In [None]:
# Mount Google Drive at /content/drive inside the Colab runtime.
# force_remount=True requests a fresh mount even if Drive is already mounted.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)


Mounted at /content/drive


In [None]:
import os
import random
import shutil
from collections import defaultdict
import yaml  # to read class names from data.yaml if available

# ----------------------------
# CONFIGURATION
# ----------------------------
# Root of the original YOLO dataset split; must contain "images" and "labels" subfolders.
DATASET_DIR = "/content/drive/MyDrive/kaggle_dataset/train"
# Destination root; "images" and "labels" subfolders are created here for the subset.
OUTPUT_DIR = "/content/drive/MyDrive/balanced_dataset"
# Image budget, divided evenly across classes. The final count can be lower:
# small classes contribute only what they have, and an image picked for
# several classes is stored once.
SUBSET_SIZE = 30000
# Only files whose name ends with this suffix are treated as images
# (use ".png" for PNG datasets).
EXT = ".jpg"

# ----------------------------
# HELPER FUNCTIONS
# ----------------------------
def get_labels(label_path):
    """Return the class ID of every annotation line in a YOLO label file.

    The class ID is the first whitespace-separated field of each line.
    Blank or whitespace-only lines are ignored. IDs are returned in file
    order, duplicates included.
    """
    class_ids = []
    with open(label_path, "r") as handle:
        for raw_line in handle:
            fields = raw_line.split()
            if fields:
                class_ids.append(int(fields[0]))
    return class_ids

def collect_class_images(images_dir, labels_dir):
    """Build a mapping from class ID to the image file names containing it.

    Only files in ``images_dir`` whose name ends with ``EXT`` and that have a
    same-stem ``.txt`` file in ``labels_dir`` are indexed. An image appears
    once in the list of every class it contains.
    """
    images_by_class = defaultdict(list)

    for filename in os.listdir(images_dir):
        if filename.endswith(EXT):
            stem, _ = os.path.splitext(filename)
            annotation_path = os.path.join(labels_dir, stem + ".txt")
            if os.path.exists(annotation_path):
                # De-duplicate so an image with several boxes of one class
                # is listed only once for that class.
                for class_id in set(get_labels(annotation_path)):
                    images_by_class[class_id].append(filename)

    return images_by_class

def make_balanced_subset(class_to_images, total_size):
    """Randomly pick an (approximately) equal number of images per class.

    Args:
        class_to_images: Mapping of class ID -> list of image file names.
        total_size: Image budget; each class gets ``total_size // n_classes``
            draws (or all of its images when it has fewer).

    Returns:
        A list of unique image file names. It can hold fewer than
        ``total_size`` entries, because small classes contribute only what
        they have and an image drawn for several classes is kept once.
        An empty mapping yields an empty list.
    """
    n_classes = len(class_to_images)
    if n_classes == 0:
        # Avoid ZeroDivisionError below when no labelled images were found.
        return []

    per_class = total_size // n_classes

    selected_images = set()
    for img_list in class_to_images.values():
        sample_size = min(per_class, len(img_list))
        selected_images.update(random.sample(img_list, sample_size))

    return list(selected_images)

def copy_subset(images, src_img_dir, src_lbl_dir, dst_img_dir, dst_lbl_dir):
    """Copy each selected image and its same-stem ``.txt`` label file.

    Both destination directories are created if missing. For every image the
    image file is copied first, then its label file.
    """
    for target_dir in (dst_img_dir, dst_lbl_dir):
        os.makedirs(target_dir, exist_ok=True)

    for image_file in images:
        stem, _ = os.path.splitext(image_file)
        transfers = (
            (src_img_dir, dst_img_dir, image_file),
            (src_lbl_dir, dst_lbl_dir, stem + ".txt"),
        )
        for source_dir, target_dir, filename in transfers:
            shutil.copy(
                os.path.join(source_dir, filename),
                os.path.join(target_dir, filename),
            )

def detect_subset_classes(selected_images, labels_dir):
    """Return the sorted, de-duplicated class IDs found in the subset's labels.

    For each image the same-stem ``.txt`` file in ``labels_dir`` is read;
    images without a label file are ignored.
    """
    found = set()
    for image_file in selected_images:
        stem = os.path.splitext(image_file)[0]
        annotation_path = os.path.join(labels_dir, stem + ".txt")
        if os.path.exists(annotation_path):
            found.update(get_labels(annotation_path))
    return sorted(found)

def load_class_names(data_yaml_path):
    """Load class names from a YOLO ``data.yaml`` if available.

    Args:
        data_yaml_path: Path to the dataset YAML file.

    Returns:
        The value stored under the top-level ``names`` key (its shape is
        whatever the file contains), or ``None`` when the file does not
        exist, is empty / not a mapping, or has no ``names`` key.
    """
    if not os.path.exists(data_yaml_path):
        return None
    # Explicit UTF-8 so non-ASCII class names load regardless of locale.
    with open(data_yaml_path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    # safe_load returns None for an empty document and may return a list or
    # scalar for unexpected content; only a mapping can carry "names".
    if not isinstance(data, dict):
        return None
    return data.get("names")

# ----------------------------
# MAIN LOGIC
# ----------------------------
def main():
    """Create a class-balanced subset of the dataset and report its classes.

    Reads images/labels from ``DATASET_DIR``, samples a balanced selection of
    at most ``SUBSET_SIZE`` images, copies the chosen images and labels into
    ``OUTPUT_DIR``, then prints the class IDs present in the subset (with
    names when ``DATASET_DIR/data.yaml`` provides them).
    """
    images_dir = os.path.join(DATASET_DIR, "images")
    labels_dir = os.path.join(DATASET_DIR, "labels")

    # Collect mapping of classes → image list
    print("Collecting class-to-image mapping...")
    class_to_images = collect_class_images(images_dir, labels_dir)

    print(f"Found {len(class_to_images)} classes in total:")
    for cid, imgs in class_to_images.items():
        print(f"  Class {cid}: {len(imgs)} images")

    if not class_to_images:
        # No image matched EXT with a label file; sampling would divide by
        # zero and there is nothing to copy, so stop with a clear message.
        print("\nNo labelled images found; nothing to copy.")
        return

    # Select balanced subset
    print("\nSelecting balanced subset...")
    subset_images = make_balanced_subset(class_to_images, SUBSET_SIZE)
    print(f"Total selected images: {len(subset_images)}")

    # Copy images and labels to output folder
    dst_img_dir = os.path.join(OUTPUT_DIR, "images")
    dst_lbl_dir = os.path.join(OUTPUT_DIR, "labels")
    print("\nCopying files...")
    copy_subset(subset_images, images_dir, labels_dir, dst_img_dir, dst_lbl_dir)
    print("✅ Balanced subset created successfully!")

    # Detect and print which classes appear in the new subset
    subset_classes = detect_subset_classes(subset_images, labels_dir)
    print("\nClasses present in subset:", subset_classes)

    # If there's a data.yaml file, map class IDs to names
    yaml_path = os.path.join(DATASET_DIR, "data.yaml")
    class_names = load_class_names(yaml_path)
    if class_names:
        print("\nClass names present in subset:")
        for cid in subset_classes:
            if isinstance(class_names, dict):
                # Mapping-style names: a missing key must not raise KeyError.
                name = class_names.get(cid, "(unknown)")
            else:
                # Sequence-style names: reject out-of-range and negative ids
                # (a negative id would otherwise index from the end).
                name = class_names[cid] if 0 <= cid < len(class_names) else "(unknown)"
            print(f"  {cid}: {name}")

# Run the pipeline only when this code is executed directly, not when imported.
if __name__ == "__main__":
    main()


Collecting class-to-image mapping...
