# Experiment 2.11: Measure label coverage required for effective weak supervision

In earlier experiments, we have shown that weak supervision is sufficient to produce a well-structured latent space. Our regularizers are selectively applied to samples based on their labels, but the labels are noisy. For example, the color $(1,0,0)$ is labelled "red" only 50% of the time, and other colors may be labelled "red" too — only less often. But the choice of 50% was arbitrary, and we never really explored what proportion of samples actually needed to be labelled.

In this experiment, we'll train the color autoencoder several times with different label coverage, and see what proportion of the training data needs to be labelled.

In [1]:
from __future__ import annotations

# Identifiers shared by the rest of the notebook: they tag saved assets and name the experiment run.
nbid = '2.11'  # ID for tagging assets
nbname = 'Measure label coverage required for effective weak supervision'
experiment_name = f'Ex {nbid}: {nbname}'  # passed to Experiment(...) and WandbLogger(...) in later cells
project = 'ex-preppy'  # project name handed to the same two constructors

In [2]:
# Basic setup: Logging, Experiment (Modal)
import logging

import modal

from infra.requirements import freeze, project_packages
from mini.experiment import Experiment
from utils.logging import SimpleLoggingConfig

# Build the logging configuration once: it is applied immediately here and
# also registered with the experiment via run.before_each further down.
logging_config = (
    SimpleLoggingConfig()
    .info('notebook', 'utils', 'mini', 'ex_color')
    .error('matplotlib.axes')  # Silence warnings about set_aspect
)
logging_config.apply()

# This is the logger for this notebook
log = logging.getLogger(f'notebook.{nbid}')

# Experiment handle. Its image is a Debian-slim Modal image with the frozen
# requirements pip-installed and the project packages added as local source.
run = Experiment(experiment_name, project=project)
run.image = modal.Image.debian_slim().pip_install(*freeze(all=True)).add_local_python_source(*project_packages())
# Register the same logging setup as a before_each hook
# (presumably so every run re-applies it — confirm against Experiment).
run.before_each(logging_config.apply)
None  # prevent auto-display of this cell

## Regularizers

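- **Anchor:** pins `red` to $(1,0,0,0)$
- **Anti-anchor:** repels everything from $(-1,0,0,0)$ (its entry is commented out in the configuration below, so it is not active in this run)
- **Separate:** angular repulsion to reduce global clumping (applied within each batch)
- **Planar:** a planarity term applied only to samples labelled `vibrant`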

In [3]:
import torch

from mini.temporal.dopesheet import Dopesheet
from ex_color.loss import AngularAnchor, Separate, RegularizerConfig, Planarity

from ex_color.training import TrainingModule

K = 4  # bottleneck dimensionality
RED = (1, 0, 0, 0)  # vector handed to AngularAnchor for the 'red' label
ANTI_RED = tuple(-c for c in RED)  # element-wise negation of RED; only used by the disabled anti-anchor entry
assert len(RED) == len(ANTI_RED) == K  # both vectors need one component per bottleneck dimension

# Regularizers passed to train(). Entries that are commented out are disabled
# for this experiment; their names (Unitarity, AntiAnchor) are not imported in
# this cell, so re-enabling them also requires adding the imports.
ALL_REGULARIZERS = [
    # RegularizerConfig(
    #     name='reg-unit',
    #     compute_loss_term=Unitarity(),
    #     label_affinities=None,
    #     layer_affinities=['encoder'],
    # ),
    # Anchor: tied to the 'red' label, computed on the 'bottleneck' layer.
    RegularizerConfig(
        name='reg-anchor',
        compute_loss_term=AngularAnchor(torch.tensor(RED, dtype=torch.float32)),
        label_affinities={'red': 1.0},
        layer_affinities=['bottleneck'],
        phase=('train', 'validate'),
    ),
    # Separate: label_affinities=None (presumably "applies regardless of
    # label" — confirm against RegularizerConfig). No `phase` is passed, so
    # RegularizerConfig's default applies.
    RegularizerConfig(
        name='reg-separate',
        compute_loss_term=Separate(power=100.0, shift=True),
        label_affinities=None,
        layer_affinities=['bottleneck'],
    ),
    # RegularizerConfig(
    #     name='reg-anti-anchor',
    #     compute_loss_term=AntiAnchor(torch.tensor(ANTI_RED, dtype=torch.float32)),
    #     label_affinities=None,
    #     layer_affinities=['bottleneck'],
    # ),
    # Planarity: tied to the 'vibrant' label, computed on the 'bottleneck' layer.
    RegularizerConfig(
        name='reg-planar',
        compute_loss_term=Planarity(),
        label_affinities={'vibrant': 1.0},
        layer_affinities=['bottleneck'],
        phase=('train', 'validate'),
    ),
]

## Data

Data is still a color cube, but this time we'll use an RGB cube for training instead of the HSV cube: since HSV has an over-representation of black and white, it would skew the label counts.

In [4]:
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset, RandomSampler
import numpy as np

from ex_color.data.color_cube import ColorCube
from ex_color.data.cube_dataset import CubeDataset, redness, vibrancy, stochastic_labels, exact_labels

# from ex_color.data.cube_sampler import vibrancy
from ex_color.data.cyclic import arange_cyclic


def prep_data() -> DataLoader:
    """Build the training loader over a 10 x 10 x 10 RGB color cube.

    Every color in the cube gets two extra per-color values:
    ``redness ** 10 * 0.07`` under ``red`` and ``vibrancy ** 100 * 0.02``
    under ``vibrant`` (presumably label probabilities consumed by
    ``stochastic_labels`` — confirm against ``cube_dataset``).

    Returns:
        A ``DataLoader`` with batch size 64 and 4 workers that draws
        ``len(dataset)`` samples with replacement per pass and collates
        them with ``stochastic_labels``.
    """
    # One independent axis array per channel, each with 10 evenly spaced steps.
    r_axis, g_axis, b_axis = (np.linspace(0, 1, 10) for _ in range(3))
    cube = ColorCube.from_rgb(r=r_axis, g=g_axis, b=b_axis)

    red_weight = redness(cube['color']) ** 10 * 0.07
    vibrant_weight = vibrancy(cube['color']) ** 100 * 0.02
    labelled_cube = cube.assign(red=red_weight, vibrant=vibrant_weight)

    dataset = CubeDataset(labelled_cube)
    sampler = RandomSampler(dataset, num_samples=len(dataset), replacement=True)
    return DataLoader(
        dataset,
        batch_size=64,
        num_workers=4,
        sampler=sampler,
        collate_fn=stochastic_labels,
    )


def prep_val_data() -> DataLoader:
    """Build the validation loader over a 4 x 4 x 4 RGB color cube.

    Unlike the training data, the ``red`` and ``vibrant`` columns hold the
    result of comparing ``redness`` / ``vibrancy`` with exactly 1.

    Returns:
        A ``DataLoader`` with 4 workers that collates with ``exact_labels``.
        Neither a sampler nor a batch size is passed, so the ``DataLoader``
        defaults apply.
    """
    # One independent axis array per channel, each with 4 evenly spaced steps.
    r_axis, g_axis, b_axis = (np.linspace(0, 1, 4) for _ in range(3))
    cube = ColorCube.from_rgb(r=r_axis, g=g_axis, b=b_axis)

    is_red = redness(cube['color']) == 1
    is_vibrant = vibrancy(cube['color']) == 1
    dataset = CubeDataset(cube.assign(red=is_red, vibrant=is_vibrant))

    return DataLoader(
        dataset,
        # batch_size=len(dataset),
        num_workers=4,
        collate_fn=exact_labels,
    )

## Training

Like in Ex 2.2, the model is trained with PyTorch Lightning, with regularizers applied as custom hooks.

Unlike earlier experiments, the model now has two nonlinear activation functions in the encoder and decoder, to allow the latent space to be warped more.

In [12]:
import wandb
from ex_color.callbacks import LabelProportionCallback
from ex_color.model import CNColorMLP


# @run.thither(env={'WANDB_API_KEY': wandb.Api().api_key})
async def train(
    dopesheet: Dopesheet,
    regularizers: list[RegularizerConfig],
    k_bottleneck: int,
) -> CNColorMLP:
    """Fit a fresh ``CNColorMLP`` and return it.

    Args:
        dopesheet: Schedule handed to ``TrainingModule``; its length is also
            used as the trainer's ``max_steps``.
        regularizers: Regularizer configs handed to ``TrainingModule``.
        k_bottleneck: First argument to ``CNColorMLP`` (the notebook passes
            its bottleneck dimensionality ``K`` here).

    Returns:
        The same model object that was wrapped in the ``TrainingModule`` and
        fitted.
    """
    import lightning as L
    from lightning.pytorch.loggers import WandbLogger

    from ex_color.seed import set_deterministic_mode
    from utils.progress.lightning import LightningProgress

    log.info(f'Training with: {[r.name for r in regularizers]}')

    # Seeding happens before the loaders and the model are created.
    set_deterministic_mode(0)

    train_batches = prep_data()
    val_batches = prep_val_data()

    # NOTE(review): the notebook text says the model has two nonlinear
    # activations; confirm how n_nonlinear=1 maps onto that.
    net = CNColorMLP(k_bottleneck, n_nonlinear=1)
    n_trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
    log.debug(f'Model initialized with {n_trainable:,} trainable parameters.')

    module = TrainingModule(net, dopesheet, torch.nn.MSELoss(), regularizers)

    wandb_logger = WandbLogger(experiment_name, project=project)

    def current_labels():
        # Looked up lazily so the callback sees the module's labels at call time.
        return module.active_labels

    callbacks = [
        LightningProgress(),
        LabelProportionCallback(prefix='labels', get_active_labels=current_labels),
    ]
    trainer = L.Trainer(
        max_steps=len(dopesheet),
        callbacks=callbacks,
        enable_checkpointing=False,
        enable_model_summary=False,
        # enable_progress_bar=True,
        check_val_every_n_epoch=10,
        logger=wandb_logger,
        # Never ask for a logging interval longer than one pass over the loader.
        log_every_n_steps=min(50, len(train_batches)),
    )

    print(f'max_steps: {len(dopesheet)}, train_loader length: {len(train_batches)}')

    # Close the W&B run whether or not fitting succeeds.
    try:
        trainer.fit(module, train_batches, val_batches)
    finally:
        wandb.finish()
    # This is only a small model, so it's OK to return it rather than storing and loading a checkpoint remotely
    return net

## Inference utils

We wrap the model that we trained above in an `InferenceModule`. We won't be using its intervention features.


In [13]:
from ex_color.inference import InferenceModule


async def infer(
    model: CNColorMLP,
    test_data: Tensor,
) -> Tensor:
    """Reconstruct ``test_data`` with ``model``, preserving the input's shape.

    The input is flattened to rows of 3 values, predicted in batches of 64
    through a Lightning ``Trainer`` (progress bar on), and the outputs are
    concatenated and reshaped back to ``test_data.shape``.
    """
    import lightning as L

    def stack_rows(batch):
        # TensorDataset yields 1-tuples; unwrap each and stack along dim 0.
        return torch.stack([row[0] for row in batch], 0)

    # The second argument is an empty list: no interventions are configured.
    module = InferenceModule(model, [])
    trainer = L.Trainer(
        enable_checkpointing=False,
        enable_model_summary=False,
        enable_progress_bar=True,
    )
    loader = DataLoader(
        TensorDataset(test_data.reshape((-1, 3))),
        batch_size=64,
        collate_fn=stack_rows,
    )
    batches = trainer.predict(module, loader)
    assert batches is not None

    # Flatten the per-batch outputs into one list of per-sample tensors.
    rows = []
    for batch in batches:
        rows.extend(batch)
    # Join them and restore the caller's original layout.
    return torch.cat(rows).reshape(test_data.shape)

In [14]:
import torch
import numpy as np

from ex_color.inference import InferenceModule


async def infer_with_latent_capture(
    model: CNColorMLP,
    test_data: Tensor,
    layer_name: str = 'bottleneck',
) -> tuple[Tensor, Tensor]:
    """Reconstruct ``test_data`` and also return activations captured at ``layer_name``.

    Returns:
        A pair ``(reconstruction, latents)``. The reconstruction is reshaped
        to ``test_data.shape``; the latents are whatever
        ``InferenceModule.read_captured(layer_name)`` returns after
        prediction.
    """
    import lightning as L

    def stack_rows(batch):
        # TensorDataset yields 1-tuples; unwrap each and stack along dim 0.
        return torch.stack([row[0] for row in batch], 0)

    # No interventions (empty list); ask the module to capture one layer.
    capture_module = InferenceModule(model, [], capture_layers=[layer_name])
    trainer = L.Trainer(
        enable_checkpointing=False,
        enable_model_summary=False,
        enable_progress_bar=False,
    )
    loader = DataLoader(
        TensorDataset(test_data.reshape((-1, 3))),
        batch_size=64,
        collate_fn=stack_rows,
    )
    outputs = trainer.predict(capture_module, loader)
    assert outputs is not None

    # Flatten the per-batch outputs into one list of per-sample tensors.
    rows = []
    for batch in outputs:
        rows.extend(batch)
    reconstruction = torch.cat(rows).reshape(test_data.shape)

    # Read captured activations as a flat [N, D] tensor
    captured = capture_module.read_captured(layer_name)
    return reconstruction, captured

In [15]:
from IPython.display import clear_output

from ex_color.vis import plot_colors
from utils.nb import displayer_mpl


# Evaluation grids used by the cells that follow.

# Coarse HSV cube (hue step 1/24, 4 saturations, 8 values), permuted to 'svh'
# order; used for the color-patch plots.
hsv_cube = ColorCube.from_hsv(
    h=arange_cyclic(step_size=1 / 24),
    s=np.linspace(0, 1, 4),
    v=np.linspace(0, 1, 8),
).permute('svh')
x_hsv = torch.tensor(hsv_cube.rgb_grid, dtype=torch.float32)

# Dense HSV cube (hue step 1/240, 48 saturations, 48 values); used for the
# per-color reconstruction-error plot.
hd_hsv_cube = ColorCube.from_hsv(
    h=arange_cyclic(step_size=1 / 240),
    s=np.linspace(0, 1, 48),
    v=np.linspace(0, 1, 48),
)
hd_x_hsv = torch.tensor(hd_hsv_cube.rgb_grid, dtype=torch.float32)

# 20 x 20 x 20 RGB cube; used for the latent-space plot.
rgb_cube = ColorCube.from_rgb(
    r=np.linspace(0, 1, 20),
    g=np.linspace(0, 1, 20),
    b=np.linspace(0, 1, 20),
)
x_rgb = torch.tensor(rgb_cube.rgb_grid, dtype=torch.float32)

# Show the ground-truth colors. The alt text is a plain (non-f) string, so its
# "{title}" placeholder is left literal here (presumably filled in by
# displayer_mpl — confirm).
with displayer_mpl(
    f'large-assets/ex-{nbid}-true-colors.png',
    alt_text="""Plot showing four slices of the HSV cube, titled "{title}". Each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. The first slice shows a grayscale gradient from black to white; the last shows the fully-saturated color spectrum.""",
) as show:
    show(lambda: plot_colors(hsv_cube, title='True colors', colors=x_hsv.numpy()))

## Train

In [16]:
# Train once inside the experiment context. The dopesheet CSV is resolved
# relative to the current working directory and named after this notebook's ID.
async with run():
    model_no_unit = await train(Dopesheet.from_csv(f'./ex-{nbid}-dopesheet.csv'), ALL_REGULARIZERS, K)

I 100.9 no.2.11:Training with: ['reg-anchor', 'reg-separate', 'reg-planar']


INFO: Seed set to 0


I 100.9 li.fa.ut.se:Seed set to 0
I 100.9 ex.se: PyTorch set to deterministic mode


INFO: GPU available: False, used: False


I 101.1 li.py.ut.ra:GPU available: False, used: False


INFO: TPU available: False, using: 0 TPU cores


I 101.1 li.py.ut.ra:TPU available: False, using: 0 TPU cores


INFO: HPU available: False, using: 0 HPUs


I 101.1 li.py.ut.ra:HPU available: False, using: 0 HPUs
max_steps: 1501, train_loader length: 16


Starting phase: Train


INFO: `Trainer.fit` stopped: `max_steps=1501` reached.


I 117.6 li.py.ut.ra:`Trainer.fit` stopped: `max_steps=1501` reached.
I 117.6 ex.ca.la:Label frequencies (n=89064): _any: 0.162%, red: 0.049%, vibrant: 0.113%


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇████
labels/_any,▁
labels/epoch/_any,█▃▃▃▂▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂
labels/epoch/red,█▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
labels/epoch/vibrant,▅▁▅▅█▄▄▃▄▃▃▃▃▄▅▅▅▅▅▅▅▅▆▆▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆
labels/red,▁
labels/vibrant,▁
train_loss,█▅▂▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_recon,█▁▁▁▁▁▁▁▁▁▁▁▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_reg-anchor,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,92.0
labels/_any,0.00162
labels/epoch/_any,0.00162
labels/epoch/red,0.00049
labels/epoch/vibrant,0.00113
labels/red,0.00049
labels/vibrant,0.00113
train_loss,6e-05
train_recon,6e-05
train_reg-anchor,0.0


In [17]:
from IPython.display import clear_output
from torch.nn import functional as F

from ex_color.vis import plot_colors, plot_cube_series


interventions = []
y_hsv = await infer(model_no_unit, x_hsv)
hd_y_hsv = await infer(model_no_unit, hd_x_hsv)
clear_output()

with displayer_mpl(
    f'large-assets/ex-{nbid}-pred-colors-no-intervention.png',
    alt_text="""Plot showing four slices of the HSV cube, titled "{title}". Nominally, each slice has constant saturation, but varies in value (brightness) from top to bottom, and in hue from left to right. Each color value is represented as a square patch of that color. The outer portion of the patches shows the color as reconstructed by the model; the inner portion shows the true (input) color.""",
) as show:
    show(
        lambda: plot_colors(
            hsv_cube,
            title='Predicted colors · no intervention',
            colors=y_hsv.numpy(),
            colors_compare=x_hsv.numpy(),
        )
    )

per_color_loss = F.mse_loss(hd_y_hsv, hd_x_hsv, reduction='none').mean(dim=-1)
loss_cube = hd_hsv_cube.assign('MSE', per_color_loss.numpy().reshape(hd_hsv_cube.shape))
max_loss = per_color_loss.max().item()
median_loss = per_color_loss.median().item()

with displayer_mpl(
    f'large-assets/ex-{nbid}-loss-colors-no-intervention.png',
    alt_text=f"""Line chart showing loss per color, titled "{{title}}". Y-axis: mean square error, ranging from zero to {max_loss:.2g}. X-axis: hue.""",
) as show:
    show(
        lambda: plot_cube_series(
            loss_cube.permute('hsv')[:, -1:, :: (loss_cube.shape[2] // -5)],
            loss_cube.permute('svh')[:, -1:, :: -(loss_cube.shape[0] // -3)],
            loss_cube.permute('vsh')[:, -1:, :: -(loss_cube.shape[0] // -3)],
            title='Reconstruction error · no intervention',
            var='MSE',
            figsize=(12, 3),
        )
    )
print(f'Max loss: {max_loss:.2g}')
print(f'Median MSE: {median_loss:.2g}')

Max loss: 0.0015
Median MSE: 4.3e-05


In [18]:
from IPython.display import clear_output

from ex_color.vis import plot_latent_grid_3d

# Run the RGB grid through the model, capturing the 'bottleneck' activations
# alongside the reconstructions.
y_rgb, h_rgb = await infer_with_latent_capture(model_no_unit, x_rgb, 'bottleneck')
clear_output()

with displayer_mpl(
    f'large-assets/ex-{nbid}-latents-no-intervention.png',
    alt_text="""Three spherical plots, titled "{title}". Each plot shows a vibrant collection of colored circles or balls scattered over the surface of a sphere.""",
) as show:
    show(
        lambda theme: plot_latent_grid_3d(
            h_rgb,
            y_rgb,
            x_rgb,
            title='Latents · no intervention',
            # Three index triples drawn from the K = 4 bottleneck dimensions
            # (presumably one axis assignment per subplot — confirm against
            # plot_latent_grid_3d).
            dims=[(1, 0, 2), (1, 2, 0), (3, 2, 0)],
            dot_radius=10,
            theme=theme,
        )
    )

## Conclusion
