# 01 - Data Preparation & Transforms

This notebook handles the **data pipeline** for the plant disease classification project:

1. **Install** dependencies (Kaggle API credentials — a `kaggle.json` file or the environment-variable method — must already be configured)
2. **Download** the dataset from Kaggle (YOLO detection format)
3. **Organise** images from YOLO format into class-per-folder layout
4. **Create subset** by sampling N images per class for train/val/test
5. **Define transforms** (CLAHE preprocessing + augmentations)
6. **Build datasets & dataloaders** and verify with sample visualisation

**Run this notebook first**, then open `02_train_convnextv2.ipynb` for training.

**Dataset:** [Plant Disease Detection](https://www.kaggle.com/datasets/ironwolf437/plant-disease-detection-dataset) (7 classes, YOLO format)

## 1. Setup & Installation

In [2]:
!pip install -q kaggle albumentations pyyaml

In [None]:
# Ensure CWD is the project root
# Works on both Colab and local Jupyter
import os
from pathlib import Path

def find_project_root():
    """Locate the project root, i.e. the nearest directory holding data/ or configs/.

    Candidates are tried in this order and the first match is returned:

    1. the current working directory,
    2. each of its ancestors, nearest first (this covers running from
       ``notebooks/`` or any deeper sub-directory),
    3. the conventional Colab checkout ``/content/plant-disease-classification``.

    Returns:
        Path: the first candidate containing a ``data`` or ``configs``
        directory, or the resolved current working directory when none does.
    """
    def _has_marker(candidate):
        # A directory qualifies as soon as one marker sub-directory exists.
        return any((candidate / marker).is_dir() for marker in ("data", "configs"))

    start = Path(os.getcwd()).resolve()
    candidates = [
        start,
        *start.parents,
        Path("/content/plant-disease-classification"),
    ]

    for candidate in candidates:
        if _has_marker(candidate):
            return candidate

    # Nothing matched: stay where we are.
    return start

# Resolve the project root once and make it the working directory, so the
# relative paths used later in the notebook (e.g. "data/raw") resolve from it.
PROJECT_ROOT = find_project_root()
os.chdir(PROJECT_ROOT)
print(f"Working directory: {os.getcwd()}")

## 2. Imports

In [3]:
import os
import random
import shutil
import zipfile
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import yaml
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
from torch.utils.data import DataLoader, Dataset

## 3. Configuration

All data-related settings in one place. Adjust subset sizes, image size, or CLAHE parameters here.

In [4]:
@dataclass
class DataConfig:
    """Configuration for data preparation and transforms.

    One instance is shared by DataPreparer (paths, subset sizes, seed),
    TransformFactory (image size, CLAHE settings) and the DataLoader
    construction (batch size, workers, pinned memory).
    """

    # --- Kaggle ---
    # "<owner>/<dataset>" identifier passed to the Kaggle API download call.
    kaggle_dataset: str = "ironwolf437/plant-disease-detection-dataset"

    # --- Paths ---
    # All paths are relative; they resolve against the current working directory.
    raw_dir: str = "data/raw"              # downloaded + unzipped YOLO dataset
    organised_dir: str = "data/organised"  # class-per-folder copy of the full dataset
    subset_dir: str = "data/subset"        # sampled subset actually loaded below

    # --- Subset sampling ---
    # Upper bound on images copied per class for each split; classes with
    # fewer images are copied in full.
    train_per_class: int = 50
    val_per_class: int = 15
    test_per_class: int = 10

    # --- Transforms ---
    image_size: int = 224                        # images are resized to image_size x image_size
    clahe_clip_limit: float = 2.0                # forwarded to A.CLAHE(clip_limit=...)
    clahe_tile_grid: Tuple[int, int] = (8, 8)    # forwarded to A.CLAHE(tile_grid_size=...)

    # --- DataLoader ---
    batch_size: int = 32
    num_workers: int = 2
    pin_memory: bool = True

    # --- Reproducibility ---
    # Seeds the global `random` / NumPy RNGs and the subset sampler.
    seed: int = 42


# Build the default configuration and seed Python's and NumPy's global RNGs
# before any sampling happens; printing the config records the settings used.
cfg = DataConfig()
random.seed(cfg.seed)
np.random.seed(cfg.seed)
print(cfg)

DataConfig(kaggle_dataset='ironwolf437/plant-disease-detection-dataset', raw_dir='data/raw', organised_dir='data/organised', subset_dir='data/subset', train_per_class=50, val_per_class=15, test_per_class=10, image_size=224, clahe_clip_limit=2.0, clahe_tile_grid=(8, 8), batch_size=32, num_workers=2, pin_memory=True, seed=42)


## 5. DataPreparer

Handles the full data pipeline in three steps:

| Step | Method | What it does |
|------|--------|--------------|
| 1 | `download()` | Download & unzip Kaggle dataset (YOLO format) |
| 2 | `organise()` | Parse YOLO labels → sort images into class-per-folder layout |
| 3 | `create_subset()` | Random sample N images per class per split |

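Each step is **idempotent** — it is skipped if its output already exists. Pass `force=True` to redo `organise()` and `create_subset()`; `download()` takes no `force` argument, so empty `data/raw` manually to re-download.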

In [6]:
# File extensions (lower-case, including the dot) that count as images.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}


class DataPreparer:
    """Download, organise, and subset the plant disease dataset.

    The Kaggle dataset is in YOLO detection format (images/ + labels/).
    This class converts it to class-per-folder layout suitable for
    PyTorch ImageFolder-style loading.

    Pipeline (see ``run``): ``download`` -> ``organise`` -> ``create_subset``.
    """

    def __init__(self, cfg: "DataConfig"):
        """Store the config and resolve the three working directories.

        Args:
            cfg: object exposing ``raw_dir``, ``organised_dir``, ``subset_dir``,
                ``kaggle_dataset``, ``seed`` and the ``*_per_class`` limits.
        """
        self.cfg = cfg
        self.raw_dir = Path(cfg.raw_dir)
        self.organised_dir = Path(cfg.organised_dir)
        self.subset_dir = Path(cfg.subset_dir)

    # ----- Download -----

    def download(self):
        """Download and unzip the dataset from Kaggle.

        The download is skipped when the archive is already present or when
        ``raw_dir`` already contains anything. A present archive is always
        extracted into ``raw_dir`` and then deleted.

        Raises:
            Whatever the Kaggle client raises, e.g. ``OSError`` when no API
            credentials are configured.
        """
        self.raw_dir.mkdir(parents=True, exist_ok=True)
        # The Kaggle client names the archive after the dataset slug (the part
        # after "owner/"), so derive it from the config instead of hard-coding.
        dataset_slug = self.cfg.kaggle_dataset.rsplit("/", 1)[-1]
        zip_path = self.raw_dir / f"{dataset_slug}.zip"

        if not zip_path.exists() and not any(self.raw_dir.iterdir()):
            print(f"[download] Downloading from Kaggle: {self.cfg.kaggle_dataset}")
            # Imported lazily: importing the Kaggle client can already fail
            # when credentials are missing, and it is only needed here.
            from kaggle.api.kaggle_api_extended import KaggleApi
            api = KaggleApi()
            api.authenticate()
            api.dataset_download_files(
                self.cfg.kaggle_dataset, path=str(self.raw_dir), unzip=False
            )
        else:
            print(f"[download] Already present in {self.raw_dir}")

        if zip_path.exists():
            print("[download] Unzipping...")
            with zipfile.ZipFile(zip_path, "r") as zf:
                zf.extractall(self.raw_dir)
            zip_path.unlink()
            print("[download] Done.")

    # ----- YOLO -> class folders -----

    def _parse_class_names(self) -> List[str]:
        """Read class names from data.yaml in the raw directory.

        Returns:
            The value stored under the ``names`` key of ``raw_dir/data.yaml``.

        Raises:
            FileNotFoundError: if ``data.yaml`` does not exist.
            KeyError: if it has no ``names`` entry.
        """
        with open(self.raw_dir / "data.yaml", encoding="utf-8") as f:
            data_cfg = yaml.safe_load(f)
        names = data_cfg["names"]
        print(f"[yolo] {len(names)} classes: {names}")
        return names

    @staticmethod
    def _get_image_class(label_path: Path) -> Optional[int]:
        """Read YOLO label file and return the dominant (most frequent) class id.

        Each label line: class_id x_center y_center width height

        Lines whose first token is not an integer are ignored, so a single
        corrupt annotation cannot abort the whole conversion half-way.

        Returns:
            The most frequent class id, or None if the file is missing, empty,
            or contains no parseable class id.
        """
        if not label_path.exists():
            return None
        class_ids = []
        with open(label_path) as f:
            for line in f:
                parts = line.split()
                if not parts:
                    continue
                try:
                    class_ids.append(int(parts[0]))
                except ValueError:
                    # Malformed line: skip it rather than crash the pipeline.
                    continue
        if not class_ids:
            return None
        return Counter(class_ids).most_common(1)[0][0]

    def organise(self, force: bool = False):
        """Convert YOLO format to class-per-folder layout.

        Maps: raw/{train,valid,test}/images -> organised/{train,val,test}/{class}/

        Each image is copied into the folder of its dominant class. Images
        without a usable label, or whose class id is outside
        ``0 .. len(class_names) - 1``, are counted as skipped.

        Args:
            force: delete an existing ``organised_dir`` and rebuild it. Without
                it, the step is skipped whenever ``organised_dir`` is non-empty.
        """
        if not force and self.organised_dir.exists() and any(self.organised_dir.iterdir()):
            print(f"[organise] Already done at {self.organised_dir}, skipping.")
            return

        if force and self.organised_dir.exists():
            shutil.rmtree(self.organised_dir)

        class_names = self._parse_class_names()
        split_map = {"train": "train", "valid": "val", "test": "test"}

        for raw_split, out_split in split_map.items():
            images_dir = self.raw_dir / raw_split / "images"
            labels_dir = self.raw_dir / raw_split / "labels"
            if not images_dir.exists():
                continue

            counts = Counter()
            skipped = 0

            for img_file in sorted(images_dir.iterdir()):
                if img_file.suffix.lower() not in IMAGE_EXTS:
                    continue
                label_file = labels_dir / (img_file.stem + ".txt")
                class_id = self._get_image_class(label_file)

                # Reject negative ids explicitly: Python's negative indexing
                # would otherwise silently file the image under a wrong class.
                if class_id is None or not 0 <= class_id < len(class_names):
                    skipped += 1
                    continue

                class_name = class_names[class_id]
                dst = self.organised_dir / out_split / class_name
                dst.mkdir(parents=True, exist_ok=True)
                shutil.copy2(img_file, dst / img_file.name)
                counts[class_name] += 1

            total = sum(counts.values())
            print(f"[organise] {out_split}: {total} images ({skipped} skipped)")
            for cls in sorted(counts):
                print(f"    {cls}: {counts[cls]}")

    # ----- Subset -----

    def create_subset(self, force: bool = False):
        """Sample a small subset from the organised dataset.

        For every split and class, at most ``cfg.{split}_per_class`` images are
        copied into ``subset_dir``; smaller classes are copied in full.
        Sampling uses a private ``random.Random(cfg.seed)`` over sorted file
        names, so the selection is reproducible.

        Args:
            force: delete an existing ``subset_dir`` and resample. Without it,
                the step is skipped when ``subset_dir/train`` is non-empty
                (only the train split is checked).
        """
        if (not force
            and (self.subset_dir / "train").exists()
            and any((self.subset_dir / "train").iterdir())):
            print(f"[subset] Already exists at {self.subset_dir}, skipping.")
            self._print_summary()
            return

        if force and self.subset_dir.exists():
            shutil.rmtree(self.subset_dir)

        rng = random.Random(self.cfg.seed)
        limits = {
            "train": self.cfg.train_per_class,
            "val": self.cfg.val_per_class,
            "test": self.cfg.test_per_class,
        }

        for split, n_per_class in limits.items():
            split_src = self.organised_dir / split
            if not split_src.exists():
                continue

            # Hidden directories (e.g. .ipynb_checkpoints) are not classes.
            classes = sorted(
                d.name for d in split_src.iterdir()
                if d.is_dir() and not d.name.startswith(".")
            )
            for cls in classes:
                cls_dir = split_src / cls
                images = sorted(
                    f.name for f in cls_dir.iterdir()
                    if f.is_file() and f.suffix.lower() in IMAGE_EXTS
                )
                sampled = images if len(images) <= n_per_class else rng.sample(images, n_per_class)

                dst = self.subset_dir / split / cls
                dst.mkdir(parents=True, exist_ok=True)
                for img_name in sampled:
                    shutil.copy2(cls_dir / img_name, dst / img_name)

        self._print_summary()

    def _print_summary(self):
        """Print image counts per split.

        Only files with an extension in ``IMAGE_EXTS`` are counted, matching
        the filter used when the subset is created and later loaded.
        """
        print("\n[subset] Summary:")
        for split in ["train", "val", "test"]:
            split_path = self.subset_dir / split
            if not split_path.exists():
                print(f"  {split}: (not found)")
                continue
            class_dirs = [d for d in split_path.iterdir() if d.is_dir()]
            n_images = sum(
                1
                for class_dir in class_dirs
                for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_EXTS
            )
            n_classes = len(class_dirs)
            print(f"  {split}: {n_images} images across {n_classes} classes")

    def run(self, force: bool = False):
        """Execute full pipeline: download -> organise -> subset.

        Args:
            force: forwarded to ``organise`` and ``create_subset``; the
                download step has no force option.
        """
        self.download()
        self.organise(force=force)
        self.create_subset(force=force)

### Run data preparation

In [7]:
data_preparer = DataPreparer(cfg)
data_preparer.run()

[download] Downloading from Kaggle: ironwolf437/plant-disease-detection-dataset


OSError: Could not find kaggle.json. Make sure it's located in /root/.config/kaggle. Or use the environment method. See setup instructions at https://github.com/Kaggle/kaggle-api/

## 6. Dataset & Transforms

**CLAHE** (Contrast Limited Adaptive Histogram Equalization) is applied as a **deterministic preprocessing step** (p=1.0) in both train and val/test to handle greenhouse fog/LED lighting.

| Pipeline | Transforms |
|----------|------------|
| **Train** | Resize -> CLAHE -> HFlip, VFlip, Rotate90, ShiftScaleRotate, ColorJitter, GaussNoise, GaussianBlur -> Normalize -> ToTensor |
| **Val/Test** | Resize -> CLAHE -> Normalize -> ToTensor |

In [None]:
# Per-channel (RGB) mean and standard deviation passed to A.Normalize in the
# transform pipelines and used again in show_batch to undo the normalisation.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


class PlantDiseaseDataset(Dataset):
    """PyTorch Dataset for plant disease images in class-per-folder layout.

    Expects:
        root_dir/
            class_a/
                img1.jpg
            class_b/
                img2.jpg

    Hidden directories (names starting with ".", e.g. ``.ipynb_checkpoints``)
    are never treated as classes, and only files whose extension is in
    ``IMAGE_EXTS`` become samples.

    Args:
        root_dir: directory containing one sub-directory per class.
        transform: optional callable invoked as ``transform(image=ndarray)``
            and returning a mapping with an ``"image"`` entry
            (Albumentations convention).
        classes: optional explicit, ordered list of class names. When given,
            label indices follow this list instead of the folders found on
            disk, so several splits can share one label mapping even if a
            class is absent from one of them. Listed classes without a folder
            simply contribute no samples.

    Raises:
        ValueError: if ``classes`` is given and ``root_dir`` contains a
            (non-hidden) class folder that is not listed in it.
    """

    def __init__(self, root_dir: str, transform=None,
                 classes: Optional[List[str]] = None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples: List[Tuple[str, int]] = []
        self._explicit_classes = classes is not None
        self.classes: List[str] = list(classes) if classes is not None else []
        self.class_to_idx: Dict[str, int] = {}
        self._scan()

    def _scan(self):
        """Discover class folders and collect ``(path, label)`` samples."""
        found = sorted(
            d for d in os.listdir(self.root_dir)
            if not d.startswith(".")
            and os.path.isdir(os.path.join(self.root_dir, d))
        )

        if self._explicit_classes:
            # Refuse to silently drop data that has no label in the mapping.
            unexpected = [d for d in found if d not in self.classes]
            if unexpected:
                raise ValueError(
                    f"Class folders not in the provided class list: {unexpected}"
                )
        else:
            self.classes = found
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}

        self.samples = []
        for cls in found:
            cls_dir = os.path.join(self.root_dir, cls)
            for fname in sorted(os.listdir(cls_dir)):
                ext = os.path.splitext(fname)[1].lower()
                if ext in IMAGE_EXTS:
                    self.samples.append(
                        (os.path.join(cls_dir, fname), self.class_to_idx[cls])
                    )

    def __len__(self):
        """Number of image samples found."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` as RGB and return ``(image, label)``.

        Without a transform the image is an ``H x W x 3`` NumPy array as
        produced by ``np.array`` on the RGB-converted PIL image; with a
        transform it is whatever the transform returns under ``"image"``.
        """
        img_path, label = self.samples[idx]
        # Context manager guarantees the file handle is released promptly,
        # which matters with many DataLoader workers.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))
        if self.transform:
            image = self.transform(image=image)["image"]
        return image, label


class TransformFactory:
    """Factory for the Albumentations pipelines used by the datasets.

    Both pipelines share the same deterministic head (resize + CLAHE) and
    tail (normalise + tensor conversion); the training pipeline additionally
    inserts random augmentations in between.
    """

    def __init__(self, cfg: DataConfig):
        self.cfg = cfg

    def _head(self) -> list:
        """Deterministic preprocessing: square resize followed by CLAHE (p=1.0)."""
        side = self.cfg.image_size
        return [
            A.Resize(side, side),
            A.CLAHE(
                clip_limit=self.cfg.clahe_clip_limit,
                tile_grid_size=self.cfg.clahe_tile_grid,
                p=1.0,
            ),
        ]

    @staticmethod
    def _tail() -> list:
        """Final steps: normalise with the module-level mean/std, convert to tensor."""
        return [
            A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
            ToTensorV2(),
        ]

    def train(self) -> A.Compose:
        """Training pipeline: resize, CLAHE, random augmentations, normalise."""
        steps = self._head()
        # Geometric augmentations.
        steps.append(A.HorizontalFlip(p=0.5))
        steps.append(A.VerticalFlip(p=0.5))
        steps.append(A.RandomRotate90(p=0.5))
        steps.append(
            A.ShiftScaleRotate(
                shift_limit=0.1, scale_limit=0.15, rotate_limit=15, p=0.5
            )
        )
        # Photometric augmentations.
        steps.append(
            A.ColorJitter(
                brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5
            )
        )
        steps.append(A.GaussNoise(p=0.3))
        steps.append(A.GaussianBlur(blur_limit=(3, 7), p=0.2))
        steps.extend(self._tail())
        return A.Compose(steps)

    def val(self) -> A.Compose:
        """Val/test pipeline: resize, CLAHE, normalise — no random augmentation."""
        steps = self._head()
        steps.extend(self._tail())
        return A.Compose(steps)

### Build datasets and dataloaders

In [None]:
tf = TransformFactory(cfg)

# One dataset per split of the sampled subset. Only the training split gets
# the random augmentations; val and test use the deterministic pipeline.
train_dataset = PlantDiseaseDataset(
    os.path.join(cfg.subset_dir, "train"), transform=tf.train()
)
val_dataset = PlantDiseaseDataset(
    os.path.join(cfg.subset_dir, "val"), transform=tf.val()
)
test_dataset = PlantDiseaseDataset(
    os.path.join(cfg.subset_dir, "test"), transform=tf.val()
)

# The training split defines the class list used in the rest of the notebook.
CLASS_NAMES = train_dataset.classes
NUM_CLASSES = len(CLASS_NAMES)

print(f"Classes ({NUM_CLASSES}): {CLASS_NAMES}")
print(f"Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}")

# Settings common to all three loaders; only the training loader shuffles.
_loader_kwargs = dict(
    batch_size=cfg.batch_size,
    num_workers=cfg.num_workers,
    pin_memory=cfg.pin_memory,
)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

## 7. Verify: Sample Images

Visualise a batch of training images to confirm transforms are working correctly.  
Images are de-normalised back to [0, 1] for display.

In [None]:
def show_batch(dataset, class_names, n=8,
               title="Sample Training Images (after transforms)"):
    """Display up to n random samples from the dataset in a grid.

    Args:
        dataset: indexable dataset whose items are ``(image, label)``, where
            ``image`` supports ``.permute(1, 2, 0).numpy()`` (channel-first
            tensor) and was normalised with IMAGENET_MEAN / IMAGENET_STD.
        class_names: sequence mapping a label index to its display name.
        n: maximum number of samples to draw; fewer are shown when the
            dataset is smaller.
        title: figure title.

    The grid is sized from the number of samples actually drawn (at most four
    columns). If there is nothing to show (empty dataset or ``n <= 0``), a
    notice is printed and no figure is created.
    """
    mean = np.array(IMAGENET_MEAN)
    std = np.array(IMAGENET_STD)

    n_show = min(n, len(dataset))
    if n_show <= 0:
        print("[show_batch] Nothing to display (empty dataset or n <= 0).")
        return

    indices = random.sample(range(len(dataset)), n_show)
    cols = min(4, n_show)
    rows = (n_show + cols - 1) // cols  # ceiling division
    # squeeze=False always yields a 2-D array of axes, whatever rows/cols are.
    fig, axes = plt.subplots(
        rows, cols, figsize=(4 * cols, 4 * rows), squeeze=False
    )
    axes = axes.flatten()

    for ax, data_idx in zip(axes, indices):
        img_tensor, label = dataset[data_idx]
        # Channel-first tensor -> channel-last array, then undo normalisation.
        img_np = img_tensor.permute(1, 2, 0).numpy()
        img_np = img_np * std + mean
        img_np = np.clip(img_np, 0, 1)

        ax.imshow(img_np)
        ax.set_title(class_names[label], fontsize=10)

    # Hide ticks/frames everywhere, including unused cells of the last row.
    for ax in axes:
        ax.axis("off")

    fig.suptitle(title, fontsize=13)
    fig.tight_layout()
    plt.show()


# Visual check of the training pipeline: 8 random samples, augmentations applied.
show_batch(train_dataset, CLASS_NAMES, n=8)

## 8. Class Distribution

In [None]:
def plot_class_distribution(dataset, class_names, title="Class Distribution"):
    """Draw a bar chart with the number of samples per class.

    Args:
        dataset: object exposing ``samples`` as an iterable of
            ``(path, label_index)`` pairs.
        class_names: sequence mapping each label index to its display name;
            classes without any sample are plotted with a count of 0.
        title: chart title.
    """
    per_label = Counter(lbl for _path, lbl in dataset.samples)
    label_ids = range(len(class_names))
    tick_names = [class_names[k] for k in label_ids]
    heights = [per_label.get(k, 0) for k in label_ids]

    fig, ax = plt.subplots(figsize=(10, 4))
    rects = ax.bar(tick_names, heights, color="steelblue")
    ax.bar_label(rects)  # print the count above each bar
    ax.set_title(title)
    ax.set_ylabel("Count")
    plt.xticks(rotation=30, ha="right")
    plt.tight_layout()
    plt.show()


# Per-class sample counts of the training subset built above.
plot_class_distribution(train_dataset, CLASS_NAMES, "Train Subset - Class Distribution")

## Done

Data is ready. Outputs:
- `data/organised/{train,val,test}/{class}/` — full dataset in class-folder layout
- `data/subset/{train,val,test}/{class}/` — small subset for quick training

**Next step:** Open `02_train_convnextv2.ipynb` to train the model.