In [22]:
import os
from PIL import Image
from collections import defaultdict

# === CONFIGURATION ===
# NOTE(review): NUM_FOLDERS and NUM_FILES_PER_FOLDER are not read by any code
# visible in this cell; presumably they belong to an earlier cropping step -
# confirm before relying on them here.
NUM_FOLDERS = None           # Set to None to process all folders or a specific number (e.g., 5)
NUM_FILES_PER_FOLDER = None  # Set to None to process all cropped images per folder or a specific number (e.g., 5)
# Folder containing the .png images to inspect. The same value is assigned to
# OUTPUT_PATH again further down in this cell.
OUTPUT_PATH = 'out_data/'

import os
from collections import defaultdict

# Folder that holds the .png images, and the minimum number of images each
# PLU should end up with. Both are used as default arguments below.
OUTPUT_PATH = 'out_data/'
DESIRED_COUNT = 300


def check_plu_balance(output_path=OUTPUT_PATH, desired_count=DESIRED_COUNT):
    """
    Count the .png images per PLU in a single folder.

    The PLU of an image is the part of its filename before the first '-'
    (e.g. "4011-19_cropped.png" belongs to PLU "4011"). Files whose name
    contains no '-' are ignored, as are sub-directories, even if their
    name ends in ".png". The extension check is case-insensitive.

    Args:
        output_path: Folder to scan (not recursive).
        desired_count: Minimum number of images every PLU should have.

    Returns:
        plu_to_files (dict): {plu: [png_filenames, sorted by name]}
        missing_dict (dict): {plu: images still missing to reach
            desired_count, 0 when the PLU already has enough}
        all_balanced (bool): True if every PLU has at least desired_count
            images. Also True (with two empty dicts) when no usable image
            was found.

    Raises:
        OSError: If `output_path` cannot be listed (e.g. it does not exist).
    """
    # Sort the listing: os.listdir() returns entries in arbitrary order, and
    # callers that pick seed images from these lists should get the same
    # order on every run.
    all_png = sorted(
        name for name in os.listdir(output_path)
        if name.lower().endswith('.png')
        and os.path.isfile(os.path.join(output_path, name))
    )

    # Group by PLU
    plu_to_files = defaultdict(list)
    for filename in all_png:
        plu, hyphen, _rest = filename.partition('-')
        if not hyphen:
            # No '-' in the name, so there is no PLU prefix: skip the file.
            continue
        plu_to_files[plu].append(filename)

    if not plu_to_files:
        print("No .png images found in the output folder.")
        return {}, {}, True

    # How many images each PLU still lacks (never negative).
    missing_dict = {
        plu: max(desired_count - len(files), 0)
        for plu, files in plu_to_files.items()
    }
    all_balanced = not any(missing_dict.values())

    return plu_to_files, missing_dict, all_balanced
import random
from PIL import Image

def balance_plu_images(output_path=OUTPUT_PATH, desired_count=DESIRED_COUNT):
    """
    Top up under-represented PLUs with augmented copies of their images.

    For each PLU that has fewer than `desired_count` images, seed images
    are visited round-robin and up to four augmentations (rotate 90,
    horizontal flip, vertical flip, horizontal flip + rotate 90) are saved
    per visit until the PLU reaches `desired_count`. New files are written
    into `output_path` as
    "<plu>-<seed name without extension>_aug_<kind>_<running count>.png".

    Seeds are the PLU's '*_cropped.png' files; if it has none, all of its
    existing images are used instead. A seed that cannot be opened is
    reported and dropped; if no readable seed remains, the PLU is left
    below `desired_count` and a message is printed.

    Args:
        output_path: Folder holding the images; augmented files are
            written to the same folder.
        desired_count: Minimum number of images every PLU should have.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    plu_to_files, missing_dict, all_balanced = check_plu_balance(output_path, desired_count)

    # Early exit if everything's balanced
    if all_balanced:
        print(f"✅ All PLUs have at least {desired_count} images already.")
        return

    # (filename tag, transform) pairs, applied in this order on every seed visit.
    augmentations = (
        ("rot90", lambda im: im.rotate(90, expand=True)),
        ("flipH", lambda im: im.transpose(Image.FLIP_LEFT_RIGHT)),
        ("flipV", lambda im: im.transpose(Image.FLIP_TOP_BOTTOM)),
        ("flipH_rot90",
         lambda im: im.transpose(Image.FLIP_LEFT_RIGHT).rotate(90, expand=True)),
    )

    for plu, missing in missing_dict.items():
        if missing <= 0:
            print(f"⏭  PLU {plu} is already at {len(plu_to_files[plu])} images, skipping.")
            continue

        # We'll only treat '_cropped.png' files as seeds
        seed_files = [f for f in plu_to_files[plu] if f.endswith('_cropped.png')]

        if not seed_files:
            print(f"⚠️  PLU {plu} has no '_cropped.png' seeds. Using all files as seeds.")
            # Copy the list: plu_to_files[plu] is appended to below, and the
            # freshly generated images must not become seeds themselves.
            seed_files = list(plu_to_files[plu])

        current_count = len(plu_to_files[plu])
        print(f"⚙️  Balancing PLU {plu}. Currently has {current_count}, needs {missing} more.")

        seed_index = 0  # round-robin position in seed_files

        # Stop as soon as the target is reached, or when every seed turned
        # out to be unreadable (otherwise this loop would never terminate).
        while current_count < desired_count and seed_files:
            seed_index %= len(seed_files)
            seed_name = seed_files[seed_index]
            seed_path = os.path.join(output_path, seed_name)

            try:
                # copy() forces the pixel data to be read, so the file can be
                # closed right away instead of staying open per seed visit.
                with Image.open(seed_path) as source:
                    img = source.copy()
            except Exception as e:
                # Deliberately broad and best-effort: Pillow can raise several
                # unrelated exception types for a missing or corrupt image.
                print(f"⚠️  Could not open {seed_path}: {e}")
                # Drop the bad seed; the next seed slides into this index.
                del seed_files[seed_index]
                continue

            seed_index += 1

            base_plu, rest = seed_name.split('-', 1)   # e.g. "4011", "19_cropped.png"
            # splitext strips the extension whatever its case (".png", ".PNG").
            base_rest = os.path.splitext(rest)[0]      # e.g. "19_cropped"

            for tag, transform in augmentations:
                if current_count >= desired_count:
                    break
                out_file = f"{base_plu}-{base_rest}_aug_{tag}_{current_count}.png"
                transform(img).save(os.path.join(output_path, out_file))
                plu_to_files[plu].append(out_file)  # Update in-memory
                current_count += 1

        if current_count >= desired_count:
            print(f"✅ PLU {plu} is now at {current_count} images.")
        else:
            print(f"❌ PLU {plu}: no readable seed images left, "
                  f"stopped at {current_count} of {desired_count} images.")


if __name__ == '__main__':
    # Report the per-PLU image counts twice: once on the folder as found,
    # and once more after the balancing step has run.
    report_stages = (
        ("\n🔍 Before balancing:", "Balanced already?", False),
        ("\n🔍 After balancing:", "Balanced now?", True),
    )

    for heading, verdict_label, balance_first in report_stages:
        if balance_first:
            # Generate the missing augmented images before re-counting.
            balance_plu_images(OUTPUT_PATH, DESIRED_COUNT)

        plu_to_files, missing_dict, all_balanced = check_plu_balance(OUTPUT_PATH, DESIRED_COUNT)

        print(heading)
        for plu in sorted(plu_to_files.keys()):
            count = len(plu_to_files[plu])
            print(f"  PLU {plu}: has {count} images, needs {missing_dict[plu]}")
        print(verdict_label, all_balanced)



🔍 Before balancing:
  PLU 4011: has 117 images, needs 183
  PLU 4015: has 242 images, needs 58
  PLU 4088: has 175 images, needs 125
  PLU 4196: has 237 images, needs 63
  PLU 7020097009819: has 181 images, needs 119
  PLU 7020097026113: has 61 images, needs 239
  PLU 7023026089401: has 99 images, needs 201
  PLU 7035620058776: has 26 images, needs 274
  PLU 7037203626563: has 47 images, needs 253
  PLU 7037206100022: has 162 images, needs 138
  PLU 7038010009457: has 70 images, needs 230
  PLU 7038010013966: has 161 images, needs 139
  PLU 7038010021145: has 70 images, needs 230
  PLU 7038010054488: has 120 images, needs 180
  PLU 7038010068980: has 163 images, needs 137
  PLU 7039610000318: has 116 images, needs 184
  PLU 7040513000022: has 138 images, needs 162
  PLU 7040513001753: has 55 images, needs 245
  PLU 7040913336684: has 70 images, needs 230
  PLU 7044610874661: has 288 images, needs 12
  PLU 7048840205868: has 68 images, needs 232
  PLU 7071688004713: has 51 images, need