# Imports

In [1]:
#%pip install git+https://github.com/neelnanda-io/TransformerLens.git
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor

import numpy as np
import pandas as pd
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
# import plotly.express as px
from torch.utils.data import DataLoader
from typing import Union, List, Optional, Callable, Tuple, Dict, Literal, Set
from jaxtyping import Float, Int
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.utils import to_numpy
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache, patching

import plotly.express as px
#%pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
import circuitsvis as cv
import os, sys

Collecting git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
  Cloning https://github.com/callummcdougall/CircuitsVis.git to /tmp/pip-req-build-xkqftywk
  Running command git clone --filter=blob:none --quiet https://github.com/callummcdougall/CircuitsVis.git /tmp/pip-req-build-xkqftywk
  Resolved https://github.com/callummcdougall/CircuitsVis.git to commit 1ec4a8f8fa4368e95500fe3a188b367500af5f98
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Note: you may need to restart the kernel to use updated packages.


In [2]:
# Fetch the path_patching helper library on first use only: the presence of
# path_patching.py in the working directory is the "already installed" marker.
if not os.path.exists("path_patching.py"):
        # Download the repository archive (wget saves it into the current working directory).
        !wget https://github.com/callummcdougall/path_patching/archive/refs/heads/main.zip
        # NOTE(review): every path below is hard-coded under /content, so this only works when the
        # working directory is /content (e.g. a Colab runtime) — confirm before running elsewhere.
        !unzip /content/main.zip 'path_patching-main/ioi_dataset.py'
        !unzip /content/main.zip 'path_patching-main/path_patching.py'
        sys.path.append("/content/path_patching-main")
        os.remove("/content/main.zip")
        # Move the two extracted modules next to the notebook, then drop the now-empty folder.
        # The sys.path entry added above therefore points at a directory that no longer exists;
        # presumably the import below resolves from the working directory instead.
        os.rename("/content/path_patching-main/ioi_dataset.py", "ioi_dataset.py")
        os.rename("/content/path_patching-main/path_patching.py", "path_patching.py")
        os.rmdir("/content/path_patching-main")

from path_patching import Node, IterNode, path_patch, act_patch

In [3]:
# Installs neel-plotly into the kernel's environment each time this cell runs
# (unpinned, taken straight from the repository's default branch).
%pip install git+https://github.com/neelnanda-io/neel-plotly.git
# NOTE: the `imshow` imported here is shadowed by the local `imshow` defined later in this notebook.
from neel_plotly import imshow, line, scatter, histogram
# NOTE: this rebinds `tqdm` from the `tqdm.auto` module (bound as `tqdm` in the imports cell)
# to the top-level `tqdm` package.
import tqdm
# Autograd is switched off globally for the rest of the session.
torch.set_grad_enabled(False)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

Collecting git+https://github.com/neelnanda-io/neel-plotly.git
  Cloning https://github.com/neelnanda-io/neel-plotly.git to /tmp/pip-req-build-l1gvx_om
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/neel-plotly.git /tmp/pip-req-build-l1gvx_om
  Resolved https://github.com/neelnanda-io/neel-plotly.git to commit 6dc24b26f8dec991908479d7445dae496b3430b7
  Preparing metadata (setup.py) ... [?25ldone
Note: you may need to restart the kernel to use updated packages.


In [4]:
device

'cuda'

In [5]:
# Keyword names that `imshow` / `hist` forward to `fig.update_layout(...)` instead of
# to the plotly-express constructor. Grouped by purpose; each name is listed once.
update_layout_set = {
    # figure-level settings
    "hovermode", "colorbar", "colorscale", "coloraxis", "margin",
    "showlegend", "legend_title_text", "title_x", "title_y",
    # bar spacing
    "bargap", "bargroupgap",
    # x axis
    "xaxis_range", "xaxis_title", "xaxis_tickformat", "xaxis_tickmode", "xaxis_tickangle",
    "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor", "xaxis_visible",
    # y axis
    "yaxis_range", "yaxis_title", "yaxis_tickformat", "yaxis_tickmode", "yaxis_tickangle",
    "yaxis_showgrid", "yaxis_gridwidth", "yaxis_gridcolor", "yaxis_visible",
}

def imshow(tensor, return_fig = False, renderer=None, **kwargs):
    """
    Show `tensor` as a plotly-express heatmap and optionally return the figure.

    Keyword arguments are split in two groups: names listed in `update_layout_set`
    are applied afterwards through `fig.update_layout`, everything else is passed
    to `px.imshow`. Two extra keywords are consumed here and never reach plotly:

    - facet_labels: replacement texts for the figure's annotations, in order
      (used to title facets).
    - border: if truthy, draw a black mirrored outline around every axis.

    Defaults applied when the caller does not override them:
    `color_continuous_scale="RdBu"` and `color_continuous_midpoint=0.0`.
    An integer `margin` is expanded to the same value on all four sides.

    The figure is always displayed with `fig.show(renderer=renderer)`; it is
    returned only when `return_fig` is true (otherwise the function returns None).
    """
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    facet_labels = kwargs_pre.pop("facet_labels", None)
    border = kwargs_pre.pop("border", False)
    kwargs_pre.setdefault("color_continuous_scale", "RdBu")
    # Treated as a default rather than a hard-coded argument, so that a caller passing
    # `color_continuous_midpoint` no longer triggers a duplicate-keyword TypeError.
    kwargs_pre.setdefault("color_continuous_midpoint", 0.0)
    if "margin" in kwargs_post and isinstance(kwargs_post["margin"], int):
        kwargs_post["margin"] = dict.fromkeys(list("tblr"), kwargs_post["margin"])
    fig = px.imshow(utils.to_numpy(tensor), **kwargs_pre)
    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]['text'] = label
    if border:
        fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=True)
        fig.update_yaxes(showline=True, linewidth=1, linecolor='black', mirror=True)
    # Per-axis settings such as `xaxis_tickangle` only target the first subplot, so copy
    # them onto xaxis2, xaxis3, ... for as many x axes as the figure layout contains.
    for setting in ["tickangle"]:
        if f"xaxis_{setting}" in kwargs_post:
            i = 2
            while f"xaxis{i}" in fig["layout"]:
                kwargs_post[f"xaxis{i}_{setting}"] = kwargs_post[f"xaxis_{setting}"]
                i += 1
    fig.update_layout(**kwargs_post)
    fig.show(renderer=renderer)
    if return_fig:
        return fig

def hist(tensor, renderer=None, **kwargs):
    """
    Show a plotly-express histogram of `tensor` (passed to `px.histogram` as `x`).

    Keyword arguments named in `update_layout_set` are applied with
    `update_layout`; all others go to `px.histogram`. The extra keyword `names`
    is consumed here: when given, trace `i` of the figure is renamed to
    `names[i // 2]`.

    Defaults applied when the caller does not override them: `barmode="overlay"`
    and `bargap=0.0`. An integer `margin` is expanded to all four sides.
    The figure is displayed and nothing is returned.
    """
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    names = kwargs_pre.pop("names", None)
    # `barmode` is not in `update_layout_set`, so a caller-supplied value ends up in
    # kwargs_pre (it is a px.histogram argument). Only fall back to "overlay" when the
    # caller gave none; otherwise the layout update would silently override their choice.
    if "barmode" not in kwargs_pre and "barmode" not in kwargs_post:
        kwargs_post["barmode"] = "overlay"
    if "bargap" not in kwargs_post:
        kwargs_post["bargap"] = 0.0
    if "margin" in kwargs_post and isinstance(kwargs_post["margin"], int):
        kwargs_post["margin"] = dict.fromkeys(list("tblr"), kwargs_post["margin"])
    fig = px.histogram(x=tensor, **kwargs_pre).update_layout(**kwargs_post)
    if names is not None:
        # NOTE(review): the `// 2` presumably accounts for two traces per series
        # (e.g. histogram + marginal) — confirm against the calls that pass `names`.
        for i in range(len(fig.data)):
            fig.data[i]["name"] = names[i // 2]
    fig.show(renderer)

In [6]:
from plotly import graph_objects as go
from plotly.subplots import make_subplots

In [413]:
# Load pretrained GPT-2 small as a HookedTransformer. All four weight-processing
# options below are switched on; see the TransformerLens documentation for their
# exact effect on the loaded weights.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


In [414]:
def topk_of_Nd_tensor(tensor: Float[Tensor, "rows cols"], k: int):
    '''
    Return the positions of the k largest entries of `tensor`.

    The tensor is flattened, the top-k flat positions are taken, and each one is
    converted back to a multi-dimensional index using `tensor.shape`. The result
    is a nested Python list with k entries, each holding one index per dimension
    of `tensor`.

    Example: for a [layer, head] tensor of per-head scores this yields the k
    best heads as [layer, head] pairs.
    '''
    flat_positions = utils.to_numpy(torch.topk(tensor.flatten(), k).indices)
    per_axis_indices = np.unravel_index(flat_positions, tensor.shape)
    return np.array(per_axis_indices).T.tolist()

# Generate Dataset

In [415]:
def generate_repeated_tokens(model, batch, seq_len) -> Int[Tensor, "batch seq_len*2"]:
  """
  Build `batch` random token sequences in which the first half is repeated once.

  Token ids are drawn with `torch.randint` from `[1, model.cfg.d_vocab)`, so id 0
  is never produced. Each row is `seq_len` random ids followed by the same
  `seq_len` ids again. The result is an integer tensor created by `torch.randint`
  with no device argument; callers move it to the model's device themselves.
  """
  tokens = torch.randint(1, model.cfg.d_vocab, (batch, seq_len))
  return torch.concat((tokens, tokens), dim=-1)

In [417]:
# Generate the clean prompts.
gen_size = 600            # number of candidate sequences to sample
prompt_half_length = 10   # length of the random half that gets repeated

# Use the notebook-wide `device` instead of a hard-coded `.cuda()` so the cell also runs on CPU.
gen_tokens = generate_repeated_tokens(model, gen_size, prompt_half_length).to(device)
gen_prompts = [''.join(model.to_str_tokens(gen_tokens[i, :])) for i in range(gen_tokens.shape[0])]

# Keep only the sequences whose token ids are unchanged after decoding to text and
# tokenizing that text again (without a prepended BOS).
induction_tokens = [
    gen_tokens[i]
    for i in range(gen_size)
    if gen_tokens[i].equal(model.to_tokens(gen_prompts[i], prepend_bos=False)[0])
]

# Sanity check: the same round trip must also hold when decoding with `model.to_string`.
for kept_tokens in induction_tokens:
    assert kept_tokens.equal(model.to_tokens(model.to_string(kept_tokens), prepend_bos=False)[0])

# Drop the final token of every kept sequence, then decode the shortened sequences to text.
induction_tokens = [kept_tokens[:-1] for kept_tokens in induction_tokens]
induction_prompts = [model.to_string(kept_tokens) for kept_tokens in induction_tokens]

broken_batch_size = len(induction_tokens)


In [418]:
print(len(induction_tokens[2]))

19


Generate the corrupted prompts by replacing the token at index `prompt_half_length` (the first token of the repeated half) with a random token.

In [419]:
# Generate candidate corrupted prompts and tokens by replacing the token at index
# `prompt_half_length` of every clean sequence with a random token id.

possible_broken_corrupted_tokens = []
possible_broken_corrupted_prompts = []

for clean_sequence in induction_tokens:
    corrupted_sequence = clean_sequence.clone()
    original_token = corrupted_sequence[prompt_half_length].item()
    replacement = torch.randint(1, model.cfg.d_vocab, (1,))
    # Resample if the draw happens to equal the original token; otherwise the
    # "corrupted" prompt would be identical to its clean counterpart.
    while replacement.item() == original_token:
        replacement = torch.randint(1, model.cfg.d_vocab, (1,))
    corrupted_sequence[prompt_half_length] = replacement
    possible_broken_corrupted_tokens.append(corrupted_sequence)
    possible_broken_corrupted_prompts.append(model.to_string(corrupted_sequence))
    


In [420]:
# Token id used as the beginning-of-sequence marker. Tokenizing the empty string leaves only
# whatever `to_tokens` prepends by default (presumably the BOS token — the cell output shows 50256).
BOS_TOKEN = model.to_tokens("")[0].item()

In [421]:
BOS_TOKEN

50256

## TODO: remove prompts whose corrupted version is identical to the clean version

In [422]:
possible_broken_corrupted_tokens[0]

tensor([11789, 28138, 39025,  3584, 43578, 26520,  7944, 18658, 18754, 40506,
         2770, 28138, 39025,  3584, 43578, 26520,  7944, 18658, 18754],
       device='cuda:0')

In [423]:
# Keep only the (clean, corrupted) pairs whose corrupted token ids are unchanged after a
# decode -> re-tokenize round trip, and prepend the BOS token to both members of each kept pair.
assert len(induction_tokens) == len(possible_broken_corrupted_tokens) == len(possible_broken_corrupted_prompts)

# Built once and reused for every sequence; follows the notebook-wide `device`.
bos_prefix = torch.tensor([BOS_TOKEN], device=device)

kept_clean_tokens = []
corrupted_tokens = []
num_removed = 0  # number of pairs dropped because the corrupted prompt re-tokenizes differently

for clean_sequence, corrupted_sequence, corrupted_prompt in zip(
    induction_tokens, possible_broken_corrupted_tokens, possible_broken_corrupted_prompts
):
    retokenized = model.to_tokens(corrupted_prompt, prepend_bos=False)[0]
    if not corrupted_sequence.equal(retokenized):
        # Dropping the corrupted prompt also drops its clean partner, keeping the lists aligned.
        num_removed += 1
        continue
    corrupted_tokens.append(torch.cat((bos_prefix, corrupted_sequence)))
    kept_clean_tokens.append(torch.cat((bos_prefix, clean_sequence)))

corrupted_prompts = [model.to_string(sequence) for sequence in corrupted_tokens]
induction_tokens = kept_clean_tokens
induction_prompts = [model.to_string(sequence) for sequence in induction_tokens]

assert len(corrupted_tokens) == len(corrupted_prompts) == len(induction_tokens) == len(induction_prompts)
batch_size = len(corrupted_tokens)
clean_tokens = torch.stack(induction_tokens)
corrupted_tokens = torch.stack(corrupted_tokens)
# Column 0 holds the clean token and column 1 the corrupted token at position
# `prompt_half_length + 1`, i.e. the replaced position shifted by one for the prepended BOS.
answer_tokens = torch.stack((clean_tokens[:, prompt_half_length + 1], corrupted_tokens[:, prompt_half_length + 1]), dim=1)


# Delete the temporary candidate lists.
del possible_broken_corrupted_tokens
del possible_broken_corrupted_prompts

In [424]:
print(batch_size)

395


In [425]:
# Run the model once on each dataset, keeping both the output logits and the
# activation cache of each run for the analyses below.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

In [426]:
# Residual-stream directions for the two answer tokens of every prompt
# (column 0 = clean token, column 1 = corrupted token, as laid out in `answer_tokens`).
answer_residual_directions: Float[Tensor, "batch 2 d_model"] = model.tokens_to_residual_directions(answer_tokens)
# NOTE: the singular names `correct_residual_direction` / `incorrect_residual_direction`
# are the ones read as globals by `calc_all_logit_diffs` further down.
correct_residual_direction, incorrect_residual_direction = answer_residual_directions.unbind(dim=1)
# Per-prompt direction whose dot product with a residual vector gives (correct - incorrect).
# `residual_stack_to_logit_diff` captures this tensor as a default argument when it is defined.
logit_diff_directions: Float[Tensor, "batch d_model"] = correct_residual_direction - incorrect_residual_direction

# Helper Functions

In [427]:
def logits_to_ave_logit_diff(
    logits: Float[Tensor, "batch seq d_vocab"],
    answer_tokens: Int[Tensor, "batch 2"] = answer_tokens,
    per_prompt: bool = False
):
    '''
    Returns the logit difference between the correct and the incorrect answer token,
    measured at the final sequence position.

    Args:
        logits: model output logits for a batch of prompts.
        answer_tokens: token ids per prompt, column 0 = correct, column 1 = incorrect.
            The default is the global `answer_tokens` as it existed when this function
            was defined; re-run this cell after regenerating the dataset.
        per_prompt: if True, return one difference per prompt instead of the mean.

    Returns:
        A [batch] tensor when `per_prompt` is True, otherwise a scalar tensor (the mean).
    '''
    # Only the logits at the last position are relevant for the answer.
    final_logits: Float[Tensor, "batch d_vocab"] = logits[:, -1, :]
    # Pick out the two answer logits per prompt. The index is moved to the device of
    # `logits` (rather than a global device) because `gather` needs both on the same one.
    answer_logits: Float[Tensor, "batch 2"] = final_logits.gather(dim=-1, index=answer_tokens.to(logits.device))
    correct_logits, incorrect_logits = answer_logits.unbind(dim=-1)
    answer_logit_diff = correct_logits - incorrect_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

In [428]:
def residual_stack_to_logit_diff(
    residual_stack: Float[Tensor, "... batch d_model"],
    cache: ActivationCache,
    logit_diff_directions: Float[Tensor, "batch d_model"] = logit_diff_directions,
) -> Float[Tensor, "..."]:
    '''
    Average, over the batch, of each residual-stream component's projection onto
    the per-prompt `logit_diff_directions`.

    The stack is first rescaled with `cache.apply_ln_to_stack(layer=-1, pos_slice=-1)`,
    then dotted with the direction of the matching prompt; the dot products are summed
    over the batch and divided by the batch size. Any leading dimensions of
    `residual_stack` are preserved in the result.

    The default for `logit_diff_directions` is the global of the same name as it
    existed when this function was defined.
    '''
    n_prompts = residual_stack.size(-2)
    normalized_stack = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)

    summed_projection = einops.einsum(
        normalized_stack, logit_diff_directions,
        "... batch d_model, batch d_model -> ..."
    )
    return summed_projection / n_prompts


In [429]:
def calc_all_logit_diffs(cache):
  """
  Per-head direct contribution to the answer logits at the final position.

  Reads the globals `model`, `correct_residual_direction` and
  `incorrect_residual_direction`, and calls `residual_stack_to_logit_diff`.

  Returns three [layer, head] tensors, in this order:
    1. contribution to the logit difference (using the default directions of
       `residual_stack_to_logit_diff`),
    2. contribution along the correct-token direction,
    3. contribution along the incorrect-token direction.
  """
  # Stack of per-head outputs, shaped [heads, batch, seq_pos, d_model]. LayerNorm is not
  # applied here because `residual_stack_to_logit_diff` applies it itself.
  per_head_residual, _ = cache.stack_head_results(layer = -1, return_labels = True, apply_ln = False)
  final_pos_residual = per_head_residual[:, :, -1, :]

  def per_head_projection(**direction_kwargs):
    # Project every head onto the requested directions, then unflatten the head axis.
    flat: Float[Tensor, "layer_head"] = residual_stack_to_logit_diff(final_pos_residual, cache, **direction_kwargs)
    return einops.rearrange(
        flat,
        "(layer head) ... -> layer head ...",
        layer=model.cfg.n_layers
    )

  per_head_logit_diff = per_head_projection()
  correct_direction_per_head_logit = per_head_projection(logit_diff_directions = correct_residual_direction)
  incorrect_direction_per_head_logit = per_head_projection(logit_diff_directions = incorrect_residual_direction)

  return per_head_logit_diff, correct_direction_per_head_logit, incorrect_direction_per_head_logit

per_head_logit_diff, correct_direction_per_head_logit, incorrect_direction_per_head_logit = calc_all_logit_diffs(clean_cache)

Tried to stack head results when they weren't cached. Computing head results now


In [430]:
# Recomputes the answer directions from `answer_tokens`, this time printing their shapes.
# NOTE: the unbound halves are stored under the plural names `correct_residual_directions` /
# `incorrect_residual_directions`; `calc_all_logit_diffs` reads the singular-named globals
# from the earlier cell, so this cell does not change what that function uses.
answer_residual_directions: Float[Tensor, "batch 2 d_model"] = model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)

correct_residual_directions, incorrect_residual_directions = answer_residual_directions.unbind(dim=1)
logit_diff_directions: Float[Tensor, "batch d_model"] = correct_residual_directions - incorrect_residual_directions
print(f"Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([395, 2, 768])
Logit difference directions shape: torch.Size([395, 768])


In [431]:
# Baselines: logit difference (correct minus incorrect answer token) on the clean and the
# corrupted runs. The two averages are later used as default arguments of the
# noising / denoising metrics.
clean_per_prompt_diff = logits_to_ave_logit_diff(clean_logits, per_prompt = True)

clean_average_logit_diff = logits_to_ave_logit_diff(clean_logits)
corrupted_average_logit_diff = logits_to_ave_logit_diff(corrupted_logits)

print(clean_average_logit_diff)
print(corrupted_average_logit_diff)

tensor(3.2651, device='cuda:0')
tensor(-0.9653, device='cuda:0')


In [432]:
# Sanity check: projecting the final residual stream onto the logit-diff directions should
# reproduce the logit difference computed directly from the logits.
# The projection does not include the unembedding bias, so the per-prompt bias difference
# between the two answer tokens is computed here and its mean is added back below.
diff_from_unembedding_bias = model.b_U[answer_tokens[:, 0]] -  model.b_U[answer_tokens[:, 1]]
final_residual_stream: Float[Tensor, "batch seq d_model"] = clean_cache["resid_post", -1]
print(f"Final residual stream shape: {final_residual_stream.shape}")
# Only the last sequence position is used.
final_token_residual_stream: Float[Tensor, "batch d_model"] = final_residual_stream[:, -1, :]

print(f"Calculated average logit diff: {(residual_stack_to_logit_diff(final_token_residual_stream, clean_cache, logit_diff_directions) + diff_from_unembedding_bias.mean(0)):.10f}") # <-- okay b_U exists... and matters
print(f"Original logit difference:     {clean_average_logit_diff:.10f}")

Final residual stream shape: torch.Size([395, 20, 768])
Calculated average logit diff: 3.2650709152
Original logit difference:     3.2650706768


In [433]:
# NOTE: this cell re-defines `calc_all_logit_diffs`, which an earlier cell of this notebook
# already defines and calls; one of the two definitions can be deleted.
def calc_all_logit_diffs(cache):
  """
  Per-head direct contribution to the answer logits at the final position.

  Reads the globals `model`, `correct_residual_direction` and
  `incorrect_residual_direction`, and calls `residual_stack_to_logit_diff`.

  Returns three [layer, head] tensors, in this order:
    1. contribution to the logit difference (using the default directions of
       `residual_stack_to_logit_diff`),
    2. contribution along the correct-token direction,
    3. contribution along the incorrect-token direction.
  """
  # Stack of per-head outputs, shaped [heads, batch, seq_pos, d_model]. LayerNorm is not
  # applied here because `residual_stack_to_logit_diff` applies it itself.
  per_head_residual, _ = cache.stack_head_results(layer = -1, return_labels = True, apply_ln = False)
  final_pos_residual = per_head_residual[:, :, -1, :]

  def per_head_projection(**direction_kwargs):
    # Project every head onto the requested directions, then unflatten the head axis.
    flat: Float[Tensor, "layer_head"] = residual_stack_to_logit_diff(final_pos_residual, cache, **direction_kwargs)
    return einops.rearrange(
        flat,
        "(layer head) ... -> layer head ...",
        layer=model.cfg.n_layers
    )

  per_head_logit_diff = per_head_projection()
  correct_direction_per_head_logit = per_head_projection(logit_diff_directions = correct_residual_direction)
  incorrect_direction_per_head_logit = per_head_projection(logit_diff_directions = incorrect_residual_direction)

  return per_head_logit_diff, correct_direction_per_head_logit, incorrect_direction_per_head_logit

per_head_logit_diff, correct_direction_per_head_logit, incorrect_direction_per_head_logit = calc_all_logit_diffs(clean_cache)

In [434]:
# Find the k heads with the largest per-head logit-difference contribution.
k = 5

flattened_tensor = per_head_logit_diff.flatten().cpu()
topk_indices = torch.topk(flattened_tensor, k).indices
# Convert the flat positions back into (layer, head) coordinates.
top_layer_arr, top_index_arr = np.unravel_index(topk_indices.numpy(), per_head_logit_diff.shape)

top_heads = list(zip(top_layer_arr, top_index_arr))

print(top_heads)

[(11, 3), (11, 2), (11, 9), (10, 0), (11, 6)]


In [435]:
def display_all_logits(cache, title = "Logit Contributions", comparison = False, return_fig = False, logits = None):
  """
  Plot the per-head logit contributions computed from `cache` as faceted heatmaps.

  Args:
    cache: activation cache handed to `calc_all_logit_diffs`.
    title: figure title.
    comparison: if True, add three extra facets showing each quantity minus the
      same quantity computed on the global `clean_cache`.
    return_fig: if True, return the plotly figure; otherwise return None.
    logits: optional logits; when given, their average logit difference is shown
      in the first facet label. When omitted the label shows 0.00.

  The figure is always displayed as a side effect of calling `imshow`.
  """
  logit_diff, correct_dir, incorrect_dir = calc_all_logit_diffs(cache)
  ld = logits_to_ave_logit_diff(logits) if logits is not None else 0.00

  panels = [logit_diff, correct_dir, incorrect_dir]
  facet_labels = [f"Logit Diff - {ld:.2f}", "Correct Direction", "Incorrect Direction"]

  if comparison:
    # Differences relative to the clean run, one facet per quantity.
    clean_logit_diff, clean_correct_dir, clean_incorrect_dir = calc_all_logit_diffs(clean_cache)
    panels += [logit_diff - clean_logit_diff, correct_dir - clean_correct_dir, incorrect_dir - clean_incorrect_dir]
    # The middle label previously read "Correction Direction Diff", a typo for the
    # "Correct Direction" facet it is paired with.
    facet_labels += ["Logit Diff Diff", "Correct Direction Diff", "Incorrect Direction Diff"]

  fig = imshow(
      torch.stack(panels),
      return_fig = True,
      facet_col = 0,
      facet_labels = facet_labels,
      title=title,
      labels={"x": "Head", "y": "Layer", "color": "Logit Contribution"},
      border=True,
      width=1500,
      margin={"r": 100, "l": 100}
  )

  if return_fig:
    return fig



fig = display_all_logits(clean_cache, title = "Logit Contributions on Clean Dataset", return_fig = True, logits = clean_logits)

In [436]:
def stare_at_attention_and_head_pat(cache, layer_to_stare_at, head_to_isolate, display_corrupted_text = False, verbose = True, specific = False, specific_index = 0):
  """
  Display the attention patterns of every head in one layer of a cached run.

  Args:
    cache: activation cache; `cache["pattern", layer_to_stare_at]` is visualised.
    layer_to_stare_at: layer whose attention patterns are shown.
    head_to_isolate: only used in the printed header line; all heads of the layer
      are rendered regardless of this value.
    display_corrupted_text: label the axes with tokens from the global
      `corrupted_tokens` instead of `clean_tokens`.
    verbose: True renders `cv.attention.attention_heads`; False prints the shape of
      the batch-averaged pattern and renders `cv.attention.attention_patterns`
      with per-head names.
    specific: if True, show the pattern of the single prompt `specific_index`;
      otherwise show the pattern averaged over the batch.
    specific_index: batch index used when `specific` is True.

  Token labels come from the prompt whose pattern is shown: prompt `specific_index`
  when `specific` is True, prompt 0 for the batch average.
  """
  source_tokens = corrupted_tokens if display_corrupted_text else clean_tokens
  # Previously prompt 0 was always used for the labels, which mislabels the plot when a
  # different prompt's attention pattern is being displayed.
  label_index = specific_index if specific else 0
  tokenized_str_tokens = model.to_str_tokens(source_tokens[label_index])

  layer_pattern = cache["pattern", layer_to_stare_at]
  print(f"Layer {layer_to_stare_at} Head {head_to_isolate} Activation Patterns:")

  if specific:
    shown_attention = layer_pattern[specific_index]
  else:
    shown_attention = layer_pattern.mean(0)

  if verbose:
    display(cv.attention.attention_heads(
      tokens=tokenized_str_tokens,
      attention=shown_attention,
    ))
    return

  print(layer_pattern.mean(0).shape)

  display(cv.attention.attention_patterns(
    tokens=tokenized_str_tokens,
    attention=shown_attention,
    attention_head_names=[f"L{layer_to_stare_at} H{i}" for i in range(model.cfg.n_heads)],
  ))

In [437]:
# Batch-averaged attention patterns of layer 11 on the clean run. The head argument (0) only
# appears in the printed header; the widget renders every head of the layer.
stare_at_attention_and_head_pat(clean_cache,11,0
                                , display_corrupted_text = False, verbose = True, specific = False, specific_index = 0)

Layer 11 Head 0 Activation Patterns:


In [408]:
def noising_ioi_metric(
    logits: Float[Tensor, "batch seq d_vocab"],
    clean_logit_diff: float = clean_average_logit_diff,
    corrupted_logit_diff: float = corrupted_average_logit_diff,
) -> float:
    '''
    Normalised drop in average logit difference caused by noising.

    The average logit difference of `logits` is rescaled so that the clean
    baseline maps to 0 (performance unharmed) and the corrupted baseline maps
    to -1 (performance destroyed).

    The two baselines default to the globals `clean_average_logit_diff` and
    `corrupted_average_logit_diff` as they existed when this function was defined.
    '''
    patched_logit_diff = logits_to_ave_logit_diff(logits)
    baseline_gap = clean_logit_diff - corrupted_logit_diff
    normalised_change = (patched_logit_diff - clean_logit_diff) / baseline_gap
    return normalised_change.item()

# Expected: 0 on the clean logits, -1 on the corrupted logits.
print(f"IOI metric (IOI dataset): {noising_ioi_metric(clean_logits):.4f}")
print(f"IOI metric (ABC dataset): {noising_ioi_metric(corrupted_logits):.4f}")

def denoising_ioi_metric(
    logits: Float[Tensor, "batch seq d_vocab"],
    clean_logit_diff: float = clean_average_logit_diff,
    corrupted_logit_diff: float = corrupted_average_logit_diff,
) -> float:
    '''
    Normalised recovery of average logit difference achieved by denoising.

    Same rescaling as the noising metric, shifted by one: the clean baseline
    maps to 1 (performance restored) and the corrupted baseline maps to 0
    (performance destroyed).

    The two baselines default to the globals `clean_average_logit_diff` and
    `corrupted_average_logit_diff` as they existed when this function was defined.
    '''
    patched_logit_diff = logits_to_ave_logit_diff(logits)
    baseline_gap = clean_logit_diff - corrupted_logit_diff
    normalised_recovery = (patched_logit_diff - clean_logit_diff) / baseline_gap + 1
    return normalised_recovery.item()


# Expected: 1 on the clean logits, 0 on the corrupted logits.
print(f"IOI metric (IOI dataset): {denoising_ioi_metric(clean_logits):.4f}")
print(f"IOI metric (ABC dataset): {denoising_ioi_metric(corrupted_logits):.4f}")

IOI metric (IOI dataset): 0.0000
IOI metric (ABC dataset): -1.0000
IOI metric (IOI dataset): 1.0000
IOI metric (ABC dataset): 0.0000


In [409]:
def path_patch_to_head(layer, head, noisily = True):
  """
  Path-patch from every head's `z` output into one receiver head, once for each of
  the receiver's "q", "k", "v" and "pattern" inputs.

  With `noisily=True` the clean tokens are the original input, the corrupted tokens
  are the new input and `noising_ioi_metric` scores the result; with
  `noisily=False` the two inputs are swapped and `denoising_ioi_metric` is used.

  Returns a list of four results (the 'z' entry of each `path_patch` call), in the
  order query, key, value, pattern.
  """
  if noisily:
    orig_input, new_input, metric = clean_tokens, corrupted_tokens, noising_ioi_metric
  else:
    orig_input, new_input, metric = corrupted_tokens, clean_tokens, denoising_ioi_metric

  head_results = []
  for receiver_input in ("q", "k", "v", "pattern"):
    # Hooks are cleared before every run to work around leftover hooks from the patching library.
    model.reset_hooks()
    patched = path_patch(
        model,
        orig_input=orig_input,
        new_input=new_input,
        sender_nodes=IterNode("z"),
        receiver_nodes=[Node(receiver_input, layer = layer, head = head)],
        patching_metric=metric,
    )
    head_results.append(patched['z'])

  return head_results

In [410]:
# all_path_patch_results = {}

In [411]:
# for head in range(12):
#     layer = 10
#     results = path_patch_to_head(layer, head)
#     all_path_patch_results[f"{layer}_{head}_noisy"] = results

In [438]:
# Receiver head for the path-patching experiment. NOTE: `layer` and `head` are plain
# module-level globals, so later cells that mention these names read 11 and 3.
layer = 11
head = 3
# One facet per receiver input (query / key / value / pattern); values are multiplied by
# 100 and shown with a "%" tick suffix.
# NOTE(review): "Degretation" in the colour label below looks like a typo for "Degradation".
imshow(
  torch.stack(path_patch_to_head(layer,head)) * 100,
  facet_col = 0,
  facet_labels = ["Query", "Key", "Value", "Pattern"],
  title=f"Effect of Noisily Path Patching in {layer}.{head}",
  labels={"x": "Head", "y": "Layer", "color": "Percent Negative Degretation"},
  coloraxis=dict(colorbar_ticksuffix = "%"),
  border=True,
  width=1500,
  margin={"r": 100, "l": 100},
)

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

## Ablation experiments
I want to check whether backup behaviour is actually occurring.

In [440]:
def return_item(item):
  """Identity function: hand back `item` unchanged (used as a pass-through patching metric)."""
  return item


In [446]:
# Heads whose `z` output is replaced with the activations cached from the corrupted run.
ablate_heads = [(8, head_idx) for head_idx in range(12)]


def _z_nodes(heads):
  """Build a fresh list of `z` patching nodes, one per (layer, head) pair."""
  return [Node("z", layer = int(node_layer), head = int(node_head)) for node_layer, node_head in heads]


# Label describing the heads that are actually patched. The title previously interpolated
# the globals `layer` / `head` set by the path-patching cell, which are unrelated to
# `ablate_heads` and therefore mislabelled the figure.
ablated_heads_label = ", ".join(f"{node_layer}.{node_head}" for node_layer, node_head in ablate_heads)

# First pass: the identity metric hands back whatever `act_patch` passes to the metric.
# With `apply_metric_to_cache = False` this is presumably the patched logits; the result
# is used as `logits` below.
act_logit_results = act_patch(
    model = model,
    orig_input = clean_tokens,
    new_cache = corrupted_cache,
    patching_nodes = _z_nodes(ablate_heads),
    patching_metric = return_item,
    apply_metric_to_cache = False
    )


# Second pass: the metric is given the cache of the patched run, so `display_all_logits`
# plots the per-head contributions after patching, next to their change from the clean run.
act_patching_results = act_patch(
    model = model,
    orig_input = clean_tokens,
    new_cache = corrupted_cache,
    patching_nodes = _z_nodes(ablate_heads),
    patching_metric = partial(display_all_logits, title = f"Logits when Act Patching in {ablated_heads_label}", comparison = True, logits = act_logit_results),
    apply_metric_to_cache = True
    )

Tried to stack head results when they weren't cached. Computing head results now
