# Experiment 2.1: Intervention lobes

In this series of experiments, we shall explore the effects of intervening on latent activations. Having structured the latent space (see Ex 1.x), it should just be a matter of transforming latent embeddings that are closely aligned to the anchored concepts.

We draw inspiration from shaders in computer graphics: BSDFs compute the output energy given: 1. an input light direction, and 2. the viewing direction, relative to the surface normal. Our interventions are similar: we have 1. a subject concept vector, and 2. activation vectors. If we treat the subject vector as analogous to a light source and activation vectors as analogous to viewing directions, we may build on a wealth of established techniques.

Here we define our intervention as a BSDF-like function:

$$e' = f(s,e)$$

Where $e$ is an embedding vector, $s$ is the subject, and $e'$ is the modified embedding. In fact $s$ need not be a (directional) vector; it could be other geometric features of our embedding space, such as a subspace defined by multiple basis vectors.

In [21]:
# Experiment identifier for this notebook; interpolated into the output image paths in later cells.
nb_id = '2.1'

In [None]:
from numpy.typing import NDArray
from torch._tensor import Tensor


from matplotlib.axes import Axes
import numpy as np
import torch

# from matplotlib.axes import Axes
from matplotlib.projections.polar import PolarAxes
from torch import Tensor

from ex_color.intervention.intervention import Intervention


def sample_idf(idf: Intervention, n=360, phase=0.0) -> tuple[NDArray, NDArray, NDArray, NDArray]:  # noqa: F821
    """
    Probe an intervention with ``n`` evenly spaced unit vectors around the 2D unit circle.

    Args:
        idf: The intervention to sample. It is called with a float32 tensor of shape [n, 2],
            and must also provide ``dist`` and ``falloff``: the effect strength is taken to be
            ``idf.falloff(idf.dist(x))``.
        n: Number of input directions to sample.
        phase: Angle (radians) of the first sample. The samples always cover one full turn,
            ``[phase, phase + 2π)``, so the spacing is ``2π / n`` regardless of the phase.

    Returns:
        A tuple of NumPy arrays, each of length ``n``:
          - theta_out: angle of each transformed vector, in [-π, π]
          - r_out: Euclidean norm of each transformed vector
          - thetas_in: the input angles
          - r_mag: the effect strength ``falloff(dist(x))`` for each input vector
    """
    # Input angles θ_in: one full turn starting at `phase`. The endpoint is excluded to avoid a
    # duplicate vertex. The end of the range moves with the phase so that a non-zero phase
    # rotates the sample grid instead of compressing it into [phase, 2π).
    thetas_in: Tensor = torch.linspace(phase, phase + 2 * torch.pi, steps=n + 1, dtype=torch.float32)[:-1]

    # Unit circle directions, shape [n, 2]
    unit: Tensor = torch.stack((torch.cos(thetas_in), torch.sin(thetas_in)), dim=-1)

    # Apply intervention (idf expects Tensors); disable grad for safety
    with torch.no_grad():
        out: Tensor = idf(unit)  # [n, 2]
        # Assume the IDF computes effect strength as falloff(dist(x))
        r_mag = idf.falloff(idf.dist(unit))  # [n]

    # Convert outputs to polar coordinates
    y, x = out[..., 1], out[..., 0]
    theta_out = torch.atan2(y, x)  # [-π, π]
    r_out = torch.linalg.norm(out, dim=-1)

    return (
        theta_out.detach().cpu().numpy(),
        r_out.detach().cpu().numpy(),
        thetas_in.detach().cpu().numpy(),
        r_mag.detach().cpu().numpy(),
    )


def filled_series(ax: Axes, xs: NDArray, ys: NDArray, color: str | None, alpha=0.3, **kwargs):
    """
    Draw a series as a line with a translucent filled region underneath it.

    The filled polygon is the series with an extra (0, 0) vertex at each end, drawn behind other
    artists (zorder 0). The line itself is drawn from the unmodified series.

    Args:
        ax: Axes to draw into.
        xs: First coordinate of each point.
        ys: Second coordinate of each point.
        color: Colour shared by the fill and the line.
        alpha: Opacity of the fill only.
        **kwargs: Forwarded to ``ax.plot`` for the line (e.g. ``label``, ``linewidth``).
    """

    def through_origin(values: NDArray) -> NDArray:
        # Bracket the series with zeros so the polygon starts and ends at the origin
        return np.concatenate([[0], values, [0]])

    fill_style = {'color': color, 'alpha': alpha, 'zorder': 0}
    ax.fill(through_origin(xs), through_origin(ys), **fill_style)
    ax.plot(xs, ys, color=color, **kwargs)


def diff_series(ax: Axes, xs1: NDArray, xs2: NDArray, ys1: NDArray, ys2: NDArray, label: str | None = None, **kwargs):
    # Draw line segments between series 1 and 2
    for x1, x2, y1, y2 in zip(xs1, xs2, ys1, ys2, strict=True):
        ax.plot([x1, x2], [y1, y2], **kwargs)

    # Only add the label once
    if label:
        ax.plot([], [], label=label, **kwargs)


def draw_intervention_slice(ax: Axes | PolarAxes, idf: Intervention):
    """
    Plot a 2D slice of an intervention function.

    The angular coordinate corresponds to the direction of a unit input vector.
    Four series are drawn:
      - Input: the unit circle of pre-intervention vectors
      - Transformed: the output vector converted to polar (θ_out, r_out)
      - Magnitude: the strength of the intervention plotted against input θ
      - Offset: thin segments joining a sparse set of inputs to their outputs

    Args:
        ax: Axes to draw into. On a PolarAxes the lobes are filled and the view is restricted to
            the 0..π half; on any other Axes the same (θ, r) data is drawn unfilled with θ on
            the x axis.
        idf: The intervention function to plot. Will be called with a tensor of [B,E] where E=2.
    """
    # Lobes are only filled on polar axes; elsewhere the fill is made fully transparent
    fill_alpha = 0.3 if isinstance(ax, PolarAxes) else 0.0

    # Dense sample for the curves. The tiny phase keeps the first sample off exactly θ = 0.
    theta_out, r_out, thetas_in, r_mag = sample_idf(idf, 360, phase=1e-7)

    # Pre-intervention activations
    ax.plot(
        thetas_in,
        np.ones_like(thetas_in),
        color='hotpink',
        linewidth=1.0,
        label='Input',
    )

    # Post-intervention activations
    filled_series(
        ax,
        theta_out,
        r_out,
        color='#1f77b4',
        linewidth=2.0,
        label='Transformed',
        alpha=fill_alpha,
    )

    # Magnitude of intervention
    filled_series(
        ax,
        thetas_in,
        r_mag,
        color='#ff7f0e',
        linewidth=2.0,
        label='Magnitude',
        alpha=fill_alpha,
    )

    # Differences: a sparser sample so individual offsets remain legible. It gets its own names
    # so the dense arrays above stay available for computing the axis limits.
    sparse_theta_out, sparse_r_out, sparse_thetas_in, _ = sample_idf(idf, 360 // 10, phase=1e-7)
    diff_series(
        ax,
        sparse_thetas_in,
        sparse_theta_out,
        np.ones_like(sparse_thetas_in),
        sparse_r_out,
        color='white',
        alpha=0.6,
        linewidth=0.5,
        marker='o',
        markersize=2.0,
        markeredgecolor='none',
        markerfacecolor='white',
        label='Offset',
    )

    # Radial limit that comfortably contains all plotted data and the unit radius
    r_limit = max(r_out.max(), r_mag.max(), 1.0) * 1.1

    if isinstance(ax, PolarAxes):
        ax.set_theta_zero_location('N')  # 0° at top (perfect alignment)
        ax.set_theta_direction(-1)  # Clockwise
        ax.set_thetalim(0, np.pi)  # Only show 0 to π (hemisphere)

        # Angular grid lines every 30°, with blank tick labels
        ax.set_thetagrids([0, 30, 60, 90, 120, 150, 180], ['', '', '', '', '', '', ''], ha='left')

        ax.set_rmax(r_limit)

    else:
        ax.set_xlim(0, np.pi)
        ax.set_ylim(0, r_limit)

## Suppression

In [23]:
from math import sin, pi
from typing import cast, Annotated, override

from annotated_types import Ge, Le
import matplotlib.pyplot as plt
from matplotlib.projections.polar import PolarAxes
from pydantic import validate_call

from ex_color.intervention.falloff import Falloff, Linear, Power
from ex_color.intervention.suppression import Suppression
from utils.nb import displayer_img
from utils.plt import configure_matplotlib

# Apply the project's matplotlib settings (defined in utils.plt) before any figure in this cell is created.
configure_matplotlib()


class Bounded(Falloff):
    """
    Falloff wrapper that has no effect at or below a lower alignment bound.

    Alignments above ``lower`` are rescaled as ``(alignment - lower) / (1 - lower)`` and passed to
    ``inner_term``; alignments at or below ``lower`` produce zero.

    Args:
        inner_term: The falloff applied to the rescaled alignment.
        lower: Alignment threshold, validated to lie in [0, 1].
        eps: Tolerance used to detect ``lower`` being effectively 0 or 1.
    """

    @validate_call(config={'arbitrary_types_allowed': True})
    def __init__(
        self,
        inner_term: Falloff,
        # Constraints must be separate Annotated arguments: wrapped in a list they are treated as
        # unknown metadata and the range is never enforced.
        lower: Annotated[float, Ge(0), Le(1)],
        eps=1e-8,
    ):
        self.inner_term = inner_term
        self.lower = lower
        self.eps = eps

    @override
    def __call__(self, alignment: Tensor) -> Tensor:
        # No bound to speak of: behave exactly like the wrapped falloff
        if self.lower < self.eps:
            return self.inner_term(alignment)
        # Bound is effectively 1: bail out before dividing by (1 - lower) ≈ 0.
        # NOTE(review): this returns the alignment unchanged, whereas the general branch below
        # tends towards all-zeros as lower -> 1. Confirm which is intended.
        if self.lower > 1 - self.eps:
            return alignment

        scale = 1 - self.lower
        # Values at or below the bound are replaced by zero in the where() below; clamping keeps
        # them out of inner_term's negative range (which could otherwise yield NaNs, e.g. for a
        # fractional power) without affecting any value that is actually kept.
        shifted = ((alignment - self.lower) / scale).clamp(min=0)
        shifted = self.inner_term(shifted)
        return torch.where(alignment > self.lower, shifted, torch.zeros_like(alignment))

    def __repr__(self):
        return f'{type(self).__name__}({self.inner_term}, {self.lower:.2g})'

    def __str__(self):
        return f'{self.inner_term} | d>{self.lower:.2g}'


class Angular(Falloff):
    """
    Falloff wrapper that applies ``inner_term`` in angle space instead of cosine space.

    The cosine similarity is converted to a normalised angular alignment, passed through
    ``inner_term``, and the result is converted back to a cosine.

    Args:
        inner_term: The falloff applied to the angular alignment.
    """

    @validate_call(config={'arbitrary_types_allowed': True})
    def __init__(
        self,
        inner_term: Falloff,
    ):
        self.inner_term = inner_term

    @override
    def __call__(self, cos_sim: Tensor) -> Tensor:  # cos_sim is a batch, shape [B]
        # Convert cosine similarity to angular alignment for linear transformations
        # 1 -> 1  (0 degrees away from subject)
        # 0.5 -> 1/3  (60 degrees away from subject)
        # 0 -> 0  (90 degrees away from subject)
        # -1 -> -1  (180 degrees away from subject)
        # Rounding can push a cosine similarity marginally outside [-1, 1], where acos is NaN
        cos_sim = cos_sim.clamp(-1.0, 1.0)
        thetas = 1 - (torch.acos(cos_sim) / (pi / 2))
        # Map the transformed angular alignment back to a cosine
        return torch.cos((1 - self.inner_term(thetas)) * (pi / 2))

    def __repr__(self):
        return f'{type(self).__name__}({self.inner_term})'


# Falloff variants to compare, one subplot each. sin(pi * 1 / 6) ≈ 0.5, so the Bounded variants
# use a lower alignment bound of about 0.5.
falloffs = [
    Linear(0.5),
    Bounded(Linear(0.5), sin(pi * 1 / 6)),
    Power(2),
    Bounded(Power(2), sin(pi * 1 / 6)),
]


# displayer_img comes from utils.nb; presumably `show` renders the figure and stores it at the
# given path together with the alt text — confirm against utils.nb.
with displayer_img(
    f'large-assets/ex-{nb_id}-suppression.png',
    alt_text="Grid of semicircular polar plots showing the effects of suppression on activations. Each plot shows two lobes: an orange one indicating the magnitude of the intervention, and a blue one showing the transformed activation space. The direction being intervened on (the 'subject') is always 'up', so the orange 'magnitude' lobes are also oriented upwards. The blue 'transformed' lobes are more circular but have a depression in the top, showing that the directions more aligned with the subject are squashed/attenuated by the intervention.",
) as show:
    # One row of subplots; the figure width grows with the number of falloffs
    fig = plt.figure(figsize=(1 + 4.5 * len(falloffs), 6), constrained_layout=True)

    axes = []
    for i, falloff in enumerate(falloffs):
        # Subplot indices are 1-based
        ax = cast(PolarAxes, fig.add_subplot(1, len(falloffs), i + 1, axes_class=PolarAxes))
        ax.spines['polar'].set_color(c='gray')
        ax.grid(True, color='#444', linewidth=0.5)

        # The subject is the θ = 0 direction, which draw_intervention_slice places at the top
        idf = Suppression(
            subject=torch.tensor([1, 0], dtype=torch.float32),  # North
            falloff=falloff,
            amount=1,
            renormalize=False,
            bidirectional=False,
        )
        draw_intervention_slice(ax, idf)
        # Title each subplot with the falloff's human-readable description
        ax.set_title(str(idf.falloff), pad=15)
        axes.append(ax)

    # Add a single legend for all axes: horizontal, no border/background. Every subplot has the
    # same series, so the entries are taken from the first one. The anchor (y = -0.07 in figure
    # coordinates) places the legend slightly below the figure's bottom edge.
    handles, labels = axes[0].get_legend_handles_labels()
    legend = fig.legend(
        handles,
        labels,
        loc='lower center',
        ncol=len(labels),
        frameon=False,
        bbox_to_anchor=(0.5, -0.07),
        bbox_transform=fig.transFigure,
        fontsize='medium',
    )

    fig.suptitle('Intervention lobes: Suppression')

    show(fig)
    # Release the figure now that it has been handed to `show`
    plt.close(fig)

## Repulsion

In [24]:
from math import sin, pi
from typing import cast

from annotated_types import Gt, Lt
import matplotlib.pyplot as plt
from matplotlib.projections.polar import PolarAxes

from ex_color.intervention.repulsion import Repulsion
from utils.plt import configure_matplotlib

# Apply the project's matplotlib settings (defined in utils.plt) before any figure in this cell is created.
configure_matplotlib()


class Mapped(Falloff):
    """
    Falloff wrapper that remaps alignments above ``a`` into the range starting at ``a``.

    Alignments above ``a`` are rescaled as ``(alignment - a) / (1 - a)``, passed through
    ``inner_term``, then scaled by ``(b - a)`` and offset by ``a``. Alignments at or below ``a``
    are returned unchanged.

    Args:
        inner_term: The falloff applied to the rescaled alignment.
        a: Lower alignment threshold, validated to lie in [0, 1).
        b: Upper end of the output range, validated to lie in (0, 1]. Must be greater than ``a``.
        eps: Stored on the instance; not used by this class itself.

    Raises:
        ValueError: If ``a`` is not strictly less than ``b``.
    """

    @validate_call(config={'arbitrary_types_allowed': True})
    def __init__(
        self,
        inner_term: Falloff,
        # Constraints must be separate Annotated arguments: wrapped in a list they are treated as
        # unknown metadata and the ranges are never enforced.
        a: Annotated[float, Ge(0), Lt(1)],
        b: Annotated[float, Gt(0), Le(1)],
        eps=1e-8,
    ):
        # Explicit check rather than `assert`, which is stripped when Python runs with -O
        if not a < b:
            raise ValueError(f'a must be less than b, got a={a}, b={b}')
        self.inner_term = inner_term
        self.a = a
        self.b = b
        self.eps = eps

    @override
    def __call__(self, alignment: Tensor) -> Tensor:  # alignment is a batch, shape [B]
        # Values at or below `a` are passed through unchanged by the where() below; clamping keeps
        # them out of inner_term's negative range (which could otherwise yield NaNs, e.g. for a
        # fractional power) without affecting any value that is actually kept.
        shifted = ((alignment - self.a) / (1 - self.a)).clamp(min=0)
        shifted = self.inner_term(shifted)
        # Scale the inner falloff's output by (b - a) and offset it to start at `a`
        shifted = shifted * (self.b - self.a) + self.a
        return torch.where(alignment > self.a, shifted, alignment)

    def __repr__(self):
        return f'{type(self).__name__}({self.inner_term}, {self.a:.2g}, {self.b:.2g})'

    def __str__(self):
        return f'{self.inner_term} | d∈[{self.a:.2g},{self.b:.2g}]'


# Falloff variants to compare, one subplot each. sin(pi * 1 / 6) ≈ 0.5 and sin(pi * 2 / 6) ≈ 0.87
# are the `a` and `b` arguments of the Mapped variants.
falloffs = [
    Linear(0.5),
    # Power(2),
    Mapped(
        Linear(0.5),
        sin(pi * 1 / 6),
        sin(pi * 2 / 6),
    ),
    Mapped(
        Power(2.0),
        sin(pi * 1 / 6),
        sin(pi * 2 / 6),
    ),
]


# displayer_img comes from utils.nb; presumably `show` renders the figure and stores it at the
# given path together with the alt text — confirm against utils.nb.
with displayer_img(
    f'large-assets/ex-{nb_id}-repulsion.png',
    alt_text="Grid of semicircular polar plots showing the effects of repulsion on activations. Each plot shows two lobes: an orange one indicating the magnitude of the intervention, and a blue one showing the transformed activation space. The direction being intervened on (the 'subject') is always 'up', so the orange 'magnitude' lobes are also oriented upwards. The blue 'transformed' lobes are more circular but have a chunk taken out of the top, showing that the directions more aligned with the subject are rotated/pushed away by the intervention.",
) as show:
    # One row of subplots; the figure width grows with the number of falloffs
    fig = plt.figure(figsize=(1 + 4.5 * len(falloffs), 6), constrained_layout=True)

    axes = []
    for i, falloff in enumerate(falloffs):
        # Subplot indices are 1-based
        ax = cast(PolarAxes, fig.add_subplot(1, len(falloffs), i + 1, axes_class=PolarAxes))
        ax.spines['polar'].set_color(c='gray')
        ax.grid(True, color='#444', linewidth=0.5)

        # The subject is the θ = 0 direction, which draw_intervention_slice places at the top
        idf = Repulsion(
            subject=torch.tensor([1, 0], dtype=torch.float32),  # North
            falloff=falloff,
        )
        draw_intervention_slice(ax, idf)
        # Title each subplot with the falloff's human-readable description
        ax.set_title(str(idf.falloff), pad=15)
        axes.append(ax)

    # Add a single legend for all axes: horizontal, no border/background. Every subplot has the
    # same series, so the entries are taken from the first one. The anchor (y = -0.07 in figure
    # coordinates) places the legend slightly below the figure's bottom edge.
    handles, labels = axes[0].get_legend_handles_labels()
    legend = fig.legend(
        handles,
        labels,
        loc='lower center',
        ncol=len(labels),
        frameon=False,
        bbox_to_anchor=(0.5, -0.07),
        bbox_transform=fig.transFigure,
        fontsize='medium',
    )

    fig.suptitle('Intervention lobes: Repulsion')

    show(fig)
    # Release the figure now that it has been handed to `show`
    plt.close(fig)