In [1]:
# --- Cell 1: Imports ---
import os
import sys
import glob
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from skimage.segmentation import slic
from skimage.color import label2rgb
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights, vit_b_16, ViT_B_16_Weights

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report, precision_recall_fscore_support

import wandb  # For experiment tracking
import optuna  # For hyperparameter optimization (used later)

import albumentations as A  # For image augmentations
from albumentations.pytorch import ToTensorV2

from pytorch_grad_cam import GradCAM  # For Grad-CAM (used later)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# New imports for progressive resizing and attention visualization
from torchvision.transforms.functional import resize
from torchvision.utils import make_grid
from einops import rearrange  # For tensor manipulation

In [None]:
# --- Cell 2: Configuration ---
import os
import torch

# --- Dataset ---
DATASET_ROOT = "/home/w2sg-arnav/cotton-disease/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection"
ORIGINAL_DIR = os.path.join(DATASET_ROOT, "Original Dataset")
AUGMENTED_DIR = os.path.join(DATASET_ROOT, "Augmented Dataset")

CLASSES = [
    "Bacterial Blight",
    "Curl Virus",
    "Healthy Leaf",
    "Herbicide Growth Damage",
    "Leaf Hopper Jassids",
    "Leaf Redding",
    "Leaf Variegation",
]
NUM_CLASSES = len(CLASSES)
CLASS_MAP = {i: name for i, name in enumerate(CLASSES)}

# --- Training ---
IMAGE_SIZE = (224, 224)  # ViT expects 224x224 images  # Start with 128x128 for progressive resizing
BATCH_SIZE = 32
LEARNING_RATE = 1e-4
EPOCHS = 30
NUM_WORKERS = 6
VAL_SIZE = 0.2
TEST_SIZE = 0.2
RANDOM_STATE = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Model ---
MODEL_NAME = "vit_b_16"
PRETRAINED = True
CHECKPOINT_DIR = "checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# --- Progressive Resizing ---
# --- Progressive Resizing ---
PROGRESSIVE_SIZES = [(128, 128), (224, 224), (384, 384)]  # Progressive resizing steps
CURRENT_SIZE_INDEX = 0  # Start with the smallest size

In [3]:
# --- Cell 3: Data Loading Functions ---
class CottonLeafDataset(Dataset):
    def __init__(self, data_dir, transform=None, class_names=CLASSES):
        self.data_dir = data_dir
        self.transform = transform
        self.class_names = class_names
        self.image_paths, self.labels = self._load_data()

    def _load_data(self):
        image_paths = []
        labels = []
        for i, class_name in enumerate(self.class_names):
            class_dir = os.path.join(self.data_dir, class_name)
            if not os.path.isdir(class_dir):
                print(f"Warning: Directory not found: {class_dir}")
                continue

            for image_path in glob.glob(os.path.join(class_dir, "*.jpg")) + glob.glob(
                os.path.join(class_dir, "*.jpeg")) + glob.glob(os.path.join(class_dir, "*.png")):
                image_paths.append(image_path)
                labels.append(i)
        return image_paths, labels

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        try:
            image = Image.open(image_path).convert("RGB")
        except (IOError, FileNotFoundError) as e:
            print(f"Error loading image: {image_path} - {e}. Returning a blank image.")
            return torch.zeros((3, *IMAGE_SIZE)), torch.tensor(0)

        label = self.labels[idx]

        if self.transform:
            # Convert PIL image to NumPy array for Albumentations
            image_np = np.array(image)
            # Apply Albumentations transform with named argument
            transformed = self.transform(image=image_np)
            image = transformed["image"]  # Extract the transformed image

        return image, label

def create_data_splits(data_dir, val_size=0.2, test_size=0.2, random_state=42):
    image_files, labels = [], []
    for i, class_name in enumerate(CLASSES):
        class_path = os.path.join(data_dir, class_name)
        for file_path in glob.glob(os.path.join(class_path, "*")):
            if file_path.lower().endswith(('.png', '.jpg', '.jpeg')):
                image_files.append(file_path)
                labels.append(i)

    train_files, temp_files, train_labels, temp_labels = train_test_split(
        image_files, labels, test_size=(val_size + test_size), random_state=random_state, stratify=labels
    )
    val_files, test_files, val_labels, test_labels = train_test_split(
        temp_files, temp_labels, test_size=(test_size / (val_size + test_size)), random_state=random_state, stratify=temp_labels
    )
    return train_files, val_files, test_files, train_labels, val_labels, test_labels

def save_data_splits(train_files, val_files, test_files, output_dir="."):
    os.makedirs(output_dir, exist_ok=True)
    def _write_list_to_file(file_list, filename):
        with open(os.path.join(output_dir, filename), "w") as f:
            for item in file_list:
                f.write(f"{item}\n")

    _write_list_to_file(train_files, "train_files.txt")
    _write_list_to_file(val_files, "val_files.txt")
    _write_list_to_file(test_files, "test_files.txt")
    print(f"Data splits saved to {output_dir}")

def load_data_splits(output_dir="."):
    def _read_list_from_file(filename):
        with open(os.path.join(output_dir, filename), "r") as f:
            return [line.strip() for line in f]

    train_files = _read_list_from_file("train_files.txt")
    val_files = _read_list_from_file("val_files.txt")
    test_files = _read_list_from_file("test_files.txt")

    return train_files, val_files, test_files

def get_transforms(image_size=PROGRESSIVE_SIZES[CURRENT_SIZE_INDEX], train=True):
    if train:
        return A.Compose([
            A.Resize(height=image_size[0], width=image_size[1]),  # Resize to current progressive size
            A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
            A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])
    else:
        return A.Compose([
            A.Resize(height=image_size[0], width=image_size[1]),  # Resize to current progressive size
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])

def create_data_loaders(data_dir, train_transform, val_transform, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS):
    train_files, val_files, test_files = load_data_splits()
    train_labels = [CLASSES.index(filepath.split(os.sep)[-2]) for filepath in train_files]
    val_labels = [CLASSES.index(filepath.split(os.sep)[-2]) for filepath in val_files]
    test_labels = [CLASSES.index(filepath.split(os.sep)[-2]) for filepath in test_files]

    train_dataset = CottonLeafDataset(data_dir, transform=train_transform)
    train_dataset.image_paths = train_files
    train_dataset.labels = train_labels

    val_dataset = CottonLeafDataset(data_dir, transform=val_transform)
    val_dataset.image_paths = val_files
    val_dataset.labels = val_labels

    test_dataset = CottonLeafDataset(data_dir, transform=val_transform)
    test_dataset.image_paths = test_files
    test_dataset.labels = test_labels

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    return train_loader, val_loader, test_loader

In [4]:
# --- Cell 4: Model Definition ---
def get_model(model_name=MODEL_NAME, pretrained=PRETRAINED, num_classes=NUM_CLASSES):
    if model_name == "vit_b_16":
        weights = ViT_B_16_Weights.DEFAULT if pretrained else None
        model = vit_b_16(weights=weights)
        model.heads = nn.Linear(model.heads[0].in_features, num_classes)
    else:
        raise ValueError(f"Unsupported model name: {model_name}")

    return model.to(DEVICE)

In [5]:
# --- Cell 5: Training Loop ---
def train_model(model, train_loader, val_loader, learning_rate=LEARNING_RATE, epochs=EPOCHS, checkpoint_dir=CHECKPOINT_DIR):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    best_val_loss = float('inf')

    for epoch in range(epochs):
        # Update image size for progressive resizing
        global CURRENT_SIZE_INDEX
        if epoch > 0 and epoch % (epochs // len(PROGRESSIVE_SIZES)) == 0:
            CURRENT_SIZE_INDEX = min(CURRENT_SIZE_INDEX + 1, len(PROGRESSIVE_SIZES) - 1)
            print(f"Updating image size to {PROGRESSIVE_SIZES[CURRENT_SIZE_INDEX]}")

            # Update transforms with the new image size
            train_transforms = get_transforms(image_size=PROGRESSIVE_SIZES[CURRENT_SIZE_INDEX], train=True)
            val_transforms = get_transforms(image_size=PROGRESSIVE_SIZES[CURRENT_SIZE_INDEX], train=False)

            # Recreate DataLoaders with the new transforms
            train_loader, val_loader, _ = create_data_loaders(ORIGINAL_DIR, train_transforms, val_transforms, BATCH_SIZE, NUM_WORKERS)

        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0

        for i, (images, labels) in enumerate(train_loader):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

        train_loss = running_loss / len(train_loader.dataset)
        train_accuracy = 100 * correct_train / total_train

        # Validation Loop
        model.eval()
        val_running_loss = 0.0
        correct_val = 0
        total_val = 0
        val_preds = []
        val_true = []

        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(DEVICE)
                labels = labels.to(DEVICE)
                outputs = model(images)
                loss = criterion(outputs, labels)

                val_running_loss += loss.item() * images.size(0)
                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

                val_preds.extend(predicted.cpu().numpy())
                val_true.extend(labels.cpu().numpy())

        val_loss = val_running_loss / len(val_loader.dataset)
        val_accuracy = 100 * correct_val / total_val

        # Logging to W&B
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_accuracy": train_accuracy,
            "val_loss": val_loss,
            "val_accuracy": val_accuracy,
            "image_size": PROGRESSIVE_SIZES[CURRENT_SIZE_INDEX][0],  # Log current image size
        })
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%, Image Size: {PROGRESSIVE_SIZES[CURRENT_SIZE_INDEX]}")

    print("Finished Training")

In [6]:
# --- Cell 6: Main Execution ---
import wandb

if wandb.run is None:
    run = wandb.init(project="vit", entity="w2sgarnav", name="w2sgarnav-vit-phase2", mode="offline")

# Get Transforms
train_transforms = get_transforms(image_size=IMAGE_SIZE, train=True)
val_transforms = get_transforms(image_size=IMAGE_SIZE, train=False)

# Create DataLoaders
train_loader, val_loader, _ = create_data_loaders(ORIGINAL_DIR, train_transforms, val_transforms, BATCH_SIZE, NUM_WORKERS)

# Get Model
model = get_model()

# Train Model
train_model(model, train_loader, val_loader)
wandb.finish()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


  original_init(self, **validated_kwargs)


Epoch 1/15, Train Loss: 0.7287, Train Acc: 75.43%, Val Loss: 0.4039, Val Acc: 85.01%, Image Size: (128, 128)


KeyboardInterrupt: 