## How to use

Introduces a single new function `plot_info_flow(prompt)` which accepts a prompt and an optional boolean argument `plot_info_flow(prompt, by_layer=True)`. Defaults to `by_layer=False`

this function creates a heatmap that shows what previous token positions a token pays attention to.

attention can be shown in aggregate across all layers, or by layer if `by_layer=True`

In [1]:
%pip install transformer_lens
%pip install sae_lens
%pip install sae_vis
%pip install safetensors
%pip install circuitsvis


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.11 -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.11 -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.11 -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.

[1m[

In [2]:
from transformer_lens import HookedTransformer
import torch as t
import einops

In [3]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)



Loaded pretrained model gpt2-small into HookedTransformer


In [12]:
import plotly.graph_objects as go

def heatmap(matrix, x_labels, y_labels, x_title, y_title, title):
    # Reverse the y_labels and corresponding rows in the matrix
    y_labels = list(reversed(y_labels))
    matrix = list(reversed(matrix))

    # Create the heatmap trace
    trace = go.Heatmap(
        z=matrix,
        x0=0, dx=1, y0=0, dy=1,
        xgap=1, ygap=1,
        xaxis='x', yaxis='y'
    )

    # Create the figure and set the x and y labels
    fig = go.Figure(data=trace)
    fig.update_layout(
        title=title,
        xaxis=dict(
            tickmode='array',
            tickvals=list(range(len(x_labels))),
            ticktext=x_labels,
            tickfont=dict(size=12),
            title=x_title
        ),
        yaxis=dict(
            tickmode='array',
            tickvals=list(range(len(y_labels))),
            ticktext=y_labels,
            tickfont=dict(size=12),
            title=y_title
        )
    )
    # Show the figure
    fig.show()

In [6]:
# Aggregate all attention outputs 
def get_activations(prompt):
    _, cache = model.run_with_cache(prompt, prepend_bos=True)
    batch = 0

    activations = []

    ## Get all the activations for every head in every layer
    for layer in range(model.cfg.n_layers):
        layer_activations = []
        for head in range(model.cfg.n_heads):
            head_acts = cache["pattern", layer][batch][head][1:, 1:] # getting the attention patterns for eveall tokens except the first <|endoftext|> token

            layer_activations.append(head_acts)
        layer_activations = t.stack(layer_activations)
        activations.append(layer_activations)

    activations = t.stack(activations)

    print(activations.shape)
    return activations

# This function will return the normalized attention values for a given token position
def aggregate_attention_layer_activations(tokens, activations, position=-1, by_layer:bool=False):
  if by_layer:
    position_contributions =  t.zeros((model.cfg.n_layers, len(tokens[1:])), device=model.cfg.device)
  else:
    position_contributions =  t.zeros(len(tokens[1:]), device=model.cfg.device)

  for layer in range(model.cfg.n_layers):
    layer_activations = activations[layer]
  
    token_activations = layer_activations[:,position,:]
    summed_values = einops.reduce(token_activations, 'tensors num -> num', 'sum')
    normalized_tensor = summed_values / t.sum(summed_values)

    if by_layer:
      position_contributions[layer] = normalized_tensor
    else:
      position_contributions += normalized_tensor

    if not by_layer: position_contributions = position_contributions/model.cfg.n_layers
  return position_contributions

def plot_info_flow(prompt, by_layer=False):
  tokens = model.to_str_tokens(prompt)
  activations = get_activations(prompt)

  total_position_contributions =  t.zeros((len(tokens[1:]), model.cfg.n_layers, len(tokens[1:])), device=model.cfg.device)

  for i in range(len(tokens) - 1):
    position_contributions = aggregate_attention_layer_activations(tokens, activations, i, by_layer)
    total_position_contributions[i] = t.tensor(position_contributions, device=model.cfg.device)

  ## Plot heatmamps by layer or aggregated
  if by_layer:
    for i in range(total_position_contributions.shape[1]):
      print(f"Layer {i}")
      layer_contributions = total_position_contributions[:,i,:].cpu().numpy()
      heatmap(layer_contributions, tokens[1:], tokens[1:], "Attention Patterns", "Target Token for prediction", "Information flow")
  else:
    total_position_contributions = total_position_contributions.mean(dim=1).cpu().numpy()
    heatmap(total_position_contributions, tokens[1:], tokens[1:], "Attention Patterns", "Target Token for prediction", "Information flow")


In [13]:
plot_info_flow("Mary and John went to the store. John gave the cart to")

torch.Size([12, 12, 13, 13])



To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



In [14]:
plot_info_flow("Mary and John went to the store. John gave the cart to", by_layer=True)

torch.Size([12, 12, 13, 13])
Layer 0



To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



Layer 1


Layer 2


Layer 3


Layer 4


Layer 5


Layer 6


Layer 7


Layer 8


Layer 9


Layer 10


Layer 11
