In [1]:
# Setup (no need to read). Scroll down for introduction.

# Setup largely copied from Neel Nanda's introduction to TransformerLens:
# https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Main_Demo.ipynb
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    %pip install circuitsvis

    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
# Use Colab's renderer unless running locally with DEVELOPMENT_MODE switched on.
pio.renderers.default = (
    "colab" if (IN_COLAB or not DEVELOPMENT_MODE) else "notebook_connected"
)
print(f"Using renderer: {pio.renderers.default}")

import circuitsvis as cv
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

# This notebook only runs forward passes: turn autograd off globally so no computation graphs are built.
torch.set_grad_enabled(False)

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show a 2-D tensor as an RdBu heatmap whose colour scale is centred on 0.

    Rows of `tensor` go on the y axis and columns on the x axis; extra kwargs
    are forwarded to plotly.express.imshow.
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    fig = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels=axis_labels,
        **kwargs,
    )
    fig.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Plot `tensor` as a line chart; extra kwargs go to plotly.express.line."""
    data = utils.to_numpy(tensor)
    fig = px.line(data, labels={"x": xaxis, "y": yaxis}, **kwargs)
    fig.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Scatter-plot `y` against `x` (both converted to numpy); kwargs go to plotly.express.scatter."""
    x_np, y_np = utils.to_numpy(x), utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    fig = px.scatter(y=y_np, x=x_np, labels=axis_labels, **kwargs)
    fig.show(renderer)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = HookedTransformer.from_pretrained("gpt2-small", device=device)

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-qjfu6a2o
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-qjfu6a2o
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 218ebd6f491f47f5e2f64e4c4327548b60a093eb
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting datasets>=2.7.1 (from transformer-lens==0.0.0)
  Downloading datasets-2.13.1-py3-none-any.whl (486 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m486.2/486.2 kB[0m [31m17.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting einops>=0.6.0 (from transformer-lens==0.0.0)
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Collecting circuitsvis
  Downloading circuitsvis-1.40.0-py3-none-any.whl (1.8 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.8/1.8 MB[0m [31m47.9 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m32.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting importlib-metadata<6.0.0,>=5.1.0 (from circuitsvis)
  Downloading importlib_metadata-5.2.0-py3-none-any.whl (21 kB)
Installing collected packages: importlib-metadata, circuitsvis
Successfully installed circuitsvis-1.40.0 importlib-metadata-5.2.0
Using renderer: colab


Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


**Introduction**

*Authors: Theo Clark, Alex Roman, Hannes Thurnherr*

This paper explores how a language model, in this case GPT2-small, is able to detect whether it is located inside or outside of a bracketed sequence.

Because it is not possible to assess this directly, we use a proxy for this concept: the likelihood assigned to predicting a ")". We hypothesised that the model needs to keep track of the concept of "inside vs outside of parentheses" to perform this task accurately.

Our general approach is to identify combinations of heads that correlate strongly with this proxy and then, once isolated, to reason more concretely and reconstruct the algorithm which might be implemented.

Despite tackling this problem from a number of angles, we did not find significant evidence of a circuit responsible for this behaviour and indeed found some evidence pointing to this information being highly distributed throughout the network. However, we are also acutely aware of the limitations in tooling available to locate such networks, and highlight this as a matter of importance for future research agendas.

This project was conducted as part of the Apart Research Alignment Jam #10 (Interpretability 3.0), 2023 (see https://alignmentjam.com/jam/interpretability)

**Splitting the sequence**

We predicted that, for a head responsible for this behaviour, we would see tokens up to and including the closing bracket attend much more strongly to the opening bracket than tokens after the closing bracket.

Our initial pass identified (Head 3, Layer 10) and (Head 5, Layer 2) as the most promising candidates.

We reasoned that the arguments for and against closing a bracket mostly depend on the actual content of the text rather than just on whether a bracket has been opened. To see whether the same heads still show the same behaviour, we had GPT-4 generate 27 sentences that contain a bracket and ran the same experiment again. We expected to see much more noise, since analysing the content of a sentence to decide whether a bracket should be closed probably requires much more circuitry than one or two heads. Surprisingly, this is not what we found: the results remain fairly consistent, with head 5 of layer 2 and, to a lesser degree, head 3 of layer 10 being the most active.

This initially seemed quite promising.

However, there are a number of reasons why a token might want to attend to "(" when it's inside rather than outside of parentheses (e.g. judging phrase length) and so this is a long way from proving a causal relationship.

Indeed, we found no further evidence of these heads being responsible for this behaviour and we remain unsure as to why these heads might be highlighted in this situation.

In [2]:
# The average per-head difference in attention to the "(" by tokens before and after ")"
opening_bracket_token = 357       # " ("
closing_bracket_token = 8         # ")"

def generate_batch(bs, len_examples = 30, max_vocab_size = 5000, open_idx = 9, close_idx = 19):
  """Build `bs` random token sequences with " (" at `open_idx` and ")" at `close_idx`.

  Filler tokens are drawn uniformly from [256, 256 + max_vocab_size); the ids
  below 256 are skipped as junk. Returns an int64 tensor of shape (bs, len_examples).
  """
  first_usable_id = 256  # first 256 tokens are junk
  batch = torch.randint(first_usable_id, first_usable_id + max_vocab_size, (bs, len_examples))
  batch[:, open_idx] = opening_bracket_token
  batch[:, close_idx] = closing_bracket_token
  return batch

# Seed so the random batch (and therefore the heatmap) is reproducible on re-run.
torch.manual_seed(0)
# Bracket positions are passed explicitly so the slicing below cannot drift from generate_batch.
OPEN_IDX, CLOSE_IDX = 9, 19
tokens = generate_batch(bs = 100, open_idx=OPEN_IDX, close_idx=CLOSE_IDX)
logits, cache = model.run_with_cache(tokens, remove_batch_dim=False)
attn = []
for layer_idx in range(model.cfg.n_layers):
  attention_pattern = cache["pattern", layer_idx, "attn"]
  attn.append(attention_pattern)
attn = torch.stack(attn).permute(0,2,1,3,4)         # layers, heads, bs, y, x

# Key is always the "(" position. "Within" = queries after "(" up to and including ")";
# "beyond" = queries after ")". Averaged over batch and query positions.
within_bracket = attn[:, :, :, OPEN_IDX+1:CLOSE_IDX+1, OPEN_IDX].mean(dim=(2,3))
beyond_bracket = attn[:, :, :, CLOSE_IDX+1:, OPEN_IDX].mean(dim=(2,3))
difference = within_bracket - beyond_bracket
imshow(difference, xaxis="Head", yaxis="Layer", title="Opening bracket attention before and after the closing bracket")

In [3]:
# Repeating the above experiment but with realistic sentences generated by GPT-4

# Token ids that count as "closing the bracket". 8 is ")" (see the In [2] constants);
# 828 and 737 are presumably ")," and ")." - confirm with model.to_str_tokens(torch.tensor([828, 8, 737])).
closing_bracket_tokens = [828, 8, 737]
# 27 GPT-4-generated sentences, each containing one parenthesised span opened with " (".
manual_samples = [
"The piano (which was old and out of tune) sat in the corner of the room.",
"Many people (including our friends and neighbors) came to enjoy the fair.",
"I quickly finished the book (despite its length and complexity), then started another.",
"The city (once bustling with life and laughter) now stood eerily silent.",
"I made pancakes (topped with berries and whipped cream) for our breakfast.",
"She gingerly picked up the letter (sealed in an old envelope) and read it.",
"The painting (hung in the most prominent spot) drew everyone's attention.",
"They daringly climbed the mountain (known for its treacherous terrain) and reached the summit.",
"He utilized his inheritance (a large sum of money) for the welfare of the poor.",
"She lovingly began knitting a blanket (warm, soft, and comforting) for her grandchild.",
"At the park, we saw the fountain (turned off for the winter), silent and still.",
"The main dish was pasta (smothered in a delicious tomato sauce), which everyone loved.",
"His thesis (focused on the impacts of climate change) won an academic award.",
"The company (just a small start-up) now competes with the tech giants.",
"She wore a beautiful dress (patterned with flowers and birds) that complemented her eyes.",
"For dessert, we indulged in cookies (chocolate chip, my personal favorite), warm and gooey.",
"The tree (old and gnarled) hid a secret hollow filled with trinkets.",
"The scientist (known for her groundbreaking work) was invited to the international conference.",
"The ship (deemed unsinkable) eventually sank due to an unfortunate accident.",
"Our tour guide (knowledgeable about local history) showed us the most interesting sites.",
"Despite my preferences, I decided to try sushi (even though I dislike raw fish) and it was great!.",
"They spent the entire day at the museum (home to many precious artifacts), exploring.",
"The house (once grand and elegant) now lay forgotten, a relic of the past.",
"His stamp collection (painstakingly gathered over years) is an impressive sight to behold.",
"The dog (cute but incredibly mischievous) left a trail of destruction in its wake.",
"Despite the minor mishaps, our vacation was fantastic (we even saw the Northern Lights), if a little cold",
"The ocean (calm and peaceful) stretched out to the horizon, a vast expanse of blue."
]

# For each sample, record the position of the last " (" before the first closing-bracket
# token, and the position of that closing token.
samples = []
for sample in manual_samples:
  tokenised = model.to_tokens(sample)
  token_ids = tokenised[0].tolist()
  opening_idx = None
  closing_idx = None
  for i, tok in enumerate(token_ids):
    if tok == opening_bracket_token:
      opening_idx = i
    elif tok in closing_bracket_tokens:
      closing_idx = i
      break
  # Previously a missing bracket silently defaulted to index 0, which yields empty slices
  # (NaN means) or attention to the BOS token in the analysis below.
  assert opening_idx is not None and closing_idx is not None and opening_idx < closing_idx, \
    f'bracket not found: {sample}'
  samples.append((tokenised, opening_idx, closing_idx))

# For every sample: each head's mean attention to "(" from positions inside the bracket
# (up to and including the closing token) minus that from positions after it; then average over samples.
per_sample_diffs = []
for tokens, start_idx, final_idx in samples:
  logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)
  attn = torch.stack([cache["pattern", layer_idx, "attn"] for layer_idx in range(12)])   # layers, heads, y, x

  inside = attn[:, :, start_idx + 1:final_idx + 1, start_idx].mean(dim=2)
  after = attn[:, :, final_idx + 1:, start_idx].mean(dim=2)
  per_sample_diffs.append(inside - after)

differences = torch.stack(per_sample_diffs).mean(dim=0)

imshow(differences, xaxis="Head", yaxis="Layer", title="Opening bracket attention before and after the closing bracket")

**Activation Patching**

Activation patching allows for a more surgical and fine-grained approach to locating target heads.

Methodology
1. Generate a prompt and confirm that ")" is the next predicted token.
2. Investigate patching attention scores to a large negative value (effectively -inf, so the post-softmax weight is ~0) to prevent the final token from attending back to the "(" token.

Results
1. We found that patching the relevant value in a single head made no difference to the overall probability.
2. We even found that patching every single head for this particular value made little difference.
3. Blocking any future tokens from attending to the "(" does prevent ")" from being predicted (sanity check).
4. Gradually patching more layers and positions does not yield a sudden change in probability at any point.

Conclusion

Information about the open bracket is passed to the final position progressively through other positions rather than exclusively by directly attending to the "(".

The evidence so far indicates that this process of passing information is highly distributed across layers.

In [4]:
prompt = "John and Mary went home (with a bottle of milk"
tokens = model.to_tokens(prompt)
print(f"Prompt tokenized as: {model.to_str_tokens(prompt)}")

logits, cache = model.run_with_cache(tokens)
correct_idx = model.to_single_token(")")
# Baseline: is ")" the top next-token prediction, and with what probability?
is_top_prediction = correct_idx == logits[0, -1, :].argmax()
probs = F.softmax(logits, dim=2)
print(f'")" is the most likely prediction? {is_top_prediction}')
print(f'probability is: {probs[0, -1, correct_idx]}')

Prompt tokenized as: ['<|endoftext|>', 'John', ' and', ' Mary', ' went', ' home', ' (', 'with', ' a', ' bottle', ' of', ' milk']
")" is the most likely prediction? True
probability is: 0.30470868945121765


In [5]:
# Preventing the final position from attending to the "(" position for each head in turn
def bracket_patching_hook(
    activation: Float[torch.Tensor, "batch head target source"],
    hook: HookPoint,
    head: int,
    source: int,
    target: int
) -> Float[torch.Tensor, "batch head target source"]:
    """Stop one head's `target` query position from attending to the `source` key position.

    Intended for an ``attn_scores`` hook: the pre-softmax score is overwritten in place
    with a large negative value, so its softmax weight is effectively zero.

    Args:
        activation: attention scores, indexed [batch, head, target, source].
        hook: the hook point (unused).
        head: head whose score is overwritten.
        source: key position to block.
        target: query position to block.

    Returns:
        The same (mutated) tensor.
    """
    # A large finite value rather than -inf avoids NaN if an entire row were ever masked.
    activation[:, head, target, source] = -1e5
    return activation

# Derive positions from the prompt instead of hard-coding them
# (for this prompt " (" is at index 6 and the final token at index 11).
bracket_pos = tokens[0].tolist().index(model.to_single_token(" ("))
final_pos = tokens.shape[-1] - 1
out = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
for layer in range(model.cfg.n_layers):
  for head in range(model.cfg.n_heads):
    temp_hook_fn = partial(bracket_patching_hook, head=head, source=bracket_pos, target=final_pos)
    patched_logits = model.run_with_hooks(tokens, fwd_hooks=[
        (utils.get_act_name("attn_scores", layer=layer), temp_hook_fn)
    ])
    out[layer, head] = F.softmax(patched_logits, dim=2)[0, -1, correct_idx]
imshow(out, xaxis="Head", yaxis="Layer", title="Output probability when a head is patched to -inf")

In [6]:
# Preventing the final position from attending to a specific token for every head
def bracket_multi_patching_hook(
    activation: Float[torch.Tensor, "batch head target source"],
    hook: HookPoint,
    source: int,
    target: int
) -> Float[torch.Tensor, "batch head target source"]:
    """Stop the `target` query position from attending to `source` in every head.

    Intended for ``attn_scores`` hooks; the scores are overwritten in place with a large
    negative value, so the softmax weight is effectively zero.

    Args:
        activation: attention scores, indexed [batch, head, target, source].
        hook: the hook point (unused).
        source: key position to block.
        target: query position to block (negative indices allowed).

    Returns:
        The same (mutated) tensor.
    """
    activation[:, :, target, source] = -1e5
    return activation

def attn_finder(name):
  """Hook-name filter: True for any hook whose name mentions attention scores."""
  return name.find("attn_scores") != -1

seq_len = tokens.shape[-1]
out = torch.zeros((1, seq_len))
# For each source position, block the final query position from attending to it in all heads and layers.
for source in range(seq_len):
  block_all_heads = partial(bracket_multi_patching_hook, source=source, target=-1)
  patched_logits = model.run_with_hooks(tokens, fwd_hooks=[(attn_finder, block_all_heads)])
  patched_probs = F.softmax(patched_logits, dim=2)
  out[0, source] = patched_probs[0, -1, correct_idx]

imshow(out, xaxis="source token", title="Output probability when all heads are patched to -inf for a specific source token", x=model.to_str_tokens(tokens))

In [7]:
# Sanity check - preventing any future tokens from attending to a source token
def bracket_multi_patching_hook(
    activation: Float[torch.Tensor, "batch head target source"],
    hook: HookPoint,
    source: int,
    target: Optional[int] = None
) -> Float[torch.Tensor, "batch head target source"]:
    """Stop every later query position, in every head, from attending to `source`.

    Intended for ``attn_scores`` hooks; scores are overwritten in place with a large
    negative value, so the softmax weights are effectively zero.

    Args:
        activation: attention scores, indexed [batch, head, target, source].
        hook: the hook point (unused).
        source: key position to block for all queries after it.
        target: unused; kept optional so existing ``partial(..., target=...)`` calls still work.

    Returns:
        The same (mutated) tensor.
    """
    activation[:, :, source+1:, source] = -1e5
    return activation

def attn_finder(name):
  """Select hook points whose name contains "attn_scores"."""
  occurrences = name.count("attn_scores")
  return occurrences > 0

seq_len = tokens.size()[-1]
out = torch.zeros((1, seq_len))
# Sanity check: hide each source token from all later positions, in every head and layer.
for source in range(seq_len):
  patched_logits = model.run_with_hooks(
      tokens,
      fwd_hooks=[(attn_finder, partial(bracket_multi_patching_hook, source=source, target=-1))],
  )
  out[0, source] = F.softmax(patched_logits, dim=2)[0, -1, correct_idx]

imshow(out, xaxis="source token", title="Output probability when no future tokens are able to attend to a source token", x=model.to_str_tokens(tokens))

In [8]:
# Gradually increasing the number of layers and the number of positions that are patched
def bracket_multi_patching_hook(
    activation: Float[torch.Tensor, "batch head target source"],
    hook: HookPoint,
    source: int,
    cutoff_token: int,
    cutoff_layer: int
) -> Float[torch.Tensor, "batch head target source"]:
    """Block attention to `source` from query positions >= `cutoff_token` in layers >= `cutoff_layer`.

    Intended for ``attn_scores`` hooks; affected scores are overwritten in place with a
    large negative value, so their softmax weights are effectively zero.

    Args:
        activation: attention scores, indexed [batch, head, target, source].
        hook: the hook point; its name supplies the layer index.
        source: key position to block.
        cutoff_token: first query position that is blocked.
        cutoff_layer: first layer that is patched.

    Returns:
        The same (possibly mutated) tensor.
    """
    # Hook names look like "blocks.<layer>.attn.hook_attn_scores"; field 1 is the layer index.
    layer = int(hook.name.split(".")[1])
    if layer >= cutoff_layer:
        activation[:, :, cutoff_token:, source] = -1e5
    return activation

def attn_finder(name):
  """Filter used by run_with_hooks: pick out the attention-score hooks."""
  return name.find("attn_scores") >= 0

# Position of " (" in the prompt (6 for the prompt above), derived rather than hard-coded.
source = tokens[0].tolist().index(model.to_single_token(" ("))
seq_len = tokens.size()[-1]
# One column per query position strictly after " (" (previously hard-coded to 5 for this prompt).
n_cutoff_tokens = seq_len - source - 1
out = torch.zeros((model.cfg.n_layers, n_cutoff_tokens))
for cutoff_layer in range(model.cfg.n_layers):
  for cutoff_token in range(source+1, seq_len):
    temp_hook_fn = partial(bracket_multi_patching_hook, source=source, cutoff_token=cutoff_token, cutoff_layer=cutoff_layer)
    patched_logits = model.run_with_hooks(tokens, fwd_hooks=[(attn_finder, temp_hook_fn)])
    out[cutoff_layer, cutoff_token - source - 1] = F.softmax(patched_logits, dim=2)[0, -1, correct_idx]

imshow(out, xaxis="cutoff_token", yaxis="cutoff_layer", title="Output probability when gradually increasing activation patching across positions and layers", x=model.to_str_tokens(tokens[0,source+1:]))

**Single token predictions**

In a bid to simplify the problem and gain more traction, the sequence following the opening bracket was shortened to a single token, i.e. what is the probability of predicting ")" one token after "("? For example, "(a)" would be a valid string.

This reduces the complexity of the problem but is less closely connected to the typical behaviour we are studying, something we address later on.

Methodology

1. A random string was used so as to decrease the chance of other relational information influencing activations (though this seemed to make little difference).
2. An inverse string, replacing " (" with ")", was used to generate activations for patching. This is a more realistic way of 'undoing' an activation than simply setting it to zero.
3. Carry out similar activation patching, but this time with fewer positions to worry about.

Results

1. Patching the attention output in the first 3 layers is enough to reduce the likelihood of predicting ")" to almost 0
2. Patching the second layer by itself was not enough to cause this behaviour: it needs to be combined with either layer 0 or layer 1
3. Ablating this result across all head combinations shows a particularly strong correlation between patching layer 1, head 5 and layer 2, head 1.

These last two results patched the value activations rather than the attention output.

In [9]:
# Patching all attention outputs below and including a "cutoff layer"
# Counterfactual run: same prompt as below but with ")" where the clean prompt has " (".
inv_prompt = "sw Pres dis advantronic)bob"
inv_tokens = model.to_tokens(inv_prompt)
_, inv_cache = model.run_with_cache(inv_tokens)

def patching_hook(
    activation: Float[torch.Tensor, "batch pos d_model"],
    hook: HookPoint,
    blocked_layer: int,
) -> Float[torch.Tensor, "batch pos d_model"]:
    """Overwrite the final position's attention output with the counterfactual run's value.

    Only layers up to and including `blocked_layer` are patched; later layers pass through.

    Args:
        activation: attention-layer output at this hook.
        hook: the hook point; its name gives the layer index and the key into `inv_cache`.
        blocked_layer: last layer (inclusive) to patch. (Previously mis-annotated as `list`.)

    Returns:
        The same (possibly mutated) tensor.
    """
    # Hook names look like "blocks.<layer>.hook_attn_out"; field 1 is the layer index.
    if int(hook.name.split(".")[1]) <= blocked_layer:
        activation[:, -1, :] = inv_cache[hook.name][:, -1, :]
    return activation

def finder(name):
  """Select the attention-output hook points."""
  return name.count("attn_out") > 0

prompt = "sw Pres dis advantronic (bob"
tokens = model.to_tokens(prompt)
out = torch.zeros(1, 12)
# Patch attention outputs at the last position in all layers <= cutoff, for each cutoff.
for blocked_layer in range(12):
  patched_logits = model.run_with_hooks(
      tokens,
      fwd_hooks=[(finder, partial(patching_hook, blocked_layer=blocked_layer))],
  )
  out[0, blocked_layer] = F.softmax(patched_logits, dim=2)[0, -1, correct_idx]

imshow(out, xaxis="cutoff layer", title="Patching the attention output in all layers up to and including the cutoff layer")

In [10]:
# Patching the values in pairs of layers for all combinations in the first 3 layers
# Counterfactual run: same prompt as below but with ")" where the clean prompt has " (".
inv_prompt = "sw Pres dis advantronic)bob"
inv_tokens = model.to_tokens(inv_prompt)
_, inv_cache = model.run_with_cache(inv_tokens)

def patching_hook(
    activation: Float[torch.Tensor, "batch pos head d_head"],
    hook: HookPoint,
    source: int,
    blocked_layers: list,
) -> Float[torch.Tensor, "batch pos head d_head"]:
    """Swap in the counterfactual value activations for every layer in `blocked_layers`.

    Args:
        activation: value activations at a ``hook_v`` hook.
        hook: the hook point; its name gives the layer index and the key into `inv_cache`.
        source: unused (kept for call-site compatibility).
        blocked_layers: layer indices whose values are replaced wholesale.

    Returns:
        The counterfactual tensor for blocked layers, otherwise `activation` unchanged.
    """
    if int(hook.name.split(".")[1]) in blocked_layers:
        replacement = inv_cache[hook.name]
        # The whole tensor is swapped, so both prompts must tokenize to the same length.
        assert replacement.shape == activation.shape, \
            f"counterfactual shape {tuple(replacement.shape)} != {tuple(activation.shape)}"
        return replacement
    return activation

def finder(name):
  """Select the value-activation hook points."""
  return name.find("hook_v") != -1

prompt = "sw Pres dis advantronic (bob"
tokens = model.to_tokens(prompt)
source = 6

seq_len = tokens.size()[-1]
out = torch.zeros(3, 3)
# Every ordered pair of the first three layers; the diagonal patches a single layer.
for i, j in itertools.product(range(3), repeat=2):
  blocked_layers = [i, j]
  hook_fn = partial(patching_hook, source=source, blocked_layers=blocked_layers)
  patched_logits = model.run_with_hooks(tokens, fwd_hooks=[(finder, hook_fn)])
  out[i, j] = F.softmax(patched_logits, dim=2)[0, -1, correct_idx]
imshow(out, xaxis="layer", yaxis="layer", title="Value patching pairs of early layers")

In [11]:
# Ablating this result further to compare pairs of heads in layers 0 and 2 and layers 1 and 2
# Counterfactual run: same prompt as below but with ")" where the clean prompt has " (".
inv_prompt = "sw Pres dis advantronic)bob"
inv_tokens = model.to_tokens(inv_prompt)
_, inv_cache = model.run_with_cache(inv_tokens)

def patching_hook(
    activation: Float[torch.Tensor, "batch pos head d_head"],
    hook: HookPoint,
    source: int,
    blocked_layers: list,
    a: int,
    b: int
) -> Float[torch.Tensor, "batch pos head d_head"]:
    """Patch one head's values in each of two layers from the counterfactual run.

    In layer ``blocked_layers[0]`` head `a` is patched; in any other layer listed in
    `blocked_layers` head `b` is patched. All positions are patched in place.

    Args:
        activation: value activations at a ``hook_v`` hook (heads on dim 2).
        hook: the hook point; its name gives the layer index and the key into `inv_cache`.
        source: unused (kept for call-site compatibility).
        blocked_layers: ``[first_layer, second_layer]``.
        a: head to patch in the first layer.
        b: head to patch in the other blocked layer.

    Returns:
        The same (possibly mutated) tensor.
    """
    layer = int(hook.name.split(".")[1])
    if layer in blocked_layers:
        head = a if layer == blocked_layers[0] else b
        activation[:, :, head, :] = inv_cache[hook.name][:, :, head, :]
    return activation

def finder(name):
  """Filter for run_with_hooks: keep only hooks whose name contains "hook_v"."""
  matches = name.count("hook_v")
  return matches > 0

prompt = "sw Pres dis advantronic (bob"
tokens = model.to_tokens(prompt)
source = 6

seq_len = tokens.size()[-1]
for i in range(2):
  blocked_layers = [i, 2]
  # Fresh tensor per figure: rows = head `a` in layer i, columns = head `b` in layer 2.
  out = torch.zeros(12, 12)
  for a in range(12):
    for b in range(12):
      hook_fn = partial(patching_hook, source=source, blocked_layers=blocked_layers, a=a, b=b)
      patched_logits = model.run_with_hooks(tokens, fwd_hooks=[(finder, hook_fn)])
      out[a, b] = F.softmax(patched_logits, dim=2)[0, -1, correct_idx]
  # px.imshow puts rows on the y axis and columns on the x axis. The previous labels
  # were swapped, so head indices read off earlier versions of this plot may be transposed.
  imshow(out, xaxis=f"head (layer {blocked_layers[1]})", yaxis=f"head (layer {blocked_layers[0]})", title=f"Value patching heads in layers {blocked_layers[0]} and {blocked_layers[1]}")

**Hypothesis testing**

This result was initially quite exciting, indicating that a combination of two heads (layer 1, head 5 and layer 2, head 1) significantly correlate with reducing the probability of outputting a ")" when the preceding " (" is replaced with a ")". We came up with a number of hypotheses for what this "circuit" of two heads might be doing:

1. The circuit is spurious and does not generalise beyond this exact example
2. The circuit is responsible for the very precise task of predicting a close paren immediately after an open paren but does not generalise beyond that.
3. The circuit is predicting that a ")" should not follow another ")" and is showing up here because we are patching with ")".
4. The circuit is responsible for the more general behaviour of detecting whether the current position is inside or outside parentheses
5. Same as before but is also able to map that behaviour to other types of brackets e.g. [], {} or <>

We then made a number of switches to the original and inverse prompts to test these hypotheses.

Results:
1. Inverting the prompts (i.e. patching " (" over ")") causes the inverse behaviour as expected: the two heads we are interested in correlate with a much higher probability of predicting ")" next.
2. Altering earlier parts of the sequence or the final token does not cause any noticeable difference in output distribution.
3. Patching " (" with a token other than ")" (i.e. a random word) does NOT yield the same pattern. Indeed sometimes it increases the probability of outputting a ")"
4. But if we instead start with ")" in the main sequence and patch with a random token, we see no correlation with that particular head.
5. Repeating the original experiment but with square brackets as opposed to round brackets breaks the pattern, indicating that this does not generalise to other forms of "openness"

Further experiments, similar to the ones above, were carried out with square brackets to identify whether the same layers (0, 1, 2) were responsible for this behaviour and whether there was any correlation. For square brackets, the third layer also seemed to be important, and there was no three-way combination of heads in the first 3 layers that seemed to correlate with reducing the probability of predicting a "]" (see result below).

**Summary**

These results indicate that the circuit is more closely related to detecting a closing bracket than an opening one, but even this is not concrete. It certainly does not generalise to other brackets. Whilst this might be something worth investigating it does not seem to represent the idea of "being inside/outside parentheses".

In [12]:
# Repeating an earlier experiment but with square braces and this time comparing triplets of heads across the first 3 layers
# Counterfactual run: same prompt as below but with "]" where the clean prompt has " [".
inv_prompt = "sw Pres dis advantronic]bob"
inv_tokens = model.to_tokens(inv_prompt)
_, inv_cache = model.run_with_cache(inv_tokens)

def patching_hook(
    activation: Float[torch.Tensor, "batch pos head d_head"],
    hook: HookPoint,
    source: int,
    a: int,
    b: int,
    c: int,
) -> Float[torch.Tensor, "batch pos head d_head"]:
    """Patch head `a` of layer 0, head `b` of layer 1 and head `c` of layer 2 from the counterfactual run.

    Args:
        activation: value activations at a ``hook_v`` hook (heads on dim 2).
        hook: the hook point; its name gives the layer index and the key into `inv_cache`.
        source: unused (kept for call-site compatibility).
        a, b, c: heads to patch in layers 0, 1 and 2 respectively.

    Returns:
        The same (possibly mutated) tensor; layers above 2 pass through untouched.
    """
    layer = int(hook.name.split(".")[1])
    head_for_layer = {0: a, 1: b, 2: c}
    if layer in head_for_layer:
        head = head_for_layer[layer]
        activation[:, :, head, :] = inv_cache[hook.name][:, :, head, :]
    return activation

def finder(name):
  """Keep only the value-activation hooks."""
  return name.find("hook_v") >= 0

prompt = "sw Pres dis advantronic [bob"
tokens = model.to_tokens(prompt)
source = 6
correct_idx = model.to_single_token("]")

seq_len = tokens.size()[-1]
out = torch.zeros(12, 144)
# The dead `blocked_layers = [i,2]` line was removed: it silently depended on `i` leaking from the previous cell.
for a in range(12):
  for b in range(12):
    for c in range(12):
      hook_fn = partial(patching_hook, source=source, a=a, b=b, c=c)
      patched_logits = model.run_with_hooks(tokens, fwd_hooks=[(finder, hook_fn)])
      # Rows: layer-1 head b. Columns: blocks of 12 layer-2 heads (c), one block per layer-0 head (a).
      out[b, a*12 + c] = F.softmax(patched_logits, dim=2)[0, -1, correct_idx]
imshow(out, yaxis="head (layer 1)", xaxis="head (layer 2 repeated across each head in layer 0)", title="Value patching heads in layers 0, 1 and 2 for square braces")

**Averaging Activations**

The methodology used above did not seem to isolate the intended behaviour. This may be because it is quite unusual to want to predict a closing bracket straight after an opening bracket, so the setup does not represent a realistic scenario in which the circuit needs to judge whether it is inside or outside of parentheses. We therefore returned to longer sequences but averaged the activations across samples in the hope of isolating a more reliable signal.

All prompts are similar to the following example and are adapted from our earlier GPT-4 prompts:

"The main house (once grand and very elegant"

For every sample, the model assigns ")" the highest probability as the next token, and all samples have the same token length with the open parenthesis in the same position.

We used the following hypothesis of how the model might approach this task to guide us:

1. Does the previous token(s) indicate that this might be a new clause?
2. If yes --> "("
3. If no: am I inside a bracketed string?
4. If no --> predict random token
5. If yes: does it make sense to close the bracket given the current context?
6. If yes --> ")"
7. If no --> predict random token

It is the third step that we are looking to isolate and locate the circuit for.

**Methodology**

Obtain a set of averaged activations:

1. Run a forward pass and save out the cache for each sequence.
2. Repeat but replace " (" with " ". These will represent the inverse sequences.
3. Average all activations for the standard sequences and the inverse sequences.
4. Confirm that the averaged activations have the desired behaviour.

Isolate potential layers and heads

5. Compare the cosine similarity of incoming keys and values, as well as the output of each attention layer.
6. Repeat some of the earlier patching experiments but this time with the averaged values


**Results**
1. Most differences in cosine similarity occur for keys originating from the " (" token as you would expect but there are still clear differences in some of the other tokens which implies that information about the bracket can be passed indirectly through the network and the current position does not have to rely on attending directly to the " (", which corroborates what we have found previously.
2. Breaking the cosine similarity down by head, for keys and values originating from the "(" position, shows a lot of variation, but there is nothing linking this to predicting ")" at the current position.
3. Plotting the cosine similarity for the "z" activations of every head shows a gradual decrease in similarity as the layers increase. This is likely because, deeper in the network, there are more opportunities to attend to information from the changed token, so we see progressively larger differences as we move through the layers. Relative differences within the same layer are therefore likely a better indicator.

Cosine similarity plots are interesting but do not link directly to the likelihood of outputting ")". It is therefore difficult to read too much into these results.

In [13]:
# Generate prompts and confirm that:
# 1. when tokenized, they are the same length and the open paren is in the same location
# 2. ")" is the highest predicted token

# Hand-adapted prompts; the cell below asserts that each tokenizes to the same length,
# has " (" at the same position, and has ")" as its top next-token prediction.
prompts = [
    "Many other people (including our friends and neighbors",
    "The main house (once grand and very elegant",
    "I finished it (despite its length and complexity",
    "The grand city (once bustling with loud laughter",
    "The thick blanket (warm and really very comforting",
    "The central fountain (turned off for the winter",
    "The main dish (covered in rich tomato sauce",
]

length = 10   # expected token count of every prompt, including the prepended BOS token
paren = 4     # expected position of " (" in every prompt
correct_idx = model.to_single_token(")")
open_paren_id = model.to_single_token(" (")   # 357, previously a bare magic number
for prompt in prompts:
  tokens = model.to_tokens(prompt)
  # Exactly one " (" at the expected position (previously `.item()` raised an opaque error on 0 or 2+ matches).
  paren_positions = (tokens[0] == open_paren_id).nonzero().flatten().tolist()
  assert paren_positions == [paren], f'idx error: {prompt}'
  assert tokens.size()[1] == length, f'length error: {prompt}'
  # Only the logits are needed here, so skip building an activation cache.
  logits = model(tokens)
  assert correct_idx == logits[0, -1, :].argmax(), f'predicting error: {prompt}'

In [14]:
# Perform forward pass and save out cache for standard and inverse prompts
caches = []
inv_caches = []
for prompt in prompts:
  tokens = model.to_tokens(prompt)
  logits, cache = model.run_with_cache(tokens)
  caches.append(cache)

  inv_tokens = tokens
  inv_tokens[:,paren] = model.to_single_token(" ")
  inv_logits, inv_cache = model.run_with_cache(inv_tokens)
  inv_caches.append(inv_cache)

In [15]:
# Average the activations across all the prompts
avg_cache = {}
avg_inv_cache = {}
for k in caches[0]:
  values = []
  for cache in caches:
    values.append(cache[k])
  values = torch.stack(values).mean(dim=0)
  avg_cache[k] = values

  inv_values = []
  for cache in inv_caches:
    inv_values.append(cache[k])
  inv_values = torch.stack(inv_values).mean(dim=0)
  avg_inv_cache[k] = inv_values

In [16]:
# Confirm the desired behaviour by patching all activations with the average.
# This is fairly pointless because only the final residual activation is
# actually important but serves as a sanity check
def patching_hook(
    activation: torch.Tensor,
    hook: HookPoint,
    invert: bool,
) -> torch.Tensor:
    """Replace an activation with its prompt-averaged value.

    This hook is registered on every hook point (see `finder`), so
    `activation` can have any shape. It is ignored, and the averaged
    activation stored under the same hook name is returned instead.

    Args:
        activation: the activation produced at this hook (unused).
        hook: the hook point; its `name` is the key into the averaged caches.
        invert: if True, use the average over the inverted (paren removed)
            prompts; otherwise use the average over the standard prompts.

    Returns:
        The averaged activation from `avg_inv_cache` or `avg_cache`.
    """
    source = avg_inv_cache if invert else avg_cache
    return source[hook.name]

def finder(name):
  """Hook-name filter that selects every hook point."""
  return True

correct_idx = model.to_single_token(")")
tokens = model.to_tokens(prompts[0])
for invert in [False, True]:
  hook_fn = partial(patching_hook, invert=invert)
  patched_logits = model.run_with_hooks(tokens, fwd_hooks=[(finder, hook_fn)])
  prob = F.softmax(patched_logits, dim=2)[0, -1, correct_idx].item()
  print(f'probability of predicting a ")" when invert={invert}: {prob}')

probability of predicting a ")" when invert=False: 0.63490891456604
probability of predicting a ")" when invert=True: 0.0003868180210702121


**Cosine Similarity**

In [17]:
# Calculate the cosine similarity between key and value activations of normal and inverted sequences.
# Per layer, all heads are flattened into one vector per position and the averaged
# normal vs inverted activations are compared at positions `paren` onwards.
# NOTE: the plotted "token" axis therefore starts at position `paren`, not 0.
key_stack = []
value_stack = []
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
for k in avg_cache:
  if "hook_k" in k or "hook_v" in k:
    # Take the single batch element explicitly; .squeeze() would also drop any
    # other size-1 dim. Presumably (pos, head, d_head) -> (pos, head*d_head).
    x = avg_cache[k][0].flatten(start_dim=1)
    y = avg_inv_cache[k][0].flatten(start_dim=1)
    sim = cos(x, y)
    if "hook_k" in k:
      key_stack.append(sim[paren:])
    else:
      value_stack.append(sim[paren:])
key_stack = torch.stack(key_stack)
value_stack = torch.stack(value_stack)
imshow(key_stack, xaxis="token", yaxis="layer", title="Cosine similarity between normal and inverted Keys")
imshow(value_stack, xaxis="token", yaxis="layer", title="Cosine similarity between normal and inverted Values")

In [18]:
# Cosine similarity, per head, between the normal and inverted averaged keys/values
# at the open-paren (source) position. Nothing is ablated here; this only compares
# the averaged activations.
key_stack = []
value_stack = []
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
# Position of " (" in every prompt (asserted when the prompts were generated);
# loop-invariant, so set once.
bracket_dim = paren
for k in avg_cache:
  if "hook_k" in k or "hook_v" in k:
    # Slice at the bracket position; rows are then compared one per head
    # (cos over dim=1), giving one similarity per head.
    x = avg_cache[k][0, bracket_dim, :]
    y = avg_inv_cache[k][0, bracket_dim, :]
    sim = cos(x, y)
    if "hook_k" in k:
      key_stack.append(sim)
    else:
      value_stack.append(sim)
key_stack = torch.stack(key_stack)
value_stack = torch.stack(value_stack)
imshow(key_stack, xaxis="head", yaxis="layer", title="Cosine similarity between normal and inverted source Keys by head")
imshow(value_stack, xaxis="head", yaxis="layer", title="Cosine similarity between normal and inverted source Values by head")

In [19]:
# The cosine similarity (between the standard and inverse averages) of the
# activations located after the values matrix has been multiplied by the
# attention scores (the "z" values), taken at the final position, one value
# per head, for every layer.
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
z_hook_names = [name for name in avg_cache if "hook_z" in name]
attn_stack = torch.stack([
    cos(
        avg_cache[name][0, -1, :].flatten(1),
        avg_inv_cache[name][0, -1, :].flatten(1),
    )
    for name in z_hook_names
])
imshow(attn_stack, xaxis="head", yaxis="layer", title="Attention similarity for each head for normal vs inverted prompt")

**Activation Patching**

Cosine similarity of activations by itself is not enough to explain the process we are trying to observe as it does not link directly to the likelihood of predicting a ")" or any notion of being inside parentheses.

It therefore makes sense to revisit activation patching but now with averaged values, in the hope that it provides a more reliable and robust signal for identifying relevant circuits.

Methodology:
1. Take a single sample sequence (the code below uses the first prompt) and always patch the very first residual stream input with the standard average, so that the specific input tokens no longer influence the result.
2. Patch the key and value activations in every layer, as these cover all of the activations that the final position is able to use to attend to prior information.
3. Iterate over different layers — patching some with the standard average and others with the inverted average — to identify whether any layers/heads correlate strongly with predicting a ")".

Results:
Sweeping over individual layers or blocks of layers revealed no clear layers or heads that might be responsible. A greedy search (shown below), which gradually increases the number of patched layers, was used to rank the layers by importance; it selects the later layers first (10, 11 and then 9). However, blocking out the final few layers produces only a gradual decrease in likelihood rather than a sharp drop.



In [20]:
def patching_hook(
    activation: torch.Tensor,
    hook: HookPoint,
    layers: list
) -> torch.Tensor:
    """Patch selected activations with the standard or inverted averages.

    - `blocks.0.hook_resid_pre` is always replaced with the standard average,
      so the specific prompt's input no longer matters.
    - hook_k / hook_v activations use the inverted average when their layer
      is in `layers`, and the standard average otherwise.
    - Every other activation is returned unchanged.
    """
    name = hook.name
    if name == "blocks.0.hook_resid_pre":      # Patch the first residual input
      return avg_cache[name]
    if "hook_k" in name or "hook_v" in name:
      layer = int(name.split(".")[1])  # layer index from the "blocks.<layer>." prefix
      return avg_inv_cache[name] if layer in layers else avg_cache[name]
    return activation

def finder(name):
  """Hook-name filter that selects every hook point."""
  return True

def get_prob(layers):
  """Return P(")") at the final position of prompts[0] (as a tensor) when the
  K/V activations of `layers` are patched with the inverted average."""
  correct_idx = model.to_single_token(")")
  tokens = model.to_tokens(prompts[0])
  hook_fn = partial(patching_hook, layers=layers)
  patched_logits = model.run_with_hooks(tokens, fwd_hooks=[(finder, hook_fn)])
  return F.softmax(patched_logits, dim=2)[0, -1, correct_idx]

# Perform a greedy search to find an estimate for the most important 'n' layers as we increase n:
# at each step, add the layer whose inverted K/V patch lowers P(")") the most.
options = list(range(model.cfg.n_layers))
curr = []
for iterations in range(len(options)):
  best_layer = None
  # Start at +inf (not 1) so that some layer is always selected, even if a
  # probability rounds to exactly 1.0.
  best_prob = float("inf")
  for layer in options:
    if layer in curr:
      continue
    prob = get_prob(curr + [layer]).item()
    if prob < best_prob:
      best_layer = layer
      best_prob = prob
  curr.append(best_layer)
print(curr)

[10, 11, 9, 7, 5, 8, 2, 1, 3, 4, 6, 0]


**Conclusion**

We have found no substantial evidence of a specific circuit dedicated to identifying whether the current position is inside or outside of a bracketed sequence. There are a number of possible reasons for this which we will address in turn.

**1. Predicting ")" is a bad proxy**

Throughout these experiments we have used the probability of predicting a ")" as a proxy for whether the model believes it is inside or outside of parentheses, based on the idea that it is unlikely to assign high probability to ")" if it is not inside parentheses.

We have already indicated that predicting a ")" straight after " (" does not represent most situations where it is important to keep track of this across a sequence.

We have tried longer sequences with both random sequences (to isolate bracket behaviour) and realistic sequences (to better simulate real situations).

The last effort, which averages realistic sequences, attempts to tie both of these together: it uses realistic sentences while averaging out some of the noise you might expect from them.

However, compared to Indirect Object Identification, where predicting the correct token is a very clear indicator of whether the model possesses this concept, the present task is less tangible and proxies are harder to come by.

There are many concepts which suffer from the same issue and further work is necessary to identify better methods or proxies for eliciting their behaviour.

**2. The model does not possess this concept**

This is a difficult question to answer and depends partly on definitions. On the one hand, it is clear that GPT2-small is very accurate when it comes to placing brackets in the correct place. On some level it clearly possesses the ability to know when it needs to add a ")". If the only thing this depends on is whether there is a "(" earlier in the sequence, then it must be using that information when making its prediction.

Even if the model does not have the concept of "am I inside a bracketed sequence?", it must have a related concept, e.g. "is there a '(' more recently than a ')' in the preceding sequence?".

In essence this issue cancels out the first point regarding proxies. We are attempting to locate whatever 'concept' it is that allows the model to correctly predict a closing bracket.

If we accept that the concept is poorly defined but nevertheless present then the proxy ceases to be a poor one.

**3. Our methodologies are flawed/limited**

We have already discussed the potential issues with our choice of proxy for this behaviour. However, the techniques used to try to isolate this behaviour are definitely limited. It is not feasible to exhaustively search all possible combinations of heads to isolate this behaviour.



**4. The circuit responsible for this concept is highly distributed**

We have painstakingly isolated parts of the network in an attempt to identify this behaviour, and no small combination of layers/heads has been found that is clearly responsible.

There are broad combinations of layers that seem to be of more importance than others but nothing fine grained.

The slightly different problem of predicting the sequence "()" seems to be easier to isolate, but it does not seem to correlate with the more general problem of predicting closing brackets later in a sequence.

With the tools and methods currently available to us, we are left to conclude that this concept is simply highly distributed throughout the network. This aligns with the fact that preventing the final token from attending to "(" is not enough to prevent a ")" from being predicted, indicating that this information is passed gradually through the network.

However, we are reluctant to draw that conclusion definitively, and the main takeaway from this paper should be that there is still much work to be done to improve the methods and techniques for identifying circuits within models.