# Experiment 2.1: Intervention lobes

In this series of experiments, we explore the effects of intervening on latent activations. Now that the latent space has been structured (see Ex 1.x), intervening should simply be a matter of transforming latent embeddings that are closely aligned with the anchored concepts.

We draw inspiration from shaders in computer graphics: bidirectional scattering distribution functions (BSDFs) compute the outgoing energy given 1. an incoming light direction and 2. the viewing direction, both relative to the surface normal. Our interventions are similar: we have 1. a subject concept vector and 2. activation vectors. If we treat the subject vector as analogous to a light source and activation vectors as analogous to viewing directions, we can build on a wealth of established techniques.

Here we define our intervention as a BSDF-like function:

$$e' = f(s,e)$$

where $e$ is an embedding vector, $s$ is the subject, and $e'$ is the modified embedding. In fact, $s$ need not be a (directional) vector; it could be another geometric feature of the embedding space, such as a subspace spanned by multiple basis vectors.

In [1]:
nb_id = "2.1"  # Notebook identifier; interpolated into the asset file paths of the figures below

In [2]:
from matplotlib.axes import Axes
import numpy as np
import torch

# from matplotlib.axes import Axes
from matplotlib.projections.polar import PolarAxes
from torch import Tensor

from ex_color.intervention.intervention import Intervention


def filled_series(ax: Axes, xs, ys, color, alpha=0.3, **kwargs):
    """
    Draw a series as a translucent filled area with a solid outline on top.

    The fill is drawn first, at zorder 0, so it sits behind other artists; the
    outline is then drawn in the same colour.

    Args:
        ax: Axes to draw into.
        xs: First coordinates of the series (θ on polar axes).
        ys: Second coordinates of the series (r on polar axes).
        color: Colour shared by the fill and the outline.
        alpha: Opacity of the fill only; it is not passed to the outline.
        **kwargs: Extra keyword arguments forwarded to ``ax.plot`` for the outline
            (e.g. ``linewidth``, ``label``).
    """
    # Translucent area, pushed to the back
    ax.fill(xs, ys, color=color, alpha=alpha, zorder=0)

    # Outline: same colour, plus whatever line styling the caller supplied
    outline_kwargs = {'color': color, **kwargs}
    ax.plot(xs, ys, **outline_kwargs)


def draw_intervention_slice(ax: PolarAxes, idf: Intervention, *, resolution: int = 360):
    """
    Plot a 2D slice of an intervention function on a polar axes.

    The angular coordinate corresponds to the direction of a unit input vector,
    sampled uniformly around the circle. Three series are drawn:
      - Input: the unit circle of input directions (radius 1)
      - Transformed: each output vector converted to polar (θ_out, r_out)
      - Magnitude: ``idf.falloff(idf.dist(unit))`` plotted against input θ

    The axes are configured to show only the hemisphere θ ∈ [0, π], with θ = 0
    (the direction [1, 0]) at the top and θ increasing clockwise. The radial
    limit is set 10% beyond the largest plotted radius (and never below 1.1).

    Args:
        ax: A PolarAxes instance to draw into.
        idf: The intervention function to plot. Will be called with a tensor of [B,E] where E=2.
        resolution: Number of input directions sampled around the unit circle.

    Raises:
        ValueError: If ``resolution`` is less than 1.
    """
    if resolution < 1:
        raise ValueError(f'resolution must be at least 1, got {resolution}')

    # Input angles θ_in in [0, 2π); the endpoint is dropped so 0 and 2π aren't duplicate vertices
    thetas_in: Tensor = torch.linspace(0.0, 2 * torch.pi, steps=resolution + 1, dtype=torch.float32)[:-1]

    # Unit circle directions [cos θ, sin θ], shape [resolution, 2]
    unit: Tensor = torch.stack((torch.cos(thetas_in), torch.sin(thetas_in)), dim=-1)

    # Evaluation only: no autograd graph is needed
    with torch.no_grad():
        out: Tensor = idf(unit)  # [resolution, 2]
        # idf doesn't return the per-direction magnitudes, so recompute them from the same
        # components it exposes (by convention idf applies falloff(dist(activations)) internally)
        r_mag: Tensor = idf.falloff(idf.dist(unit))  # [resolution]

    # Outputs in polar coordinates
    theta_out: Tensor = torch.atan2(out[..., 1], out[..., 0])  # [-π, π]
    r_out: Tensor = torch.linalg.norm(out, dim=-1)

    def closed(vals: Tensor) -> np.ndarray:
        # Repeat the first point at the end so each curve closes, then hand matplotlib a NumPy array
        return torch.cat([vals, vals[:1]], dim=0).detach().cpu().numpy()

    thetas_in_c = closed(thetas_in)

    # Pre-intervention activations: the unit circle
    ax.plot(
        thetas_in_c,
        [1.0] * len(thetas_in_c),
        color='hotpink',
        linewidth=1.0,
        label='Input',
    )

    # Post-intervention activations
    filled_series(
        ax,
        closed(theta_out),
        closed(r_out),
        color='#1f77b4',
        linewidth=2.0,
        label='Transformed',
    )

    # Magnitude of intervention, plotted against the *input* angle
    filled_series(
        ax,
        thetas_in_c,
        closed(r_mag),
        color='#ff7f0e',
        linewidth=2.0,
        label='Magnitude',
    )

    # Orientation: 0° at the top (perfect alignment with [1, 0]), increasing clockwise
    ax.set_theta_zero_location('N')
    ax.set_theta_direction(-1)
    ax.set_thetalim(0, np.pi)  # Only show 0 to π (hemisphere)

    # Label the angular grid by cosine similarity with [1, 0]: cos 0° = 1, cos 90° = 0, cos 180° = -1
    ax.set_thetagrids(
        [0, 30, 60, 90, 120, 150, 180], ['1 (aligned)', '', '', '0 (orthogonal)', '', '', '-1 (opposing)'], ha='left'
    )

    # Radial limit: comfortably contain all data and the unit (input) circle.
    # .item() brings the maxima to the host, so this works whatever device idf returns on.
    max_r = max(1.0, r_out.max().item(), r_mag.max().item())
    ax.set_rmax(max_r * 1.1)

## Suppression

In [3]:
from typing import cast

import matplotlib.pyplot as plt
from matplotlib.projections.polar import PolarAxes

from ex_color.intervention.falloff import Linear, Power, Sinus, Polynomial, Angular, Exponential, Sigmoid
from ex_color.intervention.suppression import Suppression
from utils.nb import displayer_img
from utils.plt import configure_matplotlib

configure_matplotlib()


# Falloff functions to compare: one subplot per entry
falloffs = [Linear(), Power(), Sinus(), Polynomial(), Angular(), Exponential(10), Sigmoid()]


with displayer_img(
    f'large-assets/ex-{nb_id}-suppression.png',
    alt_text="Grid of semicircular polar plots showing the effects of suppression on activations. Each plot shows two lobes: an orange one indicating the magnitude of the intervention, and a blue one showing the transformed activation space. The direction being intervened on (the 'subject') is always 'up', so the orange 'magnitude' lobes are also oriented upwards. The blue 'transformed' lobes are more circular but have a depression in the top, showing that the directions more aligned with the subject are squashed/attenuated by the intervention.",
) as show:
    # A single row of polar subplots, one per falloff
    fig = plt.figure(figsize=(1 + 4.5 * len(falloffs), 6), constrained_layout=True)

    axes = []
    for i, falloff in enumerate(falloffs):
        ax = cast(PolarAxes, fig.add_subplot(1, len(falloffs), i + 1, axes_class=PolarAxes))
        ax.set_aspect('equal')
        ax.spines['polar'].set_color(c='gray')
        ax.grid(True, color='#444', linewidth=0.5)

        # Suppress the [1, 0] direction (amount=1, renormalize and bidirectional disabled);
        # draw_intervention_slice puts θ = 0, i.e. this direction, at the top of the plot
        idf = Suppression(
            subject=torch.tensor([1, 0], dtype=torch.float32),  # North
            falloff=falloff,
            amount=1,
            renormalize=False,
            bidirectional=False,
        )
        draw_intervention_slice(ax, idf)
        ax.set_title(repr(idf.falloff), pad=15)
        axes.append(ax)

    # One shared legend (every subplot draws the same labelled series), horizontal, no frame.
    # With loc='lower center' and y=-0.07 in figure coordinates, the legend sits just *below*
    # the figure canvas; presumably displayer_img saves with a tight bounding box so it stays
    # visible — TODO confirm.
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(
        handles,
        labels,
        loc='lower center',
        ncol=len(labels),
        frameon=False,
        bbox_to_anchor=(0.5, -0.07),
        bbox_transform=fig.transFigure,
        fontsize='medium',
    )

    fig.suptitle('Intervention lobes: Suppression')

    show(fig)
    plt.close(fig)

## Repulsion

In [4]:
from typing import cast

import matplotlib.pyplot as plt
from matplotlib.projections.polar import PolarAxes

from ex_color.intervention.falloff import Linear, Power, Sinus, Polynomial, Angular, Exponential, Sigmoid
from ex_color.intervention.repulsion import Repulsion
from utils.nb import displayer_img
from utils.plt import configure_matplotlib

configure_matplotlib()


# Falloff functions to compare: one subplot per entry
falloffs = [Linear(), Power(), Sinus(), Polynomial(), Angular(), Exponential(10), Sigmoid()]


with displayer_img(
    f'large-assets/ex-{nb_id}-repulsion.png',
    alt_text="Grid of semicircular polar plots showing the effects of repulsion on activations. Each plot shows two lobes: an orange one indicating the magnitude of the intervention, and a blue one showing the transformed activation space. The direction being intervened on (the 'subject') is always 'up', so the orange 'magnitude' lobes are also oriented upwards. The blue 'transformed' lobes are more circular but have a chunk taken out of the top, showing that the directions more aligned with the subject are rotated/pushed away by the intervention.",
) as show:
    # A single row of polar subplots, one per falloff
    fig = plt.figure(figsize=(1 + 4.5 * len(falloffs), 6), constrained_layout=True)

    axes = []
    for i, falloff in enumerate(falloffs):
        ax = cast(PolarAxes, fig.add_subplot(1, len(falloffs), i + 1, axes_class=PolarAxes))
        ax.set_aspect('equal')
        ax.spines['polar'].set_color(c='gray')
        ax.grid(True, color='#444', linewidth=0.5)

        # Repel from the [1, 0] direction (other Repulsion parameters left at their defaults);
        # draw_intervention_slice puts θ = 0, i.e. this direction, at the top of the plot
        idf = Repulsion(
            subject=torch.tensor([1, 0], dtype=torch.float32),  # North
            falloff=falloff,
        )
        draw_intervention_slice(ax, idf)
        ax.set_title(repr(idf.falloff), pad=15)
        axes.append(ax)

    # One shared legend (every subplot draws the same labelled series), horizontal, no frame.
    # With loc='lower center' and y=-0.07 in figure coordinates, the legend sits just *below*
    # the figure canvas; presumably displayer_img saves with a tight bounding box so it stays
    # visible — TODO confirm.
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(
        handles,
        labels,
        loc='lower center',
        ncol=len(labels),
        frameon=False,
        bbox_to_anchor=(0.5, -0.07),
        bbox_transform=fig.transFigure,
        fontsize='medium',
    )

    fig.suptitle('Intervention lobes: Repulsion')

    show(fig)
    plt.close(fig)