# Building a base Vision Transformer for Multi-class Classification of Breast Cancer Histopathological Dataset(BreaKHis Dataset)



*   In this notebook, we've built a ViT, whose parameters are as close to that described in the original Vision Transformers paper titled, "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale".
*   Since the dataset can be used for both binary and multi-class classification, for this project, we've narrowed it down to just multi-class classification.
*   The dataset also captures the tumours at multiple zoom levels(40x, 100x, 200x and 400x), to remove unnecessary complexity, we've just focused on the 400x zoom images.



In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


### Importing libraries and setting up device agnostic code

In [2]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

device = "cuda" if torch.cuda.is_available() else "cpu"

### Setting up transforms and dataloader

In [3]:
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import v2

FLIP_PROBABILITY = 0.1

data_transform = v2.Compose([
    v2.Resize(size=(224,224)),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomRotation(degrees=(-10,10)),
    v2.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
    v2.RandomAffine(degrees=0, translate=(0.05, 0.05)),
    v2.ToTensor(),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Add normalization
])



### Getting dataset and class labels

In [4]:
train_data = datasets.ImageFolder(root="/content/drive/MyDrive/dataset_v3/train",
                                  transform=data_transform,
                                  target_transform=None)
test_data = datasets.ImageFolder(root="/content/drive/MyDrive/dataset_v3/test",
                                  transform=data_transform)
class_names = train_data.classes

### Preprocessing the dataset

The dataset is imbalanced, as the training and test examples for 'ductal_carcinoma' type of tumour are ~5x that of other cases.

So, we've randomly reduced the samples to come on line with the numbers from other cases so it doesn't imbalance the model.

In [6]:
from sklearn.utils import resample

def balance_dataset(dataset, size):
    # Separate data by class
    class_data = {i: [] for i in range(len(class_names))}  # For 8 classes

    # Group data by class
    for idx in range(len(dataset)):
        _, label = dataset[idx]
        class_data[label].append(idx)

    # Randomly sample indices from each class
    balanced_indices = []
    for class_idx, indices in class_data.items():
        # If class has more samples than min_size, downsample it
        if len(indices) > size:
            balanced_indices.extend(np.random.choice(indices, size=size, replace=False))
        else:
            balanced_indices.extend(indices)

    # Create a subset dataset
    from torch.utils.data import Subset
    balanced_dataset = Subset(dataset, balanced_indices)

    return balanced_dataset

balanced_train_data = balance_dataset(train_data, 200)

balanced_test_data = balance_dataset(test_data, 50)

### Loading up the DataLoader

In [7]:
from torch.utils.data import DataLoader

BATCH_SIZE = 16

train_dataloader = DataLoader(balanced_train_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(balanced_test_data, batch_size=BATCH_SIZE, shuffle=False)

### Setting up transforms

In [8]:
train_transform = v2.Compose([
    v2.Resize(size=(224,224)),
    v2.TrivialAugmentWide(num_magnitude_bins=5),
    v2.ToTensor()
])

test_transform = v2.Compose([
    v2.Resize(size=(224,224)),
    v2.ToTensor()
])



## Implementing the Vision Transformer

### Defining the Vision Transformer

In [9]:
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

    def forward(self, x):
        x = self.proj(x)  # (B, E, H', W')
        x = x.flatten(2)  # (B, E, N)
        x = x.transpose(1, 2)  # (B, N, E)
        return x

class Attention(nn.Module):
    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_drop=0.1, proj_drop=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.scale = (dim // n_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, C // self.n_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, drop=0.1):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class Block(nn.Module):
    def __init__(self, dim, n_heads, mlp_ratio=4., qkv_bias=True, drop=0.1, attn_drop=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, n_heads=n_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim * mlp_ratio), dim, drop)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=len(class_names),
                 embed_dim=768, depth=12, n_heads=12, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0.1, attn_drop_rate=0.1):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.n_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        torch.nn.init.trunc_normal_(self.pos_embed, std=0.02)
        torch.nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.pos_drop = nn.Dropout(drop_rate)

        self.blocks = nn.Sequential(*[
            Block(embed_dim, n_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_token = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        x = self.blocks(x)
        x = self.norm(x)

        x = x[:, 0]  # take cls token
        x = self.head(x)
        return x

## Running the ViT

In [10]:
from torch.nn import functional as F
from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score
import torch
import torch.nn as nn

class AverageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    losses = AverageMeter()
    progress_bar = tqdm(train_loader, desc='Training')

    all_predictions = []
    all_labels = []

    for batch in progress_bar:
        images, labels = batch
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        losses.update(loss.item(), images.size(0))
        progress_bar.set_postfix({'train_loss': f'{losses.avg:.4f}'})

        predictions = torch.argmax(outputs, dim=1)
        all_predictions.extend(predictions.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    epoch_accuracy = accuracy_score(all_labels, all_predictions)
    return losses.avg, epoch_accuracy

def evaluate(model, test_loader, criterion, device):
    model.eval()
    losses = AverageMeter()

    all_predictions = []
    all_labels = []

    with torch.no_grad():
        progress_bar = tqdm(test_loader, desc='Testing')
        for batch in progress_bar:
            images, labels = batch
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            losses.update(loss.item(), images.size(0))
            progress_bar.set_postfix({'test_loss': f'{losses.avg:.4f}'})

            predictions = torch.argmax(outputs, dim=1)
            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_predictions)
    return losses.avg, accuracy

def train_model(model, train_loader, test_loader, num_epochs=100):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=5e-5,
        weight_decay=0.02,
        betas=(0.9, 0.999)
    )
    steps_per_epoch = len(train_loader)
    total_steps = steps_per_epoch * num_epochs
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=5e-4,
        total_steps=total_steps,
        pct_start=0.2,
        anneal_strategy='cos',
        cycle_momentum=True,
        base_momentum=0.85,
        max_momentum=0.95,
        div_factor=10.0,
        final_div_factor=1000.0
    )

    # Add tracking for best model
    best_accuracy = 0.0
    best_model_state = None
    best_epoch = 0
    best_test_loss = float('inf')

    print("Starting training...")
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        train_loss, train_accuracy = train_epoch(
            model, train_loader, criterion, optimizer, device
        )
        test_loss, test_accuracy = evaluate(
            model, test_loader, criterion, device
        )
        scheduler.step()

        # Save best model
        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            best_test_loss = test_loss
            best_model_state = model.state_dict().copy()
            best_epoch = epoch + 1

            # Save the checkpoint with all necessary information
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'train_loss': train_loss,
                'train_accuracy': train_accuracy,
                'test_loss': test_loss,
                'test_accuracy': test_accuracy,
                'best_accuracy': best_accuracy
            }
            torch.save(checkpoint, 'best_vit_model.pth')
            print(f"New best model saved! Accuracy: {best_accuracy:.4f}")

        if epoch % 10 == 0:
            print(f"Train Loss: {train_loss:.4f} | Train Accuracy: {train_accuracy:.4f}")
            print(f"Test Loss: {test_loss:.4f} | Test Accuracy: {test_accuracy:.4f}")

    # Load best model before returning
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"\nLoaded best model from epoch {best_epoch}")
        print(f"Best Test Metrics:")
        print(f"Test Loss: {best_test_loss:.4f}")
        print(f"Test Accuracy: {best_accuracy:.4f}")

    return model, best_accuracy

# Initialize model
model = VisionTransformer(
    img_size=224,
    patch_size=16,
    in_channels=3,
    num_classes=len(class_names),
    embed_dim=384,
    depth=8,
    n_heads=8,
    mlp_ratio=4,
    qkv_bias=True,
    drop_rate=0.15
)

# Train the model
trained_model, best_accuracy = train_model(
    model,
    train_dataloader,
    test_dataloader,
    num_epochs=100
)
print(f"\nTraining completed!")
print(f"Best test accuracy achieved: {best_accuracy:.4f}")

# To load the saved model later, you can use:
"""
checkpoint = torch.load('best_vit_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
best_accuracy = checkpoint['best_accuracy']
"""

Starting training...

Epoch 1/100


Training: 100%|██████████| 64/64 [00:22<00:00,  2.86it/s, train_loss=2.0223]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.19it/s, test_loss=1.9284]


New best model saved! Accuracy: 0.2857
Train Loss: 2.0223 | Train Accuracy: 0.2454
Test Loss: 1.9284 | Test Accuracy: 0.2857

Epoch 2/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.14it/s, train_loss=1.8943]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.56it/s, test_loss=1.8885]


New best model saved! Accuracy: 0.3243

Epoch 3/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.10it/s, train_loss=1.8499]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.92it/s, test_loss=1.8510]



Epoch 4/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.15it/s, train_loss=1.8276]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.31it/s, test_loss=1.8507]


New best model saved! Accuracy: 0.4015

Epoch 5/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.92it/s, train_loss=1.7915]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.44it/s, test_loss=1.8233]



Epoch 6/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.06it/s, train_loss=1.7779]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.40it/s, test_loss=1.7263]



Epoch 7/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.14it/s, train_loss=1.7218]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.68it/s, test_loss=1.7091]



Epoch 8/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.12it/s, train_loss=1.6911]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.83it/s, test_loss=1.7854]



Epoch 9/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.11it/s, train_loss=1.6972]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.37it/s, test_loss=1.7284]


New best model saved! Accuracy: 0.4286

Epoch 10/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.02it/s, train_loss=1.6655]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.41it/s, test_loss=1.6147]


New best model saved! Accuracy: 0.4865

Epoch 11/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.02it/s, train_loss=1.5832]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.33it/s, test_loss=1.6972]


Train Loss: 1.5832 | Train Accuracy: 0.4594
Test Loss: 1.6972 | Test Accuracy: 0.4054

Epoch 12/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.07it/s, train_loss=1.5587]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.90it/s, test_loss=1.7578]



Epoch 13/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.11it/s, train_loss=1.5366]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.49it/s, test_loss=1.5681]



Epoch 14/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.09it/s, train_loss=1.4620]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.28it/s, test_loss=1.6037]


New best model saved! Accuracy: 0.4981

Epoch 15/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.03it/s, train_loss=1.4931]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.26it/s, test_loss=1.5532]



Epoch 16/100


Training: 100%|██████████| 64/64 [00:27<00:00,  2.31it/s, train_loss=1.4560]
Testing: 100%|██████████| 17/17 [00:05<00:00,  3.38it/s, test_loss=1.5270]



Epoch 17/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.08it/s, train_loss=1.4265]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.35it/s, test_loss=1.6693]



Epoch 18/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.02it/s, train_loss=1.3946]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.44it/s, test_loss=1.5023]



Epoch 19/100


Training: 100%|██████████| 64/64 [00:22<00:00,  2.89it/s, train_loss=1.3098]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.37it/s, test_loss=1.4174]


New best model saved! Accuracy: 0.5483

Epoch 20/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.99it/s, train_loss=1.3138]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.90it/s, test_loss=1.4550]


New best model saved! Accuracy: 0.5560

Epoch 21/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.09it/s, train_loss=1.2567]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.53it/s, test_loss=1.4163]


New best model saved! Accuracy: 0.6100
Train Loss: 1.2567 | Train Accuracy: 0.6197
Test Loss: 1.4163 | Test Accuracy: 0.6100

Epoch 22/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.06it/s, train_loss=1.2525]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.84it/s, test_loss=1.3464]



Epoch 23/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.06it/s, train_loss=1.2688]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.30it/s, test_loss=1.3964]



Epoch 24/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.96it/s, train_loss=1.1993]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.36it/s, test_loss=1.3805]


New best model saved! Accuracy: 0.6139

Epoch 25/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.96it/s, train_loss=1.2002]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.30it/s, test_loss=1.3404]


New best model saved! Accuracy: 0.6293

Epoch 26/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.93it/s, train_loss=1.2144]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.13it/s, test_loss=1.3975]



Epoch 27/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.09it/s, train_loss=1.1752]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.57it/s, test_loss=1.5581]



Epoch 28/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.10it/s, train_loss=1.1756]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.93it/s, test_loss=1.3849]



Epoch 29/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.03it/s, train_loss=1.1439]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.34it/s, test_loss=1.3389]


New best model saved! Accuracy: 0.6448

Epoch 30/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.96it/s, train_loss=1.1137]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.21it/s, test_loss=1.2277]


New best model saved! Accuracy: 0.6525

Epoch 31/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.97it/s, train_loss=1.0931]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.24it/s, test_loss=1.3628]


Train Loss: 1.0931 | Train Accuracy: 0.7038
Test Loss: 1.3628 | Test Accuracy: 0.6178

Epoch 32/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.00it/s, train_loss=1.1034]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.20it/s, test_loss=1.2618]



Epoch 33/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.06it/s, train_loss=1.0655]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.57it/s, test_loss=1.2349]



Epoch 34/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.07it/s, train_loss=1.0765]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.80it/s, test_loss=1.1893]


New best model saved! Accuracy: 0.6795

Epoch 35/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.05it/s, train_loss=1.0491]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.19it/s, test_loss=1.2647]



Epoch 36/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.01it/s, train_loss=1.1101]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.31it/s, test_loss=1.1919]


New best model saved! Accuracy: 0.6834

Epoch 37/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.99it/s, train_loss=1.0269]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.25it/s, test_loss=1.1846]


New best model saved! Accuracy: 0.6911

Epoch 38/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.97it/s, train_loss=1.0128]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.33it/s, test_loss=1.1593]



Epoch 39/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.03it/s, train_loss=0.9893]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.74it/s, test_loss=1.2806]



Epoch 40/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.10it/s, train_loss=1.0129]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.71it/s, test_loss=1.1964]



Epoch 41/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.10it/s, train_loss=0.9976]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.30it/s, test_loss=1.2373]


Train Loss: 0.9976 | Train Accuracy: 0.7429
Test Loss: 1.2373 | Test Accuracy: 0.6795

Epoch 42/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.98it/s, train_loss=0.9656]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.36it/s, test_loss=1.1713]


New best model saved! Accuracy: 0.6988

Epoch 43/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.97it/s, train_loss=0.9533]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.39it/s, test_loss=1.4480]



Epoch 44/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.02it/s, train_loss=0.9699]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.96it/s, test_loss=1.0343]


New best model saved! Accuracy: 0.7490

Epoch 45/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.06it/s, train_loss=0.9692]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.75it/s, test_loss=1.2055]



Epoch 46/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.06it/s, train_loss=0.9387]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.73it/s, test_loss=1.1242]



Epoch 47/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.05it/s, train_loss=0.9455]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.26it/s, test_loss=1.4309]



Epoch 48/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.97it/s, train_loss=0.9237]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.26it/s, test_loss=1.1734]



Epoch 49/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.99it/s, train_loss=0.9187]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.26it/s, test_loss=1.2012]



Epoch 50/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.00it/s, train_loss=0.9337]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.04it/s, test_loss=1.3342]



Epoch 51/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.06it/s, train_loss=0.8963]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.62it/s, test_loss=1.0479]


New best model saved! Accuracy: 0.7568
Train Loss: 0.8963 | Train Accuracy: 0.7908
Test Loss: 1.0479 | Test Accuracy: 0.7568

Epoch 52/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.07it/s, train_loss=0.8957]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.57it/s, test_loss=1.3480]



Epoch 53/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.07it/s, train_loss=0.9035]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.08it/s, test_loss=1.0313]


New best model saved! Accuracy: 0.7645

Epoch 54/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.00it/s, train_loss=0.8640]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.39it/s, test_loss=1.0862]



Epoch 55/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.97it/s, train_loss=0.8684]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.09it/s, test_loss=1.0933]



Epoch 56/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.94it/s, train_loss=0.8581]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.21it/s, test_loss=1.0781]



Epoch 57/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.01it/s, train_loss=0.8622]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.77it/s, test_loss=1.1748]



Epoch 58/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.08it/s, train_loss=0.8344]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.53it/s, test_loss=1.1239]



Epoch 59/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.07it/s, train_loss=0.8619]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.99it/s, test_loss=1.0412]



Epoch 60/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.98it/s, train_loss=0.8223]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.39it/s, test_loss=1.2435]



Epoch 61/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.97it/s, train_loss=0.8427]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.37it/s, test_loss=1.0540]


Train Loss: 0.8427 | Train Accuracy: 0.8172
Test Loss: 1.0540 | Test Accuracy: 0.7529

Epoch 62/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.99it/s, train_loss=0.8481]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.22it/s, test_loss=1.1759]



Epoch 63/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.03it/s, train_loss=0.8484]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.75it/s, test_loss=1.1364]



Epoch 64/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.05it/s, train_loss=0.8207]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.61it/s, test_loss=1.1286]



Epoch 65/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.07it/s, train_loss=0.8257]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.36it/s, test_loss=1.2142]



Epoch 66/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.98it/s, train_loss=0.8153]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.33it/s, test_loss=1.1403]



Epoch 67/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.98it/s, train_loss=0.8249]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.36it/s, test_loss=1.1241]



Epoch 68/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.02it/s, train_loss=0.8490]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.12it/s, test_loss=1.1163]



Epoch 69/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.06it/s, train_loss=0.8057]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.54it/s, test_loss=1.1107]



Epoch 70/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.06it/s, train_loss=0.8370]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.83it/s, test_loss=1.0486]


New best model saved! Accuracy: 0.7799

Epoch 71/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.00it/s, train_loss=0.8134]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.05it/s, test_loss=1.0342]


Train Loss: 0.8134 | Train Accuracy: 0.8495
Test Loss: 1.0342 | Test Accuracy: 0.7683

Epoch 72/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.98it/s, train_loss=0.8471]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.11it/s, test_loss=1.0615]



Epoch 73/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.93it/s, train_loss=0.8101]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.19it/s, test_loss=1.2165]



Epoch 74/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.95it/s, train_loss=0.7856]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.24it/s, test_loss=0.9724]



Epoch 75/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.01it/s, train_loss=0.7664]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.70it/s, test_loss=1.0298]



Epoch 76/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.04it/s, train_loss=0.7921]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.51it/s, test_loss=0.9686]


New best model saved! Accuracy: 0.7954

Epoch 77/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.97it/s, train_loss=0.7733]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.77it/s, test_loss=1.0924]



Epoch 78/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.01it/s, train_loss=0.7471]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.13it/s, test_loss=1.1044]



Epoch 79/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.94it/s, train_loss=0.7984]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.25it/s, test_loss=1.1230]



Epoch 80/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.93it/s, train_loss=0.7554]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.17it/s, test_loss=1.1722]



Epoch 81/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.91it/s, train_loss=0.7573]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.28it/s, test_loss=1.3085]


Train Loss: 0.7573 | Train Accuracy: 0.8602
Test Loss: 1.3085 | Test Accuracy: 0.6718

Epoch 82/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.97it/s, train_loss=0.7962]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.65it/s, test_loss=1.0711]



Epoch 83/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.01it/s, train_loss=0.7669]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.46it/s, test_loss=1.1128]



Epoch 84/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.01it/s, train_loss=0.7647]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.86it/s, test_loss=1.1301]



Epoch 85/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.97it/s, train_loss=0.7817]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.20it/s, test_loss=1.0594]



Epoch 86/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.94it/s, train_loss=0.7557]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.13it/s, test_loss=1.1905]



Epoch 87/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.92it/s, train_loss=0.7516]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.25it/s, test_loss=0.9281]


New best model saved! Accuracy: 0.8108

Epoch 88/100


Training: 100%|██████████| 64/64 [00:22<00:00,  2.89it/s, train_loss=0.7175]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.17it/s, test_loss=1.0724]



Epoch 89/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.95it/s, train_loss=0.7138]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.75it/s, test_loss=1.1898]



Epoch 90/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.04it/s, train_loss=0.7455]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.46it/s, test_loss=1.0997]



Epoch 91/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.05it/s, train_loss=0.7327]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.87it/s, test_loss=1.0637]


Train Loss: 0.7327 | Train Accuracy: 0.8827
Test Loss: 1.0637 | Test Accuracy: 0.7529

Epoch 92/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.00it/s, train_loss=0.7883]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.23it/s, test_loss=1.0918]



Epoch 93/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.92it/s, train_loss=0.7482]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.16it/s, test_loss=0.9903]



Epoch 94/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.94it/s, train_loss=0.7352]
Testing: 100%|██████████| 17/17 [00:03<00:00,  4.26it/s, test_loss=1.0703]



Epoch 95/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.96it/s, train_loss=0.7263]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.06it/s, test_loss=1.0978]



Epoch 96/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.99it/s, train_loss=0.7517]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.59it/s, test_loss=1.0606]



Epoch 97/100


Training: 100%|██████████| 64/64 [00:20<00:00,  3.05it/s, train_loss=0.7211]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.53it/s, test_loss=1.0205]



Epoch 98/100


Training: 100%|██████████| 64/64 [00:21<00:00,  3.01it/s, train_loss=0.7429]
Testing: 100%|██████████| 17/17 [00:04<00:00,  3.94it/s, test_loss=1.1618]



Epoch 99/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.95it/s, train_loss=0.7500]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.20it/s, test_loss=1.1239]



Epoch 100/100


Training: 100%|██████████| 64/64 [00:21<00:00,  2.93it/s, train_loss=0.7460]
Testing: 100%|██████████| 17/17 [00:04<00:00,  4.16it/s, test_loss=1.1381]


Loaded best model from epoch 87
Best Test Metrics:
Test Loss: 0.9281
Test Accuracy: 0.8108

Training completed!
Best test accuracy achieved: 0.8108





"\ncheckpoint = torch.load('best_vit_model.pth')\nmodel.load_state_dict(checkpoint['model_state_dict'])\noptimizer.load_state_dict(checkpoint['optimizer_state_dict'])\nscheduler.load_state_dict(checkpoint['scheduler_state_dict'])\nbest_accuracy = checkpoint['best_accuracy']\n"