# Expand Seed Pool + Per-Target Swap Discovery

Trains new vanilla GRU seeds (configurable range) on Colab GPU, caches their validation
predictions, then discovers the best 2-slot per-target swap candidates and builds ONNX
submission zips.

**Required mirror layout (Google Drive):**
- `wunderfund_mirror/train.parquet`
- `wunderfund_mirror/valid.parquet`
- `wunderfund_mirror/vanilla_all/gru_parity_v1_seed42.pt` ... `wunderfund_mirror/vanilla_all/gru_parity_v1_seed64.pt` (one checkpoint per existing seed, 42-64)

**Workflow:** Run cells 1-9 in order. Cell 2 (Config) is the only one to edit between sessions.

In [None]:
from google.colab import drive
from pathlib import Path
import os, shutil, subprocess, sys, re, time, json

# Mount Google Drive at /content/drive so files stored on Drive are reachable
# through the local filesystem for the rest of the notebook.
drive.mount('/content/drive')
print('Drive mounted')

In [None]:
# =============================================================
# CONFIG  -  only cell you need to edit between sessions
# =============================================================

# Git repository to clone, the branch to use, and the local checkout path.
REPO_URL = 'https://github.com/vincentvdo6/competition_package.git'
BRANCH   = 'master'
REPO     = Path('/content/competition_package')
# Google Drive folder that holds the datasets and the checkpoint mirror.
MIRROR   = Path('/content/drive/MyDrive/wunderfund_mirror')

# Seeds already trained and stored in mirror
EXISTING_SEEDS = list(range(42, 65))          # 23 seeds: 42-64

# New seeds to train this session (set to [] to skip training)
NEW_SEEDS = list(range(65, 85))               # 20 new seeds: 65-84

# Full pool, de-duplicated in first-seen order so that a seed listed in both
# lists is not counted twice.
ALL_SEEDS = list(dict.fromkeys(EXISTING_SEEDS + NEW_SEEDS))

# Current best anchor ensemble (per-target weights)
ANCHOR_SEEDS = [43, 44, 45, 46, 50, 54, 55, 57, 58, 59, 60, 61, 63, 64]
ANCHOR_W0    = [1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0]             # t0 weights
ANCHOR_W1    = [1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0.25, 0, 1.75]       # t1 weights

# The three anchor lists are passed to the discovery script side by side and
# are presumably matched by position (weight i <-> seed i), so a length
# mismatch is reported here instead of surfacing later in the session.
if not (len(ANCHOR_SEEDS) == len(ANCHOR_W0) == len(ANCHOR_W1)):
    raise ValueError(
        'ANCHOR_SEEDS, ANCHOR_W0 and ANCHOR_W1 must have the same length, got '
        f'{len(ANCHOR_SEEDS)}, {len(ANCHOR_W0)} and {len(ANCHOR_W1)}'
    )

# Session tag used in output file names
SESSION_TAG  = 'feb22-b1'
SLOT_A_NAME  = f'{SESSION_TAG}-t1swap-a-onnx.zip'
SLOT_B_NAME  = f'{SESSION_TAG}-t0swap-b-onnx.zip'

# Cache and discovery settings
CACHE_DIR    = 'cache/all_seeds_valid_preds'     # relative to the repo root
REPORT_PATH  = f'logs/parity_swap_discovery_{SESSION_TAG}.json'
BOOTSTRAP    = 200        # forwarded as --bootstrap to the discovery script
DELTA_MIN    = 0.00015    # forwarded as --delta-min
P10_MAX_DROP = 0.00005    # forwarded as --p10-max-drop


def describe_seeds(seeds):
    """Return ``'<count> (<first>-<last>)'`` for *seeds*, or ``'0 (none)'`` if empty."""
    if not seeds:
        return '0 (none)'
    return f'{len(seeds)} ({seeds[0]}-{seeds[-1]})'


print(f'Existing seeds : {describe_seeds(EXISTING_SEEDS)}')
print(f'New seeds      : {describe_seeds(NEW_SEEDS)}')
print(f'Total pool     : {len(ALL_SEEDS)} seeds')
print(f'Session        : {SESSION_TAG}')

In [None]:
# Clone / refresh repo
# Step out of the checkout before deleting it: after an earlier run of this
# cell the working directory is REPO itself, and removing the current working
# directory leaves the following git command without a valid directory to
# start from.
REPO.parent.mkdir(parents=True, exist_ok=True)
os.chdir(REPO.parent)
if REPO.exists():
    shutil.rmtree(REPO)

subprocess.run(['git', 'clone', '--branch', BRANCH, REPO_URL, str(REPO)], check=True)
# Later cells use repo-relative paths, so make the checkout the working directory.
os.chdir(REPO)
# Right after a fresh clone this should be a no-op; kept as a refresh step.
subprocess.run(['git', 'pull', 'origin', BRANCH], check=True)

# Fail early if the checkout lacks any script the later cells invoke.
for script in [
    'scripts/train.py',
    'scripts/greedy_vanilla_ensemble.py',
    'scripts/discover_parity_swaps.py',
    'scripts/check_submission_zip.py',
]:
    assert (REPO / script).exists(), f'Missing script: {script}'

print('Repo ready:', REPO)

In [None]:
# Pull the datasets and the previously trained checkpoints out of the mirror
datasets_dir = REPO / 'datasets'
ckpt_local   = REPO / 'logs' / 'vanilla_all'
ckpt_mirror  = MIRROR / 'vanilla_all'
for target_dir in (datasets_dir, ckpt_local):
    target_dir.mkdir(parents=True, exist_ok=True)

# Datasets: both files are mandatory, a missing one aborts the cell
for fname in ('train.parquet', 'valid.parquet'):
    src = MIRROR / fname
    assert src.exists(), f'Missing from mirror: {src}'
    shutil.copy2(src, datasets_dir / fname)
    print(f'Copied {fname}')

# Existing checkpoints: a missing one is only reported, then skipped
copied_seeds = []
for seed in EXISTING_SEEDS:
    ckpt_src = ckpt_mirror / f'gru_parity_v1_seed{seed}.pt'
    if not ckpt_src.exists():
        print(f'  WARN: missing existing checkpoint seed{seed} in mirror')
        continue
    shutil.copy2(ckpt_src, ckpt_local / ckpt_src.name)
    copied_seeds.append(seed)

n_copied = len(copied_seeds)
print(f'Copied {n_copied}/{len(EXISTING_SEEDS)} existing checkpoints')

In [None]:
# Train new seeds  (skips any seed already in local logs/ or mirror)
ckpt_mirror = MIRROR / 'vanilla_all'
ckpt_mirror.mkdir(parents=True, exist_ok=True)
ckpt_local  = REPO / 'logs' / 'vanilla_all'
# The mirror -> local copy below needs this directory to exist even when the
# checkpoint-copy cell was not run in this session.
ckpt_local.mkdir(parents=True, exist_ok=True)

# Seeds grouped by outcome; reported in the summary at the end of the cell.
skipped, trained, failed = [], [], []

for seed in NEW_SEEDS:
    ckpt_name   = f'gru_parity_v1_seed{seed}.pt'
    local_ckpt  = ckpt_local  / ckpt_name
    mirror_ckpt = ckpt_mirror / ckpt_name

    # NOTE(review): if train.py can leave a partial checkpoint at local_ckpt
    # when it fails, a re-run of this cell will treat it as done -- confirm.
    if local_ckpt.exists():
        print(f'seed{seed}: already local -- skip')
        skipped.append(seed)
        continue
    if mirror_ckpt.exists():
        shutil.copy2(mirror_ckpt, local_ckpt)
        print(f'seed{seed}: found in mirror -- copied')
        skipped.append(seed)
        continue

    print(f'\n{"="*60}\nTraining seed {seed} ...\n{"="*60}')
    t_start = time.time()
    cmd = [
        sys.executable, '-u', 'scripts/train.py',
        '--config', 'configs/gru_parity_v1.yaml',
        '--seed', str(seed),
    ]
    # The context manager closes the output pipe and waits for the child even
    # if streaming is interrupted. cwd=REPO keeps the relative script/config
    # paths valid regardless of the notebook's current directory, matching the
    # REPO-anchored checkpoint paths checked above.
    with subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True,
        bufsize=1, cwd=REPO,
    ) as proc:
        # Stream the trainer's combined stdout/stderr line by line.
        for line in proc.stdout:
            print(line, end='', flush=True)
    rc = proc.returncode
    elapsed = time.time() - t_start

    if rc != 0:
        print(f'  ERROR: seed{seed} failed (exit {rc})')
        failed.append(seed)
        continue

    # A zero exit code is not trusted on its own: the checkpoint must exist
    # locally before it is copied to the mirror.
    if local_ckpt.exists():
        shutil.copy2(local_ckpt, mirror_ckpt)
        print(f'  Saved to mirror [{elapsed:.0f}s]')
        trained.append(seed)
    else:
        print(f'  WARN: checkpoint not found after training: {local_ckpt}')
        failed.append(seed)

print(f'\nSummary: skipped={len(skipped)}, trained={len(trained)}, failed={len(failed)}')
if failed:
    print(f'Failed seeds: {failed}')

In [None]:
# Cache validation predictions for all available checkpoints
all_ckpts = sorted((REPO / 'logs' / 'vanilla_all').glob('gru_parity_v1_seed*.pt'))
print(f'Found {len(all_ckpts)} checkpoints to cache')
# Without any checkpoint the command below would carry an empty --checkpoints
# list; stop here with a clear message instead.
assert all_ckpts, 'No checkpoints found in logs/vanilla_all -- run the copy/train cells first'

cmd = [
    sys.executable, 'scripts/greedy_vanilla_ensemble.py', 'cache',
    '--checkpoints', *[str(c) for c in all_ckpts],
    '--data', 'datasets/valid.parquet',
    '--cache-dir', CACHE_DIR,
]
# The context manager closes the output pipe and waits for the child.
# cwd=REPO resolves the relative script, data and cache paths against the
# checkout rather than whatever the notebook's current directory is.
with subprocess.Popen(
    cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True,
    bufsize=1, cwd=REPO,
) as proc:
    # Stream the script's combined stdout/stderr line by line.
    for line in proc.stdout:
        print(line, end='', flush=True)
rc = proc.returncode
print(f'\nCache exit code: {rc}')
assert rc == 0, 'Cache step failed'

In [None]:
# Swap discovery: LOO + candidate swaps + bootstrap delta + build ONNX zips

# Determine which seeds were actually cached (handles any training failures)
cache_path = REPO / CACHE_DIR
seed_pattern = re.compile(r'seed(\d+)')
seed_matches = (
    seed_pattern.search(f.stem)
    for f in cache_path.glob('gru_parity_v1_seed*.npz')
)
# Files whose name carries no numeric seed are skipped instead of crashing on
# a failed match; the set keeps each seed once.
cached_seeds = sorted({int(m.group(1)) for m in seed_matches if m})
print(f'Cached seeds ({len(cached_seeds)}): {cached_seeds}')

# Verify all anchor seeds are present in cache
missing_anchor = [s for s in ANCHOR_SEEDS if s not in cached_seeds]
assert not missing_anchor, f'Anchor seeds missing from cache: {missing_anchor}'

cmd = [
    sys.executable, 'scripts/discover_parity_swaps.py',
    '--cache-dir',      CACHE_DIR,
    '--data',           'datasets/valid.parquet',
    '--required-seeds', *[str(s) for s in cached_seeds],
    '--anchor-seeds',   *[str(s) for s in ANCHOR_SEEDS],
    '--anchor-w0',      *[str(w) for w in ANCHOR_W0],
    '--anchor-w1',      *[str(w) for w in ANCHOR_W1],
    '--delta-min',      str(DELTA_MIN),
    '--p10-max-drop',   str(P10_MAX_DROP),
    '--bootstrap',      str(BOOTSTRAP),
    '--output-report',  REPORT_PATH,
    '--output-dir',     'submissions/ready',
    '--slot-a-name',    SLOT_A_NAME,
    '--slot-b-name',    SLOT_B_NAME,
]
# The context manager closes the output pipe and waits for the child.
# cwd=REPO resolves the relative script, cache, data and output paths against
# the checkout, the same root used for cache_path above.
with subprocess.Popen(
    cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True,
    bufsize=1, cwd=REPO,
) as proc:
    # Stream the script's combined stdout/stderr line by line.
    for line in proc.stdout:
        print(line, end='', flush=True)
rc = proc.returncode
print(f'\nDiscovery exit code: {rc}')
assert rc == 0, 'Swap discovery failed'

In [None]:
# Run the submission checker over every zip built for this session
import glob as _glob

zip_pattern = f'submissions/ready/{SESSION_TAG}-*.zip'
zips = _glob.glob(zip_pattern)
zips.sort()
assert zips, f'No zips found for session {SESSION_TAG}'

# check=True makes the first zip that fails the checker abort the cell.
checker_cmd = [sys.executable, 'scripts/check_submission_zip.py']
for z in zips:
    print(f'\n=== {z} ===')
    subprocess.run(checker_cmd + [z], check=True)

# Final listing of everything this session produced.
print(f'\nArtifacts:')
print(f'  Report : {REPORT_PATH}')
for z in zips:
    print(f'  Zip    : {z}')

In [None]:
# Load the discovery report and print the anchor score plus the best single seeds
report_file = REPO / REPORT_PATH
assert report_file.exists(), f'Report not found: {report_file}'
with open(report_file) as report_fh:
    report = json.load(report_fh)

# Anchor ensemble: overall and per-target scores, then its seed list.
anchor_sc = report['anchor']['score']
print(f'ANCHOR  avg={anchor_sc["avg"]:.4f}  t0={anchor_sc["t0"]:.4f}  t1={anchor_sc["t1"]:.4f}')
print(f'  seeds={report["anchor"]["seeds"]}')
print()

# Single-seed table, highest average first; anchor members get a trailing '*'.
print('TOP 15 SINGLE SEEDS (by avg val score):')
print(f'{"Seed":>6}  {"avg":>8}  {"t0":>8}  {"t1":>8}')
print('-' * 38)
ranked_rows = sorted(report['single_seed_ranking'], key=lambda r: -r['avg'])
for row in ranked_rows[:15]:
    if row['seed'] in report['anchor']['seeds']:
        marker = ' *'
    else:
        marker = ''
    print(f'  {row["seed"]:>4}   {row["avg"]:>8.4f}  {row["t0"]:>8.4f}  {row["t1"]:>8.4f}{marker}')
print('  (* = in anchor)')
print()

def show_slot(label, slot, zip_name):
    """Print a summary of one selected swap slot from the discovery report.

    Args:
        label: Heading printed before the details.
        slot: Mapping with the keys ``replace_seed``, ``candidate_seed``,
            ``delta_avg``, ``delta_t0``, ``delta_t1``, ``keep_gate`` and
            ``bootstrap_delta_vs_anchor`` (itself holding ``mean``, ``p10``
            and ``p50``). A falsy value (``None`` or ``{}``) is reported as
            "no swap candidate" instead of raising.
        zip_name: File name of the slot's zip under ``submissions/ready/``.
    """
    # Guard for a report that carries no entry for this slot, so the summary
    # cell still finishes and shows the remaining slot.
    if not slot:
        print(f'{label}: no swap candidate in report')
        return

    boot = slot['bootstrap_delta_vs_anchor']
    gate = 'PASS' if slot['keep_gate'] else 'FAIL'
    print(f'{label}:')
    print(f'  Replace seed {slot["replace_seed"]} -> candidate seed {slot["candidate_seed"]}')
    print(f'  delta_avg={slot["delta_avg"]:+.5f}  delta_t0={slot["delta_t0"]:+.5f}  delta_t1={slot["delta_t1"]:+.5f}')
    print(f'  bootstrap  mean={boot["mean"]:+.5f}  p10={boot["p10"]:+.5f}  p50={boot["p50"]:+.5f}  gate={gate}')
    print(f'  zip: submissions/ready/{zip_name}')

# Print both selected slots, with a blank line between them
for position, (label, slot_key, zip_name) in enumerate((
    ('SLOT A (t1 swap)', 'slot_a_t1', SLOT_A_NAME),
    ('SLOT B (t0 swap)', 'slot_b_t0', SLOT_B_NAME),
)):
    if position:
        print()
    show_slot(label, report['selected'][slot_key], zip_name)