# 04 — Seed Diversity Analysis (Wave 1)

Goal: quantify seed diversity for `gru_derived_tightwd_v2` and build a stronger seed ensemble.

This notebook performs:
1. Per-seed model scoring (t0, t1, avg weighted Pearson)
2. Pairwise prediction correlation matrix across seeds
3. Uniform ensemble score
4. SLSQP optimal non-negative ensemble weights (sum to 1)
5. Best subset score as model count increases (N=1..5)
6. Diversity analysis by `|target|` quintiles

All outputs are saved in `notebooks/artifacts/04_seed_diversity/`.
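
Steps 1 to 3 above can be sketched on synthetic data as follows. This is a minimal illustration, not the runner's code: the per-sample `weights`, the toy `target`, and the helper name `weighted_pearson` are assumptions, and only a single target is scored here.

```python
import numpy as np

def weighted_pearson(y_true, y_pred, weights):
    """Pearson correlation with per-sample weights (the weighting scheme is an assumption)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    w = np.asarray(weights, dtype=float)
    w = w / w.sum()
    mean_t = np.sum(w * y_true)
    mean_p = np.sum(w * y_pred)
    cov = np.sum(w * (y_true - mean_t) * (y_pred - mean_p))
    var_t = np.sum(w * (y_true - mean_t) ** 2)
    var_p = np.sum(w * (y_pred - mean_p) ** 2)
    denom = np.sqrt(var_t * var_p)
    return float(cov / denom) if denom > 0 else 0.0

# Synthetic stand-in for 5 seeds: shared signal plus independent noise per seed.
rng = np.random.default_rng(0)
n_samples, n_seeds = 1000, 5
target = rng.normal(size=n_samples)
preds = 0.3 * target[:, None] + rng.normal(size=(n_samples, n_seeds))
weights = np.abs(target)  # hypothetical choice of sample weights

# 1. Per-seed score
per_seed = [weighted_pearson(target, preds[:, k], weights) for k in range(n_seeds)]
# 2. Pairwise prediction correlation across seeds (columns are seeds)
pairwise_corr = np.corrcoef(preds, rowvar=False)
# 3. Uniform ensemble: plain mean of the seed predictions
uniform_score = weighted_pearson(target, preds.mean(axis=1), weights)
```

The lower the off-diagonal entries of `pairwise_corr`, the more the uniform average can improve on the individual seeds.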

In [None]:
from pathlib import Path
import json
import pandas as pd
import matplotlib.pyplot as plt

ROOT = Path('..') if Path.cwd().name == 'notebooks' else Path('.')
ART = ROOT / 'notebooks' / 'artifacts' / '04_seed_diversity'
ART.mkdir(parents=True, exist_ok=True)

SEEDS = [42, 43, 44, 45, 46]
CKPTS = [ROOT / f'logs/gru_derived_tightwd_v2_seed{s}.pt' for s in SEEDS]

print('Expected checkpoints:')
for p in CKPTS:
    print('  ', p, 'exists=', p.exists())

## Run analysis

Important pipeline requirements (implemented in the runner script, `run_04_seed_diversity_analysis.py`):
- Reset the hidden state whenever a new `seq_ix` begins
- Process **all** timesteps, including warm-up steps
- Collect predictions only when `need_prediction=True`
- Apply derived features before normalization
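
The requirements above can be sketched as a single inference loop. This is a rough illustration under stated assumptions: `ToyRecurrentModel`, `add_derived_features`, and the `(seq_ix, raw_features, need_prediction)` row layout are hypothetical stand-ins, not the runner's actual classes or data format.

```python
import numpy as np

def add_derived_features(x):
    # Hypothetical derived feature: difference of the first two raw features.
    return np.concatenate([x, [x[0] - x[1]]])

class ToyRecurrentModel:
    """Stand-in for the GRU: the hidden state is an exponential moving average."""
    def __init__(self, n_features):
        self.n_features = n_features

    def init_hidden(self):
        return np.zeros(self.n_features)

    def step(self, x, hidden):
        hidden = 0.9 * hidden + 0.1 * x
        return float(hidden.sum()), hidden

def run_inference(rows, model, mean, std):
    preds = []
    hidden = None
    current_seq = None
    for seq_ix, raw, need_prediction in rows:
        if seq_ix != current_seq:
            # New sequence: start from a fresh hidden state.
            hidden = model.init_hidden()
            current_seq = seq_ix
        # Derived features are computed on raw values, then normalized.
        x = add_derived_features(np.asarray(raw, dtype=float))
        x = (x - mean) / std
        # Every timestep updates the state, warm-up steps included.
        y, hidden = model.step(x, hidden)
        if need_prediction:
            preds.append(y)
    return np.array(preds)

rows = [
    (0, [1.0, 2.0], False),  # warm-up step, no prediction collected
    (0, [2.0, 1.0], True),
    (1, [1.0, 2.0], False),
    (1, [2.0, 1.0], True),
]
toy_preds = run_inference(rows, ToyRecurrentModel(3), mean=np.zeros(3), std=np.ones(3))
```

Because both toy sequences contain the same inputs and the state is reset between them, the two collected predictions are identical. If the reset were skipped, the second one would differ.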

In [None]:
# Uncomment to run the full analysis once all checkpoints are available.
# Scoring all 5 seeds can take a long time on CPU.
# Input paths are built from ROOT so the call works from the repo root or from notebooks/.

# import subprocess, sys
# subprocess.run([
#     sys.executable,
#     str(ROOT / 'notebooks' / 'run_04_seed_diversity_analysis.py'),
#     '--config', str(ROOT / 'configs' / 'gru_derived_tightwd_v2.yaml'),
#     '--normalizer', str(ROOT / 'logs' / 'normalizer.npz'),
#     '--data', str(ROOT / 'datasets' / 'valid.parquet'),
# ], check=True)

## Load tabular outputs

In [None]:
status_path = ART / 'status.json'
if status_path.exists():
    print('Status:')
    print(status_path.read_text(encoding='utf-8'))

def load_csv(name):
    """Return an artifact table as a DataFrame, or None if it has not been written yet."""
    path = ART / name
    return pd.read_csv(path) if path.exists() else None

model_scores = load_csv('model_scores.csv')
pairwise = load_csv('pairwise_correlations.csv')
ensemble_n = load_csv('ensemble_vs_n_models.csv')
diversity = load_csv('diversity_by_target_bucket.csv')

if model_scores is not None:
    display(model_scores)
if pairwise is not None:
    display(pairwise)
if ensemble_n is not None:
    display(ensemble_n)
if diversity is not None:
    display(diversity)

In [None]:
opt_path = ART / 'optimal_weights.json'
if opt_path.exists():
    # read_text opens and closes the file, so no handle is left open
    opt = json.loads(opt_path.read_text(encoding='utf-8'))
    print(json.dumps(opt, indent=2))
else:
    print('optimal_weights.json not found yet')
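
For reference, non-negative weights that sum to 1 can be fitted with SLSQP roughly as follows. This is a sketch on synthetic data, assuming plain (unweighted) Pearson as the objective; the helper name `fit_ensemble_weights` is hypothetical and the runner's objective may differ.

```python
import numpy as np
from scipy.optimize import minimize

def fit_ensemble_weights(preds, target):
    """Maximize Pearson correlation of the blend subject to w >= 0 and sum(w) == 1."""
    n_models = preds.shape[1]

    def neg_corr(w):
        blend = preds @ w
        return -np.corrcoef(blend, target)[0, 1]

    w0 = np.full(n_models, 1.0 / n_models)  # start from the uniform ensemble
    result = minimize(
        neg_corr,
        w0,
        method="SLSQP",
        bounds=[(0.0, 1.0)] * n_models,
        constraints=[{"type": "eq", "fun": lambda w: w.sum() - 1.0}],
    )
    return result.x

# Synthetic seeds with different noise levels, so uniform weights are not optimal.
rng = np.random.default_rng(0)
target = rng.normal(size=2000)
noise_scales = np.array([0.5, 1.0, 1.0, 2.0, 3.0])
preds = target[:, None] + rng.normal(size=(2000, 5)) * noise_scales

opt_weights = fit_ensemble_weights(preds, target)
uniform_score = np.corrcoef(preds.mean(axis=1), target)[0, 1]
optimal_score = np.corrcoef(preds @ opt_weights, target)[0, 1]
```

Weights fitted and scored on the same data will look optimistic, so a gain over the uniform average is more convincing when checked on held-out sequences.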

## Plots

In [None]:
plot_files = [
    'per_model_scores.png',
    'pairwise_corr_heatmap.png',
    'ensemble_vs_n_models.png',
    'diversity_by_target_bucket.png',
]

existing = [p for p in plot_files if (ART / p).exists()]
if not existing:
    print('No plots found yet. Run the analysis cell first.')
else:
    n = len(existing)
    fig, axes = plt.subplots(n, 1, figsize=(12, 4 * n))
    if n == 1:
        axes = [axes]
    for ax, fname in zip(axes, existing):
        ax.imshow(plt.imread(ART / fname))
        ax.axis('off')
        ax.set_title(fname)
    plt.tight_layout()
    plt.show()

## Interpretation checklist

- Are pairwise correlations meaningfully below 1.0? (near-identical predictions leave little room for ensemble gain)
- Does optimized weighting beat the simple average?
- At which N do ensemble returns diminish?
- Is disagreement concentrated in higher `|target|` quintiles?

Use these answers to pick final seed ensemble weights for export/submission.
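
The last two checklist questions can be explored with a sketch like the one below. It assumes plain Pearson scoring, uniform averaging within each subset, and per-sample standard deviation across seeds as the disagreement measure. The function names and output column names are hypothetical and need not match the runner's CSV schema.

```python
from itertools import combinations

import numpy as np
import pandas as pd

def best_subset_scores(preds, target):
    """For each N, score every N-seed uniform ensemble and keep the best one."""
    n_models = preds.shape[1]
    rows = []
    for n in range(1, n_models + 1):
        best_score, best_combo = -np.inf, None
        for combo in combinations(range(n_models), n):
            blend = preds[:, list(combo)].mean(axis=1)
            score = float(np.corrcoef(blend, target)[0, 1])
            if score > best_score:
                best_score, best_combo = score, combo
        rows.append({"n_models": n, "best_subset": best_combo, "score": best_score})
    return pd.DataFrame(rows)

def disagreement_by_quintile(preds, target):
    """Mean across-seed prediction spread within each |target| quintile (0 = smallest)."""
    bucket = pd.qcut(np.abs(target), q=5, labels=False)
    spread = preds.std(axis=1)
    return pd.Series(spread).groupby(bucket).mean()

# Synthetic stand-in for 5 seeds.
rng = np.random.default_rng(0)
target = rng.normal(size=1000)
preds = 0.3 * target[:, None] + rng.normal(size=(1000, 5))

subset_table = best_subset_scores(preds, target)
bucket_spread = disagreement_by_quintile(preds, target)
```

With 5 seeds the exhaustive search covers only 31 subsets, so brute force is cheap. The point where `score` flattens as `n_models` grows marks the diminishing return.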