# POC — Visualize Gating by Task

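Displays the training and routing plots (loss curves, balance/entropy, expert usage, routing heatmaps) for the latest run of each task (AG News, CIFAR-10, Video).
Each task has multiple scenarios: base, ablation_small, ablation_wide, cap_low. Runs are read from `checkpoints/poc/<task>/<scenario>/<run>/plots`.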

In [None]:
import os
import re
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import yaml
from IPython.display import Image, display

# Make the project importable whether the notebook is launched from the
# repository root or from its `notebooks/` sub-folder.
ROOT = Path.cwd().resolve()
PROJ = ROOT.parent if ROOT.name == "notebooks" else ROOT
# Guard against piling up duplicate entries when this cell is re-run in the
# same kernel.
if str(PROJ) not in sys.path:
    sys.path.append(str(PROJ))

def get_latest_run(task_name, scenario="base", root=None):
    """Return the newest run directory for a task/scenario, or ``None``.

    Runs live under ``<root>/checkpoints/poc/<task_name>/<scenario>/``, one
    sub-directory per run. Directories whose names start with a
    ``yyyy-mm-dd-hhmmss`` timestamp are ranked by name, which sorts
    chronologically. If no directory carries such a name, the most recently
    modified one is returned instead.

    Args:
        task_name: Task folder name, e.g. ``"ag_news"``.
        scenario: Scenario folder name, e.g. ``"base"`` or ``"cap_low"``.
        root: Project root to search. Defaults to the notebook-level ``PROJ``.

    Returns:
        The chosen run directory as a ``Path``. Returns ``None`` (after
        printing a message) when the scenario folder is missing or holds
        no runs.
    """
    base = PROJ if root is None else Path(root)
    task_dir = base / "checkpoints" / "poc" / task_name / scenario

    if not task_dir.is_dir():
        print(f"Task directory not found: {task_dir}")
        return None

    # Hidden folders (e.g. `.ipynb_checkpoints`) are editor artefacts, not runs.
    run_dirs = [
        d for d in task_dir.iterdir()
        if d.is_dir() and not d.name.startswith(".")
    ]
    if not run_dirs:
        print(f"No run directories found in {task_dir}")
        return None

    # Run folders are named yyyy-mm-dd-hhmmss, so the lexicographic maximum is
    # the newest run. Ranking by directory mtime instead is fragile: on typical
    # filesystems it is bumped whenever entries are added, which would promote
    # an old run whose plots were regenerated later.
    stamped = [d for d in run_dirs if re.match(r"\d{4}-\d{2}-\d{2}-\d{6}", d.name)]
    if stamped:
        return max(stamped, key=lambda p: p.name)

    # No timestamped names: fall back to modification time.
    return max(run_dirs, key=lambda p: p.stat().st_mtime)

def _natural_sort_key(path):
    """Sort key that orders digit runs in ``path.name`` numerically.

    Plain string sorting puts ``layer10`` before ``layer2`` (and ``epoch10``
    before ``epoch2``). Taking the "first N" plots would then skip layers or
    epochs once there are ten or more. ``re.split`` with a capturing group
    alternates text and digit runs, so odd positions are always integers.
    The full path string breaks ties between equal names in different folders.
    """
    parts = re.split(r"(\d+)", path.name)
    return ([int(tok) if i % 2 else tok for i, tok in enumerate(parts)], str(path))


def _collect_plots(directories, pattern, limit):
    """Return up to ``limit`` files matching ``pattern`` across ``directories``.

    Missing directories are skipped and duplicates removed. Results come in
    natural order, and ``limit=None`` returns every match.
    """
    found = {p for d in directories if d.is_dir() for p in d.glob(pattern)}
    ranked = sorted(found, key=_natural_sort_key)
    return ranked if limit is None else ranked[:limit]


def _show_images(title, paths):
    """Print ``title`` and display each image file in ``paths``; no-op if empty."""
    if not paths:
        return
    print(title)
    for p in paths:
        display(Image(filename=str(p)))


def display_core_plots(plots_dir, max_plots=2):
    """Display core plots: loss curves and balance/entropy plots.

    Shows ``loss_curves.png`` (or reports it missing), then the first
    ``max_plots`` balance plots and the first ``max_plots`` entropy plots in
    natural (numeric) order. Both the ``balance_entropy_train`` sub-folder and
    ``plots_dir`` itself are searched. The sub-folder's name suggests entropy
    plots may be written there as well as at the top level.

    Args:
        plots_dir: The run's ``plots`` directory (``str`` or ``Path``).
        max_plots: Maximum number of plots per group; ``None`` shows all.
    """
    plots_dir = Path(plots_dir)
    if not plots_dir.exists():
        print(f"Plots directory not found: {plots_dir}")
        return

    print("### Core Plots")

    loss_plot = plots_dir / "loss_curves.png"
    if loss_plot.exists():
        print("**Loss Curves:**")
        display(Image(filename=str(loss_plot)))
    else:
        print("Missing: loss_curves.png")

    search_dirs = [plots_dir / "balance_entropy_train", plots_dir]
    _show_images(
        "**Balance Plots (Train):**",
        _collect_plots(search_dirs, "balance_layer*_train.png", max_plots),
    )
    _show_images(
        "**Entropy Plots:**",
        _collect_plots(search_dirs, "entropy_layer*_train.png", max_plots),
    )


def display_expert_usage(plots_dir, max_plots=2):
    """Display expert-usage bar charts and routing heatmaps for one run.

    Shows the first ``max_plots`` files matching each pattern in natural
    (numeric) order, so ``epoch2`` precedes ``epoch10``. Returns silently when
    ``plots_dir`` does not exist.

    Args:
        plots_dir: The run's ``plots`` directory (``str`` or ``Path``).
        max_plots: Maximum number of plots per group; ``None`` shows all.
    """
    plots_dir = Path(plots_dir)
    if not plots_dir.exists():
        return

    print("### Expert Usage and Heatmaps")

    _show_images(
        "**Expert Usage (Bars):**",
        _collect_plots([plots_dir], "expert_usage_epoch*_layer*_*", max_plots),
    )
    _show_images(
        "**Routing Heatmaps:**",
        _collect_plots([plots_dir], "heatmap_epoch*_layer*_*", max_plots),
    )

# Echo the resolved project root so it is obvious which checkout is being read.
print(f"Project root: {PROJ}")

## AG News Task

### AG News - Base Scenario

In [None]:
# Render the plots of the newest AG News run for the "base" scenario.
print("## AG News - Base Scenario")
ag_run_base = get_latest_run("ag_news", "base")

if ag_run_base is None:
    print("No AG News base runs found")
else:
    print(f"Using run: {ag_run_base}")
    plots_dir = ag_run_base / "plots"
    display_core_plots(plots_dir)
    display_expert_usage(plots_dir)

### AG News - Ablation Small Scenario

In [None]:
# Render the plots of the newest AG News run for the "ablation_small" scenario.
print("## AG News - Ablation Small Scenario")
ag_run_small = get_latest_run("ag_news", "ablation_small")

if ag_run_small is None:
    print("No AG News ablation_small runs found")
else:
    print(f"Using run: {ag_run_small}")
    plots_dir = ag_run_small / "plots"
    display_core_plots(plots_dir)
    display_expert_usage(plots_dir)

### AG News - Ablation Wide Scenario

In [None]:
# Render the plots of the newest AG News run for the "ablation_wide" scenario.
print("## AG News - Ablation Wide Scenario")
ag_run_wide = get_latest_run("ag_news", "ablation_wide")

if ag_run_wide is None:
    print("No AG News ablation_wide runs found")
else:
    print(f"Using run: {ag_run_wide}")
    plots_dir = ag_run_wide / "plots"
    display_core_plots(plots_dir)
    display_expert_usage(plots_dir)

### AG News - Cap Low Scenario

In [None]:
# Render the plots of the newest AG News run for the "cap_low" scenario.
print("## AG News - Cap Low Scenario")
ag_run_cap = get_latest_run("ag_news", "cap_low")

if ag_run_cap is None:
    print("No AG News cap_low runs found")
else:
    print(f"Using run: {ag_run_cap}")
    plots_dir = ag_run_cap / "plots"
    display_core_plots(plots_dir)
    display_expert_usage(plots_dir)

## CIFAR-10 Task

### CIFAR-10 - Base Scenario

In [None]:
# Render the plots of the newest CIFAR-10 run for the "base" scenario.
print("## CIFAR-10 - Base Scenario")
cifar_run_base = get_latest_run("cifar10", "base")

if cifar_run_base is None:
    print("No CIFAR-10 base runs found")
else:
    print(f"Using run: {cifar_run_base}")
    plots_dir = cifar_run_base / "plots"
    display_core_plots(plots_dir)
    display_expert_usage(plots_dir)

### CIFAR-10 - Ablation Small Scenario

In [None]:
# Render the plots of the newest CIFAR-10 run for the "ablation_small" scenario.
print("## CIFAR-10 - Ablation Small Scenario")
cifar_run_small = get_latest_run("cifar10", "ablation_small")

if cifar_run_small is None:
    print("No CIFAR-10 ablation_small runs found")
else:
    print(f"Using run: {cifar_run_small}")
    plots_dir = cifar_run_small / "plots"
    display_core_plots(plots_dir)
    display_expert_usage(plots_dir)

### CIFAR-10 - Ablation Wide Scenario

In [None]:
# Render the plots of the newest CIFAR-10 run for the "ablation_wide" scenario.
print("## CIFAR-10 - Ablation Wide Scenario")
cifar_run_wide = get_latest_run("cifar10", "ablation_wide")

if cifar_run_wide is None:
    print("No CIFAR-10 ablation_wide runs found")
else:
    print(f"Using run: {cifar_run_wide}")
    plots_dir = cifar_run_wide / "plots"
    display_core_plots(plots_dir)
    display_expert_usage(plots_dir)

### CIFAR-10 - Cap Low Scenario

In [None]:
# Render the plots of the newest CIFAR-10 run for the "cap_low" scenario.
print("## CIFAR-10 - Cap Low Scenario")
cifar_run_cap = get_latest_run("cifar10", "cap_low")

if cifar_run_cap is None:
    print("No CIFAR-10 cap_low runs found")
else:
    print(f"Using run: {cifar_run_cap}")
    plots_dir = cifar_run_cap / "plots"
    display_core_plots(plots_dir)
    display_expert_usage(plots_dir)

## Video Task

### Video - Base Scenario

In [None]:
# Render the plots of the newest Video run for the "base" scenario.
print("## Video - Base Scenario")
video_run_base = get_latest_run("video", "base")

if video_run_base is None:
    print("No Video base runs found")
else:
    print(f"Using run: {video_run_base}")
    plots_dir = video_run_base / "plots"

    # Video runs may produce no plots. Report a missing folder separately
    # from an empty one instead of calling both "empty".
    if not plots_dir.is_dir():
        print(f"Plots directory not found: {plots_dir}")
    elif not any(plots_dir.iterdir()):
        print("Plots directory is empty (this is expected for video task)")
    else:
        display_core_plots(plots_dir)
        display_expert_usage(plots_dir)

### Video - Ablation Small Scenario

In [None]:
# Render the plots of the newest Video run for the "ablation_small" scenario.
print("## Video - Ablation Small Scenario")
video_run_small = get_latest_run("video", "ablation_small")

if video_run_small is None:
    print("No Video ablation_small runs found")
else:
    print(f"Using run: {video_run_small}")
    plots_dir = video_run_small / "plots"

    # Video runs may produce no plots. Report a missing folder separately
    # from an empty one instead of calling both "empty".
    if not plots_dir.is_dir():
        print(f"Plots directory not found: {plots_dir}")
    elif not any(plots_dir.iterdir()):
        print("Plots directory is empty (this is expected for video task)")
    else:
        display_core_plots(plots_dir)
        display_expert_usage(plots_dir)

### Video - Ablation Wide Scenario

In [None]:
# Render the plots of the newest Video run for the "ablation_wide" scenario.
print("## Video - Ablation Wide Scenario")
video_run_wide = get_latest_run("video", "ablation_wide")

if video_run_wide is None:
    print("No Video ablation_wide runs found")
else:
    print(f"Using run: {video_run_wide}")
    plots_dir = video_run_wide / "plots"

    # Video runs may produce no plots. Report a missing folder separately
    # from an empty one instead of calling both "empty".
    if not plots_dir.is_dir():
        print(f"Plots directory not found: {plots_dir}")
    elif not any(plots_dir.iterdir()):
        print("Plots directory is empty (this is expected for video task)")
    else:
        display_core_plots(plots_dir)
        display_expert_usage(plots_dir)

### Video - Cap Low Scenario

In [None]:
# Render the plots of the newest Video run for the "cap_low" scenario.
print("## Video - Cap Low Scenario")
video_run_cap = get_latest_run("video", "cap_low")

if video_run_cap is None:
    print("No Video cap_low runs found")
else:
    print(f"Using run: {video_run_cap}")
    plots_dir = video_run_cap / "plots"

    # Video runs may produce no plots. Report a missing folder separately
    # from an empty one instead of calling both "empty".
    if not plots_dir.is_dir():
        print(f"Plots directory not found: {plots_dir}")
    elif not any(plots_dir.iterdir()):
        print("Plots directory is empty (this is expected for video task)")
    else:
        display_core_plots(plots_dir)
        display_expert_usage(plots_dir)