In [None]:
import os
import random
import shutil
from pathlib import Path
import random
from PIL import Image
from tqdm import tqdm

In [9]:
def build_biased_shards(
    root_dir="data",
    output_dir="shards",
    dominant_count=800,
    other_count=50,
    num_shards=5,
    seed=None,
):
    """Build one class-biased image shard per class folder found in ``root_dir``.

    Every immediate sub-directory of ``root_dir`` is treated as a class. For the
    class at sorted position ``k`` a directory ``<output_dir>/shard_<k>`` is
    created that contains up to ``dominant_count`` images of that class plus up
    to ``other_count`` images of every other class, each copied into a
    sub-directory named after its class.

    Note: the "other" images of a class are the leading entries of the same
    shuffled list its dominant images are taken from, so the same files are
    reused in every shard where that class is non-dominant.

    Args:
        root_dir: Directory holding one sub-directory per class. Only files
            matching ``*.jpeg``, ``*.jpg`` or ``*.png`` directly inside a class
            folder are used.
        output_dir: Directory to (re)create. Any existing content is deleted,
            but only after the inputs have been validated.
        dominant_count: Maximum number of images taken from the dominant class.
        other_count: Maximum number of images taken from each other class.
        num_shards: Expected number of shards; must equal the number of class
            folders.
        seed: Optional seed for the per-class shuffle. ``None`` (default) uses
            the global ``random`` state, as before; any other value makes the
            selection reproducible.

    Raises:
        ValueError: If ``output_dir`` is ``root_dir`` or one of its parents
            (recreating it would delete the source images).
        FileNotFoundError: If ``root_dir`` does not exist.
        AssertionError: If the number of class folders differs from
            ``num_shards``.
    """
    root = Path(root_dir)
    out = Path(output_dir)
    root_resolved = root.resolve()
    out_resolved = out.resolve()

    # The output directory is wiped below; never let that wipe the source data.
    if out_resolved == root_resolved or out_resolved in root_resolved.parents:
        raise ValueError(
            f"output_dir ({out}) must not be root_dir ({root}) or one of its parents"
        )

    # Discover and validate the inputs BEFORE touching the output directory, so
    # a bad call cannot destroy previously built shards.
    # An output directory nested inside root_dir is not a class folder.
    class_folders = sorted(
        p for p in root.iterdir() if p.is_dir() and p.resolve() != out_resolved
    )
    num_classes = len(class_folders)

    assert num_classes == num_shards, (
        f"Found {num_classes} class folders in {root} but num_shards={num_shards}"
        " — counts must match."
    )

    # Module-level functions when unseeded (original behavior), otherwise a
    # private generator so the global random state is left alone.
    rng = random if seed is None else random.Random(seed)

    # Collect all images per class. Sorting first removes the dependence on
    # filesystem listing order, which is what makes a seeded run reproducible.
    class_to_images = {}
    for c in class_folders:
        imgs = sorted(
            list(c.glob("*.jpeg")) + list(c.glob("*.jpg")) + list(c.glob("*.png"))
        )
        rng.shuffle(imgs)
        class_to_images[c.name] = imgs

    # Inputs are valid: now it is safe to recreate the output directory.
    if out.exists():
        shutil.rmtree(out)
    out.mkdir(parents=True, exist_ok=True)

    # Build each shard
    for shard_idx, dominant_class_folder in enumerate(class_folders):
        shard_name = f"shard_{shard_idx}"
        shard_path = out / shard_name
        shard_path.mkdir(parents=True, exist_ok=True)

        dominant_class = dominant_class_folder.name

        print(f"\nBuilding {shard_name} (dominant class: {dominant_class})")

        shard_images = []

        # Up to `dominant_count` samples from the dominant class
        dominant_imgs = class_to_images[dominant_class][:dominant_count]
        shard_images.extend([(img, dominant_class) for img in dominant_imgs])

        # Up to `other_count` samples from each of the other classes
        for other_folder in class_folders:
            if other_folder.name == dominant_class:
                continue
            other_imgs = class_to_images[other_folder.name][:other_count]
            shard_images.extend([(img, other_folder.name) for img in other_imgs])

        # Save images to disk, one sub-directory per class
        for img_path, cls_name in tqdm(shard_images):
            cls_dir = shard_path / cls_name
            cls_dir.mkdir(exist_ok=True)
            shutil.copy(img_path, cls_dir / img_path.name)

        print(f"{shard_name} completed: {len(shard_images)} images")

    print("\nFinished building all shards!")

In [10]:
# Build the shards with the function's default arguments:
# root_dir="data", output_dir="shards", dominant_count=800, other_count=50, num_shards=5.
build_biased_shards()


Building shard_0 (dominant class: cat)


100%|██████████| 1000/1000 [00:00<00:00, 1901.98it/s]


shard_0 completed: 1000 images

Building shard_1 (dominant class: chicken)


100%|██████████| 1000/1000 [00:00<00:00, 2371.92it/s]


shard_1 completed: 1000 images

Building shard_2 (dominant class: dog)


100%|██████████| 1000/1000 [00:00<00:00, 2297.02it/s]


shard_2 completed: 1000 images

Building shard_3 (dominant class: elephant)


100%|██████████| 1000/1000 [00:00<00:00, 2030.15it/s]


shard_3 completed: 1000 images

Building shard_4 (dominant class: horse)


100%|██████████| 1000/1000 [00:00<00:00, 2243.63it/s]

shard_4 completed: 1000 images

Finished building all shards!





In [12]:
import shutil

# HIGH LEAKAGE
# for 3 classes experiment
# remove the cat class from shard_2 for evaluation

folder_path = "shards/shard_2/cat"

# Delete the folder and all its contents. Guarded so that re-running this cell
# is a no-op instead of raising FileNotFoundError on the already-removed folder.
if Path(folder_path).is_dir():
    shutil.rmtree(folder_path)

In [None]:
import shutil

# LOW LEAKAGE
# for 5 classes experiment
# remove the cat class from shard_2, shard_3 and shard_4, i.e. from every
# shard except shard_0 and its neighbor shard_1
folder_path = ["shards/shard_2/cat", "shards/shard_3/cat", "shards/shard_4/cat"]

# Delete each folder and all its contents. Folders that are already gone
# (e.g. shards/shard_2/cat after the high-leakage cell, or on a re-run of this
# cell) are skipped instead of raising FileNotFoundError.
for i in folder_path:
    if Path(i).is_dir():
        shutil.rmtree(i)