In [2]:
import os
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img

def load_images_from_directory(directory, target_size=(256, 256)):
    """Load every image file directly inside *directory* into one NumPy array.

    Files are visited in sorted filename order so repeated runs build the
    array in the same order. Subdirectories are ignored, and so are hidden
    files (names starting with '.'), e.g. the ``.DS_Store`` metadata file
    macOS Finder creates, which ``load_img`` cannot decode as an image.

    Args:
        directory: Path of the folder to read.
        target_size: ``(height, width)`` passed to ``load_img`` so every
            image is resized to the same shape and the results can be stacked.

    Returns:
        ``np.array`` of the per-image arrays produced by ``img_to_array``;
        an empty array when no files were loaded.
    """
    images = []
    for filename in sorted(os.listdir(directory)):
        # Hidden/metadata files are not images; load_img would raise on them.
        if filename.startswith('.'):
            continue
        img_path = os.path.join(directory, filename)
        if not os.path.isfile(img_path):
            continue
        img = load_img(img_path, target_size=target_size)
        images.append(img_to_array(img))
    return np.array(images)

def save_augmented_images(generator, images, save_dir, prefix, n_augmented_images):
    """Write ``n_augmented_images`` augmented copies of *images* to *save_dir*.

    ``generator.flow`` is called with ``save_to_dir`` so each batch it yields
    is also saved to disk as a ``jpg`` whose name starts with *prefix*. With
    ``batch_size=1`` each batch holds one image, so drawing exactly
    ``n_augmented_images`` batches writes that many files. The yielded arrays
    themselves are discarded.

    Args:
        generator: An ``ImageDataGenerator`` (or any object with a compatible
            ``flow`` method returning an iterator).
        images: Source images to augment. Presumably a rank-4
            ``(N, H, W, C)`` array as Keras ``flow`` expects — confirm at caller.
        save_dir: Output folder; created (with parents) if missing.
        prefix: Filename prefix for the saved images.
        n_augmented_images: Number of images to write; values <= 0 write nothing.

    Raises:
        ValueError: If images are requested but *images* is empty.
    """
    os.makedirs(save_dir, exist_ok=True)
    if n_augmented_images <= 0:
        return
    if len(images) == 0:
        raise ValueError(f"cannot augment an empty image set into {save_dir!r}")

    # flow() yields batches forever; pull exactly as many as we need. Each
    # next() call makes the iterator save one augmented image to save_dir.
    batches = generator.flow(images, batch_size=1, save_to_dir=save_dir,
                             save_prefix=prefix, save_format='jpg')
    for _ in range(n_augmented_images):
        next(batches)

# Parameters
base_directory = '/Users/sarath/Documents/Files/Dataset'
output_directory = '/Users/sarath/Documents/Files/NewDataset'
target_size = (256, 256)
prefix = 'augmented'  # Prefix for augmented image filenames

# Augmentations applied to every generated image: random rotation
# (rotation_range, in degrees) and random shear (shear_range).
datagen = ImageDataGenerator(
    rotation_range=40,
    shear_range=0.2,
)

# Each sub-folder of base_directory is one class. Sorted so classes are
# processed and reported in a stable order (os.listdir order is arbitrary).
class_folders = sorted(
    os.path.join(base_directory, name)
    for name in os.listdir(base_directory)
    if os.path.isdir(os.path.join(base_directory, name))
)
if not class_folders:
    raise ValueError(f"no class sub-folders found in {base_directory!r}")

# Load each class once and count what was actually loaded. A raw os.listdir()
# count would also include sub-folders and any file the loader skips, which
# inflates the target and the per-class deficit. Trade-off: all classes are
# kept in memory at the same time so no image is decoded twice.
class_images = {folder: load_images_from_directory(folder, target_size) for folder in class_folders}
class_image_counts = {folder: len(imgs) for folder, imgs in class_images.items()}
max_images = max(class_image_counts.values())

# Oversample each class up to the size of the largest class.
# NOTE(review): only the generated images are written to output_directory;
# the original images are not copied there — confirm that downstream code
# combines base_directory and output_directory.
for class_folder in class_folders:
    images = class_images[class_folder]
    class_name = os.path.basename(class_folder)
    num_images_to_generate = max_images - len(images)
    if num_images_to_generate <= 0:
        continue
    if len(images) == 0:
        # Nothing to augment from; skip instead of crashing the whole run.
        print(f"Skipping class '{class_name}': no images to augment.")
        continue
    save_augmented_images(datagen, images, os.path.join(output_directory, class_name), prefix, num_images_to_generate)
    print(f"Generated {num_images_to_generate} augmented images for class '{class_name}'.")

print("Oversampling complete.")


Generated 61 augmented images for class 'boron-B'.
Generated 66 augmented images for class 'potasium-K'.
Generated 97 augmented images for class 'iron-Fe'.
Oversampling complete.
