# Day26 Category Overlay Subject Sweep


Generate overlay grids for every story/run/option directory discovered under a subject root (e.g., ``/flash/PaoU/seann/fmri-edm-ccm/figs_new/UTS02``). Each option directory is processed independently, and its outputs (per-window overlays, the best-window prediction overlay, and the metric heatmaps) are saved back into that same option directory.

In [1]:
from pathlib import Path
from math import ceil
from typing import List, Sequence, Tuple, Optional

from PIL import Image
import matplotlib.pyplot as plt

# --- Configuration ---
# Settings read by the cells below. Edit these values, then re-run the notebook.

# Subject root; it is walked as <story>/<run>/<option> by discover_option_dirs.
ROOT_SUBJECT_DIR = Path('/flash/PaoU/seann/fmri-edm-ccm/figs_new/UTS02')

# Suffix of the per-category plot file: mde_<category>_<PLOT_SUFFIX>.png
PLOT_SUFFIX = 'rho_progression'
# Suffix of the per-category prediction plot: mde_<category>_<PREDICTION_OVERLAY_SUFFIX>.png
PREDICTION_OVERLAY_SUFFIX = 'prediction_overlay'
# Only sub-directories of an option directory whose name starts with this are categories.
CATEGORY_PREFIX = 'cat_'
# Number of columns in each overlay grid.
NUM_COLUMNS = 3

# Directory-name prefixes accepted as smoothing windows.
# NOTE(review): 'guass_' presumably tolerates a misspelled 'gauss_' on disk - confirm.
VALID_WINDOW_PREFIXES = ('none_', 'movavg_', 'gauss_', 'guass_')

# Optional allow-lists of exact directory names; None (or empty) disables the filter.
STORY_FILTERS = None  # e.g. ['wheretheressmoke']
RUN_FILTERS = None    # e.g. ['run-7']
OPTION_FILTERS = None # e.g. ['day26_smoothing_cli_MDE90step_detrend_gaussian_CAE']

# Switches for the three kinds of output written by process_option_dir.
GENERATE_WINDOW_OVERLAYS = True
GENERATE_BEST_PREDICTION_OVERLAYS = True
GENERATE_HEATMAPS = True

# Maximum number of missing-file paths printed per report.
MAX_MISSING_LOG = 5
# When set, used as the subject name instead of ROOT_SUBJECT_DIR.name.
SUBJECT_OVERRIDE = None


In [None]:
import pandas as pd
import numpy as np
from matplotlib import cm

# Column order of the metrics table / heatmap. CSV category names are passed
# through format_category_name and kept only if they appear in this list.
CATEGORIES_DISPLAY = [
    'Abstract', 'Communal', 'Emotional', 'Locational', 'Mental', 'Numeric',
    'Professional', 'Social', 'Tactile', 'Temporal', 'Violent', 'Visual'
]
# One entry per metric: (row label, key used in the '<window>_<key>_summary.csv'
# file name, preferred value column inside that CSV).
METRIC_CONFIG = [
    ('ρ', 'rho', 'rho_test'),
    ('CAE', 'cae', 'cae_test'),
    ('RMSE', 'rmse', 'rmse_test'),
]
# Name suffixes ('_lag', '_nolag') recognised by split_window_and_lag.
LAG_OPTIONS = ('lag', 'nolag')


def matches_filters(value: str, filters: Optional[Sequence[str]]) -> bool:
    """Return whether *value* passes an optional allow-list.

    Args:
        value: Name to test (a story, run, or option directory name).
        filters: Allowed names. ``None`` or an empty sequence disables
            filtering. A single bare string is treated as a one-element
            allow-list.

    Returns:
        ``True`` when filtering is disabled or *value* equals one of the
        allowed names.
    """
    if not filters:
        return True
    if isinstance(filters, str):
        # ``value in "some-string"`` is a substring test, so a filter given as
        # a plain string would let 'run-7' match 'run-70'. Compare exactly.
        return value == filters
    return value in filters


def resolve_subject(root_dir: Path, override: Optional[str]) -> str:
    """Return the subject name: *override* when truthy, else the root directory's name."""
    if override:
        return override
    return root_dir.name


def resolve_story_dirs(base_dir: Path, story: str, run: Optional[str] = None) -> List[Path]:
    """Find the directories under *base_dir* that belong to *story*.

    Matching gets progressively looser and stops at the first rule that
    yields anything:

    1. ``base_dir / story`` itself, if it is a directory;
    2. sub-directories named exactly ``story``, ``<story>_<run>`` or
       ``<story>-<run>`` (the run variants only when *run* is given);
    3. sub-directories whose name starts with ``<story>_`` or ``<story>-``;
    4. sub-directories whose name contains *story*;
    5. the only sub-directory, when there is exactly one.

    Returns:
        The matching directories in sorted order, or an empty list when
        *base_dir* does not exist or nothing matches.
    """
    if not base_dir.exists():
        return []

    exact_dir = base_dir / story
    if exact_dir.is_dir():
        return [exact_dir]

    subdirs = sorted(entry for entry in base_dir.iterdir() if entry.is_dir())

    wanted_names = {story}
    if run:
        wanted_names |= {f"{story}_{run}", f"{story}-{run}"}
    name_prefixes = (f"{story}_", f"{story}-")

    # Ordered from strictest to loosest; the first rule with hits wins.
    rules = (
        lambda name: name in wanted_names,
        lambda name: name.startswith(name_prefixes),
        lambda name: story in name,
    )
    for accepts in rules:
        hits = [entry for entry in subdirs if accepts(entry.name)]
        if hits:
            return hits

    # Last resort: a lone sub-directory is taken to be the story directory.
    return subdirs if len(subdirs) == 1 else []


def discover_option_dirs(
    root_dir: Path,
    category_prefix: str,
    story_filters: Optional[Sequence[str]] = None,
    run_filters: Optional[Sequence[str]] = None,
    option_filters: Optional[Sequence[str]] = None,
) -> List[Tuple[str, str, Path]]:
    """Walk ``root_dir/<story>/<run>/<option>`` and collect usable option directories.

    An option directory is kept when it passes *option_filters* and contains at
    least one sub-directory whose name starts with *category_prefix*.

    Returns:
        ``(story_name, run_name, option_dir)`` tuples, in sorted traversal order.
    """

    def filtered_subdirs(parent: Path, filters: Optional[Sequence[str]]):
        # Sorted sub-directories of ``parent`` whose names pass ``filters``.
        for entry in sorted(parent.iterdir()):
            if entry.is_dir() and matches_filters(entry.name, filters):
                yield entry

    found = []
    for story_dir in filtered_subdirs(root_dir, story_filters):
        for run_dir in filtered_subdirs(story_dir, run_filters):
            for option_dir in filtered_subdirs(run_dir, option_filters):
                has_category = any(
                    child.is_dir() and child.name.startswith(category_prefix)
                    for child in option_dir.iterdir()
                )
                if has_category:
                    found.append((story_dir.name, run_dir.name, option_dir))
    return found


def discover_windows(
    option_dir: Path,
    category_prefix: str,
    valid_window_prefixes: Sequence[str],
) -> List[str]:
    """List the distinct window directory names found under the category directories.

    A window is a sub-directory of a category directory (a child of
    *option_dir* whose name starts with *category_prefix*) whose own name
    starts with one of *valid_window_prefixes*.

    Returns:
        The unique window names, sorted alphabetically.
    """
    accepted_prefixes = tuple(valid_window_prefixes)
    names = {
        entry.name
        for category_dir in option_dir.iterdir()
        if category_dir.is_dir() and category_dir.name.startswith(category_prefix)
        for entry in category_dir.iterdir()
        if entry.is_dir() and entry.name.startswith(accepted_prefixes)
    }
    return sorted(names)


def collect_category_images_for_window(
    option_dir: Path,
    window: str,
    plot_suffix: str,
    category_prefix: str,
    subject: str,
    story: str,
    run: Optional[str],
) -> Tuple[Sequence[Tuple[str, Path]], Sequence[Path]]:
    """Locate each category's plot image for one smoothing window.

    For every category directory in *option_dir* the plot is looked up at
    ``<category>/<window>/<subject>/<story dir>/day22_category_mde/plots/
    <label>/mde_<label>_<plot_suffix>.png``, where the story directory
    candidates come from ``resolve_story_dirs`` and ``<label>`` is the category
    directory name with trailing underscores removed.

    Returns:
        ``(found, absent)``: *found* is a list of ``(label, image_path)`` sorted
        by label; *absent* lists the path that was expected for every category
        without an image.
    """
    found = []
    absent = []

    category_dirs = (
        entry
        for entry in sorted(option_dir.iterdir())
        if entry.is_dir() and entry.name.startswith(category_prefix)
    )
    for category_dir in category_dirs:
        label = category_dir.name.rstrip('_')
        plot_relpath = (
            Path('day22_category_mde') / 'plots' / label / f'mde_{label}_{plot_suffix}.png'
        )
        subject_root = category_dir / window / subject
        candidates = resolve_story_dirs(subject_root, story, run)

        # First candidate story directory that actually contains the plot.
        hit = next(
            (base / plot_relpath for base in candidates if (base / plot_relpath).exists()),
            None,
        )
        if hit is not None:
            found.append((label, hit))
            continue

        # Report where the plot was expected so the log is actionable.
        anchor = candidates[0] if candidates else subject_root / story
        absent.append(anchor / plot_relpath)

    found.sort(key=lambda pair: pair[0])
    return found, absent


def make_overlay_figure(
    title: str,
    plot_suffix: str,
    category_images: Sequence[Tuple[str, Path]],
    num_columns: int,
):
    """Tile the given images into one grid figure and return it.

    Args:
        title: Leading part of the figure title.
        plot_suffix: Appended to the title with underscores shown as spaces.
        category_images: ``(label, image_path)`` pairs, one per grid cell.
        num_columns: Number of grid columns; the row count is derived from it.

    Returns:
        The Matplotlib figure. Saving and closing it is left to the caller.
    """
    row_count = ceil(len(category_images) / num_columns)
    figure, axis_grid = plt.subplots(
        row_count,
        num_columns,
        figsize=(num_columns * 4.0, row_count * 3.2),
        squeeze=False,
    )
    panels = axis_grid.flatten()

    # Hide the frame of every cell first so unused trailing cells stay blank.
    for panel in panels:
        panel.axis('off')

    for panel, (label, png_path) in zip(panels, category_images):
        with Image.open(png_path) as picture:
            panel.imshow(picture)
        panel.set_title(label.replace('cat_', '').replace('_', ' ').title(), fontsize=10)

    heading = f"{title} · {plot_suffix.replace('_', ' ')}"
    figure.suptitle(heading, fontsize=14)
    figure.subplots_adjust(wspace=0.12, hspace=0.28)
    # Reserve the top 6% of the canvas for the suptitle.
    figure.tight_layout(rect=[0, 0, 1, 0.94], pad=0.6)
    return figure


def read_rho_for_window(option_dir: Path, window: str):
    """Load ``<window>_rho_summary.csv`` from *option_dir* as a two-column frame.

    The value column is ``rho_test`` when present; otherwise the third column
    if the file has at least three columns, else the last column.

    Returns:
        A DataFrame with columns ``category`` and ``value``, or ``None`` when
        the CSV file does not exist.
    """
    summary_path = option_dir / f"{window}_rho_summary.csv"
    if not summary_path.exists():
        return None

    table = pd.read_csv(summary_path)
    columns = table.columns
    if 'rho_test' in columns:
        chosen = 'rho_test'
    else:
        chosen = columns[2] if len(columns) >= 3 else columns[-1]

    return table[['category', chosen]].rename(columns={chosen: 'value'})


def find_best_windows(option_dir: Path, window_list: Sequence[str]):
    """Pick, for every category, the window with the highest rho value.

    Windows whose rho summary cannot be read are reported and skipped, and
    NaN values are ignored. On ties the earliest window in *window_list* wins,
    because only a strictly larger value replaces the current best.

    Returns:
        A dict mapping category name to ``(window, rho)``.
    """
    winners = {}
    for window in window_list:
        table = read_rho_for_window(option_dir, window)
        if table is None:
            print(f"  Skipping window {window}: missing rho summary.")
            continue

        for _, record in table.iterrows():
            score = record['value']
            if pd.isna(score):
                continue
            category = str(record['category'])
            incumbent = winners.get(category)
            if incumbent is not None and not (score > incumbent[1]):
                continue
            winners[category] = (window, float(score))
    return winners


def collect_prediction_overlays(
    option_dir: Path,
    best_windows,
    subject: str,
    story: str,
    run: Optional[str],
    prediction_suffix: str,
):
    """Locate the prediction plot of each category inside its best window.

    Args:
        option_dir: Option directory holding the category directories.
        best_windows: Mapping of category key to ``(window, score)``, as
            produced by ``find_best_windows``.
        subject: Subject directory name below each window directory.
        story: Story name handed to ``resolve_story_dirs``.
        run: Optional run name handed to ``resolve_story_dirs``.
        prediction_suffix: Suffix of the ``mde_<category>_<suffix>.png`` file.

    Returns:
        ``(found, absent)``: *found* holds ``(label, image_path)`` pairs in
        category-key order, with the window and rho in the label; *absent*
        holds one message or expected path (as a string) per category that
        could not be resolved.
    """
    found = []
    absent = []

    # Category directories keyed by name with trailing underscores removed.
    dirs_by_key = {
        entry.name.rstrip('_'): entry
        for entry in option_dir.iterdir()
        if entry.is_dir() and entry.name.startswith(CATEGORY_PREFIX)
    }

    for cat_key, (window, score) in sorted(best_windows.items()):
        if cat_key not in dirs_by_key or dirs_by_key[cat_key] is None:
            absent.append(f"Missing category directory for {cat_key}")
            continue

        plot_relpath = (
            Path('day22_category_mde')
            / 'plots'
            / cat_key
            / f"mde_{cat_key}_{prediction_suffix}.png"
        )
        subject_root = dirs_by_key[cat_key] / window / subject
        candidates = resolve_story_dirs(subject_root, story, run)

        hit = next(
            (base / plot_relpath for base in candidates if (base / plot_relpath).exists()),
            None,
        )
        if hit is None:
            # Record the location that was expected, for the missing-file log.
            anchor = candidates[0] if candidates else subject_root / story
            absent.append(str(anchor / plot_relpath))
            continue

        found.append((f"{cat_key} ({window}, rho={score:.3f})", hit))

    return found, absent


def split_window_and_lag(name: str):
    """Split a trailing ``_<lag>`` tag (one of ``LAG_OPTIONS``) off *name*.

    Returns:
        ``(base, lag)`` when *name* ends with such a tag, else ``(name, None)``.
    """
    matched = next((lag for lag in LAG_OPTIONS if name.endswith(f"_{lag}")), None)
    if matched is None:
        return name, None
    # Drop the tag together with its leading underscore.
    return name[: -(len(matched) + 1)], matched


def window_to_seconds_label(window: str) -> str:
    """Turn a window name such as ``gauss_1p5`` into a label such as ``1.5 second``.

    The text after the first underscore is read as a number, with ``p``
    standing in for the decimal point. Names without an underscore, or whose
    tail is not numeric, are returned unchanged.
    """
    _, separator, tail = window.partition('_')
    if not separator:
        return window

    try:
        seconds = float(tail.replace('p', '.'))
    except ValueError:
        return window

    # Whole numbers print without a decimal part; others keep at most two
    # decimals with trailing zeros trimmed.
    rendered = (
        str(int(seconds))
        if seconds.is_integer()
        else f"{seconds:.2f}".rstrip('0').rstrip('.')
    )
    return f"{rendered} second"


def window_sort_key(window: str) -> float:
    """Return the numeric size encoded in a window name, for sorting.

    The text after the first underscore is parsed as a float, with ``p`` read
    as the decimal point. Names without an underscore or with a non-numeric
    tail get ``inf`` so they sort last.
    """
    _, separator, tail = window.partition('_')
    if not separator:
        return float('inf')
    try:
        return float(tail.replace('p', '.'))
    except ValueError:
        return float('inf')


def format_category_name(raw_category: str) -> str:
    """Return a display name: drop ``cat_``, turn underscores into spaces, title-case."""
    without_prefix = raw_category.replace('cat_', '')
    spaced = without_prefix.replace('_', ' ')
    return spaced.title()


def discover_windows_with_lag(option_dir: Path, valid_prefixes: tuple[str, ...]):
    """Group the windows that have a rho summary CSV in *option_dir* by lag tag.

    Every ``*_rho_summary.csv`` file name is stripped of that suffix and split
    with ``split_window_and_lag``; bases that do not start with one of
    *valid_prefixes* are ignored.

    Returns:
        A dict mapping lag tag (or ``None``) to the window bases found for it,
        ordered by ``window_sort_key``.
    """
    summary_suffix = '_rho_summary.csv'
    grouped = {}
    for summary_path in option_dir.glob('*_rho_summary.csv'):
        stem = summary_path.name.replace(summary_suffix, '')
        base, lag = split_window_and_lag(stem)
        if base.startswith(valid_prefixes):
            grouped.setdefault(lag, set()).add(base)

    return {lag: sorted(bases, key=window_sort_key) for lag, bases in grouped.items()}


def read_metric_csv(csv_path: Path, preferred_test_col: str):
    """Load a metric summary CSV as a ``category`` / ``value`` frame.

    The value column is *preferred_test_col* when present; otherwise the third
    column if the file has at least three columns, else the last column.

    Returns:
        A DataFrame with columns ``category`` and ``value``, or ``None`` when
        *csv_path* does not exist.
    """
    if not csv_path.exists():
        return None

    table = pd.read_csv(csv_path)
    columns = table.columns
    if preferred_test_col in columns:
        chosen = preferred_test_col
    else:
        chosen = columns[2] if len(columns) >= 3 else columns[-1]

    return table[['category', chosen]].rename(columns={chosen: 'value'})


def build_metrics_table(option_dir: Path, windows: List[str], lag: Optional[str]) -> pd.DataFrame:
    """Assemble the window-by-metric table that feeds the heatmap.

    For each window and each ``METRIC_CONFIG`` entry one row is produced from
    ``<window>[_<lag>]_<metric>_summary.csv``. Categories not listed in
    ``CATEGORIES_DISPLAY`` are dropped, and a missing file or category leaves
    NaN in the row.

    Returns:
        A DataFrame indexed by ``"<seconds label>[ (<lag>)] – <metric>"`` with
        one column per display category plus a row-wise ``Average`` column.
    """
    file_suffix = '' if not lag else f"_{lag}"
    lag_note = '' if not lag else f" ({lag.replace('_', ' ')})"

    labels = []
    records = []
    for window in windows:
        window_label = window_to_seconds_label(window)
        for display_name, file_key, test_column in METRIC_CONFIG:
            metric_df = read_metric_csv(
                option_dir / f"{window}{file_suffix}_{file_key}_summary.csv",
                test_column,
            )

            # Start from an all-NaN row so absent categories stay blank.
            record = dict.fromkeys(CATEGORIES_DISPLAY, np.nan)
            if metric_df is not None:
                for _, entry in metric_df.iterrows():
                    column = format_category_name(str(entry['category']))
                    if column in record:
                        record[column] = entry['value']

            records.append(record)
            labels.append(f"{window_label}{lag_note} – {display_name}")

    table = pd.DataFrame(records, index=labels, columns=CATEGORIES_DISPLAY)
    table['Average'] = table.mean(axis=1)
    return table


def build_heatmap_for_lag(
    option_dir: Path,
    windows: List[str],
    lag: Optional[str],
    subject: str,
    story: str,
    run: Optional[str],
):
    """Render and save the per-category metric heatmap for one lag group.

    The table from ``build_metrics_table`` is drawn as a grid of numbers. Every
    cell is coloured by the rho value of the same window and category, so the
    CAE and RMSE rows share the colours of their window's rho row. Cells with
    no rho value are grey.

    Args:
        option_dir: Option directory holding the summary CSVs; the PNG is
            written here as ``day26_mde_per_category_heatmap[_<lag>].png``.
        windows: Window base names, in the row order to display.
        lag: Lag tag used in file names and titles, or ``None``.
        subject: Subject name shown in the title.
        story: Story name shown in the title.
        run: Optional run name shown in the title.

    Returns:
        ``True`` when the heatmap was saved, ``False`` when the rho rows hold
        no numeric value at all.

    Raises:
        ValueError: If the table has no rho rows.
    """
    target_df = build_metrics_table(option_dir, windows, lag)

    rho_mask = target_df.index.str.contains('ρ')
    rho_values = target_df[rho_mask]
    if rho_values.empty:
        raise ValueError('No rho rows found; cannot color by rho.')
    rho_flat = rho_values.to_numpy(dtype=float).ravel()
    rho_flat = rho_flat[~np.isnan(rho_flat)]
    if rho_flat.size == 0:
        print(f"  Skipping heatmap for lag={lag or 'default'}: no numeric rho values.")
        return False

    rmin, rmax = rho_flat.min(), rho_flat.max()
    if rmin == rmax:
        # Avoid a zero-width range (division by zero) when every rho is equal.
        rmax = rmin + 1e-9

    # ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed
    # in 3.9; ``pyplot.get_cmap`` is available in old and new releases.
    cmap = plt.get_cmap('RdYlGn')
    row_labels = list(target_df.index)
    n_rows, n_cols = len(row_labels), len(target_df.columns)
    color_grid = np.ones((n_rows, n_cols, 4))

    row_window_labels = [label.split('–')[0].strip() for label in row_labels]

    # Row positions by label. Different windows can share a label (the label
    # drops the window prefix, e.g. 'gauss_3' and 'movavg_3' are both
    # "3 second"), so rows are resolved by position rather than with ``.loc``,
    # which would return a DataFrame for a duplicated label.
    positions_by_label = {}
    for pos, label in enumerate(row_labels):
        positions_by_label.setdefault(label, []).append(pos)

    for i in range(n_rows):
        rho_row_label = f"{row_window_labels[i]} – ρ"
        candidates = positions_by_label.get(rho_row_label)
        if candidates:
            # build_metrics_table emits each window's rows contiguously with
            # rho first, so the closest rho row at or above row i belongs to
            # the same window.
            at_or_above = [pos for pos in candidates if pos <= i]
            rho_pos = at_or_above[-1] if at_or_above else candidates[0]
        else:
            rho_pos = i
        rho_row = target_df.iloc[rho_pos]

        for j in range(n_cols):
            rho_val = rho_row.iloc[j]
            if pd.isna(rho_val):
                color_grid[i, j] = (0.85, 0.85, 0.85, 1.0)
                continue
            norm_val = (rho_val - rmin) / (rmax - rmin)
            norm_val = min(max(norm_val, 0.0), 1.0)
            color_grid[i, j] = cmap(norm_val)

    fig, ax = plt.subplots(figsize=(n_cols * 0.9, n_rows * 0.6 + 2))
    ax.imshow(color_grid, aspect='auto')

    for i, row_label in enumerate(row_labels):
        for j in range(n_cols):
            value = target_df.iloc[i, j]
            if pd.isna(value):
                display_text = ''
            elif 'CAE' in row_label:
                display_text = f"{value:.2f}"
            else:
                display_text = f"{value:.3f}"

            # Pick black or white text from the cell's perceived brightness.
            face = color_grid[i, j][:3]
            brightness = 0.299 * face[0] + 0.587 * face[1] + 0.114 * face[2]
            text_color = 'black' if brightness > 0.5 else 'white'

            ax.text(j, i, display_text, ha='center', va='center', color=text_color, fontsize=8)

    ax.set_xticks(range(n_cols))
    ax.set_xticklabels(target_df.columns, rotation=45, ha='right')
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(target_df.index)
    ax.set_xlabel('Semantic Categories')
    ax.set_ylabel('Smoothing Options / Metrics')

    main_title = 'Results: MDE (CAE) Best State Per Category Prediction Performance'
    lag_title = f" ({lag.replace('_', ' ')})" if lag else ''
    subtitle_parts = [f"Subject {subject}", f"Story {story}"]
    if run:
        subtitle_parts.append(run)
    subtitle_parts.append(option_dir.name)
    subtitle = f" {' · '.join(subtitle_parts)}{lag_title}"
    ax.set_title(f"{main_title}{subtitle}", fontsize=12)
    plt.tight_layout()

    suffix = f"_{lag}" if lag else ''
    heatmap_path = option_dir / f"day26_mde_per_category_heatmap{suffix}.png"
    fig.savefig(heatmap_path, dpi=200, bbox_inches='tight')
    plt.close(fig)
    print(f"  Saved heatmap to: {heatmap_path}")
    return True


def build_heatmaps_for_option(
    option_dir: Path,
    windows: List[str],
    valid_window_prefixes: Sequence[str],
    subject: str,
    story: str,
    run: Optional[str],
):
    """Build one heatmap per lag group for an option directory.

    Windows are grouped by the lag tag split off their names. If none of the
    given *windows* is usable, the groups are discovered from the rho summary
    CSV files in *option_dir* instead.

    Returns:
        The number of heatmaps that were saved.
    """
    prefixes = tuple(valid_window_prefixes)

    grouped = {}
    for name in windows:
        base, lag = split_window_and_lag(name)
        if base.startswith(prefixes):
            grouped.setdefault(lag, set()).add(base)

    if not grouped:
        grouped = discover_windows_with_lag(option_dir, prefixes)

    if not grouped:
        print("  No smoothing windows discovered for heatmap generation.")
        return 0

    saved = 0
    for lag, bases in grouped.items():
        ordered = sorted(bases, key=window_sort_key)
        print(f"  Building heatmap for lag group: {lag or 'default'}")
        for base in ordered:
            print(f"    - {base}")

        try:
            rendered = build_heatmap_for_lag(option_dir, ordered, lag, subject, story, run)
        except ValueError as exc:
            # One unusable lag group must not stop the remaining groups.
            print(f"  Skipping heatmap for lag={lag or 'default'}: {exc}")
            continue
        if rendered:
            saved += 1

    return saved


def process_option_dir(
    option_dir: Path,
    story: str,
    run: str,
    subject: str,
    plot_suffix: str,
    prediction_suffix: str,
    category_prefix: str,
    num_columns: int,
    valid_window_prefixes: Sequence[str],
    generate_window_overlays: bool,
    generate_best_prediction_overlays: bool,
    generate_heatmaps: bool,
    max_missing_log: int,
):
    """Generate every requested output for a single option directory.

    Up to three kinds of files are written into *option_dir*:

    * ``<window>_<plot_suffix>_overlay.png`` - one grid of category plots per
      discovered window (``generate_window_overlays``);
    * ``best_prediction_overlay.png`` - a grid of each category's prediction
      plot taken from its best-rho window (``generate_best_prediction_overlays``);
    * heatmaps, via ``build_heatmaps_for_option`` (``generate_heatmaps``).

    Progress and missing files are reported with ``print``; at most
    *max_missing_log* missing paths are listed per report.

    Returns:
        A dict with keys ``'windows'`` (number of discovered windows),
        ``'overlays'`` (window overlays saved), ``'best_overlay'`` (whether the
        best-window overlay was saved) and ``'heatmaps'`` (heatmaps saved).
    """
    print(f"=== {story} / {run} / {option_dir.name} ===")
    windows = discover_windows(option_dir, category_prefix, valid_window_prefixes)
    if windows:
        print(f"  Discovered windows ({len(windows)}): {', '.join(windows)}")
    else:
        print("  No smoothing windows found under category directories.")

    overlay_count = 0

    # --- Per-window overlays: one grid of category plots for each window ---
    if generate_window_overlays and windows:
        for window in windows:
            print(f"  Processing window: {window}")
            category_images, missing_images = collect_category_images_for_window(
                option_dir,
                window,
                plot_suffix,
                category_prefix,
                subject,
                story,
                run,
            )
            print(f"    Located {len(category_images)} plot(s).")
            if missing_images:
                print(f"    Missing {len(missing_images)} plot(s) (showing up to {max_missing_log}):")
                for p in missing_images[:max_missing_log]:
                    print(f"      - {p}")

            # make_overlay_figure needs at least one image to lay out a grid.
            if not category_images:
                print("    No plots found; skipping overlay.")
                continue

            fig = make_overlay_figure(window, plot_suffix, category_images, num_columns)
            output_path = option_dir / f"{window}_{plot_suffix}_overlay.png"
            output_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(output_path, dpi=200, bbox_inches='tight')
            # Close the figure after saving so figures do not pile up across windows.
            plt.close(fig)

            overlay_count += 1
            print(f"    Saved overlay to: {output_path}")

    best_overlay_saved = False

    # --- Best-window overlay: each category's prediction plot from its best-rho window ---
    if generate_best_prediction_overlays and windows:
        best_windows = find_best_windows(option_dir, windows)
        if not best_windows:
            print("  No best windows determined; skipping prediction overlay.")
        else:
            best_images, missing_overlays = collect_prediction_overlays(
                option_dir,
                best_windows,
                subject,
                story,
                run,
                prediction_suffix,
            )
            if missing_overlays:
                print(f"  Missing prediction overlay files (showing up to {max_missing_log}):")
                for path in missing_overlays[:max_missing_log]:
                    print(f"    - {path}")

            if best_images:
                best_fig = make_overlay_figure(
                    'best_per_category',
                    prediction_suffix,
                    best_images,
                    num_columns,
                )
                best_output_path = option_dir / "best_prediction_overlay.png"
                best_output_path.parent.mkdir(parents=True, exist_ok=True)
                best_fig.savefig(best_output_path, dpi=200, bbox_inches='tight')
                plt.close(best_fig)
                best_overlay_saved = True
                print(f"  Saved best-window prediction overlay to: {best_output_path}")
            else:
                print("  No prediction overlay images found to assemble.")

    heatmaps_saved = 0

    # --- Heatmaps: attempted even when no window directories were found, because
    # build_heatmaps_for_option can fall back to the summary CSVs in option_dir ---
    if generate_heatmaps:
        heatmaps_saved = build_heatmaps_for_option(
            option_dir,
            windows,
            valid_window_prefixes,
            subject,
            story,
            run,
        )

    return {
        'windows': len(windows),
        'overlays': overlay_count,
        'best_overlay': best_overlay_saved,
        'heatmaps': heatmaps_saved,
    }


SyntaxError: unterminated string literal (detected at line 267) (2424381283.py, line 267)

In [None]:
# --- Discover available story/run/option directories ---
# Defines the globals `subject` and `options` used by the generation cell below.

# Fail early with a clear message if the configured root is wrong.
if not ROOT_SUBJECT_DIR.exists():
    raise FileNotFoundError(f"ROOT_SUBJECT_DIR does not exist: {ROOT_SUBJECT_DIR}")

subject = resolve_subject(ROOT_SUBJECT_DIR, SUBJECT_OVERRIDE)
# List of (story name, run name, option directory) tuples that pass the filters.
options = discover_option_dirs(
    ROOT_SUBJECT_DIR,
    CATEGORY_PREFIX,
    STORY_FILTERS,
    RUN_FILTERS,
    OPTION_FILTERS,
)

if not options:
    raise ValueError("No option directories found under the subject root.")

print(f"Subject: {subject}")
print(f"Discovered {len(options)} option directorie(s):")
for story, run, option_dir in options:
    print(f"  - {story} / {run} / {option_dir.name}")


In [None]:
# --- Generate overlays for each option directory ---
# Requires `subject` and `options` from the discovery cell above.

# Running totals across all option directories, reported at the end.
total_overlays = 0
total_best_overlays = 0
total_heatmaps = 0

for story, run, option_dir in options:
    result = process_option_dir(
        option_dir,
        story,
        run,
        subject,
        PLOT_SUFFIX,
        PREDICTION_OVERLAY_SUFFIX,
        CATEGORY_PREFIX,
        NUM_COLUMNS,
        VALID_WINDOW_PREFIXES,
        GENERATE_WINDOW_OVERLAYS,
        GENERATE_BEST_PREDICTION_OVERLAYS,
        GENERATE_HEATMAPS,
        MAX_MISSING_LOG,
    )
    total_overlays += result['overlays']
    # 'best_overlay' is a bool; count it as 0 or 1.
    total_best_overlays += int(result['best_overlay'])
    total_heatmaps += result['heatmaps']

print("\nDone.")
print(f"Window overlays saved: {total_overlays}")
print(f"Best-window overlays saved: {total_best_overlays}")
print(f"Heatmaps saved: {total_heatmaps}")
