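# Day 29 - QAEmb Karaoke Inspection (All Questions)

This notebook checks that the saved QAEmb score series can be reproduced from the per-token QA matrix, then renders one karaoke-style MP4 showing **all 29 QA questions** in a single grid.

Using the existing QAEmb outputs and the story transcript, each frame shows the currently highlighted words (by default the QA n-gram context of the active token, otherwise the words in the active TR/canonical bin), the scrolling time series for every question, and each panel's score for the current bin.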

In [None]:
import csv
import json
import os
import shutil
import sys
import time
import textwrap
from pathlib import Path

import matplotlib.animation as animation
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display

# Work from the repository root and make its `src` package importable.
PROJECT_ROOT = Path('/flash/PaoU/seann/fmri-edm-ccm')
os.chdir(PROJECT_ROOT)
if not any(entry == str(PROJECT_ROOT) for entry in sys.path):
    # Prepend so the project's modules take precedence over same-named installs.
    sys.path[:0] = [str(PROJECT_ROOT)]

from src.decoding import load_transcript_words
from src.day19_category_builder import apply_smoothing_kernel, build_smoothing_kernel
from src.utils import load_yaml

# -----------------------------
# Configuration
# -----------------------------
# Subject / story whose QAEmb artifacts and transcript are inspected.
SUBJECT = 'UTS01'
STORY = 'wheretheressmoke'

# Optional direct transcript file override (.csv/.tsv/.json/.TextGrid).
# When set it takes precedence over load_transcript_words().
TRANSCRIPT_PATH = None

# Optional transcript root override (directory containing TextGrids/transcripts).
# When set it replaces paths['transcripts'] from the config and enables a direct
# filename search under this directory if load_transcript_words() fails.
TRANSCRIPTS_ROOT = Path('/flash/PaoU/seann/ds003020_copy/derivative/TextGrids')

# QAEmb output root used by your pipeline artifacts
QAEMB_ROOT = PROJECT_ROOT / 'features_qaemb' / 'featuresqaemb' / 'qaemb'

# Which series the video plots: 'tr' (subject qaemb_timeseries.csv, drawn at bin
# onsets) or 'canonical' (story qaemb_timeseries_seconds.csv, drawn at bin centres).
USE_DOMAIN = 'tr'  # 'tr' or 'canonical'
FPS = 8  # output video frame rate
PLAYBACK_SPEED = 7.0  # story seconds advanced per second of video
WINDOW_SEC = 40.0  # width (story seconds) of each panel's scrolling x-axis
# Once the window starts scrolling, the cursor stays this many seconds to the LEFT
# of the window's right edge (the name notwithstanding).
PAD_LEFT_SEC = 0.5
PAD_RIGHT_SEC = 2.0  # extra story seconds rendered after the last transcript word
BIN_WORDS_MAX = 30  # max highlighted words printed in the header (None = no cap)
Z_SCORE = False  # z-score each question's series before plotting
LOG_EVERY_FRAMES = None  # progress print interval in frames; None -> every 5 s of video
# 'qa_context' highlights the exact previous-n-gram context scored by QAEmb for the active token.
# 'domain_bin' keeps the old TR/canonical bin highlight.
HIGHLIGHT_MODE = 'qa_context'

# Show all 29 questions in one video: 6x5 = 30 panels
GRID_ROWS = 6
GRID_COLS = 5
# None => keep full question text in titles (no word truncation)
TITLE_MAX_WORDS = None
TITLE_WRAP_CHARS = 42  # wrap width for panel titles; 0/None disables wrapping

# Output location; the directory is created as soon as this cell runs.
VIDEO_DIR = QAEMB_ROOT.parent / 'videos'
VIDEO_DIR.mkdir(parents=True, exist_ok=True)
OUT_MP4 = VIDEO_DIR / f'karaoke_qaemb_allq_{SUBJECT}_{STORY}.mp4'

# -----------------------------
# Load artifacts
# -----------------------------
subject_dir = QAEMB_ROOT / 'subjects' / SUBJECT / STORY
story_dir = QAEMB_ROOT / 'stories' / STORY
token_dir = QAEMB_ROOT / 'tokens' / SUBJECT

tr_path = subject_dir / 'qaemb_timeseries.csv'
canonical_path = story_dir / 'qaemb_timeseries_seconds.csv'
meta_path = subject_dir / 'qaemb_metadata.json'
questions_path = token_dir / f'{STORY}_qaemb_questions.json'
qa_tokens_path = token_dir / f'{STORY}_qaemb_tokens.npy'

required = [tr_path, canonical_path, meta_path, questions_path, qa_tokens_path]
missing = [str(p) for p in required if not p.exists()]
if missing:
    raise FileNotFoundError('Missing required QAEmb artifacts:\n' + '\n'.join(missing))

tr_df = pd.read_csv(tr_path)
canonical_df = pd.read_csv(canonical_path)
metadata = json.loads(meta_path.read_text())
questions = json.loads(questions_path.read_text())
qa_matrix = np.load(qa_tokens_path)


def _qa_column_sort_key(col: str):
    """Order QA columns by the integer after 'qa_q' (so 'qa_q2' sorts before 'qa_q10').

    Plain lexicographic sorting puts 'qa_q10' before 'qa_q2' when indices are not
    zero-padded, which would silently misalign the columns with `questions` and with
    the columns of `qa_matrix`. Zero-padded names sort identically either way.
    Columns whose suffix is not an integer sort lexicographically after numeric ones.
    """
    suffix = col[len('qa_q'):]
    try:
        return (0, int(suffix), col)
    except ValueError:
        return (1, 0, col)


qa_cols = sorted((c for c in tr_df.columns if c.startswith('qa_q')), key=_qa_column_sort_key)

if not qa_cols:
    raise ValueError(f'No QA columns found in {tr_path}')

if len(qa_cols) != len(questions):
    raise ValueError(f'Question count mismatch: columns={len(qa_cols)}, questions={len(questions)}')

# qa_matrix is used as (tokens x questions) below; fail clearly on any other rank.
if qa_matrix.ndim != 2:
    raise ValueError(f'qa_matrix must be 2-D (tokens x questions), got shape {qa_matrix.shape}')

if qa_matrix.shape[1] != len(questions):
    raise ValueError(f'qa_matrix columns mismatch: matrix={qa_matrix.shape[1]}, questions={len(questions)}')

# The canonical series is indexed with the same column names; check up front
# instead of failing later with a bare KeyError.
missing_canonical = [c for c in qa_cols if c not in canonical_df.columns]
if missing_canonical:
    raise ValueError(f'Canonical series {canonical_path} is missing QA columns: {missing_canonical}')

if GRID_ROWS * GRID_COLS < len(qa_cols):
    raise ValueError(f'Grid too small for all questions: {GRID_ROWS}x{GRID_COLS} < {len(qa_cols)}')

print('Loaded artifacts:')
print(f'  TR series:         {tr_path} -> {tr_df.shape}')
print(f'  Canonical series:  {canonical_path} -> {canonical_df.shape}')
print(f'  Token QA matrix:   {qa_tokens_path} -> {qa_matrix.shape}')
print(f'  Questions:         {len(questions)}')
print(f'  Metadata checkpoint: {metadata.get("checkpoint")}')
CONTEXT_NGRAM = int(metadata.get('context_ngram', 10))
TIME_ANCHOR = str(metadata.get('time_anchor', 'onset')).lower()
print(f'  QA context ngram: {CONTEXT_NGRAM} | time_anchor: {TIME_ANCHOR}')

# -----------------------------
# Load transcript tokens
# -----------------------------
def _load_textgrid_events(path: Path):
    """Parse word events from a Praat ``.TextGrid`` file.

    Every tier whose name contains ``"word"`` (case-insensitive) contributes its
    intervals with a non-blank mark. Events from all matching tiers are merged.

    Args:
        path: Path to the TextGrid file.

    Returns:
        List of ``(mark, start, end)`` tuples sorted by start time (stable sort,
        so ties keep tier/file order).

    Raises:
        ImportError: The optional ``textgrid`` package is not installed.
        ValueError: No non-blank word-tier intervals were found.
    """
    try:
        import textgrid as _textgrid
    except ImportError as exc:
        # Only a missing package is translated into the install hint; any other
        # error raised while importing textgrid propagates unchanged.
        raise ImportError(
            'Reading .TextGrid requires the `textgrid` package in this notebook kernel. '
            'Install it (e.g., `pip install textgrid`) or provide CSV/TSV/JSON via TRANSCRIPT_PATH.'
        ) from exc

    tg = _textgrid.TextGrid.fromFile(str(path))
    events = []
    for tier in tg:
        name = str(getattr(tier, 'name', '')).lower()
        if 'word' not in name:
            continue
        for item in tier:
            # A missing/None mark counts as blank, so it is skipped rather than
            # turned into the literal word "None".
            mark = str(getattr(item, 'mark', None) or '').strip()
            if not mark:
                continue
            start = float(getattr(item, 'minTime', 0.0))
            end = float(getattr(item, 'maxTime', start))
            events.append((mark, start, end))
    if not events:
        raise ValueError(f'No word-tier events parsed from TextGrid: {path}')
    events.sort(key=lambda x: x[1])
    return events


def _story_variants(story: str):
    """Return the sorted, de-duplicated, non-empty spellings used to search for a story file.

    Both the original and the lower-cased story name are included, each as-is and
    with all spaces, hyphens, or underscores removed (one separator kind at a time).
    """
    candidates = set()
    for base in (story, story.lower()):
        candidates.add(base)
        for separator in (' ', '-', '_'):
            candidates.add(base.replace(separator, ''))
    return sorted(filter(None, candidates))


def _load_transcript_from_explicit_path(path: Path):
    """Load ``(word, onset, offset)`` events from a transcript file chosen by suffix.

    Supported formats (suffix matched case-insensitively):

    * ``.json`` -- a list of objects with ``word`` and ``onset`` keys and an
      optional ``offset`` key.
    * ``.csv`` / ``.tsv`` -- a header row with ``word`` and ``onset`` columns and
      an optional ``offset`` column.
    * ``.textgrid`` -- parsed by ``_load_textgrid_events``.

    A missing, null, or blank offset falls back to the onset, giving a
    zero-duration token. Text files are decoded as UTF-8. A leading byte-order
    mark is tolerated, so it cannot corrupt the first CSV header name or break
    JSON parsing.

    Args:
        path: Transcript file (``str`` is accepted as well).

    Returns:
        List of ``(word, onset, offset)`` tuples sorted by onset (stable).

    Raises:
        ValueError: Unsupported suffix, or a non-numeric onset/offset.
        KeyError: A record lacks ``word`` or ``onset``.
    """
    path = Path(path)

    def _event(record) -> tuple:
        word = str(record['word'])
        onset = float(record['onset'])
        raw_offset = record.get('offset')
        # Blank CSV cells, short CSV rows (DictReader fills them with None), and
        # JSON nulls all mean "no offset": fall back to the onset instead of failing.
        if raw_offset is None or (isinstance(raw_offset, str) and not raw_offset.strip()):
            return word, onset, onset
        return word, onset, float(raw_offset)

    suffix = path.suffix.lower()
    if suffix == '.json':
        data = json.loads(path.read_text(encoding='utf-8-sig'))
        events = [_event(entry) for entry in data]
    elif suffix in {'.csv', '.tsv'}:
        delimiter = '\t' if suffix == '.tsv' else ','
        with path.open('r', newline='', encoding='utf-8-sig') as fh:
            events = [_event(row) for row in csv.DictReader(fh, delimiter=delimiter)]
    elif suffix == '.textgrid':
        events = _load_textgrid_events(path)
    else:
        raise ValueError(f'Unsupported transcript extension: {suffix}')
    events.sort(key=lambda x: x[1])
    return events
# Transcript search paths come from the project config; `cfg` is also read later
# as a fallback source for the TR length.
cfg = load_yaml(PROJECT_ROOT / 'configs' / 'demo.yaml')
paths = cfg.get('paths', {}) or {}

if TRANSCRIPTS_ROOT is not None:
    tr_root = Path(TRANSCRIPTS_ROOT)
    if not tr_root.exists():
        raise FileNotFoundError(f'TRANSCRIPTS_ROOT does not exist: {tr_root}')
    # load_transcript_words checks paths['transcripts'] first via candidate path generation.
    # Copy before overriding so the dict held by `cfg` is left untouched.
    paths = dict(paths)
    paths['transcripts'] = str(tr_root)

# Resolution order: explicit TRANSCRIPT_PATH, then load_transcript_words(), then
# (only when TRANSCRIPTS_ROOT is set) a direct filename search under TRANSCRIPTS_ROOT.
if TRANSCRIPT_PATH is not None:
    transcript_events = _load_transcript_from_explicit_path(Path(TRANSCRIPT_PATH))
    print(f'Loaded transcript from explicit path: {TRANSCRIPT_PATH}')
else:
    try:
        transcript_events = load_transcript_words(paths, SUBJECT, STORY)
        print('Loaded transcript via load_transcript_words(paths, subject, story).')
        if TRANSCRIPTS_ROOT is not None:
            print(f'  transcript root override: {Path(TRANSCRIPTS_ROOT)}')
    except FileNotFoundError as exc:
        # Without a root override there is nowhere else to look.
        if TRANSCRIPTS_ROOT is None:
            raise

        # Try <story spelling><extension> directly under the root. The first existing
        # candidate wins: spellings in the order _story_variants returns them, then
        # extensions in the order listed here.
        tr_root = Path(TRANSCRIPTS_ROOT)
        candidates = []
        for base in _story_variants(STORY):
            for ext in ('.TextGrid', '.textgrid', '.csv', '.tsv', '.json'):
                candidates.append(tr_root / f'{base}{ext}')
        existing = [p for p in candidates if p.exists()]

        if existing:
            chosen = existing[0]
            transcript_events = _load_transcript_from_explicit_path(chosen)
            print('load_transcript_words fallback activated.')
            print(f'  loaded from explicit candidate: {chosen}')
        else:
            # Include up to 20 file names from the root so naming mismatches are easy to spot.
            sample = sorted(p.name for p in tr_root.glob('*') if p.is_file())[:20]
            sample_txt = '\n'.join(sample) if sample else '<no files>'
            raise FileNotFoundError(
                f"{exc}\n"
                f"TRANSCRIPTS_ROOT={tr_root}\n"
                f"No candidate transcript matched story='{STORY}'.\n"
                f"Sample files:\n{sample_txt}"
            )
if not transcript_events:
    raise ValueError('Transcript is empty.')

# Token table: one row per non-empty word (whitespace-stripped), with start/end
# times as provided by the transcript (presumably seconds - TODO confirm for
# load_transcript_words output) plus derived midpoint, duration, and a 0-based index.
token_df = pd.DataFrame(transcript_events, columns=['word', 'start', 'end'])
token_df['word'] = token_df['word'].astype(str).str.strip()
token_df = token_df[token_df['word'] != ''].reset_index(drop=True)
token_df['midpoint'] = 0.5 * (token_df['start'] + token_df['end'])
token_df['duration'] = token_df['end'] - token_df['start']
token_df['token_index'] = np.arange(len(token_df), dtype=int)

# qa_matrix has one row per scored token, so the cleaned token count must match it
# exactly for token timing and token scores to line up.
if len(token_df) != qa_matrix.shape[0]:
    raise ValueError(
        'Token count mismatch between transcript and qa_matrix. '
        f'transcript_nonempty={len(token_df)}, qa_matrix_rows={qa_matrix.shape[0]}. '
        'Use TRANSCRIPT_PATH to force the exact transcript used during QAEmb generation if needed.'
    )

print(f'Transcript tokens (non-empty): {len(token_df)}')

# -----------------------------
# Validate saved QA series against recomputation from token matrix
# -----------------------------
def _corr(a: np.ndarray, b: np.ndarray) -> float:
    """Pearson correlation of ``a`` and ``b`` over the positions where both are finite.

    Returns NaN when fewer than two jointly finite samples remain, or when either
    side has zero variance over those samples.
    """
    both_finite = np.isfinite(a) & np.isfinite(b)
    if both_finite.sum() >= 2:
        x = a[both_finite]
        y = b[both_finite]
        if x.std() != 0 and y.std() != 0:
            return float(np.corrcoef(x, y)[0, 1])
    return np.nan


def _resample_irregular_to_edges(token_times: np.ndarray, token_values: np.ndarray, target_edges: np.ndarray) -> np.ndarray:
    """Linearly interpolate irregularly timed samples onto bin centres.

    Args:
        token_times: 1-D sample times, in any order.
        token_values: 2-D array with one row per entry of ``token_times`` and one
            column per feature.
        target_edges: Bin edges; the output has ``len(target_edges) - 1`` rows,
            evaluated at the midpoint of each pair of consecutive edges.

    Returns:
        Float array of shape ``(len(target_edges) - 1, n_features)``. Centres
        before the first (after the last) sample take the first (last) sample's
        value. If ``token_values`` is empty, the result has zero columns.
    """
    n_bins = len(target_edges) - 1
    if token_values.size == 0:
        return np.empty((n_bins, 0), dtype=float)
    sort_idx = np.argsort(token_times)
    xs = token_times[sort_idx]
    ys = token_values[sort_idx]
    bin_centres = 0.5 * (target_edges[:-1] + target_edges[1:])
    per_feature = [
        np.interp(bin_centres, xs, ys[:, k], left=ys[0, k], right=ys[-1, k])
        for k in range(ys.shape[1])
    ]
    return np.column_stack(per_feature).astype(float, copy=False)


def _sample_tokens_to_edges_step(token_times: np.ndarray, token_values: np.ndarray, target_edges: np.ndarray) -> np.ndarray:
    """Sample-and-hold: give each bin the value of the latest sample at or before its centre.

    Samples are ordered by time with a stable sort, so among equal times the one
    appearing last in the input wins. Bin centres earlier than every sample reuse
    the earliest sample. Returns a new float array of shape
    ``(len(target_edges) - 1, n_features)``, or zero columns if ``token_values`` is empty.
    """
    n_bins = len(target_edges) - 1
    if token_values.size == 0:
        return np.empty((n_bins, 0), dtype=float)
    perm = np.argsort(token_times, kind='stable')
    times_sorted, values_sorted = token_times[perm], token_values[perm]
    mids = 0.5 * (target_edges[:-1] + target_edges[1:])
    pick = np.clip(np.searchsorted(times_sorted, mids, side='right') - 1, 0, times_sorted.size - 1)
    return np.array(values_sorted[pick], dtype=float)


def _build_edges_from_df(df: pd.DataFrame) -> np.ndarray:
    """Return bin edges built from ``start_sec`` of every row plus ``end_sec`` of the last row.

    Raises:
        ValueError: If ``df`` has no rows, or if the edges are not strictly
            increasing (any NaN also fails this check).
    """
    bin_starts = df['start_sec'].to_numpy(dtype=float)
    bin_ends = df['end_sec'].to_numpy(dtype=float)
    if bin_starts.size == 0:
        raise ValueError('No rows in timeseries dataframe.')
    edges = np.append(bin_starts, bin_ends[-1])
    strictly_increasing = (np.diff(edges) > 0).all()
    if not strictly_increasing:
        raise ValueError('Non-monotone edges detected in series dataframe.')
    return edges

tr_edges = _build_edges_from_df(tr_df)
canonical_edges = _build_edges_from_df(canonical_df)

# Rebuild the expected TR-grid series from the per-token QA matrix using the
# timing, resampling and smoothing settings recorded in the QAEmb metadata.
# Unknown time anchors fall back to 'onset'.
anchor_mode = str(metadata.get('time_anchor', 'onset')).lower()
if anchor_mode not in {'onset', 'midpoint'}:
    anchor_mode = 'onset'

if anchor_mode == 'midpoint':
    token_times = token_df['midpoint'].to_numpy(dtype=float)
else:
    token_times = token_df['start'].to_numpy(dtype=float)

resample_method = str(metadata.get('resample_method', 'interp_linear')).lower()
if resample_method in {'step_previous', 'step', 'previous'}:
    expected_raw = _sample_tokens_to_edges_step(token_times, qa_matrix, tr_edges)
else:
    expected_raw = _resample_irregular_to_edges(token_times, qa_matrix, tr_edges)

tr_seconds = float(metadata.get('tr_seconds', cfg.get('TR', 2.0)))
smoothing_seconds = float(metadata.get('smoothing_seconds', 1.0))
smoothing_method = str(metadata.get('smoothing_method', 'moving_average')).lower()
gaussian_sigma_seconds = metadata.get('gaussian_sigma_seconds', None)
pad_mode = str(metadata.get('smoothing_pad_mode', 'reflect'))
postprocess = str(metadata.get('postprocess', '')).lower()

if postprocess == 'none' or smoothing_seconds <= 0:
    expected_selected = expected_raw.copy()
else:
    kernel = build_smoothing_kernel(
        tr_seconds,
        smoothing_seconds,
        method=smoothing_method,
        gaussian_sigma_seconds=gaussian_sigma_seconds,
    )
    if expected_raw.size and kernel.size > 1:
        expected_selected = apply_smoothing_kernel(expected_raw, kernel, pad_mode=pad_mode)
    else:
        expected_selected = expected_raw.copy()

observed_tr = tr_df[qa_cols].to_numpy(dtype=float)
observed_canonical = canonical_df[qa_cols].to_numpy(dtype=float)

if observed_tr.shape != expected_selected.shape:
    raise ValueError(f'Shape mismatch for TR validation: observed={observed_tr.shape}, expected={expected_selected.shape}')

# `expected_selected` is built on the TR grid, so the canonical series can only be
# compared bin-for-bin when it has the same shape. A canonical series on a
# different grid is not a TR validation failure: skip that comparison rather
# than aborting the notebook.
canonical_comparable = observed_canonical.shape == expected_selected.shape
if not canonical_comparable:
    print(
        f'NOTE: canonical series shape {observed_canonical.shape} differs from the TR-grid '
        f'expected shape {expected_selected.shape}; skipping canonical-vs-expected checks.'
    )

rows = []
for j, col in enumerate(qa_cols):
    obs = observed_tr[:, j]
    exp = expected_selected[:, j]
    diff = obs - exp
    m = np.isfinite(diff)
    rows.append(
        {
            'column': col,
            'question': questions[j],
            'corr_tr_vs_expected': _corr(obs, exp),
            'mae_tr_vs_expected': float(np.nanmean(np.abs(diff[m]))) if m.any() else np.nan,
            'max_abs_tr_vs_expected': float(np.nanmax(np.abs(diff[m]))) if m.any() else np.nan,
            'corr_canonical_vs_expected': (
                _corr(observed_canonical[:, j], exp) if canonical_comparable else np.nan
            ),
            'mean_tr_score': float(np.nanmean(obs)),
            'std_tr_score': float(np.nanstd(obs)),
        }
    )

validation_df = pd.DataFrame(rows)

print('\nValidation summary:')
print(f'  global max abs diff (TR vs expected): {np.nanmax(np.abs(observed_tr - expected_selected)):.6g}')
print(f'  global mean abs diff (TR vs expected): {np.nanmean(np.abs(observed_tr - expected_selected)):.6g}')
if canonical_comparable:
    print(f'  global max abs diff (canonical vs expected): {np.nanmax(np.abs(observed_canonical - expected_selected)):.6g}')
else:
    print('  global max abs diff (canonical vs expected): n/a (different grid)')

print('\nWorst 10 questions by MAE (TR vs expected):')
display(validation_df.sort_values('mae_tr_vs_expected', ascending=False).head(10))

# -----------------------------
# Karaoke rendering with all questions in one video
# -----------------------------
def abbreviate_question(q: str, max_words: int | None = None) -> str:
    q = str(q).strip()
    if q.endswith('?'):
        q = q[:-1]
    if max_words is None:
        return q
    words = q.split()
    return ' '.join(words[:max_words])


def make_karaoke_qaemb_video_all_questions(
    *,
    token_df: pd.DataFrame,
    tr_df: pd.DataFrame,
    canonical_df: pd.DataFrame,
    tr_edges: np.ndarray,
    canonical_edges: np.ndarray,
    qa_cols: list,
    questions: list,
    out_mp4: Path,
    use_domain: str = 'tr',
    fps: int = 8,
    playback_speed: float = 7.0,
    window_sec: float = 40.0,
    pad_left_sec: float = 0.5,
    pad_right_sec: float = 2.0,
    bin_words_max: int | None = 30,
    zscore: bool = False,
    rows: int = 6,
    cols: int = 5,
    title_max_words: int | None = None,
    title_wrap_chars: int = 42,
    log_every_frames: int | None = None,
    highlight_mode: str = 'qa_context',
    context_ngram: int = 10,
    time_anchor: str = 'onset',
):
    """Render a karaoke-style MP4 with one scrolling score panel per QA question.

    The header row shows the highlighted words and frame metadata:

    * with ``highlight_mode='qa_context'``, the context words before the active token;
    * with ``'domain_bin'``, the words overlapping the active series bin.

    Below the header, one panel per entry of ``qa_cols`` plots that question's
    series in a scrolling window. Each panel has a cursor at the current story
    time, a shaded highlight span, and the current-bin score.

    Story time advances ``playback_speed`` seconds per second of video. Rendering
    continues until ``pad_right_sec`` story seconds after the last transcript word.

    Args:
        token_df: Transcript tokens with ``word``, ``start`` and ``end`` columns.
        tr_df, canonical_df: Series tables containing every column in ``qa_cols``.
        tr_edges, canonical_edges: Bin edges for each table; expected to hold
            ``len(df) + 1`` values for the domain being plotted.
        qa_cols: Series columns, one panel each, in display order.
        questions: Question texts aligned one-to-one with ``qa_cols``.
        out_mp4: Output path (``str`` accepted); its parent directory is created.
        use_domain: ``'tr'`` plots ``tr_df`` at bin onsets; ``'canonical'`` plots
            ``canonical_df`` at bin centres.
        fps: Output frame rate.
        playback_speed: Story seconds per video second.
        window_sec: Width (story seconds) of each panel's x-axis window.
        pad_left_sec: Once scrolling, the cursor sits this many seconds left of
            the window's right edge.
        pad_right_sec: Story seconds rendered after the last transcript word.
        bin_words_max: Cap on words printed in the header (None = no cap).
        zscore: Z-score each question's series (NaN-aware) before plotting.
        rows, cols: Panel grid shape; must hold ``len(qa_cols)`` panels.
        title_max_words, title_wrap_chars: Panel-title truncation and wrap width.
        log_every_frames: Progress print interval in frames (None = every 5 s of video).
        highlight_mode: ``'qa_context'`` or ``'domain_bin'``.
        context_ngram: N-gram size used to build the QA context highlight.
        time_anchor: ``'onset'`` or ``'midpoint'`` token timing used to locate the
            active token.

    Raises:
        ValueError: On invalid or mutually inconsistent arguments.
    """
    if use_domain not in {'tr', 'canonical'}:
        raise ValueError("use_domain must be 'tr' or 'canonical'")
    if len(questions) != len(qa_cols):
        raise ValueError(
            f'questions and qa_cols must be aligned: {len(questions)} questions vs {len(qa_cols)} columns'
        )
    if rows * cols < len(qa_cols):
        raise ValueError(f'Grid {rows}x{cols} too small for {len(qa_cols)} questions')
    if fps <= 0:
        raise ValueError('fps must be > 0')
    if playback_speed <= 0:
        raise ValueError('playback_speed must be > 0')
    if window_sec <= 0:
        raise ValueError('window_sec must be > 0')
    if highlight_mode not in {'qa_context', 'domain_bin'}:
        raise ValueError("highlight_mode must be 'qa_context' or 'domain_bin'")
    if context_ngram < 1:
        raise ValueError('context_ngram must be >= 1')
    if log_every_frames is not None and log_every_frames < 1:
        # 0 would otherwise crash the progress logger with a modulo-by-zero.
        raise ValueError('log_every_frames must be >= 1 (or None for the default)')
    time_anchor = str(time_anchor).lower()
    if time_anchor not in {'onset', 'midpoint'}:
        raise ValueError("time_anchor must be 'onset' or 'midpoint'")
    out_mp4 = Path(out_mp4)

    # Choose the plotted series: TR samples sit at bin onsets, canonical samples at
    # bin centres. Edges are converted to arrays first so list inputs work too.
    if use_domain == 'canonical':
        df = canonical_df
        edges = np.asarray(canonical_edges, dtype=float)
        t = 0.5 * (edges[:-1] + edges[1:])
    else:
        df = tr_df
        edges = np.asarray(tr_edges, dtype=float)
        t = edges[:-1]

    # One row per question, one column per bin.
    Y = np.vstack([df[c].to_numpy(dtype=float) for c in qa_cols])
    if Y.shape[1] != len(edges) - 1:
        raise ValueError(
            f'{use_domain} series has {Y.shape[1]} samples but its edges define {len(edges) - 1} bins'
        )
    if zscore:
        mean = np.nanmean(Y, axis=1, keepdims=True)
        std = np.nanstd(Y, axis=1, keepdims=True)
        # Constant or all-NaN rows: divide by 1 instead of by 0/NaN.
        std = np.where(np.isfinite(std) & (std > 0), std, 1.0)
        Y = (Y - mean) / std

    # Fixed per-panel y-limits over the whole series, padded by 5% of the range
    # (or by 10% of |value| / 1.0 for flat series).
    ylims = []
    for vals in Y:
        finite = vals[np.isfinite(vals)]
        if finite.size == 0:
            ylims.append((-1.0, 1.0))
            continue
        lo = float(np.min(finite))
        hi = float(np.max(finite))
        if np.isclose(lo, hi):
            pad = 1.0 if lo == 0 else abs(lo) * 0.1
        else:
            pad = 0.05 * (hi - lo)
        ylims.append((lo - pad, hi + pad))

    words = token_df['word'].astype(str).tolist()
    starts = token_df['start'].to_numpy(dtype=float)
    ends = token_df['end'].to_numpy(dtype=float)
    # Token times used to locate the active token.
    anchors = starts if time_anchor == 'onset' else 0.5 * (starts + ends)

    transcript_end = float(token_df['end'].max())
    video_duration = (transcript_end + pad_right_sec) / playback_speed
    n_frames = int(np.ceil(video_duration * fps))

    if log_every_frames is None:
        log_every_frames = max(1, int(fps * 5))

    title_labels = []
    for col, q in zip(qa_cols, questions):
        q_text = abbreviate_question(q, max_words=title_max_words)
        if title_wrap_chars and title_wrap_chars > 0:
            q_text = textwrap.fill(q_text, width=title_wrap_chars)
        title_labels.append(f'{col}\n{q_text}')
    # Short labels for the header's top-5 list, derived from the same column names
    # shown in the panel titles (e.g. 'qa_q03' -> 'q03').
    short_labels = [str(col).removeprefix('qa_') for col in qa_cols]

    # Mutable render state shared with the animation callbacks.
    start_wall = None
    last_progress_len = 0
    max_rendered_frame = -1

    fig = plt.figure(figsize=(24, 14), dpi=130)
    try:
        gs = fig.add_gridspec(rows + 1, cols, height_ratios=[0.95] + [1.0] * rows, hspace=0.36, wspace=0.24)

        ax_text = fig.add_subplot(gs[0, :])
        ax_text.axis('off')

        axes = [fig.add_subplot(gs[r, c]) for r in range(1, rows + 1) for c in range(cols)]

        lines = []
        cursors = []
        spans = []
        val_texts = []
        for i, ax in enumerate(axes):
            if i >= len(qa_cols):
                # Unused grid slots stay blank.
                ax.axis('off')
                lines.append(None)
                cursors.append(None)
                spans.append(None)
                val_texts.append(None)
                continue

            ax.set_title(title_labels[i], fontsize=6, linespacing=1.05, pad=2.5)
            ax.set_xlim(0, window_sec)
            ax.set_ylim(*ylims[i])
            ax.grid(True, alpha=0.22)
            (line,) = ax.plot([], [], linewidth=1.2)
            cursor = ax.axvline(0, linewidth=0.9, alpha=0.8)
            # Full-height highlight band (x in data units, y in axes units).
            span = mpatches.Rectangle(
                (0, 0),
                0,
                1,
                transform=ax.get_xaxis_transform(),
                facecolor='#f4c542',
                alpha=0.22,
                zorder=0,
            )
            ax.add_patch(span)
            val_text = ax.text(
                0.98,
                0.92,
                '',
                transform=ax.transAxes,
                ha='right',
                va='top',
                fontsize=6,
                bbox={'facecolor': 'white', 'alpha': 0.55, 'edgecolor': 'none', 'pad': 1.2},
            )

            lines.append(line)
            cursors.append(cursor)
            spans.append(span)
            val_texts.append(val_text)

        txt_bin = ax_text.text(0.01, 0.74, '', fontsize=15, fontweight='bold', ha='left', va='center')
        txt_words = ax_text.text(0.01, 0.38, '', fontsize=13, ha='left', va='center')
        txt_meta = ax_text.text(0.01, 0.09, '', fontsize=11, ha='left', va='center', alpha=0.85)

        def _bin_idx(t_now: float) -> int:
            """Index of the bin containing t_now, clamped to the valid bin range."""
            idx = int(np.searchsorted(edges, t_now, side='right') - 1)
            return int(np.clip(idx, 0, len(edges) - 2))

        def _qa_context_window(t_now: float):
            """Return (active token, context indices, highlight start/end, context words)."""
            token_idx = int(np.searchsorted(anchors, t_now, side='right') - 1)
            token_idx = int(np.clip(token_idx, 0, len(anchors) - 1))

            # Context = up to (context_ngram - 1) tokens immediately BEFORE the active
            # token; the active token itself is not included.
            # NOTE(review): confirm against the QAEmb generator whether the scored
            # n-gram includes the active token; if so this window should end at
            # token_idx inclusive.
            ctx_start_idx = max(0, token_idx + 1 - int(context_ngram))
            if token_idx > ctx_start_idx:
                context_idxs = np.arange(ctx_start_idx, token_idx, dtype=int)
                h0 = float(starts[ctx_start_idx])
                h1 = float(ends[token_idx - 1])
                context_words = [words[k] for k in context_idxs]
            else:
                context_idxs = np.empty((0,), dtype=int)
                h0 = float(anchors[token_idx])
                h1 = h0
                context_words = []

            return token_idx, context_idxs, h0, h1, context_words

        def _format_words(bin_words: list[str]) -> str:
            if not bin_words:
                return '(no words in active bin)'
            if bin_words_max is not None and len(bin_words) > bin_words_max:
                return ' '.join(bin_words[:bin_words_max]) + ' ...'
            return ' '.join(bin_words)

        def init():
            artists = []
            for i in range(len(axes)):
                if lines[i] is None:
                    continue
                lines[i].set_data([], [])
                val_texts[i].set_text('')
                artists.extend([lines[i], cursors[i], spans[i], val_texts[i]])
            txt_bin.set_text('')
            txt_words.set_text('')
            txt_meta.set_text('')
            artists.extend([txt_bin, txt_words, txt_meta])
            return artists

        def update(frame: int):
            nonlocal start_wall, last_progress_len, max_rendered_frame
            if start_wall is None:
                start_wall = time.perf_counter()

            t_now = (frame / fps) * playback_speed
            w0 = max(0.0, t_now - window_sec + pad_left_sec)
            w1 = w0 + window_sec

            # Samples visible in the scrolling window. The mask may be empty (e.g.
            # past the end of the series); the panels then draw no line, but the
            # header, cursor and progress bookkeeping are still updated.
            mask = (t >= w0) & (t <= w1)
            tw = t[mask] - w0

            b = _bin_idx(t_now)
            b0 = float(edges[b])
            b1 = float(edges[b + 1])

            if highlight_mode == 'qa_context':
                token_idx, context_idxs, h0, h1, bin_words = _qa_context_window(t_now)
                view0 = max(h0, w0)
                view1 = min(h1, w1)
                span_x = view0 - w0
                span_w = max(0.0, view1 - view0)
                highlight_label = (
                    f'QA context token={token_idx} anchor={anchors[token_idx]:6.2f}s '
                    f'({len(context_idxs)} prior words)'
                )
                token_label = words[token_idx] if 0 <= token_idx < len(words) else '<na>'
            else:
                view0 = max(b0, w0)
                view1 = min(b1, w1)
                span_x = view0 - w0
                span_w = max(0.0, view1 - view0)
                bin_mask = (starts < b1) & (ends > b0)
                bin_words = [words[i] for i in np.where(bin_mask)[0]]
                highlight_label = f'Bin {b0:6.2f}-{b1:6.2f}s'
                token_label = '<domain-bin>'

            # Panel values always come from the active series bin, in either mode.
            current_scores = Y[:, b]

            top_idx = np.argsort(np.nan_to_num(current_scores, nan=-np.inf))[::-1][:5]
            top_text = ' | '.join(
                f'{short_labels[j]}:{current_scores[j]:.2f}' for j in top_idx if np.isfinite(current_scores[j])
            )

            artists = []
            for i in range(len(axes)):
                if lines[i] is None:
                    continue
                lines[i].set_data(tw, Y[i, mask])
                cursors[i].set_xdata([t_now - w0, t_now - w0])
                spans[i].set_x(span_x)
                spans[i].set_width(span_w)
                val = current_scores[i]
                val_texts[i].set_text(f'{val:.2f}' if np.isfinite(val) else 'nan')
                artists.extend([lines[i], cursors[i], spans[i], val_texts[i]])

            txt_bin.set_text(f'{highlight_label} | {len(bin_words)} words')
            txt_words.set_text(_format_words(bin_words))
            txt_meta.set_text(
                f't={t_now:6.2f}s | bin {b + 1}/{len(edges) - 1} | mode={highlight_mode} | '
                f'token={token_label} | top5 {top_text}'
            )

            max_rendered_frame = max(max_rendered_frame, frame)

            # Inline progress tracker (single-line refresh) so render health is always visible.
            if frame % log_every_frames == 0 or frame == n_frames - 1:
                elapsed = time.perf_counter() - start_wall
                progress = (frame + 1) / n_frames
                eta = elapsed / progress - elapsed if progress > 0 else float('inf')
                msg = (
                    f'Rendered {frame + 1}/{n_frames} frames '
                    f'({progress * 100:5.1f}%) | elapsed {elapsed / 60:5.1f}m | ETA {eta / 60:5.1f}m'
                )
                print('\r' + msg.ljust(last_progress_len), end='', flush=True)
                last_progress_len = max(last_progress_len, len(msg))
                if frame == n_frames - 1:
                    print()

            artists.extend([txt_bin, txt_words, txt_meta])
            return artists

        anim = animation.FuncAnimation(
            fig,
            update,
            init_func=init,
            frames=n_frames,
            interval=1000 / fps,
            blit=False,
        )

        writer = animation.FFMpegWriter(
            fps=fps,
            codec='libx264',
            bitrate=7000,
            extra_args=['-pix_fmt', 'yuv420p', '-movflags', '+faststart'],
        )

        out_mp4.parent.mkdir(parents=True, exist_ok=True)
        anim.save(str(out_mp4), writer=writer)
    finally:
        # Release the (large) figure even if rendering or encoding fails.
        plt.close(fig)

    rendered_frames = max_rendered_frame + 1
    if rendered_frames != n_frames:
        print(
            f'WARNING: expected {n_frames} frames but update() saw {rendered_frames}. '
            'Video was still written; inspect output if this is unexpected.'
        )
    else:
        print(f'Frame render check: {rendered_frames}/{n_frames} frames rendered.')

    print(f'Saved: {out_mp4}')


# Fail fast: matplotlib's FFMpegWriter needs the ffmpeg binary on PATH.
if not shutil.which('ffmpeg'):
    raise RuntimeError('ffmpeg not found on PATH. Install ffmpeg before rendering MP4.')

# Render every QA question into a single grid video using the settings above.
make_karaoke_qaemb_video_all_questions(
    # inputs
    token_df=token_df,
    tr_df=tr_df,
    canonical_df=canonical_df,
    tr_edges=tr_edges,
    canonical_edges=canonical_edges,
    qa_cols=qa_cols,
    questions=questions,
    # output and timing
    out_mp4=OUT_MP4,
    fps=FPS,
    playback_speed=PLAYBACK_SPEED,
    window_sec=WINDOW_SEC,
    pad_left_sec=PAD_LEFT_SEC,
    pad_right_sec=PAD_RIGHT_SEC,
    log_every_frames=LOG_EVERY_FRAMES,
    # series and highlighting
    use_domain=USE_DOMAIN,
    zscore=Z_SCORE,
    highlight_mode=HIGHLIGHT_MODE,
    context_ngram=CONTEXT_NGRAM,
    time_anchor=TIME_ANCHOR,
    # layout
    rows=GRID_ROWS,
    cols=GRID_COLS,
    title_max_words=TITLE_MAX_WORDS,
    title_wrap_chars=TITLE_WRAP_CHARS,
    bin_words_max=BIN_WORDS_MAX,
)








