# DayXX: QA-Emb Time Series

Build QA-Emb question-wise time series for ds003020 transcripts (token n-grams → QA scores → 50 ms canonical bins → temporal smoothing → TR aggregation), with plots and a transcript/question alignment explorer.

Run the QA-Emb encoding on a GPU node: load the cuda/python modules, activate `.venv`, set `HF_TOKEN`, and make sure `torch.cuda.is_available()` is `True`. Caching is enabled, so once `<featurestest>/qaemb/tokens/<SUBJECT>/<STORY>_qaemb_tokens.npy` has been written it is reloaded instead of re-encoding.

In [None]:
import sys, os
from pathlib import Path

PROJECT_ROOT = Path("/flash/PaoU/seann/fmri-edm-ccm")
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

import json, math, os, sys, warnings
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch

from src.utils import load_yaml
from src.decoding import load_transcript_words
from src.day19_category_builder import (
    build_tr_edges,
    tr_token_overlap,
    build_smoothing_kernel,
    apply_smoothing_kernel,
    aggregate_seconds_to_edges,
)


In [None]:
# Load the demo configuration relative to PROJECT_ROOT so the repository
# location is defined in exactly one place.
cfg = load_yaml(str(PROJECT_ROOT / 'configs' / 'demo.yaml'))
# `or {}` also covers a `paths:` key that is present but null in the YAML,
# mirroring how the `qa_emb` section is read a few lines further down.
paths = cfg.get('paths') or {}
TR = float(cfg.get('TR', 2.0))

# Fall back to a default subject/story when the config leaves them empty.
SUBJECT = cfg.get('subject') or 'UTS01'
STORY = cfg.get('story') or 'wheretheressmoke'

# QA-Emb settings; every key is optional and falls back to the default shown.
qa_cfg = cfg.get('qa_emb', {}) or {}
QA_QUESTIONS_PATH = qa_cfg.get('questions_path', 'configs/qaemb_questions.json')
QA_CHECKPOINT = qa_cfg.get('checkpoint', 'meta-llama/Meta-Llama-3-8B-Instruct')
QA_NGRAM_SIZE = int(qa_cfg.get('ngram_size', 10))
QA_USE_CACHE = bool(qa_cfg.get('use_cache', True))
SECONDS_BIN_WIDTH = float(qa_cfg.get('seconds_bin_width', 0.05))
SMOOTHING_SECONDS = float(qa_cfg.get('smoothing_seconds', 1.0))
SMOOTHING_METHOD = qa_cfg.get('smoothing_method', 'moving_average')
# Default sigma is half the smoothing window unless the config overrides it.
GAUSSIAN_SIGMA_SECONDS = qa_cfg.get('gaussian_sigma_seconds', 0.5 * SMOOTHING_SECONDS)
SMOOTHING_PAD_MODE = qa_cfg.get('smoothing_pad_mode', 'reflect')
SAVE_OUTPUTS = bool(qa_cfg.get('save_outputs', True))

# All QA-Emb artefacts live under <featurestest>/qaemb.
features_root = Path(paths.get('featurestest', 'featurestest')) / 'qaemb'
features_root.mkdir(parents=True, exist_ok=True)
print(f'Using features root: {features_root}')


In [None]:
# Locate the question list inside the repository and fail early if it is missing.
questions_path = PROJECT_ROOT / QA_QUESTIONS_PATH
if not questions_path.exists():
    raise FileNotFoundError(f'QA question file not found at {questions_path}')
QA_QUESTIONS = json.loads(questions_path.read_text())

# The encoder expects a flat list of question strings; reject anything else.
is_string_list = isinstance(QA_QUESTIONS, list) and all(isinstance(q, str) for q in QA_QUESTIONS)
if not is_string_list:
    raise ValueError('QA questions JSON must be a list of strings.')

n_questions = len(QA_QUESTIONS)
print(f'Loaded {n_questions} QA questions from {questions_path}')


def abbreviate_question(q: str, max_words: int = 3) -> str:
    """Return a short label built from the first ``max_words`` words of ``q``.

    The input is converted to ``str`` and stripped of surrounding whitespace;
    a single trailing ``?`` is removed before the text is split on whitespace.
    """
    text = str(q).strip()
    if text[-1:] == '?':
        text = text[:-1]
    leading_words = text.split()[:max_words]
    return ' '.join(leading_words)


# Short display labels, one per question, in the same order as QA_QUESTIONS.
QA_ABBREVS = [abbreviate_question(q) for q in QA_QUESTIONS]

In [None]:
# Fetch the word-level transcript for this subject/story and refuse to continue
# when nothing comes back.
story_events = load_transcript_words(paths, SUBJECT, STORY)
if not story_events:
    raise ValueError(f'No transcript events found for {SUBJECT} {STORY}.')

print(f'Loaded {len(story_events)} transcript tokens for {SUBJECT} / {STORY}.')

# One row per token: cleaned word, timing, midpoint and a positional index.
token_df = pd.DataFrame(story_events, columns=['word', 'start', 'end']).assign(
    word=lambda frame: frame['word'].astype(str).str.strip(),
    midpoint=lambda frame: 0.5 * (frame['start'] + frame['end']),
    token_index=lambda frame: np.arange(len(frame)),
)

all_words = token_df['word'].tolist()


def make_ngram_text(words: Sequence[str], i: int, n: int) -> str:
    """Join the (up to) ``n`` words that end at index ``i`` with single spaces.

    The window is clipped at the beginning of ``words``, so early positions
    yield fewer than ``n`` words.
    """
    first = i - n + 1
    if first < 0:
        first = 0
    window = words[first:i + 1]
    return ' '.join(window)


# One trailing n-gram per token: the token itself plus up to QA_NGRAM_SIZE - 1 predecessors.
examples = [make_ngram_text(all_words, pos, QA_NGRAM_SIZE) for pos, _ in enumerate(all_words)]
print(f'Prepared {len(examples)} QA-Emb inputs with ngram size {QA_NGRAM_SIZE}.')

In [None]:

try:
    from imodelsx import QAEmb
except ImportError as exc:
    raise ImportError('Install imodelsx to run QA-Emb encoding (e.g., `pip install imodelsx`).') from exc

# SECURITY: never embed a Hugging Face access token in the notebook. HF_TOKEN
# must be provided by the environment (e.g. `export HF_TOKEN=...` in the job
# script). A token was previously hard-coded at this spot; treat it as leaked
# and revoke it on the Hugging Face account.

QA_USE_CACHE = True  # force caching for expensive QAEmb encoding

# Token-level features are cached per subject and story.
qa_root = features_root / 'tokens' / SUBJECT
qa_root.mkdir(parents=True, exist_ok=True)
qa_file = qa_root / f'{STORY}_qaemb_tokens.npy'
qa_questions_out = qa_root / f'{STORY}_qaemb_questions.json'

print(f"cuda_available: {torch.cuda.is_available()}")

# Reuse the cached matrix only when its shape matches the current transcript
# and question list; otherwise fall through to re-encoding.
qa_matrix = None
if qa_file.exists():
    cached = np.load(qa_file)
    if cached.shape == (len(token_df), n_questions):
        qa_matrix = cached
        print(f'Loaded cached QA embeddings from {qa_file}')
    else:
        warnings.warn(f'Cached QA embeddings had shape {cached.shape}; expected {(len(token_df), n_questions)}. Recomputing.')

if qa_matrix is None:
    # An empty value is as useless as a missing one, so test truthiness.
    if not os.environ.get("HF_TOKEN"):
        warnings.warn('HF_TOKEN is not set; gated checkpoints will fail. Set it before rerunning.')
    embedder = QAEmb(
        questions=QA_QUESTIONS,
        checkpoint=QA_CHECKPOINT,
        use_cache=QA_USE_CACHE,
    )

    # Encode in fixed-size batches and stack the per-batch results row-wise.
    BATCH_SIZE = 128
    qa_rows = []
    for start in range(0, len(examples), BATCH_SIZE):
        batch = examples[start:start + BATCH_SIZE]
        emb = embedder(batch)
        qa_rows.append(np.asarray(emb, dtype=float))

    qa_matrix = np.vstack(qa_rows) if qa_rows else np.empty((0, n_questions), dtype=float)
    np.save(qa_file, qa_matrix)
    # Store the questions next to the matrix so its columns stay interpretable.
    with qa_questions_out.open('w') as fh:
        json.dump(QA_QUESTIONS, fh, indent=2)
    print(f'Saved token-level QA features to {qa_file}')
else:
    # Cache hit: make sure the companion question file exists as well.
    if not qa_questions_out.exists():
        with qa_questions_out.open('w') as fh:
            json.dump(QA_QUESTIONS, fh, indent=2)

assert qa_matrix.shape[0] == len(token_df), 'QA matrix rows must match tokens.'
assert qa_matrix.shape[1] == n_questions, 'QA matrix columns must match question count.'


In [None]:
# TR bin edges derived from the transcript events and the repetition time.
# NOTE(review): the exact edge convention lives in build_tr_edges — confirm there.
tr_edges = build_tr_edges(story_events, TR)

# Fine-grained "canonical" grid: fixed-width bins starting at 0 s that must
# reach at least the end time of the last token.
max_end_time = float(token_df['end'].max())
canonical_edges = np.arange(0.0, max_end_time + SECONDS_BIN_WIDTH, SECONDS_BIN_WIDTH, dtype=float)
# np.arange with a float step can stop short of the intended endpoint, so add
# one more edge when the grid does not yet cover the last token.
if canonical_edges[-1] < max_end_time:
    canonical_edges = np.append(canonical_edges, canonical_edges[-1] + SECONDS_BIN_WIDTH)
# Second guard with a small tolerance.
# NOTE(review): appears redundant with the check above — confirm before removing.
if canonical_edges[-1] < max_end_time - 1e-9:
    canonical_edges = np.append(canonical_edges, canonical_edges[-1] + SECONDS_BIN_WIDTH)

# Downstream binning uses np.searchsorted, which needs strictly increasing edges.
assert np.all(np.diff(canonical_edges) > 0), 'Non-monotone canonical edges.'
print(f'Canonical bins: {len(canonical_edges) - 1}, TR bins: {len(tr_edges) - 1}')

In [None]:
# Pair every transcript token with its QA score vector. Rows are matched by
# position rather than by DataFrame index label, so the pairing stays correct
# even if token_df is ever filtered or re-indexed, and the per-row Series
# construction of DataFrame.iterrows() is avoided.
event_records: List[Dict] = [
    {
        'word': word,
        'start': float(start),
        'end': float(end),
        'qa_vec': qa_matrix[pos].astype(float),
    }
    for pos, (word, start, end) in enumerate(
        zip(token_df['word'], token_df['start'], token_df['end'])
    )
]


def build_token_buckets(edges: np.ndarray, event_records: Sequence[Dict], mode: str = 'proportional') -> List[List[Dict]]:
    """Assign each token record to every bin it overlaps.

    Args:
        edges: Bin edges; bin ``k`` spans ``edges[k]`` to ``edges[k + 1]``.
            Assumed sorted ascending, as required by ``np.searchsorted``.
        event_records: Dicts with ``'word'``, ``'start'`` and ``'end'`` keys
            and, optionally, a ``'qa_vec'`` entry that is carried over.
        mode: Overlap mode forwarded to ``tr_token_overlap``. It was
            previously ignored in favour of a hard-coded ``'proportional'``.

    Returns:
        One list per bin holding a dict for each token whose overlap with the
        bin is positive. An empty list is returned for fewer than two edges.
    """
    if edges.size < 2:
        return []
    buckets: List[List[Dict]] = [[] for _ in range(len(edges) - 1)]
    last_bucket = len(buckets) - 1
    for rec in event_records:
        start = rec['start']
        end = rec['end']
        # Tokens entirely outside the covered time range contribute nothing.
        if end <= edges[0] or start >= edges[-1]:
            continue
        # Candidate bin range; it may include one extra bin on the right,
        # which the overlap check below filters out.
        start_idx = max(0, int(np.searchsorted(edges, start, side='right')) - 1)
        end_idx = max(0, int(np.searchsorted(edges, end, side='left')))
        end_idx = min(end_idx, last_bucket)
        for idx in range(start_idx, end_idx + 1):
            bucket_start = edges[idx]
            bucket_end = edges[idx + 1]
            overlap = tr_token_overlap(start, end, bucket_start, bucket_end, mode)
            if overlap <= 0:
                continue
            item = {
                'word': rec['word'],
                'overlap': overlap,
                'token_start': start,
                'token_end': end,
                'bucket_start': bucket_start,
                'bucket_end': bucket_end,
            }
            if 'qa_vec' in rec:
                item['qa_vec'] = rec['qa_vec']
            buckets[idx].append(item)
    return buckets


# Bucket the tokens on both time grids and report how many bins stayed empty.
canonical_buckets = build_token_buckets(canonical_edges, event_records, mode='proportional')
tr_buckets = build_token_buckets(tr_edges, event_records, mode='proportional')
empty_canonical = sum(not bucket for bucket in canonical_buckets)
empty_tr = sum(not bucket for bucket in tr_buckets)
print(
    f'Canonical bins without tokens: {empty_canonical}/{len(canonical_buckets)}; '
    f'TR bins without tokens: {empty_tr}/{len(tr_buckets)}'
)

In [None]:
def score_qa_time_series(
    edges: np.ndarray,
    buckets: Sequence[Sequence[Dict]],
    n_questions: int,
    *,
    index_name: str = 'bin_index',
    prefix: str = 'qa_q',
):
    """Compute an overlap-weighted mean QA vector for every bin.

    Args:
        edges: Bin edges; ``edges[:-1]`` / ``edges[1:]`` become the
            ``start_sec`` / ``end_sec`` columns.
        buckets: Per-bin items; each item may carry ``'qa_vec'`` and
            ``'overlap'`` (weight, defaults to 1.0 when absent).
        n_questions: Number of QA dimensions, i.e. output columns.
        index_name: Name of the integer bin-index column.
        prefix: Prefix of the per-question columns (``<prefix>000``, ...).

    Returns:
        ``(df, qa_ts, cols)``: the tabular result, the ``(n_bins, n_questions)``
        score matrix (NaN rows for bins without a positive total weight), and
        the list of question column names.
    """
    n_bins = len(buckets)
    qa_ts = np.full((n_bins, n_questions), np.nan, dtype=float)
    for row, members in enumerate(buckets):
        weighted_sum = np.zeros(n_questions, dtype=float)
        total_weight = 0.0
        for member in members:
            vec = member.get('qa_vec')
            if vec is None:
                continue
            weight = float(member.get('overlap', 1.0))
            weighted_sum += vec * weight
            total_weight += weight
        # Bins with no usable items keep their NaN placeholder row.
        if total_weight > 0:
            qa_ts[row] = weighted_sum / total_weight

    cols = [f'{prefix}{j:03d}' for j in range(n_questions)]
    data = {
        index_name: np.arange(n_bins, dtype=int),
        'start_sec': edges[:-1],
        'end_sec': edges[1:],
    }
    data.update({name: qa_ts[:, j] for j, name in enumerate(cols)})
    return pd.DataFrame(data), qa_ts, cols


# Score the canonical (fine-grained) bins; qa_columns names the per-question columns.
canonical_df_raw, canonical_matrix, qa_columns = score_qa_time_series(
    canonical_edges, canonical_buckets, n_questions, index_name='bin_index', prefix='qa_q'
)

In [None]:
# Build the temporal kernel; a kernel of length one (or zero) means "no smoothing".
smoothing_kernel = build_smoothing_kernel(
    SECONDS_BIN_WIDTH,
    SMOOTHING_SECONDS,
    method=SMOOTHING_METHOD,
    gaussian_sigma_seconds=GAUSSIAN_SIGMA_SECONDS,
)
smoothing_applied = smoothing_kernel.size > 1

# Smooth only when there is a real kernel and some data; otherwise pass the raw values through.
canonical_values_raw = canonical_matrix.copy()
if smoothing_applied and canonical_values_raw.size:
    canonical_values_smoothed = apply_smoothing_kernel(canonical_values_raw, smoothing_kernel, pad_mode=SMOOTHING_PAD_MODE)
else:
    canonical_values_smoothed = canonical_values_raw.copy()

# Same layout as the raw table, with the question columns replaced by the smoothed values.
canonical_df_smoothed = canonical_df_raw.copy()
if qa_columns:
    canonical_df_smoothed.loc[:, qa_columns] = canonical_values_smoothed

if smoothing_applied:
    canonical_df_selected = canonical_df_smoothed
else:
    canonical_df_selected = canonical_df_raw
print(f'Smoothing kernel length: {len(smoothing_kernel)} (applied={smoothing_applied})')

In [None]:
# Shared TR-level index/timing columns for both output tables.
base_index = np.arange(len(tr_edges) - 1, dtype=int)
base_df = pd.DataFrame({'tr_index': base_index, 'start_sec': tr_edges[:-1], 'end_sec': tr_edges[1:]})

# Re-bin the canonical series (raw and smoothed) onto the TR grid.
tr_values_raw = aggregate_seconds_to_edges(canonical_edges, canonical_values_raw, tr_edges)
tr_values_smoothed = aggregate_seconds_to_edges(canonical_edges, canonical_values_smoothed, tr_edges)

tr_df_raw = base_df.copy()
tr_df_smoothed = base_df.copy()
if qa_columns:
    tr_df_raw.loc[:, qa_columns] = tr_values_raw
    tr_df_smoothed.loc[:, qa_columns] = tr_values_smoothed

if smoothing_applied:
    tr_df_selected = tr_df_smoothed
else:
    tr_df_selected = tr_df_raw

In [None]:
# Subject-specific TR series and story-level canonical series are stored separately.
output_root = features_root / 'subjects' / SUBJECT / STORY
canonical_root = features_root / 'stories' / STORY

if SAVE_OUTPUTS:
    output_root.mkdir(parents=True, exist_ok=True)
    canonical_root.mkdir(parents=True, exist_ok=True)

    # The main CSV holds the selected (smoothed if applied) series; the raw
    # series is written alongside it only when smoothing changed the data.
    canonical_csv = canonical_root / 'qaemb_timeseries_seconds.csv'
    canonical_df_selected.to_csv(canonical_csv, index=False)
    if smoothing_applied:
        canonical_df_raw.to_csv(canonical_root / 'qaemb_timeseries_seconds_raw.csv', index=False)

    tr_csv = output_root / 'qaemb_timeseries.csv'
    tr_df_selected.to_csv(tr_csv, index=False)
    if smoothing_applied:
        tr_df_raw.to_csv(output_root / 'qaemb_timeseries_raw.csv', index=False)

    # Record the settings that produced these files next to the TR series.
    meta = {
        'subject': SUBJECT,
        'story': STORY,
        'tr_seconds': TR,
        'seconds_bin_width': SECONDS_BIN_WIDTH,
        'smoothing_seconds': SMOOTHING_SECONDS,
        'smoothing_method': SMOOTHING_METHOD,
        'gaussian_sigma_seconds': GAUSSIAN_SIGMA_SECONDS,
        'smoothing_pad_mode': SMOOTHING_PAD_MODE,
        'questions_path': str(questions_path),
        'checkpoint': QA_CHECKPOINT,
        'n_questions': n_questions,
        'ngram_size': QA_NGRAM_SIZE,
    }
    with (output_root / 'qaemb_metadata.json').open('w') as fh:
        json.dump(meta, fh, indent=2)

    # Two separate messages; they were previously fused into one unreadable line.
    print(f'Saved canonical QA series to {canonical_csv}')
    print(f'Saved TR QA series to {tr_csv}')
else:
    print('Skipping save (SAVE_OUTPUTS is False).')

In [None]:
# x-values for the canonical series: the centre of each canonical bin.
canonical_time = 0.5 * (canonical_edges[:-1] + canonical_edges[1:])
# x-values for the TR series: the left (start) edge of each TR bin, not its centre.
tr_time = tr_edges[:-1]

def plot_qa_series(selected_cols):
    """Plot the canonical and TR-level series of the chosen question columns.

    Prints a notice and draws nothing when ``selected_cols`` is empty.
    """
    if not selected_cols:
        print('No QA columns selected.')
        return
    fig, ax = plt.subplots(figsize=(12, 4))
    for name in selected_cols:
        ax.plot(canonical_time, canonical_df_selected[name], label=f'{name} canonical', linewidth=1.4)
        ax.plot(tr_time, tr_df_selected[name], label=f'{name} TR', linestyle='--', marker='.', markersize=3)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('QA score')
    ax.set_title(f'{SUBJECT} / {STORY} | smoothing={SMOOTHING_METHOD} ({SMOOTHING_SECONDS}s)')
    ax.grid(True, alpha=0.3)
    # Keep the legend outside the axes so it never hides the curves.
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    fig.tight_layout()
    plt.show()


# `widgets` was tested against None below without ever being imported, which
# raised NameError. Import it defensively so the static fallback is reachable
# when ipywidgets is not installed.
try:
    import ipywidgets as widgets
except ImportError:
    widgets = None

# Start with (up to) the first three questions selected.
default_cols = qa_columns[: min(3, len(qa_columns))]
if widgets is not None:
    # `display` is provided by the IPython/Jupyter runtime.
    selector = widgets.SelectMultiple(options=qa_columns, value=tuple(default_cols), description='Questions:', layout=widgets.Layout(width='40%'))
    out = widgets.Output()
    display(selector, out)

    def _update(*_):
        """Redraw the plot inside the output widget for the current selection."""
        with out:
            out.clear_output()
            plot_qa_series(list(selector.value))

    selector.observe(_update, 'value')
    _update()
else:
    plot_qa_series(default_cols)

In [None]:
# `widgets` is not imported anywhere else in this notebook, so the check below
# raised NameError instead of the intended RuntimeError. Import it defensively.
try:
    import ipywidgets as widgets
except ImportError:
    widgets = None

if widgets is None:
    raise RuntimeError('ipywidgets unavailable; install ipywidgets to use the alignment explorer.')

# Per-token table used by the explorer, plus one score array per question column.
TOKEN_BASE_DF = token_df[['token_index', 'word', 'start', 'end', 'midpoint']].copy()
TOKEN_BASE_DF['duration'] = TOKEN_BASE_DF['end'] - TOKEN_BASE_DF['start']
QA_SCORE_CACHE = {col: qa_matrix[:, idx].astype(float) for idx, col in enumerate(qa_columns)}
# Largest absolute token score; used to size the |score| threshold slider.
QA_SCORE_ABS_MAX = float(np.nanmax(np.abs(qa_matrix))) if qa_matrix.size else 0.0

# Canonical series are placed at bin centres, TR series at each TR's start edge.
canonical_time = 0.5 * (canonical_edges[:-1] + canonical_edges[1:])
tr_time = tr_edges[:-1]
# Full question text per column, falling back to the column name itself.
qa_name_lookup = {col: QA_QUESTIONS[idx] if idx < len(QA_QUESTIONS) else col for idx, col in enumerate(qa_columns)}
MAX_TOKENS_DISPLAY = 60  # cap on tokens drawn/listed per view
FOCUS_WINDOW_SECONDS = 20.0  # width of the window set by the "Focus" buttons


def _interpolate_series(series_values: pd.Series, times: np.ndarray, query: np.ndarray) -> np.ndarray:
    values = np.asarray(series_values, dtype=float)
    times = np.asarray(times, dtype=float)
    query = np.asarray(query, dtype=float)
    finite = np.isfinite(values)
    if finite.sum() < 2:
        return np.full(query.shape, np.nan, dtype=float)
    interp = np.interp(query, times[finite], values[finite])
    interp[(query < times[finite][0]) | (query > times[finite][-1])] = np.nan
    return interp


def _prepare_subset(col: str, t0: float, t1: float) -> pd.DataFrame:
    """Return the tokens whose midpoint lies in ``[t0, t1]`` with scores for ``col``.

    The result is sorted by token start time and carries the raw token score,
    its absolute value, and the canonical / TR series values interpolated at
    each token midpoint.

    Raises:
        RuntimeError: if no cached score array exists for ``col``.
    """
    col_scores = QA_SCORE_CACHE.get(col)
    if col_scores is None:
        raise RuntimeError(f'No cached scores for {col}')

    in_window = (TOKEN_BASE_DF['midpoint'] >= t0) & (TOKEN_BASE_DF['midpoint'] <= t1)
    window = TOKEN_BASE_DF.loc[in_window].copy()
    window['score'] = col_scores[in_window.to_numpy()]
    window['abs_score'] = window['score'].abs()
    window = window.sort_values('start')

    midpoints = window['midpoint'].to_numpy()
    window['canonical_value'] = _interpolate_series(canonical_df_selected[col], canonical_time, midpoints)
    window['tr_value'] = _interpolate_series(tr_df_selected[col], tr_time, midpoints)
    return window


def _plot_alignment(col: str, subset: pd.DataFrame, highlight: pd.DataFrame, *, t0: float, t1: float, question_label: str):
    """Draw the alignment figure for one question and show the matching token table.

    The upper panel plots the selected canonical and TR series for ``col``
    inside ``[t0, t1]`` and overlays the highlighted tokens at their
    interpolated canonical value. The lower panel draws one stem per
    highlighted token (height = token score) labelled with the word. A table
    of at most ``MAX_TOKENS_DISPLAY`` tokens is displayed afterwards.

    Args:
        col: Question column name in the selected canonical / TR tables.
        subset: All tokens in the window (as built by ``_prepare_subset``).
        highlight: Tokens that passed the current filters; may be empty.
        t0: Window start in seconds.
        t1: Window end in seconds.
        question_label: Text used in the figure title.
    """
    import matplotlib.lines as mlines

    # Keep only the highest-scoring tokens when too many pass the filters,
    # then restore chronological order for plotting.
    highlight_ranked = highlight.sort_values('score', ascending=False)
    if len(highlight_ranked) > MAX_TOKENS_DISPLAY:
        plot_highlight = highlight_ranked.head(MAX_TOKENS_DISPLAY).sort_values('midpoint')
    else:
        plot_highlight = highlight_ranked.sort_values('midpoint')

    series_canon = canonical_df_selected[col].to_numpy(dtype=float)
    series_tr = tr_df_selected[col].to_numpy(dtype=float)
    # Vertical extent of the token panel: at least 1.0, larger if scores exceed it.
    max_abs = float(np.nanmax(np.abs(subset['score'].to_numpy()))) if subset['score'].notna().any() else 1.0
    max_abs = max(max_abs, 1.0)

    # Two stacked panels sharing the time axis: series on top, token stems below.
    fig = plt.figure(figsize=(14, 5))
    gs = fig.add_gridspec(2, 1, height_ratios=[3, 1], hspace=0.12)
    ax = fig.add_subplot(gs[0])
    ax_tokens = fig.add_subplot(gs[1], sharex=ax)

    canon_mask = (canonical_time >= t0) & (canonical_time <= t1)
    tr_mask = (tr_time >= t0) & (tr_time <= t1)
    # NOTE(review): the labels say "(smoothed)" even when the selected tables
    # are the raw ones (smoothing not applied) — consider making them dynamic.
    ax.plot(canonical_time[canon_mask], series_canon[canon_mask], color='tab:blue', label='Canonical (smoothed)')
    ax.plot(tr_time[tr_mask], series_tr[tr_mask], color='tab:orange', label='TR (smoothed)')

    token_handle = None
    if not plot_highlight.empty:
        # Marker size grows with |score| relative to the largest highlighted token.
        scale = float(np.nanmax(plot_highlight['abs_score'].to_numpy())) if plot_highlight['abs_score'].notna().any() else 0.0
        scale = scale if scale > 0 else 1.0
        # Green for non-negative scores, red for negative ones.
        colors = np.where(plot_highlight['score'] >= 0, 'tab:green', 'tab:red')
        sizes = 60 + 200 * (plot_highlight['abs_score'] / scale)
        ax.scatter(plot_highlight['midpoint'], plot_highlight['canonical_value'], s=sizes, c=colors, alpha=0.9, edgecolor='white', linewidth=0.4)
        # Proxy artist so the scatter gets a single legend entry.
        token_handle = mlines.Line2D([], [], marker='o', linestyle='None', color='tab:green', markerfacecolor='tab:green', markeredgecolor='white', label='Transcript tokens')

        ax_tokens.axhline(0.0, color='0.6', linewidth=1.0)
        ax_tokens.vlines(plot_highlight['midpoint'], 0.0, plot_highlight['score'], colors=colors, linewidth=2.0, alpha=0.8)
        for row in plot_highlight.itertuples():
            # Place the word just beyond the tip of its stem, on the stem's side of zero.
            y = row.score
            offset = 0.04 * max_abs
            text_y = y + offset if y >= 0 else y - offset
            va = 'bottom' if y >= 0 else 'top'
            ax_tokens.text(row.midpoint, text_y, row.word, rotation=90, ha='center', va=va, fontsize=8)
        ax_tokens.set_ylim(-max_abs * 1.3, max_abs * 1.3)
    else:
        # Nothing to highlight: show a placeholder message in the token panel.
        ax_tokens.axhline(0.0, color='0.6', linewidth=1.0)
        ax_tokens.text(0.5, 0.5, 'No tokens matched the current filters', transform=ax_tokens.transAxes, ha='center', va='center', fontsize=10, color='0.4')
        ax_tokens.set_ylim(-1.0, 1.0)

    ax.set_xlim(t0, t1)
    ax.grid(True, alpha=0.3)
    ax.set_ylabel(col)
    ax.set_title(f'{question_label} | window {t0:.1f}–{t1:.1f} s', loc='left', fontsize=11)
    handles, labels = ax.get_legend_handles_labels()
    if token_handle is not None:
        handles.append(token_handle)
        labels.append('Transcript tokens')
    ax.legend(handles, labels, loc='upper right')
    # The x tick labels are shown on the lower panel only.
    plt.setp(ax.get_xticklabels(), visible=False)

    ax_tokens.set_xlim(t0, t1)
    ax_tokens.set_xlabel('Time (s)')
    ax_tokens.set_ylabel('Token score')
    ax_tokens.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Companion table: highlighted tokens ranked by score, or — when none
    # matched — the first tokens of the window.
    display_columns = ['word', 'start', 'end', 'duration', 'score', 'canonical_value', 'tr_value']
    if not highlight_ranked.empty:
        display_df = highlight_ranked[display_columns].head(MAX_TOKENS_DISPLAY).reset_index(drop=True)
        if len(highlight_ranked) > MAX_TOKENS_DISPLAY:
            print(f'Showing top {MAX_TOKENS_DISPLAY} of {len(highlight_ranked)} tokens (sorted by score).')
    else:
        display_df = subset[display_columns].head(MAX_TOKENS_DISPLAY).reset_index(drop=True)
        # NOTE(review): this notice is printed only when the window holds more
        # than MAX_TOKENS_DISPLAY tokens — confirm that is intended.
        if len(subset) > MAX_TOKENS_DISPLAY:
            print('No tokens matched the current filters; showing first tokens in window.')
    display(display_df)


# Dropdown entries are (label, value) pairs: "<column> | <abbreviation>" -> column.
options = [
    (f"{name} | {short}" if short else name, name)
    for name, short in zip(qa_columns, QA_ABBREVS)
]

qa_dropdown = widgets.Dropdown(options=options, description='Question:', layout=widgets.Layout(width='45%'))

# Time window selector; starts at the first two minutes (or the whole story if shorter).
window_slider = widgets.FloatRangeSlider(
    value=(0.0, min(120.0, float(canonical_time[-1]))),
    min=0.0,
    max=float(canonical_time[-1]),
    step=1.0,
    description='Window (s):',
    layout=widgets.Layout(width='70%'),
)

# Minimum absolute token score required for a token to be highlighted.
threshold_slider = widgets.FloatSlider(
    value=min(0.05, QA_SCORE_ABS_MAX),
    min=0.0,
    max=max(0.1, QA_SCORE_ABS_MAX),
    step=0.01,
    readout_format='.2f',
    description='|score| ≥',
    layout=widgets.Layout(width='50%'),
)

# Minimum signed token score; the filter is skipped when left at the slider minimum.
score_slider = widgets.FloatSlider(
    value=0.0,
    min=-1.0,
    max=1.0,
    step=0.01,
    readout_format='.2f',
    description='score ≥',
    layout=widgets.Layout(width='50%'),
)

focus_peak_btn = widgets.Button(description='Focus on peak', icon='arrow-up')
focus_trough_btn = widgets.Button(description='Focus on trough', icon='arrow-down')

slider_row = widgets.HBox([threshold_slider, score_slider])
button_row = widgets.HBox([focus_peak_btn, focus_trough_btn])
controls = widgets.VBox([qa_dropdown, window_slider, slider_row, button_row])
out = widgets.Output()


def _update_alignment(*_):
    """Re-render the alignment view from the current widget state.

    Tokens are highlighted when ``|score|`` reaches the threshold slider and,
    if the score slider was moved above its minimum, when the signed score
    reaches that value as well.
    """
    with out:
        out.clear_output()
        selected_col = qa_dropdown.value
        window_start, window_end = window_slider.value
        abs_floor = float(threshold_slider.value)
        score_floor = float(score_slider.value)

        window_tokens = _prepare_subset(selected_col, window_start, window_end)
        keep = window_tokens['abs_score'] >= abs_floor
        if score_floor > score_slider.min + 1e-9:
            keep &= window_tokens['score'] >= score_floor
        highlighted = window_tokens[keep].copy()

        _plot_alignment(
            selected_col,
            window_tokens,
            highlighted,
            t0=window_start,
            t1=window_end,
            question_label=qa_name_lookup.get(selected_col, selected_col),
        )


def _focus_window(extreme: str):
    """Centre the time window on the highest- or lowest-scoring token.

    ``extreme == 'high'`` selects the maximum score, anything else the
    minimum. The threshold and score sliders are relaxed where necessary so
    that the target token stays visible, then the view is redrawn.
    """
    selected_col = qa_dropdown.value
    col_scores = QA_SCORE_CACHE.get(selected_col)
    if col_scores is None:
        return

    scored = TOKEN_BASE_DF.assign(score=col_scores)
    scored = scored[np.isfinite(scored['score'])]
    if scored.empty:
        return

    if extreme == 'high':
        target_idx = scored['score'].idxmax()
    else:
        target_idx = scored['score'].idxmin()
    target_value = float(scored.loc[target_idx, 'score'])
    target_abs = abs(target_value)
    center = float(scored.loc[target_idx, 'midpoint'])

    # Clip the focus window to the slider's allowed range.
    half = 0.5 * FOCUS_WINDOW_SECONDS
    t0 = max(window_slider.min, center - half)
    t1 = min(window_slider.max, center + half)
    window_slider.value = (t0, t1)

    if target_abs > 0:
        # 60% of the target magnitude, clamped to the slider's range.
        relaxed = max(threshold_slider.min, target_abs * 0.6)
        threshold_slider.value = min(threshold_slider.max, relaxed)
    if score_slider.value > target_value:
        score_slider.value = max(score_slider.min, target_value)
    _update_alignment()

# The button callbacks receive the clicked button, which is not needed here.
focus_peak_btn.on_click(lambda _button: _focus_window('high'))
focus_trough_btn.on_click(lambda _button: _focus_window('low'))

# Draw once with the initial settings, then redraw whenever a control changes.
_update_alignment()
for control in (qa_dropdown, window_slider, threshold_slider, score_slider):
    control.observe(_update_alignment, names='value')

display(controls, out)