In [1]:
pip install albumentations opencv-python


Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.



In [2]:
import albumentations as A

# Output (height, width) of every RandomResizedCrop below. The crops run with
# p < 1, so a share of augmented images keeps the source resolution; the
# downstream loader must still resize.
IMG_SIZE = (224, 224)

# Strongest pipeline: +/-25 deg rotation, +/-0.3 brightness/contrast,
# 85-100% area crops and occasional Gaussian blur.
heavy_aug = A.Compose(
    [
        A.Rotate(limit=25, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.8),
        A.RandomResizedCrop(
            size=IMG_SIZE,
            scale=(0.85, 1.0),  # fraction of the source area kept; must lie in (0, 1]
            ratio=(0.9, 1.1),   # allowed aspect-ratio range of the crop
            p=0.8,
        ),
        A.GaussianBlur(blur_limit=(3, 5), p=0.3),
    ]
)

# Milder pipeline: +/-15 deg rotation, +/-0.2 brightness/contrast and
# tighter 95-100% area crops, no blur.
medium_aug = A.Compose(
    [
        A.Rotate(limit=15, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.6),
        A.RandomResizedCrop(
            size=IMG_SIZE,
            scale=(0.95, 1.0),
            ratio=(0.95, 1.05),
            p=0.5,
        ),
    ]
)

# Lightest pipeline: small rotation, flip and brightness/contrast jitter only.
# NOTE(review): not referenced by AUG_MAP in this notebook.
light_aug = A.Compose(
    [
        A.Rotate(limit=10, p=1.0),
        A.HorizontalFlip(p=0.3),
        A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.3),
    ]
)


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Class-directory name -> augmentation pipeline. A class with no entry gets
# no augmented copies; only its originals are written out.
AUG_MAP = {
    # The three bacterial classes share the heavy pipeline.
    **dict.fromkeys(
        (
            "bacterial_panicle_blight",
            "bacterial_leaf_streak",
            "bacterial_leaf_blight",
        ),
        heavy_aug,
    ),
    "downy_mildew": medium_aug,
}


In [4]:
import cv2
import os
import warnings
from tqdm import tqdm

# Root of the raw training set: one sub-directory per class label.
# Override with the PADDY_TRAIN_DIR environment variable instead of editing
# this machine-specific default.
INPUT_DIR = os.environ.get(
    "PADDY_TRAIN_DIR",
    r"C:\Users\rk001\Downloads\paddy-disease-classification\train_images",
)
OUTPUT_DIR = "train_aug"
AUG_PER_IMAGE = 3  # augmented copies per image, only for classes listed in AUG_MAP


def _imwrite_checked(path, image_bgr):
    """Write a BGR image with cv2.imwrite and raise OSError if it fails.

    cv2.imwrite reports some failures (e.g. an unwritable path) only through
    its boolean return value, so an unchecked call can silently drop files.
    """
    if not cv2.imwrite(path, image_bgr):
        raise OSError(f"cv2.imwrite failed for {path}")


os.makedirs(OUTPUT_DIR, exist_ok=True)

# Iterate in sorted order: os.listdir order is arbitrary, so sorting keeps the
# processing order and progress log identical across machines.
# NOTE(review): the albumentations pipelines are unseeded, so augmented pixels
# still differ between runs; seed them if exact reproducibility is required.
for cls in sorted(os.listdir(INPUT_DIR)):
    cls_in = os.path.join(INPUT_DIR, cls)
    if not os.path.isdir(cls_in):
        continue  # ignore stray files sitting next to the class folders
    cls_out = os.path.join(OUTPUT_DIR, cls)
    os.makedirs(cls_out, exist_ok=True)

    aug = AUG_MAP.get(cls)  # None for classes without an entry -> originals only

    for img_name in tqdm(sorted(os.listdir(cls_in)), desc=cls):
        img_path = os.path.join(cls_in, img_name)
        image_bgr = cv2.imread(img_path)
        if image_bgr is None:
            # cv2.imread returns None (it does not raise) for unreadable or
            # non-image entries; skip them instead of crashing in cvtColor.
            warnings.warn(f"Skipping unreadable image: {img_path}")
            continue

        # Save the original. OpenCV already holds it as BGR, so no colour
        # round trip is needed here.
        _imwrite_checked(os.path.join(cls_out, img_name), image_bgr)

        # Classes with no AUG_MAP entry only get their originals.
        if aug is None:
            continue

        # Follow the usual RGB convention for albumentations inputs.
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

        # Derive names from the real extension. Replacing ".jpg" would leave
        # names like "x.JPG"/"x.png" unchanged and overwrite the original.
        stem, ext = os.path.splitext(img_name)
        for i in range(AUG_PER_IMAGE):
            augmented = aug(image=image_rgb)["image"]
            _imwrite_checked(
                os.path.join(cls_out, f"{stem}_aug{i}{ext}"),
                cv2.cvtColor(augmented, cv2.COLOR_RGB2BGR),
            )


bacterial_leaf_blight:   0%|          | 0/479 [00:00<?, ?it/s]

bacterial_leaf_blight: 100%|██████████| 479/479 [00:29<00:00, 16.23it/s]
bacterial_leaf_streak: 100%|██████████| 380/380 [00:11<00:00, 32.89it/s]
bacterial_panicle_blight: 100%|██████████| 337/337 [00:11<00:00, 28.49it/s]
blast: 100%|██████████| 1738/1738 [00:15<00:00, 109.73it/s]
brown_spot: 100%|██████████| 965/965 [00:17<00:00, 56.02it/s] 
dead_heart: 100%|██████████| 1442/1442 [00:31<00:00, 46.26it/s]
downy_mildew: 100%|██████████| 620/620 [00:31<00:00, 19.75it/s]
hispa: 100%|██████████| 1594/1594 [00:41<00:00, 38.30it/s]
normal: 100%|██████████| 1764/1764 [00:41<00:00, 42.10it/s]
tungro: 100%|██████████| 1088/1088 [00:24<00:00, 44.83it/s]
