# Split Training Data

Load the filtered, self-contained AO3 fragments, split them into train/validation/test sets (80/10/10) with a fixed random seed, and save each split as a JSONL file for training.

## Setup

In [1]:
import sys
sys.path.append('..')

import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from config import PROCESSED_DATA_DIR

print("✓ Imports loaded")

✓ Imports loaded


## Load Filtered Fragments

In [None]:
# Load the filtered, self-contained fragments from CSV.
fragments_path = PROCESSED_DATA_DIR / "ao3_fragments_selfcontained.csv"
df = pd.read_csv(fragments_path)
# NOTE(review): pandas' default NA parsing turns empty fields and literal strings such as
# 'NA' or 'null' into NaN. If a fragment can legitimately be one of those strings,
# consider read_csv(..., keep_default_na=False) for this file.

# Fail fast on schema drift: the JSONL export below reads the `text` column.
assert "text" in df.columns, f"{fragments_path} has no 'text' column; found {list(df.columns)}"

## Split Data (80/10/10)

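Split into train, validation, and test sets in two stages. First, hold out 20% of the fragments. Then halve that holdout into validation and test, giving 10% each of the original. `random_state=42` makes the split reproducible.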

In [9]:
# Split parameters, named so they are easy to find and tune.
HOLDOUT_FRAC = 0.2          # fraction of all fragments held out for validation + test
TEST_FRAC_OF_HOLDOUT = 0.5  # share of the holdout that becomes test (0.5 -> 10% / 10%)
SPLIT_SEED = 42             # fixed seed so the split is reproducible

# First split: 80% train, 20% holdout
train_df, valtest_df = train_test_split(df, test_size=HOLDOUT_FRAC, random_state=SPLIT_SEED)

# Second split: halve the holdout into validation and test (10% each of original)
val_df, test_df = train_test_split(
    valtest_df, test_size=TEST_FRAC_OF_HOLDOUT, random_state=SPLIT_SEED
)

# NOTE(review): this is a row-level random split. If several fragments come from the same
# work, they can land in different splits (train/test leakage). Consider a grouped split
# (e.g. sklearn GroupShuffleSplit) keyed on a work identifier — TODO confirm such a column exists.

# The three splits must partition the input: no row lost, none shared between splits.
assert len(train_df) + len(val_df) + len(test_df) == len(df)
assert train_df.index.intersection(val_df.index).empty
assert train_df.index.intersection(test_df.index).empty
assert val_df.index.intersection(test_df.index).empty

print(f"Total fragments: {len(df):,}")
print(f"\nTrain: {len(train_df):,} ({len(train_df)/len(df)*100:.1f}%)")
print(f"Validation: {len(val_df):,} ({len(val_df)/len(df)*100:.1f}%)")
print(f"Test: {len(test_df):,} ({len(test_df)/len(df)*100:.1f}%)")

Total fragments: 133,043

Train: 106,434 (80.0%)
Validation: 13,304 (10.0%)
Test: 13,305 (10.0%)


## Format for Training (JSONL)

Write each split as JSONL: one JSON object per line, each with a single `text` field (e.g. `{"text": "..."}`). All other columns are dropped.

In [None]:
import json

def save_jsonl(df, path):
    """Save a dataframe's ``text`` column as JSONL.

    Each row becomes one line of the form ``{"text": ...}``; every other
    column is dropped.

    Args:
        df: DataFrame with a ``text`` column.
        path: Destination file path (str or ``Path``); overwritten if it exists.

    Raises:
        ValueError: If any ``text`` value is missing (NaN/None). Without this
            check ``json.dumps`` would write the bare token ``NaN`` (invalid
            JSON) or ``null``, neither of which is a usable example. The check
            runs before the file is opened, so no partial file is left behind.
    """
    missing = df['text'].isna()
    if missing.any():
        raise ValueError(
            f"{int(missing.sum()):,} row(s) have a missing 'text' value; "
            f"drop or fill them before writing {path}"
        )

    # json.dumps escapes non-ASCII by default, so the content is plain ASCII;
    # utf-8 is stated explicitly so the file never depends on the platform default.
    with open(path, 'w', encoding='utf-8') as f:
        for text in df['text']:
            f.write(json.dumps({"text": text}) + '\n')
    print(f"Saved {len(df):,} examples to {path}")

# Write one JSONL file per split. The paths stay bound as train_path / val_path /
# test_path because the preview cell below reads train_path back.
train_path = PROCESSED_DATA_DIR / "train.jsonl"
val_path = PROCESSED_DATA_DIR / "val.jsonl"
test_path = PROCESSED_DATA_DIR / "test.jsonl"

splits_to_write = [
    (train_df, train_path),
    (val_df, val_path),
    (test_df, test_path),
]
for split_frame, split_path in splits_to_write:
    save_jsonl(split_frame, split_path)

print("\n✓ Data splits saved!")

## Preview JSONL Files

In [None]:
# Print the first five records of the training file as a quick sanity check.
print("First 5 lines of train.jsonl:")
print("=" * 80)

with open(train_path, 'r') as jsonl_file:
    # zip stops after five records without reading further into the file.
    for line_no, raw_line in zip(range(1, 6), jsonl_file):
        record = json.loads(raw_line)
        print(f"{line_no}. {record['text']}")
        print()