In [1]:
import os
import re
import random
import shutil

# Input folders
original_folder = "/content/drive/MyDrive/train/original"
masked_folder = "/content/drive/MyDrive/train/masked"
blackout_folder = "/content/drive/MyDrive/train/blackout"

# Output base
output_base = "/content/drive/MyDrive/train/finaldata"
train_ratio = 0.8  # 80% train, 20% val

# Regex to get base name (works for img*, vid*, IMG* etc.)
pattern = re.compile(r"^(.*_jpg)")

# List all original files
original_files = [f for f in os.listdir(original_folder) if f.endswith(".jpg")]

# ✅ Keep only files that have all three versions
valid_files = []
for orig in original_files:
    match = pattern.match(orig)
    if not match:
        print(f"⚠️ Skipping unexpected filename: {orig}")
        continue

    base = match.group(1)  # e.g. img52_jpg or vid23_frame_00001_jpg
    mask_name = f"{base}_mask.png"

    mask_path = os.path.join(masked_folder, mask_name)
    blackout_path = os.path.join(blackout_folder, mask_name)

    if os.path.exists(mask_path) and os.path.exists(blackout_path):
        valid_files.append(orig)
    else:
        print(f"❌ Missing pair for {orig} -> skipped")

print(f"✅ Total usable images (with all 3 versions): {len(valid_files)}")

# Shuffle for randomness
random.shuffle(valid_files)

# Train/val split
split_idx = int(len(valid_files) * train_ratio)
train_files = valid_files[:split_idx]
val_files = valid_files[split_idx:]

def make_dirs(split):
    for sub in ["original", "masked", "blackouts"]:
        os.makedirs(os.path.join(output_base, split, sub), exist_ok=True)

make_dirs("train")
make_dirs("val")

def copy_split(files, split):
    for orig in files:
        base = pattern.match(orig).group(1)
        mask_name = f"{base}_mask.png"

        # paths
        orig_path = os.path.join(original_folder, orig)
        mask_path = os.path.join(masked_folder, mask_name)
        blackout_path = os.path.join(blackout_folder, mask_name)

        # dests
        dst_orig = os.path.join(output_base, split, "original", orig)
        dst_mask = os.path.join(output_base, split, "masked", mask_name)
        dst_blackout = os.path.join(output_base, split, "blackouts", mask_name)

        shutil.copy(orig_path, dst_orig)
        shutil.copy(mask_path, dst_mask)
        shutil.copy(blackout_path, dst_blackout)

# Perform copy
copy_split(train_files, "train")
copy_split(val_files, "val")

print(f"✅ Final dataset created. Train: {len(train_files)}, Val: {len(val_files)}")


❌ Missing pair for img97_jpg.rf.b930b64be9e78e638e293f3e5ef0198a.jpg -> skipped
❌ Missing pair for img702_jpg.rf.b4113b62f41976cfdb42d0aeb1f22c5c.jpg -> skipped
❌ Missing pair for img3297_jpg.rf.15a586362f2cf9d8a636fe66c87df20f.jpg -> skipped
❌ Missing pair for img2361_jpg.rf.cb899c3c4f03db931e2eec9e6975f0ad.jpg -> skipped
❌ Missing pair for img1585_jpg.rf.2b6398a8ddb693a7070a4f723268ec90.jpg -> skipped
❌ Missing pair for img308_jpg.rf.27ebff9827733225452402805803900e.jpg -> skipped
❌ Missing pair for img208_jpg.rf.231e5a1236757005a07de22c65823dfd.jpg -> skipped
❌ Missing pair for img551_jpg.rf.d890bedadfb6c7c375842556bd0ce1fa.jpg -> skipped
❌ Missing pair for img507_jpg.rf.fa6444980764bfed189ff6345f35792b.jpg -> skipped
❌ Missing pair for img1764_jpg.rf.ef233f5cf05ffb42dcc45ff16d3e4b2f.jpg -> skipped
❌ Missing pair for img3796_jpg.rf.304d84fd1a67d020a60a103340f4fc21.jpg -> skipped
❌ Missing pair for img81_jpg.rf.5bd8226477b655380596765fae625e44.jpg -> skipped
❌ Missing pair for img300