# In [None]:
import os
import shutil
import random
from tqdm import tqdm
from collections import defaultdict

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Seed for the global `random` module (applied below).
random_seed = 42

# One output subset is built per entry; each entry is the number of images
# requested per bird.
num_images_per_bird = [
    50, 100, 150, 200, 250, 300, 350, 400, 450, 500,
    600, 700, 800, 900, 1000,
]

# Input: one sub-directory per bird ID holding all pre-processed frames.
base_directory = "det_mask_frames"
# Output root (in this case: data from the baseline dataset).
output_root = "train_val_data"

# Only directories whose name is one of these bird IDs are processed.
ALLOWED_BIRD_IDS = {
    "BRG-YOM", "EYB-RPM", "BNU-RPM", "GBM-ORY",
    "OGY-BRM", "GBY-ORM", "RYO-BGM", "OYR-BGM",
    "OEB-RPM", "ORB-UYM", "YM-OBR", "BNY-RPM",
    "OUB-RPM", "YRU-POM", "RGY-BOM", "BRK-NOM",
    "YGO-M",
}

# ---------------------------------------------------------------------------
# One-time setup: seed the RNG and make sure the output root exists.
# ---------------------------------------------------------------------------
random.seed(random_seed)
os.makedirs(output_root, exist_ok=True)

def list_images(dir_path):
    """
    Return the sorted names of the image files directly inside dir_path.

    An entry counts as an image when its name ends with '.jpg' or '.png'
    (case-insensitive). Bare names are returned, not full paths.

    os.listdir() returns entries in arbitrary order, so the names are sorted
    here: callers shuffle this list with a seeded RNG, and a seeded shuffle
    is only reproducible when its input order is stable.
    """
    return sorted(
        f for f in os.listdir(dir_path) if f.lower().endswith(('.jpg', '.png'))
    )

def copy_images(file_list, src_dir, dst_dir):
    """
    Copy each file named in file_list from src_dir into dst_dir.

    File names are kept unchanged. Copies are made with shutil.copy2.
    dst_dir is not created here.
    """
    for name in file_list:
        shutil.copy2(os.path.join(src_dir, name), os.path.join(dst_dir, name))

def safe_split(files):
    """
    Split files into (train, val, test), aiming at 70/20/10 by count.

    The split is positional: train is the leading slice, val the next one
    and test the remainder, so for a given input order the result is
    deterministic. The input sequence is not modified.

    n==0: all empty
    n==1: train=1
    n==2: train=1, val=1
    n>=3: 70/20/10 with rounding, and every split gets >=1
    """
    n = len(files)
    if n == 0:
        return [], [], []
    if n == 1:
        return files[:1], [], []
    if n == 2:
        return files[:1], files[1:2], []

    n_train = int(round(0.7 * n))
    n_val = int(round(0.2 * n))

    # For n >= 3 the rounded sizes satisfy n_train >= 2 and n_val >= 1, so
    # train and val are never empty. Only test (the remainder) can come out
    # empty, which happens when the two rounded sizes already add up to n.
    # In that case one item is moved from train, the largest split, to test.
    if n_train + n_val >= n:
        n_train = n - n_val - 1

    train = files[:n_train]
    val = files[n_train:n_train + n_val]
    test = files[n_train + n_val:]
    return train, val, test

# Precompute the image file names per allowed bird (once).
# Both the bird directories and the file names are sorted because
# os.listdir() returns entries in arbitrary order; without a stable order the
# seeded shuffles below would not be reproducible.
available_files = {}
for bird_id in sorted(os.listdir(base_directory)):
    if bird_id not in ALLOWED_BIRD_IDS:
        continue
    bird_dir = os.path.join(base_directory, bird_id)
    if not os.path.isdir(bird_dir):
        continue
    available_files[bird_id] = sorted(list_images(bird_dir))

# Number of available images per bird.
available_counts = {bird_id: len(files) for bird_id, files in available_files.items()}

# subset size -> list of (bird_id, available, missing) for birds that have
# fewer images than the subset asks for.
insufficient_report = defaultdict(list)

# NOTE: output directories are reused if they already exist and are never
# cleaned, so files copied by an earlier run stay in place.
for num_images in tqdm(num_images_per_bird, desc="Processing subsets"):
    subset_directory = os.path.join(output_root, str(num_images))
    os.makedirs(subset_directory, exist_ok=True)

    train_dir = os.path.join(subset_directory, "train")
    val_dir   = os.path.join(subset_directory, "val")
    test_dir  = os.path.join(subset_directory, "test")
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    for bird_id, all_files in available_files.items():
        bird_directory = os.path.join(base_directory, bird_id)
        avail = available_counts[bird_id]

        if avail < num_images:
            insufficient_report[num_images].append((bird_id, avail, num_images - avail))

        if avail == 0:
            # nothing to copy for this bird in this subset
            continue

        # Shuffle a copy: random.shuffle works in place and the cached list
        # must keep its sorted order for the next subset size.
        image_files = list(all_files)
        random.shuffle(image_files)

        # Use up to requested number (or all if fewer); slicing clamps to
        # the list length.
        selected_images = image_files[:num_images]
        train_files, val_files, test_files = safe_split(selected_images)

        # Create bird-specific dirs
        train_bird_dir = os.path.join(train_dir, bird_id)
        val_bird_dir   = os.path.join(val_dir, bird_id)
        test_bird_dir  = os.path.join(test_dir, bird_id)
        os.makedirs(train_bird_dir, exist_ok=True)
        os.makedirs(val_bird_dir, exist_ok=True)
        os.makedirs(test_bird_dir, exist_ok=True)

        # Copy files
        copy_images(train_files, bird_directory, train_bird_dir)
        copy_images(val_files, bird_directory, val_bird_dir)
        copy_images(test_files, bird_directory, test_bird_dir)

# Summary of birds that could not supply the requested number of images.
print("\nBirds with insufficient instances")
if not insufficient_report:
    print("All birds met all requested subset sizes")
else:
    for subset_size in sorted(insufficient_report.keys()):
        rows = insufficient_report[subset_size]
        print(f"\nSubset size {subset_size}: {len(rows)} bird(s) with insufficient instances")
        for bird_id, avail, missing in sorted(rows):
            print(f"  - {bird_id}: available={avail}, missing={missing}")
print("\nDone!")