In [None]:
# Mount Google Drive at /content/drive via the google.colab helper.
# NOTE(review): presumably the dataset lives on Drive — confirm and point
# `dataset_path` below at the mounted location.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import random
import shutil
from collections import Counter
from pathlib import Path

import imgaug.augmenters as iaa
import numpy as np
from PIL import Image

# Input and output locations
dataset_path = "path_to_your_dataset"  # Replace with the actual dataset path
balanced_dataset_path = "balanced_dataset"

# Make sure the output root exists (no error if it is already there)
os.makedirs(balanced_dataset_path, exist_ok=True)

# Map every class sub-directory of the dataset to the number of entries in it
class_counts = {}
for cls in os.listdir(dataset_path):
    cls_dir = os.path.join(dataset_path, cls)
    if os.path.isdir(cls_dir):
        class_counts[cls] = len(os.listdir(cls_dir))

# The smallest class sets the per-class target used for balancing
min_samples = min(class_counts.values())

# Augmentation pipeline applied to images that are synthesised for upsampling
augmenter = iaa.Sequential(
    children=[
        # Horizontal mirror, applied to 50% of the images
        iaa.Fliplr(0.5),
        # Random rotation with the angle drawn from the range -20..20
        iaa.Affine(rotate=(-20, 20)),
    ]
)

def augment_image(image_path, output_path):
    """Write a randomly augmented RGB copy of an image.

    The image at ``image_path`` is opened, converted to RGB, passed through
    the module-level ``augmenter`` pipeline and the result is saved to
    ``output_path``.

    Args:
        image_path: Path of the source image to read.
        output_path: Path the augmented image is written to.
    """
    # The context manager releases the source file handle as soon as the
    # pixels have been read, instead of leaving it open until garbage
    # collection.
    with Image.open(image_path) as source:
        pixels = np.array(source.convert("RGB"))  # Ensure consistency
    augmented_image = augmenter(image=pixels)
    Image.fromarray(augmented_image).save(output_path)

# Balancing dataset: each class directory under `balanced_dataset_path` is
# filled up to `min_samples` images. Classes with more files are randomly
# downsampled; classes with fewer files keep all originals and are topped up
# with augmented copies.
for cls in class_counts:
    class_dir = os.path.join(dataset_path, cls)
    new_class_dir = os.path.join(balanced_dataset_path, cls)
    os.makedirs(new_class_dir, exist_ok=True)

    # Only regular files can be copied or augmented; a nested directory would
    # make shutil.copy fail, so such entries are ignored.
    images = [
        name
        for name in os.listdir(class_dir)
        if os.path.isfile(os.path.join(class_dir, name))
    ]

    if len(images) > min_samples:  # Downsampling
        selected_images = random.sample(images, min_samples)
    else:  # Keep every original; any shortfall is upsampled below
        selected_images = images

    # Copy selected originals to new dataset. Augmented images are NOT part of
    # this list: they are written straight into `new_class_dir` and never
    # exist in `class_dir`, so trying to copy them from there would fail.
    for img in selected_images:
        src = os.path.join(class_dir, img)
        dst = os.path.join(new_class_dir, img)
        shutil.copy(src, dst)

    # Upsampling: generate augmented copies until the class reaches the target
    if len(selected_images) < min_samples and not images:
        # random.choice on an empty list would raise IndexError
        print(f"Warning: class '{cls}' contains no image files; cannot upsample.")
        continue
    for index in range(len(selected_images), min_samples):
        img_to_augment = random.choice(images)
        new_img_name = f"aug_{index}.jpg"
        augment_image(
            os.path.join(class_dir, img_to_augment),
            os.path.join(new_class_dir, new_img_name),
        )

print("Dataset balancing complete. Balanced dataset saved at:", balanced_dataset_path)
