In [37]:

import os
import cv2
import shutil
import random
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm
import albumentations as A
import matplotlib.pyplot as plt
from collections import defaultdict
from imblearn.over_sampling import SMOTE
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

config = {
    "csv_path": "processed_data/cleaned_imbalance_metadata.csv",
    "label_encoder_path": "processed_data/le_cleaned_imbalance_metadata.npy",
    "val_set_csv": "processed_data/new_val_metadata.csv",
    "balanced_train_csv": "processed_data/new_balanced_train_metadata.csv",
    "original_images_dir": "Dataset/train_images",
    "augmented_images_dir": "Dataset/SMOT_images_temp",
    "merged_output_dir": "Dataset/train_images_balanced",

}

def load_and_preprocess_data(random_state=42):
    df = pd.read_csv(config["csv_path"])
    
    le = LabelEncoder()
    df['label_encoded'] = le.fit_transform(df['label'])
    print(f"Label classes: {le.classes_}")
    
    with open(config['label_encoder_path'], 'wb') as f:
        np.save(f, le.classes_)
    
    train_df, val_df = train_test_split(
        df, 
        test_size=0.2, 
        stratify=df['label'],
        random_state=random_state,
    )
    
    train_df.to_csv(config['train_set_csv'], index=False)
    val_df.to_csv(config['val_set_csv'], index=False)
    
    return train_df, val_df


In [38]:
def move_images_based_on_csv(csv_path, src_dir, dest_dir):
    df = pd.read_csv(csv_path)
    for _, row in tqdm(df.iterrows(), total=len(df)):
        src = os.path.join(src_dir, row['label'], row['image_id'])
        dst = os.path.join(dest_dir, row['label'], row['image_id'])
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        shutil.copy2(src, dst)

In [39]:
def generate_augmented_images(df_resampled, original_dir, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    
    aug = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.CLAHE(p=0.5),
        A.HueSaturationValue(p=0.5, hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20),
        A.RandomBrightnessContrast(p=0.5, brightness_limit=0.2, contrast_limit=0.2),
        A.GaussNoise(p=0.5, var_limit=(10.0, 50.0)),        
    ])

    # Track existing augmented images to avoid duplicates
    existing_augmented = set()
    for label in os.listdir(output_dir):
        label_dir = os.path.join(output_dir, label)
        if os.path.isdir(label_dir):
            existing_augmented.update(
                os.path.join(label, f) 
                for f in os.listdir(label_dir) 
                if f.endswith('.jpg')
            )

    generated_count = 0

    for _, row in tqdm(df_resampled.iterrows(), total=len(df_resampled)):
        base_name = os.path.splitext(row['image_id'])[0]
        new_filename = f"{base_name}_aug_{random.randint(1,100000)}.jpg"
        output_path = os.path.join(output_dir, row['label'], new_filename)        

            
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        
        # Load original image
        original_path = os.path.join(original_dir, row['label'], row['image_id'])
        original_img = cv2.imread(original_path)
        if original_img is None:
            continue
            
        # Generate and save augmented version
        augmented = aug(image=original_img)['image']
        cv2.imwrite(output_path, augmented)
        generated_count += 1

    print(f"Generated {generated_count} synthetic images in {output_dir}")

In [40]:
def merge_datasets(original_dir, smote_dir, output_dir):
    """
    Merges original and SMOTE-augmented images into a single dataset.
    """
    os.makedirs(output_dir, exist_ok=True)
    
    # Iterate through all class subdirectories
    for label in tqdm(os.listdir(original_dir)):
        original_label_dir = os.path.join(original_dir, label)
        smote_label_dir = os.path.join(smote_dir, label)
        output_label_dir = os.path.join(output_dir, label)
        
        os.makedirs(output_label_dir, exist_ok=True)
        
        for img_file in os.listdir(original_label_dir):
            src = os.path.join(original_label_dir, img_file)
            dst = os.path.join(output_label_dir, img_file)
            if not os.path.exists(dst):
                shutil.copy2(src, dst)
        
        if os.path.exists(smote_label_dir):
            for img_file in os.listdir(smote_label_dir):
                src = os.path.join(smote_label_dir, img_file)
                dst = os.path.join(output_label_dir, img_file)
                if not os.path.exists(dst):
                    shutil.copy2(src, dst)

    print(f"Merged dataset created at: {output_dir}")

In [41]:

def count_images_in_folders(data_dir):
    counts = defaultdict(int)
    
    for label in os.listdir(data_dir):
        label_dir = os.path.join(data_dir, label)
        if os.path.isdir(label_dir):
            counts[label] = len([
                f for f in os.listdir(label_dir) 
                if f.lower().endswith(('.jpg', '.jpeg', '.png'))
            ])
            
    folder_counts = dict(counts)
    
    print("=== Image Counts by Label (From Folders) ===")
    for label, count in folder_counts.items():
        print(f"{label}: {count} images")
    print(f"TOTAL: {sum(folder_counts.values())} images")
    

In [42]:
def generate_final_dataset_from_smote_csv(smote_csv_path, original_dir, output_dir, augmentations_per_image=1):
    """
    Generates COMPLETELY NEW augmented dataset from SMOTE CSV
    - Doesn't preserve any original images
    - Creates new augmented versions for every entry in CSV
    - Can generate multiple variations per sample (augmentations_per_image)
    """
    os.makedirs(output_dir, exist_ok=True)
    df = pd.read_csv(smote_csv_path)
    
    # Strong augmentation pipeline
    aug = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(p=0.5),
        A.RandomBrightnessContrast(p=0.8),
        A.CLAHE(p=0.5),
        A.HueSaturationValue(p=0.5),
        A.GaussianBlur(blur_limit=(3, 7), p=0.3),
    ])

    for row_idx, row in tqdm(df.iterrows(), total=len(df)):
        # Load random original from same class
        class_dir = os.path.join(original_dir, row['label'])
        available_images = [f for f in os.listdir(class_dir) if f.endswith('.jpg')]
        
        if not available_images:
            continue
            
        original_img = cv2.imread(os.path.join(class_dir, random.choice(available_images)))
        
        # Generate N augmented versions
        for i in range(augmentations_per_image):
            # Create unique filename
            base_name = os.path.splitext(row['image_id'])[0]
            aug_filename = f"{base_name}_{row_idx}.jpg"  # Using row index instead of random number
            output_path = os.path.join(output_dir, row['label'], aug_filename)
            
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            
            # Apply augmentation and save
            augmented = aug(image=original_img)['image']
            cv2.imwrite(output_path, augmented)

    print(f"Generated {len(df)*augmentations_per_image} completely new augmented images in {output_dir}")


In [43]:
def handle_imbalance_and_augment(config):
    """Complete pipeline for handling imbalance and generating augmented images"""
    
    # Load and preprocess data
    df = pd.read_csv(config["csv_path"])
    le = LabelEncoder()
    df['label_encoded'] = le.fit_transform(df['label'])
    
    # Split before oversampling
    train_df, val_df = train_test_split(
        df,
        test_size=0.2,
        stratify=df['label'],
        random_state=42
    )
    
    print("\n=== Before SMOTE ===")
    print(train_df['label'].value_counts())
    
    target_counts = train_df['label'].value_counts().max()  
    
    balanced_dfs = []
    for label in train_df['label'].unique():
        label_df = train_df[train_df['label'] == label]
        needed = target_counts - len(label_df)
        
        if needed > 0:
            oversampled = label_df.sample(needed, replace=True, random_state=42)
            balanced_dfs.append(pd.concat([label_df, oversampled]))
        else:
            balanced_dfs.append(label_df)
    
    train_df_resampled = pd.concat(balanced_dfs)
    
    print("\n=== After Resampling ===")
    print(train_df_resampled['label'].value_counts())
    
    train_df_resampled.to_csv(config["balanced_train_csv"], index=False)
    val_df.to_csv(config["val_set_csv"], index=False)
    

    generate_final_dataset_from_smote_csv(config["balanced_train_csv"], config["original_images_dir"], config["merged_output_dir"])
   
    # shutil.rmtree(config["augmented_images_dir"])
    
    print("\n=== Final Counts ===")
    return count_images_in_folders(config["merged_output_dir"])

In [44]:
merged_output_dir = handle_imbalance_and_augment(config)
print("\n=== Augmented Images Directory ===")
print(merged_output_dir)
print("\n=== Original Images Directory ===")
print(count_images_in_folders(config["original_images_dir"]))
print("\n=== Merged Output Directory ===")
print(count_images_in_folders(config["merged_output_dir"]))
print("\n=== Final Counts ===")
final_counts = count_images_in_folders(config["merged_output_dir"])

  original_init(self, **validated_kwargs)



=== Before SMOTE ===
label
normal                      1393
blast                       1364
hispa                       1245
dead_heart                  1130
tungro                       856
brown_spot                   749
downy_mildew                 487
bacterial_leaf_blight        373
bacterial_leaf_streak        294
bacterial_panicle_blight     269
Name: count, dtype: int64

=== After Resampling ===
label
normal                      1393
dead_heart                  1393
blast                       1393
tungro                      1393
hispa                       1393
bacterial_leaf_blight       1393
brown_spot                  1393
downy_mildew                1393
bacterial_panicle_blight    1393
bacterial_leaf_streak       1393
Name: count, dtype: int64


100%|██████████| 13930/13930 [01:07<00:00, 206.82it/s]

Generated 13930 completely new augmented images in Dataset/train_images_balanced

=== Final Counts ===
=== Image Counts by Label (From Folders) ===
normal: 1393 images
dead_heart: 1393 images
blast: 1393 images
tungro: 1393 images
hispa: 1393 images
bacterial_leaf_blight: 1393 images
brown_spot: 1393 images
downy_mildew: 1393 images
bacterial_panicle_blight: 1393 images
bacterial_leaf_streak: 1393 images
TOTAL: 13930 images

=== Augmented Images Directory ===
None

=== Original Images Directory ===
=== Image Counts by Label (From Folders) ===
bacterial_leaf_blight: 479 images
bacterial_leaf_streak: 380 images
bacterial_panicle_blight: 337 images
blast: 1738 images
brown_spot: 965 images
dead_heart: 1442 images
downy_mildew: 620 images
hispa: 1594 images
normal: 1764 images
tungro: 1088 images
TOTAL: 10407 images
None

=== Merged Output Directory ===
=== Image Counts by Label (From Folders) ===
normal: 1393 images
dead_heart: 1393 images
blast: 1393 images
tungro: 1393 images
hispa: 139




In [45]:
df_resampled = pd.read_csv(config["balanced_train_csv"])
print(df_resampled['image_id'].value_counts())

image_id
101965.jpg    12
100058.jpg    11
102861.jpg    10
105100.jpg    10
105533.jpg    10
              ..
109547.jpg     1
107430.jpg     1
108070.jpg     1
105177.jpg     1
100535.jpg     1
Name: count, Length: 8160, dtype: int64
