# Experiment 2.9.1: Delete only red with two repulsive regularizers

This is a re-run of [Ex 2.9](./ex-2.9-delete-only-red-5d.ipynb) with more mature tooling. See the earlier notebook for discussion. Unlike 2.9, we only use one label (_red_), but unlike 2.10.1, we use both an anti-subspace and anti-anchor regularizer.

In [1]:
from __future__ import annotations

# Notebook identity: reused below for the logger name, asset file names and experiment tracking.
nbid = '2.9.1'  # ID for tagging assets
nbname = 'Ablate red (only), 5D'  # human-readable notebook title
experiment_name = f'Ex {nbid}: {nbname}'  # name handed to Experiment and WandbLogger
project = 'ex-preppy'  # project handed to Experiment and WandbLogger

In [2]:
# Basic setup: Logging, Experiment (Modal)
import logging

import modal

from infra.requirements import freeze, project_packages
from mini.experiment import Experiment
from utils.logging import SimpleLoggingConfig

# Logging: INFO for the notebook and the project packages named below, ERROR-only for matplotlib.axes.
logging_config = (
    SimpleLoggingConfig()
    .info('notebook', 'utils', 'mini', 'ex_color')
    .error('matplotlib.axes')  # Silence warnings about set_aspect
)
logging_config.apply()

# This is the logger for this notebook
log = logging.getLogger(f'notebook.{nbid}')

run = Experiment(experiment_name, project=project)
# Container image: Debian slim, pip-installing the frozen requirements and adding the project packages as local source.
run.image = modal.Image.debian_slim().pip_install(*freeze(all=True)).add_local_python_source(*project_packages())
# NOTE(review): presumably re-applies the logging config inside each remote call — confirm against Experiment.before_each.
run.before_each(logging_config.apply)
None  # prevent auto-display of this cell

## Model parameters

Like Ex 2.9, we use the following regularizers:

- **Anchor:** pins `red` to $(1,0,0,0,0)$ (5D)
- **AxisAlignedSubspace:** repels everything from dimension $1$ (with varying weight, see schedule)
- **Separate:** angular repulsion to reduce global clumping (applied within each batch)

Since we're isolating _red_, we have 5D latent embeddings and two nonlinear activation functions in the encoder and decoder, to allow the latent space to be warped more.

But unlike 2.9:
- **Anti-anchor:** is enabled in this notebook (see `reg_anti_anchor` below), working together with anti-subspace to keep other concepts clear of the dimension to be ablated.
- **Unitarity:** is present in this list, but we'll do a run without it too.

In [3]:
import torch

from ex_color.loss import AngularAnchor, AntiAnchor, AxisAlignedSubspace, Separate, RegularizerConfig

from ex_color.training import TrainingModule

# Architecture
K = 5  # bottleneck dimensionality
N = 2  # number of nonlinear layers
H = 10  # hidden layer size

# Unit vector along the first latent axis (the anchor for red) and its negation
RED = tuple(1 if axis == 0 else 0 for axis in range(K))
ANTI_RED = tuple(-component for component in RED)
assert len(RED) == K and len(ANTI_RED) == K

# Data and seed-sweep settings
BATCH_SIZE = 64
CUBE_SUBDIVISIONS = 8
NUM_RUNS = 100  # to probe seed sensitivity
RUN_SEEDS = list(range(NUM_RUNS))

# Regularizer configurations. Each one names a loss term and the layer(s) it is attached to; all four
# target the 'bottleneck' layer.
# NOTE(review): `label_affinities=None` presumably means "not tied to any label" — confirm against RegularizerConfig.
reg_separate = RegularizerConfig(
    name='separate',
    compute_loss_term=Separate(power=100.0, shift=True),
    label_affinities=None,
    layer_affinities=['bottleneck'],
)
# Anchor term built from the RED vector, tied to the 'red' label with affinity 1.0.
reg_anchor = RegularizerConfig(
    name='anchor',
    compute_loss_term=AngularAnchor(torch.tensor(RED, dtype=torch.float32)),
    label_affinities={'red': 1.0},
    layer_affinities=['bottleneck'],
    phase=('train', 'validate'),
)
# Anti-anchor term built from the ANTI_RED vector (the negation of RED), with no label affinity.
reg_anti_anchor = RegularizerConfig(
    name='anti-anchor',
    compute_loss_term=AntiAnchor(torch.tensor(ANTI_RED, dtype=torch.float32)),
    label_affinities=None,
    layer_affinities=['bottleneck'],
    phase=('train', 'validate'),
)
# Inverted axis-aligned subspace term over latent index 0, with no label affinity.
reg_anti_subspace = RegularizerConfig(
    name='anti-subspace',
    compute_loss_term=AxisAlignedSubspace((0,), invert=True),
    label_affinities=None,
    layer_affinities=['bottleneck'],
)

In [4]:
from typing import cast

import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from IPython.display import display, Markdown

from mini.temporal.dopesheet import Dopesheet
from mini.temporal.timeline import Timeline
from mini.temporal.vis import plot_timeline, realize_timeline, ParamGroup
from utils.nb import displayer_mpl
from utils.plt import Theme

# Load the parameter schedule from the CSV that sits next to this notebook (file name keyed by notebook ID).
dopesheet = Dopesheet.from_csv(f'./ex-{nbid}-dopesheet.csv')

# Show the schedule as a Markdown table under its own heading.
display(Markdown(f"""## Parameter schedule \n{dopesheet.to_markdown()}"""))


def plot_dopesheet(dopesheet: Dopesheet, theme: Theme):
    """Draw the parameter schedule as two stacked panels that share an x-axis.

    The upper panel plots every dopesheet property except ``lr``; the lower panel plots ``lr`` alone.

    Args:
        dopesheet: Schedule to visualize.
        theme: Forwarded to ``plot_timeline`` for both panels.

    Returns:
        The Matplotlib figure containing both panels.
    """
    history_df = realize_timeline(Timeline(dopesheet))
    keyframes_df = dopesheet.as_df()

    fig = plt.figure(figsize=(9, 3), constrained_layout=True)
    ax1, ax2 = cast(tuple[Axes, ...], fig.subplots(2, 1, sharex=True, height_ratios=[3, 1]))

    # Upper panel: all scheduled weights (everything but the learning rate)
    weight_group = ParamGroup(name='', params=[p for p in dopesheet.props if p not in {'lr'}])
    plot_timeline(
        history_df,
        keyframes_df,
        groups=(weight_group,),
        theme=theme,
        ax=ax1,
        show_phase_labels=False,
    )
    ax1.set_ylabel('Weight')
    ax1.set_xlabel('')

    # Lower panel: learning rate only, without a legend
    lr_group = ParamGroup(name='', params=['lr'])
    plot_timeline(
        history_df,
        keyframes_df,
        groups=(lr_group,),
        theme=theme,
        ax=ax2,
        show_legend=False,
        show_phase_labels=False,
    )
    ax2.set_ylabel('LR')

    # add a little space on the y-axis extents (both limits are scaled by 1.1)
    for ax in (ax1, ax2):
        lower, upper = ax.get_ylim()
        ax.set_ylim(lower * 1.1, upper * 1.1)

    return fig


# Render the schedule figure through displayer_mpl; the asset path is tagged with the notebook ID.
# NOTE(review): the "{title}" placeholder in alt_text is presumably substituted by displayer_mpl — confirm.
with displayer_mpl(
    f'large-assets/ex-{nbid}-dopesheet.png',
    alt_text="""Plot showing the parameter schedule for the training run, titled "{title}". The plot has two sections: the upper section shows various regularization weights over time, and the lower section shows the learning rate over time. The x-axis represents training steps.""",
) as show:
    show(lambda theme: plot_dopesheet(dopesheet, theme))

## Parameter schedule 
|   STEP | PHASE   |   ACTION |      lr |   separate |   anchor |   anti-anchor |   anti-subspace |
|-------:|:--------|---------:|--------:|-----------:|---------:|--------------:|----------------:|
|      0 | Train   |          |   1e-08 |            |      0   |          0    |           0.25  |
|     10 |         |          |   0.01  |            |          |               |                 |
|    248 |         |          |         |      0.01  |      0.1 |          0.05 |                 |
|    750 |         |          |   0.1   |      0.001 |      0.1 |               |           0.003 |
|   1425 |         |          |   0.1   |      0     |      0   |          0    |           0     |
|   1500 |         |          |   0.05  |            |          |               |                 |

## Data

Data is the same as last time: color cubes with values in RGB.


In [5]:
from torch import Tensor
from torch.utils.data import DataLoader, RandomSampler

from ex_color.data.color_cube import ColorCube
from ex_color.data.cube_dataset import prep_color_dataset, redness, stochastic_labels, exact_labels


def prep_train_data(training_subs: int, *, batch_size: int) -> DataLoader:
    """Build the training loader over a color dataset sampled at cell corners.

    Args:
        training_subs: Forwarded to ``prep_color_dataset`` as its first argument.
        batch_size: Batch size of the returned loader.

    Returns:
        A ``DataLoader`` that samples with replacement (one dataset-length of draws) and collates
        with ``stochastic_labels``.
    """
    train_set = prep_color_dataset(
        training_subs,
        sample_at='cell-corners',
        red=lambda c: redness(c) ** 8 * 0.08,
    )
    resampler = RandomSampler(train_set, num_samples=len(train_set), replacement=True)
    return DataLoader(
        train_set,
        batch_size=batch_size,
        num_workers=4,
        sampler=resampler,
        collate_fn=stochastic_labels,
    )


def prep_val_data(training_subs: int, *, batch_size: int) -> DataLoader:
    """Build the validation loader over a color dataset sampled at cell centers.

    Args:
        training_subs: Forwarded to ``prep_color_dataset`` as its first argument.
        batch_size: Batch size of the returned loader; also the number of samples drawn per pass.

    Returns:
        A ``DataLoader`` that draws ``batch_size`` samples with replacement and collates with
        ``exact_labels``.
    """
    val_set = prep_color_dataset(
        training_subs,
        sample_at='cell-centers',
        red=lambda c: redness(c) == 1,
    )
    # Only a single batch worth of samples is drawn each time the loader is iterated
    one_batch_sampler = RandomSampler(val_set, num_samples=batch_size, replacement=True)
    return DataLoader(
        val_set,
        batch_size=batch_size,
        num_workers=2,
        sampler=one_batch_sampler,
        collate_fn=exact_labels,
    )

For test data it can be useful to use an HSV cube: then it can be sliced for visualization. The RGB cube is also useful for visualizations and quantitative analyses where an unbiased distribution is needed.

In [6]:
import numpy as np

# from ex_color.data.color_cube import ColorCube
from ex_color.data.cyclic import arange_cyclic


# Low-resolution HSV cube: hue in cyclic steps of 1/12, 4 saturation levels, 5 value levels.
hsv_cube = ColorCube.from_hsv(
    h=arange_cyclic(step_size=1 / 12),
    s=np.linspace(0, 1, 4),
    v=np.linspace(0, 1, 5),
)

# Size of the first axis of the low-res cube — presumably the hue axis; confirm against ColorCube.from_hsv.
n_h = hsv_cube.shape[0]

# High-resolution HSV cube, built at 300 x 48 x 48 and then thinned below.
hd_hsv_cube = ColorCube.from_hsv(
    # Extend hue range to encompass the end pixels of the low-res cube above
    h=np.linspace(0 - 1 / n_h, 1 + 1 / n_h, 300),
    s=np.linspace(0, 1, 48),
    v=np.linspace(0, 1, 48),
)
# Keep every second sample along each of the three axes.
hd_hsv_cube = hd_hsv_cube[::2, ::2, ::2]

# RGB cube with 20 evenly spaced samples per channel.
rgb_cube = ColorCube.from_rgb(
    r=np.linspace(0, 1, 20),
    g=np.linspace(0, 1, 20),
    b=np.linspace(0, 1, 20),
)

# with displayer_mpl(
#     f'large-assets/ex-{nbid}-true-colors.png',
#     alt_text="""Plot showing four slices of the HSV cube, titled "{title}". Each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. The first slice shows a grayscale gradient from black to white; the last shows the fully-saturated color spectrum.""",
# ) as show:
#     show(lambda: plot_colors(hsv_cube, title='True colors'))

## Training & inference utils

In [15]:
from tempfile import gettempdir
import wandb
from ex_color.model import CNColorMLP


# @run.thither(env={'WANDB_API_KEY': wandb.Api().api_key})
async def train(
    dopesheet: Dopesheet,
    regularizers: list[RegularizerConfig],
    k_bottleneck: int,
    n_nonlinear: int,
    k_codec: int,
    *,
    seed: int | None = None,
) -> CNColorMLP:
    """Train a new ``CNColorMLP`` under the given schedule and regularizers.

    Args:
        dopesheet: Parameter schedule; its length is used as the number of training steps.
        regularizers: Regularizer configurations handed to ``TrainingModule``.
        k_bottleneck: First positional argument of ``CNColorMLP``.
        n_nonlinear: Passed to ``CNColorMLP`` as ``n_nonlinear``.
        k_codec: Passed to ``CNColorMLP`` as ``k_codec``.
        seed: If given, ``set_deterministic_mode(seed)`` is called before data loaders and model
            are created. The seed is also logged as a hyperparameter.

    Returns:
        The trained model object itself (no checkpoint is written).
    """
    import lightning as L
    from lightning.pytorch.loggers import WandbLogger

    from ex_color.callbacks import LabelProportionCallback
    from ex_color.seed import set_deterministic_mode

    from utils.progress.lightning import LightningProgress

    log.info(f'Training with: {[r.name for r in regularizers]}')

    if seed is not None:
        set_deterministic_mode(seed)

    train_loader = prep_train_data(CUBE_SUBDIVISIONS, batch_size=BATCH_SIZE)
    val_loader = prep_val_data(CUBE_SUBDIVISIONS, batch_size=BATCH_SIZE)

    model = CNColorMLP(k_bottleneck, n_nonlinear=n_nonlinear, k_codec=k_codec)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log.debug(f'Model initialized with {total_params:,} trainable parameters.')

    training_module = TrainingModule(model, dopesheet, torch.nn.MSELoss(), regularizers)

    logger = WandbLogger(experiment_name, project=project, save_dir=gettempdir())

    # Logging hyperparameters may already start a W&B run, so everything from here on sits inside
    # the try block: a failure while configuring the trainer must not leave that run open.
    try:
        logger.log_hyperparams({'seed': seed})

        trainer = L.Trainer(
            max_steps=len(dopesheet),
            callbacks=[
                LightningProgress(),
                LabelProportionCallback(prefix='labels', get_active_labels=lambda: training_module.active_labels),
            ],
            enable_checkpointing=False,
            enable_model_summary=False,
            # enable_progress_bar=True,
            val_check_interval=len(dopesheet) // 10,
            check_val_every_n_epoch=None,
            logger=logger,
            log_every_n_steps=min(50, len(train_loader)),
        )

        print(f'max_steps: {len(dopesheet)}, train_loader length: {len(train_loader)}')

        # Train the model
        trainer.fit(training_module, train_loader, val_loader)
    finally:
        wandb.finish()
    # This is only a small model, so it's OK to return it rather than storing and loading a checkpoint remotely
    return model

In [16]:
import torch
import numpy as np
from torch.utils.data import TensorDataset, DataLoader

from ex_color.inference import InferenceModule
from ex_color.intervention.intervention import InterventionConfig


async def infer_with_latent_capture(
    model: CNColorMLP,
    test_data: Tensor,
    interventions: list[InterventionConfig],
    layer_name: str = 'bottleneck',
) -> tuple[Tensor, Tensor]:
    """Run prediction over ``test_data`` while capturing one layer's activations.

    Args:
        model: Model wrapped in an ``InferenceModule``.
        test_data: Input tensor; it is flattened to rows of 3 values for prediction.
        interventions: Interventions handed to the ``InferenceModule``.
        layer_name: Layer whose activations are captured and returned.

    Returns:
        ``(y, latents)``: predictions reshaped to ``test_data.shape``, and whatever
        ``module.read_captured(layer_name)`` returns for the captured layer.
    """
    module = InferenceModule(model, interventions, capture_layers=[layer_name])
    import lightning as L

    trainer = L.Trainer(enable_checkpointing=False, enable_model_summary=False, enable_progress_bar=False, logger=False)
    flat_rows = TensorDataset(test_data.reshape((-1, 3)))
    loader = DataLoader(
        flat_rows,
        batch_size=64,
        # TensorDataset yields 1-tuples; unwrap them and stack into one batch tensor
        collate_fn=lambda batch: torch.stack([row[0] for row in batch], 0),
    )
    batches = trainer.predict(module, loader)
    assert batches is not None

    # Flatten the per-batch outputs into a single list of items before concatenating
    preds = []
    for batch in batches:
        preds.extend(batch)
    y = torch.cat(preds).reshape(test_data.shape)

    # Read captured activations as a flat [N, D] tensor
    latents = module.read_captured(layer_name)
    return y, latents

In [17]:
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import pandas as pd

from torch.nn import functional as F


async def test(model: CNColorMLP, interventions: list[InterventionConfig], test_data: ColorCube) -> ColorCube:
    """Reconstruct every color in a cube and attach the results to it.

    Args:
        model: Model to evaluate.
        interventions: Interventions applied during inference.
        test_data: Cube whose ``rgb_grid`` is used as model input.

    Returns:
        The result of ``test_data.assign`` with ``recon``, ``MSE`` (mean squared error over the last
        axis) and ``latents`` (captured 'bottleneck' activations) added.
    """
    inputs = torch.tensor(test_data.rgb_grid, dtype=torch.float32)
    recon, latents = await infer_with_latent_capture(model, inputs, interventions, 'bottleneck')
    mse = F.mse_loss(recon, inputs, reduction='none').mean(dim=-1)

    # Every variable is laid out on the cube's grid with one trailing feature axis
    grid_shape = (*test_data.shape, -1)
    outputs = {'recon': recon, 'MSE': mse, 'latents': latents}
    return test_data.assign(**{name: value.numpy().reshape(grid_shape) for name, value in outputs.items()})


async def test_named(
    model: CNColorMLP, interventions: list[InterventionConfig], test_data: pd.DataFrame
) -> pd.DataFrame:
    """Reconstruct the colors in a table's ``rgb`` column and append the results.

    Args:
        model: Model to evaluate.
        interventions: Interventions applied during inference.
        test_data: Table with an ``rgb`` column used as model input.

    Returns:
        ``test_data`` with a ``recon`` column (one tuple per row) and an ``MSE`` column
        (mean squared error over the last axis).
    """
    inputs = torch.tensor(test_data['rgb'], dtype=torch.float32)
    recon, _ = await infer_with_latent_capture(model, inputs, interventions, 'bottleneck')
    mse = F.mse_loss(recon, inputs, reduction='none').mean(dim=-1)
    recon_tuples = list(map(tuple, recon.numpy()))
    return test_data.assign(recon=recon_tuples, MSE=mse.numpy())  # pyright: ignore[reportArgumentType]

In [18]:
# # Generate a list of dimensions to visualize
# from itertools import combinations
# [
#     (
#         b,
#         a,
#         (a + 1) % 5 if (a + 1) % 5 not in (a, b) else (a + 2) % 5,
#     )
#     for a, b in combinations((0, 1, 2, 3, 4), 2)
# ]

In [19]:
from typing import Sequence

from ex_color.vis import (
    plot_colors,
    plot_cube_series,
    plot_latent_grid_3d_from_cube,
)
from utils.nb import displayer_mpl


def tags_for_file(tags: Sequence[str]) -> str:
    """Turn a sequence of tags into a file-name fragment.

    Each tag is lower-cased and every run of characters other than ASCII letters and digits is
    replaced by a single hyphen; the resulting tags are then joined with hyphens.
    """
    import re

    non_alnum_run = re.compile(r'[^a-zA-Z0-9]+')
    return '-'.join(non_alnum_run.sub('-', tag.lower()) for tag in tags)


def visualize_reconstructed_cube(data: ColorCube, *, tags: Sequence[str] = ()):
    """Plot reconstructed colors against the true colors and save the figure as a tagged asset.

    Args:
        data: Cube providing the ``recon`` and ``color`` variables named in the ``plot_colors`` call.
        tags: Labels appended to the plot title and folded into the asset file name.
    """
    asset_path = f'large-assets/ex-{nbid}-pred-colors-{tags_for_file(tags)}.png'
    title = f'Predicted colors · {" · ".join(tags)}'
    alt_text = """Plot showing four slices of the HSV cube, titled "{title}". Nominally, each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. Each color value is represented as a square patch of that color. The outer portion of the patches shows the color as reconstructed by the model; the inner portion shows the true (input) color."""
    with displayer_mpl(asset_path, alt_text=alt_text) as show:
        show(lambda: plot_colors(data, title=title, colors='recon', colors_compare='color'))


def visualize_reconstruction_loss(data: ColorCube, *, tags: Sequence[str] = ()):
    """Plot per-color reconstruction error as line series and print summary statistics.

    Args:
        data: Cube providing an ``MSE`` variable.
        tags: Labels appended to the plot title and folded into the asset file name.
    """
    mse = data['MSE']
    max_loss = np.max(mse)
    median_loss = np.median(mse)

    # Slice steps for the last axis of each permuted view.
    # NOTE(review): the first step is negative (floor division by -5), whereas the other two negate
    # the floor division again and are therefore positive — presumably intentional; confirm.
    step_hsv = data.shape[2] // -5
    step_by_hue = -(data.shape[0] // -3)

    asset_path = f'large-assets/ex-{nbid}-loss-colors-{tags_for_file(tags)}.png'
    title = f'Reconstruction error · {" · ".join(tags)}'
    alt_text = f"""Line chart showing loss per color, titled "{{title}}". Y-axis: mean square error, ranging from zero to {max_loss:.2g}. X-axis: hue."""
    with displayer_mpl(asset_path, alt_text=alt_text) as show:
        show(
            lambda: plot_cube_series(
                data.permute('hsv')[:, -1:, ::step_hsv],
                data.permute('svh')[:, -1:, ::step_by_hue],
                data.permute('vsh')[:, -1:, ::step_by_hue],
                title=title,
                var='MSE',
                figsize=(12, 3),
            )
        )
    print(f'Max loss: {max_loss:.2g}')
    print(f'Median MSE: {median_loss:.2g}')


def visualize_latent_space(data: ColorCube, *, tags: Sequence[str] = (), dims: Sequence[tuple[int, int, int]]):
    """Plot projections of the latent space and save the figure as a tagged asset.

    Args:
        data: Cube providing the ``recon``, ``color`` and ``latents`` variables named in the plot call.
        tags: Labels appended to the plot title and folded into the asset file name.
        dims: Triples of latent dimension indices, forwarded to ``plot_latent_grid_3d_from_cube``.
    """
    with displayer_mpl(
        f'large-assets/ex-{nbid}-latents-{tags_for_file(tags)}.png',
        alt_text="""Two rows of three spherical plots, titled "{title}". Each plot shows a vibrant collection of colored circles or balls scattered over the surface of a hypersphere, with each plot showing one 2D projection.""",
    ) as show:
        show(
            lambda theme: plot_latent_grid_3d_from_cube(
                data,
                colors='recon',
                colors_compare='color',
                latents='latents',
                # Single separator between the heading and the tags, matching the other plot titles
                title=f'Latents · {" · ".join(tags)}',
                dims=dims,
                dot_radius=10,
                theme=theme,
            )
        )

In [20]:
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from typing import Any, Callable

from IPython.display import clear_output
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from skimage.color import hsv2rgb, rgb2hsv
from scipy.stats import pearsonr

from ex_color.data import hsv_similarity
from ex_color.data.color_cube import ColorCube
from ex_color.intervention.intervention import InterventionConfig
from ex_color.vis import draw_stacked_results, ConicalAnnotation, draw_cube_scatter
from ex_color.vis.plot_latent_slices import LatentD
from utils.nb import displayer_mpl
from utils.plt import Theme

# HSV triple (hue 0, saturation 1, value 1) used as the reference color for the similarity metrics below.
ANCHOR_HSV = (0.0, 1.0, 1.0)


@dataclass
class Resultset:
    """Evaluation outputs for one plan, bundled for plotting and tabulation."""

    tags: Sequence[str]  # joined into titles, asset file names and table column names
    latent_cube: ColorCube  # filled from `test(..., rgb_cube)` in collect_full_results
    color_slice_cube: ColorCube  # filled from `test(..., hsv_cube)` in collect_full_results
    loss_cube: ColorCube  # filled from `test(..., hd_hsv_cube)` in collect_full_results
    named_colors: pd.DataFrame  # filled from `test_named(...)`, which adds `recon` and `MSE` columns


def visualize_stacked_results(
    res: Resultset,
    *,
    latent_dims: tuple[LatentD, LatentD],
    max_error: float | None = None,
    latent_annotations: Sequence[ConicalAnnotation | Callable[[Theme], ConicalAnnotation]] = (),
) -> None:
    """Draw the composite results figure for one result set and save it as a tagged asset.

    Args:
        res: Result set supplying the latent, color-slice and loss cubes, plus the tags for the file name.
        latent_dims: Forwarded to ``draw_stacked_results``.
        max_error: Forwarded to ``draw_stacked_results``.
        latent_annotations: Annotations, or callables that build one from the active theme.
    """

    def resolve_annotations(theme: Theme) -> list[ConicalAnnotation]:
        # Ready-made annotations pass through; anything else is called with the theme
        return [ann if isinstance(ann, ConicalAnnotation) else ann(theme) for ann in latent_annotations]

    asset_path = f'large-assets/ex-{nbid}-results-{tags_for_file(res.tags)}.png'
    alt_text = """Composite figure with two latent panels (top), a color slice (middle), and a loss chart (bottom)."""
    with displayer_mpl(asset_path, alt_text=alt_text) as show:
        show(
            lambda theme: draw_stacked_results(
                res.latent_cube,
                res.color_slice_cube,
                res.loss_cube,
                latent_dims=latent_dims,
                theme=theme,
                max_error=max_error,
                latent_annotations=resolve_annotations(theme),
            )
        )


def scatter_similarity_vs_error(
    cube: ColorCube,
    anchor_hsv: tuple[float, float, float],
    *,
    anchor_name: str,
    theme: Theme,
    power: float,
):
    """Scatter reconstruction error (``MSE``) against similarity to an anchor color.

    Similarity is ``hsv_similarity`` between each color's HSV value and ``anchor_hsv``, raised to
    ``power``.

    Returns:
        The Matplotlib figure holding the scatter plot.
    """
    with_hsv = cube.assign(hsv=rgb2hsv(cube['color']))
    anchor = np.array(anchor_hsv)
    similarity = hsv_similarity(with_hsv['hsv'], anchor, hemi=True, mode='angular') ** power
    scored = with_hsv.assign(similarity=similarity)

    fig, ax = plt.subplots(figsize=(4, 4), constrained_layout=True)
    draw_cube_scatter(ax, scored, theme=theme, x_var='similarity', y_var='MSE')
    ax.set_ylabel(r'MSE')
    ax.set_xlabel(rf'$\text{{sim}}_\text{{{anchor_name}}}^{{{power:.2g}}}$')
    ax.legend(loc='upper left')
    return fig


def visualize_error_vs_similarity(
    cube: ColorCube,
    anchor_hsv: tuple[float, float, float],
    *,
    tags: Sequence[str] = (),
    anchor_name: str = 'anchor',
    power: float,
) -> None:
    """Display the error-versus-similarity scatter plot and save it as a tagged asset.

    Args:
        cube: Cube forwarded to ``scatter_similarity_vs_error``.
        anchor_hsv: HSV triple of the anchor color.
        tags: Labels folded into the asset file name.
        anchor_name: Name of the anchor, used in the alt text and the axis label.
        power: Exponent applied to the similarity.
    """
    asset_path = f'large-assets/ex-{nbid}-error-vs-similarity-{tags_for_file(tags)}.png'
    alt_text = f"""Scatter plot showing reconstruction error versus similarity to {anchor_name}. Each point represents a color, with its position on the x-axis indicating how similar it is to pure red, and its position on the y-axis indicating the reconstruction error (mean squared error) for that color. The points are colored according to their actual color values."""
    with displayer_mpl(asset_path, alt_text=alt_text) as show:
        show(
            lambda theme: scatter_similarity_vs_error(
                cube,
                anchor_hsv,
                theme=theme,
                anchor_name=anchor_name,
                power=power,
            )
        )


def hstack_named_results(*res: Resultset) -> pd.DataFrame:
    """Combine the named-color errors of several result sets into one wide table.

    The first result set is the baseline: its ``name``, ``rgb`` and ``hsv`` columns are kept and its
    ``MSE`` column is renamed to the space-joined tags. Every further result set contributes an
    error column (named the same way) merged on ``name``, plus a ``<label>-delta`` column holding
    its error minus the baseline error.
    """
    labels = [' '.join(r.tags) for r in res]
    baseline_label = labels[0]

    table = res[0].named_colors[['name', 'rgb', 'hsv', 'MSE']].rename(columns={'MSE': baseline_label})
    for position in range(1, len(res)):
        label = labels[position]
        errors = res[position].named_colors[['name', 'MSE']].rename(columns={'MSE': label})
        table = table.merge(errors, on='name')
        table[f'{label}-delta'] = table[label] - table[baseline_label]
    return table


def metric_prefix(name: str) -> str:
    """Return ``name`` with every space and hyphen replaced by an underscore."""
    return name.translate(str.maketrans({' ': '_', '-': '_'}))


@dataclass(frozen=True)
class CorrelationStats:
    """Correlation between anchor similarity and reconstruction error for one evaluation."""

    correlation: float
    r_squared: float
    p_value: float
    power: float

    def to_row(self, prefix: str) -> dict[str, float]:
        """Return the four statistics keyed as ``<prefix>_r``, ``_r2``, ``_p`` and ``_power``."""
        by_suffix = {
            'r': self.correlation,
            'r2': self.r_squared,
            'p': self.p_value,
            'power': self.power,
        }
        return {f'{prefix}_{suffix}': value for suffix, value in by_suffix.items()}


@dataclass(frozen=True)
class CorrelationSpec:
    """Describes one error-vs-similarity correlation to compute."""

    plan: str  # name of the evaluation plan whose results are scored (used as a lookup key)
    anchor_hsv: tuple[float, float, float]  # HSV triple that similarity is measured against
    power: float  # exponent applied to the similarity before correlating


@dataclass
class EvaluationContext:
    """What a plan's ``setup`` produces: the model to run and the interventions to apply."""

    model: CNColorMLP  # model handed to `test` / `test_named` for this plan
    interventions: Sequence[InterventionConfig]  # converted to a list before inference
    extras: dict[str, Any] = field(default_factory=dict)  # optional plan-specific objects, looked up by key


@dataclass(frozen=True)
class EvaluationPlan:
    """A named way of evaluating a trained model."""

    name: str  # key under which contexts, results and correlations are stored
    tags: Sequence[str]  # copied into the plan's Resultset
    setup: Callable[[CNColorMLP], EvaluationContext]  # builds the evaluation context from a trained model


@dataclass(frozen=True)
class RunMetrics:
    """Correlation statistics gathered for a single training seed, keyed by plan name."""

    seed: int
    correlations: Mapping[str, CorrelationStats]

    def to_row(self) -> dict[str, float]:
        """Flatten the metrics into one table row.

        The row starts with ``seed`` (as a float). Plans follow in sorted name order, each adding
        its statistics under a ``metric_prefix``-derived prefix plus a ``<prefix>_abs`` entry with
        the absolute correlation.
        """
        row: dict[str, float] = {'seed': float(self.seed)}
        for plan_name in sorted(self.correlations):
            plan_stats = self.correlations[plan_name]
            key = metric_prefix(plan_name)
            row.update(plan_stats.to_row(key))
            row[f'{key}_abs'] = abs(plan_stats.correlation)
        return row

    def score(self, plan: str) -> float:
        """Return the absolute correlation recorded for ``plan``.

        Raises:
            KeyError: If no correlation was recorded for ``plan``.
        """
        stats = self.correlations.get(plan)
        if stats is not None:
            return abs(stats.correlation)
        raise KeyError(f'No correlation recorded for plan {plan!r}.')


@dataclass
class BestRunArtifacts:
    """Everything retained from the selected run: model, metrics, plans, results and contexts."""

    seed: int
    model: CNColorMLP
    metrics: RunMetrics
    plans: Mapping[str, EvaluationPlan]
    results: dict[str, Resultset]
    contexts: dict[str, EvaluationContext]
    named_colors: pd.DataFrame

    def get_extra(self, plan: str, key: str, default: Any | None = None) -> Any | None:
        """Look up ``key`` in the extras of ``plan``'s context.

        Returns ``default`` when the plan has no context or the context has no such extra.
        """
        context = self.contexts.get(plan)
        return default if context is None else context.extras.get(key, default)


def error_correlation(cube: ColorCube, anchor_hsv: tuple[float, float, float], *, power: float) -> tuple[float, float]:
    """Correlate similarity-to-anchor with reconstruction error over all colors in a cube.

    Similarity is ``hsv_similarity`` between each color's HSV value and ``anchor_hsv``, raised to
    ``power``.

    Returns:
        ``(r, p)`` as plain floats, taken from ``scipy.stats.pearsonr``.
    """
    with_hsv = cube.assign(hsv=rgb2hsv(cube['color']))
    anchor = np.array(anchor_hsv)
    similarity = hsv_similarity(with_hsv['hsv'], anchor, hemi=True, mode='angular') ** power
    scored = with_hsv.assign(similarity=similarity)

    xs = scored['similarity'].flatten()
    ys = scored['MSE'].flatten()
    corr, p_value = pearsonr(xs, ys)
    return float(corr), float(p_value)


def correlation_stats(cube: ColorCube, *, anchor_hsv: tuple[float, float, float], power: float) -> CorrelationStats:
    """Wrap ``error_correlation`` in a ``CorrelationStats`` record, adding r² and the power used."""
    r, p = error_correlation(cube, anchor_hsv, power=power)
    return CorrelationStats(
        correlation=r,
        r_squared=r**2,
        p_value=p,
        power=power,
    )


async def evaluate_seed(
    seed: int,
    plans: Mapping[str, EvaluationPlan],
    correlation_specs: Sequence[CorrelationSpec],
) -> tuple[RunMetrics, CNColorMLP, dict[str, EvaluationContext]]:
    """Train one model with ``seed`` and score it under each correlation spec.

    Training uses the notebook-level ``dopesheet``, the four regularizer configs, and ``K``/``N``/``H``.

    Args:
        seed: Seed handed to ``train``.
        plans: Evaluation plans keyed by name; each spec selects one via ``spec.plan``.
        correlation_specs: Correlations to compute on the RGB cube.

    Returns:
        ``(metrics, model, contexts)``, where ``contexts`` holds the evaluation context built for
        every plan that was scored.
    """
    active_regularizers = [reg_separate, reg_anchor, reg_anti_anchor, reg_anti_subspace]
    model = await train(dopesheet, active_regularizers, K, N, H, seed=seed)

    contexts: dict[str, EvaluationContext] = {}
    correlations: dict[str, CorrelationStats] = {}
    for spec in correlation_specs:
        plan = plans[spec.plan]
        # Each plan's context is built once and reused if several specs refer to the same plan
        context = contexts.get(plan.name)
        if context is None:
            context = contexts[plan.name] = plan.setup(model)
        scored_cube = await test(context.model, list(context.interventions), rgb_cube)
        correlations[plan.name] = correlation_stats(scored_cube, anchor_hsv=spec.anchor_hsv, power=spec.power)

    return RunMetrics(seed=seed, correlations=correlations), model, contexts


async def collect_full_results(
    seed: int,
    model: CNColorMLP,
    plans: Sequence[EvaluationPlan],
    *,
    named_colors_factory: Callable[[], pd.DataFrame],
    correlation_specs: Sequence[CorrelationSpec],
    precomputed_contexts: Mapping[str, EvaluationContext] | None = None,
) -> BestRunArtifacts:
    """Run every plan on all test sets for one model and bundle the outputs.

    Args:
        seed: Seed recorded in the returned metrics and artifacts.
        model: Trained model; plans without a precomputed context build theirs from it.
        plans: Plans to evaluate, in order.
        named_colors_factory: Produces the named-colors table; each plan works on its own copy.
        correlation_specs: Correlations computed from each referenced plan's ``latent_cube``.
        precomputed_contexts: Contexts to reuse, keyed by plan name.

    Returns:
        A ``BestRunArtifacts`` holding the model, metrics, plans by name, results, contexts and
        the base named-colors table.
    """
    reusable = precomputed_contexts or {}
    base_named_colors = named_colors_factory()
    plans_by_name = {plan.name: plan for plan in plans}

    contexts: dict[str, EvaluationContext] = {}
    results: dict[str, Resultset] = {}
    for plan in plans:
        context = reusable.get(plan.name)
        if context is None:
            context = plan.setup(model)
        contexts[plan.name] = context

        interventions = list(context.interventions)
        # Evaluated in this order: RGB cube, low-res HSV cube, high-res HSV cube, named colors
        latent_cube = await test(context.model, interventions, rgb_cube)
        color_slice_cube = await test(context.model, interventions, hsv_cube)
        loss_cube = await test(context.model, interventions, hd_hsv_cube)
        named = await test_named(context.model, interventions, base_named_colors.copy())
        results[plan.name] = Resultset(
            tags=list(plan.tags),
            latent_cube=latent_cube,
            color_slice_cube=color_slice_cube,
            loss_cube=loss_cube,
            named_colors=named,
        )

    correlations: dict[str, CorrelationStats] = {
        spec.plan: correlation_stats(
            results[spec.plan].latent_cube,
            anchor_hsv=spec.anchor_hsv,
            power=spec.power,
        )
        for spec in correlation_specs
    }
    return BestRunArtifacts(
        seed=seed,
        model=model,
        metrics=RunMetrics(seed=seed, correlations=correlations),
        plans=plans_by_name,
        results=results,
        contexts=contexts,
        named_colors=base_named_colors,
    )


async def run_multi_seed_training(
    seeds: Sequence[int],
    plans: Sequence[EvaluationPlan],
    correlation_specs: Sequence[CorrelationSpec],
    *,
    best_plan: str,
    named_colors_factory: Callable[[], pd.DataFrame],
) -> tuple[list[RunMetrics], BestRunArtifacts]:
    """Train once per seed, keep the best-scoring run, and collect its full results.

    A run replaces the current best only when its ``score(best_plan)`` is strictly greater, so
    ties keep the earlier seed.

    Args:
        seeds: Seeds to train with, in order.
        plans: Evaluation plans, used for scoring and for the final result collection.
        correlation_specs: Correlations computed for every seed.
        best_plan: Name of the plan whose score selects the best run.
        named_colors_factory: Forwarded to ``collect_full_results``.

    Returns:
        ``(metrics, best_artifacts)``: per-seed metrics in training order, and the artifacts of
        the best run.
    """
    plans_by_name = {plan.name: plan for plan in plans}
    all_metrics: list[RunMetrics] = []
    leader: tuple[RunMetrics, CNColorMLP, dict[str, EvaluationContext]] | None = None

    n_seeds = len(seeds)
    for position, seed in enumerate(seeds, start=1):
        clear_output()
        leader_summary = "N/A" if leader is None else leader[0].to_row()
        print(f'Best so far: {leader_summary}')
        print(f'[{position}/{n_seeds}] Training seed {seed}...')

        candidate = await evaluate_seed(seed, plans_by_name, correlation_specs)
        candidate_metrics = candidate[0]
        all_metrics.append(candidate_metrics)

        if leader is None or candidate_metrics.score(best_plan) > leader[0].score(best_plan):
            leader = candidate
        # Drop this iteration's local reference; a new leader stays reachable through `leader`
        del candidate

    assert leader is not None
    leader_metrics, leader_model, leader_contexts = leader
    best_artifacts = await collect_full_results(
        leader_metrics.seed,
        leader_model,
        plans,
        named_colors_factory=named_colors_factory,
        correlation_specs=correlation_specs,
        precomputed_contexts=leader_contexts,
    )
    return all_metrics, best_artifacts

In [21]:
from math import cos, radians

from ex_color.data.color import get_named_colors_df
from ex_color.intervention import BoundedFalloff, Suppression
from ex_color.intervention.intervention import InterventionConfig
from ex_color.surgery import ablate, prune


def no_intervention_plan(model: CNColorMLP) -> EvaluationContext:
    """Evaluate the model as-is: same model object, empty tuple of interventions."""
    untouched = model
    return EvaluationContext(model=untouched, interventions=())


def ablation_plan(model: CNColorMLP) -> EvaluationContext:
    """Evaluate the result of ``ablate(model, 'bottleneck', [0])`` with no interventions."""
    return EvaluationContext(model=ablate(model, 'bottleneck', [0]), interventions=())


def pruning_plan(model: CNColorMLP) -> EvaluationContext:
    """Evaluate the result of ``prune(model, 'bottleneck', [0])`` with no interventions."""
    return EvaluationContext(model=prune(model, 'bottleneck', [0]), interventions=())


def suppression_plan(model: CNColorMLP) -> EvaluationContext:
    """Evaluate the unmodified model with a suppression intervention at the bottleneck.

    The intervention is a ``Suppression`` built from the ``RED`` vector and a ``BoundedFalloff``,
    attached to the 'bottleneck' layer.

    Returns:
        A context holding the original model, the single intervention, and the intervention
        config and falloff under the extras keys ``'config'`` and ``'falloff'``.
    """
    falloff = BoundedFalloff(
        cos(radians(90)),
        1.0,
        0.0,
    )
    # RED is a tuple of ints; build the vector as float32 explicitly (as the regularizers do)
    # instead of letting torch infer an integer dtype.
    subject = torch.tensor(RED, dtype=torch.float32)
    config = InterventionConfig(
        apply=Suppression(subject, falloff),
        layer_affinities=['bottleneck'],
    )
    return EvaluationContext(
        model=model,
        interventions=[config],
        extras={'config': config, 'falloff': falloff},
    )


# Every way the trained model is evaluated; `name` is the key used for results and contexts.
EVALUATION_PLANS = (
    EvaluationPlan(name='no intervention', tags=['no intervention'], setup=no_intervention_plan),
    EvaluationPlan(name='ablated', tags=['ablated'], setup=ablation_plan),
    EvaluationPlan(name='pruned', tags=['pruned'], setup=pruning_plan),
    EvaluationPlan(name='suppression', tags=['suppression'], setup=suppression_plan),
)

# Correlations computed for every seed; only these two plans are scored during the seed sweep.
CORRELATION_SPECS = (
    CorrelationSpec(plan='ablated', anchor_hsv=ANCHOR_HSV, power=3),
    CorrelationSpec(plan='suppression', anchor_hsv=ANCHOR_HSV, power=2),
)

# Plan whose absolute correlation selects the best seed.
BEST_PLAN = 'ablated'
# Similarity exponent per scored plan, derived from the specs above.
POWER_BY_PLAN = {spec.plan: spec.power for spec in CORRELATION_SPECS}


def named_colors_factory() -> pd.DataFrame:
    """Return the named-colors table requested with 12 hues and 5 grays."""
    palette = get_named_colors_df(n_hues=12, n_grays=5)
    return palette

## Train

In [22]:
# Reload dopesheet: makes tweaking params during development easier
dopesheet = Dopesheet.from_csv(f'./ex-{nbid}-dopesheet.csv')

run_metrics, best_run = await run_multi_seed_training(
    RUN_SEEDS,
    EVALUATION_PLANS,
    CORRELATION_SPECS,
    best_plan=BEST_PLAN,
    named_colors_factory=named_colors_factory,
)
run_metrics_df = pd.DataFrame(m.to_row() for m in run_metrics).sort_values('seed').set_index('seed')

print(
    f'Completed {len(run_metrics)} runs. Best seed {best_run.seed} has {BEST_PLAN} r = '
    f'{best_run.metrics.correlations[BEST_PLAN].correlation:+.3f} '
    f'(R² = {best_run.metrics.correlations[BEST_PLAN].r_squared:.3f}).'
)

model = best_run.model
best_run_metrics = best_run.metrics
named_colors = best_run.named_colors
results = best_run.results
contexts = best_run.contexts

suppression_extra = best_run.get_extra('suppression', 'config')
falloff_extra = best_run.get_extra('suppression', 'falloff')
if not isinstance(suppression_extra, InterventionConfig) or not isinstance(falloff_extra, BoundedFalloff):
    raise RuntimeError('Suppression extras were not recorded for the best run.')

suppression = suppression_extra
falloff = falloff_extra

no_intervention_results = results['no intervention']
ablation_results = results['ablated']
pruned_results = results['pruned']
suppression_results = results['suppression']

Best so far: {'seed': 47.0, 'ablated_r': 0.9798149234738697, 'ablated_r2': 0.9600372842621051, 'ablated_p': 0.0, 'ablated_power': 3, 'ablated_abs': 0.9798149234738697, 'suppression_r': 0.9847591455092971, 'suppression_r2': 0.969750574664201, 'suppression_p': 0.0, 'suppression_power': 2, 'suppression_abs': 0.9847591455092971}
[100/100] Training seed 99...
I 2261.6 no.2.9.1:Training with: ['separate', 'anchor', 'anti-anchor', 'anti-subspace']


INFO: Seed set to 99


I 2261.6 li.fa.ut.se:Seed set to 99
I 2261.6 ex.se:PyTorch set to deterministic mode
I 2261.6 ex.se:PyTorch set to deterministic mode


INFO: GPU available: False, used: False


I 2262.9 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 2262.9 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 2262.9 li.py.ut.ra:HPU available: False, using: 0 HPUs
max_steps: 1501, train_loader length: 8
max_steps: 1501, train_loader length: 8


/workspaces/ex-color-transformer/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:476: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


Starting phase: Train


INFO: `Trainer.fit` stopped: `max_steps=1501` reached.


I 2280.9 li.py.ut.ra:`Trainer.fit` stopped: `max_steps=1501` reached.
I 2280.9 ex.ca.la:Label frequencies (n=96064): _any: 87 (0.091%), red: 87 (0.091%)
I 2280.9 ex.ca.la:Label frequencies (n=96064): _any: 87 (0.091%), red: 87 (0.091%)


0,1
epoch,▁▁▁▁▁▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇█████
labels/_any,▁
labels/epoch/_any,▁█▆▅▇▇▄▅▅▅▅▄▅▄▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆
labels/epoch/red,▁█▅▇█▅▅▄▄▅▅▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆
labels/red,▁
train_anchor,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁
train_anti-anchor,██▇▂▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_anti-subspace,▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▅▂▆▄▂▄▄▂▄▅▃█▂▃▅▅▄▃▅▅▄▄
train_loss,█▄▄▂▂▂▂▁▁▁▁▁▁▆▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_recon,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,187.0
labels/_any,0.00091
labels/epoch/_any,0.00091
labels/epoch/red,0.00091
labels/red,0.00091
train_anchor,0.0
train_anti-anchor,0.0
train_anti-subspace,0.0856
train_loss,3e-05
train_recon,3e-05


INFO: GPU available: False, used: False


I 2282.3 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 2282.3 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 2282.3 li.py.ut.ra:HPU available: False, using: 0 HPUs


/workspaces/ex-color-transformer/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
INFO: GPU available: False, used: False
INFO: GPU available: False, used: False


I 2282.6 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 2282.6 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 2282.6 li.py.ut.ra:HPU available: False, using: 0 HPUs


INFO: GPU available: False, used: False


I 2283.1 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 2283.1 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 2283.1 li.py.ut.ra:HPU available: False, using: 0 HPUs


/workspaces/ex-color-transformer/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
INFO: GPU available: False, used: False


I 2283.3 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 2283.3 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 2283.3 li.py.ut.ra:HPU available: False, using: 0 HPUs


INFO: GPU available: False, used: False


I 2283.3 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 2283.3 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 2283.3 li.py.ut.ra:HPU available: False, using: 0 HPUs


INFO: GPU available: False, used: False


I 2284.9 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 2284.9 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 2284.9 li.py.ut.ra:HPU available: False, using: 0 HPUs


INFO: GPU available: False, used: False


I 2284.9 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 2284.9 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 2284.9 li.py.ut.ra:HPU available: False, using: 0 HPUs


/workspaces/ex-color-transformer/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
INFO: GPU available: False, used: False
INFO: GPU available: False, used: False


I 2285.2 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 2285.2 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 2285.2 li.py.ut.ra:HPU available: False, using: 0 HPUs


INFO: GPU available: False, used: False


I 2285.2 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 2285.2 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 2285.2 li.py.ut.ra:HPU available: False, using: 0 HPUs


INFO: GPU available: False, used: False


I 2286.3 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 2286.3 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 2286.3 li.py.ut.ra:HPU available: False, using: 0 HPUs


INFO: GPU available: False, used: False


I 2286.3 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 2286.3 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 2286.3 li.py.ut.ra:HPU available: False, using: 0 HPUs


/workspaces/ex-color-transformer/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
INFO: GPU available: False, used: False
INFO: GPU available: False, used: False


I 2286.6 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 2286.6 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 2286.6 li.py.ut.ra:HPU available: False, using: 0 HPUs


INFO: GPU available: False, used: False


I 2286.6 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 2286.6 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 2286.6 li.py.ut.ra:HPU available: False, using: 0 HPUs


INFO: GPU available: False, used: False


I 2288.9 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 2289.0 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 2289.0 li.py.ut.ra:HPU available: False, using: 0 HPUs


INFO: GPU available: False, used: False


I 2289.0 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 2289.0 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 2289.0 li.py.ut.ra:HPU available: False, using: 0 HPUs


/workspaces/ex-color-transformer/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
INFO: GPU available: False, used: False
INFO: GPU available: False, used: False


I 2289.6 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 2289.6 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 2289.6 li.py.ut.ra:HPU available: False, using: 0 HPUs


INFO: GPU available: False, used: False


I 2289.6 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 2289.6 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 2289.6 li.py.ut.ra:HPU available: False, using: 0 HPUs


INFO: GPU available: False, used: False


I 2291.5 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 2291.5 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 2291.5 li.py.ut.ra:HPU available: False, using: 0 HPUs
Completed 100 runs. Best seed 47 has ablated r = +0.980 (R² = 0.960).
Completed 100 runs. Best seed 47 has ablated r = +0.980 (R² = 0.960).


### Multi-run summary
We trained one model per seed to understand how sensitive the intervention metrics are to initialization. The tables below report the error–similarity correlation statistics for each run and their summary across runs; the best run is the one with the highest absolute ablation correlation.

In [23]:
from IPython.display import display

display(run_metrics_df)

# Keep only the columns whose name starts with the metric prefix of a plan that has a correlation spec
prefixes = {metric_prefix(spec.plan) for spec in CORRELATION_SPECS}
correlation_columns = [column for column in run_metrics_df.columns if column.startswith(tuple(prefixes))]
# Summary statistics across seeds, one row per metric
correlation_summary = (
    run_metrics_df[correlation_columns]
    .agg(['mean', 'std', 'min', 'max'])
    .transpose()
    .rename_axis('metric')
)
display(correlation_summary)

Unnamed: 0_level_0,ablated_r,ablated_r2,ablated_p,ablated_power,ablated_abs,suppression_r,suppression_r2,suppression_p,suppression_power,suppression_abs
seed,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
0.0,0.921529,0.849216,0.0,3,0.921529,0.946401,0.895675,0.0,2,0.946401
1.0,0.945684,0.894318,0.0,3,0.945684,0.958356,0.918447,0.0,2,0.958356
2.0,0.939028,0.881774,0.0,3,0.939028,0.970026,0.940951,0.0,2,0.970026
3.0,0.865224,0.748613,0.0,3,0.865224,0.927171,0.859646,0.0,2,0.927171
4.0,0.902018,0.813637,0.0,3,0.902018,0.956959,0.915771,0.0,2,0.956959
...,...,...,...,...,...,...,...,...,...,...
95.0,0.905823,0.820515,0.0,3,0.905823,0.934974,0.874176,0.0,2,0.934974
96.0,0.971634,0.944073,0.0,3,0.971634,0.953305,0.908791,0.0,2,0.953305
97.0,0.962716,0.926822,0.0,3,0.962716,0.965537,0.932261,0.0,2,0.965537
98.0,0.930735,0.866268,0.0,3,0.930735,0.918169,0.843034,0.0,2,0.918169


Unnamed: 0_level_0,mean,std,min,max
metric,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
ablated_r,0.931351,0.033753,0.818136,0.979815
ablated_r2,0.868543,0.061725,0.669346,0.960037
ablated_p,0.0,0.0,0.0,0.0
ablated_power,3.0,0.0,3.0,3.0
ablated_abs,0.931351,0.033753,0.818136,0.979815
suppression_r,0.956301,0.024416,0.861479,0.991759
suppression_r2,0.915102,0.046036,0.742145,0.983585
suppression_p,0.0,0.0,0.0,0.0
suppression_power,2.0,0.0,2.0,2.0
suppression_abs,0.956301,0.024416,0.861479,0.991759


## Test

In [24]:
from IPython.display import clear_output

# Clear anything already rendered in this cell's output area before drawing
clear_output()

# Reconstructed colors for the best run with no intervention applied.
# permute('svh') reorders the cube axes — presumably saturation/value/hue; confirm against the cube type.
visualize_reconstructed_cube(
    no_intervention_results.color_slice_cube.permute('svh'),
    tags=no_intervention_results.tags,
)
# Optional diagnostics (loss cube and latent-space projections), currently disabled:
# visualize_reconstruction_loss(no_intervention_results.loss_cube, tags=no_intervention_results.tags)
# visualize_latent_space(
#     no_intervention_results.latent_cube,
#     tags=no_intervention_results.tags,
#     dims=[(1, 0, 2), (2, 0, 1), (3, 0, 1), (4, 1, 2), (3, 2, 4), (4, 3, 0)],
# )

### Ablation

In [25]:
from IPython.display import clear_output

# Clear anything already rendered in this cell's output area before drawing
clear_output()

# Reconstructed colors for the best run under the 'ablated' plan.
# permute('svh') reorders the cube axes — presumably saturation/value/hue; confirm against the cube type.
visualize_reconstructed_cube(
    ablation_results.color_slice_cube.permute('svh'),
    tags=ablation_results.tags,
)
# Optional diagnostics (loss cube and latent-space projections), currently disabled:
# visualize_reconstruction_loss(ablation_results.loss_cube, tags=ablation_results.tags)
# visualize_latent_space(
#     ablation_results.latent_cube,
#     tags=ablation_results.tags,
#     dims=[(1, 0, 2), (2, 0, 1), (3, 0, 1), (4, 1, 2), (3, 2, 4), (4, 3, 0)],
# )

### Pruning

In [26]:
from IPython.display import clear_output

# Clear anything already rendered in this cell's output area before drawing
clear_output()

# Reconstructed colors for the best run under the 'pruned' plan.
# permute('svh') reorders the cube axes — presumably saturation/value/hue; confirm against the cube type.
visualize_reconstructed_cube(
    pruned_results.color_slice_cube.permute('svh'),
    tags=pruned_results.tags,
)
# Optional diagnostics (loss cube and latent-space projections), currently disabled:
# visualize_reconstruction_loss(pruned_results.loss_cube, tags=pruned_results.tags)
# visualize_latent_space(
#     pruned_results.latent_cube,
#     tags=pruned_results.tags,
#     dims=[(1, 0, 2), (2, 0, 1), (3, 0, 1), (4, 1, 2), (3, 2, 4), (4, 3, 0)],
# )

### Suppression

Included for comparison/completeness, but this model was not really designed for it.

In [27]:
from IPython.display import clear_output

# Clear anything already rendered in this cell's output area before drawing
clear_output()

# Reconstructed colors for the best run under the 'suppression' plan.
# permute('svh') reorders the cube axes — presumably saturation/value/hue; confirm against the cube type.
visualize_reconstructed_cube(
    suppression_results.color_slice_cube.permute('svh'),
    tags=suppression_results.tags,
)
# Optional diagnostics (loss cube and latent-space projections), currently disabled.
# NOTE(review): the dims listed here differ from the other three cells — confirm which set is intended before enabling.
# visualize_reconstruction_loss(suppression_results.loss_cube, tags=suppression_results.tags)
# visualize_latent_space(
#     suppression_results.latent_cube,
#     tags=suppression_results.tags,
#     dims=[(1, 0, 2), (2, 0, 1), (3, 0, 1), (2, 1, 3), (3, 1, 2), (3, 2, 0)],
# )

### Stacked visualizations for paper

In [28]:
from IPython.display import display

print('Max error:')
# Largest MSE found in each plan's loss cube, keyed by a display label
display(
    {
        label: float(outcome.loss_cube['MSE'].max())
        for label, outcome in (
            ('No intervention', no_intervention_results),
            ('Ablation', ablation_results),
            ('Pruned', pruned_results),
            ('Suppression', suppression_results),
        )
    }
)

Max error:


{'No intervention': 9.104946366278455e-05,
 'Ablation': 0.32055285573005676,
 'Pruned': 0.32055285573005676,
 'Suppression': 0.18120326101779938}

In [29]:
def themed_annotation(
    theme: Theme, direction: Sequence[float], angle: float, dashed: bool = False
) -> ConicalAnnotation:
    """Build a ``ConicalAnnotation`` whose line styling is resolved through ``theme``.

    Args:
        theme: Supplies ``val(default, dark=...)`` to pick a value per theme.
        direction: Passed through as the annotation's ``direction``.
        angle: Passed through as the annotation's ``angle``.
        dashed: When true, ``dashes`` and ``gapcolor`` are also passed to the annotation.

    Returns:
        The configured ``ConicalAnnotation``.
    """
    stroke_color = theme.val('black', dark='#fff')
    stroke_width = theme.val(0.75, dark=1)

    # Dash styling is only supplied when requested; otherwise the annotation's own defaults apply.
    dash_style = {}
    if dashed:
        dash_style = {
            'dashes': theme.val((8, 8), dark=(4, 4)),
            'gapcolor': theme.val('#ddda', dark='#222a'),
        }

    return ConicalAnnotation(
        direction=direction,
        angle=angle,
        color=stroke_color,
        linewidth=stroke_width,
        **dash_style,
    )


# Shared upper bound for the error scale of all four stacked figures.
# Every result set plotted below is included, so none of them can exceed the
# scale (the suppression results were previously left out even though they are
# plotted against the same bound).
max_error = np.max(
    [
        no_intervention_results.loss_cube['MSE'],
        ablation_results.loss_cube['MSE'],
        pruned_results.loss_cube['MSE'],
        suppression_results.loss_cube['MSE'],
    ]
)

print('Baseline')
visualize_stacked_results(
    no_intervention_results,
    latent_dims=((3, 0, 1), (3, 2, 4)),
    max_error=max_error,
)

print('Ablation')
visualize_stacked_results(
    ablation_results,
    latent_dims=((3, 0, 1), (3, 2, 4)),
    max_error=max_error,
)

print('Pruned')
# NOTE(review): different latent indices than the other plots — presumably
# because pruning removed a bottleneck dimension; confirm.
visualize_stacked_results(
    pruned_results,
    latent_dims=((2, None, 0), (2, 1, 3)),
    max_error=max_error,
)

print('Suppression')
visualize_stacked_results(
    suppression_results,
    latent_dims=((3, 0, 1), (3, 2, 0)),
    # latent_dims=((1, 0, 2), (1, 2, 0)),
    max_error=max_error,
    latent_annotations=[
        # falloff.a is constructed from cos(90°) in suppression_plan, i.e. it is
        # a cosine, not an angle. The cone's full apex angle is therefore
        # 2 * arccos(a); the previous `pi/2 - a` only coincides with this when
        # a is ~0. The value is clipped to [-1, 1] so arccos stays defined.
        lambda theme: themed_annotation(
            theme,
            direction=RED,
            angle=2 * float(np.arccos(np.clip(float(falloff.a), -1.0, 1.0))),
            dashed=True,
        ),
    ],
)

Baseline


Ablation


Pruned


Suppression


### Tabular results: error vs color per intervention

In [30]:
from IPython.display import display
from ex_color.vis import ColorTableHtmlFormatter

# Combine the per-plan results for the named colors into a single table.
# `df` is reused by the LaTeX export cell that follows.
df = hstack_named_results(no_intervention_results, ablation_results, pruned_results, suppression_results)

# Render the combined table through the HTML color-table formatter
display(ColorTableHtmlFormatter().style(df))

Name,RGB,No Intervention,Ablated,Δ Abl,Pruned,Δ Pru,Suppression,Δ Sup
red,,0.0,0.333,0.333,0.333,0.333,0.198,0.198
orange,,0.0,0.123,0.123,0.123,0.123,0.112,0.112
yellow,,0.0,0.017,0.017,0.017,0.017,0.024,0.024
lime,,0.0,0.0,0.0,0.0,0.0,0.0,0.0
green,,0.0,0.0,0.0,0.0,0.0,0.0,0.0
teal,,0.0,0.0,0.0,0.0,0.0,0.0,0.0
cyan,,0.0,0.0,-0.0,0.0,-0.0,0.0,-0.0
azure,,0.0,0.0,0.0,0.0,0.0,0.0,0.0
blue,,0.0,0.0,0.0,0.0,0.0,0.0,0.0
purple,,0.0,0.002,0.002,0.002,0.002,0.002,0.002


In [31]:
from IPython.display import display

import pandas as pd

from ex_color.vis import ColorTableLatexFormatter

formatter = ColorTableLatexFormatter()
# print(formatter.preamble)

# Baseline is the absolute error of the unmodified model; the intervention
# columns are taken from the per-plan '-delta' columns of `df`.
latex = formatter.to_str(
    pd.DataFrame(
        dict(
            color=df['name'].str.capitalize(),
            rgb=df['rgb'],
            baseline=df['no intervention'],
            ablated=df['ablated-delta'],
            pruned=df['pruned-delta'],
            suppression=df['suppression-delta'],
        )
    ),
    caption='Reconstruction error by color and intervention method',
    label='tab:error-per-color-soft',
)
# Emit a raw MIME bundle: a fenced LaTeX block for markdown frontends, plain text otherwise
display(
    {
        'text/markdown': f'```latex\n{latex}\n```',
        'text/plain': latex,
    },
    raw=True,
)

```latex
\begin{table}
\centering
\label{tab:error-per-color-soft}
\caption{Reconstruction error by color and intervention method}
\sisetup{
    round-mode = places,
    round-precision = 3,
    table-auto-round = true,
    % drop-zero-decimal = true,
}
\begin{tabular}{l c g g g g}
\toprule
\multicolumn{2}{c}{{Color}} & \multicolumn{1}{c}{{Baseline}} & \multicolumn{1}{c}{{Ablated}} & \multicolumn{1}{c}{{Pruned}} & \multicolumn{1}{c}{{Suppression}} \\
\midrule
Red        & \swatch{FF0000} &  0.000036028 &  0.333297312 &  0.333297312 &  0.198443457 \\
Orange     & \swatch{FF7F00} &  0.000041728 &  0.123259634 &  0.123259634 &  0.111782402 \\
Yellow     & \swatch{FFFF00} &  0.000059968 &  0.017365806 &  0.017365806 &  0.023780219 \\
Lime       & \swatch{7FFF00} &  0.000084073 &  0.000207367 &  0.000207367 &  0.000209407 \\
Green      & \swatch{00FF00} &  0.000091481 &  0.000096842 &  0.000096842 &  0.000100940 \\
Teal       & \swatch{00FF7F} &  0.000000467 &  0.000002635 &  0.000002635 &  0.000002622 \\
Cyan       & \swatch{00FFFF} &  0.000104127 & -0.000003132 & -0.000003132 & -0.000003129 \\
Azure      & \swatch{007FFF} &  0.000046455 &  0.000071969 &  0.000071969 &  0.000072302 \\
Blue       & \swatch{0000FF} &  0.000007944 &  0.000199724 &  0.000199724 &  0.000211864 \\
Purple     & \swatch{7F00FF} &  0.000010084 &  0.001649732 &  0.001649732 &  0.001688957 \\
Magenta    & \swatch{FF00FF} &  0.000012721 &  0.038103655 &  0.038103655 &  0.060558371 \\
Pink       & \swatch{FF007F} &  0.000000190 &  0.147680908 &  0.147680908 &  0.119920596 \\
Black      & \swatch{000000} &  0.000074612 &  0.000788459 &  0.000788459 &  0.000860083 \\
Dark gray  & \swatch{3F3F3F} &  0.000000158 &  0.000464755 &  0.000464755 &  0.000457245 \\
Gray       & \swatch{7F7F7F} &  0.000010912 &  0.000178977 &  0.000178977 &  0.000178518 \\
Light gray & \swatch{BFBFBF} &  0.000014506 &  0.000211861 &  0.000211861 &  0.000210544 \\
White      & \swatch{FFFFFF} &  0.000111809 &  0.000714028 &  0.000714028 &  0.000731633 \\
\bottomrule
\end{tabular}
\end{table}
```

### Correlation: error vs. similarity to anchor per intervention

In [32]:
from utils.strings import sup

# Ablated: statistics were computed during the multi-seed run; `power` is the exponent from its correlation spec
ablated_stats = best_run_metrics.correlations['ablated']
power = POWER_BY_PLAN['ablated']
print(
    f'MSE,sim{sup(power)} Ablated (seed {best_run.seed}): '
    f'r = {ablated_stats.correlation:+.2f}, R²: {ablated_stats.r_squared:.2f}, p = {ablated_stats.p_value:.3g}'
)

# Pruned has no entry in CORRELATION_SPECS, so its correlation is computed here,
# reusing the ablated exponent still held in `power`.
pruned_corr, pruned_p_value = error_correlation(pruned_results.latent_cube, ANCHOR_HSV, power=power)
print(f'MSE,sim{sup(power)} Pruned: r = {pruned_corr:+.2f}, R²: {pruned_corr**2:.2f}, p = {pruned_p_value:.3g}')

# Suppression: `power` is rebound to the suppression exponent before printing
suppression_stats = best_run_metrics.correlations['suppression']
power = POWER_BY_PLAN['suppression']
print(
    f'MSE,sim{sup(power)} Suppression (seed {best_run.seed}): '
    f'r = {suppression_stats.correlation:+.2f}, R²: {suppression_stats.r_squared:.2f}, p = {suppression_stats.p_value:.3g}'
)

MSE,sim³ Ablated (seed 47): r = +0.98, R²: 0.96, p = 0
MSE,sim³ Pruned: r = +0.98, R²: 0.96, p = 0
MSE,sim² Suppression (seed 47): r = +0.98, R²: 0.97, p = 0


In [33]:
# One error-vs-similarity plot per intervention, each preceded by a printed heading.
print('Ablated')
visualize_error_vs_similarity(
    ablation_results.latent_cube,
    ANCHOR_HSV,
    tags=ablation_results.tags,
    anchor_name='red',
    power=POWER_BY_PLAN['ablated'],
)

print('Pruned')
# Pruned has no correlation spec of its own, so it is plotted with the ablated exponent
visualize_error_vs_similarity(
    pruned_results.latent_cube,
    ANCHOR_HSV,
    tags=pruned_results.tags,
    anchor_name='red',
    power=POWER_BY_PLAN['ablated'],
)

print('Suppression')
visualize_error_vs_similarity(
    suppression_results.latent_cube,
    ANCHOR_HSV,
    tags=suppression_results.tags,
    anchor_name='red',
    power=POWER_BY_PLAN['suppression'],
)

Ablated


Pruned


Suppression
