In [61]:
import os
os.chdir('/flash/PaoU/seann/fmri-edm-ccm')

In [62]:
from src.utils import load_yaml

cfg = load_yaml('configs/demo.yaml')
SUB, STORY = cfg['subject'], cfg['story']
paths = cfg['paths']
print('Subject/Story:', SUB, STORY)
print('E_univ:', cfg.get('E_univ', cfg['E_grid'][0]), 'E_mult:', cfg.get('E_mult', cfg['E_grid'][1]), 'tau:', cfg['tau'], 'delta:', cfg['delta'])

Subject/Story: UTS01 wheretheressmoke
E_univ: 3 E_mult: 4 tau: 1 delta: [1, 2]


In [63]:
from pathlib import Path

data_root = Path(paths['data_root'])
bold_runs = sorted(data_root.glob(f'sub-{SUB}/**/*{STORY}*_bold.nii.gz'))
print('BOLD runs found:', len(bold_runs))
print('Example BOLD:', bold_runs[0] if bold_runs else None)
audio_path = next((p for p in [data_root / 'stimuli' / f'{STORY}.wav', data_root / 'audio' / f'{STORY}.wav'] if p.exists()), None)
textgrid_path = next((p for p in [
    data_root / 'stimuli' / f'{STORY}.TextGrid',
    data_root / 'annotations' / f'{STORY}.TextGrid',
    data_root / 'derivative' / 'TextGrids' / f'{STORY}.TextGrid',
    data_root / 'derivative' / 'TextGrids' / f'{STORY.lower()}.TextGrid',
] if p.exists()), None)
print('Audio WAV:', audio_path)
print('TextGrid:', textgrid_path)

BOLD runs found: 10
Example BOLD: /bucket/PaoU/seann/openneuro/ds003020/sub-UTS01/ses-10/func/sub-UTS01_ses-10_task-wheretheressmoke_run-9_bold.nii.gz
Audio WAV: /bucket/PaoU/seann/openneuro/ds003020/stimuli/wheretheressmoke.wav
TextGrid: /bucket/PaoU/seann/openneuro/ds003020/derivative/TextGrids/wheretheressmoke.TextGrid


In [64]:
import sys
import numpy as np
import pandas as pd

sys.path.append('/flash/PaoU/seann/pyEDM/src')
from pyEDM import Simplex, SMap, CCM
from sklearn.linear_model import LinearRegression

from src import features, roi, ccm
from src.utils import zscore_per_column

X = features.load_english1000_TR(SUB, STORY, paths)
Z, _ = features.pca_fit_transform(X, cfg['pca_components'])
env = features.load_envelope_TR(SUB, STORY, paths)
wr = features.load_wordrate_TR(SUB, STORY, paths)
R = roi.load_schaefer_timeseries_TR(SUB, STORY, cfg['n_parcels'], paths)

print('Semantic/TR shapes:', X.shape, Z.shape)
print('Drivers shapes:', env.shape, wr.shape)
print('ROI shape:', R.shape)

Semantic/TR shapes: (301, 985) (301, 128)
Drivers shapes: (301,) (301,)
ROI shape: (301, 400)


In [65]:
E_univ = cfg.get('E_univ', cfg['E_grid'][0])
theiler_univ = max(cfg['theiler_min'], E_univ)
shortlist = ccm.ccm_conditional_screen(
    R,
    Z[:, 0],
    [env, wr],
    E_univ,
    cfg['tau'],
    theiler_univ,
    cfg['lib_sizes'],
)[: cfg['shortlist_topk']]
print('CCM shortlist (top 10):', shortlist[:10])

candidate_count = min(6, len(shortlist))
candidate_rois = shortlist[:candidate_count]

base_df = pd.DataFrame({
    'Time': np.arange(1, len(Z) + 1),
    'sem_pc1': Z[:, 0],
    'env': env,
    'wr': wr,
})
for idx in candidate_rois:
    base_df[f'roi_{idx}'] = R[:, idx]
base_df.head()

CCM shortlist (top 10): [251, 278, 254, 294, 240, 391, 234, 352, 397, 55]


Unnamed: 0,Time,sem_pc1,env,wr,roi_251,roi_278,roi_254,roi_294,roi_240,roi_391
0,1,7.008479,0.015846,7.0,1.212879,3.058342,1.301039,0.709288,2.214728,1.074983
1,2,1.394021,0.015706,7.0,0.480061,1.545674,0.015077,-0.075154,0.729541,1.209382
2,3,6.088841,0.017975,9.0,1.051968,1.178203,0.757851,0.068815,-1.210854,1.216795
3,4,9.195161,0.027316,11.0,-0.089952,0.815287,-0.27598,0.097228,-1.678459,-0.810547
4,5,4.663446,0.030437,9.0,0.341899,0.828872,0.217353,-0.805272,-1.854795,-0.919949


In [66]:
delta = cfg['delta'][0]
tau_embed = cfg['tau']
train_fraction = 0.6
candidate_lags = [0, 1, 2]
max_lag = max(E_univ - 1, max(candidate_lags))

lag_vars = ['sem_pc1'] + [f'roi_{idx}' for idx in candidate_rois]
lagged = {var: {} for var in lag_vars}
for var in lag_vars:
    for lag in range(max(candidate_lags) + 1):
        lagged[var][lag] = base_df[var].shift(lag).iloc[max_lag:].reset_index(drop=True)

time_trim = base_df['Time'].iloc[max_lag:].reset_index(drop=True)
target_trim = lagged['sem_pc1'][0].copy()
env_trim = base_df['env'].iloc[max_lag:].reset_index(drop=True)
wr_trim = base_df['wr'].iloc[max_lag:].reset_index(drop=True)
roi_trim = {f'roi_{idx}': base_df[f'roi_{idx}'].iloc[max_lag:].reset_index(drop=True) for idx in candidate_rois}
print('Usable samples after lag trimming:', len(time_trim))

from typing import List, Tuple, Dict

def evaluate_state(state_spec: List[Tuple[str, int]], return_all: bool = False) -> Dict[str, object]:
    state_cols = []
    data = pd.DataFrame({'Time': time_trim.values, 'sem_pc1': target_trim.values})
    for var, lag in state_spec:
        col_name = f'{var}_lag{lag}'
        data[col_name] = lagged[var][lag].values
        state_cols.append(col_name)
    for var, series in roi_trim.items():
        data[var] = series.values
    data['env'] = env_trim.values
    data['wr'] = wr_trim.values
    data = data.dropna().reset_index(drop=True)
    N = len(data)
    pred_end = N - delta
    if pred_end <= 0:
        return {'rho_simplex': np.nan}
    train_end = min(int(N * train_fraction), pred_end - 5)
    if train_end <= delta or pred_end <= train_end:
        return {'rho_simplex': np.nan}
    for col in state_cols + ['sem_pc1']:
        mu = data.loc[:train_end-1, col].mean()
        sigma = data.loc[:train_end-1, col].std(ddof=0)
        if sigma == 0:
            sigma = 1.0
        data[col] = (data[col] - mu) / sigma
    lib_str = f'1 {train_end}'
    pred_str = f'{train_end + 1} {pred_end}'
    exclusion = max((len(state_cols) - 1) * tau_embed + delta, cfg['theiler_min'])
    knn = len(state_cols) + 1
    simplex_obj = Simplex(
        dataFrame=data[['Time', 'sem_pc1'] + state_cols],
        columns=' '.join(state_cols),
        target='sem_pc1',
        E=len(state_cols),
        tau=-tau_embed,
        Tp=delta,
        knn=knn,
        exclusionRadius=exclusion,
        lib=lib_str,
        pred=pred_str,
        embedded=True,
        verbose=False,
        returnObject=True,
    )
    simplex_obj.Run()
    proj = simplex_obj.Projection.dropna()
    rho_simplex = float(np.corrcoef(proj['Observations'], proj['Predictions'])[0, 1]) if not proj.empty else np.nan
    result = {
        'rho_simplex': rho_simplex,
        'train_end': train_end,
        'pred_end': pred_end,
        'state_cols': state_cols,
    }
    if return_all:
        result['data'] = data
    return result

Usable samples after lag trimming: 299


In [67]:
base_state = [('sem_pc1', 0), ('sem_pc1', 1), ('sem_pc1', 2)]
base_res = evaluate_state(base_state)
print('Base simplex ρ:', base_res['rho_simplex'])

best_state = base_state
best_res = base_res
best_details = {'type': 'base'}
threshold = 0.02

for roi_idx in candidate_rois:
    var = f'roi_{roi_idx}'
    for lag in candidate_lags:
        candidate_state = [('sem_pc1', 0), ('sem_pc1', 1), (var, lag)]
        res = evaluate_state(candidate_state)
        if np.isfinite(res['rho_simplex']) and res['rho_simplex'] > best_res['rho_simplex'] + threshold:
            best_state = candidate_state
            best_res = res
            best_details = {'type': 'single', 'roi': roi_idx, 'lag': lag}

for i, roi_a in enumerate(candidate_rois):
    for roi_b in candidate_rois[i+1:]:
        var_a, var_b = f'roi_{roi_a}', f'roi_{roi_b}'
        for lag_a in candidate_lags:
            for lag_b in candidate_lags:
                candidate_state = [('sem_pc1', 0), (var_a, lag_a), (var_b, lag_b)]
                res = evaluate_state(candidate_state)
                if np.isfinite(res['rho_simplex']) and res['rho_simplex'] > best_res['rho_simplex'] + threshold:
                    best_state = candidate_state
                    best_res = res
                    best_details = {'type': 'pair', 'roiA': roi_a, 'lagA': lag_a, 'roiB': roi_b, 'lagB': lag_b}

print('Best state:', best_state)
print('Δρ relative to baseline:', best_res['rho_simplex'] - base_res['rho_simplex'])
print('Details:', best_details)

Base simplex ρ: 0.07910433509432334
Best state: [('sem_pc1', 0), ('roi_294', 1), ('roi_240', 2)]
Δρ relative to baseline: 0.31279251069821495
Details: {'type': 'pair', 'roiA': 294, 'lagA': 1, 'roiB': 240, 'lagB': 2}


In [68]:
final_res = evaluate_state(best_state, return_all=True)
df_final = final_res['data']
state_cols = final_res['state_cols']
train_end = final_res['train_end']
pred_end = final_res['pred_end']
print(f'Train rows: {train_end}, Pred rows: {pred_end - train_end}')

theta_grid = [0, 1, 2, 4, 8]
theta_scores = {}
exclusion = max((len(state_cols) - 1) * tau_embed + delta, cfg['theiler_min'])
knn = len(state_cols) + 1
for theta in theta_grid:
    smap_obj = SMap(
        dataFrame=df_final[['Time', 'sem_pc1'] + state_cols],
        columns=' '.join(state_cols),
        target='sem_pc1',
        E=len(state_cols),
        tau=-tau_embed,
        Tp=delta,
        knn=knn,
        theta=theta,
        exclusionRadius=exclusion,
        lib=f'1 {train_end}',
        pred=f'{train_end + 1} {pred_end}',
        embedded=True,
        verbose=False,
        returnObject=True,
    )
    smap_obj.Run()
    proj = smap_obj.Projection.dropna()
    theta_scores[theta] = float(np.corrcoef(proj['Observations'], proj['Predictions'])[0, 1]) if not proj.empty else np.nan

theta_pref = max(theta_scores, key=lambda t: theta_scores[t])
rho_smap = theta_scores[theta_pref]
print('ρ_univ (base simplex):', base_res['rho_simplex'])
print('ρ_simplex (best state):', best_res['rho_simplex'])
print('ρ_smap:', rho_smap)
print('θ preference:', theta_pref)

baseline_E, baseline_tau = 3, 1
from src import baselines
lag_stack = features.make_lag_stack(df_final[['env', 'wr']].to_numpy(), E=baseline_E, tau=baseline_tau)
baseline_offset = (baseline_E - 1) * baseline_tau
y_baseline = df_final['sem_pc1'].iloc[baseline_offset:].reset_index(drop=True)
lag_stack = lag_stack[: len(y_baseline)]
train_end_baseline = max(0, train_end - baseline_offset)
pred_end_baseline = max(train_end_baseline + 1, pred_end - baseline_offset)
yhat_baseline = baselines.ridge_forecast(lag_stack, y_baseline.to_numpy())
rho_baseline = float(np.corrcoef(
    y_baseline.iloc[train_end_baseline:pred_end_baseline],
    yhat_baseline[train_end_baseline:pred_end_baseline]
)[0, 1]) if pred_end_baseline > train_end_baseline else np.nan
print('ρ_drivers (ridge baseline):', rho_baseline)

Train rows: 179, Pred rows: 119
ρ_univ (base simplex): 0.07910433509432334
ρ_simplex (best state): 0.3918968457925383
ρ_smap: -0.120778041978878
θ preference: 0
ρ_drivers (ridge baseline): 0.20351073748316667


In [69]:
from pathlib import Path
from src import plots

plot_root = Path(paths['figs']) / SUB / STORY / 'day10_pyEDM_small'
plot_root.mkdir(parents=True, exist_ok=True)

skill_summary = {
    'drivers-only': rho_baseline,
    'simplex': best_res['rho_simplex'],
    'smap': rho_smap,
}
plots.forecast_bars(skill_summary, str(plot_root / 'forecast_bars.png'))
plots.theta_sweep(theta_grid, [theta_scores[t] for t in theta_grid], str(plot_root / 'theta_sweep.png'))
plots.attractor_3d(df_final[state_cols].to_numpy(), str(plot_root / 'attractor_3d.png'), color=np.linspace(0, 1, df_final.shape[0]))

In [70]:
roi_vars = [var for var, lag in best_state if var.startswith('roi_')]
if roi_vars:
    chosen_var = roi_vars[0]
    drivers_all = df_final[['env', 'wr']].to_numpy()
    model_roi = LinearRegression(fit_intercept=True).fit(drivers_all[:train_end], df_final[chosen_var].iloc[:train_end])
    roi_resid = df_final[chosen_var] - model_roi.predict(drivers_all)
    model_target = LinearRegression(fit_intercept=True).fit(drivers_all[:train_end], df_final['sem_pc1'].iloc[:train_end])
    target_resid = df_final['sem_pc1'] - model_target.predict(drivers_all)
    ccm_df = pd.DataFrame({
        'Time': df_final['Time'].iloc[:pred_end].astype(int).values,
        'X': roi_resid.iloc[:pred_end].values,
        'Y': target_resid.iloc[:pred_end].values,
    })
    max_lib = min(pred_end, train_end)
    lib_grid = sorted(set(int(val) for val in np.linspace(int(max_lib * 0.3), max_lib, num=3)))
    lib_grid = [val for val in lib_grid if val > 10]
    if lib_grid:
        lib_str = ' '.join(str(val) for val in lib_grid)
        ccm_res = CCM(
            dataFrame=ccm_df,
            columns='X',
            target='Y',
            E=E_univ,
            tau=-tau_embed,
            Tp=delta,
            exclusionRadius=max((E_univ - 1) * tau_embed + delta, cfg['theiler_min']),
            libSizes=lib_str,
            sample=1,
            embedded=False,
            verbose=False,
        )
        print(ccm_res)
        plots.ccm_curve(ccm_res['LibSize'].tolist(), ccm_res.iloc[:, 1].tolist(), str(plot_root / 'ccm_curve.png'))
    else:
        print('Insufficient library size for CCM.')
else:
    print('Final state contains only semantic lags; CCM skipped.')

   LibSize      X:Y      Y:X
0       53 -0.04932 -0.01085
1      116 -0.10473  0.12439
2      179 -0.16154  0.02286


In [71]:
print('Final manifold:', best_state)
print('ρ_univ:', base_res['rho_simplex'])
print('ρ_simplex:', best_res['rho_simplex'])
print('ρ_smap:', rho_smap)
print('θ preference:', theta_pref)
print('ρ_drivers:', rho_baseline)

Final manifold: [('sem_pc1', 0), ('roi_294', 1), ('roi_240', 2)]
ρ_univ: 0.07910433509432334
ρ_simplex: 0.3918968457925383
ρ_smap: -0.120778041978878
θ preference: 0
ρ_drivers: 0.20351073748316667


In [72]:
from nilearn import plotting
import nibabel as nib
import numpy as np
from pathlib import Path

atlas_path = Path('parcellations/Parcellations/MNI/Schaefer2018_400Parcels_7Networks_order_FSLMNI152_2mm.nii.gz')
atlas_img = nib.load(str(atlas_path))
roi_ids = [int(var.split('_')[1]) for var, _ in best_state if var.startswith('roi_')]
roi_labels = {240: '7Networks_RH_SomMot_10', 294: '7Networks_RH_SalVentAttn_TempOccPar_1'}
fig_root = Path(paths['figs']) / SUB / STORY / 'day10_pyEDM_small'
fig_root.mkdir(parents=True, exist_ok=True)

for roi in roi_ids:
    data = (atlas_img.get_fdata() == roi).astype(np.int16)
    if not data.any():
        print(f'Warning: ROI {roi} not found in atlas.')
        continue
    roi_img = nib.Nifti1Image(data, atlas_img.affine)
    title = f'ROI {roi}: ' + roi_labels.get(roi, 'Schaefer parcel')
    out_orth = str(fig_root / f'roi_{roi}_orth.png')
    display = plotting.plot_roi(roi_img, display_mode='ortho', alpha=0.8, title=title, cmap='autumn_r')
    display.savefig(out_orth)
    display.close()
    out_glass = str(fig_root / f'roi_{roi}_glass.png')
    glass = plotting.plot_roi(roi_img, display_mode='lyrz', title=title + ' (glass)', cmap='autumn_r')
    glass.savefig(out_glass)
    glass.close()
    print(f'Saved {out_orth} and {out_glass}')

ValueError: lyrz is not a valid display_mode. Valid options are ['mosaic', 'ortho', 'tiled', 'x', 'xz', 'y', 'yx', 'yz', 'z']

In [None]:
from nilearn import plotting
import pandas as pd
from pathlib import Path

roi_ids = [int(var.split('_')[1]) for var, _ in best_state if var.startswith('roi_')]
cent_path = Path('parcellations/Parcellations/MNI/Centroid_coordinates/Schaefer2018_400Parcels_7Networks_order_FSLMNI152_2mm.Centroid_RAS.csv')
cent = pd.read_csv(cent_path)
subset = cent[cent['ROI Label'].isin(roi_ids)]
coords = subset[['R', 'A', 'S']].to_numpy()
labels = subset['ROI Name'].tolist()
view = plotting.view_markers(coords, marker_labels=labels, marker_size=10)
html_path = Path(paths['figs']) / SUB / STORY / 'day10_pyEDM_small' / 'roi_markers.html'
html_path.parent.mkdir(parents=True, exist_ok=True)
view.save_as_html(str(html_path))
print('Saved interactive ROI view to', html_path)
view