# Day 26 – Smoothing ↔︎ MDE Tuning

Explore how different smoothing controls impact category time-series and downstream ROI→category MDE performance for a single subject/story. The notebook reuses the Day19 category builder and Day22 MDE runner, plots smoothed categories, and captures MDE metrics while saving outputs under `figs/<subject>/<story>/day26_smoothing/`.

In [1]:
import json
import math
import warnings
from collections import Counter
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple


import nbformat
import numpy as np
import pandas as pd

try:
    import matplotlib.pyplot as plt
except Exception as exc:
    plt = None
    raise RuntimeError(f"Matplotlib is required for Day26 plots: {exc}")

from IPython.display import display

import sys

PROJECT_ROOT = Path('/flash/PaoU/seann/fmri-edm-ccm')

# Make the project package and the local pyEDM / MDE source trees importable.
# Entries are appended only when missing, so re-running this cell does not keep
# growing sys.path with duplicate entries.
for _extra_path in (
    str(PROJECT_ROOT),
    '/flash/PaoU/seann/pyEDM/src',
    '/flash/PaoU/seann/MDE-main/src',
):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)
del _extra_path  # do not leave the loop variable behind in the notebook namespace



from src.utils import load_yaml
from src.day22_category_mde import run_mde_for_pair, sanitize_name
from src.decoding import load_transcript_words
from src.edm_ccm import English1000Loader


KeyboardInterrupt: 

In [None]:
# Notebook-wide configuration: load the demo YAML and derive the constants that
# the later cells read (subject/story, category defaults, MDE grid, output dirs).
PROJECT_ROOT = Path('/flash/PaoU/seann/fmri-edm-ccm')
CONFIG_PATH = PROJECT_ROOT / 'configs' / 'demo.yaml'

cfg = load_yaml(CONFIG_PATH)

# Path settings: default the project root, then anchor any relative
# cache/figs/results entry at PROJECT_ROOT. Absolute or missing entries are left alone.
paths: Dict[str, str] = cfg.get('paths') or {}
if 'project_root' not in paths:
    paths['project_root'] = str(PROJECT_ROOT)
for key in ('cache', 'figs', 'results'):
    val = paths.get(key)
    if not val or Path(val).is_absolute():
        continue
    paths[key] = str((PROJECT_ROOT / val).resolve())

# Subject / story identifiers fall back to fixed defaults when unset or empty.
SUBJECT = (cfg.get('subject') or 'UTS01').strip()
STORY = (cfg.get('story') or 'wheretheressmoke').strip()
TR = float(cfg.get('TR', 2.0))

# Category-builder defaults taken from the `categories` section of the config.
categories_cfg = cfg.get('categories') or {}
cluster_csv_rel = categories_cfg.get('cluster_csv_path', '')
if cluster_csv_rel:
    cluster_csv_path = str((PROJECT_ROOT / cluster_csv_rel).resolve())
else:
    cluster_csv_path = ''
prototype_weight_power = float(categories_cfg.get('prototype_weight_power', 1.0))
seconds_bin_width_default = float(categories_cfg.get('seconds_bin_width', 0.05))
temporal_weighting_default = str(
    categories_cfg.get('temporal_weighting', 'proportional')
).lower()

# MDE search settings. A scalar tau is wrapped into a one-element list, and an
# 'E_mult' entry takes precedence over 'E_cap' when both are present.
TAU_GRID = cfg.get('tau_grid') or [1, 2]
if isinstance(TAU_GRID, (int, float)):
    TAU_GRID = [int(TAU_GRID)]
E_CAP = int(cfg.get('E_mult', cfg.get('E_cap', 6)))

# Output locations for this notebook; both are created eagerly.
FIGS_BASE = PROJECT_ROOT / 'figs' / SUBJECT / STORY / 'day26_smoothing'
FIGS_BASE.mkdir(parents=True, exist_ok=True)

FEATURES_EVAL_BASE = PROJECT_ROOT / 'features_day26_eval'
FEATURES_EVAL_BASE.mkdir(parents=True, exist_ok=True)

print(f'Subject/story: {SUBJECT} / {STORY}')
print(f'Default tau grid: {TAU_GRID} | embedding cap: {E_CAP}')
print(f'Cluster CSV: {cluster_csv_path or "<none>"}')


In [None]:
# Reuse helper functions from Day19
# The Day19 notebook is parsed with nbformat and the source text of two of its
# cells is executed in this notebook's global namespace, so every name those
# cells define becomes available to the cells below.
DAY19_NOTEBOOK = PROJECT_ROOT / 'notebooks' / 'Day19_smooth_categories.ipynb'
day19_nb = nbformat.read(DAY19_NOTEBOOK, as_version=4)
# NOTE(review): the cells are selected by hard-coded position (indices 3 and 4).
# Inserting, removing or reordering cells in the Day19 notebook changes which
# code is loaded here without any error being raised.
helper_code = day19_nb.cells[3].source
generator_code = day19_nb.cells[4].source

# Provide globals expected by the Day19 utilities
EPS = 1e-12

# The helper cell runs before the generator cell (presumably the generator
# relies on the helpers -- confirm against the Day19 notebook).
# NOTE(review): exec() runs the notebook source verbatim; only use with a trusted notebook.
exec(helper_code, globals())
exec(generator_code, globals())

print('Loaded Day19 helper + generator functions.')


In [None]:
# Build smoothing configuration grid with biologically motivated band
PLAUSIBLE_BAND = (0.5, 2.0)  # seconds, roughly matching expected semantic/BOLD dynamics
PLAUDIBLE_BAND = PLAUSIBLE_BAND  # legacy misspelled name, kept so existing references keep working
# Window lengths from 0.0 s to 1.25 s in 0.25 s steps; the tiny epsilon makes
# np.arange include the end point, and rounding strips float noise.
smooth_seconds_grid = np.round(np.arange(0.0, 1.25 + 1e-9, 0.25), 2)
# Filled in below with one dict per smoothing setting to evaluate.
SMOOTHING_CONFIGS: List[Dict[str, object]] = []

def _format_seconds_tag(seconds: float) -> str:
    """Render *seconds* with two decimals, using 'p' in place of the decimal point.

    Example: ``0.25`` becomes ``'0p25'``, which is safe to embed in config names.
    """
    fixed_point = format(seconds, '.2f')
    return 'p'.join(fixed_point.split('.'))

# Flag each window by whether it lies inside the plausible band (bounds inclusive).
# Every configuration used to be hard-coded as within-band, so the flag (and the
# `within_plausible_band` summary column derived from it) carried no information.
_band_low, _band_high = PLAUDIBLE_BAND  # (sic) spelled as where the band is defined


def _is_within_band(seconds: float) -> bool:
    """Return True when *seconds* lies inside the plausible band, bounds included."""
    return _band_low <= seconds <= _band_high


# Explicit "no smoothing" baseline at 0.0s
SMOOTHING_CONFIGS.append({
    'name': 'none_0p00',
    'smoothing_seconds': 0.0,
    'method': 'none',
    'gaussian_sigma_seconds': None,
    'pad_mode': 'edge',
    'seconds_bin_width': seconds_bin_width_default,
    'temporal_weighting': temporal_weighting_default,
    'within_band': _is_within_band(0.0),
})

# Add Gaussian + Moving Average for the remaining windows
for seconds in smooth_seconds_grid:
    seconds = float(seconds)  # plain Python float instead of a numpy scalar
    if seconds == 0.0:
        # The zero-width window is already covered by the baseline above.
        continue
    tag = _format_seconds_tag(seconds)
    within_band = _is_within_band(seconds)

    # Gaussian smoothing configuration (sigma is half the window length)
    SMOOTHING_CONFIGS.append({
        'name': f'gauss_{tag}',
        'smoothing_seconds': seconds,
        'method': 'gaussian',
        'gaussian_sigma_seconds': seconds * 0.5,
        'pad_mode': 'reflect',
        'seconds_bin_width': seconds_bin_width_default,
        'temporal_weighting': temporal_weighting_default,
        'within_band': within_band,
    })

    # Moving-average smoothing configuration
    SMOOTHING_CONFIGS.append({
        'name': f'movavg_{tag}',
        'smoothing_seconds': seconds,
        'method': 'moving_average',
        'gaussian_sigma_seconds': None,
        'pad_mode': 'edge',
        'seconds_bin_width': seconds_bin_width_default,
        'temporal_weighting': temporal_weighting_default,
        'within_band': within_band,
    })

print(f"Evaluating {len(SMOOTHING_CONFIGS)} settings across windows {smooth_seconds_grid.tolist()} (none + gaussian + movavg).")


In [None]:
TARGET_COLUMN = 'cat_abstract'  # adjust if you want to optimize a different category

# One summary row (dict) per smoothing configuration, collected by the loop below.
results: List[Dict] = []

# The cluster CSV does not depend on the smoothing configuration, so it is
# resolved and validated once here instead of on every loop iteration.
cluster_csv_use = cluster_csv_path
if cluster_csv_use and not Path(cluster_csv_use).is_absolute():
    cluster_csv_use = str((PROJECT_ROOT / cluster_csv_use).resolve())
if cluster_csv_use and not Path(cluster_csv_use).exists():
    warnings.warn(f'Cluster CSV not found at {cluster_csv_use}; proceeding without clusters.')
    cluster_csv_use = ''

# For every smoothing configuration: build the category time series, plot an
# overview, write the series to CSV, run MDE on the target column, and record
# summary statistics of both the target series and the MDE selection table.
for config in SMOOTHING_CONFIGS:
    cfg_name = str(config['name']).strip() or 'config'
    safe_name = sanitize_name(cfg_name)
    smoothing_seconds = float(config['smoothing_seconds'])
    smoothing_method = str(config['method'])
    gaussian_sigma = config.get('gaussian_sigma_seconds')
    pad_mode = config.get('pad_mode', 'edge')
    seconds_bin_width = float(config.get('seconds_bin_width', seconds_bin_width_default))
    temporal_weighting = str(config.get('temporal_weighting', temporal_weighting_default))
    within_band = bool(config.get('within_band', False))

    print(f"=== Smoothing config: {cfg_name} | method={smoothing_method} | window={smoothing_seconds:.2f}s ===")

    features_root = FEATURES_EVAL_BASE / safe_name
    # Published explicitly as a global; presumably read by the Day19 code that was
    # exec'd into this namespace -- confirm before removing.
    globals()['features_root'] = features_root
    features_root.mkdir(parents=True, exist_ok=True)

    result = generate_category_time_series(
        SUBJECT,
        STORY,
        cfg_base=cfg,
        categories_cfg_base=categories_cfg,
        cluster_csv_path=cluster_csv_use or '',
        temporal_weighting=temporal_weighting,
        prototype_weight_power=prototype_weight_power,
        smoothing_seconds=smoothing_seconds,
        smoothing_method=smoothing_method,
        gaussian_sigma_seconds=gaussian_sigma,
        smoothing_pad=pad_mode,
        seconds_bin_width=seconds_bin_width,
        save_outputs=False,
    )

    category_df = result['category_df_selected']
    category_cols = result['category_columns']
    if not category_cols:
        raise RuntimeError('No category columns generated. Check configuration.')
    if TARGET_COLUMN not in category_cols:
        raise RuntimeError(f"Target column {TARGET_COLUMN} not present in category dataframe.")

    # Simple variability descriptors of the (smoothed) target series.
    target_series = category_df[TARGET_COLUMN].astype(float)
    target_std = float(target_series.std(ddof=1))
    target_range = float(target_series.max() - target_series.min())
    target_diff = target_series.diff().dropna()
    target_diff_mean = float(target_diff.abs().mean()) if not target_diff.empty else 0.0

    plot_dir = FIGS_BASE / safe_name
    plot_dir.mkdir(parents=True, exist_ok=True)
    plot_path = plot_dir / 'category_timeseries_overview.png'

    # Overview figure: up to the first 12 category columns on a 3x4 grid.
    top_cols = category_cols[:12]
    fig, axes = plt.subplots(3, 4, figsize=(14, 8), sharex=True)
    try:
        axes = axes.flatten()
        time_axis = category_df['start_sec']
        for ax_idx, col in enumerate(top_cols):
            ax = axes[ax_idx]
            ax.plot(time_axis, category_df[col], linewidth=1.0)
            ax.set_title(col)
            ax.grid(alpha=0.3)
        # Hide the panels that have no category to show.
        for ax_idx in range(len(top_cols), len(axes)):
            axes[ax_idx].axis('off')
        fig.suptitle(f'{SUBJECT} / {STORY} – {cfg_name} smoothing')
        fig.tight_layout(rect=(0, 0, 1, 0.96))
        fig.savefig(plot_path, dpi=180)
    finally:
        # Always release the figure, even when plotting or saving fails, so a
        # failing configuration does not leave open figures behind.
        plt.close(fig)
    print(f'Saved category plot to {plot_path}')

    # Write the category table under <eval_root>/subjects/<subject>/<story>/,
    # the features root that is handed to run_mde_for_pair below.
    eval_root = FEATURES_EVAL_BASE / safe_name
    category_dir = eval_root / 'subjects' / SUBJECT / STORY
    category_dir.mkdir(parents=True, exist_ok=True)
    cat_csv = category_dir / 'category_timeseries.csv'
    category_df.to_csv(cat_csv, index=False)

    figs_root_config = FIGS_BASE / safe_name

    summary = run_mde_for_pair(
        SUBJECT,
        STORY,
        target_column=TARGET_COLUMN,
        features_root=eval_root,
        figs_root=figs_root_config,
        paths_cfg=paths,
        n_parcels=int(cfg.get('n_parcels', 400)),
        tau_grid=TAU_GRID,
        E_cap=E_CAP,
        lib_sizes=cfg.get('lib_sizes', [80, 120, 160]),
        delta_default=int((cfg.get('delta') or [1])[0]),
        theiler_min=int(cfg.get('theiler_min', 3)),
        train_frac=0.5,
        val_frac=0.25,
        top_n_plot=6,
        save_input_frame=True,
        save_scatter=True,
        plt_module=plt,
    )

    # Summarise the MDE selection table. The first row is reported as the "top"
    # entry and the first five rows as "top5".
    # NOTE(review): this assumes the CSV is already ordered best-first -- confirm.
    selection_path = Path(summary['selection_csv'])
    mde_df = pd.read_csv(selection_path)
    rho_col = next((c for c in mde_df.columns if c.lower().startswith('rho')), None)
    best_rho = float(mde_df[rho_col].iloc[0]) if rho_col and not mde_df.empty else float('nan')
    rho_top5 = mde_df[rho_col].head(5) if rho_col else pd.Series(dtype=float)
    rho_mean_top5 = float(rho_top5.mean()) if not rho_top5.empty else float('nan')
    rho_median_top5 = float(rho_top5.median()) if not rho_top5.empty else float('nan')
    # A sample standard deviation needs at least two values.
    rho_std_top5 = float(rho_top5.std(ddof=1)) if len(rho_top5) >= 2 else float('nan')
    positive_rho_top5 = int((rho_top5 > 0).sum()) if not rho_top5.empty else 0

    # The variable column may be named 'variables' or 'variable'. When neither
    # exists, fall back to empty values instead of raising KeyError (the top-5
    # lookup already had this guard; the best-variable lookup did not).
    best_var_col = 'variables' if 'variables' in mde_df.columns else 'variable'
    has_var_col = best_var_col in mde_df.columns
    best_variable = str(mde_df[best_var_col].iloc[0]) if has_var_col and not mde_df.empty else ''
    top5_variables = mde_df[best_var_col].head(5).astype(str) if has_var_col else pd.Series(dtype=str)
    unique_top5 = int(top5_variables.nunique()) if not top5_variables.empty else 0

    results.append({
        'config': cfg_name,
        'safe_name': safe_name,
        'method': smoothing_method,
        'smoothing_seconds': smoothing_seconds,
        'within_plausible_band': within_band,
        'gaussian_sigma_seconds': gaussian_sigma,
        'pad_mode': pad_mode,
        'seconds_bin_width': seconds_bin_width,
        'temporal_weighting': temporal_weighting,
        'top_variable': best_variable,
        'top_rho': best_rho,
        'rho_mean_top5': rho_mean_top5,
        'rho_median_top5': rho_median_top5,
        'rho_std_top5': rho_std_top5,
        'positive_rho_top5': positive_rho_top5,
        'unique_top5_variables': unique_top5,
        'target_std': target_std,
        'target_range': target_range,
        'target_diff_abs_mean': target_diff_mean,
        'selection_csv': str(selection_path),
        'plot_dir': str(plot_dir),
        'mde_dir': str(figs_root_config / 'day22_category_mde'),
    })

# Collect the per-configuration rows into one table, show it and save it as CSV.
results_df = pd.DataFrame(results)
# An empty frame has no 'method' / 'smoothing_seconds' columns, so sorting it by
# name raises KeyError; that crash used to make the "No results captured"
# branch below unreachable. Only sort when there is something to sort.
if not results_df.empty:
    results_df = results_df.sort_values(by=['method', 'smoothing_seconds']).reset_index(drop=True)
display(results_df)
summary_path = FIGS_BASE / 'day26_mde_smoothing_summary.csv'
results_df.to_csv(summary_path, index=False)
print(f'Summary saved to {summary_path}')

# Build method × smoothing matrices for multiple metrics
if not results_df.empty:
    # Summary column -> human-readable label used in the log message.
    metrics_to_pivot = {
        'top_rho': 'Top rho',
        'rho_mean_top5': 'Mean rho (top5)',
        'rho_median_top5': 'Median rho (top5)',
        'target_std': 'Target std',
        'target_range': 'Target range',
        'target_diff_abs_mean': 'Mean |Δtarget|',
    }
    pivot_dir = FIGS_BASE / 'matrices'
    pivot_dir.mkdir(parents=True, exist_ok=True)

    for metric_key, metric_label in metrics_to_pivot.items():
        # Rows are smoothing windows, columns are smoothing methods.
        # DataFrame.pivot requires each (window, method) pair to occur at most once.
        matrix = results_df.pivot(index='smoothing_seconds', columns='method', values=metric_key)
        display(matrix)
        matrix_path = pivot_dir / f'day26_{metric_key}_matrix.csv'
        matrix.to_csv(matrix_path)
        print(f"{metric_label} matrix saved to {matrix_path}")
else:
    print('No results captured; skipping matrix export.')
