In [1]:
# Imports (common to many phases - consolidate at the top of your file)
import os
import sys  # For exiting on errors
import glob
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from skimage.segmentation import slic
from skimage.color import label2rgb

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights, vit_b_16, ViT_B_16_Weights

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report

# For experiment tracking (choose one - example with Weights & Biases)
import wandb

# For hyperparameter optimization (example with Optuna)
import optuna

# For image augmentations (albumentations is very powerful)
import albumentations as A
from albumentations.pytorch import ToTensorV2

# For Grad-CAM
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

In [2]:
import torch
print(torch.cuda.is_available())  # Should print True
print(torch.cuda.device_count())   # Should print the number of GPUs
print(torch.cuda.get_device_name(0)) # Should print your GPU name

True
1
NVIDIA GeForce RTX 3050 6GB Laptop GPU


In [6]:
wandb.init(project="vit", entity="w2sgarnav", name="w2sgarnav-vit", mode="online")

In [9]:
# --- Configuration (Constants and Hyperparameters) ---
#  -- Dataset --
DATASET_ROOT = "/home/w2sg-arnav/msusir/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection"  # YOUR DATASET PATH
ORIGINAL_DIR = os.path.join(
    DATASET_ROOT, "Original Dataset"
)
AUGMENTED_DIR = os.path.join(
    DATASET_ROOT, "Augmented Dataset"
) # Use if you have augmentations

CLASSES = [
    "Bacterial Blight",
    "Curl Virus",
    "Healthy Leaf",
    "Herbicide Growth Damage",
    "Leaf Hopper Jassids",
    "Leaf Redding",
    "Leaf Variegation",
]
NUM_CLASSES = len(CLASSES)

# -- Training --
IMAGE_SIZE = (224, 224)  # Start with 224x224
BATCH_SIZE = 32
LEARNING_RATE = 1e-4  # Initial learning rate
EPOCHS = 30
NUM_WORKERS = 6  # Number of data loading workers (adjust based on your system)
VAL_SIZE = 0.2  # Proportion of data for validation
TEST_SIZE = 0.2  # Proportion of data for testing
RANDOM_STATE = 42  # For reproducibility
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [10]:
# Cell 2: Custom Dataset Class (CottonLeafDataset)

class CottonLeafDataset(Dataset):
    """
    Custom Dataset for loading cotton leaf images and labels.
    Handles reading images from a directory structure, applying
    optional transformations, and providing data to PyTorch's DataLoader.
    """

    def __init__(self, data_dir, transform=None, class_names=CLASSES):
        """
        Initializes the dataset.

        Args:
            data_dir (str): Path to the root directory of the dataset
                           (e.g., ORIGINAL_DIR).
            transform (callable, optional): Optional transform to be applied
                on a sample (e.g., from torchvision.transforms or
                albumentations).
            class_names (list): List of class names.
        """
        self.data_dir = data_dir
        self.transform = transform
        self.class_names = class_names
        self.image_paths, self.labels = self._load_data()


    def _load_data(self):
        """
        Loads image paths and corresponding labels from the dataset directory.

        Returns:
            tuple: (list of image paths, list of corresponding labels).
        """
        image_paths = []
        labels = []
        for i, class_name in enumerate(self.class_names):
            class_dir = os.path.join(self.data_dir, class_name)
            if not os.path.isdir(class_dir):
                print(f"Warning: Directory not found: {class_dir}")
                continue  # Skip this class if the directory is missing

            for image_path in glob.glob(os.path.join(class_dir, "*.jpg")) + glob.glob(
                os.path.join(class_dir, "*.jpeg")) + glob.glob(os.path.join(class_dir, "*.png")):
                image_paths.append(image_path)
                labels.append(i)  # Use numerical labels (0, 1, 2, ...)
        return image_paths, labels

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Gets a single sample (image and label) from the dataset.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: (image, label), where image is a PyTorch tensor and
                   label is an integer.  Returns a black image and a
                   default label if the image can't be loaded.
        """
        image_path = self.image_paths[idx]
        try:
            image = Image.open(image_path).convert("RGB")  # Load image and convert to RGB
        except (IOError, FileNotFoundError) as e:
            print(f"Error loading image: {image_path} - {e}. Returning a blank image.")
            # Return a black image and a default label (0)
            return torch.zeros((3, *IMAGE_SIZE)), torch.tensor(0)

        label = self.labels[idx]  # Get the label

        if self.transform:
            image = self.transform(image)  # Apply transformations

        return image, label

In [11]:
# Cell 3: Data Splitting Function

def create_data_splits(data_dir, val_size=0.2, test_size=0.2, random_state=42):
    """Creates train/validation/test splits stratified by class."""
    image_files, labels = [], []
    for i, class_name in enumerate(CLASSES):
      class_path = os.path.join(data_dir, class_name)
      for file_path in glob.glob(os.path.join(class_path, "*")):
          if file_path.lower().endswith(('.png', '.jpg', '.jpeg')):  #check file extension
            image_files.append(file_path)
            labels.append(i)

    # First split: train + temp (for val and test)
    train_files, temp_files, train_labels, temp_labels = train_test_split(
        image_files, labels, test_size=(val_size + test_size), random_state=random_state, stratify=labels
    )

    # Second split: temp -> val + test
    val_files, test_files, val_labels, test_labels = train_test_split(
        temp_files, temp_labels, test_size=(test_size / (val_size + test_size)), random_state=random_state, stratify=temp_labels
    )
    return train_files, val_files, test_files, train_labels, val_labels, test_labels

In [12]:
# Cell 4: Data Augmentation and Preprocessing (Transforms)
import albumentations as A
from albumentations.pytorch import ToTensorV2

def get_transforms(image_size=IMAGE_SIZE, train=True):
    """
    Defines data augmentation and preprocessing transforms using Albumentations.

    Args:
        image_size (tuple):  The desired output image size (height, width).
        train (bool):  If True, returns transforms for training (including
            augmentations). If False, returns transforms for validation/testing
            (no augmentations).

    Returns:
        albumentations.core.composition.Compose:  An Albumentations Compose
        object containing the defined transformations.
    """
    if train:
        return A.Compose([
            A.Resize(height=image_size[0], width=image_size[1]),
            # Augmentations (for training only):
            A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
            A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            # Normalization (ImageNet statistics):
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),  # Convert to PyTorch tensor
        ])
    else:
        # Validation/Test (no augmentation, just resize and normalize)
        return A.Compose([
            A.Resize(height=image_size[0], width=image_size[1]),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])

# Custom transform to use PIL and albumentation together.
class PILTransform(object):
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, img):
        img_np = np.array(img) # PIL -> NumPy
        augmented = self.transform(image = img_np) #pass to albumentations
        return augmented['image'] #returns a tensor.

In [17]:
# Cell 5: Image Preprocessing Functions (Revised normalize_orientation)

def normalize_orientation(image, mask):
    """
    Normalizes the orientation of the leaf using contour fitting and rotation.
    Handles cases where the contour is too small or invalid.
    """
    image_np = np.array(image)
    mask_np = np.array(mask)
    contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    if contours:  # Check if any contours were found
        largest_contour = max(contours, key=cv2.contourArea)

        # --- Robustness Checks ---
        min_area = 100  # Minimum contour area threshold (adjust as needed)
        if cv2.contourArea(largest_contour) < min_area:
            print("Contour area too small, skipping orientation normalization.")
            return image, mask  # Return original image and mask

        if len(largest_contour) < 5:  # fitEllipse needs at least 5 points
            print("Contour has too few points, skipping orientation normalization.")
            return image, mask

        try:
            # --- Contour Approximation (Optional, but can help) ---
            epsilon = 0.01 * cv2.arcLength(largest_contour, True)  # 1% of perimeter
            approx_contour = cv2.approxPolyDP(largest_contour, epsilon, True)

            # --- Ellipse Fitting and Rotation ---
            ellipse = cv2.fitEllipse(approx_contour)  # Use the approximated contour
            angle = ellipse[2]
            (h, w) = image_np.shape[:2]
            center = (w // 2, h // 2)
            rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)

            # Use INTER_CUBIC for image, INTER_NEAREST for mask (important!)
            rotated_image = cv2.warpAffine(image_np, rotation_matrix, (w, h), flags=cv2.INTER_CUBIC)
            rotated_mask = cv2.warpAffine(mask_np, rotation_matrix, (w, h), flags=cv2.INTER_NEAREST)

            return Image.fromarray(rotated_image), Image.fromarray(rotated_mask)

        except cv2.error as e:
            print(f"OpenCV error during orientation normalization: {e}")
            print("Skipping orientation normalization for this image.")
            return image, mask  # Return original image and mask

    else:
        print("No contours found, skipping orientation normalization.")
        return image, mask # Return original image

    return image, mask # Return original image if no contours.

In [18]:
# Cell 6: Data Loader Creation

def create_data_loaders(train_files, val_files, test_files, train_labels, val_labels, test_labels, train_transform, val_transform, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS):
    """Creates PyTorch DataLoaders for training, validation, and testing."""

    # Create Datasets
    train_dataset = CottonLeafDataset(ORIGINAL_DIR, transform=train_transform)  # Use the full dataset
    train_dataset.image_paths = train_files   #then override .image_paths and labels
    train_dataset.labels = train_labels

    val_dataset = CottonLeafDataset(ORIGINAL_DIR, transform=val_transform)  # Use the full dataset
    val_dataset.image_paths = val_files   #then override .image_paths and labels
    val_dataset.labels = val_labels


    test_dataset = CottonLeafDataset(ORIGINAL_DIR, transform=val_transform)  # Use the full dataset
    test_dataset.image_paths = test_files   #then override .image_paths and labels
    test_dataset.labels = test_labels

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory = True)

    return train_loader, val_loader, test_loader

In [19]:
# Cell 7: Data Loading and Preprocessing Pipeline (Putting it all together)

# 1. Create data splits
train_files, val_files, test_files, train_labels, val_labels, test_labels = create_data_splits(
    ORIGINAL_DIR, val_size=VAL_SIZE, test_size=TEST_SIZE, random_state=RANDOM_STATE
)

# 2. Get transforms
train_transforms = get_transforms(train=True)
val_transforms = get_transforms(train=False)  # Same as test transforms

# 3. Apply preprocessing and combine with albumentations.
train_transforms = transforms.Compose([
    PILTransform(train_transforms), #apply albumentations first
    transforms.Lambda(preprocess_image) #then preprocessing
])
val_transforms = transforms.Compose([
    PILTransform(val_transforms),  # Use same transforms for val/test
    transforms.Lambda(preprocess_image)
])
# 4. Create data loaders
train_loader, val_loader, test_loader = create_data_loaders(
  train_files, val_files, test_files, train_labels, val_labels, test_labels,
    train_transforms, val_transforms,
    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS
)

print(f"Number of training samples: {len(train_loader.dataset)}")
print(f"Number of validation samples: {len(val_loader.dataset)}")
print(f"Number of test samples: {len(test_loader.dataset)}")

Number of training samples: 1282
Number of validation samples: 427
Number of test samples: 428


  original_init(self, **validated_kwargs)


In [20]:
# Cell 8: Data Visualization (Optional, but highly recommended)

def visualize_data(dataloader, num_images=5, title="Data Samples"):
    """Visualizes a few sample images from a DataLoader."""
    images, labels = next(iter(dataloader))  # Get a batch of images and labels

    plt.figure(figsize=(12, 6))
    for i in range(num_images):
        plt.subplot(1, num_images, i + 1)
        # Convert from tensor to numpy array and transpose dimensions
        #   PyTorch tensors are (C, H, W), matplotlib expects (H, W, C)
        img = images[i].numpy().transpose((1, 2, 0))

        # Un-normalize the image (reverse the normalization)
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img = std * img + mean  # Un-normalize
        img = np.clip(img, 0, 1)  # Clip values to [0, 1]

        plt.imshow(img)
        plt.title(CLASSES[labels[i]])
        plt.axis("off")
    plt.suptitle(title)
    plt.show()

# Visualize training data (after augmentations and preprocessing)
visualize_data(train_loader, title="Training Data Samples (Augmented & Preprocessed)")

# Visualize validation data (after preprocessing)
visualize_data(val_loader, title="Validation Data Samples (Preprocessed)")

error: Caught error in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/w2sg-arnav/anaconda3/envs/cotton_env/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/home/w2sg-arnav/anaconda3/envs/cotton_env/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/w2sg-arnav/anaconda3/envs/cotton_env/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_121257/452904154.py", line 76, in __getitem__
    image = self.transform(image)  # Apply transformations
  File "/home/w2sg-arnav/anaconda3/envs/cotton_env/lib/python3.9/site-packages/torchvision/transforms/transforms.py", line 95, in __call__
    img = t(img)
  File "/home/w2sg-arnav/anaconda3/envs/cotton_env/lib/python3.9/site-packages/torchvision/transforms/transforms.py", line 479, in __call__
    return self.lambd(img)
  File "/tmp/ipykernel_121257/1742540200.py", line 42, in preprocess_image
    segmented_image, leaf_mask = segment_leaf(image)
  File "/tmp/ipykernel_121257/1742540200.py", line 6, in segment_leaf
    image_lab = cv2.cvtColor(image_np, cv2.COLOR_RGB2Lab)
cv2.error: OpenCV(4.11.0) /io/opencv/modules/imgproc/src/color.simd_helpers.hpp:92: error: (-15:Bad number of channels) in function 'cv::impl::{anonymous}::CvtHelper<VScn, VDcn, VDepth, sizePolicy>::CvtHelper(cv::InputArray, cv::OutputArray, int) [with VScn = cv::impl::{anonymous}::Set<3, 4>; VDcn = cv::impl::{anonymous}::Set<3>; VDepth = cv::impl::{anonymous}::Set<0, 5>; cv::impl::{anonymous}::SizePolicy sizePolicy = cv::impl::<unnamed>::NONE; cv::InputArray = const cv::_InputArray&; cv::OutputArray = const cv::_OutputArray&]'
> Invalid number of channels in input image:
>     'VScn::contains(scn)'
> where
>     'scn' is 224

