# Training a ViT based classifier for 20 classes of ImageNet dataset

## Setup: Imports and Device Configuration

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import os
import shutil
from tqdm import tqdm
import time
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


## Configuration and Hyperparameters

In [None]:
FULL_IMAGENET_PATH = '/path/to/your/imagenet' 
SUBSET_PATH = './ImageNet20'
NUM_CLASSES = 20
IMAGE_SIZE = 224
PATCH_SIZE = 16
NUM_CHANNELS = 3
D_MODEL = 384
NUM_HEADS = 6
NUM_LAYERS = 6
MLP_RATIO = 4
BATCH_SIZE = 64
EPOCHS = 50
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 0.05

## Dataset Preparation: Creating the ImageNet Subset

In [None]:
import os
from datasets import load_dataset, DatasetDict

IMAGENET_20_SYNSETS = [
    'n02113186','n02099601','n02123045','n02124075','n02871525','n03085013',
    'n03126707','n03417042','n03445777','n03770679','n03888257','n03930630',
    'n04141975','n04209133','n04254680','n01855672','n01514859','n02410509',
    'n02422699','n02480495'
]

SYNSET_TO_HUMAN_LABEL = {
    'n01514859': 'cock',
    'n01855672': 'goose',
    'n02099601': 'Eskimo dog, husky',
    'n02113186': 'Cardigan, Cardigan Welsh corgi',
    'n02123045': 'tabby, tabby cat',
    'n02124075': 'Egyptian cat',
    'n02410509': 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
    'n02422699': 'impala, Aepyceros melampus',
    'n02480495': 'gorilla, Gorilla gorilla',
    'n02871525': 'bookshop, bookstore, bookstall',
    'n03085013': 'computer keyboard, keypad',
    'n03126707': 'crane',
    'n03417042': 'garbage truck, dustcart',
    'n03445777': 'golf ball',
    'n03770679': 'minibus',
    'n03888257': 'parachute, chute',
    'n03930630': 'pizza, pizza pie',
    'n04141975': 'safe',
    'n04209133': 'snowplow, snowplough',
    'n04254680': 'sports car, sport car'
}

OUT = "./ImageNet20_hf"

if not os.path.exists(OUT):
    print("Preparing dataset for the first time...")
    ds_id = "benjamin-paine/imagenet-1k-256x256"
    train_full = load_dataset(ds_id, split="train")
    val_full   = load_dataset(ds_id, split="validation")
    all_class_names = train_full.features["label"].names
    name_to_id = {name: i for i, name in enumerate(all_class_names)}
    TARGET_CLASS_NAMES = [SYNSET_TO_HUMAN_LABEL[s] for s in IMAGENET_20_SYNSETS]
    missing = [name for name in TARGET_CLASS_NAMES if name not in name_to_id]
    if missing:
        raise RuntimeError(f"Could not find the following class names in the dataset: {missing}")

    tgt_ids = {name_to_id[name] for name in TARGET_CLASS_NAMES}
    
    print("Filtering for 20 classes...")
    train_20 = train_full.filter(lambda ex: ex["label"] in tgt_ids, num_proc=4)
    val_20   = val_full.filter(lambda ex: ex["label"] in tgt_ids, num_proc=4)

    print("Remapping labels to 0-19 range...")
    sorted_target_names = [SYNSET_TO_HUMAN_LABEL[s] for s in sorted(IMAGENET_20_SYNSETS)]
    remap = {name_to_id[name]: i for i, name in enumerate(sorted_target_names)}
    
    train_final = train_20.map(lambda ex: {"label": remap[ex["label"]]}, num_proc=4)
    val_20_remapped = val_20.map(lambda ex: {"label": remap[ex["label"]]}, num_proc=4)

    print("Splitting validation set into validation and test sets...")
    val_test_split = val_20_remapped.train_test_split(test_size=0.5, seed=42, stratify_by_column="label")
    
    final_dataset = DatasetDict({
        "train": train_final, 
        "val": val_test_split['train'], 
        "test": val_test_split['test']
    })

    final_dataset.save_to_disk(OUT)
    print(f"--- Dataset saved to {OUT} ---")
else:
    print(f"Dataset already exists at {OUT}. Skipping preparation.")

Dataset already exists at ./ImageNet20_hf. Skipping preparation.


## Data Loading: Transforms and DataLoaders

In [None]:
from datasets import load_from_disk
from torchvision import transforms
from torch.utils.data import DataLoader
import torch

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

val_test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

final_dataset = load_from_disk(OUT)
train_dataset_hf = final_dataset['train']
val_dataset_hf = final_dataset['val']
test_dataset_hf = final_dataset['test']

def apply_train_transforms(examples):
    examples['pixel_values'] = [train_transform(image.convert("RGB")) for image in examples['image']]
    return examples

def apply_val_test_transforms(examples):
    examples['pixel_values'] = [val_test_transform(image.convert("RGB")) for image in examples['image']]
    return examples

train_dataset_hf.set_transform(apply_train_transforms)
val_dataset_hf.set_transform(apply_val_test_transforms)
test_dataset_hf.set_transform(apply_val_test_transforms)

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['label'] for x in batch])
    }

train_loader = DataLoader(train_dataset_hf, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset_hf, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset_hf, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate_fn)

print("\n--- DataLoaders Ready ---")
print(f"Training samples:   {len(train_dataset_hf)}")
print(f"Validation samples: {len(val_dataset_hf)}")
print(f"Test samples:       {len(test_dataset_hf)}")


--- DataLoaders Ready ---
Training samples:   25729
Validation samples: 500
Test samples:       500


## Vision Transformer (ViT) Model Implementation

In [None]:
class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, d_model):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2)
        x = x.transpose(1, 2)
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.dropout(x)
        return x

class MLP(nn.Module):
    def __init__(self, d_model, mlp_ratio, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, int(d_model * mlp_ratio))
        self.act = nn.GELU()
        self.fc2 = nn.Linear(int(d_model * mlp_ratio), d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class TransformerEncoder(nn.Module):
    def __init__(self, d_model, n_heads, mlp_ratio, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model, mlp_ratio, dropout)
    
    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class VisionTransformer(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, n_classes, d_model, n_heads, n_layers, mlp_ratio):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, d_model)
        num_patches = (img_size // patch_size) ** 2

        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, d_model))
        
        self.encoder = nn.Sequential(*[
            TransformerEncoder(d_model, n_heads, mlp_ratio) for _ in range(n_layers)
        ])
        
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, n_classes)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        
        x = self.encoder(x)
        x = self.norm(x)
        
        cls_token_final = x[:, 0]
        x = self.head(cls_token_final)
        
        return x

## Training Setup: Model, Optimizer, Loss, and Scaler

In [8]:
model = VisionTransformer(
    img_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    in_channels=NUM_CHANNELS,
    n_classes=NUM_CLASSES,
    d_model=D_MODEL,
    n_heads=NUM_HEADS,
    n_layers=NUM_LAYERS,
    mlp_ratio=MLP_RATIO
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {total_params / 1e6:.2f}M")

Total trainable parameters: 11.03M


## Defining the Training and Validation Loops

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, scaler, device):
    model.train()
    running_loss = 0.0
    
    loop = tqdm(loader, desc="Training")
    for batch in loop:
        images = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)
        
        optimizer.zero_grad()
        
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            outputs = model(images)
            loss = criterion(outputs, labels)
            
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        running_loss += loss.item() * images.size(0)
        loop.set_postfix(loss=loss.item())

    return running_loss / len(loader.dataset)

def validate(model, loader, criterion, device):
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        loop = tqdm(loader, desc="Validating")
        for batch in loop:
            images = batch['pixel_values'].to(device)
            labels = batch['labels'].to(device)
            
            with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
                outputs = model(images)
                loss = criterion(outputs, labels)

            val_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
    accuracy = 100 * correct / total
    return val_loss / len(loader.dataset), accuracy

## Running the Full Training Process

In [17]:
best_val_acc = 0.0
history = {'train_loss': [], 'val_loss': [], 'val_acc': []}

print("Starting training...")
start_time = time.time()

for epoch in range(EPOCHS):
    epoch_start_time = time.time()
    
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, scaler, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    
    epoch_duration = time.time() - epoch_start_time
    
    print(f"Epoch {epoch+1}/{EPOCHS} | "
          f"Train Loss: {train_loss:.4f} | "
          f"Val Loss: {val_loss:.4f} | "
          f"Val Acc: {val_acc:.2f}% | "
          f"Time: {epoch_duration:.2f}s")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'vit_best_model.pth')
        print(f"New best model saved with accuracy: {best_val_acc:.2f}%")

total_training_time = time.time() - start_time
print(f"\nTraining finished in {total_training_time/60:.2f} minutes.")
print(f"Best validation accuracy: {best_val_acc:.2f}%")

Starting training...


Training: 100%|██████████| 805/805 [02:21<00:00,  5.69it/s, loss=2.48]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.09it/s]


Epoch 1/50 | Train Loss: 5.1865 | Val Loss: 4.9396 | Val Acc: 5.40% | Time: 142.76s
New best model saved with accuracy: 5.40%


Training: 100%|██████████| 805/805 [02:23<00:00,  5.61it/s, loss=1.98]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.54it/s]


Epoch 2/50 | Train Loss: 4.6815 | Val Loss: 4.4494 | Val Acc: 5.60% | Time: 144.78s
New best model saved with accuracy: 5.60%


Training: 100%|██████████| 805/805 [02:23<00:00,  5.59it/s, loss=4.3] 
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.42it/s]


Epoch 3/50 | Train Loss: 4.2037 | Val Loss: 3.9892 | Val Acc: 6.60% | Time: 145.11s
New best model saved with accuracy: 6.60%


Training: 100%|██████████| 805/805 [02:24<00:00,  5.59it/s, loss=4.56]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.21it/s]


Epoch 4/50 | Train Loss: 3.7412 | Val Loss: 3.5558 | Val Acc: 8.00% | Time: 145.33s
New best model saved with accuracy: 8.00%


Training: 100%|██████████| 805/805 [02:24<00:00,  5.58it/s, loss=3.95]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.51it/s]


Epoch 5/50 | Train Loss: 3.3874 | Val Loss: 3.2745 | Val Acc: 8.00% | Time: 145.54s


Training: 100%|██████████| 805/805 [02:25<00:00,  5.53it/s, loss=2.62]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.13it/s]


Epoch 6/50 | Train Loss: 3.1765 | Val Loss: 3.1048 | Val Acc: 8.60% | Time: 146.78s
New best model saved with accuracy: 8.60%


Training: 100%|██████████| 805/805 [02:24<00:00,  5.55it/s, loss=2.93]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.61it/s]


Epoch 7/50 | Train Loss: 3.0705 | Val Loss: 3.0114 | Val Acc: 10.60% | Time: 146.11s
New best model saved with accuracy: 10.60%


Training: 100%|██████████| 805/805 [02:25<00:00,  5.53it/s, loss=3.07]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.26it/s]


Epoch 8/50 | Train Loss: 2.9975 | Val Loss: 2.9572 | Val Acc: 11.40% | Time: 146.87s
New best model saved with accuracy: 11.40%


Training: 100%|██████████| 805/805 [02:25<00:00,  5.55it/s, loss=2.8] 
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.39it/s]


Epoch 9/50 | Train Loss: 2.9270 | Val Loss: 2.8898 | Val Acc: 13.80% | Time: 146.24s
New best model saved with accuracy: 13.80%


Training: 100%|██████████| 805/805 [02:25<00:00,  5.55it/s, loss=2.92]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.76it/s]


Epoch 10/50 | Train Loss: 2.8673 | Val Loss: 2.8140 | Val Acc: 15.20% | Time: 146.33s
New best model saved with accuracy: 15.20%


Training: 100%|██████████| 805/805 [02:25<00:00,  5.54it/s, loss=3.29]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.29it/s]


Epoch 11/50 | Train Loss: 2.7924 | Val Loss: 2.7187 | Val Acc: 16.80% | Time: 146.56s
New best model saved with accuracy: 16.80%


Training: 100%|██████████| 805/805 [02:25<00:00,  5.54it/s, loss=1.75]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.19it/s]


Epoch 12/50 | Train Loss: 2.6456 | Val Loss: 2.5333 | Val Acc: 23.60% | Time: 146.65s
New best model saved with accuracy: 23.60%


Training: 100%|██████████| 805/805 [02:24<00:00,  5.56it/s, loss=1.6] 
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.56it/s]


Epoch 13/50 | Train Loss: 2.4501 | Val Loss: 2.3177 | Val Acc: 27.60% | Time: 146.08s
New best model saved with accuracy: 27.60%


Training: 100%|██████████| 805/805 [02:25<00:00,  5.54it/s, loss=3.09]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.56it/s]


Epoch 14/50 | Train Loss: 2.3084 | Val Loss: 2.2055 | Val Acc: 32.00% | Time: 146.41s
New best model saved with accuracy: 32.00%


Training: 100%|██████████| 805/805 [02:25<00:00,  5.52it/s, loss=3.32]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.40it/s]


Epoch 15/50 | Train Loss: 2.1944 | Val Loss: 2.0724 | Val Acc: 35.40% | Time: 147.03s
New best model saved with accuracy: 35.40%


Training: 100%|██████████| 805/805 [02:25<00:00,  5.54it/s, loss=0.196]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.73it/s]


Epoch 16/50 | Train Loss: 2.0830 | Val Loss: 1.9554 | Val Acc: 37.60% | Time: 146.60s
New best model saved with accuracy: 37.60%


Training: 100%|██████████| 805/805 [02:25<00:00,  5.54it/s, loss=1.55]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.45it/s]


Epoch 17/50 | Train Loss: 1.9782 | Val Loss: 1.9041 | Val Acc: 40.80% | Time: 146.47s
New best model saved with accuracy: 40.80%


Training: 100%|██████████| 805/805 [02:25<00:00,  5.53it/s, loss=1.69]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.10it/s]


Epoch 18/50 | Train Loss: 1.8830 | Val Loss: 1.8412 | Val Acc: 40.20% | Time: 146.67s


Training: 100%|██████████| 805/805 [02:25<00:00,  5.54it/s, loss=0.271]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.59it/s]


Epoch 19/50 | Train Loss: 1.7976 | Val Loss: 1.7249 | Val Acc: 44.20% | Time: 146.50s
New best model saved with accuracy: 44.20%


Training: 100%|██████████| 805/805 [02:25<00:00,  5.52it/s, loss=0.635]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.12it/s]


Epoch 20/50 | Train Loss: 1.7276 | Val Loss: 1.7812 | Val Acc: 42.60% | Time: 147.12s


Training: 100%|██████████| 805/805 [02:26<00:00,  5.51it/s, loss=2.08]
Validating: 100%|██████████| 16/16 [00:01<00:00, 12.99it/s]


Epoch 21/50 | Train Loss: 1.6714 | Val Loss: 1.7015 | Val Acc: 47.40% | Time: 147.28s
New best model saved with accuracy: 47.40%


Training: 100%|██████████| 805/805 [02:25<00:00,  5.55it/s, loss=0.265]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.29it/s]


Epoch 22/50 | Train Loss: 1.6203 | Val Loss: 1.6460 | Val Acc: 45.80% | Time: 146.33s


Training: 100%|██████████| 805/805 [02:25<00:00,  5.52it/s, loss=3.42] 
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.59it/s]


Epoch 23/50 | Train Loss: 1.5781 | Val Loss: 1.5028 | Val Acc: 52.00% | Time: 147.13s
New best model saved with accuracy: 52.00%


Training: 100%|██████████| 805/805 [02:26<00:00,  5.51it/s, loss=1.85] 
Validating: 100%|██████████| 16/16 [00:01<00:00, 14.03it/s]


Epoch 24/50 | Train Loss: 1.5297 | Val Loss: 1.4936 | Val Acc: 53.20% | Time: 147.23s
New best model saved with accuracy: 53.20%


Training: 100%|██████████| 805/805 [02:25<00:00,  5.54it/s, loss=0.902]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.32it/s]


Epoch 25/50 | Train Loss: 1.4876 | Val Loss: 1.6037 | Val Acc: 51.20% | Time: 146.39s


Training: 100%|██████████| 805/805 [02:25<00:00,  5.54it/s, loss=2.45] 
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.27it/s]


Epoch 26/50 | Train Loss: 1.4475 | Val Loss: 1.5906 | Val Acc: 50.20% | Time: 146.45s


Training: 100%|██████████| 805/805 [02:25<00:00,  5.53it/s, loss=1.61] 
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.48it/s]


Epoch 27/50 | Train Loss: 1.4337 | Val Loss: 1.3614 | Val Acc: 56.60% | Time: 146.89s
New best model saved with accuracy: 56.60%


Training: 100%|██████████| 805/805 [02:25<00:00,  5.53it/s, loss=1.18] 
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.39it/s]


Epoch 28/50 | Train Loss: 1.3829 | Val Loss: 1.3897 | Val Acc: 57.80% | Time: 146.66s
New best model saved with accuracy: 57.80%


Training: 100%|██████████| 805/805 [02:25<00:00,  5.54it/s, loss=1.24] 
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.61it/s]


Epoch 29/50 | Train Loss: 1.3493 | Val Loss: 1.4035 | Val Acc: 56.20% | Time: 146.62s


Training: 100%|██████████| 805/805 [02:24<00:00,  5.57it/s, loss=0.981]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.72it/s]


Epoch 30/50 | Train Loss: 1.3227 | Val Loss: 1.3677 | Val Acc: 57.00% | Time: 145.75s


Training: 100%|██████████| 805/805 [02:25<00:00,  5.53it/s, loss=1.38] 
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.55it/s]


Epoch 31/50 | Train Loss: 1.2941 | Val Loss: 1.3074 | Val Acc: 57.80% | Time: 146.87s


Training: 100%|██████████| 805/805 [02:25<00:00,  5.53it/s, loss=1.28] 
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.56it/s]


Epoch 32/50 | Train Loss: 1.2601 | Val Loss: 1.3687 | Val Acc: 56.00% | Time: 146.81s


Training: 100%|██████████| 805/805 [02:25<00:00,  5.53it/s, loss=2.63] 
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.38it/s]


Epoch 33/50 | Train Loss: 1.2596 | Val Loss: 1.2847 | Val Acc: 59.20% | Time: 146.79s
New best model saved with accuracy: 59.20%


Training: 100%|██████████| 805/805 [02:25<00:00,  5.54it/s, loss=0.015]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.90it/s]


Epoch 34/50 | Train Loss: 1.2199 | Val Loss: 1.3388 | Val Acc: 57.60% | Time: 146.50s


Training: 100%|██████████| 805/805 [02:25<00:00,  5.53it/s, loss=5.07] 
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.62it/s]


Epoch 35/50 | Train Loss: 1.1973 | Val Loss: 1.2977 | Val Acc: 58.80% | Time: 146.87s


Training: 100%|██████████| 805/805 [02:25<00:00,  5.52it/s, loss=0.559]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.53it/s]


Epoch 36/50 | Train Loss: 1.1637 | Val Loss: 1.2484 | Val Acc: 57.80% | Time: 146.97s


Training: 100%|██████████| 805/805 [02:26<00:00,  5.51it/s, loss=0.829]
Validating: 100%|██████████| 16/16 [00:01<00:00, 12.65it/s]


Epoch 37/50 | Train Loss: 1.1563 | Val Loss: 1.2263 | Val Acc: 60.20% | Time: 147.34s
New best model saved with accuracy: 60.20%


Training: 100%|██████████| 805/805 [02:25<00:00,  5.54it/s, loss=0.00495]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.51it/s]


Epoch 38/50 | Train Loss: 1.1358 | Val Loss: 1.1451 | Val Acc: 62.80% | Time: 146.59s
New best model saved with accuracy: 62.80%


Training: 100%|██████████| 805/805 [02:25<00:00,  5.52it/s, loss=0.526]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.35it/s]


Epoch 39/50 | Train Loss: 1.1109 | Val Loss: 1.2543 | Val Acc: 60.80% | Time: 147.17s


Training: 100%|██████████| 805/805 [02:26<00:00,  5.51it/s, loss=2.12] 
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.34it/s]


Epoch 40/50 | Train Loss: 1.0914 | Val Loss: 1.1918 | Val Acc: 62.00% | Time: 147.35s


Training: 100%|██████████| 805/805 [02:26<00:00,  5.51it/s, loss=4.05] 
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.63it/s]


Epoch 41/50 | Train Loss: 1.0627 | Val Loss: 1.1260 | Val Acc: 65.00% | Time: 147.20s
New best model saved with accuracy: 65.00%


Training: 100%|██████████| 805/805 [02:25<00:00,  5.54it/s, loss=0.122]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.48it/s]


Epoch 42/50 | Train Loss: 1.0537 | Val Loss: 1.1326 | Val Acc: 63.00% | Time: 146.59s


Training: 100%|██████████| 805/805 [02:25<00:00,  5.54it/s, loss=0.544]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.45it/s]


Epoch 43/50 | Train Loss: 1.0304 | Val Loss: 1.1499 | Val Acc: 61.20% | Time: 146.63s


Training: 100%|██████████| 805/805 [02:25<00:00,  5.51it/s, loss=0.0123]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.45it/s]


Epoch 44/50 | Train Loss: 1.0148 | Val Loss: 1.1420 | Val Acc: 66.00% | Time: 147.18s
New best model saved with accuracy: 66.00%


Training: 100%|██████████| 805/805 [02:25<00:00,  5.52it/s, loss=0.0431]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.62it/s]


Epoch 45/50 | Train Loss: 0.9960 | Val Loss: 1.1471 | Val Acc: 64.00% | Time: 146.95s


Training: 100%|██████████| 805/805 [02:25<00:00,  5.54it/s, loss=0.0923]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.26it/s]


Epoch 46/50 | Train Loss: 0.9744 | Val Loss: 1.1911 | Val Acc: 61.00% | Time: 146.59s


Training: 100%|██████████| 805/805 [02:25<00:00,  5.53it/s, loss=0.0099]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.23it/s]


Epoch 47/50 | Train Loss: 0.9588 | Val Loss: 1.1864 | Val Acc: 64.40% | Time: 146.74s


Training: 100%|██████████| 805/805 [02:25<00:00,  5.52it/s, loss=0.0894]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.63it/s]


Epoch 48/50 | Train Loss: 0.9353 | Val Loss: 1.1107 | Val Acc: 64.80% | Time: 146.92s


Training: 100%|██████████| 805/805 [02:25<00:00,  5.53it/s, loss=0.292]
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.45it/s]


Epoch 49/50 | Train Loss: 0.9329 | Val Loss: 1.1288 | Val Acc: 65.80% | Time: 146.89s


Training: 100%|██████████| 805/805 [02:25<00:00,  5.53it/s, loss=1.1]  
Validating: 100%|██████████| 16/16 [00:01<00:00, 13.24it/s]

Epoch 50/50 | Train Loss: 0.9075 | Val Loss: 1.1404 | Val Acc: 65.20% | Time: 146.66s

Training finished in 122.23 minutes.
Best validation accuracy: 66.00%





## Experiment with Different Number of heads

In [None]:
HEADS_TO_TEST = [4, 8]
experiment_results = {}
for n_heads in HEADS_TO_TEST:
    print(f"\n{'='*50}")
    print(f"  STARTING EXPERIMENT: {n_heads} ATTENTION HEADS")
    print(f"{'='*50}\n")
    if D_MODEL % n_heads != 0:
        print(f"Skipping {n_heads} heads: D_MODEL ({D_MODEL}) is not divisible by {n_heads}.")
        continue

    model = VisionTransformer(
        img_size=IMAGE_SIZE,
        patch_size=PATCH_SIZE,
        in_channels=NUM_CHANNELS,
        n_classes=NUM_CLASSES,
        d_model=D_MODEL,
        n_heads=n_heads,
        n_layers=NUM_LAYERS,
        mlp_ratio=MLP_RATIO
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model with {n_heads} heads has {total_params / 1e6:.2f}M trainable parameters.")

    best_val_acc = 0.0
    model_save_path = f'vit_heads_{n_heads}_best.pth'
    history = {'train_loss': [], 'val_loss': [], 'val_acc': []}

    print(f"Starting training for {n_heads}-head model...")
    start_time = time.time()

    for epoch in range(EPOCHS):
        epoch_start_time = time.time()
        
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, scaler, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        
        history[f'train_loss_{n_heads}'] = train_loss
        history[f'val_loss_{n_heads}'] = val_loss
        history[f'val_acc_{n_heads}'] = val_acc
        
        epoch_duration = time.time() - epoch_start_time
        
        print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | Time: {epoch_duration:.2f}s")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), model_save_path)
            print(f"--> New best model saved to {model_save_path} with accuracy: {best_val_acc:.2f}%")

    total_training_time = time.time() - start_time
    print(f"\nTraining for {n_heads}-head model finished in {total_training_time/60:.2f} minutes.")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    print(f"\n--- Evaluating best {n_heads}-head model on the TEST set ---")
    final_model = VisionTransformer(img_size=IMAGE_SIZE, patch_size=PATCH_SIZE, in_channels=NUM_CHANNELS,
                                    n_classes=NUM_CLASSES, d_model=D_MODEL, n_heads=n_heads,
                                    n_layers=NUM_LAYERS, mlp_ratio=MLP_RATIO).to(device)
    final_model.load_state_dict(torch.load(model_save_path))
    
    test_loss, test_acc = validate(final_model, test_loader, criterion, device)
    print(f"Final Test Accuracy for {n_heads} heads: {test_acc:.2f}%")
    experiment_results[n_heads] = {
        'best_val_acc': best_val_acc,
        'test_acc': test_acc,
        'training_time_min': total_training_time / 60
    }
print(f"\n\n{'='*50}")
print(f"  EXPERIMENT SUMMARY: EFFECT OF NUMBER OF HEADS")
print(f"{'='*50}")
print(f"{'Heads':<10} | {'Best Val Acc (%)':<20} | {'Final Test Acc (%)':<20} | {'Train Time (min)':<20}")
print(f"-"*75)
for n_heads, results in experiment_results.items():
    print(f"{n_heads:<10} | {results['best_val_acc']:<20.2f} | {results['test_acc']:<20.2f} | {results['training_time_min']:<20.2f}")


  STARTING EXPERIMENT: 4 ATTENTION HEADS

Model with 4 heads has 11.03M trainable parameters.
Starting training for 4-head model...


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=2.14]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.68it/s]


Epoch 1/50 | Train Loss: 2.4379 | Val Loss: 2.0361 | Val Acc: 37.20% | Time: 135.44s
--> New best model saved to vit_heads_4_best.pth with accuracy: 37.20%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.02it/s, loss=0.326]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.71it/s]


Epoch 2/50 | Train Loss: 2.0386 | Val Loss: 1.8754 | Val Acc: 41.00% | Time: 134.55s
--> New best model saved to vit_heads_4_best.pth with accuracy: 41.00%


Training: 100%|██████████| 403/403 [02:14<00:00,  2.99it/s, loss=0.662]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.47it/s]


Epoch 3/50 | Train Loss: 1.8834 | Val Loss: 1.9238 | Val Acc: 38.60% | Time: 136.01s


Training: 100%|██████████| 403/403 [02:15<00:00,  2.97it/s, loss=1.12]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.82it/s]


Epoch 4/50 | Train Loss: 1.7836 | Val Loss: 1.7220 | Val Acc: 47.80% | Time: 136.81s
--> New best model saved to vit_heads_4_best.pth with accuracy: 47.80%


Training: 100%|██████████| 403/403 [02:16<00:00,  2.95it/s, loss=2.35]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.63it/s]


Epoch 5/50 | Train Loss: 1.7077 | Val Loss: 1.6764 | Val Acc: 47.00% | Time: 137.80s


Training: 100%|██████████| 403/403 [02:15<00:00,  2.97it/s, loss=0.271]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.87it/s]


Epoch 6/50 | Train Loss: 1.6508 | Val Loss: 1.6148 | Val Acc: 48.60% | Time: 137.06s
--> New best model saved to vit_heads_4_best.pth with accuracy: 48.60%


Training: 100%|██████████| 403/403 [02:15<00:00,  2.97it/s, loss=2.23]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.50it/s]


Epoch 7/50 | Train Loss: 1.5972 | Val Loss: 1.7363 | Val Acc: 46.40% | Time: 136.96s


Training: 100%|██████████| 403/403 [02:16<00:00,  2.96it/s, loss=0.532]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.41it/s]


Epoch 8/50 | Train Loss: 1.6054 | Val Loss: 1.5868 | Val Acc: 50.60% | Time: 137.47s
--> New best model saved to vit_heads_4_best.pth with accuracy: 50.60%


Training: 100%|██████████| 403/403 [02:16<00:00,  2.95it/s, loss=1.5] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.70it/s]


Epoch 9/50 | Train Loss: 1.5384 | Val Loss: 1.5407 | Val Acc: 51.00% | Time: 137.69s
--> New best model saved to vit_heads_4_best.pth with accuracy: 51.00%


Training: 100%|██████████| 403/403 [02:15<00:00,  2.96it/s, loss=1.54]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.67it/s]


Epoch 10/50 | Train Loss: 1.4949 | Val Loss: 1.6175 | Val Acc: 48.60% | Time: 137.18s


Training: 100%|██████████| 403/403 [02:16<00:00,  2.96it/s, loss=0.281]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.86it/s]


Epoch 11/50 | Train Loss: 1.4828 | Val Loss: 1.5162 | Val Acc: 53.80% | Time: 137.47s
--> New best model saved to vit_heads_4_best.pth with accuracy: 53.80%


Training: 100%|██████████| 403/403 [02:16<00:00,  2.95it/s, loss=1.51]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.80it/s]


Epoch 12/50 | Train Loss: 1.4312 | Val Loss: 1.4961 | Val Acc: 52.80% | Time: 137.74s


Training: 100%|██████████| 403/403 [02:15<00:00,  2.96it/s, loss=0.599]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.92it/s]


Epoch 13/50 | Train Loss: 1.4250 | Val Loss: 1.4060 | Val Acc: 55.80% | Time: 137.12s
--> New best model saved to vit_heads_4_best.pth with accuracy: 55.80%


Training: 100%|██████████| 403/403 [02:15<00:00,  2.97it/s, loss=1.78] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.62it/s]


Epoch 14/50 | Train Loss: 1.3800 | Val Loss: 1.4630 | Val Acc: 53.20% | Time: 137.12s


Training: 100%|██████████| 403/403 [02:16<00:00,  2.95it/s, loss=3.76] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.71it/s]


Epoch 15/50 | Train Loss: 1.3854 | Val Loss: 1.3690 | Val Acc: 58.80% | Time: 137.69s
--> New best model saved to vit_heads_4_best.pth with accuracy: 58.80%


Training: 100%|██████████| 403/403 [02:16<00:00,  2.96it/s, loss=0.0241]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.89it/s]


Epoch 16/50 | Train Loss: 1.3395 | Val Loss: 1.4199 | Val Acc: 56.60% | Time: 137.26s


Training: 100%|██████████| 403/403 [02:16<00:00,  2.95it/s, loss=0.043]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.60it/s]


Epoch 17/50 | Train Loss: 1.3149 | Val Loss: 1.3953 | Val Acc: 56.20% | Time: 137.83s


Training: 100%|██████████| 403/403 [02:16<00:00,  2.95it/s, loss=1.41] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.55it/s]


Epoch 18/50 | Train Loss: 1.2831 | Val Loss: 1.3311 | Val Acc: 57.20% | Time: 137.85s


Training: 100%|██████████| 403/403 [02:16<00:00,  2.96it/s, loss=3.78] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.85it/s]


Epoch 19/50 | Train Loss: 1.2923 | Val Loss: 1.3180 | Val Acc: 57.80% | Time: 137.51s


Training: 100%|██████████| 403/403 [02:16<00:00,  2.95it/s, loss=2.23] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.45it/s]


Epoch 20/50 | Train Loss: 1.3186 | Val Loss: 1.3737 | Val Acc: 56.40% | Time: 137.77s


Training: 100%|██████████| 403/403 [02:16<00:00,  2.95it/s, loss=0.401]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.81it/s]


Epoch 21/50 | Train Loss: 1.2429 | Val Loss: 1.3278 | Val Acc: 58.00% | Time: 137.98s


Training: 100%|██████████| 403/403 [02:16<00:00,  2.96it/s, loss=1.14] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.81it/s]


Epoch 22/50 | Train Loss: 1.2250 | Val Loss: 1.4526 | Val Acc: 53.60% | Time: 137.30s


Training: 100%|██████████| 403/403 [02:15<00:00,  2.98it/s, loss=4.28] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.40it/s]


Epoch 23/50 | Train Loss: 1.2488 | Val Loss: 1.2746 | Val Acc: 60.80% | Time: 136.65s
--> New best model saved to vit_heads_4_best.pth with accuracy: 60.80%


Training: 100%|██████████| 403/403 [02:16<00:00,  2.96it/s, loss=2.89] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.45it/s]


Epoch 24/50 | Train Loss: 1.1788 | Val Loss: 1.2967 | Val Acc: 58.60% | Time: 137.27s


Training: 100%|██████████| 403/403 [02:15<00:00,  2.98it/s, loss=0.0709]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.76it/s]


Epoch 25/50 | Train Loss: 1.1962 | Val Loss: 1.2796 | Val Acc: 59.00% | Time: 136.45s


Training: 100%|██████████| 403/403 [02:15<00:00,  2.97it/s, loss=1.32] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.71it/s]


Epoch 26/50 | Train Loss: 1.1377 | Val Loss: 1.4349 | Val Acc: 55.40% | Time: 136.91s


Training: 100%|██████████| 403/403 [02:14<00:00,  2.99it/s, loss=2.9]  
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.32it/s]


Epoch 27/50 | Train Loss: 1.1578 | Val Loss: 1.3976 | Val Acc: 58.60% | Time: 136.05s


Training: 100%|██████████| 403/403 [02:14<00:00,  2.99it/s, loss=1.43] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.41it/s]


Epoch 28/50 | Train Loss: 1.1726 | Val Loss: 1.2401 | Val Acc: 59.60% | Time: 135.87s


Training: 100%|██████████| 403/403 [02:13<00:00,  3.01it/s, loss=1.34] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.78it/s]


Epoch 29/50 | Train Loss: 1.1105 | Val Loss: 1.3711 | Val Acc: 57.00% | Time: 135.12s


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.258]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.37it/s]


Epoch 30/50 | Train Loss: 1.1195 | Val Loss: 1.1710 | Val Acc: 63.00% | Time: 135.46s
--> New best model saved to vit_heads_4_best.pth with accuracy: 63.00%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=1.16] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.68it/s]


Epoch 31/50 | Train Loss: 1.0615 | Val Loss: 1.2908 | Val Acc: 62.00% | Time: 135.45s


Training: 100%|██████████| 403/403 [02:13<00:00,  3.01it/s, loss=0.0659]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.83it/s]


Epoch 32/50 | Train Loss: 1.0601 | Val Loss: 1.2210 | Val Acc: 60.80% | Time: 134.90s


Training: 100%|██████████| 403/403 [02:13<00:00,  3.03it/s, loss=0.174]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.60it/s]


Epoch 33/50 | Train Loss: 1.0299 | Val Loss: 1.1287 | Val Acc: 66.60% | Time: 134.28s
--> New best model saved to vit_heads_4_best.pth with accuracy: 66.60%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.02it/s, loss=1.84] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.68it/s]


Epoch 34/50 | Train Loss: 1.0142 | Val Loss: 1.2018 | Val Acc: 64.40% | Time: 134.67s


Training: 100%|██████████| 403/403 [02:13<00:00,  3.02it/s, loss=1.94] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.81it/s]


Epoch 35/50 | Train Loss: 1.0019 | Val Loss: 1.1808 | Val Acc: 63.40% | Time: 134.80s


Training: 100%|██████████| 403/403 [02:12<00:00,  3.04it/s, loss=4.27] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.67it/s]


Epoch 36/50 | Train Loss: 1.0081 | Val Loss: 1.1743 | Val Acc: 62.20% | Time: 133.98s


Training: 100%|██████████| 403/403 [02:13<00:00,  3.02it/s, loss=0.000309]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.55it/s]


Epoch 37/50 | Train Loss: 0.9711 | Val Loss: 1.1573 | Val Acc: 64.20% | Time: 134.84s


Training: 100%|██████████| 403/403 [02:13<00:00,  3.02it/s, loss=0.032]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.78it/s]


Epoch 38/50 | Train Loss: 0.9567 | Val Loss: 1.2032 | Val Acc: 63.60% | Time: 134.51s


Training: 100%|██████████| 403/403 [02:12<00:00,  3.03it/s, loss=1.24] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.90it/s]


Epoch 39/50 | Train Loss: 0.9283 | Val Loss: 1.3389 | Val Acc: 60.60% | Time: 134.04s


Training: 100%|██████████| 403/403 [02:12<00:00,  3.03it/s, loss=2.81] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.82it/s]


Epoch 40/50 | Train Loss: 0.9483 | Val Loss: 1.1156 | Val Acc: 65.40% | Time: 134.16s


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.852]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.84it/s]


Epoch 41/50 | Train Loss: 0.9249 | Val Loss: 1.2183 | Val Acc: 63.80% | Time: 135.71s


Training: 100%|██████████| 403/403 [02:13<00:00,  3.02it/s, loss=0.0156]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.54it/s]


Epoch 42/50 | Train Loss: 0.8841 | Val Loss: 1.0837 | Val Acc: 65.20% | Time: 134.77s


Training: 100%|██████████| 403/403 [02:13<00:00,  3.01it/s, loss=0.0123]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.63it/s]


Epoch 43/50 | Train Loss: 0.8681 | Val Loss: 1.1785 | Val Acc: 64.60% | Time: 135.15s


Training: 100%|██████████| 403/403 [02:15<00:00,  2.98it/s, loss=0.251]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.79it/s]


Epoch 44/50 | Train Loss: 0.8607 | Val Loss: 1.1380 | Val Acc: 64.60% | Time: 136.33s


Training: 100%|██████████| 403/403 [02:14<00:00,  2.99it/s, loss=0.0135]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.87it/s]


Epoch 45/50 | Train Loss: 0.8579 | Val Loss: 1.0853 | Val Acc: 66.80% | Time: 135.79s
--> New best model saved to vit_heads_4_best.pth with accuracy: 66.80%


Training: 100%|██████████| 403/403 [02:14<00:00,  2.99it/s, loss=0.213]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.69it/s]


Epoch 46/50 | Train Loss: 0.8197 | Val Loss: 1.0761 | Val Acc: 66.00% | Time: 136.10s


Training: 100%|██████████| 403/403 [02:15<00:00,  2.97it/s, loss=1.45] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.91it/s]


Epoch 47/50 | Train Loss: 0.8204 | Val Loss: 1.1148 | Val Acc: 65.00% | Time: 136.86s


Training: 100%|██████████| 403/403 [02:15<00:00,  2.98it/s, loss=3.68] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.58it/s]


Epoch 48/50 | Train Loss: 0.8183 | Val Loss: 1.1363 | Val Acc: 65.80% | Time: 136.50s


Training: 100%|██████████| 403/403 [02:15<00:00,  2.98it/s, loss=1.06] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.41it/s]


Epoch 49/50 | Train Loss: 0.8456 | Val Loss: 1.2279 | Val Acc: 61.80% | Time: 136.55s


Training: 100%|██████████| 403/403 [02:16<00:00,  2.96it/s, loss=0.115]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.35it/s]


Epoch 50/50 | Train Loss: 0.8136 | Val Loss: 1.1997 | Val Acc: 64.20% | Time: 137.42s

Training for 4-head model finished in 113.64 minutes.
Best validation accuracy: 66.80%

--- Evaluating best 4-head model on the TEST set ---


Validating: 100%|██████████| 8/8 [00:01<00:00,  6.76it/s]


Final Test Accuracy for 4 heads: 67.60%

  STARTING EXPERIMENT: 8 ATTENTION HEADS

Model with 8 heads has 11.03M trainable parameters.
Starting training for 8-head model...


Training: 100%|██████████| 403/403 [02:27<00:00,  2.74it/s, loss=2.39]
Validating: 100%|██████████| 8/8 [00:01<00:00,  5.61it/s]


Epoch 1/50 | Train Loss: 2.4009 | Val Loss: 2.1347 | Val Acc: 35.20% | Time: 148.52s
--> New best model saved to vit_heads_8_best.pth with accuracy: 35.20%


Training: 100%|██████████| 403/403 [02:26<00:00,  2.75it/s, loss=1.54]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.02it/s]


Epoch 2/50 | Train Loss: 2.0117 | Val Loss: 1.9045 | Val Acc: 39.20% | Time: 147.63s
--> New best model saved to vit_heads_8_best.pth with accuracy: 39.20%


Training: 100%|██████████| 403/403 [02:26<00:00,  2.76it/s, loss=1.41]
Validating: 100%|██████████| 8/8 [00:01<00:00,  5.93it/s]


Epoch 3/50 | Train Loss: 1.8273 | Val Loss: 1.8383 | Val Acc: 41.80% | Time: 147.44s
--> New best model saved to vit_heads_8_best.pth with accuracy: 41.80%


Training: 100%|██████████| 403/403 [02:26<00:00,  2.76it/s, loss=1.5] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.27it/s]


Epoch 4/50 | Train Loss: 1.7296 | Val Loss: 1.9328 | Val Acc: 40.00% | Time: 147.49s


Training: 100%|██████████| 403/403 [02:26<00:00,  2.75it/s, loss=2.22]
Validating: 100%|██████████| 8/8 [00:01<00:00,  5.82it/s]


Epoch 5/50 | Train Loss: 1.6602 | Val Loss: 1.7251 | Val Acc: 43.20% | Time: 147.71s
--> New best model saved to vit_heads_8_best.pth with accuracy: 43.20%


Training: 100%|██████████| 403/403 [02:26<00:00,  2.76it/s, loss=3.03]
Validating: 100%|██████████| 8/8 [00:01<00:00,  5.93it/s]


Epoch 6/50 | Train Loss: 1.6100 | Val Loss: 1.6026 | Val Acc: 49.80% | Time: 147.42s
--> New best model saved to vit_heads_8_best.pth with accuracy: 49.80%


Training: 100%|██████████| 403/403 [02:26<00:00,  2.76it/s, loss=0.049]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.24it/s]


Epoch 7/50 | Train Loss: 1.5256 | Val Loss: 1.4419 | Val Acc: 57.20% | Time: 147.49s
--> New best model saved to vit_heads_8_best.pth with accuracy: 57.20%


Training: 100%|██████████| 403/403 [02:26<00:00,  2.76it/s, loss=1.57]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.02it/s]


Epoch 8/50 | Train Loss: 1.4834 | Val Loss: 1.5851 | Val Acc: 51.60% | Time: 147.60s


Training: 100%|██████████| 403/403 [02:26<00:00,  2.75it/s, loss=3.34]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.22it/s]


Epoch 9/50 | Train Loss: 1.4776 | Val Loss: 1.7273 | Val Acc: 47.00% | Time: 147.62s


Training: 100%|██████████| 403/403 [02:26<00:00,  2.76it/s, loss=3.08]
Validating: 100%|██████████| 8/8 [00:01<00:00,  5.97it/s]


Epoch 10/50 | Train Loss: 1.5119 | Val Loss: 1.3993 | Val Acc: 55.80% | Time: 147.48s


Training: 100%|██████████| 403/403 [02:26<00:00,  2.76it/s, loss=0.696]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.11it/s]


Epoch 11/50 | Train Loss: 1.4262 | Val Loss: 1.4326 | Val Acc: 53.40% | Time: 147.47s


Training: 100%|██████████| 403/403 [02:26<00:00,  2.76it/s, loss=0.778]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.17it/s]


Epoch 12/50 | Train Loss: 1.3825 | Val Loss: 1.3783 | Val Acc: 56.80% | Time: 147.50s


Training: 100%|██████████| 403/403 [02:26<00:00,  2.76it/s, loss=3.03] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.00it/s]


Epoch 13/50 | Train Loss: 1.3436 | Val Loss: 1.4091 | Val Acc: 55.40% | Time: 147.57s


Training: 100%|██████████| 403/403 [02:27<00:00,  2.72it/s, loss=2.62] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  5.79it/s]


Epoch 14/50 | Train Loss: 1.3430 | Val Loss: 1.3994 | Val Acc: 55.80% | Time: 149.31s


Training: 100%|██████████| 403/403 [02:28<00:00,  2.72it/s, loss=0.0279]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.12it/s]


Epoch 15/50 | Train Loss: 1.3156 | Val Loss: 1.3048 | Val Acc: 57.80% | Time: 149.49s
--> New best model saved to vit_heads_8_best.pth with accuracy: 57.80%


Training: 100%|██████████| 403/403 [02:28<00:00,  2.72it/s, loss=1.66] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.09it/s]


Epoch 16/50 | Train Loss: 1.2717 | Val Loss: 1.2654 | Val Acc: 60.00% | Time: 149.58s
--> New best model saved to vit_heads_8_best.pth with accuracy: 60.00%


Training: 100%|██████████| 403/403 [02:28<00:00,  2.72it/s, loss=2.32] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.09it/s]


Epoch 17/50 | Train Loss: 1.2465 | Val Loss: 1.4103 | Val Acc: 56.20% | Time: 149.50s


Training: 100%|██████████| 403/403 [02:27<00:00,  2.74it/s, loss=0.108]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.13it/s]


Epoch 18/50 | Train Loss: 1.2928 | Val Loss: 1.2742 | Val Acc: 60.20% | Time: 148.65s
--> New best model saved to vit_heads_8_best.pth with accuracy: 60.20%


Training: 100%|██████████| 403/403 [02:26<00:00,  2.76it/s, loss=0.31] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.21it/s]


Epoch 19/50 | Train Loss: 1.2028 | Val Loss: 1.3690 | Val Acc: 55.80% | Time: 147.47s


Training: 100%|██████████| 403/403 [02:26<00:00,  2.74it/s, loss=1.27] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.26it/s]


Epoch 20/50 | Train Loss: 1.1949 | Val Loss: 1.2688 | Val Acc: 59.00% | Time: 148.10s


Training: 100%|██████████| 403/403 [02:26<00:00,  2.76it/s, loss=0.942]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.10it/s]


Epoch 21/50 | Train Loss: 1.2002 | Val Loss: 1.2406 | Val Acc: 60.20% | Time: 147.51s


Training: 100%|██████████| 403/403 [02:26<00:00,  2.76it/s, loss=0.633]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.07it/s]


Epoch 22/50 | Train Loss: 1.1491 | Val Loss: 1.2173 | Val Acc: 60.00% | Time: 147.34s


Training: 100%|██████████| 403/403 [02:26<00:00,  2.75it/s, loss=0.322]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.22it/s]


Epoch 23/50 | Train Loss: 1.1477 | Val Loss: 1.2578 | Val Acc: 62.80% | Time: 147.77s
--> New best model saved to vit_heads_8_best.pth with accuracy: 62.80%


Training: 100%|██████████| 403/403 [02:25<00:00,  2.76it/s, loss=0.473]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.07it/s]


Epoch 24/50 | Train Loss: 1.1369 | Val Loss: 1.3046 | Val Acc: 59.40% | Time: 147.28s


Training: 100%|██████████| 403/403 [02:26<00:00,  2.75it/s, loss=1.37] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.09it/s]


Epoch 25/50 | Train Loss: 1.0949 | Val Loss: 1.2726 | Val Acc: 60.80% | Time: 147.64s


Training: 100%|██████████| 403/403 [02:26<00:00,  2.76it/s, loss=1.41] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.08it/s]


Epoch 26/50 | Train Loss: 1.1083 | Val Loss: 1.1622 | Val Acc: 65.20% | Time: 147.55s
--> New best model saved to vit_heads_8_best.pth with accuracy: 65.20%


Training: 100%|██████████| 403/403 [02:26<00:00,  2.75it/s, loss=0.0482]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.05it/s]


Epoch 27/50 | Train Loss: 1.0575 | Val Loss: 1.1715 | Val Acc: 64.60% | Time: 147.73s


Training: 100%|██████████| 403/403 [02:26<00:00,  2.76it/s, loss=0.213]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.20it/s]


Epoch 28/50 | Train Loss: 1.0429 | Val Loss: 1.2230 | Val Acc: 61.20% | Time: 147.37s


Training: 100%|██████████| 403/403 [02:26<00:00,  2.75it/s, loss=1]    
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.07it/s]


Epoch 29/50 | Train Loss: 1.0312 | Val Loss: 1.2516 | Val Acc: 60.40% | Time: 148.14s


Training: 100%|██████████| 403/403 [02:26<00:00,  2.74it/s, loss=0.72] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.08it/s]


Epoch 30/50 | Train Loss: 1.0331 | Val Loss: 1.1028 | Val Acc: 65.60% | Time: 148.14s
--> New best model saved to vit_heads_8_best.pth with accuracy: 65.60%


Training: 100%|██████████| 403/403 [02:27<00:00,  2.72it/s, loss=2.39] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.03it/s]


Epoch 31/50 | Train Loss: 1.0010 | Val Loss: 1.0718 | Val Acc: 68.00% | Time: 149.26s
--> New best model saved to vit_heads_8_best.pth with accuracy: 68.00%


Training: 100%|██████████| 403/403 [02:27<00:00,  2.74it/s, loss=2.45] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.17it/s]


Epoch 32/50 | Train Loss: 0.9815 | Val Loss: 1.1328 | Val Acc: 63.60% | Time: 148.48s


Training: 100%|██████████| 403/403 [02:25<00:00,  2.77it/s, loss=0.0167]
Validating: 100%|██████████| 8/8 [00:01<00:00,  5.59it/s]


Epoch 33/50 | Train Loss: 0.9960 | Val Loss: 1.1373 | Val Acc: 64.60% | Time: 146.75s


Training: 100%|██████████| 403/403 [02:24<00:00,  2.78it/s, loss=0.422]
Validating: 100%|██████████| 8/8 [00:01<00:00,  5.93it/s]


Epoch 34/50 | Train Loss: 0.9470 | Val Loss: 1.1536 | Val Acc: 62.60% | Time: 146.32s


Training: 100%|██████████| 403/403 [02:25<00:00,  2.78it/s, loss=2]    
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.34it/s]


Epoch 35/50 | Train Loss: 0.9354 | Val Loss: 1.2302 | Val Acc: 63.40% | Time: 146.37s


Training: 100%|██████████| 403/403 [02:25<00:00,  2.77it/s, loss=2.12] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.10it/s]


Epoch 36/50 | Train Loss: 0.9585 | Val Loss: 1.0993 | Val Acc: 66.20% | Time: 146.93s


Training: 100%|██████████| 403/403 [02:27<00:00,  2.74it/s, loss=2.4]  
Validating: 100%|██████████| 8/8 [00:01<00:00,  5.88it/s]


Epoch 37/50 | Train Loss: 0.8957 | Val Loss: 1.1619 | Val Acc: 66.60% | Time: 148.65s


Training: 100%|██████████| 403/403 [02:27<00:00,  2.73it/s, loss=0.092]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.12it/s]


Epoch 38/50 | Train Loss: 0.9120 | Val Loss: 1.1486 | Val Acc: 64.00% | Time: 148.93s


Training: 100%|██████████| 403/403 [02:27<00:00,  2.73it/s, loss=0.0485]
Validating: 100%|██████████| 8/8 [00:01<00:00,  5.91it/s]


Epoch 39/50 | Train Loss: 0.8672 | Val Loss: 1.1966 | Val Acc: 63.60% | Time: 149.07s


Training: 100%|██████████| 403/403 [02:28<00:00,  2.72it/s, loss=0.134]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.17it/s]


Epoch 40/50 | Train Loss: 0.8546 | Val Loss: 1.1413 | Val Acc: 66.60% | Time: 149.31s


Training: 100%|██████████| 403/403 [02:27<00:00,  2.73it/s, loss=0.308]
Validating: 100%|██████████| 8/8 [00:01<00:00,  5.83it/s]


Epoch 41/50 | Train Loss: 0.8344 | Val Loss: 1.0624 | Val Acc: 68.20% | Time: 149.12s
--> New best model saved to vit_heads_8_best.pth with accuracy: 68.20%


Training: 100%|██████████| 403/403 [02:27<00:00,  2.73it/s, loss=0.00162]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.11it/s]


Epoch 42/50 | Train Loss: 0.8285 | Val Loss: 1.1072 | Val Acc: 66.20% | Time: 148.74s


Training: 100%|██████████| 403/403 [02:27<00:00,  2.74it/s, loss=0.0555]
Validating: 100%|██████████| 8/8 [00:01<00:00,  5.84it/s]


Epoch 43/50 | Train Loss: 0.8067 | Val Loss: 1.2035 | Val Acc: 66.20% | Time: 148.66s


Training: 100%|██████████| 403/403 [02:27<00:00,  2.74it/s, loss=4.93] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  5.94it/s]


Epoch 44/50 | Train Loss: 0.7904 | Val Loss: 1.0742 | Val Acc: 70.60% | Time: 148.58s
--> New best model saved to vit_heads_8_best.pth with accuracy: 70.60%


Training: 100%|██████████| 403/403 [02:26<00:00,  2.74it/s, loss=0.228]
Validating: 100%|██████████| 8/8 [00:01<00:00,  5.90it/s]


Epoch 45/50 | Train Loss: 0.7739 | Val Loss: 1.1589 | Val Acc: 67.60% | Time: 148.33s


Training: 100%|██████████| 403/403 [02:26<00:00,  2.74it/s, loss=0.0703]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.01it/s]


Epoch 46/50 | Train Loss: 0.7795 | Val Loss: 1.0768 | Val Acc: 68.40% | Time: 148.29s


Training: 100%|██████████| 403/403 [02:27<00:00,  2.73it/s, loss=0.00858]
Validating: 100%|██████████| 8/8 [00:01<00:00,  5.99it/s]


Epoch 47/50 | Train Loss: 0.7513 | Val Loss: 1.0946 | Val Acc: 69.20% | Time: 148.88s


Training: 100%|██████████| 403/403 [02:27<00:00,  2.73it/s, loss=1.1]  
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.20it/s]


Epoch 48/50 | Train Loss: 0.7362 | Val Loss: 1.2074 | Val Acc: 65.20% | Time: 148.74s


Training: 100%|██████████| 403/403 [02:27<00:00,  2.73it/s, loss=0.000436]
Validating: 100%|██████████| 8/8 [00:01<00:00,  5.84it/s]


Epoch 49/50 | Train Loss: 0.7535 | Val Loss: 1.1263 | Val Acc: 65.60% | Time: 148.85s


Training: 100%|██████████| 403/403 [02:27<00:00,  2.74it/s, loss=0.00553]
Validating: 100%|██████████| 8/8 [00:01<00:00,  5.97it/s]


Epoch 50/50 | Train Loss: 0.7239 | Val Loss: 1.0895 | Val Acc: 66.80% | Time: 148.44s

Training for 8-head model finished in 123.49 minutes.
Best validation accuracy: 70.60%

--- Evaluating best 8-head model on the TEST set ---


Validating: 100%|██████████| 8/8 [00:01<00:00,  5.92it/s]

Final Test Accuracy for 8 heads: 67.40%


  EXPERIMENT SUMMARY: EFFECT OF NUMBER OF HEADS
Heads      | Best Val Acc (%)     | Final Test Acc (%)   | Train Time (min)    
---------------------------------------------------------------------------
4          | 66.80                | 67.60                | 113.64              
8          | 70.60                | 67.40                | 123.49              





## ViT Architecture

In [None]:
import torch
import torch.nn as nn
import math

class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, d_model):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2)
        x = x.transpose(1, 2)
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, d_model * 3)
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.dropout(x)
        return x

class MLP(nn.Module):
    def __init__(self, d_model, mlp_ratio, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, int(d_model * mlp_ratio))
        self.act = nn.GELU()
        self.fc2 = nn.Linear(int(d_model * mlp_ratio), d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class TransformerEncoder(nn.Module):
    def __init__(self, d_model, n_heads, mlp_ratio, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model, mlp_ratio, dropout)
    
    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class VisionTransformer(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, n_classes, d_model, n_heads, n_layers, mlp_ratio, pos_embed_type='learnable'):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, d_model)
        num_patches = (img_size // patch_size) ** 2
        
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        
        if pos_embed_type == 'learnable':
            print("Using Learnable Positional Embedding")
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, d_model))
        elif pos_embed_type == 'sine':
            print("Using Sinusoidal Positional Embedding")
            pe = torch.zeros(num_patches + 1, d_model)
            position = torch.arange(0, num_patches + 1, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
            pe = pe.unsqueeze(0)
            self.register_buffer('pos_embed', pe)
        else:
            print("Not using any Positional Embedding")
            self.pos_embed = None

        self.encoder = nn.Sequential(*[
            TransformerEncoder(d_model, n_heads, mlp_ratio) for _ in range(n_layers)
        ])
        
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, n_classes)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        
        if self.pos_embed is not None:
            x = x + self.pos_embed
        
        x = self.encoder(x)
        x = self.norm(x)
        
        cls_token_final = x[:, 0]
        x = self.head(cls_token_final)
        
        return x

## ViT Training

In [None]:
import time
import torch
import torch.nn as nn
import torch.optim as optim

POS_EMBEDS_TO_TEST = ['learnable', 'sine', None]
FIXED_NUM_HEADS = 4
experiment_results = {}

for pos_embed_type in POS_EMBEDS_TO_TEST:
    pos_embed_name = str(pos_embed_type)
    
    print(f"\n{'='*60}")
    print(f"  STARTING EXPERIMENT: {pos_embed_name.upper()} POSITIONAL EMBEDDING")
    print(f"{'='*60}\n")
    
    model = VisionTransformer(
        img_size=IMAGE_SIZE,
        patch_size=PATCH_SIZE,
        in_channels=NUM_CHANNELS,
        n_classes=NUM_CLASSES,
        d_model=D_MODEL,
        n_heads=FIXED_NUM_HEADS,
        n_layers=NUM_LAYERS,
        mlp_ratio=MLP_RATIO,
        pos_embed_type=pos_embed_type
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model has {total_params / 1e6:.2f}M trainable parameters.")

    best_val_acc = 0.0
    model_save_path = f'vit_pos_{pos_embed_name.lower()}_best.pth'
    history = {}

    print(f"Starting training for {pos_embed_name} model...")
    start_time = time.time()

    for epoch in range(EPOCHS):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, scaler, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        
        print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), model_save_path)
            print(f"--> New best model saved to {model_save_path} with accuracy: {best_val_acc:.2f}%")

    total_training_time = time.time() - start_time
    print(f"\nTraining for {pos_embed_name} model finished in {total_training_time/60:.2f} minutes.")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")

    print(f"\n--- Evaluating best {pos_embed_name} model on the TEST set ---")
    final_model = VisionTransformer(img_size=IMAGE_SIZE, patch_size=PATCH_SIZE, in_channels=NUM_CHANNELS,
                                    n_classes=NUM_CLASSES, d_model=D_MODEL, n_heads=FIXED_NUM_HEADS,
                                    n_layers=NUM_LAYERS, mlp_ratio=MLP_RATIO, pos_embed_type=pos_embed_type).to(device)
    final_model.load_state_dict(torch.load(model_save_path))
    
    test_loss, test_acc = validate(final_model, test_loader, criterion, device)
    print(f"Final Test Accuracy for {pos_embed_name} model: {test_acc:.2f}%")

    experiment_results[pos_embed_name] = {
        'best_val_acc': best_val_acc,
        'test_acc': test_acc,
        'training_time_min': total_training_time / 60
    }

print(f"\n\n{'='*75}")
print(f"  EXPERIMENT SUMMARY: EFFECT OF POSITIONAL EMBEDDING (Heads={FIXED_NUM_HEADS})")
print(f"{'='*75}")
print(f"{'Positional Embedding':<25} | {'Best Val Acc (%)':<20} | {'Final Test Acc (%)':<20} | {'Train Time (min)':<20}")
print(f"-"*90)
for pos_embed_name, results in experiment_results.items():
    print(f"{pos_embed_name:<25} | {results['best_val_acc']:<20.2f} | {results['test_acc']:<20.2f} | {results['training_time_min']:<20.2f}")


  STARTING EXPERIMENT: LEARNABLE POSITIONAL EMBEDDING

Using Learnable Positional Embedding
Model has 11.03M trainable parameters.
Starting training for learnable model...


Training: 100%|██████████| 403/403 [02:13<00:00,  3.02it/s, loss=1.58]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.48it/s]


Epoch 1/50 | Train Loss: 2.4275 | Val Loss: 2.1593 | Val Acc: 33.20%
--> New best model saved to vit_pos_learnable_best.pth with accuracy: 33.20%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.01it/s, loss=1.68]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.52it/s]


Epoch 2/50 | Train Loss: 2.0644 | Val Loss: 1.9522 | Val Acc: 39.60%
--> New best model saved to vit_pos_learnable_best.pth with accuracy: 39.60%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.01it/s, loss=1.72]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.74it/s]


Epoch 3/50 | Train Loss: 1.9049 | Val Loss: 2.0888 | Val Acc: 36.20%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=1.32]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.54it/s]


Epoch 4/50 | Train Loss: 1.8189 | Val Loss: 1.8850 | Val Acc: 40.60%
--> New best model saved to vit_pos_learnable_best.pth with accuracy: 40.60%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=1.45]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.56it/s]


Epoch 5/50 | Train Loss: 1.7255 | Val Loss: 1.7303 | Val Acc: 45.40%
--> New best model saved to vit_pos_learnable_best.pth with accuracy: 45.40%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.0732]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.48it/s]


Epoch 6/50 | Train Loss: 1.6726 | Val Loss: 1.7494 | Val Acc: 46.60%
--> New best model saved to vit_pos_learnable_best.pth with accuracy: 46.60%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=2.45]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.61it/s]


Epoch 7/50 | Train Loss: 1.6049 | Val Loss: 1.6552 | Val Acc: 47.20%
--> New best model saved to vit_pos_learnable_best.pth with accuracy: 47.20%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.918]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.19it/s]


Epoch 8/50 | Train Loss: 1.5558 | Val Loss: 1.6229 | Val Acc: 46.40%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.01it/s, loss=1.24]
Validating: 100%|██████████| 8/8 [00:01<00:00,  5.67it/s]


Epoch 9/50 | Train Loss: 1.5320 | Val Loss: 1.6530 | Val Acc: 45.00%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.01it/s, loss=1.1] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  5.93it/s]


Epoch 10/50 | Train Loss: 1.4982 | Val Loss: 1.6008 | Val Acc: 49.60%
--> New best model saved to vit_pos_learnable_best.pth with accuracy: 49.60%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.01it/s, loss=2.8] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  5.93it/s]


Epoch 11/50 | Train Loss: 1.4497 | Val Loss: 1.4242 | Val Acc: 55.00%
--> New best model saved to vit_pos_learnable_best.pth with accuracy: 55.00%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=1.72]
Validating: 100%|██████████| 8/8 [00:01<00:00,  5.83it/s]


Epoch 12/50 | Train Loss: 1.3948 | Val Loss: 1.4992 | Val Acc: 52.40%


Training: 100%|██████████| 403/403 [02:14<00:00,  2.99it/s, loss=2.66]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.59it/s]


Epoch 13/50 | Train Loss: 1.4041 | Val Loss: 1.5652 | Val Acc: 48.20%


Training: 100%|██████████| 403/403 [02:14<00:00,  2.99it/s, loss=1.83] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.60it/s]


Epoch 14/50 | Train Loss: 1.3828 | Val Loss: 1.4774 | Val Acc: 54.60%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.349]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.68it/s]


Epoch 15/50 | Train Loss: 1.3674 | Val Loss: 1.3579 | Val Acc: 57.80%
--> New best model saved to vit_pos_learnable_best.pth with accuracy: 57.80%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.01it/s, loss=3.39] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.68it/s]


Epoch 16/50 | Train Loss: 1.3091 | Val Loss: 1.3375 | Val Acc: 59.60%
--> New best model saved to vit_pos_learnable_best.pth with accuracy: 59.60%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.01it/s, loss=1.33] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.66it/s]


Epoch 17/50 | Train Loss: 1.2817 | Val Loss: 1.3952 | Val Acc: 55.20%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.01it/s, loss=1.78] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.59it/s]


Epoch 18/50 | Train Loss: 1.2883 | Val Loss: 1.3604 | Val Acc: 56.60%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=1.88] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.55it/s]


Epoch 19/50 | Train Loss: 1.2799 | Val Loss: 1.2686 | Val Acc: 60.40%
--> New best model saved to vit_pos_learnable_best.pth with accuracy: 60.40%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=1.42] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.59it/s]


Epoch 20/50 | Train Loss: 1.2525 | Val Loss: 1.2932 | Val Acc: 58.60%


Training: 100%|██████████| 403/403 [02:14<00:00,  2.99it/s, loss=0.332]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.40it/s]


Epoch 21/50 | Train Loss: 1.2131 | Val Loss: 1.3550 | Val Acc: 56.20%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.212]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.68it/s]


Epoch 22/50 | Train Loss: 1.1939 | Val Loss: 1.2582 | Val Acc: 60.60%
--> New best model saved to vit_pos_learnable_best.pth with accuracy: 60.60%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.01it/s, loss=0.0358]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.67it/s]


Epoch 23/50 | Train Loss: 1.1681 | Val Loss: 1.3265 | Val Acc: 56.80%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.01it/s, loss=0.0795]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.61it/s]


Epoch 24/50 | Train Loss: 1.1427 | Val Loss: 1.2254 | Val Acc: 61.60%
--> New best model saved to vit_pos_learnable_best.pth with accuracy: 61.60%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.02it/s, loss=0.793]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.70it/s]


Epoch 25/50 | Train Loss: 1.1226 | Val Loss: 1.2960 | Val Acc: 58.60%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=1.03] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.67it/s]


Epoch 26/50 | Train Loss: 1.1203 | Val Loss: 1.1579 | Val Acc: 64.40%
--> New best model saved to vit_pos_learnable_best.pth with accuracy: 64.40%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.01it/s, loss=0.18] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.65it/s]


Epoch 27/50 | Train Loss: 1.0953 | Val Loss: 1.1682 | Val Acc: 63.40%


Training: 100%|██████████| 403/403 [02:14<00:00,  2.99it/s, loss=0.899]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.53it/s]


Epoch 28/50 | Train Loss: 1.0701 | Val Loss: 1.2072 | Val Acc: 63.00%


Training: 100%|██████████| 403/403 [02:14<00:00,  2.99it/s, loss=0.225]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.61it/s]


Epoch 29/50 | Train Loss: 1.0606 | Val Loss: 1.2413 | Val Acc: 62.00%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.771]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.42it/s]


Epoch 30/50 | Train Loss: 1.0342 | Val Loss: 1.2956 | Val Acc: 58.60%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.01it/s, loss=0.0483]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.57it/s]


Epoch 31/50 | Train Loss: 1.0330 | Val Loss: 1.1787 | Val Acc: 62.20%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.01it/s, loss=0.00101]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.58it/s]


Epoch 32/50 | Train Loss: 0.9969 | Val Loss: 1.1329 | Val Acc: 65.40%
--> New best model saved to vit_pos_learnable_best.pth with accuracy: 65.40%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.01it/s, loss=0.849]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.56it/s]


Epoch 33/50 | Train Loss: 0.9760 | Val Loss: 1.1242 | Val Acc: 65.60%
--> New best model saved to vit_pos_learnable_best.pth with accuracy: 65.60%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=3.03] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.65it/s]


Epoch 34/50 | Train Loss: 0.9653 | Val Loss: 1.1788 | Val Acc: 64.80%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=3.12] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.61it/s]


Epoch 35/50 | Train Loss: 0.9451 | Val Loss: 1.1775 | Val Acc: 64.40%


Training: 100%|██████████| 403/403 [02:14<00:00,  2.99it/s, loss=0.236]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.65it/s]


Epoch 36/50 | Train Loss: 0.9747 | Val Loss: 1.1412 | Val Acc: 66.40%
--> New best model saved to vit_pos_learnable_best.pth with accuracy: 66.40%


Training: 100%|██████████| 403/403 [02:14<00:00,  2.99it/s, loss=0.00972]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.58it/s]


Epoch 37/50 | Train Loss: 0.9165 | Val Loss: 1.1805 | Val Acc: 64.00%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.0191]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.35it/s]


Epoch 38/50 | Train Loss: 0.8964 | Val Loss: 1.2031 | Val Acc: 63.40%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.119]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.53it/s]


Epoch 39/50 | Train Loss: 0.8714 | Val Loss: 1.1385 | Val Acc: 63.00%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.01it/s, loss=3.14] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.71it/s]


Epoch 40/50 | Train Loss: 0.8594 | Val Loss: 1.0519 | Val Acc: 66.60%
--> New best model saved to vit_pos_learnable_best.pth with accuracy: 66.60%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.01it/s, loss=0.00102]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.64it/s]


Epoch 41/50 | Train Loss: 0.8426 | Val Loss: 1.1021 | Val Acc: 67.00%
--> New best model saved to vit_pos_learnable_best.pth with accuracy: 67.00%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.221]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.29it/s]


Epoch 42/50 | Train Loss: 0.8295 | Val Loss: 1.1336 | Val Acc: 65.60%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.346]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.76it/s]


Epoch 43/50 | Train Loss: 0.8192 | Val Loss: 1.2076 | Val Acc: 63.60%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.198]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.44it/s]


Epoch 44/50 | Train Loss: 0.8139 | Val Loss: 1.1217 | Val Acc: 64.40%


Training: 100%|██████████| 403/403 [02:14<00:00,  2.99it/s, loss=8.8e-5]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.62it/s]


Epoch 45/50 | Train Loss: 0.7920 | Val Loss: 1.1209 | Val Acc: 65.80%


Training: 100%|██████████| 403/403 [02:14<00:00,  2.99it/s, loss=0.226]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.67it/s]


Epoch 46/50 | Train Loss: 0.7684 | Val Loss: 1.1586 | Val Acc: 65.60%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=1.53] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.61it/s]


Epoch 47/50 | Train Loss: 0.7562 | Val Loss: 1.0737 | Val Acc: 70.40%
--> New best model saved to vit_pos_learnable_best.pth with accuracy: 70.40%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.01it/s, loss=0.966]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.29it/s]


Epoch 48/50 | Train Loss: 0.7634 | Val Loss: 1.1903 | Val Acc: 64.40%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.01it/s, loss=1.87] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.59it/s]


Epoch 49/50 | Train Loss: 0.7495 | Val Loss: 1.1287 | Val Acc: 64.80%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.01it/s, loss=0.0109]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.68it/s]


Epoch 50/50 | Train Loss: 0.7686 | Val Loss: 1.1416 | Val Acc: 68.20%

Training for learnable model finished in 113.00 minutes.
Best validation accuracy: 70.40%

--- Evaluating best learnable model on the TEST set ---
Using Learnable Positional Embedding


Validating: 100%|██████████| 8/8 [00:01<00:00,  6.70it/s]


Final Test Accuracy for learnable model: 68.80%

  STARTING EXPERIMENT: SINE POSITIONAL EMBEDDING

Using Sinusoidal Positional Embedding
Model has 10.95M trainable parameters.
Starting training for sine model...


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=1.73]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.87it/s]


Epoch 1/50 | Train Loss: 2.4187 | Val Loss: 2.1872 | Val Acc: 33.40%
--> New best model saved to vit_pos_sine_best.pth with accuracy: 33.40%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=1.03]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.69it/s]


Epoch 2/50 | Train Loss: 2.0533 | Val Loss: 1.9888 | Val Acc: 39.60%
--> New best model saved to vit_pos_sine_best.pth with accuracy: 39.60%


Training: 100%|██████████| 403/403 [02:14<00:00,  2.99it/s, loss=0.783]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.73it/s]


Epoch 3/50 | Train Loss: 1.9073 | Val Loss: 1.9104 | Val Acc: 39.40%


Training: 100%|██████████| 403/403 [02:14<00:00,  2.99it/s, loss=1.07]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.75it/s]


Epoch 4/50 | Train Loss: 1.7756 | Val Loss: 1.7984 | Val Acc: 46.40%
--> New best model saved to vit_pos_sine_best.pth with accuracy: 46.40%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=1.57]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.38it/s]


Epoch 5/50 | Train Loss: 1.7050 | Val Loss: 1.6634 | Val Acc: 49.00%
--> New best model saved to vit_pos_sine_best.pth with accuracy: 49.00%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.01it/s, loss=0.408]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.48it/s]


Epoch 6/50 | Train Loss: 1.6138 | Val Loss: 1.5600 | Val Acc: 51.40%
--> New best model saved to vit_pos_sine_best.pth with accuracy: 51.40%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.01it/s, loss=5.39]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.56it/s]


Epoch 7/50 | Train Loss: 1.5663 | Val Loss: 1.6423 | Val Acc: 46.60%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.01it/s, loss=0.406]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.78it/s]


Epoch 8/50 | Train Loss: 1.5714 | Val Loss: 1.6233 | Val Acc: 47.60%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.01it/s, loss=0.395]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.58it/s]


Epoch 9/50 | Train Loss: 1.4799 | Val Loss: 1.5867 | Val Acc: 49.60%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.0205]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.22it/s]


Epoch 10/50 | Train Loss: 1.4486 | Val Loss: 1.3998 | Val Acc: 52.80%
--> New best model saved to vit_pos_sine_best.pth with accuracy: 52.80%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.618]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.47it/s]


Epoch 11/50 | Train Loss: 1.4128 | Val Loss: 1.4573 | Val Acc: 54.80%
--> New best model saved to vit_pos_sine_best.pth with accuracy: 54.80%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.508]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.43it/s]


Epoch 12/50 | Train Loss: 1.3869 | Val Loss: 1.4508 | Val Acc: 55.00%
--> New best model saved to vit_pos_sine_best.pth with accuracy: 55.00%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.102]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.54it/s]


Epoch 13/50 | Train Loss: 1.3579 | Val Loss: 1.4631 | Val Acc: 55.40%
--> New best model saved to vit_pos_sine_best.pth with accuracy: 55.40%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.0501]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.65it/s]


Epoch 14/50 | Train Loss: 1.3263 | Val Loss: 1.2659 | Val Acc: 59.40%
--> New best model saved to vit_pos_sine_best.pth with accuracy: 59.40%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.01it/s, loss=0.285]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.57it/s]


Epoch 15/50 | Train Loss: 1.2931 | Val Loss: 1.3501 | Val Acc: 57.20%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.01it/s, loss=1.4]  
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.53it/s]


Epoch 16/50 | Train Loss: 1.2726 | Val Loss: 1.2967 | Val Acc: 59.60%
--> New best model saved to vit_pos_sine_best.pth with accuracy: 59.60%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.01it/s, loss=0.662]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.70it/s]


Epoch 17/50 | Train Loss: 1.2426 | Val Loss: 1.2614 | Val Acc: 61.60%
--> New best model saved to vit_pos_sine_best.pth with accuracy: 61.60%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.971]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.83it/s]


Epoch 18/50 | Train Loss: 1.2254 | Val Loss: 1.2842 | Val Acc: 60.00%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=1.09] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.78it/s]


Epoch 19/50 | Train Loss: 1.2221 | Val Loss: 1.2374 | Val Acc: 60.80%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=1.55] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.27it/s]


Epoch 20/50 | Train Loss: 1.1998 | Val Loss: 1.1908 | Val Acc: 63.60%
--> New best model saved to vit_pos_sine_best.pth with accuracy: 63.60%


Training: 100%|██████████| 403/403 [02:14<00:00,  2.99it/s, loss=0.692]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.57it/s]


Epoch 21/50 | Train Loss: 1.1750 | Val Loss: 1.1789 | Val Acc: 62.80%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.752]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.45it/s]


Epoch 22/50 | Train Loss: 1.1627 | Val Loss: 1.2066 | Val Acc: 62.00%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.131]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.76it/s]


Epoch 23/50 | Train Loss: 1.1259 | Val Loss: 1.1473 | Val Acc: 63.60%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.01it/s, loss=1.13] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.68it/s]


Epoch 24/50 | Train Loss: 1.0973 | Val Loss: 1.1822 | Val Acc: 60.60%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=1.87] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.66it/s]


Epoch 25/50 | Train Loss: 1.0695 | Val Loss: 1.3333 | Val Acc: 60.00%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.0578]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.46it/s]


Epoch 26/50 | Train Loss: 1.0945 | Val Loss: 1.1160 | Val Acc: 63.80%
--> New best model saved to vit_pos_sine_best.pth with accuracy: 63.80%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.00118]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.35it/s]


Epoch 27/50 | Train Loss: 1.0372 | Val Loss: 1.1535 | Val Acc: 65.80%
--> New best model saved to vit_pos_sine_best.pth with accuracy: 65.80%


Training: 100%|██████████| 403/403 [02:14<00:00,  3.00it/s, loss=0.0232]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.67it/s]


Epoch 28/50 | Train Loss: 1.0164 | Val Loss: 1.1452 | Val Acc: 64.20%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.04it/s, loss=0.216]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.57it/s]


Epoch 29/50 | Train Loss: 1.0035 | Val Loss: 1.1495 | Val Acc: 62.60%


Training: 100%|██████████| 403/403 [02:11<00:00,  3.05it/s, loss=0.0628]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.55it/s]


Epoch 30/50 | Train Loss: 0.9795 | Val Loss: 1.0734 | Val Acc: 65.60%


Training: 100%|██████████| 403/403 [02:11<00:00,  3.06it/s, loss=0.186]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.66it/s]


Epoch 31/50 | Train Loss: 0.9740 | Val Loss: 1.0936 | Val Acc: 67.20%
--> New best model saved to vit_pos_sine_best.pth with accuracy: 67.20%


Training: 100%|██████████| 403/403 [02:11<00:00,  3.05it/s, loss=0.857]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.56it/s]


Epoch 32/50 | Train Loss: 0.9427 | Val Loss: 1.0678 | Val Acc: 67.20%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.05it/s, loss=2.29] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.54it/s]


Epoch 33/50 | Train Loss: 0.9252 | Val Loss: 1.0855 | Val Acc: 67.20%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.05it/s, loss=1.55] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.23it/s]


Epoch 34/50 | Train Loss: 0.9480 | Val Loss: 1.1157 | Val Acc: 64.60%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.04it/s, loss=0.0969]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.30it/s]


Epoch 35/50 | Train Loss: 0.9159 | Val Loss: 1.0510 | Val Acc: 67.00%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.04it/s, loss=0.022]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.56it/s]


Epoch 36/50 | Train Loss: 0.8844 | Val Loss: 1.1231 | Val Acc: 64.60%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.04it/s, loss=1.25] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.29it/s]


Epoch 37/50 | Train Loss: 0.8542 | Val Loss: 1.0357 | Val Acc: 68.40%
--> New best model saved to vit_pos_sine_best.pth with accuracy: 68.40%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.05it/s, loss=0.209]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.33it/s]


Epoch 38/50 | Train Loss: 0.8424 | Val Loss: 1.0749 | Val Acc: 66.40%


Training: 100%|██████████| 403/403 [02:11<00:00,  3.06it/s, loss=0.0646]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.38it/s]


Epoch 39/50 | Train Loss: 0.8280 | Val Loss: 1.0671 | Val Acc: 67.80%


Training: 100%|██████████| 403/403 [02:11<00:00,  3.06it/s, loss=0.311]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.64it/s]


Epoch 40/50 | Train Loss: 0.8126 | Val Loss: 1.1485 | Val Acc: 66.20%


Training: 100%|██████████| 403/403 [02:11<00:00,  3.06it/s, loss=0.0107]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.64it/s]


Epoch 41/50 | Train Loss: 0.7915 | Val Loss: 1.0716 | Val Acc: 69.80%
--> New best model saved to vit_pos_sine_best.pth with accuracy: 69.80%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.05it/s, loss=0.00355]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.57it/s]


Epoch 42/50 | Train Loss: 0.7849 | Val Loss: 1.0910 | Val Acc: 68.60%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.04it/s, loss=0.00444]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.57it/s]


Epoch 43/50 | Train Loss: 0.7633 | Val Loss: 1.1122 | Val Acc: 67.60%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.04it/s, loss=0.12] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.33it/s]


Epoch 44/50 | Train Loss: 0.7571 | Val Loss: 1.0998 | Val Acc: 69.80%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.04it/s, loss=0.000369]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.33it/s]


Epoch 45/50 | Train Loss: 0.7299 | Val Loss: 1.0035 | Val Acc: 67.20%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.04it/s, loss=0.0401]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.31it/s]


Epoch 46/50 | Train Loss: 0.7166 | Val Loss: 1.0206 | Val Acc: 70.00%
--> New best model saved to vit_pos_sine_best.pth with accuracy: 70.00%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.05it/s, loss=0.393]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.65it/s]


Epoch 47/50 | Train Loss: 0.7015 | Val Loss: 1.0405 | Val Acc: 68.80%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.05it/s, loss=2.14] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.61it/s]


Epoch 48/50 | Train Loss: 0.6954 | Val Loss: 0.9669 | Val Acc: 70.40%
--> New best model saved to vit_pos_sine_best.pth with accuracy: 70.40%


Training: 100%|██████████| 403/403 [02:11<00:00,  3.06it/s, loss=0.0375]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.60it/s]


Epoch 49/50 | Train Loss: 0.6886 | Val Loss: 1.0874 | Val Acc: 68.40%


Training: 100%|██████████| 403/403 [02:11<00:00,  3.06it/s, loss=1.37] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.34it/s]


Epoch 50/50 | Train Loss: 0.6677 | Val Loss: 1.1144 | Val Acc: 67.80%

Training for sine model finished in 112.26 minutes.
Best validation accuracy: 70.40%

--- Evaluating best sine model on the TEST set ---
Using Sinusoidal Positional Embedding


Validating: 100%|██████████| 8/8 [00:01<00:00,  6.39it/s]


Final Test Accuracy for sine model: 70.20%

  STARTING EXPERIMENT: NONE POSITIONAL EMBEDDING

Not using any Positional Embedding
Model has 10.95M trainable parameters.
Starting training for None model...


Training: 100%|██████████| 403/403 [02:11<00:00,  3.06it/s, loss=0.411]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.65it/s]


Epoch 1/50 | Train Loss: 2.4311 | Val Loss: 2.1580 | Val Acc: 34.00%
--> New best model saved to vit_pos_none_best.pth with accuracy: 34.00%


Training: 100%|██████████| 403/403 [02:11<00:00,  3.05it/s, loss=0.833]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.29it/s]


Epoch 2/50 | Train Loss: 2.0712 | Val Loss: 2.0692 | Val Acc: 36.80%
--> New best model saved to vit_pos_none_best.pth with accuracy: 36.80%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.05it/s, loss=0.445]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.22it/s]


Epoch 3/50 | Train Loss: 1.9377 | Val Loss: 1.8476 | Val Acc: 43.40%
--> New best model saved to vit_pos_none_best.pth with accuracy: 43.40%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.05it/s, loss=0.296]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.55it/s]


Epoch 4/50 | Train Loss: 1.8134 | Val Loss: 1.8324 | Val Acc: 42.80%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.05it/s, loss=5]   
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.60it/s]


Epoch 5/50 | Train Loss: 1.7351 | Val Loss: 1.7283 | Val Acc: 46.00%
--> New best model saved to vit_pos_none_best.pth with accuracy: 46.00%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.05it/s, loss=0.0473]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.59it/s]


Epoch 6/50 | Train Loss: 1.6761 | Val Loss: 1.6818 | Val Acc: 46.00%


Training: 100%|██████████| 403/403 [02:11<00:00,  3.06it/s, loss=1.03]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.45it/s]


Epoch 7/50 | Train Loss: 1.6314 | Val Loss: 1.6349 | Val Acc: 50.40%
--> New best model saved to vit_pos_none_best.pth with accuracy: 50.40%


Training: 100%|██████████| 403/403 [02:11<00:00,  3.06it/s, loss=0.104]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.73it/s]


Epoch 8/50 | Train Loss: 1.5800 | Val Loss: 1.5940 | Val Acc: 49.80%


Training: 100%|██████████| 403/403 [02:11<00:00,  3.06it/s, loss=0.292]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.52it/s]


Epoch 9/50 | Train Loss: 1.5332 | Val Loss: 1.6290 | Val Acc: 47.00%


Training: 100%|██████████| 403/403 [02:11<00:00,  3.06it/s, loss=0.733]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.37it/s]


Epoch 10/50 | Train Loss: 1.5047 | Val Loss: 1.5582 | Val Acc: 49.00%


Training: 100%|██████████| 403/403 [02:11<00:00,  3.05it/s, loss=0.602]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.53it/s]


Epoch 11/50 | Train Loss: 1.5128 | Val Loss: 1.4567 | Val Acc: 52.60%
--> New best model saved to vit_pos_none_best.pth with accuracy: 52.60%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.05it/s, loss=0.6] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.73it/s]


Epoch 12/50 | Train Loss: 1.4384 | Val Loss: 1.4021 | Val Acc: 53.20%
--> New best model saved to vit_pos_none_best.pth with accuracy: 53.20%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.05it/s, loss=0.59] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.32it/s]


Epoch 13/50 | Train Loss: 1.4155 | Val Loss: 1.4569 | Val Acc: 53.80%
--> New best model saved to vit_pos_none_best.pth with accuracy: 53.80%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.05it/s, loss=3.02]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.69it/s]


Epoch 14/50 | Train Loss: 1.3914 | Val Loss: 1.3623 | Val Acc: 54.60%
--> New best model saved to vit_pos_none_best.pth with accuracy: 54.60%


Training: 100%|██████████| 403/403 [02:11<00:00,  3.05it/s, loss=0.966]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.52it/s]


Epoch 15/50 | Train Loss: 1.3524 | Val Loss: 1.4999 | Val Acc: 53.20%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.05it/s, loss=0.535]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.51it/s]


Epoch 16/50 | Train Loss: 1.3466 | Val Loss: 1.3565 | Val Acc: 57.20%
--> New best model saved to vit_pos_none_best.pth with accuracy: 57.20%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.04it/s, loss=1.74] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.54it/s]


Epoch 17/50 | Train Loss: 1.3184 | Val Loss: 1.3739 | Val Acc: 52.80%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.04it/s, loss=0.828]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.30it/s]


Epoch 18/50 | Train Loss: 1.3186 | Val Loss: 1.3161 | Val Acc: 59.20%
--> New best model saved to vit_pos_none_best.pth with accuracy: 59.20%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.04it/s, loss=0.176]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.59it/s]


Epoch 19/50 | Train Loss: 1.2833 | Val Loss: 1.3601 | Val Acc: 56.60%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.03it/s, loss=1.61] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.44it/s]


Epoch 20/50 | Train Loss: 1.2448 | Val Loss: 1.3769 | Val Acc: 56.20%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.03it/s, loss=0.198]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.39it/s]


Epoch 21/50 | Train Loss: 1.2539 | Val Loss: 1.4022 | Val Acc: 54.60%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.04it/s, loss=1.22] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.32it/s]


Epoch 22/50 | Train Loss: 1.2086 | Val Loss: 1.3976 | Val Acc: 54.80%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.04it/s, loss=0.114]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.30it/s]


Epoch 23/50 | Train Loss: 1.2121 | Val Loss: 1.2125 | Val Acc: 60.40%
--> New best model saved to vit_pos_none_best.pth with accuracy: 60.40%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.03it/s, loss=1.62] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.53it/s]


Epoch 24/50 | Train Loss: 1.1613 | Val Loss: 1.3152 | Val Acc: 57.40%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.04it/s, loss=0.113]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.51it/s]


Epoch 25/50 | Train Loss: 1.1322 | Val Loss: 1.3093 | Val Acc: 59.40%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.03it/s, loss=0.792]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.16it/s]


Epoch 26/50 | Train Loss: 1.1299 | Val Loss: 1.2411 | Val Acc: 61.80%
--> New best model saved to vit_pos_none_best.pth with accuracy: 61.80%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.03it/s, loss=2.55] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.70it/s]


Epoch 27/50 | Train Loss: 1.1259 | Val Loss: 1.2569 | Val Acc: 60.00%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.03it/s, loss=0.141]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.59it/s]


Epoch 28/50 | Train Loss: 1.1095 | Val Loss: 1.2188 | Val Acc: 62.20%
--> New best model saved to vit_pos_none_best.pth with accuracy: 62.20%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.03it/s, loss=0.205]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.47it/s]


Epoch 29/50 | Train Loss: 1.0688 | Val Loss: 1.2352 | Val Acc: 60.60%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.03it/s, loss=0.155]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.88it/s]


Epoch 30/50 | Train Loss: 1.0561 | Val Loss: 1.1923 | Val Acc: 63.40%
--> New best model saved to vit_pos_none_best.pth with accuracy: 63.40%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.03it/s, loss=0.00767]
Validating: 100%|██████████| 8/8 [00:01<00:00,  5.98it/s]


Epoch 31/50 | Train Loss: 1.0246 | Val Loss: 1.2095 | Val Acc: 62.20%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.04it/s, loss=0.287]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.31it/s]


Epoch 32/50 | Train Loss: 1.0102 | Val Loss: 1.2025 | Val Acc: 63.40%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.03it/s, loss=0.00485]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.42it/s]


Epoch 33/50 | Train Loss: 1.0003 | Val Loss: 1.1498 | Val Acc: 62.20%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.03it/s, loss=0.00188]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.41it/s]


Epoch 34/50 | Train Loss: 0.9686 | Val Loss: 1.1459 | Val Acc: 63.40%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.03it/s, loss=0.0108]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.31it/s]


Epoch 35/50 | Train Loss: 0.9565 | Val Loss: 1.2158 | Val Acc: 62.80%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.03it/s, loss=0.00345]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.58it/s]


Epoch 36/50 | Train Loss: 0.9483 | Val Loss: 1.1516 | Val Acc: 63.80%
--> New best model saved to vit_pos_none_best.pth with accuracy: 63.80%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.03it/s, loss=0.0139]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.51it/s]


Epoch 37/50 | Train Loss: 0.9198 | Val Loss: 1.1367 | Val Acc: 64.20%
--> New best model saved to vit_pos_none_best.pth with accuracy: 64.20%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.03it/s, loss=0.552]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.24it/s]


Epoch 38/50 | Train Loss: 0.9030 | Val Loss: 1.1897 | Val Acc: 60.80%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.03it/s, loss=0.407]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.43it/s]


Epoch 39/50 | Train Loss: 0.8892 | Val Loss: 1.1090 | Val Acc: 66.40%
--> New best model saved to vit_pos_none_best.pth with accuracy: 66.40%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.03it/s, loss=0.196]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.28it/s]


Epoch 40/50 | Train Loss: 0.8792 | Val Loss: 1.1524 | Val Acc: 66.00%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.03it/s, loss=0.00483]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.70it/s]


Epoch 41/50 | Train Loss: 0.8690 | Val Loss: 1.1891 | Val Acc: 62.80%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.03it/s, loss=0.0758]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.56it/s]


Epoch 42/50 | Train Loss: 0.8397 | Val Loss: 1.2073 | Val Acc: 62.00%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.03it/s, loss=0.0388]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.25it/s]


Epoch 43/50 | Train Loss: 0.8136 | Val Loss: 1.2005 | Val Acc: 62.80%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.03it/s, loss=2.81] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.60it/s]


Epoch 44/50 | Train Loss: 0.8161 | Val Loss: 1.0820 | Val Acc: 67.00%
--> New best model saved to vit_pos_none_best.pth with accuracy: 67.00%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.03it/s, loss=0.104]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.22it/s]


Epoch 45/50 | Train Loss: 0.7863 | Val Loss: 1.1716 | Val Acc: 64.60%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.03it/s, loss=5.25] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.53it/s]


Epoch 46/50 | Train Loss: 0.7693 | Val Loss: 1.0783 | Val Acc: 67.60%
--> New best model saved to vit_pos_none_best.pth with accuracy: 67.60%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.03it/s, loss=2.68] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.33it/s]


Epoch 47/50 | Train Loss: 0.7679 | Val Loss: 1.0526 | Val Acc: 67.40%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.03it/s, loss=0.0273]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.66it/s]


Epoch 48/50 | Train Loss: 0.7683 | Val Loss: 1.0801 | Val Acc: 66.20%


Training: 100%|██████████| 403/403 [02:12<00:00,  3.03it/s, loss=0.122]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.19it/s]


Epoch 49/50 | Train Loss: 0.7344 | Val Loss: 1.1049 | Val Acc: 68.00%
--> New best model saved to vit_pos_none_best.pth with accuracy: 68.00%


Training: 100%|██████████| 403/403 [02:13<00:00,  3.03it/s, loss=0.000245]
Validating: 100%|██████████| 8/8 [00:01<00:00,  6.41it/s]


Epoch 50/50 | Train Loss: 0.7187 | Val Loss: 1.0915 | Val Acc: 67.60%

Training for None model finished in 111.66 minutes.
Best validation accuracy: 68.00%

--- Evaluating best None model on the TEST set ---
Not using any Positional Embedding


Validating: 100%|██████████| 8/8 [00:01<00:00,  6.29it/s]

Final Test Accuracy for None model: 69.40%


  EXPERIMENT SUMMARY: EFFECT OF POSITIONAL EMBEDDING (Heads=4)
Positional Embedding      | Best Val Acc (%)     | Final Test Acc (%)   | Train Time (min)    
------------------------------------------------------------------------------------------
learnable                 | 70.40                | 68.80                | 113.00              
sine                      | 70.40                | 70.20                | 112.26              
None                      | 68.00                | 69.40                | 111.66              





## FCNN

In [10]:
class FCFNNClassifier(nn.Module):
    def __init__(self, img_size=224, in_channels=3, num_classes=20):
        super(FCFNNClassifier, self).__init__()
        input_features = in_channels * img_size * img_size
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_features, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        return self.classifier(x)

print("\n--- Starting FCFNN Experiment ---")
fcfnn_model = FCFNNClassifier(img_size=IMAGE_SIZE, in_channels=NUM_CHANNELS, num_classes=NUM_CLASSES).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(fcfnn_model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

total_params = sum(p.numel() for p in fcfnn_model.parameters() if p.requires_grad)
print(f"FCFNN Model - Total trainable parameters: {total_params / 1e6:.2f}M")

best_val_acc = 0.0
fcfnn_history = {'train_loss': [], 'val_loss': [], 'val_acc': []}

print("Starting FCFNN training...")
start_time = time.time()

for epoch in range(EPOCHS):
    epoch_start_time = time.time()
    
    train_loss = train_one_epoch(fcfnn_model, train_loader, criterion, optimizer, scaler, device)
    val_loss, val_acc = validate(fcfnn_model, val_loader, criterion, device)
    
    fcfnn_history['train_loss'].append(train_loss)
    fcfnn_history['val_loss'].append(val_loss)
    fcfnn_history['val_acc'].append(val_acc)
    
    epoch_duration = time.time() - epoch_start_time
    
    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | Time: {epoch_duration:.2f}s")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(fcfnn_model.state_dict(), 'fcfnn_best_model.pth')
        print(f"New best FCFNN model saved with accuracy: {best_val_acc:.2f}%")

total_training_time = time.time() - start_time
print(f"\nFCFNN training finished in {total_training_time/60:.2f} minutes.")
print(f"FCFNN best validation accuracy: {best_val_acc:.2f}%")

print("\n--- Evaluating best FCFNN model on the final test set ---")
final_fcfnn_model = FCFNNClassifier(img_size=IMAGE_SIZE, in_channels=NUM_CHANNELS, num_classes=NUM_CLASSES).to(device)
final_fcfnn_model.load_state_dict(torch.load('fcfnn_best_model.pth'))
fcfnn_test_loss, fcfnn_test_acc = validate(final_fcfnn_model, test_loader, criterion, device)
print(f"\nFinal FCFNN Test Accuracy: {fcfnn_test_acc:.2f}%")
print(f"Final FCFNN Test Loss: {fcfnn_test_loss:.4f}")


--- Starting FCFNN Experiment ---
FCFNN Model - Total trainable parameters: 154.68M
Starting FCFNN training...


Training: 100%|██████████| 403/403 [00:44<00:00,  9.02it/s, loss=2.99]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.33it/s]


Epoch 1/50 | Train Loss: 4.5035 | Val Loss: 2.9220 | Val Acc: 10.20% | Time: 45.35s
New best FCFNN model saved with accuracy: 10.20%


Training: 100%|██████████| 403/403 [00:44<00:00,  8.96it/s, loss=2.96]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.00it/s]


Epoch 2/50 | Train Loss: 2.9744 | Val Loss: 2.8780 | Val Acc: 11.60% | Time: 45.64s
New best FCFNN model saved with accuracy: 11.60%


Training: 100%|██████████| 403/403 [00:44<00:00,  8.97it/s, loss=2.96]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.31it/s]


Epoch 3/50 | Train Loss: 2.9379 | Val Loss: 2.8110 | Val Acc: 15.60% | Time: 45.59s
New best FCFNN model saved with accuracy: 15.60%


Training: 100%|██████████| 403/403 [00:45<00:00,  8.95it/s, loss=2.97]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.42it/s]


Epoch 4/50 | Train Loss: 2.9056 | Val Loss: 2.7703 | Val Acc: 16.00% | Time: 45.66s
New best FCFNN model saved with accuracy: 16.00%


Training: 100%|██████████| 403/403 [00:44<00:00,  8.99it/s, loss=2.75]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.08it/s]


Epoch 5/50 | Train Loss: 2.8787 | Val Loss: 2.7532 | Val Acc: 16.60% | Time: 45.50s
New best FCFNN model saved with accuracy: 16.60%


Training: 100%|██████████| 403/403 [00:44<00:00,  8.98it/s, loss=2.75]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.59it/s]


Epoch 6/50 | Train Loss: 2.8588 | Val Loss: 2.7043 | Val Acc: 18.00% | Time: 45.53s
New best FCFNN model saved with accuracy: 18.00%


Training: 100%|██████████| 403/403 [00:44<00:00,  8.98it/s, loss=2.2] 
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.37it/s]


Epoch 7/50 | Train Loss: 2.8482 | Val Loss: 2.7227 | Val Acc: 15.40% | Time: 45.52s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.97it/s, loss=2.89]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.16it/s]


Epoch 8/50 | Train Loss: 2.8389 | Val Loss: 2.6991 | Val Acc: 19.40% | Time: 45.60s
New best FCFNN model saved with accuracy: 19.40%


Training: 100%|██████████| 403/403 [00:44<00:00,  8.99it/s, loss=2.86]
Validating: 100%|██████████| 8/8 [00:00<00:00, 11.62it/s]


Epoch 9/50 | Train Loss: 2.8313 | Val Loss: 2.6896 | Val Acc: 20.20% | Time: 45.50s
New best FCFNN model saved with accuracy: 20.20%


Training: 100%|██████████| 403/403 [00:44<00:00,  8.97it/s, loss=2.56]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.11it/s]


Epoch 10/50 | Train Loss: 2.8293 | Val Loss: 2.6877 | Val Acc: 20.80% | Time: 45.59s
New best FCFNN model saved with accuracy: 20.80%


Training: 100%|██████████| 403/403 [00:44<00:00,  8.97it/s, loss=2.88]
Validating: 100%|██████████| 8/8 [00:00<00:00, 11.68it/s]


Epoch 11/50 | Train Loss: 2.8285 | Val Loss: 2.6864 | Val Acc: 18.40% | Time: 45.62s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.97it/s, loss=2.89]
Validating: 100%|██████████| 8/8 [00:00<00:00, 11.99it/s]


Epoch 12/50 | Train Loss: 2.8215 | Val Loss: 2.6874 | Val Acc: 16.60% | Time: 45.60s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.97it/s, loss=2.96]
Validating: 100%|██████████| 8/8 [00:00<00:00, 11.76it/s]


Epoch 13/50 | Train Loss: 2.8267 | Val Loss: 2.7033 | Val Acc: 18.20% | Time: 45.59s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.98it/s, loss=2.95]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.05it/s]


Epoch 14/50 | Train Loss: 2.8250 | Val Loss: 2.6695 | Val Acc: 19.00% | Time: 45.56s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.97it/s, loss=2.81]
Validating: 100%|██████████| 8/8 [00:00<00:00, 11.64it/s]


Epoch 15/50 | Train Loss: 2.8298 | Val Loss: 2.6844 | Val Acc: 16.20% | Time: 45.64s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.97it/s, loss=3.08]
Validating: 100%|██████████| 8/8 [00:00<00:00, 11.75it/s]


Epoch 16/50 | Train Loss: 2.8315 | Val Loss: 2.6890 | Val Acc: 16.00% | Time: 45.60s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.96it/s, loss=2.98]
Validating: 100%|██████████| 8/8 [00:00<00:00, 11.48it/s]


Epoch 17/50 | Train Loss: 2.8278 | Val Loss: 2.6882 | Val Acc: 19.40% | Time: 45.66s


Training: 100%|██████████| 403/403 [00:44<00:00,  9.00it/s, loss=1.55]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.01it/s]


Epoch 18/50 | Train Loss: 2.8316 | Val Loss: 2.6928 | Val Acc: 17.60% | Time: 45.47s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.96it/s, loss=3.55]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.30it/s]


Epoch 19/50 | Train Loss: 2.8252 | Val Loss: 2.7086 | Val Acc: 17.40% | Time: 45.63s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.97it/s, loss=2.74]
Validating: 100%|██████████| 8/8 [00:00<00:00, 11.69it/s]


Epoch 20/50 | Train Loss: 2.8337 | Val Loss: 2.6939 | Val Acc: 18.60% | Time: 45.62s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.98it/s, loss=2.81]
Validating: 100%|██████████| 8/8 [00:00<00:00, 11.87it/s]


Epoch 21/50 | Train Loss: 2.8269 | Val Loss: 2.6858 | Val Acc: 16.00% | Time: 45.58s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.98it/s, loss=0.615]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.35it/s]


Epoch 22/50 | Train Loss: 2.8309 | Val Loss: 2.6851 | Val Acc: 18.80% | Time: 45.52s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.97it/s, loss=3.1] 
Validating: 100%|██████████| 8/8 [00:00<00:00, 11.57it/s]


Epoch 23/50 | Train Loss: 2.8372 | Val Loss: 2.7041 | Val Acc: 18.60% | Time: 45.61s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.97it/s, loss=2.79]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.09it/s]


Epoch 24/50 | Train Loss: 2.8337 | Val Loss: 2.7171 | Val Acc: 19.00% | Time: 45.58s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.98it/s, loss=2.98]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.50it/s]


Epoch 25/50 | Train Loss: 2.8329 | Val Loss: 2.6639 | Val Acc: 20.40% | Time: 45.54s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.97it/s, loss=2.13]
Validating: 100%|██████████| 8/8 [00:00<00:00, 11.97it/s]


Epoch 26/50 | Train Loss: 2.8401 | Val Loss: 2.7009 | Val Acc: 19.60% | Time: 45.60s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.97it/s, loss=2.07]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.37it/s]


Epoch 27/50 | Train Loss: 2.8471 | Val Loss: 2.6592 | Val Acc: 20.20% | Time: 45.56s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.98it/s, loss=2.61]
Validating: 100%|██████████| 8/8 [00:00<00:00, 11.82it/s]


Epoch 28/50 | Train Loss: 2.8376 | Val Loss: 2.7097 | Val Acc: 15.60% | Time: 45.57s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.97it/s, loss=2.86]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.49it/s]


Epoch 29/50 | Train Loss: 2.8430 | Val Loss: 2.6820 | Val Acc: 16.20% | Time: 45.58s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.98it/s, loss=2.96]
Validating: 100%|██████████| 8/8 [00:00<00:00, 11.84it/s]


Epoch 30/50 | Train Loss: 2.8453 | Val Loss: 2.6940 | Val Acc: 18.20% | Time: 45.56s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.98it/s, loss=3.11]
Validating: 100%|██████████| 8/8 [00:00<00:00, 11.71it/s]


Epoch 31/50 | Train Loss: 2.8523 | Val Loss: 2.6622 | Val Acc: 19.00% | Time: 45.55s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.99it/s, loss=2.6] 
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.63it/s]


Epoch 32/50 | Train Loss: 2.8498 | Val Loss: 2.7324 | Val Acc: 14.20% | Time: 45.48s


Training: 100%|██████████| 403/403 [00:45<00:00,  8.95it/s, loss=3.31]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.39it/s]


Epoch 33/50 | Train Loss: 2.8684 | Val Loss: 2.7991 | Val Acc: 11.80% | Time: 45.69s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.98it/s, loss=2.82]
Validating: 100%|██████████| 8/8 [00:00<00:00, 11.66it/s]


Epoch 34/50 | Train Loss: 2.8657 | Val Loss: 2.7134 | Val Acc: 16.80% | Time: 45.58s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.98it/s, loss=1.36]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.21it/s]


Epoch 35/50 | Train Loss: 2.8479 | Val Loss: 2.7144 | Val Acc: 15.60% | Time: 45.55s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.98it/s, loss=2.96]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.25it/s]


Epoch 36/50 | Train Loss: 2.8521 | Val Loss: 2.7216 | Val Acc: 14.40% | Time: 45.54s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.98it/s, loss=1.56]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.46it/s]


Epoch 37/50 | Train Loss: 2.8690 | Val Loss: 2.6689 | Val Acc: 19.60% | Time: 45.51s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.98it/s, loss=3.18]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.27it/s]


Epoch 38/50 | Train Loss: 2.8742 | Val Loss: 2.7775 | Val Acc: 15.40% | Time: 45.53s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.97it/s, loss=3.1] 
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.28it/s]


Epoch 39/50 | Train Loss: 2.8625 | Val Loss: 2.7348 | Val Acc: 14.80% | Time: 45.61s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.98it/s, loss=3.25]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.11it/s]


Epoch 40/50 | Train Loss: 2.8611 | Val Loss: 2.7443 | Val Acc: 14.40% | Time: 45.57s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.99it/s, loss=2.98]
Validating: 100%|██████████| 8/8 [00:00<00:00, 11.89it/s]


Epoch 41/50 | Train Loss: 2.8635 | Val Loss: 2.7696 | Val Acc: 14.80% | Time: 45.53s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.97it/s, loss=2.97]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.31it/s]


Epoch 42/50 | Train Loss: 2.8737 | Val Loss: 2.7523 | Val Acc: 15.40% | Time: 45.57s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.98it/s, loss=2.79]
Validating: 100%|██████████| 8/8 [00:00<00:00, 11.93it/s]


Epoch 43/50 | Train Loss: 2.8555 | Val Loss: 2.7598 | Val Acc: 15.40% | Time: 45.54s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.98it/s, loss=3.03]
Validating: 100%|██████████| 8/8 [00:00<00:00, 11.97it/s]


Epoch 44/50 | Train Loss: 2.8601 | Val Loss: 2.7695 | Val Acc: 12.60% | Time: 45.53s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.97it/s, loss=2.93]
Validating: 100%|██████████| 8/8 [00:00<00:00, 11.99it/s]


Epoch 45/50 | Train Loss: 2.8591 | Val Loss: 2.8208 | Val Acc: 9.60% | Time: 45.61s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.96it/s, loss=2.99]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.06it/s]


Epoch 46/50 | Train Loss: 2.8612 | Val Loss: 2.7561 | Val Acc: 14.80% | Time: 45.62s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.98it/s, loss=1.68]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.13it/s]


Epoch 47/50 | Train Loss: 2.8580 | Val Loss: 2.7272 | Val Acc: 17.00% | Time: 45.55s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.98it/s, loss=2.68]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.33it/s]


Epoch 48/50 | Train Loss: 2.8526 | Val Loss: 2.6916 | Val Acc: 15.80% | Time: 45.55s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.98it/s, loss=3.13]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.01it/s]


Epoch 49/50 | Train Loss: 2.8580 | Val Loss: 2.7460 | Val Acc: 14.00% | Time: 45.53s


Training: 100%|██████████| 403/403 [00:44<00:00,  8.98it/s, loss=2.85]
Validating: 100%|██████████| 8/8 [00:00<00:00, 12.09it/s]


Epoch 50/50 | Train Loss: 2.8554 | Val Loss: 2.7313 | Val Acc: 13.00% | Time: 45.57s

FCFNN training finished in 38.82 minutes.
FCFNN best validation accuracy: 20.80%

--- Evaluating best FCFNN model on the final test set ---


Validating: 100%|██████████| 8/8 [00:00<00:00, 11.57it/s]


Final FCFNN Test Accuracy: 21.40%
Final FCFNN Test Loss: 2.6615





## CNN

In [11]:
class CNNClassifier(nn.Module):
    def __init__(self, in_channels=3, num_classes=20):
        super(CNNClassifier, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * 7 * 7, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = self.classifier(x)
        return x

print("\n--- Starting CNN Experiment ---")
cnn_model = CNNClassifier(in_channels=NUM_CHANNELS, num_classes=NUM_CLASSES).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(cnn_model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

total_params = sum(p.numel() for p in cnn_model.parameters() if p.requires_grad)
print(f"CNN Model - Total trainable parameters: {total_params / 1e6:.2f}M")

best_val_acc = 0.0
cnn_history = {'train_loss': [], 'val_loss': [], 'val_acc': []}

print("Starting CNN training...")
start_time = time.time()

for epoch in range(EPOCHS):
    epoch_start_time = time.time()
    
    train_loss = train_one_epoch(cnn_model, train_loader, criterion, optimizer, scaler, device)
    val_loss, val_acc = validate(cnn_model, val_loader, criterion, device)
    
    cnn_history['train_loss'].append(train_loss)
    cnn_history['val_loss'].append(val_loss)
    cnn_history['val_acc'].append(val_acc)
    
    epoch_duration = time.time() - epoch_start_time
    
    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | Time: {epoch_duration:.2f}s")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(cnn_model.state_dict(), 'cnn_best_model.pth')
        print(f"New best CNN model saved with accuracy: {best_val_acc:.2f}%")

total_training_time = time.time() - start_time
print(f"\nCNN training finished in {total_training_time/60:.2f} minutes.")
print(f"CNN best validation accuracy: {best_val_acc:.2f}%")

print("\n--- Evaluating best CNN model on the final test set ---")
final_cnn_model = CNNClassifier(in_channels=NUM_CHANNELS, num_classes=NUM_CLASSES).to(device)
final_cnn_model.load_state_dict(torch.load('cnn_best_model.pth'))
cnn_test_loss, cnn_test_acc = validate(final_cnn_model, test_loader, criterion, device)
print(f"\nFinal CNN Test Accuracy: {cnn_test_acc:.2f}%")
print(f"Final CNN Test Loss: {cnn_test_loss:.4f}")


--- Starting CNN Experiment ---
CNN Model - Total trainable parameters: 27.79M
Starting CNN training...


Training: 100%|██████████| 403/403 [01:14<00:00,  5.43it/s, loss=3.33]
Validating: 100%|██████████| 8/8 [00:00<00:00,  8.15it/s]


Epoch 1/50 | Train Loss: 2.5994 | Val Loss: 2.0940 | Val Acc: 33.20% | Time: 75.25s
New best CNN model saved with accuracy: 33.20%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.43it/s, loss=2.61]
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.73it/s]


Epoch 2/50 | Train Loss: 2.1537 | Val Loss: 1.9239 | Val Acc: 38.80% | Time: 75.25s
New best CNN model saved with accuracy: 38.80%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.42it/s, loss=0.827]
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.53it/s]


Epoch 3/50 | Train Loss: 1.9332 | Val Loss: 1.6732 | Val Acc: 46.00% | Time: 75.48s
New best CNN model saved with accuracy: 46.00%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.43it/s, loss=7.73]
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.67it/s]


Epoch 4/50 | Train Loss: 1.7676 | Val Loss: 1.4404 | Val Acc: 55.80% | Time: 75.29s
New best CNN model saved with accuracy: 55.80%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.43it/s, loss=2.39]
Validating: 100%|██████████| 8/8 [00:00<00:00,  8.23it/s]


Epoch 5/50 | Train Loss: 1.6315 | Val Loss: 1.3541 | Val Acc: 56.60% | Time: 75.21s
New best CNN model saved with accuracy: 56.60%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.43it/s, loss=7.93]
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.82it/s]


Epoch 6/50 | Train Loss: 1.5183 | Val Loss: 1.2298 | Val Acc: 61.20% | Time: 75.23s
New best CNN model saved with accuracy: 61.20%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.42it/s, loss=0.261]
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.84it/s]


Epoch 7/50 | Train Loss: 1.4377 | Val Loss: 1.2603 | Val Acc: 60.40% | Time: 75.39s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.42it/s, loss=0.49] 
Validating: 100%|██████████| 8/8 [00:00<00:00,  8.05it/s]


Epoch 8/50 | Train Loss: 1.3370 | Val Loss: 1.0803 | Val Acc: 65.00% | Time: 75.41s
New best CNN model saved with accuracy: 65.00%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.43it/s, loss=2.53] 
Validating: 100%|██████████| 8/8 [00:00<00:00,  8.21it/s]


Epoch 9/50 | Train Loss: 1.2932 | Val Loss: 1.2109 | Val Acc: 61.80% | Time: 75.23s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.42it/s, loss=0.139]
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.68it/s]


Epoch 10/50 | Train Loss: 1.2453 | Val Loss: 1.0128 | Val Acc: 66.20% | Time: 75.34s
New best CNN model saved with accuracy: 66.20%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.43it/s, loss=0.114]
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.68it/s]


Epoch 11/50 | Train Loss: 1.1834 | Val Loss: 0.9261 | Val Acc: 69.60% | Time: 75.30s
New best CNN model saved with accuracy: 69.60%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.43it/s, loss=0.606]
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.91it/s]


Epoch 12/50 | Train Loss: 1.1456 | Val Loss: 0.9100 | Val Acc: 71.40% | Time: 75.24s
New best CNN model saved with accuracy: 71.40%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.43it/s, loss=0.00672]
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.92it/s]


Epoch 13/50 | Train Loss: 1.1185 | Val Loss: 0.8719 | Val Acc: 72.00% | Time: 75.29s
New best CNN model saved with accuracy: 72.00%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.43it/s, loss=0.00122]
Validating: 100%|██████████| 8/8 [00:00<00:00,  8.16it/s]


Epoch 14/50 | Train Loss: 1.0747 | Val Loss: 0.8258 | Val Acc: 74.40% | Time: 75.20s
New best CNN model saved with accuracy: 74.40%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.43it/s, loss=0.016]
Validating: 100%|██████████| 8/8 [00:00<00:00,  8.07it/s]


Epoch 15/50 | Train Loss: 1.0343 | Val Loss: 0.7460 | Val Acc: 74.60% | Time: 75.28s
New best CNN model saved with accuracy: 74.60%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.43it/s, loss=1.66] 
Validating: 100%|██████████| 8/8 [00:00<00:00,  8.13it/s]


Epoch 16/50 | Train Loss: 1.0058 | Val Loss: 0.8523 | Val Acc: 73.20% | Time: 75.26s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.42it/s, loss=0.0138]
Validating: 100%|██████████| 8/8 [00:00<00:00,  8.21it/s]


Epoch 17/50 | Train Loss: 0.9929 | Val Loss: 0.7060 | Val Acc: 76.80% | Time: 75.29s
New best CNN model saved with accuracy: 76.80%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.43it/s, loss=1.03] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.94it/s]


Epoch 18/50 | Train Loss: 0.9418 | Val Loss: 0.7521 | Val Acc: 75.60% | Time: 75.30s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.42it/s, loss=1.06] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.85it/s]


Epoch 19/50 | Train Loss: 0.9338 | Val Loss: 0.8058 | Val Acc: 75.80% | Time: 75.32s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.43it/s, loss=1.01] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.72it/s]


Epoch 20/50 | Train Loss: 0.9198 | Val Loss: 0.7026 | Val Acc: 77.60% | Time: 75.32s
New best CNN model saved with accuracy: 77.60%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.42it/s, loss=0.00166]
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.81it/s]


Epoch 21/50 | Train Loss: 0.8810 | Val Loss: 0.7697 | Val Acc: 75.40% | Time: 75.40s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.42it/s, loss=2.45] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.60it/s]


Epoch 22/50 | Train Loss: 0.8654 | Val Loss: 0.7403 | Val Acc: 76.60% | Time: 75.46s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.40it/s, loss=0.000934]
Validating: 100%|██████████| 8/8 [00:00<00:00,  8.04it/s]


Epoch 23/50 | Train Loss: 0.8850 | Val Loss: 0.7192 | Val Acc: 77.60% | Time: 75.58s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.41it/s, loss=0.00848]
Validating: 100%|██████████| 8/8 [00:00<00:00,  8.03it/s]


Epoch 24/50 | Train Loss: 0.8301 | Val Loss: 0.6660 | Val Acc: 78.40% | Time: 75.55s
New best CNN model saved with accuracy: 78.40%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.41it/s, loss=1.31e-6]
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.37it/s]


Epoch 25/50 | Train Loss: 0.8102 | Val Loss: 0.6544 | Val Acc: 80.00% | Time: 75.53s
New best CNN model saved with accuracy: 80.00%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.40it/s, loss=0.448]
Validating: 100%|██████████| 8/8 [00:00<00:00,  8.35it/s]


Epoch 26/50 | Train Loss: 0.8023 | Val Loss: 0.6839 | Val Acc: 77.80% | Time: 75.53s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.41it/s, loss=0.0417]
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.76it/s]


Epoch 27/50 | Train Loss: 0.7911 | Val Loss: 0.5873 | Val Acc: 81.00% | Time: 75.50s
New best CNN model saved with accuracy: 81.00%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.41it/s, loss=0.381]
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.62it/s]


Epoch 28/50 | Train Loss: 0.7542 | Val Loss: 0.6213 | Val Acc: 79.60% | Time: 75.50s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.41it/s, loss=0.135]
Validating: 100%|██████████| 8/8 [00:00<00:00,  8.02it/s]


Epoch 29/50 | Train Loss: 0.7602 | Val Loss: 0.6303 | Val Acc: 81.00% | Time: 75.50s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.41it/s, loss=0]    
Validating: 100%|██████████| 8/8 [00:00<00:00,  8.05it/s]


Epoch 30/50 | Train Loss: 0.7490 | Val Loss: 0.5764 | Val Acc: 80.60% | Time: 75.51s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.41it/s, loss=0]    
Validating: 100%|██████████| 8/8 [00:00<00:00,  8.14it/s]


Epoch 31/50 | Train Loss: 0.7257 | Val Loss: 0.5609 | Val Acc: 81.40% | Time: 75.49s
New best CNN model saved with accuracy: 81.40%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.41it/s, loss=1.25] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.74it/s]


Epoch 32/50 | Train Loss: 0.7248 | Val Loss: 0.7483 | Val Acc: 78.60% | Time: 75.53s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.41it/s, loss=0.0037]
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.79it/s]


Epoch 33/50 | Train Loss: 0.7324 | Val Loss: 0.6198 | Val Acc: 80.80% | Time: 75.53s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.41it/s, loss=0.0421]
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.91it/s]


Epoch 34/50 | Train Loss: 0.7002 | Val Loss: 0.5790 | Val Acc: 81.20% | Time: 75.45s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.41it/s, loss=6.5]  
Validating: 100%|██████████| 8/8 [00:00<00:00,  8.39it/s]


Epoch 35/50 | Train Loss: 0.6811 | Val Loss: 0.6208 | Val Acc: 79.80% | Time: 75.50s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.41it/s, loss=1.08] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.58it/s]


Epoch 36/50 | Train Loss: 0.6792 | Val Loss: 0.6597 | Val Acc: 80.80% | Time: 75.56s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.42it/s, loss=0.0118]
Validating: 100%|██████████| 8/8 [00:00<00:00,  8.21it/s]


Epoch 37/50 | Train Loss: 0.6806 | Val Loss: 0.6076 | Val Acc: 81.00% | Time: 75.36s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.42it/s, loss=9.54e-6]
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.58it/s]


Epoch 38/50 | Train Loss: 0.6644 | Val Loss: 0.5598 | Val Acc: 82.60% | Time: 75.43s
New best CNN model saved with accuracy: 82.60%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.41it/s, loss=0.249]
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.82it/s]


Epoch 39/50 | Train Loss: 0.6390 | Val Loss: 0.6188 | Val Acc: 81.20% | Time: 75.46s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.41it/s, loss=4.36e-5]
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.61it/s]


Epoch 40/50 | Train Loss: 0.6499 | Val Loss: 0.5945 | Val Acc: 83.40% | Time: 75.61s
New best CNN model saved with accuracy: 83.40%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.41it/s, loss=2.63] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.64it/s]


Epoch 41/50 | Train Loss: 0.6382 | Val Loss: 0.5399 | Val Acc: 81.60% | Time: 75.52s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.41it/s, loss=0.916]
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.74it/s]


Epoch 42/50 | Train Loss: 0.6295 | Val Loss: 0.5409 | Val Acc: 82.60% | Time: 75.60s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.41it/s, loss=3.22] 
Validating: 100%|██████████| 8/8 [00:00<00:00,  8.04it/s]


Epoch 43/50 | Train Loss: 0.6255 | Val Loss: 0.6882 | Val Acc: 82.60% | Time: 75.43s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.41it/s, loss=3.28] 
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.87it/s]


Epoch 44/50 | Train Loss: 0.6295 | Val Loss: 0.6571 | Val Acc: 80.40% | Time: 75.50s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.41it/s, loss=1.06] 
Validating: 100%|██████████| 8/8 [00:00<00:00,  8.02it/s]


Epoch 45/50 | Train Loss: 0.6300 | Val Loss: 0.5617 | Val Acc: 83.60% | Time: 75.53s
New best CNN model saved with accuracy: 83.60%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.41it/s, loss=0.00742]
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.80it/s]


Epoch 46/50 | Train Loss: 0.6427 | Val Loss: 0.5917 | Val Acc: 84.20% | Time: 75.51s
New best CNN model saved with accuracy: 84.20%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.42it/s, loss=0.0184]
Validating: 100%|██████████| 8/8 [00:00<00:00,  8.06it/s]


Epoch 47/50 | Train Loss: 0.6115 | Val Loss: 0.5796 | Val Acc: 83.80% | Time: 75.42s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.41it/s, loss=1.68] 
Validating: 100%|██████████| 8/8 [00:00<00:00,  8.17it/s]


Epoch 48/50 | Train Loss: 0.5865 | Val Loss: 0.5490 | Val Acc: 85.00% | Time: 75.51s
New best CNN model saved with accuracy: 85.00%


Training: 100%|██████████| 403/403 [01:14<00:00,  5.41it/s, loss=1.19e-7]
Validating: 100%|██████████| 8/8 [00:01<00:00,  7.98it/s]


Epoch 49/50 | Train Loss: 0.6011 | Val Loss: 0.5790 | Val Acc: 83.20% | Time: 75.53s


Training: 100%|██████████| 403/403 [01:14<00:00,  5.40it/s, loss=5.13e-6]
Validating: 100%|██████████| 8/8 [00:00<00:00,  8.17it/s]


Epoch 50/50 | Train Loss: 0.5695 | Val Loss: 0.4994 | Val Acc: 85.20% | Time: 75.60s
New best CNN model saved with accuracy: 85.20%

CNN training finished in 63.28 minutes.
CNN best validation accuracy: 85.20%

--- Evaluating best CNN model on the final test set ---


Validating: 100%|██████████| 8/8 [00:01<00:00,  7.68it/s]


Final CNN Test Accuracy: 84.00%
Final CNN Test Loss: 0.5392



