<a href="https://colab.research.google.com/github/rohith-66/ai-generated-image-detection/blob/Rohith/create_splits.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
import pandas as pd

# Attach Google Drive to the Colab VM; force_remount=True re-mounts even if
# a mount already exists at this path.
drive.mount("/content/drive", force_remount=True)

# Location of the dataset on Drive and of the index CSV this notebook edits.
DATA_ROOT = "/content/drive/MyDrive/AI_Image_Detection_Data"
CSV_PATH = DATA_ROOT + "/index/dataset_index.csv"

# Load the index and report its size plus the split / label distributions.
df = pd.read_csv(CSV_PATH)

print("Loaded:", df.shape[0])
print(df["split"].value_counts(), df["label"].value_counts(), sep="\n")


Mounted at /content/drive
Loaded: 25000
split
train    20000
val       5000
Name: count, dtype: int64
label
0    15000
1    10000
Name: count, dtype: int64


In [2]:
from sklearn.model_selection import train_test_split

# Re-split the index: carve class-balanced `val` and `test` sets out of the
# current `train` rows and write the result back to CSV_PATH.
SEED = 42             # random_state for every train_test_split call below
VAL_PER_CLASS = 1000  # rows per label drawn into the new `val` split
TEST_PER_CLASS = 1000 # rows per label drawn into the new `test` split

# Idempotency guard. This cell overwrites CSV_PATH, the same file `df` was
# loaded from. If the loaded index already carries a `test` split, running
# again would relabel the balanced val's label-0 coco rows as `extra_real`
# and carve a second val/test out of the already-reduced train split --
# none of which the duplicate/leak check further down can detect.
if (df["split"] == "test").any():
    raise RuntimeError(
        "Index already contains a 'test' split; refusing to re-split. "
        "Restore the original dataset_index.csv before re-running this cell."
    )

# Park the original `val` rows that are coco / label 0 under `extra_real`
# so they are not mixed into the new balanced `val` split.
# NOTE(review): original `val` rows that do NOT match this mask keep
# split == "val" and end up alongside the new val rows -- confirm none exist.
mask_old_val = (df["split"] == "val") & (df["source"] == "coco") & (df["label"] == 0)
df.loc[mask_old_val, "split"] = "extra_real"

# Everything currently marked `train` is the pool the new splits are drawn from.
train_pool = df[df["split"] == "train"].copy().reset_index(drop=True)

# Split the pool by label: 0 is handled as `real`, 1 as `ai`.
real = train_pool[train_pool["label"] == 0].copy()
ai   = train_pool[train_pool["label"] == 1].copy()

print("Train pool BEFORE:", train_pool["label"].value_counts().to_dict())

# Balanced VAL: an integer test_size asks for that many rows; the second
# returned frame is the held-out part, the first is the remainder.
real_rem, real_val = train_test_split(real, test_size=VAL_PER_CLASS, random_state=SEED, shuffle=True)
ai_rem,   ai_val   = train_test_split(ai,   test_size=VAL_PER_CLASS, random_state=SEED, shuffle=True)

# Balanced TEST: drawn from what is left after VAL, so VAL and TEST are disjoint.
real_rem, real_test = train_test_split(real_rem, test_size=TEST_PER_CLASS, random_state=SEED, shuffle=True)
ai_rem,   ai_test   = train_test_split(ai_rem,   test_size=TEST_PER_CLASS, random_state=SEED, shuffle=True)

# Reassemble one frame per split from the per-class pieces.
train_final = pd.concat([real_rem, ai_rem], ignore_index=True)
val_final   = pd.concat([real_val, ai_val], ignore_index=True)
test_final  = pd.concat([real_test, ai_test], ignore_index=True)

train_final["split"] = "train"
val_final["split"]   = "val"
test_final["split"]  = "test"

# Rows that were never in the train pool are carried over unchanged.
df_nontrain = df[df["split"] != "train"].copy()
df_final = pd.concat([df_nontrain, train_final, val_final, test_final], ignore_index=True)

# Integrity checks: no filepath may appear twice, and none in more than one split.
dup_any = df_final["filepath"].duplicated().sum()
split_leaks = df_final.groupby("filepath")["split"].nunique()
leak_paths = (split_leaks > 1).sum()

print("\nSplit counts:\n", df_final["split"].value_counts())
print("\nLabel counts:\n", df_final["label"].value_counts())
print("\nDuplicate filepaths overall:", dup_any)
print("Filepaths appearing in >1 split:", leak_paths)

if dup_any != 0 or leak_paths != 0:
    raise RuntimeError("Leakage detected. Not saving CSV.")

# Re-splitting only relabels rows, so the row count must be conserved.
if len(df_final) != len(df):
    raise RuntimeError(
        f"Row count changed during re-split ({len(df)} -> {len(df_final)}). Not saving CSV."
    )

# Overwrites the index that was loaded into `df`.
df_final.to_csv(CSV_PATH, index=False)
print("\nSaved updated CSV:", CSV_PATH)


Train pool BEFORE: {0: 10000, 1: 10000}

Split counts:
 split
train         16000
extra_real     5000
val            2000
test           2000
Name: count, dtype: int64

Label counts:
 label
0    15000
1    10000
Name: count, dtype: int64

Duplicate filepaths overall: 0
Filepaths appearing in >1 split: 0

Saved updated CSV: /content/drive/MyDrive/AI_Image_Detection_Data/index/dataset_index.csv


In [3]:
# Read the index back from disk and summarise it, so the numbers shown
# reflect what was actually written rather than in-memory state.
df2 = pd.read_csv(CSV_PATH)

n_rows = df2.shape[0]
per_split = df2["split"].value_counts()
per_split_label = df2.groupby("split")["label"].value_counts()

print("Total rows:", n_rows)
print("\nSplit counts:\n", per_split)
print("\nLabel counts by split:\n", per_split_label)


Total rows: 25000

Split counts:
 split
train         16000
extra_real     5000
val            2000
test           2000
Name: count, dtype: int64

Label counts by split:
 split       label
extra_real  0        5000
test        0        1000
            1        1000
train       0        8000
            1        8000
val         0        1000
            1        1000
Name: count, dtype: int64
