In [4]:
import json
import random
from collections import defaultdict

# Build a small, class-balanced subset of the COCO train2017 instance annotations.
#
# Only images that contain exactly one of the classes in CLASSES_TO_KEEP are
# eligible; an equal number of such images is sampled per class and written,
# together with their annotations for the kept classes, to OUTPUT_JSON.
#
# NOTE(review): apart from the file names and N this cell is a copy of the
# val2017 cell in this notebook; consider moving the shared logic into one
# function that both cells call.
dataset_dir = "dataset_coco/annotations"
dataset = "instances_train2017.json"
dataset_path = f"{dataset_dir}/{dataset}"
INPUT_JSON = dataset_path
OUTPUT_JSON = f"{dataset_dir}/instances_train2017_filtered_balanced.json"
CLASSES_TO_KEEP = ["chair", "motorcycle"]
N = 2000  # Total number of images to keep (must be divisible by the number of classes, i.e. even for two classes)
SEED = 42  # Fixed seed so that re-running the cell selects the same images

# Fail early instead of silently writing fewer than N images.
if not CLASSES_TO_KEEP:
    raise ValueError("CLASSES_TO_KEEP must not be empty")
if N % len(CLASSES_TO_KEEP) != 0:
    raise ValueError(
        f"N ({N}) must be divisible by the number of classes ({len(CLASSES_TO_KEEP)})"
    )

# --- Load COCO annotations ---
# Explicit encoding so the result does not depend on the platform's locale.
with open(INPUT_JSON, "r", encoding="utf-8") as f:
    data = json.load(f)

# --- Get category IDs for the classes to keep ---
cat_name_to_id = {cat["name"]: cat["id"] for cat in data["categories"]}
missing_classes = [name for name in CLASSES_TO_KEEP if name not in cat_name_to_id]
if missing_classes:
    # Same exception type as the plain dict lookup would raise, but with context.
    raise KeyError(f"Classes not found in {INPUT_JSON}: {missing_classes}")
keep_ids = {name: cat_name_to_id[name] for name in CLASSES_TO_KEEP}
id_to_name = {cat_id: name for name, cat_id in keep_ids.items()}

# --- Group annotations by image and record which kept classes each image contains ---
anns_by_image = defaultdict(list)
image_classes = defaultdict(set)
for ann in data["annotations"]:
    class_name = id_to_name.get(ann["category_id"])
    if class_name is None:
        continue  # annotation belongs to a class we are not keeping
    anns_by_image[ann["image_id"]].append(ann)
    image_classes[ann["image_id"]].add(class_name)

# --- Separate images by class ---
# Only images containing exactly one of the kept classes are used, so the
# per-class pools are disjoint and the final class balance is strict.
class_to_images = {cls: [] for cls in CLASSES_TO_KEEP}
for image_id, classes in image_classes.items():
    if len(classes) == 1:
        (only_class,) = classes
        class_to_images[only_class].append(image_id)

# --- Sample the same number of images per class ---
num_per_class = N // len(CLASSES_TO_KEEP)
rng = random.Random(SEED)  # local RNG: reproducible and independent of global random state
selected_img_ids = set()
for cls in CLASSES_TO_KEEP:
    # Sorting makes the sample independent of the annotation order in the file.
    imgs = sorted(class_to_images[cls])
    if len(imgs) < num_per_class:
        raise ValueError(f"Not enough images for class '{cls}' (needed {num_per_class}, found {len(imgs)})")
    selected_img_ids.update(rng.sample(imgs, num_per_class))

# --- Collect new annotations and images ---
# Iterating over data["images"] keeps the original image order in the output.
new_annotations = []
new_images = []
image_id_set = set(selected_img_ids)

for img in data["images"]:
    if img["id"] in image_id_set:
        new_images.append(img)
        new_annotations.extend(anns_by_image[img["id"]])

# --- Build filtered data dict ---
filtered_data = {
    "info": data.get("info", {}),
    "licenses": data.get("licenses", []),
    "images": new_images,
    "annotations": new_annotations,
    "categories": [cat for cat in data["categories"] if cat["name"] in CLASSES_TO_KEEP]
}

# --- Save to output file ---
with open(OUTPUT_JSON, "w", encoding="utf-8") as f:
    json.dump(filtered_data, f)


In [2]:
import json
import random
from collections import defaultdict

# Build a small, class-balanced subset of the COCO val2017 instance annotations.
#
# Only images that contain exactly one of the classes in CLASSES_TO_KEEP are
# eligible; an equal number of such images is sampled per class and written,
# together with their annotations for the kept classes, to OUTPUT_JSON.
#
# NOTE(review): apart from the file names and N this cell is a copy of the
# train2017 cell in this notebook; consider moving the shared logic into one
# function that both cells call.
dataset_dir = "dataset_coco/annotations"
dataset = "instances_val2017.json"
dataset_path = f"{dataset_dir}/{dataset}"
INPUT_JSON = dataset_path
OUTPUT_JSON = f"{dataset_dir}/instances_val2017_filtered_balanced.json"
CLASSES_TO_KEEP = ["chair", "motorcycle"]
N = 200  # Total number of images to keep (must be divisible by the number of classes, i.e. even for two classes)
SEED = 42  # Fixed seed so that re-running the cell selects the same images

# Fail early instead of silently writing fewer than N images.
if not CLASSES_TO_KEEP:
    raise ValueError("CLASSES_TO_KEEP must not be empty")
if N % len(CLASSES_TO_KEEP) != 0:
    raise ValueError(
        f"N ({N}) must be divisible by the number of classes ({len(CLASSES_TO_KEEP)})"
    )

# --- Load COCO annotations ---
# Explicit encoding so the result does not depend on the platform's locale.
with open(INPUT_JSON, "r", encoding="utf-8") as f:
    data = json.load(f)

# --- Get category IDs for the classes to keep ---
cat_name_to_id = {cat["name"]: cat["id"] for cat in data["categories"]}
missing_classes = [name for name in CLASSES_TO_KEEP if name not in cat_name_to_id]
if missing_classes:
    # Same exception type as the plain dict lookup would raise, but with context.
    raise KeyError(f"Classes not found in {INPUT_JSON}: {missing_classes}")
keep_ids = {name: cat_name_to_id[name] for name in CLASSES_TO_KEEP}
id_to_name = {cat_id: name for name, cat_id in keep_ids.items()}

# --- Group annotations by image and record which kept classes each image contains ---
anns_by_image = defaultdict(list)
image_classes = defaultdict(set)
for ann in data["annotations"]:
    class_name = id_to_name.get(ann["category_id"])
    if class_name is None:
        continue  # annotation belongs to a class we are not keeping
    anns_by_image[ann["image_id"]].append(ann)
    image_classes[ann["image_id"]].add(class_name)

# --- Separate images by class ---
# Only images containing exactly one of the kept classes are used, so the
# per-class pools are disjoint and the final class balance is strict.
class_to_images = {cls: [] for cls in CLASSES_TO_KEEP}
for image_id, classes in image_classes.items():
    if len(classes) == 1:
        (only_class,) = classes
        class_to_images[only_class].append(image_id)

# --- Sample the same number of images per class ---
num_per_class = N // len(CLASSES_TO_KEEP)
rng = random.Random(SEED)  # local RNG: reproducible and independent of global random state
selected_img_ids = set()
for cls in CLASSES_TO_KEEP:
    # Sorting makes the sample independent of the annotation order in the file.
    imgs = sorted(class_to_images[cls])
    if len(imgs) < num_per_class:
        raise ValueError(f"Not enough images for class '{cls}' (needed {num_per_class}, found {len(imgs)})")
    selected_img_ids.update(rng.sample(imgs, num_per_class))

# --- Collect new annotations and images ---
# Iterating over data["images"] keeps the original image order in the output.
new_annotations = []
new_images = []
image_id_set = set(selected_img_ids)

for img in data["images"]:
    if img["id"] in image_id_set:
        new_images.append(img)
        new_annotations.extend(anns_by_image[img["id"]])

# --- Build filtered data dict ---
filtered_data = {
    "info": data.get("info", {}),
    "licenses": data.get("licenses", []),
    "images": new_images,
    "annotations": new_annotations,
    "categories": [cat for cat in data["categories"] if cat["name"] in CLASSES_TO_KEEP]
}

# --- Save to output file ---
with open(OUTPUT_JSON, "w", encoding="utf-8") as f:
    json.dump(filtered_data, f)
