In [2]:
!pip install torch

Collecting torch
  Downloading torch-2.9.1-cp311-none-macosx_11_0_arm64.whl.metadata (30 kB)
Collecting filelock (from torch)
  Downloading filelock-3.20.0-py3-none-any.whl.metadata (2.1 kB)
Collecting sympy>=1.13.3 (from torch)
  Using cached sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx>=2.5.1 (from torch)
  Downloading networkx-3.6-py3-none-any.whl.metadata (6.8 kB)
Collecting jinja2 (from torch)
  Using cached jinja2-3.1.6-py3-none-any.whl.metadata (2.9 kB)
Collecting fsspec>=0.8.5 (from torch)
  Downloading fsspec-2025.10.0-py3-none-any.whl.metadata (10 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy>=1.13.3->torch)
  Using cached mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Collecting MarkupSafe>=2.0 (from jinja2->torch)
  Using cached markupsafe-3.0.3-cp311-cp311-macosx_11_0_arm64.whl.metadata (2.7 kB)
Downloading torch-2.9.1-cp311-none-macosx_11_0_arm64.whl (74.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m74.5/74.5 MB[0m [31m82.5 k

In [3]:
import sys
sys.path.append('../src')
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from vit_model import VisionTransformer
from preprocessing import prepare_dl_data


In [4]:
data_path = '../data_dl/train'
X, y = prepare_dl_data(data_path)
print(f'Data shape: {X.shape}, Labels: {y.shape}')
print(f'Class distribution: {np.bincount(y)}')


Data shape: (25000, 224, 224, 3), Labels: (25000,)
Class distribution: [12500 12500]


In [8]:
from typing import Any


class ImageDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.FloatTensor(X).permute(0, 3, 1, 2)
        self.y = torch.LongTensor(y)
    def __len__(self):
        return len(self.y)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

indices = np.random.permutation(len(X))
split = int(0.8 * len(X))
train_idx, val_idx = indices[:split], indices[split:]

train_dataset = ImageDataset(X[train_idx], y[train_idx])
val_dataset = ImageDataset(X[val_idx], y[val_idx])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)


In [9]:
device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
model = VisionTransformer(img_size=224, patch_size=16, num_classes=2, 
                          embed_dim=384, depth=6, num_heads=6)
model = model.to(device)
print(f'Model on {device}')
print(f'Total parameters: {sum(p.numel() for p in model.parameters()):,}')


Model on mps
Total parameters: 11,019,650


In [10]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)

epochs = 5
for epoch in range(epochs):
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    
    batch_count = 0
    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        train_total += y_batch.size(0)
        train_correct += predicted.eq(y_batch).sum().item()
        batch_count += 1
        if batch_count % 100 == 0:
            print(f'  Batch {batch_count}/{len(train_loader)}')
    
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    
    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            val_loss += loss.item()
            _, predicted = outputs.max(1)
            val_total += y_batch.size(0)
            val_correct += predicted.eq(y_batch).sum().item()
    
    train_acc = 100. * train_correct / train_total
    val_acc = 100. * val_correct / val_total
    
    print(f'Epoch {epoch+1}/{epochs} - Train Loss: {train_loss/len(train_loader):.4f}, '
          f'Train Acc: {train_acc:.2f}%, Val Loss: {val_loss/len(val_loader):.4f}, '
          f'Val Acc: {val_acc:.2f}%')


  Batch 100/625
  Batch 200/625
  Batch 300/625
  Batch 400/625
  Batch 500/625
  Batch 600/625
Epoch 1/5 - Train Loss: 0.6963, Train Acc: 54.43%, Val Loss: 0.6600, Val Acc: 59.56%
  Batch 100/625
  Batch 200/625
  Batch 300/625
  Batch 400/625
  Batch 500/625
  Batch 600/625
Epoch 2/5 - Train Loss: 0.6609, Train Acc: 59.97%, Val Loss: 0.6797, Val Acc: 57.36%
  Batch 100/625
  Batch 200/625
  Batch 300/625
  Batch 400/625
  Batch 500/625
  Batch 600/625
Epoch 3/5 - Train Loss: 0.6393, Train Acc: 62.98%, Val Loss: 0.6359, Val Acc: 63.24%
  Batch 100/625
  Batch 200/625
  Batch 300/625
  Batch 400/625
  Batch 500/625
  Batch 600/625
Epoch 4/5 - Train Loss: 0.6313, Train Acc: 63.62%, Val Loss: 0.6436, Val Acc: 63.66%
  Batch 100/625
  Batch 200/625
  Batch 300/625
  Batch 400/625
  Batch 500/625
  Batch 600/625
Epoch 5/5 - Train Loss: 0.6118, Train Acc: 66.05%, Val Loss: 0.6099, Val Acc: 66.04%


In [11]:
model.eval()
with torch.no_grad():
    X_test, y_test = next(iter(val_loader))
    X_test, y_test = X_test.to(device), y_test.to(device)
    outputs = model(X_test)
    _, predicted = outputs.max(1)
    probs = torch.softmax(outputs, dim=1)
    
    print('Sample predictions:')
    for i in range(len(y_test)):
        true_label = 'cat' if y_test[i] == 0 else 'dog'
        pred_label = 'cat' if predicted[i] == 0 else 'dog'
        confidence = probs[i][predicted[i]].item() * 100
        print(f'True: {true_label}, Pred: {pred_label}, Confidence: {confidence:.2f}%')


Sample predictions:
True: cat, Pred: cat, Confidence: 65.62%
True: dog, Pred: cat, Confidence: 63.68%
True: cat, Pred: cat, Confidence: 52.38%
True: cat, Pred: cat, Confidence: 78.83%
True: dog, Pred: dog, Confidence: 54.71%
True: dog, Pred: dog, Confidence: 52.71%
True: dog, Pred: dog, Confidence: 57.68%
True: dog, Pred: cat, Confidence: 50.52%
True: dog, Pred: cat, Confidence: 59.05%
True: cat, Pred: dog, Confidence: 50.57%
True: dog, Pred: dog, Confidence: 70.01%
True: dog, Pred: dog, Confidence: 50.31%
True: cat, Pred: dog, Confidence: 50.22%
True: dog, Pred: cat, Confidence: 51.28%
True: dog, Pred: dog, Confidence: 53.89%
True: cat, Pred: cat, Confidence: 54.29%
True: dog, Pred: cat, Confidence: 60.17%
True: dog, Pred: dog, Confidence: 88.49%
True: dog, Pred: cat, Confidence: 67.38%
True: dog, Pred: dog, Confidence: 54.10%
True: cat, Pred: cat, Confidence: 68.28%
True: dog, Pred: cat, Confidence: 72.56%
True: dog, Pred: cat, Confidence: 61.57%
True: dog, Pred: dog, Confidence: 56.

In [12]:
%pip install timm

Collecting timm
[0m  Downloading timm-1.0.22-py3-none-any.whl.metadata (63 kB)
Collecting torchvision (from timm)
  Downloading torchvision-0.24.1-cp311-cp311-macosx_11_0_arm64.whl.metadata (5.9 kB)
Collecting pyyaml (from timm)
  Downloading pyyaml-6.0.3-cp311-cp311-macosx_11_0_arm64.whl.metadata (2.4 kB)
Collecting huggingface_hub (from timm)
  Downloading huggingface_hub-1.1.7-py3-none-any.whl.metadata (13 kB)
Collecting safetensors (from timm)
  Downloading safetensors-0.7.0-cp38-abi3-macosx_11_0_arm64.whl.metadata (4.1 kB)
Collecting hf-xet<2.0.0,>=1.2.0 (from huggingface_hub->timm)
  Downloading hf_xet-1.2.0-cp37-abi3-macosx_11_0_arm64.whl.metadata (4.9 kB)
Collecting httpx<1,>=0.23.0 (from huggingface_hub->timm)
  Using cached httpx-0.28.1-py3-none-any.whl.metadata (7.1 kB)
Collecting shellingham (from huggingface_hub->timm)
  Using cached shellingham-1.5.4-py2.py3-none-any.whl.metadata (3.5 kB)
Collecting tqdm>=4.42.1 (from huggingface_hub->timm)
  Using cached tqdm-4.67.1-py3

In [13]:
import timm
pretrained_model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=2)
pretrained_model = pretrained_model.to(device)
print(f'Pretrained model on {device}')
print(f'Total parameters: {sum(p.numel() for p in pretrained_model.parameters()):,}')


  from .autonotebook import tqdm as notebook_tqdm


Pretrained model on mps
Total parameters: 85,800,194


In [14]:
criterion_pretrained = nn.CrossEntropyLoss()
optimizer_pretrained = torch.optim.AdamW(pretrained_model.parameters(), lr=1e-4, weight_decay=0.01)

epochs_pretrained = 3
for epoch in range(epochs_pretrained):
    pretrained_model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    
    batch_count = 0
    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer_pretrained.zero_grad()
        outputs = pretrained_model(X_batch)
        loss = criterion_pretrained(outputs, y_batch)
        loss.backward()
        optimizer_pretrained.step()
        
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        train_total += y_batch.size(0)
        train_correct += predicted.eq(y_batch).sum().item()
        batch_count += 1
        if batch_count % 100 == 0:
            print(f'  Batch {batch_count}/{len(train_loader)}')
    
    pretrained_model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    
    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            outputs = pretrained_model(X_batch)
            loss = criterion_pretrained(outputs, y_batch)
            val_loss += loss.item()
            _, predicted = outputs.max(1)
            val_total += y_batch.size(0)
            val_correct += predicted.eq(y_batch).sum().item()
    
    train_acc = 100. * train_correct / train_total
    val_acc = 100. * val_correct / val_total
    
    print(f'PRETRAINED Epoch {epoch+1}/{epochs_pretrained} - Train Loss: {train_loss/len(train_loader):.4f}, '
          f'Train Acc: {train_acc:.2f}%, Val Loss: {val_loss/len(val_loader):.4f}, '
          f'Val Acc: {val_acc:.2f}%')


  Batch 100/625
  Batch 200/625
  Batch 300/625
  Batch 400/625
  Batch 500/625
  Batch 600/625
PRETRAINED Epoch 1/3 - Train Loss: 0.0938, Train Acc: 96.47%, Val Loss: 0.0775, Val Acc: 97.02%


KeyboardInterrupt: 

In [None]:
train_acc