# Setup

In [38]:
!pip install transformer_lens
!pip install circuitsvis
!pip install auto-circuit



In [39]:
import io
import itertools
import json
from pathlib import Path

import circuitsvis as cv
import einops
import numpy as np
import pandas as pd
import torch as t
import torch.nn as nn
from transformer_lens import HookedTransformer

import auto_circuit as ac
from auto_circuit.data import load_datasets_from_json
from auto_circuit.experiment_utils import load_tl_model
from auto_circuit.prune_algos.mask_gradient import mask_gradient_prune_scores
from auto_circuit.types import AblationType
from auto_circuit.types import PruneScores
from auto_circuit.utils.ablation_activations import src_ablations
from auto_circuit.utils.graph_utils import patch_mode, patchable_model
from auto_circuit.utils.misc import repo_path_to_abs_path
from auto_circuit.visualize import draw_seq_graph, net_viz

In [40]:
# Notebook-wide configuration: pandas display width, compute device, and the two
# GPT-2 small instances used below (tl_model for tokenization/generation,
# ac_model for the auto-circuit pipeline).
MODEL_NAME = "gpt2-small"

pd.set_option("display.max_colwidth", 300)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

tl_model = HookedTransformer.from_pretrained(MODEL_NAME)
ac_model = load_tl_model(MODEL_NAME, device)


`clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884



Loaded pretrained model gpt2-small into HookedTransformer
Loaded pretrained model gpt2-small into HookedTransformer


# Bat prompt creation

In [41]:
# Location prefixes used for the "animal" reading of "bat" (clean prompts).
animal_bat_context = [
    "In the cave,",
    "In the forest,",
    "At the zoo,",
]

# Location prefixes used for the "object" reading of "bat" (corrupt prompts).
object_bat_context = [
    "In the stadium,",
    "In the arena,",
    "At home plate,",
]

In [42]:
# Grammatical subjects that can fill the "<actor> <verb> bat." slot.
bat_actor = [
    "I",
    "Sarah",
    "John",
    "you",
]

In [43]:
# Perception verbs (each already followed by the article "a").
bat_verb = [
    "saw a",
    "found a",
    "detected a",
    "sensed a",
]

In [44]:
# The ambiguous target word; the trailing period closes the context sentence.
bat_word = ["bat."]

In [45]:
# Question suffixes that elicit the disambiguating answer; both orderings of the
# two options are included.
bat_eliciter = [
    "Was the bat an animal or an object? It was an",
    "Was the bat an object or an animal? It was an",
]

In [46]:
# Every "<actor> <verb> <word> <eliciter>" sentence tail, in itertools.product order.
prompt_tails = [
    " ".join(parts)
    for parts in itertools.product(bat_actor, bat_verb, bat_word, bat_eliciter)
]

# Pair every animal context (clean) with every object context (corrupt) on each tail.
# Iteration order matches the nested loops: animal context outermost, tail innermost.
prompt_dicts = [
    {
        "clean": f"{animal_ctx} {tail}",
        "corrupt": f"{object_ctx} {tail}",
        "answers": [" animal"],
        "wrong_answers": [" object"],
    }
    for animal_ctx, object_ctx, tail in itertools.product(
        animal_bat_context, object_bat_context, prompt_tails
    )
]

def correct_prompt(clean_prompt, correct_ans, wrong_ans, model=None):
  """Check whether the model prefers `correct_ans` over `wrong_ans` as the next token.

  Runs the model on `clean_prompt`, takes the softmax over the logits at the last
  position, prints the prompt, the two answer probabilities and the ten most
  likely next tokens, and compares the two answer probabilities.

  Args:
    clean_prompt: Prompt string fed to the model.
    correct_ans: Expected answer string. Only its first token id is compared.
    wrong_ans: Competing answer string. Only its first token id is compared.
    model: Model to query. It must be callable on a prompt, return logits indexed
      as [batch, position, vocab], and expose a `tokenizer`. Defaults to the
      notebook-level `ac_model`.

  Returns:
    True if the probability of `correct_ans` is strictly greater than that of
    `wrong_ans`, otherwise False.
  """
  if model is None:
    model = ac_model
  print(clean_prompt)
  with t.inference_mode():
    logits = model(clean_prompt)[:, -1, :]  # logits for the next-token position only
  probs = t.softmax(logits, dim=-1)
  topk_next_tokens = t.topk(probs[0], 10)  # top 10 tokens regardless of what they are
  tokenizer = model.tokenizer
  correct_idx = tokenizer(correct_ans)['input_ids'][0]
  wrong_idx = tokenizer(wrong_ans)['input_ids'][0]
  print("Correct|", correct_ans, ":", probs[0,correct_idx], "\nIncorrect |", wrong_ans, ":", probs[0, wrong_idx])
  print(*[(tokenizer.decode(idx), prob) for idx, prob in zip(topk_next_tokens.indices, topk_next_tokens.values)], sep="\n")
  # Always return a real bool; the original fell through and returned None on the
  # "wrong" branch because `else: False` is an expression, not a return.
  return bool(probs[0, correct_idx] > probs[0, wrong_idx])


# Quick look at the container type, its size and the element type before filtering.
print(type(prompt_dicts))
print((len(prompt_dicts),))
print(type(prompt_dicts[0]))

# Keep a prompt pair only if the model prefers " animal" on the clean prompt and
# both prompts tokenize to exactly 22 tokens (count includes the prepended BOS token).
# The model check runs first, so its diagnostics are printed for every candidate.
correct_prompt_dicts = []
for candidate in prompt_dicts:
  if not correct_prompt(candidate["clean"], " animal", " object"):
    continue
  if len(tl_model.to_str_tokens(tl_model.to_tokens(candidate["clean"]))) != 22:
    continue
  if len(tl_model.to_str_tokens(tl_model.to_tokens(candidate["corrupt"]))) != 22:
    continue
  print(candidate)
  correct_prompt_dicts.append(candidate)

print(*correct_prompt_dicts, sep="\n")

<class 'list'>
(288,)
<class 'dict'>
In the cave, I saw a bat. Was the bat an animal or an object? It was an
Correct|  animal : tensor(0.2908) 
Incorrect |  object : tensor(0.1106)
(' animal', tensor(0.2908))
(' object', tensor(0.1106))
(' insect', tensor(0.0444))
(' old', tensor(0.0390))
(' owl', tensor(0.0257))
(' enormous', tensor(0.0194))
(' elephant', tensor(0.0136))
(' ancient', tensor(0.0108))
(' egg', tensor(0.0089))
(' amphib', tensor(0.0085))
{'clean': 'In the cave, I saw a bat. Was the bat an animal or an object? It was an', 'corrupt': 'In the stadium, I saw a bat. Was the bat an animal or an object? It was an', 'answers': [' animal'], 'wrong_answers': [' object']}
In the cave, I saw a bat. Was the bat an object or an animal? It was an
Correct|  animal : tensor(0.2507) 
Incorrect |  object : tensor(0.1795)
(' animal', tensor(0.2507))
(' object', tensor(0.1795))
(' insect', tensor(0.0391))
(' old', tensor(0.0357))
(' owl', tensor(0.0203))
(' enormous', tensor(0.0161))
(' elep

In [47]:
# Persist the filtered bat prompts under a single top-level "prompts" key.
prompt_dicts_json = dict(prompts=correct_prompt_dicts)
with open('disambiguation_bat_dataset.json', 'w') as outfile:
    outfile.write(json.dumps(prompt_dicts_json))

# Baseline

In [48]:
# Verify that the answer word and the target word are each a single token.
# to_tokens prepends the BOS token, so the expected length is 2.
for word in (" animal", " bat"):
  word_tokens = tl_model.to_str_tokens(tl_model.to_tokens(word))
  print(word_tokens)
  assert len(word_tokens) == 2, f"{word!r} should be BOS + 1 token, got {word_tokens}"

# Verify that every context prefix is exactly 4 tokens long (5 with BOS), so that
# clean and corrupt prompts built from them have the same length.
for context in animal_bat_context + object_bat_context:
  context_tokens = tl_model.to_str_tokens(tl_model.to_tokens(context))
  print(context_tokens)
  assert len(context_tokens) == 5, f"{context!r} should be BOS + 4 tokens, got {context_tokens}"

['<|endoftext|>', ' animal']
['<|endoftext|>', ' bat']
['<|endoftext|>', 'In', ' the', ' cave', ',']
['<|endoftext|>', 'In', ' the', ' forest', ',']
['<|endoftext|>', 'At', ' the', ' zoo', ',']
['<|endoftext|>', 'In', ' the', ' stadium', ',']
['<|endoftext|>', 'In', ' the', ' arena', ',']
['<|endoftext|>', 'At', ' home', ' plate', ',']


In [49]:
# Show how one complete prompt is split into tokens (BOS token included).
example_prompt = "In the forest, I detected a bat. Was the bat an object or an animal? It was an"
example_token_ids = tl_model.to_tokens(example_prompt)
print(tl_model.to_str_tokens(example_token_ids))

['<|endoftext|>', 'In', ' the', ' forest', ',', ' I', ' detected', ' a', ' bat', '.', ' Was', ' the', ' bat', ' an', ' object', ' or', ' an', ' animal', '?', ' It', ' was', ' an']


In [50]:
# View baseline assumption of model
for x in bat_eliciter:
  baseline_completion = tl_model.generate(x, 50, temperature=0)
  print(baseline_completion)
# We see that the model leans toward thinking of " bat" as an animal.
# Therefore, the animal contexts are used for the clean prompts,

VBox(children=(  0%|          | 0/50 [00:00<?, ?it/s],))

Was the bat an animal or an object? It was an animal, but it was not an object.

Was the bat an animal or an object? It was an animal, but it was not an object.

Was the bat an animal or an object? It was an animal, but it


VBox(children=(  0%|          | 0/50 [00:00<?, ?it/s],))

Was the bat an object or an animal? It was an object or an animal.

Was the bat an object or an animal? It was an object or an animal.

Was the bat an object or an animal? It was an object or an animal.

Was the bat an object


# Finding the disambiguating circuit

In [51]:
# Build train/test loaders from the bat dataset JSON.
# The dataset is written with a relative path (i.e. into the current working
# directory), so it is resolved from there instead of a hard-coded Colab
# "/content/..." location that only exists on Colab.
train_loader, test_loader = load_datasets_from_json(
    model=ac_model,
    path=Path("disambiguation_bat_dataset.json").resolve(),
    device=device,
    prepend_bos=True,
    batch_size=16,
    train_test_size=(128, 128),
)

In [52]:
# Wrap ac_model in auto-circuit's patchable form; this wrapped model is what the
# prune-score and drawing calls below operate on.
patchable_options = {
    "factorized": True,
    "slice_output": "last_seq",
    "separate_qkv": True,
    "device": device,
}
ac_detector_model = patchable_model(ac_model, **patchable_options)


In [53]:
# Score edges of the patchable model on the bat training prompts using the
# mask-gradient method.
bat_prune_options = {
    "official_edges": None,
    "grad_function": "logit",
    "answer_function": "avg_diff",
    "mask_val": 0.0,
}
attribution_scores: PruneScores = mask_gradient_prune_scores(
    model=ac_detector_model, dataloader=train_loader, **bat_prune_options
)


VBox(children=(          | 0/1 [00:00<?, ?it/s],))

In [54]:
# Draw the bat circuit and save it next to the notebook's working directory.
# A relative output path is used so the cell also works where "/content" does not exist.
fig = draw_seq_graph(
    ac_detector_model,
    attribution_scores,
    1.1,  # NOTE(review): third positional argument, presumably the score threshold - confirm
    layer_spacing=True,
    orientation="v",
    file_path="bat_circuit.png",
)

# Bow prompt

In [55]:
# People who "made a bow".
bow_agents = ["Tom", "Clara", "Emily", "Robert"]

# Role + trailing phrase used for the "action" reading of "bow" (corrupt prompts).
bow_action_context_1 = ["the actor", "the performer", "the dancer"]
bow_action_context_2 = ["with gratefulness.", "to the crowd.", "on the stage."]

# Role + trailing phrase used for the "object" reading of "bow" (clean prompts).
bow_object_context_1 = ["the artisan", "the smith", "the tailor"]
bow_object_context_2 = ["in the workshop.", "on the bench.", "carefully and precisely."]

bow_verb = ["made a bow"]

# Both orderings of the two answer options.
bow_eliciter = [
    "Was the bow an object or an action? It was an",
    "Was the bow an action or an object? It was an",
]

action_bow_contexts = list(itertools.product(bow_action_context_1, bow_action_context_2))
object_bow_contexts = list(itertools.product(bow_object_context_1, bow_object_context_2))
bow_agent_and_eliciter = list(itertools.product(bow_agents, bow_eliciter))
print(*bow_agent_and_eliciter, sep="\n")
print(bow_agent_and_eliciter[0][0])

# One prompt pair per (agent/eliciter, action context, object context) combination;
# the object context varies fastest, matching the original nested-loop order.
bow_prompts_dicts = [
    {
        "clean": f"{agent} {obj_role} {bow_verb[0]} {obj_phrase} {eliciter}",
        "corrupt": f"{agent} {act_role} {bow_verb[0]} {act_phrase} {eliciter}",
        "answers": [" object"],
        "wrong_answers": [" action"],
    }
    for (agent, eliciter), (act_role, act_phrase), (obj_role, obj_phrase) in itertools.product(
        bow_agent_and_eliciter, action_bow_contexts, object_bow_contexts
    )
]


('Tom', 'Was the bow an object or an action? It was an')
('Tom', 'Was the bow an action or an object? It was an')
('Clara', 'Was the bow an object or an action? It was an')
('Clara', 'Was the bow an action or an object? It was an')
('Emily', 'Was the bow an object or an action? It was an')
('Emily', 'Was the bow an action or an object? It was an')
('Robert', 'Was the bow an object or an action? It was an')
('Robert', 'Was the bow an action or an object? It was an')
Tom


In [56]:

# Keep a bow prompt pair only if the model prefers " object" on the clean prompt
# and both prompts tokenize to exactly 23 tokens. The model check runs first, so
# its diagnostics are printed for every candidate.
correct_bow_prompt_dicts = []
for candidate in bow_prompts_dicts:
  if not correct_prompt(candidate["clean"], " object", " action"):
    continue
  if len(tl_model.to_str_tokens(tl_model.to_tokens(candidate["clean"]))) != 23:
    continue
  if len(tl_model.to_str_tokens(tl_model.to_tokens(candidate["corrupt"]))) != 23:
    continue
  correct_bow_prompt_dicts.append(candidate)
print(correct_bow_prompt_dicts)

# Persist the filtered prompts under a single top-level "prompts" key.
bow_prompt_dicts_json = dict(prompts=correct_bow_prompt_dicts)
with open('disambiguation_bow_dataset.json', 'w') as outfile:
    outfile.write(json.dumps(bow_prompt_dicts_json))

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
(' arrow', tensor(0.0253))
(' act', tensor(0.0214))
(' important', tensor(0.0204))
(' item', tensor(0.0188))
(' ordinary', tensor(0.0105))
(' art', tensor(0.0094))
(' extremely', tensor(0.0093))
Clara the smith made a bow in the workshop. Was the bow an action or an object? It was an
Correct|  object : tensor(0.2645) 
Incorrect |  action : tensor(0.0965)
(' object', tensor(0.2645))
(' action', tensor(0.0965))
(' arrow', tensor(0.0506))
(' item', tensor(0.0439))
(' important', tensor(0.0144))
(' act', tensor(0.0126))
(' old', tensor(0.0122))
(' interesting', tensor(0.0102))
(' ordinary', tensor(0.0101))
(' instrument', tensor(0.0100))
Clara the smith made a bow on the bench. Was the bow an action or an object? It was an
Correct|  object : tensor(0.2806) 
Incorrect |  action : tensor(0.2461)
(' object', tensor(0.2806))
(' action', tensor(0.2461))
(' arrow', tensor(0.0586))
(' act', tensor(0.0273))
(' item', tensor(0.0171))


In [57]:
# Number of bow prompt pairs that passed the correct_prompt / token-length filter
# (shown as the cell's output).
len(correct_bow_prompt_dicts)

252

In [58]:
# Build train/test loaders from the bow dataset JSON.
# The dataset is written with a relative path (i.e. into the current working
# directory), so it is resolved from there instead of a hard-coded Colab
# "/content/..." location that only exists on Colab.
bow_train_loader, bow_test_loader = load_datasets_from_json(
    model=ac_model,
    path=Path("disambiguation_bow_dataset.json").resolve(),
    device=device,
    prepend_bos=True,
    batch_size=16,
    train_test_size=(128, 128),
)

In [59]:
# Score edges of the patchable model on the bow training prompts using the
# mask-gradient method.
bow_prune_options = {
    "official_edges": None,
    "grad_function": "logit",
    "answer_function": "avg_diff",
    "mask_val": 0.0,
}
bow_attribution_scores: PruneScores = mask_gradient_prune_scores(
    model=ac_detector_model, dataloader=bow_train_loader, **bow_prune_options
)

VBox(children=(          | 0/1 [00:00<?, ?it/s],))

In [60]:
# Draw the bow circuit and save it next to the notebook's working directory.
# A relative output path is used so the cell also works where "/content" does not exist.
fig = draw_seq_graph(
    ac_detector_model,
    bow_attribution_scores,
    1.1,  # NOTE(review): third positional argument, presumably the score threshold - confirm
    layer_spacing=True,
    orientation="v",
    file_path="bow_circuit.png",
)

# Board prompt

In [61]:
# Subjects used for the "board of people" reading (clean prompts).
people_board_havers = [
    "The charity",
    "The company",
    "The organization",
    "The business",
    "The group",
    "The nonprofit",
    "The foundation",
    "The agency",
]

# Subjects used for the "wooden board" reading (corrupt prompts).
wood_board_havers = [
    "The ceiling",
    "The table",
    "The cabinet",
    "The kitchen",
    "The floor",
    "The drawer",
    "The bed",
    "The bench",
]

board_word = [
    "did have a board.",
    "was under a board.",
]

# Both orderings of the two answer options.
board_eliciter = [
    "Was the board composed of wood or people? It was composed of",
    "Was the board composed of people or wood? It was composed of",
]

# One prompt pair per (people subject, wood subject, eliciter, board phrase);
# the board phrase varies fastest, matching the original nested-loop order.
board_prompts_dicts = [
    {
        "clean": f"{people_subject} {board_phrase} {eliciter}",
        "corrupt": f"{wood_subject} {board_phrase} {eliciter}",
        "answers": [" people"],
        "wrong_answers": [" wood"],
    }
    for people_subject, wood_subject, eliciter, board_phrase in itertools.product(
        people_board_havers, wood_board_havers, board_eliciter, board_word
    )
]
print(len(board_prompts_dicts))

256


In [62]:
# Keep a board prompt pair only if the model prefers " people" on the clean prompt
# and both prompts tokenize to exactly 21 tokens. The model check runs first, so
# its diagnostics are printed for every candidate.
correct_board_prompt_dicts = []
for candidate in board_prompts_dicts:
  if not correct_prompt(candidate["clean"], " people", " wood"):
    continue
  if len(tl_model.to_str_tokens(tl_model.to_tokens(candidate["clean"]))) != 21:
    continue
  if len(tl_model.to_str_tokens(tl_model.to_tokens(candidate["corrupt"]))) != 21:
    continue
  correct_board_prompt_dicts.append(candidate)
print(correct_board_prompt_dicts)

# Persist the filtered prompts under a single top-level "prompts" key.
board_prompt_dicts_json = dict(prompts=correct_board_prompt_dicts)
with open('disambiguation_board_dataset.json', 'w') as outfile:
    outfile.write(json.dumps(board_prompt_dicts_json))

The charity did have a board. Was the board composed of wood or people? It was composed of
Correct|  people : tensor(0.3044) 
Incorrect |  wood : tensor(0.0587)
(' people', tensor(0.3044))
(' wood', tensor(0.0587))
(' a', tensor(0.0472))
(' the', tensor(0.0248))
(' two', tensor(0.0172))
(' members', tensor(0.0148))
(' some', tensor(0.0133))
(' men', tensor(0.0125))
(' volunteers', tensor(0.0117))
(' three', tensor(0.0113))
The charity was under a board. Was the board composed of wood or people? It was composed of
Correct|  people : tensor(0.3086) 
Incorrect |  wood : tensor(0.0677)
(' people', tensor(0.3086))
(' wood', tensor(0.0677))
(' a', tensor(0.0427))
(' the', tensor(0.0249))
(' men', tensor(0.0185))
(' two', tensor(0.0176))
(' members', tensor(0.0154))
(' three', tensor(0.0112))
(' some', tensor(0.0109))
(' many', tensor(0.0108))
The charity did have a board. Was the board composed of people or wood? It was composed of
Correct|  people : tensor(0.4726) 
Incorrect |  wood : tenso

In [63]:
# Build train/test loaders from the board dataset JSON.
# The dataset is written with a relative path (i.e. into the current working
# directory), so it is resolved from there instead of a hard-coded Colab
# "/content/..." location that only exists on Colab.
board_train_loader, board_test_loader = load_datasets_from_json(
    model=ac_model,
    path=Path("disambiguation_board_dataset.json").resolve(),
    device=device,
    prepend_bos=True,
    batch_size=16,
    train_test_size=(128, 128),
)

# Score edges of the patchable model on the board training prompts using the
# mask-gradient method.
board_attribution_scores: PruneScores = mask_gradient_prune_scores(
    model=ac_detector_model,
    dataloader=board_train_loader,
    official_edges=None,
    grad_function="logit",
    answer_function="avg_diff",
    mask_val=0.0,
)

VBox(children=(          | 0/1 [00:00<?, ?it/s],))

In [64]:
# Draw the board circuit and save it next to the notebook's working directory.
# A relative output path is used so the cell also works where "/content" does not exist.
fig = draw_seq_graph(
    ac_detector_model,
    board_attribution_scores,
    3.0,  # NOTE(review): 3.0 here vs 1.1 for the bat and bow graphs - confirm this is intended
    layer_spacing=True,
    orientation="v",
    file_path="board_circuit.png",
)

# Limitations

In [65]:
# Limitation check on a different ambiguous word ("seal"): two contexts, each with
# both orderings of the answer options, completed with temperature=0 and a budget of 5.
seal_prompts = [
    "On the boat, I saw a seal in the sea. Was the seal an object or an animal? It was an",
    "On the boat, I saw a seal in the sea. Was the seal an animal or an object? It was an",
    "In the market, I saw a seal on a box. Was the seal an object or an animal? It was an",
    "In the market, I saw a seal on a box. Was the seal an animal or an object? It was an",
]
for seal_prompt in seal_prompts:
  print(tl_model.generate(seal_prompt, 5, temperature=0))

VBox(children=(  0%|          | 0/5 [00:00<?, ?it/s],))

On the boat, I saw a seal in the sea. Was the seal an object or an animal? It was an animal. I asked the


VBox(children=(  0%|          | 0/5 [00:00<?, ?it/s],))

On the boat, I saw a seal in the sea. Was the seal an animal or an object? It was an animal, but it was


VBox(children=(  0%|          | 0/5 [00:00<?, ?it/s],))

In the market, I saw a seal on a box. Was the seal an object or an animal? It was an animal. I was curious


VBox(children=(  0%|          | 0/5 [00:00<?, ?it/s],))

In the market, I saw a seal on a box. Was the seal an animal or an object? It was an animal. I was curious
