In [36]:
import os
import pandas as pd
import numpy as np
from pathlib import Path
from sklearn.model_selection import train_test_split # Still useful for shuffling and initial split

In [37]:
# --- Configuration ---
# Input root: every `pruning_results.csv` found (recursively) below this directory is
# loaded as the results of one model.
BASE_DATA_PATH = Path("../data-generation/spam/runs")
# Output directory for train/validation/test CSVs (created if missing).
OUTPUT_DATA_PATH = Path("data_split")

# NOTE(review): not referenced by the cells in this notebook; presumably the column
# subset consumed downstream -- confirm before removing.
COLUMNS = [
    'threshold',
    'flops',
    'non_zero_params',
    'params_reduction_pct',
    'flops_reduction_pct',
    'overall_accuracy',
]
# Threshold grid 0.0, 0.1, ..., 10.0 (101 values). Rounding to one decimal removes the
# floating-point drift that `arange` accumulates with a 0.1 step.
# NOTE(review): also not referenced in this notebook -- confirm downstream use.
THRESHOLDS = np.arange(0, 10.1, 0.1).round(1)

# Number of whole models held out for each evaluation split; all remaining models train.
N_VALIDATION_MODELS = 2
N_TEST_MODELS = 2

# Seed for the model-id shuffle that decides split membership.
RANDOM_STATE = 42

In [38]:
# --- 1. Load and Combine All CSVs with a 'model_id' ---
# Each `pruning_results.csv` under BASE_DATA_PATH becomes one model, tagged with a
# contiguous integer `model_id` (0, 1, 2, ...) in load order.
# Files are visited in *sorted* path order: `rglob` yields entries in filesystem order,
# which can differ between machines/filesystems, and the seeded split in the next cell is
# only reproducible if each model_id always refers to the same file.
print("\nLoading and combining CSVs...")
all_dfs = []
model_id_to_path = {}  # model_id -> source CSV, so split membership can be traced to a run
skipped_files = []
model_id_counter = 0
for csv_file in sorted(BASE_DATA_PATH.rglob("pruning_results.csv")):
    try:
        df = pd.read_csv(csv_file)
    except Exception as e:
        # Best-effort load: skip unreadable files, but remember them for the summary below.
        print(f"Error reading {csv_file}: {e}")
        skipped_files.append(csv_file)
        continue
    if df.empty:
        # A header-only file would consume a model_id that never appears in combined_df.
        print(f"Skipping {csv_file}: no data rows")
        skipped_files.append(csv_file)
        continue
    df['model_id'] = model_id_counter
    model_id_to_path[model_id_counter] = csv_file
    all_dfs.append(df)
    model_id_counter += 1

if not all_dfs:
    raise ValueError(f"No CSV files found or loaded from {BASE_DATA_PATH}.")
if skipped_files:
    print(f"Warning: skipped {len(skipped_files)} file(s); see messages above.")

combined_df = pd.concat(all_dfs, ignore_index=True)
print(f"Combined {len(all_dfs)} CSVs. Total rows: {len(combined_df)}")
print(f"Number of unique model_ids: {combined_df['model_id'].nunique()}")


Loading and combining CSVs...
Combined 17 CSVs. Total rows: 1717
Number of unique model_ids: 17


In [40]:
# --- 2. Perform Model-Aware Train-Validation-Test Split with Fixed Counts ---
# Whole models (not individual rows) are assigned to a split, so every row of a given
# model_id lands in exactly one of train / validation / test.
print("\nPerforming model-aware train-validation-test split with fixed counts...")

n_held_out = N_VALIDATION_MODELS + N_TEST_MODELS

# Sort before shuffling so the permutation depends only on the set of ids, not on row order.
# A local RandomState seeded with RANDOM_STATE gives the same permutation as
# `np.random.seed(RANDOM_STATE); np.random.shuffle(...)` without mutating numpy's global RNG.
unique_model_ids = np.random.RandomState(RANDOM_STATE).permutation(
    np.sort(combined_df['model_id'].unique())
)

if len(unique_model_ids) < n_held_out:
    raise ValueError(
        f"Not enough unique models ({len(unique_model_ids)}) to satisfy "
        f"{N_VALIDATION_MODELS} for validation and {N_TEST_MODELS} for testing. "
        f"Need at least {N_VALIDATION_MODELS + N_TEST_MODELS} models."
    )

# Consecutive slices of the shuffled ids: test first, then validation, remainder trains.
test_model_ids = unique_model_ids[:N_TEST_MODELS]
val_model_ids = unique_model_ids[N_TEST_MODELS:n_held_out]
train_model_ids = unique_model_ids[n_held_out:]

if len(train_model_ids) == 0:
    print("Warning: No models remaining for the training set after allocating to validation and test.")


print(f"Number of models for Training: {len(train_model_ids)}")
print(f"Number of models for Validation: {len(val_model_ids)}")
print(f"Number of models for Testing: {len(test_model_ids)}")

# Create DataFrames based on the split model_ids
train_df = combined_df[combined_df['model_id'].isin(train_model_ids)].copy()
validation_df = combined_df[combined_df['model_id'].isin(val_model_ids)].copy()
test_df = combined_df[combined_df['model_id'].isin(test_model_ids)].copy()

# The three id slices are disjoint and cover every id, so rows must add up exactly.
assert len(train_df) + len(validation_df) + len(test_df) == len(combined_df), \
    "model-aware split lost or duplicated rows"

print(f"\nShape of Training DataFrame: {train_df.shape}")
print(f"Shape of Validation DataFrame: {validation_df.shape}")
print(f"Shape of Test DataFrame: {test_df.shape}")


Performing model-aware train-validation-test split with fixed counts...
Number of models for Training: 13
Number of models for Validation: 2
Number of models for Testing: 2

Shape of Training DataFrame: (1313, 18)
Shape of Validation DataFrame: (202, 18)
Shape of Test DataFrame: (202, 18)


In [41]:
# --- 3. Save Split DataFrames to CSV Files ---
# Write one CSV per split into OUTPUT_DATA_PATH (created if it does not exist yet).
# The pandas row index is not written.
print(f"\nSaving datasets to: {OUTPUT_DATA_PATH}")
OUTPUT_DATA_PATH.mkdir(parents=True, exist_ok=True)

split_outputs = (
    ("train_dataset.csv", train_df),
    ("validation_dataset.csv", validation_df),
    ("test_dataset.csv", test_df),
)
for out_name, split_frame in split_outputs:
    split_frame.to_csv(OUTPUT_DATA_PATH / out_name, index=False)

print("\nSuccessfully created and saved train_dataset.csv, validation_dataset.csv, and test_dataset.csv.")


Saving datasets to: data_split

Successfully created and saved train_dataset.csv, validation_dataset.csv, and test_dataset.csv.


In [42]:
# --- Optional: Verification ---
# Show a few ids per split, then *enforce* (not just print) that the id sets are pairwise
# disjoint and jointly cover every model, and that the saved training CSV round-trips.
print("\nVerification of splits (first few model_ids in each set):")
for split_label, split_ids in (
    ("Train", train_model_ids),
    ("Validation", val_model_ids),
    ("Test", test_model_ids),
):
    print(f"{split_label} model_ids sample: {split_ids[:5] if len(split_ids) > 0 else 'N/A'}")

train_set, val_set, test_set = set(train_model_ids), set(val_model_ids), set(test_model_ids)
overlaps = {
    "train-val": train_set & val_set,
    "train-test": train_set & test_set,
    "val-test": val_set & test_set,
}
for pair_name, shared_ids in overlaps.items():
    print(f"Overlap {pair_name}: {len(shared_ids)}")
assert not any(overlaps.values()), f"model_ids shared between splits: {overlaps}"
assert train_set | val_set | test_set == set(combined_df['model_id'].unique()), \
    "some model_ids were not assigned to any split"

train_roundtrip = pd.read_csv(OUTPUT_DATA_PATH / "train_dataset.csv")
assert train_roundtrip.shape == train_df.shape, "saved train_dataset.csv does not match train_df"

print("\nExample rows from train_dataset.csv:")
print(train_roundtrip.head())


Verification of splits (first few model_ids in each set):
Train model_ids sample: [11 14  8 13  2]
Validation model_ids sample: [ 5 15]
Test model_ids sample: [0 1]
Overlap train-val: 0
Overlap train-test: 0
Overlap val-test: 0

Example rows from train_dataset.csv:
   threshold         flops  non_zero_params  params_reduction_pct  \
0        0.0  5.586813e+09         66955009              0.000000   
1        0.1  5.427130e+09         65041295              2.858209   
2        0.2  5.270893e+09         63168879              5.654737   
3        0.3  5.121136e+09         61374123              8.335278   
4        0.4  4.979191e+09         59672984             10.875997   

   flops_reduction_pct  overall_accuracy  overall_f1  overall_precision  \
0             0.000000             0.988    0.987999           0.988125   
1             2.858209             0.988    0.987999           0.988125   
2             5.654737             0.988    0.987999           0.988125   
3             8.33