In [1]:
import random
from collections import defaultdict

# Carve a per-class validation subset out of the training list:
# sample up to SAMPLES_PER_CLASS "<path> <label>" lines per label from TRAIN_TXT,
# append them to the existing VAL_TXT lines, and write the results to
# NEW_TRAIN_TXT / NEW_VAL_TXT.
TRAIN_TXT = "train.txt"
VAL_TXT = "val.txt"
NEW_TRAIN_TXT = "train_new.txt"
NEW_VAL_TXT = "val_new.txt"

SAMPLES_PER_CLASS = 300
RANDOM_SEED = 42
# True  -> sampled lines are MOVED: removed from train and appended to val.
# False -> sampled lines are COPIED: they stay in train AND appear in val.
#          NOTE: copying leaks training samples into validation, so val
#          metrics will be optimistic. Only use it if that is intended.
MOVE_INSTEAD_OF_DUPLICATE = True
# Explicit encoding so the result does not depend on the platform locale.
ENCODING = "utf-8"

random.seed(RANDOM_SEED)

# -----------------------------
# 1. Read existing train & val
# -----------------------------
# Blank lines are dropped; surrounding whitespace is stripped from each line.
with open(TRAIN_TXT, "r", encoding=ENCODING) as f:
    train_lines = [line.strip() for line in f if line.strip()]

with open(VAL_TXT, "r", encoding=ENCODING) as f:
    val_lines = [line.strip() for line in f if line.strip()]

# Parse train into (path, label). Split on the LAST space so a path that
# contains spaces still parses; the label is the final token.
train_entries = []
for entry_no, line in enumerate(train_lines, start=1):
    if " " not in line:
        raise ValueError(
            f"{TRAIN_TXT}: non-blank entry {entry_no} is not '<path> <label>': {line!r}"
        )
    path, label = line.rsplit(" ", 1)
    train_entries.append((path, label))

# ---------------------------------------
# 2. Group train entries by class/label
# ---------------------------------------
by_class = defaultdict(list)  # label -> list of indices in train_entries

for idx, (_path, label) in enumerate(train_entries):
    by_class[label].append(idx)

print(f"Found {len(by_class)} classes in {TRAIN_TXT}")

# ---------------------------------------
# 3. Sample up to SAMPLES_PER_CLASS per class from train
# ---------------------------------------
selected_indices = set()
selected_lines_for_val = []

# Labels are visited in first-appearance order, so with the fixed seed the
# sequence of random.sample calls (and hence the split) is reproducible.
# Every label in by_class was appended to at least once, so `indices` is
# never empty.
for label, indices in by_class.items():
    n_samples = min(SAMPLES_PER_CLASS, len(indices))
    sampled = random.sample(indices, n_samples)

    for idx in sampled:
        path, lbl = train_entries[idx]
        selected_lines_for_val.append(f"{path} {lbl}")
        selected_indices.add(idx)

    print(f"Class {label}: selected {n_samples} samples to move/add to val")
    if MOVE_INSTEAD_OF_DUPLICATE and n_samples == len(indices):
        # min() took every sample of this class, so moving leaves none in train.
        print(
            f"  WARNING: class {label} has only {len(indices)} train samples; "
            f"moving all of them leaves it with none in {NEW_TRAIN_TXT}"
        )

# ---------------------------------------
# 4. Build new train and val lists
# ---------------------------------------
if MOVE_INSTEAD_OF_DUPLICATE:
    # Option A: move them (remove from train)
    new_train_lines = [
        f"{path} {label}"
        for i, (path, label) in enumerate(train_entries)
        if i not in selected_indices
    ]
else:
    # Option B: keep all original train entries (see leakage note above)
    new_train_lines = [f"{path} {label}" for (path, label) in train_entries]

# New val = old val + new sampled ones
new_val_lines = val_lines + selected_lines_for_val

# Invariants: sampling is without replacement within disjoint per-class index
# lists, so each selected index appears exactly once.
_removed = len(selected_indices) if MOVE_INSTEAD_OF_DUPLICATE else 0
assert len(selected_lines_for_val) == len(selected_indices)
assert len(new_train_lines) == len(train_entries) - _removed
assert len(new_val_lines) == len(val_lines) + len(selected_indices)

# ---------------------------------------
# 5. Write out new files
# ---------------------------------------
with open(NEW_TRAIN_TXT, "w", encoding=ENCODING) as f:
    for line in new_train_lines:
        f.write(line + "\n")

with open(NEW_VAL_TXT, "w", encoding=ENCODING) as f:
    for line in new_val_lines:
        f.write(line + "\n")

print("Done!")
print(f"Original train: {len(train_lines)} lines")
print(f"New train:      {len(new_train_lines)} lines")
print(f"Original val:   {len(val_lines)} lines")
print(f"New val:        {len(new_val_lines)} lines")


Found 100 classes in train.txt
Class 0: selected 300 samples to move/add to val
Class 2: selected 300 samples to move/add to val
Class 4: selected 300 samples to move/add to val
Class 9: selected 300 samples to move/add to val
Class 20: selected 300 samples to move/add to val
Class 23: selected 300 samples to move/add to val
Class 29: selected 300 samples to move/add to val
Class 31: selected 300 samples to move/add to val
Class 41: selected 300 samples to move/add to val
Class 53: selected 300 samples to move/add to val
Class 55: selected 300 samples to move/add to val
Class 59: selected 300 samples to move/add to val
Class 61: selected 300 samples to move/add to val
Class 62: selected 300 samples to move/add to val
Class 65: selected 300 samples to move/add to val
Class 73: selected 300 samples to move/add to val
Class 74: selected 300 samples to move/add to val
Class 75: selected 300 samples to move/add to val
Class 78: selected 300 samples to move/add to val
Class 82: selected 300 

In [3]:
from collections import defaultdict

# Sanity-check the split: count images per class in the new train list and
# confirm every class ended up with exactly EXPECTED_PER_CLASS entries.
train_file = "train_new.txt"
EXPECTED_PER_CLASS = 1000  # per-class count expected after the split


def _label_sort_key(label):
    """Sort numeric labels numerically, then any non-numeric labels lexicographically."""
    try:
        return (0, int(label), "")
    except ValueError:
        return (1, 0, label)


class_counts = defaultdict(int)

with open(train_file, "r", encoding="utf-8") as f:
    for line_no, raw_line in enumerate(f, start=1):
        line = raw_line.strip()
        if not line:
            continue  # tolerate blank lines instead of crashing on unpacking
        if " " not in line:
            raise ValueError(
                f"{train_file}: line {line_no} is not '<path> <label>': {line!r}"
            )
        _path, label = line.rsplit(" ", 1)
        class_counts[label] += 1

# Print counts
for label in sorted(class_counts, key=_label_sort_key):
    print(f"Class {label}: {class_counts[label]} images")

# all() over an empty mapping is vacuously True, so require at least one class.
mismatched = {
    label: count
    for label, count in class_counts.items()
    if count != EXPECTED_PER_CLASS
}
all_ok = bool(class_counts) and not mismatched

print(f"\nAll classes have {EXPECTED_PER_CLASS} images?:", all_ok)
if not class_counts:
    print(f"No entries found in {train_file}")
elif mismatched:
    print("Mismatched classes:", dict(sorted(mismatched.items(), key=lambda kv: _label_sort_key(kv[0]))))


Class 0: 1000 images
Class 1: 1000 images
Class 2: 1000 images
Class 3: 1000 images
Class 4: 1000 images
Class 5: 1000 images
Class 6: 1000 images
Class 7: 1000 images
Class 8: 1000 images
Class 9: 1000 images
Class 10: 1000 images
Class 11: 1000 images
Class 12: 1000 images
Class 13: 1000 images
Class 14: 1000 images
Class 15: 1000 images
Class 16: 1000 images
Class 17: 1000 images
Class 18: 1000 images
Class 19: 1000 images
Class 20: 1000 images
Class 21: 1000 images
Class 22: 1000 images
Class 23: 1000 images
Class 24: 1000 images
Class 25: 1000 images
Class 26: 1000 images
Class 27: 1000 images
Class 28: 1000 images
Class 29: 1000 images
Class 30: 1000 images
Class 31: 1000 images
Class 32: 1000 images
Class 33: 1000 images
Class 34: 1000 images
Class 35: 1000 images
Class 36: 1000 images
Class 37: 1000 images
Class 38: 1000 images
Class 39: 1000 images
Class 40: 1000 images
Class 41: 1000 images
Class 42: 1000 images
Class 43: 1000 images
Class 44: 1000 images
Class 45: 1000 image