# Experiment 2.1: Subspace ablation

In the [1.x series](/README.md#m1-preliminary-experiments-with-color) of experiments (Milestone 1), we validated our ideas for imposing structure on latent space. Providing only weak supervision, we were able to guide a simple RGB autoencoder to use the color wheel for its latent representations. In this series, we will attempt to inhibit and even delete certain concepts from the model.

To start with, let's take one of the earlier experiments and see what happens when we delete the _hue_ subspace.

In [1]:
# Basic setup: Logging, Experiment (Modal)
from __future__ import annotations
import logging

import modal

from infra.requirements import freeze, project_packages
from mini.experiment import Experiment
from utils.logging import SimpleLoggingConfig

# Build the logging configuration one step at a time, then activate it in this process.
logging_config = SimpleLoggingConfig()
logging_config = logging_config.info('notebook', 'utils', 'mini', 'ex_color', 'track')
logging_config = logging_config.error('matplotlib.axes')  # Silence warnings about set_aspect
logging_config.apply()

# Notebook ID, used for tagging assets
nbid = '2.1'
# Name given to the Experiment below (and reused later for the tracking run)
experiment_name = f'ex-color-{nbid}'
# Logger dedicated to this notebook
log = logging.getLogger(f'notebook.{nbid}')

run = Experiment(experiment_name)
# Container image: Debian slim + the frozen package list + this project's own packages.
run.image = (
    modal.Image.debian_slim()
    .pip_install(*freeze(all=True))
    .add_local_python_source(*project_packages())
)
# Register the same logging setup with the experiment.
# NOTE(review): presumably invoked before each remote call — confirm against Experiment.before_each.
run.before_each(logging_config.apply)
None  # prevent auto-display of this cell

In [2]:
import torch

from mini.temporal.dopesheet import Dopesheet
from ex_color.loss import Anchor, Separate, planarity, unitarity, RegularizerConfig
from ex_color.model import ColorMLP
from ex_color.training import TrainingModule


# Full set of regularizers used for training. Every entry targets the 'encoder' layer.
# NOTE(review): label_affinities=None presumably means "not restricted to any label" —
# confirm against RegularizerConfig.
ALL_REGULARIZERS = [
    # Anchor term tied to the 'red' label, with target vector [1, 0, 0, 0]
    RegularizerConfig(
        name='reg-polar',
        layer_affinities=['encoder'],
        label_affinities={'red': 1.0},
        compute_loss_term=Anchor(torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32)),
    ),
    # Separation term, no label restriction
    RegularizerConfig(
        name='reg-separate',
        layer_affinities=['encoder'],
        label_affinities=None,
        compute_loss_term=Separate(power=10.0, shift=False),
    ),
    # Planarity term tied to the 'vibrant' label
    RegularizerConfig(
        name='reg-planar',
        layer_affinities=['encoder'],
        label_affinities={'vibrant': 1.0},
        compute_loss_term=planarity,
    ),
    # Unitarity term tied to the 'vibrant' label
    RegularizerConfig(
        name='reg-norm-v',
        layer_affinities=['encoder'],
        label_affinities={'vibrant': 1.0},
        compute_loss_term=unitarity,
    ),
    # Unitarity term, no label restriction
    RegularizerConfig(
        name='reg-norm',
        layer_affinities=['encoder'],
        label_affinities=None,
        compute_loss_term=unitarity,
    ),
]

The data is the same as in the previous experiment:
- Train: an HSV cube (of RGB values)
- Test: an RGB cube

In [4]:
from functools import partial
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
import numpy as np

from ex_color.data.color_cube import ColorCube
from ex_color.data.cube_sampler import vibrancy
from ex_color.data.cyclic import arange_cyclic
from ex_color.labelling import collate_with_generated_labels


def prep_data() -> tuple[DataLoader, Tensor]:
    """
    Build the datasets used by this experiment.

    The training set is a color cube sampled on an HSV grid and stored as RGB
    triples, paired with a per-color vibrancy value. It is served through a
    weighted sampler and a collate function that generates labels per batch.
    The evaluation set is a flat tensor of RGB triples sampled at 8 evenly
    spaced values per channel.

    Returns: (train loader, evaluation tensor)
    """
    train_cube = ColorCube.from_hsv(
        h=arange_cyclic(step_size=10 / 360),
        s=np.linspace(0, 1, 10),
        v=np.linspace(0, 1, 10),
    )
    # Flatten the grid into rows of (r, g, b), with one vibrancy value per row
    train_colors = torch.tensor(train_cube.rgb_grid.reshape(-1, 3), dtype=torch.float32)
    train_vibrancy = torch.tensor(vibrancy(train_cube).flatten(), dtype=torch.float32)
    train_set = TensorDataset(train_colors, train_vibrancy)

    make_batch = partial(
        collate_with_generated_labels,
        soft=False,  # Use binary labels (stochastic) to simulate the labelling of internet text
        scale={'red': 0.5, 'vibrant': 0.5},
    )
    # Desaturated and dark colors are over-represented in the cube, so we use a weighted sampler to balance them out
    balanced_sampler = WeightedRandomSampler(
        weights=train_cube.bias.flatten().tolist(),
        num_samples=len(train_set),
        replacement=True,
    )
    train_loader = DataLoader(
        train_set,
        batch_size=64,
        sampler=balanced_sampler,
        collate_fn=make_batch,
    )

    # Three independent axes, one per channel
    red_axis, green_axis, blue_axis = (np.linspace(0, 1, 8) for _ in range(3))
    eval_cube = ColorCube.from_rgb(r=red_axis, g=green_axis, b=blue_axis)
    eval_colors = torch.tensor(eval_cube.rgb_grid.reshape(-1, 3), dtype=torch.float32)
    return train_loader, eval_colors

In [5]:
import wandb

# Read the W&B API key once at notebook level; `train` later passes it to `wandb.login`.
# NOTE(review): presumably needed because `train` runs remotely, where no local
# credentials exist — confirm.
wandb_api_key = wandb.Api().api_key


@run.thither
async def train(
    dopesheet: Dopesheet,
    regularizers: list[RegularizerConfig],
):
    """
    Train a fresh ColorMLP with the given dopesheet and regularizers.

    The run is seeded deterministically (seed 0), trained on the loader from
    `prep_data`, and logged to Weights & Biases under `experiment_name`. The
    W&B run is always closed before returning, even if training fails.

    Args:
        dopesheet: Passed to the TrainingModule; its length is used as the
            trainer's `max_steps`.
        regularizers: Regularizer configurations passed to the TrainingModule.

    Returns:
        None.
    """
    import lightning as L
    from lightning.pytorch.loggers import WandbLogger

    from ex_color.seed import set_deterministic_mode
    from utils.progress.lightning import LightningProgress

    log.info(f'Training with: {[r.name for r in regularizers]}')

    seed = 0
    set_deterministic_mode(seed)

    # Only the training loader is needed here; the evaluation tensor is unused.
    hsv_loader, _ = prep_data()

    # NOTE(review): 4 is presumably the latent width (the 'reg-polar' anchor is 4-D) — confirm.
    model = ColorMLP(4)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log.debug(f'Model initialized with {total_params:,} trainable parameters.')

    training_module = TrainingModule(model, dopesheet, torch.nn.MSELoss(), regularizers)

    # Authenticate with the key captured at notebook level.
    wandb.login(key=wandb_api_key)
    logger = WandbLogger(experiment_name, project='ex-color-transformer')

    trainer = L.Trainer(
        max_steps=len(dopesheet),
        callbacks=[
            LightningProgress(),
        ],
        enable_checkpointing=False,
        enable_model_summary=False,
        logger=logger,
    )

    print(f'max_steps: {len(dopesheet)}, hsv_loader length: {len(hsv_loader)}')

    try:
        # Train the model
        trainer.fit(training_module, hsv_loader)
    finally:
        # Close the W&B run explicitly so it is flushed and marked finished before
        # this function returns; wandb.finish() does nothing if no run is active.
        wandb.finish()


# Enter the experiment context and run training inside it. The dopesheet is loaded
# from a CSV named after this notebook's ID, relative to the working directory.
async with run():
    await train(Dopesheet.from_csv(f'./ex-{nbid}-dopesheet.csv'), ALL_REGULARIZERS)

I 0.2 no.2.1:  Training with: ['reg-polar', 'reg-separate', 'reg-planar', 'reg-norm-v', 'reg-norm']
I 0.2 ex.se:   PyTorch set to deterministic mode
Seed set to 0
wandb: No netrc file found, creating one.
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
wandb: Currently logged in as: z0r to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
max_steps: 2001, hsv_loader length: 57
wandb: Tracking run with wandb version 0.21.0
wandb: Run data is saved locally in ./wandb/run-20250805_032314-2dvogdp5
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run ex-color-2.1
wandb: ⭐️ View project at https://wandb.ai/z0r/ex-color-transformer
wandb: 🚀 View run at https://wandb.ai/z0r/ex-color-transformer/runs/2dvogdp5
/usr/local/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' do