# Superposition by embedding bottleneck

This is an experiment to squash higher-dimensional data into a lower-dimensional embedding space. We'll start with color: RGB values (3 dimensions) ranging from 0..1. If we compress them into a 2D embedding, we should expect to see superposition, with directions interpretable as they would be in a classic color wheel: three primary color directions (RGB) spaced 120° apart.

## Training data

While we describe colors in terms of hue (the color wheel), saturation (color intensity), and value (brightness), RGB values combine these properties in a non-intuitive way.

Primary colors are fully-saturated, and are vectors of length 1:

| Color   |     R      |     G      |     B      |  Length  |
| ------- | :--------: | :--------: | :--------: | :------: |
| red     |    $1$     |    $0$     |    $0$     |   $1$    |
| green   |    $0$     |    $1$     |    $0$     |   $1$    |
| blue    |    $0$     |    $0$     |    $1$     |   $1$    |

Secondary colors are also fully-saturated, but have length $\sqrt2$:

| Color   |     R      |     G      |     B      |  Length  |
| ------- | :--------: | :--------: | :--------: | :------: |
| yellow  |    $1$     |    $1$     |    $0$     | $\sqrt2$ |
| cyan    |    $0$     |    $1$     |    $1$     | $\sqrt2$ |
| magenta |    $1$     |    $0$     |    $1$     | $\sqrt2$ |

Grayscale values, by contrast, are fully _desaturated_ (they have no defined hue) and have various lengths:

| Color   |     R      |     G      |     B      |  Length  |
| ------- | :--------: | :--------: | :--------: | :------: |
| black   |    $0$     |    $0$     |    $0$     |   $0$    |
| gray    | $1/\sqrt3$ | $1/\sqrt3$ | $1/\sqrt3$ |   $1$    |
| white   |    $1$     |    $1$     |    $1$     | $\sqrt3$ |

We could pick random RGB values, but even if the vectors were normalized, we wouldn't necessarily get data that compresses well to a 2D embedding. Instead, we will use `colorsys` — a built-in Python standard-library module that converts between RGB and hue-based color models such as HSV and HLS. For our superposition experiment, generating colors by varying hue while keeping saturation and value fixed gives us a much more intuitive and evenly distributed set of training examples that isolate the property we're most interested in studying.

Let's do a quick test to see if we can reproduce the primary and secondary RGB values just by varying the hue.

In [1]:
import colorsys
import numpy as np

# Sample the hue wheel every 60 degrees: one stop per primary and secondary color.
hues = np.arange(0, 360, 60, dtype=float)

# Each row is [hue in degrees, R, G, B] for a fully saturated, full-value color.
# colorsys expects hue as a fraction of a full turn, hence the division by 360.
rgb_colors = np.array([[hue, *colorsys.hsv_to_rgb(hue / 360, 1, 1)] for hue in hues])

print(rgb_colors)

[[  0.   1.   0.   0.]
 [ 60.   1.   1.   0.]
 [120.   0.   1.   0.]
 [180.   0.   1.   1.]
 [240.   0.   0.   1.]
 [300.   1.   0.   1.]]


In [None]:
from colorsys import hsv_to_rgb, hls_to_rgb
from functools import reduce
from operator import mul
from typing import overload
import torch
from torch import Tensor
from torch.utils.data import Dataset
from jaxtyping import Float


# Type alias for the per-axis sample values passed to ColorCube: a float tensor
# whose jaxtyping shape spec ' T' names a single dimension (the number of samples).
Seq = Float[torch.Tensor, ' T']


class ColorCube(Dataset):
    """Dataset of RGB colors sampled on a regular 3-D grid of color components.

    Construct it with exactly one keyword-only family of per-axis sample tensors:

    - ``h``, ``s``, ``v`` -- HSV components, converted with ``colorsys.hsv_to_rgb``;
    - ``h``, ``l``, ``s`` -- HLS components, converted with ``colorsys.hls_to_rgb``;
    - ``r``, ``g``, ``b`` -- RGB components, returned unchanged.

    The three tensors are stored as ``a`` (varies fastest with the flat index),
    ``b`` and ``c`` (varies slowest). ``axes`` names them in ``(c, b, a)`` order
    (``'vsh'``, ``'slh'`` or ``'bgr'``), which is also the default dimension
    order of :meth:`as_cube`. Every item is a length-3 RGB tensor.
    """

    @overload
    def __init__(self, *, h: Seq, s: Seq, v: Seq): ...

    @overload
    def __init__(self, *, h: Seq, l: Seq, s: Seq): ...

    @overload
    def __init__(self, *, r: Seq, g: Seq, b: Seq): ...

    def __init__(
        self,
        *,
        h: Seq | None = None,
        s: Seq | None = None,
        v: Seq | None = None,
        l: Seq | None = None,
        r: Seq | None = None,
        g: Seq | None = None,
        b: Seq | None = None,
    ):
        """Store the component tensors for one color model.

        Raises:
            ValueError: if none of the HSV, HLS or RGB triples is fully given.
        """
        # HSV is checked first, so passing h, s, v (and also l) selects HSV.
        if h is not None and s is not None and v is not None:
            self.a = h
            self.b = s
            self.c = v
            self.axes = 'vsh'
        elif h is not None and l is not None and s is not None:
            self.a = h
            self.b = l
            self.c = s
            self.axes = 'slh'
        elif r is not None and g is not None and b is not None:
            self.a = r
            self.b = g
            self.c = b
            self.axes = 'bgr'
        else:
            raise ValueError('Invalid parameters: must provide either HSV, HLS, or RGB values.')

        # Shape of the un-permuted cube produced by as_cube(): (c, b, a, rgb).
        self.shape = (len(self.c), len(self.b), len(self.a), 3)

    def axis(self, axis: str) -> torch.Tensor:
        """Return the sample values along one named axis (case-insensitive).

        Raises:
            ValueError: if ``axis`` is not a single letter from ``self.axes``.
        """
        key = axis.lower()
        # str.index alone would also accept substrings like 'vs' or '', so check explicitly.
        if len(key) != 1 or key not in self.axes:
            raise ValueError(f'Invalid axis: {axis}. Valid axes are: {self.axes}.')
        # self.axes lists the components in (c, b, a) order.
        return (self.c, self.b, self.a)[self.axes.index(key)]

    def as_cube(self, axes: str | None = None) -> torch.Tensor:
        """Return every color as a 4-D tensor, one dimension per component plus RGB.

        Args:
            axes: permutation of ``self.axes`` (case-insensitive) giving the order
                of the first three dimensions. Defaults to ``self.axes``.

        Returns:
            Tensor of shape ``(len(axes[0]), len(axes[1]), len(axes[2]), 3)``.

        Raises:
            ValueError: if ``axes`` is not a permutation of ``self.axes``.
        """
        if axes:
            axes = axes.lower()
            # Compare sorted letters so repeated/missing letters ('rgbb', 'rrg') are rejected.
            if sorted(axes) != sorted(self.axes):
                raise ValueError(f'Invalid axes: must be a permutation of {self.axes}.')
        else:
            axes = self.axes

        # Flat index order matches (c, b, a), so a plain reshape recovers the grid.
        cube = self.__getitems__(list(range(len(self)))).reshape(*self.shape)

        permutation = [self.axes.index(ax) for ax in axes] + [3]
        return torch.permute(cube, permutation)

    def get_components(self, axes: str) -> tuple[Tensor, ...]:
        """Return the sample tensors for each letter of ``axes``, in that order."""
        axes = axes.lower()
        return tuple(self.axis(ax) for ax in axes)

    def as_image(self, x: str, y: str) -> torch.Tensor:
        """Flatten the cube into a 2-D RGB image.

        Args:
            x: axis letters laid out horizontally (flattened into the width).
            y: axis letters laid out vertically (flattened into the height).
                ``y + x`` must be a permutation of ``self.axes``.

        Returns:
            Tensor of shape ``(height, width, 3)``.
        """
        # as_cube returns a single tensor (not a tuple), so no unpacking here.
        cube = self.as_cube(y + x)

        height = reduce(mul, cube.shape[: len(y)], 1)
        width = reduce(mul, cube.shape[len(y) : -1], 1)
        return cube.reshape(height, width, 3)

    def plot(self, components: str):
        """Show the cube as a row of images, one per value of the first component.

        ``components`` is a permutation of ``self.axes`` (case-insensitive): the
        first letter selects the slicing axis, the second is plotted vertically
        and the third horizontally.
        """
        import matplotlib.pyplot as plt

        color_cube = self.as_cube(components)
        dims = self.get_components(components)

        n_values = color_cube.size(0)
        # squeeze=False always yields a 2-D array of Axes (even for a single slice);
        # sharey=True shares the Y axis without making an Axes share with itself.
        fig, axes = plt.subplots(1, n_values, figsize=(15, 3), sharey=True, squeeze=False)

        # Plot each slice of the cube (one per value of the first component).
        for i, ax in enumerate(axes[0]):
            ax.imshow(color_cube[i])
            ax.set_title(f'{components[0]} = {dims[0][i]:.2g}')

            if i == 0:
                # Label only the first subplot, with the end values of each axis.
                ax.xaxis.set_ticks([0, color_cube.shape[2] - 1])
                ax.xaxis.set_ticklabels([f'{dims[2][0]:.2g}', f'{dims[2][-1]:.2g}'])
                ax.set_xlabel(components[2])

                ax.yaxis.set_ticks([0, color_cube.shape[1] - 1])
                ax.yaxis.set_ticklabels([f'{dims[1][0]:.2g}', f'{dims[1][-1]:.2g}'])
                ax.set_ylabel(components[1])
            else:
                # Hide both axes on the remaining subplots to avoid clutter.
                ax.xaxis.set_visible(False)
                ax.yaxis.set_visible(False)

        fig.suptitle(components.upper(), fontsize=16)
        plt.show()

    def __len__(self) -> int:
        """Number of grid points, i.e. the product of the three axis lengths."""
        return len(self.a) * len(self.b) * len(self.c)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return the RGB color at flat index ``idx`` (negative indices count from the end).

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``. Raising
                is required for ``iter(dataset)`` / ``list(dataset)`` to terminate,
                since iteration falls back to the ``__getitem__`` protocol.
            ValueError: if ``self.axes`` is not one of the supported layouts.
        """
        n = len(self)
        pos = idx + n if idx < 0 else idx
        if not 0 <= pos < n:
            raise IndexError(f'Index {idx} out of range for ColorCube of length {n}.')

        # Decompose the flat index: `a` varies fastest, `c` slowest.
        len_a, len_b = len(self.a), len(self.b)
        a = self.a[pos % len_a].item()
        b = self.b[(pos // len_a) % len_b].item()
        c = self.c[pos // (len_a * len_b)].item()

        if self.axes == 'vsh':
            rgb = hsv_to_rgb(a, b, c)
        elif self.axes == 'slh':
            rgb = hls_to_rgb(a, b, c)
        elif self.axes == 'bgr':
            rgb = (a, b, c)
        else:
            raise ValueError('Invalid axes configuration.')
        return torch.tensor(rgb)

    def __getitems__(self, indices: list[int]) -> torch.Tensor:
        """Return the colors at ``indices`` stacked into a ``(len(indices), 3)`` tensor."""
        if not indices:
            # torch.stack rejects an empty sequence.
            return torch.empty(0, 3)
        return torch.stack([self[i] for i in indices])

# Preview the grid in each supported color model. Each entry pairs the plot's
# component order with the keyword arguments used to build the ColorCube.
_demo_grids = [
    ('RGB', dict(
        r=torch.linspace(0, 1, 5),
        g=torch.linspace(0, 1, 5),
        b=torch.linspace(0, 1, 5),
    )),
    ('SLH', dict(
        h=torch.arange(0, 1, 1 / 12),
        l=torch.linspace(0, 1, 5),
        s=torch.linspace(0, 1, 5),
    )),
    ('SVH', dict(
        h=torch.arange(0, 1, 1 / 12),
        s=torch.linspace(0, 1, 5),
        v=torch.linspace(0, 1, 5),
    )),
]

for _components, _grid in _demo_grids:
    cols = ColorCube(**_grid)
    cols.plot(_components)
