In [1]:
# --- second.ipynb ---

# --- Cell 1: Imports --- (No changes)
import os
import sys
import glob
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from skimage.segmentation import slic
from skimage.color import label2rgb
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights, vit_b_16, ViT_B_16_Weights

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report, precision_recall_fscore_support

import wandb  # For experiment tracking
import optuna  # For hyperparameter optimization (used later)

import albumentations as A  # For image augmentations
from albumentations.pytorch import ToTensorV2

from pytorch_grad_cam import GradCAM  # For Grad-CAM (used later)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

In [2]:
# --- Cell 2: Configuration ---
# Every value a reader might want to tune lives in this cell.
import os
import torch

# --- Dataset ---
# Point the notebook at a different copy of the dataset by setting the
# COTTON_DATASET_ROOT environment variable instead of editing this cell.
# The default is the path used on the original training machine.
DATASET_ROOT = os.environ.get(
    "COTTON_DATASET_ROOT",
    "/teamspace/studios/this_studio/8-phases/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection",
)
ORIGINAL_DIR = os.path.join(DATASET_ROOT, "Original Dataset")
AUGMENTED_DIR = os.path.join(DATASET_ROOT, "Augmented Dataset")

# Class names; CLASS_MAP maps each list position (label index) to its name.
CLASSES = [
    "Bacterial Blight",
    "Curl Virus",
    "Healthy Leaf",
    "Herbicide Growth Damage",
    "Leaf Hopper Jassids",
    "Leaf Redding",
    "Leaf Variegation",
]
NUM_CLASSES = len(CLASSES)
CLASS_MAP = {i: name for i, name in enumerate(CLASSES)}

# --- Training ---
IMAGE_SIZE = (224, 224)  # Model input size; every transform resizes to this
BATCH_SIZE = 32
LEARNING_RATE = 1e-4
EPOCHS = 60
NUM_WORKERS = 6
# NOTE(review): VAL_SIZE / TEST_SIZE are not referenced in this notebook;
# presumably split fractions for data_utils -- confirm before relying on them.
VAL_SIZE = 0.2
TEST_SIZE = 0.2
RANDOM_STATE = 42  # Seed value for reproducible runs
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Model ---
MODEL_NAME = "vit_b_16"
PRETRAINED = True
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# --- Progressive augmentation ---
# Only the NUMBER of entries matters: training is split into
# len(PROGRESSIVE_SIZES) equal stages with increasingly strong augmentation.
# The sizes themselves are NOT used; the model input always stays IMAGE_SIZE.
PROGRESSIVE_SIZES = [(128, 128), (224, 224), (384, 384)]
CURRENT_SIZE_INDEX = 0  # Current augmentation stage (0 = weakest)

In [3]:
# --- Cell 3: Data Loading Functions (Now in data_utils.py) ---
from data_utils import (CottonDataset, get_transforms, segment_leaf,  # Corrected import
                        create_data_loaders)


In [4]:
# --- Cell 4: Model Definition (Modified) ---

def get_model(model_name=MODEL_NAME, pretrained=PRETRAINED, num_classes=NUM_CLASSES):
    """Build a classification backbone with a freshly initialised output head.

    Args:
        model_name: Architecture identifier; only ``"vit_b_16"`` is supported.
        pretrained: When True, load torchvision's ``ViT_B_16_Weights.DEFAULT``
            weights; otherwise start from random initialisation.
        num_classes: Output dimension of the replacement head.

    Returns:
        The model, moved to ``DEVICE``.

    Raises:
        ValueError: If ``model_name`` is not supported.
    """
    if model_name != "vit_b_16":
        raise ValueError(f"Unsupported model name: {model_name}")

    backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT if pretrained else None)
    # Replace the whole ``heads`` container with one Linear layer sized to our
    # label set. State dicts saved from this model therefore use the
    # ``heads.weight`` / ``heads.bias`` keys.
    feature_dim = backbone.heads[0].in_features
    backbone.heads = nn.Linear(feature_dim, num_classes)
    return backbone.to(DEVICE)

In [5]:
# --- Cell 5: Training Loop (Modified for best accuracy and NO model reloading) ---
def _stage_train_transforms(stage_index):
    """Return the training transform pipeline for augmentation stage ``stage_index``.

    Stage 0 is the base ``get_transforms`` pipeline. Stage 1 adds
    ``RandomRotate90``. Stage 2 (and above) adds ``RandomRotate90`` followed
    by ``HorizontalFlip``. The resize target is always ``IMAGE_SIZE``.
    """
    pipeline = get_transforms(image_size=IMAGE_SIZE, train=True)
    extra_augmentations = []
    if stage_index >= 1:
        extra_augmentations.append(A.RandomRotate90(p=0.7))
    if stage_index >= 2:
        extra_augmentations.append(A.HorizontalFlip(p=0.7))
    # Insert just before the pipeline's last transform (presumably the tensor
    # conversion, e.g. ToTensorV2 -- TODO confirm in data_utils.get_transforms).
    for augmentation in extra_augmentations:
        pipeline.transforms.insert(-1, augmentation)
    return pipeline


def _train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader``.

    Returns:
        (mean loss per sample, accuracy in percent) over the samples seen.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # criterion returns the batch mean; re-weight so the epoch average is per sample.
        running_loss += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_size
    if total == 0:
        raise ValueError("Training loader produced no samples.")
    return running_loss / total, 100 * correct / total


def _evaluate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Returns:
        (mean loss per sample, accuracy in percent, predicted labels, true labels).

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    predictions, targets = [], []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            outputs = model(images)
            batch_size = labels.size(0)
            running_loss += criterion(outputs, labels).item() * batch_size
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += batch_size
            predictions.extend(predicted.cpu().tolist())
            targets.extend(labels.cpu().tolist())
    if total == 0:
        raise ValueError("Validation loader produced no samples.")
    return running_loss / total, 100 * correct / total, predictions, targets


def train_model(model, train_loader, val_loader, learning_rate=LEARNING_RATE, epochs=EPOCHS, checkpoint_dir=CHECKPOINT_DIR, model_name=""):
    """Train ``model`` with staged (progressively stronger) augmentation.

    Training is split into ``len(PROGRESSIVE_SIZES)`` equal stages. At the start
    of each new stage the training transforms gain extra augmentations and both
    DataLoaders are rebuilt via ``create_data_loaders``. The input size stays
    ``IMAGE_SIZE`` throughout.

    Whenever validation accuracy improves, the weights are saved to
    ``<checkpoint_dir>/<model_name>_best_accuracy.pth`` and logged as a W&B
    artifact. Per-epoch metrics are logged to W&B. The best epoch's per-class
    report is printed and logged once at the end.

    Args:
        model: Model to train in place (already on ``DEVICE``).
        train_loader: Loader for the first augmentation stage.
        val_loader: Validation loader.
        learning_rate: AdamW learning rate.
        epochs: Number of epochs to run.
        checkpoint_dir: Directory for the best checkpoint; created if missing.
        model_name: Prefix for checkpoint file and artifact names.

    Returns:
        The same ``model`` object holding the weights from the LAST epoch.
        Load the saved checkpoint to obtain the best-accuracy weights.

    Side effects:
        Updates the module-level ``CURRENT_SIZE_INDEX`` to the current stage.
    """
    global CURRENT_SIZE_INDEX

    os.makedirs(checkpoint_dir, exist_ok=True)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    # -inf guarantees the first epoch is checkpointed, so best_val_report is
    # set whenever at least one epoch runs.
    best_val_accuracy = float("-inf")
    best_val_report = None

    # Always start at stage 0 so re-running this cell in the same kernel does
    # not inherit the previous run's final stage.
    stage_index = 0
    CURRENT_SIZE_INDEX = stage_index
    # At least 1, otherwise epochs < len(PROGRESSIVE_SIZES) divides by zero.
    stage_length = max(1, epochs // len(PROGRESSIVE_SIZES))

    for epoch in range(epochs):
        # --- Progressive augmentation: advance the stage and rebuild the loaders ---
        if epoch > 0 and epoch % stage_length == 0:
            stage_index = min(stage_index + 1, len(PROGRESSIVE_SIZES) - 1)
            CURRENT_SIZE_INDEX = stage_index
            print(f"Updating augmentation level (image size remains 224x224). Current size index: {CURRENT_SIZE_INDEX}")

            train_transforms = _stage_train_transforms(stage_index)
            val_transforms = get_transforms(image_size=IMAGE_SIZE, train=False)  # Val/Test always 224
            # NOTE(review): this assumes create_data_loaders reproduces the same
            # train/val split on every call (e.g. seeded with RANDOM_STATE).
            # TODO confirm in data_utils; otherwise validation images can leak
            # into training.
            train_loader, val_loader, _ = create_data_loaders(ORIGINAL_DIR, train_transforms, val_transforms, BATCH_SIZE, NUM_WORKERS, CLASSES)

        train_loss, train_accuracy = _train_one_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_accuracy, val_preds, val_true = _evaluate(model, val_loader, criterion)

        # Passing every label index keeps target_names aligned even when a class
        # is absent from this validation pass. Without it sklearn raises ValueError.
        report = classification_report(
            val_true,
            val_preds,
            labels=list(range(len(CLASSES))),
            target_names=CLASSES,
            zero_division=0,
            output_dict=True,
        )
        report = pd.DataFrame(report).transpose()

        # --- Checkpointing (based on accuracy) and W&B artifact logging ---
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_val_report = report
            checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}_best_accuracy.pth")
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Saved best model (based on accuracy) to {checkpoint_path}")

            artifact = wandb.Artifact(f"{model_name}_best_model", type="model")
            artifact.add_file(checkpoint_path)
            wandb.log_artifact(artifact)

        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_accuracy": train_accuracy,
            "val_loss": val_loss,
            "val_accuracy": val_accuracy,
            "image_size": IMAGE_SIZE[0],  # Fixed model input size (always 224)
        })
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%, Aug Size: {IMAGE_SIZE}")

    print("Finished Training")
    # None only when no epoch ran (epochs == 0).
    if best_val_report is not None:
        print("Best Validation Classification Report:\n", best_val_report)
        wandb.log({"best_validation_classification_report": wandb.Table(dataframe=best_val_report)})
    return model

In [6]:
# --- Cell 6: Main Execution ---
import random

# --- Reproducibility ---
# Seed the Python, NumPy and PyTorch RNGs so head initialisation and
# (presumably) DataLoader shuffling and augmentation draws repeat across
# Restart & Run All. Bit-exact GPU determinism also depends on cuDNN
# settings, which are not changed here.
random.seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)

# --- Initialize W&B ---
if wandb.run is None:  # Reuse an already-active run instead of starting a second one
    run = wandb.init(project="vit", entity="w2sgarnav", name="w2sgarnav-vit", mode="offline")
    wandb.config.update({  # Log configuration
        "model_name": MODEL_NAME,
        "pretrained": PRETRAINED,
        "learning_rate": LEARNING_RATE,
        "epochs": EPOCHS,
        "image_size": IMAGE_SIZE,  # Fixed model input size; only augmentation strength changes during training
        "batch_size": BATCH_SIZE,
        "random_state": RANDOM_STATE,
    })

try:
    # --- Transforms ---
    # Albumentations pipelines are passed straight to create_data_loaders
    # (not combined here). Both resize to IMAGE_SIZE, the ViT input size.
    train_transforms = get_transforms(image_size=IMAGE_SIZE, train=True)
    val_transforms = get_transforms(image_size=IMAGE_SIZE, train=False)  # Always 224 for val/test

    # --- DataLoaders (the test loader is not needed for training) ---
    train_loader, val_loader, _ = create_data_loaders(
        ORIGINAL_DIR, train_transforms, val_transforms, BATCH_SIZE, NUM_WORKERS, CLASSES
    )

    # --- Model & training ---
    model = get_model()
    model = train_model(model, train_loader, val_loader, model_name=MODEL_NAME)
finally:
    # Always close the run. If anything above raises, a lingering wandb.run
    # would make the next execution of this cell skip init and config logging.
    wandb.finish()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


--- Loaded File Paths (First 5 of each) ---
Train: ['/teamspace/studios/this_studio/8-phases/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection/Original Dataset/Bacterial Blight/BBC00250.jpg', '/teamspace/studios/this_studio/8-phases/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection/Original Dataset/Herbicide Growth Damage/HGD00175.jpg', '/teamspace/studios/this_studio/8-phases/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection/Original Dataset/Curl Virus/CV00093.jpg', '/teamspace/studios/this_studio/8-phases/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection/Original Dataset/Herbicide Growth Damage/HGD00042.jpg', '/teamspace/studios/this_studio/8-phases/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection/Original Dataset/Bacterial Blight/BBC00041.jpg']
Validation: ['/teamspace/studios/this_studio/8-phases/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection/Original Da

0,1
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
image_size,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▇▇█▇████▇████▇▇▅▅▇█████████▁▆█▆████████
train_loss,█▃▂▂▁▁▁▁▂▁▁▁▂▃▂▄▄▂▁▁▁▁▁▁▁▁▁▃▂▂▄▂▂▂▁▁▁▁▁▁
val_accuracy,▃▇▆█▇▇▇█▆▇▅█▇▇▆▇▇█▆▇█▇▇▇▇▇▇██▅▁▃█▅▆▆▇▆▇▇
val_loss,▂▁▂▂▂▁▃▄▂▇▃▅▂▂▂▃▃▃▁▂▂▂▂▂▂▆█▆▂▃▅▃▃▁▂▃▂▃▂▂

0,1
epoch,60.0
image_size,224.0
train_accuracy,98.28259
train_loss,0.04702
val_accuracy,96.26168
val_loss,0.11157
