# Analyze Predictions Notebook

Sections:
1) Setup
2) Load metadata and session data
3) Load submission(s) and build prediction matrices
4) Combine predictions with masked SBP
5) Visualizations
6) Compare two submissions (optional)
7) Diagnostics / stats

## 1) Setup

In [None]:
from pathlib import Path
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = (12, 4)
plt.rcParams['image.aspect'] = 'auto'

# ipywidgets is optional and only powers the interactive explorer in section 7.
# Best-effort import: any failure simply disables the widgets.
try:
    import ipywidgets as widgets
    HAS_WIDGETS = True
except Exception:
    HAS_WIDGETS = False

# `display` is called unconditionally in sections 6 and 7, so it is imported on
# its own (not tied to ipywidgets); outside IPython fall back to plain print.
try:
    from IPython.display import display
except ImportError:
    display = print

# User configuration
DATA_DIR = "kaggle_data"  # root with metadata.csv and one sub-directory per split
METADATA_PATH = f"{DATA_DIR}/metadata.csv"
SUBMISSION_PATHS = ["full_window.csv", "adversarial.csv"]  # one or two files
SESSION_ID = "S008"  # session analysed in sections 4-5
CHANNEL = 10  # channel index (0-95) for single-channel plots
TRIAL_ID = None  # None -> first non-negative trial id in the session
START = 0  # absolute time-bin window [START, END) for channel plots
END = 2000
INFER_MASK_FROM_ZERO = True  # without a mask file, treat sbp == 0 entries as masked

print(f"HAS_WIDGETS={HAS_WIDGETS}")
print(f"DATA_DIR={DATA_DIR}")
print(f"SUBMISSION_PATHS={SUBMISSION_PATHS}")

## 2) Load metadata and session data

In [None]:
def _candidate_paths(split_dir: Path, session_id: str, kind: str):
    if kind == 'sbp':
        return [
            split_dir / f"{session_id}_sbp_masked.npy",
            split_dir / f"{session_id}_sbp.npy",
            split_dir / session_id / "sbp_masked.npy",
            split_dir / session_id / "sbp.npy",
        ]
    if kind == 'kinematics':
        return [
            split_dir / f"{session_id}_kinematics.npy",
            split_dir / session_id / "kinematics.npy",
        ]
    if kind == 'trial':
        return [
            split_dir / f"{session_id}_trial_info.npz",
            split_dir / f"{session_id}_trial_ids.npy",
            split_dir / f"{session_id}_trials.csv",
            split_dir / session_id / "trial_info.npz",
            split_dir / session_id / "trial_ids.npy",
            split_dir / session_id / "trials.csv",
        ]
    if kind == 'mask':
        return [
            split_dir / f"{session_id}_mask.npy",
            split_dir / session_id / "mask.npy",
        ]
    raise ValueError(f"Unknown kind={kind}")


def _find_existing(candidates):
    hit = next((p for p in candidates if p.exists()), None)
    return hit


def _load_trial_ids(path: Path, n_bins: int) -> np.ndarray:
    if path.suffix == '.npz':
        obj = np.load(path)
        keys = set(obj.files)
        if {'start_bins', 'end_bins'}.issubset(keys):
            starts = obj['start_bins'].astype(np.int64)
            ends = obj['end_bins'].astype(np.int64)
            trial_ids = np.full((n_bins,), -1, dtype=np.int64)
            for idx, (s, e) in enumerate(zip(starts, ends)):
                if s < 0 or e > n_bins or e <= s:
                    raise ValueError(f"Invalid trial segment in {path}: start={s}, end={e}, n_bins={n_bins}")
                trial_ids[s:e] = idx
            return trial_ids
        if 'trial_ids' in keys:
            trial_ids = obj['trial_ids'].astype(np.int64)
            if len(trial_ids) != n_bins:
                raise ValueError(f"trial_ids length mismatch in {path}: {len(trial_ids)} vs {n_bins}")
            return trial_ids
        raise ValueError(f"Unsupported trial npz format in {path}. Keys={sorted(keys)}")

    if path.suffix == '.npy':
        trial_ids = np.load(path).astype(np.int64)
        if len(trial_ids) != n_bins:
            raise ValueError(f"trial_ids length mismatch in {path}: {len(trial_ids)} vs {n_bins}")
        return trial_ids

    if path.suffix == '.csv':
        df = pd.read_csv(path)
        if 'trial_id' in df.columns:
            trial_ids = df['trial_id'].to_numpy(dtype=np.int64)
            if len(trial_ids) != n_bins:
                raise ValueError(f"trial_id length mismatch in {path}: {len(trial_ids)} vs {n_bins}")
            return trial_ids
        if {'start_bin', 'end_bin'}.issubset(df.columns):
            trial_ids = np.full((n_bins,), -1, dtype=np.int64)
            for idx, row in df.reset_index(drop=True).iterrows():
                s = int(row['start_bin'])
                e = int(row['end_bin'])
                if s < 0 or e > n_bins or e <= s:
                    raise ValueError(f"Invalid trial segment in {path}: start={s}, end={e}, n_bins={n_bins}")
                trial_ids[s:e] = idx
            return trial_ids
        raise ValueError(f"Unsupported trial csv format in {path}. Need trial_id or start_bin/end_bin")

    raise ValueError(f"Unsupported trial file type: {path}")


def load_metadata(metadata_path: str | Path) -> pd.DataFrame:
    meta_path = Path(metadata_path)
    if not meta_path.exists():
        raise FileNotFoundError(f"metadata.csv not found at {meta_path}")
    df = pd.read_csv(meta_path)
    if 'session_id' not in df.columns:
        raise ValueError(f"metadata.csv missing session_id column. Columns: {list(df.columns)}")
    df['session_id'] = df['session_id'].astype(str)
    return df


def resolve_split(metadata: pd.DataFrame, session_id: str, default_split: str = 'test') -> str:
    """Return the data split (sub-directory name) recorded for *session_id*.

    Uses the ``split`` value of the first metadata row whose ``session_id`` matches.
    Falls back to *default_split* when the session is absent, there is no ``split``
    column, or the value is missing (NaN) or blank. A NaN split must not be
    stringified, since ``str(nan)`` is the truthy string ``'nan'``.
    """
    rows = metadata.loc[metadata['session_id'] == session_id]
    if len(rows) == 0 or 'split' not in rows.columns:
        return default_split
    value = rows.iloc[0]['split']
    if pd.isna(value):
        return default_split
    split = str(value).strip()
    return split if split else default_split


def load_session_arrays(data_dir: str | Path, metadata: pd.DataFrame, session_id: str, infer_mask_from_zero: bool = True):
    """Load and validate the SBP, kinematics, trial ids and mask for one session.

    Files are looked up under ``<data_dir>/<split>``, where the split comes from
    ``resolve_split`` (falling back to 'test'). For each artifact the first existing
    path returned by ``_candidate_paths`` is used.

    Args:
        data_dir: Root data directory containing one sub-directory per split.
        metadata: Frame from ``load_metadata``; only used to resolve the split.
        session_id: Session to load.
        infer_mask_from_zero: When no mask file exists, treat ``sbp == 0`` entries as
            masked (and warn) instead of raising.

    Returns:
        dict with 'session_id', 'split', 'sbp' (float32, (T, 96)), 'kinematics'
        (float32, (T, 4)), 'trial_ids' (int64, length T), 'mask' (bool, (T, 96)) and
        'paths' (str paths of the files used; paths['mask'] is None when inferred).

    Raises:
        FileNotFoundError: missing split directory, SBP/kinematics/trial file, or a
            missing mask file while ``infer_mask_from_zero`` is False.
        ValueError: unexpected array shapes or mismatched lengths.
    """
    split = resolve_split(metadata, session_id, default_split='test')
    split_dir = Path(data_dir) / split
    if not split_dir.exists():
        raise FileNotFoundError(f"Split directory not found: {split_dir}")

    kinds = ('sbp', 'kinematics', 'trial', 'mask')
    candidates = {kind: _candidate_paths(split_dir, session_id, kind) for kind in kinds}
    found = {kind: _find_existing(candidates[kind]) for kind in kinds}

    def _searched(kind: str) -> str:
        # Newline-separated list of probed paths, for error messages.
        return '\n'.join(str(p) for p in candidates[kind])

    # SBP, kinematics and trial info are mandatory; the mask may be inferred.
    for kind, label in (('sbp', 'SBP'), ('kinematics', 'kinematics'), ('trial', 'trial')):
        if found[kind] is None:
            raise FileNotFoundError(f"Missing {label} file for {session_id}. Searched:\n{_searched(kind)}")

    sbp_path, kin_path, trial_path, mask_path = (found[kind] for kind in kinds)

    sbp = np.load(sbp_path).astype(np.float32)
    kin = np.load(kin_path).astype(np.float32)
    if sbp.ndim != 2 or sbp.shape[1] != 96:
        raise ValueError(f"Expected sbp shape (T,96), got {sbp.shape} from {sbp_path}")
    if kin.ndim != 2 or kin.shape[1] != 4:
        raise ValueError(f"Expected kinematics shape (T,4), got {kin.shape} from {kin_path}")
    n_bins = sbp.shape[0]
    if kin.shape[0] != n_bins:
        raise ValueError(f"Length mismatch: sbp T={n_bins} vs kin T={kin.shape[0]}")

    trial_ids = _load_trial_ids(trial_path, n_bins=n_bins)
    if len(trial_ids) != n_bins:
        raise ValueError(f"trial_ids length mismatch: {len(trial_ids)} vs T={n_bins}")

    if mask_path is None:
        if not infer_mask_from_zero:
            raise FileNotFoundError(
                f"Missing mask file for {session_id} and INFER_MASK_FROM_ZERO=False. Searched:\n{_searched('mask')}"
            )
        mask = (sbp == 0)
        warnings.warn(
            f"Mask file not found for {session_id}. Using inferred mask from sbp==0.",
            stacklevel=1,
        )
    else:
        mask = np.load(mask_path).astype(bool)
        if mask.shape != sbp.shape:
            raise ValueError(f"Mask shape mismatch: {mask.shape} vs sbp {sbp.shape} in {mask_path}")

    return {
        'session_id': session_id,
        'split': split,
        'sbp': sbp,
        'kinematics': kin,
        'trial_ids': trial_ids,
        'mask': mask,
        'paths': {
            'sbp': str(sbp_path),
            'kinematics': str(kin_path),
            'trial_ids': str(trial_path),
            'mask': None if mask_path is None else str(mask_path),
        },
    }


metadata = load_metadata(METADATA_PATH)

# Sorted ids of sessions labelled split == 'test' (empty without a split column).
test_ids = (
    sorted(metadata.loc[metadata['split'] == 'test', 'session_id'].astype(str).tolist())
    if 'split' in metadata.columns
    else []
)

print(f"metadata rows={len(metadata)}")
print(f"test sessions in metadata={len(test_ids)}")
if test_ids:
    print(f"first test IDs: {test_ids[:8]}")

## 3) Load submission(s) and build prediction matrices

In [None]:
def load_submission_csv(path: str | Path) -> pd.DataFrame:
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"Submission file not found: {p}")
    df = pd.read_csv(p)
    required = {'sample_id', 'session_id', 'time_bin', 'channel', 'predicted_sbp'}
    missing = required - set(df.columns)
    if missing:
        raise ValueError(f"Submission {p} missing columns: {sorted(missing)}")

    df = df[['sample_id', 'session_id', 'time_bin', 'channel', 'predicted_sbp']].copy()
    df['session_id'] = df['session_id'].astype(str)
    df['time_bin'] = df['time_bin'].astype(np.int64)
    df['channel'] = df['channel'].astype(np.int64)
    df['predicted_sbp'] = df['predicted_sbp'].astype(np.float32)

    bad_ch = df[(df['channel'] < 0) | (df['channel'] > 95)]
    if len(bad_ch) > 0:
        raise ValueError(f"Found {len(bad_ch)} rows with channel outside [0,95] in {p}")
    bad_t = df[df['time_bin'] < 0]
    if len(bad_t) > 0:
        raise ValueError(f"Found {len(bad_t)} rows with negative time_bin in {p}")
    return df


def build_prediction_matrices(df: pd.DataFrame, expected_lengths: dict[str, int] | None = None):
    mats = {}
    coverage = {}

    for sid, grp in df.groupby('session_id', sort=True):
        max_t = int(grp['time_bin'].max()) + 1
        t_len = max_t
        if expected_lengths is not None and sid in expected_lengths:
            t_len = max(t_len, int(expected_lengths[sid]))

        pred = np.full((t_len, 96), np.nan, dtype=np.float32)
        tb = grp['time_bin'].to_numpy(dtype=np.int64)
        ch = grp['channel'].to_numpy(dtype=np.int64)
        vals = grp['predicted_sbp'].to_numpy(dtype=np.float32)

        pred[tb, ch] = vals
        mats[sid] = pred
        coverage[sid] = float(np.isfinite(pred).mean())

    return mats, coverage


# Validate the configuration before doing any file I/O.
if len(SUBMISSION_PATHS) == 0:
    raise ValueError("SUBMISSION_PATHS cannot be empty")
if len(SUBMISSION_PATHS) > 2:
    warnings.warn("Only first two submissions will be used for pairwise comparison plots.")

submission_frames = {}
submission_mats = {}

for sp in SUBMISSION_PATHS:
    s_df = load_submission_csv(sp)
    submission_frames[sp] = s_df
    mats, cov = build_prediction_matrices(s_df)
    submission_mats[sp] = mats
    # A submission with no rows yields no sessions; avoid np.mean([]) (RuntimeWarning).
    mean_cov = float(np.mean(list(cov.values()))) if cov else float('nan')
    print(f"Loaded {sp}: rows={len(s_df)}, sessions={len(mats)}, mean_coverage={mean_cov:.4f}")

## 4) Combine predictions with masked SBP

In [None]:
def align_prediction_to_session(pred: np.ndarray, target_t: int) -> np.ndarray:
    """Return *pred* resized along time to exactly ``target_t`` rows.

    If the length already matches, *pred* itself is returned (not a copy).
    Otherwise a new float32 (target_t, 96) array is built: the first
    min(T, target_t) rows are copied, any extra rows are NaN, and a warning is emitted.

    Raises:
        ValueError: if *pred* is not a 2-D array with 96 columns.
    """
    if pred.ndim != 2 or pred.shape[1] != 96:
        raise ValueError(f"Prediction must be (T,96), got {pred.shape}")
    if pred.shape[0] == target_t:
        return pred

    warnings.warn(
        f"Prediction T={pred.shape[0]} differs from session T={target_t}. Truncated/padded to match.",
        stacklevel=2,
    )
    out = np.full((target_t, 96), np.nan, dtype=np.float32)
    n_copy = min(target_t, pred.shape[0])
    out[:n_copy] = pred[:n_copy]
    return out


def fill_reconstruction(sbp: np.ndarray, mask: np.ndarray, pred: np.ndarray) -> np.ndarray:
    """Return a copy of *sbp* whose masked entries are replaced by predictions.

    Only entries where *mask* is True and *pred* is finite are replaced; masked
    entries without a finite prediction keep their *sbp* value. The result has
    *sbp*'s dtype and the inputs are left untouched.

    Raises:
        ValueError: if *mask* or *pred* does not have exactly *sbp*'s shape.
    """
    if sbp.shape != mask.shape:
        raise ValueError(f"sbp shape {sbp.shape} != mask shape {mask.shape}")
    if pred.shape != sbp.shape:
        raise ValueError(f"pred shape {pred.shape} != sbp shape {sbp.shape}")

    replace = mask & np.isfinite(pred)
    filled = sbp.copy()
    np.copyto(filled, pred, casting='unsafe', where=replace)
    return filled


def _corr_safe(x: np.ndarray, y: np.ndarray) -> float:
    if len(x) < 2 or np.std(x) < 1e-12 or np.std(y) < 1e-12:
        return np.nan
    return float(np.corrcoef(x, y)[0, 1])


def compute_stats(sbp: np.ndarray, pred: np.ndarray, mask: np.ndarray):
    """Summarise a prediction matrix against a session's SBP.

    Coverage fractions are taken over all (time_bin, channel) entries; note that
    'fraction_masked_pred_finite' is masked-and-finite entries divided by ALL
    entries, not by the masked entries. Error metrics use only entries that are
    observed (``~mask``) and have a finite prediction:
    mse, nmse = mse / (mean(sbp**2) + 1e-12), and Pearson corr (via ``_corr_safe``).
    Masked entries only contribute the mean/std of their finite predictions.

    Returns:
        dict with the fraction_* keys, pred_masked_mean/std, mse_observed,
        nmse_observed, corr_observed (NaN when nothing can be scored) and
        'per_channel': a DataFrame with columns channel, n_obs, mse, nmse, corr.
    """
    def _errors(truth, guess):
        # Shared formulas for the global and the per-channel metrics.
        err = guess - truth
        mse_val = float(np.mean(err ** 2))
        power = float(np.mean((truth) ** 2))
        return mse_val, mse_val / (power + 1e-12), _corr_safe(truth, guess)

    observed = ~mask
    finite = np.isfinite(pred)
    scored = observed & finite
    masked_finite = mask & finite
    has_masked_pred = np.any(masked_finite)

    stats = {
        'fraction_masked': float(mask.mean()),
        'fraction_observed': float(observed.mean()),
        'fraction_pred_finite': float(finite.mean()),
        'fraction_masked_pred_finite': float(masked_finite.mean()),
        'pred_masked_mean': float(np.nanmean(pred[masked_finite])) if has_masked_pred else np.nan,
        'pred_masked_std': float(np.nanstd(pred[masked_finite])) if has_masked_pred else np.nan,
    }

    no_score = (np.nan, np.nan, np.nan)
    stats['mse_observed'], stats['nmse_observed'], stats['corr_observed'] = (
        _errors(sbp[scored], pred[scored]) if np.any(scored) else no_score
    )

    rows = []
    for ch in range(96):
        col = scored[:, ch]
        if not np.any(col):
            rows.append({'channel': ch, 'n_obs': 0, 'mse': np.nan, 'nmse': np.nan, 'corr': np.nan})
            continue
        mse_c, nmse_c, corr_c = _errors(sbp[col, ch], pred[col, ch])
        rows.append({'channel': ch, 'n_obs': int(col.sum()), 'mse': mse_c, 'nmse': nmse_c, 'corr': corr_c})

    stats['per_channel'] = pd.DataFrame(rows)
    return stats


session = load_session_arrays(DATA_DIR, metadata, SESSION_ID, infer_mask_from_zero=INFER_MASK_FROM_ZERO)
print(f"Loaded session={SESSION_ID} split={session['split']} sbp={session['sbp'].shape} kin={session['kinematics'].shape}")
print(f"Paths: {session['paths']}")

# For each submission: align its SESSION_ID predictions to the session length,
# fill the masked SBP entries with them, and compute summary statistics.
analysis_by_submission = {}
for sp in SUBMISSION_PATHS:
    mats = submission_mats[sp]
    if SESSION_ID not in mats:
        raise KeyError(f"Session {SESSION_ID} not found in submission {sp}")
    pred_aligned = align_prediction_to_session(mats[SESSION_ID], target_t=session['sbp'].shape[0])
    recon = fill_reconstruction(session['sbp'], session['mask'], pred_aligned)
    stats = compute_stats(session['sbp'], pred_aligned, session['mask'])
    analysis_by_submission[sp] = dict(pred=pred_aligned, recon=recon, stats=stats)
    print(
        f"{sp}: mse_obs={stats['mse_observed']:.6f}, "
        f"nmse_obs={stats['nmse_observed']:.6f}, "
        f"corr_obs={stats['corr_observed']:.6f}"
    )

# Smoke test checks on the loaded session
assert session['sbp'].shape[1] == 96
assert session['mask'].shape == session['sbp'].shape
assert session['mask'].sum() > 0, "No masked entries found for this session"
print('Smoke test passed.')

## 5) Visualizations

In [None]:
def _trial_boundaries(trial_ids: np.ndarray):
    return np.where(np.diff(trial_ids) != 0)[0] + 1


def _rolling_mean_1d(x: np.ndarray, window: int = 25) -> np.ndarray:
    window = int(max(1, window))
    if window <= 1:
        return x.copy()
    kernel = np.ones(window, dtype=np.float64) / window
    return np.convolve(x, kernel, mode='same')


def plot_heatmaps(session, pred, recon, title_prefix=''):
    """Show observed SBP, the reconstruction and the observed-entry residual side by side.

    Args:
        session: dict from ``load_session_arrays`` (uses 'sbp', 'mask', 'trial_ids').
        pred: prediction matrix with the same shape as session['sbp'].
        recon: reconstruction to display, e.g. the output of ``fill_reconstruction``.
        title_prefix: text prepended to every panel title.

    Each panel is a channel x time_bin image with thin black lines at trial
    boundaries. The residual (pred - observed, NaN at masked entries) uses a
    diverging colormap with limits symmetric around zero, so white means zero error.
    """
    sbp = session['sbp']
    mask = session['mask']
    obs = ~mask

    observed_sbp = sbp.copy()
    observed_sbp[mask] = np.nan
    resid_obs = np.full_like(sbp, np.nan, dtype=np.float32)
    resid_obs[obs] = pred[obs] - sbp[obs]

    # Centre coolwarm on zero; fall back to +/-1 when there is no finite residual.
    finite_abs = np.abs(resid_obs[np.isfinite(resid_obs)])
    vmax = float(finite_abs.max()) if finite_abs.size else 0.0
    vmax = vmax or 1.0

    panels = (
        (observed_sbp, f"{title_prefix}Observed SBP (masked hidden)", {}),
        (recon, f"{title_prefix}Reconstruction (masked filled)", {}),
        (resid_obs, f"{title_prefix}Residual on observed (pred - observed)",
         {'cmap': 'coolwarm', 'vmin': -vmax, 'vmax': vmax}),
    )
    boundaries = _trial_boundaries(session['trial_ids'])  # computed once for all panels

    _, axes = plt.subplots(1, 3, figsize=(18, 5), constrained_layout=True)
    for ax, (data, title, style) in zip(axes, panels):
        im = ax.imshow(data.T, origin='lower', **style)
        ax.set_title(title)
        ax.set_xlabel('time_bin')
        ax.set_ylabel('channel')
        for b in boundaries:
            ax.axvline(b, color='k', lw=0.2, alpha=0.25)
        plt.colorbar(im, ax=ax, fraction=0.03, pad=0.02)
    plt.show()


def plot_channel_series(session, pred, channel=10, start=0, end=2000, smooth_window=0, title_prefix=''):
    """Plot one channel's predicted vs observed SBP over the time bins [start, end).

    *start*/*end* are clipped to [0, T]. Observed values at masked entries are
    hidden, and predictions at masked entries are additionally marked with black
    dots. When ``smooth_window > 1`` a rolling mean of the predictions is overlaid.

    Raises:
        ValueError: if the clipped range is empty or *channel* is outside [0, 95].
    """
    sbp, mask = session['sbp'], session['mask']
    lo = int(max(0, start))
    hi = int(min(sbp.shape[0], end))
    if hi <= lo:
        raise ValueError(f"Invalid range start={lo}, end={hi}")
    c = int(channel)
    if not 0 <= c <= 95:
        raise ValueError(f"channel must be in [0,95], got {c}")

    t = np.arange(lo, hi)
    window_mask = mask[lo:hi, c]
    observed = sbp[lo:hi, c].copy()
    observed[window_mask] = np.nan
    predicted = pred[lo:hi, c]

    smoothed = None
    if smooth_window and smooth_window > 1:
        smoothed = _rolling_mean_1d(predicted, window=smooth_window)

    _, ax = plt.subplots(figsize=(14, 4), constrained_layout=True)
    ax.plot(t, predicted, lw=1.0, label='predicted sbp', color='C1')
    if smoothed is not None:
        ax.plot(t, smoothed, lw=1.6, label=f'predicted smoothed (w={smooth_window})', color='C3')
    ax.plot(t, observed, lw=1.0, label='observed sbp', color='C0')

    if window_mask.any():
        ax.scatter(t[window_mask], predicted[window_mask], s=6, alpha=0.45, color='black', label='masked locations')

    ax.set_title(f"{title_prefix}Session {session['session_id']} | channel {c}")
    ax.set_xlabel('time_bin')
    ax.set_ylabel('sbp')
    ax.legend(loc='best')
    plt.show()


def plot_distributions(session, pred, bins=80, title_prefix=''):
    """Plot value and residual histograms (density-normalised) for one session.

    Left panel: observed SBP values vs finite predictions at masked entries.
    Right panel: residuals (pred - observed) on observed entries with a finite prediction.
    """
    sbp, mask = session['sbp'], session['mask']
    finite = np.isfinite(pred)
    scored = ~mask & finite

    observed = sbp[~mask]
    predicted_masked = pred[mask & finite]
    residuals = pred[scored] - sbp[scored]

    _, (ax_val, ax_res) = plt.subplots(1, 2, figsize=(14, 4), constrained_layout=True)
    for values, label in ((observed, 'observed entries'), (predicted_masked, 'predicted at masked entries')):
        ax_val.hist(values, bins=bins, alpha=0.6, label=label, density=True)
    ax_val.set_title(f"{title_prefix}Value distribution")
    ax_val.set_xlabel('sbp')
    ax_val.set_ylabel('density')
    ax_val.legend()

    ax_res.hist(residuals, bins=bins, alpha=0.8, density=True)
    ax_res.set_title(f"{title_prefix}Residual distribution on observed (pred - observed)")
    ax_res.set_xlabel('residual')
    ax_res.set_ylabel('density')
    plt.show()


def pick_trial_segment(trial_ids: np.ndarray, requested_trial_id=None):
    """Locate a trial and return ``(trial_id, first_bin, last_bin + 1)``.

    With ``requested_trial_id=None`` the smallest non-negative id is used. The span
    runs from the first to the last bin carrying the id; contiguity is not checked.

    Raises:
        ValueError: if no non-negative id exists (default case) or the id is absent.
    """
    tid = requested_trial_id
    if tid is None:
        candidates = np.unique(trial_ids)
        candidates = candidates[candidates >= 0]
        if candidates.size == 0:
            raise ValueError('No non-negative trial IDs found')
        tid = int(candidates[0])

    positions = np.flatnonzero(trial_ids == tid)
    if positions.size == 0:
        raise ValueError(f"trial_id={tid} not present in session")
    return int(tid), int(positions[0]), int(positions[-1]) + 1


def plot_trial_zoom(session, pred, channel=10, requested_trial_id=None, start=None, end=None, title_prefix=''):
    """Zoom into one trial: a channel time series plus observed/reconstruction/residual heatmaps.

    The trial is chosen by ``pick_trial_segment``. *start*/*end* (absolute time
    bins) can narrow the window but are clipped to the trial's span. If that leaves
    an empty window (e.g. the global START/END do not overlap the chosen trial), a
    warning is issued and the whole trial is shown.

    Heatmaps are drawn with absolute time-bin coordinates so they line up with the
    series plot. The residual colormap is symmetric around zero.
    """
    tid, seg_start, seg_end = pick_trial_segment(session['trial_ids'], requested_trial_id)

    lo = seg_start if start is None else int(max(seg_start, start))
    hi = seg_end if end is None else int(min(seg_end, end))
    if hi <= lo:
        warnings.warn(
            f"Window [{start}, {end}) does not overlap trial_id={tid} span "
            f"[{seg_start}, {seg_end}); showing the whole trial instead.",
            stacklevel=2,
        )
        lo, hi = seg_start, seg_end

    print(f"Using trial_id={tid}, window=[{lo}, {hi})")
    plot_channel_series(session, pred, channel=channel, start=lo, end=hi, smooth_window=15, title_prefix=title_prefix)

    sbp = session['sbp'][lo:hi]
    mask = session['mask'][lo:hi]
    pred_win = pred[lo:hi]
    # Reconstruction is element-wise, so filling just the window equals slicing a full fill.
    recon = fill_reconstruction(sbp, mask, pred_win)
    obs = ~mask
    resid_obs = np.full_like(sbp, np.nan, dtype=np.float32)
    resid_obs[obs] = pred_win[obs] - sbp[obs]

    finite_abs = np.abs(resid_obs[np.isfinite(resid_obs)])
    vmax = float(finite_abs.max()) if finite_abs.size else 0.0
    vmax = vmax or 1.0

    # Absolute time-bin coordinates: pixel centres at lo..hi-1, channels 0..C-1.
    extent = (lo - 0.5, hi - 0.5, -0.5, sbp.shape[1] - 0.5)
    panels = (
        (np.where(mask, np.nan, sbp), 'Observed SBP (trial zoom)', {}),
        (recon, 'Reconstruction (trial zoom)', {}),
        (resid_obs, 'Residual on observed (trial zoom)', {'cmap': 'coolwarm', 'vmin': -vmax, 'vmax': vmax}),
    )

    _, axes = plt.subplots(1, 3, figsize=(18, 4), constrained_layout=True)
    for ax, (data, title, style) in zip(axes, panels):
        im = ax.imshow(data.T, origin='lower', extent=extent, **style)
        ax.set_title(f"{title_prefix}{title}")
        ax.set_xlabel('time_bin (zoomed)')
        ax.set_ylabel('channel')
        plt.colorbar(im, ax=ax, fraction=0.03, pad=0.02)
    plt.show()


primary_submission = SUBMISSION_PATHS[0]
primary = analysis_by_submission[primary_submission]
# Shared title prefix naming the submission file, e.g. "[full_window.csv] ".
primary_prefix = f"[{Path(primary_submission).name}] "

plot_heatmaps(session, primary['pred'], primary['recon'], title_prefix=primary_prefix)
plot_channel_series(
    session, primary['pred'], channel=CHANNEL, start=START, end=END,
    smooth_window=25, title_prefix=primary_prefix,
)
plot_distributions(session, primary['pred'], title_prefix=primary_prefix)
plot_trial_zoom(
    session, primary['pred'], channel=CHANNEL, requested_trial_id=TRIAL_ID,
    start=START, end=END, title_prefix=primary_prefix,
)

## 6) Compare two submissions (optional)

In [None]:
def compare_two_submissions(sub_path_a: str, sub_path_b: str, data_dir: str, metadata: pd.DataFrame,
                            infer_mask_from_zero: bool | None = None):
    """Compare two submission files session by session.

    For every session present in both files, both prediction matrices are aligned
    to the session length (loaded via ``load_session_arrays``) and compared on the
    entries where both are finite.

    Args:
        sub_path_a, sub_path_b: Submission CSV paths.
        data_dir: Root data directory passed to ``load_session_arrays``.
        metadata: Frame from ``load_metadata``.
        infer_mask_from_zero: Forwarded to ``load_session_arrays``; None (default)
            uses the notebook-level INFER_MASK_FROM_ZERO setting.

    Returns:
        (cmp_df, mats_a, mats_b): ``cmp_df`` has one row per common session with
        session_id, mean_abs_diff, max_abs_diff and n_compared (entries finite in
        both), sorted by max_abs_diff descending with NaN rows last.

    Raises:
        ValueError: if the submissions share no session.
    """
    if infer_mask_from_zero is None:
        infer_mask_from_zero = INFER_MASK_FROM_ZERO

    df_a = load_submission_csv(sub_path_a)
    df_b = load_submission_csv(sub_path_b)
    mats_a, _ = build_prediction_matrices(df_a)
    mats_b, _ = build_prediction_matrices(df_b)

    common = sorted(set(mats_a) & set(mats_b))
    if not common:
        raise ValueError('No common sessions between the two submissions')

    rows = []
    for sid in common:
        sess = load_session_arrays(data_dir, metadata, sid, infer_mask_from_zero=infer_mask_from_zero)
        n_bins = sess['sbp'].shape[0]
        pa = align_prediction_to_session(mats_a[sid], n_bins)
        pb = align_prediction_to_session(mats_b[sid], n_bins)
        both = np.isfinite(pa) & np.isfinite(pb)
        n_compared = int(both.sum())
        if n_compared:
            diff = np.abs(pa[both] - pb[both])
            mean_abs = float(np.mean(diff))
            max_abs = float(np.max(diff))
        else:
            mean_abs = np.nan
            max_abs = np.nan
        rows.append({
            'session_id': sid,
            'mean_abs_diff': mean_abs,
            'max_abs_diff': max_abs,
            'n_compared': n_compared,
        })

    cmp_df = pd.DataFrame(rows).sort_values('max_abs_diff', ascending=False, na_position='last')
    return cmp_df, mats_a, mats_b


if len(SUBMISSION_PATHS) < 2:
    print('Only one submission provided; skipping section 6.')
else:
    sub_a, sub_b = SUBMISSION_PATHS[0], SUBMISSION_PATHS[1]
    cmp_df, mats_a, mats_b = compare_two_submissions(sub_a, sub_b, DATA_DIR, metadata)
    display(cmp_df.head(15))

    # Per-session summary curves, in the table's order (max_abs_diff descending).
    fig, axes = plt.subplots(1, 2, figsize=(14, 4), constrained_layout=True)
    _curves = (
        ('mean_abs_diff', 'Mean abs diff per session'),
        ('max_abs_diff', 'Max abs diff per session'),
    )
    for _ax, (_col, _title) in zip(axes, _curves):
        _ax.plot(cmp_df[_col].to_numpy())
        _ax.set_title(_title)
        _ax.set_xlabel('session rank (sorted by max_abs_diff)')
        _ax.set_ylabel(_col)
    plt.show()

    # Full abs-diff heatmap for the configured session, when both files cover it.
    if SESSION_ID in mats_a and SESSION_ID in mats_b:
        s = load_session_arrays(DATA_DIR, metadata, SESSION_ID, infer_mask_from_zero=INFER_MASK_FROM_ZERO)
        pa = align_prediction_to_session(mats_a[SESSION_ID], s['sbp'].shape[0])
        pb = align_prediction_to_session(mats_b[SESSION_ID], s['sbp'].shape[0])
        abs_diff = np.abs(pa - pb)

        fig, ax = plt.subplots(figsize=(14, 4), constrained_layout=True)
        im = ax.imshow(abs_diff.T, origin='lower')
        ax.set_title(f"Abs diff heatmap: {Path(sub_a).name} vs {Path(sub_b).name} | session {SESSION_ID}")
        ax.set_xlabel('time_bin')
        ax.set_ylabel('channel')
        for b in _trial_boundaries(s['trial_ids']):
            ax.axvline(b, color='k', lw=0.2, alpha=0.25)
        plt.colorbar(im, ax=ax, fraction=0.03, pad=0.02)
        plt.show()

## 7) Diagnostics / stats

In [None]:
# Scalar diagnostics printed for every analysed submission, in this order.
_REPORT_KEYS = (
    'fraction_masked',
    'fraction_observed',
    'fraction_pred_finite',
    'fraction_masked_pred_finite',
    'pred_masked_mean',
    'pred_masked_std',
    'mse_observed',
    'nmse_observed',
    'corr_observed',
)

for sp, obj in analysis_by_submission.items():
    print('\n' + '=' * 80)
    print(f"Submission: {sp}")
    s = obj['stats']
    for _key in _REPORT_KEYS:
        print(f"{_key}={s[_key]:.6f}")

    # Ten channels with the largest observed-entry MSE.
    per_ch = s['per_channel']
    display(per_ch.sort_values('mse', ascending=False).head(10))


# Optional widgets for quick session/channel exploration
if not HAS_WIDGETS:
    print('ipywidgets not available. Edit config variables and re-run cells.')
else:
    # Offer every metadata test session plus the configured SESSION_ID.
    sid_options = sorted({SESSION_ID, *test_ids})
    sid_dd = widgets.Dropdown(options=sid_options, value=SESSION_ID, description='Session')
    ch_slider = widgets.IntSlider(value=CHANNEL, min=0, max=95, step=1, description='Channel')
    sub_dd = widgets.Dropdown(options=SUBMISSION_PATHS, value=SUBMISSION_PATHS[0], description='Submission')

    def _update(session_id, channel, submission):
        """Redraw the heatmaps and one channel trace for the selected widget values."""
        sess = load_session_arrays(DATA_DIR, metadata, session_id, infer_mask_from_zero=INFER_MASK_FROM_ZERO)
        available = submission_mats[submission]
        if session_id not in available:
            print(f"Session {session_id} not in {submission}")
            return
        aligned = align_prediction_to_session(available[session_id], sess['sbp'].shape[0])
        filled = fill_reconstruction(sess['sbp'], sess['mask'], aligned)
        label = f"[{Path(submission).name}] "
        plot_heatmaps(sess, aligned, filled, title_prefix=label)
        plot_channel_series(sess, aligned, channel=channel, start=START, end=END, smooth_window=25, title_prefix=label)

    controls = {'session_id': sid_dd, 'channel': ch_slider, 'submission': sub_dd}
    out = widgets.interactive_output(_update, controls)
    display(widgets.HBox([sid_dd, ch_slider, sub_dd]), out)