# Experiment 2.1: Intervention lobes

In this series of experiments, we shall explore the effects of intervening on latent activations. Having structured the latent space (see Ex 1.x), it should just be a matter of transforming latent embeddings that are closely aligned to the anchored concepts.

We draw inspiration from shaders in computer graphics: BSDFs (bidirectional scattering distribution functions) compute the output energy given 1. an input light direction, and 2. the viewing direction, both relative to the surface normal. Our interventions are similar: we have 1. a subject concept vector, and 2. activation vectors. If we treat the subject vector as analogous to a light source and activation vectors as analogous to viewing directions, we may build on a wealth of established techniques.

Here we define our intervention as a BSDF-like function:

$$e' = f(s,e)$$

Where $e$ is an embedding vector, $s$ is the subject, and $e'$ is the modified embedding. In fact $s$ need not be a (directional) vector; it could be other geometric features of our embedding space, such as a subspace defined by multiple basis vectors.

In [1]:
# Notebook identifier; interpolated into the output image paths of the figures below
# (e.g. f'large-assets/ex-{nb_id}-suppression.png').
nb_id = '2.1'

### Intervention function sampling

In [None]:
from math import radians

import numpy as np
import torch
from matplotlib.axes import Axes
from matplotlib.projections.polar import PolarAxes
from numpy.typing import NDArray
from torch import Tensor

from ex_color.intervention.intervention import Intervention


def sample_idf(idf: Intervention, n=360, *, phase=0.0) -> tuple[NDArray, NDArray, NDArray, NDArray]:
    """
    Probe an intervention with ``n`` unit vectors spaced evenly around the circle.

    Args:
        idf: Intervention to sample. It is called with a ``[n, 2]`` tensor, and its ``dist`` and
            ``falloff`` attributes are used to compute the strength of the effect.
        n: Number of input directions.
        phase: Offset (radians) added to every input angle.

    Returns:
        ``(theta_out, r_out, thetas_in, r_mag)`` as NumPy arrays: the polar angle (wrapped to be
        non-negative) and the norm of each output vector, the input angles, and
        ``falloff(dist(input))`` for each input direction.
    """
    full_turn = 2 * torch.pi

    # n + 1 evenly spaced stops over one turn; the last coincides with the first, so drop it
    thetas_in: Tensor = torch.linspace(phase, full_turn + phase, steps=n + 1, dtype=torch.float32)[:-1]

    # Points on the unit circle, shape [n, 2]
    directions: Tensor = torch.stack((torch.cos(thetas_in), torch.sin(thetas_in)), dim=-1)

    # Gradients are not needed for sampling
    with torch.no_grad():
        transformed: Tensor = idf(directions)  # [n, 2]
        # Assumes the intervention computes its effect strength as falloff(dist(x)).
        # An alternative measure would be the change in distance, |dist(x) - dist(f(x))|.
        r_mag = idf.falloff(idf.dist(directions))  # [n]

    # Polar form of the outputs; atan2 yields [-π, π], so shift and wrap to non-negative angles
    theta_out = torch.atan2(transformed[..., 1], transformed[..., 0])
    theta_out = (theta_out + full_turn) % full_turn
    r_out = torch.linalg.norm(transformed, dim=-1)

    theta_out_np, r_out_np, thetas_in_np, r_mag_np = (
        series.detach().cpu().numpy() for series in (theta_out, r_out, thetas_in, r_mag)
    )
    return theta_out_np, r_out_np, thetas_in_np, r_mag_np

### Special chart series

In [None]:
from typing import Literal


def wrapped_angular_diff(a: float, b: float) -> float:
    """Return the smallest separation (radians) between angles a and b on a circle, in [0, π]."""
    full_turn = 2 * np.pi
    # Going one way around the circle from a to b...
    forward = (b - a) % full_turn
    # ...and the other way. Whichever is shorter wins, so 0 and 2π count as the same angle.
    backward = full_turn - forward
    return backward if backward < forward else forward


def filled_series(ax: Axes, xs: NDArray, ys: NDArray, *, color: str | None, alpha=0.3, **kwargs):
    """
    Plot a series as an outline on top of a translucent filled shape.

    The filled polygon is anchored at (0, 0) at both ends. On polar axes, a series whose first and
    last angles are within 2° of each other is treated as a full loop and is closed by repeating
    its first point.

    Args:
        ax: Axes to draw into.
        xs: x (or angular) coordinates.
        ys: y (or radial) coordinates.
        color: Colour shared by the fill and the outline.
        alpha: Opacity of the fill only.
        **kwargs: Passed to ``ax.plot`` for the outline (e.g. ``label``, ``linewidth``).
    """
    end_gap = wrapped_angular_diff(xs[0], xs[-1])
    if isinstance(ax, PolarAxes) and end_gap < radians(2):
        # The ends nearly meet: join them so the outline has no gap
        xs = np.append(xs, xs[0])
        ys = np.append(ys, ys[0])

    # Fill first (behind everything), starting and ending at the origin
    origin = [0]
    fill_x = np.concatenate([origin, xs, origin])
    fill_y = np.concatenate([origin, ys, origin])
    ax.fill(fill_x, fill_y, color=color, alpha=alpha, zorder=0)

    # Outline on top
    ax.plot(xs, ys, color=color, **kwargs)


# How each connector is drawn: a straight segment, or a curve that dips towards zero mid-way
Shape = Literal['line', 'chord']


def diff_series(
    ax: Axes,
    xs1: NDArray,
    xs2: NDArray,
    ys1: NDArray,
    ys2: NDArray,
    *,
    shape: Shape,
    label: str | None = None,
    **kwargs,
):
    """
    Draw a connector from each point of series 1 to the matching point of series 2.

    The x coordinates are treated as angles in radians: when the two ends of a connector are more
    than π apart, the start is shifted by a full turn (in whichever direction is needed) so that
    the connector takes the short way around the circle.

    Args:
        ax: Axes to draw into.
        xs1, ys1: Start points.
        xs2, ys2: End points. All four arrays must have the same length.
        shape: 'line' draws straight segments. 'chord' draws curves that are pulled towards zero
            in the middle, like a chord diagram; the longer the angular span, the deeper the dip.
        label: Optional legend label, added once for the whole series.
        **kwargs: Keys starting with 'marker' style the end-point markers; all other keys style
            the connectors.
    """
    # Split kwargs: marker properties only apply to the end-point markers
    marker_kwargs = {k: v for k, v in kwargs.items() if k.startswith('marker')}
    other_kwargs = {k: v for k, v in kwargs.items() if not k.startswith('marker')}

    # The chord profile (1 at both ends, 0 in the middle) is the same for every connector
    curve_power = 2.2
    yfrac = np.concatenate([np.linspace(1, 0, 50), np.linspace(0, 1, 50)]) ** curve_power

    for x1, x2, y1, y2 in zip(xs1, xs2, ys1, ys2, strict=True):
        # Take the shortest path around the circle. The shift has to go towards x2: always adding
        # a turn would make the path longer whenever the start lies more than π *above* the end.
        delta = x2 - x1
        if delta > np.pi:
            x1 += 2 * np.pi
        elif delta < -np.pi:
            x1 -= 2 * np.pi

        if shape == 'chord':
            # Draw a curve, like a chord diagram, to make it easier to see where the points go.
            # Without this, rotations are hard to interpret because the lines have similar angles.
            curve_length_x = wrapped_angular_diff(x1, x2)
            curve_strength = 0.97 * (curve_length_x / np.pi) + 0.03
            xs = np.linspace(x1, x2, 100)
            ys = np.linspace(y1, y2, 100)
            # Pull ys down in the middle
            ys *= yfrac * curve_strength + 1 - curve_strength
        else:
            xs = [x1, x2]
            ys = [y1, y2]
        ax.plot(xs, ys, **other_kwargs)

    # Markers at the end points only
    ax.plot(xs2, ys2, linestyle='', **marker_kwargs)

    # Only add the label once, via an empty proxy series
    if label:
        ax.plot([], [], label=label, **kwargs)

### Polar lobe charts

These charts show a 2D slice of functions with a polar projection. This helps to visualize the _shape_ of the intervention, although it's a bit hard to interpret the magnitude.

In [15]:
def draw_intervention_slice(ax: Axes | PolarAxes, idf: Intervention):
    """
    Plot a 2D slice of an intervention function.

    The angular (or x) coordinate corresponds to the direction of a unit input vector.
    Three series are drawn:
      - Transformed: the output vectors converted to polar (θ_out, r_out)
      - Magnitude: the strength of the intervention plotted against the input θ
      - Offset: connectors from a coarse set of unit inputs to their outputs

    Args:
        ax: Axes to draw into. On a PolarAxes the lobes are filled and the polar grid is
            configured; on any other Axes only outlines are drawn and x is limited to [0, π].
        idf: The intervention function to plot. Will be called with a tensor of [B,E] where E=2.
    """
    # Dense sample for the lobe outlines. The tiny phase keeps samples off the exact axis
    # directions — presumably to avoid degenerate inputs such as perfect alignment.
    theta_out, r_out, thetas_in, r_mag = sample_idf(idf, 360, phase=1e-7)

    # Fills are only drawn on polar axes
    fill_alpha = 0.3 if isinstance(ax, PolarAxes) else 0.0

    # Post-intervention activations
    filled_series(
        ax,
        theta_out,
        r_out,
        color='#1f77b4',
        linewidth=2.0,
        label='Transformed',
        alpha=fill_alpha,
    )

    # Magnitude of intervention
    filled_series(
        ax,
        thetas_in,
        r_mag,
        color='#ff7f0e',
        linewidth=2.0,
        label='Magnitude',
        alpha=fill_alpha,
    )

    # Differences, on a coarser sample so that individual connectors stay legible. Kept in
    # separate variables so that the axis limits below are still based on the dense sample.
    coarse_theta_out, coarse_r_out, coarse_thetas_in, _ = sample_idf(idf, 360 // 10)
    diff_series(
        ax,
        coarse_thetas_in,
        coarse_theta_out,
        np.ones_like(coarse_thetas_in),
        coarse_r_out,
        shape='line' if idf.type == 'linear' else 'chord',
        color='white',
        alpha=0.6,
        linewidth=0.5,
        marker='o',
        markersize=2.0,
        markeredgecolor='none',
        markerfacecolor='white',
        label='Offset',
    )

    # Limit that comfortably contains all plotted data and the unit radius
    r_limit = max(r_out.max(), r_mag.max(), 1.0) * 1.1

    if isinstance(ax, PolarAxes):
        ax.set_theta_zero_location('N')  # 0° at top (perfect alignment)

        # Angular grid lines every 30° from 0° to 180°, with blank labels
        ax.set_thetagrids([0, 30, 60, 90, 120, 150, 180], ['', '', '', '', '', '', ''], ha='left')

        ax.set_rmax(r_limit)
    else:
        ax.set_xlim(0, np.pi)
        ax.set_ylim(0, r_limit)

### Linear charts

These charts show the effects of the intervention as well. The input to the intervention is the alignment with the concept vector — so we'll use that as the x-axis. The choice of y-axis depends on the type of the intervention:

- For suppression, it's useful to see the magnitude of the intervention
- For repulsion, it's more useful to see the output of the mapping (i.e. the post-intervention alignment).


In [16]:
# Helpers for linear charts reused across figures
from math import cos, pi, sqrt

import numpy as np
import torch
from matplotlib.axes import Axes


def draw_mapping_linear(
    ax: Axes,
    mapping,  # Callable[[Tensor], Tensor]
    *,
    title: str | None = None,
    color: str = 'white',
    show_identity: bool = True,
    annotate_params: bool = True,
) -> None:
    """
    Plot y = mapping(x) on Cartesian axes, where x is cosine similarity.

    The mapping is evaluated at 400 points spanning [-1, 1], and both axes are fixed to [-1, 1].

    Args:
        ax: Axes to draw into.
        mapping: Callable that takes and returns a 1D tensor.
        title: Optional axes title.
        color: Colour of the mapping curve.
        show_identity: Whether to draw a faint dashed y = x reference line.
        annotate_params: Whether to mark the mapping's ``a`` attribute (vertical line) and ``b``
            attribute (horizontal line), when they exist and are plain numbers.
    """
    inputs = np.linspace(-1, 1, 400, dtype=np.float32)
    with torch.no_grad():
        outputs = mapping(torch.from_numpy(inputs)).detach().cpu().numpy()

    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)
    setup_cosine_axes(ax, axis='both')

    if show_identity:
        ax.axline((0, 0), slope=1, color='gray', alpha=0.2, linewidth=1, linestyle='--')

    ax.plot(inputs, outputs, label='activation alignment', color=color, linewidth=2)

    if annotate_params:
        param_a = getattr(mapping, 'a', None)
        param_b = getattr(mapping, 'b', None)

        if isinstance(param_a, (float, int)):
            # 'a' is an input-side parameter, so it is a vertical line with rotated text beside it
            a_val = float(param_a)
            ax.axvline(a_val, color='hotpink', alpha=1.0, linewidth=1, linestyle='--', label='a')
            ax.text(
                a_val + 0.02,
                ax.viewLim.ymin + 0.2,
                f'a = {a_val:.2g}',
                color='hotpink',
                fontsize='x-small',
                rotation=90,
            )

        if isinstance(param_b, (float, int)):
            # 'b' is an output-side parameter, so it is a horizontal line with text above it
            b_val = float(param_b)
            ax.axhline(b_val, color='orange', alpha=1.0, linewidth=1, linestyle='--', label='b')
            ax.text(
                ax.viewLim.xmin + 0.2,
                b_val + 0.02,
                f'b = {b_val:.2g}',
                color='orange',
                fontsize='x-small',
            )

    if title:
        ax.set_title(title)


def draw_suppression_strength(
    ax: Axes,
    falloff,  # Callable[[Tensor], Tensor]
    *,
    title: str | None = None,
    color: str = '#ff7f0e',
    annotate_threshold: bool = True,
) -> None:
    """
    Plot the suppression amount as a function of cosine similarity.

    The x-axis is cosine similarity over [-1, 1]. The y-axis is ``falloff(alignment)``, where
    ``alignment = max(0, x)``: negative similarities are clamped to zero before the falloff is
    applied, i.e. suppression is treated as unidirectional.

    Args:
        ax: Axes to draw into.
        falloff: Callable that takes and returns a 1D tensor.
        title: Optional axes title.
        color: Colour of the curve.
        annotate_threshold: Whether to mark the falloff's ``lower`` attribute with a vertical
            line, when it exists and is a plain number in [0, 1].
    """
    cos_sim = np.linspace(-1, 1, 400, dtype=np.float32)
    # Only positive alignment contributes
    clamped = np.clip(cos_sim, 0.0, 1.0).astype(np.float32)
    with torch.no_grad():
        amount = falloff(torch.from_numpy(clamped)).detach().cpu().numpy()

    # Leave a little headroom if the falloff exceeds 1
    y_top = max(1.0, float(np.max(amount)) * 1.05)
    ax.set_xlim(-1, 1)
    ax.set_ylim(0, y_top)
    setup_cosine_axes(ax, axis='x')
    ax.set_ylabel('Suppression amount', fontsize='small', labelpad=10)

    ax.plot(cos_sim, amount, label='amount', color=color, linewidth=2)

    if annotate_threshold:
        threshold = getattr(falloff, 'lower', None)
        in_range = isinstance(threshold, (float, int)) and 0 <= float(threshold) <= 1
        if in_range:
            cx = float(threshold)
            ax.axvline(cx, color='hotpink', alpha=1.0, linewidth=1, linestyle='--', label='lower')
            ax.text(
                cx + 0.02,
                ax.viewLim.ymin + 0.2,
                f'lower = {cx:.2g}',
                color='hotpink',
                fontsize='x-small',
                rotation=90,
            )

    if title:
        ax.set_title(title)


def setup_cosine_axes(ax: Axes, axis: str = 'both') -> None:
    """
    Configure ticks and labels for axes that show cosine similarity.

    Major ticks sit at the cosines of 0°, 30°, 60° and 90° and their negatives; only -1, 0 and 1
    get text. Minor ticks sit at the cosines of every 10°. The major grid is always enabled.

    axis: which axes to configure — 'x' | 'y' | 'both'
    """
    on_x = axis in ('x', 'both')
    on_y = axis in ('y', 'both')

    cos30, cos60 = cos(pi / 6), cos(pi / 3)
    major = np.array([-1, -cos30, -cos60, 0, cos60, cos30, 1.0])

    # cos(0°), cos(10°), ..., cos(90°), mirrored onto the negative side (zero is listed once)
    quarter = np.cos(np.arange(0, 91, 10) * np.pi / 180)
    minor = np.concatenate([-quarter[:-1], quarter])

    if on_x:
        ax.set_xticks(major)
        ax.set_xticklabels(
            np.array(['-1\nopposing', '', '', '0\northogonal', '', '', '1\naligned']),
            fontsize='x-small',
        )
        ax.set_xlabel('Cosine similarity (angle from subject)', fontsize='small', labelpad=10)

    if on_y:
        ax.set_yticks(major)
        ax.set_yticklabels(np.array(['-1', '', '', '0', '', '', '1']), fontsize='x-small')
        ax.set_ylabel('Output cosine similarity', fontsize='small', labelpad=10)

    if on_x:
        ax.set_xticks(minor, minor=True)
    if on_y:
        ax.set_yticks(minor, minor=True)

    ax.grid(True, which='major', alpha=0.1)

## Suppression

This type of intervention is used to reduce the **magnitude** of embeddings that are aligned with a concept vector.

In [None]:
from math import sin, pi
from typing import cast, Annotated, override

from annotated_types import Ge, Le
import matplotlib.pyplot as plt
from matplotlib.projections.polar import PolarAxes
from pydantic import validate_call

from ex_color.intervention.falloff import Falloff, Linear, Power
from ex_color.intervention.suppression import Suppression
from utils.nb import displayer_img
from utils.plt import configure_matplotlib

# Apply the project's matplotlib configuration (utils.plt) before any figure is created.
# NOTE(review): presumably sets the plotting theme/rcParams — confirm in utils.plt.
configure_matplotlib()


class Bounded(Falloff):
    """
    Falloff that is zero up to a threshold and applies an inner falloff above it.

    Alignments above ``lower`` are rescaled from ``(lower, 1]`` onto ``(0, 1]`` and passed to
    ``inner_term``; alignments at or below ``lower`` map to zero.
    """

    @validate_call(config={'arbitrary_types_allowed': True})
    def __init__(
        self,
        inner_term: Falloff,
        # Each constraint must be its own Annotated argument: wrapped in a list, the constraints
        # are opaque metadata that the validator does not recognise, so nothing is enforced.
        lower: Annotated[float, Ge(0), Le(1)],
        eps=1e-8,
    ):
        """
        Args:
            inner_term: Falloff applied to the rescaled alignment above the threshold.
            lower: Alignment threshold, in [0, 1].
            eps: Tolerance used to detect the degenerate thresholds 0 and 1.
        """
        self.inner_term = inner_term
        self.lower = lower
        self.eps = eps

    @override
    def __call__(self, alignment: Tensor) -> Tensor:
        if self.lower < self.eps:
            # No threshold to speak of: behave exactly like the inner falloff
            return self.inner_term(alignment)
        if self.lower > 1 - self.eps:
            # The threshold is (almost) 1, so nothing lies above it and the rescaling below would
            # divide by ~0. Follow the general rule for values at or below the threshold: zero.
            return torch.zeros_like(alignment)

        scale = 1 - self.lower
        shifted = (alignment - self.lower) / scale
        shifted = self.inner_term(shifted)
        # Values at or below the threshold are discarded, whatever the inner falloff made of them
        return torch.where(alignment > self.lower, shifted, torch.zeros_like(alignment))

    def __repr__(self):
        return f'{type(self).__name__}({self.inner_term}, {self.lower:.2g})'

    def __str__(self):
        return f'{self.inner_term} | d>{self.lower:.2g}'


# Falloffs to compare, one column per entry. The Bounded variants return zero for alignments at
# or below the threshold sin(π/6) = 0.5.
falloffs = [
    Linear(0.5),
    Bounded(Linear(0.5), sin(pi * 1 / 6)),
    Power(2),
    Bounded(Power(2), sin(pi * 1 / 6)),
]


# Render the figure and hand it to the `show` callback provided by displayer_img, which is given
# the output path and alt text.
with displayer_img(
    f'large-assets/ex-{nb_id}-suppression.png',
    alt_text="Grid of semicircular polar plots showing the effects of suppression on activations. Each plot shows two lobes: an orange one indicating the magnitude of the intervention, and a blue one showing the transformed activation space. The direction being intervened on (the 'subject') is always 'up', so the orange 'magnitude' lobes are also oriented upwards. The blue 'transformed' lobes are more circular but have a depression in the top, showing that the directions more aligned with the subject are squashed/attenuated by the intervention.",
) as show:
    # Two rows: polar slices (top) and linear suppression amount (bottom)
    n = len(falloffs)
    fig = plt.figure(figsize=(1 + 4.5 * n, 10), constrained_layout=True)

    axes = []
    linear_axes = []
    for i, falloff in enumerate(falloffs):
        # Top row: polar lobe chart for this falloff
        ax = cast(PolarAxes, fig.add_subplot(2, n, i + 1, axes_class=PolarAxes))
        ax.spines['polar'].set_color(c='gray')
        ax.grid(True, color='#444', linewidth=0.5)

        idf = Suppression(
            # θ = 0, which draw_intervention_slice places at the top of the polar chart
            subject=torch.tensor([1, 0], dtype=torch.float32),  # North
            falloff=falloff,
            amount=1,
            renormalize=False,
            bidirectional=False,
        )
        draw_intervention_slice(ax, idf)
        ax.set_title(str(idf.falloff), pad=15)
        axes.append(ax)

        # Bottom row: linear suppression-strength chart for the same falloff
        lax = fig.add_subplot(2, n, n + i + 1)
        draw_suppression_strength(lax, falloff, title='amount vs cos(θ)')
        linear_axes.append(lax)

    # Single legend for all polar axes, built from the first one (they all share the same series)
    handles, labels = axes[0].get_legend_handles_labels()
    legend = fig.legend(
        handles,
        labels,
        loc='lower center',
        ncol=len(labels),
        frameon=False,
        bbox_to_anchor=(0.5, -0.05),
        bbox_transform=fig.transFigure,
        fontsize='medium',
    )

    fig.suptitle('Intervention lobes: Suppression')

    show(fig)
    # Release the figure once it has been handed to `show`
    plt.close(fig)

## Repulsion

In [22]:
from math import sin, pi
from typing import cast, Annotated

from annotated_types import Ge, Gt, Le, Lt
import matplotlib.pyplot as plt
from matplotlib.projections.polar import PolarAxes

from ex_color.intervention.repulsion import Repulsion
from utils.plt import configure_matplotlib

# Apply the project's matplotlib configuration (utils.plt) before any figure is created.
# NOTE(review): presumably sets the plotting theme/rcParams — confirm in utils.plt.
configure_matplotlib()


class Mapped(Falloff):
    """
    Alignment mapping that is the identity up to ``a`` and is compressed above it.

    Alignments above ``a`` are rescaled from ``(a, 1]`` onto ``(0, 1]``, passed through
    ``inner_term``, and then scaled onto the interval starting at ``a`` with width ``b - a``.
    Provided the inner term maps [0, 1] into [0, 1], the result lies in [a, b].
    """

    @validate_call(config={'arbitrary_types_allowed': True})
    def __init__(
        self,
        inner_term: Falloff,
        # Each constraint must be its own Annotated argument: wrapped in a list, the constraints
        # are opaque metadata that the validator does not recognise, so nothing is enforced.
        a: Annotated[float, Ge(0), Lt(1)],
        b: Annotated[float, Gt(0), Le(1)],
        eps=1e-8,
    ):
        """
        Args:
            inner_term: Falloff applied to the rescaled alignment above ``a``.
            a: Input threshold, in [0, 1). Alignments at or below it are left unchanged.
            b: Output value that an alignment of 1 is mapped to when ``inner_term(1) == 1``,
                in (0, 1]. Must be greater than ``a``.
            eps: Stored on the instance; not used by the mapping itself.
        """
        assert a < b, f'a must be less than b, got a={a}, b={b}'
        self.inner_term = inner_term
        self.a = a
        self.b = b
        self.eps = eps

    @override
    def __call__(self, alignment: Tensor) -> Tensor:  # alignment is a batch, shape [B]
        # Rescale (a, 1] onto (0, 1]; a < 1 is guaranteed by the constructor's constraints
        shifted = (alignment - self.a) / (1 - self.a)
        shifted = self.inner_term(shifted)
        # Scale the inner term's output onto the interval that starts at a and has width b - a
        shifted = shifted * (self.b - self.a) + self.a
        # Values at or below a pass through untouched
        return torch.where(alignment > self.a, shifted, alignment)

    def __repr__(self):
        return f'{type(self).__name__}({self.inner_term}, {self.a:.2g}, {self.b:.2g})'

    def __str__(self):
        return f'{self.inner_term} | d∈[{self.a:.2g},{self.b:.2g}]'


# Mappings to compare, one column per entry. The Mapped variants leave alignments at or below
# a = sin(π/6) = 0.5 unchanged and use b = sin(2π/6) ≈ 0.87 as the upper parameter.
falloffs = [
    Linear(0.5),
    # Power(2),
    Mapped(
        Linear(1),
        sin(pi * 1 / 6),
        sin(pi * 2 / 6),
    ),
    Mapped(
        Power(2.0),
        sin(pi * 1 / 6),
        sin(pi * 2 / 6),
    ),
]


# Render the figure and hand it to the `show` callback provided by displayer_img, which is given
# the output path and alt text.
with displayer_img(
    f'large-assets/ex-{nb_id}-repulsion.png',
    alt_text="Grid of semicircular polar plots showing the effects of repulsion on activations. Each plot shows two lobes: an orange one indicating the magnitude of the intervention, and a blue one showing the transformed activation space. The direction being intervened on (the 'subject') is always 'up', so the orange 'magnitude' lobes are also oriented upwards. The blue 'transformed' lobes are more circular but have a chunk taken out of the top, showing that the directions more aligned with the subject are rotated/pushed away by the intervention.",
) as show:
    # Two rows: polar slices (top) and linear mapping (bottom)
    n = len(falloffs)
    fig = plt.figure(figsize=(1 + 4.5 * n, 10), constrained_layout=True)

    axes = []
    linear_axes = []
    for i, falloff in enumerate(falloffs):
        # Top row: polar lobe chart for this mapping
        ax = cast(PolarAxes, fig.add_subplot(2, n, i + 1, axes_class=PolarAxes))
        ax.spines['polar'].set_color(c='gray')
        ax.grid(True, color='#444', linewidth=0.5)

        idf = Repulsion(
            # θ = 0, which draw_intervention_slice places at the top of the polar chart
            subject=torch.tensor([1, 0], dtype=torch.float32),  # North
            falloff=falloff,
        )
        draw_intervention_slice(ax, idf)
        ax.set_title(str(idf.falloff), pad=15)
        axes.append(ax)

        # Linear mapping chart using the same mapping function ("falloff" here)
        lax = fig.add_subplot(2, n, n + i + 1)
        draw_mapping_linear(lax, falloff, title='mapping')
        linear_axes.append(lax)

    # Single legend for all polar axes, built from the first one (they all share the same series)
    handles, labels = axes[0].get_legend_handles_labels()
    legend = fig.legend(
        handles,
        labels,
        loc='lower center',
        ncol=len(labels),
        frameon=False,
        bbox_to_anchor=(0.5, -0.05),
        bbox_transform=fig.transFigure,
        fontsize='medium',
    )

    fig.suptitle('Intervention lobes: Repulsion')

    show(fig)
    # Release the figure once it has been handed to `show`
    plt.close(fig)

In [19]:
from math import cos, pi, sqrt

import torch
from torch import Tensor


class BezierMapper(Falloff):
    """
    Alignment mapping that is the identity up to ``a`` and a cubic Bézier curve above it.

    The curve runs from ``(a, a)`` to ``(1, b)``. Its inner control points lie on the tangent
    lines through those end points (with slopes ``start_slope`` and ``end_slope``), part-way
    towards the point where the two tangents intersect.

    Because the curve is parametric, evaluating it at a given alignment ``x`` means first solving
    ``bezier_x(t) = x`` for ``t``; this class does so with Newton's method.
    """

    @validate_call
    def __init__(
        self,
        # Each constraint must be its own Annotated argument: wrapped in a list, the constraints
        # are opaque metadata that the validator does not recognise, so nothing is enforced.
        a: Annotated[float, Ge(0), Lt(1)],
        b: Annotated[float, Gt(0), Le(1)],
        start_slope: float = 1.0,  # Aligned with unmapped leadup
        end_slope: float = 0.0,  # Flat
        control_distance: float = 1 / sqrt(2),  # Relative to intersection point
    ):
        """
        Args:
            a: Input threshold, in [0, 1). Alignments at or below it are left unchanged.
            b: Output value at an alignment of 1, in (0, 1]. Must be greater than ``a``.
            start_slope: Slope of the curve where it leaves the identity at ``(a, a)``.
            end_slope: Slope of the curve where it arrives at ``(1, b)``.
            control_distance: How far each inner control point sits along the way from its end
                point to the tangent intersection (0 = at the end point, 1 = at the intersection).
        """
        assert a < b <= 1
        self.a = a
        self.b = b

        # Find intersection of the two tangent lines
        # Line 1: y - a = start_slope * (x - a)  =>  y = start_slope * (x - a) + a
        # Line 2: y - b = end_slope * (x - 1)    =>  y = end_slope * (x - 1) + b
        # At intersection: start_slope * (x - a) + a = end_slope * (x - 1) + b

        if abs(start_slope - end_slope) < 1e-8:
            # Parallel lines - use midpoint as fallback
            intersection_x = (a + 1) / 2
            intersection_y = (a + b) / 2
        else:
            intersection_x = (a * (start_slope - 1) + b - end_slope) / (start_slope - end_slope)
            intersection_y = start_slope * (intersection_x - a) + a

        intersection = torch.tensor([intersection_x, intersection_y], dtype=torch.float32)

        # Define the 4 control points for cubic Bézier
        self.P0 = torch.tensor([a, a], dtype=torch.float32)
        self.P3 = torch.tensor([1.0, b], dtype=torch.float32)

        # P1: from P0 towards the intersection, scaled by control_distance
        self.P1 = self.P0 + control_distance * (intersection - self.P0)

        # P2: from P3 towards the intersection, scaled by control_distance
        self.P2 = self.P3 + control_distance * (intersection - self.P3)

    def bezier_point(self, t: Tensor) -> Tensor:
        """Evaluate the cubic Bézier curve at parameters t (shape [N]); returns points [N, 2]."""
        one_minus_t = 1 - t

        term0 = (one_minus_t**3)[:, None] * self.P0[None, :]
        term1 = (3 * one_minus_t**2 * t)[:, None] * self.P1[None, :]
        term2 = (3 * one_minus_t * t**2)[:, None] * self.P2[None, :]
        term3 = (t**3)[:, None] * self.P3[None, :]

        return term0 + term1 + term2 + term3

    def bezier_x(self, t: Tensor) -> Tensor:
        """Get x-coordinate of Bézier curve at parameter t"""
        return self.bezier_point(t)[:, 0]

    def bezier_y(self, t: Tensor) -> Tensor:
        """Get y-coordinate of Bézier curve at parameter t"""
        return self.bezier_point(t)[:, 1]

    def _bezier_dx_dt(self, t: Tensor) -> Tensor:
        """Derivative of the curve's x-coordinate with respect to t, in closed form."""
        one_minus_t = 1 - t
        x0, x1, x2, x3 = self.P0[0], self.P1[0], self.P2[0], self.P3[0]
        return 3 * one_minus_t**2 * (x1 - x0) + 6 * one_minus_t * t * (x2 - x1) + 3 * t**2 * (x3 - x2)

    def solve_for_t(self, x: Tensor, max_iters: int = 10) -> Tensor:
        """
        Solve for parameter t such that bezier_x(t) = x using Newton's method.

        The derivative is computed in closed form rather than with autograd, so this also works
        inside ``torch.no_grad()`` — where ``torch.autograd.grad`` would raise because no graph
        is recorded.

        Args:
            x: Target x-coordinates, shape [N].
            max_iters: Maximum number of Newton steps.

        Returns:
            Parameters t in [0, 1], shape [N].
        """
        # Initial guess: linear interpolation
        t = (x - self.a) / (1 - self.a)
        t = torch.clamp(t, 0.01, 0.99)  # Avoid endpoints

        for _ in range(max_iters):
            # Newton step: t_new = t - f(t)/f'(t), where f(t) = bezier_x(t) - target_x
            error = self.bezier_x(t) - x
            dx_dt = self._bezier_dx_dt(t)

            # Small epsilon guards against division by zero where the curve is vertical in t
            t = torch.clamp(t - error / (dx_dt + 1e-8), 0.0, 1.0)

            # Check convergence (using the error measured before this step)
            if torch.max(torch.abs(error)) < 1e-6:
                break

        return t

    @override
    def __call__(self, alignment: Tensor) -> Tensor:
        result = alignment.clone()

        # Only apply Bézier mapping for alignment > a; everything else passes through unchanged
        mask = alignment > self.a
        if mask.any():
            x_vals = alignment[mask]

            # Solve for t parameters
            t_vals = self.solve_for_t(x_vals)

            # Get corresponding y values
            y_vals = self.bezier_y(t_vals)

            result[mask] = y_vals

        return result

    def __repr__(self):
        return f'BezierMapper(a={self.a:.2g}, b={self.b:.2g})'

    def __str__(self):
        return f'Bézier[{self.a:.2g}→{self.b:.2g}]'


class FastBezierMapper(BezierMapper):
    """
    BezierMapper that evaluates the curve by interpolating a precomputed lookup table.

    The curve is sampled once at ``lookup_resolution`` evenly spaced parameters; a query is then
    answered by linear interpolation between the two neighbouring samples instead of running
    Newton's method.

    NOTE(review): the lookup uses ``torch.searchsorted``, which needs the sampled x-coordinates
    to be in ascending order. That holds when the control points' x-coordinates are ordered, as
    with the default slopes — confirm before using unusual slopes or control distances.
    """

    def __init__(self, *args, lookup_resolution: int = 1000, **kwargs):
        """
        Args:
            *args, **kwargs: Passed to ``BezierMapper``.
            lookup_resolution: Number of samples in the lookup table; at least 2.

        Raises:
            ValueError: If ``lookup_resolution`` is less than 2 (interpolation needs two points).
        """
        if lookup_resolution < 2:
            raise ValueError(f'lookup_resolution must be at least 2, got {lookup_resolution}')
        super().__init__(*args, **kwargs)

        # Precompute lookup table
        t_vals = torch.linspace(0, 1, lookup_resolution, dtype=torch.float32)
        bezier_points = self.bezier_point(t_vals)

        # Ensure contiguous storage to avoid searchsorted warning
        self.x_lookup = bezier_points[:, 0].contiguous()  # x coordinates
        self.y_lookup = bezier_points[:, 1].contiguous()  # y coordinates

    def interpolate_1d(self, x_query: Tensor) -> Tensor:
        """1D linear interpolation of the curve's y at x_query, using the lookup table."""
        # searchsorted needs the table and the query to share a dtype (and device), so match the
        # table to the query. This is a no-op for float32 CPU queries.
        x_lookup = self.x_lookup.to(x_query)
        y_lookup = self.y_lookup.to(x_query)

        # Find insertion points for x_query in x_lookup
        indices = torch.searchsorted(x_lookup, x_query, right=False)

        # Clamp indices to valid range; queries outside the table extrapolate from the end segment
        indices = torch.clamp(indices, 1, len(x_lookup) - 1)

        # Get surrounding points
        x0 = x_lookup[indices - 1]
        x1 = x_lookup[indices]
        y0 = y_lookup[indices - 1]
        y1 = y_lookup[indices]

        # Linear interpolation: y = y0 + (y1 - y0) * (x - x0) / (x1 - x0)
        t = (x_query - x0) / (x1 - x0 + 1e-8)  # Add small epsilon to avoid division by zero
        return y0 + t * (y1 - y0)

    @override
    def __call__(self, alignment: Tensor) -> Tensor:
        result = alignment.clone()

        # Only alignments above a are remapped; everything else passes through unchanged
        mask = alignment > self.a
        if mask.any():
            # Use interpolation on lookup table instead of Newton's method
            result[mask] = self.interpolate_1d(alignment[mask])

        return result


def draw_bezier_handle(cp1: Tensor, cp2: Tensor, *, color: str, ax: Axes | None = None, **kwargs):
    """
    Draw a Bézier control handle: a line between two control points, with a marker on each.

    Args:
        cp1: First control point, as (x, y).
        cp2: Second control point, as (x, y).
        color: Colour of the line and of the marker edges.
        ax: Axes to draw into. Defaults to the current axes, so that the function no longer
            depends on a module-level variable that happens to be named ``ax``.
        **kwargs: Passed to both ``ax.plot`` calls.
    """
    if ax is None:
        ax = plt.gca()

    # Regroup the two (x, y) points into an x series and a y series
    cx, cy = zip(cp1, cp2, strict=True)

    # Handle line, behind everything else
    ax.plot(
        cx,
        cy,
        color=color,
        linewidth=1.5,
        zorder=0,
        **kwargs,
    )
    # Markers only (no connecting line), drawn on top of the handle line
    ax.plot(
        cx,
        cy,
        color=color,
        linestyle=' ',
        marker='o',
        markersize=4,
        markerfacecolor='black',
        markeredgewidth=1.2,
        **kwargs,
    )


# Create a smooth curve from (a, a) to (1, b), with a = cos(60°) = 0.5 and b = cos(30°) ≈ 0.866
bezier_falloff = FastBezierMapper(
    a=cos(pi / 3),
    b=cos(pi / 6),
    start_slope=1.0,  # Continue identity slope
    end_slope=0.0,  # Flat arrival at the end point
)

title = 'Mapping function (Bézier)'
# Render the figure and hand it to the `show` callback provided by displayer_img, which is given
# the output path and alt text.
with displayer_img(
    f'large-assets/ex-{nb_id}-bezier-mapping.png',
    alt_text=f"""A graph titled '{title}' showing how input cosine similarity values are transformed by the intervention. The x-axis shows 'Cosine similarity (angle from subject)' ranging from 180° (opposing) on the left to 0° (aligned) on the right, passing through 90° (orthogonal) in the middle. The y-axis shows 'Output cosine similarity' with the same angular scale from 180° at bottom to 0° at top. A diagonal gray 'Identity' line runs from bottom-left to top-right, representing no transformation. Two dashed reference lines mark key parameters: a vertical magenta line 'a' at approximately 60° and an orange horizontal line 'b' at approximately 30°. The main white curve labeled 'activation alignment' follows the identity line until reaching the 'a' threshold, then smoothly curves to the right, demonstrating how the Bézier function maps highly aligned inputs (near 0°) to the target output value 'b' (30°). This illustrates the intervention's effect: activations below the 'a' threshold remain unchanged, while those above are smoothly redirected according to the Bézier curve.""",
) as show:
    fig = plt.figure(1, (6, 6))
    ax = fig.add_subplot()

    # Use generic linear mapping drawer
    draw_mapping_linear(ax, bezier_falloff, title=title)

    # Control points overlay, drawn onto the `ax` created above: one handle per end of the curve
    draw_bezier_handle(bezier_falloff.P0, bezier_falloff.P1, color='hotpink')
    draw_bezier_handle(bezier_falloff.P2, bezier_falloff.P3, color='orange')

    ax.legend(frameon=False, fontsize='small')

    show(fig)
    # Release the figure once it has been handed to `show`
    plt.close(fig)