# Vision Transformer (ViT) — Knee OA Classifier
*Generated on 2025-09-23 14:43:50*

Trains a **Vision Transformer (ViT_B_16)** on the Knee OA dataset with:
- Same augmentation & imbalance handling as earlier
- **Minimum 50 epochs** before early stopping
- **Percentage confusion matrix** (each row sums to 100%; diagonal shows per-class accuracy in %)


In [1]:
# ================== Setup & Config ==================
import os, random, time, math, json
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms, models
from PIL import Image

from sklearn.metrics import confusion_matrix, classification_report, roc_curve, precision_recall_curve
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import kagglehub

SEED = 42
def set_seed(seed=SEED):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

BATCH_SIZE, NUM_WORKERS = 32, 0
NUM_EPOCHS, MIN_EPOCHS, PATIENCE = 100, 20, 5
LR, WEIGHT_DECAY = 3e-4, 1e-4
PRETRAINED, USE_AMP = True, True
CLASSES = ['0','1','2','3','4']; N_CLASSES=len(CLASSES); IMG_SIZE=224


Using device: cuda


In [2]:
# ================== Data ==================
# path = kagglehub.dataset_download("shashwatwork/knee-osteoarthritis-dataset-with-severity")
# print("Dataset path:", path)
path = 'Data'

DATA_DIR = Path(path); TRAIN_DIR = DATA_DIR/'train'; TEST_DIR = DATA_DIR/'test'

def load_dataset_as_dataframe(subdir: str, max_images=None):
    root = TRAIN_DIR.parent / subdir; rows=[]
    for cls in CLASSES:
        cdir = root / cls
        if not cdir.exists(): continue
        files = sorted([p for p in cdir.iterdir() if p.is_file() and p.suffix.lower() in {'.png','.jpg','.jpeg'}])
        if max_images: files = files[:max_images]
        for p in files: rows.append({'path': str(p), 'label': int(cls)})
    return pd.DataFrame(rows)

train_df = load_dataset_as_dataframe('train'); test_df = load_dataset_as_dataframe('test')
tr_df, val_df = train_test_split(train_df, test_size=0.2, stratify=train_df['label'], random_state=SEED)
print('Splits:', tr_df.shape, val_df.shape, test_df.shape)

class_counts = tr_df['label'].value_counts().sort_index(); total = class_counts.sum()
class_weights = (total / (len(CLASSES) * class_counts)).to_numpy()
sample_weights = tr_df['label'].map({i:w for i,w in enumerate(class_weights)}).to_numpy()
print('Class counts:', class_counts.to_dict())
print('Class weights:', class_weights)


Splits: (4622, 2) (1156, 2) (1656, 2)
Class counts: {0: 1829, 1: 837, 2: 1213, 3: 605, 4: 138}
Class weights: [0.50541279 1.10442055 0.76207749 1.52793388 6.69855072]


In [3]:
# ================== Transforms & Dataset ==================
train_tfms = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.02),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
eval_tfms = transforms.Compose([
    transforms.Resize(int(IMG_SIZE*1.15)),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

class KneeDataset(Dataset):
    def __init__(self, df, transform):
        self.df = df.reset_index(drop=True); self.transform = transform
    def __len__(self): return len(self.df)
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = Image.open(row['path']).convert('RGB')
        img = self.transform(img) if self.transform else img
        return img, int(row['label'])

train_ds = KneeDataset(tr_df, train_tfms)
val_ds   = KneeDataset(val_df, eval_tfms)
test_ds  = KneeDataset(test_df, eval_tfms)

sampler = torch.utils.data.WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler, num_workers=NUM_WORKERS, pin_memory=True)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)


In [4]:
# ================== ViT Model ==================
def build_vit_b16(num_classes=len(CLASSES), pretrained=True):
    vit = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT if pretrained else None)
    vit.heads.head = torch.nn.Linear(vit.heads.head.in_features, num_classes)
    return vit

model = build_vit_b16(pretrained=PRETRAINED).to(device)

criterion = torch.nn.CrossEntropyLoss(weight=torch.tensor(class_weights, dtype=torch.float32, device=device))
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)
scaler = torch.cuda.amp.GradScaler(enabled=(device.type=='cuda') and USE_AMP)


In [5]:
# ================== Train / Eval Utils ==================
def accuracy_from_logits(logits, targets):
    preds = torch.argmax(logits, dim=1)
    return (preds == targets).float().mean().item()

@torch.no_grad()
def evaluate(model, loader):
    model.eval(); all_logits=[]; all_targets=[]; total=0.0
    for imgs,labels in loader:
        imgs,labels = imgs.to(device), labels.to(device)
        logits = model(imgs)
        loss = criterion(logits, labels)
        total += loss.item()*imgs.size(0)
        all_logits.append(logits.detach().cpu()); all_targets.append(labels.detach().cpu())
    all_logits = torch.cat(all_logits); all_targets = torch.cat(all_targets)
    acc = accuracy_from_logits(all_logits, all_targets)
    return total/len(loader.dataset), acc, all_logits, all_targets

def train_one_epoch(model, loader, optimizer, scaler):
    model.train(); total_loss=0.0; total_acc=0.0
    for imgs,labels in tqdm(loader, leave=False):
        imgs,labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=(device.type=='cuda') and USE_AMP):
            logits = model(imgs); loss = criterion(logits, labels)
        scaler.scale(loss).backward(); scaler.step(optimizer); scaler.update()
        total_loss += loss.item()*imgs.size(0)
        total_acc  += accuracy_from_logits(logits.detach(), labels)*imgs.size(0)
    n=len(loader.dataset); return total_loss/n, total_acc/n


In [6]:
# ================== Training Loop (Min 50 epochs) ==================
best_val_acc=-1.0; history={'train_loss':[], 'train_acc':[], 'val_loss':[], 'val_acc':[], 'lr':[]}
ckpt_path='best_vit_b16.pt'; patience_used=0

for epoch in range(1, NUM_EPOCHS+1):
    t0=time.time()
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, scaler)
    val_loss, val_acc, _, _ = evaluate(model, val_loader)
    scheduler.step(val_acc)

    history['train_loss'].append(train_loss); history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss); history['val_acc'].append(val_acc)
    history['lr'].append(optimizer.param_groups[0]['lr'])

    print(f"Epoch {epoch:02d} | train_loss={train_loss:.4f} acc={train_acc:.4f} | "
          f"val_loss={val_loss:.4f} acc={val_acc:.4f} | lr={optimizer.param_groups[0]['lr']:.2e}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({'model_state_dict': model.state_dict(), 'classes': CLASSES}, ckpt_path)
        patience_used = 0
    else:
        patience_used += 1
    if epoch >= MIN_EPOCHS and patience_used >= PATIENCE:
        print(f"Early stopping at epoch {epoch}."); break
print('Best val acc:', best_val_acc)


                                                 

Epoch 01 | train_loss=1.2206 acc=0.1967 | val_loss=2.1066 acc=0.0303 | lr=3.00e-04


                                                 

Epoch 02 | train_loss=1.1700 acc=0.2114 | val_loss=2.1668 acc=0.0303 | lr=3.00e-04


                                                 

Epoch 03 | train_loss=1.1923 acc=0.1984 | val_loss=2.1470 acc=0.0303 | lr=3.00e-04


                                                 

Epoch 00004: reducing learning rate of group 0 to 1.5000e-04.
Epoch 04 | train_loss=1.1733 acc=0.1997 | val_loss=2.0330 acc=0.0303 | lr=1.50e-04


                                                 

Epoch 05 | train_loss=1.1694 acc=0.2042 | val_loss=2.1781 acc=0.0303 | lr=1.50e-04


                                                 

Epoch 06 | train_loss=1.2102 acc=0.1876 | val_loss=2.1157 acc=0.0303 | lr=1.50e-04


                                                  

KeyboardInterrupt: 

In [None]:
# ================== Curves ==================
def plot_curves(history, title='ViT_B_16'):
    fig,ax = plt.subplots(figsize=(6,4))
    ax.plot(history['train_loss'], label='train_loss')
    ax.plot(history['val_loss'], label='val_loss')
    ax.set_title(f'{title} — Loss'); ax.set_xlabel('Epoch'); ax.set_ylabel('Loss'); ax.grid(alpha=0.3); ax.legend(); plt.show()

    fig,ax = plt.subplots(figsize=(6,4))
    ax.plot(history['train_acc'], label='train_acc')
    ax.plot(history['val_acc'], label='val_acc')
    ax.set_title(f'{title} — Accuracy'); ax.set_xlabel('Epoch'); ax.set_ylabel('Accuracy'); ax.grid(alpha=0.3); ax.legend(); plt.show()

plot_curves(history)


In [None]:
# ================== Percentage Confusion Matrix ==================
@torch.no_grad()
def percent_confusion_matrix(logits, targets, classes=CLASSES):
    y_true = targets.numpy()
    y_pred = torch.argmax(logits,1).numpy()
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes)))).astype(np.float64)
    row_sums = cm.sum(axis=1, keepdims=True); row_sums[row_sums==0]=1.0
    cm_pct = (cm/row_sums)*100.0
    df = pd.DataFrame(np.round(cm_pct,2), index=[f'True {c}' for c in classes],
                      columns=[f'Pred {c}' for c in classes])
    diag = np.diag(cm_pct); support = cm.sum(axis=1).astype(int)
    summary = pd.DataFrame({'Class': classes, 'Per-Class Acc (%)': np.round(diag,2), 'Support': support})
    return df, summary

ckpt = torch.load('best_vit_b16.pt', map_location=device)
model.load_state_dict(ckpt['model_state_dict'])

val_loss, val_acc, val_logits, val_targets = evaluate(model, val_loader)
test_loss, test_acc, test_logits, test_targets = evaluate(model, test_loader)

print(f"VAL  loss={val_loss:.4f} acc={val_acc:.4f}")
val_df, val_summary = percent_confusion_matrix(val_logits, val_targets); display(val_df); display(val_summary)

print(f"TEST loss={test_loss:.4f} acc={test_acc:.4f}")
test_df, test_summary = percent_confusion_matrix(test_logits, test_targets); display(test_df); display(test_summary)

val_df.to_csv('vit_val_percent_confusion.csv'); test_df.to_csv('vit_test_percent_confusion.csv')
val_summary.to_csv('vit_val_per_class_accuracy.csv', index=False)
test_summary.to_csv('vit_test_per_class_accuracy.csv', index=False)
print('Saved CSVs.')
