# Create Train/Validation/Test Splits

This notebook creates stratified train/validation/test splits from the real dataset index.

## Purpose
- Load real dataset index
- Perform stratified splitting by wave type and direction
- Ensure balanced representation across splits
- Analyze split statistics

In [None]:
import random
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict, Counter
import numpy as np

In [None]:
# Import utility functions from previous notebooks.
# This notebook calls read_jsonl, write_jsonl, ensure_dir and set_seed without
# defining them, so they are presumably brought into the namespace by the
# notebook executed here — TODO confirm against 02_data_loading.ipynb.
%run 02_data_loading.ipynb

## Configuration

In [None]:
# Configuration
# Paths are relative to the directory the notebook is run from.
REAL_INDEX = "data/processed/real_index.jsonl"
OUT_DIR = "data/processed/splits"

# Only TRAIN_RATIO and VAL_RATIO are passed to the splitter; the test split is
# whatever remains. TEST_RATIO is reported for readability, so it must be kept
# consistent with the other two by hand.
TRAIN_RATIO = 0.8
VAL_RATIO = 0.1
TEST_RATIO = 0.1  # Remaining after train and val
SEED = 42

# Fail fast if the ratios were edited inconsistently, instead of printing a
# test ratio that does not match the split that will actually be produced.
if min(TRAIN_RATIO, VAL_RATIO, TEST_RATIO) < 0 or abs(TRAIN_RATIO + VAL_RATIO + TEST_RATIO - 1.0) > 1e-9:
    raise ValueError(
        "TRAIN_RATIO, VAL_RATIO and TEST_RATIO must be non-negative and sum to 1, "
        f"got {TRAIN_RATIO}, {VAL_RATIO}, {TEST_RATIO}"
    )

print(f"Input index: {REAL_INDEX}")
print(f"Output directory: {OUT_DIR}")
print(f"Split ratios - Train: {TRAIN_RATIO}, Val: {VAL_RATIO}, Test: {TEST_RATIO}")
print(f"Random seed: {SEED}")

## Stratified Splitting Function

In [None]:
def stratified_split(items, train_ratio=0.8, val_ratio=0.1, seed=42):
    """Split records into train/val/test, stratified by (wave_type, direction).

    Records are grouped by their ``(wave_type, direction)`` pair. Each group is
    shuffled and sliced on its own, so every group contributes to the three
    splits in roughly the requested proportions. The test split receives
    whatever is left of each group after train and val have been taken.

    Args:
        items: Iterable of dict-like records, each providing the keys
            ``"wave_type"`` and ``"direction"``.
        train_ratio: Fraction of every group assigned to the train split.
        val_ratio: Fraction of every group assigned to the validation split.
        seed: Seed for a private ``random.Random`` instance; the global
            ``random`` state is not touched.

    Returns:
        A ``(train, val, test)`` tuple of lists. The input sequence is not
        reordered; the returned lists reference the same record objects.

    Raises:
        ValueError: If a ratio is negative or ``train_ratio + val_ratio``
            exceeds 1.
        KeyError: If a record lacks ``"wave_type"`` or ``"direction"``.

    Note:
        Sizes are rounded per group, so small groups can end up with no
        validation or test samples (e.g. a group of 6 with the default ratios
        is split 5/1/0).
    """
    # A negative ratio would turn into a negative slice index below and make
    # the reported sizes disagree with the slices actually taken; a sum above
    # 1 would silently starve the test split.
    if train_ratio < 0 or val_ratio < 0:
        raise ValueError(
            f"train_ratio and val_ratio must be non-negative, got {train_ratio} and {val_ratio}"
        )
    if train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            f"train_ratio + val_ratio must not exceed 1, got {train_ratio + val_ratio}"
        )

    rng = random.Random(seed)

    # Group by wave_type and direction
    buckets = defaultdict(list)
    for r in items:
        key = (r["wave_type"], r["direction"])
        buckets[key].append(r)

    train, val, test = [], [], []

    print(f"Stratifying by {len(buckets)} unique (wave_type, direction) combinations:")

    for key, bucket in buckets.items():
        # Buckets are fresh lists, so shuffling them leaves `items` untouched.
        rng.shuffle(bucket)
        n = len(bucket)

        # Calculate split sizes
        n_train = int(round(n * train_ratio))
        n_val = int(round(n * val_ratio))

        # Independent rounding can overshoot the bucket size; clamp so the
        # three slices always partition the bucket.
        n_train = min(n_train, n)
        n_val = min(n_val, n - n_train)
        n_test = n - n_train - n_val

        print(f"  {key}: {n} total → {n_train} train, {n_val} val, {n_test} test")

        train.extend(bucket[:n_train])
        val.extend(bucket[n_train:n_train + n_val])
        test.extend(bucket[n_train + n_val:])

    # Final shuffle so records are not ordered by stratification group
    rng.shuffle(train)
    rng.shuffle(val)
    rng.shuffle(test)

    return train, val, test

## Load Data and Create Splits

In [None]:
# Load the real dataset index; fall back to synthetic records if it is missing
try:
    items = read_jsonl(REAL_INDEX)
    print(f"✓ Loaded {len(items)} items from {REAL_INDEX}")
except FileNotFoundError:
    print(f"✗ Index file not found: {REAL_INDEX}")
    print("Please run the 04_build_real_index.ipynb notebook first.")
    # Create dummy data for demonstration
    print("Creating dummy data for demonstration...")
    print(f"⚠️ Later cells will write splits built from this dummy data to {OUT_DIR}")
    # This cell runs before set_seed(), so draw from a dedicated generator
    # seeded with SEED; otherwise the dummy records differ on every run.
    dummy_rng = np.random.default_rng(SEED)
    items = [
        {
            "image_path": f"data/real/images/img_{i:03d}.jpg",
            "height_meters": float(dummy_rng.uniform(0.5, 2.5)),
            # str()/float() turn NumPy scalars into plain Python values so the
            # records look like ones parsed from JSON.
            "wave_type": str(dummy_rng.choice(["beach_break", "reef_break", "point_break", "closeout", "a_frame"])),
            "direction": str(dummy_rng.choice(["left", "right", "both"])),
            "confidence": str(dummy_rng.choice(["high", "medium", "low"])),
            "notes": f"Example note {i}",
            "data_key": i,
            "source": "real"
        }
        for i in range(100)
    ]

In [None]:
# Summarise the label distribution of the full dataset before it is split
df_full = pd.DataFrame(items)

print("Dataset overview before splitting:")
print(f"Total samples: {len(df_full)}")

# Per-column class counts for the two stratification columns
for heading, column in (("Wave type", "wave_type"), ("Direction", "direction")):
    print(f"\n{heading} distribution:")
    print(df_full[column].value_counts())

# Number of records in every (wave_type, direction) group
strat_keys = df_full.groupby(['wave_type', 'direction']).size()
print("\nStratification groups (wave_type, direction):")
print(strat_keys)

In [None]:
# Seed via set_seed (defined outside this notebook), then build the three splits
set_seed(SEED)

train, val, test = stratified_split(
    items,
    train_ratio=TRAIN_RATIO,
    val_ratio=VAL_RATIO,
    seed=SEED,
)

# Report each split's size and its share of the full dataset
print("\nSplit results:")
for label, subset in (("Train", train), ("Val", val), ("Test", test)):
    print(f"{label}: {len(subset)} samples ({len(subset)/len(items)*100:.1f}%)")
print(f"Total: {len(train) + len(val) + len(test)} samples")

## Save Splits

In [None]:
# Write each split to its own .jsonl path under OUT_DIR
ensure_dir(OUT_DIR)

train_path, val_path, test_path = (
    f"{OUT_DIR}/{stem}.jsonl" for stem in ("train", "val", "test")
)

for records, destination in ((train, train_path), (val, val_path), (test, test_path)):
    write_jsonl(records, destination)

print("✓ Saved splits:")
for label, destination in (("Train", train_path), ("Val", val_path), ("Test", test_path)):
    print(f"  {label}: {destination}")

## Split Analysis and Validation

In [None]:
# One DataFrame per split, each tagged with a 'split' column so the frames
# can be stacked and compared in a single table
df_train = pd.DataFrame(train).assign(split='train')
df_val = pd.DataFrame(val).assign(split='val')
df_test = pd.DataFrame(test).assign(split='test')

# Stack the three frames; the original row indices are discarded
df_combined = pd.concat([df_train, df_val, df_test], ignore_index=True)

In [None]:
# For every class, show the percentage of its samples that landed in each
# split (rows are normalised, so each row sums to 100)
print("Distribution analysis across splits:")

split_labels = df_combined['split']

print("\n1. Wave Type Distribution:")
wave_type_dist = 100 * pd.crosstab(df_combined['wave_type'], split_labels, normalize='index')
print(wave_type_dist.round(1))

print("\n2. Direction Distribution:")
direction_dist = 100 * pd.crosstab(df_combined['direction'], split_labels, normalize='index')
print(direction_dist.round(1))

print("\n3. Combined (Wave Type, Direction) Distribution:")
# Joint key matching the grouping used for stratification; stored on the frame
wave_types = df_combined['wave_type']
directions = df_combined['direction']
df_combined['strat_key'] = wave_types + '_' + directions
combined_dist = 100 * pd.crosstab(df_combined['strat_key'], split_labels, normalize='index')
print(combined_dist.round(1))

In [None]:
# 2x3 grid: one column per split, heights on the top row, wave types below
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
split_frames = [("Train", df_train), ("Val", df_val), ("Test", df_test)]

# Top row: height histogram with the split's mean marked
for ax, (split_name, split_df) in zip(axes[0], split_frames):
    heights = split_df['height_meters']
    mean_height = heights.mean()
    ax.hist(heights, bins=15, alpha=0.7, edgecolor='black')
    ax.set_title(f'{split_name} - Height Distribution')
    ax.set_xlabel('Height (meters)')
    ax.set_ylabel('Frequency')
    ax.axvline(mean_height, color='red', linestyle='--',
               label=f'Mean: {mean_height:.2f}m')
    ax.legend()

# Bottom row: wave type counts per split
wave_type_counts = [split_df['wave_type'].value_counts() for split_name, split_df in split_frames]

for ax, (split_name, split_df), counts in zip(axes[1], split_frames, wave_type_counts):
    ax.bar(counts.index, counts.values)
    ax.set_title(f'{split_name} - Wave Type Distribution')
    ax.set_xlabel('Wave Type')
    ax.set_ylabel('Count')
    ax.tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

In [None]:
# Compare the splits statistically: KS test on heights, chi-square on labels
from scipy.stats import ks_2samp, chi2_contingency

print("Statistical tests for split similarity:")

# Two-sample Kolmogorov-Smirnov test on wave heights for every pair of splits
train_heights = df_train['height_meters']
val_heights = df_val['height_meters']
test_heights = df_test['height_meters']

ks_train_val = ks_2samp(train_heights, val_heights)
ks_train_test = ks_2samp(train_heights, test_heights)
ks_val_test = ks_2samp(val_heights, test_heights)

print("\nHeight distribution similarity (KS test p-values):")
for pair_label, ks_result in (
    ("Train vs Val", ks_train_val),
    ("Train vs Test", ks_train_test),
    ("Val vs Test", ks_val_test),
):
    print(f"{pair_label}: {ks_result.pvalue:.4f}")
print("(Higher p-values indicate more similar distributions)")

# Chi-square test on the class-by-split count tables
split_labels = df_combined['split']

wave_type_contingency = pd.crosstab(df_combined['wave_type'], split_labels)
chi2_wt, p_wt, _, _ = chi2_contingency(wave_type_contingency)

direction_contingency = pd.crosstab(df_combined['direction'], split_labels)
chi2_dir, p_dir, _, _ = chi2_contingency(direction_contingency)

print("\nCategorical distribution similarity (Chi-square test p-values):")
for test_label, p_value in (("Wave type across splits", p_wt), ("Direction across splits", p_dir)):
    print(f"{test_label}: {p_value:.4f}")
print("(Higher p-values indicate more similar distributions)")

In [None]:
# Per-split report: size, height statistics and label counts
print("Detailed Split Summary:")
print("=" * 50)

named_splits = (("TRAIN", df_train), ("VALIDATION", df_val), ("TEST", df_test))
label_columns = (
    ("Wave types", 'wave_type'),
    ("Directions", 'direction'),
    ("Confidence levels", 'confidence'),
)

for split_name, split_df in named_splits:
    heights = split_df['height_meters']

    print(f"\n{split_name} SET:")
    print(f"  Samples: {len(split_df)}")
    print(f"  Height range: {heights.min():.2f}m - {heights.max():.2f}m")
    print(f"  Height mean: {heights.mean():.2f}m ± {heights.std():.2f}m")

    # Class counts for each categorical column, printed as a plain dict
    for label, column in label_columns:
        print(f"  {label}: {dict(split_df[column].value_counts())}")

## Validation Checks

In [None]:
# Validation checks
print("Split Validation Checks:")
print("=" * 30)

# Check no overlap between splits, using image_path as the record identity
train_paths = set(r['image_path'] for r in train)
val_paths = set(r['image_path'] for r in val)
test_paths = set(r['image_path'] for r in test)

overlap_train_val = train_paths & val_paths
overlap_train_test = train_paths & test_paths
overlap_val_test = val_paths & test_paths

if not overlap_train_val and not overlap_train_test and not overlap_val_test:
    print("✓ No overlap between splits")
else:
    print("✗ Overlap detected between splits!")
    if overlap_train_val:
        print(f"  Train-Val overlap: {len(overlap_train_val)} samples")
    if overlap_train_test:
        print(f"  Train-Test overlap: {len(overlap_train_test)} samples")
    if overlap_val_test:
        print(f"  Val-Test overlap: {len(overlap_val_test)} samples")

# Check total count: every input record must be in exactly one split
total_split = len(train) + len(val) + len(test)
if total_split == len(items):
    print(f"✓ All {len(items)} samples accounted for in splits")
else:
    print(f"✗ Sample count mismatch: {total_split} in splits vs {len(items)} original")

# Check minimum samples per class in each split.
# value_counts() only lists classes that occur in the split, so on its own it
# can never report a class that is missing entirely. Reindex the counts against
# every class seen in any split so that absent classes show up with a count of 0.
min_samples_per_class = 1
split_frames = [("train", df_train), ("val", df_val), ("test", df_test)]
all_wave_types = pd.concat([frame['wave_type'] for _, frame in split_frames]).dropna().unique()
all_directions = pd.concat([frame['direction'] for _, frame in split_frames]).dropna().unique()

for split_name, split_df in split_frames:
    wt_counts = split_df['wave_type'].value_counts().reindex(all_wave_types, fill_value=0)
    dir_counts = split_df['direction'].value_counts().reindex(all_directions, fill_value=0)

    if wt_counts.min() >= min_samples_per_class and dir_counts.min() >= min_samples_per_class:
        print(f"✓ {split_name.capitalize()} split has sufficient samples per class")
    else:
        print(f"⚠️ {split_name.capitalize()} split has classes with few samples:")
        if wt_counts.min() < min_samples_per_class:
            print(f"  Wave types: {wt_counts[wt_counts < min_samples_per_class].to_dict()}")
        if dir_counts.min() < min_samples_per_class:
            print(f"  Directions: {dir_counts[dir_counts < min_samples_per_class].to_dict()}")

print(f"\n✓ Splits saved successfully to {OUT_DIR}/")

## Export Split Statistics

In [None]:
# Save split statistics
# json is not imported in this notebook's import cell; import it here so the
# cell does not depend on another notebook having leaked it into the namespace.
import json

# Sizes, realised ratios, height summaries and label counts for each split,
# plus the seed and the number of (wave_type, direction) groups.
split_stats = {
    "total_samples": len(items),
    "split_sizes": {
        "train": len(train),
        "val": len(val),
        "test": len(test)
    },
    "split_ratios": {
        "train": len(train) / len(items),
        "val": len(val) / len(items),
        "test": len(test) / len(items)
    },
    "height_stats": {
        "train": df_train['height_meters'].describe().to_dict(),
        "val": df_val['height_meters'].describe().to_dict(),
        "test": df_test['height_meters'].describe().to_dict()
    },
    "wave_type_distribution": {
        "train": df_train['wave_type'].value_counts().to_dict(),
        "val": df_val['wave_type'].value_counts().to_dict(),
        "test": df_test['wave_type'].value_counts().to_dict()
    },
    "direction_distribution": {
        "train": df_train['direction'].value_counts().to_dict(),
        "val": df_val['direction'].value_counts().to_dict(),
        "test": df_test['direction'].value_counts().to_dict()
    },
    "stratification_groups": len(strat_keys),
    "seed": SEED
}

stats_path = f"{OUT_DIR}/split_statistics.json"
with open(stats_path, 'w') as f:
    json.dump(split_stats, f, indent=2)

print(f"Split statistics saved to: {stats_path}")