# Day 16 – Decoding & Evaluation

Extends the ROI forecasting pipeline with text decoding, encoding-based reranking, and evaluation controls.

In [None]:
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd

# Work from the project checkout so relative paths (e.g. 'configs/demo.yaml')
# resolve against it.
project_root = Path('/flash/PaoU/seann/fmri-edm-ccm')
os.chdir(project_root)

# Extend the import path with the project root and two local source trees
# (presumably where the pyEDM and MDE imports below are resolved - confirm).
sys.path.append(str(project_root))
sys.path.append('/flash/PaoU/seann/pyEDM/src')
sys.path.append('/flash/PaoU/seann/MDE-main/src')

from pyEDM import Simplex, SMap, CCM, ComputeError
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.preprocessing import StandardScaler
from MDE import MDE

from src.utils import load_yaml
from src import features, roi
from src.edm_ccm import English1000Loader
from src.decoding import (
    load_transcript_words,
    make_tr_windows,
    reference_text_windows,
    PCProjector,
    ROIEncoder,
    pc_encoding_score,
    roi_encoding_score,
    eval_text_list,
    identification_matrix,
)

try:
    from src.decoding import BeamDecoder
except ImportError:
    BeamDecoder = None
from IPython.display import display


In [None]:
# Load the run configuration. The `decoding` section may be missing or
# explicitly null in the YAML; both cases fall back to an empty dict so the
# `.get` calls below keep working (same convention as `weights` further down).
cfg = load_yaml('configs/demo.yaml')
dec_cfg = cfg.get('decoding') or {}
SUB, STORY = cfg['subject'], cfg['story']
paths = cfg['paths']

# Controls for the decoding-based state selection.
target_type = str(dec_cfg.get('target_type', 'pcs')).lower()
target_basis_for_selection = str(dec_cfg.get('target_basis_for_selection', 'embedding')).lower()
selection_val_windows = int(dec_cfg.get('selection_val_windows', 40))
selection_metric = str(dec_cfg.get('selection_metric', 'cca_mean')).lower()
selection_cca_components = int(dec_cfg.get('selection_cca_components', 3))
# Evaluate the decoding metric at most once every k selection steps (k >= 1).
selection_every_k_steps = max(1, int(dec_cfg.get('selection_every_k_steps', 1)))
# Optional wall-clock budget (seconds) per decoding evaluation; None disables it.
selection_timeout_s = dec_cfg.get('selection_timeout_s')
if selection_timeout_s is not None:
    selection_timeout_s = float(selection_timeout_s)
use_topk_pcs_cfg = int(dec_cfg.get('use_topk_pcs', 5))

# Relative weights of the score terms combined during selection.
weights_cfg = dec_cfg.get('weights', {}) or {}
selection_weights = {
    'embedding': float(weights_cfg.get('embedding', 1.0)),
    'categories': float(weights_cfg.get('categories', 0.0)),
    'roi': float(weights_cfg.get('roi', 0.0)),
}

print(f"Subject/Story: {SUB} {STORY}")
print('Decoding config:', dec_cfg)

# TR-aligned inputs: semantic features, two drivers, and the ROI time series.
X_full = features.load_english1000_TR(SUB, STORY, paths)
semantic_dim = X_full.shape[1]

env = pd.Series(features.load_envelope_TR(SUB, STORY, paths), name='env')
wr = pd.Series(features.load_wordrate_TR(SUB, STORY, paths), name='wr')
R = roi.load_schaefer_timeseries_TR(SUB, STORY, cfg['n_parcels'], paths)

# Every artifact written by this notebook goes under this directory.
output_root = Path(paths['figs']) / SUB / STORY / 'day16_decoding'
output_root.mkdir(parents=True, exist_ok=True)

# The word-embedding table is optional; without it `english_loader` stays None.
english1000_path = Path(paths['data_root']) / 'derivative' / 'english1000sm.hf5'
if english1000_path.exists():
    english_loader = English1000Loader(english1000_path)
    print('Loaded English1000 embeddings from', english1000_path)
else:
    english_loader = None
    print('WARNING: English1000 embeddings not found at', english1000_path)

print('Semantic matrix:', X_full.shape)
print('Drivers:', env.shape, wr.shape)
print('ROI matrix:', R.shape)


In [None]:
# Delay (tau) and horizon (delta) grids plus the default forecast horizon.
tau_grid = [1, 2]
delta_options = [1, 2]
delta_default = int(dec_cfg.get('delta_default', 1))
# Embedding dimensions tried by the univariate baseline; E_cap bounds the
# largest lag computed below.
E_univ_grid = [2, 3, 4]
E_cap = 6
# Floor for the exclusion radius handed to Simplex/MDE (never below 1).
theiler_min = max(int(cfg.get('theiler_min', 3)), 1)
lib_sizes_primary = sorted(int(v) for v in cfg['lib_sizes'])

roi_cols = [f'roi_{idx}' for idx in range(R.shape[1])]
max_tau = max(tau_grid)
# Largest lag any coordinate may use: (E_cap - 1) steps of the largest tau.
max_lag_primary = max_tau * (E_cap - 1)
if X_full.shape[0] <= max_lag_primary:
    raise ValueError('Time series too short for requested lag configuration')

N_total = X_full.shape[0]
# Number of rows left after dropping the first max_lag_primary samples.
N_trim = N_total - max_lag_primary


def make_splits(n_samples):
    """Return contiguous train/val/test row ranges over ``n_samples`` rows.

    Train takes the first half and validation the next quarter (each at
    least one row); test takes whatever remains. Ranges are ``(start, stop)``
    tuples keyed by ``'train'``, ``'val'`` and ``'test'``.
    """
    last_row = n_samples - 1
    n_train = max(1, int(np.floor(n_samples * 0.5)))
    n_val = max(1, int(np.floor(n_samples * 0.25)))

    val_stop = min(last_row, n_train + n_val)
    if val_stop <= n_train:
        # Retry with a single validation row; still capped at the last row.
        val_stop = min(last_row, n_train + 1)

    if n_samples > val_stop:
        test_stop = n_samples
    else:
        test_stop = min(n_samples, val_stop + 1)

    bounds = (0, n_train, val_stop, test_stop)
    names = ('train', 'val', 'test')
    return dict(zip(names, zip(bounds[:-1], bounds[1:])))

# Split the lag-trimmed timeline into train/val/test ranges.
splits = make_splits(N_trim)
print('Split indices:', splits)

# Fit the PCA on rows up to the end of the training split (expressed in
# untrimmed row numbers), then project the full series with it.
train_end_trim = splits['train'][1]
train_end_global = max_lag_primary + train_end_trim
Z_train, pca = features.pca_fit_transform(X_full[:train_end_global], cfg['pca_components'])
Z = pca.transform(X_full)

# Persist the PCA basis and mean.
np.save(output_root / 'pca_components.npy', pca.components_)
np.save(output_root / 'pca_mean.npy', pca.mean_)

# Untrimmed frame holding a 1-based Time column, the first semantic PC, the
# two drivers, and one column per ROI.
sem_pc1 = pd.Series(Z[:, 0], name='sem_pc1')
base = pd.DataFrame({
    'Time': np.arange(1, N_total + 1),
    'sem_pc1': sem_pc1,
    'env': env,
    'wr': wr,
})
for idx, col in enumerate(roi_cols):
    base[col] = R[:, idx]


def make_lag_dict(series, max_lag):
    """Map each lag in ``0..max_lag`` to the lagged, trimmed copy of ``series``.

    Every copy is shifted by its lag, loses its first ``max_lag`` rows, and is
    re-indexed from zero, so all copies have the same length.
    """
    lagged = {}
    for lag in range(max_lag + 1):
        shifted = series.shift(lag)
        lagged[lag] = shifted.iloc[max_lag:].reset_index(drop=True)
    return lagged

# {variable: {lag: trimmed Series}} for the target, both drivers and every ROI.
lag_store = {name: make_lag_dict(base[name], max_lag_primary)
             for name in ['sem_pc1', 'env', 'wr'] + roi_cols}

# Time axis and lag-0 target after dropping the first max_lag_primary rows.
time_trim = base['Time'].iloc[max_lag_primary:].reset_index(drop=True)
target_trim = lag_store['sem_pc1'][0]
N = len(time_trim)
print(f'Usable samples after lag trimming: {N}')

# Trimmed PC scores and ROI matrix, plus the untrimmed row number of each
# trimmed row.
sem_trim_all = Z[max_lag_primary:, :]
roi_trim_all = R[max_lag_primary:, :]
trim_indices = np.arange(max_lag_primary, N_total)

# Keep at most the configured number of leading PCs (at least one).
embedding_dim_trim = sem_trim_all.shape[1]
topk = max(1, min(use_topk_pcs_cfg, embedding_dim_trim))
if topk < use_topk_pcs_cfg:
    print(f'Adjusted top-{use_topk_pcs_cfg} PCs to available dimension {embedding_dim_trim}')
trimmed_pc = sem_trim_all[:, :topk]
pc_columns = [f'sem_pc{i+1}' for i in range(topk)]
trimmed_pc_df = pd.DataFrame(trimmed_pc, columns=pc_columns)
trimmed_pc_df.insert(0, 'trim_index', np.arange(len(trimmed_pc)))
trimmed_pc_df.to_csv(output_root / 'semantic_pcs_trimmed.csv', index=False)


def format_state(state_spec):
    """Render (variable, lag) pairs as ``'var:lag | var:lag | ...'``."""
    parts = []
    for var, lag in state_spec:
        parts.append(f'{var}:{lag}')
    return ' | '.join(parts)


def evaluate_state(state_spec, phase='select', delta_step=1,
                   data_store=lag_store, splits=splits,
                   theiler=theiler_min, lib_sizes=lib_sizes_primary):
    """Forecast the trimmed target from one state with pyEDM Simplex.

    Parameters
    ----------
    state_spec : sequence of (variable, lag)
        Coordinates of the state; each must be present in ``data_store``.
    phase : {'select', 'final'}
        'select' uses the train split as library and predicts the validation
        split; 'final' uses train+validation as library and predicts test.
    delta_step : int
        Forecast horizon, passed to Simplex as ``Tp``.
    data_store : dict
        ``{variable: {lag: Series}}`` of lag-trimmed series.
    splits : dict
        ``{'train'|'val'|'test': (start, stop)}`` half-open row ranges.
    theiler : int
        Lower bound for the exclusion radius.
    lib_sizes : list of int
        Not used by this function; kept for interface compatibility.

    Returns
    -------
    dict
        Keys 'rho', 'df', 'result', 'stats', 'lib', 'pred', 'columns',
        'exclusion', 'phase', 'delta', 'state'. 'rho' is NaN (and 'result'
        an empty frame) when a range is empty, Simplex raises, or Simplex
        returns no rows.

    Raises
    ------
    ValueError
        If ``state_spec`` is empty or ``phase`` is unknown.
    KeyError
        If a variable or lag is missing from ``data_store``.
    """
    if not state_spec:
        raise ValueError('state_spec must contain at least one coordinate')
    df = pd.DataFrame({'Time': time_trim.values, 'target': target_trim.values})
    max_lag_state = 0
    columns = []
    for var, lag in state_spec:
        if var not in data_store:
            raise KeyError(f'{var} not in data_store')
        if lag not in data_store[var]:
            raise KeyError(f'{var} lag {lag} unavailable')
        col_name = f'{var}_lag{lag}'
        df[col_name] = data_store[var][lag].values
        columns.append(col_name)
        max_lag_state = max(max_lag_state, lag)

    # z-score every column with training-split statistics only. Positional
    # slicing (.iloc) is required here: a label slice via .loc includes its
    # end label and would pull the first validation row into the statistics.
    train_slice = slice(*splits['train'])
    stats = {}
    for col in ['target'] + columns:
        train_values = df[col].iloc[train_slice]
        mu = train_values.mean()
        sigma = train_values.std(ddof=0)
        if sigma == 0 or np.isnan(sigma):
            sigma = 1.0  # constant or empty column: centre only
        df[col] = (df[col] - mu) / sigma
        stats[col] = {'mean': float(mu), 'std': float(sigma)}

    # pyEDM ranges are 1-based and inclusive, hence the +1 on each start.
    if phase == 'select':
        lib_range = (splits['train'][0] + 1, splits['train'][1])
        pred_range = (splits['val'][0] + 1, splits['val'][1])
    elif phase == 'final':
        lib_range = (splits['train'][0] + 1, splits['val'][1])
        pred_range = (splits['test'][0] + 1, splits['test'][1])
    else:
        raise ValueError(f'Unknown phase {phase}')

    # Degenerate ranges: report NaN without calling Simplex.
    if lib_range[1] <= lib_range[0] or pred_range[1] <= pred_range[0]:
        return {'rho': np.nan, 'df': df, 'result': pd.DataFrame(), 'stats': stats,
                'lib': lib_range, 'pred': pred_range, 'columns': columns,
                'exclusion': max_lag_state + delta_step, 'phase': phase,
                'delta': delta_step, 'state': state_spec}

    train_len = splits['train'][1] - splits['train'][0]
    # E + 1 neighbours, limited by the training length, never fewer than 2.
    knn = max(2, min(len(state_spec) + 1, max(2, train_len - 1)))
    exclusion = max(max_lag_state + delta_step, theiler)
    try:
        simplex_df = Simplex(
            dataFrame=df[['Time', 'target'] + columns],
            columns=' '.join(columns),
            target='target',
            lib=f'{lib_range[0]} {lib_range[1]}',
            pred=f'{pred_range[0]} {pred_range[1]}',
            E=len(state_spec),
            Tp=delta_step,
            tau=0,
            knn=knn,
            exclusionRadius=exclusion,
            embedded=True,
            verbose=False
        )
        if simplex_df.empty:
            rho = np.nan
        else:
            err = ComputeError(simplex_df['Observations'], simplex_df['Predictions'])
            rho = float(err.get('rho', np.nan))
    except Exception as exc:
        # Best effort: a failing state is reported and scored as NaN so a
        # sweep over many states can continue.
        print(f'Simplex failed for state {format_state(state_spec)} ({phase}): {exc}')
        simplex_df = pd.DataFrame()
        rho = np.nan
    return {
        'rho': rho,
        'df': df[['Time', 'target'] + columns],
        'result': simplex_df,
        'stats': stats,
        'lib': lib_range,
        'pred': pred_range,
        'columns': columns,
        'exclusion': exclusion,
        'phase': phase,
        'delta': delta_step,
        'state': state_spec,
    }


In [None]:
# One named column per (ROI, lag) pair, in ROI-major / lag-minor order.
lagged_roi_cols = {
    f'{roi_name}_lag{lag}': lag_store[roi_name][lag].values
    for roi_name in roi_cols
    for lag in range(max_lag_primary + 1)
}

# Build the frame in a single constructor call. Inserting the columns one at
# a time fragments the DataFrame and triggers pandas' PerformanceWarning once
# there are many ROI x lag columns. Column order is unchanged: target first,
# then the lagged ROI columns in insertion order.
mde_df = pd.DataFrame({'sem_pc1': target_trim.values, **lagged_roi_cols})

N_trim = len(mde_df)
# 1-based inclusive row ranges: train split as library, validation as prediction.
lib_span = [splits['train'][0] + 1, splits['train'][1]]
pred_span = [splits['val'][0] + 1, splits['val'][1]]


def clamp_span(span, max_len):
    """Clamp a ``(start, end)`` pair into ``[1, max_len]`` with ``end >= start``.

    Returns a new two-element list; ``span`` itself is not modified.
    """
    lo, hi = (min(max(1, bound), max_len) for bound in span)
    return [lo, hi if hi >= lo else lo]

# Keep both spans inside the available rows.
lib_span = clamp_span(lib_span, N_trim)
pred_span = clamp_span(pred_span, N_trim)

# Convert the configured absolute library sizes into integer percentages of
# N_trim (clamped to 1..99, de-duplicated, sorted); use fixed defaults if none
# of the configured sizes is smaller than N_trim.
p_lib_sizes = sorted({max(1, min(99, int(round(size / N_trim * 100)))) for size in lib_sizes_primary if size < N_trim})
if not p_lib_sizes:
    p_lib_sizes = [10, 25, 50, 75]

print('Trimmed samples for MDE:', N_trim)
print('Library span:', lib_span, 'Prediction span:', pred_span)
print('pLibSizes (%):', p_lib_sizes)

# Run MDE with sem_pc1 as target over the lagged ROI columns.
# NOTE(review): argument semantics (D, tau, sample, ccmSlope, ...) are defined
# by the MDE package and are not visible here - confirm against its docs.
mde = MDE(
    dataFrame=mde_df,
    target='sem_pc1',
    removeColumns=['sem_pc1'],
    D=E_cap,
    lib=lib_span,
    pred=pred_span,
    Tp=delta_default,
    tau=-1,
    exclusionRadius=theiler_min,
    sample=5,
    pLibSizes=p_lib_sizes,
    ccmSlope=0.0,
    crossMapRhoMin=0.0,
    embedDimRhoMin=0.0,
    cores=1,
    noTime=True,
    verbose=False,
    consoleOut=False
)
mde.Run()

# Copy the MDE result table and number its rows from 1 in a 'step' column.
mde_output = mde.MDEOut.copy()
mde_output.insert(0, 'step', range(1, len(mde_output) + 1))


def base_roi_name(var_name):
    """Return the part of ``var_name`` before its first ``'_lag'`` marker.

    Names without the marker are returned unchanged.
    """
    head, _sep, _tail = var_name.partition('_lag')
    return head

# Variables in the order they appear in the MDE output table.
mde_variables = mde_output['variables'].tolist()
# Distinct ROI indices among them, parsed from the 'roi_<idx>' prefix.
mde_roi_bases = [base_roi_name(var) for var in mde_variables if var.startswith('roi_')]
mde_roi_labels = sorted({int(name.split('_')[1]) for name in mde_roi_bases})
print('MDE-selected variables:', mde_variables)

# Persist the table with 'variables'/'rho' renamed to 'variable'/'rho_val'.
mde_rank_df = mde_output.rename(columns={'variables': 'variable', 'rho': 'rho_val'})
mde_rank_df.to_csv(output_root / 'mde_rank.csv', index=False)

# ROI univariate baseline
# For every ROI, tau and E, build a time-delay state from that single ROI
# (lags 0, tau, ..., (E-1)*tau) and score it on the validation split.
roi_univ_records = []
for roi_name in roi_cols:
    for tau in tau_grid:
        for E in E_univ_grid:
            lags = [tau * k for k in range(E)]
            # Skip embeddings needing lags beyond what lag_store holds.
            if lags[-1] > max_lag_primary:
                continue
            state = [(roi_name, lag) for lag in lags]
            res = evaluate_state(state, phase='select')
            roi_univ_records.append({
                'roi': roi_name,
                'tau': tau,
                'E': E,
                'state': state,
                'rho_val': res['rho'],
            })

# Take the state with the highest validation rho and evaluate it in both phases.
roi_univ_df = pd.DataFrame(roi_univ_records).sort_values('rho_val', ascending=False).reset_index(drop=True)
best_univ_row = roi_univ_df.iloc[0]
roi_univ_state = best_univ_row['state']
roi_univ_select = evaluate_state(roi_univ_state, phase='select')
roi_univ_final = evaluate_state(roi_univ_state, phase='final')
print('ROI-Univariate best state:', format_state(roi_univ_state))
print('Validation rho:', roi_univ_select['rho'])
print('Test rho:', roi_univ_final['rho'])

# ROI multivariate candidates from MDE order
# Parse each MDE variable name ('roi_<idx>_lag<k>') back into a (variable, lag)
# coordinate, keeping the MDE ordering. Names without a lag suffix get lag 0.
roi_multi_candidates = []
for var in mde_variables:
    if not var.startswith('roi_'):
        continue
    if '_lag' in var:
        # Named `roi_base` (not `base`) so the module-level `base` DataFrame
        # is not overwritten by a string.
        roi_base, lag_text = var.rsplit('_lag', 1)
        roi_multi_candidates.append((roi_base, int(lag_text)))
    else:
        roi_multi_candidates.append((var, 0))

print('ROI-Multivariate candidate coordinate count:', len(roi_multi_candidates))
if roi_multi_candidates:
    print('First candidates:', format_state(roi_multi_candidates[:3]))

# Start from the full candidate list; the selection step narrows it and fills
# in the two evaluation results.
roi_multi_state = list(roi_multi_candidates)
roi_multi_select = None
roi_multi_final = None


In [None]:
# Initialize placeholders for ROI decoding artifacts; populated after selection
# (None / empty means "not computed yet").
sem_pred_trim = None
roi_state_matrix = None
roi_state_columns = []
state_scaler = None
ridge_semantic = None
mde_steps = []


In [None]:
# Window geometry; the TR duration falls back to the top-level 'TR' setting.
tr_s = float(dec_cfg.get('tr_seconds', cfg['TR']))
window_len_tr = int(dec_cfg.get('window_len_tr', 6))
stride_tr = int(dec_cfg.get('stride_tr', 3))
hrf_shift = int(dec_cfg.get('hrf_shift_tr', 0))

# Windows are produced on the trimmed timeline; each record also stores the
# bounds shifted by max_lag_primary (untrimmed TR numbers) and in seconds.
windows_tr = make_tr_windows(N, tr_s, window_len_tr, stride_tr, hrf_shift)
window_records = []
for idx, (start, end) in enumerate(windows_tr):
    record = {
        'window_index': idx,
        'start_trim': start,
        'end_trim': end,
        'start_global_tr': max_lag_primary + start,
        'end_global_tr': max_lag_primary + end,
        'start_sec': (max_lag_primary + start) * tr_s,
        'end_sec': (max_lag_primary + end) * tr_s,
    }
    window_records.append(record)

windows_df = pd.DataFrame(window_records)
windows_df.to_csv(output_root / 'tr_windows.csv', index=False)
print(f'Generated {len(windows_df)} windows.')

# The transcript is optional: a missing file leaves word_events empty and the
# reference texts/map below stay empty as well.
try:
    word_events = load_transcript_words(paths, SUB, STORY)
    print(f'Loaded {len(word_events)} transcript intervals.')
except FileNotFoundError as exc:
    print(f'Transcript unavailable: {exc}')
    word_events = []

reference_texts = []
reference_map = {}
if word_events:
    # One reference text per window (windows given in seconds); also kept as a
    # window_index -> text lookup.
    windows_sec = [(row['start_sec'], row['end_sec']) for row in window_records]
    reference_texts = reference_text_windows(word_events, windows_sec)
    ref_df = pd.DataFrame({'window_index': windows_df['window_index'], 'reference_text': reference_texts})
    ref_df.to_csv(output_root / 'reference_texts.csv', index=False)
    reference_map = dict(zip(windows_df['window_index'], reference_texts))

# Test windows are those lying entirely inside the test split.
win_test_mask = (windows_df['start_trim'] >= splits['test'][0]) & (windows_df['end_trim'] <= splits['test'][1])
test_windows_df = windows_df[win_test_mask].reset_index(drop=True)
test_windows_df.to_csv(output_root / 'test_windows.csv', index=False)
print(f'Test windows: {len(test_windows_df)}')

# Reference text per test window ('' when a window has no entry in the map).
if reference_map:
    test_reference_texts = [reference_map.get(idx, '') for idx in test_windows_df['window_index']]
else:
    test_reference_texts = []

# Validation windows lie entirely inside the validation split; when there are
# more than selection_val_windows, keep a centred run of that many.
val_mask = (windows_df['start_trim'] >= splits['val'][0]) & (windows_df['end_trim'] <= splits['val'][1])
val_candidates = windows_df[val_mask].copy()
if selection_val_windows > 0 and len(val_candidates) > selection_val_windows:
    start_idx = max(0, (len(val_candidates) - selection_val_windows) // 2)
    val_candidates = val_candidates.iloc[start_idx:start_idx + selection_val_windows].copy()

# Fallback when no window fits the validation split: use windows ending at or
# before the start of the test split, or every window if there are none, again
# cropped to a centred run.
if val_candidates.empty and selection_val_windows > 0:
    non_test_mask = windows_df['end_trim'] <= splits['test'][0]
    fallback_df = windows_df[non_test_mask].copy()
    if fallback_df.empty:
        fallback_df = windows_df.copy()
    if len(fallback_df) > selection_val_windows:
        start_idx = max(0, (len(fallback_df) - selection_val_windows) // 2)
        fallback_df = fallback_df.iloc[start_idx:start_idx + selection_val_windows].copy()
    val_candidates = fallback_df

validation_windows_df = val_candidates.reset_index(drop=True)
validation_windows_df.to_csv(output_root / 'validation_windows.csv', index=False)
print('Validation windows:', len(validation_windows_df))

# One reference text per validation window, in validation_windows_df row order.
if reference_map:
    validation_reference_texts = [reference_map.get(idx, '') for idx in validation_windows_df['window_index']]
else:
    validation_reference_texts = ['' for _ in range(len(validation_windows_df))]

# Per-window metadata used during selection. Windows with a non-positive TR
# count are skipped, so this list can be shorter than validation_windows_df.
validation_meta = []
for row in validation_windows_df.itertuples():
    start = int(row.start_trim)
    end = int(row.end_trim)
    n_tr = end - start
    if n_tr <= 0:
        continue
    # n_tr + 1 evenly spaced edges (seconds) delimiting the window's TRs.
    tr_edges = np.linspace(row.start_sec, row.end_sec, n_tr + 1)
    validation_meta.append({
        'window_index': int(row.window_index),
        'trim_slice': (start, end),
        'tr_edges': tr_edges,
        'observed_embedding': sem_trim_all[start:end],
        'observed_pc': trimmed_pc[start:end],
    })
print('Validation windows prepared for selection:', len(validation_meta))



In [None]:
# Two projectors over the same PCA: every component, and only the top-k.
embedding_projector_full = PCProjector(pca.components_, pca.mean_)
pc_projector = PCProjector(pca.components_[:topk], pca.mean_)
# Selection uses the full projector for the 'embedding' basis, top-k otherwise.
selection_projector = embedding_projector_full if target_basis_for_selection == 'embedding' else pc_projector
# Only the 'cosine_mean' metric selects cosine scoring; everything else uses CCA.
selection_embedding_method = 'cosine' if selection_metric == 'cosine_mean' else 'cca'
print(f"Selection basis: {target_basis_for_selection}; projector dims = {selection_projector.n_components}")


def _distribute_word_times(count: int, start: float, end: float):
    if count <= 0:
        return []
    duration = max(end - start, 1e-6)
    step = duration / count
    return [(start + i * step, start + (i + 1) * step) for i in range(count)]


def aggregate_candidate_series(
    text: str,
    tr_edges: np.ndarray,
    projector: PCProjector,
    max_components: int | None = None,
) -> np.ndarray | None:
    """Turn candidate text into a per-TR series in the projector's space.

    Whitespace-separated tokens are lower-cased and looked up in the
    English1000 loader's ``lookup`` table; unknown tokens are dropped. The
    remaining word vectors are projected with ``projector.word_to_pc``,
    optionally cut to the first ``max_components`` columns, spread evenly
    over the window, and aggregated with ``projector.aggregate_to_TR``.

    Returns ``None`` when the loader is unavailable or no token is found.
    """
    if english_loader is None:
        return None

    table = getattr(english_loader, 'lookup', {})
    known_vectors = []
    for token in text.strip().split():
        vector = table.get(token.lower())
        if vector is not None:
            known_vectors.append(vector)
    if not known_vectors:
        return None

    projected = projector.word_to_pc(np.vstack(known_vectors))
    too_wide = max_components is not None and projected.shape[1] > max_components
    if too_wide:
        projected = projected[:, :max_components]

    # No per-word timestamps exist for candidate text, so the found words are
    # spaced uniformly between the first and last TR edge.
    window_start = float(tr_edges[0])
    window_end = float(tr_edges[-1])
    timings = _distribute_word_times(len(known_vectors), window_start, window_end)
    return projector.aggregate_to_TR(projected, timings, tr_edges)


def compute_embedding_score(target: np.ndarray, candidate: np.ndarray, method: str, cca_components: int) -> float:
    """Return ``pc_encoding_score`` as a float, mapping non-finite results to ``-inf``."""
    raw = pc_encoding_score(target, candidate, method=method, cca_components=cca_components)
    if not np.isfinite(raw):
        return float('-inf')
    return float(raw)


def compute_roi_score(target: np.ndarray, candidate: np.ndarray, method: str, cca_components: int) -> float:
    """Return ``roi_encoding_score`` as a float, mapping non-finite results to ``-inf``."""
    raw = roi_encoding_score(target, candidate, method=method, cca_components=cca_components)
    if not np.isfinite(raw):
        return float('-inf')
    return float(raw)


def category_scorer(text: str, tr_edges: np.ndarray, category_target: np.ndarray | None) -> float:
    _ = (text, tr_edges, category_target)
    return float('-inf')


def combined_score(
    text: str,
    tr_edges: np.ndarray,
    *,
    embedding_target: np.ndarray | None = None,
    embedding_projector: PCProjector | None = None,
    embedding_method: str = 'cca',
    cca_components: int = 5,
    category_target: np.ndarray | None = None,
    roi_target: np.ndarray | None = None,
    roi_encoder: ROIEncoder | None = None,
    roi_method: str = 'mean',
    roi_cca_components: int | None = None,
    weights: dict | None = None,
    cache: dict | None = None,
) -> float:
    """Weighted sum of the embedding, category and ROI scores for ``text``.

    A term takes part only if its target (and helper object, where one is
    needed) is supplied, its weight is positive, and its score is finite.
    ``weights`` falls back to the module-level ``selection_weights`` when
    falsy. When ``cache`` is given, aggregated candidate series are memoised
    in it (``None`` results included), keyed by text, window geometry
    (edge count and duration) and projector size.

    Returns the weighted sum, or ``-inf`` when no term contributed.
    """
    weights = weights or selection_weights
    edges_key = (len(tr_edges), float(tr_edges[-1] - tr_edges[0]))

    def _series_for(key, projector, limit):
        # Fetch the aggregated series from the cache, computing it on a miss.
        if cache is None:
            return aggregate_candidate_series(text, tr_edges, projector, max_components=limit)
        if key not in cache:
            cache[key] = aggregate_candidate_series(text, tr_edges, projector, max_components=limit)
        return cache[key]

    terms = []

    # Embedding term: candidate series versus the embedding target.
    if embedding_target is not None and embedding_projector is not None:
        emb_weight = float(weights.get('embedding', 0.0))
        if emb_weight > 0:
            n_dims = embedding_target.shape[1]
            emb_key = ('embedding', text, edges_key, embedding_projector.n_components, n_dims)
            series = _series_for(emb_key, embedding_projector, n_dims)
            if series is not None and series.shape[1] >= n_dims:
                emb_score = compute_embedding_score(
                    embedding_target, series[:, :n_dims], embedding_method, cca_components
                )
                if np.isfinite(emb_score):
                    terms.append(emb_weight * emb_score)

    # Category term.
    if category_target is not None:
        cat_weight = float(weights.get('categories', 0.0))
        if cat_weight > 0:
            cat_score = category_scorer(text, tr_edges, category_target)
            if np.isfinite(cat_score):
                terms.append(cat_weight * cat_score)

    # ROI term: encode the top-k PC series into ROI space, then compare with
    # the ROI target. This branch always uses the module-level pc_projector.
    if roi_target is not None and roi_encoder is not None:
        roi_weight = float(weights.get('roi', 0.0))
        if roi_weight > 0:
            roi_key = ('roi', text, edges_key, pc_projector.n_components)
            series = _series_for(roi_key, pc_projector, pc_projector.n_components)
            if series is not None and series.size > 0:
                try:
                    roi_pred = roi_encoder.predict(series)
                except Exception:
                    roi_pred = None
                if roi_pred is not None:
                    if roi_cca_components is not None:
                        cca_rois = roi_cca_components
                    else:
                        cca_rois = min(roi_target.shape[1], cca_components)
                    roi_score = compute_roi_score(roi_target, roi_pred, roi_method, cca_rois)
                    if np.isfinite(roi_score):
                        terms.append(roi_weight * roi_score)

    if not terms:
        return float('-inf')
    return float(sum(terms, 0.0))



In [None]:
from time import perf_counter

# Bookkeeping for the forward selection: best state by decoding metric and,
# as a fallback, best state by validation rho (initially the full MDE state).
selection_records = []
best_metric = float('-inf')
best_state_metric = []
best_rho_val = float('-inf')
best_state_rho = list(roi_multi_state)
timed_out_any = False

# Candidate texts are the per-window reference texts (empty without a transcript).
candidate_pool = reference_texts if reference_texts else []
if english_loader is None:
    print('WARNING: English1000 unavailable; decoding-metric selection disabled. Using rho fallback.')
if validation_meta and not candidate_pool:
    print('Validation selection skipped: candidate text pool empty.')
# Cache shared across all states and windows; maps candidate text -> its
# position in candidate_pool (duplicate texts resolve to the last position).
selection_candidate_cache = {}
candidate_index_map = {text: idx for idx, text in enumerate(candidate_pool)}

ridge_alpha = float(dec_cfg.get('pc_decoder_reg', 1.0))


def evaluate_decoding_state(state):
    """Score a candidate ROI state by how well its decoded semantics pick text.

    A ridge model is fit on the training split from the state's standardised
    lagged ROI columns to the selection target (all PCs for the 'embedding'
    basis, otherwise the top-k PCs). For every validation window the predicted
    target is scored against each text in ``candidate_pool`` with
    ``combined_score``.

    Parameters
    ----------
    state : sequence of (variable, lag)
        Coordinates looked up in ``lag_store``.

    Returns
    -------
    (float, bool)
        ``(metric_value, timed_out)``. For 'ident_diag_pct' the metric is the
        mean, over windows, of the fraction of finite candidate scores that
        do not exceed the score of the window's own reference text. For any
        other metric name it is the mean of each window's best candidate
        score. The value is NaN when nothing could be scored, when there are
        no validation windows or candidates, or when ``selection_timeout_s``
        was exceeded (then ``timed_out`` is True).
    """
    if not validation_meta or not candidate_pool:
        return float('nan'), False
    start_time = perf_counter()

    # Fit scaler and ridge on the training rows only, then predict every row.
    matrix = np.column_stack([lag_store[var][lag].values for var, lag in state])
    scaler = StandardScaler()
    X_train = scaler.fit_transform(matrix[slice(*splits['train'])])
    X_all = scaler.transform(matrix)
    target_array = sem_trim_all if target_basis_for_selection == 'embedding' else trimmed_pc
    ridge = Ridge(alpha=ridge_alpha)
    ridge.fit(X_train, target_array[slice(*splits['train'])])
    pred_all = ridge.predict(X_all)

    # validation_meta may omit windows that validation_windows_df (and hence
    # validation_reference_texts) still contains, so pairing the two lists by
    # position can attach the wrong reference to a window. Resolve the
    # reference through the window index instead.
    ref_by_window = {
        int(idx): ref
        for idx, ref in zip(validation_windows_df['window_index'], validation_reference_texts)
    }

    n_candidates = len(candidate_pool)
    per_window_scores = []
    score_matrix = []
    window_refs = []  # reference text of each row in score_matrix
    for meta in validation_meta:
        if selection_timeout_s is not None and (perf_counter() - start_time) > selection_timeout_s:
            return float('nan'), True
        start, end = meta['trim_slice']
        tr_edges = meta['tr_edges']
        window_refs.append(ref_by_window.get(meta['window_index'], ''))
        embedding_target = pred_all[start:end]

        # A window whose prediction does not match its TR grid (or is empty)
        # cannot be scored; record it as -inf for every candidate.
        unusable = (
            embedding_target.ndim != 2
            or embedding_target.shape[0] != len(tr_edges) - 1
            or embedding_target.size == 0
        )
        if unusable:
            per_window_scores.append(float('-inf'))
            score_matrix.append([float('-inf')] * n_candidates)
            continue

        # Never score more columns than the selection basis provides.
        if target_basis_for_selection == 'pcs' and embedding_target.shape[1] > topk:
            embedding_target = embedding_target[:, :topk]
        elif target_basis_for_selection == 'embedding':
            observed_dim = meta['observed_embedding'].shape[1]
            if embedding_target.shape[1] > observed_dim:
                embedding_target = embedding_target[:, :observed_dim]

        row_scores = []
        best_window_score = float('-inf')
        for cand_text in candidate_pool:
            score = combined_score(
                cand_text,
                tr_edges,
                embedding_target=embedding_target,
                embedding_projector=selection_projector,
                embedding_method=selection_embedding_method,
                cca_components=selection_cca_components,
                weights=selection_weights,
                cache=selection_candidate_cache,
            )
            row_scores.append(score)
            if score > best_window_score:
                best_window_score = score
        per_window_scores.append(best_window_score)
        score_matrix.append(row_scores)

    if selection_metric == 'ident_diag_pct':
        diag_percentiles = []
        for row_scores, ref_text in zip(score_matrix, window_refs):
            if not ref_text:
                continue
            col_idx = candidate_index_map.get(ref_text)
            if col_idx is None or col_idx >= len(row_scores):
                continue
            row_arr = np.asarray(row_scores, dtype=float)
            value = row_arr[col_idx]
            finite = row_arr[np.isfinite(row_arr)]
            if finite.size == 0 or not np.isfinite(value):
                continue
            diag_percentiles.append(float(np.mean(finite <= value)))
        metric_value = float(np.mean(diag_percentiles)) if diag_percentiles else float('nan')
    else:
        # 'cca_mean', 'cosine_mean' and any unrecognised metric name: mean of
        # the finite per-window best scores.
        scores_arr = np.asarray(per_window_scores, dtype=float)
        scores_arr[~np.isfinite(scores_arr)] = np.nan
        metric_value = float(np.nanmean(scores_arr)) if np.isfinite(scores_arr).any() else float('nan')
    return metric_value, False


# Forward pass along the MDE ordering: grow the state one coordinate at a
# time and score every prefix.
cumulative_state = []
for step_idx, coord in enumerate(roi_multi_state, start=1):
    cumulative_state = cumulative_state + [coord]
    select_eval = evaluate_state(cumulative_state, phase='select')
    final_eval = evaluate_state(cumulative_state, phase='final')
    rho_val = float(select_eval['rho']) if select_eval else float('nan')
    # Remember the prefix with the best validation rho as the fallback choice.
    if np.isfinite(rho_val) and rho_val > best_rho_val:
        best_rho_val = rho_val
        best_state_rho = list(cumulative_state)
    # Change in validation rho versus the previous step (the first step is
    # compared against 0).
    delta_val = rho_val - (selection_records[-1]['rho_val'] if selection_records else 0.0)

    metric_value = float('nan')
    timed_out = False
    # The decoding metric runs only every k-th step and always on the last
    # step, and only when validation windows and candidate texts exist.
    should_eval = bool(validation_meta and candidate_pool and ((step_idx % selection_every_k_steps == 0) or (step_idx == len(roi_multi_state))))
    if should_eval:
        metric_value, timed_out = evaluate_decoding_state(cumulative_state)
        if timed_out:
            timed_out_any = True
        if not timed_out and np.isfinite(metric_value) and (not np.isfinite(best_metric) or metric_value > best_metric):
            best_metric = metric_value
            best_state_metric = list(cumulative_state)

    # The test rho is recorded for reporting; it takes no part in the
    # comparisons above.
    selection_records.append({
        'step': step_idx,
        'roi': coord[0],
        'lag': coord[1],
        'rho_val': rho_val,
        'rho_test': float(final_eval['rho']) if final_eval else float('nan'),
        'delta_rho_val': delta_val,
        'state': format_state(cumulative_state),
        'selection_metric': metric_value,
        'selection_metric_name': selection_metric,
        'selection_timed_out': timed_out,
        'evaluated': should_eval,
    })

if roi_multi_state and timed_out_any:
    print('Decoding selection timed out; falling back to rho-based criterion.')

# The decoding metric decides only if no step timed out and a finite metric
# was observed; otherwise the prefix with the best validation rho is used.
# Wrapped in bool() so the flag is a plain boolean, never a list.
metric_selected = bool(roi_multi_state and not timed_out_any and np.isfinite(best_metric))
if metric_selected and best_state_metric:
    roi_multi_state = best_state_metric
    print('Selected ROI-Multivariate state by decoding metric:', format_state(roi_multi_state))
    print(f"Selection metric ({selection_metric}):", best_metric)
else:
    roi_multi_state = best_state_rho if best_state_rho else list(roi_multi_state)
    metric_selected = False
    print('Selected ROI-Multivariate state by rho-based fallback:', format_state(roi_multi_state))

roi_multi_select = evaluate_state(roi_multi_state, phase='select') if roi_multi_state else {'rho': np.nan}
roi_multi_final = evaluate_state(roi_multi_state, phase='final') if roi_multi_state else {'rho': np.nan}

# Reuse the final-phase result computed just above instead of running the
# same Simplex forecast again on identical arguments.
final_eval = roi_multi_final if roi_multi_state else None
if final_eval and not final_eval['result'].empty:
    pred_df = final_eval['result'].copy()
    pred_df.rename(columns={'Predictions': 'sem_pc1_pred', 'Observations': 'sem_pc1_obs'}, inplace=True)
    pred_df.to_csv(output_root / 'sem_pc1_predictions.csv', index=False)

# Reset the decoding artifacts before rebuilding them for the selected state.
roi_state_matrix = None
roi_state_columns = []
state_scaler = None
ridge_semantic = None
sem_pred_trim = None

if roi_multi_state:
    # Design matrix: one column per selected (ROI, lag) coordinate.
    roi_state_columns = [f'{var}_lag{lag}' for var, lag in roi_multi_state]
    roi_state_matrix = np.column_stack([lag_store[var][lag].values for var, lag in roi_multi_state])

    # Scaler and ridge are fit on the training rows only, then applied to all rows.
    state_scaler = StandardScaler()
    X_train = state_scaler.fit_transform(roi_state_matrix[slice(*splits['train'])])
    X_all = state_scaler.transform(roi_state_matrix)

    # Decode the top-k semantic PCs from the ROI state.
    ridge_semantic = Ridge(alpha=ridge_alpha)
    ridge_semantic.fit(X_train, trimmed_pc[slice(*splits['train'])])
    sem_pred_trim = ridge_semantic.predict(X_all)

    sem_pred_df = pd.DataFrame(sem_pred_trim, columns=pc_columns)
    sem_pred_df.insert(0, 'trim_index', np.arange(len(sem_pred_trim)))
    sem_pred_df.to_csv(output_root / 'semantic_pc_predictions.csv', index=False)

    roi_state_df = pd.DataFrame(roi_state_matrix, columns=roi_state_columns)
    roi_state_df.insert(0, 'trim_index', np.arange(len(roi_state_df)))
    roi_state_df.to_csv(output_root / 'roi_state_series.csv', index=False)
else:
    print('No ROI multivariate state selected; semantic predictions unavailable.')

# Persist the per-step selection log.
mde_steps = selection_records
mde_steps_df = pd.DataFrame(selection_records)
mde_steps_df.to_csv(output_root / 'mde_path.csv', index=False)

# Summary table: one row for the univariate baseline, one for the
# MDE-derived multivariate state.
primary_rows = [
    {
        'label': 'ROI-Univariate',
        'rho_test': roi_univ_final['rho'],
        'rho_val': roi_univ_select['rho'],
        'E': len(roi_univ_state),
        'tau_sel': best_univ_row['tau'],
        'Tp': delta_default,
        'state': format_state(roi_univ_state),
        'selection_metric': float('nan'),
        'selection_basis': 'univariate',
    }
]
primary_rows.append({
    'label': 'ROI-Multivariate (MDE)',
    'rho_test': roi_multi_final['rho'] if roi_multi_final else np.nan,
    'rho_val': roi_multi_select['rho'] if roi_multi_select else np.nan,
    'E': len(roi_multi_state),
    'tau_sel': 'MDE-derived',
    'Tp': delta_default,
    'state': format_state(roi_multi_state),
    # Reported only when the decoding metric actually chose the state.
    'selection_metric': best_metric if metric_selected else float('nan'),
    'selection_basis': target_basis_for_selection,
})
primary_results_df = pd.DataFrame(primary_rows)
primary_results_df.to_csv(output_root / 'primary_results.csv', index=False)
# A bare expression is rendered only when it is the last statement of a cell;
# more code follows here, so the table has to be displayed explicitly.
display(primary_results_df)

# Assemble one decoding target per test window. Windows that cover no TRs
# (end_trim <= start_trim) are left out of `test_targets`.
test_targets = []
for row in test_windows_df.itertuples():
    start, end = int(row.start_trim), int(row.end_trim)
    n_tr = end - start
    if n_tr > 0:
        # n_tr + 1 evenly spaced edges from the window's start_sec to end_sec.
        tr_edges = np.linspace(row.start_sec, row.end_sec, n_tr + 1)
        # Use the predicted PCs when they exist, otherwise the trimmed PCs.
        pc_source = trimmed_pc if sem_pred_trim is None else sem_pred_trim
        pc_target = pc_source[start:end]
        roi_target = None if roi_state_matrix is None else roi_state_matrix[start:end]
        test_targets.append(
            {
                'window_index': int(row.window_index),
                'trim_slice': (start, end),
                'tr_edges': tr_edges,
                'pc_target': pc_target,
                'roi_target': roi_target,
            }
        )
print('Prepared targets for', len(test_targets), 'test windows.')



In [None]:
# Scoring options for the decoding stage, read from the `decoding` config.
# Both score methods default to 'mean' and are normalised to lower case.
pc_score_method, roi_score_method = (
    str(dec_cfg.get(option, 'mean')).lower()
    for option in ('pc_score_method', 'roi_score_method')
)
# Without a configured value, the CCA component count is capped at 5 and at topk.
default_cca_components = min(topk, 5)
cca_components = int(dec_cfg.get('cca_components', default_cca_components))



In [None]:
# Resolve the list of encoding weights (alphas) to sweep during decoding.
# Precedence: `alpha_sweep` (scalar or list) -> `alpha_encoding` -> 0.6.
# A key that is present but null in the YAML is treated like a missing key,
# so it no longer reaches float(None).
alpha_values = dec_cfg.get('alpha_sweep')
if alpha_values is None:
    alpha_fallback = dec_cfg.get('alpha_encoding')
    alpha_values = [0.6 if alpha_fallback is None else alpha_fallback]
elif not isinstance(alpha_values, (list, tuple)):
    # A single scalar in the config becomes a one-element sweep.
    alpha_values = [alpha_values]
alpha_values = [float(val) for val in alpha_values]

print(f'PC projector prepared with top-{topk} components.')

# The ROI encoder is only fitted when an ROI state matrix exists; it stays
# None otherwise.
# NOTE(review): presumably ROIEncoder.fit(X, Y) maps semantic PCs to the ROI
# state -- confirm the argument order against src.decoding.
roi_encoder = None
if roi_state_matrix is not None:
    roi_encoder = ROIEncoder(alpha=float(dec_cfg.get('roi_encoding_reg', 1.0)))
    roi_encoder.fit(trimmed_pc[slice(*splits['train'])], roi_state_matrix[slice(*splits['train'])])
    print('ROI encoder fitted on training split.')
else:
    print('ROI encoder unavailable (no ROI multivariate state).')

# Status report on the inputs the decoding cell relies on.
if reference_texts:
    print('Reference texts available for', len(reference_texts), 'windows.')
else:
    print('Reference texts unavailable; nearest-neighbor decoding will be skipped if needed.')

if test_targets:
    print('Test targets ready:', len(test_targets))
else:
    print('No test targets prepared; decoding will be skipped.')


In [None]:
decoded_outputs = {}
alpha_summary = []
decode_cache = {}

beam_available = False
if BeamDecoder is not None and english_loader is not None and test_targets:
    lm_name = dec_cfg.get('lm_name', 'gpt2')
    beam_size = int(dec_cfg.get('beam_size', 5))
    topk_next = int(dec_cfg.get('topk_next', 10))
    max_tokens = int(dec_cfg.get('max_tokens_per_window', 25))
    prompt_text = dec_cfg.get('prompt_text', '')
    for alpha in alpha_values:
        try:
            decoder = BeamDecoder(
                lm_name=lm_name,
                beam_size=beam_size,
                topk_next=topk_next,
                alpha_encoding=float(alpha),
                max_tokens_per_window=max_tokens,
            )
        except ImportError as exc:
            print(f'Beam decoder unavailable for alpha {alpha}:', exc)
            decoded_outputs = {}
            beam_available = False
            break
        beam_available = True
        decoded_texts = []
        for target in test_targets:
            embedding_target = target['pc_target']
            roi_target = target['roi_target']
            tr_edges = target['tr_edges']
            roi_cca = min(cca_components, roi_target.shape[1]) if roi_target is not None else None

            def scorer(text, tr_edges_inner, _unused, embedding_target=embedding_target, roi_target=roi_target, roi_cca=roi_cca):
                return combined_score(
                    text,
                    tr_edges_inner,
                    embedding_target=embedding_target,
                    embedding_projector=pc_projector,
                    embedding_method=pc_score_method,
                    cca_components=cca_components,
                    roi_target=roi_target,
                    roi_encoder=roi_encoder,
                    roi_method=roi_score_method,
                    roi_cca_components=roi_cca,
                    weights=selection_weights,
                    cache=decode_cache,
                )

            decoded = decoder.decode_window(prompt_text, scorer, tr_edges, embedding_target)
            decoded_texts.append(decoded.strip())
        key = f'beam_alpha_{alpha:.2f}'
        decoded_outputs[key] = decoded_texts
        alpha_summary.append({'method': 'beam', 'alpha': float(alpha), 'n_windows': len(decoded_texts)})
if not beam_available:
    print('Beam decoding unavailable; falling back to nearest-neighbor reference selection.')

# Nearest-neighbour fallback: when beam decoding produced nothing usable and
# reference texts exist, pick for each test window the reference text with the
# highest combined score.
if not (decoded_outputs and beam_available) and reference_texts:
    candidate_pool = reference_texts
    nn_texts = []
    for target in test_targets:
        roi_target = target['roi_target']
        # CCA components are limited by the ROI state width; None without ROI data.
        roi_cca = None if roi_target is None else min(cca_components, roi_target.shape[1])
        # Keyword arguments shared by every candidate scored for this window.
        scoring_kwargs = {
            'embedding_target': target['pc_target'],
            'embedding_projector': pc_projector,
            'embedding_method': pc_score_method,
            'cca_components': cca_components,
            'roi_target': roi_target,
            'roi_encoder': roi_encoder,
            'roi_method': roi_score_method,
            'roi_cca_components': roi_cca,
            'weights': selection_weights,
            'cache': decode_cache,
        }
        # Strict '>' keeps the earliest candidate on ties; the empty string is
        # kept if no score ever exceeds -inf.
        best_text, best_score = '', float('-inf')
        for cand in candidate_pool:
            score = combined_score(cand, target['tr_edges'], **scoring_kwargs)
            if score > best_score:
                best_text, best_score = cand, score
        nn_texts.append(best_text)
    decoded_outputs['nearest_neighbor'] = nn_texts
    alpha_summary.append({'method': 'nearest_neighbor', 'alpha': None, 'n_windows': len(nn_texts)})

# Flatten the decoded texts to long format (one row per method and window)
# and write them out together with the per-method summary.
# Each decoded list was built by iterating `test_targets`, which omits windows
# that cover no TRs, so the window index is taken from `test_targets` rather
# than from `test_windows_df` to keep every text paired with its own window.
target_window_indices = [int(target['window_index']) for target in test_targets]

decoded_records = []
for method, texts in decoded_outputs.items():
    for window_idx, text in zip(target_window_indices, texts):
        decoded_records.append({'method': method, 'window_index': window_idx, 'decoded_text': text})

# Explicit columns keep the CSV header present even when nothing was decoded.
decoded_df = pd.DataFrame(decoded_records, columns=['method', 'window_index', 'decoded_text'])
decoded_df.to_csv(output_root / 'decoded_texts.csv', index=False)
pd.DataFrame(alpha_summary, columns=['method', 'alpha', 'n_windows']).to_csv(
    output_root / 'alpha_decode_summary.csv', index=False
)
print('Decoded outputs saved for methods:', list(decoded_outputs.keys()))
