In [1]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import svd

from pattern_lens.figure_util import matplotlib_figure_saver, save_matrix_wrapper
from pattern_lens.attn_figure_funcs import register_attn_figure_func
from pattern_lens.figures import figures_main

In [2]:
# define and register your own functions
# don't take these too seriously, they're just examples


# using matplotlib_figure_saver -- define a function that takes the matrix and a `plt.Axes`, and draws onto the axes
@register_attn_figure_func
@matplotlib_figure_saver(fmt="svgz")
def svd_spectra(attn_matrix: np.ndarray, ax: plt.Axes) -> None:
    """Plot the singular value spectrum of an attention matrix on a log scale.

    Args:
        attn_matrix: 2D attention matrix to decompose.
        ax: axes to draw on. The figure is presumably created and saved by the
            `matplotlib_figure_saver` decorator (as svgz); this function only
            draws onto the axes.
    """
    # only the singular values are plotted, so skip computing U and Vh entirely
    singular_values: np.ndarray = svd(attn_matrix, compute_uv=False)

    ax.plot(singular_values, "o-")
    ax.set_yscale("log")
    ax.set_xlabel("Singular Value Index")
    ax.set_ylabel("Singular Value")
    ax.set_title("Singular Value Spectrum of Attention Matrix")


# manually creating and saving a figure -- define a function that takes the matrix and the output directory `Path`
@register_attn_figure_func
def attention_flow(
    attn_matrix: np.ndarray,
    path: Path,
    *,
    min_weight: float = 0.05,
) -> None:
    """Visualize attention as flows between tokens.

    Creates a simplified Sankey-style diagram: token position k sits at height
    k on both the left (x=0) and right (x=1) side, and a line joins left
    position i to right position j for every entry `attn_matrix[i, j]` greater
    than `min_weight`. Line thickness and opacity both scale with the weight.

    Args:
        attn_matrix: square 2D attention matrix; row indices are drawn on the
            left, column indices on the right. Values are assumed to be in
            [0, 1] (softmax attention); opacity is clipped to that range.
        path: directory in which `attention_flow.svgz` is written.
        min_weight: only connections strictly stronger than this are drawn.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        n_tokens: int = attn_matrix.shape[0]

        # iterate only over connections that will actually be drawn;
        # argwhere yields (i, j) in row-major order, same as a nested loop
        for i, j in np.argwhere(attn_matrix > min_weight):
            weight: float = float(attn_matrix[i, j])
            ax.plot(
                [0, 1],
                [i, j],  # token k is placed at y=k on both sides
                # matplotlib rejects alpha outside [0, 1]
                alpha=float(np.clip(weight, 0.0, 1.0)),
                linewidth=weight * 5,
                color="blue",
            )

        ax.set_xlim(-0.1, 1.1)
        ax.set_ylim(-1, n_tokens)
        ax.axis("off")
        ax.set_title("Attention Flow Between Positions")

        # be sure to save the figure as `function_name.format` in the given location
        fig.savefig(path / "attention_flow.svgz", format="svgz")
    finally:
        # this function may be called many times by the pipeline; without an
        # explicit close, every figure stays alive in pyplot's figure manager
        plt.close(fig)


@register_attn_figure_func
@save_matrix_wrapper(fmt="svgz")
def gram_matrix(attn_matrix: np.ndarray) -> np.ndarray:
    """Return the Gram matrix `A @ A.T` of the attention matrix `A`.

    Entry (i, j) is the dot product of rows i and j of the input, so the result
    is symmetric. The returned matrix is presumably rendered and saved (as svgz)
    by the `save_matrix_wrapper` decorator.
    """
    transposed: np.ndarray = attn_matrix.T
    return np.matmul(attn_matrix, transposed)

In [4]:
# run the full figure pipeline on a small sample of prompts
MODEL_NAME: str = "pythia-14m"
SAVE_PATH: Path = Path("docs/demo/")
N_SAMPLES: int = 5

figures_main(
    model_name=MODEL_NAME,
    save_path=SAVE_PATH,
    n_samples=N_SAMPLES,
    force=False,  # NOTE(review): presumably reuses existing outputs instead of regenerating -- confirm in pattern_lens
)

✔️  (0.02s) setting up paths                                                   
✔️  (0.02s) loading prompts                                                    
5 prompts loaded
4 figure functions loaded
	raw, svd_spectra, attention_flow, gram_matrix


Making figures: 100%|██████████| 5/5 [00:01<00:00,  4.23prompt/s]

| (0.00s) updating jsonl metadata for models and functions                     




✔️  (0.04s) updating jsonl metadata for models and functions                   
