# Setup

In [1]:
import os
import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split, Subset
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from torch.cuda.amp import GradScaler, autocast
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.nn.functional import softmax
import random

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data Loader

In [2]:
# Transforms
transform_inception = transforms.Compose([
    transforms.Resize(320),
    transforms.RandomResizedCrop(299),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

transform_vit = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [3]:
# Paths
data_dir = './cassava'
train_dir = os.path.join(data_dir, 'train')

# Full datasets
full_dataset_incep = datasets.ImageFolder(train_dir, transform=transform_inception)
full_dataset_vit = datasets.ImageFolder(train_dir, transform=transform_vit)

# Shared split
dataset_size = len(full_dataset_incep)
train_size = int(0.8 * dataset_size)
val_size = dataset_size - train_size
indices = torch.randperm(dataset_size)
train_indices, val_indices = indices[:train_size], indices[train_size:]

# Subsets
train_dataset_incep = Subset(full_dataset_incep, train_indices)
val_dataset_incep = Subset(full_dataset_incep, val_indices)

train_dataset_vit = Subset(full_dataset_vit, train_indices)
val_dataset_vit = Subset(full_dataset_vit, val_indices)

# Dataloaders
train_loader_incep = DataLoader(train_dataset_incep, batch_size=32, shuffle=True)
val_loader_incep = DataLoader(val_dataset_incep, batch_size=32, shuffle=False)

train_loader_vit = DataLoader(train_dataset_vit, batch_size=32, shuffle=True)
val_loader_vit = DataLoader(val_dataset_vit, batch_size=32, shuffle=False)

In [6]:
# Class weights
labels = [y for _, y in full_dataset_incep]
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(labels), y=labels)
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

# Number of classes
num_classes = len(full_dataset_incep.classes)

# Inception V3

In [7]:
# InceptionV3
model_incep = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
model_incep.fc = nn.Linear(model_incep.fc.in_features, num_classes)
model_incep.to(device)

criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer_incep = torch.optim.Adam(model_incep.parameters(), lr=0.001)
scheduler_incep = torch.optim.lr_scheduler.StepLR(optimizer_incep, step_size=10, gamma=0.1)
scaler_incep = GradScaler()

best_val_acc_incep = 0

for epoch in range(30):
    model_incep.train()
    train_loss = 0
    for images, labels in train_loader_incep:
        images, labels = images.to(device), labels.to(device)
        optimizer_incep.zero_grad()
        with autocast():
            outputs = model_incep(images)
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            loss = criterion(outputs, labels)
        scaler_incep.scale(loss).backward()
        scaler_incep.step(optimizer_incep)
        scaler_incep.update()
        train_loss += loss.item() * images.size(0)
    train_loss /= len(train_loader_incep.dataset)

    # Validation
    model_incep.eval()
    val_preds, val_labels = [], []
    with torch.no_grad():
        for images, labels in val_loader_incep:
            images, labels = images.to(device), labels.to(device)
            outputs = model_incep(images)
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            preds = torch.argmax(outputs, dim=1)
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())

    val_acc = accuracy_score(val_labels, val_preds)

    if val_acc > best_val_acc_incep:
        best_val_acc_incep = val_acc
        torch.save(model_incep.state_dict(), 'best_model_inception.pth')

    scheduler_incep.step()

    print(f"[Inception] Epoch {epoch+1}/30 | Train Loss: {train_loss:.4f} | Val Acc: {val_acc:.4f}")

[Inception] Epoch 1/30 | Train Loss: 1.4182 | Val Acc: 0.5194
[Inception] Epoch 2/30 | Train Loss: 1.1607 | Val Acc: 0.6634
[Inception] Epoch 3/30 | Train Loss: 1.0649 | Val Acc: 0.5186
[Inception] Epoch 4/30 | Train Loss: 0.9968 | Val Acc: 0.6784
[Inception] Epoch 5/30 | Train Loss: 0.9211 | Val Acc: 0.6475
[Inception] Epoch 6/30 | Train Loss: 0.8856 | Val Acc: 0.6528
[Inception] Epoch 7/30 | Train Loss: 0.8826 | Val Acc: 0.6201
[Inception] Epoch 8/30 | Train Loss: 0.8667 | Val Acc: 0.6846
[Inception] Epoch 9/30 | Train Loss: 0.8467 | Val Acc: 0.5424
[Inception] Epoch 10/30 | Train Loss: 0.8135 | Val Acc: 0.7898
[Inception] Epoch 11/30 | Train Loss: 0.7194 | Val Acc: 0.7871
[Inception] Epoch 12/30 | Train Loss: 0.6914 | Val Acc: 0.7942
[Inception] Epoch 13/30 | Train Loss: 0.6507 | Val Acc: 0.8030
[Inception] Epoch 14/30 | Train Loss: 0.6213 | Val Acc: 0.8048
[Inception] Epoch 15/30 | Train Loss: 0.6210 | Val Acc: 0.8092
[Inception] Epoch 16/30 | Train Loss: 0.6271 | Val Acc: 0.7959
[

In [8]:
model_incep.load_state_dict(torch.load('best_model_inception.pth'))
model_incep.eval()

val_preds, val_labels = [], []
with torch.no_grad():
    for images, labels in val_loader_incep:
        images, labels = images.to(device), labels.to(device)
        outputs = model_incep(images)
        if isinstance(outputs, tuple):
            outputs = outputs[0]
        preds = torch.argmax(outputs, dim=1)
        val_preds.extend(preds.cpu().numpy())
        val_labels.extend(labels.cpu().numpy())

print("InceptionV3 Final Accuracy:", accuracy_score(val_labels, val_preds))
print(classification_report(val_labels, val_preds))
print(confusion_matrix(val_labels, val_preds))

InceptionV3 Final Accuracy: 0.818904593639576
              precision    recall  f1-score   support

           0       0.49      0.65      0.56        75
           1       0.85      0.79      0.82       285
           2       0.71      0.84      0.77       148
           3       0.95      0.84      0.89       564
           4       0.60      0.90      0.72        60

    accuracy                           0.82      1132
   macro avg       0.72      0.80      0.75      1132
weighted avg       0.84      0.82      0.83      1132

[[ 49   8   5   6   7]
 [ 34 226   7  10   8]
 [  3   6 124  10   5]
 [ 12  24  38 474  16]
 [  3   2   0   1  54]]


# ViT

In [9]:
model_vit = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)
model_vit.heads.head = nn.Linear(model_vit.heads.head.in_features, num_classes)
model_vit.to(device)

# Freeze backbone initially
for param in model_vit.parameters():
    param.requires_grad = False
for param in model_vit.heads.parameters():
    param.requires_grad = True

optimizer_vit = AdamW(model_vit.parameters(), lr=5e-5, weight_decay=0.01)
scheduler_vit = CosineAnnealingWarmRestarts(optimizer_vit, T_0=10, T_mult=2)
scaler_vit = GradScaler()
criterion_vit = nn.CrossEntropyLoss(weight=class_weights)

best_val_acc_vit = 0
freeze_epochs = 5

for epoch in range(30):
    if epoch == freeze_epochs:
        # Unfreeze backbone
        for param in model_vit.parameters():
            param.requires_grad = True
        optimizer_vit = AdamW(model_vit.parameters(), lr=5e-5, weight_decay=0.01)
        scheduler_vit = CosineAnnealingWarmRestarts(optimizer_vit, T_0=10, T_mult=2)
        print("Unfroze ViT backbone.")

    model_vit.train()
    train_loss = 0
    for images, labels in train_loader_vit:
        images, labels = images.to(device), labels.to(device)
        optimizer_vit.zero_grad()
        with autocast():
            outputs = model_vit(images)
            loss = criterion_vit(outputs, labels)
        scaler_vit.scale(loss).backward()
        scaler_vit.step(optimizer_vit)
        scaler_vit.update()
        train_loss += loss.item() * images.size(0)
    train_loss /= len(train_loader_vit.dataset)

    # Validation
    model_vit.eval()
    val_preds, val_labels = [], []
    with torch.no_grad():
        for images, labels in val_loader_vit:
            images, labels = images.to(device), labels.to(device)
            outputs = model_vit(images)
            preds = torch.argmax(outputs, dim=1)
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())

    val_acc = accuracy_score(val_labels, val_preds)

    if val_acc > best_val_acc_vit:
        best_val_acc_vit = val_acc
        torch.save(model_vit.state_dict(), 'best_model_vit.pth')

    scheduler_vit.step()

    print(f"[ViT] Epoch {epoch+1}/30 | Train Loss: {train_loss:.4f} | Val Acc: {val_acc:.4f}")

[ViT] Epoch 1/30 | Train Loss: 1.5192 | Val Acc: 0.4611
[ViT] Epoch 2/30 | Train Loss: 1.3745 | Val Acc: 0.5186
[ViT] Epoch 3/30 | Train Loss: 1.2898 | Val Acc: 0.5583
[ViT] Epoch 4/30 | Train Loss: 1.2361 | Val Acc: 0.5592
[ViT] Epoch 5/30 | Train Loss: 1.1981 | Val Acc: 0.5707
Unfroze ViT backbone.
[ViT] Epoch 6/30 | Train Loss: 0.8763 | Val Acc: 0.7906
[ViT] Epoch 7/30 | Train Loss: 0.6854 | Val Acc: 0.7279
[ViT] Epoch 8/30 | Train Loss: 0.5932 | Val Acc: 0.7968
[ViT] Epoch 9/30 | Train Loss: 0.5304 | Val Acc: 0.8074
[ViT] Epoch 10/30 | Train Loss: 0.4763 | Val Acc: 0.8242
[ViT] Epoch 11/30 | Train Loss: 0.4235 | Val Acc: 0.8534
[ViT] Epoch 12/30 | Train Loss: 0.4078 | Val Acc: 0.8366
[ViT] Epoch 13/30 | Train Loss: 0.3597 | Val Acc: 0.8525
[ViT] Epoch 14/30 | Train Loss: 0.3132 | Val Acc: 0.8428
[ViT] Epoch 15/30 | Train Loss: 0.2917 | Val Acc: 0.8454
[ViT] Epoch 16/30 | Train Loss: 0.4767 | Val Acc: 0.8366
[ViT] Epoch 17/30 | Train Loss: 0.4490 | Val Acc: 0.7818
[ViT] Epoch 18/30 

In [10]:
model_vit.load_state_dict(torch.load('best_model_vit.pth'))
model_vit.eval()

val_preds_vit, val_labels_vit = [], []
with torch.no_grad():
    for images, labels in val_loader_vit:
        images, labels = images.to(device), labels.to(device)
        outputs = model_vit(images)
        preds = torch.argmax(outputs, dim=1)
        val_preds_vit.extend(preds.cpu().numpy())
        val_labels_vit.extend(labels.cpu().numpy())

print("ViT Final Accuracy:", accuracy_score(val_labels_vit, val_preds_vit))
print(classification_report(val_labels_vit, val_preds_vit))
print(confusion_matrix(val_labels_vit, val_preds_vit))

ViT Final Accuracy: 0.8498233215547704
              precision    recall  f1-score   support

           0       0.49      0.79      0.60        75
           1       0.87      0.78      0.82       285
           2       0.82      0.74      0.78       148
           3       0.94      0.91      0.93       564
           4       0.75      0.93      0.83        60

    accuracy                           0.85      1132
   macro avg       0.77      0.83      0.79      1132
weighted avg       0.87      0.85      0.85      1132

[[ 59   8   5   1   2]
 [ 36 223   5  16   5]
 [ 11   9 109  13   6]
 [ 15  14  14 515   6]
 [  0   2   0   2  56]]


# Ensemble

In [11]:
# Load best models
model_incep.load_state_dict(torch.load('best_model_inception.pth'))
model_incep.eval()
model_vit.load_state_dict(torch.load('best_model_vit.pth'))
model_vit.eval()

# Get softmax logits
inception_logits = []
vit_logits = []
true_labels = []

# Inception predictions
with torch.no_grad():
    for images, labels in val_loader_incep:
        images = images.to(device)
        outputs = model_incep(images)
        if isinstance(outputs, tuple):
            outputs = outputs[0]
        inception_logits.append(softmax(outputs, dim=1).cpu())
        true_labels.extend(labels.cpu().numpy())

# ViT predictions
with torch.no_grad():
    for images, _ in val_loader_vit:
        images = images.to(device)
        outputs = model_vit(images)
        vit_logits.append(softmax(outputs, dim=1).cpu())

# Stack
inception_logits = torch.cat(inception_logits, dim=0)
vit_logits = torch.cat(vit_logits, dim=0)

# Weighted ensemble
ensemble_logits = 0.6 * inception_logits + 0.4 * vit_logits
ensemble_preds = torch.argmax(ensemble_logits, dim=1).numpy()

# Evaluation
print("Final Ensemble Accuracy:", accuracy_score(true_labels, ensemble_preds))
print(classification_report(true_labels, ensemble_preds))
print(confusion_matrix(true_labels, ensemble_preds))

Final Ensemble Accuracy: 0.8683745583038869
              precision    recall  f1-score   support

           0       0.56      0.76      0.64        75
           1       0.86      0.80      0.83       285
           2       0.88      0.80      0.84       148
           3       0.95      0.93      0.94       564
           4       0.70      0.93      0.80        60

    accuracy                           0.87      1132
   macro avg       0.79      0.84      0.81      1132
weighted avg       0.88      0.87      0.87      1132

[[ 57   9   4   2   3]
 [ 32 228   3  17   5]
 [  4  10 119   8   7]
 [  8  14  10 523   9]
 [  1   3   0   0  56]]
