# Experiment 2.5.1: Only one anchor

This is a re-run of [Ex 2.5](./ex-2.5-only-red.ipynb) with more mature tooling. See the earlier notebook for discussion.

We only anchor _red_, without the planarity regularizer used in earlier experiments. The model should perform as well as or better than previous models, and the interventions should be at least as precise.

In [1]:
from __future__ import annotations

# Notebook identity: these values tag assets and name the remote app, volume and runs.
project = 'ex-preppy'
nbid = '2.5.1'  # short ID used to tag assets produced by this notebook
nbname = 'Only red'
experiment_name = f'Ex {nbid}: {nbname}'  # human-readable title passed to training

In [2]:
# Basic setup: Logging, Experiment (Modal)
import logging

import modal

from infra.requirements import uv_freeze, project_packages
from utils.logging import SimpleLoggingConfig
from ex_color.vis import NbViz

# Logging: INFO for this notebook and the project packages, errors only for
# matplotlib.axes. The same config object is re-applied inside remote functions.
logging_config = (
    SimpleLoggingConfig()
    .info('notebook', 'utils', 'mini', 'ex_color')
    .error('matplotlib.axes')  # Silence warnings about set_aspect
)
logging_config.apply()

# This is the logger for this notebook
log = logging.getLogger(f'notebook.{nbid}')

# Container image for remote execution: Debian slim plus the project's pinned
# dependencies and local packages.
# NOTE(review): presumably uv_freeze(...) yields the frozen requirement list
# for every dependency group except 'dev' — confirm in infra.requirements.
image = (
    modal.Image.debian_slim()
    .pip_install(*uv_freeze(all_groups=True, not_groups='dev'))
    .add_local_python_source(*project_packages())
)
# Per-notebook volume (created on first use), mounted at /data in the app's
# containers; the training function writes its checkpoints there.
volume = modal.Volume.from_name(f'{project}-{nbid}', create_if_missing=True, version=2)
app = modal.App(name=f'{project}-{nbid}', image=image, volumes={'/data': volume})

# Visualization helper, constructed with the notebook ID
viz = NbViz(nbid)
None  # prevent auto-display of this cell

## Model parameters

We use the following regularizers:

- **Anchor:** pins `red` to $(1,0,0,0)$ (4D)
- **Separate:** angular repulsion to reduce global clumping (applied within each batch)
- **Unitarity:** pulls all embeddings to the surface of the unit hypersphere

Unlike Ex 2.4:

- **Planarity:** this regularizer has been removed.

In [3]:
import torch

from ex_color.loss import AngularAnchor, Separate, RegularizerConfig

# --- Model architecture ---
K = 4  # bottleneck dimensionality
N = 1  # number of nonlinear layers
H = 16  # hidden layer size (NOTE(review): not referenced by any visible cell)

# --- Anchor target ---
RED = (1, 0, 0, 0)  # latent direction that `red` is pinned to
assert len(RED) == K  # the anchor needs one component per bottleneck dimension

# --- Data and sweep ---
BATCH_SIZE = 64
CUBE_SUBDIVISIONS = 8
NUM_RUNS = 60
RUN_SEEDS = list(range(NUM_RUNS))  # one seed per run: 0 .. NUM_RUNS - 1

# Separation regularizer on the bottleneck layer. No label affinities are
# given (None).
# NOTE(review): presumably None means it applies to every sample regardless of
# label — confirm in RegularizerConfig.
reg_separate = RegularizerConfig(
    name='separate',
    compute_loss_term=Separate(power=100.0, shift=True),
    label_affinities=None,
    layer_affinities=['bottleneck'],
)
# Anchor regularizer on the bottleneck layer, tied to the 'red' label with
# affinity 1.0 and targeting the RED direction as a float32 tensor.
# Unlike reg_separate it sets `phase` explicitly to both 'train' and 'validate'.
reg_anchor = RegularizerConfig(
    name='anchor',
    compute_loss_term=AngularAnchor(torch.tensor(RED, dtype=torch.float32)),
    label_affinities={'red': 1.0},
    layer_affinities=['bottleneck'],
    phase=('train', 'validate'),
)

In [4]:
from mini.temporal.dopesheet import Dopesheet

# Load this notebook's parameter schedule from the CSV stored next to it
# (file name derived from nbid), then hand it to the two dopesheet views.
dopesheet = Dopesheet.from_csv(f'./ex-{nbid}-dopesheet.csv')
viz.tab_dopesheet(dopesheet)
viz.plot_dopesheet(dopesheet)

## Parameter schedule 
|   STEP | PHASE   |   ACTION |      lr |   separate |   anchor |
|-------:|:--------|---------:|--------:|-----------:|---------:|
|      0 | Train   |          |   1e-08 |            |     0.01 |
|     10 |         |          |   0.01  |            |          |
|    375 |         |          |         |            |          |
|    750 |         |          |   0.1   |      0.015 |     0.1  |
|   1125 |         |          |         |            |          |
|   1425 |         |          |   0.1   |      0     |     0    |
|   1500 |         |          |   0.05  |            |          |

## Data

Data is the same as last time: color cubes with values in RGB.


In [5]:
from torch.utils.data import DataLoader, RandomSampler

from ex_color.data.cube_dataset import prep_color_dataset, redness, stochastic_labels, exact_labels


def prep_train_data(training_subs: int, *, batch_size: int) -> DataLoader:
    """Build the training loader over a color dataset sampled at cell corners.

    Args:
        training_subs: Forwarded as the first argument of ``prep_color_dataset``
            (callers pass ``CUBE_SUBDIVISIONS``).
        batch_size: Batch size of the returned loader.

    Returns:
        A ``DataLoader`` with 4 workers that draws ``len(dataset)`` samples with
        replacement and collates each batch with ``stochastic_labels``.
    """

    def red_label(c):
        # 8th power of redness, scaled by 0.08.
        # NOTE(review): presumably a per-sample label probability — confirm.
        return redness(c) ** 8 * 0.08

    train_set = prep_color_dataset(
        training_subs,
        sample_at='cell-corners',
        red=red_label,
    )
    resampler = RandomSampler(train_set, num_samples=len(train_set), replacement=True)
    return DataLoader(
        train_set,
        batch_size=batch_size,
        num_workers=4,
        sampler=resampler,
        collate_fn=stochastic_labels,
    )


def prep_val_data(training_subs: int, *, batch_size: int) -> DataLoader:
    """Build the validation loader over a color dataset sampled at cell centers.

    Args:
        training_subs: Forwarded as the first argument of ``prep_color_dataset``
            (callers pass ``CUBE_SUBDIVISIONS``).
        batch_size: Batch size of the returned loader.

    Returns:
        A ``DataLoader`` with 2 workers and no custom sampler that collates
        each batch with ``exact_labels``.
    """

    def is_pure_red(c):
        # Only colors whose redness is exactly 1 count as red here
        return redness(c) == 1

    val_set = prep_color_dataset(
        training_subs,
        sample_at='cell-centers',
        red=is_pure_red,
    )
    loader_options = dict(batch_size=batch_size, num_workers=2, collate_fn=exact_labels)
    return DataLoader(val_set, **loader_options)

## Train

In [6]:
from typing import Callable

import torch
import wandb

from ex_color.model import CNColorMLP
from ex_color.seed import set_deterministic_mode
from ex_color.workflow import train_model
from ex_color.evaluation import Result
from utils.time import hour


@app.function(
    cpu=1,
    max_containers=20,
    timeout=1 * hour,
    env={'WANDB_API_KEY': wandb.Api().api_key or ''},
)
async def train(
    dopesheet: Dopesheet,
    regularizers: list[RegularizerConfig],
    *,
    seed: int,
    score_fn: Callable[[CNColorMLP], float],
):
    """Train, score and checkpoint a single seeded model.

    Registered on ``app`` as a Modal function. The W&B API key is looked up
    when this cell is evaluated and passed to the function through ``env``
    (an empty string when no key is available).

    Args:
        dopesheet: Parameter schedule forwarded to ``train_model``.
        regularizers: Regularizer configurations forwarded to ``train_model``.
        seed: Passed to ``set_deterministic_mode`` before any data loader or
            model is created, and recorded as the run's only hyperparameter.
        score_fn: Called with the trained model; its return value becomes the
            run's score.

    Returns:
        ``Result(seed, checkpoint key, run URL, run summary, score)``. The
        checkpoint key is the file name of the state dict saved under ``/data``.
    """
    logging_config.apply()

    # Seed first: objects created below may draw on global RNG state.
    set_deterministic_mode(seed)

    # Creation order (loaders, then model) is kept fixed for that same reason.
    train_loader = prep_train_data(CUBE_SUBDIVISIONS, batch_size=BATCH_SIZE)
    val_loader = prep_val_data(CUBE_SUBDIVISIONS, batch_size=BATCH_SIZE)
    model = CNColorMLP(K, n_nonlinear=N)

    outcome = train_model(
        model,
        dopesheet,
        regularizers,
        train_loader,
        val_loader,
        experiment_name=experiment_name,
        project=project,
        hparams={'seed': seed},
    )
    trained = outcome.model

    # Score before saving, then persist only the weights on the mounted volume.
    run_score = score_fn(trained)
    checkpoint_key = f'model-{outcome.id_}.pt'
    torch.save(trained.state_dict(), f'/data/{checkpoint_key}')
    return Result(seed, checkpoint_key, outcome.url, outcome.summary, run_score)

In [7]:
from math import cos, radians

from ex_color.evaluation import EvaluationPlan, ScoreByHSVSimilarity
from ex_color.intervention import InterventionConfig, Suppression, BoundedFalloff


# Intervention used to score every run: suppress the RED direction at the
# bottleneck layer, then score the resulting model.
falloff = BoundedFalloff(
    # First argument: cosine of 90 degrees. The same call in the Suppression
    # section annotates this position as cos(max_angle).
    cos(radians(90)),
    1,  # completely squash aligned vectors
    0,  # constant effect (no fall-off)
)
# NOTE(review): torch.tensor(RED) infers an integer dtype from the int tuple,
# whereas the anchor regularizer builds its target with dtype=torch.float32 —
# confirm that Suppression accepts or casts an integer tensor.
suppression = InterventionConfig(
    apply=Suppression(torch.tensor(RED), falloff),
    layer_affinities=['bottleneck'],
)
# Evaluation plan tagged 'suppression': identity model transform plus the
# single suppression intervention.
suppression_plan = EvaluationPlan(
    {'suppression'},
    lambda m: m,
    [suppression],
)

# Callable passed to each training run as `score_fn`.
# NOTE(review): (0.0, 1.0, 1.0) is presumably the HSV triple of pure red — confirm.
score_fn = ScoreByHSVSimilarity(suppression_plan, (0.0, 1.0, 1.0), power=2.0, cube_subdivisions=CUBE_SUBDIVISIONS)

In [8]:
import asyncio

# Reload dopesheet: makes tweaking params during development easier
dopesheet = Dopesheet.from_csv(f'./ex-{nbid}-dopesheet.csv')
regularizers = [reg_separate, reg_anchor]


async def sweep():
    logging_config.apply()
    workers = [train.remote.aio(dopesheet, regularizers, seed=seed, score_fn=score_fn) for seed in RUN_SEEDS]
    return await asyncio.gather(*workers)


with app.run():
    results = await sweep()

In [9]:
from IPython.display import display
from ex_color.evaluation import results_to_dataframe

# Collect the per-run results into a DataFrame for summary statistics and plots
runs_df = results_to_dataframe(results)
# Show min, max, mean, stddev of each column
log.info(f'Summary statistics for all {len(runs_df)} runs:')
display(runs_df.describe().loc[['min', 'max', 'mean', 'std']].style.format(precision=4))

# Distribution of the per-run score; the x axis is capped at 1.
# NOTE(review): the printed heading describes what the score is taken to
# measure — confirm against ScoreByHSVSimilarity.
print('Correlation of reconstruction error vs. similarity to anchor')
viz.plot_boxplot(runs_df['score'], ylabel='', xlim=(None, 1), tags=('score',))

# Distribution of the validation reconstruction loss, on a log scale
print('Reconstruction loss')
viz.plot_boxplot(runs_df['val_recon'], ylabel='', log_scale=True, tags=('val_recon',))

# Distribution of the validation anchor loss, on a log scale
print('Anchor loss')
viz.plot_boxplot(runs_df['val_anchor'], ylabel='', log_scale=True, tags=('val_anchor',))

I 334.8 no.2.5.1:Summary statistics for all 60 runs:


Unnamed: 0,seed,score,labels/n/red,_runtime,val_recon,val_anchor,val_loss,labels/n/_any,labels/n_total
min,0.0,0.8513,62.0,51.1265,0.0,0.0003,0.0,62.0,96064.0
max,59.0,0.9762,103.0,159.3958,0.0,0.0042,0.0,103.0,96064.0
mean,29.5,0.9231,81.65,92.9487,0.0,0.0016,0.0,81.65,96064.0
std,17.4642,0.0316,8.9041,23.6588,0.0,0.001,0.0,8.9041,0.0


Correlation of reconstruction error vs. similarity to anchor


Reconstruction loss


Anchor loss


Select the best runs from the Pareto front of non-dominated runs, minimizing the validation reconstruction and anchor losses while maximizing the score.

In [10]:
from ex_color.evaluation import pareto_front

# Keep only the runs on the Pareto front: minimize both validation losses,
# maximize the score.
non_dominated = pareto_front(runs_df, minimize=['val_recon', 'val_anchor'], maximize=['score'])
log.info(f'Best of {len(non_dominated)} non-dominated runs (Pareto front):')
# Show the five highest-scoring runs on the front, with URLs rendered as links
display(non_dominated.sort_values(by='score', ascending=False).head(5).style.format(precision=4, hyperlinks='html'))

I 336.8 no.2.5.1:Best of 12 non-dominated runs (Pareto front):


Unnamed: 0,seed,wandb url,score,labels/n/red,_runtime,val_recon,val_anchor,val_loss,labels/n/_any,labels/n_total
8,8,https://wandb.ai/z0r/ex-preppy/runs/ilmm4gwa,0.9762,79,103.3465,0.0,0.002,0.0,79,96064
58,58,https://wandb.ai/z0r/ex-preppy/runs/sfl2qzqk,0.9745,80,97.7478,0.0,0.0042,0.0,80,96064
44,44,https://wandb.ai/z0r/ex-preppy/runs/6rc3fhzl,0.9731,85,123.7886,0.0,0.0014,0.0,85,96064
25,25,https://wandb.ai/z0r/ex-preppy/runs/50zb1ctm,0.9553,81,89.0483,0.0,0.0022,0.0,81,96064
18,18,https://wandb.ai/z0r/ex-preppy/runs/moxj233f,0.9553,87,88.69,0.0,0.0011,0.0,87,96064


In [11]:
from typing import cast

from mini.data import load_checkpoint_from_volume

# Pick the highest-scoring run on the Pareto front.
# `idxmax` returns an index *label*, not a list position, so using it to index
# `results` only works while labels happen to equal positions. Resolve the run
# through its seed instead: `seed` is a column of the runs table and an
# attribute of every Result.
best_label = non_dominated['score'].idxmax()
best_seed = int(non_dominated.loc[best_label, 'seed'])
best_run = next(run for run in results if run.seed == best_seed)
log.info(f'Loading checkpoint of best run: seed={best_run.seed}, score={best_run.score:.4f} @ {best_run.url}')
# Rebuild the architecture used for training, then load the saved weights
# from the Modal volume.
model = CNColorMLP(K, n_nonlinear=N)
model = load_checkpoint_from_volume(model, volume, best_run.checkpoint_key)

I 336.8 no.2.5.1:Loading checkpoint of best run: seed=8, score=0.9762 @ https://wandb.ai/z0r/ex-preppy/runs/ilmm4gwa


## Results

In [12]:
# # Generate a list of dimensions to visualize
# from itertools import combinations
# [
#     (
#         b,
#         a,
#         (a + 1) % 4 if (a + 1) % 4 not in (a, b) else (a + 2) % 4,
#     )
#     for a, b in combinations((0, 1, 2, 3), 2)
# ]

In [13]:
from ex_color.evaluation import TestSet

# Shared test set used by the baseline, suppression and ablation evaluations below
test_set = TestSet.create()

In [14]:
from IPython.display import clear_output

# Baseline: evaluate the selected model with no interventions applied
baseline_results = test_set.evaluate(model, [], tags={'baseline'})
clear_output()  # drop anything the evaluation wrote to this cell's output

viz.plot_cube(baseline_results)
# viz.plot_recon_loss(baseline_results)
# viz.plot_latent_space(
#     baseline_results,
#     dims=[(1, 0, 2), (2, 0, 1), (3, 0, 1)],
# )

### Suppression

In [15]:
from math import cos, radians
from IPython.display import clear_output

from ex_color.intervention import Suppression, BoundedFalloff, InterventionConfig


# NOTE(review): this rebuilds `falloff` and `suppression` with the same
# arguments as the cell that defines score_fn; later cells read `falloff` from
# whichever of the two cells ran last.
falloff = BoundedFalloff(
    cos(radians(90)),  # cos(max_angle)
    1,  # completely squash fully-aligned vectors
    # 2,  # soft rim, sharp hub
    0,  # annotated as "constant effect (no fall-off)" where score_fn is built
)
# NOTE(review): torch.tensor(RED) infers an integer dtype from the int tuple —
# confirm that Suppression accepts or casts it.
suppression = InterventionConfig(
    apply=Suppression(torch.tensor(RED), falloff),
    layer_affinities=['bottleneck'],
)
# Evaluate the selected model with the suppression intervention applied
suppression_results = test_set.evaluate(model, [suppression], tags={'suppression'})
clear_output()  # drop anything the evaluation wrote to this cell's output

viz.plot_cube(suppression_results)
# viz.plot_recon_loss(suppression_results)
# viz.plot_latent_space(
#     suppression_results,
#     dims=[(1, 0, 2), (2, 0, 1), (3, 0, 1)],
# )

### Ablation

Included for comparison/completeness, but this model was not really designed for it.

In [16]:
from IPython.display import clear_output
from ex_color.surgery import ablate

# Ablate index 0 of the 'bottleneck' layer. RED = (1, 0, 0, 0) has its only
# non-zero component at that index.
# NOTE(review): presumably `ablate` returns a modified copy and leaves `model`
# untouched — confirm in ex_color.surgery.
ablated_model = ablate(model, 'bottleneck', [0])
# Evaluate the ablated model with no further interventions
ablation_results = test_set.evaluate(ablated_model, [], tags={'ablated'})
clear_output()  # drop anything the evaluation wrote to this cell's output

viz.plot_cube(ablation_results)
# viz.plot_recon_loss(ablation_results)
# viz.plot_latent_space(
#     ablation_results,
#     dims=[(1, 0, 2), (2, 0, 1), (3, 0, 1)],
# )

In [17]:
import numpy as np
from ex_color.vis.helpers import ThemedAnnotation


# Largest MSE across all three evaluations; the same value is passed to every
# plot below.
# NOTE(review): presumably used as a common error scale so the three plots are
# comparable — confirm in plot_stacked_results.
max_error = np.max(
    [
        baseline_results.loss_cube['MSE'],
        ablation_results.loss_cube['MSE'],
        suppression_results.loss_cube['MSE'],
    ]
)

# Two triples of latent dimension indices, passed as `latent_dims` to each plot
dims = ((2, 0, 1), (1, 2, 0))
# NOTE(review): `pruned_dims` is not referenced by any visible cell.
pruned_dims = ((1, None, 0), (0, 1, None))

print('Baseline')
viz.plot_stacked_results(
    baseline_results,
    latent_dims=dims,
    max_error=max_error,
)

print('Suppression')
viz.plot_stacked_results(
    suppression_results,
    latent_dims=dims,
    max_error=max_error,
    latent_annotations=[
        # Dashed annotation along RED.
        # NOTE(review): this expression treats `falloff.a` as an angle in
        # radians, while BoundedFalloff received cos(radians(90)) as its first
        # argument. If `a` stores that cosine, the result equals the intended
        # cone (pi) only because cos(90 deg) is ~0 — confirm before changing
        # the cut-off angle.
        ThemedAnnotation(direction=RED, angle=2 * (np.pi / 2 - falloff.a), dashed=True),
    ],
)

print('Ablation')
viz.plot_stacked_results(
    ablation_results,
    latent_dims=dims,
    max_error=max_error,
)

Baseline


Suppression


Ablation


In [18]:
# Per-color reconstruction error for baseline, suppression and ablation:
# first as a table in the notebook, then as LaTeX source.
viz.tab_error_vs_color(baseline_results, suppression_results, ablation_results)
viz.tab_error_vs_color_latex(baseline_results, suppression_results, ablation_results)

Name,RGB,Baseline,Suppression,Δ Sup,Ablated,Δ Abl
red,,0.001,0.234,0.233,0.334,0.334
orange,,0.0,0.12,0.12,0.137,0.137
yellow,,0.0,0.04,0.04,0.029,0.029
lime,,0.0,0.0,0.0,0.0,0.0
green,,0.0,0.0,0.0,0.026,0.026
teal,,0.0,0.0,0.0,0.135,0.135
cyan,,0.0,0.0,0.0,0.248,0.248
azure,,0.0,0.0,0.0,0.132,0.132
blue,,0.0,0.0,0.0,0.025,0.024
purple,,0.0,0.0,0.0,0.0,0.0


```latex
\begin{table}
\centering
\label{tab:placeholder}
\caption{Reconstruction error by color and intervention method}
\sisetup{
    round-mode = places,
    round-precision = 3,
    table-auto-round = true,
    % drop-zero-decimal = true,
}
\begin{tabular}{l c g g g}
\toprule
\multicolumn{2}{c}{{Color}} & \multicolumn{1}{c}{{Baseline}} & \multicolumn{1}{c}{{Suppression}} & \multicolumn{1}{c}{{Ab}} \\
\midrule
Red        & \swatch{FF0000} &  0.000569792 &  0.233233362 &  0.333869487 \\
Orange     & \swatch{FF7F00} &  0.000023389 &  0.119690016 &  0.137201920 \\
Yellow     & \swatch{FFFF00} &  0.000036972 &  0.040286772 &  0.029353250 \\
Lime       & \swatch{7FFF00} &  0.000013694 &  0.000000042 &  0.000000040 \\
Green      & \swatch{00FF00} &  0.000108810 &  0.000000000 &  0.025526717 \\
Teal       & \swatch{00FF7F} &  0.000030726 &  0.000000000 &  0.135105342 \\
Cyan       & \swatch{00FFFF} &  0.000124518 &  0.000000000 &  0.247848153 \\
Azure      & \swatch{007FFF} &  0.000082252 &  0.000000000 &  0.132054538 \\
Blue       & \swatch{0000FF} &  0.000121482 &  0.000000000 &  0.024472622 \\
Purple     & \swatch{7F00FF} &  0.000014293 &  0.000167103 &  0.000164949 \\
Magenta    & \swatch{FF00FF} &  0.000278841 &  0.041148938 &  0.030989304 \\
Pink       & \swatch{FF007F} &  0.000000015 &  0.120681360 &  0.132281214 \\
Black      & \swatch{000000} &  0.000528365 &  0.004112418 &  0.003650059 \\
Dark gray  & \swatch{3F3F3F} &  0.000018525 &  0.001252160 &  0.001274158 \\
Gray       & \swatch{7F7F7F} &  0.000056444 &  0.000024199 &  0.000024105 \\
Light gray & \swatch{BFBFBF} &  0.000010967 &  0.000000000 &  0.000791224 \\
White      & \swatch{FFFFFF} &  0.000166427 &  0.000000000 &  0.001296695 \\
\bottomrule
\end{tabular}
\end{table}
```

In [19]:
# Error vs. similarity to the 'red' anchor for the suppression results,
# with similarity raised to the power 2.
# NOTE(review): (0, 1, 1) is presumably the HSV triple of pure red — confirm.
viz.plot_error_vs_similarity(
    suppression_results,
    (0, 1, 1),
    anchor_name='red',
    power=2,
)

# Same plot for the ablation results, but with power 3 instead of 2
viz.plot_error_vs_similarity(
    ablation_results,
    (0, 1, 1),
    anchor_name='red',
    power=3,
)

MSE,sim² suppression: r = 0.99, R²: 0.98, p = 0


MSE,sim³ ablated: r = 0.68, R²: 0.47, p = 0
