# SSL Pretraining Kill Test
**Goal**: Test masked-feature-reconstruction SSL pretrain + fine-tune on 3 seeds (42, 43, 44).

**Kill gate**: the mean validation score across the 3 seeds must exceed 0.2709 to justify scaling up.

**Two-stage**:
1. SSL pretrain (15 epochs, mask 25% of timesteps, reconstruct features)
2. Fine-tune with gru_parity_v1 recipe (50 epochs, early stopping)

**Pipeline**: Setup -> SSL pretrain x3 -> Fine-tune x3 -> Evaluate -> Save to Drive

In [None]:
# Cell 1: Mount Drive, download data from Kaggle, clone repo
import os, json, subprocess

from google.colab import drive
drive.mount('/content/drive')
os.makedirs('/content/drive/MyDrive/wunderfund', exist_ok=True)

!pip install -q kaggle==1.6.14 --force-reinstall
os.makedirs('/root/.kaggle', exist_ok=True)
with open('/root/.kaggle/kaggle.json', 'w') as f:
    json.dump({"username": "vincentvdo6", "key": "FILL_IN"}, f)
os.chmod('/root/.kaggle/kaggle.json', 0o600)

os.makedirs('/content/data', exist_ok=True)
!kaggle datasets download -d vincentvdo6/wunderfund-predictorium -p /content/data/ --force
!unzip -o -q /content/data/wunderfund-predictorium.zip -d /content/data/
!ls /content/data/*.parquet

# Clone repo
REPO = "/content/competition_package"
os.chdir("/content")
subprocess.run(["rm", "-rf", REPO], check=False)
subprocess.run(["git", "clone", "https://github.com/vincentvdo6/competition_package.git", REPO], check=True)
os.chdir(REPO)
os.makedirs("datasets", exist_ok=True)
os.makedirs("logs", exist_ok=True)

subprocess.run(["ln", "-sf", "/content/data/train.parquet", "datasets/train.parquet"], check=True)
subprocess.run(["ln", "-sf", "/content/data/valid.parquet", "datasets/valid.parquet"], check=True)

assert os.path.exists("datasets/train.parquet"), "train.parquet not found!"
assert os.path.exists("datasets/valid.parquet"), "valid.parquet not found!"

commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"], text=True).strip()
print(f"Commit: {commit}")
print(f"GPU: {subprocess.check_output(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'], text=True).strip()}")
print("Ready!")

In [None]:
# Cell 2: SSL Pretrain — 3 seeds (15 epochs each, ~9 min/seed)
#
# Runs scripts/train_ssl.py once per seed, streaming its output live.
# Seeds whose checkpoint already exists in logs/ are skipped, so the cell
# can be re-run after an interruption. Failures do not abort the loop, but
# they are collected and reported at the end instead of being hidden behind
# an unconditional "complete" banner.
import subprocess, sys, os
os.chdir("/content/competition_package")

SEEDS = [42, 43, 44]
CONFIG = "configs/gru_ssl_pretrain_v1.yaml"

print(f"=== SSL PRETRAIN ({len(SEEDS)} seeds) ===")
print(f"Config: {CONFIG}")
print("=" * 60, flush=True)

failed_seeds = []  # seeds whose run failed or produced no checkpoint

for seed in SEEDS:
    ckpt = f"logs/gru_ssl_pretrain_v1_seed{seed}.pt"
    if os.path.exists(ckpt):
        print(f"seed {seed}: SSL checkpoint exists -- skip")
        continue
    print(f"\n{'='*60}")
    print(f"SSL PRETRAIN seed {seed}")
    print(f"{'='*60}", flush=True)
    # The context manager closes the pipe and reaps the child even if the
    # cell is interrupted while streaming output.
    with subprocess.Popen(
        [sys.executable, "-u", "scripts/train_ssl.py",
         "--config", CONFIG,
         "--seed", str(seed), "--device", "cuda"],
        stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
    ) as proc:
        for line in proc.stdout:
            print(line, end="", flush=True)
    rc = proc.returncode
    if rc != 0:
        print(f"ERROR: seed {seed} failed with rc={rc}")
        failed_seeds.append(seed)
    elif not os.path.exists(ckpt):
        # The skip check above (and the fine-tune cell) key on this path, so
        # a clean exit without the file means downstream steps will skip it.
        print(f"WARNING: seed {seed} exited cleanly but {ckpt} was not created")
        failed_seeds.append(seed)
    else:
        print(f"seed {seed}: SSL pretrain done")

if failed_seeds:
    print(f"\nSSL pretraining finished WITH FAILURES for seeds: {failed_seeds}")
else:
    print("\nSSL pretraining complete!")

In [None]:
# Cell 3: Fine-tune from SSL checkpoints — 3 seeds (50 epochs w/ early stopping)
#
# Runs scripts/train.py once per seed, passing the seed's SSL checkpoint via
# --resume and streaming output live. Seeds without an SSL checkpoint, or
# with an existing fine-tuned checkpoint, are skipped. Skipped-for-missing
# and failed seeds are collected and reported at the end so the final banner
# reflects what actually happened.
import subprocess, sys, os
os.chdir("/content/competition_package")

SEEDS = [42, 43, 44]
FINETUNE_CONFIG = "configs/gru_parity_v1.yaml"

print(f"=== FINE-TUNE FROM SSL ({len(SEEDS)} seeds) ===")
print(f"Config: {FINETUNE_CONFIG}")
print("=" * 60, flush=True)

missing_ssl_seeds = []  # seeds skipped because no SSL checkpoint was found
failed_seeds = []       # seeds whose run failed or produced no checkpoint

for seed in SEEDS:
    ssl_ckpt = f"logs/gru_ssl_pretrain_v1_seed{seed}.pt"
    out_ckpt = f"logs/gru_parity_v1_seed{seed}.pt"
    if not os.path.exists(ssl_ckpt):
        print(f"seed {seed}: SSL checkpoint missing -- skip")
        missing_ssl_seeds.append(seed)
        continue
    if os.path.exists(out_ckpt):
        print(f"seed {seed}: fine-tuned checkpoint exists -- skip")
        continue
    print(f"\n{'='*60}")
    print(f"FINE-TUNE seed {seed} (from {ssl_ckpt})")
    print(f"{'='*60}", flush=True)
    # The context manager closes the pipe and reaps the child even if the
    # cell is interrupted while streaming output.
    with subprocess.Popen(
        [sys.executable, "-u", "scripts/train.py",
         "--config", FINETUNE_CONFIG,
         "--seed", str(seed), "--device", "cuda",
         "--resume", ssl_ckpt],
        stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
    ) as proc:
        for line in proc.stdout:
            print(line, end="", flush=True)
    rc = proc.returncode
    if rc != 0:
        print(f"ERROR: seed {seed} failed with rc={rc}")
        failed_seeds.append(seed)
    elif not os.path.exists(out_ckpt):
        # The skip check above keys on this path, so a clean exit without
        # the file means later steps will not find this seed's checkpoint.
        print(f"WARNING: seed {seed} exited cleanly but {out_ckpt} was not created")
        failed_seeds.append(seed)
    else:
        print(f"seed {seed}: fine-tune done")

if failed_seeds or missing_ssl_seeds:
    print("\nFine-tuning finished WITH PROBLEMS:")
    if missing_ssl_seeds:
        print(f"  skipped (no SSL checkpoint): {missing_ssl_seeds}")
    if failed_seeds:
        print(f"  failed: {failed_seeds}")
else:
    print("\nFine-tuning complete!")

In [None]:
# Cell 4: Evaluate — compare SSL vs vanilla baseline
#
# Reads "best_score" / "best_epoch" from each fine-tuned checkpoint in logs/,
# prints a per-seed comparison against BASELINE_VALS, and applies the kill
# gate to the mean SSL score. Baseline mean and delta are computed only over
# the seeds that were actually scored, so a missing checkpoint cannot skew
# the comparison.
import os, torch
os.chdir("/content/competition_package")

SEEDS = [42, 43, 44]
BASELINE_VALS = {42: 0.2649, 43: 0.2737, 44: 0.2690}  # vanilla baseline reference
GATE = 0.2709  # kill gate: mean must exceed this
NEAR_MISS = 0.2695  # below the gate but above this: worth training more seeds

print(f"{'Seed':<6} {'SSL Val':>10} {'Baseline':>10} {'Delta':>10}")
print("-" * 40)

ssl_scores = []  # every SSL score that could be read
paired = []      # (ssl_score, baseline) for seeds that have both
for seed in SEEDS:
    pt = f"logs/gru_parity_v1_seed{seed}.pt"
    if not os.path.exists(pt):
        print(f"s{seed}:  MISSING")
        continue
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints produced by this project's own training scripts.
    ckpt = torch.load(pt, map_location="cpu", weights_only=False)
    if ckpt.get("best_score") is None:
        # Counting an absent score as 0.0 would drag the mean down and could
        # trigger a false KILL, so treat it like a missing checkpoint.
        print(f"s{seed}:  NO best_score in checkpoint")
        continue
    score = float(ckpt["best_score"])
    epoch = ckpt.get("best_epoch", "?")
    ssl_scores.append(score)
    baseline = BASELINE_VALS.get(seed)
    if baseline is None:
        print(f"s{seed:<5} {score:>10.4f} {'n/a':>10} {'n/a':>10}  (ep {epoch})")
        continue
    delta = score - baseline
    paired.append((score, baseline))
    print(f"s{seed:<5} {score:>10.4f} {baseline:>10.4f} {delta:>+10.4f}  (ep {epoch})")

if ssl_scores:
    mean_ssl = sum(ssl_scores) / len(ssl_scores)
    print(f"\nMean SSL:      {mean_ssl:.4f}")
    if paired:
        # Compare like with like: both means use exactly the same seeds.
        mean_paired_ssl = sum(s for s, _ in paired) / len(paired)
        mean_base = sum(b for _, b in paired) / len(paired)
        mean_delta = mean_paired_ssl - mean_base
        print(f"Mean baseline: {mean_base:.4f}")
        print(f"Mean delta:    {mean_delta:+.4f}")
    if len(ssl_scores) < len(SEEDS):
        print(f"WARNING: only {len(ssl_scores)}/{len(SEEDS)} seeds scored; "
              "gate verdict below is based on partial results.")
    print(f"\nKill gate ({GATE:.4f}): {'PASS' if mean_ssl > GATE else 'FAIL'}")
    if mean_ssl > GATE:
        print("-> Scale to 5-10 seeds + submit!")
    elif mean_ssl > NEAR_MISS:
        print("-> Near miss. Train 5 more seeds to check mean.")
    else:
        print("-> KILL SSL direction.")
else:
    print("\nNo SSL scores available -- nothing to evaluate.")

In [None]:
# Cell 5: Save checkpoints to Drive
import os, torch, shutil
os.chdir("/content/competition_package")

SEEDS = [42, 43, 44]
DRIVE_DIR = "/content/drive/MyDrive/wunderfund"

# Both the SSL-pretrain and the fine-tuned checkpoint of every seed are
# copied, keeping only the model weights, config and best score/epoch.
CKPT_PREFIXES = ["gru_ssl_pretrain_v1", "gru_parity_v1"]
OPTIONAL_KEYS = ("best_score", "best_epoch")

for seed in SEEDS:
    for prefix in CKPT_PREFIXES:
        src_path = f"logs/{prefix}_seed{seed}.pt"
        if not os.path.exists(src_path):
            # Nothing was produced for this seed/prefix; move on quietly.
            continue

        # Strip optimizer state to save space
        full_ckpt = torch.load(src_path, map_location="cpu", weights_only=False)
        slim_ckpt = {
            "model_state_dict": full_ckpt["model_state_dict"],
            "config": full_ckpt.get("config", {}),
        }
        for key in OPTIONAL_KEYS:
            slim_ckpt[key] = full_ckpt.get(key, None)

        dst_path = f"{DRIVE_DIR}/ssl_{prefix}_seed{seed}.pt"
        torch.save(slim_ckpt, dst_path)
        size_mb = os.path.getsize(dst_path) / 1e6
        print(f"Saved: {dst_path} ({size_mb:.1f}MB)")

print("\nDone! Download from Drive or continue to scale up.")