<a href="https://www.kaggle.com/code/shobhiii/common-test-1-ml4sci?scriptVersionId=230354842" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Vision Transformer (ViT) Setup for Multiclass Image Classification  

This script imports essential libraries for loading datasets, preprocessing images, training a **ViT-based classifier**, and evaluating its performance using PyTorch and `transformers`.  


In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torchvision import datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from transformers import ViTForImageClassification, ViTFeatureExtractor
from torchvision.transforms import functional as F
from torchvision import transforms
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, confusion_matrix, precision_score, recall_score, f1_score
from itertools import cycle

In [2]:
def plot_roc_curves(fpr, tpr, roc_auc, class_names):
    plt.figure(figsize=(10, 8))
    
    # Plot ROC curve for each class
    colors = cycle(['blue', 'red', 'green'])
    for i, color in zip(range(len(class_names)), colors):
        plt.plot(
            fpr[i], 
            tpr[i], 
            color=color, 
            lw=2,
            label=f'{class_names[i]} (AUC = {roc_auc[i]:.3f})'
        )
    
    # Plot micro-average ROC curve
    plt.plot(
        fpr["micro"], 
        tpr["micro"],
        color='deeppink', 
        linestyle=':', 
        linewidth=4,
        label=f'Micro-average (AUC = {roc_auc["micro"]:.3f})'
    )
    
    # Plot diagonal line (random classifier)
    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curves')
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.savefig('roc_curves.png', dpi=300)
    plt.close()

### Data Loading and Transformation  

Defines functions to load `.npy` image files, apply augmentations for training, and normalize images for evaluation. Creates PyTorch dataloaders for training, validation, and testing.  
 


In [3]:
def npy_loader(path):
    tensor = torch.from_numpy(np.load(path))

    # the tensor has shape [1,150,150] - single channel. so we will repeat the single channel to create 3 identical channels
    if tensor.dim() == 3 and tensor.shape[0] == 1:
        tensor = tensor.repeat(3, 1, 1)
    return tensor

In [4]:
# Transform pipeline for training data
train_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomResizedCrop(224, scale = (0.8, 1.0)), #random crops with some zoom
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(0.1),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def train_loader(path):
    tensor = npy_loader(path)
    return train_transforms(tensor)

In [5]:
# Transform pipeline for validation and test data
eval_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def eval_loader(path):
    tensor = npy_loader(path)
    return eval_transforms(tensor)

In [6]:
train_data_path = "/kaggle/input/common-test-1-dataset/dataset/train"
test_data_path = "/kaggle/input/common-test-1-dataset/dataset/val"

# Load the full training dataset
full_train_dataset = datasets.DatasetFolder(
    root=train_data_path,
    loader=train_loader,
    extensions=('.npy',)  
)

# Load the test dataset 
test_dataset = datasets.DatasetFolder(
    root=test_data_path,
    loader=eval_loader,
    extensions=('.npy',)
)

train_size = int(0.9 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(
    full_train_dataset, 
    [train_size, val_size]
)

In [7]:
# Create data loaders
train_dataloader = DataLoader(
    train_dataset, 
    batch_size=32, 
    shuffle=True, 
    num_workers=4, 
    pin_memory=True
)

val_dataloader = DataLoader(
    val_dataset,  
    batch_size=32, 
    shuffle=False,
    num_workers=4, 
    pin_memory=True
)

test_dataloader = DataLoader(
    test_dataset, 
    batch_size=32, 
    shuffle=False, 
    num_workers=4, 
    pin_memory=True
)

### Model Initialization  

Loads a pre-trained Vision Transformer (ViT) model and feature extractor, adapting it for a 3-class classification task.  


In [8]:
model_name = "google/vit-base-patch16-224"
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name,
                                                 num_labels = 3,
                                                 ignore_mismatched_sizes=True)

preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([3]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([3, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
print("Class-to-label mapping:", train_dataset.dataset.class_to_idx)

Class-to-label mapping: {'no': 0, 'sphere': 1, 'vort': 2}


### Fine-Tuning Setup  

Defines the loss function (**CrossEntropyLoss**), optimizer (**AdamW** with weight decay), and learning rate scheduler (**CosineAnnealingLR**). Moves the model to GPU if available.  


In [10]:
#  FINE TUNING(LOSS FUNCTION, OPTIMIZER AND SCHEDULER)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr = 5e-5, weight_decay = 0.05)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = 10)

### Evaluation Metrics  

Defines functions to compute **accuracy** and **ROC-AUC** for a 3-class classification task. The `compute_roc_auc` function calculates **one-vs-rest ROC curves** and micro-average AUC.  


In [11]:
def accuracy(outputs, labels):
    _, preds = torch.max(outputs.logits, dim = 1)
    return torch.sum(preds == labels).item() / len(labels)

# Function to compute ROC curve and AUC for each class
def compute_roc_auc(all_labels, all_logits):
    # convert to numpy array 
    all_labels = all_labels.cpu().numpy()
    all_logits = all_logits.cpu().numpy()

    n_classes = 3

    fpr = {}
    tpr = {}
    roc_auc = {}

    # for each class compute the ROC curve using one-vs-rest approach
    for i in range (n_classes):
        fpr[i], tpr[i], _ = roc_curve(all_labels == i, all_logits[:,i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # compute micro-average ROC curve and ROC area
    fpr["micro"], tpr["micro"], _ = roc_curve(
        np.eye(n_classes)[all_labels].ravel(), all_logits.ravel()
    )
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    return fpr, tpr, roc_auc

In [12]:
num_epochs = 15
best_val_acc = 0.0
best_val_metric = 0.0

### Training and Validation Loop  

Implements the **training and evaluation process** for fine-tuning ViT:  
- Performs forward and backward passes, updates weights, and tracks loss & accuracy.  
- Computes **ROC-AUC** for multi-class classification.  
- Saves the best model based on **micro-average AUC**.  
- Optionally plots ROC curves at the last epoch.  


In [13]:
for epoch in range(num_epochs):
    model.train
    train_loss = 0.0
    train_acc = 0.0

    train_bar = tqdm(train_dataloader, desc = f"Epoch {epoch + 1}/{num_epochs}[train]")
    for images, labels in train_bar:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(pixel_values = images, labels = labels)
        loss = outputs.loss

        loss.backward()
        optimizer.step()

        # TRACK STATISTICS
        train_loss += loss.item()*images.size(0)
        train_acc += accuracy(outputs, labels)*images.size(0)

        # Update progress bar
        train_bar.set_postfix({
            "loss":loss.item(),
            "LR":optimizer.param_groups[0]["lr"]
        })

    # Calculate average training statistics
    train_loss = train_loss/len(train_dataset)
    train_acc = train_acc/len(train_dataset)

    # Validation phase
    model.eval()
    val_loss = 0.0
    val_acc = 0.0
    all_labels = []
    all_logits = []
    
    with torch.no_grad():
        val_bar = tqdm(val_dataloader, desc = f"Epoch {epoch+1}/{num_epochs} [Validation]")
        for images, labels in val_bar:
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(pixel_values=images, labels = labels)
            loss = outputs.loss

            # Track Statistics
            val_loss += loss.item()*images.size(0)
            val_acc += accuracy(outputs, labels)*images.size(0)

            # Collect predictions and labels for ROC-AUC calculation
            all_labels.append(labels)
            all_logits.append(outputs.logits)
            
            # Update progress bar
            val_bar.set_postfix({
                "Loss":loss.item()
            })

    # Concatenate all batches
    all_labels = torch.cat(all_labels)
    all_logits = torch.cat(all_logits)

    # Compute ROC-AUC
    fpr, tpr, roc_auc = compute_roc_auc(all_labels, all_logits)
    
    # Calculate average validation statistics
    val_loss = val_loss/len(val_dataset)
    val_acc = val_acc/len(val_dataset)

    # Update learning rate
    scheduler.step()

    # Print epoch summary
    print(f"Epoch {epoch+1}/{num_epochs}: "
         f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
         f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # Print AUC for each class
    for i in range(3):
        class_name = train_dataset.dataset.classes[i]
        print(f"AUC for Class {class_name}: {roc_auc[i]:.4f}")

    # print micro-average AUC
    print(f"Micro-average AUC: {roc_auc['micro']:.4f}")
    
    # Save best model based on micro-average AUC instead of accuracy
    if roc_auc['micro'] > best_val_metric:
        best_val_metric = roc_auc['micro']
        model.save_pretrained("./best_vit_model")
        print(f"Model saved with best micro-average AUC: {best_val_metric:.4f}")

    # Optionally, plot and save ROC curves
    if epoch == num_epochs - 1:  # Plot on the last epoch
        plot_roc_curves(fpr, tpr, roc_auc, train_dataset.dataset.classes)   
print(f"Training complete! Best micro-average AUC: {best_val_metric:.4f}")

Epoch 1/15[train]: 100%|██████████| 844/844 [08:23<00:00,  1.68it/s, loss=0.651, LR=5e-5]
Epoch 1/15 [Validation]: 100%|██████████| 94/94 [00:18<00:00,  5.05it/s, Loss=0.683]


Epoch 1/15: Train Loss: 0.9215, Train Acc: 0.5354, Val Loss: 0.7304, Val Acc: 0.6727
AUC for Class no: 0.8932
AUC for Class sphere: 0.7429
AUC for Class vort: 0.8155
Micro-average AUC: 0.8269
Model saved with best micro-average AUC: 0.8269


Epoch 2/15[train]: 100%|██████████| 844/844 [08:23<00:00,  1.68it/s, loss=0.41, LR=4.88e-5]
Epoch 2/15 [Validation]: 100%|██████████| 94/94 [00:18<00:00,  5.10it/s, Loss=0.772]


Epoch 2/15: Train Loss: 0.6728, Train Acc: 0.7030, Val Loss: 0.6035, Val Acc: 0.7447
AUC for Class no: 0.9288
AUC for Class sphere: 0.8263
AUC for Class vort: 0.8972
Micro-average AUC: 0.8691
Model saved with best micro-average AUC: 0.8691


Epoch 3/15[train]: 100%|██████████| 844/844 [08:23<00:00,  1.68it/s, loss=0.489, LR=4.52e-5]
Epoch 3/15 [Validation]: 100%|██████████| 94/94 [00:18<00:00,  5.10it/s, Loss=0.466]


Epoch 3/15: Train Loss: 0.5543, Train Acc: 0.7678, Val Loss: 0.5125, Val Acc: 0.7980
AUC for Class no: 0.9466
AUC for Class sphere: 0.8608
AUC for Class vort: 0.9109
Micro-average AUC: 0.9135
Model saved with best micro-average AUC: 0.9135


Epoch 4/15[train]: 100%|██████████| 844/844 [08:23<00:00,  1.67it/s, loss=0.303, LR=3.97e-5]
Epoch 4/15 [Validation]: 100%|██████████| 94/94 [00:18<00:00,  5.09it/s, Loss=0.424]


Epoch 4/15: Train Loss: 0.4753, Train Acc: 0.8078, Val Loss: 0.4793, Val Acc: 0.8150
AUC for Class no: 0.9541
AUC for Class sphere: 0.8746
AUC for Class vort: 0.9353
Micro-average AUC: 0.9180
Model saved with best micro-average AUC: 0.9180


Epoch 5/15[train]: 100%|██████████| 844/844 [08:24<00:00,  1.67it/s, loss=0.156, LR=3.27e-5]
Epoch 5/15 [Validation]: 100%|██████████| 94/94 [00:18<00:00,  5.07it/s, Loss=0.314]


Epoch 5/15: Train Loss: 0.4165, Train Acc: 0.8374, Val Loss: 0.3912, Val Acc: 0.8507
AUC for Class no: 0.9644
AUC for Class sphere: 0.9031
AUC for Class vort: 0.9466
Micro-average AUC: 0.9406
Model saved with best micro-average AUC: 0.9406


Epoch 6/15[train]: 100%|██████████| 844/844 [08:24<00:00,  1.67it/s, loss=0.165, LR=2.5e-5]
Epoch 6/15 [Validation]: 100%|██████████| 94/94 [00:18<00:00,  5.09it/s, Loss=0.487]


Epoch 6/15: Train Loss: 0.3686, Train Acc: 0.8531, Val Loss: 0.3686, Val Acc: 0.8563
AUC for Class no: 0.9695
AUC for Class sphere: 0.9345
AUC for Class vort: 0.9476
Micro-average AUC: 0.9480
Model saved with best micro-average AUC: 0.9480


Epoch 7/15[train]: 100%|██████████| 844/844 [08:24<00:00,  1.67it/s, loss=0.364, LR=1.73e-5]
Epoch 7/15 [Validation]: 100%|██████████| 94/94 [00:18<00:00,  5.09it/s, Loss=0.223]


Epoch 7/15: Train Loss: 0.3307, Train Acc: 0.8718, Val Loss: 0.3195, Val Acc: 0.8800
AUC for Class no: 0.9747
AUC for Class sphere: 0.9331
AUC for Class vort: 0.9659
Micro-average AUC: 0.9601
Model saved with best micro-average AUC: 0.9601


Epoch 8/15[train]: 100%|██████████| 844/844 [08:23<00:00,  1.68it/s, loss=0.42, LR=1.03e-5]
Epoch 8/15 [Validation]: 100%|██████████| 94/94 [00:18<00:00,  5.10it/s, Loss=0.224]


Epoch 8/15: Train Loss: 0.2897, Train Acc: 0.8907, Val Loss: 0.2927, Val Acc: 0.8850
AUC for Class no: 0.9787
AUC for Class sphere: 0.9398
AUC for Class vort: 0.9691
Micro-average AUC: 0.9636
Model saved with best micro-average AUC: 0.9636


Epoch 9/15[train]: 100%|██████████| 844/844 [08:23<00:00,  1.68it/s, loss=0.364, LR=4.77e-6]
Epoch 9/15 [Validation]: 100%|██████████| 94/94 [00:18<00:00,  5.07it/s, Loss=0.163]


Epoch 9/15: Train Loss: 0.2678, Train Acc: 0.8967, Val Loss: 0.2962, Val Acc: 0.8897
AUC for Class no: 0.9771
AUC for Class sphere: 0.9434
AUC for Class vort: 0.9763
Micro-average AUC: 0.9640
Model saved with best micro-average AUC: 0.9640


Epoch 10/15[train]: 100%|██████████| 844/844 [08:23<00:00,  1.68it/s, loss=0.264, LR=1.22e-6]
Epoch 10/15 [Validation]: 100%|██████████| 94/94 [00:18<00:00,  5.08it/s, Loss=0.239]


Epoch 10/15: Train Loss: 0.2534, Train Acc: 0.9030, Val Loss: 0.2603, Val Acc: 0.8987
AUC for Class no: 0.9816
AUC for Class sphere: 0.9462
AUC for Class vort: 0.9787
Micro-average AUC: 0.9705
Model saved with best micro-average AUC: 0.9705


Epoch 11/15[train]: 100%|██████████| 844/844 [08:23<00:00,  1.67it/s, loss=0.41, LR=0]
Epoch 11/15 [Validation]: 100%|██████████| 94/94 [00:18<00:00,  5.09it/s, Loss=0.23]


Epoch 11/15: Train Loss: 0.2475, Train Acc: 0.9071, Val Loss: 0.2715, Val Acc: 0.8960
AUC for Class no: 0.9797
AUC for Class sphere: 0.9453
AUC for Class vort: 0.9765
Micro-average AUC: 0.9692


Epoch 12/15[train]: 100%|██████████| 844/844 [08:23<00:00,  1.67it/s, loss=0.0427, LR=1.22e-6]
Epoch 12/15 [Validation]: 100%|██████████| 94/94 [00:18<00:00,  5.09it/s, Loss=0.337]


Epoch 12/15: Train Loss: 0.2493, Train Acc: 0.9056, Val Loss: 0.2762, Val Acc: 0.8947
AUC for Class no: 0.9799
AUC for Class sphere: 0.9462
AUC for Class vort: 0.9768
Micro-average AUC: 0.9676


Epoch 13/15[train]: 100%|██████████| 844/844 [08:24<00:00,  1.67it/s, loss=0.413, LR=4.77e-6]
Epoch 13/15 [Validation]: 100%|██████████| 94/94 [00:18<00:00,  5.10it/s, Loss=0.239]


Epoch 13/15: Train Loss: 0.2543, Train Acc: 0.9041, Val Loss: 0.2686, Val Acc: 0.8960
AUC for Class no: 0.9813
AUC for Class sphere: 0.9531
AUC for Class vort: 0.9753
Micro-average AUC: 0.9717
Model saved with best micro-average AUC: 0.9717


Epoch 14/15[train]: 100%|██████████| 844/844 [08:24<00:00,  1.67it/s, loss=0.554, LR=1.03e-5]
Epoch 14/15 [Validation]: 100%|██████████| 94/94 [00:18<00:00,  5.07it/s, Loss=0.162]


Epoch 14/15: Train Loss: 0.2644, Train Acc: 0.9009, Val Loss: 0.2940, Val Acc: 0.8887
AUC for Class no: 0.9800
AUC for Class sphere: 0.9391
AUC for Class vort: 0.9725
Micro-average AUC: 0.9670


Epoch 15/15[train]: 100%|██████████| 844/844 [08:24<00:00,  1.67it/s, loss=0.443, LR=1.73e-5]
Epoch 15/15 [Validation]: 100%|██████████| 94/94 [00:18<00:00,  5.06it/s, Loss=0.375]


Epoch 15/15: Train Loss: 0.2819, Train Acc: 0.8934, Val Loss: 0.2961, Val Acc: 0.8863
AUC for Class no: 0.9767
AUC for Class sphere: 0.9463
AUC for Class vort: 0.9692
Micro-average AUC: 0.9648
Training complete! Best micro-average AUC: 0.9717


In [14]:
# Loading the best saved model
model_path = "./best_vit_model"
model = ViTForImageClassification.from_pretrained(model_path)
model.to(device)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

### Model Evaluation on Test Set  

Defines a function to evaluate the model on the test dataset by computing **loss, accuracy, and predictions**. Tracks performance metrics and updates a progress bar.  


In [15]:
# Evaluation on test set
def evaluate_model(model, dataloader, criterion, device):
    model.eval()
    test_loss = 0.0
    test_acc = 0.0
    all_labels = []
    all_logits = []
    all_preds = []
    
    with torch.no_grad():
        test_bar = tqdm(dataloader, desc="Testing")
        for images, labels in test_bar:
            images = images.to(device)
            labels = labels.to(device)
            
            # Forward pass
            outputs = model(pixel_values=images, labels=labels)
            loss = outputs.loss
            
            # Get predictions
            _, preds = torch.max(outputs.logits, dim=1)
            
            # Track statistics
            test_loss += loss.item() * images.size(0)
            test_acc += torch.sum(preds == labels).item()
            
            # Collect for metrics calculation
            all_labels.append(labels)
            all_logits.append(outputs.logits)
            all_preds.append(preds)
            
            # Update progress bar
            test_bar.set_postfix({
                "Loss": loss.item()
            })
    
    # Calculate final metrics
    test_loss = test_loss / len(dataloader.dataset)
    test_acc = test_acc / len(dataloader.dataset)
    
    # Concatenate all batches
    all_labels = torch.cat(all_labels)
    all_logits = torch.cat(all_logits)
    all_preds = torch.cat(all_preds)
    
    return test_loss, test_acc, all_labels, all_logits, all_preds


In [16]:
# Run evaluation
test_loss, test_acc, all_labels, all_logits, all_preds = evaluate_model(
    model, test_dataloader, criterion, device
)

Testing: 100%|██████████| 235/235 [00:45<00:00,  5.20it/s, Loss=0.579]


### Test Set Performance Metrics  

Computes and prints key evaluation metrics:  
- **ROC-AUC** for each class and micro-average.  
- **Confusion matrix** for classification results.  
- **Precision, Recall, and F1-score** for each class, including weighted F1-score.  


In [17]:
# Compute ROC-AUC
fpr, tpr, roc_auc = compute_roc_auc(all_labels, all_logits)

# Calculate confusion matrix
conf_matrix = confusion_matrix(all_labels.cpu().numpy(), all_preds.cpu().numpy())

# Calculate precision, recall, and F1 score
precision = precision_score(all_labels.cpu().numpy(), all_preds.cpu().numpy(), average=None)
recall = recall_score(all_labels.cpu().numpy(), all_preds.cpu().numpy(), average=None)
f1 = f1_score(all_labels.cpu().numpy(), all_preds.cpu().numpy(), average=None)
weighted_f1 = f1_score(all_labels.cpu().numpy(), all_preds.cpu().numpy(), average='weighted')

# Print results
print("\n--- Test Set Evaluation Results ---")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")
print(f"Weighted F1 Score: {weighted_f1:.4f}")

# Print AUC for each class
print("\n--- AUC Scores per Class ---")
for i in range(3):
    class_name = test_dataset.classes[i]
    print(f"Class {class_name}: {roc_auc[i]:.4f}")
print(f"Micro-average AUC: {roc_auc['micro']:.4f}")

# Print precision, recall, and F1 score for each class
print("\n--- Precision, Recall, and F1 Score per Class ---")
for i in range(3):
    class_name = test_dataset.classes[i]
    print(f"Class {class_name}: Precision={precision[i]:.4f}, Recall={recall[i]:.4f}, F1={f1[i]:.4f}")



--- Test Set Evaluation Results ---
Test Loss: 0.3073
Test Accuracy: 0.8880
Weighted F1 Score: 0.8864

--- AUC Scores per Class ---
Class no: 0.9839
Class sphere: 0.9494
Class vort: 0.9697
Micro-average AUC: 0.9703

--- Precision, Recall, and F1 Score per Class ---
Class no: Precision=0.8132, Recall=0.9980, F1=0.8962
Class sphere: Precision=0.9714, Recall=0.7596, F1=0.8525
Class vort: Precision=0.9148, Recall=0.9064, F1=0.9106


In [18]:
# Display confusion matrix
print("\n--- Confusion Matrix ---")
class_names = test_dataset.classes
df_cm = pd.DataFrame(conf_matrix, index=class_names, columns=class_names)
print(df_cm)

# Plot ROC curves
plt.figure(figsize=(10, 8))
for i in range(3):
    class_name = test_dataset.classes[i]
    plt.plot(fpr[i], tpr[i], 
             label=f'Class {class_name} (AUC = {roc_auc[i]:.4f})')

plt.plot(fpr['micro'], tpr['micro'],
         label=f'Micro-average (AUC = {roc_auc["micro"]:.4f})',
         linestyle=':', linewidth=4)

plt.plot([0, 1], [0, 1], 'k--', lw=2)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curves for Test Set')
plt.legend(loc="lower right")
plt.savefig('test_roc_curves.png')
plt.close()


--- Confusion Matrix ---
          no  sphere  vort
no      2495       3     2
sphere   392    1899   209
vort     181      53  2266


In [19]:
# # Save the model to Kaggle output directory
# output_path = "/kaggle/working/vit_model_final"
# model.save_pretrained(output_path)
# print(f"\nModel saved to {output_path}")