In [1]:
#!/usr/bin/env python
# coding: utf-8
import os
os.chdir("/data/lodhar2/GSViT")
import pandas as pd
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import Compose, CenterCrop, RandomCrop, Resize, RandomHorizontalFlip, RandomRotation, ToTensor
from sklearn.metrics import balanced_accuracy_score

from EfficientViT.classification.model.build import EfficientViT_M5



In [2]:
# ——————————————
# 1. channel-flip to match GSViT’s BGR convention
def process_inputs(images):
    """Exchange channels 0 and 2 of a batch of images (RGB <-> BGR).

    Args:
        images: tensor indexed as ``[batch, channel, height, width]`` with at
            least three channels.

    Returns:
        A new tensor with the same shape and dtype as ``images`` in which
        channel 0 and channel 2 have traded places. All other channels are
        copied unchanged.

    The input tensor is NOT modified. Swapping in place on the caller's tensor
    means that passing the same batch through this function twice silently
    undoes the swap, so the result is built on a copy instead.
    """
    swapped = images.clone()
    swapped[:, 0, :, :] = images[:, 2, :, :]
    swapped[:, 2, :, :] = images[:, 0, :, :]
    return swapped

In [3]:
# ——————————————
# 2. Model wrapper with new head, filtering only the `evit.` keys
class GSViTWithHead(nn.Module):
    """EfficientViT-M5 feature extractor initialised from a checkpoint, plus a linear head.

    The backbone is every child module of ``EfficientViT_M5`` except the last
    one (its original classifier). Checkpoint entries whose key starts with
    ``"evit."`` are loaded into it with that prefix removed. The flattened
    backbone output feeds a freshly initialised ``nn.Linear`` classifier.

    Args:
        num_classes: number of output logits of the new head.
        ckpt_path: path of the checkpoint file; defaults to ``"GSViT.pkl"``
            in the current working directory.

    Note:
        The head's input width is measured on a 3x224x224 input and the
        features are flattened, so ``forward`` expects inputs of that spatial
        size.
    """

    def __init__(self, num_classes, ckpt_path="GSViT.pkl"):
        super().__init__()
        import warnings  # local import: only needed for the load report below

        # build backbone and strip its old classifier (the last child module)
        backbone_pre = EfficientViT_M5(pretrained=None)
        backbone = nn.Sequential(*list(backbone_pre.children())[:-1])

        # load checkpoint & keep only the "evit." entries
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # obtained from a trusted source.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        prefix = "evit."
        filtered = {
            key[len(prefix):]: value
            for key, value in ckpt.items()
            if key.startswith(prefix)
        }

        # strict=False tolerates key mismatches, so inspect the report instead
        # of discarding it: a silent total mismatch would leave the backbone
        # at its random initialisation.
        report = backbone.load_state_dict(filtered, strict=False)
        if not filtered:
            warnings.warn(
                f"No '{prefix}*' entries found in {ckpt_path!r}; "
                "backbone keeps its initial weights.",
                RuntimeWarning,
            )
        elif report.missing_keys or report.unexpected_keys:
            warnings.warn(
                f"Partial checkpoint load from {ckpt_path!r}: "
                f"{len(report.missing_keys)} backbone key(s) missing from the checkpoint, "
                f"{len(report.unexpected_keys)} checkpoint key(s) unused.",
                RuntimeWarning,
            )
        self.backbone = backbone

        # figure out feature-dim on a dummy.
        # Run the probe in eval mode: a forward pass in training mode would
        # update any BatchNorm running statistics from the all-zero image.
        was_training = self.backbone.training
        self.backbone.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, 3, 224, 224)
            feat_dim = self.backbone(dummy).flatten(1).shape[1]
        self.backbone.train(was_training)

        # new classification head
        self.classifier = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        """Return class logits for a batch ``x`` indexed as [batch, 3, H, W].

        The batch is passed through ``process_inputs`` (channel swap) before
        the backbone; backbone features are flattened per sample and fed to
        the linear head.
        """
        x = process_inputs(x)
        # flatten(1) also copes with a non-contiguous backbone output
        feats = self.backbone(x).flatten(1)
        return self.classifier(feats)

In [4]:
# ——————————————
# 3. Dataset → yields (image_tensor, label_idx)
class HistologyDataset(Dataset):
    """Image dataset backed by a DataFrame with ``img_path`` and ``Class`` columns.

    Each item is ``(image, label_idx)`` where ``image`` is the RGB image at
    ``root_dir / img_path`` (after ``transform``, if one is given) and
    ``label_idx`` is the position of the row's ``Class`` value in
    ``self.classes``.

    Args:
        df: DataFrame with one row per image.
        root_dir: directory that ``img_path`` values are relative to.
        transform: optional callable applied to the PIL image.
        classes: optional explicit, ordered list of class names. When omitted
            the sorted unique values of ``df["Class"]`` are used. Pass the
            same list to every split so label indices agree even if a split
            lacks some class.

    Raises:
        ValueError: if ``classes`` is given and ``df`` contains a class that
            is not in it.
    """

    def __init__(self, df, root_dir, transform=None, classes=None):
        self.df = df.reset_index(drop=True)
        self.root = root_dir
        self.transform = transform

        present = sorted(self.df["Class"].unique())
        if classes is None:
            self.classes = present
        else:
            self.classes = list(classes)
            allowed = set(self.classes)
            unknown = [c for c in present if c not in allowed]
            if unknown:
                raise ValueError(
                    f"DataFrame contains classes missing from `classes`: {unknown}"
                )
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

    def __len__(self):
        """Number of rows (images) in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the ``idx``-th image as RGB and return it with its label index."""
        row = self.df.iloc[idx]
        path = os.path.join(self.root, row["img_path"])
        # The context manager closes the file handle promptly; convert()
        # yields a fully loaded image that stays valid after the file closes.
        with Image.open(path) as handle:
            img = handle.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        label = self.class_to_idx[row["Class"]]
        return img, label

In [5]:
# Load the split manifest: one row per image.
df = pd.read_csv("split_balanced_dataset.csv")

# Fail fast if a column used further down is absent; otherwise a missing
# "img_path" only surfaces later, inside a DataLoader worker.
_missing_cols = {"img_path", "Class", "split"} - set(df.columns)
if _missing_cols:
    raise KeyError(
        f"split_balanced_dataset.csv is missing required column(s): {sorted(_missing_cols)}"
    )

# Optional experiment (disabled): collapse labels to "Clear_cell" vs "Other".
#df['Class'] = np.where(df['Class'] == "Clear_cell", "Clear_cell", "Other")

In [6]:
# Partition the manifest by its "split" column.
train_df = df[df["split"]=="train"]
val_df   = df[df["split"]=="val"]

# === Transforms ===
# Training: random crop / flip / small rotation for augmentation, then
# downscale to the 224x224 input the model is built for.
train_transform = Compose([
    CenterCrop(1080),
    RandomCrop(768),
    RandomHorizontalFlip(),
    RandomRotation(degrees=15),
    Resize((224, 224)),
    ToTensor(),
])

# Validation: deterministic centre crop only.
val_transform = Compose([
    CenterCrop(768),
    Resize((224, 224)),
    ToTensor(),
])

# === Datasets ===
train_ds = HistologyDataset(train_df, ".", transform=train_transform)
val_ds   = HistologyDataset(val_df, ".", transform=val_transform)

# Each HistologyDataset derives its label indices from its own rows, so a
# class absent from one split would shift the indices and misalign validation
# labels with the model's outputs. Make validation reuse the training mapping.
_unseen = sorted(set(val_ds.classes) - set(train_ds.classes))
if _unseen:
    raise ValueError(f"Validation split contains classes absent from training: {_unseen}")
val_ds.classes = train_ds.classes
val_ds.class_to_idx = train_ds.class_to_idx

# === Dataloaders ===
train_loader = DataLoader(
    train_ds,
    batch_size=16, shuffle=True,  num_workers=4
)
val_loader = DataLoader(
    val_ds,
    batch_size=16, shuffle=False, num_workers=4
)

In [7]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# One output logit per distinct label found in the training split.
model = GSViTWithHead(num_classes=len(train_df["Class"].unique()))
model = model.to(device)

# Cross-entropy objective; Adam updates every parameter (backbone and head).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [8]:
# Train for 10 epochs; after each one, score the validation split with
# balanced accuracy and checkpoint the weights whenever that score improves.
best_acc = 0.0
for epoch in range(1, 11):
    # — train —
    model.train()
    total_loss = 0
    for imgs, labels in train_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(imgs), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Average of the per-batch mean losses.
    mean_loss = total_loss/len(train_loader)
    print(f"Epoch {epoch} train loss: {mean_loss:.4f}")

    # — validate —
    model.eval()
    pred_chunks, label_chunks = [], []
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            preds = model(imgs).argmax(dim=1)
            pred_chunks.append(preds.cpu())
            label_chunks.append(labels.cpu())

    # Stitch the per-batch tensors into flat numpy arrays for sklearn.
    all_preds = torch.cat(pred_chunks).numpy()
    all_labels = torch.cat(label_chunks).numpy()

    acc = balanced_accuracy_score(all_labels, all_preds)
    print(f"Epoch {epoch} val balanced acc: {acc:.4f}")

    # Keep only the weights of the best-scoring epoch so far.
    improved = acc > best_acc
    if improved:
        best_acc = acc
        torch.save(model.state_dict(), "best_gsvit.pth")
        print("saved new best model")

Epoch 1 train loss: 9.6432
Epoch 1 val balanced acc: 0.1558
saved new best model
Epoch 2 train loss: 5.5558
Epoch 2 val balanced acc: 0.2027
saved new best model
Epoch 3 train loss: 5.6478
Epoch 3 val balanced acc: 0.1483
Epoch 4 train loss: 6.5286
Epoch 4 val balanced acc: 0.1653
Epoch 5 train loss: 4.7867
Epoch 5 val balanced acc: 0.1738
Epoch 6 train loss: 6.3377
Epoch 6 val balanced acc: 0.1780
Epoch 7 train loss: 4.9505
Epoch 7 val balanced acc: 0.1639
Epoch 8 train loss: 4.8724
Epoch 8 val balanced acc: 0.1863
Epoch 9 train loss: 4.0978
Epoch 9 val balanced acc: 0.1754
Epoch 10 train loss: 4.0055
Epoch 10 val balanced acc: 0.1812
