# Day 21 – Batch Category Export with Word2Vec Fallback

This notebook extends Day 20 by aligning a Word2Vec fallback embedding space to
English1000 so OOV tokens can be mapped automatically. It batches category
feature generation for every subject/story pair, saves the outputs, and reports
combined token coverage using both vocabularies.


In [None]:
import json
import math
import os
import sys
import warnings
from collections import Counter
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd

# Fix the NumPy RNG seed and widen pandas' column display limit for this notebook.
np.random.seed(42)
pd.options.display.max_columns = 60

# Create the project directory if needed and make it the working directory,
# so relative paths used later (e.g. 'configs/demo.yaml') resolve against it.
project_root = Path('/flash/PaoU/seann/fmri-edm-ccm')
project_root.mkdir(parents=True, exist_ok=True)
os.chdir(project_root)

# Extend the import path with the project root and two external source trees.
sys.path.append(str(project_root))
sys.path.append('/flash/PaoU/seann/pyEDM/src')
sys.path.append('/flash/PaoU/seann/MDE-main/src')

try:
    from gensim.models import KeyedVectors
except Exception:
    KeyedVectors = None  # type: ignore

try:
    from IPython.display import display
except Exception:
    def display(obj):  # type: ignore
        print(obj)

try:
    import matplotlib.pyplot as plt  # noqa: F401  # kept for parity with Day 19 helpers
except Exception as exc:
    plt = None
    warnings.warn(f'Matplotlib unavailable: {exc}')

from src.utils import load_yaml
from src.decoding import load_transcript_words
from src.edm_ccm import English1000Loader

# Threshold below which prototype/embedding norms are treated as zero by the helpers below.
EPS = 1e-12


In [None]:
# --- Configuration -------------------------------------------------------------------
# Read the demo config; the 'categories' section drives most of this notebook.
cfg = load_yaml('configs/demo.yaml')
categories_cfg = cfg.get('categories', {}) or {}
cluster_csv_path = categories_cfg.get('cluster_csv_path', '')
prototype_weight_power = float(categories_cfg.get('prototype_weight_power', 1.0))
seconds_bin_width_default = float(categories_cfg.get('seconds_bin_width', 0.05))
temporal_weighting_default = str(categories_cfg.get('temporal_weighting', 'proportional')).lower()

# Paths and TR come from the top-level config; the features directory is created eagerly.
paths = cfg.get('paths', {})
TR = float(cfg.get('TR', 2.0))
features_root = Path(paths.get('features', 'features'))
features_root.mkdir(parents=True, exist_ok=True)

# Default subject/story fall back to hard-coded values when the config leaves them empty.
SUBJECT = cfg.get('subject') or 'UTS01'
STORY = cfg.get('story') or 'wheretheressmoke'
TEMPORAL_WEIGHTING = temporal_weighting_default  # {'proportional', 'none'}
SECONDS_BIN_WIDTH = seconds_bin_width_default

# canonical smoothing controls tuned for forecasting
SMOOTHING_SECONDS = 0.75             # shorter window preserves fast dynamics for forecasting
SMOOTHING_METHOD = 'gaussian'       # {'moving_average', 'gaussian'}
GAUSSIAN_SIGMA_SECONDS = 0.5 * SMOOTHING_SECONDS  # tie sigma to window length for EDM
SMOOTHING_PAD_MODE = 'reflect'      # {'edge', 'reflect'}

# Batch-run switches.
SAVE_OUTPUTS = True  # toggle off to dry-run the generation loop
BATCH_SUBJECTS: Sequence[str] = []  # optionally restrict to these subjects
BATCH_STORIES: Sequence[str] = []   # optionally restrict to these stories
REPORT_TOKEN_COVERAGE = True

# Fallback embedding settings live under categories.fallback; the model path may
# alternatively be supplied via paths.fallback_embeddings.
fallback_cfg = categories_cfg.get('fallback', {}) or {}
FALLBACK_ENABLED = bool(fallback_cfg.get('enabled', True))
FALLBACK_MODEL_PATH = fallback_cfg.get('model_path') or paths.get('fallback_embeddings')
FALLBACK_BINARY = bool(fallback_cfg.get('binary', True))
FALLBACK_TRANSFORM_PATH = fallback_cfg.get('transform_path', 'misc/fallback_to_english1000.npz')
FALLBACK_LABEL = str(fallback_cfg.get('label', 'word2vec'))

# Placeholders; they are only populated when the fallback is enabled and loads successfully.
FALLBACK_MODEL = None
FALLBACK_ALIGNED = None
FALLBACK_INFO: Dict[str, Any] = {}

if FALLBACK_ENABLED:
    # Validate every prerequisite (gensim, model file, transform file) before loading.
    if KeyedVectors is None:
        raise ImportError('gensim is required to load word2vec-format embeddings for fallback vocab.')
    if not FALLBACK_MODEL_PATH:
        raise ValueError('Fallback embeddings enabled but no model_path provided (see categories.fallback.model_path).')
    fallback_model_path = Path(FALLBACK_MODEL_PATH)
    if not fallback_model_path.exists():
        raise FileNotFoundError(f'Fallback embedding model not found at {fallback_model_path}')
    transform_path = Path(FALLBACK_TRANSFORM_PATH)
    if not transform_path.exists():
        raise FileNotFoundError(f'Fallback alignment transform not found at {transform_path}')

    print(f'[INFO] Loading fallback embeddings from {fallback_model_path} ...')
    FALLBACK_MODEL = KeyedVectors.load_word2vec_format(str(fallback_model_path), binary=FALLBACK_BINARY)
    # The transform archive must provide 'rotation', 'fallback_mean' and 'english_mean';
    # 'tokens_used' is optional and only reported.
    transform_data = np.load(transform_path)
    rotation = np.asarray(transform_data['rotation'], dtype=np.float32)
    fallback_mean = np.asarray(transform_data['fallback_mean'], dtype=np.float32)
    english_mean = np.asarray(transform_data['english_mean'], dtype=np.float32)

    class AlignedFallback:
        """Wrap a KeyedVectors model and return its vectors after a stored affine map.

        Each vector is transformed as ``(vec - fallback_mean) @ rotation + english_mean``.
        Membership tests and indexing delegate to the wrapped model's vocabulary.
        """

        def __init__(self, kv: 'KeyedVectors', rotation: np.ndarray, fallback_mean: np.ndarray, english_mean: np.ndarray):
            """Store the wrapped model and the transform components."""
            self._kv = kv
            self.rotation = rotation
            self.fallback_mean = fallback_mean
            self.english_mean = english_mean
            self.key_to_index = kv.key_to_index
            # Output dimensionality is the column count of the rotation matrix.
            self.vector_size = rotation.shape[1]

        def get_vector(self, key: str) -> np.ndarray:
            """Return the transformed vector for ``key`` (looked up in the wrapped model)."""
            vec = self._kv.get_vector(key)
            return (vec - self.fallback_mean) @ self.rotation + self.english_mean

        def __contains__(self, key: str) -> bool:
            """Report whether ``key`` is in the wrapped model's vocabulary."""
            return key in self._kv.key_to_index

        def __getitem__(self, key: str) -> np.ndarray:
            """Indexing shorthand for :meth:`get_vector`."""
            return self.get_vector(key)

    FALLBACK_ALIGNED = AlignedFallback(FALLBACK_MODEL, rotation, fallback_mean, english_mean)
    overlap_tokens = int(transform_data['tokens_used'].shape[0]) if 'tokens_used' in transform_data else None
    # Summary used by the status prints below; 'dim' is the raw model's vector size.
    FALLBACK_INFO = {
        'model_path': str(fallback_model_path),
        'transform_path': str(transform_path),
        'tokens': len(FALLBACK_MODEL),
        'dim': FALLBACK_MODEL.vector_size,
        'overlap_tokens': overlap_tokens,
        'label': FALLBACK_LABEL,
    }
else:
    print('[INFO] Fallback embeddings disabled; English1000 only.')

# Echo the effective configuration so each run's settings are visible in the cell output.
print(f'Subject/story default: {SUBJECT} / {STORY}')
print(f'Cluster CSV: {cluster_csv_path or "<none>"}')
print(f'Temporal weighting: {TEMPORAL_WEIGHTING}')
print(f'Seconds bin width: {SECONDS_BIN_WIDTH}')
print(f'Smoothing: {SMOOTHING_METHOD} | window={SMOOTHING_SECONDS}s | sigma={GAUSSIAN_SIGMA_SECONDS}')
# FALLBACK_ALIGNED is only set when the fallback branch above ran to completion.
if FALLBACK_ALIGNED is not None:
    print(
        f"Fallback model: {FALLBACK_INFO['model_path']} | dim={FALLBACK_INFO['dim']} | tokens={FALLBACK_INFO['tokens']} "
        f"| overlap={FALLBACK_INFO.get('overlap_tokens', 'n/a')}"
    )
    print(f"Fallback transform: {FALLBACK_INFO['transform_path']}")
else:
    print('Fallback model: <disabled>')


In [None]:
# Helper functions

def load_story_words(paths: Dict, subject: str, story: str) -> List[Tuple[str, float, float]]:
    """Load the transcript for ``subject``/``story`` as normalized word events.

    Every event returned by ``load_transcript_words`` is unpacked as
    ``(word, start, end)``; the word is stringified and stripped and both
    timestamps are cast to ``float``.

    Raises:
        ValueError: If the loader returns an empty (falsy) result.
    """
    raw_events = load_transcript_words(paths, subject, story)
    if not raw_events:
        raise ValueError(f'No transcript events found for {subject} {story}.')
    normalized: List[Tuple[str, float, float]] = []
    for token, onset, offset in raw_events:
        normalized.append((str(token).strip(), float(onset), float(offset)))
    return normalized


def load_clusters_from_csv(csv_path: str) -> Dict[str, Dict[str, List[Tuple[str, float]]]]:
    """Parse a category/word(/weight) CSV into per-category word lists.

    Column names are matched case-insensitively after stripping whitespace.
    ``category`` and ``word`` are required; ``weight`` is optional and defaults
    to 1.0. Categories and words are stripped and lower-cased. Weights that
    cannot be parsed become 1.0 and negative weights are clipped to 0.0. When
    a word repeats within a category, the last weight wins.

    Rows whose category or word cell is missing are dropped. Without this,
    casting the NaN cell to ``str`` would create a spurious ``'nan'`` word or
    category.

    Args:
        csv_path: Path to the cluster CSV.

    Returns:
        ``{category: {'words': [(word, weight), ...]}}`` with each word list
        sorted by word.

    Raises:
        FileNotFoundError: If ``csv_path`` is empty or does not exist.
        ValueError: If a required column is missing or no clusters are parsed.
    """
    if not csv_path or not Path(csv_path).exists():
        raise FileNotFoundError(f'Cluster CSV not found at {csv_path}')
    df = pd.read_csv(csv_path)
    cols = {str(c).lower().strip(): c for c in df.columns}
    for needed in ('category', 'word'):
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if needed not in cols:
            raise ValueError(f"CSV must contain '{needed}' column.")
    cat_col = cols['category']
    word_col = cols['word']
    weight_col = cols.get('weight')
    if weight_col is None:
        df['_weight'] = 1.0
        weight_col = '_weight'
    # Drop rows with a missing category/word before the string cast below.
    df = df[[cat_col, word_col, weight_col]].dropna(subset=[cat_col, word_col]).copy()
    df[word_col] = df[word_col].astype(str).str.strip().str.lower()
    df[cat_col] = df[cat_col].astype(str).str.strip().str.lower()
    df[weight_col] = pd.to_numeric(df[weight_col], errors='coerce').fillna(1.0).clip(lower=0.0)
    clusters: Dict[str, Dict[str, List[Tuple[str, float]]]] = {}
    for cat, sub in df.groupby(cat_col):
        # Dict assignment de-duplicates words; later rows override earlier ones.
        bucket: Dict[str, float] = {
            w: float(wt)
            for w, wt in zip(sub[word_col].tolist(), sub[weight_col].tolist())
            if w
        }
        pairs = sorted(bucket.items())
        if pairs:
            clusters[cat] = {'words': pairs}
    if not clusters:
        raise ValueError('No clusters parsed from CSV.')
    return clusters


def build_states_from_csv(
    clusters: Dict[str, Dict[str, List[Tuple[str, float]]]],
    primary_lookup: Dict[str, np.ndarray],
    fallback=None,
    weight_power: float = 1.0
) -> Tuple[Dict[str, Dict], Dict[str, Dict]]:
    """Build category state and definition records from CSV-derived clusters.

    For each category, every ``(word, weight)`` pair is resolved through
    ``lookup_embedding``. Resolved vectors are combined into a weighted-mean
    prototype, with weights clipped at zero, raised to ``weight_power`` and
    normalized to sum to one. A category with no resolved word, or whose
    prototype norm falls below ``EPS``, gets ``prototype=None``.

    Returns:
        ``(category_states, category_definitions)``, both keyed by category.

    Warns:
        Once per category with no usable embeddings, and once overall if any
        category has out-of-vocabulary words.
    """
    states: Dict[str, Dict] = {}
    definitions: Dict[str, Dict] = {}
    missing_per_category: Dict[str, int] = {}

    for cat, spec in clusters.items():
        word_weight_pairs = spec.get('words', [])
        resolved_vectors: List[np.ndarray] = []
        resolved_weights: List[float] = []
        found_words: List[str] = []
        missing_words: List[str] = []

        for word, wt in word_weight_pairs:
            embedding = lookup_embedding(word, primary_lookup, fallback)
            if embedding is not None:
                resolved_vectors.append(embedding.astype(float))
                resolved_weights.append(float(max(0.0, wt)) ** float(weight_power))
                found_words.append(word)
            else:
                missing_words.append(word)

        prototype = None
        prototype_norm = None
        if resolved_vectors:
            mix = np.array(resolved_weights, dtype=float)
            mix = mix / (mix.sum() + 1e-12)
            stacked = np.stack(resolved_vectors, axis=0)
            candidate = (mix[:, None] * stacked).sum(axis=0)
            candidate_norm = float(np.linalg.norm(candidate))
            # A (near-)zero prototype cannot be used for cosine similarity.
            if not candidate_norm < EPS:
                prototype = candidate
                prototype_norm = candidate_norm
        else:
            warnings.warn(f"[{cat}] no usable representative embeddings; prototype will be None.")

        rep_lex = {word: float(wt) for word, wt in word_weight_pairs}
        prototype_dim = int(prototype.shape[0]) if isinstance(prototype, np.ndarray) else 0

        states[cat] = {
            'name': cat,
            'seeds': [],
            'found_seeds': found_words,
            'missing_seeds': missing_words,
            'prototype': prototype,
            'prototype_norm': prototype_norm,
            'lexicon': rep_lex,
            'expanded_count': 0,
            'expansion_params': {'enabled': False, 'top_k': 0, 'min_sim': 0.0},
        }
        definitions[cat] = {
            'from': 'csv',
            'seeds': [],
            'found_seeds': found_words,
            'missing_seeds': missing_words,
            'prototype_dim': prototype_dim,
            'prototype_norm': prototype_norm,
            'representative_words': rep_lex,
            'lexicon': rep_lex,
            'expanded_neighbors': {},
        }
        missing_per_category[cat] = len(missing_words)

    if any(missing_per_category.values()):
        warnings.warn(f"OOV representative words: {missing_per_category}")
    return states, definitions


def build_tr_edges(word_events: Sequence[Tuple[str, float, float]], tr_s: float) -> np.ndarray:
    """Return TR bin edges starting at 0 that cover every word event.

    Edges are computed as ``tr_s * [0, 1, ..., n_tr]`` rather than with a
    float-step ``np.arange``. With a non-integer ``tr_s`` the float-step form
    can emit one extra edge from rounding, which adds a spurious trailing bin.

    Args:
        word_events: ``(word, start, end)`` tuples; only ``end`` is used.
        tr_s: TR length in seconds; must be positive and finite.

    Returns:
        A float array of ``n_tr + 1`` edges with at least one bin. For empty
        input the result is the single edge ``[0.0]`` (zero bins).

    Raises:
        ValueError: If ``tr_s`` is not a positive finite number.
    """
    tr_s = float(tr_s)
    if not math.isfinite(tr_s) or tr_s <= 0.0:
        raise ValueError(f'tr_s must be a positive finite number, got {tr_s!r}.')
    if not word_events:
        return np.zeros(1, dtype=float)
    max_end = max(float(end) for _, _, end in word_events)
    n_tr = max(1, int(math.ceil(max_end / tr_s)))
    edges = tr_s * np.arange(n_tr + 1, dtype=float)
    # Guard against n_tr * tr_s landing just short of max_end after rounding.
    if edges[-1] < max_end - 1e-9:
        edges = np.append(edges, edges[-1] + tr_s)
    return edges


def lookup_embedding(
    token: str,
    primary_lookup: Dict[str, np.ndarray],
    fallback=None,
    *,
    return_source: bool = False,
) -> Optional[np.ndarray]:
    """Resolve ``token`` to a float vector, trying the primary table first.

    The token is lower-cased and stripped before lookup. If the primary
    mapping has no entry, ``fallback`` is consulted: ``get_vector`` is used
    when available, otherwise plain indexing for containers that define
    ``__contains__``. Any exception raised by the fallback is treated as a
    miss.

    Returns:
        The vector, or ``None`` on a miss. With ``return_source=True`` the
        result is a ``(vector, source)`` tuple where ``source`` is
        ``'primary'``, ``'fallback'`` or ``None``.
    """
    def _result(vector, origin):
        # Shape the return value according to ``return_source``.
        return (vector, origin) if return_source else vector

    key = token.lower().strip()
    if not key:
        return _result(None, None)

    primary_hit = primary_lookup.get(key) if primary_lookup else None
    if primary_hit is not None:
        return _result(np.asarray(primary_hit, dtype=float), 'primary')

    if fallback is None:
        return _result(None, None)

    secondary_hit = None
    try:
        has_getter = hasattr(fallback, 'get_vector')
        if (has_getter or hasattr(fallback, '__contains__')) and key in fallback:
            raw = fallback.get_vector(key) if has_getter else fallback[key]
            secondary_hit = np.asarray(raw, dtype=float)
    except Exception:
        secondary_hit = None

    if secondary_hit is None:
        return _result(None, None)
    return _result(secondary_hit, 'fallback')


def make_category_prototype(seeds: Sequence[str], primary_lookup: Dict[str, np.ndarray], fallback=None, allow_single: bool = False) -> Tuple[Optional[np.ndarray], List[str], List[str]]:
    """Average the embeddings of ``seeds`` into a single prototype vector.

    Seeds are resolved with ``lookup_embedding``. No prototype is produced
    when nothing resolves, or when fewer than two seeds resolve and
    ``allow_single`` is false (a warning is emitted in the latter case).

    Returns:
        ``(prototype_or_None, found_seeds, missing_seeds)``.
    """
    vectors = []
    resolved = []
    unresolved = []
    for seed in seeds:
        embedding = lookup_embedding(seed, primary_lookup, fallback)
        if embedding is not None:
            vectors.append(embedding)
            resolved.append(seed)
        else:
            unresolved.append(seed)

    if not vectors:
        return None, resolved, unresolved

    if not allow_single and len(vectors) < 2:
        warnings.warn(f'Only {len(vectors)} usable seed(s); enable allow_single_seed to accept singleton prototypes.')
        return None, resolved, unresolved

    return np.mean(vectors, axis=0), resolved, unresolved


def expand_category(prototype: np.ndarray, vocab_embeddings: np.ndarray, vocab_words: Sequence[str], top_k: int, min_sim: float) -> Dict[str, float]:
    """Return vocabulary words most cosine-similar to ``prototype``.

    Rows of ``vocab_embeddings`` with zero norm are assigned similarity -1.
    The ``top_k`` highest-scoring rows are selected (unordered, via
    ``np.argpartition``) and those scoring below ``min_sim`` are discarded.

    Returns:
        ``{word: similarity}``. The dict is empty when any input is ``None``,
        the prototype has zero norm, or ``top_k`` is non-positive.
    """
    if prototype is None or vocab_embeddings is None or vocab_words is None:
        return {}

    direction = np.asarray(prototype, dtype=float)
    magnitude = np.linalg.norm(direction)
    if magnitude == 0:
        return {}
    unit = direction / magnitude

    row_norms = np.linalg.norm(vocab_embeddings, axis=1)
    nonzero_rows = row_norms > 0
    cosine = np.full(vocab_embeddings.shape[0], -1.0, dtype=float)
    cosine[nonzero_rows] = (vocab_embeddings[nonzero_rows] @ unit) / row_norms[nonzero_rows]

    keep = min(top_k, len(cosine))
    if keep <= 0:
        return {}

    shortlist = np.argpartition(-cosine, keep - 1)[:keep]
    neighbours: Dict[str, float] = {}
    for row in shortlist:
        similarity = float(cosine[row])
        if not similarity < min_sim:
            neighbours[vocab_words[row]] = similarity
    return neighbours


def tr_token_overlap(token_start: float, token_end: float, tr_start: float, tr_end: float, mode: str = 'proportional') -> float:
    """Weight of a token within the window ``[tr_start, tr_end)``.

    Tokens whose end is not after their start are given a nominal 1 ms
    duration. In ``'midpoint'`` mode the result is 1.0 when the token's
    midpoint lies inside the window, else 0.0. In every other mode the result
    is the fraction of the token's duration that overlaps the window, clamped
    to ``[0, 1]``.
    """
    begin = float(token_start)
    finish = float(token_end)
    if finish <= begin:
        finish = begin + 1e-3

    if mode == 'midpoint':
        centre = 0.5 * (begin + finish)
        inside = tr_start <= centre < tr_end
        return 1.0 if inside else 0.0

    shared = max(0.0, min(finish, tr_end) - max(begin, tr_start))
    span = finish - begin
    if span <= 0:
        return 1.0 if shared > 0 else 0.0
    fraction = shared / span
    return max(0.0, min(1.0, fraction))


def score_tr(token_payload: Sequence[Dict], method: str, *, lexicon: Optional[Dict[str, float]] = None, prototype: Optional[np.ndarray] = None, prototype_norm: Optional[float] = None) -> float:
    """Score one time bin's tokens against a category.

    ``'count'`` sums ``lexicon[word] * overlap`` over tokens whose lower-cased
    word is in ``lexicon``. ``'similarity'`` is the overlap-weighted mean
    cosine similarity between token embeddings and ``prototype``, clipped to
    ``[-1, 1]``; tokens without an embedding or with a norm below ``EPS`` are
    ignored.

    Returns:
        The score, or NaN when the bin is empty or the method has nothing to
        work with (no lexicon, no prototype, or no contributing token).

    Raises:
        ValueError: If ``method`` is neither ``'count'`` nor ``'similarity'``.
    """
    undefined = float('nan')
    if not token_payload:
        return undefined

    method = method.lower()
    if method not in ('count', 'similarity'):
        raise ValueError(f'Unknown scoring method: {method}')

    if method == 'count':
        if not lexicon:
            return undefined
        running = 0.0
        for entry in token_payload:
            word_weight = lexicon.get(entry['word'].lower())
            if word_weight is not None:
                running += word_weight * entry['overlap']
        return float(running)

    # method == 'similarity'
    if prototype is None or prototype_norm is None or prototype_norm < EPS:
        return undefined
    weighted_sum = 0.0
    total_overlap = 0.0
    for entry in token_payload:
        vector = entry.get('embedding')
        if vector is None:
            continue
        vector_norm = entry.get('embedding_norm')
        if vector_norm is None or vector_norm < EPS:
            continue
        cosine = float(np.dot(vector, prototype) / (vector_norm * prototype_norm))
        weighted_sum += cosine * entry['overlap']
        total_overlap += entry['overlap']
    if total_overlap == 0:
        return undefined
    return float(np.clip(weighted_sum / total_overlap, -1.0, 1.0))


def ensure_serializable(obj):
    """Recursively convert NumPy containers and scalars to plain Python objects.

    Arrays become nested lists. Every NumPy scalar (any ``np.generic``, which
    includes ``np.bool_`` as well as float and integer types) becomes the
    matching Python scalar. Dicts are rebuilt with converted values, and
    NumPy-scalar keys are converted too, because ``json.dumps`` rejects them.
    Lists and tuples become lists. Anything else is returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, dict):
        return {
            (key.item() if isinstance(key, np.generic) else key): ensure_serializable(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [ensure_serializable(value) for value in obj]
    return obj


def build_token_buckets(edges: np.ndarray, event_records: Sequence[Dict], mode: str = 'proportional') -> List[List[Dict]]:
    """Distribute word events into the time bins delimited by ``edges``.

    Args:
        edges: Monotonic bin edges; fewer than two edges yields no bins.
        event_records: Dicts with ``word``, ``start``, ``end``, ``embedding``
            and ``embedding_norm`` keys.
        mode: Temporal weighting. ``'none'`` gives weight 1.0 to every bin the
            token intersects. ``'midpoint'`` gives weight 1.0 only to the bin
            containing the token midpoint. Any other value weights each bin by
            the fraction of the token's duration that falls inside it.

    Returns:
        One list per bin. Each entry copies the record's word/embedding fields
        and adds ``overlap``, the token times and the bin bounds. Tokens with
        zero weight in a bin are not added to it.
    """
    if edges.size < 2:
        return []
    buckets: List[List[Dict]] = [[] for _ in range(len(edges) - 1)]
    mode = str(mode or 'proportional').lower()
    # 'midpoint' used to be silently treated as proportional; forward it now.
    overlap_mode = 'midpoint' if mode == 'midpoint' else 'proportional'
    last_bucket = len(buckets) - 1
    for rec in event_records:
        start = rec['start']
        end = rec['end']
        # Skip tokens that lie entirely outside the covered time range.
        if end <= edges[0] or start >= edges[-1]:
            continue
        # Candidate bins: from the bin containing `start` to the bin reaching `end`.
        start_idx = max(0, int(np.searchsorted(edges, start, side='right')) - 1)
        end_idx = max(0, int(np.searchsorted(edges, end, side='left')))
        end_idx = min(end_idx, last_bucket)
        for idx in range(start_idx, end_idx + 1):
            bucket_start = edges[idx]
            bucket_end = edges[idx + 1]
            if mode == 'none':
                overlap = 1.0 if not (end <= bucket_start or start >= bucket_end) else 0.0
            else:
                overlap = tr_token_overlap(start, end, bucket_start, bucket_end, overlap_mode)
            if overlap <= 0:
                continue
            buckets[idx].append({
                'word': rec['word'],
                'overlap': overlap,
                'embedding': rec['embedding'],
                'embedding_norm': rec['embedding_norm'],
                'token_start': rec['start'],
                'token_end': rec['end'],
                'bucket_start': bucket_start,
                'bucket_end': bucket_end,
            })
    return buckets


def score_time_series(edges: np.ndarray, buckets: Sequence[Sequence[Dict]], category_states: Dict[str, Dict], category_names: Sequence[str], category_columns: Sequence[str], method: str, index_name: str) -> Tuple[pd.DataFrame, np.ndarray]:
    """Score every bin against every category.

    Each cell is ``score_tr`` applied to one bucket, using the category's
    ``lexicon``, ``prototype`` and ``prototype_norm`` from ``category_states``.

    Returns:
        ``(frame, matrix)``. ``matrix`` has shape ``(n_bins, n_categories)``.
        ``frame`` holds an integer bin index under ``index_name``, the
        ``start_sec``/``end_sec`` bin bounds and one column per entry of
        ``category_columns``.
    """
    bin_count = len(buckets)
    scores = np.full((bin_count, len(category_names)), np.nan, dtype=float)

    for column, name in enumerate(category_names):
        state = category_states[name]
        scoring_inputs = {
            'lexicon': state.get('lexicon'),
            'prototype': state.get('prototype'),
            'prototype_norm': state.get('prototype_norm'),
        }
        for row, bucket in enumerate(buckets):
            scores[row, column] = score_tr(bucket, method, **scoring_inputs)

    table = {
        index_name: np.arange(bin_count, dtype=int),
        'start_sec': edges[:-1],
        'end_sec': edges[1:],
    }
    table.update({label: scores[:, column] for column, label in enumerate(category_columns)})
    return pd.DataFrame(table), scores


def build_smoothing_kernel(seconds_bin_width: float, smoothing_seconds: float, *, method: str = 'moving_average', gaussian_sigma_seconds: Optional[float] = None) -> np.ndarray:
    """Build a normalized 1-D smoothing kernel sampled at ``seconds_bin_width``.

    ``'moving_average'`` gives a flat window of ``smoothing_seconds`` rounded
    to an odd sample count. ``'gaussian'`` gives a Gaussian truncated at about
    three sigma. Sigma defaults to half of ``smoothing_seconds``, but never
    less than one bin.

    Returns:
        A kernel summing to 1. The identity kernel ``[1.0]`` is returned when
        ``smoothing_seconds`` is non-positive or the kernel sum is not positive.

    Raises:
        ValueError: For an unknown ``method``.
    """
    identity = np.array([1.0], dtype=float)
    if smoothing_seconds <= 0:
        return identity

    method = str(method or 'moving_average').lower()
    if method == 'gaussian':
        if gaussian_sigma_seconds in (None, ''):
            sigma_seconds = max(smoothing_seconds / 2.0, seconds_bin_width)
        else:
            sigma_seconds = float(gaussian_sigma_seconds)
        sigma_samples = max(sigma_seconds / seconds_bin_width, 1e-6)
        reach = max(1, int(round(3.0 * sigma_samples)))
        offsets = np.arange(-reach, reach + 1, dtype=float)
        weights = np.exp(-0.5 * (offsets / sigma_samples) ** 2)
    elif method == 'moving_average':
        width = max(1, int(round(smoothing_seconds / seconds_bin_width)))
        # Force an odd width so the window is centred on a sample.
        width = width if width % 2 else width + 1
        weights = np.ones(width, dtype=float)
    else:
        raise ValueError(f"Unknown smoothing method: {method}")

    total = float(weights.sum())
    if total <= 0:
        return identity
    return weights / total


def apply_smoothing_kernel(values: np.ndarray, kernel: np.ndarray, *, pad_mode: str = 'edge', eps: float = 1e-8) -> np.ndarray:
    """Smooth each column of a 2-D array with ``kernel`` while ignoring non-finite samples.

    The array is padded along the first axis by half the kernel length using
    ``pad_mode`` (anything other than ``'edge'``/``'reflect'`` falls back to
    ``'edge'``). Each output sample is the kernel-weighted sum of the finite
    inputs divided by the kernel mass that landed on finite inputs. Samples
    where that mass is below ``eps`` become NaN.

    Returns:
        An array shaped like ``values``, or a plain copy when ``values`` is
        empty or the kernel has at most one tap.
    """
    if values.size == 0 or kernel.size <= 1:
        return values.copy()

    if pad_mode not in {'edge', 'reflect'}:
        pad_mode = 'edge'
    reach = kernel.size // 2
    padded = np.pad(values, ((reach, reach), (0, 0)), mode=pad_mode)

    is_finite = np.isfinite(padded)
    presence = is_finite.astype(float)
    zero_filled = np.where(is_finite, padded, 0.0)

    n_rows, n_cols = values.shape[0], values.shape[1]
    result = np.empty((n_rows, n_cols), dtype=float)
    for j in range(n_cols):
        weighted = np.convolve(zero_filled[:, j], kernel, mode='valid')
        mass = np.convolve(presence[:, j], kernel, mode='valid')
        with np.errstate(divide='ignore', invalid='ignore'):
            column = weighted / np.maximum(mass, eps)
        # No finite support under the kernel: leave the sample undefined.
        column[mass < eps] = np.nan
        result[:, j] = column
    return result


def aggregate_seconds_to_edges(canonical_edges: np.ndarray, canonical_values: np.ndarray, target_edges: np.ndarray) -> np.ndarray:
    """Average fine-grained rows into the coarser bins defined by ``target_edges``.

    Each canonical row is assigned to a target bin by the midpoint of its
    canonical bin; midpoints outside the target range are clipped into the
    first/last bin. Per target bin and column, the result is the NaN-ignoring
    mean of the assigned rows. It stays NaN when the bin received no row or
    the column has no finite value there.

    Returns:
        Array of shape ``(len(target_edges) - 1, n_columns)``, or a zero-column
        array when ``canonical_values`` is empty.
    """
    n_target = len(target_edges) - 1
    if canonical_values.size == 0:
        return np.empty((n_target, 0), dtype=float)

    centres = 0.5 * (canonical_edges[:-1] + canonical_edges[1:])
    assignment = np.digitize(centres, target_edges) - 1
    if assignment.size:
        assignment = np.clip(assignment, 0, n_target - 1)

    aggregated = np.full((n_target, canonical_values.shape[1]), np.nan, dtype=float)
    for row in range(aggregated.shape[0]):
        members = assignment == row
        if not np.any(members):
            continue
        block = canonical_values[members]
        if block.ndim == 1:
            block = block[:, None]
        usable = np.isfinite(block).any(axis=0)
        if usable.any():
            means = np.full(block.shape[1], np.nan, dtype=float)
            # Restrict nanmean to columns with data to avoid empty-slice warnings.
            means[usable] = np.nanmean(block[:, usable], axis=0)
            aggregated[row] = means
    return aggregated


In [None]:
def generate_category_time_series(
    subject: str,
    story: str,
    *,
    cfg_base: Dict[str, Any],
    categories_cfg_base: Dict[str, Any],
    cluster_csv_path: str,
    temporal_weighting: str,
    prototype_weight_power: float,
    smoothing_seconds: float,
    smoothing_method: str,
    gaussian_sigma_seconds: Optional[float],
    smoothing_pad: str,
    seconds_bin_width: float,
    fallback_model: Optional[Any] = None,
    record_coverage: bool = True,
    save_outputs: bool = True,
) -> Dict[str, Any]:
    if not subject or not story:
        raise ValueError('Subject and story must be provided.')
    print(f"=== Day21 category build for {subject} / {story} ===")

    categories_cfg = json.loads(json.dumps(categories_cfg_base or {}))
    categories_cfg['seconds_bin_width'] = float(seconds_bin_width)
    category_sets = categories_cfg.get('sets', {})
    available_sets = sorted(category_sets.keys())
    category_set_name = categories_cfg.get('category_set') or (available_sets[0] if available_sets else None)
    if cluster_csv_path:
        if not category_set_name:
            category_set_name = 'csv_clusters'
        categories_cfg['category_set'] = category_set_name
        categories_cfg['category_score_method'] = 'similarity'
        categories_cfg['allow_single_seed'] = True
        categories_cfg['expansion'] = {'enabled': False}
    category_score_method = str(categories_cfg.get('category_score_method', 'similarity')).lower()
    overlap_mode = str(categories_cfg.get('overlap_weighting', 'proportional')).lower()
    expansion_cfg = categories_cfg.get('expansion', {})
    allow_single = bool(categories_cfg.get('allow_single_seed', False))
    exp_enabled = bool(expansion_cfg.get('enabled', True))
    exp_top_k = int(expansion_cfg.get('top_k', 2000)) if exp_enabled else 0
    exp_min_sim = float(expansion_cfg.get('min_sim', 0.35)) if exp_enabled else 0.0

    selected_set_spec = category_sets.get(category_set_name, {}) if category_sets else {}

    output_root = features_root / 'subjects' / subject / story
    canonical_root = features_root / 'stories' / story
    if save_outputs:
        output_root.mkdir(parents=True, exist_ok=True)
        canonical_root.mkdir(parents=True, exist_ok=True)

    story_events = load_story_words(paths, subject, story)
    print(f'Loaded {len(story_events)} transcript events.')
    tr_edges = build_tr_edges(story_events, TR)
    n_tr = len(tr_edges) - 1
    print(f'TR edges: {len(tr_edges)} (n_tr={n_tr}) spanning {tr_edges[-1]:.2f} seconds.')

    embedding_source = str(categories_cfg.get('embedding_source', 'english1000')).lower()
    english_loader = None
    english_lookup: Dict[str, np.ndarray] = {}
    english_vocab: List[str] = []
    english_matrix = None
    if embedding_source in {'english1000', 'both'}:
        english1000_path = Path(paths.get('data_root', '')) / 'derivative' / 'english1000sm.hf5'
        if english1000_path.exists():
            english_loader = English1000Loader(english1000_path)
            english_lookup = english_loader.lookup
            english_vocab = english_loader.vocab
            english_matrix = english_loader.embeddings
            print(f'Loaded English1000 embeddings from {english1000_path} (vocab={len(english_vocab)}).')
        else:
            raise FileNotFoundError(f'English1000 embeddings not found at {english1000_path}')
    else:
        print('English1000 disabled by configuration.')

    word2vec_model = fallback_model
    if word2vec_model is None and embedding_source in {'word2vec', 'both'}:
        w2v_path = categories_cfg.get('word2vec_path')
        if w2v_path:
            w2v_path = Path(w2v_path)
            if w2v_path.exists():
                try:
                    if KeyedVectors is None:
                        raise ImportError('gensim is required for word2vec fallback loading.')
                    binary = w2v_path.suffix.lower() in {'.bin', '.gz'}
                    word2vec_model = KeyedVectors.load_word2vec_format(w2v_path, binary=binary)
                    print(f'Loaded Word2Vec fallback from {w2v_path}.')
                except Exception as exc:
                    warnings.warn(f'Failed to load Word2Vec fallback: {exc}')
            else:
                warnings.warn(f'Word2Vec path does not exist: {w2v_path}')
        else:
            warnings.warn('Word2Vec fallback requested but no path provided.')
    elif word2vec_model is not None:
        print('Using pre-loaded fallback embedding model (aligned).')
    else:
        print('Word2Vec fallback disabled.')

    if cluster_csv_path:
        csv_clusters = load_clusters_from_csv(cluster_csv_path)
        category_states, category_definitions = build_states_from_csv(
            csv_clusters,
            english_lookup,
            word2vec_model,
            weight_power=prototype_weight_power,
        )
        category_names = sorted(category_states.keys())
        category_columns = [f'cat_{name}' for name in category_names]
        print(f"Loaded {len(category_names)} CSV-driven categories from {cluster_csv_path}: {category_names}")
        zero_norm = [k for k, v in category_states.items() if v.get('prototype') is not None and (v.get('prototype_norm') or 0.0) < EPS]
        if zero_norm:
            warnings.warn(f"Zero-norm prototypes (check OOV/weights): {zero_norm}")
    else:
        category_states = {}
        category_definitions = {}
        seed_oov_counter = Counter()
        for cat_name, cat_spec in selected_set_spec.items():
            seeds = cat_spec.get('seeds', [])
            explicit_words = cat_spec.get('words', [])
            prototype = None
            found_seeds: List[str] = []
            missing_seeds: List[str] = []
            if seeds:
                prototype, found_seeds, missing_seeds = make_category_prototype(seeds, english_lookup, word2vec_model, allow_single)
                seed_oov_counter[cat_name] = len(missing_seeds)
                if prototype is None and category_score_method == 'similarity':
                    warnings.warn(f"Category '{cat_name}' has no usable prototype; TR scores will be NaN.")
            elif category_score_method == 'similarity':
                warnings.warn(f'Category {cat_name} has no seeds; similarity method will yield NaNs.')
            lexicon = {word.lower(): 1.0 for word in explicit_words}
            for seed in found_seeds:
                lexicon.setdefault(seed.lower(), 1.0)
            prototype_norm = None
            expanded_words = {}
            if prototype is not None:
                prototype_norm = float(np.linalg.norm(prototype))
                if exp_enabled and english_matrix is not None:
                    expanded_words = expand_category(prototype, english_matrix, english_vocab, exp_top_k, exp_min_sim)
                    for word, weight in expanded_words.items():
                        lexicon.setdefault(word.lower(), float(weight))
            if not lexicon and category_score_method == 'count':
                warnings.warn(f'Category {cat_name} lexicon is empty; counts will be NaN.')
            category_states[cat_name] = {
                'name': cat_name,
                'seeds': seeds,
                'found_seeds': found_seeds,
                'missing_seeds': missing_seeds,
                'prototype': prototype,
                'prototype_norm': prototype_norm,
                'lexicon': lexicon,
                'expanded_count': len(expanded_words),
                'expansion_params': {
                    'enabled': exp_enabled,
                    'top_k': exp_top_k,
                    'min_sim': exp_min_sim,
                },
            }
            category_definitions[cat_name] = {
                'seeds': seeds,
                'found_seeds': found_seeds,
                'missing_seeds': missing_seeds,
                'prototype_dim': int(prototype.shape[0]) if isinstance(prototype, np.ndarray) else 0,
                'prototype_norm': prototype_norm,
                'expanded_neighbors': ensure_serializable(expanded_words),
                'lexicon': {word: float(weight) for word, weight in sorted(category_states[cat_name]['lexicon'].items())},
            }
        print('Category seeds missing counts:', dict(seed_oov_counter))
        category_names = sorted(category_states.keys())
        category_columns = [f'cat_{name}' for name in category_names]
        print(f'Prepared {len(category_names)} categories: {category_names}')

    tw_mode = str(temporal_weighting or 'proportional').lower()
    if tw_mode not in {'proportional', 'none', 'midpoint'}:
        raise ValueError(f'Unsupported temporal weighting: {tw_mode}')

    seconds_bin_width = float(seconds_bin_width)
    if seconds_bin_width <= 0:
        raise ValueError('seconds_bin_width must be positive.')
    smoothing_method = str(smoothing_method or 'moving_average').lower()
    gaussian_sigma_seconds = gaussian_sigma_seconds if gaussian_sigma_seconds not in (None, '') else None
    smoothing_pad = str(smoothing_pad or 'edge').lower()
    if smoothing_pad not in {'edge', 'reflect'}:
        smoothing_pad = 'edge'

    embedding_cache: Dict[str, Tuple[Optional[np.ndarray], Optional[str]]] = {}
    event_records: List[Dict] = []
    tokens_with_embeddings = 0
    tokens_primary_hits = 0
    tokens_fallback_hits = 0
    unique_all: set[str] = set()
    unique_primary: set[str] = set()
    unique_fallback: set[str] = set()
    for word, onset, offset in story_events:
        token = word.strip()
        if not token:
            continue
        key = token.lower()
        unique_all.add(key)
        if key not in embedding_cache:
            emb, source = lookup_embedding(token, english_lookup, word2vec_model, return_source=True)
            embedding_cache[key] = (emb, source)
        else:
            emb, source = embedding_cache[key]
        emb_norm = float(np.linalg.norm(emb)) if emb is not None else None
        if emb is not None:
            tokens_with_embeddings += 1
            if source == 'primary':
                tokens_primary_hits += 1
                unique_primary.add(key)
            elif source == 'fallback':
                tokens_fallback_hits += 1
                unique_fallback.add(key)
        event_records.append({
            'word': token,
            'start': float(onset),
            'end': float(offset),
            'embedding': emb,
            'embedding_norm': emb_norm,
        })

    total_tokens = len(event_records)
    combined_hits = tokens_primary_hits + tokens_fallback_hits
    primary_pct = (100.0 * tokens_primary_hits / total_tokens) if total_tokens else 0.0
    fallback_pct = (100.0 * tokens_fallback_hits / total_tokens) if total_tokens else 0.0
    combined_pct = (100.0 * combined_hits / total_tokens) if total_tokens else 0.0
    unique_total = len(unique_all)
    unique_combined = len(unique_primary | unique_fallback)
    tokens_oov = total_tokens - combined_hits
    print(f'Tokens with embeddings: {combined_hits}/{total_tokens} (OOV rate={tokens_oov / max(total_tokens, 1):.2%}).')
    if record_coverage:
        print(
            f'Token coverage (primary={primary_pct:.2f}%, fallback={fallback_pct:.2f}%, combined={combined_pct:.2f}%)'
        )

    coverage = None
    if record_coverage:
        coverage = {
            'tokens_total': total_tokens,
            'tokens_primary': tokens_primary_hits,
            'tokens_fallback': tokens_fallback_hits,
            'tokens_combined': combined_hits,
            'tokens_oov': tokens_oov,
            'pct_tokens_primary': primary_pct,
            'pct_tokens_fallback': fallback_pct,
            'pct_tokens_combined': combined_pct,
            'unique_total': unique_total,
            'unique_primary': len(unique_primary),
            'unique_fallback': len(unique_fallback),
            'unique_combined': unique_combined,
            'pct_unique_primary': (100.0 * len(unique_primary) / unique_total) if unique_total else 0.0,
            'pct_unique_fallback': (100.0 * len(unique_fallback) / unique_total) if unique_total else 0.0,
            'pct_unique_combined': (100.0 * unique_combined / unique_total) if unique_total else 0.0,
        }
    if not event_records:
        raise ValueError('No token events available for category featurization.')

    max_end_time = max(rec['end'] for rec in event_records)
    canonical_edges = np.arange(0.0, max_end_time + seconds_bin_width, seconds_bin_width, dtype=float)
    if canonical_edges[-1] < max_end_time:
        canonical_edges = np.append(canonical_edges, canonical_edges[-1] + seconds_bin_width)
    if canonical_edges[-1] < max_end_time - 1e-9:
        canonical_edges = np.append(canonical_edges, canonical_edges[-1] + seconds_bin_width)
    assert np.all(np.diff(canonical_edges) > 0), 'Non-monotone canonical edges.'

    canonical_buckets = build_token_buckets(canonical_edges, event_records, tw_mode)
    empty_canonical = sum(1 for bucket in canonical_buckets if not bucket)
    print(f'Canonical bins without tokens: {empty_canonical}/{len(canonical_buckets)}')

    canonical_df_raw, canonical_matrix = score_time_series(
        canonical_edges,
        canonical_buckets,
        category_states,
        category_names,
        category_columns,
        category_score_method,
        index_name='bin_index',
    )
    canonical_values_raw = canonical_matrix.copy()
    smoothing_kernel = build_smoothing_kernel(
        seconds_bin_width,
        smoothing_seconds,
        method=smoothing_method,
        gaussian_sigma_seconds=gaussian_sigma_seconds,
    )
    smoothing_applied = smoothing_kernel.size > 1
    if canonical_values_raw.size and smoothing_applied:
        canonical_values_smoothed = apply_smoothing_kernel(canonical_values_raw, smoothing_kernel, pad_mode=smoothing_pad)
    else:
        canonical_values_smoothed = canonical_values_raw.copy()

    canonical_df_smoothed = canonical_df_raw.copy()
    if category_columns:
        canonical_df_smoothed.loc[:, category_columns] = canonical_values_smoothed
    canonical_df_selected = canonical_df_smoothed if smoothing_applied else canonical_df_raw

    if save_outputs:
        canonical_root.mkdir(parents=True, exist_ok=True)
        canonical_csv_path = canonical_root / 'category_timeseries_seconds.csv'
        canonical_df_selected.to_csv(canonical_csv_path, index=False)
        if smoothing_applied:
            canonical_df_raw.to_csv(canonical_root / 'category_timeseries_seconds_raw.csv', index=False)
        canonical_definition_path = canonical_root / 'category_definition.json'
        with canonical_definition_path.open('w') as fh:
            json.dump(ensure_serializable(category_definitions), fh, indent=2)
        print(f'Saved canonical story series to {canonical_csv_path}')

    tr_buckets = build_token_buckets(tr_edges, event_records, tw_mode)
    empty_tr = sum(1 for bucket in tr_buckets if not bucket)
    print(f'TRs without tokens: {empty_tr}/{len(tr_buckets)}')

    if category_columns:
        tr_values_raw = aggregate_seconds_to_edges(canonical_edges, canonical_values_raw, tr_edges)
        tr_values_smoothed = aggregate_seconds_to_edges(canonical_edges, canonical_values_smoothed, tr_edges)
    else:
        tr_values_raw = np.empty((len(tr_edges) - 1, 0), dtype=float)
        tr_values_smoothed = tr_values_raw

    base_index = np.arange(len(tr_edges) - 1, dtype=int)
    base_df = pd.DataFrame({'tr_index': base_index, 'start_sec': tr_edges[:-1], 'end_sec': tr_edges[1:]})
    category_df_raw = base_df.copy()
    category_df_smoothed = base_df.copy()
    if category_columns:
        category_df_raw.loc[:, category_columns] = tr_values_raw
        category_df_smoothed.loc[:, category_columns] = tr_values_smoothed
    category_df = category_df_smoothed if smoothing_applied else category_df_raw
    print(category_df.head())

    if category_score_method == 'similarity' and category_columns:
        finite_vals = category_df[category_columns].to_numpy(dtype=float)
        finite_vals = finite_vals[np.isfinite(finite_vals)]
        if finite_vals.size:
            assert np.nanmin(finite_vals) >= -1.0001 and np.nanmax(finite_vals) <= 1.0001, 'Similarity scores out of bounds.'
    else:
        if category_columns:
            assert (category_df[category_columns].fillna(0.0) >= -1e-9).all().all(), 'Count scores must be non-negative.'

    if save_outputs:
        output_root.mkdir(parents=True, exist_ok=True)
        category_csv_path = output_root / 'category_timeseries.csv'
        category_df.to_csv(category_csv_path, index=False)
        if smoothing_applied:
            category_df_raw.to_csv(output_root / 'category_timeseries_raw.csv', index=False)
        definition_path = output_root / 'category_definition.json'
        with definition_path.open('w') as fh:
            json.dump(ensure_serializable(category_definitions), fh, indent=2)
        print(f'Saved category time series to {category_csv_path}')

    trimmed_path = Path(paths.get('figs', 'figs')) / subject / story / 'day16_decoding' / 'semantic_pcs_trimmed.csv'
    max_lag_primary = 0
    trimmed_df = None
    if trimmed_path.exists():
        day16_trim = pd.read_csv(trimmed_path)
        expected_len = len(day16_trim)
        if len(day16_trim) > len(category_df):
            raise ValueError('Day16 trimmed series longer than category series; regenerate Day16 or rerun Day17.')
        max_lag_primary = max(0, len(category_df) - expected_len)
        trimmed_df = category_df.iloc[max_lag_primary:].reset_index(drop=True)
        if save_outputs:
            trimmed_out = trimmed_df.copy()
            trimmed_out.insert(0, 'trim_index', np.arange(len(trimmed_out), dtype=int))
            trimmed_out.drop(columns=['tr_index'], inplace=True, errors='ignore')
            trimmed_out.to_csv(output_root / 'category_timeseries_trimmed.csv', index=False)
            print(f'Saved trimmed category series to {output_root / "category_timeseries_trimmed.csv"}')
    else:
        warnings.warn('Day16 trimmed PCs not found; skipping auto-alignment.')

    smoothing_meta = {
        'applied': bool(smoothing_applied),
        'seconds': smoothing_seconds,
        'method': smoothing_method,
        'gaussian_sigma_seconds': float(gaussian_sigma_seconds) if gaussian_sigma_seconds is not None else None,
        'kernel_size': int(smoothing_kernel.size),
        'pad_mode': smoothing_pad,
        'bin_width_seconds': seconds_bin_width,
    }

    return {
        'subject': subject,
        'story': story,
        'temporal_weighting': tw_mode,
        'category_columns': category_columns,
        'category_states': category_states,
        'category_definitions': category_definitions,
        'category_score_method': category_score_method,
        'event_records': event_records,
        'canonical_buckets': canonical_buckets,
        'tr_buckets': tr_buckets,
        'canonical_df_raw': canonical_df_raw,
        'canonical_df_smoothed': canonical_df_smoothed,
        'canonical_df_selected': canonical_df_selected,
        'category_df_raw': category_df_raw,
        'category_df_smoothed': category_df_smoothed,
        'category_df_selected': category_df,
        'canonical_edges': canonical_edges,
        'tr_edges': tr_edges,
        'smoothing': smoothing_meta,
        'output_root': output_root,
        'canonical_root': canonical_root,
        'trimmed_df': trimmed_df,
        'coverage': coverage,
        'max_lag_primary': max_lag_primary,
    }



In [None]:
# --- Subject/story discovery ---------------------------------------------------------

def discover_subject_story_pairs(
    paths: Dict[str, Any],
    *,
    default_subject: str,
    default_story: str,
    subject_overrides: Sequence[str] | None = None,
    story_overrides: Sequence[str] | None = None,
) -> List[Tuple[str, str]]:
    combos = set()
    failures: List[Dict[str, str]] = []

    def _canon(name: str) -> str:
        return ''.join(ch for ch in name.lower() if ch.isalnum())

    transcript_variants: set[str] = set()

    candidate_transcript_roots: List[Path] = []
    transcripts_root_cfg = paths.get('transcripts')
    if transcripts_root_cfg:
        candidate_transcript_roots.append(Path(transcripts_root_cfg))
    data_root_cfg = paths.get('data_root')
    if data_root_cfg:
        dr_cfg = Path(data_root_cfg)
        candidate_transcript_roots.append(dr_cfg / 'derivative' / 'TextGrids')
        candidate_transcript_roots.append(dr_cfg / 'derivatives' / 'TextGrids')

    for root in candidate_transcript_roots:
        if not root.exists():
            continue
        for path in root.glob('**/*'):
            if path.is_file():
                transcript_variants.add(_canon(path.stem))

    def _norm(seq: Optional[Sequence[str]]) -> List[str]:
        return [str(item).strip() for item in (seq or []) if str(item).strip()]

    subject_overrides = _norm(subject_overrides)
    story_overrides = _norm(story_overrides)

    cache_root = Path(paths.get('cache', 'data_cache'))
    if cache_root.exists():
        for subj_dir in sorted(p for p in cache_root.iterdir() if p.is_dir()):
            for story_dir in sorted(p for p in subj_dir.iterdir() if p.is_dir()):
                combos.add((subj_dir.name, story_dir.name))

    transcripts_root = paths.get('transcripts')
    if transcripts_root:
        tr_root = Path(transcripts_root)
        if tr_root.exists():
            for subj_dir in sorted(p for p in tr_root.iterdir() if p.is_dir()):
                story_dirs = [d for d in subj_dir.iterdir() if d.is_dir()]
                if story_dirs:
                    for story_dir in story_dirs:
                        combos.add((subj_dir.name, story_dir.name))
                        transcript_variants.add(_canon(story_dir.name))
                else:
                    for file in subj_dir.glob('*.*'):
                        if file.is_file():
                            combos.add((subj_dir.name, file.stem))
                            transcript_variants.add(_canon(file.stem))
            for file in tr_root.glob('*.*'):
                if file.is_file():
                    transcript_variants.add(_canon(file.stem))
                if file.is_file() and '_' in file.stem:
                    sub_name, story_name = file.stem.split('_', 1)
                    combos.add((sub_name, story_name))
                    transcript_variants.add(_canon(story_name))

    data_root = paths.get('data_root')
    if data_root:
        dr_root = Path(data_root)
        if dr_root.exists():
            for subj_dir in sorted(dr_root.glob('sub-*')):
                subject_name = subj_dir.name.replace('sub-', '', 1)
                for func_dir in sorted(subj_dir.glob('ses-*/func')):
                    for bold_file in func_dir.glob('*task-*_bold.nii.gz'):
                        task_part = bold_file.name.split('task-')[-1]
                        task_part = task_part.split('_', 1)[0]
                        if not task_part:
                            continue
                        canon = _canon(task_part)
                        if transcript_variants and canon not in transcript_variants:
                            continue
                        combos.add((subject_name, task_part))

    if subject_overrides and story_overrides:
        combos.update((sub, story) for sub in subject_overrides for story in story_overrides)
    elif subject_overrides:
        fallback_stories = {story for _, story in combos} or {default_story}
        combos.update((sub, story) for sub in subject_overrides for story in fallback_stories)
    elif story_overrides:
        fallback_subjects = {sub for sub, _ in combos} or {default_subject}
        combos.update((sub, story) for sub in fallback_subjects for story in story_overrides)

    combos.add((default_subject, default_story))

    valid_pairs: List[Tuple[str, str]] = []
    for subject, story in sorted(combos):
        try:
            load_story_words(paths, subject, story)
        except FileNotFoundError:
            warnings.warn(f'Skipping {subject} / {story}: transcript not found.')
            failures.append({'subject': subject, 'story': story, 'error': 'transcript not found'})
            continue
        except Exception as exc:
            warnings.warn(f'Skipping {subject} / {story}: failed to load transcript ({exc}).')
            failures.append({'subject': subject, 'story': story, 'error': str(exc)})
            continue
        valid_pairs.append((subject, story))

    if not valid_pairs:
        raise RuntimeError('No valid subject/story pairs discovered; check transcript paths and overrides.')
    return valid_pairs, failures


# Run discovery once for the whole batch. The call returns the validated pairs
# together with a record for every pair whose transcript could not be loaded.
subject_story_pairs, transcript_failures = discover_subject_story_pairs(
    paths,
    default_subject=SUBJECT,
    default_story=STORY,
    subject_overrides=BATCH_SUBJECTS,
    story_overrides=BATCH_STORIES,
)
# List the pairs that will be processed.
print(f'Discovered {len(subject_story_pairs)} subject/story pairs:')
for sub, story in subject_story_pairs:
    print(f'  - {sub} / {story}')
# Only report a count here; the per-pair details stay in `transcript_failures`.
if transcript_failures:
    print(f"{len(transcript_failures)} transcript(s) could not be parsed; see `transcript_failures` for details.")


In [None]:
# --- Transcript diagnostics ---------------------------------------------------
# Show one row per rejected subject/story pair, or confirm that none was rejected.
if not transcript_failures:
    print('No transcript parsing issues detected.')
else:
    issues_df = pd.DataFrame(transcript_failures)
    issues_df = issues_df.sort_values(['subject', 'story'])
    issues_df = issues_df.reset_index(drop=True)
    display(issues_df)

In [None]:
# Optional: peek at the first few lines of problematic transcript files
from src.decoding.text_align import _candidate_paths

def preview_transcript_file(subject: str, story: str, *, max_lines: int = 8) -> None:
    """Print the head of the first existing transcript candidate.

    Candidates come from ``_candidate_paths(paths, subject, story)`` and are
    tried in the order that helper yields them. At most ``max_lines`` lines
    are printed, with trailing whitespace stripped, followed by a separator
    line. If no candidate exists, a one-line notice is printed instead.
    """
    existing = next(
        (cand for cand in _candidate_paths(paths, subject, story) if cand.exists()),
        None,
    )
    if existing is None:
        print(f'No transcript file found for {subject} / {story}.')
        return

    print(f'Preview for {subject} / {story} -> {existing}')
    # errors='replace' keeps the preview usable for files that are not valid UTF-8.
    with existing.open('r', encoding='utf-8', errors='replace') as handle:
        # zip stops as soon as the range is exhausted, capping the output.
        for _, text in zip(range(max_lines), handle):
            print(text.rstrip())
    print('-' * 60)

In [None]:
# --- Batch generation ---------------------------------------------------------------
# Run the category pipeline for every discovered pair. A failure in one pair is
# recorded in the summary and does not stop the remaining pairs.
results_summary: List[Dict[str, Any]] = []
coverage_rows: List[Dict[str, Any]] = []
for subject, story in subject_story_pairs:
    print(f">>> Generating categories for {subject} / {story}")
    try:
        result = generate_category_time_series(
            subject,
            story,
            cfg_base=cfg,
            categories_cfg_base=categories_cfg,
            cluster_csv_path=cluster_csv_path,
            temporal_weighting=TEMPORAL_WEIGHTING,
            prototype_weight_power=prototype_weight_power,
            smoothing_seconds=SMOOTHING_SECONDS,
            smoothing_method=SMOOTHING_METHOD,
            gaussian_sigma_seconds=GAUSSIAN_SIGMA_SECONDS,
            smoothing_pad=SMOOTHING_PAD_MODE,
            seconds_bin_width=SECONDS_BIN_WIDTH,
            fallback_model=FALLBACK_ALIGNED,
            record_coverage=REPORT_TOKEN_COVERAGE,
            save_outputs=SAVE_OUTPUTS,
        )
    except Exception as exc:
        # Batch boundary: surface the failure as a warning plus an 'error' row.
        warnings.warn(f'FAILED: {subject} / {story} ({exc})')
        results_summary.append(
            {'subject': subject, 'story': story, 'status': 'error', 'error': str(exc)}
        )
    else:
        # Coverage is only present when the pipeline was asked to record it.
        coverage = result.get('coverage')
        if coverage:
            row = {'subject': subject, 'story': story, **coverage}
            coverage_rows.append(row)

        results_summary.append(
            {
                'subject': subject,
                'story': story,
                'status': 'ok',
                'tokens': len(result.get('event_records', [])),
                # N edges delimit N - 1 bins.
                'canonical_bins': len(result.get('canonical_edges', [])) - 1,
                'trs': len(result.get('tr_edges', [])) - 1,
                'categories': len(result.get('category_columns', [])),
                'subject_features': str(result.get('output_root')),
                'story_features': str(result.get('canonical_root')),
            }
        )

# Per-pair outcome table.
summary_df = pd.DataFrame(results_summary)
if not summary_df.empty:
    summary_df = summary_df.sort_values(['status', 'subject', 'story'])
    summary_df = summary_df.reset_index(drop=True)
display(summary_df)

# Token coverage, sorted ascending so the weakest stories come first.
coverage_df = pd.DataFrame(coverage_rows)
if coverage_df.empty:
    print('No coverage data collected; ensure REPORT_TOKEN_COVERAGE is enabled.')
else:
    coverage_df = coverage_df.sort_values(['pct_tokens_combined', 'subject', 'story'])
    coverage_df = coverage_df.reset_index(drop=True)
    combined_total = coverage_df['tokens_total'].sum()
    combined_hits = coverage_df['tokens_combined'].sum()
    if combined_total:
        combined_pct = 100.0 * combined_hits / combined_total
    else:
        combined_pct = 0.0
    print(f"Combined coverage across all processed stories: {combined_hits}/{combined_total} ({combined_pct:.2f}%)")
    print('Lowest combined coverage stories:')
    display(coverage_df[['subject', 'story', 'tokens_total', 'tokens_oov', 'pct_tokens_combined', 'pct_unique_combined']].head(10))


## Next steps

- Inspect `summary_df` to confirm every pair completed.
- Review `coverage_df` to see combined token coverage and spot low-coverage stories.
- The generated CSVs live under `features/subjects/<subject>/<story>/` and `features/stories/<story>/`.
- Re-run this notebook whenever transcript corrections, embeddings, or smoothing settings change.
