# Day 22 – ROI to Category MDE

Run Multiview Distance Embedding (MDE) to discover lagged ROI predictors for each Day 20 category time series.
Each subject/story pair is analysed independently and results (tables + plots) are written per story.


In [None]:
import json
import math
import os
import sys
import warnings
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from collections import defaultdict

import numpy as np
import pandas as pd

# Anchor the notebook at the project checkout so relative config/output paths resolve there.
project_root = Path('/flash/PaoU/seann/fmri-edm-ccm')
project_root.mkdir(parents=True, exist_ok=True)
os.chdir(project_root)

# Make the project package and the local pyEDM / MDE source trees importable.
sys.path.extend([
    str(project_root),
    '/flash/PaoU/seann/pyEDM/src',
    '/flash/PaoU/seann/MDE-main/src',
])

try:
    from IPython.display import HTML, IFrame, display
except Exception:
    # Outside IPython/Jupyter, fall back to minimal plain-Python stand-ins.
    def display(obj):
        """Print *obj* (fallback for ``IPython.display.display``)."""
        print(obj)

    def HTML(data):
        """Return *data* unchanged (fallback for ``IPython.display.HTML``)."""
        return data

    def IFrame(*_args, **_kwargs):
        """Return ``None`` (fallback for ``IPython.display.IFrame``)."""
        return None

try:
    import matplotlib.pyplot as plt
except Exception as mpl_exc:
    # Downstream code checks ``plt is None`` and skips figure generation.
    plt = None
    warnings.warn(f'Matplotlib unavailable: {mpl_exc}')

from src.utils import load_yaml
from src.day22_category_mde import run_mde_for_pair, save_combined_roi_view
from src.day24_subject_concat import (
    DEFAULT_FEATURES_ROOT as DAY24_DEFAULT_FEATURES_ROOT,
    DEFAULT_OUTPUT_SUBDIR as DAY24_OUTPUT_SUBDIR,
    load_subject_boundaries,
    load_subject_concat_manifest,
)
from src.day25_bleed_correction import DEFAULT_DAY25_OUTPUT_SUBDIR
from MDE import MDE

# Seed NumPy's global RNG for reproducibility and widen pandas' column display.
np.random.seed(42)
pd.set_option('display.max_columns', 120)

# Load the demo configuration; ``paths`` becomes an empty mapping when absent or null.
cfg = load_yaml('configs/demo.yaml')
paths: Dict[str, str] = cfg.get('paths') or {}


In [None]:
def _as_int_list(value: Any) -> List[int]:
    """Coerce a config value that may be a scalar or a sequence into a list of ints.

    ``None`` yields an empty list. Strings/bytes are treated as scalars, so ``"80"``
    becomes ``[80]`` rather than being iterated character by character.
    """
    if value is None:
        return []
    if isinstance(value, Sequence) and not isinstance(value, (str, bytes)):
        return [int(v) for v in value]
    return [int(value)]


# ``str()`` guards against non-string YAML scalars (e.g. a numeric story id).
SUBJECT_DEFAULT = str(cfg.get('subject') or 'UTS01').strip()
STORY_DEFAULT = str(cfg.get('story') or 'wheretheressmoke').strip()
TR = float(cfg.get('TR', 2.0))

# ``delta`` may be a scalar or a list; its first entry is passed on as ``delta_default``.
delta_config = cfg.get('delta', [1])
delta_values = _as_int_list(delta_config)
delta_default = delta_values[0] if delta_values else 1

# ``lib_sizes`` gets the same scalar-or-list handling as ``delta``; empty/missing falls back.
lib_sizes_cfg = cfg.get('lib_sizes', [80, 120, 160])
lib_sizes_primary = sorted(_as_int_list(lib_sizes_cfg)) if lib_sizes_cfg else [80, 120, 160]
theiler_min = max(int(cfg.get('theiler_min', 3)), 1)
n_parcels = int(cfg.get('n_parcels', 400))

# Batch controls (leave empty to use defaults from the config)
BATCH_SUBJECTS: Sequence[str] = []
BATCH_STORIES: Sequence[str] = []

# Analysis parameters
TARGET_COLUMNS: Sequence[str] = ['cat_abstract']  # extend for additional categories
TAU_GRID = [1, 2]
E_CAP = 6
SPLIT_TRAIN_FRAC = 0.5
SPLIT_VAL_FRAC = 0.25
TOP_N_PLOT = 6
SAVE_INPUT_FRAME = True
SAVE_SCATTER = True

print(f'Default subject/story: {SUBJECT_DEFAULT} / {STORY_DEFAULT}')
print(f'Target categories: {TARGET_COLUMNS}')
print(f'Embedding cap (E): {E_CAP} | tau grid: {TAU_GRID}')


In [None]:
# Module-level output/input roots; both are assigned further down in this cell.
# ``run_mde_for_pair`` is already imported in the setup cell, so it is not re-imported here.
FIGS_ROOT: Optional[Path] = None
FEATURES_ROOT: Optional[Path] = None

def resolve_path(value: Optional[Union[str, Path]], default: Union[str, Path]) -> Path:
    """Resolve a configured path, falling back to *default* when *value* is empty.

    A leading ``~`` is expanded to the user's home directory. Any path that is
    still relative afterwards is anchored at ``project_root`` (not the current
    working directory).

    Args:
        value: Configured path; ``None`` or ``''`` selects *default*.
        default: Path used when *value* is empty.

    Returns:
        An absolute ``Path``.
    """
    # Without expanduser, '~/figs' would be treated as relative and end up
    # under project_root as a literal '~' directory.
    candidate = Path(value if value else default).expanduser()
    if not candidate.is_absolute():
        candidate = project_root / candidate
    return candidate

# Features are read from the ``features_no_fallback`` directory under the project root
# (this location is not taken from ``paths``); fail fast if it is missing.
FEATURES_ROOT = resolve_path(None, 'features_no_fallback')
if not FEATURES_ROOT.exists():
    raise FileNotFoundError(f"Expected features directory at {FEATURES_ROOT}")

# Figures go to the configured ``paths.figs`` (default ``figs``), created on demand.
FIGS_ROOT = resolve_path(paths.get('figs'), 'figs')
FIGS_ROOT.mkdir(parents=True, exist_ok=True)

print(f'Features root: {FEATURES_ROOT}', f'Figs root: {FIGS_ROOT}', sep='\n')


In [None]:
def _zscore_vector(arr: np.ndarray) -> np.ndarray:
    """Return a NaN-aware z-score of *arr* as a float array.

    NaNs are ignored when computing the mean and (population) standard deviation
    and propagate to the output. If the spread is non-finite or effectively zero
    (<= 1e-12), the values are only mean-centred. Empty input is returned as-is.
    """
    values = np.asarray(arr, dtype=float)
    if not values.size:
        return values
    centered = values - np.nanmean(values)
    spread = np.nanstd(values)
    # Avoid dividing by ~0 (or NaN) for constant or all-NaN series.
    degenerate = (not np.isfinite(spread)) or spread <= 1e-12
    return centered if degenerate else centered / spread


def _selection_columns(selection_df: pd.DataFrame) -> Tuple[str, Optional[str]]:
    """Return the ``(variable_column, rho_column)`` names of an MDE selection table.

    The variable column is ``'variables'`` if present, else ``'variable'`` (which may
    itself be absent -- callers must check). The rho column is the first of
    ``'rho'``, ``'rho_val'``, ``'rho_value'`` that is present, or ``None``.
    """
    var_col = 'variables' if 'variables' in selection_df.columns else 'variable'
    rho_col = next((c for c in ('rho', 'rho_val', 'rho_value') if c in selection_df.columns), None)
    return var_col, rho_col


def _load_comparison_entries(summaries: Sequence[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Load each summary's input/selection CSVs and keep those usable for plotting.

    A summary is skipped when:
    - either CSV path is missing or does not exist;
    - the selection table is empty or has no variable column;
    - the input frame has no ``target`` column;
    - the top selected variable is not a column of the input frame.
    """
    valid: List[Dict[str, Any]] = []
    for summary in summaries:
        input_path = summary.get('input_csv')
        selection_path = summary.get('selection_csv')
        if not input_path or not selection_path:
            continue
        input_path = Path(input_path)
        selection_path = Path(selection_path)
        if not input_path.exists() or not selection_path.exists():
            continue
        input_df = pd.read_csv(input_path)
        selection_df = pd.read_csv(selection_path)
        var_col, rho_col = _selection_columns(selection_df)
        if selection_df.empty or var_col not in selection_df.columns:
            continue
        # Every comparison plot overlays the target trace, so skip frames without it
        # instead of failing later for all categories.
        if 'target' not in input_df.columns:
            continue
        # NOTE(review): the first row is treated as the top predictor; presumably the
        # selection CSV is written best-first -- confirm in run_mde_for_pair.
        top_entry = selection_df.iloc[0]
        top_var = str(top_entry[var_col])
        if top_var not in input_df.columns:
            continue
        valid.append({
            'category': summary.get('target_column', 'unknown'),
            'input_df': input_df,
            'selection_df': selection_df,
            'top_var': top_var,
            'rho': float(top_entry[rho_col]) if rho_col else float('nan'),
        })
    return valid


def _save_comparison_figures(records: Sequence[Dict[str, Any]], time_axis: np.ndarray, comp_dir: Path) -> None:
    """Save and show the per-category grid plus the target and predictor overlays.

    All traces are z-scored with ``_zscore_vector``. Each figure is closed even if
    saving fails.
    """
    rows = len(records)
    # squeeze=False always yields a 2-D Axes array, so one row needs no special case.
    fig, axes = plt.subplots(rows, 1, figsize=(12, 2.6 * rows + 1.5), sharex=True, squeeze=False)
    try:
        column = axes[:, 0]
        for ax, rec in zip(column, records):
            ax.plot(time_axis, _zscore_vector(rec['target']), color='black', linewidth=1.8, label='target')
            ax.plot(time_axis, _zscore_vector(rec['predictor']), color=rec['color'], linewidth=1.4, label=f"{rec['category']} -> {rec['top_var']}")
            ax.set_ylabel('z')
            ax.legend(loc='upper right', frameon=False, fontsize=9)
            ax.grid(alpha=0.25)
        column[-1].set_xlabel('Trimmed index')
        fig.tight_layout()
        fig.savefig(comp_dir / 'category_time_series_grid.png', dpi=200)
        plt.show()
    finally:
        plt.close(fig)

    fig_target, ax_target = plt.subplots(figsize=(12, 4.2))
    try:
        for rec in records:
            ax_target.plot(time_axis, _zscore_vector(rec['target']), linewidth=1.2, label=f"{rec['category']} target")
        ax_target.set_title('Category targets (z-scored)')
        ax_target.set_xlabel('Trimmed index')
        ax_target.set_ylabel('z')
        ax_target.grid(alpha=0.3)
        ax_target.legend(loc='upper right', frameon=False, fontsize=9)
        fig_target.tight_layout()
        fig_target.savefig(comp_dir / 'category_target_overlay.png', dpi=200)
        plt.show()
    finally:
        plt.close(fig_target)

    fig_pred, ax_pred = plt.subplots(figsize=(12, 4.2))
    try:
        for rec in records:
            ax_pred.plot(time_axis, _zscore_vector(rec['predictor']), color=rec['color'], linewidth=1.4, label=f"{rec['category']} ({rec['top_var']})")
        ax_pred.set_title('MDE top predictors across categories (z-scored)')
        ax_pred.set_xlabel('Trimmed index')
        ax_pred.set_ylabel('z')
        ax_pred.grid(alpha=0.3)
        ax_pred.legend(loc='upper right', frameon=False, fontsize=9)
        fig_pred.tight_layout()
        fig_pred.savefig(comp_dir / 'category_predictor_overlay.png', dpi=200)
        plt.show()
    finally:
        plt.close(fig_pred)


def _save_top_variable_tables(records: Sequence[Dict[str, Any]], comp_dir: Path) -> None:
    """Write and display the top-1 variable per category and the top-5 per category."""
    summary_table = pd.DataFrame({
        'category': [rec['category'] for rec in records],
        'top_variable': [rec['top_var'] for rec in records],
        'rho': [rec['rho'] for rec in records],
    })
    summary_table.to_csv(comp_dir / 'category_top_variables.csv', index=False)
    display(summary_table)
    top_rows: List[Dict[str, Any]] = []
    for rec in records:
        sel_df = rec['selection_df']
        var_col, rho_col = _selection_columns(sel_df)
        for _, row in sel_df.head(5).iterrows():
            top_rows.append({
                'category': rec['category'],
                'variable': row[var_col],
                'rho': float(row[rho_col]) if rho_col else float('nan'),
            })
    if top_rows:
        top_df = pd.DataFrame(top_rows)
        top_df.to_csv(comp_dir / 'category_top5_variables.csv', index=False)
        display(top_df)


def _collect_day25_corrections(day25_dir: Path) -> List[Dict[str, Any]]:
    """Summarise every ``*/manifest.json`` found under *day25_dir*.

    Unreadable or malformed manifests are skipped with a warning rather than
    aborting the whole comparison.
    """
    entries: List[Dict[str, Any]] = []
    if not day25_dir.exists():
        return entries
    for manifest_file in sorted(day25_dir.glob('*/manifest.json')):
        try:
            manifest = json.loads(manifest_file.read_text())
        except (OSError, json.JSONDecodeError) as exc:
            warnings.warn(f'Skipping unreadable Day25 manifest {manifest_file}: {exc}')
            continue
        # A null ``bleed_correction`` is treated like a missing one; the strategy then
        # falls back to the manifest's parent directory name.
        strategy = (manifest.get('bleed_correction') or {}).get('strategy') or manifest_file.parent.name
        entries.append({
            'strategy': strategy,
            'manifest': str(manifest_file),
            'categories': manifest.get('categories_path'),
            'roi': manifest.get('roi_path'),
        })
    return entries


def _write_concat_context(subject: str, story: str, comp_dir: Path) -> None:
    """Write ``story_concat_context.json`` linking this story to Day24/Day25 outputs."""
    features_root = (FEATURES_ROOT or (project_root / DAY24_DEFAULT_FEATURES_ROOT)).resolve()
    subject_dir = features_root / 'subjects' / subject
    concat_dir = subject_dir / DAY24_OUTPUT_SUBDIR
    concat_manifest = load_subject_concat_manifest(
        subject,
        features_root=features_root,
        output_subdir=DAY24_OUTPUT_SUBDIR,
    )
    boundaries_df = load_subject_boundaries(
        subject,
        features_root=features_root,
        output_subdir=DAY24_OUTPUT_SUBDIR,
    )
    has_boundaries = bool(boundaries_df is not None and not boundaries_df.empty)
    context_payload: Dict[str, Any] = {
        'subject': subject,
        'story': story,
        'concat_dir': str(concat_dir),
        'has_manifest': concat_manifest is not None,
        'has_boundaries': has_boundaries,
    }
    if concat_manifest:
        context_payload['combined_paths'] = {
            'categories': concat_manifest.get('categories_path'),
            'roi': concat_manifest.get('roi_path'),
            'boundaries': concat_manifest.get('boundaries_path'),
            'manifest': str(concat_dir / 'manifest.json'),
        }
        context_payload['lag_reset_indices'] = concat_manifest.get('lag_reset_indices', [])
    day25_entries = _collect_day25_corrections(subject_dir / DEFAULT_DAY25_OUTPUT_SUBDIR)
    if day25_entries:
        context_payload['day25_corrections'] = day25_entries
    # Guard on the column too: a boundaries table without 'story' would otherwise raise KeyError.
    if has_boundaries and 'story' in boundaries_df.columns:
        match = boundaries_df[boundaries_df['story'] == story]
        if not match.empty:
            row = match.iloc[0]
            context_payload['story_boundary'] = {
                'start_index': int(row['start_index']),
                'end_index': int(row['end_index']),
                'length': int(row['length']),
                'lag_reset_after': int(row['end_index'] + 1),
            }
    context_path = comp_dir / 'story_concat_context.json'
    context_path.write_text(json.dumps(context_payload, indent=2))
    if concat_manifest:
        print(f'Day24 subject concat context saved to {context_path}')
    else:
        print('Day24 subject concat manifest not found; wrote context stub for reference.')


def create_category_comparisons(subject: str, story: str, summaries: Sequence[Dict[str, Any]], figs_root: Path, paths_cfg: Dict[str, str], n_parcels: int) -> None:
    """Compare the top MDE predictor of each category for one subject/story.

    Outputs go under ``<figs_root>/<subject>/<story>/day22_category_mde/comparisons``:
    - a per-category time-series grid, a target overlay and a predictor overlay
      (all z-scored);
    - ``category_top_variables.csv`` and ``category_top5_variables.csv``;
    - ``story_concat_context.json``.

    The combined ROI view is saved via ``save_combined_roi_view``. Nothing is
    written when Matplotlib is unavailable or no summary has usable
    input/selection CSVs.

    Args:
        subject: Subject identifier, used in output paths and Day24/Day25 lookups.
        story: Story identifier.
        summaries: Per-category results from ``run_mde_for_pair``. The keys read
            here are ``input_csv``, ``selection_csv`` and ``target_column``.
        figs_root: Root directory for figure outputs.
        paths_cfg: ``paths`` config section, forwarded to ``save_combined_roi_view``.
        n_parcels: Parcel count, forwarded to ``save_combined_roi_view``.
    """
    if plt is None:
        return
    valid = _load_comparison_entries(summaries)
    if not valid:
        return
    story_dir = figs_root / subject / story / 'day22_category_mde'
    comp_dir = story_dir / 'comparisons'
    comp_dir.mkdir(parents=True, exist_ok=True)

    # Truncate every series to the shortest target so all categories share one time axis.
    min_len = min(len(item['input_df']['target']) for item in valid)
    time_axis = np.arange(1, min_len + 1, dtype=int)
    # ``plt.cm.get_cmap`` was removed in Matplotlib 3.9; ``pyplot.get_cmap(name, lut)``
    # is the supported equivalent.
    color_map = plt.get_cmap('tab10', max(3, len(valid)))
    records: List[Dict[str, Any]] = [
        {
            'category': item['category'],
            'target': item['input_df']['target'].to_numpy(dtype=float)[:min_len],
            'predictor': item['input_df'][item['top_var']].to_numpy(dtype=float)[:min_len],
            'color': color_map(idx),
            'top_var': item['top_var'],
            'rho': item['rho'],
            'selection_df': item['selection_df'],
        }
        for idx, item in enumerate(valid)
    ]

    _save_comparison_figures(records, time_axis, comp_dir)
    _save_top_variable_tables(records, comp_dir)
    combined_html = save_combined_roi_view(
        output_dir=story_dir,
        subject=subject,
        story=story,
        summaries=summaries,
        paths_cfg=paths_cfg,
        n_parcels=n_parcels,
    )
    if combined_html:
        print(f'Combined ROI view saved to {combined_html}')
    _write_concat_context(subject, story, comp_dir)

In [None]:
# Group successful summaries by (subject, story) so cross-category comparisons
# can be built per story after the batch finishes.
summary_by_story: Dict[Tuple[str, str], List[Dict[str, Any]]] = defaultdict(list)
subjects = sorted({s.strip() for s in (BATCH_SUBJECTS or [SUBJECT_DEFAULT]) if s.strip()})
stories = sorted({s.strip() for s in (BATCH_STORIES or [STORY_DEFAULT]) if s.strip()})
# Deduplicate targets like subjects/stories (so no category runs twice), but keep
# the configured order.
targets = list(dict.fromkeys(col.strip() for col in TARGET_COLUMNS if col.strip()))
if not targets:
    raise ValueError('No target columns specified.')

results: List[Dict[str, Any]] = []
for sub in subjects:
    for story in stories:
        for target in targets:
            print(f'\n=== Running MDE for {sub}/{story} (target={target}) ===')
            try:
                summary = run_mde_for_pair(
                    sub,
                    story,
                    target_column=target,
                    features_root=FEATURES_ROOT,
                    figs_root=FIGS_ROOT,
                    paths_cfg=paths,
                    n_parcels=n_parcels,
                    tau_grid=TAU_GRID,
                    E_cap=E_CAP,
                    lib_sizes=lib_sizes_primary,
                    delta_default=delta_default,
                    theiler_min=theiler_min,
                    train_frac=SPLIT_TRAIN_FRAC,
                    val_frac=SPLIT_VAL_FRAC,
                    top_n_plot=TOP_N_PLOT,
                    save_input_frame=SAVE_INPUT_FRAME,
                    save_scatter=SAVE_SCATTER,
                    plt_module=plt,
                )
            except Exception as exc:
                # Best-effort batch: report the failure and continue with the next target.
                print(f'No output generated: {exc}')
                continue
            results.append(summary)
            summary_by_story[(sub, story)].append(summary)
            # ``.get`` so a summary without this key cannot abort the remaining batch.
            print(f"Top variables: {summary.get('top_variables')}")

if results:
    display(pd.DataFrame(results))
    for (sub, story), summaries in summary_by_story.items():
        try:
            create_category_comparisons(sub, story, summaries, FIGS_ROOT, paths, n_parcels)
        except Exception as exc:
            print(f'Comparison generation failed for {sub}/{story}: {exc}')
else:
    print('No successful MDE runs.')
