<a href="https://colab.research.google.com/github/kmeng01/rome/blob/main/notebooks/causal_trace.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

In [1]:
IS_COLAB = False

## Causal Tracing

A demonstration of the double-intervention causal tracing method.

The strategy used by causal tracing is to identify important
states within a transformer by performing two interventions simultaneously:

1. Corrupt a subset of the input.  In our paper, we corrupt the subject tokens
   to frustrate the ability of the transformer to accurately complete factual
   prompts about the subject.
2. Restore a subset of the internal hidden states.  In our paper, we scan
   hidden states at all layers and all tokens, searching for individual states
   that carry the necessary information for the transformer to recover its
   capability to complete the factual prompt.
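The two interventions can be illustrated on a toy model. The sketch below is a hypothetical NumPy stand-in. It is not the `experiments.causal_trace` implementation and not GPT-2. `ToyModel` is a tiny residual stack whose "attention" is a uniform average over the current and earlier tokens. `causal_trace` adds Gaussian noise to the subject embeddings (intervention 1), then restores one clean hidden state at a time (intervention 2) and records the probability of the clean run's top prediction. The noise scale of 3.0 is an arbitrary choice for this toy. The paper scales the noise relative to the embedding standard deviation, which is presumably what the imported `collect_embedding_std` is for.

```python
import numpy as np


def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()


class ToyModel:
    """Tiny residual stack with causal token mixing; a stand-in for a transformer, not GPT-2."""

    def __init__(self, n_layers=4, dim=8, vocab=10, seed=0):
        rng = np.random.default_rng(seed)
        self.E = rng.normal(size=(vocab, dim))  # token embeddings
        self.W = [rng.normal(scale=0.5, size=(dim, dim)) for _ in range(n_layers)]
        self.U = rng.normal(size=(dim, vocab))  # unembedding
        self.n_layers = n_layers

    def run(self, token_ids, noise=None, patch=None):
        """Return hidden states (n_layers + 1, n_tokens, dim) and next-token probabilities.

        noise: optional array added to the input embeddings (corruption).
        patch: optional {(layer, token): vector} of states to overwrite (restoration).
        Layer 0 is the embedding output; layer i is the output of block i.
        """
        patch = patch or {}

        def apply_patch(h, layer):
            for (l, t), v in patch.items():
                if l == layer:
                    h[t] = v  # overwrite this state; later layers compute from it
            return h

        h = self.E[token_ids].copy()
        if noise is not None:
            h = h + noise
        h = apply_patch(h, 0)
        states = [h.copy()]
        for layer, W in enumerate(self.W, start=1):
            # Causal mixing: each token sees the mean of itself and all earlier tokens.
            ctx = np.cumsum(h, axis=0) / np.arange(1, len(h) + 1)[:, None]
            h = apply_patch(h + np.tanh(ctx @ W), layer)
            states.append(h.copy())
        return np.stack(states), softmax(h[-1] @ self.U)


def causal_trace(model, token_ids, subject_range, noise_scale=3.0, seed=1):
    """Return (clean prob, corrupted prob, grid of restored probs) for the clean top prediction."""
    clean_states, clean_probs = model.run(token_ids)
    answer = int(np.argmax(clean_probs))

    # Intervention 1: corrupt only the subject tokens' embeddings.
    start, end = subject_range
    noise = np.zeros_like(clean_states[0])
    rng = np.random.default_rng(seed)
    noise[start:end] = rng.normal(scale=noise_scale, size=(end - start, noise.shape[1]))
    _, corrupt_probs = model.run(token_ids, noise=noise)

    # Intervention 2: in the corrupted run, restore a single clean state at (layer, token).
    scores = np.zeros((model.n_layers + 1, len(token_ids)))
    for layer in range(model.n_layers + 1):
        for tok in range(len(token_ids)):
            restore = {(layer, tok): clean_states[layer, tok]}
            _, p = model.run(token_ids, noise=noise, patch=restore)
            scores[layer, tok] = p[answer]
    return clean_probs[answer], corrupt_probs[answer], scores


model = ToyModel()
tokens = [3, 1, 4, 1, 5]  # arbitrary toy token ids; tokens 1..2 play the "subject"
high, low, grid = causal_trace(model, tokens, subject_range=(1, 3))
layer, tok = np.unravel_index(np.argmax(grid - low), grid.shape)
print(f"clean={high:.3f} corrupted={low:.3f} largest gain at layer {layer}, token {tok}")
```

In this toy, restoring the last token at the final layer recovers the clean probability exactly. Restoring the token before the subject changes nothing at any layer, because earlier positions never see the corrupted subject. The informative cells of `grid` lie between these extremes, and a heatmap of `grid` is the toy analogue of the plots produced in this notebook.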

The traces of decisive states can be shown on a heatmap.  This notebook
demonstrates the code for conducting causal traces and creating these heatmaps.

In [2]:
%load_ext autoreload
%autoreload 2

The `experiments.causal_trace` module contains a set of functions for running causal traces.

In this notebook, we reproduce, demonstrate, and discuss the most relevant of these functions.

We begin by importing several utility functions that deal with tokens and transformer models.

In [3]:
import os, sys, re, json
import string
import torch
import numpy as np
import copy
from collections import defaultdict, Counter
from util import nethook
from util.globals import DATA_DIR
from experiments.causal_trace import (
    ModelAndTokenizer,
    layername,
    guess_subject,
    plot_trace_heatmap,
)
from experiments.causal_trace import (
    make_inputs,
    decode_tokens,
    find_token_range,
    predict_token,
    predict_from_input,
    collect_embedding_std,
)
from dsets import KnownsDataset

In [4]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7ffa884fcbe0>

In [5]:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

## USKG

In [7]:
from transformers import (
    HfArgumentParser,
    set_seed,
    AutoTokenizer
)
from transformers import AutoModelForPreTraining

# from uskg.models.unified.prefixtuning import Model
from uskg.models.unified import finetune, prefixtuning
from uskg.utils.configue import Configure
from uskg.utils.training_arguments import WrappedSeq2SeqTrainingArguments
from uskg.seq2seq_construction import spider as s2s_spider
from uskg.third_party.spider.preprocess.get_tables import dump_db_json_schema
from uskg.third_party.spider import evaluation as sp_eval
from tqdm.notebook import tqdm

# from nltk.stem.wordnet import WordNetLemmatizer
# import stanza

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

import matplotlib.pyplot as plt
import sqlite3
import ujson
import pickle

from experiments import causal_trace_uskg as ctu

## Temp

In [8]:
gpt2_model = AutoModelForPreTraining.from_pretrained('gpt2')

In [10]:
gpt2_tokenizer = AutoTokenizer.from_pretrained('gpt2')

In [11]:
gpt2_tokenizer

PreTrainedTokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_len=1024, is_fast=True, padding_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'})

In [12]:
gpt2_tokenizer.eos_token_id

50256

In [18]:
_toks = gpt2_tokenizer.tokenize('x ; SQL: Select')
_toks

['x', 'Ġ;', 'ĠSQL', ':', 'ĠSelect']
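The `Ġ` prefix in these tokens is not part of the text. GPT-2's byte-level BPE maps every byte to a printable character, and the space byte (0x20) becomes `Ġ` (U+0120). The sketch below rebuilds that byte-to-character table in plain Python, following the construction of `bytes_to_unicode` in OpenAI's released GPT-2 encoder. `token_to_text` is a hypothetical helper for illustration. With the real tokenizer, `convert_tokens_to_string` (used in the next cell) performs this decoding.

```python
def bytes_to_unicode():
    """Byte -> printable-character table used by GPT-2's byte-level BPE."""
    # Bytes that are already printable map to themselves: '!'..'~', '¡'..'¬', '®'..'ÿ'.
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(161, 173)) + list(range(174, 256))
    cs = bs[:]
    n = 0
    for b in range(256):
        if b not in bs:
            # Remaining bytes (control chars, space, ...) are shifted to code points >= 256.
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return dict(zip(bs, map(chr, cs)))


BYTE_TO_CHAR = bytes_to_unicode()
CHAR_TO_BYTE = {c: b for b, c in BYTE_TO_CHAR.items()}


def token_to_text(token):
    """Map a BPE token string such as 'ĠSQL' back to the text it covers."""
    # A single token can hold a partial UTF-8 sequence, hence errors="replace".
    return bytes(CHAR_TO_BYTE[c] for c in token).decode("utf-8", errors="replace")


print(BYTE_TO_CHAR[ord(" ")])  # Ġ
print([token_to_text(t) for t in ["x", "Ġ;", "ĠSQL", ":", "ĠSelect"]])
# ['x', ' ;', ' SQL', ':', ' Select']
```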

In [19]:
gpt2_tokenizer.convert_tokens_to_string(_toks)

'x ; SQL: Select'

In [20]:
gpt2_tokenizer.convert_tokens_to_ids(_toks)

[87, 2162, 16363, 25, 9683]

In [30]:
_sql = "select role_code from project_staff where date_from > '2003-04-19 15:06:20' and date_to < '2016-03-15 00:33:18'"
_toks = gpt2_tokenizer.tokenize(_sql)

In [31]:
print(_toks)

['select', 'Ġrole', '_', 'code', 'Ġfrom', 'Ġproject', '_', 'staff', 'Ġwhere', 'Ġdate', '_', 'from', 'Ġ>', "Ġ'", '2003', '-', '04', '-', '19', 'Ġ15', ':', '06', ':', '20', "'", 'Ġand', 'Ġdate', '_', 'to', 'Ġ<', "Ġ'", '2016', '-', '03', '-', '15', 'Ġ00', ':', '33', ':', '18', "'"]
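Causal tracing corrupts a contiguous span of tokens, so a substring such as `project_staff` has to be mapped to the token indices that cover it. As the output above shows, it spans `'Ġproject', '_', 'staff'`. The imported `find_token_range` serves this purpose. The sketch below is a standalone illustration of the character-offset bookkeeping involved, not that function's actual code. `char_span_to_token_range` is a hypothetical name, and it takes already-decoded token strings.

```python
def char_span_to_token_range(token_texts, substring):
    """Return (start, end) token indices, end exclusive, covering the first occurrence of substring.

    token_texts: decoded token strings, e.g. ' project' rather than 'Ġproject'.
    """
    whole = "".join(token_texts)
    char_start = whole.index(substring)  # raises ValueError if the substring is absent
    char_end = char_start + len(substring)
    start = end = None
    pos = 0
    for i, text in enumerate(token_texts):
        pos += len(text)  # pos is now the character offset where token i ends
        if start is None and pos > char_start:
            start = i
        if pos >= char_end:
            end = i + 1
            break
    return start, end


toks = ["select", "Ġrole", "_", "code", "Ġfrom", "Ġproject", "_", "staff", "Ġwhere"]
texts = [t.replace("Ġ", " ") for t in toks]  # adequate for ASCII-only tokens like these
print(char_span_to_token_range(texts, "project_staff"))  # (5, 8)
```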


In [32]:
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
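GPT-2 ships without a pad token, so this cell reuses `<|endoftext|>` (id 50256) for padding. When several prompts are batched for a decoder-only model, padding on the left keeps each prompt's last real token at the final position, where the next-token prediction is read. The tokenizer above reports `padding_side='right'`, so calling it with `padding=True` would pad on the right unless that attribute is changed. The sketch below shows left padding on plain id lists. `left_pad` is a hypothetical helper, not part of `transformers` or `experiments.causal_trace`.

```python
def left_pad(token_lists, pad_id):
    """Left-pad token-id lists to equal length and build the matching attention mask."""
    maxlen = max(len(t) for t in token_lists)
    input_ids = [[pad_id] * (maxlen - len(t)) + t for t in token_lists]
    # 0 marks padding positions the model should ignore, 1 marks real tokens.
    attention_mask = [[0] * (maxlen - len(t)) + [1] * len(t) for t in token_lists]
    return input_ids, attention_mask


EOS_ID = 50256  # matches gpt2_tokenizer.eos_token_id printed above
ids, mask = left_pad([[87, 2162, 16363], [16363, 25]], pad_id=EOS_ID)
print(ids)   # [[87, 2162, 16363], [50256, 16363, 25]]
print(mask)  # [[1, 1, 1], [0, 1, 1]]
```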

In [33]:
gpt2_tokenizer.convert_tokens_to_string(_toks)

"select role_code from project_staff where date_from > '2003-04-19 15:06:20' and date_to < '2016-03-15 00:33:18'"

In [34]:
len(_toks)

42