# Day 3 — EDM/CCM Drivers for `sub-UTS01`

Focus: single-subject semantic forecasting diagnostics with **word**, **phoneme**, **acoustic**, and **English1000** trajectories.

Checklist
- [ ] Load drivers (envelope, word-rate, phoneme-rate, semantic PCs)
- [ ] Run per-driver Simplex sweeps (E = 2–6)
- [ ] Compare multi-driver forecasts (word + acoustics / semantics)
- [ ] Run an S-Map nonlinearity test (θ sweep) on WordRate at the best Simplex E (best semantic channel still TODO)
- [ ] Perform CCM library sweeps across candidate drivers
- [ ] Save artifacts under `derivatives/results/day3_sub-UTS01/`
- [ ] Record recommendations and future semantic embedding TODOs


In [None]:
import sys
from pathlib import Path

# Put the repository root (one level above the notebook's working directory)
# at the front of sys.path so the project's `src.*` modules can be imported.
REPO_ROOT = Path.cwd().parent
_repo_root_entry = str(REPO_ROOT)
if _repo_root_entry not in sys.path:
    sys.path.insert(0, _repo_root_entry)


In [None]:
import time
import pickle
from pathlib import Path
from typing import List, Dict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
from IPython.display import display

try:
    import pyEDM
    HAVE_PYEDM = True
except ImportError:
    HAVE_PYEDM = False
    print('pyEDM not installed. Install with `pip install pyEDM` before running EDM cells.')

from src.io_ds003020 import list_stories_for_subject
from src.qc_viz import (
    HAVE_AUDIO,
    HAVE_TEXTGRID,
    ensure_dir,
    normalize,
)
from src.edm_ccm import (
    English1000Loader,
    StoryDriver,
    aggregate_driver_summary,
    load_subject_stories,
    multi_segment_library_lengths,
)

# Plotting defaults shared by every figure in this notebook.
plt.style.use('default')
plt.rcParams['figure.figsize'] = (8, 4)

# Dataset / subject configuration. TR is the sampling interval in seconds
# (later cells build the Time axis as sample_index * TR).
DATA_ROOT = Path('/bucket/PaoU/seann/openneuro/ds003020')
SUBJECT_ID = 'sub-UTS01'
TR = 2.0

# Output locations; ensure_dir presumably creates the results folder if needed.
RESULTS_DIR = ensure_dir(REPO_ROOT / 'derivatives' / 'results' / f'day3_{SUBJECT_ID}')
DRIVER_CACHE_PATH = RESULTS_DIR / f'{SUBJECT_ID}_story_drivers.pkl'
ENGLISH1000_PATH = DATA_ROOT / 'derivative' / 'english1000sm.hf5'

# Semantic settings: components requested from the embedding loader, and how
# many leading semantic columns enter the driver comparisons.
SEMANTIC_COMPONENTS = 64
SEMANTIC_MULTI_TOPK = 3

print(f'AUDIO deps: {HAVE_AUDIO}, TEXTGRID deps: {HAVE_TEXTGRID}, pyEDM: {HAVE_PYEDM}')


In [None]:
# Enumerate every story recorded for this subject and flag which have a TextGrid.
subject_records = list_stories_for_subject(DATA_ROOT, SUBJECT_ID)
story_df = pd.DataFrame(subject_records)
if story_df.empty:
    raise RuntimeError(f'No stories found for {SUBJECT_ID}.')

story_df = story_df.assign(has_textgrid=story_df['textgrid'].notna())
n_with_textgrid = story_df['has_textgrid'].sum()
print(f"Total stories: {len(story_df)} | with TextGrid: {n_with_textgrid}")
preview_columns = ['story_id', 'session', 'run', 'has_textgrid']
display(story_df[preview_columns].head())


In [None]:
# Driver extraction needs both the audio and TextGrid toolchains; fail fast otherwise.
if not (HAVE_AUDIO and HAVE_TEXTGRID):
    raise RuntimeError('Missing audio/TextGrid dependencies; install librosa, soundfile, textgrid.')

textgrid_mask = story_df['has_textgrid']
usable_df = story_df.loc[textgrid_mask].copy()
print(f'Candidate stories with WAV+TextGrid: {len(usable_df)}')


In [None]:
# Optional semantic embeddings: english_loader stays None when the file is
# absent or fails to load, and later cells then skip semantic PCs.
english_loader = None
if not ENGLISH1000_PATH.exists():
    print('English1000 file not found; semantic PCs will be skipped.')
else:
    try:
        english_loader = English1000Loader(ENGLISH1000_PATH)
    except Exception as exc:
        print('Unable to load English1000 embeddings:', exc)
    else:
        print('English1000 embeddings loaded.')


In [None]:
from typing import Optional

# Settings that must match for a pickled driver cache to be reused; any change
# in TR, semantic configuration, or the story set forces regeneration.
cache_meta = {
    'tr': TR,
    'semantic_components': SEMANTIC_COMPONENTS if english_loader is not None else 0,
    'semantic_loader': 'english1000' if english_loader is not None else None,
    'semantic_source': str(ENGLISH1000_PATH) if english_loader is not None else None,
    'story_ids': sorted(usable_df['story_id'].tolist()),
}


def _read_driver_cache(path: Path, expected_meta: dict) -> Optional[List[StoryDriver]]:
    """Return cached story drivers if *path* holds a cache built with *expected_meta*.

    Returns None (after printing why) when the file is missing, unreadable,
    lacks a ``drivers`` entry, or was produced with different settings.
    """
    if not path.exists():
        return None
    try:
        # NOTE: pickle.load can execute arbitrary code; acceptable only because
        # this cache is written by this notebook. Never point it at foreign files.
        with path.open('rb') as fh:
            cached = pickle.load(fh)
        if cached.get('metadata', {}) != expected_meta:
            print('Driver cache metadata mismatch; regenerating.')
            return None
        drivers = cached['drivers']
    except Exception as exc:
        print(f'Unable to read driver cache ({exc}); regenerating.')
        return None
    print(f"Loaded story drivers from cache: {path}")
    return drivers


def _write_driver_cache(path: Path, meta: dict, drivers: List[StoryDriver]) -> None:
    """Pickle *drivers* with *meta* to *path*.

    The payload is dumped to a sibling temp file and then moved into place, so
    an interrupted dump can never leave a truncated cache at *path*.
    """
    tmp_path = path.with_name(path.name + '.tmp')
    with tmp_path.open('wb') as fh:
        pickle.dump({'metadata': meta, 'drivers': drivers}, fh)
    tmp_path.replace(path)


story_driver_list: Optional[List[StoryDriver]] = _read_driver_cache(DRIVER_CACHE_PATH, cache_meta)
if story_driver_list is None:
    start = time.time()
    story_driver_list = load_subject_stories(
        usable_df.to_dict('records'),
        tr=TR,
        semantic_loader=english_loader,
        semantic_components=SEMANTIC_COMPONENTS if english_loader is not None else None,
    )
    elapsed = time.time() - start
    print(f'Regenerated story drivers in {elapsed/60:.1f} min; caching to {DRIVER_CACHE_PATH}')
    _write_driver_cache(DRIVER_CACHE_PATH, cache_meta, story_driver_list)

print(f'Stories with driver series: {len(story_driver_list)}')
if not story_driver_list:
    raise RuntimeError('No driver series available for EDM analysis.')

driver_cache: Dict[str, StoryDriver] = {s.story_id: s for s in story_driver_list}


In [None]:
# Per-story driver summary, longest stories (most TRs) first, saved as CSV.
summary_rows = aggregate_driver_summary(story_driver_list)
summary_df = (
    pd.DataFrame(summary_rows)
    .sort_values('n_tr', ascending=False)
    .reset_index(drop=True)
)
summary_path = RESULTS_DIR / f'{SUBJECT_ID}_driver_summary.csv'
summary_df.to_csv(summary_path, index=False)
print(f'Summary saved to {summary_path}')
display(summary_df.head())


In [None]:
# Target story: the first story in summary_df order (longest first) that carries
# semantic PCs; otherwise fall back to the longest story overall.
target_story_id = next(
    (
        story_id
        for story_id in summary_df['story_id']
        if story_id in driver_cache and driver_cache[story_id].drivers.semantic is not None
    ),
    None,
)
if target_story_id is None:
    target_story_id = summary_df.iloc[0]['story_id']

print(f'Target story: {target_story_id}')
target_story = driver_cache[target_story_id]
series = target_story.drivers

# Missing lexical drivers are replaced by all-zero series so the table schema is fixed.
word_rate = np.zeros(series.n_tr, dtype=float) if series.word_rate is None else series.word_rate
phoneme_rate = np.zeros(series.n_tr, dtype=float) if series.phoneme_rate is None else series.phoneme_rate
if series.semantic is None or series.semantic_labels is None:
    semantic_cols = []
else:
    semantic_cols = [f'Semantic_{label}' for label in series.semantic_labels]

# Time axis in seconds (sample index * TR).
time_index = np.arange(series.n_tr) * TR
data = {
    'Time': time_index,
    'Envelope': series.envelope.astype(float),
    'WordRate': word_rate.astype(float),
    'PhonemeRate': phoneme_rate.astype(float),
}
data.update(
    (label, series.semantic[:, idx].astype(float))
    for idx, label in enumerate(semantic_cols)
)
data_df = pd.DataFrame(data)

# Zero-fill gaps in every driver column; Time is left untouched.
driver_columns = [col for col in data_df.columns if col != 'Time']
data_df[driver_columns] = data_df[driver_columns].fillna(0.0)

data_path = RESULTS_DIR / f'{SUBJECT_ID}_{target_story_id}_drivers.csv'
data_df.to_csv(data_path, index=False)
print(f'Driver time series saved to {data_path}')
data_df.head()


In [None]:
# Overlay normalized drivers for a quick visual sanity check; PhonemeRate is
# drawn only when it actually varies.
fig, ax = plt.subplots(figsize=(10, 3))
for column in ('WordRate', 'Envelope'):
    ax.plot(data_df['Time'], normalize(data_df[column]), label=f'{column} (norm)')
phoneme_varies = 'PhonemeRate' in data_df.columns and data_df['PhonemeRate'].std() > 0
if phoneme_varies:
    ax.plot(data_df['Time'], normalize(data_df['PhonemeRate']), label='PhonemeRate (norm)')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Normalized amplitude')
ax.set_title(f'{SUBJECT_ID} | {target_story_id} — Driver Overview')
ax.legend(loc='upper right')
fig_path = RESULTS_DIR / f'{SUBJECT_ID}_{target_story_id}_drivers.png'
fig.savefig(fig_path, dpi=150, bbox_inches='tight')
plt.close(fig)
print(f'Saved {fig_path}')


In [None]:
E_VALUES = [2, 3, 4, 5, 6]
# pyEDM's convention is a negative tau (delays into the past; -1 is its
# default). A positive tau builds the embedding from *future* samples, which
# with Tp=1 puts the forecast target inside the predictor vector.
TAU = -1
TP = 1


def _has_variance(column: str) -> bool:
    """Return True when *column* exists in data_df and is not constant."""
    return column in data_df.columns and float(data_df[column].std()) > 0


def _dedupe(columns: List[str]) -> List[str]:
    """Drop repeated column names while keeping first-seen order."""
    return list(dict.fromkeys(columns))


base_drivers = [col for col in ['Envelope', 'PhonemeRate'] if _has_variance(col)]
top_semantic = [col for col in semantic_cols[:SEMANTIC_MULTI_TOPK] if _has_variance(col)]

# Single-driver configs: WordRate's own history, then each varying driver
# (acoustic, phonemic, leading semantic PCs) cross-mapped onto WordRate.
single_driver_configs = []
if 'WordRate' in data_df.columns:
    single_driver_configs.append({
        'name': 'WordRate→WordRate',
        'columns': 'WordRate',
        'target': 'WordRate',
    })
for col in base_drivers + top_semantic:
    single_driver_configs.append({
        'name': f'{col}→WordRate',
        'columns': col,
        'target': 'WordRate',
    })

# Multi-driver configs: the base drivers, then base drivers plus the top-k
# semantic PCs for k = 1..len(top_semantic).
candidates = []
if base_drivers:
    candidates.append(base_drivers)
for k in range(1, len(top_semantic) + 1):
    candidates.append(base_drivers + top_semantic[:k])

multi_configs = []
seen_pairs = set()
for cols in candidates:
    unique_cols = _dedupe(cols)
    # A multi-driver forecast needs at least two inputs. A one-column candidate
    # would reuse a single-driver config's name and double up its rows in the
    # per-name grouping used by the sweep plot.
    if len(unique_cols) < 2:
        continue
    key = tuple(unique_cols)
    if key in seen_pairs:
        continue
    seen_pairs.add(key)
    multi_configs.append({
        'name': '+'.join(unique_cols) + '→WordRate',
        'columns': ' '.join(unique_cols),
        'target': 'WordRate',
    })

if not single_driver_configs and not multi_configs:
    raise RuntimeError('No driver configurations available for simplex sweep.')


In [None]:
if not HAVE_PYEDM:
    raise RuntimeError('pyEDM is required for the Simplex sweep; install with `pip install pyEDM`.')


def _forecast_rho(table: pd.DataFrame) -> float:
    """Return Pearson rho between the observed and predicted columns of *table*.

    pyEDM's Simplex returns a prediction table (``Observations`` /
    ``Predictions``), not a skill score, so skill is computed here. Rows where
    either value is missing (embedding warm-up, the Tp step past the end) are
    ignored. Returns NaN when fewer than three valid rows remain or either
    side is constant.
    """
    obs = np.asarray(table['Observations'], dtype=float)
    pred = np.asarray(table['Predictions'], dtype=float)
    valid = np.isfinite(obs) & np.isfinite(pred)
    if valid.sum() < 3 or np.std(obs[valid]) == 0 or np.std(pred[valid]) == 0:
        return float('nan')
    return float(np.corrcoef(obs[valid], pred[valid])[0, 1])


# In-sample sweep: library and prediction sets both span the full series.
lib_range = f"1 {len(data_df)}"
pred_range = lib_range

simplex_records = []
for cfg in single_driver_configs + multi_configs:
    for E in E_VALUES:
        try:
            res = pyEDM.Simplex(
                dataFrame=data_df,
                E=E,
                tau=TAU,
                Tp=TP,
                columns=cfg['columns'],
                target=cfg['target'],
                lib=lib_range,
                pred=pred_range,
                ignoreNan=True,
                verbose=False,
            )
        except Exception as exc:
            print(f"Skip {cfg['name']} (E={E}): {exc}")
            rho = np.nan
        else:
            rho = _forecast_rho(res)
        simplex_records.append({'name': cfg['name'], 'columns': cfg['columns'], 'target': cfg['target'], 'E': E, 'rho': rho})

simplex_df = pd.DataFrame(simplex_records)
simplex_path = RESULTS_DIR / f'{SUBJECT_ID}_{target_story_id}_simplex_comparison.csv'
simplex_df.to_csv(simplex_path, index=False)
print(f'Simplex comparison saved to {simplex_path}')

fig, ax = plt.subplots()
for name, sub in simplex_df.groupby('name'):
    ax.plot(sub['E'], sub['rho'], marker='o', label=name)
ax.set_xlabel('Embedding dimension E')
ax.set_ylabel('Forecast skill (rho)')
ax.set_title(f'Simplex comparison | {SUBJECT_ID} {target_story_id}')
ax.axhline(0, color='black', linewidth=1)
ax.legend(loc='best')
fig_path = RESULTS_DIR / f'{SUBJECT_ID}_{target_story_id}_simplex_comparison.png'
fig.savefig(fig_path, dpi=150, bbox_inches='tight')
plt.close(fig)
print(f'Saved {fig_path}')

# Highest-skill WordRate configuration; NaN skills sort last.
ranked = simplex_df.sort_values('rho', ascending=False, na_position='last')
wordrate_ranked = ranked[ranked['target'] == 'WordRate']
best_word_row = (wordrate_ranked if not wordrate_ranked.empty else ranked).iloc[0]
if np.isnan(best_word_row['rho']):
    print('Warning: no Simplex run produced a finite rho; best E is an arbitrary fallback.')

try:
    best_E = int(best_word_row['E'])
except (TypeError, ValueError):
    best_E = E_VALUES[0]

print(f"Best WordRate config: {best_word_row['name']} (E={best_E}, rho={best_word_row['rho']:.3f})")


In [None]:
# Re-run the winning embedding dimension as a WordRate self-forecast, using the
# same tau/Tp as the sweep that selected best_E, and keep the prediction table.
simplex_run = pyEDM.Simplex(
    dataFrame=data_df,
    E=best_E,
    tau=TAU,
    Tp=TP,
    columns='WordRate',
    target='WordRate',
    lib=f"1 {len(data_df)}",
    pred=f"1 {len(data_df)}",
    ignoreNan=True,
    verbose=True,
)

forecast_df = pd.DataFrame({
    'Time': simplex_run['Time'],
    'Observed': simplex_run['Observations'],
    'Forecast': simplex_run['Predictions'],
})
forecast_path = RESULTS_DIR / f'{SUBJECT_ID}_{target_story_id}_simplex_wordrate_forecast.csv'
forecast_df.to_csv(forecast_path, index=False)
print(f'Forecast series saved to {forecast_path}')

fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(forecast_df['Time'], forecast_df['Observed'], label='Observed')
ax.plot(forecast_df['Time'], forecast_df['Forecast'], label='Forecast', alpha=0.8)
# The Time column comes from data_df['Time'] (sample index * TR), i.e. seconds.
ax.set_xlabel('Time (s)')
ax.set_ylabel('WordRate')
ax.set_title(f'WordRate Simplex forecast (E={best_E}, Tp={TP})')
ax.legend(loc='upper right')
fig_path = RESULTS_DIR / f'{SUBJECT_ID}_{target_story_id}_simplex_wordrate_forecast.png'
fig.savefig(fig_path, dpi=150, bbox_inches='tight')
plt.close(fig)
print(f'Saved {fig_path}')


In [None]:
theta_values = [0, 0.5, 1.0, 1.5, 2.0]


def _smap_rho(result) -> float:
    """Return Pearson rho between observed and S-Map-predicted values.

    pyEDM's SMap returns a dict whose ``predictions`` entry is the prediction
    table. A bare table is also accepted. Rows with a missing observation or
    prediction are ignored. Returns NaN when fewer than three valid rows
    remain or either side is constant.
    """
    table = result['predictions'] if isinstance(result, dict) else result
    obs = np.asarray(table['Observations'], dtype=float)
    pred = np.asarray(table['Predictions'], dtype=float)
    valid = np.isfinite(obs) & np.isfinite(pred)
    if valid.sum() < 3 or np.std(obs[valid]) == 0 or np.std(pred[valid]) == 0:
        return float('nan')
    return float(np.corrcoef(obs[valid], pred[valid])[0, 1])


# Nonlinearity test: skill rising with theta (localisation) suggests
# state-dependent dynamics; theta=0 is the global linear map.
sm_records = []
for theta in theta_values:
    try:
        res = pyEDM.SMap(
            dataFrame=data_df,
            E=best_E,
            tau=TAU,
            Tp=TP,
            columns='WordRate',
            target='WordRate',
            theta=theta,
            lib=f"1 {len(data_df)}",
            pred=f"1 {len(data_df)}",
            ignoreNan=True,
            verbose=False,
        )
    except Exception as exc:
        print(f'Skip S-Map theta={theta}: {exc}')
        rho = np.nan
    else:
        rho = _smap_rho(res)
    sm_records.append({'theta': theta, 'rho': rho})

smap_df = pd.DataFrame(sm_records)
smap_path = RESULTS_DIR / f'{SUBJECT_ID}_{target_story_id}_smap_wordrate.csv'
smap_df.to_csv(smap_path, index=False)
print(f'S-Map results saved to {smap_path}')

display(smap_df)
fig, ax = plt.subplots()
ax.plot(smap_df['theta'], smap_df['rho'], marker='o')
ax.set_xlabel('Nonlinearity (theta)')
ax.set_ylabel('Forecast skill (rho)')
ax.set_title(f'WordRate S-Map | E={best_E}, Tp={TP}')
ax.axhline(0, color='black', linewidth=1)
fig_path = RESULTS_DIR / f'{SUBJECT_ID}_{target_story_id}_smap_wordrate.png'
fig.savefig(fig_path, dpi=150, bbox_inches='tight')
plt.close(fig)
print(f'Saved {fig_path}')


In [None]:
# CCM candidates: each varying acoustic/phonemic driver, plus the first
# semantic PC, paired with WordRate.
semantic_cols = [col for col in data_df.columns if col.startswith('Semantic_')]
ccm_pairs = []
if 'Envelope' in data_df.columns and data_df['Envelope'].std() > 0:
    ccm_pairs.append(('Envelope', 'WordRate'))
if 'PhonemeRate' in data_df.columns and data_df['PhonemeRate'].std() > 0:
    ccm_pairs.append(('PhonemeRate', 'WordRate'))
if semantic_cols:
    ccm_pairs.append((semantic_cols[0], 'WordRate'))

# Library sizes must stay below the number of rows that have a complete
# delay vector: the first (E-1)*|tau| rows cannot be embedded.
max_library = len(data_df) - (best_E - 1) * abs(TAU)
lib_sizes = np.arange(50, min(600, max_library), 50)

# Always (re)bind ccm_df so a stale result from a previous run is never reused.
ccm_df = pd.DataFrame(columns=['LibSize', 'pair', 'Rho'])
if lib_sizes.size < 2:
    # A convergence sweep needs at least two library sizes.
    print('Library sizes too small for CCM sweep; skipping CCM.')
else:
    # pyEDM accepts library sizes as a "start stop increment" string.
    lib_spec = f'{int(lib_sizes[0])} {int(lib_sizes[-1])} 50'
    ccm_frames = []
    for driver, target in ccm_pairs:
        try:
            # One CCM call cross-maps in both directions. The result holds a
            # LibSize column plus one skill column per direction, labelled
            # 'columns:target' and 'target:columns' by pyEDM.
            result = pyEDM.CCM(
                dataFrame=data_df,
                E=best_E,
                tau=TAU,
                columns=driver,
                target=target,
                Tp=0,
                libSizes=lib_spec,
                sample=100,
                verbose=False,
            )
        except Exception as exc:
            print(f'Skip CCM pair {driver}->{target}: {exc}')
            continue
        # Long format: one row per (library size, direction).
        ccm_frames.append(result.melt(id_vars='LibSize', var_name='pair', value_name='Rho'))

    if ccm_frames:
        ccm_df = pd.concat(ccm_frames, ignore_index=True)
        ccm_path = RESULTS_DIR / f'{SUBJECT_ID}_{target_story_id}_ccm_comparison.csv'
        ccm_df.to_csv(ccm_path, index=False)
        print(f'CCM results saved to {ccm_path}')

        fig, ax = plt.subplots()
        for pair, sub in ccm_df.groupby('pair'):
            sub = sub.sort_values('LibSize')
            ax.plot(sub['LibSize'], sub['Rho'], marker='o', label=pair)
        ax.set_xlabel('Library size (TR)')
        ax.set_ylabel('CCM ρ')
        ax.set_title(f'CCM library sweep | {SUBJECT_ID} {target_story_id}')
        ax.legend(loc='best')
        fig_path = RESULTS_DIR / f'{SUBJECT_ID}_{target_story_id}_ccm_comparison.png'
        fig.savefig(fig_path, dpi=150, bbox_inches='tight')
        plt.close(fig)
        print(f'Saved {fig_path}')
    else:
        print('No CCM pairs computed; all drivers missing variance or pyEDM error.')


In [None]:
# How many usable rows each story keeps as the embedding dimension grows.
# Each story loses E rows here, presumably (E-1) delay rows plus one Tp=1
# forecast step. The E grid is the one used by the Simplex sweep.
library_lengths = summary_df['n_tr'].to_numpy()
rows = []
for E in E_VALUES:
    effective = np.maximum(library_lengths - E, 0)
    rows.append({
        'E': E,
        'median_library': float(np.median(effective)),
        'min_library': int(np.min(effective)),
        'stories_ge_200': int((effective >= 200).sum()),
    })
embedding_df = pd.DataFrame(rows)
display(embedding_df)
fig, ax = plt.subplots()
ax.plot(embedding_df['E'], embedding_df['median_library'], marker='o', label='Median usable TRs')
ax.plot(embedding_df['E'], embedding_df['min_library'], marker='s', label='Minimum usable TRs')
ax.axhline(200, color='black', linestyle='--', linewidth=1)
ax.set_xlabel('Embedding dimension E')
ax.set_ylabel('Usable TR count per story')
ax.set_title(f'{SUBJECT_ID} — Library depth vs embedding dimension')
ax.legend()
fig_path = RESULTS_DIR / f'{SUBJECT_ID}_{target_story_id}_library_vs_E.png'
fig.savefig(fig_path, dpi=150, bbox_inches='tight')
plt.close(fig)
print(f'Saved {fig_path}')


In [None]:
# Delay-embedding attractor for WordRate (at least 3 dims so it can be drawn in 3-D).
EMBED_E = max(3, best_E)
TAU_ATTR = 1
window = (EMBED_E - 1) * TAU_ATTR
if len(data_df) <= window:
    print(f'Skipping attractor plot: insufficient TRs for E={EMBED_E}.')
else:
    values = data_df['WordRate'].to_numpy(dtype=float)
    n_rows = len(values) - window
    # Backward-looking embedding: row r corresponds to t = window + r and
    # column j holds WordRate(t - j*τ), matching the axis labels below.
    embedded = np.column_stack([
        values[window - j * TAU_ATTR: window - j * TAU_ATTR + n_rows]
        for j in range(EMBED_E)
    ])

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # Only the first three delay coordinates are drawn.
    ax.plot(embedded[:, 0], embedded[:, 1], embedded[:, 2], lw=0.6, alpha=0.9)
    ax.set_xlabel('WordRate(t)')
    ax.set_ylabel('WordRate(t-τ)')
    ax.set_zlabel('WordRate(t-2τ)')
    ax.set_title(f'{SUBJECT_ID} | {target_story_id} — Attractor (E={EMBED_E}, τ={TAU_ATTR})')
    fig_path = RESULTS_DIR / f'{SUBJECT_ID}_{target_story_id}_attractor.png'
    fig.savefig(fig_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f'Saved {fig_path}')



In [None]:
import json
import math


def _json_safe(value):
    """Recursively replace NaN/inf floats with None so the notes file is strict JSON."""
    if isinstance(value, float) and not math.isfinite(value):
        return None
    if isinstance(value, dict):
        return {key: _json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(item) for item in value]
    return value


# Best Simplex skill per configuration (max ignores NaN; all-NaN stays NaN).
best_rho_dict = {name: float(rho) for name, rho in simplex_df.groupby('name')['rho'].max().items()}

# Theta with the highest finite S-Map skill; NaN if no theta produced one.
smap_valid = smap_df.dropna(subset=['rho']) if not smap_df.empty else smap_df
if smap_valid.empty:
    smap_best_theta = float('nan')
else:
    smap_best_theta = float(smap_valid.loc[smap_valid['rho'].idxmax(), 'theta'])

ccm_pairs_summary = ccm_df['pair'].unique().tolist() if 'ccm_df' in globals() else []

summary_notes = {
    'subject': SUBJECT_ID,
    'story': target_story_id,
    'simplex_best_rho': best_rho_dict,
    'best_wordrate_E': int(best_word_row['E']),
    'best_wordrate_rho': float(best_word_row['rho']),
    'smap_best_theta': smap_best_theta,
    'ccm_pairs': ccm_pairs_summary,
    'semantic_components': int(SEMANTIC_COMPONENTS) if english_loader is not None else 0,
    'semantic_columns_used': semantic_cols[:SEMANTIC_MULTI_TOPK] if semantic_cols else [],
}

notes_path = RESULTS_DIR / f'{SUBJECT_ID}_{target_story_id}_day3_notes.json'
# json.dumps would otherwise emit bare NaN tokens, which strict JSON parsers reject.
notes_path.write_text(json.dumps(_json_safe(summary_notes), indent=2, allow_nan=False))
print('Notes saved to', notes_path)
summary_notes



### Future embedding TODOs
- Evaluate Word2Vec and GloVe embeddings aligned via TextGrid timings.
- Test contextual LMs (e.g., GPT/BERT) and QA-style features (Huth & Tang, 2025).
- Run conditional CCM controlling for envelope/phoneme drivers when forecasting semantic PCs.
