In [None]:
import pickle
import gzip
import random
import os
import shutil
from pathlib import Path
from dataclasses import dataclass
from datetime import datetime
from typing import List, Dict, Union, Tuple

# --- CONFIGURATION ---
# Input filenames (pointing to the original full-scale data).
# Both are pickle files opened relative to the current working directory.
# Loaded into `race_details` by the loading cell.
unified_grammar_input = "race_specific_simplified_grammar.pickle"
# Loaded into `runner_data`, the mapping iterated by the processing cell.
cleaned_runner_input = "runner_data_with_full_grammar_token_and_race_details_pared_down.pickle"

# Production parameters
# Number of split files; every kept runner gets a random index in [0, num_splits - 1].
num_splits = 100
# Exclusive upper bound, in weeks (5 years), on how far before the target race
# an earlier race may lie and still be included in a training example.
max_weeks_between_races = 52 * 5  # Max history to look back (5 years)
# Number of kept runners (those with at least one training example) buffered in
# memory before they are appended to the split files.
flush_size = 5000  # Flush to disk every 5000 runners to keep RAM low

# Output directory. It is deleted and recreated at the start of the processing cell.
splits_output_dir = "training_splits"

@dataclass
class RaceData:
    """Per-race record: identifier, distance/condition tokens and start time.

    NOTE(review): presumably this is the type of the ``raceDetails`` objects
    stored in the input pickles (the attribute names match what the processing
    code reads) — confirm against the code that produced those files.
    """
    race_id: str
    distance_token: str
    # The five vc_* fields are emitted, in this order, into each race's token
    # sequence. NOTE(review): they look like weather tokens; the meaning of the
    # "vc_" prefix is not explained in this file.
    vc_conditions_token: str
    vc_humidity_token: str
    vc_temperature_token: str
    vc_feels_like_token: str
    vc_wind_speed_token: str
    # Used by the processing cell to order a runner's races and to compute the
    # week gaps between them.
    start_date_time: datetime

@dataclass
class TrainingExample:
    """One target race together with the earlier races that form its history."""
    # Flat token list, oldest race first and the target race last. Each race
    # contributes the eleven tokens built by
    # convert_single_race_to_single_race_sequence().
    unpadded_example_sequence: List[str]
    # Pace of the target (newest) race in seconds; None when the pace string
    # could not be parsed by pace_to_seconds().
    actual_pace_seconds: Union[int, None]
    # One (distance_token, weeks_to_final_race_token, pace_seconds) tuple per
    # race, in the same oldest -> newest order as the token sequence.
    raw_pace_data: List[tuple]

@dataclass
class RunnerForTraining:
    """All training examples of one runner plus the split they are written to."""
    # The runner's key in the input runner mapping.
    # NOTE(review): the name suggests a (name, gender, dedup_int) tuple — confirm.
    name_gender_dedup_int: tuple
    # One example per race of this runner that has usable earlier history.
    training_examples: List[TrainingExample]
    # Index of the split file this runner is appended to by
    # append_to_split_files(); must be in range(n_splits) there.
    split_assignment: int

In [None]:
def append_to_split_files(runners_batch, output_dir, n_splits, base_name="runners_split", compress=True):
    """Append a batch of runners to the pickle file of each runner's split.

    Args:
        runners_batch: Iterable of objects exposing a ``split_assignment``
            attribute that is a valid index in ``range(n_splits)``.
        output_dir: Directory that receives the split files; created
            (including parents) when missing.
        n_splits: Total number of splits.
        base_name: Filename prefix; files are named
            ``{base_name}_{split:03d}.pkl`` plus ``.gz`` when compressed.
        compress: Write through ``gzip`` when true, plain binary otherwise.

    Raises:
        KeyError: If a runner's ``split_assignment`` is not in
            ``range(n_splits)``.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    # Group the batch by split; the dict is pre-populated so that an unknown
    # split index fails loudly instead of silently creating a new file.
    grouped = {split_idx: [] for split_idx in range(n_splits)}
    for runner in runners_batch:
        grouped[runner.split_assignment].append(runner)

    opener = gzip.open if compress else open

    for split_idx, members in grouped.items():
        if not members:
            # Nothing for this split in the current batch: do not touch its file.
            continue

        filename = f"{base_name}_{split_idx:03d}.pkl" + (".gz" if compress else "")
        destination = target_dir / filename

        # Append mode: each call adds one more pickled list to the file, so a
        # reader has to call pickle.load() repeatedly until EOFError.
        with opener(destination, "ab") as sink:
            pickle.dump(members, sink, protocol=pickle.HIGHEST_PROTOCOL)

def pace_to_seconds(pace_str):
    """Convert an ``"M:SS"`` pace string to a whole number of seconds.

    Args:
        pace_str: Pace as text with exactly two colon-separated integer
            fields (minutes, then seconds), e.g. ``"7:30"``.

    Returns:
        ``minutes * 60 + seconds`` as an int, or ``None`` when the value is
        missing, is not a string, or does not consist of exactly two integer
        fields.
    """
    # Non-string values (None, NaN floats, ints, bytes, ...) carry no parseable
    # pace. Rejecting them here also keeps the ``in`` test from raising
    # TypeError on non-iterable inputs.
    if not isinstance(pace_str, str) or ":" not in pace_str:
        return None
    try:
        minutes, seconds = map(int, pace_str.split(":"))
    except ValueError:
        # Non-numeric field, or a field count other than two (e.g. "1:02:03").
        return None
    return minutes * 60 + seconds

def convert_single_race_to_single_race_sequence(raw_race_instance):
    """Flatten one race record into its ordered list of eleven tokens.

    Args:
        raw_race_instance: Mapping with the keys ``age_token``,
            ``gender_token``, ``weeks_to_next_race_token``,
            ``weeks_to_final_race_token``, ``paceToken`` and ``raceDetails``;
            the latter is an object exposing the ``vc_*`` and
            ``distance_token`` attributes read below.

    Returns:
        A new list: runner tokens, then race-condition and distance tokens,
        then the two week-delta tokens, and finally the pace token.
    """
    details = raw_race_instance["raceDetails"]

    runner_part = [raw_race_instance[key] for key in ("age_token", "gender_token")]
    race_part = [
        getattr(details, attr)
        for attr in (
            "vc_conditions_token",
            "vc_humidity_token",
            "vc_temperature_token",
            "vc_feels_like_token",
            "vc_wind_speed_token",
            "distance_token",
        )
    ]
    timing_and_pace_part = [
        raw_race_instance[key]
        for key in ("weeks_to_next_race_token", "weeks_to_final_race_token", "paceToken")
    ]
    return runner_part + race_part + timing_and_pace_part

def convert_training_examples_to_dataclass(training_examples):
    """Build a ``TrainingExample`` from a newest-first list of race records.

    Args:
        training_examples: Non-empty sequence of race mappings ordered newest
            first, so index 0 is the target race. Each mapping needs the keys
            used by ``convert_single_race_to_single_race_sequence`` plus
            ``pace``.

    Returns:
        A ``TrainingExample`` whose token sequence and raw pace tuples run
        oldest -> newest, and whose ``actual_pace_seconds`` is the target
        race's pace (``None`` if that pace string cannot be parsed).
    """
    # The caller sorts newest -> oldest, so the target race comes first.
    target_pace_seconds = pace_to_seconds(training_examples[0]["pace"])

    token_stream = []
    pace_history = []

    # Walk the history backwards so the output is in chronological order.
    for entry in reversed(training_examples):
        token_stream += convert_single_race_to_single_race_sequence(entry)
        pace_record = (
            entry["raceDetails"].distance_token,
            entry["weeks_to_final_race_token"],
            pace_to_seconds(entry["pace"]),
        )
        pace_history.append(pace_record)

    return TrainingExample(
        raw_pace_data=pace_history,
        actual_pace_seconds=target_pace_seconds,
        unpadded_example_sequence=token_stream,
    )

In [None]:
print("Loading input data...")

# NOTE: pickle.load() can execute arbitrary code embedded in the file, so both
# inputs must come from a trusted source.
with open(unified_grammar_input, "rb") as grammar_file:
    race_details = pickle.load(grammar_file)

# Mapping of runner key -> list of that runner's race records; this is what
# the processing cell iterates over.
with open(cleaned_runner_input, "rb") as runner_file:
    runner_data = pickle.load(runner_file)

print(f"Loaded {len(race_details)} races and {len(runner_data)} runners.")

In [None]:
# Clear and prepare the output directory.
# append_to_split_files() opens its files in append mode, so files left over
# from an earlier run must be removed or their contents would be duplicated.
if os.path.exists(splits_output_dir):
    shutil.rmtree(splits_output_dir)
os.makedirs(splits_output_dir, exist_ok=True)

# Dedicated, seeded RNG: the runner -> split assignment is then identical on
# every "Restart & Run All" (given the same input file), instead of being
# reshuffled on each run by the unseeded global generator.
split_assignment_seed = 42
split_rng = random.Random(split_assignment_seed)

seconds_per_week = 7 * 24 * 3600

batch_cache = []            # runners waiting to be flushed to disk
processed_count = 0         # runners that produced at least one example
total_examples_created = 0

print(f"Processing runners with streaming (batch size: {flush_size})...")

for i, (key, races) in enumerate(runner_data.items()):
    if i % 10000 == 0:
        print(f"Processed {i} runners...")

    # An example needs a target race plus at least one earlier race.
    if len(races) < 2: continue

    # Sort Newest -> Oldest
    races_inv = sorted(races, key=lambda x: x["raceDetails"].start_date_time, reverse=True)

    runner_examples = []
    # Every race except the oldest one serves once as the prediction target.
    for idx in range(len(races_inv) - 1):
        target_start = races_inv[idx]["raceDetails"].start_date_time

        # Shallow copies keep the per-example tokens from leaking into the
        # shared input records (the same race appears in several examples).
        target = races_inv[idx].copy()
        target["weeks_to_next_race_token"] = "week_delta_0"
        target["weeks_to_final_race_token"] = "week_delta_0"
        candidate = [target]

        for j in range(idx + 1, len(races_inv)):
            race_start = races_inv[j]["raceDetails"].start_date_time
            w = int(round(abs(target_start - race_start).total_seconds() / seconds_per_week))

            if w >= max_weeks_between_races:
                # races_inv is sorted newest -> oldest, so every remaining
                # race is at least as far from the target: stop scanning.
                break
            if w <= 0:
                # Same (rounded) week as the target: not usable as history.
                continue

            # NOTE(review): the gap is measured to the immediately newer race
            # in races_inv, even when that race was itself skipped above for
            # falling in the same week as the target — confirm this is intended.
            newer_start = races_inv[j - 1]["raceDetails"].start_date_time
            wn = int(round(abs(newer_start - race_start).total_seconds() / seconds_per_week))

            curr = races_inv[j].copy()
            curr["weeks_to_next_race_token"] = f"week_delta_{wn}"
            curr["weeks_to_final_race_token"] = f"week_delta_{w}"
            candidate.append(curr)

        # Keep the example only if at least one earlier race qualified.
        if len(candidate) > 1:
            runner_examples.append(convert_training_examples_to_dataclass(candidate))

    if runner_examples:
        batch_cache.append(RunnerForTraining(
            name_gender_dedup_int=key,
            training_examples=runner_examples,
            # All examples of a runner share one split, so a runner never
            # straddles two split files.
            split_assignment=split_rng.randint(0, num_splits - 1)
        ))
        processed_count += 1
        total_examples_created += len(runner_examples)

    if len(batch_cache) >= flush_size:
        append_to_split_files(batch_cache, splits_output_dir, num_splits)
        batch_cache = []
        print(f"Flush! {processed_count} runners saved to disk so far.")

# Final flush
if batch_cache:
    append_to_split_files(batch_cache, splits_output_dir, num_splits)

print("\nCOMPLETED.")
print(f"Total Runners with History: {processed_count}")
print(f"Total Training Examples: {total_examples_created}")
print(f"Files saved to: {splits_output_dir}/")