# Superposition by embedding bottleneck

This is an experiment to squash higher-dimensional data into a lower-dimensional embedding space. We'll start with color: RGB values (3 dimensions) ranging from 0..1. If we compress them into a 2D embedding, we should expect to see superposition, with directions interpretable as they would be in a classic color wheel: three primary color directions (RGB) spaced 120° apart.


In [1]:
from __future__ import annotations

In [2]:
import logging
from utils.logging import SimpleLoggingConfig

# Configure logging for this notebook and the project packages it relies on.
# NOTE(review): presumably `.info(...)` sets the named loggers to INFO level — confirm in utils.logging.
logging_config = SimpleLoggingConfig().info('notebook', 'utils', 'mini', 'ex_color')
logging_config.apply()

# This is the logger for this notebook; the cells below log through it as `log`.
log = logging.getLogger('notebook')

## Simple MLP with bottleneck

We'll train a simple 2-layer MLP with a low-dimensional bottleneck to map RGB values (inputs like [1.0, 0.0, 0.0] for red) back to RGB values. This forces colors through a small embedding space, where we expect to see the superposition effect.

In [3]:
import torch
import torch.nn as nn

E = 4

class ColorMLP(nn.Module):
    def __init__(self, normalize_bottleneck=False):
        super().__init__()
        # RGB input (3D) → hidden layer → bottleneck → hidden layer → RGB output
        self.encoder = nn.Sequential(
            nn.Linear(3, 16),
            nn.GELU(),
            # nn.Linear(16, 16),
            # nn.GELU(),
            nn.Linear(16, E),  # Our critical bottleneck!
        )

        self.decoder = nn.Sequential(
            nn.Linear(E, 16),
            nn.GELU(),
            # nn.Linear(16, 16),
            # nn.GELU(),
            nn.Linear(16, 3),
            nn.Sigmoid(),  # Keep RGB values in [0,1]
        )

        self.normalize = normalize_bottleneck

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Get our bottleneck representation
        bottleneck = self.encoder(x)

        # Optionally normalize to unit vectors (like nGPT)
        if self.normalize:
            norm = torch.norm(bottleneck, dim=1, keepdim=True)
            bottleneck = bottleneck / (norm + 1e-8)  # Avoid division by zero

        # Decode back to RGB
        output = self.decoder(bottleneck)
        return output, bottleneck

## Training machinery

In [4]:
from dataclasses import dataclass
from typing import Protocol, runtime_checkable
from torch import Tensor
import torch.optim as optim

from mini.temporal.timeline import State


@dataclass
class InferenceResult:
    """The two tensors produced by one forward pass: decoded outputs and bottleneck latents."""

    outputs: Tensor
    latents: Tensor

    def _map(self, fn):
        """Return a new result with `fn` applied to both tensors."""
        return InferenceResult(fn(self.outputs), fn(self.latents))

    def detach(self):
        """Return a copy whose tensors are detached from the autograd graph."""
        return self._map(lambda t: t.detach())

    def clone(self):
        """Return a copy whose tensors are clones of the originals."""
        return self._map(lambda t: t.clone())

    def cpu(self):
        """Return a copy whose tensors have been moved to the CPU."""
        return self._map(lambda t: t.cpu())


@runtime_checkable
class LossCriterion(Protocol):
    """
    Structural type for a loss or regularization term.

    The training loop calls a criterion as `criterion(batch, result)`, then
    multiplies the returned tensor by the term's scheduled weight and adds it
    to the total loss.
    """

    def __call__(self, data: Tensor, res: InferenceResult) -> Tensor: ...


@runtime_checkable
class SpecialLossCriterion(LossCriterion, Protocol):
    """
    A loss criterion that runs its own forward pass.

    The training loop first calls `forward(model, batch)`. If that returns
    `None` the term is skipped for the step; otherwise the returned result is
    passed to `__call__` in place of the regular batch result.

    NOTE: `isinstance()` against a runtime-checkable protocol only checks that
    the members exist, so any object that has both `forward` and `__call__`
    (an `nn.Module`, for example) is treated as a special criterion.
    """

    def forward(self, model: ColorMLP, data: Tensor) -> InferenceResult | None: ...


@dataclass(eq=False, frozen=True)
class Event:
    """
    Snapshot of the training context handed to event handlers.

    `frozen=True` means fields cannot be reassigned after construction;
    `eq=False` means no field-wise `__eq__` is generated.
    """

    # Qualified event name, e.g. 'pre-step', 'phase-start:<phase>' or 'action:<action>'
    name: str
    # Index of the training step during which the event was emitted
    step: int
    # The model being trained (the live object, not a copy)
    model: ColorMLP
    # Timeline state for the step in which the event was emitted
    timeline_state: State
    # The optimizer driving the training loop
    optimizer: optim.Optimizer


@dataclass(eq=False, frozen=True)
class PhaseEndEvent(Event):
    """Event emitted on the last step of a phase, carrying validation results."""

    # The validation tensor registered for the phase that just ended
    validation_data: Tensor
    # Model outputs and latents for `validation_data`, computed under `torch.no_grad()`
    inference_result: InferenceResult


class EventHandler[T](Protocol):
    """Structural type for a callback that accepts one event of type `T` and returns nothing."""

    def __call__(self, event: T) -> None: ...


class EventBinding[T]:
    """A class to bind events to handlers."""

    def __init__(self, event_name: str):
        self.event_name = event_name
        self.handlers: list[tuple[str, EventHandler[T]]] = []

    def add_handler(self, event_name: str, handler: EventHandler[T]) -> None:
        self.handlers.append((event_name, handler))

    def emit(self, event_name: str, event: T) -> None:
        for name, handler in self.handlers:
            if name == event_name:
                handler(event)


class EventHandlers:
    """A simple event system to allow for custom callbacks."""

    # Emitted by the training loop on steps whose timeline state has `is_phase_start` set
    phase_start: EventBinding[Event]
    # Emitted on every step, before the optimizer update
    pre_step: EventBinding[Event]
    # Emitted once per action listed in the step's timeline state
    action: EventBinding[Event]
    # Emitted after the update on steps whose timeline state has `is_phase_end` set
    phase_end: EventBinding[PhaseEndEvent]

    def __init__(self):
        # Each binding starts with no handlers; register them with `add_handler`
        self.phase_start = EventBinding[Event]('phase-start')
        self.pre_step = EventBinding[Event]('pre-step')
        self.action = EventBinding[Event]('action')
        self.phase_end = EventBinding[PhaseEndEvent]('phase-end')

In [5]:
from typing import Iterable, Iterator
from torch.utils.data import DataLoader
import torch.optim as optim

from mini.temporal.dopesheet import Dopesheet
from mini.temporal.timeline import Timeline
from utils.progress import RichProgress


def reiterate[T](it: Iterable[T]) -> Iterator[T]:
    """
    Iterates over an iterable indefinitely.

    When the iterable is exhausted, it starts over from the beginning. Unlike
    `itertools.cycle`, yielded values are not cached — so each iteration may be
    different.
    """
    while True:
        yield from it


def train_color_model(  # noqa: C901
    model: ColorMLP,
    datasets: dict[str, tuple[DataLoader, Tensor]],
    dopesheet: Dopesheet,
    loss_criteria: dict[str, LossCriterion | SpecialLossCriterion],
    event_handlers: EventHandlers | None = None,
) -> None:
    """
    Train `model` by following the schedule defined in `dopesheet`.

    For every step of the timeline this:

    1. draws a batch from the data loader of the current phase;
    2. emits `phase-start`, `action` and `pre-step` events as applicable;
    3. sets the optimizer learning rate from the scheduled `lr` property;
    4. sums the loss terms in `loss_criteria`, each scaled by the scheduled
       property of the same name (terms with weight 0 or no property are skipped);
    5. back-propagates and steps the optimizer if the total loss is positive;
    6. on the last step of a phase, runs the phase's validation tensor through
       the model and emits a `phase-end` event with the results.

    Args:
        model: The model to train, updated in place.
        datasets: Maps phase name to `(train loader, validation tensor)`.
            Loaders are expected to yield one-element tuples.
        dopesheet: The schedule; must define an `lr` property.
        loss_criteria: Maps a dopesheet property name to its loss term.
        event_handlers: Optional callbacks; defaults to no handlers.

    Raises:
        ValueError: If a dopesheet phase has no dataset, or `lr` is not defined.
    """
    if event_handlers is None:
        event_handlers = EventHandlers()

    # --- Validate inputs ---
    # Check if all phases in dopesheet have corresponding data
    dopesheet_phases = dopesheet.phases
    missing_data = dopesheet_phases - set(datasets.keys())
    if missing_data:
        raise ValueError(f'Missing data for dopesheet phases: {missing_data}')

    # Check if 'lr' is defined in the dopesheet properties
    if 'lr' not in dopesheet.props:
        raise ValueError("Dopesheet must define the 'lr' property column.")
    # --- End Validation ---

    timeline = Timeline(dopesheet)
    # The learning rate is overwritten from the schedule on every step
    optimizer = optim.Adam(model.parameters(), lr=0)
    device = next(model.parameters()).device

    # One endless iterator per phase, so a phase can run longer than one epoch
    data_iterators = {
        phase_name: reiterate(dataloader)  #
        for phase_name, (dataloader, _) in datasets.items()
    }

    total_steps = len(timeline)

    with RichProgress(total=total_steps, description='Training Steps') as pbar:
        for step in range(total_steps):
            # Get state *before* advancing timeline for this step's processing
            current_state = timeline.state
            current_phase_name = current_state.phase

            # Assuming TensorDataset yields a tuple with one element
            (batch,) = next(data_iterators[current_phase_name])
            # Move the batch once, so the forward pass and every loss criterion
            # see tensors on the same device as the model
            batch = batch.to(device)

            # --- Event Handling ---
            event_template = {
                'step': step,
                'model': model,
                'timeline_state': current_state,
                'optimizer': optimizer,
            }

            # Each event is emitted twice: under its qualified name
            # (e.g. 'phase-start:<phase>') and under its generic name
            if current_state.is_phase_start:
                event = Event(name=f'phase-start:{current_phase_name}', **event_template)
                event_handlers.phase_start.emit(event.name, event)
                event_handlers.phase_start.emit('phase-start', event)

            for action in current_state.actions:
                event = Event(name=f'action:{action}', **event_template)
                event_handlers.action.emit(event.name, event)
                event_handlers.action.emit('action', event)

            event = Event(name='pre-step', **event_template)
            event_handlers.pre_step.emit('pre-step', event)

            # --- Training Step ---
            current_lr = current_state.props['lr']
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            optimizer.zero_grad()

            outputs, latents = model(batch)
            current_results = InferenceResult(outputs, latents)

            total_loss = torch.tensor(0.0, device=device)
            losses_dict: dict[str, float] = {}
            for name, criterion in loss_criteria.items():
                weight = current_state.props.get(name, 0.0)
                if weight == 0:
                    continue

                if isinstance(criterion, SpecialLossCriterion):
                    # Special criteria might run on their own data (like Anchor)
                    # or potentially use the current batch (depends on implementation).
                    # The forward method gets the model and the *current batch*
                    special_results = criterion.forward(model, batch)
                    if special_results is None:
                        continue
                    term_loss = criterion(batch, special_results)
                else:
                    term_loss = criterion(batch, current_results)

                total_loss += term_loss * weight
                losses_dict[name] = term_loss.item()

            # Skipped when no term contributed: the zero tensor has no graph to back-propagate
            if total_loss > 0:
                total_loss.backward()
                optimizer.step()
            # --- End Training Step ---

            # --- Post-Step Event Handling ---
            if current_state.is_phase_end:
                # Trigger phase-end for the *current* phase
                _, validation_data = datasets[current_phase_name]
                with torch.no_grad():
                    val_outputs, val_latents = model(validation_data.to(device))
                event = PhaseEndEvent(
                    name=f'phase-end:{current_phase_name}',
                    **event_template,
                    validation_data=validation_data,
                    inference_result=InferenceResult(val_outputs, val_latents),
                )
                event_handlers.phase_end.emit(event.name, event)
                event_handlers.phase_end.emit('phase-end', event)
            # --- End Event Handling ---

            # Update progress bar
            pbar.update(
                metrics={
                    'PHASE': current_phase_name,
                    'lr': f'{current_lr:.6f}',
                    'loss': f'{total_loss.item():.4f}',
                    **{name: f'{lt:.4f}' for name, lt in losses_dict.items()},
                },
            )

            # Advance timeline *after* processing the current step, except on
            # the final step, where there is no further state to advance to
            if step + 1 < total_steps:
                timeline.step()

    log.info('Training finished!')

## Visualization

In [6]:
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
from torch import Tensor

from utils.nb import save_fig


class PhasePlotter:
    """
    Event handler to plot latent space at the end of each phase.

    Each call appends the phase's validation data and inference result to
    `history`, then redraws one row of scatter plots per recorded phase and
    one column per entry in `dim_pairs`.
    """

    def __init__(self, dim_pairs: list[tuple[int, int]] | None = None):
        """
        Args:
            dim_pairs: Pairs of latent dimensions to plot against each other.
                Defaults to `[(0, 1), (0, 2)]`.
        """
        from utils.nb import displayer

        # Store (phase_name, end_step, data, result) - data comes from event now
        self.history: list[tuple[str, int, Tensor, InferenceResult]] = []
        self.display = displayer()
        self.dim_pairs = dim_pairs or [(0, 1), (0, 2)]

    # Expect PhaseEndEvent specifically
    def __call__(self, event: PhaseEndEvent):
        """
        Handle phase-end events.

        Raises:
            TypeError: If `event` is not a `PhaseEndEvent`.
        """
        # Imported here so the handler does not depend on another cell having
        # imported `HTML` before it is first called
        from IPython.display import HTML

        if not isinstance(event, PhaseEndEvent):
            raise TypeError(f'Expected PhaseEndEvent, got {type(event)}')

        # TODO: Don't assume device = CPU
        # TODO: Split this class so that the event handler is separate from the plotting, and so the plotting can happen locally with @run.hither
        phase_name = event.timeline_state.phase
        end_step = event.step
        phase_dataset = event.validation_data
        inference_result = event.inference_result

        log.info(f'Plotting end of phase: {phase_name} at step {end_step} using provided results.')

        # Append to history. Detach first, so the later `.numpy()` conversion
        # also works for tensors that still track gradients
        self.history.append((phase_name, end_step, phase_dataset.detach().cpu(), inference_result.detach().cpu()))

        fig = self._plot_phase_history()
        try:
            self.display(
                HTML(
                    save_fig(
                        fig,
                        'assets/ex-1.5-color-phase-history.png',
                        alt_text='Visualizations of latent space at the end of each curriculum phase.',
                    )
                )
            )
        finally:
            # A new figure is created on every call; release it once it has been saved
            plt.close(fig)

    def _plot_phase_history(self):
        """Build a figure with one row per recorded phase and one column per dimension pair."""
        num_phases = len(self.history)
        plt.style.use('dark_background')
        if num_phases == 0:
            # Nothing recorded yet: return a placeholder figure
            fig, ax = plt.subplots()
            fig.set_facecolor('#333')
            ax.set_facecolor('#222')
            ax.text(0.5, 0.5, 'Waiting...', ha='center', va='center')
            return fig

        num_cols = len(self.dim_pairs)
        fig, axes = plt.subplots(num_phases, num_cols, figsize=(5 * num_cols, 5 * num_phases), squeeze=False)
        fig.set_facecolor('#333')

        for row_idx, (phase_name, end_step, data, res) in enumerate(self.history):
            _latents = res.latents.numpy()
            # Each point is drawn in the color it encodes
            _colors = data.numpy()

            for col_idx, (dim1, dim2) in enumerate(self.dim_pairs):
                ax = axes[row_idx, col_idx]
                ax.set_facecolor('#222')
                ax.scatter(_latents[:, dim1], _latents[:, dim2], c=_colors, s=50, alpha=0.7)

                if row_idx == 0:
                    ax.set_title(f'Dims {dim1} vs {dim2}')
                ax.set_xlabel(f'Dim {dim1}')

                if col_idx == 0:
                    # The first column's y-label doubles as the row header
                    ax.set_ylabel(
                        f'Phase: {phase_name}\n(End Step: {end_step})',
                        fontsize='medium',
                        rotation=0,
                        labelpad=40,
                        verticalalignment='center',
                    )
                    ax.yaxis.set_label_coords(-0.2, 0.5)
                    ax.set_yticks([])
                else:
                    ax.set_ylabel(f'Dim {dim2}')

                # Reference guides: the axes through the origin and the unit circle
                ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
                ax.axvline(x=0, color='gray', linestyle='--', alpha=0.5)
                ax.add_patch(Circle((0, 0), 1, fill=False, linestyle='--', color='gray', alpha=0.3))
                ax.set_aspect('equal')

        fig.tight_layout()
        return fig

## Curriculum

In order to get a somewhat predictable orientation in latent space, we'll use curriculum learning:

| Phase | Training Data              | Regularization                       |
| ----- | -------------------------- | ------------------------------------ |
| 1     | Primary & secondary colors | Separation, normalization, planarity |
| 2     | All pure hues              | Anchor, normalization                |
| 3     | Full color space           | Anchor, normalization                |

With this curriculum, we expect to see a strong, well-formed color wheel with just the hues. As we add the darker tones, it should more-or-less retain its shape, but the darker tones should appear as smaller rings further from the _hue_ plane. Viewed from the side, it should start to resemble a dome.

The constraints are applied as regularization terms. These ultimately form part of the loss function, but are calculated from the latent HSV embeddings rather than the reconstructed RGB outputs.

- Separation: Penalizes close points, to encourage regular spacing of the primary and secondary colors.
- Normalization: This uses the L2 norm in two steps:
  - First calculates the L2 norm of each latent vector
  - Then applies mean squared error between these norms and 1.0
  - This encourages all latent vectors to lie on a unit sphere.
- Planar constraint: This is also L2-based:
  - Takes the mean of the squared values (i.e. an L2-style penalty, not absolute values) of the dimensions after the first two
  - This pushes the later dimensions toward zero, encouraging a planar representation in the first two dimensions.
- Anchor: Penalizes _specific_ points if they move after a certain step in training. This is used to keep the primary and secondary colors in place while the curriculum continues.

To reduce disruption to the embeddings learnt in the earlier phases, we'll use smoothly-varying parameters for the learning rate and the weights of each regularization term. The full curriculum is defined as keyframes in [a dopesheet](ex-1.5-dopesheet.csv).


In [7]:
from IPython.display import display, HTML, Markdown
from matplotlib import pyplot as plt

from mini.temporal.vis import group_properties_by_scale, plot_timeline, realize_timeline
from mini.temporal.dopesheet import Dopesheet
from mini.temporal.timeline import Timeline

# Load the curriculum keyframes and show them as a Markdown table
dopesheet = Dopesheet.from_csv('ex-1.5-dopesheet.csv')
schedule_md = f"""
## Parameter schedule
{dopesheet.to_markdown()}
"""
display(Markdown(schedule_md))

# Expand the keyframes into a per-step history, for plotting next to the keyframes
timeline = Timeline(dopesheet)
history_df = realize_timeline(timeline)
keyframes_df = dopesheet.as_df()

# Plot the schedule, save the figure, and embed the saved image in the notebook
groups = group_properties_by_scale(keyframes_df[dopesheet.props])
fig, ax = plot_timeline(history_df, keyframes_df, groups)
timeline_html = save_fig(
    fig,
    'assets/ex-1.5-color-timeline.png',
    alt_text='Line chart showing the hyperparameter schedule over time.',
)
display(HTML(timeline_html))


## Parameter schedule
|   STEP | PHASE               | ACTION   |      lr |   loss-recon |   reg-separate |   reg-planar |   reg-norm |   reg-anchor |   data-fraction |
|-------:|:--------------------|:---------|--------:|-------------:|---------------:|-------------:|-----------:|-------------:|----------------:|
|      0 | Primary & secondary |          |         |          1   |            0   |          0.2 |            |              |                 |
|   1200 |                     |          |         |          0.8 |            0.3 |              |            |              |                 |
|   1800 |                     |          |         |              |                |          0.4 |       0.1  |              |                 |
|   3000 | All hues            | anchor   |         |              |            0   |              |       0.25 |         0    |            0    |
|   3350 |                     |          |   0.01  |              |                |              |            |              |                 |
|   6500 |                     |          |         |          0.8 |                |          0   |            |              |                 |
|   8600 |                     |          |         |              |                |              |            |         0.3  |                 |
|  10000 | Full color space    |          |         |              |                |              |            |              |            0.25 |
|  10500 |                     |          |         |              |                |              |            |              |                 |
|  13000 |                     |          |         |          1   |                |              |            |         0.1  |            1    |
|  20000 |                     |          |   0.001 |              |                |              |            |         0.75 |                 |


I 4.8 ut.nb:   Figure saved: 'assets/ex-1.5-color-timeline.png'


In [8]:
from torch import linalg as LA

from ex_color.data.color_cube import ColorCube
from ex_color.data.cyclic import arange_cyclic


def objective(fn):
    """
    Wrap a plain loss function so it has the same call shape as the regularizers.

    The returned callable takes `(data, res)` and calls `fn(data, res.outputs)`.
    """

    def criterion(data: Tensor, res: InferenceResult) -> Tensor:
        reconstructed = res.outputs
        return fn(data, reconstructed)

    return criterion


def unitary(data: Tensor, res: InferenceResult) -> Tensor:
    """Penalize latents whose L2 norm differs from 1: mean squared error between the norms and 1.0."""
    lengths = LA.vector_norm(res.latents, dim=-1)
    deviation = lengths - 1.0
    return (deviation**2).mean()


def planarity(data: Tensor, res: InferenceResult) -> Tensor:
    """
    Regularize latents to be planar in the first two channels (so zero in other channels).

    Returns the mean squared value of every latent channel after the first
    two. If there are no such values (latents with two or fewer channels, or
    an empty batch), returns zero rather than the NaN that the mean of an
    empty tensor would give.
    """
    extra = res.latents[:, 2:]
    if extra.numel() == 0:
        # Nothing to penalize; a NaN here would poison the summed training loss
        return res.latents.new_zeros(())
    return torch.mean(extra**2)


class Separate(LossCriterion):
    """
    Regularizer that pushes latent points apart in the selected channels.

    Args:
        channels: Indices of the latent channels in which to measure distance.
    """

    def __init__(self, channels: tuple[int, ...] = (0, 1)):
        self.channels = channels

    def __call__(self, data: Tensor, res: InferenceResult) -> Tensor:
        """
        Regularize latents to be separated from each other in the selected channels.

        Returns the sum of inverse squared distances over all ordered pairs of
        *distinct* points, divided by B² (B = batch size). Self-pairs are
        excluded from the sum: zeroing their distance and then inverting it
        would add 1/epsilon per point, which swamps the real terms in the
        reported loss. Dividing by B² rather than B·(B−1) keeps the gradient
        scale the same as when self-pairs were counted in the mean.
        """
        # Get pairwise differences in the selected channels
        points = res.latents[:, self.channels]  # [B, C]
        diffs = points.unsqueeze(1) - points.unsqueeze(0)  # [B, B, C]

        # Calculate squared distances
        sq_dists = torch.sum(diffs**2, dim=-1)  # [B, B]

        # Select everything except the self-distances (diagonal)
        num_points = sq_dists.shape[0]
        off_diagonal = ~torch.eye(num_points, dtype=torch.bool, device=sq_dists.device)

        # Encourage separation by minimizing inverse distances (stronger repulsion between close points)
        epsilon = 1e-6  # Prevent division by zero
        repulsion = 1.0 / (sq_dists[off_diagonal] + epsilon)

        # max(..., 1) keeps an empty batch at zero instead of 0/0
        return repulsion.sum() / max(sq_dists.numel(), 1)


class Anchor(SpecialLossCriterion):
    """
    Penalize the latents of a fixed set of reference points for drifting from a snapshot.

    The snapshot is taken by `on_anchor`; until then the loss is zero.
    """

    ref_data: Tensor
    _ref_latents: Tensor | None = None

    def __init__(self, ref_data: Tensor):
        self._ref_latents = None
        self.ref_data = ref_data
        log.info(f'Anchor initialized with reference data shape: {ref_data.shape}')

    def _ref_data_for(self, model: ColorMLP) -> Tensor:
        """Return the reference data on the device that holds `model`'s parameters."""
        model_device = next(model.parameters()).device
        return self.ref_data.to(model_device)

    def forward(self, model: ColorMLP, data: Tensor) -> InferenceResult | None:
        """Run the stored reference data (not `data`, which is ignored) through the current model."""
        outputs, latents = model(self._ref_data_for(model))
        return InferenceResult(outputs, latents)

    def __call__(self, data: Tensor, special: InferenceResult) -> Tensor:
        """Mean squared difference between the current reference latents and the snapshot."""
        snapshot = self._ref_latents
        if snapshot is not None:
            drift = special.latents - snapshot.to(special.latents.device)
            return torch.mean(drift**2)

        # No snapshot yet (on_anchor has not run), so there is nothing to
        # compare against; contribute zero instead of raising
        log.debug('Anchor.__call__ invoked before reference latents captured. Returning zero loss.')
        return torch.tensor(0.0, device=special.latents.device)

    def on_anchor(self, event: Event):
        """Event handler: snapshot the reference points' latents under the event's model."""
        log.info(f'Capturing anchor latents via Anchor.on_anchor at step {event.step}')

        ref_data = self._ref_data_for(event.model)
        with torch.no_grad():
            _, latents = event.model(ref_data)

        # Keep the snapshot on the CPU; it is moved to the right device when compared
        self._ref_latents = latents.detach().cpu()
        log.info(f'Anchor state captured internally. Ref data: {ref_data.shape}, Ref latents: {latents.shape}')

In [9]:
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from ex_color.data.cube_sampler import DynamicWeightedRandomBatchSampler, vibrancy
from ex_color.data.filters import levels


def _as_rgb_rows(cube: ColorCube) -> Tensor:
    """Flatten a cube's RGB grid into a float32 tensor with three columns."""
    return torch.tensor(cube.rgb_grid.reshape(-1, 3), dtype=torch.float32)


# Hues spaced 1/6 apart at full saturation and value, served as one full batch
primary_cube = ColorCube.from_hsv(h=arange_cyclic(step_size=1 / 6), s=np.ones(1), v=np.ones(1))
primary_tensor = _as_rgb_rows(primary_cube)
primary_dataset = TensorDataset(primary_tensor)
primary_loader = DataLoader(primary_dataset, batch_size=len(primary_tensor))

# HSV grid: hues every 10/360 of the cycle, with 10 saturation and 10 value levels
hue_steps = arange_cyclic(step_size=10 / 360)
full_cube = ColorCube.from_hsv(h=hue_steps, s=np.linspace(0, 1, 10), v=np.linspace(0, 1, 10))
full_tensor = _as_rgb_rows(full_cube)
full_dataset = TensorDataset(full_tensor)
# The sampler's weights are replaced during training (see update_sampler_weights)
full_sampler = DynamicWeightedRandomBatchSampler(
    bias=full_cube.bias.flatten(),
    batch_size=256,
    steps_per_epoch=100,
)
vibrancy_weights = vibrancy(full_cube).flatten()
full_loader = DataLoader(full_dataset, batch_sampler=full_sampler)

# Regular 10×10×10 RGB grid, used as validation data
channel_levels = np.linspace(0, 1, 10)
rgb_cube = ColorCube.from_rgb(r=channel_levels, g=channel_levels, b=channel_levels)
rgb_tensor = _as_rgb_rows(rgb_cube)


def update_sampler_weights(event: Event):
    """Event handler: re-weight the full-color sampler from the scheduled 'data-fraction'."""
    fraction = event.timeline_state.props['data-fraction']

    # The fraction is mapped piecewise-linearly onto the two `levels` parameters:
    #   fraction near 0 -> in_low close to 1: everything is scaled down to 0 except weights of 1
    #   fraction = 0.5  -> in_low = out_low = 0: the weights are unchanged
    #   fraction = 1    -> in_low = 0, out_low = 1: all weights are scaled to 1
    input_floor = np.interp(fraction, [0, 0.5], [0.99, 0])
    output_floor = np.interp(fraction, [0.5, 1], [0, 1])

    full_sampler.weights = levels(vibrancy_weights, in_low=input_floor, out_low=output_floor)


class ModelRecorder(EventHandler):
    """Event handler that snapshots the model's state dict every time it is called."""

    # (step, state_dict) pairs, where state_dict is a copy of the model's state dict
    history: list[tuple[int, dict[str, Tensor]]]

    def __init__(self):
        self.history = []

    def __call__(self, event: Event):
        """Append a CPU copy of the event model's state dict to `history`."""
        # Copy each tensor to the CPU, so the snapshot neither aliases the
        # live parameters nor holds on to GPU memory
        snapshot: dict[str, Tensor] = {}
        for key, value in event.model.state_dict().items():
            snapshot[key] = value.cpu().clone()

        self.history.append((event.step, snapshot))
        log.debug(f'Recorded model state at step {event.step}')


# Records a copy of the model's weights each time it is called (registered on 'pre-step' below)
recorder = ModelRecorder()

# Phase -> (train loader, validation tensor)
# Keys must match the phase names used in the dopesheet.
datasets: dict[str, tuple[DataLoader, Tensor]] = {
    'Primary & secondary': (primary_loader, primary_tensor),
    'All hues': (full_loader, rgb_tensor),
    'Full color space': (full_loader, rgb_tensor),
}

model = ColorMLP(normalize_bottleneck=False)
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
log.info(f'Model initialized with {total_params:,} trainable parameters.')

# Handlers registered under the same name run in registration order:
# the recorder first, then the sampler re-weighting.
event_handlers = EventHandlers()
event_handlers.pre_step.add_handler('pre-step', recorder)
event_handlers.pre_step.add_handler('pre-step', update_sampler_weights)

# Plot the latent space at the end of every phase
plotter = PhasePlotter(dim_pairs=[(0, 1), (0, 2), (0, 3)])
event_handlers.phase_end.add_handler('phase-end', plotter)

# The anchor captures its reference latents when the 'anchor' action fires
reg_anchor = Anchor(ref_data=primary_tensor)
event_handlers.action.add_handler('action:anchor', reg_anchor.on_anchor)

# Keys of loss_criteria are dopesheet property names; the scheduled value of
# each property is used as the weight of that loss term.
# NOTE: train_color_model has no return statement, so `history` is None.
history = train_color_model(
    model,
    datasets,
    dopesheet,
    loss_criteria={
        'loss-recon': objective(nn.MSELoss()),
        'reg-separate': Separate((0, 1)),
        'reg-planar': planarity,
        'reg-norm': unitary,
        'reg-anchor': reg_anchor,
    },
    event_handlers=event_handlers,
)

I 5.2 no:      Model initialized with 263 trainable parameters.
I 5.2 no:      Anchor initialized with reference data shape: torch.Size([6, 3])


I 10.3 no:     Plotting end of phase: Primary & secondary at step 2999 using provided results.
I 10.5 ut.nb:  Figure saved: 'assets/ex-1.5-color-phase-history.png'


I 10.5 no:     Capturing anchor latents via Anchor.on_anchor at step 3000
I 10.5 no:     Anchor state captured internally. Ref data: torch.Size([6, 3]), Ref latents: torch.Size([6, 4])
I 30.9 no:     Plotting end of phase: All hues at step 9999 using provided results.
I 31.4 ut.nb:  Figure saved: 'assets/ex-1.5-color-phase-history.png'
I 60.4 no:     Plotting end of phase: Full color space at step 20000 using provided results.
I 61.3 ut.nb:  Figure saved: 'assets/ex-1.5-color-phase-history.png'
I 61.3 no:     Training finished!


## Latent space evolution

In [10]:
import numpy as np


def eval_latent_history(
    recorder: ModelRecorder,
    rgb_tensor: Tensor,
):
    """
    Compute the latents of `rgb_tensor` under every model state held by `recorder`.

    Returns a list of `(step, latents)` tuples in recording order, where
    `latents` is a NumPy array.
    """
    from utils.progress import RichProgress

    # A fresh model serves as a container that the recorded weights are loaded into
    probe = ColorMLP(normalize_bottleneck=False)

    snapshots = RichProgress(recorder.history, description='Evaluating latents')
    latent_history: list[tuple[int, np.ndarray]] = []
    for step, state_dict in snapshots:
        probe.load_state_dict(state_dict)
        probe.eval()
        probe_device = next(probe.parameters()).device
        with torch.no_grad():
            _, latents = probe(rgb_tensor.to(probe_device))
        latent_history.append((step, latents.cpu().numpy()))
    return latent_history


# Latents of the RGB validation grid under every recorded model state: a list of (step, latents)
latent_history = eval_latent_history(recorder, rgb_tensor)

In [11]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import imageio_ffmpeg
from matplotlib import rcParams

from utils.progress import RichProgress

# Point matplotlib's animation writer at the ffmpeg executable located by imageio_ffmpeg
rcParams['animation.ffmpeg_path'] = imageio_ffmpeg.get_ffmpeg_exe()


def animate_latent_evolution(
    history: list[tuple[int, np.ndarray]],
    colors: np.ndarray,
    dim_pair: tuple[int, int] = (0, 1),
    interval=1 / 30,
):
    """
    Create an animation of the latent space evolution.

    Args:
        history: `(step, latents)` tuples, one per frame; `latents` is indexed
            as `latents[:, dim]`. Must contain at least one entry.
        colors: Color of each point, passed to `scatter` as `c`. Must have one
            entry per row of `latents`.
        dim_pair: The two latent dimensions to plot as x and y.
        interval: Delay between frames, in seconds.

    Returns:
        `(fig, ani)`: the figure and the `FuncAnimation` that animates it.
    """
    plt.style.use('dark_background')
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.set_xlim(-1.5, 1.5)
    ax.set_ylim(-1.5, 1.5)
    ax.set_aspect('equal')

    fig.patch.set_facecolor('#333')
    ax.patch.set_facecolor('#222')

    # A list, so indexing selects the two columns regardless of the sequence type passed in
    dims = list(dim_pair)

    # The first history entry is drawn up front; the animation covers the rest
    step, latents = history[0]
    scatter = ax.scatter(latents[:, dims[0]], latents[:, dims[1]], c=colors, s=50, alpha=0.7)
    title = ax.set_title(f'Step {step}')

    def update(frame: int):
        # Animation frame 0 corresponds to history[1]: first frame already drawn
        step, latents = history[frame + 1]
        scatter.set_offsets(latents[:, dims])
        title.set_text(f'Step {step}')
        return (scatter, title)

    # FuncAnimation expects the interval in milliseconds
    ani = animation.FuncAnimation(fig, update, frames=len(history) - 1, interval=interval * 1000, blit=True)
    return fig, ani


# Animate every `stride`-th recorded step to keep the video short
stride = 10
fig, ani = animate_latent_evolution(latent_history[::stride], rgb_tensor.cpu().numpy(), dim_pair=(0, 1))

# Encode the animation to an H.264 MP4, advancing the progress bar from the save callback
video_file = 'assets/ex-1.5-latent-evolution.mp4'
with RichProgress(total=len(latent_history) // stride, description='Rendering video') as pbar:
    ani.save(
        video_file,
        fps=30,
        extra_args=['-vcodec', 'libx264'],
        progress_callback=lambda i, n: pbar.update(1),
    )
plt.close(fig)

from random import randint
from IPython.display import HTML

# A random query parameter gives the video a new URL on every run, so the
# browser does not reuse a cached copy of the file
cache_buster = randint(1, 1_000_000)

# Last expression of the cell: rendered by the notebook as an embedded player
HTML(
    f"""
    <video width="600" height="600" controls loop>
        <source src="{video_file}?v={cache_buster:d}" type="video/mp4">
        Your browser does not support the video tag.
    </video>
    """
)