# Experiment 2.2: Specific concept intervention

In the [1.x series](/README.md#m1-preliminary-experiments-with-color) of experiments (milestone 1), we validated our ideas for imposing structure on latent space. With only weak supervision, we guided a simple RGB autoencoder to use the color wheel for its latent representations. In this series, we'll try to inhibit and even delete certain concepts from the model.

To start, let's take one of the earlier experiments and see what happens when we suppress activations that align with _red_.

## Hypothesis

We've structured the latent space so red is located at $[1,0,0,0]$. If we suppress or redirect activations close to that vector, model performance on near-red colors should drop, while other colors remain mostly unaffected.
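Once the model is trained, one way to check this hypothesis is to compare reconstruction error on near-red colors against all other colors. The sketch below is a hypothetical helper, not part of `ex_color`. It groups test colors by hue distance from red. The 30° hue window and the 0.5 saturation cutoff are arbitrary assumptions. The cutoff keeps greys out of the red group, because their hue is undefined and `colorsys` reports it as 0°.

```python
import colorsys

import torch


def hue_distance_from_red(rgb: torch.Tensor) -> torch.Tensor:
    """Angular distance (degrees) of each RGB color's hue from red (0°)."""
    hues = torch.tensor([colorsys.rgb_to_hsv(*c)[0] * 360 for c in rgb.tolist()])
    return torch.minimum(hues, 360 - hues)


def error_by_group(
    original: torch.Tensor,
    reconstructed: torch.Tensor,
    near_red_deg: float = 30.0,  # hypothetical threshold
    min_saturation: float = 0.5,  # hypothetical threshold; excludes greys
) -> tuple[torch.Tensor, torch.Tensor]:
    """Mean squared reconstruction error for (near-red colors, all other colors)."""
    per_color = ((original - reconstructed) ** 2).mean(dim=-1)
    sats = torch.tensor([colorsys.rgb_to_hsv(*c)[1] for c in original.tolist()])
    near_red = (hue_distance_from_red(original) <= near_red_deg) & (sats >= min_saturation)
    return per_color[near_red].mean(), per_color[~near_red].mean()


# Toy example: only the red reconstruction is damaged
original = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [0.5, 0.5, 0.5]])
reconstructed = original.clone()
reconstructed[0, 0] = 0.5
near, other = error_by_group(original, reconstructed)
```

If the hypothesis holds, a red intervention should raise `near` substantially while leaving `other` close to its baseline value.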

In [1]:
# Basic setup: Logging, Experiment (Modal)
from __future__ import annotations
import logging

import modal

from infra.requirements import freeze, project_packages
from mini.experiment import Experiment
from utils.logging import SimpleLoggingConfig

logging_config = (
    SimpleLoggingConfig()
    .info('notebook', 'utils', 'mini', 'ex_color')
    .error('matplotlib.axes')  # Silence warnings about set_aspect
)
logging_config.apply()

# ID for tagging assets
nbid = '2.2'
# This is the logger for this notebook
log = logging.getLogger(f'notebook.{nbid}')
experiment_name = f'ex-color-{nbid}'

run = Experiment(experiment_name)
run.image = (
    modal.Image.debian_slim()
    .pip_install(*freeze(all=True))
    .add_local_python_source(*project_packages())
)
run.before_each(logging_config.apply)
None  # prevent auto-display of this cell

## Regularizers

Like Ex 1.7:

- **Anchor:** pins `red` to $(1,0,0,0)$
- **Separate:** angular repulsion to reduce global clumping (applied within each batch)
- **Planarity:** pulls vibrant hues onto the plane spanned by latent dimensions 0 and 1
- **Unitarity:** pulls all embeddings to the surface of the unit hypersphere, i.e. toward unit length. There are two terms: one that applies to all colors equally, and another that applies only to vibrant colors (because they seemed to need a little more help).
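To make the list above concrete, here is a minimal sketch of what each term might compute on a batch of 4-D latents. These are illustrative formulas of our own, not the `ex_color.loss` implementations. The names `anchor_loss`, `separate_loss`, `planarity_loss`, `unitarity_loss` and `weighted_mean` are hypothetical, and the exact penalty shapes (squared distances, a power-law repulsion on positive cosine similarity) are assumptions. Each term returns a per-sample loss, so it can be weighted by a label affinity such as `red` or `vibrant`.

```python
from __future__ import annotations

import torch
import torch.nn.functional as F


def anchor_loss(z: torch.Tensor, anchor: torch.Tensor) -> torch.Tensor:
    """Squared distance from each latent (B, D) to a fixed anchor point (D,)."""
    return ((z - anchor) ** 2).sum(dim=-1)


def separate_loss(z: torch.Tensor, power: float = 10.0) -> torch.Tensor:
    """Angular repulsion: penalize positive cosine similarity between distinct samples."""
    zn = F.normalize(z, dim=-1)
    sim = zn @ zn.T  # (B, B) pairwise cosine similarities
    penalty = sim.clamp_min(0.0) ** power  # a high power focuses on near-duplicates
    penalty = penalty * (1 - torch.eye(len(z)))  # ignore each sample's self-similarity
    return penalty.sum(dim=-1) / max(len(z) - 1, 1)


def planarity_loss(z: torch.Tensor) -> torch.Tensor:
    """Squared magnitude of everything outside the plane of dims 0 and 1."""
    return (z[:, 2:] ** 2).sum(dim=-1)


def unitarity_loss(z: torch.Tensor) -> torch.Tensor:
    """Squared deviation of each latent's norm from 1."""
    return (z.norm(dim=-1) - 1) ** 2


def weighted_mean(per_sample: torch.Tensor, weights: torch.Tensor | None) -> torch.Tensor:
    """Average a per-sample loss, optionally weighted by a label affinity."""
    if weights is None:
        return per_sample.mean()
    return (per_sample * weights).sum() / weights.sum().clamp_min(1e-8)


red = torch.tensor([1.0, 0.0, 0.0, 0.0])
z = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],  # exactly red
    [0.0, 1.0, 0.0, 0.0],  # on the plane, unit length
    [0.0, 0.0, 2.0, 0.0],  # off the plane, too long
])
is_red = torch.tensor([1.0, 0.0, 0.0])  # label affinity for 'red'

# Anchor only applies to samples labelled red; unitarity applies to every sample
total = weighted_mean(anchor_loss(z, red), is_red) + weighted_mean(unitarity_loss(z), None)
```

Because each term is computed per sample, the same function can serve both unitarity variants: pass `None` to weight all colors equally, or a `vibrant` affinity to add extra pressure on vibrant colors.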

In [None]:
import torch

from mini.temporal.dopesheet import Dopesheet
from ex_color.loss import Anchor, Separate, planarity, unitarity, RegularizerConfig
from ex_color.model import ColorMLP
from ex_color.training import TrainingModule

RED = (1, 0, 0, 0)

ALL_REGULARIZERS = [
    RegularizerConfig(
        name='reg-anchor',
        compute_loss_term=Anchor(torch.tensor(RED, dtype=torch.float32)),
        label_affinities={'red': 1.0},
        layer_affinities=['encoder'],
    ),
    RegularizerConfig(
        name='reg-separate',
        compute_loss_term=Separate(power=10.0, shift=False),
        label_affinities=None,
        layer_affinities=['encoder'],
    ),
    RegularizerConfig(
        name='reg-planar',
        compute_loss_term=planarity,
        label_affinities={'vibrant': 1.0},
        layer_affinities=['encoder'],
    ),
    RegularizerConfig(
        name='reg-unit-v',
        compute_loss_term=unitarity,
        label_affinities={'vibrant': 1.0},
        layer_affinities=['encoder'],
    ),
    RegularizerConfig(
        name='reg-unit',
        compute_loss_term=unitarity,
        label_affinities=None,
        layer_affinities=['encoder'],
    ),
]

## Interventions

Next, we'll define an intervention to suppress the concept of _red_. This will use the `steer_away` function, which projects the activations onto the subspace orthogonal to the red vector used in the `Anchor` regularizer. This should remove 'redness' from the model's representations at inference time. For now, the list of interventions below is empty, so the first inference run is an unmodified baseline.
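The projection described above amounts to subtracting each activation's component along the unit red direction: $h' = h - (h \cdot \hat{r})\,\hat{r}$. Below is a minimal sketch in plain PyTorch. `steer_away_sketch` is a hypothetical stand-in, not the project's `steer_away` function, whose exact behavior may differ.

```python
import torch
import torch.nn.functional as F


def steer_away_sketch(h: torch.Tensor, direction: torch.Tensor) -> torch.Tensor:
    """Project activations h of shape (B, D) onto the subspace orthogonal to `direction`."""
    r = F.normalize(direction, dim=0)  # unit vector r̂
    coeff = h @ r  # (B,) component of each activation along r̂
    return h - coeff[:, None] * r  # remove that component


h = torch.tensor([[0.8, 0.6, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
red = torch.tensor([1.0, 0.0, 0.0, 0.0])
out = steer_away_sketch(h, red)
# out is [[0.0, 0.6, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
```

Note that this projection is global. It removes the red component from every activation, so a latent at $[-1,0,0,0]$ (opposite red on the wheel) would be mapped to the origin.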

In [None]:
import torch

from ex_color.intervention.intervention import InterventionConfig

RED_VECTOR = torch.tensor([1, 0, 0, 0], dtype=torch.float32)

ALL_INTERVENTIONS = [
    # InterventionConfig(...),
]

## Data

Data is the same as last time:
- Train: an HSV grid of 36 hues × 10 saturations × 10 values (3,600 colors, stored as RGB)
- Test: an 8×8×8 RGB cube (512 colors)
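The grid sizes also determine how many batches make up one pass over the training data, which gives a useful sanity check against the training log. This assumes `arange_cyclic(step_size=10 / 360)` yields 36 hues (0° to 350°, with 360° excluded as a duplicate of 0°). That assumption agrees with the `hsv_loader length: 57` reported in the log below.

```python
import math

n_hues = round(360 / 10)  # hue step of 10°, endpoint excluded (assumed)
n_sat, n_val = 10, 10  # np.linspace(0, 1, 10) for saturation and value
train_size = n_hues * n_sat * n_val
test_size = 8**3  # 8 steps per RGB channel

batch_size = 64
# The sampler draws len(dataset) samples per epoch, and the last partial batch is kept
batches_per_epoch = math.ceil(train_size / batch_size)

max_steps = 2001  # as reported in the training log
epochs = max_steps / batches_per_epoch

print(train_size, test_size, batches_per_epoch, round(epochs, 1))
# prints: 3600 512 57 35.1
```

So the 2001-step schedule corresponds to roughly 35 passes over the (resampled) training set.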

In [3]:
from functools import partial
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
import numpy as np

from ex_color.data.color_cube import ColorCube
from ex_color.data.cube_sampler import vibrancy
from ex_color.data.cyclic import arange_cyclic
from ex_color.labelling import collate_with_generated_labels


def prep_data() -> tuple[DataLoader, Tensor]:
    """
    Prepare data for training.

    Returns: (train_loader, test_colors)
    """
    hsv_cube = ColorCube.from_hsv(
        h=arange_cyclic(step_size=10 / 360),
        s=np.linspace(0, 1, 10),
        v=np.linspace(0, 1, 10),
    )
    hsv_tensor = torch.tensor(hsv_cube.rgb_grid.reshape(-1, 3), dtype=torch.float32)
    vibrancy_tensor = torch.tensor(vibrancy(hsv_cube).flatten(), dtype=torch.float32)
    hsv_dataset = TensorDataset(hsv_tensor, vibrancy_tensor)

    labeller = partial(
        collate_with_generated_labels,
        soft=False,  # Use binary labels (stochastic) to simulate the labelling of internet text
        scale={'red': 0.5, 'vibrant': 0.5},
    )
    # Desaturated and dark colors are over-represented in the cube, so we use a weighted sampler to balance them out
    hsv_loader = DataLoader(
        hsv_dataset,
        batch_size=64,
        sampler=WeightedRandomSampler(
            weights=hsv_cube.bias.flatten().tolist(),
            num_samples=len(hsv_dataset),
            replacement=True,
        ),
        collate_fn=labeller,
    )

    rgb_cube = ColorCube.from_rgb(
        r=np.linspace(0, 1, 8),
        g=np.linspace(0, 1, 8),
        b=np.linspace(0, 1, 8),
    )
    rgb_tensor = torch.tensor(rgb_cube.rgb_grid.reshape(-1, 3), dtype=torch.float32)
    return hsv_loader, rgb_tensor
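The weighted sampler draws indices with replacement, with probability proportional to their weights. Over-represented colors can therefore be down-weighted without being discarded. Below is a toy example with made-up weights. In `prep_data` the weights come from `hsv_cube.bias`.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Five hypothetical items: two over-represented items get a low weight,
# and item 4 gets zero weight, so it is never drawn.
weights = [0.1, 0.1, 1.0, 1.0, 0.0]
sampler = WeightedRandomSampler(
    weights=weights,
    num_samples=1000,
    replacement=True,
    generator=torch.Generator().manual_seed(0),  # reproducible draws
)
indices = list(sampler)
counts = torch.bincount(torch.tensor(indices), minlength=len(weights))
print(counts.tolist())
```

Items 2 and 3 should each be drawn about ten times as often as items 0 and 1.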

## Training

_Unlike_ earlier experiments, we now use PyTorch Lightning instead of our custom training loop. We also tried porting to Catalyst and Ignite, but Lightning was the closest match to the shape our training code had evolved into.

Functionally, not much has changed at this point, but now we should be able to take advantage of things like [Lightning's distributed processing support](https://lightning.ai/docs/pytorch/stable/api_references.html#strategies).

We have also switched to Modal for remote compute and Weights & Biases for experiment tracking. We tried running our own Aim experiment tracker instance too. It worked, but it was slow. We're not sure why; perhaps we hadn't configured the storage or networking properly. If you're curious, check out the [`aim` tag in the Git history](https://github.com/z0u/ex-color-transformer/tree/aim).

In [None]:
import wandb
from ex_color.inference import InferenceModule
from ex_color.intervention.intervention import InterventionConfig

wandb_api_key = wandb.Api().api_key


@run.thither
async def train(
    dopesheet: Dopesheet,
    regularizers: list[RegularizerConfig],
) -> ColorMLP:
    """Train the model with the given dopesheet and regularizers."""
    import lightning as L
    from lightning.pytorch.loggers import WandbLogger

    from ex_color.seed import set_deterministic_mode

    from utils.progress.lightning import LightningProgress

    log.info(f'Training with: {[r.name for r in regularizers]}')

    seed = 0
    set_deterministic_mode(seed)

    hsv_loader, rgb_tensor = prep_data()

    model = ColorMLP(4)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log.debug(f'Model initialized with {total_params:,} trainable parameters.')

    training_module = TrainingModule(model, dopesheet, torch.nn.MSELoss(), regularizers)

    wandb.login(key=wandb_api_key)
    logger = WandbLogger(experiment_name, project='ex-color-transformer')

    trainer = L.Trainer(
        max_steps=len(dopesheet),
        callbacks=[
            LightningProgress(),
        ],
        enable_checkpointing=False,
        enable_model_summary=False,
        # enable_progress_bar=True,
        logger=logger,
    )

    print(f'max_steps: {len(dopesheet)}, hsv_loader length: {len(hsv_loader)}')

    # Train the model
    trainer.fit(training_module, hsv_loader)
    # This is only a small model, so it's OK to return it rather than storing and loading a checkpoint remotely
    return model


@run.thither
async def infer(
    model: ColorMLP,
    interventions: list[InterventionConfig],
) -> Tensor:
    """Run inference with the given model and interventions."""
    import lightning as L

    _, rgb_tensor = prep_data()
    inference_module = InferenceModule(model, interventions)
    trainer = L.Trainer(
        enable_checkpointing=False,
        enable_model_summary=False,
        enable_progress_bar=True,
    )
    reconstructed_colors_batches = trainer.predict(
        inference_module, DataLoader(TensorDataset(rgb_tensor), batch_size=64)
    )
    assert reconstructed_colors_batches is not None
    # Each prediction batch is a (batch_size, 3) tensor; concatenate along the batch dimension.
    # (Splitting batches into individual rows before torch.cat would produce a flat 1-D tensor.)
    return torch.cat(reconstructed_colors_batches)


async with run():
    model = await train(Dopesheet.from_csv(f'./ex-{nbid}-dopesheet.csv'), ALL_REGULARIZERS)

    reconstructed_colors = await infer(model, ALL_INTERVENTIONS)
    log.info(f'Reconstructed {reconstructed_colors.shape[0]} colors.')

Seed set to 0
I 0.3 no.2.1:  Training with: ['reg-anchor', 'reg-separate', 'reg-planar', 'reg-unit-v', 'reg-unit']
I 0.3 ex.se:   PyTorch set to deterministic mode
wandb: No netrc file found, creating one.
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
wandb: Currently logged in as: z0r to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
max_steps: 2001, hsv_loader length: 57
wandb: Tracking run with wandb version 0.21.0
wandb: Run data is saved locally in ./wandb/run-20250807_084721-b25i121x
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run ex-color-2.1
wandb: ⭐️ View project at https://wandb.ai/z0r/ex-color-transformer
wandb: 🚀 View run at https://wandb.ai/z0r/ex-color-transformer/runs/b25i121x
/usr/local/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' d

## Suppression

Now that we have our model, let's try suppressing _red_. We'll use the `Suppression` intervention developed in [Ex 2.1](./ex-2.1-intervention-lobe.ipynb).

In [None]:
from math import cos, pi
import torch

from ex_color.intervention.bounded_falloff import BoundedFalloff
from ex_color.intervention.suppression import Suppression


intervention = Suppression(
    torch.tensor(RED, dtype=torch.float32),
    BoundedFalloff(cos(pi / 3), 1.0, 2),
)
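We don't reproduce `Suppression` or `BoundedFalloff` here. The following sketch shows one way such an intervention *could* behave, under two explicit assumptions. First, we read `BoundedFalloff(a, b, power)` as a weight that is 0 when the cosine similarity to the concept vector is at or below `a` (here $\cos(\pi/3) = 0.5$, i.e. 60°), 1 at or above `b`, and eased by `power` in between. Second, we read suppression as removing that fraction of the activation's component along the concept vector. `bounded_falloff_sketch` and `suppress_sketch` are hypothetical names, and the real implementations from Ex 2.1 may differ.

```python
from math import cos, pi, sin

import torch
import torch.nn.functional as F


def bounded_falloff_sketch(sim: torch.Tensor, lower: float, upper: float, power: float) -> torch.Tensor:
    """Map cosine similarity to a weight in [0, 1]: 0 at or below `lower`, 1 at or above `upper`."""
    t = ((sim - lower) / (upper - lower)).clamp(0.0, 1.0)
    return t**power  # power > 1 eases in gently just inside the boundary


def suppress_sketch(
    h: torch.Tensor, concept: torch.Tensor, lower: float, upper: float, power: float
) -> torch.Tensor:
    """Remove a similarity-dependent fraction of each activation's component along `concept`."""
    r = F.normalize(concept, dim=0)
    proj = h @ r  # (B,) signed component along the concept direction
    sim = proj / h.norm(dim=-1).clamp_min(1e-8)  # (B,) cosine similarity
    w = bounded_falloff_sketch(sim, lower, upper, power)
    return h - (w * proj)[:, None] * r


red = torch.tensor([1.0, 0.0, 0.0, 0.0])
h = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],  # pure red: fully suppressed
    [0.0, 1.0, 0.0, 0.0],  # 90° from red: outside the falloff, untouched
    [cos(pi / 6), sin(pi / 6), 0.0, 0.0],  # 30° from red: partially suppressed
])
out = suppress_sketch(h, red, cos(pi / 3), 1.0, 2)
```

Under these assumptions, and unlike a global orthogonal projection, activations more than 60° from red pass through unchanged. That matches the hypothesis that colors far from red should be mostly unaffected.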