
# PyTorch Training
Train frozen-backbone PyTorch models on the same preprocessed PlantVillage dataset used by the TensorFlow notebook (no data augmentation).


In [None]:
import time
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models

# --- Reproducibility -------------------------------------------------------
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# --- Hyper-parameters a reader might tune ----------------------------------
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 32
EPOCHS = 5

# --- Paths (resolved relative to the notebook's working directory) ---------
PROJECT_ROOT = Path(".").resolve()
PREPROCESSED_DIR = PROJECT_ROOT / "preprocessed"
RESULTS_DIR = PROJECT_ROOT / "results"
MODELS_DIR = PROJECT_ROOT / "models"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
MODELS_DIR.mkdir(parents=True, exist_ok=True)

# --- Device selection: CUDA first, then Apple MPS, then CPU ----------------
# `torch.backends.mps` is looked up defensively so the cell also runs on
# torch builds that do not ship the MPS backend.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Report the selection so the cell output documents where training ran.
print(f"Using device: {device}")



Using device: mps


## 1) Dataset & DataLoader

In [None]:
class PlantVillageCSVDataset(Dataset):
    """Image/label dataset driven by a CSV manifest.

    The manifest is read with pandas and must provide the three columns this
    class accesses: ``split`` (used to select rows), ``filename`` (joined onto
    ``img_dir``) and ``label`` (returned as a float32 scalar tensor).

    Args:
        split: Value of the ``split`` column to keep.
        img_dir: Directory that the ``filename`` entries are relative to.
        csv_path: Path of the CSV manifest.
        transform: Callable applied to each RGB PIL image; ``T.ToTensor()``
            is used when nothing (or a falsy value) is supplied.
    """

    def __init__(self, split, img_dir, csv_path, transform=None):
        manifest = pd.read_csv(csv_path)
        # Keep only the rows of the requested split and renumber them from 0
        # so positional lookups in __getitem__ line up with __len__.
        in_split = manifest["split"] == split
        self.df = manifest.loc[in_split].reset_index(drop=True)
        self.img_dir = Path(img_dir)
        self.transform = transform or T.ToTensor()

    def __len__(self):
        """Number of rows belonging to the selected split."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for the row at position ``idx``."""
        record = self.df.iloc[idx]
        target = torch.tensor(record["label"], dtype=torch.float32)

        # Force three channels regardless of how the file was stored.
        picture = Image.open(self.img_dir / record["filename"]).convert("RGB")
        return self.transform(picture), target





import sys

# Platform switch for the DataLoader settings: macOS gets in-process loading
# (0 workers), every other platform gets 4 worker processes.
# NOTE(review): presumably this works around multiprocessing DataLoader
# problems on macOS — confirm.
IS_MAC = (sys.platform == "darwin")
WORKERS = 4 if not IS_MAC else 0

def make_loader(split):
    """Build the DataLoader for one dataset split.

    Images are read from ``PREPROCESSED_DIR / "images"`` using the manifest
    ``PREPROCESSED_DIR / "labels.csv"``, resized to ``IMAGE_SIZE`` and
    converted to tensors. No augmentation is applied.

    Args:
        split: Split name matched against the manifest's ``split`` column;
            only ``"train"`` is shuffled.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches of
        ``BATCH_SIZE`` samples.
    """
    # NOTE(review): no ImageNet mean/std normalisation is applied even though
    # the backbones are ImageNet-pretrained — confirm this is intentional.
    preprocessing = T.Compose([
        T.Resize(IMAGE_SIZE),  # shared constant instead of a second hard-coded size
        T.ToTensor(),
    ])
    dataset = PlantVillageCSVDataset(
        split=split,
        img_dir=PREPROCESSED_DIR / "images",
        csv_path=PREPROCESSED_DIR / "labels.csv",
        transform=preprocessing,
    )
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=(split == "train"),
        num_workers=WORKERS,
        # Pinned host memory only helps host-to-CUDA copies; requesting it on
        # a CPU-only machine just triggers a warning.
        pin_memory=(device.type == "cuda"),
    )




# One loader per split; only the "train" loader shuffles (see make_loader).
train_loader, val_loader, test_loader = (
    make_loader(split_name) for split_name in ("train", "val", "test")
)


## 2) Model: ResNet50 (Frozen Backbone)
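Load an ImageNet-pretrained ResNet50, freeze its convolutional backbone, and replace the final classifier with a single-logit layer for binary classification (diseased vs. healthy). `build_model` also accepts other architecture names (e.g. `alexnet`, `vgg16`, `densenet121`); ResNet50 is the one selected below via `model_key`.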

In [None]:
import torchvision.models as models
def build_model(model_name, freeze_backbone=True):
    """Create an ImageNet-pretrained classifier with a single-logit output.

    Args:
        model_name: Architecture name (case-insensitive). ``resnet50``,
            ``alexnet``, ``vgg16``, ``densenet121``, ``googlenet`` and
            ``inception_v3`` are built from torchvision; any other name is
            forwarded to ``timm.create_model``.
        freeze_backbone: When True, every parameter outside the
            classification head gets ``requires_grad = False``.

    Returns:
        The model, moved to the module-level ``device``.
    """
    model_name = model_name.lower()

    # `head_modules` lists the sub-modules that must stay trainable when the
    # backbone is frozen. Tracking them explicitly replaces the former
    # substring test on parameter names, which also matched backbone layers
    # whose names merely contain "fc" (e.g. ``mlp.fc1``).

    # --- torchvision models ---
    if model_name == "resnet50":
        model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        model.fc = nn.Linear(model.fc.in_features, 1)
        head_modules = (model.fc,)

    elif model_name == "alexnet":
        model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, 1)
        # The whole classifier stack stays trainable, exactly as it did under
        # the previous name-based rule.
        head_modules = (model.classifier,)

    elif model_name == "vgg16":
        model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, 1)
        # Same as AlexNet: the full classifier stack remains trainable.
        head_modules = (model.classifier,)

    elif model_name == "densenet121":
        model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        model.classifier = nn.Linear(model.classifier.in_features, 1)
        head_modules = (model.classifier,)

    elif model_name == "googlenet":
        model = models.googlenet(weights=models.GoogLeNet_Weights.IMAGENET1K_V1)
        model.fc = nn.Linear(model.fc.in_features, 1)
        head_modules = (model.fc,)

    elif model_name == "inception_v3":
        model = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1)
        model.fc = nn.Linear(model.fc.in_features, 1)
        # Drop the auxiliary classifier: with it enabled the model returns a
        # (logits, aux_logits) tuple in train mode, which the training helper
        # cannot ``.squeeze(1)``, and its 1000-way ``AuxLogits.fc`` would be
        # left trainable by a name-based freeze.
        model.aux_logits = False
        model.AuxLogits = None
        head_modules = (model.fc,)

    # --- timm models ---
    else:
        import timm  # optional dependency, only needed for non-torchvision names
        model = timm.create_model(model_name, pretrained=True, num_classes=1)
        get_head = getattr(model, "get_classifier", None)
        head = get_head() if callable(get_head) else None
        if head is None:
            candidates = ()
        elif isinstance(head, (tuple, list)):
            # Some models may expose more than one classifier module.
            candidates = tuple(head)
        else:
            candidates = (head,)
        # Keep only real modules that actually own parameters.
        head_modules = tuple(
            module for module in candidates
            if isinstance(module, nn.Module) and any(True for _ in module.parameters())
        )

    if freeze_backbone:
        if head_modules:
            # Freeze everything, then re-enable just the head.
            for param in model.parameters():
                param.requires_grad = False
            for module in head_modules:
                for param in module.parameters():
                    param.requires_grad = True
        else:
            # Fallback when the head could not be identified: match whole
            # dotted name components rather than substrings.
            head_names = {"fc", "classifier", "head"}
            for name, param in model.named_parameters():
                if head_names.isdisjoint(name.split(".")):
                    param.requires_grad = False

    return model.to(device)


# Build model, criterion, optimizer
model_key = "resnet50"
# Use the factory defined above; `get_resnet50_binary` does not exist in this
# notebook. `build_model` already moves the model to `device`.
model = build_model(model_key, freeze_backbone=True)
# Single-logit output + BCEWithLogitsLoss: the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()
# Only optimize classifier head parameters
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-3)

# Checkpoint file for the best validation epoch of this architecture.
best_path = MODELS_DIR / f"{model_key}_best.pt"
print(f"Model: {model_key}, frozen_backbone=True, saving to: {best_path}")

## 3) Train/eval helpers

In [4]:
def train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Args:
        model: Network producing one logit per sample, shaped ``(batch, 1)``
            (the output is squeezed on dim 1 before the loss).
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer that updates the model's trainable parameters.

    Returns:
        The per-sample average loss over the samples actually processed, or
        ``0.0`` when the loader yields no batches.
    """
    model.train()
    # Feed batches to wherever the model lives; fall back to the module-level
    # `device` only for a parameter-free model.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    running_loss = 0.0
    seen = 0
    for images, labels in loader:
        images = images.to(run_device)
        labels = labels.to(run_device)
        optimizer.zero_grad()
        logits = model(images).squeeze(1)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        batch_size = images.size(0)
        # `loss` is a batch mean, so weight it by the batch size.
        running_loss += loss.item() * batch_size
        seen += batch_size
    # Normalise by samples seen, not len(loader.dataset): the two differ with
    # drop_last or a sub-sampling sampler, and this avoids a division by zero
    # on an empty loader.
    return running_loss / max(1, seen)


def eval_model(model, loader):
    """Score ``model`` on ``loader`` with binary-classification metrics.

    NOTE: a later cell in this notebook defines ``eval_model`` again with an
    optional ``criterion`` argument; once that cell runs it shadows this
    version.

    Args:
        model: Network producing one logit per sample, shaped ``(batch, 1)``.
        loader: Iterable of ``(images, labels)`` batches.

    Returns:
        Dict with ``accuracy``, ``precision``, ``recall``, ``f1`` and ``auc``
        as Python floats. Predictions use a 0.5 probability threshold.
    """
    model.eval()
    targets, scores = [], []
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            batch_probs = torch.sigmoid(model(batch_images).squeeze(1))
            targets.extend(batch_labels.cpu().numpy())
            scores.extend(batch_probs.cpu().numpy())

    y_true = np.array(targets)
    y_prob = np.array(scores)
    y_pred = (y_prob > 0.5).astype(int)

    # Threshold-based metrics share one call shape; AUC needs probabilities.
    thresholded = (
        ("accuracy", accuracy_score),
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )
    report = {name: float(scorer(y_true, y_pred)) for name, scorer in thresholded}
    report["auc"] = float(roc_auc_score(y_true, y_prob))
    return report


## 4) Run training/evaluation

In [5]:
def eval_model(model, loader, criterion=None):
    """Evaluate ``model`` on ``loader`` and return binary-classification metrics.

    Args:
        model: Network producing one logit per sample, shaped ``(batch, 1)``
            (the output is squeezed on dim 1).
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Optional loss called as ``criterion(logits, labels)``.
            When given, the per-sample mean loss is reported under ``"loss"``.

    Returns:
        Dict with ``accuracy``, ``precision``, ``recall``, ``f1`` and ``auc``
        as floats, plus ``loss`` when ``criterion`` is supplied. Predictions
        use a 0.5 probability threshold. ``auc`` is ``nan`` when the labels
        contain fewer than two classes, because ROC AUC is undefined there.
    """
    model.eval()
    # Feed batches to wherever the model lives; fall back to the module-level
    # `device` only for a parameter-free model.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    y_true, y_prob = [], []
    total_loss = 0.0
    n = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(run_device)
            labels = labels.to(run_device)
            logits = model(images).squeeze(1)
            probs = torch.sigmoid(logits)
            if criterion is not None:
                loss = criterion(logits, labels)
                # `loss` is a batch mean, so weight it by the batch size.
                total_loss += loss.item() * images.size(0)
            y_true.extend(labels.cpu().numpy())
            y_prob.extend(probs.cpu().numpy())
            n += images.size(0)
    y_true = np.array(y_true)
    y_prob = np.array(y_prob)
    y_pred = (y_prob > 0.5).astype(int)

    # roc_auc_score raises when only one class is present (easy to hit on a
    # small or skewed split); report NaN instead of aborting the run.
    if np.unique(y_true).size >= 2:
        auc = float(roc_auc_score(y_true, y_prob))
    else:
        auc = float("nan")

    metrics = {
        "accuracy": float(accuracy_score(y_true, y_pred)),
        # zero_division=0 yields the same 0.0 as the default but without the
        # undefined-metric warning when a class is never predicted/present.
        "precision": float(precision_score(y_true, y_pred, zero_division=0)),
        "recall": float(recall_score(y_true, y_pred, zero_division=0)),
        "f1": float(f1_score(y_true, y_pred, zero_division=0)),
        "auc": auc,
    }
    if criterion is not None:
        metrics["loss"] = total_loss / max(1, n)
    return metrics

# Training loop tuned to match TF notebook behaviour
PATIENCE = 2
# The scheduler's `verbose=` argument is deprecated in recent PyTorch, so LR
# reductions are reported manually after each scheduler step instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1, min_lr=1e-6)

t0 = time.time()
best_val_acc = -1.0
best_val_loss = float('inf')
epochs_no_improve = 0
checkpoint_saved = False  # True once THIS run has written best_path

for epoch in range(1, EPOCHS + 1):
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer)
    val_metrics = eval_model(model, val_loader, criterion=criterion)
    val_loss = val_metrics.get('loss', float('inf'))
    val_acc = val_metrics['accuracy']
    print(f"Epoch {epoch}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}, val_acc={val_acc:.4f}, val_f1={val_metrics['f1']:.4f}")

    # ModelCheckpoint: save best by val_accuracy
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_path)
        checkpoint_saved = True

    # Scheduler step on val_loss
    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    lr_after = optimizer.param_groups[0]['lr']
    if lr_after < lr_before:
        print(f"Epoch {epoch}: reducing learning rate from {lr_before:.2e} to {lr_after:.2e}.")

    # EarlyStopping on val_loss
    if val_loss < best_val_loss - 1e-6:
        best_val_loss = val_loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= PATIENCE:
        print(f"Early stopping: no improvement in val_loss for {PATIENCE} epochs.")
        break

# Wall-clock training time (t0 was previously recorded but never used).
train_time_sec = time.time() - t0
print(f"Training time: {train_time_sec:.1f}s ({train_time_sec / 60:.2f} min)")

# Load best model and evaluate on test set.
# Only restore a checkpoint written by this run: a file left over from an
# earlier run may belong to different weights or a different architecture.
if checkpoint_saved and best_path.exists():
    model.load_state_dict(torch.load(best_path, map_location=device))

test_metrics = eval_model(model, test_loader, criterion=criterion)
print("Test metrics:", test_metrics)

=== Training alexnet ===


Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /Users/pratyaksh/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
100%|██████████| 233M/233M [00:08<00:00, 29.0MB/s] 


Epoch 1: train_loss=0.1307, val_f1=0.9817
Epoch 2: train_loss=0.0672, val_f1=0.9850
Epoch 3: train_loss=0.0566, val_f1=0.9863
Epoch 4: train_loss=0.0502, val_f1=0.9918
Epoch 5: train_loss=0.0473, val_f1=0.9886
{'framework': 'pytorch', 'model': 'alexnet', 'accuracy': 0.9868647188804321, 'precision': 0.9893975092561427, 'recall': 0.992571332095222, 'f1': 0.9909818794774548, 'auc': 0.9989285195395186, 'train_time_sec': 1907.5788941383362, 'train_time_min': 31.79298156897227, 'checkpoint': '/Users/pratyaksh/UTA/sem3/CV/plantvillage-study/notebooks/models/alexnet.pt'}
Saved PyTorch metrics -> /Users/pratyaksh/UTA/sem3/CV/plantvillage-study/notebooks/results/pytorch_metrics.csv


Unnamed: 0,framework,model,accuracy,precision,recall,f1,auc,train_time_sec,train_time_min,checkpoint
0,pytorch,alexnet,0.986865,0.989398,0.992571,0.990982,0.998929,1907.578894,31.792982,/Users/pratyaksh/UTA/sem3/CV/plantvillage-stud...
