In [7]:
# --- Cell 1: General Imports ---
import os
import sys
import glob
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from skimage.segmentation import slic
from skimage.color import label2rgb
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights, vit_b_16, ViT_B_16_Weights

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report, precision_recall_fscore_support

import wandb  # For experiment tracking
import optuna  # For hyperparameter optimization (used later)

import albumentations as A  # For image augmentations
from albumentations.pytorch import ToTensorV2

from pytorch_grad_cam import GradCAM  # For Grad-CAM (used later)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

In [8]:
# --- Cell 2: Self-Supervised Learning Imports ---
import torch
import torch.nn as nn
import torchvision
from torchvision.models import vit_b_16, ViT_B_16_Weights
#import lightly # Remove - we're defining SimCLR components ourselves
#from lightly.models.modules.heads import SimCLRProjectionHead # Remove
#from lightly.loss import NTXentLoss # Remove
from lightly.data import LightlyDataset
from lightly.transforms import SimCLRTransform, utils
import pytorch_lightning as pl  # We'll use PyTorch Lightning for easier training
from pytorch_lightning.callbacks import ModelCheckpoint

In [9]:
# --- Cell 3: Configuration ---
import os
import torch

# --- Dataset ---
DATASET_ROOT = "/home/w2sg-arnav/8-phases/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection"  # YOUR DATASET PATH.  UPDATE THIS!
ORIGINAL_DIR = os.path.join(DATASET_ROOT, "Original Dataset")
AUGMENTED_DIR = os.path.join(DATASET_ROOT, "Augmented Dataset")

CLASSES = [
    "Bacterial Blight",
    "Curl Virus",
    "Healthy Leaf",
    "Herbicide Growth Damage",
    "Leaf Hopper Jassids",
    "Leaf Redding",
    "Leaf Variegation",
]
NUM_CLASSES = len(CLASSES)
CLASS_MAP = {i: name for i, name in enumerate(CLASSES)}

# --- Training ---
IMAGE_SIZE = (224, 224)  # Model ALWAYS expects 224x224
BATCH_SIZE = 32
LEARNING_RATE = 1e-4
EPOCHS = 10
NUM_WORKERS = 6
VAL_SIZE = 0.2
TEST_SIZE = 0.2
RANDOM_STATE = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Model ---
MODEL_NAME = "vit_b_16"
PRETRAINED = True
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# --- Progressive Resizing (Modified) ---
# We'll still use this for *when* to increase augmentations, but NOT for model input size
PROGRESSIVE_SIZES = [(128, 128), (224, 224), (384, 384)]
CURRENT_SIZE_INDEX = 0  # Start with the smallest size

In [10]:
# --- Cell 4: Data Loading (SSL) ---
from data_utils import (CottonDataset, get_transforms, segment_leaf,
                        create_data_loaders, PILTransform)

# Define SimCLR Augmentations
simclr_transforms = SimCLRTransform(
    input_size=224,  # ViT uses 224x224
    gaussian_blur=0.5,  # Add gaussian blur
)

# Create a LightlyDataset
# 1. Use your CottonDataset to load images, but with *training* transforms.
train_transforms = get_transforms(train=True)
train_dataset = CottonDataset(ORIGINAL_DIR, transform=PILTransform(train_transforms))

# 2. Use LightlyDataset.from_torch_dataset to create a compatible dataset
ssl_train_dataset = LightlyDataset.from_torch_dataset(train_dataset, transform=simclr_transforms)

# Create a dataloader
ssl_train_dataloader = DataLoader(
    ssl_train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,  # Important for contrastive learning
    num_workers=NUM_WORKERS,
    pin_memory=True,  # Use pin_memory for faster data transfer to GPU
)

Found 2137 files
First 5 file paths: ['/home/w2sg-arnav/8-phases/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection/Original Dataset/Healthy Leaf/HL00226.jpg', '/home/w2sg-arnav/8-phases/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection/Original Dataset/Healthy Leaf/HL00113.jpg', '/home/w2sg-arnav/8-phases/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection/Original Dataset/Healthy Leaf/HL00102.jpg', '/home/w2sg-arnav/8-phases/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection/Original Dataset/Healthy Leaf/HL00100.jpg', '/home/w2sg-arnav/8-phases/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection/Original Dataset/Healthy Leaf/HL00098.jpg']
Dataset length: 2137
Dataset length: 2137


In [11]:
# --- Cell 5: SimCLR Model Definition ---
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.optim import AdamW

class SimCLRModel(pl.LightningModule):
    def __init__(self, backbone, num_ftrs, out_dim=128, learning_rate=1e-4):
        super().__init__()
        self.save_hyperparameters(ignore=['backbone'])  # Add ignore=['backbone']
        self.backbone = backbone
        self.backbone.heads = nn.Identity()
        self.projection_head = SimCLRProjectionHead(num_ftrs, hidden_dim=512, out_dim=out_dim)
        self.criterion = NTXentLoss(temperature=0.5)
        self.learning_rate = learning_rate

    def forward(self, x):
        x = self.backbone(x)
        z = self.projection_head(x)
        return z

    def training_step(self, batch, batch_idx):
        (x0, x1), _ = batch  # Correct unpacking
        z0 = self.forward(x0)
        z1 = self.forward(x1)
        loss = self.criterion(z0, z1)
        self.log("train_loss_ssl", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer


class SimCLRProjectionHead(nn.Module):
    def __init__(self, in_dim, hidden_dim=512, out_dim=128):
        super().__init__()
        self.projection_head = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        return self.projection_head(x)


class NTXentLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature
        self.cosine_similarity = nn.CosineSimilarity(dim=2)

    def forward(self, z0, z1):
        batch_size = z0.size(0)
        z = torch.cat([z0, z1], dim=0)
        sim_matrix = self.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature
        labels = torch.arange(batch_size, device=z.device)
        labels = torch.cat([labels + batch_size, labels])
        loss = nn.functional.cross_entropy(sim_matrix, labels)
        return loss

In [12]:
# --- Cell 6: Create ViT Backbone, SimCLR Model, Trainer, and Train (SSL) ---
from data_utils import get_vit_model
import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar #for a better progress bar

# 1. Create the ViT Backbone
vit_backbone = get_vit_model(pretrained=True)  # Use ImageNet pre-trained weights
num_ftrs = vit_backbone.heads[-1].in_features

# 2. Create the SimCLR Model
model = SimCLRModel(vit_backbone, num_ftrs)

# 3. Train with PyTorch Lightning (Trainer setup *and* execution)
trainer = pl.Trainer(
    max_epochs=10,  # Adjust as needed.
    devices=1,  # Use 1 GPU
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    default_root_dir="ssl_checkpoints",  # Save checkpoints here
    enable_progress_bar=True,  # Enable the progress bar
    #strategy="ddp_notebook",  # For multi-GPU training
    #progress_bar_refresh_rate=1 # Use if progress bar is flickering
    #callbacks=[RichProgressBar()]
)

trainer.fit(model, ssl_train_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type                 | Params | Mode 
-----------------------------------------------------------------
0 | backbone        | VisionTransformer    | 85.8 M | train
1 | projection_head | SimCLRProjectionHead | 459 K  | train
2 | criterion       | NTXentLoss           | 0      | train
-----------------------------------------------------------------
86.3 M    Trainable params
0         Non-trainable params
86.3 M    Total params
345.032   Total estimated model params size (MB)
158       Modules in train mode
0         Modules in eval mode


Dataset length: 2137
Dataset length: 2137
Dataset length: 2137
Dataset length: 2137
Dataset length: 2137
Dataset length: 2137
Dataset length: 2137


Training: |          | 0/? [00:00<?, ?it/s]

Dataset length: 2137


ValueError: too many values to unpack (expected 2)

In [None]:
# --- Cell 9: Save Pre-trained Backbone ---
torch.save(model.backbone.state_dict(), "pretrained_vit_backbone.pth")

In [None]:
# --- Cell 7: Trainer Setup (SSL) ---
trainer = pl.Trainer(
    max_epochs=10,
    devices=1,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    default_root_dir="ssl_checkpoints",
    # strategy="ddp_notebook"  # For multi-GPU training (optional, requires setup)
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/w2sg-arnav/anaconda3/envs/cotton_env/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [None]:
# --- Cell 10: Supervised Fine-tuning Setup ---
finetune_model = get_vit_model(num_classes=NUM_CLASSES)
finetune_model.load_state_dict(torch.load("pretrained_vit_backbone.pth"), strict=False)

# Get training and validation transforms
train_transforms = get_transforms(train=True)
val_transforms = get_transforms(train=False)

# Create data loaders for supervised learning
train_loader, val_loader, test_loader = create_data_loaders(
    ORIGINAL_DIR, train_transforms, val_transforms, BATCH_SIZE, NUM_WORKERS, CLASSES
)

# Freeze the weights of the backbone
for name, param in finetune_model.named_parameters():
    if "heads" not in name:
        param.requires_grad = False

# (Optional) Unfreeze some layers later
# for param in finetune_model.encoder.layers[-2:].parameters():
#     param.requires_grad = True

# Optimizer and loss function (setup - training loop is later)
optimizer = optim.AdamW(finetune_model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

In [None]:
# --- Cell 11: WandB Initialization ---
# Place *before* the training loop.
import wandb

if wandb.run is None:
    run = wandb.init(project="vit", entity="w2sgarnav", name="w2sgarnav-vit", mode="offline")
    wandb.config.update({
        "model_name": MODEL_NAME,
        "pretrained": PRETRAINED,  # Whether we used ImageNet pre-training for the backbone
        "learning_rate": LEARNING_RATE,
        "epochs": EPOCHS,
        "image_size": IMAGE_SIZE,
        "batch_size": BATCH_SIZE,
    })

In [None]:
# --- Cell 12:  Trainer Setup (Supervised, with Checkpointing) ---
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath="finetune_checkpoints/",
    filename="{epoch}-{val_loss_epoch:.2f}",  # Include epoch and val_loss in filename
    monitor="val_loss_epoch", # Monitor validation loss for best model
    save_top_k=1,             # Save only the best model
    mode="min",               # Lower validation loss is better
)

trainer = pl.Trainer(
    max_epochs=EPOCHS,
    devices=1,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    default_root_dir="finetune_checkpoints",
    callbacks=[checkpoint_callback],  # Add the checkpointing callback
)

In [None]:
# --- Cell 13: FineTuneModel Definition (PyTorch Lightning) ---
# Includes WandB logging within the training and validation steps.

class FineTuneModel(pl.LightningModule):
    def __init__(self, model, learning_rate=LEARNING_RATE):
        super().__init__()
        self.model = model
        self.criterion = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        _, predicted = torch.max(outputs, 1)
        accuracy = (predicted == labels).sum().item() / labels.size(0)
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_acc", accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        wandb.log({"train_loss": loss, "train_acc": accuracy})  # Log to WandB
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self(images)
        loss = self.criterion(outputs, labels)
        _, predicted = torch.max(outputs, 1)
        accuracy = (predicted == labels).sum().item() / labels.size(0)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("val_acc", accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        wandb.log({"val_loss": loss, "val_acc": accuracy})  # Log to WandB
        return loss
    
    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x for x in outputs]).mean()
        self.log('val_loss_epoch', avg_loss)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.model.parameters(), lr=self.learning_rate)
        return optimizer

In [None]:
# --- Cell 14: Supervised Fine-tuning - Training Execution ---
# Wrap the fine-tune model
finetune_pl_model = FineTuneModel(finetune_model) #instantiate model
# Train the model
trainer.fit(finetune_pl_model, train_loader, val_loader)

In [None]:
# --- Cell 15: WandB Finish ---
wandb.finish()

In [None]:
# --- Cell 16: Supervised Fine-tuning - Save Fine-tuned Model ---
#Although we have checkpoints, this can act as a backup, good practice
torch.save(finetune_pl_model.model.state_dict(), "finetuned_vit_model.pth")

In [None]:
# --- Cell 17: Comprehensive Evaluation - Setup ---
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
import matplotlib.pyplot as plt
import seaborn as sns

finetune_pl_model.model.eval()  # Switch to evaluation mode
test_preds = []
test_true = []
test_probs = []  # For ROC and PR curves

In [None]:
# --- Cell 18: Comprehensive Evaluation - Prediction Loop ---
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        outputs = finetune_pl_model.model(images)
        _, predicted = torch.max(outputs.data, 1)
        test_preds.extend(predicted.cpu().numpy())
        test_true.extend(labels.cpu().numpy())
        test_probs.extend(F.softmax(outputs, dim=1).cpu().numpy())

test_probs = np.array(test_probs)

In [None]:
# --- Cell 19: Comprehensive Evaluation - Metrics and Plots ---

# --- Classification Report ---
print(classification_report(test_true, test_preds, target_names=CLASSES))

# --- Confusion Matrix ---
conf_mat = confusion_matrix(test_true, test_preds)
plt.figure(figsize=(10, 8))
sns.heatmap(conf_mat, annot=True, fmt="d", cmap="Blues", xticklabels=CLASSES, yticklabels=CLASSES)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix")
plt.show()

# --- ROC Curve and AUC (for each class) ---
plt.figure(figsize=(10, 8))
for i in range(NUM_CLASSES):
    fpr, tpr, _ = roc_curve(test_true, test_probs[:, i], pos_label=i)
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, label=f'{CLASSES[i]} (AUC = {roc_auc:.2f})')

plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend(loc="lower right")
plt.show()

# --- Precision-Recall Curve (for each class) ---
plt.figure(figsize=(10, 8))
for i in range(NUM_CLASSES):
    precision, recall, _ = precision_recall_curve(test_true, test_probs[:, i], pos_label=i)
    ap = average_precision_score(test_true, test_probs[:, i], pos_label=i)
    plt.plot(recall, precision, label=f'{CLASSES[i]} (AP = {ap:.2f})')

plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve')
plt.legend(loc="lower left")
plt.show()

In [None]:
# --- Cell 20: Comprehensive Evaluation - WandB Logging and Saving Results ---

# --- Log to WandB ---
if wandb.run is not None:  # Check if WandB is initialized
    wandb.log({"confusion_matrix": wandb.plot.confusion_matrix(
        probs=None, y_true=test_true, preds=test_preds, class_names=CLASSES
    )})
    wandb.log({"roc": wandb.plot.roc_curve(test_true, test_probs, CLASSES)})
    wandb.log({"pr": wandb.plot.pr_curve(test_true, test_probs, CLASSES)})

# --- Save Results to Files ---
eval_dir = "evaluation_results"
os.makedirs(eval_dir, exist_ok=True)

# Save classification report
report = classification_report(test_true, test_preds, target_names=CLASSES, output_dict=True)
report_df = pd.DataFrame(report).transpose()
report_df.to_csv(os.path.join(eval_dir, "classification_report.csv"))

# Save confusion matrix
conf_mat_df = pd.DataFrame(conf_mat, index=CLASSES, columns=CLASSES)
conf_mat_df.to_csv(os.path.join(eval_dir, "confusion_matrix.csv"))

In [None]:
# --- Cell 21: Loading a Trained Model ---
# Load the best model checkpoint.

# Find the best checkpoint file
checkpoint_dir = "finetune_checkpoints"
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "*.ckpt"))
if checkpoint_files:
    # Find the checkpoint with the lowest validation loss
    best_checkpoint = min(checkpoint_files, key=lambda x: float(x.split("=")[-1].split(".ckpt")[0]))
    print(f"Loading best checkpoint: {best_checkpoint}")

    # Load the best model checkpoint
    loaded_model = FineTuneModel.load_from_checkpoint(best_checkpoint, model=get_vit_model(num_classes=NUM_CLASSES))
    loaded_model.eval()  # Set to evaluation mode
    loaded_model.to(DEVICE) # Move to the correct device
else:
    print("No checkpoint files found.  Make sure you trained the model and checkpoints were saved.")
    loaded_model = None # Set to None to avoid errors

In [None]:
# --- Cell 22: Prediction on a Single Image ---

def predict_single_image(model, image_path, transform):
    """Predicts the class of a single image.

    Args:
        model: The trained PyTorch model.
        image_path: Path to the image file.
        transform: The transformation to apply to the image.

    Returns:
        predicted_class: The predicted class index.
        probabilities: A tensor of class probabilities.
    """
    image = Image.open(image_path).convert("RGB")
    image = transform(image=np.array(image))['image'] # Apply transformations
    image = image.unsqueeze(0).to(DEVICE)  # Add batch dimension and move to device

    model.eval()  # Ensure the model is in evaluation mode
    with torch.no_grad():
        output = model(image)
        probabilities = F.softmax(output, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()

    return predicted_class, probabilities.cpu().numpy()


# Example usage (check if the model loaded correctly):
if loaded_model is not None:
    image_path_to_predict = "/path/to/your/image.jpg"   # CHANGE THIS to a real image path
    if os.path.exists(image_path_to_predict):
        predicted_class, probabilities = predict_single_image(loaded_model.model, image_path_to_predict, val_transforms)

        print(f"Predicted Class: {CLASSES[predicted_class]}")
        print(f"Probabilities: {probabilities}")

        # Display the image (optional):
        img = Image.open(image_path_to_predict)
        plt.imshow(img)
        plt.title(f"Predicted: {CLASSES[predicted_class]}")
        plt.show()
    else:
        print(f"Error: Image file not found at {image_path_to_predict}")
else:
    print("Error: Model was not loaded successfully. Cannot make predictions.")