
# PyTorch Training
Train ImageNet-pretrained PyTorch models with frozen backbones (only the classification head is trained) on the same preprocessed PlantVillage dataset, without data augmentation.


In [None]:
import os, time, psutil, json
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, roc_curve
from tqdm import tqdm   # NEW

# ======================================================
# CONFIG
# ======================================================
SEED = 42
# Seed the global NumPy and torch RNGs to make runs more repeatable
# (some accelerator kernels may still be nondeterministic).
np.random.seed(SEED)
torch.manual_seed(SEED)

BATCH_SIZE = 32
EPOCHS = 5
PATIENCE = 2  # epochs without val-loss improvement before early stopping

# All paths are relative to the notebook's working directory.
ROOT = Path(".")
NPY_DIR = ROOT / "preprocessed_numpy"
RESULTS_DIR = ROOT / "results"
MODELS_DIR = ROOT / "models"
for _out_dir in (RESULTS_DIR, MODELS_DIR):
    _out_dir.mkdir(exist_ok=True)

# ======================================================
# DEVICE 
# ======================================================
def _detect_accelerator():
    """Pick CUDA, then Apple MPS, then CPU.

    Returns (device, display_name, total_memory_mb, allocated_memory_mb).
    Memory figures are only queried for CUDA; other backends report 0.
    """
    if torch.cuda.is_available():
        cuda_dev = torch.device("cuda")
        cuda_name = torch.cuda.get_device_name(0)
        cuda_total_mb = torch.cuda.get_device_properties(0).total_memory / (1024**2)
        cuda_used_mb = torch.cuda.memory_allocated() / (1024**2)
        return cuda_dev, cuda_name, cuda_total_mb, cuda_used_mb
    if torch.backends.mps.is_available():
        return torch.device("mps"), "Apple MPS", 0, 0
    return torch.device("cpu"), "CPU", 0, 0


device, gpu_name, gpu_total, gpu_used = _detect_accelerator()

print("Using device:", device)
print("GPU/Accelerator:", gpu_name)

# ======================================================
# DATASET
# ======================================================
class NumpyDataset(Dataset):
    """Map-style dataset over ``{split}_images.npy`` / ``{split}_labels.npy`` in NPY_DIR.

    Each item is ``(image_tensor, label_tensor)``: the image is passed through
    ``torchvision.transforms.ToTensor`` and the label is cast to a float32
    tensor, the target dtype used with BCEWithLogitsLoss in this notebook.
    """

    def __init__(self, split):
        images_path = NPY_DIR / f"{split}_images.npy"
        labels_path = NPY_DIR / f"{split}_labels.npy"
        # np.load without mmap_mode reads both arrays fully into memory.
        self.images = np.load(images_path)
        self.labels = np.load(labels_path)
        self.transform = T.ToTensor()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.transform(self.images[idx])
        target = torch.tensor(self.labels[idx], dtype=torch.float32)
        return image, target


def make_loader(split):
    """Build a DataLoader over the ``split`` arrays; only the "train" split is shuffled."""
    dataset = NumpyDataset(split)
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=split == "train",
        # Single-process loading, no pinned host memory.
        num_workers=0,
        pin_memory=False,
    )


train_loader, val_loader, test_loader = (make_loader(s) for s in ("train", "val", "test"))


# ======================================================
# MODEL BUILDER
# ======================================================
def build_model(name, freeze_backbone=True):
    """Create an ImageNet-pretrained network with a single-logit binary head.

    Torchvision architectures are handled explicitly; any other name is passed
    to ``timm.create_model(name, pretrained=True, num_classes=1)``.

    Args:
        name: architecture name (case-insensitive).
        freeze_backbone: if True, all parameters are frozen except those of the
            classification head module recorded for the architecture below.

    Returns:
        The model, moved to the global ``device``.
    """
    lname = name.lower()

    # ``head`` is the module left trainable when freeze_backbone=True. It is
    # recorded explicitly because name-substring matching both misses real
    # heads (timm: ``last_linear``, ``classif``) and unfreezes backbone layers
    # whose names merely contain "fc"/"head" (e.g. SE blocks' ``fc1``/``fc2``).
    if lname == "resnet50":
        model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        model.fc = nn.Linear(model.fc.in_features, 1)
        head = model.fc

    elif lname == "alexnet":
        model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
        model.classifier[6] = nn.Linear(4096, 1)
        head = model.classifier  # whole classifier MLP stays trainable

    elif lname == "vgg16":
        model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        model.classifier[6] = nn.Linear(4096, 1)
        head = model.classifier  # whole classifier MLP stays trainable

    elif lname == "densenet121":
        model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        model.classifier = nn.Linear(model.classifier.in_features, 1)
        head = model.classifier

    elif lname == "googlenet":
        model = models.googlenet(weights=models.GoogLeNet_Weights.IMAGENET1K_V1)
        model.fc = nn.Linear(model.fc.in_features, 1)
        head = model.fc

    elif lname == "inception_v3":
        model = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1)
        # Pretrained Inception v3 keeps its auxiliary classifier, which makes
        # forward() return an InceptionOutputs tuple in train mode (breaking
        # ``model(X).squeeze(1)``). Drop it so the output is always a tensor.
        model.aux_logits = False
        model.AuxLogits = None
        model.fc = nn.Linear(model.fc.in_features, 1)
        head = model.fc

    elif lname == "shufflenet_v2":
        model = models.shufflenet_v2_x1_0(weights=models.ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1)
        model.fc = nn.Linear(model.fc.in_features, 1)
        head = model.fc

    elif lname == "mobilenet_v2":
        model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, 1)
        head = model.classifier

    elif lname == "squeezenet1_0":
        model = models.squeezenet1_0(weights=models.SqueezeNet1_0_Weights.IMAGENET1K_V1)
        model.classifier[1] = nn.Conv2d(model.classifier[1].in_channels, 1, kernel_size=1)
        # The stock head applies ReLU after the final conv; that would clamp
        # the logit to >= 0 (sigmoid >= 0.5), so remove it for BCE-with-logits.
        model.classifier[2] = nn.Identity()
        head = model.classifier

    elif lname == "mnasnet1_0":
        model = models.mnasnet1_0(weights=models.MNASNet1_0_Weights.IMAGENET1K_V1)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, 1)
        head = model.classifier

    elif lname == "regnet_y_800mf":
        model = models.regnet_y_800mf(weights=models.RegNet_Y_800MF_Weights.IMAGENET1K_V1)
        model.fc = nn.Linear(model.fc.in_features, 1)
        head = model.fc

    else:
        import timm
        model = timm.create_model(lname, pretrained=True, num_classes=1)
        # timm exposes the final classifier uniformly, whatever its attribute name.
        head = model.get_classifier()

    if freeze_backbone:
        model.requires_grad_(False)
        head.requires_grad_(True)

    return model.to(device)


# ======================================================
# METRICS
# ======================================================
def eval_metrics(model, loader, criterion=None, return_raw=False, *, threshold=0.5):
    """Evaluate a single-logit binary classifier over every batch of ``loader``.

    Args:
        model: network whose output is squeezed along dim 1 to one logit per
            sample (assumes a (batch, 1) output — as built by build_model).
        loader: iterable of ``(X, y)`` batches; ``y`` holds the 0/1 targets.
        criterion: optional loss called as ``criterion(logits, y)``; when
            given, the sample-weighted mean loss is reported under "loss".
        return_raw: if True, return ``(y_true, y_prob, fpr, tpr, thresholds)``
            instead of the metrics dict.
        threshold: a sample is predicted positive when its sigmoid
            probability is strictly greater than this value.

    Returns:
        A dict with loss/accuracy/precision/recall/f1/auc plus the ROC arrays
        (``roc_fpr``, ``roc_tpr``, ``roc_thresh``), or the raw tuple above.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    true_chunks, prob_chunks = [], []
    total_loss, n = 0.0, 0

    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            logits = model(X).squeeze(1)

            if criterion is not None:
                total_loss += criterion(logits, y).item() * X.size(0)

            # Collect whole batches; concatenated once after the loop.
            true_chunks.append(y.cpu().numpy())
            prob_chunks.append(torch.sigmoid(logits).cpu().numpy())
            n += X.size(0)

    if n == 0:
        raise ValueError("eval_metrics: loader produced no samples")

    y_true = np.concatenate(true_chunks)
    y_prob = np.concatenate(prob_chunks)
    y_pred = (y_prob > threshold).astype(int)

    # roc_auc_score raises ValueError when y_true contains only one class.
    try:
        auc = roc_auc_score(y_true, y_prob)
    except ValueError:
        auc = float("nan")

    fpr, tpr, thr = roc_curve(y_true, y_prob)

    if return_raw:
        return y_true, y_prob, fpr, tpr, thr

    return {
        "loss": total_loss / n if criterion is not None else None,
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0),
        "auc": auc,
        "roc_fpr": fpr,
        "roc_tpr": tpr,
        "roc_thresh": thr,
    }


# ======================================================
# TRAINING LOOP WITH tqdm
# ======================================================
def _train_one_epoch(model, loader, criterion, optimizer, desc):
    """Run one optimisation pass over ``loader``; return the sample-weighted mean training loss."""
    # NOTE(review): model.train() also lets BatchNorm layers inside the frozen
    # backbone update their running statistics; requires_grad=False does not
    # prevent that. Confirm this is intended for a "frozen backbone" study.
    model.train()
    total_loss, n = 0.0, 0

    batch_bar = tqdm(loader, desc=desc, leave=False)
    for X, y in batch_bar:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(X).squeeze(1)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss * X.size(0)
        n += X.size(0)
        batch_bar.set_postfix({"batch_loss": batch_loss})

    return total_loss / n


def _upsert_metrics_row(out_path, row):
    """Write ``row`` to the CSV at ``out_path``, replacing any earlier row for the same model, then display the table."""
    new_df = pd.DataFrame([row])

    if out_path.exists():
        old_df = pd.read_csv(out_path)
        old_df = old_df[old_df["model"] != row["model"]]
        updated_df = pd.concat([old_df, new_df], ignore_index=True)
        updated_df.to_csv(out_path, index=False)
        print(f"Updated CSV: {out_path}")
        display(updated_df)
    else:
        new_df.to_csv(out_path, index=False)
        print(f"Created CSV: {out_path}")
        display(new_df)


def train_models(models_to_train):
    """Train, checkpoint, evaluate and log each architecture in ``models_to_train``.

    For every name: build the model with build_model, train with Adam and
    BCEWithLogitsLoss, save the weights with the lowest validation loss to
    ``MODELS_DIR / f"{name}_best.pth"`` (stopping early after PATIENCE epochs
    without improvement), reload them, evaluate on val and test, and upsert
    one row of metrics into ``RESULTS_DIR / "metrics.csv"``.
    """
    out_path = RESULTS_DIR / "metrics.csv"

    for name in models_to_train:
        print("\n" + "=" * 80)
        print(f"Training: {name}")
        print("=" * 80)

        model = build_model(name)
        criterion = nn.BCEWithLogitsLoss()
        # Frozen parameters never receive gradients; give Adam only the trainable ones.
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        optimizer = optim.Adam(trainable_params, lr=1e-4)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=1)

        best_loss = float("inf")
        best_path = MODELS_DIR / f"{name}_best.pth"
        saved_checkpoint = False
        no_improve = 0
        start = time.time()

        for epoch in range(1, EPOCHS + 1):
            train_loss = _train_one_epoch(
                model, train_loader, criterion, optimizer,
                desc=f"Epoch {epoch}/{EPOCHS} ({name})",
            )
            val_m = eval_metrics(model, val_loader, criterion)
            val_loss = val_m["loss"]

            print(f"[{name}] Epoch {epoch} | Train {train_loss:.4f} | Val {val_loss:.4f} | Acc={val_m['accuracy']:.4f}")

            if val_loss < best_loss:
                best_loss = val_loss
                torch.save(model.state_dict(), best_path)
                saved_checkpoint = True
                no_improve = 0
            else:
                no_improve += 1

            scheduler.step(val_loss)
            if no_improve >= PATIENCE:
                print("Early stopping triggered.")
                break

        # Reload the best checkpoint written by THIS run. If none was written
        # (e.g. every validation loss was NaN), any file at best_path is left
        # over from an earlier run, so evaluate the current weights instead.
        if saved_checkpoint:
            model.load_state_dict(torch.load(best_path, map_location=device))
        else:
            print(f"[{name}] No checkpoint saved in this run; evaluating final-epoch weights.")

        val_m = eval_metrics(model, val_loader, criterion)
        test_m = eval_metrics(model, test_loader, criterion)
        # The metrics dict already holds the test ROC curve; no second test pass needed.
        fpr, tpr, thr = test_m["roc_fpr"], test_m["roc_tpr"], test_m["roc_thresh"]

        elapsed = time.time() - start
        mem = psutil.Process(os.getpid()).memory_info().rss / (1024**2)
        gpu_used = torch.cuda.memory_allocated() / (1024**2) if torch.cuda.is_available() else 0

        row = {
            "model": name,

            "val_loss": val_m["loss"],
            "val_accuracy": val_m["accuracy"],
            "val_precision": val_m["precision"],
            "val_recall": val_m["recall"],
            "val_f1": val_m["f1"],
            "val_auc": val_m["auc"],

            "test_loss": test_m["loss"],
            "test_accuracy": test_m["accuracy"],
            "test_precision": test_m["precision"],
            "test_recall": test_m["recall"],
            "test_f1": test_m["f1"],
            "test_auc": test_m["auc"],

            "roc_fpr": json.dumps(fpr.tolist()),
            "roc_tpr": json.dumps(tpr.tolist()),
            "roc_thresholds": json.dumps(thr.tolist()),

            "train_time_sec": elapsed,
            "memory_mb": mem,
            "gpu_name": gpu_name,
            "gpu_total_memory_mb": gpu_total,
            "gpu_used_memory_mb": gpu_used,
        }

        _upsert_metrics_row(out_path, row)


# ======================================================
# RUN
# ======================================================
# Uncomment the architectures to train in this session; owners/status are
# noted beside each entry. Names without an explicit branch in build_model
# are passed through to timm.

classic_cnns = [
    # "alexnet",                # Classic CNN  | Pratyaksh done
    "vgg16",                    # Classic CNN  | Pratyaksh
    # "resnet50",               # Classic CNN  | Pratyaksh
    # "googlenet",              # Classic CNN (Inception V1) | Pratyaksh done
    # "densenet121",            # Classic CNN (Dense connections) | Pratyaksh
]

lightweight_cnns = [
    # "mobilenet_v2",           # Lightweight CNN | Pratyaksh done
    # "shufflenet_v2",          # Lightweight CNN | Pratyaksh done
    # "squeezenet1_0",          # Lightweight CNN (Fire modules) | Pratyaksh
    # "mnasnet1_0",             # Lightweight CNN (NAS-designed) | Pratyaksh
    # "tf_efficientnet_lite4",  # Lightweight CNN (EfficientNet-Lite) | Pratyaksh done
]

deep_modern_cnns = [
    # "inception_v4",           # Deep CNN (Inception series) | Vandit
    # "inception_resnet_v2",    # Deep CNN (Inception + ResNet hybrid) | Vandit
    # "xception",               # Deep CNN (Depthwise separable conv) | Vandit
    # "seresnet50",             # Deep CNN (ResNet + Squeeze-and-Excitation) | Vandit
    # "seresnext50_32x4d",      # Deep CNN (ResNeXt + SE attention) | Vandit
    # "regnet_y_800mf",         # Deep CNN (Facebook RegNet scalable design) | Vandit
    # "convnext_base",          # Deep CNN (ConvNeXt — CNN inspired by ViT) | Vandit
]

hybrid_cnn_attention = [
    # "maxvit_base_tf_224",     # Hybrid CNN + Attention (MaxViT)
]

models_to_train = classic_cnns + lightweight_cnns + deep_modern_cnns + hybrid_cnn_attention


train_models(models_to_train)


Using device: mps
GPU/Accelerator: Apple MPS

Training: vgg16


Epoch 1/5 (vgg16):  13%|█▎        | 274/2114 [01:28<09:58,  3.07it/s, batch_loss=0.0609] 