In [1]:
# --- Cell 1: Imports (Consolidated) ---
import os
import sys
import glob
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from skimage.segmentation import slic
from skimage.color import label2rgb
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights, vit_b_16, ViT_B_16_Weights

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report, precision_recall_fscore_support

import wandb
import optuna  # For hyperparameter optimization

import albumentations as A
from albumentations.pytorch import ToTensorV2

from pytorch_grad_cam import GradCAM  # For Grad-CAM (later)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

In [2]:
# --- Cell 2: Configuration (Updated) ---
# Every tunable constant of the notebook lives in this cell.

# --- Dataset ---
# Root folder of the SAR-CLD-2024 dataset. Set the SAR_CLD_DATASET_ROOT
# environment variable to point at your copy instead of editing the notebook;
# the fallback is the original hard-coded location.
DATASET_ROOT = os.environ.get(
    "SAR_CLD_DATASET_ROOT",
    "/home/w2sg-arnav/8-phases/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection",
)
ORIGINAL_DIR = os.path.join(DATASET_ROOT, "Original Dataset")
AUGMENTED_DIR = os.path.join(DATASET_ROOT, "Augmented Dataset")

# Class names; the list position is the integer label id (see CLASS_MAP).
CLASSES = [
    "Bacterial Blight",
    "Curl Virus",
    "Healthy Leaf",
    "Herbicide Growth Damage",
    "Leaf Hopper Jassids",
    "Leaf Redding",
    "Leaf Variegation",
]
NUM_CLASSES = len(CLASSES)
CLASS_MAP = dict(enumerate(CLASSES))  # label id -> class name

# --- Training ---
IMAGE_SIZE = (224, 224)   # Use 224x224
BATCH_SIZE = 32
LEARNING_RATE = 1e-4  # Initial learning rate (will be tuned)
EPOCHS = 10  # Epochs for the final training run; Optuna trials use OPTUNA_EPOCHS.
NUM_WORKERS = 6
VAL_SIZE = 0.2
TEST_SIZE = 0.2
RANDOM_STATE = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global NumPy and PyTorch generators so that anything drawing from
# them (weight initialisation, shuffling, ...) is repeatable after a kernel
# restart.
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)

# --- Model ---
# Keep ViT settings from Phase 2
VIT_MODEL_NAME = "vit_b_16"
VIT_PRETRAINED = True
VIT_CHECKPOINT_DIR = "checkpoints"  # Use the same directory
VIT_CHECKPOINT_PATH = os.path.join(VIT_CHECKPOINT_DIR, f"{VIT_MODEL_NAME}_best_accuracy.pth")  # From Phase 2

# Add EfficientNetV2 settings
EFFNET_MODEL_NAME = "efficientnet_v2_s"
EFFNET_PRETRAINED = True
EFFNET_CHECKPOINT_DIR = "checkpoints"  # Use the same directory
EFFNET_CHECKPOINT_PATH = os.path.join(EFFNET_CHECKPOINT_DIR, f"{EFFNET_MODEL_NAME}_best.pth")  # Best-epoch EfficientNet weights

# --- Ensemble --- (NEW)
ENSEMBLE_METHOD = "simple"  # "simple" or "weighted" (start with simple)

# --- Hyperparameter Optimization --- (NEW)
N_TRIALS = 10  # Number of Optuna trials (start small, increase later)
OPTUNA_EPOCHS = 5  # Use fewer epochs *during* Optuna trials.

In [3]:
# --- Cell 3: Data Loading Functions (REVISED to use data_utils.py) ---

from data_utils import create_data_loaders, get_transforms, CottonDataset  # Import from data_utils

In [4]:
# --- Cell 4: Model Definitions (ViT and EfficientNetV2) ---

def get_vit_model(model_name=VIT_MODEL_NAME, pretrained=VIT_PRETRAINED, num_classes=NUM_CLASSES):
    """Build a torchvision ViT classifier with a ``num_classes``-way head on DEVICE.

    Args:
        model_name: Architecture identifier; only ``"vit_b_16"`` is supported.
        pretrained: When truthy, start from ``ViT_B_16_Weights.DEFAULT``;
            otherwise the network is created without pretrained weights.
        num_classes: Output size of the replacement classification head.

    Returns:
        The model, moved to ``DEVICE``.

    Raises:
        ValueError: If ``model_name`` is not a supported ViT variant.
    """
    if model_name != "vit_b_16":
        raise ValueError(f"Unsupported ViT model name: {model_name}")

    backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT if pretrained else None)

    # The whole `heads` container is swapped for a single Linear layer whose
    # input width is read from the first module of the stock head.
    head_in_features = backbone.heads[0].in_features
    backbone.heads = nn.Linear(head_in_features, num_classes)

    return backbone.to(DEVICE)


def get_effnet_model(model_name=EFFNET_MODEL_NAME, pretrained=EFFNET_PRETRAINED, num_classes=NUM_CLASSES):
    """Build a torchvision EfficientNetV2 classifier with a ``num_classes``-way head on DEVICE.

    Args:
        model_name: Architecture identifier; only ``"efficientnet_v2_s"`` is supported.
        pretrained: When truthy, start from ``EfficientNet_V2_S_Weights.DEFAULT``;
            otherwise the network is created without pretrained weights.
        num_classes: Output size of the replacement final linear layer.

    Returns:
        The model, moved to ``DEVICE``.

    Raises:
        ValueError: If ``model_name`` is not a supported EfficientNet variant.
    """
    if model_name != "efficientnet_v2_s":
        raise ValueError(f"Unsupported EfficientNet model name: {model_name}")

    network = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.DEFAULT if pretrained else None)

    # Only element 1 of the classifier is replaced; element 0 is left as built.
    final_in_features = network.classifier[1].in_features
    network.classifier[1] = nn.Linear(final_in_features, num_classes)

    return network.to(DEVICE)

# --- Function to Load Checkpoints ---
def load_checkpoint(model, checkpoint_path):
    """Load a saved ``state_dict`` into ``model`` if the checkpoint file exists.

    A missing file is not an error: a notice is printed and ``model`` is
    returned untouched.

    Args:
        model: Module whose ``load_state_dict`` receives the loaded weights.
        checkpoint_path: Path of the file written with ``torch.save``.

    Returns:
        The same ``model`` object (modified in place when the file exists).
    """
    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint not found at {checkpoint_path}. Starting from scratch or ImageNet weights.")
        return model

    # NOTE: torch.load unpickles the file, so only point this at trusted checkpoints.
    saved_state = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(saved_state)
    print(f"Loaded checkpoint from {checkpoint_path}")
    return model

In [5]:
# --- Cell 5: Generalized Training Loop (Modified for Optuna) ---

def train_model(model, train_loader, val_loader, learning_rate, epochs, checkpoint_path=None, trial=None): # Added trial
    """
    Train a classifier with AdamW and cross-entropy, validating after every epoch.

    After each epoch the validation accuracy is reported to Optuna (when
    ``trial`` is given), the weights are checkpointed whenever the validation
    accuracy improves, and the epoch's metrics are logged to Weights & Biases.

    Args:
        model: ``nn.Module`` trained in place. It is not moved by this
            function, so it must already be on ``DEVICE`` (batches are sent there).
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        learning_rate: Learning rate handed to ``optim.AdamW``.
        epochs: Number of passes over ``train_loader``.
        checkpoint_path: File that receives the ``state_dict`` of the best
            epoch; nothing is saved when falsy. A missing parent directory is created.
        trial: Optional Optuna trial used for intermediate reporting and pruning.

    Returns:
        ``(model, best_val_accuracy)``. The model carries the weights of the
        *last* epoch (the best ones are only written to ``checkpoint_path``);
        the accuracy is a percentage.

    Raises:
        optuna.exceptions.TrialPruned: When ``trial`` requests pruning after an epoch.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    best_val_accuracy = 0.0  # Best validation accuracy (in %) seen so far.
    best_val_report = None   # Per-class report of the best epoch.

    # Passing explicit label ids keeps the report aligned with CLASSES even
    # when a class is missing from both labels and predictions of a split.
    # Assumes labels are integer ids 0..len(CLASSES)-1 in CLASSES order.
    label_ids = list(range(len(CLASSES)))

    # torch.save does not create directories, so make sure the target exists.
    if checkpoint_path:
        checkpoint_dir = os.path.dirname(checkpoint_path)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(epochs):
        # ---- Training pass ----
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0

        for images, labels in train_loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean; re-weight by batch size so the
            # epoch figure is a per-sample mean.
            running_loss += loss.item() * images.size(0)
            predicted = outputs.argmax(dim=1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

        # Normalise by the samples actually seen (guards against empty loaders
        # and loaders that do not visit every dataset item).
        train_loss = running_loss / max(total_train, 1)
        train_accuracy = 100 * correct_train / max(total_train, 1)

        # ---- Validation pass ----
        model.eval()
        val_running_loss = 0.0
        correct_val = 0
        total_val = 0
        val_preds = []
        val_true = []

        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(DEVICE)
                labels = labels.to(DEVICE)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_running_loss += loss.item() * images.size(0)
                predicted = outputs.argmax(dim=1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()
                val_preds.extend(predicted.cpu().tolist())  # for metrics
                val_true.extend(labels.cpu().tolist())

        val_loss = val_running_loss / max(total_val, 1)
        val_accuracy = 100 * correct_val / max(total_val, 1)

        # Per-class report as a DataFrame (one row per class plus the averages).
        report = classification_report(
            val_true,
            val_preds,
            labels=label_ids,
            target_names=CLASSES,
            zero_division=0,
            output_dict=True,
        )
        report = pd.DataFrame(report).transpose()
        precision, recall, f1, _ = precision_recall_fscore_support(
            val_true, val_preds, average='weighted', zero_division=0
        )

        # Optuna reporting (report *accuracy* for maximization)
        if trial is not None:
            trial.report(val_accuracy, epoch)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()

        # Checkpointing (save based on *accuracy*, not loss)
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            if checkpoint_path: #save only if path is given
                torch.save(model.state_dict(), checkpoint_path)
                print(f"Saved best model to {checkpoint_path}")
            best_val_report = report

        # Log every epoch, not only the improving ones, so the W&B curves
        # have no gaps.
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_accuracy": train_accuracy,
            "val_loss": val_loss,
            "val_accuracy": val_accuracy,
            "val_precision": precision,
            "val_recall": recall,
            "val_f1": f1,
        })

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%")

    print("Finished Training")
    # best_val_report stays None when no epoch ever improved on 0% accuracy.
    if best_val_report is not None:
        print("Best Validation Classification Report:\n", best_val_report)
        wandb.log({"best_validation_classification_report": wandb.Table(dataframe=best_val_report)})

    return model, best_val_accuracy  # Return accuracy for Optuna

In [6]:
# --- Cell 6: Ensemble Prediction ---

import torch
import torch.nn.functional as F
from sklearn.metrics import classification_report
import numpy as np
import pandas as pd
import wandb

def evaluate_ensemble(models, dataloader, ensemble_method="simple", weights=None):
    """Evaluate an ensemble of models on ``dataloader`` and log the metrics to W&B.

    The raw outputs of all models for a batch are combined (no softmax is
    applied first) and the arg-max of the combination is the prediction.

    Args:
        models: Non-empty sequence of modules. For a given batch they must all
            return tensors of the same shape so they can be stacked.
        dataloader: Iterable of ``(images, labels)`` batches.
        ensemble_method: ``"simple"`` takes the mean of the model outputs;
            ``"weighted"`` takes ``sum(weights[i] * outputs[i])``.
        weights: One weight per model; required for ``"weighted"``.

    Returns:
        ``(accuracy, report)``: accuracy as a percentage and the dictionary
        produced by ``classification_report(..., output_dict=True)``.

    Raises:
        ValueError: On an empty ``models`` list, an unknown ``ensemble_method``,
            missing or mis-sized ``weights``, or a dataloader that yields no samples.
    """
    # Validate the configuration before doing any (expensive) forward pass.
    if not models:
        raise ValueError("At least one model is required for an ensemble.")
    if ensemble_method == "weighted":
        if weights is None:
            raise ValueError("Weights must be provided for weighted averaging.")
        if len(weights) != len(models):
            raise ValueError(
                f"Expected {len(models)} weights (one per model), got {len(weights)}."
            )
    elif ensemble_method != "simple":
        raise ValueError(f"Unknown ensemble method: {ensemble_method}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for model in models:
        model.eval()  # Set models to evaluation mode
        model.to(device)

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            outputs = [model(images) for model in models]

            if ensemble_method == "simple":
                # Simple averaging over the model axis created by stack().
                ensemble_output = torch.stack(outputs).mean(dim=0)
            else:
                # Weighted sum; the weights are used as given (not normalised).
                weighted_outputs = [w * out for w, out in zip(weights, outputs)]
                ensemble_output = torch.stack(weighted_outputs).sum(dim=0)

            predicted = ensemble_output.argmax(dim=1)

            all_preds.extend(predicted.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    if not all_labels:
        raise ValueError("The dataloader yielded no samples to evaluate.")

    # Calculate metrics
    accuracy = np.mean(np.array(all_preds) == np.array(all_labels)) * 100
    report = classification_report(all_labels, all_preds, output_dict=True, zero_division=0) #output_dict=True is KEY

    print("Classification Report (Raw):")
    print(report)  #Inspect the report

    precision = report['macro avg']['precision']
    recall = report['macro avg']['recall']
    f1 = report['macro avg']['f1-score']

    metrics = {
        "ensemble_accuracy": accuracy,
        "ensemble_precision": precision,
        "ensemble_recall": recall,
        "ensemble_f1": f1,
    }

    # The table is best-effort: if it cannot be built the scalar metrics are
    # still logged.
    try:
        report_frame = pd.DataFrame(report).transpose()
        metrics["ensemble_classification_report"] = wandb.Table(dataframe=report_frame)
    except Exception as e:
        print(f"Error creating wandb table: {e}")

    wandb.log(metrics)

    return accuracy, report

In [7]:
# --- Cell 7: Hyperparameter Optimization (Integrated) ---

def objective(trial):
    """Optuna objective: train EfficientNet for a short run and score it.

    A learning rate is sampled log-uniformly from [1e-5, 1e-3], fresh data
    loaders and a fresh EfficientNet are created, and the model is trained for
    ``OPTUNA_EPOCHS`` epochs with pruning driven by ``trial``.

    Args:
        trial: The Optuna trial supplying the hyperparameter suggestion.

    Returns:
        The best validation accuracy reported by ``train_model`` for this trial.
    """
    sampled_lr = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)

    # Loaders are rebuilt for every trial; the third return value (test
    # loader) is not needed during the search.
    trial_train_loader, trial_val_loader, _unused_test_loader = create_data_loaders(
        ORIGINAL_DIR,
        get_transforms(train=True),
        get_transforms(train=False),
        BATCH_SIZE,
        NUM_WORKERS,
        CLASSES,
    )

    # No checkpoint_path is passed, so trial models are never written to disk.
    _trained_model, best_accuracy = train_model(
        get_effnet_model(),
        trial_train_loader,
        trial_val_loader,
        learning_rate=sampled_lr,
        epochs=OPTUNA_EPOCHS,
        trial=trial,
    )
    return best_accuracy

In [None]:
# --- Cell 8: Main Execution ---
# Pipeline: tune the EfficientNet learning rate with Optuna, train the final
# EfficientNet, load the Phase-2 ViT, then evaluate their ensemble on the
# validation and test splits.

# --- Initialize W&B ---
# Re-running the cell reuses an already active run instead of opening a new one.
if wandb.run is None:
    run = wandb.init(project="vit", entity="w2sgarnav", name="w2sgarnav-vit-phase3", mode="offline")


# --- Load Data (outside Optuna, for final training) ---
#  Use standard, non-Optuna data loading for the final training.
train_transforms = get_transforms(train=True)
val_transforms = get_transforms(train=False)
train_loader, val_loader, test_loader = create_data_loaders(
    ORIGINAL_DIR, train_transforms, val_transforms, BATCH_SIZE, NUM_WORKERS, CLASSES
)

# --- 1. Hyperparameter Optimization (EfficientNet) ---
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=N_TRIALS)

print("Best trial:")
trial = study.best_trial
print(f"  Value: {trial.value}")
print(f"  Params: {trial.params}")
best_lr = trial.params["learning_rate"]
wandb.log({"best_learning_rate": best_lr})


# --- 2. Train/Load Models ---

# 2.1 Train/Load ViT (from Phase 2 checkpoint)
vit_model = get_vit_model()
load_checkpoint(vit_model, VIT_CHECKPOINT_PATH)  # Load your best ViT model

# 2.2 Train EfficientNetV2 (using best LR from Optuna)
effnet_model = get_effnet_model()
# Train with the *best* learning rate found by Optuna.
train_model(effnet_model, train_loader, val_loader, learning_rate=best_lr, epochs=EPOCHS, checkpoint_path=EFFNET_CHECKPOINT_PATH)
# train_model leaves the model at its last-epoch weights; the best-accuracy
# weights are only on disk, so reload them before building the ensemble.
load_checkpoint(effnet_model, EFFNET_CHECKPOINT_PATH)


# --- 3. Create and Evaluate Ensemble ---
ensemble_models = [vit_model, effnet_model]

# 3.1 Simple Averaging
print("Evaluating Simple Averaging Ensemble...")
simple_val_accuracy, _ = evaluate_ensemble(ensemble_models, val_loader, ensemble_method="simple")
print(f"Simple ensemble validation accuracy: {simple_val_accuracy:.2f}%")


# 3.2 Weighted Averaging (Example - you can refine the weight optimization)
print("Evaluating Weighted Averaging Ensemble...")
#  Example weights (you can optimize these, perhaps with a separate Optuna study!)
vit_weight = 0.6
effnet_weight = 0.4
ensemble_weights = [vit_weight, effnet_weight]  # same order as ensemble_models
weighted_val_accuracy, _ = evaluate_ensemble(ensemble_models, val_loader, ensemble_method="weighted", weights=ensemble_weights)
print(f"Weighted ensemble validation accuracy: {weighted_val_accuracy:.2f}%")
# Kept for later cells: this name has always held the weighted validation result.
ensemble_accuracy = weighted_val_accuracy

# --- 4. Evaluate on Test Set (using the method selected by ENSEMBLE_METHOD) ---
print("Evaluating on Test Set...")
if ENSEMBLE_METHOD not in ("simple", "weighted"):
    # Fail loudly instead of silently skipping the test evaluation.
    raise ValueError(f"Unknown ensemble method: {ENSEMBLE_METHOD}")
test_weights = ensemble_weights if ENSEMBLE_METHOD == "weighted" else None
ensemble_accuracy_test, _ = evaluate_ensemble(ensemble_models, test_loader, ensemble_method=ENSEMBLE_METHOD, weights=test_weights)
print(f"Test accuracy ({ENSEMBLE_METHOD} ensemble): {ensemble_accuracy_test:.2f}%")


wandb.finish()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


[I 2025-03-04 22:32:54,388] A new study created in memory with name: no-name-837c0c1b-da20-4780-b35c-21784b8cdd88


--- Loaded File Paths (First 5 of each) ---
Train: ['/home/w2sg-arnav/8-phases/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection/Original Dataset/Bacterial Blight/BBC00155.jpg', '/home/w2sg-arnav/8-phases/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection/Original Dataset/Herbicide Growth Damage/HGD00260.jpg', '/home/w2sg-arnav/8-phases/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection/Original Dataset/Curl Virus/CV00202.jpg', '/home/w2sg-arnav/8-phases/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection/Original Dataset/Herbicide Growth Damage/HGD00203.jpg', '/home/w2sg-arnav/8-phases/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection/Original Dataset/Bacterial Blight/BBC00184.jpg']
Validation: ['/home/w2sg-arnav/8-phases/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection/Original Dataset/Herbicide Growth Damage/HGD00095.jpg', '/home/w2sg-arnav/8-phases/SAR-CLD-2024

[I 2025-03-04 22:34:24,222] Trial 0 finished with value: 96.96261682242991 and parameters: {'learning_rate': 0.0005007496855582154}. Best is trial 0 with value: 96.96261682242991.


Epoch 5/5, Train Loss: 0.1030, Train Acc: 96.80%, Val Loss: 0.1278, Val Acc: 96.96%
Finished Training
Best Validation Classification Report:
                          precision    recall  f1-score     support
Bacterial Blight          0.923077  0.960000  0.941176   50.000000
Curl Virus                0.966292  0.988506  0.977273   87.000000
Healthy Leaf              0.983607  1.000000  0.991736   60.000000
Herbicide Growth Damage   1.000000  0.945455  0.971963   55.000000
Leaf Hopper Jassids       0.900000  1.000000  0.947368   36.000000
Leaf Redding              1.000000  0.941667  0.969957  120.000000
Leaf Variegation          0.952381  1.000000  0.975610   20.000000
accuracy                  0.969626  0.969626  0.969626    0.969626
macro avg                 0.960765  0.976518  0.967869  428.000000
weighted avg              0.971227  0.969626  0.969757  428.000000
--- Loaded File Paths (First 5 of each) ---
Train: ['/home/w2sg-arnav/8-phases/SAR-CLD-2024 A Comprehensive Dataset for C