In [1]:
import os
from PIL import Image
from collections import defaultdict
import random
from math import ceil

In [4]:
def count_plu_types(directory):
    """
    Group the PNG files in ``directory`` by PLU and report the largest group size.

    The PLU is the part of the filename before the first '-'. PNG files whose
    name contains no '-' are ignored, as are entries that are not regular files
    (e.g. a sub-directory named like an image). The '.png' extension check is
    case-insensitive. The directory is not searched recursively.

    Args:
        directory: path of the folder to scan.

    Returns:
        - dict of {plu: list of files}; each list is sorted by filename so the
          grouping (and anything that iterates over it) is reproducible
        - max_count (the highest number of files any PLU has; 0 if there are none)
    """
    plu_to_files = defaultdict(list)

    # os.listdir returns entries in an arbitrary, filesystem-dependent order;
    # sorting keeps results stable across machines and re-runs.
    for name in sorted(os.listdir(directory)):
        if not name.lower().endswith('.png'):
            continue
        if not os.path.isfile(os.path.join(directory, name)):
            continue
        plu, separator, _ = name.partition("-")
        if separator:
            plu_to_files[plu].append(name)

    max_count = max((len(files) for files in plu_to_files.values()), default=0)
    return plu_to_files, max_count

In [6]:
# Augmentations to pick from, as (name, callable) pairs. The name labels the
# augmented output; the callable takes an image and returns the transformed one.
# expand=True asks Pillow to enlarge the canvas so rotated content is not cropped.
transformations = [
    ("rot90", lambda im: im.rotate(90, expand=True)),
    ("rot180", lambda im: im.rotate(180, expand=True)),
    ("flipH", lambda im: im.transpose(Image.FLIP_LEFT_RIGHT)),
    ("flipV", lambda im: im.transpose(Image.FLIP_TOP_BOTTOM)),
    (
        "flipH_rot90",
        lambda im: im.transpose(Image.FLIP_LEFT_RIGHT).rotate(90, expand=True),
    ),
]

In [None]:
def augment_randomly(image, transforms=None, rng=None):
    """
    Applies one randomly chosen transformation to the given image.

    Args:
        image: the image to transform (a PIL.Image for the default transforms).
        transforms: optional sequence of (name, callable) pairs to choose from;
            defaults to the module-level ``transformations`` list.
        rng: optional object with a ``choice`` method (e.g. ``random.Random(seed)``)
            used to pick the transformation; defaults to the global ``random``
            module. Pass a seeded instance for reproducible augmentation.

    Returns:
        - augmented_image: the value returned by the chosen callable
          (PIL.Image for the default transforms)
        - augmentation_name (str)

    Raises:
        IndexError: if the transform sequence is empty.
    """
    candidates = transformations if transforms is None else transforms
    chooser = random if rng is None else rng
    augmentation_name, transform_fn = chooser.choice(candidates)
    return transform_fn(image), augmentation_name


In [8]:
def balance_plu_to_max(directory):
    """
    Balances all PLUs in the given directory so each has the same number of images
    as the most prevalent PLU. Augments randomly to create more images.

    For every PLU with fewer than the maximum number of PNGs, the PLU's original
    images are cycled through in order. Each pick gets one random augmentation and
    is written back into ``directory`` as ``{plu}-{n}-{augmentation}.png``, where
    ``n`` is the PLU's image count just before the new file is added.

    Only original files are used as sources. Re-augmenting generated outputs can
    cancel the first transform (e.g. flipH twice, rot180 twice) and write exact
    duplicates of an original.

    Args:
        directory: folder containing the ``{plu}-*.png`` images; new files are
            written into it.
    """
    plu_to_files, max_count = count_plu_types(directory)

    for plu, file_list in plu_to_files.items():
        # Snapshot the originals before generated files are appended to file_list.
        originals = list(file_list)
        idx = 0
        while len(file_list) < max_count:
            source_path = os.path.join(directory, originals[idx % len(originals)])
            # Context manager closes the source file handle; saving inside the
            # block keeps it valid even if a transform returns a lazily-loaded image.
            with Image.open(source_path) as source_image:
                augmented_image, augmentation_name = augment_randomly(source_image)
                new_filename = f"{plu}-{len(file_list)}-{augmentation_name}.png"
                augmented_image.save(os.path.join(directory, new_filename))
            file_list.append(new_filename)
            idx += 1
    print("Balanced all PLUs to the maximum count of images.")

In [9]:
# Seed the global RNG used by augment_randomly so the sequence of augmentation
# choices is repeatable under Restart & Run All.
SEED = 42
random.seed(SEED)

balance_plu_to_max("cropped_training")
# NOTE(review): this also writes augmented images into the evaluation split,
# which changes its class distribution — confirm this is intended.
balance_plu_to_max("cropped_testing")

Balanced all PLUs to the maximum count of images.
Balanced all PLUs to the maximum count of images.
