# Experiment 2.4.1: Soft intervention on red with color wheel


In [1]:
from __future__ import annotations

nbid = '2.4.1'  # ID for tagging assets (used in asset filenames, the dopesheet path and the logger name)
nbname = 'Soft intervention on red with color wheel'
# Human-readable run name; passed to Experiment and to the W&B logger in `train`.
experiment_name = f'Ex {nbid}: {nbname}'
# Project name; passed to Experiment and to the W&B logger in `train`.
project = 'ex-preppy'

In [2]:
# Basic setup: Logging, Experiment (Modal)
import logging

import modal

from infra.requirements import freeze, project_packages
from mini.experiment import Experiment
from utils.logging import SimpleLoggingConfig

# Logging setup: `.info(...)` / `.error(...)` presumably set the level for each named logger.
logging_config = (
    SimpleLoggingConfig()
    .info('notebook', 'utils', 'mini', 'ex_color')
    .error('matplotlib.axes')  # Silence warnings about set_aspect
)
logging_config.apply()

# This is the logger for this notebook
log = logging.getLogger(f'notebook.{nbid}')

# Modal experiment wrapper. Its container image installs the requirements returned by
# `freeze(all=True)` and adds the project's local Python packages.
# NOTE: `train` has its `@run.thither` decorator commented out, so training currently runs locally.
run = Experiment(experiment_name, project=project)
run.image = modal.Image.debian_slim().pip_install(*freeze(all=True)).add_local_python_source(*project_packages())
# Re-apply the logging config before each run (presumably inside the remote process).
run.before_each(logging_config.apply)
None  # prevent auto-display of this cell

## Model parameters

Like Ex 2.4, we use the following regularizers:

- **Anchor:** pins `red` to $(1,0,0,0)$ (4D)
- **AxisAlignedSubspace:** attracts vibrant hues to dimensions $0,1$
- **Separate:** angular repulsion to reduce global clumping (applied within each batch)

In [3]:
import torch

from ex_color.loss import AngularAnchor, AxisAlignedSubspace, Separate, RegularizerConfig

from ex_color.training import TrainingModule

K = 4  # bottleneck dimensionality
N = 1  # number of nonlinear layers
# Anchor direction for red in the K-dimensional bottleneck: the first basis vector.
RED = (1, 0, 0, 0)
assert len(RED) == K

# Separate: angular repulsion at the bottleneck. label_affinities=None, so presumably it
# applies to every sample rather than to a labelled subset.
# NOTE(review): unlike the other two regularizers, `phase` is not set here, so the
# RegularizerConfig default applies — confirm that is intended.
reg_separate = RegularizerConfig(
    name='separate',
    compute_loss_term=Separate(power=100.0, shift=True),
    label_affinities=None,
    layer_affinities=['bottleneck'],
)
# Anchor: angular anchor toward RED for samples labelled 'red', during training and validation.
reg_anchor = RegularizerConfig(
    name='anchor',
    compute_loss_term=AngularAnchor(torch.tensor(RED, dtype=torch.float32)),
    label_affinities={'red': 1.0},
    layer_affinities=['bottleneck'],
    phase=('train', 'validate'),
)
# Subspace: targets bottleneck dims 0 and 1 for samples labelled 'vibrant'.
reg_subspace = RegularizerConfig(
    name='subspace',
    compute_loss_term=AxisAlignedSubspace((0, 1)),
    label_affinities={'vibrant': 1},
    # label_affinities={'primary': 1},
    layer_affinities=['bottleneck'],
    phase=('train', 'validate'),
)

In [4]:
from typing import cast

import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from IPython.display import display, Markdown

from mini.temporal.dopesheet import Dopesheet
from mini.temporal.timeline import Timeline
from mini.temporal.vis import plot_timeline, realize_timeline, ParamGroup
from utils.nb import displayer_mpl
from utils.plt import Theme

# Load the keyframed schedule (learning rate and regularizer weights) for this experiment.
dopesheet = Dopesheet.from_csv(f'./ex-{nbid}-dopesheet.csv')

# Render the schedule as a markdown table in the cell output.
display(Markdown(f"""## Parameter schedule \n{dopesheet.to_markdown()}"""))


def _pad_ylim(ax: Axes, frac: float = 0.05) -> None:
    """Widen ``ax``'s y-limits by ``frac`` of their current span on each side."""
    lo, hi = ax.get_ylim()
    pad = (hi - lo) * frac
    ax.set_ylim(lo - pad, hi + pad)


def plot_dopesheet(dopesheet: Dopesheet, theme: Theme):
    """
    Plot a dopesheet's parameter schedule as two stacked panels sharing the step axis.

    The upper (taller) panel shows every scheduled property except ``lr``. In this
    notebook those are the regularizer weights. The lower panel shows ``lr`` alone.

    Args:
        dopesheet: The keyframed schedule to realize and plot.
        theme: Plot theme, passed through to ``plot_timeline``.

    Returns:
        The matplotlib Figure.
    """
    timeline = Timeline(dopesheet)
    history_df = realize_timeline(timeline)
    keyframes_df = dopesheet.as_df()

    fig = plt.figure(figsize=(9, 3), constrained_layout=True)
    axs = fig.subplots(2, 1, sharex=True, height_ratios=[3, 1])
    ax1, ax2 = cast(tuple[Axes, ...], axs)

    plot_timeline(
        history_df,
        keyframes_df,
        groups=(ParamGroup(name='', params=[p for p in dopesheet.props if p not in {'lr'}]),),
        theme=theme,
        ax=ax1,
        show_phase_labels=False,
    )
    ax1.set_ylabel('Weight')
    ax1.set_xlabel('')

    plot_timeline(
        history_df,
        keyframes_df,
        groups=(ParamGroup(name='', params=['lr']),),
        theme=theme,
        ax=ax2,
        show_legend=False,
        show_phase_labels=False,
    )
    ax2.set_ylabel('LR')

    # Add a little space on the y-axis extents. Pad by a fraction of the span: scaling
    # each limit (e.g. x1.1) would move a positive lower limit *up* and clip the data.
    _pad_ylim(ax1)
    _pad_ylim(ax2)

    return fig


# Display the schedule plot and save it under large-assets/. The callback takes `theme`,
# so displayer_mpl presumably renders it once per theme.
# NOTE(review): the alt text mentions a title placeholder, but plot_dopesheet sets no
# figure title — confirm displayer_mpl fills "{title}" sensibly.
with displayer_mpl(
    f'large-assets/ex-{nbid}-dopesheet.png',
    alt_text="""Plot showing the parameter schedule for the training run, titled "{title}". The plot has two sections: the upper section shows various regularization weights over time, and the lower section shows the learning rate over time. The x-axis represents training steps.""",
) as show:
    show(lambda theme: plot_dopesheet(dopesheet, theme))

## Parameter schedule 
|   STEP | PHASE   |   ACTION |      lr |   separate |   anchor |   subspace |
|-------:|:--------|---------:|--------:|-----------:|---------:|-----------:|
|      0 | Train   |          |   1e-08 |            |     0.01 |       0.01 |
|     10 |         |          |   0.01  |            |          |            |
|    375 |         |          |         |            |          |       0.1  |
|    750 |         |          |   0.1   |       0.01 |     0.1  |            |
|   1125 |         |          |         |            |          |       0.1  |
|   1425 |         |          |   0.1   |       0    |     0    |       0    |
|   1500 |         |          |   0.05  |            |          |            |

## Data

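The data is the same as in the previous experiment: a cube of RGB colors. _Red_ and _vibrant_ colors get soft labels, which are stochastically discretized in each training batch.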


In [5]:
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset, RandomSampler
import numpy as np

from ex_color.data.color_cube import ColorCube
from ex_color.data.cube_dataset import CubeDataset, redness, stochastic_labels, exact_labels, vibrancy


def prep_data() -> DataLoader:
    """
    Build the training DataLoader over a 10x10x10 RGB color cube.

    Each color gets two soft label probabilities: ``red`` (from ``redness``) and
    ``vibrant`` (from ``vibrancy``). ``stochastic_labels`` discretizes them per batch.
    Samples are drawn with replacement, ``len(dataset)`` per epoch, in batches of 64.
    """
    r_axis, g_axis, b_axis = (np.linspace(0, 1, 10) for _ in range(3))
    grid = ColorCube.from_rgb(r=r_axis, g=g_axis, b=b_axis)
    # High powers concentrate each label near the extreme of its score (presumably
    # scores lie in [0, 1]). The small multipliers keep the labels rare.
    grid = grid.assign(
        red=redness(grid['color']) ** 8 * 0.08,
        vibrant=vibrancy(grid['color']) ** 10 * 0.01,
    )
    samples = CubeDataset(grid)
    with_replacement = RandomSampler(samples, num_samples=len(samples), replacement=True)
    return DataLoader(
        samples,
        batch_size=64,
        num_workers=4,
        sampler=with_replacement,
        collate_fn=stochastic_labels,
    )


def prep_val_data() -> DataLoader:
    """
    Build the validation DataLoader over a coarse 4x4x4 RGB cube.

    Labels are exact: a color is labelled ``red`` only where ``redness`` equals 1,
    because validation only checks where the prototypes end up. The loader uses
    DataLoader's default batch size and no sampler.
    """
    r_axis, g_axis, b_axis = (np.linspace(0, 1, 4) for _ in range(3))
    grid = ColorCube.from_rgb(r=r_axis, g=g_axis, b=b_axis)
    grid = grid.assign(red=redness(grid['color']) == 1)
    return DataLoader(CubeDataset(grid), num_workers=2, collate_fn=exact_labels)

In [6]:
import numpy as np

# from ex_color.data.color_cube import ColorCube
from ex_color.data.cyclic import arange_cyclic


# Low-res HSV cube, used for the color-slice plots: cyclic hues at 1/12 steps,
# 4 saturations, 5 values.
hsv_cube = ColorCube.from_hsv(
    h=arange_cyclic(step_size=1 / 12),
    s=np.linspace(0, 1, 4),
    v=np.linspace(0, 1, 5),
)

n_h = hsv_cube.shape[0]  # number of hues in the low-res cube

# High-res HSV cube, used for the loss charts; downsampled by 2 on every axis below.
hd_hsv_cube = ColorCube.from_hsv(
    # Extend hue range to encompass the end pixels of the low-res cube above
    h=np.linspace(0 - 1 / n_h, 1 + 1 / n_h, 300),
    s=np.linspace(0, 1, 48),
    v=np.linspace(0, 1, 48),
)
hd_hsv_cube = hd_hsv_cube[::2, ::2, ::2]

# 20x20x20 RGB cube, used for the latent-space plots.
rgb_cube = ColorCube.from_rgb(
    r=np.linspace(0, 1, 20),
    g=np.linspace(0, 1, 20),
    b=np.linspace(0, 1, 20),
)

# with displayer_mpl(
#     f'large-assets/ex-{nbid}-true-colors.png',
#     alt_text="""Plot showing four slices of the HSV cube, titled "{title}". Each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. The first slice shows a grayscale gradient from black to white; the last shows the fully-saturated color spectrum.""",
# ) as show:
#     show(lambda: plot_colors(hsv_cube, title='True colors'))

## Training & inference utils

Like in Ex 2.2..2.9, the model is trained with PyTorch Lightning, with regularizers applied as custom hooks.

In [7]:
from tempfile import gettempdir
import wandb
from ex_color.model import CNColorMLP


# @run.thither(env={'WANDB_API_KEY': wandb.Api().api_key})
async def train(
    dopesheet: Dopesheet,
    regularizers: list[RegularizerConfig],
    k_bottleneck: int,
    n_nonlinear: int,
    *,
    seed: int | None = None,
) -> CNColorMLP:
    """
    Train a CNColorMLP on the color cube, following a dopesheet schedule.

    The reconstruction loss is MSE. Metrics are logged to Weights & Biases.

    Args:
        dopesheet: Parameter schedule. Its length sets ``max_steps``, and it is passed
            to TrainingModule (presumably to drive LR and regularizer weights per step).
        regularizers: Regularizer configs handed to TrainingModule.
        k_bottleneck: Bottleneck dimensionality of the model.
        n_nonlinear: Number of nonlinear layers in the model.
        seed: If given, enable deterministic mode with this seed before the data
            loaders and the model are created.

    Returns:
        The trained model. It is small, so it is returned directly rather than saved
        to and loaded from a checkpoint.
    """
    import lightning as L
    from lightning.pytorch.loggers import WandbLogger

    from ex_color.callbacks import LabelProportionCallback
    from ex_color.seed import set_deterministic_mode

    from utils.progress.lightning import LightningProgress

    log.info(f'Training with: {[r.name for r in regularizers]}')

    # Seed before building the loaders and the model, so sampling and initialization are reproducible.
    if seed is not None:
        set_deterministic_mode(seed)

    train_loader = prep_data()
    val_loader = prep_val_data()

    model = CNColorMLP(k_bottleneck, n_nonlinear=n_nonlinear)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log.debug(f'Model initialized with {total_params:,} trainable parameters.')

    training_module = TrainingModule(model, dopesheet, torch.nn.MSELoss(), regularizers)

    logger = WandbLogger(experiment_name, project=project, save_dir=gettempdir())

    trainer = L.Trainer(
        max_steps=len(dopesheet),
        callbacks=[
            LightningProgress(),
            # Logs label frequencies, reading the module's currently active labels.
            LabelProportionCallback(prefix='labels', get_active_labels=lambda: training_module.active_labels),
        ],
        enable_checkpointing=False,
        enable_model_summary=False,
        # enable_progress_bar=True,
        check_val_every_n_epoch=10,
        logger=logger,
        # Capped at the number of batches per epoch (16 in this notebook), presumably so
        # that logging still happens within every epoch.
        log_every_n_steps=min(50, len(train_loader)),
    )

    print(f'max_steps: {len(dopesheet)}, train_loader length: {len(train_loader)}')

    # Train the model
    try:
        trainer.fit(training_module, train_loader, val_loader)
    finally:
        # Close the W&B run even if training raises.
        wandb.finish()
    # This is only a small model, so it's OK to return it rather than storing and loading a checkpoint remotely
    return model

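For inference, we wrap the trained model in an `InferenceModule`, which applies any given interventions and captures the bottleneck activations. The baseline run uses no interventions; the suppression and repulsion sections below do.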

In [8]:
import torch
import numpy as np

from ex_color.inference import InferenceModule
from ex_color.intervention.intervention import InterventionConfig


async def infer_with_latent_capture(
    model: CNColorMLP,
    test_data: Tensor,
    interventions: list[InterventionConfig],
    layer_name: str = 'bottleneck',
) -> tuple[Tensor, Tensor]:
    """
    Run the model over a set of colors, applying interventions and capturing activations.

    Args:
        model: Trained model to run.
        test_data: Colors whose last dimension is RGB; flattened to ``(-1, 3)`` for inference.
        interventions: Interventions for the InferenceModule to apply (may be empty).
        layer_name: Layer whose activations are captured.

    Returns:
        ``(reconstructions, latents)``: reconstructions reshaped back to ``test_data.shape``,
        and the captured activations of ``layer_name`` as a flat ``[N, D]`` tensor.
    """
    import lightning as L

    inference = InferenceModule(model, interventions, capture_layers=[layer_name])
    loader = DataLoader(
        TensorDataset(test_data.reshape((-1, 3))),
        batch_size=64,
        # Each TensorDataset row is a 1-tuple; stack the single tensors into a batch.
        collate_fn=lambda rows: torch.stack([r[0] for r in rows], 0),
    )
    predictor = L.Trainer(enable_checkpointing=False, enable_model_summary=False, enable_progress_bar=False, logger=False)
    batch_outputs = predictor.predict(inference, loader)
    assert batch_outputs is not None

    per_row = [row for batch in batch_outputs for row in batch]
    reconstructions = torch.cat(per_row).reshape(test_data.shape)
    return reconstructions, inference.read_captured(layer_name)

In [9]:
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import pandas as pd

from torch.nn import functional as F


async def test(model: CNColorMLP, interventions: list[InterventionConfig], test_data: ColorCube) -> ColorCube:
    """
    Reconstruct every color in a ColorCube and attach the results as new variables.

    Returns a copy of ``test_data`` with three added variables, each reshaped to
    ``(*test_data.shape, -1)``:

    - ``recon``: the model's reconstructed RGB
    - ``MSE``: per-color mean squared error between reconstruction and input
    - ``latents``: captured bottleneck activations
    """
    inputs = torch.tensor(test_data.rgb_grid, dtype=torch.float32)
    recon, latents = await infer_with_latent_capture(model, inputs, interventions, 'bottleneck')
    mse = F.mse_loss(recon, inputs, reduction='none').mean(dim=-1)
    grid_shape = (*test_data.shape, -1)
    return test_data.assign(
        recon=recon.numpy().reshape(grid_shape),
        MSE=mse.numpy().reshape(grid_shape),
        latents=latents.numpy().reshape(grid_shape),
    )


async def test_named(
    model: CNColorMLP, interventions: list[InterventionConfig], test_data: pd.DataFrame
) -> pd.DataFrame:
    """
    Reconstruct a table of named colors and attach per-color results.

    Args:
        model: Trained model to evaluate.
        interventions: Interventions to apply at inference time (may be empty).
        test_data: Table with an ``rgb`` column holding one RGB triple per row.

    Returns:
        A copy of ``test_data`` with two added columns: ``recon`` (the reconstructed RGB,
        as a tuple per row) and ``MSE`` (per-color mean squared error).
    """
    # Convert the column to a plain list first. Passing the Series straight to
    # torch.tensor makes torch read it as a Python sequence via series[i]. That is
    # *label*-based indexing, so it only works while the index is a 0..n-1 RangeIndex.
    rgb = np.asarray(test_data['rgb'].tolist(), dtype=np.float32)
    x = torch.tensor(rgb, dtype=torch.float32)
    y, _ = await infer_with_latent_capture(model, x, interventions, 'bottleneck')
    per_color_loss = F.mse_loss(y, x, reduction='none').mean(dim=-1)
    y_tuples = [tuple(row) for row in y.numpy()]
    # Lists and arrays are assigned positionally, so they line up with test_data's rows.
    return test_data.assign(recon=y_tuples, MSE=per_color_loss.numpy())  # pyright: ignore[reportArgumentType]

In [10]:
# # Generate a list of dimensions to visualize
# from itertools import combinations
# [
#     (
#         b,
#         a,
#         (a + 1) % K if (a + 1) % K not in (a, b) else (a + 2) % K,
#     )
#     for a, b in combinations(tuple(range(K)), 2)
# ]

In [11]:
from typing import Sequence

from ex_color.vis import (
    plot_colors,
    plot_cube_series,
    plot_latent_grid_3d_from_cube,
)
from utils.nb import displayer_mpl


def tags_for_file(tags: Sequence[str]) -> str:
    """
    Turn a sequence of human-readable tags into a filename-safe slug.

    Each tag is lower-cased, and every run of characters outside ``[a-zA-Z0-9]`` is
    collapsed to a single ``-``. Leading and trailing dashes are trimmed. Tags that
    slugify to nothing (e.g. ``''`` or ``'  '``) are skipped, so the result never has
    doubled or dangling separators. The remaining slugs are joined with ``-``.

    Example: ``['no intervention', 'Run #2']`` -> ``'no-intervention-run-2'``.
    """
    import re

    slugs = (re.sub(r'[^a-zA-Z0-9]+', '-', tag.lower()).strip('-') for tag in tags)
    return '-'.join(slug for slug in slugs if slug)


def visualize_reconstructed_cube(data: ColorCube, *, tags: Sequence[str] = ()):
    """
    Show (and save) a cube's reconstructed colors alongside the true colors.

    The ``recon`` variable is drawn as the patch color and ``color`` as the comparison
    color. Per the alt text, these are the outer and inner portions of each patch.
    The tags are used in the title and the asset filename.
    """
    asset_path = f'large-assets/ex-{nbid}-pred-colors-{tags_for_file(tags)}.png'

    def draw():
        heading = f'Predicted colors · {" · ".join(tags)}'
        return plot_colors(data, title=heading, colors='recon', colors_compare='color')

    with displayer_mpl(
        asset_path,
        alt_text="""Plot showing four slices of the HSV cube, titled "{title}". Nominally, each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. Each color value is represented as a square patch of that color. The outer portion of the patches shows the color as reconstructed by the model; the inner portion shows the true (input) color.""",
    ) as show:
        show(draw)


def visualize_reconstruction_loss(data: ColorCube, *, tags: Sequence[str] = ()):
    """
    Show (and save) per-color reconstruction error as line-chart series, then print summary stats.

    Args:
        data: Result cube with an 'MSE' variable (e.g. as produced by `test`).
        tags: Labels used in the chart title and the asset filename.
    """
    max_loss = np.max(data['MSE'])
    median_loss = np.median(data['MSE'])
    with displayer_mpl(
        f'large-assets/ex-{nbid}-loss-colors-{tags_for_file(tags)}.png',
        alt_text=f"""Line chart showing loss per color, titled "{{title}}". Y-axis: mean square error, ranging from zero to {max_loss:.2g}. X-axis: hue.""",
    ) as show:
        show(
            lambda: plot_cube_series(
                # Three series; each keeps only the last index of the permuted middle axis.
                # NOTE(review): `data.shape[2] // -5` floors toward -inf, giving a *negative*
                # step (sampled backwards from the last index). The next two use
                # `-(n // -3)`, i.e. ceil(n / 3) as a positive step. Confirm the negative
                # step is intended. Also, `data.shape` is the *un-permuted* cube's shape,
                # which matches the sliced axis only if the axis orders line up.
                data.permute('hsv')[:, -1:, :: (data.shape[2] // -5)],
                data.permute('svh')[:, -1:, :: -(data.shape[0] // -3)],
                data.permute('vsh')[:, -1:, :: -(data.shape[0] // -3)],
                title=f'Reconstruction error · {" · ".join(tags)}',
                var='MSE',
                figsize=(12, 3),
            )
        )
    print(f'Max loss: {max_loss:.2g}')
    print(f'Median MSE: {median_loss:.2g}')


def visualize_latent_space(data: ColorCube, *, tags: Sequence[str] = (), dims: Sequence[tuple[int, int, int]]):
    """
    Show (and save) 3D projections of the latent space.

    Args:
        data: Result cube with 'recon', 'color' and 'latents' variables (as produced by `test`).
        tags: Labels used in the title and the asset filename.
        dims: Latent-dimension triples, passed to ``plot_latent_grid_3d_from_cube``
            (presumably one triple per panel).
    """
    with displayer_mpl(
        f'large-assets/ex-{nbid}-latents-{tags_for_file(tags)}.png',
        alt_text="""Two rows of three spherical plots, titled "{title}". Each plot shows a vibrant collection of colored circles or balls scattered over the surface of a hypersphere, with each plot showing one 2D projection.""",
    ) as show:
        show(
            lambda theme: plot_latent_grid_3d_from_cube(
                data,
                colors='recon',
                colors_compare='color',
                latents='latents',
                # Single separator, matching the other figure titles in this notebook.
                title=f'Latents · {" · ".join(tags)}',
                dims=dims,
                dot_radius=10,
                theme=theme,
            )
        )

In [12]:
import importlib
import ex_color.vis

# Development aid: reload ex_color.vis to pick up source edits without restarting the kernel.
# Names already bound via `from ex_color.vis import ...` in earlier cells still refer to the old objects.
importlib.reload(ex_color.vis)

<module 'ex_color.vis' from '/workspaces/ex-color-transformer/src/ex_color/vis/__init__.py'>

In [33]:
from dataclasses import dataclass
from typing import Callable

import skimage as ski

from ex_color.data import hsv_similarity
from ex_color.vis import draw_stacked_results, ConicalAnnotation, draw_cube_scatter
from ex_color.vis.plot_latent_slices import LatentD
from utils.nb import displayer_mpl


@dataclass
class Resultset:
    """Evaluation outputs for one model/intervention combination, bundled for figures and tables."""

    # Labels for the run; used in figure titles, asset filenames and (space-joined) table column names.
    tags: Sequence[str]
    # Cube shown in the latent-space panels (this notebook evaluates `rgb_cube` for it).
    latent_cube: ColorCube
    # Cube shown as the color-slice panel (this notebook evaluates `hsv_cube` for it).
    color_slice_cube: ColorCube
    # Cube whose per-color 'MSE' feeds the loss chart (this notebook evaluates `hd_hsv_cube` for it).
    loss_cube: ColorCube
    # Per-named-color results; `hstack_named_results` reads its 'name', 'rgb', 'hsv' and 'MSE' columns.
    named_colors: pd.DataFrame


def visualize_stacked_results(
    res: Resultset,
    *,
    latent_dims: tuple[LatentD, LatentD],
    max_error: float | None = None,
    latent_annotations: Sequence[ConicalAnnotation | Callable[[Theme], ConicalAnnotation]] = (),
):
    """
    Show (and save) the composite figure for one Resultset.

    Layout: two latent panels on top, a color slice in the middle, and a loss chart at the bottom.

    Args:
        res: Results to draw.
        latent_dims: Dimension selections for the two latent panels.
        max_error: Passed through to ``draw_stacked_results`` for the loss chart.
        latent_annotations: Annotations for the latent panels. Each is either a ready-made
            ConicalAnnotation or a factory called with the active theme.
    """

    def resolve(theme: Theme) -> list[ConicalAnnotation]:
        # Ready-made annotations pass through; factories are called with the theme.
        return [ann if isinstance(ann, ConicalAnnotation) else ann(theme) for ann in latent_annotations]

    def draw(theme: Theme):
        return draw_stacked_results(
            res.latent_cube,
            res.color_slice_cube,
            res.loss_cube,
            latent_dims=latent_dims,
            theme=theme,
            max_error=max_error,
            latent_annotations=resolve(theme),
        )

    asset_path = f'large-assets/ex-{nbid}-results-{tags_for_file(res.tags)}.png'
    description = """Composite figure with two latent panels (top), a color slice (middle), and a loss chart (bottom)."""
    with displayer_mpl(asset_path, alt_text=description) as show:
        show(draw)


def scatter_similarity_vs_error(
    cube: ColorCube,
    anchor_hsv: tuple[float, float, float],
    *,
    anchor_name: str,
    theme: Theme,
):
    """
    Scatter each color's reconstruction error against its squared similarity to an anchor.

    Similarity is ``hsv_similarity`` (cosine mode, ``hemi=True``) between each color's
    HSV value and ``anchor_hsv``, squared.

    Returns:
        The matplotlib Figure.
    """
    with_hsv = cube.assign(hsv=ski.color.rgb2hsv(cube['color']))
    sim_sq = hsv_similarity(with_hsv['hsv'], np.array(anchor_hsv), hemi=True, mode='cosine') ** 2
    scored = with_hsv.assign(similarity=sim_sq)

    fig, ax = plt.subplots(figsize=(4, 4), constrained_layout=True)
    draw_cube_scatter(ax, scored, theme=theme, x_var='similarity', y_var='MSE')
    ax.set_xlabel(rf'$\text{{sim}}_\text{{{anchor_name}}}^2$')
    ax.set_ylabel('MSE')
    ax.legend(loc='upper left')
    return fig


def visualize_error_vs_similarity(
    cube: ColorCube,
    anchor_hsv: tuple[float, float, float],
    *,
    tags: Sequence[str] = (),
    anchor_name: str = 'anchor',
):
    """
    Show (and save) a scatter of reconstruction error against similarity to an anchor color.

    Args:
        cube: Result cube with 'color' and 'MSE' variables.
        anchor_hsv: Anchor color in HSV.
        tags: Labels used in the asset filename.
        anchor_name: Human-readable anchor name, used in the axis label and the alt text.
    """
    with displayer_mpl(
        f'large-assets/ex-{nbid}-error-vs-similarity-{tags_for_file(tags)}.png',
        # Name the actual anchor rather than hard-coding "pure red", so the alt text
        # stays accurate when this is used with other anchors.
        alt_text=f"""Scatter plot showing reconstruction error versus similarity to {anchor_name}. Each point represents a color, with its position on the x-axis indicating how similar it is to {anchor_name}, and its position on the y-axis indicating the reconstruction error (mean squared error) for that color. The points are colored according to their actual color values.""",
    ) as show:
        show(lambda theme: scatter_similarity_vs_error(cube, anchor_hsv, theme=theme, anchor_name=anchor_name))


def hstack_named_results(*res: Resultset) -> pd.DataFrame:
    """
    Create a table of results across several experiments.

    The output table begins with columns taken from the first Resultset:
    - name (of color)
    - rgb
    - hsv
    - {tags}: reconstruction error (MSE) of first Resultset

    Then, for each subsequent Resultset, two columns are added:
    - {tags}: reconstruction error (MSE) of that Resultset
    - {tags}-delta: change in reconstruction error relative to the first Resultset
    """
    names = [' '.join(r.tags) for r in res]
    df = res[0].named_colors[['name', 'rgb', 'hsv', 'MSE']].rename(columns={'MSE': names[0]})
    for name, r in zip(names[1:], res[1:], strict=True):
        df = df.merge(
            r.named_colors[['name', 'MSE']].rename(columns={'MSE': name}),
            on='name',
        )
        df[f'{name}-delta'] = df[name] - df[names[0]]
    return df

## Train

We train a single model with all three regularizers described above: separate, anchor, and subspace.

In [14]:
# Reload dopesheet: makes tweaking params during development easier
dopesheet = Dopesheet.from_csv(f'./ex-{nbid}-dopesheet.csv')

# Train one model with all three regularizers: separate, anchor and subspace.
model = await train(
    dopesheet,
    [reg_separate, reg_anchor, reg_subspace],
    K,
    N,
    seed=0,  # Arbitrary but not cherry-picked
)

I 20.3 no.2.4.1:Training with: ['separate', 'anchor', 'subspace']


INFO: Seed set to 0


I 20.4 li.fa.ut.se:Seed set to 0
I 20.4 ex.se:  PyTorch set to deterministic mode


INFO: GPU available: False, used: False


I 20.7 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 20.7 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 20.7 li.py.ut.ra:HPU available: False, using: 0 HPUs
max_steps: 1501, train_loader length: 16


[34m[1mwandb[0m: Currently logged in as: [33mz0r[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Starting phase: Train


INFO: `Trainer.fit` stopped: `max_steps=1501` reached.


I 38.4 li.py.ut.ra:`Trainer.fit` stopped: `max_steps=1501` reached.
I 38.5 ex.ca.la:Label frequencies (n=89064): _any: 0.192%, red: 0.082%, vibrant: 0.111%


0,1
epoch,▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇█████
labels/_any,▁
labels/epoch/_any,█▅▁▁▂▄▄▃▃▃▃▃▃▃▄▃▅▅▄▄▄▄▅▄▅▄▄▄▄▄▅▅▄▅▅▅▅▅▅▅
labels/epoch/red,█▃▃▁▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▁▂▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂
labels/epoch/vibrant,▁▄▃▆▄▅▄▄▄▄▅▆▆▆▆▇▇▇▇▇▆▇▇▇▇▇▇▇▇▇▇█████████
labels/red,▁
labels/vibrant,▁
train_anchor,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▄▂▃▂▂▂▂▃▃▃▂▂▃▂▂▂▁▂▂▂▂▂▂▁▁▂▁▂▂▂▁▁▁▁▁▁▁▁▁
train_recon,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,92.0
labels/_any,0.00192
labels/epoch/_any,0.00192
labels/epoch/red,0.00082
labels/epoch/vibrant,0.00111
labels/red,0.00082
labels/vibrant,0.00111
train_anchor,0.0
train_loss,6e-05
train_recon,6e-05


In [15]:
from IPython.display import clear_output

from ex_color.data.color import get_named_colors_df

# Named colors for the per-color tables: 12 hues plus 5 grays.
named_colors = get_named_colors_df(n_hues=12, n_grays=5)

# Baseline: evaluate the trained model with no interventions on every test set.
no_intervention_results = Resultset(
    tags=['no intervention'],
    latent_cube=await test(model, [], rgb_cube),
    color_slice_cube=await test(model, [], hsv_cube),
    loss_cube=await test(model, [], hd_hsv_cube),
    named_colors=await test_named(model, [], named_colors),
)
# Clear the log output printed during inference before showing the figures.
clear_output()

visualize_reconstructed_cube(no_intervention_results.color_slice_cube.permute('svh'), tags=no_intervention_results.tags)
# visualize_reconstruction_loss(no_intervention_results.loss_cube, tags=no_intervention_results.tags)
# visualize_latent_space(
#     no_intervention_results.latent_cube,
#     tags=no_intervention_results.tags,
#     dims=[(1, 0, 2), (2, 0, 1), (3, 0, 1), (2, 1, 3), (3, 1, 2), (3, 2, 0)],
# )

Reconstruction loss looks about as good as last time. There's higher loss around red even without intervention, but it's still low.

Latent space looks surprisingly good: the area opposing red is clear, and there's a pronounced collection of reds near the anchor point.

## Suppression

In [16]:
from math import cos, radians
from IPython.display import clear_output

from ex_color.intervention import Suppression, BoundedFalloff


# Falloff for suppression. The first argument is cos(max_angle); cos(90°) = 0, so the
# falloff presumably ends 90° away from the anchor.
# NOTE(review): the meaning of the third argument (0) isn't visible here — consider documenting it.
falloff = BoundedFalloff(
    cos(radians(90)),  # cos(max_angle)
    1,  # completely squash fully-aligned vectors
    # 2,  # soft rim, sharp hub
    0,
)
# Suppress the RED direction at the bottleneck layer.
suppression = InterventionConfig(
    apply=Suppression(torch.tensor(RED), falloff),
    layer_affinities=['bottleneck'],
)
# Evaluate the same trained model, with suppression, on the same test sets as the baseline.
suppression_results = Resultset(
    tags=['suppression'],
    latent_cube=await test(model, [suppression], rgb_cube),
    color_slice_cube=await test(model, [suppression], hsv_cube),
    loss_cube=await test(model, [suppression], hd_hsv_cube),
    named_colors=await test_named(model, [suppression], named_colors),
)
clear_output()

visualize_reconstructed_cube(suppression_results.color_slice_cube.permute('svh'), tags=suppression_results.tags)
# visualize_reconstruction_loss(suppression_results.loss_cube, tags=suppression_results.tags)
# visualize_latent_space(
#     suppression_results.latent_cube,
#     tags=suppression_results.tags,
#     dims=[(1, 0, 2), (2, 0, 1), (3, 0, 1), (2, 1, 3), (3, 1, 2), (3, 2, 0)],
# )

## Repulsion

In [17]:
from math import cos, pi, radians
from IPython.display import clear_output

from ex_color.intervention import Repulsion, FastBezierMapper, LinearMapper

mapper = FastBezierMapper(
    0,  # cos(max_angle)
    cos(pi / 3),  # cos(min_angle)
)
mapper = LinearMapper(
    cos(radians(90)),
    cos(radians(70)),
)
repulsion = InterventionConfig(
    Repulsion(torch.tensor(RED), mapper),
    layer_affinities=['bottleneck'],
)
repulsion_results = Resultset(
    tags=['repulsion'],
    latent_cube=await test(model, [repulsion], rgb_cube),
    color_slice_cube=await test(model, [repulsion], hsv_cube),
    loss_cube=await test(model, [repulsion], hd_hsv_cube),
    named_colors=await test_named(model, [repulsion], named_colors),
)
clear_output()

visualize_reconstructed_cube(repulsion_results.color_slice_cube.permute('svh'), tags=repulsion_results.tags)
# visualize_reconstruction_loss(repulsion_results.loss_cube, tags=repulsion_results.tags)
# visualize_latent_space(
#     repulsion_results.latent_cube,
#     tags=repulsion_results.tags,
#     dims=[(1, 0, 2), (2, 0, 1), (3, 0, 1), (2, 1, 3), (3, 1, 2), (3, 2, 0)],
# )

## Ablation

In [18]:
from ex_color.surgery import ablate

# Ablate bottleneck dimension 0, the axis that RED = (1, 0, 0, 0) is anchored to.
# NOTE(review): presumably `ablate` returns a modified copy. If it mutates `model` in place,
# re-running earlier cells after this one would give different results — confirm.
ablated_model = ablate(model, 'bottleneck', [0])

ablation_results = Resultset(
    tags=['ablation'],
    latent_cube=await test(ablated_model, [], rgb_cube),
    color_slice_cube=await test(ablated_model, [], hsv_cube),
    loss_cube=await test(ablated_model, [], hd_hsv_cube),
    named_colors=await test_named(ablated_model, [], named_colors),
)
clear_output()

visualize_reconstructed_cube(ablation_results.color_slice_cube.permute('svh'), tags=ablation_results.tags)
# visualize_reconstruction_loss(ablation_results.loss_cube, tags=ablation_results.tags)
# visualize_latent_space(
#     ablation_results.latent_cube,
#     tags=ablation_results.tags,
#     dims=[(1, 0, 2), (2, 0, 1), (3, 0, 1), (2, 1, 3), (3, 1, 2), (3, 2, 0)],
# )

## Stacked visualizations for paper

In [19]:
from IPython.display import display

# Peak per-color reconstruction error on the high-res loss cube, for each method.
print('Max error:')
display(
    {
        'No intervention': float(no_intervention_results.loss_cube['MSE'].max()),
        'Suppression': float(suppression_results.loss_cube['MSE'].max()),
        'Repulsion': float(repulsion_results.loss_cube['MSE'].max()),
        'Ablation': float(ablation_results.loss_cube['MSE'].max()),
    }
)

Max error:


{'No intervention': 0.0004948975401930511,
 'Suppression': 0.1882771998643875,
 'Repulsion': 0.18570585548877716,
 'Ablation': 0.33162155747413635}

In [20]:
def themed_annotation(
    theme: Theme, direction: Sequence[float], angle: float, dashed: bool = False
) -> ConicalAnnotation:
    """
    Build a ConicalAnnotation whose line styling follows the given theme.

    Lines are solid by default. With ``dashed=True``, a themed dash pattern and gap
    color are added.
    """
    style = {
        'color': theme.val('black', dark='#fff'),
        'linewidth': theme.val(0.75, dark=1),
    }
    if dashed:
        style['dashes'] = theme.val((8, 8), dark=(4, 4))
        style['gapcolor'] = theme.val('#ddda', dark='#222a')
    return ConicalAnnotation(direction=direction, angle=angle, **style)


# Shared error limit for the loss charts, so the panels are directly comparable.
# NOTE(review): the ablation figure below also uses this limit, but ablation's errors are
# not included in the max (ablation peaks at ~0.33 vs ~0.19 for the others). Confirm
# that is intended.
max_error = np.max(
    [
        no_intervention_results.loss_cube['MSE'],
        suppression_results.loss_cube['MSE'],
        repulsion_results.loss_cube['MSE'],
    ]
)

print('Baseline')
visualize_stacked_results(
    no_intervention_results,
    latent_dims=((1, 0, 2), (1, 2, 0)),
    max_error=max_error,
)

# Cone annotations around RED. The falloff/mapper bounds are *cosines* of angles from RED,
# so convert them back to angles with arccos. (`pi / 2 - c` only approximates arccos(c)
# to first order; it drifts as c grows, e.g. ~0.4° for cos(70°).) The factor of 2 turns
# the half-angle from RED into the full cone angle, assuming ConicalAnnotation.angle is
# the apex angle.
print('Suppression')
visualize_stacked_results(
    suppression_results,
    latent_dims=((1, 0, 2), (1, 2, 0)),
    max_error=max_error,
    latent_annotations=[
        lambda theme: themed_annotation(theme, direction=RED, angle=2 * np.arccos(falloff.a), dashed=True),
    ],
)

print('Repulsion')
visualize_stacked_results(
    repulsion_results,
    latent_dims=((1, 0, 2), (1, 2, 0)),
    max_error=max_error,
    latent_annotations=[
        lambda theme: themed_annotation(theme, direction=RED, angle=2 * np.arccos(mapper.a), dashed=True),
        lambda theme: themed_annotation(theme, direction=RED, angle=2 * np.arccos(mapper.b), dashed=False),
    ],
)

print('Ablation')
visualize_stacked_results(
    ablation_results,
    latent_dims=((1, 0, 2), (1, 2, 0)),
    max_error=max_error,
)

Baseline


Suppression


Repulsion


Ablation


## Tabular results: error vs color per intervention

In [21]:
from IPython.display import display
from ex_color.vis import ColorTableHtmlFormatter

# One row per named color: baseline MSE, plus each method's MSE and its change vs. the baseline.
df = hstack_named_results(no_intervention_results, suppression_results, repulsion_results, ablation_results)

display(ColorTableHtmlFormatter().style(df))

Name,RGB,No Intervention,Suppression,Δ Sup,Repulsion,Δ Rep,Ablation,Δ Abl
red,,0.001,0.206,0.205,0.198,0.197,0.333,0.333
orange,,0.0,0.121,0.121,0.071,0.071,0.144,0.144
yellow,,0.0,0.047,0.046,0.018,0.017,0.04,0.039
lime,,0.0,0.0,-0.0,0.0,-0.0,0.0,-0.0
green,,0.001,0.001,0.0,0.001,0.0,0.046,0.046
teal,,0.0,0.0,0.0,0.0,0.0,0.148,0.148
cyan,,0.001,0.001,0.0,0.001,0.0,0.345,0.344
azure,,0.0,0.0,0.0,0.0,0.0,0.171,0.171
blue,,0.001,0.001,0.0,0.001,0.0,0.049,0.049
purple,,0.0,0.0,0.0,0.0,0.0,0.001,0.001


In [22]:
from IPython.display import display

import pandas as pd

from ex_color.vis import ColorTableLatexFormatter

formatter = ColorTableLatexFormatter()
# print(formatter.preamble)
# Baseline column is absolute MSE; the intervention columns are *changes* in MSE vs. baseline.
# NOTE(review): the caption doesn't say that the intervention columns are deltas.
latex = formatter.to_str(
    pd.DataFrame(
        {
            'color': df['name'].str.capitalize(),
            'rgb': df['rgb'],
            'baseline': df['no intervention'],
            'suppression': df['suppression-delta'],
            'repulsion': df['repulsion-delta'],
            'ablation': df['ablation-delta'],
        }
    ),
    caption='Reconstruction error by color and intervention method',
    label='tab:error-per-color-soft',
)
# Show the LaTeX as a highlighted markdown code block, with plain text as a fallback.
display({'text/markdown': f'```latex\n{latex}\n```', 'text/plain': latex}, raw=True)

```latex
\begin{table}
\centering
\label{tab:error-per-color-soft}
\caption{Reconstruction error by color and intervention method}
\sisetup{
    round-mode = places,
    round-precision = 3,
    table-auto-round = true,
    % drop-zero-decimal = true,
}
\begin{tabular}{l c g g g g}
\toprule
\multicolumn{2}{c}{{Color}} & \multicolumn{1}{c}{{Baseline}} & \multicolumn{1}{c}{{Suppression}} & \multicolumn{1}{c}{{Repulsion}} & \multicolumn{1}{c}{{Ablation}} \\
\midrule
Red        & \swatch{FF0000} &  0.000768037 &  0.205282241 &  0.196813837 &  0.332565308 \\
Orange     & \swatch{FF7F00} &  0.000007175 &  0.121251166 &  0.071423836 &  0.144406125 \\
Yellow     & \swatch{FFFF00} &  0.000465135 &  0.046414554 &  0.017316537 &  0.039200030 \\
Lime       & \swatch{7FFF00} &  0.000046972 & -0.000013951 & -0.000010727 & -0.000014021 \\
Green      & \swatch{00FF00} &  0.000651126 &  0.000000000 &  0.000000000 &  0.045809921 \\
Teal       & \swatch{00FF7F} &  0.000004109 &  0.000000000 &  0.000000000 &  0.147797063 \\
Cyan       & \swatch{00FFFF} &  0.001199420 &  0.000000000 &  0.000000000 &  0.343560725 \\
Azure      & \swatch{007FFF} &  0.000035421 &  0.000000000 &  0.000000000 &  0.171221405 \\
Blue       & \swatch{0000FF} &  0.000555592 &  0.000000000 &  0.000000000 &  0.048887715 \\
Purple     & \swatch{7F00FF} &  0.000046439 &  0.000000000 &  0.000000000 &  0.000859078 \\
Magenta    & \swatch{FF00FF} &  0.000616221 &  0.030436348 &  0.011453097 &  0.025803939 \\
Pink       & \swatch{FF007F} &  0.000131380 &  0.113162704 &  0.074021526 &  0.142593428 \\
Black      & \swatch{000000} &  0.000072439 &  0.000000000 &  0.000000000 &  0.001567130 \\
Dark gray  & \swatch{3F3F3F} &  0.000058124 &  0.000000000 &  0.000000000 &  0.000449589 \\
Gray       & \swatch{7F7F7F} &  0.000161277 &  0.000000000 &  0.000000000 &  0.000215866 \\
Light gray & \swatch{BFBFBF} &  0.000028848 &  0.000000000 &  0.000000000 &  0.000028761 \\
White      & \swatch{FFFFFF} &  0.000146175 &  0.000000000 &  0.000000000 &  0.000147133 \\
\bottomrule
\end{tabular}
\end{table}
```

## Correlation: error vs. similarity to anchor per intervention

In [34]:
import skimage as ski
from scipy.stats import pearsonr

from ex_color.data import hsv_similarity


def error_correlation(cube: ColorCube, anchor_hsv: tuple[float, float, float]) -> tuple[float, float]:
    """
    Pearson correlation between squared similarity-to-anchor and reconstruction error.

    Similarity is ``hsv_similarity`` (cosine mode, ``hemi=True``) between each color's
    HSV value and ``anchor_hsv``. It is squared to match the squared-error metric.

    Args:
        cube: Result cube with 'color' and 'MSE' variables.
        anchor_hsv: Anchor color in HSV.

    Returns:
        ``(r, p_value)`` from ``scipy.stats.pearsonr``, as Python floats.
    """
    with_hsv = cube.assign(hsv=ski.color.rgb2hsv(cube['color']))
    sim_sq = hsv_similarity(with_hsv['hsv'], np.array(anchor_hsv), hemi=True, mode='cosine') ** 2
    scored = with_hsv.assign(red_similarity=sim_sq)
    r, p = pearsonr(scored['red_similarity'].flatten(), scored['MSE'].flatten())
    return float(r), float(p)


# Correlation of error with squared similarity to red. (0, 1, 1) is pure red in HSV.
# p presumably prints as 0 because it underflows double precision at this sample size.
corr, p_value = error_correlation(suppression_results.loss_cube, (0, 1, 1))
print(f'Suppression: r = {corr:.2f}, R²: {corr**2:.2f}, p = {p_value:.3g}')
corr, p_value = error_correlation(repulsion_results.loss_cube, (0, 1, 1))
print(f'Repulsion: r = {corr:.2f}, R²: {corr**2:.2f}, p = {p_value:.3g}')
corr, p_value = error_correlation(ablation_results.loss_cube, (0, 1, 1))
print(f'Ablation: r = {corr:.2f}, R²: {corr**2:.2f}, p = {p_value:.3g}')

Suppression: r = 0.99, R²: 0.97, p = 0
Repulsion: r = 0.96, R²: 0.92, p = 0
Ablation: r = 0.67, R²: 0.45, p = 0


In [35]:
# Error vs. squared similarity to red for each intervention. (0, 1, 1) is pure red in HSV.
print('Suppression')
visualize_error_vs_similarity(
    suppression_results.loss_cube.permute('vsh'),
    (0, 1, 1),
    tags=suppression_results.tags,
    anchor_name='red',
)

print('Repulsion')
visualize_error_vs_similarity(
    repulsion_results.loss_cube.permute('vsh'),
    (0, 1, 1),
    tags=repulsion_results.tags,
    anchor_name='red',
)

print('Ablation')
visualize_error_vs_similarity(
    ablation_results.loss_cube.permute('vsh'),
    (0, 1, 1),
    tags=ablation_results.tags,
    anchor_name='red',
)

Suppression


Repulsion


Ablation
