# Experiment 2.10: Delete only red with fewer regularization terms

In [Ex 2.9](ex-2.9-delete-only-red-5d.ipynb), we succeeded in deleting _red_ without deleting _cyan_ or other colors, with precision similar to an intervention with a cosine falloff. In that experiment, we used a subspace regularizer to attract desaturated colors to the last three dimensions. Let's see if we can get similar results without that regularizer. We'll also try removing the unitarity regularizer to verify that it's still needed.

## Hypothesis 1: subspace regularization is not needed

If we remove the subspace regularizer but keep the unitarity, anchor, and anti-anchor terms, the model will warp latent space enough to isolate _red_, and we will be able to delete _red_ without also deleting other colors. We should see error vs. color curves similar to those achieved in Ex 2.9.

## Hypothesis 2: unitarity is still needed

If we remove unitarity in addition to the subspace term, the latent space won't be regular enough for the model to isolate _red_, so deleting _red_ will also affect other colors.

In both cases, the bottleneck will be explicitly normalized.
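
Explicit normalization means that each bottleneck vector is rescaled to unit length, so only its direction can carry information. Here is a minimal NumPy sketch of that step, assuming plain L2 normalization. The helper name `normalize_rows` is hypothetical; the real model presumably does this internally, and in PyTorch the equivalent call would be `torch.nn.functional.normalize`.

```python
import numpy as np


def normalize_rows(z: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row of z to unit L2 norm (hypothetical helper)."""
    norms = np.linalg.norm(z, axis=-1, keepdims=True)
    # Guard against division by zero for all-zero rows
    return z / np.maximum(norms, eps)


# Two example 5D latents with different lengths
z = np.array([[3.0, 0.0, 4.0, 0.0, 0.0], [0.5, 0.5, 0.5, 0.5, 0.0]])
z_hat = normalize_rows(z)
print(np.linalg.norm(z_hat, axis=-1))
```

After this step, a unitarity regularizer has nothing left to do at the bottleneck itself, which is why it is applied to an earlier layer in the configuration used here.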

In [1]:
from __future__ import annotations

nbid = '2.10'  # ID for tagging assets
nbname = 'Ablate red (only), 5D, fewer regularizers'
experiment_name = f'Ex {nbid}: {nbname}'
project = 'ex-color-transformer'

In [2]:
# Basic setup: Logging, Experiment (Modal)
import logging

import modal

from infra.requirements import freeze, project_packages
from mini.experiment import Experiment
from utils.logging import SimpleLoggingConfig

logging_config = (
    SimpleLoggingConfig()
    .info('notebook', 'utils', 'mini', 'ex_color')
    .error('matplotlib.axes')  # Silence warnings about set_aspect
)
logging_config.apply()

# This is the logger for this notebook
log = logging.getLogger(f'notebook.{nbid}')

run = Experiment(experiment_name, project=project)
run.image = modal.Image.debian_slim().pip_install(*freeze(all=True)).add_local_python_source(*project_packages())
run.before_each(logging_config.apply)
None  # prevent auto-display of this cell

## Regularizers

Like Ex 2.9:

- **Anchor:** pins `red` to $(1,0,0,0,0)$ (5D)
- **Anti-anchor:** repels everything from $(-1,0,0,0,0)$ (5D)
- **Separate:** angular repulsion to reduce global clumping (applied within each batch)

But unlike 2.9:
- **AxisAlignedSubspace:** has been removed, relying on anti-anchor to keep other concepts clear of the dimension to be ablated.
- **Unitarity:** is present in this list, but we'll do a run without it too.
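
To make the roles of these terms concrete, here is a rough NumPy sketch of what such losses could look like. These are guesses for illustration only: the function names and exact formulas are assumptions, and the real implementations in `ex_color.loss` (including `Separate`, which is omitted here) may differ.

```python
import numpy as np


def cosine_to(z: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Cosine similarity between each row of z and a single target vector."""
    z_hat = z / np.maximum(np.linalg.norm(z, axis=-1, keepdims=True), 1e-12)
    t_hat = target / np.linalg.norm(target)
    return z_hat @ t_hat


def unitarity_loss(z: np.ndarray) -> float:
    """Assumed form: penalize vector lengths that differ from 1."""
    return float(np.mean((np.linalg.norm(z, axis=-1) - 1.0) ** 2))


def anchor_loss(z: np.ndarray, anchor: np.ndarray) -> float:
    """Assumed form: zero when aligned with the anchor, growing with angle."""
    return float(np.mean(1.0 - cosine_to(z, anchor)))


def anti_anchor_loss(z: np.ndarray, anti_anchor: np.ndarray) -> float:
    """Assumed form: only penalize vectors in the anti-anchor's hemisphere."""
    return float(np.mean(np.clip(cosine_to(z, anti_anchor), 0.0, None) ** 2))


RED = np.array([1.0, 0.0, 0.0, 0.0, 0.0])
ANTI_RED = -RED
z = np.array([[2.0, 0, 0, 0, 0], [0, 1.0, 0, 0, 0], [-1.0, 0, 0, 0, 0]])
print(unitarity_loss(z), anchor_loss(z[:1], RED), anti_anchor_loss(z, ANTI_RED))
```

In this sketch the anchor only attracts samples labelled _red_, while the anti-anchor pushes every sample out of the opposite hemisphere, leaving the rest of the sphere unconstrained.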

In [3]:
import torch

from mini.temporal.dopesheet import Dopesheet
from ex_color.loss import AngularAnchor, AntiAnchor, Separate, Unitarity, RegularizerConfig

from ex_color.training import TrainingModule

K = 5  # bottleneck dimensionality
RED = (1, 0, 0, 0, 0)
ANTI_RED = tuple(-c for c in RED)
assert len(RED) == len(ANTI_RED) == K

ALL_REGULARIZERS = [
    RegularizerConfig(
        name='reg-unit',
        compute_loss_term=Unitarity(),
        label_affinities=None,
        layer_affinities=['encoder'],
    ),
    RegularizerConfig(
        name='reg-anchor',
        compute_loss_term=AngularAnchor(torch.tensor(RED, dtype=torch.float32)),
        label_affinities={'red': 1.0},
        layer_affinities=['bottleneck'],
    ),
    RegularizerConfig(
        name='reg-separate',
        compute_loss_term=Separate(power=100.0, shift=True),
        label_affinities=None,
        layer_affinities=['bottleneck'],
    ),
    RegularizerConfig(
        name='reg-anti-anchor',
        compute_loss_term=AntiAnchor(torch.tensor(ANTI_RED, dtype=torch.float32)),
        label_affinities=None,
        layer_affinities=['bottleneck'],
    ),
]

## Data

Data is the same as last time: color cubes with values in RGB.
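
For readers without access to `ColorCube`, the training grid can be approximated with the standard library's `colorsys` module. This is a sketch under assumptions: `hsv_grid_to_rgb` is a hypothetical helper, and the hue axis is assumed to have 36 steps of 10° with the endpoint excluded, which is consistent with the 57 batches reported in the training log.

```python
import colorsys

import numpy as np


def hsv_grid_to_rgb(h, s, v) -> np.ndarray:
    """Return RGB values on an HSV grid, shape (len(h), len(s), len(v), 3)."""
    grid = np.empty((len(h), len(s), len(v), 3))
    for i, hi in enumerate(h):
        for j, sj in enumerate(s):
            for k, vk in enumerate(v):
                grid[i, j, k] = colorsys.hsv_to_rgb(hi, sj, vk)
    return grid


h = np.linspace(0, 1, 36, endpoint=False)  # cyclic axis: 0 and 1 are the same hue
s = np.linspace(0, 1, 10)
v = np.linspace(0, 1, 10)
rgb = hsv_grid_to_rgb(h, s, v)
flat = rgb.reshape(-1, 3)  # one row per color, as fed to the model
n_batches = -(-len(flat) // 64)  # ceiling division for batch_size=64
print(flat.shape, n_batches)
```

Because many grid points with low saturation or value map to nearly identical RGB colors, a uniform sample of this grid over-represents grays and darks, which is what the weighted sampler in the next cell compensates for.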


In [4]:
from functools import partial
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
import numpy as np

from ex_color.data.color_cube import ColorCube
from ex_color.data.cube_sampler import vibrancy
from ex_color.data.cyclic import arange_cyclic
from ex_color.labelling import collate_with_generated_labels


def prep_data() -> tuple[DataLoader, Tensor]:
    """
    Prepare data for training.

    Returns: (train, val)
    """
    hsv_cube = ColorCube.from_hsv(
        h=arange_cyclic(step_size=10 / 360),
        s=np.linspace(0, 1, 10),
        v=np.linspace(0, 1, 10),
    )
    hsv_tensor = torch.tensor(hsv_cube.rgb_grid.reshape(-1, 3), dtype=torch.float32)
    vibrancy_tensor = torch.tensor(vibrancy(hsv_cube).flatten(), dtype=torch.float32)
    hsv_dataset = TensorDataset(hsv_tensor, vibrancy_tensor)

    labeller = partial(
        collate_with_generated_labels,
        soft=False,  # Use binary labels (stochastic) to simulate the labelling of internet text
        red=0.5,
    )
    # Desaturated and dark colors are over-represented in the cube, so we use a weighted sampler to balance them out
    hsv_loader = DataLoader(
        hsv_dataset,
        batch_size=64,
        num_workers=2,
        sampler=WeightedRandomSampler(
            weights=hsv_cube.bias.flatten().tolist(),
            num_samples=len(hsv_dataset),
            replacement=True,
        ),
        collate_fn=labeller,
    )

    rgb_cube = ColorCube.from_rgb(
        r=np.linspace(0, 1, 8),
        g=np.linspace(0, 1, 8),
        b=np.linspace(0, 1, 8),
    )
    rgb_tensor = torch.tensor(rgb_cube.rgb_grid.reshape(-1, 3), dtype=torch.float32)
    return hsv_loader, rgb_tensor

## Training

Like in Ex 2.2, the model is trained with PyTorch Lightning, with regularizers applied as custom hooks.

Unlike earlier experiments, the model now has two nonlinear activation functions in the encoder and decoder, to allow the latent space to be warped more.

In [5]:
import wandb
from ex_color.model import CNColorMLP


# @run.thither(env={'WANDB_API_KEY': wandb.Api().api_key})
async def train(
    dopesheet: Dopesheet,
    regularizers: list[RegularizerConfig],
    k_bottleneck: int,
) -> CNColorMLP:
    """Train the model with the given dopesheet and regularizers."""
    import lightning as L
    from lightning.pytorch.loggers import WandbLogger

    from ex_color.seed import set_deterministic_mode

    from utils.progress.lightning import LightningProgress

    log.info(f'Training with: {[r.name for r in regularizers]}')

    seed = 0
    set_deterministic_mode(seed)

    hsv_loader, _ = prep_data()

    model = CNColorMLP(k_bottleneck, n_nonlinear=2)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log.debug(f'Model initialized with {total_params:,} trainable parameters.')

    training_module = TrainingModule(model, dopesheet, torch.nn.MSELoss(), regularizers)

    logger = WandbLogger(experiment_name, project='ex-color-transformer')

    trainer = L.Trainer(
        max_steps=len(dopesheet),
        callbacks=[
            LightningProgress(),
        ],
        enable_checkpointing=False,
        enable_model_summary=False,
        # enable_progress_bar=True,
        logger=logger,
    )

    print(f'max_steps: {len(dopesheet)}, hsv_loader length: {len(hsv_loader)}')

    # Train the model
    try:
        trainer.fit(training_module, hsv_loader)
    finally:
        wandb.finish()
    # This is only a small model, so it's OK to return it rather than storing and loading a checkpoint remotely
    return model

## Inference utils

We wrap the model that we trained above in an `InferenceModule`. We won't be using its intervention features.


In [6]:
from ex_color.inference import InferenceModule


async def infer(
    model: CNColorMLP,
    test_data: Tensor,
) -> Tensor:
    """Run inference with the given model."""
    import lightning as L

    inference_module = InferenceModule(model, [])
    trainer = L.Trainer(
        enable_checkpointing=False,
        enable_model_summary=False,
        enable_progress_bar=True,
    )
    reconstructed_colors_batches = trainer.predict(
        inference_module,
        DataLoader(
            TensorDataset(test_data.reshape((-1, 3))),
            batch_size=64,
            collate_fn=lambda batch: torch.stack([row[0] for row in batch], 0),
        ),
    )
    assert reconstructed_colors_batches is not None
    # Flatten the list of batches to a single list of tensors
    reconstructed_colors = [item for batch in reconstructed_colors_batches for item in batch]
    # Reshape to match input
    return torch.cat(reconstructed_colors).reshape(test_data.shape)

In [7]:
import torch
import numpy as np

from ex_color.inference import InferenceModule


async def infer_with_latent_capture(
    model: CNColorMLP,
    test_data: Tensor,
    layer_name: str = 'bottleneck',
) -> tuple[Tensor, Tensor]:
    """Run inference and also return the activations captured at `layer_name`."""
    import lightning as L

    module = InferenceModule(model, [], capture_layers=[layer_name])

    trainer = L.Trainer(enable_checkpointing=False, enable_model_summary=False, enable_progress_bar=False)
    batches = trainer.predict(
        module,
        DataLoader(
            TensorDataset(test_data.reshape((-1, 3))),
            batch_size=64,
            collate_fn=lambda batch: torch.stack([row[0] for row in batch], 0),
        ),
    )
    assert batches is not None
    preds = [item for batch in batches for item in batch]
    y = torch.cat(preds).reshape(test_data.shape)
    # Read captured activations as a flat [N, D] tensor
    latents = module.read_captured(layer_name)
    return y, latents

In [8]:
from IPython.display import clear_output

from ex_color.vis import plot_colors
from utils.nb import displayer_mpl


hsv_cube = ColorCube.from_hsv(
    h=arange_cyclic(step_size=1 / 24),
    s=np.linspace(0, 1, 4),
    v=np.linspace(0, 1, 8),
).permute('svh')
x_hsv = torch.tensor(hsv_cube.rgb_grid, dtype=torch.float32)

hd_hsv_cube = ColorCube.from_hsv(
    h=arange_cyclic(step_size=1 / 240),
    s=np.linspace(0, 1, 48),
    v=np.linspace(0, 1, 48),
)
hd_x_hsv = torch.tensor(hd_hsv_cube.rgb_grid, dtype=torch.float32)

rgb_cube = ColorCube.from_rgb(
    r=np.linspace(0, 1, 20),
    g=np.linspace(0, 1, 20),
    b=np.linspace(0, 1, 20),
)
x_rgb = torch.tensor(rgb_cube.rgb_grid, dtype=torch.float32)

with displayer_mpl(
    f'large-assets/ex-{nbid}-true-colors.png',
    alt_text="""Plot showing four slices of the HSV cube, titled "{title}". Each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. The first slice shows a grayscale gradient from black to white; the last shows the fully-saturated color spectrum.""",
) as show:
    show(lambda: plot_colors(hsv_cube, title='True colors', colors=x_hsv.numpy()))

## No subspace

The first model we'll test is one that is trained without the subspace term, but with the unitarity regularization term included.

In [9]:
async with run():
    model_unit = await train(Dopesheet.from_csv(f'./ex-{nbid}-dopesheet.csv'), ALL_REGULARIZERS, K)

I 23.7 no.2.10:Training with: ['reg-unit', 'reg-anchor', 'reg-separate', 'reg-anti-anchor']


INFO: Seed set to 0


I 23.7 li.fa.ut.se:Seed set to 0
I 23.7 ex.se:  PyTorch set to deterministic mode


INFO: GPU available: False, used: False


I 23.8 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 23.8 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 23.8 li.py.ut.ra:HPU available: False, using: 0 HPUs
max_steps: 3001, hsv_loader length: 57


[34m[1mwandb[0m: Currently logged in as: [33mz0r[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Starting phase: Train


INFO: `Trainer.fit` stopped: `max_steps=3001` reached.


I 69.1 li.py.ut.ra:`Trainer.fit` stopped: `max_steps=3001` reached.


Run history:
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇████
train_loss,█▆▅▄▄▄▃▃▃▂▅▃▃▃▂▃▂▂▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_recon,█▄▃▂▁▂▃▂▂▂▃▂▂▃▄▂▂▂▂▂▂▂▃▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁
train_reg-anchor,▁▁▁▇▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▂▁▁▁▁
train_reg-anti-anchor,▁▁▁▁▂▂▃▆█▅▆▂▁▂▁▁▁▂▁▂▁▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁
train_reg-separate,█▅▅▄▅▃▃▂▃▃▂▂▃▂▂▃▂▁▂▂▁▂▁▂▁▁▁▂▂▁▁▁▁▁▁▁▂▂▂▁
train_reg-unit,█▇▅▄▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇████

Run summary:
epoch,52.0
train_loss,2e-05
train_recon,2e-05
train_reg-anchor,0.0
train_reg-anti-anchor,0.00115
train_reg-separate,0.11513
train_reg-unit,0.00318
trainer/global_step,2999.0


In [10]:
from IPython.display import clear_output
from torch.nn import functional as F

from ex_color.vis import plot_colors, plot_cube_series


interventions = []
y_hsv = await infer(model_unit, x_hsv)
hd_y_hsv = await infer(model_unit, hd_x_hsv)
clear_output()

with displayer_mpl(
    f'large-assets/ex-{nbid}-pred-colors-no-intervention-unit.png',
    alt_text="""Plot showing four slices of the HSV cube, titled "{title}". Nominally, each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. Each color value is represented as a square patch of that color. The outer portion of the patches shows the color as reconstructed by the model; the inner portion shows the true (input) color. The reconstructed and true colors agree fairly well, but some slight differences are visible; for example, "white" is slightly gray, and many of the fully-saturated colors are less saturated than they should be.""",
) as show:
    show(
        lambda: plot_colors(
            hsv_cube,
            title='Predicted colors · no intervention · no subspace',
            colors=y_hsv.numpy(),
            colors_compare=x_hsv.numpy(),
        )
    )

per_color_loss = F.mse_loss(hd_y_hsv, hd_x_hsv, reduction='none').mean(dim=-1)
loss_cube = hd_hsv_cube.assign('MSE', per_color_loss.numpy().reshape(hd_hsv_cube.shape))
max_loss = per_color_loss.max().item()
median_loss = per_color_loss.median().item()

with displayer_mpl(
    f'large-assets/ex-{nbid}-loss-colors-no-intervention-unit.png',
    alt_text=f"""Line chart showing loss per color, titled "{{title}}". Y-axis: mean square error, ranging from zero to {max_loss:.2g}. X-axis: hue. The range of loss values is small, but there are notable peaks at red, yellow, green, and cyan. Between those points the lines are wavy, reminiscent of an audio waveform.""",
) as show:
    show(
        lambda: plot_cube_series(
            loss_cube.permute('hsv')[:, -1:, :: (loss_cube.shape[2] // -5)],
            loss_cube.permute('svh')[:, -1:, :: -(loss_cube.shape[0] // -3)],
            loss_cube.permute('vsh')[:, -1:, :: -(loss_cube.shape[0] // -3)],
            title='Reconstruction error · no intervention · no subspace',
            var='MSE',
            figsize=(12, 3),
        )
    )
print(f'Max loss: {max_loss:.2g}')
print(f'Median MSE: {median_loss:.2g}')

Max loss: 0.00023
Median MSE: 1.5e-05


Reconstruction loss looks about as good as last time.

In [11]:
# # Generate a list of dimensions to visualize
# from itertools import combinations
# [
#     (
#         b,
#         a,
#         (a + 1) % 5 if (a + 1) % 5 not in (a, b) else (a + 2) % 5,
#     )
#     for a, b in combinations((0, 1, 2, 3, 4), 2)
# ]

In [12]:
from IPython.display import clear_output

from ex_color.vis import plot_latent_grid_3d

y_rgb, h_rgb = await infer_with_latent_capture(model_unit, x_rgb, 'bottleneck')
clear_output()

with displayer_mpl(
    f'large-assets/ex-{nbid}-latents-no-intervention-unit.png',
    alt_text="""Two rows of three spherical plots, titled "{title}". Each plot shows a vibrant collection of colored circles or balls scattered over the surface of a sphere. On the top row, the first plot shows a thick curve — like a tongue seen from the side — touching the top and right side of the sphere and passing through the middle. It is red at the top, purple and blue in the middle, and cyan on the right. The sphere is empty elsewhere. The other plots in the top row show different views of the same space, all with red at the top but a different horizontal axis. They look much more dome-shaped. The second row shows still more views, focused on the other dimensions. These are more spherical and almost look like color wheels, but with the colors out of order.""",
) as show:
    show(
        lambda theme: plot_latent_grid_3d(
            h_rgb,
            y_rgb,
            x_rgb,
            title='Latents · no intervention · no subspace',
            dims=[
                (1, 0, 2),
                (2, 0, 1),
                (3, 0, 1),
                # (4, 0, 1),
                # (2, 1, 3),
                # (3, 1, 2),
                (4, 1, 2),
                (3, 2, 4),
                # (4, 2, 3),
                (4, 3, 0),
            ],
            dot_radius=10,
            theme=theme,
        )
    )

Latent space looks surprisingly good: the area opposing red is clear, and there's a pronounced collection of reds near the anchor point.

### Ablation
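
Ablation here means removing the first bottleneck dimension, the one that _red_ was anchored to. A minimal sketch of the idea, assuming that ablation simply zeroes the chosen dimensions (`ablate_dims` is a hypothetical helper; whether the real `ex_color.surgery.ablate` also renormalizes afterwards is not shown in this notebook):

```python
import numpy as np


def ablate_dims(z, dims) -> np.ndarray:
    """Return a copy of z with the listed latent dimensions set to zero."""
    out = np.array(z, dtype=float, copy=True)
    out[..., list(dims)] = 0.0
    return out


z = np.array(
    [
        [1.0, 0.0, 0.0, 0.0, 0.0],  # at the red anchor: loses everything
        [0.0, 0.6, 0.8, 0.0, 0.0],  # orthogonal to red: unchanged
        [0.6, 0.8, 0.0, 0.0, 0.0],  # partly red: keeps only the remainder
    ]
)
z_ablated = ablate_dims(z, [0])
print(z_ablated)
```

The third row shows why isolation matters: any color whose latent has a nonzero first component is altered by the ablation, even if it is not _red_.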

In [13]:
from ex_color.surgery import ablate

ablated_model = ablate(model_unit, 'bottleneck', [0])

y_hsv = await infer(ablated_model, x_hsv)
hd_y_hsv = await infer(ablated_model, hd_x_hsv)
clear_output()

with displayer_mpl(
    f'large-assets/ex-{nbid}-pred-colors-ablated-unit.png',
    alt_text="""Plot showing four slices of the HSV cube, titled "{title}". Nominally, each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. Each color value is represented as a square patch of that color. The outer portion of the patches shows the color as reconstructed by the model; the inner portion shows the true (input) color. The reconstructed and true colors agree fairly well up to yellow and purple, but disagree significantly near red. Desaturated colors and grays are almost unchanged.""",
) as show:
    show(
        lambda: plot_colors(
            hsv_cube,
            title='Predicted colors · ablated · no subspace',
            colors=y_hsv.numpy(),
            colors_compare=x_hsv.numpy(),
        )
    )

per_color_loss = F.mse_loss(hd_y_hsv, hd_x_hsv, reduction='none').mean(dim=-1)
loss_cube = hd_hsv_cube.assign('MSE', per_color_loss.numpy().reshape(hd_hsv_cube.shape))
max_loss = per_color_loss.max().item()
median_loss = per_color_loss.median().item()
with displayer_mpl(
    f'large-assets/ex-{nbid}-loss-colors-ablated-unit.png',
    alt_text=f"""Line chart showing loss per color, titled "{{title}}". Y-axis: mean square error, ranging from zero to {max_loss:.2g}. X-axis: hue. There is very low error at yellow-green, green, cyan, blue, and purple; high error at red, and moderate error at yellow and magenta. Saturation and value show low error at white and black, with error levels gradually increasing toward vibrant red. The curves are fairly smooth, but show a high-frequency dip very close to red. In fact red-orange has higher error than pure red.""",
) as show:
    show(
        lambda: plot_cube_series(
            loss_cube.permute('hsv')[:, -1:, :: (loss_cube.shape[2] // -5)],
            loss_cube.permute('svh')[:, -1:, :: -(loss_cube.shape[0] // -6)],
            loss_cube.permute('vsh')[:, -1:, :: -(loss_cube.shape[0] // -6)],
            title='Reconstruction error · ablated · no subspace',
            var='MSE',
            figsize=(12, 3),
        )
    )
print(f'Max loss: {max_loss:.2g}')
print(f'Median MSE: {median_loss:.2g}')

Max loss: 0.36
Median MSE: 0.00043


Red has certainly been ablated, but cyan has been somewhat perturbed, especially the darker shades. The effect is very visible in the cube slices (top). Surprisingly, the loss curves are pretty low (but clearly non-zero) at green, cyan, and blue. I would have expected them to be higher, given how obviously perturbed the reconstructed colors are to the eye.

These results are clearly worse than [Ex 2.9](ex-2.9-delete-only-red-5d.ipynb) — so removing the subspace constraint seems to have harmed the model.
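
One possible reason the loss stays low for visibly perturbed dark colors is that MSE is an absolute measure: dark colors have small channel values, so even a hue-changing shift contributes little squared error. The sketch below uses made-up colors purely for illustration; they are not taken from this run.

```python
def mse(a, b) -> float:
    """Mean squared error between two RGB triples."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


# Made-up colors for illustration only (not measurements from this experiment)
dark_cyan, dark_cyan_shifted = (0.0, 0.2, 0.2), (0.1, 0.1, 0.2)  # hue visibly wrong
bright_red, bright_red_grayed = (1.0, 0.0, 0.0), (0.3, 0.3, 0.3)  # collapsed to gray

mse_dark = mse(dark_cyan, dark_cyan_shifted)
mse_red = mse(bright_red, bright_red_grayed)
print(f'dark cyan: {mse_dark:.4f}, bright red: {mse_red:.4f}')
```

In this toy case the dark color's error is more than 30 times smaller than the bright one's, even though both reconstructions would look wrong to the eye.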

In [14]:
from IPython.display import clear_output

from ex_color.vis import plot_latent_grid_3d

y_rgb, h_rgb = await infer_with_latent_capture(ablated_model, x_rgb, 'bottleneck')
clear_output()

with displayer_mpl(
    f'large-assets/ex-{nbid}-latents-ablated-unit.png',
    alt_text="""Two rows of three spherical plots, titled "{title}". Each plot shows a vibrant collection of colored circles or balls scattered over the surface of a sphere. The vertical axis of each plot in the top row is the first dimension of latent space. The plots in the top row all have a line across the equator varying between purple, blue, green, and white. The bottom row shows similar colors, but with more of a ball-like appearance. Each circle has a point in the middle showing the true color of the sample; the bottom row shows that many of the warmer colors have been shifted to purple or black.""",
) as show:
    show(
        lambda theme: plot_latent_grid_3d(
            h_rgb,
            y_rgb,
            x_rgb,
            title='Latents · ablated · no subspace',
            dims=[
                (1, 0, 2),
                (2, 0, 1),
                (3, 0, 1),
                # (4, 0, 1),
                # (2, 1, 3),
                # (3, 1, 2),
                (4, 1, 2),
                (3, 2, 4),
                # (4, 2, 3),
                (4, 3, 0),
            ],
            dot_radius=10,
            theme=theme,
        )
    )

## No subspace, no unitarity

In [15]:
async with run():
    model_no_unit = await train(
        Dopesheet.from_csv(f'./ex-{nbid}-dopesheet.csv'),
        [r for r in ALL_REGULARIZERS if not isinstance(r.compute_loss_term, Unitarity)],
        K,
    )

I 122.4 no.2.10:Training with: ['reg-anchor', 'reg-separate', 'reg-anti-anchor']


INFO: Seed set to 0


I 122.4 li.fa.ut.se:Seed set to 0
I 122.4 ex.se: PyTorch set to deterministic mode


INFO: GPU available: False, used: False


I 122.4 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 122.4 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 122.4 li.py.ut.ra:HPU available: False, using: 0 HPUs
max_steps: 3001, hsv_loader length: 57


Starting phase: Train


INFO: `Trainer.fit` stopped: `max_steps=3001` reached.


I 141.5 li.py.ut.ra:`Trainer.fit` stopped: `max_steps=3001` reached.


Run history:
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇██
train_loss,█▆▅▃▃▃▂▄▅▂▃▃▃▂▃▃▂▄▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_recon,█▃▂▂▂▂▂▃▄▅▂▂▂▂▂▃▄▂▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_reg-anchor,▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▁▂▁▁▁▁▁
train_reg-anti-anchor,▁▁▁▁▁▅▅▃▄▄█▄▄▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_reg-separate,██▆▄▃▃▂▃▄▃▂▂▂▃▂▂▂▂▁▂▂▁▂▁▂▁▂▁▁▁▁▁▁▁▁▂▁▂▂▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇████

Run summary:
epoch,52.0
train_loss,5e-05
train_recon,5e-05
train_reg-anchor,0.0
train_reg-anti-anchor,0.0017
train_reg-separate,0.0947
trainer/global_step,2999.0


In [16]:
from IPython.display import clear_output
from torch.nn import functional as F

from ex_color.vis import plot_colors, plot_cube_series


interventions = []
y_hsv = await infer(model_no_unit, x_hsv)
hd_y_hsv = await infer(model_no_unit, hd_x_hsv)
clear_output()

with displayer_mpl(
    f'large-assets/ex-{nbid}-pred-colors-no-intervention-no-unit.png',
    alt_text="""Plot showing four slices of the HSV cube, titled "{title}". Nominally, each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. Each color value is represented as a square patch of that color. The outer portion of the patches shows the color as reconstructed by the model; the inner portion shows the true (input) color. The reconstructed and true colors agree fairly well, but some slight differences are visible; for example, "white" is slightly gray, and many of the fully-saturated colors are less saturated than they should be.""",
) as show:
    show(
        lambda: plot_colors(
            hsv_cube,
            title='Predicted colors · no intervention · no subspace, no unitarity',
            colors=y_hsv.numpy(),
            colors_compare=x_hsv.numpy(),
        )
    )

per_color_loss = F.mse_loss(hd_y_hsv, hd_x_hsv, reduction='none').mean(dim=-1)
loss_cube = hd_hsv_cube.assign('MSE', per_color_loss.numpy().reshape(hd_hsv_cube.shape))
max_loss = per_color_loss.max().item()
median_loss = per_color_loss.median().item()

with displayer_mpl(
    f'large-assets/ex-{nbid}-loss-colors-no-intervention-no-unit.png',
    alt_text=f"""Line chart showing loss per color, titled "{{title}}". Y-axis: mean square error, ranging from zero to {max_loss:.2g}. X-axis: hue. The range of loss values is small, but there are notable peaks at red, yellow, green, and cyan. Between those points the lines are wavy, reminiscent of an audio waveform.""",
) as show:
    show(
        lambda: plot_cube_series(
            loss_cube.permute('hsv')[:, -1:, :: (loss_cube.shape[2] // -5)],
            loss_cube.permute('svh')[:, -1:, :: -(loss_cube.shape[0] // -3)],
            loss_cube.permute('vsh')[:, -1:, :: -(loss_cube.shape[0] // -3)],
            title='Reconstruction error · no intervention · no subspace, no unitarity',
            var='MSE',
            figsize=(12, 3),
        )
    )
print(f'Max loss: {max_loss:.2g}')
print(f'Median MSE: {median_loss:.2g}')

Max loss: 0.00031
Median MSE: 2.5e-05


The reconstruction loss for the model without unitarity is pretty good, roughly on par with the model trained with the unitarity constraint (median MSE 2.5e-05 vs. 1.5e-05). That's a bit surprising, because I expected that the explicit normalization would not let enough signal pass to the encoder. Let's see what latent space looks like.

In [17]:
from IPython.display import clear_output

from ex_color.vis import plot_latent_grid_3d

y_rgb, h_rgb = await infer_with_latent_capture(model_no_unit, x_rgb, 'bottleneck')
clear_output()

with displayer_mpl(
    f'large-assets/ex-{nbid}-latents-no-intervention-no-unit.png',
    alt_text="""Two rows of three spherical plots, titled "{title}". Each plot shows a vibrant collection of colored circles or balls scattered over the surface of a sphere. On the top row, the first plot shows a thick curve — like a tongue seen from the side — touching the top and right side of the sphere and passing through the middle. It is red at the top, purple and blue in the middle, and cyan on the right. The sphere is empty elsewhere. The other plots in the top row show different views of the same space, all with red at the top but a different horizontal axis. They look much more dome-shaped. The second row shows still more views, focused on the other dimensions. These are more spherical and almost look like color wheels, but with the colors out of order.""",
) as show:
    show(
        lambda theme: plot_latent_grid_3d(
            h_rgb,
            y_rgb,
            x_rgb,
            title='Latents · no intervention · no subspace, no unitarity',
            dims=[
                (1, 0, 2),
                (2, 0, 1),
                (3, 0, 1),
                # (4, 0, 1),
                # (2, 1, 3),
                # (3, 1, 2),
                (4, 1, 2),
                (3, 2, 4),
                # (4, 2, 3),
                (4, 3, 0),
            ],
            dot_radius=10,
            theme=theme,
        )
    )

Huh. Again that's surprisingly good: the spacing of the samples is about as regular as in the model _with_ unitarity, and isolation of red colors looks significantly better. This time it has a nearly flat base (along dimension 1).

### Ablation

In [18]:
from ex_color.surgery import ablate

ablated_model = ablate(model_no_unit, 'bottleneck', [0])

y_hsv = await infer(ablated_model, x_hsv)
hd_y_hsv = await infer(ablated_model, hd_x_hsv)
clear_output()

with displayer_mpl(
    f'large-assets/ex-{nbid}-pred-colors-ablated-no-unit.png',
    alt_text="""Plot showing four slices of the HSV cube, titled "{title}". Nominally, each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. Each color value is represented as a square patch of that color. The outer portion of the patches shows the color as reconstructed by the model; the inner portion shows the true (input) color. The reconstructed and true colors agree fairly well up to yellow and purple, but disagree significantly near red. Desaturated colors and grays are almost unchanged.""",
) as show:
    show(
        lambda: plot_colors(
            hsv_cube,
            title='Predicted colors · ablated · no subspace, no unitarity',
            colors=y_hsv.numpy(),
            colors_compare=x_hsv.numpy(),
        )
    )

per_color_loss = F.mse_loss(hd_y_hsv, hd_x_hsv, reduction='none').mean(dim=-1)
loss_cube = hd_hsv_cube.assign('MSE', per_color_loss.numpy().reshape(hd_hsv_cube.shape))
max_loss = per_color_loss.max().item()
median_loss = per_color_loss.median().item()
with displayer_mpl(
    f'large-assets/ex-{nbid}-loss-colors-ablated-no-unit.png',
    alt_text=f"""Line chart showing loss per color, titled "{{title}}". Y-axis: mean square error, ranging from zero to {max_loss:.2g}. X-axis: hue. There is very low error at yellow-green, green, cyan, blue, and purple; high error at red, and moderate error at yellow and magenta. Saturation and value show low error at white and black, with error levels gradually increasing toward vibrant red. The curves are fairly smooth, but show a high-frequency dip very close to red. In fact red-orange has higher error than pure red.""",
) as show:
    show(
        lambda: plot_cube_series(
            loss_cube.permute('hsv')[:, -1:, :: (loss_cube.shape[2] // -5)],
            loss_cube.permute('svh')[:, -1:, :: -(loss_cube.shape[0] // -6)],
            loss_cube.permute('vsh')[:, -1:, :: -(loss_cube.shape[0] // -6)],
            title='Reconstruction error · ablated · no subspace, no unitarity',
            var='MSE',
            figsize=(12, 3),
        )
    )
print(f'Max loss: {max_loss:.2g}')
print(f'Median MSE: {median_loss:.2g}')

Max loss: 0.39
Median MSE: 6.1e-05


Wow OK that's _much_ better. Apart from _yellow_, which has a slightly higher reconstruction loss than in [Ex 2.9](ex-2.9-delete-only-red-5d.ipynb), there's less collateral damage across the board. The median error is particularly low, and the curves near _red_ look less noisy than the results in 2.9.

In [19]:
from IPython.display import clear_output

from ex_color.vis import plot_latent_grid_3d

y_rgb, h_rgb = await infer_with_latent_capture(ablated_model, x_rgb, 'bottleneck')
clear_output()

with displayer_mpl(
    f'large-assets/ex-{nbid}-latents-ablated-no-unit.png',
    alt_text="""Two rows of three spherical plots, titled "{title}". Each plot shows a vibrant collection of colored circles or balls scattered over the surface of a sphere. The vertical axis of each plot in the top row is the first dimension of latent space. The plots in the top row all have a line across the equator varying between purple, blue, green, and white. The bottom row shows similar colors, but with more of a ball-like appearance. Each circle has a point in the middle showing the true color of the sample; the bottom row shows that many of the warmer colors have been shifted to purple or black.""",
) as show:
    show(
        lambda theme: plot_latent_grid_3d(
            h_rgb,
            y_rgb,
            x_rgb,
            title='Latents · ablated · no subspace, no unitarity',
            dims=[
                (1, 0, 2),
                (2, 0, 1),
                (3, 0, 1),
                # (4, 0, 1),
                # (2, 1, 3),
                # (3, 1, 2),
                (4, 1, 2),
                (3, 2, 4),
                # (4, 2, 3),
                (4, 3, 0),
            ],
            dot_radius=10,
            theme=theme,
        )
    )

Here we clearly see that the first dimension has been removed from latent space, but it's hard to say what else has happened. I think maybe the space is now so high-dimensional that it's becoming hard to interpret.

## Conclusion

Hypothesis 1: partially confirmed. Without the subspace regularizer pulling desaturated colors away from the first two dimensions, base reconstruction loss was good, and we were able to ablate _red_. However, there was some surprising collateral damage to other colors, particularly dark shades of _green_, _cyan_, and _blue_ (but not to vibrant _cyan_).

Hypothesis 2: disconfirmed! Removing the unitarity term barely harmed reconstruction (median MSE 2.5e-05 vs. 1.5e-05 without intervention), and in fact significantly improved the specificity of the ablation (median MSE after ablation 6.1e-05 vs. 0.00043).
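
The comparison between the two runs can be made explicit with a little arithmetic on the figures printed by the cells above. The values are copied from those outputs, which are rounded to two significant figures, so the ratios are approximate.

```python
# Metrics copied from the printed cell outputs (rounded to 2 significant figures)
reported = {
    'with unitarity': {'baseline_median': 1.5e-05, 'ablated_median': 0.00043, 'ablated_max': 0.36},
    'without unitarity': {'baseline_median': 2.5e-05, 'ablated_median': 6.1e-05, 'ablated_max': 0.39},
}

# How much the median error grows when the first dimension is ablated
increase = {name: m['ablated_median'] / m['baseline_median'] for name, m in reported.items()}

# How much lower the post-ablation median is without unitarity
improvement = reported['with unitarity']['ablated_median'] / reported['without unitarity']['ablated_median']

for name, factor in increase.items():
    print(f'{name}: median MSE grows {factor:.1f}x after ablation')
print(f'post-ablation median is {improvement:.1f}x lower without unitarity')
```

The maximum loss is similar in both runs (0.36 vs. 0.39), which suggests that _red_ itself is removed about equally well; the difference is in how much the other colors are disturbed.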

This makes me want to re-run some of the older experiments to see if unitarity was really needed for the lower-dimensional bottlenecks.