
# PyTorch Training
Train frozen-backbone torch models on the same preprocessed PlantVillage dataset (no augmentation).


In [4]:
import time, json
from pathlib import Path

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, roc_curve, confusion_matrix
from tqdm import tqdm   # NEW

# ======================================================
# CONFIG
# ======================================================
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

BATCH_SIZE = 32
EPOCHS = 5
PATIENCE = 2

ROOT = Path(".")
NPY_DIR = ROOT / "preprocessed_numpy"
RESULTS_DIR = ROOT / "results"
MODELS_DIR = ROOT / "models"
RESULTS_DIR.mkdir(exist_ok=True)
MODELS_DIR.mkdir(exist_ok=True)

# ======================================================
# DEVICE 
# ======================================================
if torch.cuda.is_available():
    device = torch.device("cuda")
    gpu_name = torch.cuda.get_device_name(0)
    gpu_total = torch.cuda.get_device_properties(0).total_memory / (1024**2)
    gpu_used = torch.cuda.memory_allocated() / (1024**2)
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    gpu_name = "Apple MPS"
    gpu_total = 0
    gpu_used = 0
else:
    device = torch.device("cpu")
    gpu_name = "CPU"
    gpu_total = 0
    gpu_used = 0

print("Using device:", device)
print("GPU/Accelerator:", gpu_name)

# ======================================================
# DATASET
# ======================================================
class NumpyDataset(Dataset):
    def __init__(self, split):
        self.images = np.load(NPY_DIR / f"{split}_images.npy")
        self.labels = np.load(NPY_DIR / f"{split}_labels.npy")
        self.transform = T.ToTensor()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.transform(self.images[idx]), torch.tensor(self.labels[idx], dtype=torch.float32)


def make_loader(split):
    return DataLoader(
        NumpyDataset(split),
        batch_size=BATCH_SIZE,
        shuffle=(split == "train"),
        num_workers=0,
        pin_memory=False,
    )

train_loader = make_loader("train")
val_loader = make_loader("val")
test_loader = make_loader("test")


# ======================================================
# MODEL BUILDER
# ======================================================
def build_model(name, freeze_backbone=True):
    lname = name.lower()

    if lname == "resnet50":
        model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        model.fc = nn.Linear(model.fc.in_features, 1)

    elif lname == "alexnet":
        model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
        model.classifier[6] = nn.Linear(4096, 1)

    elif lname == "vgg16":
        model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        model.classifier[6] = nn.Linear(4096, 1)

    elif lname == "densenet121":
        model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        model.classifier = nn.Linear(model.classifier.in_features, 1)

    elif lname == "googlenet":
        model = models.googlenet(
            weights=models.GoogLeNet_Weights.IMAGENET1K_V1,
        )
        model.fc = nn.Linear(model.fc.in_features, 1)
        # Set aux_logits to False after initialization
        model.aux_logits = False

    elif lname == "shufflenet_v2":
        model = models.shufflenet_v2_x1_0(weights=models.ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1)
        model.fc = nn.Linear(model.fc.in_features, 1)

    elif lname == "mobilenet_v2":
        model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, 1)
    elif lname == "squeezenet1_0":
        model = models.squeezenet1_0(weights=models.SqueezeNet1_0_Weights.IMAGENET1K_V1)
        # Replace the final conv layer (classifier)
        model.classifier[1] = nn.Conv2d(512, 1, kernel_size=1)
        model.num_classes = 1

    elif lname == "mnasnet1_0":
        model = models.mnasnet1_0(weights=models.MNASNet1_0_Weights.IMAGENET1K_V1)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, 1)


    else:
        import timm
        model = timm.create_model(lname, pretrained=True, num_classes=1)

    if freeze_backbone:
        for n, p in model.named_parameters():
            if "fc" not in n and "classifier" not in n and "head" not in n:
                p.requires_grad = False

    return model.to(device)


# ======================================================
# METRICS
# ======================================================
def eval_metrics(model, loader, criterion=None, return_raw=False):
    model.eval()
    y_true, y_prob = [], []
    total_loss, n = 0.0, 0

    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            logits = model(X).squeeze(1)
            probs = torch.sigmoid(logits)

            if criterion:
                total_loss += criterion(logits, y).item() * X.size(0)

            y_true.extend(y.cpu().numpy())
            y_prob.extend(probs.cpu().numpy())
            n += X.size(0)

    y_true = np.array(y_true)
    y_prob = np.array(y_prob)
    y_pred = (y_prob > 0.5).astype(int)

    try: auc = roc_auc_score(y_true, y_prob)
    except: auc = float("nan")

    fpr, tpr, thr = roc_curve(y_true, y_prob)

    if return_raw:
        return y_true, y_prob, fpr, tpr, thr

    return {
        "loss": total_loss / n if criterion else None,
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0),
        "auc": auc,
        "roc_fpr": fpr,
        "roc_tpr": tpr,
        "roc_thresh": thr,
    }


# ======================================================
# TRAINING LOOP WITH tqdm
# ======================================================
def train_models(models_to_train):
    # Load existing history if JSON file exists
    json_path = RESULTS_DIR / "all_models_training_history.json"
    if json_path.exists():
        with open(json_path, 'r') as f:
            all_models_history = json.load(f)
    else:
        all_models_history = {}

    for name in models_to_train:
        print("\n" + "="*80)
        print(f"Training: {name}")
        print("="*80)

        model = build_model(name)
        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-4)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=1)

        best_loss = float("inf")
        best_epoch = 0
        best_path = MODELS_DIR / f"{name}_best.pth"
        no_improve = 0
        start = time.time()
        
        # Track per-epoch metrics
        epoch_history = []

        for epoch in range(1, EPOCHS + 1):
            

            # tqdm progress bar for batches
            batch_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS} ({name})", leave=False)

            model.train()
            total_loss, n = 0.0, 0

            for X, y in batch_bar:
                X, y = X.to(device), y.to(device)
                optimizer.zero_grad()
                logits = model(X).squeeze(1)
                loss = criterion(logits, y)
                loss.backward()
                optimizer.step()

                total_loss += loss.item() * X.size(0)
                n += X.size(0)
                batch_bar.set_postfix({"batch_loss": loss.item()})

            train_loss = total_loss / n
            
            # Calculate train accuracy
            train_m = eval_metrics(model, train_loader, criterion)
            train_acc = train_m["accuracy"]
            
            # Calculate validation metrics
            val_m = eval_metrics(model, val_loader, criterion)
            val_loss = val_m["loss"]
            val_acc = val_m["accuracy"]
            
            # Calculate confusion matrix for validation
            val_y_true, val_y_prob, _, _, _ = eval_metrics(model, val_loader, criterion, return_raw=True)
            val_y_pred = (val_y_prob > 0.5).astype(int)
            val_cm = confusion_matrix(val_y_true, val_y_pred)
            val_tn, val_fp, val_fn, val_tp = val_cm.ravel()
            
            # Save epoch metrics
            epoch_history.append({
                "epoch": epoch,
                "train_loss": train_loss,
                "train_accuracy": train_acc,
                "val_loss": val_loss,
                "val_accuracy": val_acc,
                "val_tn": int(val_tn),
                "val_fp": int(val_fp),
                "val_fn": int(val_fn),
                "val_tp": int(val_tp)
            })

            print(f"[{name}] Epoch {epoch} | Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

            if val_loss < best_loss:
                best_loss = val_loss
                best_epoch = epoch
                torch.save(model.state_dict(), best_path)
                no_improve = 0
            else:
                no_improve += 1

            scheduler.step(val_loss)
            if no_improve >= PATIENCE:
                print("Early stopping triggered.")
                break

        # Final evaluation
        model.load_state_dict(torch.load(best_path, map_location=device))
        val_m = eval_metrics(model, val_loader, criterion)
        test_m = eval_metrics(model, test_loader, criterion)
        
        # Get confusion matrices for final metrics
        val_y_true, val_y_prob, _, _, _ = eval_metrics(model, val_loader, criterion, return_raw=True)
        val_y_pred = (val_y_prob > 0.5).astype(int)
        val_cm = confusion_matrix(val_y_true, val_y_pred)
        val_tn, val_fp, val_fn, val_tp = val_cm.ravel()
        
        test_y_true, test_y_prob, fpr, tpr, thr = eval_metrics(model, test_loader, criterion, return_raw=True)
        test_y_pred = (test_y_prob > 0.5).astype(int)
        test_cm = confusion_matrix(test_y_true, test_y_pred)
        test_tn, test_fp, test_fn, test_tp = test_cm.ravel()

        elapsed = time.time() - start

        # Store this model's training history
        all_models_history[name] = {
            "model": name,
            "best_epoch": best_epoch,
            "total_epochs_trained": len(epoch_history),
            "epochs": epoch_history,
            "final_metrics": {
                "val_loss": val_m["loss"],
                "val_accuracy": val_m["accuracy"],
                "val_precision": val_m["precision"],
                "val_recall": val_m["recall"],
                "val_f1": val_m["f1"],
                "val_auc": val_m["auc"],
                "val_confusion_matrix": {"tn": int(val_tn), "fp": int(val_fp), "fn": int(val_fn), "tp": int(val_tp)},
                "test_loss": test_m["loss"],
                "test_accuracy": test_m["accuracy"],
                "test_precision": test_m["precision"],
                "test_recall": test_m["recall"],
                "test_f1": test_m["f1"],
                "test_auc": test_m["auc"],
                "test_confusion_matrix": {"tn": int(test_tn), "fp": int(test_fp), "fn": int(test_fn), "tp": int(test_tp)},
            },
            "training_info": {
                "train_time_sec": elapsed
            }
        }
    
    # Save all models' training history to one JSON file
    with open(json_path, 'w') as f:
        json.dump(all_models_history, f, indent=2)
    print(f"\n{'='*80}")
    print(f"Saved complete training history for all models to: {json_path}")
    print(f"{'='*80}")
# ======================================================
# RUN
# ======================================================
models_to_train = [

    # ============================
    # Lightweight Architectures
    # ============================

    # "mobilenet_v2",           # Lightweight CNN | Pratyaksh done
    # "shufflenet_v2",          # Lightweight CNN | Pratyaksh done
    # "squeezenet1_0",          # Lightweight CNN (Fire modules) | Pratyaksh done
    # "mnasnet1_0",             # Lightweight CNN (NAS-designed) | Pratyaksh done
    # "tf_efficientnet_lite4",  # Lightweight CNN (EfficientNet-Lite) | Pratyaksh done


    # ============================
    # Classic CNNs
    # ============================
    "googlenet",              # Classic CNN (Inception V1) | Pratyaksh  done
    "alexnet",                # Classic CNN  | pratyaksh done
    "vgg16",                  # Classic CNN  | Prartyaksh done
    "resnet50",               # Classic CNN  | Pratyaksh done
    "densenet121",            # Classic CNN (Dense connections) | Pratyaksh done

    # "convnext_base",          # Deep CNN (ConvNeXt â€” CNN inspired by ViT) | Vandit

]


train_models(models_to_train)


Using device: mps
GPU/Accelerator: Apple MPS

Training: mobilenet_v2

Training: mobilenet_v2


                                                                                                

[mobilenet_v2] Epoch 1 | Train Loss: 0.2882 Acc: 0.9549 | Val Loss: 0.1786 Acc: 0.9405


                                                                                                

[mobilenet_v2] Epoch 2 | Train Loss: 0.1683 Acc: 0.9628 | Val Loss: 0.1408 Acc: 0.9508


                                                                                                

[mobilenet_v2] Epoch 3 | Train Loss: 0.1434 Acc: 0.9627 | Val Loss: 0.1387 Acc: 0.9477


                                                                                                

[mobilenet_v2] Epoch 4 | Train Loss: 0.1347 Acc: 0.9666 | Val Loss: 0.1246 Acc: 0.9517


                                                                                                

[mobilenet_v2] Epoch 5 | Train Loss: 0.1293 Acc: 0.9690 | Val Loss: 0.1167 Acc: 0.9575

Training: shufflenet_v2


                                                                                                

[shufflenet_v2] Epoch 1 | Train Loss: 0.6165 Acc: 0.9091 | Val Loss: 0.5665 Acc: 0.8700


                                                                                                

[shufflenet_v2] Epoch 2 | Train Loss: 0.5012 Acc: 0.9246 | Val Loss: 0.4661 Acc: 0.9024


                                                                                                

[shufflenet_v2] Epoch 3 | Train Loss: 0.4243 Acc: 0.9294 | Val Loss: 0.4092 Acc: 0.9052


                                                                                                

[shufflenet_v2] Epoch 4 | Train Loss: 0.3703 Acc: 0.9341 | Val Loss: 0.3587 Acc: 0.9138


                                                                                                

[shufflenet_v2] Epoch 5 | Train Loss: 0.3321 Acc: 0.9407 | Val Loss: 0.3107 Acc: 0.9289

Training: squeezenet1_0


                                                                                                

[squeezenet1_0] Epoch 1 | Train Loss: 0.6218 Acc: 0.5928 | Val Loss: 0.4272 Acc: 0.7825


                                                                                                

[squeezenet1_0] Epoch 2 | Train Loss: 0.5326 Acc: 0.6364 | Val Loss: 0.4002 Acc: 0.8022


                                                                                                

[squeezenet1_0] Epoch 3 | Train Loss: 0.5195 Acc: 0.6551 | Val Loss: 0.3846 Acc: 0.8134


                                                                                                

[squeezenet1_0] Epoch 4 | Train Loss: 0.5132 Acc: 0.6954 | Val Loss: 0.3861 Acc: 0.8352


                                                                                                

[squeezenet1_0] Epoch 5 | Train Loss: 0.5091 Acc: 0.7225 | Val Loss: 0.3859 Acc: 0.8516
Early stopping triggered.

Training: mnasnet1_0


                                                                                              

[mnasnet1_0] Epoch 1 | Train Loss: 0.3136 Acc: 0.6666 | Val Loss: 1.1109 Acc: 0.4902


                                                                                              

[mnasnet1_0] Epoch 2 | Train Loss: 0.1838 Acc: 0.7386 | Val Loss: 1.0217 Acc: 0.6011


                                                                                              

[mnasnet1_0] Epoch 3 | Train Loss: 0.1596 Acc: 0.7796 | Val Loss: 0.8668 Acc: 0.6683


                                                                                              

[mnasnet1_0] Epoch 4 | Train Loss: 0.1449 Acc: 0.8324 | Val Loss: 0.6349 Acc: 0.7431


                                                                                              

[mnasnet1_0] Epoch 5 | Train Loss: 0.1368 Acc: 0.8951 | Val Loss: 0.3796 Acc: 0.8446

Training: tf_efficientnet_lite4


                                                                                                           

[tf_efficientnet_lite4] Epoch 1 | Train Loss: 0.2612 Acc: 0.9656 | Val Loss: 0.1124 Acc: 0.9624


                                                                                                           

[tf_efficientnet_lite4] Epoch 2 | Train Loss: 0.1072 Acc: 0.9756 | Val Loss: 0.0990 Acc: 0.9685


                                                                                                           

[tf_efficientnet_lite4] Epoch 3 | Train Loss: 0.0774 Acc: 0.9725 | Val Loss: 0.0861 Acc: 0.9722


                                                                                                           

[tf_efficientnet_lite4] Epoch 4 | Train Loss: 0.0617 Acc: 0.9854 | Val Loss: 0.0793 Acc: 0.9757


                                                                                                           

[tf_efficientnet_lite4] Epoch 5 | Train Loss: 0.0554 Acc: 0.9896 | Val Loss: 0.0527 Acc: 0.9836

Saved complete training history for all models to: results/all_models_training_history.json
