# Prepare Training Data

Creates 3 datasets with train/val/test splits (80/10/10):
1. **Combined**: Balanced (10k one-liners, oversampled with replacement from the ~2k available, + 10k short stories)
2. **One-Liner**: Full dataset (~2k)
3. **Short-Story**: Full dataset (~50k)

## Setup

In [1]:
import pandas as pd
import sys
from pathlib import Path
import json
from sklearn.model_selection import train_test_split

sys.path.append('..')
from config import PROCESSED_DATA_DIR

# Output directory for the JSONL files, given relative to the current working
# directory. Missing parent directories are created, and an already existing
# directory is left as it is.
TRAIN_DIR = Path("..") / "data" / "train"
TRAIN_DIR.mkdir(parents=True, exist_ok=True)

## Load Data

In [2]:
# Read the processed parquet file named below from PROCESSED_DATA_DIR (imported
# from config) into a DataFrame, then report how many rows were loaded.
input_path = PROCESSED_DATA_DIR / "ao3_tifu_enriched_labels_with_instruction.parquet"
df = pd.read_parquet(input_path)
print(f"Loaded {len(df):,} rows")

Loaded 51,337 rows


## Balance Data

In [3]:
# Number of rows drawn from each type for the balanced dataset.
TARGET_SIZE = 10000

# Partition the loaded frame by its 'type' label, copying each slice so the
# per-type frames are independent of `df`.
one_liner_df = df.loc[df['type'].eq('one_liner')].copy()
short_story_df = df.loc[df['type'].eq('short_story')].copy()

# One-liners are drawn with replacement, short stories without; each draw uses
# its own fixed seed, so the two draws do not influence each other.
# NOTE(review): sampling with replace=True can repeat rows, and the combined
# frame is split at random further down, so copies of the same one-liner may
# end up in more than one of train/val/test -- confirm this is acceptable.
one_liner_balanced = one_liner_df.sample(n=TARGET_SIZE, replace=True, random_state=42)
short_story_balanced = short_story_df.sample(n=TARGET_SIZE, random_state=42)

# Stack one-liners first, then shuffle every row with a fixed seed and
# renumber the index from zero.
balanced_df = (
    pd.concat([one_liner_balanced, short_story_balanced], ignore_index=True)
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)

print(f"Balanced: {len(balanced_df):,} rows")

Balanced: 20,000 rows


## Create 3 Datasets

In [4]:
# Take an independent copy of each source frame, in the order: balanced mix,
# all one-liners, all short stories.
combined_df, one_liner_full_df, short_story_full_df = (
    frame.copy() for frame in (balanced_df, one_liner_df, short_story_df)
)

## Split into Train/Val/Test (80/10/10)

In [5]:
def split_dataset(df, test_size=0.1, val_size=0.1, random_state=42):
    """Partition ``df`` into train, validation and test subsets.

    The test subset is carved off first; the validation subset is then taken
    from what remains. With the default sizes this gives roughly 80/10/10.

    Args:
        df: The dataset to split.
        test_size: Share of the whole dataset assigned to the test subset.
        val_size: Share of the whole dataset assigned to the validation subset.
        random_state: Seed passed to both ``train_test_split`` calls.

    Returns:
        A ``(train, val, test)`` tuple.
    """
    remainder, test_part = train_test_split(
        df, test_size=test_size, random_state=random_state
    )
    # val_size is expressed against the full dataset, so rescale it to a
    # fraction of the rows left after the test subset was removed.
    val_fraction = val_size / (1 - test_size)
    train_part, val_part = train_test_split(
        remainder, test_size=val_fraction, random_state=random_state
    )
    return train_part, val_part, test_part

# Split each of the three datasets separately, using split_dataset's default
# sizes and seed.
# NOTE(review): the three splits are made independently and combined_df is
# drawn from the same one-liner / short-story rows, so a row in one dataset's
# test split may appear in another dataset's train split -- keep this in mind
# before evaluating a model across datasets.
combined_train, combined_val, combined_test = split_dataset(combined_df)
one_liner_train, one_liner_val, one_liner_test = split_dataset(one_liner_full_df)
short_story_train, short_story_val, short_story_test = split_dataset(short_story_full_df)

# Report the row count of every split as a quick sanity check.
print(f"Combined: {len(combined_train):,} train / {len(combined_val):,} val / {len(combined_test):,} test")
print(f"One-liner: {len(one_liner_train):,} train / {len(one_liner_val):,} val / {len(one_liner_test):,} test")
print(f"Short-story: {len(short_story_train):,} train / {len(short_story_val):,} val / {len(short_story_test):,} test")

Combined: 15,999 train / 2,001 val / 2,000 test
One-liner: 1,671 train / 210 val / 210 test
Short-story: 39,396 train / 4,925 val / 4,925 test


## Format for Instruction Tuning

In [6]:
def format_for_instruction_tuning(df):
    """Convert a DataFrame into a list of chat-style training examples.

    Each row becomes one dict with a single ``"messages"`` list: a ``"user"``
    message holding the row's ``instruction`` value followed by an
    ``"assistant"`` message holding the row's ``text`` value. Row order is
    preserved and all other columns are ignored.

    Args:
        df: DataFrame with ``instruction`` and ``text`` columns.

    Returns:
        A list of example dicts, one per row (empty for an empty frame).

    Raises:
        KeyError: If ``df`` lacks an ``instruction`` or ``text`` column.
    """
    # Walk the two needed columns in lockstep instead of using iterrows(),
    # which materialises a full Series for every row and is far slower.
    return [
        {
            "messages": [
                {"role": "user", "content": instruction},
                {"role": "assistant", "content": text},
            ]
        }
        for instruction, text in zip(df['instruction'], df['text'])
    ]

## Save JSONL Files

In [7]:
def save_to_jsonl(data, filepath):
    """Write ``data`` to ``filepath`` in JSON Lines format.

    Every item is serialised with ``json.dumps`` and written on its own line.
    An existing file at ``filepath`` is overwritten; an empty ``data`` yields
    an empty file.

    Args:
        data: Iterable of JSON-serialisable objects.
        filepath: Destination path (string or path-like).
    """
    # Pin the encoding and the line terminator so the file has the same bytes
    # on every platform instead of following the locale / OS defaults.
    with open(filepath, 'w', encoding='utf-8', newline='\n') as f:
        f.writelines(json.dumps(example) + '\n' for example in data)

# Every split paired with the file it is written to, in write order:
# combined, then one-liner, then short-story (train / val / test each).
_outputs = [
    (combined_train, "combined_train.jsonl"),
    (combined_val, "combined_val.jsonl"),
    (combined_test, "combined_test.jsonl"),
    (one_liner_train, "one_liner_train.jsonl"),
    (one_liner_val, "one_liner_val.jsonl"),
    (one_liner_test, "one_liner_test.jsonl"),
    (short_story_train, "short_story_train.jsonl"),
    (short_story_val, "short_story_val.jsonl"),
    (short_story_test, "short_story_test.jsonl"),
]

# Format one split at a time and write it out before moving to the next.
for _frame, _filename in _outputs:
    save_to_jsonl(format_for_instruction_tuning(_frame), TRAIN_DIR / _filename)

print("✓ Saved 9 JSONL files to data/train/")

✓ Saved 9 JSONL files to data/train/
