# In [None]:
# hybrid_supcon_234_balanced_viz_messidor_convnext_mobilevit.py
"""
SupCon pretrain on DR grades [1,2,3,4] (see SUPCON_CLASSES) with balanced batches + weighted SupCon loss,
then fine-tune whole hybrid model on all 5 classes with MixUp + Focal Loss,
gradual unfreeze and multi-GPU DataParallel for the classifier stage.

Swapped dataset -> Messidor
Swapped models -> ConvNeXt + MobileViT
"""

import os
import random

import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    ConfusionMatrixDisplay,
    f1_score,
    roc_auc_score,
    RocCurveDisplay,
)
import timm
import matplotlib.pyplot as plt

# ----------------------------
# CONFIG (Messidor)
# ----------------------------
CSV_PATH = "/kaggle/input/messidor2preprocess/messidor_data.csv"
IMG_DIR = "/kaggle/input/messidor2preprocess/messidor-2/messidor-2/preprocess"

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# SupCon stage
SUPCON_CLASSES = [1, 2, 3, 4]   # DR grades used for contrastive pretraining
SUPCON_IMG = 224
SUPCON_EPOCHS = 50
SUPCON_BPC = 4                  # samples drawn per class for every balanced SupCon batch
SUPCON_LR = 1e-4
# One weight per entry of SUPCON_CLASSES, in the same order
# (grade 2 gets 1.0, the other grades 1.5).
SUPCON_CLASS_WEIGHTS = torch.tensor([1.5, 1.0, 1.5, 1.5])
if len(SUPCON_CLASS_WEIGHTS) != len(SUPCON_CLASSES):
    raise ValueError("SUPCON_CLASS_WEIGHTS needs exactly one weight per entry of SUPCON_CLASSES")

# Classifier stage
IMG_SIZE = 380
BATCH_SIZE = 16
EPOCHS = 40
LR = 1e-4
# torch.cuda.amp autocast / GradScaler only take effect on a GPU; on CPU they
# would just warn and disable themselves.
USE_AMP = DEVICE == "cuda"
MIXUP_ALPHA = 0.4
SAVE_PATH = "best_hybrid_supcon234_finetune_messidor.pth"
NUM_WORKERS = 2
PIN_MEMORY = True

# Epoch -> parameter-name substrings to unfreeze at the start of that epoch.
# timm ConvNeXt exposes `stages.0`-`stages.3`; timm MobileViTv2 is a ByobNet
# that exposes `stages.0`-`stages.4` (it has no `blocks` attribute), so the
# MobileViT patterns thaw its last three stages, deepest first.
FREEZE_SCHEDULE = {
    2: [],
    5: ["model_cnx.stages.3", "model_mv.stages.4"],
    8: ["model_cnx.stages.2", "model_mv.stages.3"],
    12: ["model_cnx.stages.1", "model_mv.stages.2"],
}
FOCAL_GAMMA = 2.0

# ----------------------------
# Dataset
# ----------------------------
class MessidorDataset(Dataset):
    """Map-style dataset of (image, label) pairs described by a DataFrame.

    Args:
        df: rows with an ``id_code`` column (image file name, joined onto
            ``img_dir``) and a label column. The index is reset so positional
            indices 0..len-1 address rows.
        img_dir: directory holding the image files.
        transform: optional callable applied to the RGB PIL image.
        label_col: column whose value is returned (as ``int``) as the label.
    """

    def __init__(self, df, img_dir, transform=None, label_col="diagnosis"):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        self.label_col = label_col

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # messidor id_code already contains extension (e.g. 'xyz.jpg')
        img_path = os.path.join(self.img_dir, row["id_code"])
        # Close the underlying file handle deterministically; convert() returns
        # a fully loaded copy, so the image stays usable after the file closes.
        with Image.open(img_path) as src:
            img = src.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        label = int(row[self.label_col])
        return img, label

# ----------------------------
# Balanced Batch Sampler
# ----------------------------
class BalancedBatchSampler(Sampler):
    """Batch sampler that takes exactly ``bpc`` indices from every class per batch.

    Args:
        groups: one sequence of dataset indices per class.
        bpc: number of indices drawn from each class for every batch.

    Each epoch every class pool is shuffled independently. Batch ``b`` then
    takes the ``b``-th consecutive slice of ``bpc`` indices from each pool, in
    class order, and the assembled batch is shuffled. The epoch length is
    limited by the smallest class (``min(len(group) // bpc)``), so no index
    repeats within an epoch and leftover indices are dropped.
    """

    def __init__(self, groups, bpc):
        self.groups = [list(members) for members in groups]
        self.bpc = int(bpc)
        self.num_classes = len(groups)
        if self.groups:
            self._batches_per_epoch = min(len(members) // self.bpc for members in self.groups)
        else:
            self._batches_per_epoch = 0

    def __iter__(self):
        shuffled_pools = []
        for members in self.groups:
            pool = list(members)
            random.shuffle(pool)
            shuffled_pools.append(pool)

        for batch_no in range(self._batches_per_epoch):
            lo = batch_no * self.bpc
            hi = lo + self.bpc
            batch = [index for pool in shuffled_pools for index in pool[lo:hi]]
            random.shuffle(batch)
            yield batch

    def __len__(self):
        return self._batches_per_epoch

# ----------------------------
# Models (ConvNeXt + MobileViT)
# ----------------------------
class HybridModel(nn.Module):
    """Late-fusion classifier over a timm MobileViT and a timm ConvNeXt backbone.

    Both backbones are created headless (``num_classes=0``) with average
    pooling. Their pooled features are concatenated (MobileViT first) and fed
    to a two-layer MLP head (``Linear -> ReLU -> Dropout(0.3) -> Linear``).

    Args:
        num_classes: output size of the classifier head.
        pretrained: forwarded to ``timm.create_model`` for both backbones.
        mv_name: timm model name for the MobileViT branch.
        cnx_name: timm model name for the ConvNeXt branch.
    """

    def __init__(self, num_classes=5, pretrained=True, mv_name="mobilevitv2_100", cnx_name="convnext_base"):
        super().__init__()
        # create feature-only backbones (num_classes=0)
        self.model_mv = timm.create_model(mv_name, pretrained=pretrained, num_classes=0, global_pool="avg")
        self.model_cnx = timm.create_model(cnx_name, pretrained=pretrained, num_classes=0, global_pool="avg")

        # read num_features (timm exposes .num_features for many models)
        f1 = getattr(self.model_mv, "num_features", None)
        f2 = getattr(self.model_cnx, "num_features", None)
        if f1 is None or f2 is None:
            f1, f2 = self._probe_feature_dims()

        total_features = f1 + f2

        self.classifier = nn.Sequential(
            nn.Linear(total_features, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def _probe_feature_dims(self):
        """Infer both backbones' pooled feature widths with one dummy forward pass.

        The backbones are switched to eval mode for the probe and then restored
        to their previous train/eval state, so construction has no lasting
        side effect on the module mode.
        """
        modes = (self.model_mv.training, self.model_cnx.training)
        self.model_mv.eval()
        self.model_cnx.eval()
        try:
            with torch.no_grad():
                dummy = torch.randn(1, 3, SUPCON_IMG, SUPCON_IMG)
                mv_feat = self.model_mv(dummy)
                cnx_feat = self.model_cnx(dummy)
        finally:
            self.model_mv.train(modes[0])
            self.model_cnx.train(modes[1])
        return mv_feat.shape[1], cnx_feat.shape[1]

    def forward(self, x, return_features=False):
        """Return class logits, or the fused backbone features if ``return_features``."""
        f1 = self.model_mv(x)
        f2 = self.model_cnx(x)
        # Concatenate along the feature axis; widths must sum to the head's input size.
        fused = torch.cat([f1, f2], dim=1)
        if return_features:
            return fused
        return self.classifier(fused)

# ----------------------------
# Losses
# ----------------------------
class WeightedSupConLoss(nn.Module):
    """Supervised contrastive loss with optional per-class anchor weights.

    For each anchor ``i`` the loss averages ``log p(j | i)`` over its positives
    ``j`` (same label, ``j != i``). Here ``p`` is a softmax over all other
    samples of cosine similarity divided by ``temperature``. The per-anchor
    terms are optionally scaled by ``class_weights[label_i]`` and then averaged
    over the batch.

    Args:
        temperature: divisor applied to the cosine-similarity logits.
        class_weights: optional 1-D tensor indexed by label; ``None`` gives
            every anchor weight 1.
    """

    def __init__(self, temperature=0.07, class_weights=None):
        super().__init__()
        self.temperature = temperature
        self.class_weights = class_weights.float() if class_weights is not None else None

    def forward(self, features, labels):
        """Return the scalar loss for ``features`` (B, D) and integer ``labels`` (B,).

        An anchor with no positive in the batch contributes 0 (its positive
        count is padded by 1e-9), but it still counts toward the mean.
        """
        device = features.device
        # Always compute in float32 with autocast disabled. Under AMP the
        # similarity matmul would otherwise run in fp16, coarsely quantizing
        # logits that are scaled up by 1/temperature.
        with torch.autocast(device_type=device.type, enabled=False):
            features = F.normalize(features.float(), dim=1)
            labels = labels.reshape(-1)
            batch_size = features.shape[0]

            same_label = torch.eq(labels.unsqueeze(1), labels.unsqueeze(0)).float()
            self_mask = torch.eye(batch_size, device=device)

            logits = torch.matmul(features, features.T) / self.temperature
            # Subtract the row max before exp for numerical stability (softmax-invariant).
            logits = logits - logits.max(dim=1, keepdim=True).values.detach()

            # Exclude self-similarity from the softmax denominator.
            exp_logits = torch.exp(logits) * (1 - self_mask)
            log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-9)

            positives = same_label - self_mask
            mean_log_prob_pos = (positives * log_prob).sum(1) / (positives.sum(1) + 1e-9)

            if self.class_weights is not None:
                # Move the weights first so CPU-held weights also work with CUDA labels.
                sample_weights = self.class_weights.to(device)[labels.long()]
                return -(sample_weights * mean_log_prob_pos).mean()
            return -mean_log_prob_pos.mean()

class FocalLoss(nn.Module):
    """Multi-class focal loss: ``-alpha_t * (1 - p_t) ** gamma * log(p_t)``.

    Args:
        gamma: focusing exponent; ``0`` reduces to (optionally weighted) cross-entropy.
        alpha: optional per-class weights (list, ndarray or tensor) indexed by
            target class; ``None`` disables class weighting.
        reduction: ``"mean"``, ``"sum"``, or anything else for per-sample losses.
    """

    def __init__(self, gamma=2.0, alpha=None, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        # as_tensor accepts tensors without the copy-construct warning that
        # torch.tensor(tensor) emits; clone/detach keeps a private copy.
        self.alpha = (
            torch.as_tensor(alpha, dtype=torch.float).clone().detach()
            if alpha is not None
            else None
        )
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the loss for logits ``inputs`` (B, C) and class ``targets`` (B,)."""
        targets = targets.long()
        logpt = F.log_softmax(inputs, dim=1)
        logpt_t = logpt.gather(1, targets.unsqueeze(1)).squeeze(1)
        pt_t = logpt_t.exp()
        loss = -((1 - pt_t) ** self.gamma) * logpt_t
        if self.alpha is not None:
            loss = self.alpha.to(inputs.device).gather(0, targets) * loss
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss

# ----------------------------
# MixUp
# ----------------------------
def mixup_data(x, y, alpha=MIXUP_ALPHA):
    """Blend every sample with a randomly permuted partner from the same batch (MixUp).

    The mixing weight ``lam`` is drawn from ``Beta(alpha, alpha)``; when
    ``alpha <= 0`` it is fixed at 1.0 (no blending).

    Returns:
        ``(mixed_x, y_a, y_b, lam)``: the blended inputs, the original targets,
        the partners' targets, and the weight of the original samples.
    """
    lam = 1.0
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(x.size(0)).to(x.device)
    partner_x = x[perm, :]
    mixed_x = lam * x + (1.0 - lam) * partner_x
    return mixed_x, y, y[perm], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Return the MixUp loss: ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``."""
    loss_original = criterion(pred, y_a)
    loss_partner = criterion(pred, y_b)
    return lam * loss_original + (1.0 - lam) * loss_partner

# ----------------------------
# Data prep (Messidor)
# ----------------------------
df = pd.read_csv(CSV_PATH)

# Rows without a diagnosis cannot be used by either stage. Left in place they
# would also turn the column into floats and create a bogus extra class below.
df = df.dropna(subset=["diagnosis"]).reset_index(drop=True)

# Ensure labels are ints 0-4. Numeric grades (possibly stored as floats) are
# cast directly so each grade keeps its identity; other encodings are mapped
# by sorted order.
if pd.api.types.is_numeric_dtype(df["diagnosis"]):
    df["diagnosis"] = df["diagnosis"].astype(int)
else:
    label_map = {cls: i for i, cls in enumerate(sorted(df["diagnosis"].unique()))}
    df["diagnosis"] = df["diagnosis"].map(label_map)
if not df["diagnosis"].between(0, 4).all():
    raise ValueError(f"Expected diagnosis labels in 0..4, got {sorted(df['diagnosis'].unique())}")

# SupCon subset
supcon_df = df[df["diagnosis"].isin(SUPCON_CLASSES)].reset_index(drop=True)
# Map to 0..len(SUPCON_CLASSES)-1 for supcon labels
supcon_df["supcon_label"] = supcon_df["diagnosis"].map({c: i for i, c in enumerate(SUPCON_CLASSES)})

supcon_transform = transforms.Compose([
    transforms.RandomResizedCrop(SUPCON_IMG, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
supcon_dataset = MessidorDataset(supcon_df, IMG_DIR, transform=supcon_transform, label_col="supcon_label")

# Positional indices (valid after reset_index) of each SupCon class.
indices_per_class = [
    supcon_df.index[supcon_df["supcon_label"] == class_idx].tolist()
    for class_idx in range(len(SUPCON_CLASSES))
]

batch_sampler = BalancedBatchSampler(indices_per_class, bpc=SUPCON_BPC)
if len(batch_sampler) == 0:
    # The sampler caps the epoch at the rarest class, so one tiny class would
    # silently make every SupCon epoch empty.
    raise ValueError(
        f"Every SupCon class needs at least SUPCON_BPC={SUPCON_BPC} images; "
        f"got counts {[len(ix) for ix in indices_per_class]}"
    )
supcon_loader = DataLoader(supcon_dataset, batch_sampler=batch_sampler, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

# ----------------------------
# SupCon Training
# ----------------------------
model_supcon = HybridModel(num_classes=5, pretrained=True).to(DEVICE)
supcon_loss_fn = WeightedSupConLoss(temperature=0.07, class_weights=SUPCON_CLASS_WEIGHTS.to(DEVICE))
optimizer_supcon = optim.AdamW(model_supcon.parameters(), lr=SUPCON_LR, weight_decay=1e-4)
scaler_supcon = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))

print("=== SupCon pretraining on classes", SUPCON_CLASSES, "===")
for epoch in range(SUPCON_EPOCHS):
    model_supcon.train()
    running_loss = 0.0
    loop = tqdm(supcon_loader, desc=f"SupCon Epoch {epoch+1}/{SUPCON_EPOCHS}")
    for imgs, labels in loop:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        optimizer_supcon.zero_grad()
        with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
            # Contrast the fused backbone features directly (the classifier head is unused here).
            feats = model_supcon(imgs, return_features=True)
            loss = supcon_loss_fn(feats, labels)
        scaler_supcon.scale(loss).backward()
        scaler_supcon.step(optimizer_supcon)
        scaler_supcon.update()
        running_loss += loss.item()
        loop.set_postfix(loss=loss.item())
    avg_loss = running_loss / max(1, len(supcon_loader))
    print(f"SupCon Epoch {epoch+1} | Avg Loss: {avg_loss:.4f}")

torch.save(model_supcon.state_dict(), "hybrid_supcon234_pretrained_messidor.pth")
print("Saved SupCon weights.")

# The classifier stage rebuilds its model from the checkpoint above. Release
# this model and its optimizer state so two full hybrids do not share the GPU
# for the rest of the run.
del model_supcon, optimizer_supcon, scaler_supcon
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# ----------------------------
# Classifier Training
# ----------------------------
train_df, val_df = train_test_split(df, test_size=0.2, stratify=df["diagnosis"], random_state=SEED)

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

train_dataset = MessidorDataset(train_df, IMG_DIR, transform=train_transform)
val_dataset = MessidorDataset(val_df, IMG_DIR, transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

# The SupCon checkpoint was saved from the same HybridModel class, so it holds
# every parameter. Skip the redundant pretrained download and load strictly, so
# a key mismatch fails loudly instead of silently keeping unrelated weights.
model = HybridModel(num_classes=5, pretrained=False).to(DEVICE)
state = torch.load("hybrid_supcon234_pretrained_messidor.pth", map_location=DEVICE)
model.load_state_dict(state, strict=True)

# Inverse-frequency class weights. reindex keeps one entry per grade even if a
# grade is missing from train_df; np.maximum avoids dividing by a zero count.
class_counts = train_df["diagnosis"].value_counts().reindex(range(5), fill_value=0).to_numpy()
class_weights = (class_counts.sum() / (len(class_counts) * np.maximum(class_counts, 1))).astype(np.float32)
criterion_cls = FocalLoss(gamma=FOCAL_GAMMA, alpha=class_weights)
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

if torch.cuda.device_count() > 1:
    print("Using", torch.cuda.device_count(), "GPUs.")
    model = nn.DataParallel(model)

# By default freeze all except classifier. Substring matching also covers the
# "module." prefix that DataParallel adds to parameter names.
for name, p in model.named_parameters():
    p.requires_grad = ("classifier" in name)

def gradual_unfreeze(model_obj, epoch, schedule):
    """Enable gradients for the parameters scheduled to thaw at ``epoch``.

    ``schedule`` maps epoch numbers to lists of name substrings. Every
    parameter of ``model_obj`` whose qualified name contains any of them gets
    ``requires_grad = True``. Epochs missing from the schedule are a no-op.
    A scheduled epoch always prints a log line, even if its list is empty.
    """
    if epoch not in schedule:
        return
    patterns = schedule[epoch]
    to_thaw = (
        param
        for param_name, param in model_obj.named_parameters()
        if any(pat in param_name for pat in patterns)
    )
    for param in to_thaw:
        param.requires_grad = True
    print(f"[Unfreeze] Epoch {epoch}: {patterns}")

# ----------------------------
# Track metrics for plotting
# ----------------------------
train_losses, val_losses = [], []
train_accs, val_accs = [], []

best_macro_f1 = -1.0
print("=== Classifier training ===")
for epoch in range(EPOCHS):
    # Thaw any backbone stages scheduled for this epoch before training on it.
    gradual_unfreeze(model, epoch, FREEZE_SCHEDULE)

    # ---- Train: MixUp inputs, focal loss, AMP ----
    model.train()
    running_loss = 0.0
    preds_all, targets_all = [], []
    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]", leave=False):
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        imgs_mixed, ya, yb, lam = mixup_data(imgs, labels)
        with torch.cuda.amp.autocast(enabled=USE_AMP):
            outputs = model(imgs_mixed)
            loss = mixup_criterion(criterion_cls, outputs, ya, yb, lam)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
        # The model saw MixUp blends, so score it against the label of the
        # dominant image (weight >= 0.5), not always the un-permuted one.
        dominant_targets = ya if lam >= 0.5 else yb
        preds_all.extend(outputs.argmax(1).detach().cpu().numpy())
        targets_all.extend(dominant_targets.cpu().numpy())
    train_acc = accuracy_score(targets_all, preds_all)
    train_loss_avg = running_loss / max(1, len(train_loader))
    train_losses.append(train_loss_avg)
    train_accs.append(train_acc)

    # ---- Validate on clean images ----
    model.eval()
    val_preds, val_targets = [], []
    val_loss_accum = 0.0
    with torch.no_grad():
        for imgs, labels in tqdm(val_loader, desc="Validation", leave=False):
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            with torch.cuda.amp.autocast(enabled=USE_AMP):
                outputs = model(imgs)
                loss_v = criterion_cls(outputs, labels)
            val_loss_accum += loss_v.item()
            val_preds.extend(outputs.argmax(1).cpu().numpy())
            val_targets.extend(labels.cpu().numpy())
    val_loss_avg = val_loss_accum / max(1, len(val_loader))
    val_losses.append(val_loss_avg)
    val_acc = accuracy_score(val_targets, val_preds)
    val_accs.append(val_acc)

    # zero_division=0 scores a never-predicted class as F1 0 without warning.
    macro_f1 = f1_score(val_targets, val_preds, average="macro", zero_division=0)
    print(f"Epoch {epoch+1} | Train Loss {train_loss_avg:.4f} Acc {train_acc:.4f} | Val Loss {val_loss_avg:.4f} Acc {val_acc:.4f} | F1 {macro_f1:.4f}")

    # Checkpoint on best validation macro-F1.
    if macro_f1 > best_macro_f1:
        best_macro_f1 = macro_f1
        # Save the unwrapped module so checkpoint keys carry no "module." prefix.
        model_to_save = model.module if isinstance(model, nn.DataParallel) else model
        torch.save(model_to_save.state_dict(), SAVE_PATH)
        print(f"Saved best model (F1={macro_f1:.4f})")

    scheduler.step()

# ----------------------------
# Plot metrics
# ----------------------------
# Size the x-axis from the recorded history rather than EPOCHS, so the plot
# still works if training was interrupted early.
epochs_ran = range(1, len(train_losses) + 1)
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))

ax_loss.plot(epochs_ran, train_losses, label="Train Loss")
ax_loss.plot(epochs_ran, val_losses, label="Val Loss")
ax_loss.set_xlabel("Epoch")
ax_loss.set_ylabel("Loss")
ax_loss.set_title("Loss vs Epoch")
ax_loss.legend()

ax_acc.plot(epochs_ran, train_accs, label="Train Acc")
ax_acc.plot(epochs_ran, val_accs, label="Val Acc")
ax_acc.set_xlabel("Epoch")
ax_acc.set_ylabel("Accuracy")
ax_acc.set_title("Accuracy vs Epoch")
ax_acc.legend()
plt.show()

# ----------------------------
# Final eval & ROC-AUC curves
# ----------------------------
best_state = torch.load(SAVE_PATH, map_location=DEVICE)
# The checkpoint covers every parameter, so skip the pretrained download and
# load strictly; a key mismatch must not silently leave random weights.
final_model = HybridModel(num_classes=5, pretrained=False).to(DEVICE)
final_model.load_state_dict(best_state, strict=True)
final_model.eval()

val_preds, val_targets = [], []
val_probs = []
with torch.no_grad():
    for imgs, labels in val_loader:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        outputs = final_model(imgs)
        val_preds.extend(outputs.argmax(1).cpu().numpy())
        val_targets.extend(labels.cpu().numpy())
        val_probs.append(outputs.softmax(dim=1).cpu())
val_probs = torch.cat(val_probs, dim=0).numpy()
val_targets_np = np.array(val_targets)

all_classes = list(range(5))
print("\nFinal Report:")
print(classification_report(val_targets, val_preds, labels=all_classes, digits=4, zero_division=0))
# Fix the label set so the matrix stays 5x5 (matching display_labels) even
# if a grade is absent from the validation predictions/targets.
cm = confusion_matrix(val_targets, val_preds, labels=all_classes)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=all_classes)
disp.plot(cmap="Blues", xticks_rotation=45)
plt.show()

try:
    macro_auc = roc_auc_score(val_targets_np, val_probs, multi_class="ovr", average="macro", labels=all_classes)
    print(f"Macro one-vs-rest ROC-AUC: {macro_auc:.4f}")
except ValueError as err:
    # e.g. a class with no positive validation samples
    print(f"Macro ROC-AUC unavailable: {err}")

# ROC-AUC per class. Draw all curves on one shared axes; without `ax`, every
# from_predictions call opens its own figure.
fig, ax = plt.subplots(figsize=(10, 8))
for i in all_classes:
    RocCurveDisplay.from_predictions((val_targets_np == i).astype(int), val_probs[:, i], name=f"Class {i}", ax=ax)
ax.plot([0, 1], [0, 1], "k--")
ax.set_title("ROC Curves per Class")
plt.show()

# ----------------------------
# Visualize predictions on validation images
# ----------------------------
def show_predictions(model, dataset, num_images=8):
    """Plot up to ``num_images`` random samples with ground truth and predicted class.

    Images are drawn without replacement, so at most ``len(dataset)`` are
    shown. They are laid out on a 2-row grid. Pixel values are un-normalized
    with the ImageNet mean/std used by the transforms before display.
    """
    model.eval()
    num_images = min(num_images, len(dataset))
    if num_images <= 0:
        return
    indices = np.random.choice(len(dataset), size=num_images, replace=False)
    # Round up so an odd count still fits in two rows.
    ncols = (num_images + 1) // 2
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    plt.figure(figsize=(15, 5))
    for i, idx in enumerate(indices):
        img, label = dataset[idx]
        with torch.no_grad():
            pred = model(img.unsqueeze(0).to(DEVICE)).argmax(1).item()
        img_np = np.clip(img.permute(1, 2, 0).numpy() * std + mean, 0, 1)
        plt.subplot(2, ncols, i + 1)
        plt.imshow(img_np)
        plt.title(f"GT: {label} | Pred: {pred}")
        plt.axis("off")
    plt.show()

show_predictions(final_model, val_dataset, num_images=8)

# Save a copy of the final (best-checkpoint) weights next to SAVE_PATH with a "_final" suffix.
final_weights_path = SAVE_PATH.replace(".pth", "_final.pth")
torch.save(final_model.state_dict(), final_weights_path)
print("Final model saved.")
