# In [None]:
import os
import Augmentor
from tqdm import tqdm
import shutil

# Dataset locations passed to augment_dataset() below: class folders are read
# from input_path, and the copied plus augmented images are written under
# output_path.
input_path = "D:/Dataset/archive/train"
output_path = "D:/Dataset/archive1/train"


# Function to count the number of images in each class
def count_images_in_classes(dataset_path):
    """Return a dict mapping each class folder name to its entry count.

    Every immediate sub-directory of ``dataset_path`` is treated as a class;
    plain files sitting directly in ``dataset_path`` are ignored.  The count
    is the number of directory entries inside the class folder, whatever
    their type.
    """
    # Pair each entry name with its full path once, so the path is not
    # rebuilt for the directory test and again for the listing.
    entries = (
        (name, os.path.join(dataset_path, name))
        for name in os.listdir(dataset_path)
    )
    return {
        name: len(os.listdir(path))
        for name, path in entries
        if os.path.isdir(path)
    }

# Function to augment the dataset
def augment_dataset(input_path, output_path, target_count=None):
    """Copy each class folder to ``output_path`` and top it up by augmentation.

    Every immediate sub-directory of ``input_path`` is treated as one class.
    Its contents are copied to ``output_path/<class>``, then Augmentor
    generates additional PNG samples from the originals until the
    destination folder holds ``target_count`` entries.

    Args:
        input_path: Directory containing one sub-directory per class.
        output_path: Directory receiving the copied and augmented classes;
            created if it does not exist.
        target_count: Desired number of entries per class.  ``None``
            balances every class to the size of the largest one.

    Classes already at or above the target are copied but not augmented.
    A failure while augmenting one class is printed and does not stop the
    remaining classes.
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_path, exist_ok=True)

    # Count the number of images in each class
    class_counts = count_images_in_classes(input_path)
    if not class_counts:
        # No class folders: nothing to do, and max() below would raise.
        return

    if target_count is None:
        target_count = max(class_counts.values())

    for class_folder in tqdm(class_counts, desc="Augmenting"):
        source_path = os.path.join(input_path, class_folder)
        dest_path = os.path.join(output_path, class_folder)

        # copytree() refuses to write into an existing directory, so a class
        # folder left behind by an earlier run is reused as-is.
        if not os.path.isdir(dest_path):
            shutil.copytree(source_path, dest_path)

        # Measure the destination rather than the source: on a first run the
        # two are equal, and on a re-run only the remaining gap is generated.
        augment_count = target_count - len(os.listdir(dest_path))
        if augment_count <= 0:
            continue

        # Augment the class using Augmentor
        try:
            p = Augmentor.Pipeline(source_directory=source_path, output_directory=dest_path)
            p.rotate(probability=0.7, max_left_rotation=10, max_right_rotation=10)
            p.flip_left_right(probability=0.5)
            p.zoom_random(probability=0.5, percentage_area=0.8)
            p.flip_top_bottom(probability=0.5)
            p.set_save_format("PNG")  # Change the save format to PNG
            p.sample(augment_count)
        except Exception as e:
            # Best effort: report the failure and move on to the next class.
            print(f"Error during augmentation of {class_folder}: {str(e)}")

# Script entry point: runs the augmentation only when the file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    # Requested per-class image total for balancing, forwarded to
    # augment_dataset() as its target_count argument.
    target_count = 2534

    augment_dataset(input_path, output_path, target_count)
