# Day 28 - Karaoke Category Video (LLM embeddings)

This notebook rebuilds category time series using LLM embeddings and renders a
karaoke-style MP4: the active time bin is highlighted in a 4x3 grid of category
plots, and the words spoken in that bin are listed at the top.

Workflow:
1. Update the configuration cell (LLM embedding path, subject/story, smoothing, video settings).
2. Run the generation cell to create `result`.
3. Run the token prep cell.
4. Pick 12 categories for the grid.
5. Run the cell that defines `make_karaoke_category_video`.
6. Run the video cell (requires `ffmpeg` on `PATH`).


In [None]:
import json
import math
import os
import sys
import warnings
from collections import Counter
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd

# Reproducibility and display options.
np.random.seed(42)
pd.options.display.max_columns = 60

# Repository root. The cluster checkout is the default; set FMRI_EDM_CCM_ROOT to run
# from another location without editing the notebook.
project_root = Path(
    os.environ.get('FMRI_EDM_CCM_ROOT', '/flash/PaoU/seann/fmri-edm-ccm')
).expanduser()
project_root.mkdir(parents=True, exist_ok=True)
# Relative paths used later (e.g. 'configs/demo.yaml') resolve against the project root.
os.chdir(project_root)

# Source trees that must be importable: this project plus the pyEDM and MDE checkouts.
# Each sibling tree can be relocated via an environment variable.
_extra_import_paths = [
    str(project_root),
    os.environ.get('PYEDM_SRC', '/flash/PaoU/seann/pyEDM/src'),
    os.environ.get('MDE_SRC', '/flash/PaoU/seann/MDE-main/src'),
]
for _import_path in _extra_import_paths:
    # Skip entries already present so re-running this cell does not keep growing sys.path.
    if _import_path not in sys.path:
        sys.path.append(_import_path)

# Optional interactive widgets; fall back to plain printing when unavailable.
try:
    import ipywidgets as widgets
    from IPython.display import display
except Exception:
    widgets = None
    def display(obj):
        print(obj)

# Matplotlib is optional at setup time; plotting cells check `plt is None` before use.
try:
    import matplotlib.pyplot as plt
except Exception as exc:
    plt = None
    warnings.warn(f'Matplotlib unavailable: {exc}')

from src.utils import load_yaml
from src.category_builder import generate_category_time_series, get_embedding_backend

EPS = 1e-12


In [None]:
# --- Configuration -------------------------------------------------------------------
cfg = load_yaml('configs/demo.yaml')
# A YAML section that is present but empty parses as None; normalise both sections to dicts.
categories_cfg = cfg.get('categories', {}) or {}
cluster_csv_path = categories_cfg.get('cluster_csv_path', '')
prototype_weight_power = float(categories_cfg.get('prototype_weight_power', 1.0))
seconds_bin_width_default = float(categories_cfg.get('seconds_bin_width', 0.05))
temporal_weighting_default = str(categories_cfg.get('temporal_weighting', 'proportional')).lower()

paths = cfg.get('paths', {}) or {}
TR = float(cfg.get('TR', 2.0))
features_root = Path(paths.get('featurestest', 'featurestest'))
features_root.mkdir(parents=True, exist_ok=True)

SUBJECT = cfg.get('subject') or 'UTS01'
STORY = cfg.get('story') or 'wheretheressmoke'
TEMPORAL_WEIGHTING = temporal_weighting_default  # {'proportional', 'none'}
SECONDS_BIN_WIDTH = seconds_bin_width_default

# LLM embedding settings (matches run_day26_smoothing_llm.slurm defaults)
LM_EMBEDDING_PATH = Path(os.environ.get('LM_EMBEDDING_PATH', 'embeddings/gpt_tokens.npz'))
# Accept the usual truthy spellings ('true', '1', 'yes', 'on'); anything else disables lowercasing.
LM_LOWERCASE_TOKENS = (
    os.environ.get('LM_LOWERCASE_TOKENS', 'true').strip().lower() in {'1', 'true', 'yes', 'on'}
)
if not LM_EMBEDDING_PATH.is_absolute():
    LM_EMBEDDING_PATH = (project_root / LM_EMBEDDING_PATH).resolve()
if not LM_EMBEDDING_PATH.exists():
    raise FileNotFoundError(f'LLM embedding file not found: {LM_EMBEDDING_PATH}')

# canonical smoothing controls (edit to taste)
SMOOTHING_SECONDS = 1.00            # shorter window preserves fast dynamics for forecasting
SMOOTHING_METHOD = 'moving_average'       # {'moving_average', 'gaussian'}
GAUSSIAN_SIGMA_SECONDS = 0.5 * SMOOTHING_SECONDS  # tie sigma to window length for EDM
SMOOTHING_PAD_MODE = 'reflect'      # {'edge', 'reflect'}

SAVE_OUTPUTS = True  # toggle off to skip writing CSVs

# video settings
VIDEO_DIR = features_root / 'videos'
VIDEO_DIR.mkdir(parents=True, exist_ok=True)
KARAOKE_OUTPUT = str(VIDEO_DIR / f'karaoke_llm_{SUBJECT}_{STORY}.mp4')
KARAOKE_USE_DOMAIN = 'tr'  # 'tr' or 'canonical'
KARAOKE_FPS = 1
KARAOKE_WINDOW_SEC = 30.0
KARAOKE_PAD_LEFT_SEC = 0.5
KARAOKE_PAD_RIGHT_SEC = 2.0
KARAOKE_PLAYBACK_SPEED = 7.0  # 1.0 real-time; <1 slower; >1 faster
KARAOKE_YLIM_PAD_FRAC = 0.05  # add 5% headroom to y-lims
KARAOKE_ZSCORE = True  # z-score each category series before plotting
KARAOKE_LOG_EVERY_FRAMES = None  # None -> log every ~5s of video time
KARAOKE_BIN_WORDS_MAX = 30  # None to show all words in the bin
KARAOKE_CATEGORY_COUNT = 12  # must be 12 for the 4x3 grid
KARAOKE_CATEGORY_COLUMNS = None  # set to a list of 12 column names if desired

print(f'Subject/story: {SUBJECT} / {STORY}')
print(f'Cluster CSV: {cluster_csv_path or "<none>"}')
print(f'Temporal weighting: {TEMPORAL_WEIGHTING}')
print(f'Seconds bin width: {SECONDS_BIN_WIDTH}')
print(f'Smoothing: {SMOOTHING_METHOD} | window={SMOOTHING_SECONDS}s | sigma={GAUSSIAN_SIGMA_SECONDS}')
print(f'LLM embeddings: {LM_EMBEDDING_PATH} (lowercase={LM_LOWERCASE_TOKENS})')
print(f'Video output: {KARAOKE_OUTPUT}')
print(f'Video domain: {KARAOKE_USE_DOMAIN} | fps={KARAOKE_FPS} | window={KARAOKE_WINDOW_SEC}s')
print(f'Video domain: {KARAOKE_USE_DOMAIN} | fps={KARAOKE_FPS} | window={KARAOKE_WINDOW_SEC}s')


In [None]:
# Regenerate the category time series using the LLM embedding backend.
embedding_backend = get_embedding_backend(
    'llm',
    lm_lowercase_tokens=LM_LOWERCASE_TOKENS,
    lm_embedding_path=LM_EMBEDDING_PATH,
)

result = generate_category_time_series(
    SUBJECT,
    STORY,
    # configuration sources
    cfg_base=cfg,
    categories_cfg_base=categories_cfg,
    paths=paths,
    features_root=features_root,
    # category construction
    cluster_csv_path=cluster_csv_path,
    prototype_weight_power=prototype_weight_power,
    embedding_backend=embedding_backend,
    # temporal resampling and smoothing
    TR=TR,
    seconds_bin_width=SECONDS_BIN_WIDTH,
    temporal_weighting=TEMPORAL_WEIGHTING,
    smoothing_method=SMOOTHING_METHOD,
    smoothing_seconds=SMOOTHING_SECONDS,
    gaussian_sigma_seconds=GAUSSIAN_SIGMA_SECONDS,
    smoothing_pad=SMOOTHING_PAD_MODE,
    # outputs
    save_outputs=SAVE_OUTPUTS,
)

canonical_df, tr_df = result['canonical_df_selected'], result['category_df_selected']

print()
print('Smoothing configuration:', result['smoothing'])
for _preview_title, _preview_df in (('Canonical preview:', canonical_df), ('TR-aligned preview:', tr_df)):
    print()
    print(_preview_title)
    display(_preview_df.head())

if result['trimmed_df'] is not None:
    print()
    print(f"Trimmed window length: {len(result['trimmed_df'])} (max_lag_primary={result['max_lag_primary']})")


In [7]:
# Prepare transcript tokens and per-category scores.
# Produces:
#   TOKEN_BASE_DF        one row per transcript token: token_index, word, start, end, midpoint, duration
#   TOKEN_SCORE_CACHE    category name -> float array of per-token scores (same order as TOKEN_BASE_DF)
#   TOKEN_SCORE_METHOD   lower-cased score method; 'count' = lexicon lookup, anything else = cosine
#   TOKEN_SCORE_ABS_MAX  largest finite |score| over all categories (1.0 when there is none)
if 'result' not in globals():
    raise RuntimeError('Run the generation cell first to populate `result`.')

_tokens_raw = result.get('event_records') or []
if not _tokens_raw:
    raise RuntimeError('No transcript events found - rerun upstream steps.')

_category_states = result.get('category_states') or {}
if not _category_states:
    raise RuntimeError('Category states missing from result; rerun the generation cell.')


def _score_tokens_for_category(records, state, score_method):
    """Return one float score per token record for a single category state.

    score_method == 'count': the category lexicon value for the lower-cased word
    (NaN when the word is absent from the lexicon).
    Any other method: cosine similarity between the token embedding and the
    category prototype, clipped to [-1, 1]; NaN when either vector is missing or
    has a non-positive norm.
    """
    lexicon = state.get('lexicon', {}) or {}
    if score_method == 'count':
        return np.array([lexicon.get(rec['word'].lower(), np.nan) for rec in records], dtype=float)

    proto = state.get('prototype')
    proto_norm = state.get('prototype_norm') or 0.0
    has_prototype = proto is not None and proto_norm > 0

    scores = []
    for rec in records:
        emb = rec.get('embedding')
        emb_norm = rec.get('embedding_norm') or 0.0
        if not has_prototype or emb is None or emb_norm <= 0:
            scores.append(np.nan)
        else:
            cosine = np.dot(emb, proto) / (emb_norm * proto_norm)
            scores.append(float(np.clip(cosine, -1.0, 1.0)))
    return np.array(scores, dtype=float)


_score_method = str(result.get('category_score_method', 'similarity')).lower()
_token_scores = {
    cat_name: _score_tokens_for_category(_tokens_raw, state, _score_method)
    for cat_name, state in _category_states.items()
}

_abs_max = 0.0
for _scores in _token_scores.values():
    _finite = np.abs(_scores[np.isfinite(_scores)])
    if _finite.size:
        _abs_max = max(_abs_max, float(_finite.max()))

# Only the timing columns are needed downstream. Copying just these keys keeps the
# embedding vectors out of the DataFrame, and it does not require 'embedding' columns
# to exist (e.g. with 'count' scoring). A record missing a timing key raises KeyError here.
_token_times = pd.DataFrame(
    [{'word': rec['word'], 'start': rec['start'], 'end': rec['end']} for rec in _tokens_raw]
)
TOKEN_BASE_DF = (
    _token_times
    .assign(
        token_index=np.arange(len(_token_times)),
        midpoint=0.5 * (_token_times['start'] + _token_times['end']),
        duration=_token_times['end'] - _token_times['start'],
    )
    [['token_index', 'word', 'start', 'end', 'midpoint', 'duration']]
    .copy()
)
TOKEN_SCORE_CACHE = _token_scores
TOKEN_SCORE_METHOD = _score_method
TOKEN_SCORE_ABS_MAX = _abs_max if _abs_max > 0 else 1.0

# Remove temporary names from the notebook namespace. The event records stay referenced
# by `result`, so this does not release their memory.
del _token_times, _tokens_raw


In [8]:
# Pick 12 categories for the 4x3 grid (set KARAOKE_CATEGORY_COLUMNS in the config cell to override).
if KARAOKE_CATEGORY_COLUMNS is None:
    category_cols_12 = list(result['category_columns'][:KARAOKE_CATEGORY_COUNT])
else:
    category_cols_12 = list(KARAOKE_CATEGORY_COLUMNS)

if len(category_cols_12) != 12:
    raise ValueError(f'Expected 12 categories, got {len(category_cols_12)}.')

# Fail fast on typos: the video function reads these columns from the frame for the
# chosen domain, and a missing column would otherwise only surface at render time.
_plot_df = (
    result['canonical_df_selected']
    if KARAOKE_USE_DOMAIN == 'canonical'
    else result['category_df_selected']
)
_missing_cols = [col for col in category_cols_12 if col not in _plot_df.columns]
if _missing_cols:
    raise KeyError(
        f'Selected categories not found in the {KARAOKE_USE_DOMAIN!r} series: {_missing_cols}'
    )

print('Selected categories:', category_cols_12)


Selected categories: ['cat_abstract', 'cat_communal', 'cat_emotional', 'cat_locational', 'cat_mental', 'cat_numeric', 'cat_professional', 'cat_social', 'cat_tactile', 'cat_temporal', 'cat_violent', 'cat_visual']


In [9]:
import time

import numpy as np
import pandas as pd

# Check for matplotlib *before* importing its submodules. If the setup cell could not
# import matplotlib, the imports below would fail with a bare ImportError and this
# actionable message would never be shown. `globals().get` also covers a kernel
# where the setup cell has not been run.
if globals().get('plt') is None:
    raise RuntimeError('Matplotlib unavailable in this environment.')

import matplotlib.animation as animation
import matplotlib.patches as mpatches


def make_karaoke_category_video(
    *,
    result: dict,
    token_df: pd.DataFrame,
    category_cols: list,
    out_mp4: str = "karaoke_categories.mp4",
    use_domain: str = "tr",          # "tr" or "canonical"
    fps: int = 30,
    window_sec: float = 30.0,        # how many seconds visible at once
    pad_left_sec: float = 0.5,
    pad_right_sec: float = 2.0,
    log_every_frames: int | None = None,
    playback_speed: float = 1.0,
    ypad_frac: float = 0.05,
    bin_words_max: int | None = 30,
    zscore: bool = True,
):
    '''
    Creates an MP4 where the top shows words within the active bin and bottom is a 4x3 grid.
    Assumes token_df has columns: word,start,end,midpoint (like your TOKEN_BASE_DF).
    '''

    assert use_domain in {"tr", "canonical"}
    assert len(category_cols) == 12, "Pass exactly 12 categories for a 4x3 grid."

    if playback_speed <= 0:
        raise ValueError("playback_speed must be > 0.")
    if ypad_frac < 0:
        raise ValueError("ypad_frac must be >= 0.")
    if bin_words_max is not None and bin_words_max <= 0:
        raise ValueError("bin_words_max must be > 0 or None.")

    if not isinstance(zscore, bool):
        raise ValueError("zscore must be a boolean.")

    if log_every_frames is None:
        log_every_frames = max(1, int(fps * 5))

    # --- Choose time base and series ---
    if use_domain == "canonical":
        df = result["canonical_df_selected"]
        t = 0.5 * (result["canonical_edges"][:-1] + result["canonical_edges"][1:])
        edges = result["canonical_edges"]
    else:
        df = result["category_df_selected"]
        t = result["tr_edges"][:-1]
        edges = result["tr_edges"]

    t = np.asarray(t, dtype=float)
    edges = np.asarray(edges, dtype=float)

    # Determine video duration from transcript
    t_end = float(token_df["end"].max())
    transcript_duration = t_end + pad_right_sec
    duration = transcript_duration / playback_speed
    n_frames = int(np.ceil(duration * fps))

    # Pre-extract series arrays for speed
    Y = np.vstack([df[c].to_numpy(dtype=float) for c in category_cols])  # shape (12, T)

    if zscore:
        means = np.nanmean(Y, axis=1, keepdims=True)
        stds = np.nanstd(Y, axis=1, keepdims=True)
        stds = np.where(np.isfinite(stds) & (stds > 0), stds, 1.0)
        Y = (Y - means) / stds

    # Robust y-lims per subplot (full range with padding)
    ylims = []
    for i in range(12):
        vals = Y[i]
        finite = vals[np.isfinite(vals)]
        if finite.size == 0:
            ylims.append((-1, 1))
        else:
            lo = float(np.min(finite))
            hi = float(np.max(finite))
            if np.isclose(lo, hi):
                pad = 1.0 if lo == 0 else abs(lo) * 0.1
            else:
                pad = (hi - lo) * ypad_frac
            ylims.append((lo - pad, hi + pad))

    # --- Figure layout: top karaoke + 4x3 plots ---
    fig = plt.figure(figsize=(16, 9), dpi=150)
    gs = fig.add_gridspec(5, 3, height_ratios=[0.65, 1, 1, 1, 1], hspace=0.35, wspace=0.25)

    ax_text = fig.add_subplot(gs[0, :])
    ax_text.axis("off")

    axes = []
    for r in range(1, 5):
        for c in range(3):
            axes.append(fig.add_subplot(gs[r, c]))

    # Initialize lines
    lines = []
    cursors = []
    highlights = []
    for i, ax in enumerate(axes):
        ax.set_title(category_cols[i].replace("cat_", ""), fontsize=10)
        ax.set_xlim(0, window_sec)
        ax.set_ylim(*ylims[i])
        ax.grid(True, alpha=0.25)
        (ln,) = ax.plot([], [], linewidth=1.6)
        cursor = ax.axvline(0, linewidth=1.2, alpha=0.9)
        highlight = mpatches.Rectangle(
            (0, 0), 0, 1,
            transform=ax.get_xaxis_transform(),
            facecolor="#f4c542",
            alpha=0.25,
            zorder=0,
        )
        ax.add_patch(highlight)
        lines.append(ln)
        cursors.append(cursor)
        highlights.append(highlight)

    # Text artists
    txt_bin = ax_text.text(0.01, 0.72, "", fontsize=16, fontweight="bold", va="center", ha="left")
    txt_words = ax_text.text(0.01, 0.36, "", fontsize=14, va="center", ha="left")
    txt_sub  = ax_text.text(0.01, 0.10, "", fontsize=11, va="center", ha="left", alpha=0.8)

    words = token_df["word"].astype(str).tolist()
    starts = token_df["start"].to_numpy(dtype=float)
    ends = token_df["end"].to_numpy(dtype=float)

    start_time = None

    def _bin_index(t_now: float) -> int:
        idx = int(np.searchsorted(edges, t_now, side="right") - 1)
        return int(np.clip(idx, 0, len(edges) - 2))

    def _format_bin_words(bin_words: list[str]) -> str:
        if not bin_words:
            return "(no words in bin)"
        if bin_words_max is not None and len(bin_words) > bin_words_max:
            shown = bin_words[:bin_words_max]
            return " ".join(shown) + " ..."
        return " ".join(bin_words)

    def init():
        for ln in lines:
            ln.set_data([], [])
        txt_bin.set_text("")
        txt_words.set_text("")
        txt_sub.set_text("")
        return lines + cursors + highlights + [txt_bin, txt_words, txt_sub]

    def update(frame):
        nonlocal start_time
        if start_time is None:
            start_time = time.perf_counter()

        t_now = (frame / fps) * playback_speed

        # sliding window
        w0 = max(0.0, t_now - window_sec + pad_left_sec)
        w1 = w0 + window_sec

        # slice series in the window (by time array)
        mask = (t >= w0) & (t <= w1)
        if not np.any(mask):
            return lines + cursors + highlights + [txt_bin, txt_words, txt_sub]

        tw = t[mask] - w0  # shift to window coords

        bin_idx = _bin_index(t_now)
        bin_start = float(edges[bin_idx])
        bin_end = float(edges[bin_idx + 1])

        view_start = max(bin_start, w0)
        view_end = min(bin_end, w1)
        span_start = view_start - w0
        span_width = max(0.0, view_end - view_start)

        for i in range(12):
            yw = Y[i, mask]
            lines[i].set_data(tw, yw)
            cursors[i].set_xdata([t_now - w0, t_now - w0])
            axes[i].set_xlim(0, window_sec)
            highlights[i].set_x(span_start)
            highlights[i].set_width(span_width)

        bin_mask = (starts < bin_end) & (ends > bin_start)
        bin_words = [words[i] for i in np.where(bin_mask)[0]]
        words_text = _format_bin_words(bin_words)

        txt_bin.set_text(
            f"Bin {bin_start:6.2f}-{bin_end:6.2f}s | {len(bin_words)} words"
        )
        txt_words.set_text(words_text)
        txt_sub.set_text(
            f"t = {t_now:6.2f}s | bin {bin_idx + 1}/{len(edges) - 1} | domain: {use_domain}"
        )

        if frame % log_every_frames == 0 or frame == n_frames - 1:
            elapsed = time.perf_counter() - start_time
            progress = (frame + 1) / n_frames
            eta = elapsed / progress - elapsed if progress > 0 else float('inf')
            print(
                f"Frame {frame + 1}/{n_frames} "
                f"({progress * 100:.1f}%) | "
                f"elapsed {elapsed / 60:.1f}m | ETA {eta / 60:.1f}m"
            )

        return lines + cursors + highlights + [txt_bin, txt_words, txt_sub]

    anim = animation.FuncAnimation(
        fig, update, init_func=init,
        frames=n_frames, interval=1000/fps, blit=False
    )

    writer = animation.FFMpegWriter(
        fps=fps,
        codec="libx264",
        bitrate=5000,
        extra_args=[
            "-pix_fmt", "yuv420p",
            "-profile:v", "baseline",
            "-level", "3.0",
            "-movflags", "+faststart"
        ]
    )
    anim.save(out_mp4, writer=writer)
    plt.close(fig)
    print(f"Saved: {out_mp4}")


In [10]:
import shutil

# Fail fast, before the multi-minute render starts, if a prerequisite is missing.
if plt is None:
    raise RuntimeError('Matplotlib unavailable in this environment.')
if not shutil.which('ffmpeg'):
    raise RuntimeError('ffmpeg not found on PATH. Install ffmpeg to write MP4 files.')

make_karaoke_category_video(
    result=result,
    token_df=TOKEN_BASE_DF,
    category_cols=category_cols_12,
    out_mp4=KARAOKE_OUTPUT,
    use_domain=KARAOKE_USE_DOMAIN,
    # timing
    fps=KARAOKE_FPS,
    playback_speed=KARAOKE_PLAYBACK_SPEED,
    window_sec=KARAOKE_WINDOW_SEC,
    pad_left_sec=KARAOKE_PAD_LEFT_SEC,
    pad_right_sec=KARAOKE_PAD_RIGHT_SEC,
    # appearance
    zscore=KARAOKE_ZSCORE,
    ypad_frac=KARAOKE_YLIM_PAD_FRAC,
    bin_words_max=KARAOKE_BIN_WORDS_MAX,
    # progress reporting
    log_every_frames=KARAOKE_LOG_EVERY_FRAMES,
)


Frame 1/302 (0.3%) | elapsed 0.0m | ETA 0.0m
Frame 6/302 (2.0%) | elapsed 0.1m | ETA 2.7m
Frame 11/302 (3.6%) | elapsed 0.1m | ETA 2.5m
Frame 16/302 (5.3%) | elapsed 0.1m | ETA 2.4m
Frame 21/302 (7.0%) | elapsed 0.2m | ETA 2.5m
Frame 26/302 (8.6%) | elapsed 0.2m | ETA 2.4m
Frame 31/302 (10.3%) | elapsed 0.3m | ETA 2.5m
Frame 36/302 (11.9%) | elapsed 0.3m | ETA 2.4m
Frame 41/302 (13.6%) | elapsed 0.4m | ETA 2.3m
Frame 46/302 (15.2%) | elapsed 0.4m | ETA 2.3m
Frame 51/302 (16.9%) | elapsed 0.5m | ETA 2.2m
Frame 56/302 (18.5%) | elapsed 0.5m | ETA 2.2m
Frame 61/302 (20.2%) | elapsed 0.5m | ETA 2.2m
Frame 66/302 (21.9%) | elapsed 0.6m | ETA 2.1m
Frame 71/302 (23.5%) | elapsed 0.6m | ETA 2.1m
Frame 76/302 (25.2%) | elapsed 0.7m | ETA 2.0m
Frame 81/302 (26.8%) | elapsed 0.7m | ETA 2.0m
Frame 86/302 (28.5%) | elapsed 0.8m | ETA 1.9m
Frame 91/302 (30.1%) | elapsed 0.8m | ETA 1.9m
Frame 96/302 (31.8%) | elapsed 0.9m | ETA 1.9m
Frame 101/302 (33.4%) | elapsed 0.9m | ETA 1.8m
Frame 106/302 (35.1%

## Next steps

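- If you want different categories, set `KARAOKE_CATEGORY_COLUMNS` to an explicit list of 12 column names.
- For a faster preview, lower `KARAOKE_FPS` or raise `KARAOKE_PLAYBACK_SPEED`. Both reduce the number of rendered frames (frames ≈ (story length + `KARAOKE_PAD_RIGHT_SEC`) × FPS / playback speed). Shrinking `KARAOKE_WINDOW_SEC` only trims per-frame plotting work slightly.
- The output MP4 path is set by `KARAOKE_OUTPUT`.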
