In [1]:
#!/usr/bin/env python
# coding: utf-8
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import vit_h_14, ViT_H_14_Weights
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import balanced_accuracy_score
import numpy as np
import random
import os

# === Parameters ===
# Every relative path below (data files, checkpoints) is resolved against
# `working_dir`, so the process-wide working directory is switched first.
# NOTE(review): this is a machine-specific absolute path; adjust it when the
# notebook is run elsewhere.
working_dir = "/data/lodhar2/milan"
os.chdir(working_dir)
TRAIN_PATH = "data/vit_train.npz"
VAL_PATH = "data/vit_val.npz"
BATCH_SIZE = 1
EPOCHS = 20
PATIENCE = 5  # epochs without a new best validation score before stopping
SEED = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# === Reproducibility ===
# Seed every RNG in use (Python, NumPy, PyTorch CPU + all CUDA devices) and
# ask cuDNN for deterministic kernels instead of auto-tuned ones.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# === Load preprocessed data ===
print("Loading data...")
train_data = np.load(TRAIN_PATH)
val_data = np.load(VAL_PATH)

# torch.from_numpy wraps the freshly loaded array without copying it, and
# .float() then converts to float32 only when the dtype differs.
# torch.tensor(ndarray) would always duplicate the whole image array first.
# NOTE(review): the arrays are assumed to already be laid out and normalised
# the way the model expects -- confirm against the preprocessing step.
X_train = torch.from_numpy(train_data["images"]).float()
X_val = torch.from_numpy(val_data["images"]).float()

# === Map string labels to integers ===
# An .npz archive re-reads a member each time it is indexed, so pull each
# label array out exactly once and reuse it.
train_labels = train_data["labels"]
val_labels = val_data["labels"]

# Class indices follow the sorted order of the distinct training labels.
unique_labels = np.unique(train_labels)
label_to_idx = {label: i for i, label in enumerate(unique_labels)}

# Fail early, with a readable message, if validation contains a class the
# training set never saw (it would have no index in `label_to_idx`).
unseen_val_labels = np.setdiff1d(val_labels, unique_labels)
if unseen_val_labels.size:
    raise ValueError(
        "Validation labels missing from the training set: "
        f"{unseen_val_labels.tolist()}"
    )

y_train = torch.tensor([label_to_idx[label] for label in train_labels]).long()
y_val = torch.tensor([label_to_idx[label] for label in val_labels]).long()

num_classes = len(torch.unique(y_train))
print(f"Detected {num_classes} classes: {unique_labels}.")

# Only the training loader is shuffled; validation order is kept fixed.
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=BATCH_SIZE, shuffle=False)

Using device: cuda
Loading data...
Detected 6 classes: ['Angiomyolipoma' 'Chromophobe' 'Clear_cell' 'Hybrid' 'Oncocytoma'
 'Papillary'].


In [2]:
# === Top-K accuracy ===
def top_k_accuracy(y_true, y_probs, k=2):
    """Fraction of samples whose true label is among the k top-scoring classes.

    Parameters
    ----------
    y_true : array-like of int, shape (n_samples,)
        True class indices.
    y_probs : array-like of float, shape (n_samples, n_classes)
        Per-class scores; only their ordering within each row matters.
    k : int, default 2
        Number of highest-scoring classes treated as a hit. If k is at least
        n_classes, every sample counts as correct.

    Returns
    -------
    float
        Value in [0, 1]; 0.0 when there are no samples.

    Raises
    ------
    ValueError
        If k < 1, or if y_true and y_probs disagree on the number of samples.
    """
    # A slice of `-0:` selects *every* column, so k=0 would silently score 1.0.
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    y_true = np.asarray(y_true)
    y_probs = np.asarray(y_probs)

    n_samples = len(y_true)
    if n_samples == 0:
        return 0.0
    if y_probs.shape[0] != n_samples:
        raise ValueError(
            f"y_true has {n_samples} samples but y_probs has {y_probs.shape[0]} rows"
        )

    # argsort is ascending, so the last k columns hold the k best classes.
    top_k_preds = np.argsort(y_probs, axis=1)[:, -k:]
    hits = (top_k_preds == y_true[:, None]).any(axis=1)
    return int(hits.sum()) / n_samples

# === Define model ===
# Builds a ViT-H/14 initialised from the IMAGENET1K_SWAG_E2E_V1 weights and
# swaps its classifier for one sized to this dataset's `num_classes`.
# NOTE(review): these weights are documented for 518x518 inputs -- confirm the
# preprocessed image arrays match that size.
print("Loading ViT-H-14...")
weights = ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1
model = vit_h_14(weights=weights)

# Replace head
# The input width is read from the last layer of the existing head, then the
# whole `heads` module is replaced by a single freshly initialised Linear
# layer. Checkpoints written from this model therefore store the classifier
# directly under `heads`.
in_features = model.heads[-1].in_features
model.heads = nn.Linear(in_features, num_classes)
model.to(DEVICE)

# === Loss and optimizer ===
# All model parameters are handed to the optimizer, i.e. the full network is
# fine-tuned rather than only the new head.
# NOTE(review): the logged run stays at chance-level balanced accuracy
# (0.1667 with 6 classes); lr=1e-4 for every parameter may be too aggressive
# -- to be confirmed experimentally.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
# mode='max' because the scheduler is stepped with a validation accuracy
# (higher is better); patience=2 is the number of non-improving epochs it
# tolerates before lowering the learning rate.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=2)

Loading ViT-H-14...


In [None]:
# === Training loop ===
# Fine-tunes `model` for at most EPOCHS epochs. After each epoch the model is
# scored on the validation loader; validation top-2 accuracy is the single
# metric that drives the LR scheduler, checkpointing and early stopping.
best_top2_acc = 0  # best validation top-2 accuracy seen so far
patience_counter = 0  # consecutive epochs without a new best top-2 accuracy

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0  # sum of the per-batch losses of this epoch
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # The reported training loss is the mean of the per-batch losses.
    print(f"\nEpoch {epoch+1}/{EPOCHS} - Training Loss: {running_loss / len(train_loader):.4f}")

    # === Validation ===
    # Collect class probabilities, argmax predictions and true labels for the
    # whole validation set with gradients disabled.
    model.eval()
    all_probs, all_preds, all_labels = [], [], []
    with torch.no_grad():
        for inputs, labels in val_loader:
            # Only the inputs go to the device; labels are consumed on the CPU.
            inputs = inputs.to(DEVICE)
            outputs = model(inputs)
            probs = torch.softmax(outputs, dim=1).cpu().numpy()
            preds = np.argmax(probs, axis=1)
            all_probs.append(probs)
            all_preds.extend(preds)
            all_labels.extend(labels.numpy())

    # Stack the per-batch probability arrays into one (samples, classes) array.
    all_probs = np.concatenate(all_probs, axis=0)
    all_labels = np.array(all_labels)
    bal_acc = balanced_accuracy_score(all_labels, all_preds)
    top2_acc = top_k_accuracy(all_labels, all_probs, k=2)

    print(f"Validation Balanced Accuracy: {bal_acc:.4f} | Top-2 Accuracy: {top2_acc:.4f}")

    # The plateau scheduler monitors the same metric used for model selection.
    scheduler.step(top2_acc)

    # Checkpoint only on a strict improvement; otherwise count towards early
    # stopping. The checkpoint path is relative to the current working
    # directory. Nothing in this cell reloads the checkpoint, so once the loop
    # ends `model` holds the last epoch's weights, not necessarily the best.
    if top2_acc > best_top2_acc:
        best_top2_acc = top2_acc
        patience_counter = 0
        torch.save(model.state_dict(), "best_vit.pt")
        print("New best model saved.")
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print("Early stopping triggered.")
            break


Epoch 1/20 - Training Loss: 2.3485
Validation Balanced Accuracy: 0.1667 | Top-2 Accuracy: 0.4091
New best model saved.

Epoch 2/20 - Training Loss: 1.9910
Validation Balanced Accuracy: 0.1771 | Top-2 Accuracy: 0.2727

Epoch 3/20 - Training Loss: 1.9517
Validation Balanced Accuracy: 0.1553 | Top-2 Accuracy: 0.1773

Epoch 4/20 - Training Loss: 1.9403
Validation Balanced Accuracy: 0.1595 | Top-2 Accuracy: 0.2818

Epoch 5/20 - Training Loss: 1.7836
Validation Balanced Accuracy: 0.1523 | Top-2 Accuracy: 0.3727

Epoch 6/20 - Training Loss: 1.7581
