In [1]:
nbid = '1.1.1'

In [7]:
from typing import cast
from matplotlib.axes import Axes
import numpy as np
import matplotlib.pyplot as plt

from ex_color.data.color_cube import ColorCube
from utils.nb import displayer_mpl


def set_alpha(colors: np.ndarray, alpha: float) -> np.ndarray:
    """Set alpha channel of colors."""
    assert colors.ndim == 2 and colors.shape[1] in (3, 4), 'colors must be [N, 3] or [N, 4]'
    if colors.shape[1] == 3:
        colors = np.concatenate([colors, np.ones((colors.shape[0], 1), dtype=colors.dtype)], axis=1)
    colors[:, 3] = alpha
    return colors


def plot_rgb_cube_orthographic(rgb_grid: np.ndarray, *, point_size: int = 30):
    """Plot RGB cube with diagonal (black→white) vertical (white on top)."""
    assert rgb_grid.ndim == 4 and rgb_grid.shape[-1] == 3, 'rgb_grid must be [R,G,B,3]'
    R, G, B, _ = rgb_grid.shape
    # Normalized coordinate for each lattice point
    r = np.linspace(0, 1, R)
    g = np.linspace(0, 1, G)
    b = np.linspace(0, 1, B)
    rr, gg, bb = np.meshgrid(r, g, b, indexing='ij')
    coords = np.stack([rr, gg, bb], axis=-1).reshape(-1, 3)
    colors = rgb_grid.reshape(-1, 3)

    # sides = ('front',)
    sides = ('front', 'back')
    fig, axs = plt.subplots(1, len(sides), figsize=(4 * len(sides), 4), sharey=True, squeeze=False)
    axs = axs.flatten()
    for i, (side, ax) in enumerate(zip(sides, axs, strict=True)):
        ax = cast(Axes, ax)

        # Build an orthonormal basis with diag as vertical axis
        diag = np.array([1.0, 1.0, 1.0])
        e3 = diag / np.linalg.norm(diag)  # vertical (black→white)
        e1 = np.array([0.0, -1.0, 1.0] if side == 'front' else [0.0, 1.0, -1.0])
        e1 -= e1 @ e3 * e3
        e1 /= np.linalg.norm(e1)
        e2 = np.cross(e3, e1)

        projected = coords @ np.stack([e1, e2, e3], axis=1)
        x: np.ndarray
        y: np.ndarray
        z: np.ndarray  # noqa: E702
        x, y, z = projected.T  # z is vertical but we will use y= z for 2D plot

        # Sort so that lower (darker) points do not occlude brighter ones
        order = np.argsort(y)
        ax.scatter(x[order], z[order], c=colors[order], s=point_size)
        ax.set_xlabel('Hue (⊥ to value)')
        if i == 0:
            ax.set_ylabel('Value')
        ax.margins(0.1)
        ax.xaxis.set_ticks([])
        ax.yaxis.set_ticks([])
        ax.spines['top'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.patch.set_alpha(1)
        ax.set_aspect('equal')

        # ax.set_title(side.capitalize())

    # fig.suptitle('RGB cube (true colors)')
    return fig


with displayer_mpl(
    f'large-assets/ex-{nbid}-rgb-cube.png',
    alt_text="""Two colorful, orthographic views of the RGB cube, rotated such that black is at the bottom and white is at the top. The other corners of the cube are arranged around the middle in two bands, one higher and one lower. The left plot, titled 'font', has in its top band cyan, yellow, and magenta, and in its bottom band green, red. The right plot, titled 'back', has in its top band magenta and cyan, and in its bottom band red, blue, and green.""",
) as show:
    show(
        lambda: plot_rgb_cube_orthographic(
            ColorCube.from_rgb(
                np.linspace(0, 1, 8),
                np.linspace(0, 1, 8),
                np.linspace(0, 1, 8),
            ).rgb_grid,
            point_size=250,
        )
    )