In [None]:
# === PART 0: INSTALL DEPENDENCIES (KAGGLE) ============================
# What this does:
# - Installs auxiliary libs (NOT external grad-cam).
# - PyTorch stays as preinstalled (Kaggle).
# ======================================================================
%pip install -q timm umap-learn
# (optional) %pip install -q torchcam


In [None]:
# === PART 1: IMPORTS & GLOBAL CONFIG (ROBUST + BUILT-IN GRAD-CAM) =====
# What this does:
# - Imports libs and sets paths/hyperparams (100 MoCo + 100 fine-tune epochs; no layer freezing).
# - Adds a built-in Grad-CAM fallback (no external install needed).
# - Kaggle-safe: NUM_WORKERS=0 and throttled tqdm to avoid timeouts.
# ======================================================================
import os, random, time, math, shutil
from pathlib import Path
from collections import Counter

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import make_grid, save_image

import timm

from sklearn.metrics import (classification_report, confusion_matrix, roc_curve, auc)
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, silhouette_score

import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.cm as mpl_cm
from tqdm.auto import tqdm

# ---------- Built-in Grad-CAM (pure PyTorch) ----------
def _overlay_cam_on_image_rgb(raw_rgb, grayscale_cam, alpha=0.35):
    """Blend an RGB image with a Grad-CAM heatmap and return a uint8 overlay.

    Args:
        raw_rgb: [H, W, 3] RGB image, float in 0..1 or values in 0..255
            (anything whose max exceeds 1.0 is rescaled by 1/255).
        grayscale_cam: [h, w] (or [1, h, w]) CAM, expected roughly in 0..1.
            If its spatial size differs from the image's, it is bilinearly
            resized to [H, W] so feature-map-resolution CAMs can be overlaid.
        alpha: weight of the heatmap in the blend (0 = image only, 1 = heatmap only).

    Returns:
        [H, W, 3] uint8 array.
    """
    raw = np.asarray(raw_rgb, dtype=np.float32)
    if raw.max() > 1.0:
        raw = raw / 255.0
    gray = np.asarray(grayscale_cam, dtype=np.float32)
    if gray.ndim == 3 and gray.shape[0] == 1:
        gray = gray[0]
    target_hw = tuple(raw.shape[:2])
    if gray.shape != target_hw:
        # A CAM taken at a conv layer has feature-map resolution (e.g. 7x7);
        # upsample it so it broadcasts against the [H, W, 3] image.
        cam_t = torch.from_numpy(np.ascontiguousarray(gray))[None, None]
        gray = F.interpolate(cam_t, size=target_hw, mode="bilinear",
                             align_corners=False)[0, 0].numpy()
    gray = np.clip(gray, 0.0, 1.0)
    # pyplot.get_cmap exists on every Matplotlib version; matplotlib.cm.get_cmap was removed in 3.9.
    heat = plt.get_cmap('jet')(gray)[..., :3]
    out = (1 - alpha) * raw + alpha * heat
    return (np.clip(out, 0, 1) * 255).astype(np.uint8)

class _SimpleGradCAM:
    """Minimal Grad-CAM for a single target layer.

    Hooks ``target_layer`` to capture its forward output and the gradient of the
    chosen class score with respect to that output. The CAM is the gradient-weighted
    (spatially averaged) channel sum, passed through ReLU and min-max normalised per
    sample. It is then bilinearly upsampled to the input's spatial size.

    Assumes ``target_layer`` outputs a 4-D [B, C, h, w] feature map (e.g. a conv layer).
    """
    def __init__(self, model, target_layer):
        self.model, self.layer = model, target_layer
        self.activations, self.gradients = None, None
        self.h_fwd = self.layer.register_forward_hook(self._save_act)
        if hasattr(self.layer, 'register_full_backward_hook'):
            self.h_bwd = self.layer.register_full_backward_hook(self._save_grad)
        else:
            # Older PyTorch only offers the deprecated (non-"full") backward hook.
            self.h_bwd = self.layer.register_backward_hook(self._save_grad_depr)

    def _save_act(self, mod, inp, out): self.activations = out
    def _save_grad(self, mod, gin, gout): self.gradients = gout[0]
    def _save_grad_depr(self, mod, gin, gout): self.gradients = gout[0]

    def remove(self):
        """Detach both hooks from the target layer."""
        self.h_fwd.remove()
        self.h_bwd.remove()

    @torch.no_grad()
    def _normalize(self, cam):
        cam = torch.relu(cam)
        cam_min = cam.amin(dim=(1,2), keepdim=True)
        cam_max = cam.amax(dim=(1,2), keepdim=True)
        return (cam - cam_min) / (cam_max - cam_min + 1e-6)

    def __call__(self, input_tensor, class_idx=None):
        """Return a [B, H, W] numpy CAM in 0..1 at the input's spatial size (H, W).

        Args:
            input_tensor: model input batch; its last two dims give (H, W).
            class_idx: target class. Defaults to the argmax of the model output,
                which requires a batch of one (``.item()``).
        """
        # Clear state from a previous call so stale tensors are never reused.
        self.activations, self.gradients = None, None
        self.model.zero_grad(set_to_none=True)
        input_tensor.requires_grad_(True)
        output = self.model(input_tensor)
        if class_idx is None:
            class_idx = output.argmax(1).item()
        loss = output[:, class_idx].sum()
        loss.backward()
        grads, acts = self.gradients, self.activations
        if grads is None or acts is None:
            raise RuntimeError("Grad-CAM hooks did not fire; is target_layer used in model.forward?")
        with torch.no_grad():
            weights = grads.mean(dim=(2,3), keepdim=True)
            cam = self._normalize((weights * acts).sum(dim=1))
            # Upsample the feature-map CAM to the input resolution so it lines up with the image.
            cam = F.interpolate(cam.unsqueeze(1), size=tuple(input_tensor.shape[-2:]),
                                mode='bilinear', align_corners=False).squeeze(1)
            cam = cam.clamp_(0.0, 1.0)
        return cam.detach().cpu().numpy()

def make_cam_runner(model, target_layer):
    """Build a callable ``(input_tensor, class_idx) -> CAM`` bound to one target layer.

    The returned function runs a single ``_SimpleGradCAM`` engine (hooks are
    registered once, here) for the integer class and returns only the first
    sample's CAM from the batch.
    """
    cam_engine = _SimpleGradCAM(model, target_layer)

    def run_single(input_tensor, class_idx):
        batch_cams = cam_engine(input_tensor, int(class_idx))
        return batch_cams[0]

    return run_single

# ------------------ Device/seed/paths/hparams -------------------------
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
SEED = 42
# Seed Python, NumPy and torch (CPU + all CUDA devices) RNGs.
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
# Prefer deterministic cuDNN kernels over autotuned (faster) ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

IS_KAGGLE = os.path.exists("/kaggle")
RUN_NAME = "moco_rice38_70_20_10"

# EDIT if your dataset folder differs:
DATA_DIR = Path("/kaggle/input/riceds-original/Original") if IS_KAGGLE else Path("riceds-original/Original")
WORK_DIR = Path(f"/kaggle/working/runs/{RUN_NAME}") if IS_KAGGLE else Path(f"runs/{RUN_NAME}")
WORK_DIR.mkdir(parents=True, exist_ok=True)

IMG_SIZE = 224
BATCH_SIZE = 16

# Kaggle-safe loader settings (fixes multiprocessing shutdown AssertionError)
NUM_WORKERS = 0
PIN_MEMORY = (DEVICE == 'cuda')
PERSISTENT = False

NUM_CLASSES = 38  # must equal the number of class subfolders (asserted in Part 2)

# MoCo pretrain + Supervised finetune epochs
MOCO_EPOCHS = 100
FINETUNE_EPOCHS = 100

# Supervised training hparams
BASE_LR = 3e-4
WEIGHT_DECAY = 1e-4
PATIENCE = 1000          # effectively off
LABEL_SMOOTH = 0.05
MIXUP_ALPHA = 0.2
CUTMIX_ALPHA = 0.0  # only consulted when MIXUP_ALPHA <= 0, and then used as a MixUp alpha (no CutMix is implemented)
TTA_TIMES = 4       # test-time passes; pass 0 is the un-augmented image

# MoCo-specific hparams
PROJ_DIM = 256       # projection-head output size (= queue feature dim)
QUEUE_SIZE = 65536   # number of negative keys; keeping it a multiple of BATCH_SIZE keeps queue writes aligned
# NOTE(review): the queue is filled from the train split only; if that split has fewer than
# QUEUE_SIZE images, the queue will hold several (stale) keys per image — confirm this is intended.
MOMENTUM = 0.999     # key-encoder EMA momentum
TEMP = 0.2           # InfoNCE temperature
SSL_LR = 0.2         # SGD learning rate for MoCo (constant; no scheduler is attached)
SSL_WD = 1e-4
SSL_MOM = 0.9

# Throttle tqdm to reduce Jupyter IO stalls
TQDM_KW = dict(dynamic_ncols=True, leave=False, mininterval=10.0)


In [None]:
# === PART 2: DATA SCAN & STRATIFIED SPLIT (70/10/20) ==================
# What this does:
# - Collects image paths/labels from class subfolders.
# - Stratified split: 20% test from full; then 12.5% of remaining 80% as val ⇒ 10% absolute.
# - Leaves 70% for train.
# ======================================================================
assert DATA_DIR.exists(), f"Path not found: {DATA_DIR.resolve()}"

# Class names are the sorted subfolder names; a label is the index in that sorted list.
classes = sorted([p.name for p in DATA_DIR.iterdir() if p.is_dir()])
class_to_idx = {c:i for i,c in enumerate(classes)}
assert len(classes) == NUM_CLASSES, f"Expected {NUM_CLASSES} classes, found {len(classes)}"

# Collect image files directly inside each class folder (non-recursive), filtered by extension.
# NOTE(review): glob() order is filesystem-dependent and is not sorted here, so the seeded
# split below may still differ across machines — consider sorting if exact reproducibility matters.
all_paths, all_labels = [], []
for c in classes:
    for imgp in (DATA_DIR/c).glob("*"):
        if imgp.suffix.lower() in {".jpg",".jpeg",".png",".bmp",".tif",".tiff",".webp"}:
            all_paths.append(str(imgp))
            all_labels.append(class_to_idx[c])

all_paths = np.array(all_paths)
all_labels = np.array(all_labels)
print("Total images:", len(all_paths))

# 20% test from full (stratified by class)
sss_outer = StratifiedShuffleSplit(n_splits=1, test_size=0.20, random_state=SEED)
trainval_idx, test_idx = next(sss_outer.split(all_paths, all_labels))
X_trainval, y_trainval = all_paths[trainval_idx], all_labels[trainval_idx]
X_test,     y_test     = all_paths[test_idx],    all_labels[test_idx]

# From remaining 80%, take 12.5% as val (=> 10% absolute); the rest (70%) is train
sss_inner = StratifiedShuffleSplit(n_splits=1, test_size=0.125, random_state=SEED)
train_idx, val_idx = next(sss_inner.split(X_trainval, y_trainval))
X_train, y_train = X_trainval[train_idx], y_trainval[train_idx]
X_val,   y_val   = X_trainval[val_idx],   y_trainval[val_idx]

print(f"Split -> train: {len(X_train)} | val: {len(X_val)} | test: {len(X_test)}")


In [None]:
# === PART 3: DATASETS, AUGS, DATALOADERS (MoCo + Supervised) ==========
# What this does:
# - MoCo two-view dataset on train split; labeled datasets for train/val/test.
# - Strong augs for SSL & train; light augs for val/test.
# - WeightedRandomSampler to mitigate class imbalance (for supervised).
# - DataLoader settings avoid multiprocessing teardown errors on Kaggle.
# ======================================================================
class LabeledDataset(Dataset):
    """Image-classification dataset over parallel sequences of file paths and labels.

    Each item is ``(image, label, path)``. ``image`` is the file decoded as RGB and
    passed through ``transform`` when one is given.

    Raises:
        ValueError: if ``paths`` and ``labels`` differ in length.
    """
    def __init__(self, paths, labels, transform=None):
        self.paths = list(paths); self.labels = list(labels); self.transform = transform
        if len(self.paths) != len(self.labels):
            raise ValueError(f"paths ({len(self.paths)}) and labels ({len(self.labels)}) differ in length")
    def __len__(self): return len(self.paths)
    def __getitem__(self, idx):
        p = self.paths[idx]; y = self.labels[idx]
        # Close the file handle promptly; convert() returns an independent in-memory copy.
        with Image.open(p) as im:
            img = im.convert("RGB")
        if self.transform: img = self.transform(img)
        return img, y, p

class TwoCropsBYOL(Dataset):
    """Returns two independently augmented views of the same image (x1, x2).

    Despite the name, this is the unlabeled dataset used for MoCo pretraining (``ssl_ds``).
    ``transform1`` and ``transform2`` are each applied to the same decoded RGB image.
    """
    def __init__(self, paths, transform1, transform2):
        self.paths = list(paths); self.t1 = transform1; self.t2 = transform2
    def __len__(self): return len(self.paths)
    def __getitem__(self, idx):
        # Close the file handle promptly; convert() returns an independent in-memory copy.
        with Image.open(self.paths[idx]) as im:
            img = im.convert("RGB")
        return self.t1(img), self.t2(img)

# SSL augs (MoCo v2-style strong augs): crop/flip/colour jitter/grayscale/blur.
# NOTE(review): GaussianBlur is applied to every view here (not wrapped in RandomApply),
# and view 2 uses weaker colour jitter than view 1 — confirm both are intended.
ssl_view1 = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.2, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([transforms.ColorJitter(0.4,0.4,0.4,0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(kernel_size=21, sigma=(0.1, 2.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
])
ssl_view2 = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.2, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([transforms.ColorJitter(0.2,0.2,0.2,0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(kernel_size=21, sigma=(0.1, 2.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
])

# Supervised training augs: milder crop range plus light photometric jitter.
train_tfms = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([transforms.ColorJitter(0.2,0.2,0.2,0.1)], p=0.6),
    transforms.RandomAutocontrast(p=0.3),
    transforms.RandomAdjustSharpness(1.5, p=0.3),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
])

# Eval transform: plain square resize (no crop), then ImageNet normalisation.
test_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
])

# Datasets
ssl_ds   = TwoCropsBYOL(X_train, ssl_view1, ssl_view2)   # unlabeled SSL on train
train_ds = LabeledDataset(X_train, y_train, train_tfms)
val_ds   = LabeledDataset(X_val,   y_val,   test_tfms)
test_ds  = LabeledDataset(X_test,  y_test,  test_tfms)

# Weighted sampler for supervised training: per-sample weight = 1 / (train count of its class),
# drawing len(y_train) samples with replacement each epoch.
class_counts = Counter(y_train)
weights = [1.0/class_counts[y] for y in y_train]
sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

# Loaders (NUM_WORKERS=0 & persistent=False to avoid teardown AssertionError).
# SSL and train loaders drop the last partial batch; val/test keep every sample.
ssl_loader   = DataLoader(ssl_ds,   batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS, 
                          pin_memory=PIN_MEMORY, drop_last=True,  persistent_workers=PERSISTENT)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler, num_workers=NUM_WORKERS, 
                          pin_memory=PIN_MEMORY, drop_last=True,  persistent_workers=PERSISTENT)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, 
                          pin_memory=PIN_MEMORY,                persistent_workers=PERSISTENT)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, 
                          pin_memory=PIN_MEMORY,                persistent_workers=PERSISTENT)

# (Optional) quick batch viz
def show_batch(n=24, title="Supervised Train samples (augmented)"):
    """Display a grid of ``n`` randomly drawn, augmented training images (6 per row).

    Images are min-max rescaled for display from an assumed (-2, 2) normalised range.
    """
    preview_loader = DataLoader(train_ds, batch_size=n, shuffle=True, num_workers=0)
    images, _labels, _paths = next(iter(preview_loader))
    grid = make_grid(images[:n], nrow=6, normalize=True, value_range=(-2,2))
    plt.figure(figsize=(10,10))
    plt.imshow(grid.permute(1,2,0).cpu())
    plt.axis('off')
    plt.title(title)
    plt.show()

# show_batch()  # uncomment if you want a grid


In [None]:
# === PART 4A: MoCo PRETRAINING ON TRAIN SPLIT =========================
# What this does:
# - Implements MoCo v2 (query/key encoders, momentum update, queue, InfoNCE).
# - Trains self-supervised for MOCO_EPOCHS on the 70% train split.
# - Saves the query encoder weights to moco_encoder.pt for fine-tuning.
# ======================================================================
class ProjectionMLP(nn.Module):
    """Two-layer MLP projection head: Linear -> BN -> ReLU -> Linear -> BN.

    Maps [B, in_dim] features to [B, out_dim] embeddings through a hidden width ``hid``.
    """
    def __init__(self, in_dim, hid=2048, out_dim=256):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hid),
            nn.BatchNorm1d(hid),
            nn.ReLU(inplace=True),
            nn.Linear(hid, out_dim),
            nn.BatchNorm1d(out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.net(x)
        return projected

@torch.no_grad()
def _update_momentum(q_model, k_model, m=MOMENTUM, copy_buffers=True):
    """Momentum (EMA) update of a key network from its query network.

    For every parameter pair, in place: ``p_k <- m * p_k + (1 - m) * p_q``.

    Args:
        q_model: query-side module (updated by the optimizer).
        k_model: key-side module with the same architecture; its parameters and
            buffers must pair up one-to-one with ``q_model``'s.
        m: momentum coefficient; values closer to 1 make the key network move slower.
        copy_buffers: also copy buffers (e.g. BatchNorm running_mean/running_var)
            from ``q_model`` into ``k_model``. The key networks in this notebook run
            in eval mode. Without this, their BN layers would keep normalising with the
            statistics they were created with instead of stats tracking the query side.
    """
    for p_q, p_k in zip(q_model.parameters(), k_model.parameters()):
        # In-place update avoids allocating a new tensor per parameter every step.
        p_k.mul_(m).add_(p_q.detach(), alpha=1.0 - m)
    if copy_buffers:
        for b_q, b_k in zip(q_model.buffers(), k_model.buffers()):
            b_k.copy_(b_q)

def _normalize(x):
    """L2-normalise each row of ``x`` (unit Euclidean norm along dim 1)."""
    return F.normalize(x, p=2.0, dim=1)

# Width of the pooled backbone features fed to the projection heads.
feat_dim = 1280  # EfficientNet-B0 feature size
proj_dim = PROJ_DIM  # projection output size; also the queue's feature dimension

def make_encoder():
    """Create a pretrained timm EfficientNet-B0 backbone with no classifier (``num_classes=0``)."""
    return timm.create_model('tf_efficientnet_b0_ns', pretrained=True, num_classes=0)

# Query side (trained by SGD) and key side (momentum-updated copy) share one architecture.
encoder_q = make_encoder().to(DEVICE)
encoder_k = make_encoder().to(DEVICE)
projector_q = ProjectionMLP(feat_dim, 2048, proj_dim).to(DEVICE)
projector_k = ProjectionMLP(feat_dim, 2048, proj_dim).to(DEVICE)

# Init key encoder with query encoder weights; freeze key side (updated only by momentum).
# NOTE: strict=False would silently skip mismatched keys; none are expected since the architectures match.
encoder_k.load_state_dict(encoder_q.state_dict(), strict=False)
projector_k.load_state_dict(projector_q.state_dict(), strict=False)
for p in list(encoder_k.parameters()) + list(projector_k.parameters()):
    p.requires_grad = False

# MoCo negatives queue stored as [C, K] (one key per column) so q @ queue yields [B, K] logits;
# it starts as random unit vectors and is overwritten by real keys as training proceeds.
queue = torch.randn(proj_dim, QUEUE_SIZE, device=DEVICE)
queue = F.normalize(queue, dim=0)
queue_ptr = torch.zeros(1, dtype=torch.long, device=DEVICE)

@torch.no_grad()
def _dequeue_and_enqueue(keys):
    """Write a batch of keys into the MoCo ring-buffer queue, overwriting the oldest entries.

    Args:
        keys: [B, C] l2-normalised key embeddings (C must equal ``queue.shape[0]``).

    Mutates the module-level ``queue`` ([C, K]) and ``queue_ptr`` in place. Writes
    wrap around the end of the queue, so K need not be a multiple of B. If B > K,
    only the last K keys are kept.
    """
    queue_len = queue.shape[1]
    keys = keys.detach()
    bsz = keys.shape[0]
    if bsz == 0:
        return
    if bsz > queue_len:
        keys = keys[-queue_len:]
        bsz = queue_len
    ptr = int(queue_ptr.item())
    # Column indices modulo K handle the wrap-around that a plain slice cannot.
    cols = (torch.arange(bsz, device=queue.device) + ptr) % queue_len
    queue[:, cols] = keys.T.to(device=queue.device, dtype=queue.dtype)
    queue_ptr[0] = (ptr + bsz) % queue_len

# SGD over the query side only; the key side is updated by momentum inside the loop.
opt_ssl = torch.optim.SGD(
    list(encoder_q.parameters()) + list(projector_q.parameters()),
    lr=SSL_LR, momentum=SSL_MOM, weight_decay=SSL_WD
)
scaler_ssl = torch.cuda.amp.GradScaler(enabled=(DEVICE=='cuda'))

print("Starting MoCo pretraining...")
encoder_q.train(); projector_q.train()
encoder_k.eval();  projector_k.eval()
for epoch in range(1, MOCO_EPOCHS+1):
    running = 0.0
    seen = 0
    pbar = tqdm(ssl_loader, desc=f"MoCo Epoch {epoch}/{MOCO_EPOCHS}", **TQDM_KW)
    for x1, x2 in pbar:
        x1, x2 = x1.to(DEVICE), x2.to(DEVICE)
        with torch.cuda.amp.autocast(enabled=(DEVICE=='cuda')):
            # Query path
            q_feat = encoder_q(x1)
            if isinstance(q_feat, (tuple, list)): q_feat = q_feat[0]
            q = _normalize(projector_q(q_feat.flatten(1)))

            with torch.no_grad():
                # Update key encoder via momentum BEFORE computing keys
                _update_momentum(encoder_q, encoder_k, m=MOMENTUM)
                _update_momentum(projector_q, projector_k, m=MOMENTUM)
                k_feat = encoder_k(x2)
                if isinstance(k_feat, (tuple, list)): k_feat = k_feat[0]
                k = _normalize(projector_k(k_feat.flatten(1)))

            # MoCo InfoNCE logits: [B, 1 + K]
            l_pos = (q * k).sum(dim=1, keepdim=True)         # [B,1]
            l_neg = torch.matmul(q, queue)                   # [B,K]
            logits = torch.cat([l_pos, l_neg], dim=1) / TEMP
            labels = torch.zeros(logits.size(0), dtype=torch.long, device=DEVICE)  # positives at index 0

            loss = F.cross_entropy(logits, labels)

        scaler_ssl.scale(loss).backward()
        scaler_ssl.step(opt_ssl); scaler_ssl.update()
        opt_ssl.zero_grad(set_to_none=True)

        # Enqueue only after backward, so the queue used in l_neg is not modified before its gradient is computed.
        with torch.no_grad():
            _dequeue_and_enqueue(k)

        running += loss.item() * x1.size(0)
        seen += x1.size(0)

    # ssl_loader uses drop_last=True, so average over the samples actually processed
    # rather than len(ssl_ds) (which would bias the reported loss low).
    epoch_loss = running / max(1, seen)
    print(f"MoCo epoch {epoch}: loss={epoch_loss:.4f}")

# Save the query encoder weights for fine-tuning
moco_ckpt_path = WORK_DIR / "moco_encoder.pt"
torch.save({"encoder": encoder_q.state_dict()}, moco_ckpt_path)
print("Saved MoCo encoder to:", moco_ckpt_path)


In [None]:
# === PART 4B: SUPERVISED MODEL + OPTIMIZER + SCHEDULER + EMA ==========
# What this does:
# - Builds EfficientNet-B0 classifier (38 classes).
# - Loads MoCo-pretrained query encoder weights where names match.
# - FREEZER OFF: all layers trainable.
# - AdamW + cosine LR with warmup; label smoothing; EMA tracker.
# ======================================================================
class EMA:
    """Exponential moving average of a model's trainable parameters.

    ``shadow`` maps parameter name -> averaged tensor. Only parameters with
    ``requires_grad=True`` are tracked; buffers (e.g. BN statistics) are not.
    """
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.detach().clone()

    def update(self, model):
        """Blend current weights in: ``shadow = (1 - decay) * param + decay * shadow``."""
        d = self.decay
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            self.shadow[name] = (1-d)*param.detach() + d*self.shadow[name]

    def apply_to(self, model):
        """Copy the averaged values into ``model``'s matching trainable parameters, in place."""
        with torch.no_grad():
            for name, param in model.named_parameters():
                if param.requires_grad:
                    param.data.copy_(self.shadow[name].data)

model = timm.create_model('tf_efficientnet_b0_ns', pretrained=False, num_classes=NUM_CLASSES).to(DEVICE)

# Load MoCo (query) encoder weights.
# Only keys that exist in the classifier with identical shapes are copied; every other
# classifier weight (presumably the new NUM_CLASSES-way head) keeps its fresh init.
ckpt = torch.load(moco_ckpt_path, map_location=DEVICE)
enc = ckpt["encoder"]
model_dict = model.state_dict()
loadable = {k:v for k,v in enc.items() if k in model_dict and v.shape == model_dict[k].shape}
model_dict.update(loadable)
model.load_state_dict(model_dict)
print(f"Loaded {len(loadable)} MoCo params into classifier backbone.")

# FREEZER OFF: every parameter is trained.
for p in model.parameters(): p.requires_grad = True
print("All layers trainable.")

# Optimizer + cosine LR warmup
optimizer = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=WEIGHT_DECAY)
# The scheduler is stepped once per batch, so the schedule horizon is counted in batches.
total_steps = FINETUNE_EPOCHS * max(1, len(train_loader))
warmup_steps = int(0.1 * total_steps)  # first 10% of steps are linear warmup
def lr_fn(step):
    """LambdaLR multiplier: linear warmup over ``warmup_steps``, then cosine decay to 0 at ``total_steps``."""
    if step >= warmup_steps:
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1 + math.cos(math.pi * progress))
    return (step + 1) / max(1, warmup_steps)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_fn)

# AMP loss scaler (disabled on CPU), label-smoothed cross-entropy, and an EMA of the
# trainable weights (the EMA weights are what Part 6 validates and checkpoints).
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE=='cuda'))
criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTH)
ema = EMA(model, decay=0.999)


In [None]:
# === PART 5: MIXUP / METRICS HELPERS =================================
# What this does:
# - Provides MixUp utilities (regularization for small datasets).
# - Accuracy convenience function for quick evaluation.
# ======================================================================
def mixup_data(x, y, alpha):
    """Apply MixUp to a batch.

    Args:
        x: input batch tensor [B, ...].
        y: integer class-index targets [B].
        alpha: Beta(alpha, alpha) concentration; ``alpha <= 0`` disables mixing.

    Returns:
        ``(mixed_x, (y_a, y_b), lam, idx)`` with ``mixed_x = lam*x + (1-lam)*x[idx]``,
        ``y_a = y`` and ``y_b = y[idx]``. When mixing is disabled this is
        ``(x, (y, y), 1.0, identity permutation)``, so the result can always be
        passed to ``mixup_criterion``.
    """
    bs = x.size(0)
    if alpha <= 0:
        # Same structure as the mixing path (previously returned a float tensor, which
        # mixup_criterion could not unpack into (y_a, y_b)).
        return x, (y, y), 1.0, torch.arange(bs, device=x.device)
    lam = float(np.random.beta(alpha, alpha))
    idx = torch.randperm(bs, device=x.device)
    mixed_x = lam * x + (1 - lam) * x[idx]
    return mixed_x, (y, y[idx]), lam, idx

def mixup_criterion(criterion, pred, targets, lam):
    """Mix two losses: ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.

    ``targets`` is the ``(y_a, y_b)`` pair produced by ``mixup_data``.
    """
    first, second = targets
    loss_first = criterion(pred, first)
    loss_second = criterion(pred, second)
    return lam * loss_first + (1 - lam) * loss_second

def accuracy(output, target):
    """Fraction of rows whose argmax over dim 1 equals ``target``, as a Python float."""
    hits = output.argmax(1).eq(target)
    return hits.float().mean().item()


In [None]:
# === PART 6: SUPERVISED TRAINING + VALIDATION + ROBUST CKPTS ==========
# What this does:
# - Trains for FINETUNE_EPOCHS with AMP, MixUp, EMA updates (FREEZER OFF).
# - Validates each epoch on the 10% validation split.
# - Saves 'last_model.pt' EVERY epoch; saves 'best_model.pt' on val improv.
# - Plots loss curve + val acc curve.
# ======================================================================
best_val = -1.0
best_path = WORK_DIR / "best_model.pt"
last_path = WORK_DIR / "last_model.pt"
history = {"train_loss":[], "val_loss":[], "val_acc":[]}
no_improve = 0
global_step = 0
CLIP_NORM = 1.0

# One reusable evaluation model. Each epoch it is fully overwritten with the live model's
# state (weights + buffers) and then with the EMA parameters, so rebuilding it every epoch is unnecessary.
model_ema_eval = timm.create_model('tf_efficientnet_b0_ns', pretrained=False, num_classes=NUM_CLASSES).to(DEVICE)

for epoch in range(1, FINETUNE_EPOCHS+1):
    model.train()
    running = 0.0
    seen = 0
    pbar = tqdm(train_loader, desc=f"FT Epoch {epoch}/{FINETUNE_EPOCHS}", **TQDM_KW)
    for xb, yb, _ in pbar:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        with torch.cuda.amp.autocast(enabled=(DEVICE=='cuda')):
            if MIXUP_ALPHA>0 or CUTMIX_ALPHA>0:
                # NOTE: only MixUp is implemented; CUTMIX_ALPHA is used as a MixUp alpha fallback.
                xb, targets, lam, _ = mixup_data(xb, yb, MIXUP_ALPHA if MIXUP_ALPHA>0 else CUTMIX_ALPHA)
                out = model(xb)
                loss = mixup_criterion(criterion, out, targets, lam)
            else:
                out = model(xb)
                loss = criterion(out, yb)
        scaler.scale(loss).backward()
        # Unscale before clipping so CLIP_NORM applies to the true gradient magnitudes.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
        scaler.step(optimizer); scaler.update()
        optimizer.zero_grad(set_to_none=True)
        ema.update(model)
        running += loss.item() * xb.size(0)
        seen += xb.size(0)
        scheduler.step(); global_step += 1

    # train_loader uses drop_last=True: average over samples actually processed, not len(train_ds).
    train_loss = running / max(1, seen)
    history["train_loss"].append(train_loss)

    # ---- Validation with EMA weights (buffers such as BN statistics come from the live model)
    model_ema_eval.load_state_dict(model.state_dict())
    ema.apply_to(model_ema_eval)
    model_ema_eval.eval()

    val_loss, val_acc = 0.0, 0.0
    with torch.no_grad():
        for xb, yb, _ in val_loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            out = model_ema_eval(xb)
            loss = criterion(out, yb)  # includes label smoothing, like the train loss
            val_loss += loss.item() * xb.size(0)
            val_acc  += (out.argmax(1) == yb).float().sum().item()

    val_loss /= len(val_ds)
    val_acc  /= len(val_ds)
    history["val_loss"].append(val_loss)
    history["val_acc"].append(val_acc)

    # Save LAST (EMA) every epoch
    torch.save({"model": model_ema_eval.state_dict(),
                "classes": classes,
                "cfg": dict(NUM_CLASSES=NUM_CLASSES, IMG_SIZE=IMG_SIZE),
                "epoch": epoch,
                "val_acc": val_acc},
               last_path)
    print(f"Epoch {epoch}: train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | val_acc={val_acc:.4f}")
    print(f"  -> saved LAST to {last_path}")

    # Save BEST if improved
    if val_acc > best_val:
        best_val = val_acc
        no_improve = 0
        shutil.copy2(last_path, best_path)
        print(f"  -> updated BEST to {best_path} (val_acc={val_acc:.4f})")
    else:
        no_improve += 1
        if no_improve >= PATIENCE:
            print("Early stopping.")
            break

# Ensure BEST exists even if no improvement was recorded
if not best_path.exists() and last_path.exists():
    shutil.copy2(last_path, best_path)
    print(f"No best found; copied LAST -> BEST:\n  {last_path} -> {best_path}")

# Plot curves
plt.figure(figsize=(10,4))
plt.subplot(1,2,1); plt.plot(history["train_loss"], label="train"); plt.plot(history["val_loss"], label="val"); plt.title("Loss"); plt.legend()
plt.subplot(1,2,2); plt.plot(history["val_acc"], label="val acc"); plt.title("Val Accuracy"); plt.legend()
plt.tight_layout(); plt.show()


In [None]:
# === PART 7: MICRO OVERFIT CHECK (OPTIONAL) ===========================
# What this does:
# - Trains a fresh model on 8 samples for ~60 steps.
# - If it can memorize them (high acc), your training pipeline is sound.
# ======================================================================
def tiny_overfit_steps(steps=60):
    """Sanity check: fit a fresh pretrained model to 8 training samples for ``steps`` steps.

    Prints loss and accuracy every 10 steps; accuracy is measured on the logits
    computed before that step's optimizer update.
    """
    net = timm.create_model('tf_efficientnet_b0_ns', pretrained=True, num_classes=NUM_CLASSES).to(DEVICE)
    for param in net.parameters():
        param.requires_grad = True
    opt = torch.optim.AdamW(net.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    net.train()
    images, targets, _ = next(iter(train_loader))
    images, targets = images[:8].to(DEVICE), targets[:8].to(DEVICE)
    for step in range(1, steps + 1):
        logits = net(images)
        loss = loss_fn(logits, targets)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if step % 10 == 0:
            acc = (logits.argmax(1) == targets).float().mean().item()
            print(f"step {step}/{steps} loss {loss.item():.3f} acc {acc:.3f}")

# tiny_overfit_steps()  # uncomment to run once


In [None]:
# === PART 8: TEST EVAL + REPORTS + ROC/AUC + CONFUSION MATRIX =========
# What this does:
# - Loads BEST checkpoint; if missing, falls back to LAST.
# - Restores `classes` from ckpt if needed.
# - Runs TTA inference, prints classification report, plots CM & ROC.
# ======================================================================
# Prefer the best-validation checkpoint; fall back to the last-epoch one.
best_path = WORK_DIR / "best_model.pt"
last_path = WORK_DIR / "last_model.pt"
ckpt_path = best_path if best_path.exists() else last_path
assert ckpt_path.exists(), f"Checkpoint not found.\nLooked for:\n  {best_path}\n  {last_path}\nRun Part 6 first."

print(f"Loading checkpoint: {ckpt_path}")
ckpt = torch.load(ckpt_path, map_location=DEVICE)
# In a session where earlier cells were skipped, recover class names from the checkpoint.
if "classes" not in globals():
    classes = ckpt.get("classes", None)
    assert classes is not None, "Classes not found in checkpoint and not defined in session."
    NUM_CLASSES = len(classes)

# Part 6 saves the EMA evaluation model's state here; strict=True requires an exact key match.
model_best = timm.create_model('tf_efficientnet_b0_ns', pretrained=False, num_classes=NUM_CLASSES).to(DEVICE)
model_best.load_state_dict(ckpt["model"], strict=True)
model_best.eval()

def predict_with_tta(model, loader, tta=TTA_TIMES):
    """Run test-time-augmented inference over ``loader``.

    Pass 0 uses each batch exactly as loaded. Each later pass de-normalises every
    image back to PIL, applies a random horizontal flip plus a mild random-resized
    crop, and re-applies ``test_tfms``. Logits are averaged over the ``tta`` passes
    before the softmax.

    Returns:
        (probs, labels, paths): numpy probabilities [N, num_classes], numpy labels [N],
        and the list of image paths in loader order.
    """
    tta_aug = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomResizedCrop(IMG_SIZE, scale=(0.9,1.0)),
    ])
    to_pil = transforms.ToPILImage()
    norm_std = torch.tensor([0.229,0.224,0.225]).view(3,1,1)
    norm_mean = torch.tensor([0.485,0.456,0.406]).view(3,1,1)

    def _reaugment(batch):
        # Undo the ImageNet normalisation, re-augment in PIL space, then re-normalise.
        views = []
        for img_t in batch:
            pil = to_pil((img_t.cpu()*norm_std + norm_mean).clamp(0,1))
            views.append(test_tfms(tta_aug(pil)))
        return torch.stack(views).to(DEVICE)

    prob_chunks, label_chunks, path_list = [], [], []
    for xb, yb, ps in tqdm(loader, desc="Testing", **TQDM_KW):
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        with torch.no_grad():
            logits_sum = 0
            for t in range(tta):
                view = xb if t == 0 else _reaugment(xb)
                logits_sum += model(view)
            probs = (logits_sum/tta).softmax(1)
        prob_chunks.append(probs.cpu())
        label_chunks.append(yb.cpu())
        path_list.extend(ps)

    return torch.cat(prob_chunks).numpy(), torch.cat(label_chunks).numpy(), path_list

test_probs, test_labels, test_paths = predict_with_tta(model_best, test_loader, tta=TTA_TIMES)
test_pred = test_probs.argmax(1)

# Pin the label set to all NUM_CLASSES ids so the report and confusion matrix stay
# aligned with `classes` even if some class never appears in the test labels or predictions.
class_ids = list(range(NUM_CLASSES))

print("\nClassification report:")
print(classification_report(test_labels, test_pred, labels=class_ids,
                            target_names=classes, digits=3, zero_division=0))

# Confusion matrix (do NOT name it 'cm' to avoid colormap shadowing)
conf_mat = confusion_matrix(test_labels, test_pred, labels=class_ids)
plt.figure(figsize=(12,10))
sns.heatmap(conf_mat, cmap="Blues", annot=False, fmt="d")
plt.title("Confusion Matrix"); plt.xlabel("Predicted"); plt.ylabel("True"); plt.show()

# ROC / AUC (one-vs-rest)
y_true_bin = label_binarize(test_labels, classes=class_ids)
fpr, tpr, roc_auc = {}, {}, {}
for i in range(NUM_CLASSES):
    fpr[i], tpr[i], _ = roc_curve(y_true_bin[:,i], test_probs[:,i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Micro/macro AUC
fpr["micro"], tpr["micro"], _ = roc_curve(y_true_bin.ravel(), test_probs.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
# A class with no positive test samples has an undefined (NaN) AUC; exclude it from the macro mean.
macro_auc = np.nanmean([roc_auc[i] for i in range(NUM_CLASSES)])
print(f"Macro AUC: {macro_auc:.4f} | Micro AUC: {roc_auc['micro']:.4f}")

# Plot micro + top-5 per-class by AUC
plt.figure(figsize=(8,6))
plt.plot(fpr["micro"], tpr["micro"], label=f"micro-average ROC (AUC = {roc_auc['micro']:.3f})")
top5 = np.argsort([-roc_auc[i] for i in range(NUM_CLASSES)])[:5]
for i in top5:
    plt.plot(fpr[i], tpr[i], lw=1, label=f"{classes[i]} (AUC={roc_auc[i]:.3f})")
plt.plot([0,1],[0,1],'--', lw=1, color='gray')
plt.xlabel("False Positive Rate"); plt.ylabel("True Positive Rate")
plt.title("ROC Curves"); plt.legend(); plt.grid(alpha=0.2); plt.show()


In [None]:
# === PART 9: GRAD-CAM VISUAL EXPLANATIONS (SAME AS BYOL PART 9) =======
# What this does:
# - Uses the built-in Grad-CAM from Part 1 (no external lib).
# - Shows overlays for top confident corrects and mistakes.
# - Includes a hotfix to delete any accidental global 'cm' variable.
# ======================================================================
# Remove any stray global named 'cm' left over from an earlier cell; the colormap module
# is imported as `mpl_cm`, so nothing here needs a global `cm`.
if 'cm' in globals():
    try: del cm
    except Exception: pass

# Grad-CAM target: EfficientNet's `conv_head` when present, otherwise the model's last module.
# NOTE(review): the last module may be the Linear classifier, which does not produce the
# 4-D [B, C, h, w] map _SimpleGradCAM averages over — confirm if the architecture changes.
target_layer = model_best.conv_head if hasattr(model_best, "conv_head") else list(model_best.modules())[-1]
cam_runner = make_cam_runner(model_best, target_layer)
print("Grad-CAM ready on layer:", type(target_layer).__name__)

def pick_examples(n=3):
    """Return indices of n most-confident corrects and n most-confident mistakes.

    Confidence is the max predicted probability per test sample (``test_probs``).
    Reads the module-level ``test_probs``, ``test_pred`` and ``test_labels``.
    """
    confidence = test_probs.max(1)
    is_right = (test_pred == test_labels)

    def _most_confident(mask):
        pool = np.where(mask)[0]
        order = np.argsort(-confidence[pool])
        return list(pool[order][:n])

    return _most_confident(is_right), _most_confident(~is_right)

# The three most-confident correct and the three most-confident wrong test predictions.
good_ids, bad_ids = pick_examples(n=3)

def show_cam(idx_list, title):
    """Render Grad-CAM overlays for a list of test indices in a single row.

    Each panel shows the CAM for the sample's predicted class and is titled with
    the predicted and true class names. Reads the module-level test_paths,
    test_pred, test_labels, classes and cam_runner.
    """
    if not idx_list:
        print(f"{title}: no examples to show")
        return
    plt.figure(figsize=(12,4))
    for j, idx in enumerate(idx_list):
        # Close the file handle promptly. Build the model input from the same resized
        # uint8 image that is displayed, instead of a float -> uint8 round trip.
        with Image.open(test_paths[idx]) as im:
            rgb = im.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
        raw = np.asarray(rgb, dtype=np.float32) / 255.0
        x_cam = test_tfms(rgb).unsqueeze(0).to(DEVICE)
        cls = int(test_pred[idx])
        grayscale = cam_runner(x_cam, cls)
        vis = _overlay_cam_on_image_rgb(raw, grayscale, alpha=0.35)
        plt.subplot(1, len(idx_list), j+1)
        plt.imshow(vis); plt.axis('off')
        plt.title(f"pred={classes[cls]}\ntrue={classes[test_labels[idx]]}")
    plt.suptitle(title); plt.tight_layout(); plt.show()

# Overlays are computed for each example's predicted class.
show_cam(good_ids, "Grad-CAM — confident corrects")
show_cam(bad_ids,  "Grad-CAM — confident mistakes")


In [None]:
# === PART 10: CLUSTERING & 2D VISUALIZATION ===========================
# What this does:
# - Extracts penultimate-layer embeddings from the trained classifier.
# - KMeans clustering (k=38) with ARI/NMI/Silhouette metrics.
# - t-SNE plots: colored by TRUE labels and by predicted clusters.
# ======================================================================
# Reload the best (else last) checkpoint and turn the classifier into an embedding extractor.
ckpt_path = (WORK_DIR/"best_model.pt") if (WORK_DIR/"best_model.pt").exists() else (WORK_DIR/"last_model.pt")
ckpt = torch.load(ckpt_path, map_location=DEVICE)

feature_model = timm.create_model('tf_efficientnet_b0_ns', pretrained=False, num_classes=NUM_CLASSES)
feature_model.load_state_dict(ckpt["model"])
# Replacing the head with Identity makes forward() return the features that fed the classifier.
feature_model.classifier = nn.Identity()  # output embeddings
feature_model = feature_model.to(DEVICE).eval()

def extract_features(loader):
    """Run ``feature_model`` over ``loader`` without gradients and collect its outputs.

    Returns:
        (feats, labels, paths): per-sample outputs stacked along dim 0 as a numpy
        array, numpy labels, and the list of image paths in loader order.
    """
    feat_chunks, label_chunks, path_list = [], [], []
    with torch.no_grad():
        for images, targets, batch_paths in tqdm(loader, desc="Extract feats", **TQDM_KW):
            embeddings = feature_model(images.to(DEVICE))
            feat_chunks.append(embeddings.cpu())
            label_chunks.append(targets)
            path_list.extend(batch_paths)
    return torch.cat(feat_chunks).numpy(), torch.cat(label_chunks).numpy(), path_list

feats_test, labels_test, _ = extract_features(test_loader)

# One cluster per class so the clustering can be compared against the true labels.
kmeans = KMeans(n_clusters=NUM_CLASSES, n_init=20, random_state=SEED)
clusters = kmeans.fit_predict(feats_test)

ari = adjusted_rand_score(labels_test, clusters)
nmi = normalized_mutual_info_score(labels_test, clusters)
# Seed the silhouette subsample so the reported score is reproducible like the rest of the run.
sil = silhouette_score(feats_test, clusters, sample_size=min(1000, len(feats_test)), random_state=SEED)
print(f"Clustering -> ARI: {ari:.3f} | NMI: {nmi:.3f} | Silhouette: {sil:.3f}")

# t-SNE requires perplexity < n_samples; the extra bound only matters for very small test sets.
n_feats = len(feats_test)
tsne = TSNE(n_components=2, init='pca', random_state=SEED,
            perplexity=min(30, max(5, n_feats//5), max(1, n_feats - 1)))
Z = tsne.fit_transform(feats_test)

# NOTE: 'tab20' has only 20 colours, so with NUM_CLASSES=38 some labels/clusters share a colour.
plt.figure(figsize=(9,7))
plt.scatter(Z[:,0], Z[:,1], c=labels_test, s=12, cmap='tab20')
plt.title("t-SNE (TRUE labels)"); plt.axis('off'); plt.show()

plt.figure(figsize=(9,7))
plt.scatter(Z[:,0], Z[:,1], c=clusters, s=12, cmap='tab20')
plt.title("t-SNE (KMeans clusters)"); plt.axis('off'); plt.show()


In [None]:
# === PART 11: SAVE PREDICTIONS + ERROR GALLERY ========================
# What this does:
# - Writes per-image results (path/true/pred/confidence) from TEST to CSV.
# - Saves a montage of the most-confident mistakes for quick QA.
# ======================================================================
# One row per test image (loader order): path, true/predicted class names, max probability.
df = pd.DataFrame({
    "path": test_paths,
    "true": [classes[i] for i in test_labels],
    "pred": [classes[i] for i in test_pred],
    "conf": test_probs.max(1)
})
csv_path = WORK_DIR/"test_predictions.csv"
df.to_csv(csv_path, index=False)
print("Saved:", csv_path)

def save_error_montage(df, k=25, out=WORK_DIR/"top_errors.jpg", nrow=5):
    """Save an image grid of the ``k`` most-confident misclassifications.

    Args:
        df: predictions table with ``path``, ``true``, ``pred`` and ``conf``
            columns; rows where ``true != pred`` are treated as errors.
        k: maximum number of errors to show, highest ``conf`` first.
        out: destination file for the montage image.
        nrow: number of tiles per grid row.

    Tiles are laid out in descending-confidence order. The matching
    path/true/pred/conf rows are printed in the same order so the
    (unlabelled) montage can be read back.
    """
    errors = df[df["true"] != df["pred"]]
    df_err = errors.sort_values("conf", ascending=False).head(k)
    if df_err.empty:
        print("No errors to show.")
        return

    to_tensor = transforms.ToTensor()
    tiles = []
    for p in df_err["path"]:
        # PIL opens lazily and keeps the file handle; the context manager
        # closes it. convert() loads the pixels into a detached copy first.
        with Image.open(p) as im:
            tile = im.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
        tiles.append(to_tensor(tile))

    grid = make_grid(torch.stack(tiles), nrow=nrow)
    save_image(grid, out)
    print("Saved montage:", out)
    # Legend: row i here corresponds to tile i (row-major) in the montage.
    print(df_err[["path", "true", "pred", "conf"]].to_string(index=False))

# Render the most-confident TEST mistakes from the predictions table above
# (function defaults for count, output path and layout).
save_error_montage(df=df)


In [None]:
# === PART 12: SINGLE-IMAGE INFERENCE + GRAD-CAM (ROBUST) ==============
# What this does:
# - Runs a single image with optional TTA and prints top-k predictions.
# - If true label is given, reports correctness.
# - Saves a Grad-CAM overlay using the built-in fallback (same as BYOL).
# ======================================================================
from typing import Optional, Union

def infer_single_image(
    image_path: str,
    model=None,
    topk: int = 5,
    tta: int = TTA_TIMES,
    true_label: Optional[Union[int, str]] = None,
    save_cam_path: Optional[str] = None,
):
    """Classify one image (optionally with TTA) and save a Grad-CAM overlay.

    Args:
        image_path: path to the image file.
        model: classifier to use; defaults to the global ``model_best``.
            It is switched to eval mode.
        topk: number of top classes to report (clamped to 1..NUM_CLASSES).
        tta: number of views whose softmax outputs are averaged. View 0 is the
            un-augmented image; the others add a random horizontal flip and a
            mild random resized crop, so ``tta > 1`` is not deterministic.
            Values below 1 are treated as 1.
        true_label: optional ground truth, either a class name or an index
            into ``classes``; when given, top-1 correctness is reported.
        save_cam_path: where to write the Grad-CAM overlay; defaults to
            ``WORK_DIR / "single_image_gradcam.jpg"``.

    Returns:
        dict with ``image``, ``topk`` (list of ``(class_name, prob)``),
        ``pred_idx``, ``pred_class``, ``pred_conf``, ``is_correct`` (a Python
        bool, or None when no true label was given) and ``cam_path`` (None
        when the Grad-CAM overlay could not be produced/saved).

    Raises:
        AssertionError: missing image file, unknown class name, or an integer
            ``true_label`` outside ``[0, len(classes))``.

    Grad-CAM is best-effort: a failure is reported in the printed summary
    instead of raising.
    """
    assert os.path.exists(image_path), f"Image not found: {image_path}"
    if model is None:
        model = model_best
    model.eval()

    # Resolve the ground-truth label before doing any model work.
    true_idx = None
    if true_label is not None:
        if isinstance(true_label, str):
            assert true_label in classes, f"Unknown class name: {true_label}"
            true_idx = classes.index(true_label)
        else:
            true_idx = int(true_label)
            # Negative ints would otherwise silently alias classes from the end.
            assert 0 <= true_idx < len(classes), (
                f"true_label index out of range: {true_idx}"
            )

    # convert() returns a loaded copy, so the file handle can be closed here.
    with Image.open(image_path) as im:
        base_pil = im.convert("RGB")

    aug = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomResizedCrop(IMG_SIZE, scale=(0.9,1.0)),
    ])

    xs = []
    for t in range(max(1, tta)):
        pil = base_pil if t == 0 else aug(base_pil)
        xs.append(test_tfms(pil))
    xb = torch.stack(xs, dim=0).to(DEVICE)

    with torch.no_grad():
        logits = model(xb)
        probs = logits.softmax(1).mean(0)  # average class probabilities over views

    k = max(1, min(topk, NUM_CLASSES))  # top-1 is always needed below
    confs, idxs = torch.topk(probs, k=k)
    confs = confs.cpu().numpy()
    idxs = idxs.cpu().numpy()
    top1 = int(idxs[0])
    preds = [(classes[i], float(c)) for i, c in zip(idxs, confs)]

    # Plain bool rather than numpy.bool_ so the result compares/serialises cleanly.
    is_correct = None if true_idx is None else bool(top1 == true_idx)

    cam_path = None  # only set once the overlay has actually been written
    try:
        if hasattr(model, "conv_head"):
            tgt_layer = model.conv_head
        else:
            # Heuristic: the last registered module is often the classifier,
            # which has no spatial map; prefer the last Conv2d instead.
            convs = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
            tgt_layer = convs[-1] if convs else list(model.modules())[-1]
        cam_runner = make_cam_runner(model, tgt_layer)
        x_cam = test_tfms(base_pil).unsqueeze(0).to(DEVICE)
        grayscale = cam_runner(x_cam, top1)  # explain the top-1 class
        raw = np.array(base_pil.resize((IMG_SIZE, IMG_SIZE))).astype(np.float32)/255.0
        cam_vis = np.asarray(_overlay_cam_on_image_rgb(raw, grayscale, 0.35))
        if cam_vis.dtype != np.uint8:
            # NOTE(review): the overlay helper's output dtype is not visible
            # here; PIL needs uint8, so rescale a 0..1 float result.
            scale = 255.0 if cam_vis.max() <= 1.0 else 1.0
            cam_vis = np.clip(cam_vis * scale, 0, 255).astype(np.uint8)
        target_path = (
            save_cam_path if save_cam_path is not None
            else str(WORK_DIR / "single_image_gradcam.jpg")
        )
        Image.fromarray(cam_vis).save(target_path)
        cam_path = target_path
        cam_msg = f"Grad-CAM saved to: {cam_path}"
    except Exception as e:
        # Best-effort: keep the prediction even if the explanation fails.
        cam_msg = f"Grad-CAM failed: {e}"

    print(f"\nSingle-image prediction for: {image_path}")
    for rank, (name, conf) in enumerate(preds, 1):
        print(f"  {rank:>2}. {name:>20s}  conf={conf:.4f}")
    if is_correct is not None:
        print(f"Correct? {is_correct} (true label = {classes[true_idx]})")
    print(cam_msg)

    return {
        "image": image_path,
        "topk": preds,
        "pred_idx": top1,
        "pred_class": classes[top1],
        "pred_conf": float(confs[0]),
        "is_correct": is_correct,
        "cam_path": cam_path,
    }

# Example:
# result = infer_single_image(
#     image_path="/kaggle/input/riceds-original/Original/<class>/<image>.jpg",
#     model=model_best,
#     topk=5,
#     tta=4,
#     true_label="<class>",
# )
# print(result)
