<a href="https://colab.research.google.com/github/kmeng01/rome/blob/main/notebooks/causal_trace.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

In [None]:
# If running in colab

In [2]:
%%bash
!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit
cd /content && rm -rf /content/memit
git clone https://github.com/kmeng01/memit memit > install.log 2>&1
pip install -r /content/memit/scripts/colab_reqs/rome.txt >> install.log 2>&1
pip install --upgrade google-cloud-storage >> install.log 2>&1

In [3]:


IS_COLAB = True
try:
    import google.colab
    import torch
    import os

    IS_COLAB = True
    os.chdir("/content/memit")
    if not torch.cuda.is_available():
        raise Exception("Change runtime type to include a GPU.")
except ModuleNotFoundError as _:
    pass

## Causal Tracing

In [4]:
%load_ext autoreload
%autoreload 2

In [5]:
!pip install transformers
!pip install datasets
!pip install accelerate

Collecting datasets
  Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Using cached dill-0.3.7-py3-none-any.whl (115 kB)
Collecting multiprocess (from datasets)
  Using cached multiprocess-0.70.15-py310-none-any.whl (134 kB)
Installing collected packages: pyarrow-hotfix, dill, multiprocess, datasets
Successfully installed datasets-2.15.0 dill-0.3.7 multiprocess-0.70.15 pyarrow-hotfix-0.6
Collecting accelerate
  Downloading accelerate-0.24.1-py3-none-any.whl (261 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.4/261.4 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: accelerate
Successfully installed accelerate-0.24.1


In [6]:
import os, re, json
import matplotlib.pyplot as plt
import torch, numpy
from collections import defaultdict
from util import nethook
from util.globals import DATA_DIR
from experiments.causal_trace import (
    ModelAndTokenizer,
    layername,
    guess_subject,
    plot_trace_heatmap,
)
from experiments.causal_trace import (
    make_inputs,
    decode_tokens,
    find_token_range,
    predict_token,
    predict_from_input,
    collect_embedding_std,
)
from dsets import KnownsDataset

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7b208821b910>

In [7]:
model_name = "gpt2-xl"  
mt = ModelAndTokenizer(
    model_name,
    low_cpu_mem_usage=IS_COLAB,
    torch_dtype=(torch.float16 if "20b" in model_name else None),
)

config.json:   0%|          | 0.00/689 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/6.43G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [9]:
predict_token(
      mt,
      ['The capital of Egypt is'],
      return_p=True,
    )

([' Cairo'], tensor([0.3710], device='cuda:0'))

In [10]:
with open('/content/correct_capitals_predicted.txt', 'r') as file:
    sentences = [line.strip() for line in file]

c = 0
unique_sentence = []
for s in sentences:
  s1 = ' '.join(s.split()[:-1])
  ans = predict_token(
      mt,
      [s1],
      return_p=True,
    )[0][0]
  print(s1 + ans)

Kabul is the capital of Afghanistan
Tirana is the capital of Albania
Algiers is the capital of Algeria
Pago Pago is the capital of American Samoa
Luanda is the capital of Angola
St. John's is the capital of Antigua And Barb
Buenos Aires is the capital of Argentina
Yerevan is the capital of Armenia
Canberra is the capital of Australia
Vienna is the capital of Austria
Baku is the capital of Azerbaijan
Dhaka is the capital of Bangladesh
Brussels is the capital of Belgium
Sarajevo is the capital of Bosnia and Herz
Brasilia is the capital of Brazil
Diego Garcia is the capital of British Indian Ocean Territory
Sofia is the capital of Bulgaria
Ouagadougou is the capital of Burkina Fas
Phnom Penh is the capital of Cambodia
Ottawa is the capital of Canada
Praia is the capital of Cape Ver
George Town is the capital of Cayman Islands
Bangui is the capital of Central African Republic
Beijing is the capital of China
Flying Fish Cove is the capital of Christmas Island
West Island is the capital of C

In [12]:
knowns = KnownsDataset(DATA_DIR)  # Dataset of known facts
noise_level = 3 * collect_embedding_std(mt, [k["subject"] for k in knowns])
print(f"Using noise level {noise_level}")


data/known_1000.json does not exist. Downloading from https://memit.baulab.info/data/dsets/known_1000.json


100%|██████████| 335k/335k [00:00<00:00, 776kB/s]


Loaded dataset with 1209 elements
Using noise level 0.13462981581687927


In [13]:
knowns.data

[{'known_id': 0,
  'subject': 'Vinson Massif',
  'attribute': 'Antarctica',
  'template': '{} is located in the continent',
  'prediction': ' of Antarctica. It is the largest of the three',
  'prompt': 'Vinson Massif is located in the continent of',
  'relation_id': 'P30'},
 {'known_id': 1,
  'subject': 'Beats Music',
  'attribute': 'Apple',
  'template': '{} is owned by',
  'prediction': ' Apple, which is also the owner of Beats Electronics',
  'prompt': 'Beats Music is owned by',
  'relation_id': 'P127'},
 {'known_id': 2,
  'subject': 'Audible.com',
  'attribute': 'Amazon',
  'template': '{} is owned by',
  'prediction': ' Amazon.com, Inc. or its affiliates.',
  'prompt': 'Audible.com is owned by',
  'relation_id': 'P127'},
 {'known_id': 3,
  'subject': 'The Big Bang Theory',
  'attribute': 'CBS',
  'template': '{} premieres on',
  'prediction': ' CBS on September 22.<|endoftext|>',
  'prompt': 'The Big Bang Theory premieres on',
  'relation_id': 'P449'},
 {'known_id': 4,
  'subject'

In [14]:
def extract_prompts_by_relation(knowns, relation_id):
    prompts = [known['prompt'] for known in knowns if known['relation_id'] == relation_id]
    return prompts


relation_id = 'P112'  # Change this to the relation_id you want to extract prompts for
prompts = extract_prompts_by_relation(knowns.data, relation_id)
print(len(prompts))

0


In [15]:
def trace_with_patch(
    model,  # The model
    inp,  # A set of inputs
    states_to_patch,  # A list of (token index, layername) triples to restore
    answers_t,  # Answer probabilities to collect
    tokens_to_mix,  # Range of tokens to corrupt (begin, end)
    noise=0.1,  # Level of noise to add
    trace_layers=None,  # List of traced outputs to return
):
    prng = numpy.random.RandomState(1)  # For reproducibility, use pseudorandom noise
    patch_spec = defaultdict(list)
    for t, l in states_to_patch:
        patch_spec[l].append(t)
    embed_layername = layername(model, 0, "embed")

    def untuple(x):
        return x[0] if isinstance(x, tuple) else x

    # Define the model-patching rule.
    def patch_rep(x, layer):
        if layer == embed_layername:
            # If requested, we corrupt a range of token embeddings on batch items x[1:]
            if tokens_to_mix is not None:
                b, e = tokens_to_mix
                x[1:, b:e] += noise * torch.from_numpy(
                    prng.randn(x.shape[0] - 1, e - b, x.shape[2])
                ).to(x.device)
            return x
        if layer not in patch_spec:
            return x
        # If this layer is in the patch_spec, restore the uncorrupted hidden state
        # for selected tokens.
        h = untuple(x)
        for t in patch_spec[layer]:
            h[1:, t] = h[0, t]
        return x

    # With the patching rules defined, run the patched model in inference.
    additional_layers = [] if trace_layers is None else trace_layers
    with torch.no_grad(), nethook.TraceDict(
        model,
        [embed_layername] + list(patch_spec.keys()) + additional_layers,
        edit_output=patch_rep,
    ) as td:
        outputs_exp = model(**inp)

    # We report softmax probabilities for the answers_t token predictions of interest.
    probs = torch.softmax(outputs_exp.logits[1:, -1, :], dim=1).mean(dim=0)[answers_t]

    # If tracing all layers, collect all activations together to return.
    if trace_layers is not None:
        all_traced = torch.stack(
            [untuple(td[layer].output).detach().cpu() for layer in trace_layers], dim=2
        )
        return probs, all_traced

    return probs

In [18]:
def calculate_hidden_flow(
    mt, prompt, subject, samples=10, noise=0.1, window=10, kind=None
):
    """
    Runs causal tracing over every token/layer combination in the network
    and returns a dictionary numerically summarizing the results.
    """
    inp = make_inputs(mt.tokenizer, [prompt] * (samples + 1))
    with torch.no_grad():
        answer_t, base_score = [d[0] for d in predict_from_input(mt.model, inp)]
    [answer] = decode_tokens(mt.tokenizer, [answer_t])
    e_range = find_token_range(mt.tokenizer, inp["input_ids"][0], subject)
    low_score = trace_with_patch(
        mt.model, inp, [], answer_t, e_range, noise=noise
    ).item()
    subj_prob = 0

    differences, subj_prob = trace_important_window(
        mt.model,
        mt.num_layers,
        inp,
        e_range,
        answer_t,
        noise=noise,
        window=window,
        kind=kind,
    )
    differences = differences.detach().cpu()
   
    return (dict(
    scores=differences,
    low_score=low_score,
    high_score=base_score,
    input_ids=inp["input_ids"][0],
    input_tokens=decode_tokens(mt.tokenizer, inp["input_ids"][0]),
    subject_range=e_range,
    answer=answer,
    window=window,
    kind=kind or "",

), subj_prob)



def trace_important_states(model, num_layers, inp, e_range, answer_t, noise=0.1):
    ntoks = inp["input_ids"].shape[1]
    table = []
    for tnum in range(ntoks):
        row = []
        for layer in range(0, num_layers):
            r = trace_with_patch(
                model,
                inp,
                [(tnum, layername(model, layer))],
                answer_t,
                tokens_to_mix=e_range,
                noise=noise,
            )


            row.append(r)

        table.append(torch.stack(row))

    return torch.stack(table)


def trace_important_window(
    model, num_layers, inp, e_range, answer_t, kind, window=10, noise=0.1
):
    ntoks = inp["input_ids"].shape[1]
    table = []
    for tnum in range(ntoks):
        row = []
        for layer in range(0, num_layers):
            layerlist = [
                (tnum, layername(model, L, kind))
                for L in range(
                    max(0, layer - window // 2), min(num_layers, layer - (-window // 2))
                )
            ]
            r = trace_with_patch(
                model, inp, layerlist, answer_t, tokens_to_mix=e_range, noise=noise
            )

            row.append(r)
        table.append(torch.stack(row))
    
    return torch.stack(table), torch.stack(table)[e_range[1] - 1]

## Code to plot the results

In [19]:
def plot_hidden_flow(
    mt,
    prompt,
    subject=None,
    samples=10,
    noise=0.1,
    window=10,
    kind=None,
    modelname=None,
    savepdf=None,
):

    if subject is None:
        subject = guess_subject(prompt)

    result, subj_prob = calculate_hidden_flow(
      mt, prompt, subject, samples=samples, noise=noise, window=window, kind=kind
    )

    return subj_prob




def plot_all_flow(mt, prompt, subject=None, noise=0.1, modelname=None):
    subj_prob = 0
    for kind in ["mlp"]: #[None, "mlp", "attn"]:
        subj_prob = plot_hidden_flow(
            mt, prompt, subject, modelname=modelname, noise=noise, kind=kind
        )
    return subj_prob



In [None]:
with open('/content/correct_capitals_predicted.txt', 'r') as file:
    sentences = [line.strip() for line in file]
import pandas as pd
c = 0
unique_sentence = []
#subj_prob_list = []
df = pd.DataFrame({})
for s in sentences:

  prompt = s.split('is the capital of')[0] + 'is the capital of'
  print(prompt)
  val = plot_all_flow(mt, prompt, noise=noise_level)
  #print(val)
  temp_df = pd.DataFrame(val.cpu().numpy()).transpose()
  df = pd.concat([df, temp_df])
  df.to_csv('df_causal_trace_capitals.csv')



Kabul is the capital of
Tirana is the capital of
Algiers is the capital of


In [None]:
df.index = range(len(df))
df

In [None]:
# df.to_csv('df_causal_trace_capitals.csv')