<a href="https://colab.research.google.com/github/kmeng01/rome/blob/main/notebooks/causal_trace.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

In [1]:
IS_COLAB = False

## Causal Tracing

A demonstration of the double-intervention causal tracing method.

The strategy used by causal tracing is to identify important
states within a transformer by performing two interventions simultaneously:

1. Corrupt a subset of the input.  In our paper, we corrupt the subject tokens
   to frustrate the ability of the transformer to accurately complete factual
   prompts about the subject.
2. Restore a subset of the internal hidden states.  In our paper, we scan
   hidden states at all layers and all tokens, searching for individual states
   that carry the necessary information for the transformer to recover its
   capability to complete the factual prompt.

The traces of decisive states can be shown on a heatmap.  This notebook
demonstrates the code for conducting causal traces and creating these heatmaps.
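
The two interventions can be sketched on a toy model. The block below is a minimal illustration, not the code in `experiments.causal_trace`: the scalar "transformer" and the names `run_toy_model` and `causal_trace` are hypothetical, chosen only so the corrupt-then-restore loop runs without a real model.

```python
import math
import random


def run_toy_model(embeddings, n_layers=3, patch=None):
    """Run a toy causal model over scalar token states.

    `patch` maps (layer, token) -> value and overwrites the hidden state at
    that position (layer 0 is the embedding layer). Returns all hidden states.
    """
    patch = patch or {}
    states = [[patch.get((0, t), x) for t, x in enumerate(embeddings)]]
    for layer in range(1, n_layers + 1):
        prev = states[-1]
        cur = []
        for t in range(len(prev)):
            context = sum(prev[: t + 1]) / (t + 1)  # causal mean over the prefix
            value = math.tanh(prev[t] + 0.5 * context)
            cur.append(patch.get((layer, t), value))
        states.append(cur)
    return states


def causal_trace(embeddings, subject_range, noise=1.0, n_layers=3, seed=0):
    """Corrupt the subject embeddings, then restore one clean state at a time."""
    rng = random.Random(seed)
    clean = run_toy_model(embeddings, n_layers)

    # Intervention 1: add noise to the subject tokens.
    corrupted_emb = list(embeddings)
    for t in range(*subject_range):
        corrupted_emb[t] += rng.gauss(0.0, noise)
    corrupted = run_toy_model(corrupted_emb, n_layers)

    clean_out, corrupt_out = clean[-1][-1], corrupted[-1][-1]
    damage = abs(corrupt_out - clean_out)

    # Intervention 2: restore a single clean hidden state and measure recovery.
    effects = {}
    for layer in range(n_layers + 1):
        for t in range(len(embeddings)):
            restored = run_toy_model(
                corrupted_emb, n_layers, patch={(layer, t): clean[layer][t]}
            )
            effects[(layer, t)] = damage - abs(restored[-1][-1] - clean_out)
    return clean_out, corrupt_out, effects


clean_out, corrupt_out, effects = causal_trace([0.2, -0.4, 0.9, 0.1], (1, 2))
```

Each entry of `effects` says how much of the corruption damage is undone by restoring that one state; laid out on a (layer, token) grid, these values are what a heatmap would display.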

In [2]:
%load_ext autoreload
%autoreload 2

The `experiments.causal_trace` module contains a set of functions for running causal traces.

In this notebook, we reproduce, demonstrate and discuss the interesting functions.

We begin by importing several utility functions that deal with tokens and transformer models.

In [3]:
import os, sys, re, json
import string
import torch
import numpy as np
import copy
from collections import defaultdict, Counter
from util import nethook
from util.globals import DATA_DIR
from experiments.causal_trace import (
    ModelAndTokenizer,
    layername,
    guess_subject,
    plot_trace_heatmap,
)
from experiments.causal_trace import (
    make_inputs,
    decode_tokens,
    find_token_range,
    predict_token,
    predict_from_input,
    collect_embedding_std,
)
from dsets import KnownsDataset

In [4]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fd1ec0d1bb0>

In [5]:
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

## USKG

In [20]:
from transformers import (
    HfArgumentParser,
    set_seed,
    AutoTokenizer
)
from transformers import AutoModelForPreTraining

import uskg
# from uskg.models.unified.prefixtuning import Model
from uskg.models.unified import finetune, prefixtuning
from uskg.utils.configue import Configure
from uskg.utils.training_arguments import WrappedSeq2SeqTrainingArguments
from uskg.seq2seq_construction import spider as s2s_spider
from uskg.third_party.spider.preprocess.get_tables import dump_db_json_schema
from uskg.third_party.spider import evaluation as sp_eval

from uskg.utils.dataset import gpt2_construct_input

from tqdm.notebook import tqdm


# from nltk.stem.wordnet import WordNetLemmatizer
# import stanza

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

import matplotlib.pyplot as plt
import sqlite3
import ujson
import pickle

# import experiments
from experiments import causal_trace_uskg as ctu
from experiments import causal_trace_uskg_gpt2 as ctu2

# import util

In [8]:
# Attempted to re-import the modules to fix the class mismatch problem; it did not help.

# import importlib

# importlib.reload(uskg.models.prompt.modeling_auto)
# importlib.reload(uskg.models.prompt.modeling_gpt2)
# importlib.reload(uskg.models.unified)
# importlib.reload(experiments)
# importlib.reload(ctu)
# importlib.reload(util)
# importlib.reload(util.uskg)
# importlib.reload(uskg.utils)

# from uskg.models.unified import finetune, prefixtuning

# from uskg.models.prompt.modeling_auto import AutoModelForPreTraining
# from uskg.models.prompt.modeling_gpt2 import GPT2LMHeadModel
# from uskg.models.prompt.modeling_t5 import T5ForConditionalGeneration

# from experiments import causal_trace_uskg as ctu

In [9]:
from util import uskg as uu
from util import uskg_gpt2 as uu2

### Loading model

In [10]:
# uskg_gpt2_model, tokenizer_uskg, tokenizer_fast, training_args, model_args, task_args = ctu.load_model_uskg('gpt2-prefix', untie_embeddings=False)
mt_uskg_gpt2 = ctu.ModelAndTokenizer_USKG_GPT2('gpt2-medium-prefix')

Using tokenizer_uskg: /home/yshao/Projects/UnifiedSKG/output/A-GPT2_medium_prefix_spider_with_cell_value-pfx=20/run-20231113/checkpoint-23500
Using tokenizer_fast: gpt2-medium
gpt2-medium
prefix-tuning sequence length is 20.


In [None]:
mt_uskg_gpt2.model

In [None]:
# [k for k, v in mt_uskg_gpt2.model.named_modules()]

### Loading dataset

In [12]:
spider_train_path = '/home/yshao/Projects/SDR-analysis/data/spider/train+ratsql_graph.json'
spider_dev_path = '/home/yshao/Projects/SDR-analysis/data/spider/dev+ratsql_graph.json'
spider_db_dir = '/home/yshao/Projects/language/language/xsp/data/spider/database'

In [13]:
raw_spider_dev = ctu.load_raw_dataset(
    data_filepath = spider_dev_path,
    db_path=spider_db_dir,
#     schema_cache=SCHEMA_CACHE
)
len(raw_spider_dev)

1034

In [14]:
raw_spider_dev[0].keys()

dict_keys(['query', 'question', 'db_id', 'db_path', 'db_table_names', 'db_column_names', 'db_column_types', 'db_primary_keys', 'db_foreign_keys', 'rat_sql_graph'])

In [15]:
mt_uskg_gpt2.task_args.dataset.use_cache

True

In [16]:
processed_spider_dev = s2s_spider.DevDataset(
    args=mt_uskg_gpt2.task_args,
    raw_datasets=raw_spider_dev,
    cache_root='../cache')

In [17]:
_id = 130
processed_spider_dev[_id]['text_in'], \
processed_spider_dev[_id]['struct_in'], \
processed_spider_dev[_id]['seq_out']

('What are the names of all European countries with at least 3 manufacturers?',
 '| car_1 | continents : contid , continent ( europe ) | countries : countryid , countryname , continent | car_makers : id , maker , fullname , country | model_list : modelid , maker , model | car_names : makeid , model , make | cars_data : id , mpg , cylinders , edispl , horsepower , weight , accelerate , year',
 "select t1.countryname from countries as t1 join continents as t2 on t1.continent = t2.contid join car_makers as t3 on t1.countryid = t3.country where t2.continent = 'europe' group by t1.countryname having count(*) >= 3;")

In [18]:
_enc_sentence = f"{processed_spider_dev[_id]['text_in']}; structed knowledge: {processed_spider_dev[_id]['struct_in']}"
_toks = mt_uskg_gpt2.tokenizer.tokenize(_enc_sentence)
len(_toks)

102

### Utils

#### Evaluator

In [24]:
table_path = '/home/yshao/Projects/language/language/xsp/data/spider/tables.json'
db_dir = '/home/yshao/Projects/language/language/xsp/data/spider/database'

In [25]:
kmaps = sp_eval.build_foreign_key_map_from_json(table_path)
evaluator = sp_eval.Evaluator(db_dir=db_dir, kmaps=kmaps, etype='all')

In [26]:
ctu.evaluate_hardness.evaluator = evaluator

## Testing

### find_text_struct_in_range_gpt2()

In [80]:
_id = 130
ex = processed_spider_dev[_id]
seq_in = f"{ex['text_in']}; structed knowledge: {ex['struct_in']}"
seq_out = ex['seq_out']
_test_seq_input = gpt2_construct_input(seq_in, seq_out, mt_uskg_gpt2.tokenizer,
                                       in_maxlen=362,
                                       out_maxlen=128,
                                       padding=False)

In [81]:
_test_seq = mt_uskg_gpt2.tokenizer.decode(_test_seq_input['input_ids'])
_test_seq

"What are the names of all European countries with at least 3 manufacturers?; structed knowledge: | car_1 | continents : contid, continent ( europe ) | countries : countryid, countryname, continent | car_makers : id, maker, fullname, country | model_list : modelid, maker, model | car_names : makeid, model, make | cars_data : id, mpg, cylinders, edispl, horsepower, weight, accelerate, year ; SQL: select t1.countryname from countries as t1 join continents as t2 on t1.continent = t2.contid join car_makers as t3 on t1.countryid = t3.country where t2.continent = 'europe' group by t1.countryname having count(*) >= 3; END OF SQL"

In [82]:
_toks = uu.decode_tokens(mt_uskg_gpt2.tokenizer, _test_seq_input['input_ids'])
print(_toks)

['What', ' are', ' the', ' names', ' of', ' all', ' European', ' countries', ' with', ' at', ' least', ' 3', ' manufacturers', '?', ';', ' struct', 'ed', ' knowledge', ':', ' |', ' car', '_', '1', ' |', ' continents', ' :', ' cont', 'id', ',', ' continent', ' (', ' euro', 'pe', ' )', ' |', ' countries', ' :', ' country', 'id', ',', ' country', 'name', ',', ' continent', ' |', ' car', '_', 'makers', ' :', ' id', ',', ' maker', ',', ' full', 'name', ',', ' country', ' |', ' model', '_', 'list', ' :', ' model', 'id', ',', ' maker', ',', ' model', ' |', ' car', '_', 'names', ' :', ' make', 'id', ',', ' model', ',', ' make', ' |', ' cars', '_', 'data', ' :', ' id', ',', ' m', 'pg', ',', ' cylinders', ',', ' ed', 'is', 'pl', ',', ' horsepower', ',', ' weight', ',', ' accelerate', ',', ' year', ' ;', ' SQL', ':', ' select', ' t', '1', '.', 'country', 'name', ' from', ' countries', ' as', ' t', '1', ' join', ' continents', ' as', ' t', '2', ' on', ' t', '1', '.', 'cont', 'inent', ' =', ' t', '

In [83]:
text_range, struct_range, out_range = uu2.find_text_struct_in_range_gpt2(mt_uskg_gpt2.tokenizer, _test_seq_input['input_ids'])

In [84]:
text_range, struct_range, out_range

((0, 14), (19, 102), (105, 176))

In [85]:
text_st, text_ed = text_range
struct_st, struct_ed = struct_range
out_st, out_ed = out_range

''.join(_toks[text_st : text_ed]), \
''.join(_toks[struct_st : struct_ed]), \
''.join(_toks[out_st : out_ed])

('What are the names of all European countries with at least 3 manufacturers?',
 ' | car_1 | continents : contid, continent ( europe ) | countries : countryid, countryname, continent | car_makers : id, maker, fullname, country | model_list : modelid, maker, model | car_names : makeid, model, make | cars_data : id, mpg, cylinders, edispl, horsepower, weight, accelerate, year',
 " select t1.countryname from countries as t1 join continents as t2 on t1.continent = t2.contid join car_makers as t3 on t1.countryid = t3.country where t2.continent = 'europe' group by t1.countryname having count(*) >= 3;")

In [86]:
''.join(_toks[36 : 40])

' : countryid,'
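
The ranges above line up with the separator tokens in the decoded sequence: the text ends where `;` + ` structed knowledge:` begins, and the struct section ends at ` ; SQL:`. A minimal sketch of that idea on plain token strings follows. It is not the implementation of `uu2.find_text_struct_in_range_gpt2`; the helper names are hypothetical, and the split of the trailing `END OF SQL` marker into three tokens is an assumption.

```python
def find_subsequence(tokens, pattern, start=0):
    """Return the first index >= start where `pattern` occurs in `tokens`, or -1."""
    n = len(pattern)
    for i in range(start, len(tokens) - n + 1):
        if tokens[i : i + n] == pattern:
            return i
    return -1


def find_section_ranges(tokens):
    """Locate (text, struct, sql) token ranges from the separator tokens."""
    struct_sep = [";", " struct", "ed", " knowledge", ":"]
    sql_sep = [" ;", " SQL", ":"]
    end_marker = [" END", " OF", " SQL"]  # assumed tokenization of the end marker

    struct_sep_at = find_subsequence(tokens, struct_sep)
    if struct_sep_at < 0:
        raise ValueError("struct separator not found")
    struct_start = struct_sep_at + len(struct_sep)

    sql_sep_at = find_subsequence(tokens, sql_sep, struct_start)
    if sql_sep_at < 0:
        raise ValueError("SQL separator not found")
    sql_start = sql_sep_at + len(sql_sep)

    end_at = find_subsequence(tokens, end_marker, sql_start)
    sql_end = end_at if end_at >= 0 else len(tokens)

    return (0, struct_sep_at), (struct_start, sql_sep_at), (sql_start, sql_end)


toy_tokens = ["How", " many", "?", ";", " struct", "ed", " knowledge", ":",
              " |", " db", " |", " t", " :", " a",
              " ;", " SQL", ":", " select", " a", " END", " OF", " SQL"]
toy_text_range, toy_struct_range, toy_out_range = find_section_ranges(toy_tokens)
```

On the real example the same arithmetic holds: the struct separator starts at token 14 and spans 5 tokens, giving a struct start of 19, and the SQL separator at 102 spans 3 tokens, giving an output start of 105.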

### create_analysis_sample_dicts() (& syntax)

In [215]:
_id = 222
ex = copy.deepcopy(processed_spider_dev[_id])

In [216]:
ex['text_in'], ex['struct_in'], ex['seq_out']

('Which city has the most frequent destination airport?',
 '| flight_2 | airlines : uid , airline , abbreviation , country | airports : city , airportcode , airportname , country , countryabbrev | flights : airline , flightno , sourceairport , destairport',
 'select t1.city from airports as t1 join flights as t2 on t1.airportcode = t2.destairport group by t1.city order by count(*) desc limit 1')

In [217]:
a_ex_list = ctu2.create_analysis_sample_dicts_gpt2(mt_uskg_gpt2, ex, subject_type='column')

In [218]:
a_ex_list[0].keys()

dict_keys(['query', 'question', 'db_id', 'db_path', 'db_table_names', 'db_column_names', 'db_column_types', 'db_primary_keys', 'db_foreign_keys', 'rat_sql_graph', 'serialized_schema', 'struct_in', 'text_in', 'seq_out', 'enc_sentence', 'tokenized_item', 'pre_sql_sequence', 'text_range', 'struct_range', 'sql_range', 'parsed_struct_in', 'alias2table', 'col2table', 'col_name_counter', 'tab_name_counter', 'struct_node_ranges_dict', 'sql_tokens', 'sql_token_ranges', 'tok_ranges2type', 'type2tok_ranges', 'sql_col_nodes', 'sql_tab_nodes', 'sql_alias_nodes', 'dec_prompt', 'expect', 'expect_type', 'remove_struct_duplicate_nodes', 'token_ranges_dict', 'node_name_ranges', 'expect_input_ranges', 'self_ranges', 'context_ranges', 'category'])

In [219]:
a_ex_list[0]['alias2table']

{'t1': 'airports', 't2': 'flights'}

In [220]:
a_ex_list[0]['type2tok_ranges']

defaultdict(list,
            {'syntax': [(0, 6),
              (15, 19),
              (29, 31),
              (35, 39),
              (48, 50),
              (54, 56),
              (72, 73),
              (89, 94),
              (95, 97),
              (106, 111),
              (112, 114),
              (115, 120),
              (120, 121),
              (121, 122),
              (122, 123),
              (124, 128),
              (129, 134)],
             'table_alias': [(7, 10),
              (32, 34),
              (51, 53),
              (57, 60),
              (74, 77),
              (98, 101)],
             'column': [(10, 14), (60, 71), (77, 88), (101, 105)],
             'table': [(20, 28), (40, 47)],
             'val': [(135, 136)]})

In [221]:
_toks = mt_uskg_gpt2.tokenizer.tokenize(a_ex_list[0]['seq_out'])
len(_toks)

43

In [222]:
_toks[7:10]

['Ġas', 'Ġt', '1']

In [223]:
a_ex_list[0]['seq_out'][7:10]

't1.'
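
The two cells above suggest that the ranges in `type2tok_ranges` index characters of `seq_out` rather than GPT-2 tokens: `seq_out[7:10]` is the alias `'t1.'`, while `_toks[7:10]` is unrelated. The sketch below shows one way such character ranges could be produced. It is an illustration under that assumption, not the parser used by `ctu2`; the regex, the function name, and the hand-written schema sets are all hypothetical, and quoted string values are not handled.

```python
import re
from collections import defaultdict

# An identifier with an optional trailing dot (alias prefix), a number,
# or any single non-space character.
TOKEN_RE = re.compile(r"[A-Za-z_]\w*\.?|\d+|\S")


def sql_char_ranges_by_type(sql, tables, columns, aliases):
    """Group the character spans of SQL tokens by a coarse token type."""
    ranges = defaultdict(list)
    for m in TOKEN_RE.finditer(sql):
        tok = m.group()
        name = tok.rstrip(".")
        if name in aliases:
            kind = "table_alias"
        elif name in tables:
            kind = "table"
        elif name in columns:
            kind = "column"
        elif tok.isdigit():
            kind = "val"
        else:
            kind = "syntax"
        ranges[kind].append(m.span())
    return ranges


sql = ("select t1.city from airports as t1 join flights as t2 "
       "on t1.airportcode = t2.destairport group by t1.city "
       "order by count(*) desc limit 1")
sql_ranges = sql_char_ranges_by_type(
    sql,
    tables={"airlines", "airports", "flights"},
    columns={"city", "airportcode", "destairport", "airline", "country"},
    aliases={"t1", "t2"},
)
```

For this query the sketch yields the same spans as the `type2tok_ranges` output shown earlier, for example `(7, 10)` for the first alias and `(135, 136)` for the value `1`.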

In [29]:
a_ex_list = ctu2.create_analysis_sample_dicts_gpt2(mt_uskg_gpt2, ex, subject_type='table')

In [None]:
a_ex_list[0]

In [31]:
a_ex_list = ctu2.create_syntax_analysis_sample_dicts_gpt2(mt_uskg_gpt2, ex)

In [None]:
a_ex_list[0]

### add_clean_prediction()

In [43]:
_id = 130
ex = copy.deepcopy(processed_spider_dev[_id])
a_ex_list = ctu2.create_analysis_sample_dicts_gpt2(mt_uskg_gpt2, ex, subject_type='column')
a_ex = a_ex_list[0]

In [44]:
a_ex['dec_prompt'], a_ex['expect']

('What are the names of all European countries with at least 3 manufacturers?; structed knowledge: | car_1 | continents : contid, continent ( europe ) | countries : countryid, countryname, continent | car_makers : id, maker, fullname, country | model_list : modelid, maker, model | car_names : makeid, model, make | cars_data : id, mpg, cylinders, edispl, horsepower, weight, accelerate, year ; SQL: select t1.',
 'countryname')

In [45]:
a_ex = ctu2.add_clean_prediction_gpt2(mt_uskg_gpt2, a_ex, samples=2)

In [46]:
a_ex['answer_len'], \
a_ex['base_score'], \
a_ex['answers_t'], \
a_ex['answer'], \
a_ex['correct_prediction']

(2,
 0.9999053478240967,
 tensor([19315,  3672], device='cuda:0'),
 'countryname',
 True)
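
Here the expected answer spans two tokens (`answer_len` is 2), so the score has to summarize more than one decoding step. The notebook does not show how `add_clean_prediction_gpt2` computes `base_score`; the sketch below is a hypothetical teacher-forced scorer that reports both the minimum and the product of the per-token probabilities, using made-up logits over a three-word vocabulary.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def score_answer(step_logits, answer_ids):
    """Score a multi-token answer.

    step_logits[i] holds the vocabulary logits at the step that should emit
    answer_ids[i] (teacher forcing: earlier answer tokens are fed as input).
    """
    probs = [softmax(lg)[tok] for lg, tok in zip(step_logits, answer_ids)]
    predicted = [max(range(len(lg)), key=lg.__getitem__) for lg in step_logits]
    return {
        "token_probs": probs,
        "min_prob": min(probs),
        "joint_prob": math.prod(probs),
        "correct_prediction": predicted == list(answer_ids),
    }


toy_logits = [[0.0, 5.0, 0.0], [4.0, 0.0, 0.0]]
toy_score = score_answer(toy_logits, [1, 0])
toy_wrong = score_answer(toy_logits, [0, 0])
```

Either summary is close to 1 only when every answer token is predicted confidently, which is the situation a clean run needs before corruption experiments are meaningful.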

### find_struct_name_ranges()

In [154]:
_id = 130
ex = copy.deepcopy(processed_spider_dev[_id])
ctu2.add_basic_analysis_info_gpt2(mt_uskg_gpt2, ex)

In [155]:
ex.keys()

dict_keys(['query', 'question', 'db_id', 'db_path', 'db_table_names', 'db_column_names', 'db_column_types', 'db_primary_keys', 'db_foreign_keys', 'rat_sql_graph', 'serialized_schema', 'struct_in', 'text_in', 'seq_out', 'enc_sentence', 'tokenized_item', 'pre_sql_sequence', 'text_range', 'struct_range', 'sql_range', 'parsed_struct_in', 'alias2table', 'col2table', 'col_name_counter', 'tab_name_counter', 'struct_node_ranges_dict', 'sql_tokens', 'sql_token_ranges', 'tok_ranges2type', 'type2tok_ranges', 'sql_col_nodes', 'sql_tab_nodes', 'sql_alias_nodes'])

In [156]:
struct_node_ranges_dict = uu2.find_struct_name_ranges_gpt2(mt_uskg_gpt2.tokenizer, ex)

In [157]:
struct_node_ranges_dict

{'db_id_ranges': defaultdict(list, {'car_1': [(20, 23)]}),
 'table_name_ranges': defaultdict(list,
             {'continents': [(24, 25)],
              'countries': [(35, 36)],
              'car_makers': [(45, 48)],
              'model_list': [(58, 61)],
              'car_names': [(69, 72)],
              'cars_data': [(80, 83)]}),
 'col_name_ranges': defaultdict(list,
             {'contid': [(26, 28)],
              'continent': [(29, 34), (43, 44)],
              'countryid': [(37, 39)],
              'countryname': [(40, 42)],
              'id': [(49, 50), (84, 85)],
              'maker': [(51, 52), (65, 66)],
              'fullname': [(53, 55)],
              'country': [(56, 57)],
              'modelid': [(62, 64)],
              'model': [(67, 68), (76, 77)],
              'makeid': [(73, 75)],
              'make': [(78, 79)],
              'mpg': [(86, 88)],
              'cylinders': [(89, 90)],
              'edispl': [(91, 94)],
              'horsepower': [(95, 96)

In [158]:
_toks = mt_uskg_gpt2.tokenizer.convert_ids_to_tokens(ex['tokenized_item']['input_ids'])
print(_toks)

['What', 'Ġare', 'Ġthe', 'Ġnames', 'Ġof', 'Ġall', 'ĠEuropean', 'Ġcountries', 'Ġwith', 'Ġat', 'Ġleast', 'Ġ3', 'Ġmanufacturers', '?', ';', 'Ġstruct', 'ed', 'Ġknowledge', ':', 'Ġ|', 'Ġcar', '_', '1', 'Ġ|', 'Ġcontinents', 'Ġ:', 'Ġcont', 'id', 'Ġ,', 'Ġcontinent', 'Ġ(', 'Ġeuro', 'pe', 'Ġ)', 'Ġ|', 'Ġcountries', 'Ġ:', 'Ġcountry', 'id', 'Ġ,', 'Ġcountry', 'name', 'Ġ,', 'Ġcontinent', 'Ġ|', 'Ġcar', '_', 'makers', 'Ġ:', 'Ġid', 'Ġ,', 'Ġmaker', 'Ġ,', 'Ġfull', 'name', 'Ġ,', 'Ġcountry', 'Ġ|', 'Ġmodel', '_', 'list', 'Ġ:', 'Ġmodel', 'id', 'Ġ,', 'Ġmaker', 'Ġ,', 'Ġmodel', 'Ġ|', 'Ġcar', '_', 'names', 'Ġ:', 'Ġmake', 'id', 'Ġ,', 'Ġmodel', 'Ġ,', 'Ġmake', 'Ġ|', 'Ġcars', '_', 'data', 'Ġ:', 'Ġid', 'Ġ,', 'Ġm', 'pg', 'Ġ,', 'Ġcylinders', 'Ġ,', 'Ġed', 'is', 'pl', 'Ġ,', 'Ġhorsepower', 'Ġ,', 'Ġweight', 'Ġ,', 'Ġaccelerate', 'Ġ,', 'Ġyear', 'Ġ;', 'ĠSQL', ':', 'Ġselect', 'Ġt', '1', '.', 'country', 'name', 'Ġfrom', 'Ġcountries', 'Ġas', 'Ġt', '1', 'Ġjoin', 'Ġcontinents', 'Ġas', 'Ġt', '2', 'Ġon', 'Ġt', '1', '.', 'cont', 'inen

In [159]:
for k, v in struct_node_ranges_dict['col_name_ranges'].items():
    print(k)
    for s, e in v:
        print(_toks[s:e])
    print('-'*20)

contid
['Ġcont', 'id']
--------------------
continent
['Ġcontinent', 'Ġ(', 'Ġeuro', 'pe', 'Ġ)']
['Ġcontinent']
--------------------
countryid
['Ġcountry', 'id']
--------------------
countryname
['Ġcountry', 'name']
--------------------
id
['Ġid']
['Ġid']
--------------------
maker
['Ġmaker']
['Ġmaker']
--------------------
fullname
['Ġfull', 'name']
--------------------
country
['Ġcountry']
--------------------
modelid
['Ġmodel', 'id']
--------------------
model
['Ġmodel']
['Ġmodel']
--------------------
makeid
['Ġmake', 'id']
--------------------
make
['Ġmake']
--------------------
mpg
['Ġm', 'pg']
--------------------
cylinders
['Ġcylinders']
--------------------
edispl
['Ġed', 'is', 'pl']
--------------------
horsepower
['Ġhorsepower']
--------------------
weight
['Ġweight']
--------------------
accelerate
['Ġaccelerate']
--------------------
year
['Ġyear']
--------------------
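
The printout above is a manual check that each range covers the tokens of its node name. The same check could be automated; the helper below is a small hypothetical sketch (not part of `uu2`) that joins the tokens of each range, turns the byte-level BPE space marker `Ġ` back into a space, and reports ranges whose text does not start with the expected name.

```python
def readable(tokens):
    """Join byte-level BPE tokens into text; 'Ġ' encodes a leading space."""
    return "".join(tokens).replace("Ġ", " ").strip()


def check_name_ranges(tokens, name_ranges):
    """Return (name, range, text) for every range that does not match its name."""
    problems = []
    for name, ranges in name_ranges.items():
        for s, e in ranges:
            text = readable(tokens[s:e])
            if not text.startswith(name):
                problems.append((name, (s, e), text))
    return problems


# First tokens of the struct section, indexed from 0 for this toy check.
toy_struct_tokens = ["|", "Ġcar", "_", "1", "Ġ|", "Ġcontinents", "Ġ:", "Ġcont",
                     "id", "Ġ,", "Ġcontinent", "Ġ(", "Ġeuro", "pe", "Ġ)", "Ġ|"]
good_ranges = {"car_1": [(1, 4)], "continents": [(5, 6)],
               "contid": [(7, 9)], "continent": [(10, 15)]}
bad_ranges = {"contid": [(8, 10)]}

no_problems = check_name_ranges(toy_struct_tokens, good_ranges)
found_problems = check_name_ranges(toy_struct_tokens, bad_ranges)
```

A prefix match is used because a column range may also cover its cell values, as with `continent ( europe )` above.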


In [160]:
ex['struct_in'], ex['parsed_struct_in']

('| car_1 | continents : contid , continent ( europe ) | countries : countryid , countryname , continent | car_makers : id , maker , fullname , country | model_list : modelid , maker , model | car_names : makeid , model , make | cars_data : id , mpg , cylinders , edispl , horsepower , weight , accelerate , year',
 ((1, 'car_1', 'car_1'),
  [((3, 'continents', 'continents'),
    [[(5, 'contid', 'contid'), []],
     [(7, 'continent', 'continent ( europe )'), [(9, 'europe', 'europe')]]]),
   ((12, 'countries', 'countries'),
    [[(14, 'countryid', 'countryid'), []],
     [(16, 'countryname', 'countryname'), []],
     [(18, 'continent', 'continent'), []]]),
   ((20, 'car_makers', 'car_makers'),
    [[(22, 'id', 'id'), []],
     [(24, 'maker', 'maker'), []],
     [(26, 'fullname', 'fullname'), []],
     [(28, 'country', 'country'), []]]),
   ((30, 'model_list', 'model_list'),
    [[(32, 'modelid', 'modelid'), []],
     [(34, 'maker', 'maker'), []],
     [(36, 'model', 'model'), []]]),
   ((

### test exp2

In [168]:
from experiments.exp_2_gpt2 import trace_exp2_section_corrupt_restore_gpt2

In [190]:
_id = 10
ex = copy.deepcopy(processed_spider_dev[_id])
a_ex_list = ctu2.create_analysis_sample_dicts_gpt2(mt_uskg_gpt2, ex, subject_type='column')
len(a_ex_list)

2

In [191]:
a_ex = a_ex_list[0]
a_ex = ctu2.add_clean_prediction_gpt2(mt_uskg_gpt2, a_ex, samples=2)

a_ex['answer_len'], \
a_ex['base_score'], \
a_ex['answers_t'], \
a_ex['answer'], \
a_ex['correct_prediction']

(1, 0.9999260902404785, tensor([1499], device='cuda:0'), ' country', True)

In [192]:
result = trace_exp2_section_corrupt_restore_gpt2(mt_uskg_gpt2, a_ex)

In [193]:
a_ex['text_range'], a_ex['struct_range'], a_ex['sql_range'], a_ex['dec_prompt'], a_ex['expect']

((0, 11),
 (16, 97),
 (100, 112),
 'Show all countries and the number of singers in each country.; structed knowledge: | concert_singer | stadium : stadium_id, location, name, capacity, highest, lowest, average | singer : singer_id, name, country, song_name, song_release_year, age, is_male | concert : concert_id, concert_name, theme, stadium_id, year | singer_in_concert : concert_id, singer_id ; SQL: select',
 'country')

In [198]:
a_ex['self_ranges']

[(47, 50)]

In [195]:
dec_toks = mt_uskg_gpt2.tokenizer.tokenize(a_ex['dec_prompt'])
len(dec_toks)

101

In [200]:
dec_toks[a_ex['self_ranges'][0][0] : a_ex['self_ranges'][0][1]]

[',', 'Ġcountry', ',']

In [201]:
len(mt_uskg_gpt2.tokenizer.tokenize(a_ex['expect']))

1

In [202]:
result

{'enc_sentence': 'Show all countries and the number of singers in each country.; structed knowledge: | concert_singer | stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id',
 'seq_out': 'select country, count(*) from singer group by country',
 'dec_prompt': 'Show all countries and the number of singers in each country.; structed knowledge: | concert_singer | stadium : stadium_id, location, name, capacity, highest, lowest, average | singer : singer_id, name, country, song_name, song_release_year, age, is_male | concert : concert_id, concert_name, theme, stadium_id, year | singer_in_concert : concert_id, singer_id ; SQL: select',
 'db_id': 'concert_singer',
 'expect': 'country',
 'expect_type': 'column',
 'expect_input_ranges': [(48, 49)],
 'self_ranges': [(47, 50)],
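
In this result, `self_ranges` `(47, 50)` is `expect_input_ranges` `(48, 49)` widened by one token on each side, which is why `dec_toks[47:50]` includes the two surrounding commas. The notebook does not show how these ranges, or `context_ranges`, are derived; the helpers below are a hypothetical sketch that treats the context as everything in the struct section outside the self ranges.

```python
def expand_range(rng, margin, bounds):
    """Widen a (start, end) range by `margin` on each side, clipped to `bounds`."""
    s, e = rng
    lo, hi = bounds
    return (max(lo, s - margin), min(hi, e + margin))


def complement_ranges(ranges, bounds):
    """Return the parts of `bounds` not covered by any of `ranges`."""
    lo, hi = bounds
    out = []
    cursor = lo
    for s, e in sorted(ranges):
        if s > cursor:
            out.append((cursor, min(s, hi)))
        cursor = max(cursor, e)
    if cursor < hi:
        out.append((cursor, hi))
    return out


struct_bounds = (16, 97)  # a_ex['struct_range'] from the cell above
toy_self = [expand_range(r, 1, struct_bounds) for r in [(48, 49)]]
toy_context = complement_ranges(toy_self, struct_bounds)
```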


## Temp

In [10]:
gpt2_model = AutoModelForPreTraining.from_pretrained('gpt2')

In [9]:
gpt2_tokenizer = AutoTokenizer.from_pretrained('gpt2')

### Tokenizer

In [11]:
gpt2_tokenizer

PreTrainedTokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_len=1024, is_fast=True, padding_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'})

In [12]:
gpt2_tokenizer.eos_token_id

50256

In [13]:
_toks = gpt2_tokenizer.tokenize(' Select * from singer')
_toks

['ĠSelect', 'Ġ*', 'Ġfrom', 'Ġsinger']

In [18]:
_toks = gpt2_tokenizer.tokenize('x ; SQL: Select')
_toks

['x', 'Ġ;', 'ĠSQL', ':', 'ĠSelect']

In [19]:
gpt2_tokenizer.convert_tokens_to_string(_toks)

'x ; SQL: Select'

In [20]:
gpt2_tokenizer.convert_tokens_to_ids(_toks)

[87, 2162, 16363, 25, 9683]

In [30]:
_sql = "select role_code from project_staff where date_from > '2003-04-19 15:06:20' and date_to < '2016-03-15 00:33:18'"
_toks = gpt2_tokenizer.tokenize(_sql)

In [31]:
print(_toks)

['select', 'Ġrole', '_', 'code', 'Ġfrom', 'Ġproject', '_', 'staff', 'Ġwhere', 'Ġdate', '_', 'from', 'Ġ>', "Ġ'", '2003', '-', '04', '-', '19', 'Ġ15', ':', '06', ':', '20', "'", 'Ġand', 'Ġdate', '_', 'to', 'Ġ<', "Ġ'", '2016', '-', '03', '-', '15', 'Ġ00', ':', '33', ':', '18', "'"]


In [32]:
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

In [33]:
gpt2_tokenizer.convert_tokens_to_string(_toks)

"select role_code from project_staff where date_from > '2003-04-19 15:06:20' and date_to < '2016-03-15 00:33:18'"

In [34]:
len(_toks)

42

In [142]:
ex['parsed_struct_in']

((1, 'car_1', 'car_1'),
 [((3, 'continents', 'continents'),
   [[(5, 'contid', 'contid'), []],
    [(7, 'continent', 'continent ( europe )'), [(9, 'europe', 'europe')]]]),
  ((12, 'countries', 'countries'),
   [[(14, 'countryid', 'countryid'), []],
    [(16, 'countryname', 'countryname'), []],
    [(18, 'continent', 'continent'), []]]),
  ((20, 'car_makers', 'car_makers'),
   [[(22, 'id', 'id'), []],
    [(24, 'maker', 'maker'), []],
    [(26, 'fullname', 'fullname'), []],
    [(28, 'country', 'country'), []]]),
  ((30, 'model_list', 'model_list'),
   [[(32, 'modelid', 'modelid'), []],
    [(34, 'maker', 'maker'), []],
    [(36, 'model', 'model'), []]]),
  ((38, 'car_names', 'car_names'),
   [[(40, 'makeid', 'makeid'), []],
    [(42, 'model', 'model'), []],
    [(44, 'make', 'make'), []]]),
  ((46, 'cars_data', 'cars_data'),
   [[(48, 'id', 'id'), []],
    [(50, 'mpg', 'mpg'), []],
    [(52, 'cylinders', 'cylinders'), []],
    [(54, 'edispl', 'edispl'), []],
    [(56, 'horsepower', 'ho

In [136]:
# _sent = 'What are the names of all European countries with at least 3 manufacturers?; structed knowledge: | car_1 | continents : contid , continent ( europe ) | countries : countryid , countryname , continent | car_makers : id , maker , fullname , country | model_list : modelid , maker , model | car_names : makeid , model , make | cars_data : id , mpg , cylinders , edispl , horsepower , weight , accelerate , year'
_sent = '| car_1 | continents : contid , continent ( europe ) | countries : countryid , countryname , continent | car_makers : id , maker , fullname , country | model_list : modelid , maker , model | car_names : makeid , model , make | cars_data : id , mpg , cylinders , edispl , horsepower , weight , accelerate , year'
_sent

'| car_1 | continents : contid , continent ( europe ) | countries : countryid , countryname , continent | car_makers : id , maker , fullname , country | model_list : modelid , maker , model | car_names : makeid , model , make | cars_data : id , mpg , cylinders , edispl , horsepower , weight , accelerate , year'

In [137]:
_toks = mt_uskg_gpt2.tokenizer.tokenize(_sent)

In [138]:
tokenized_sent = mt_uskg_gpt2.tokenizer(_sent)

In [139]:
list(enumerate(_sent.split()))

[(0, '|'),
 (1, 'car_1'),
 (2, '|'),
 (3, 'continents'),
 (4, ':'),
 (5, 'contid'),
 (6, ','),
 (7, 'continent'),
 (8, '('),
 (9, 'europe'),
 (10, ')'),
 (11, '|'),
 (12, 'countries'),
 (13, ':'),
 (14, 'countryid'),
 (15, ','),
 (16, 'countryname'),
 (17, ','),
 (18, 'continent'),
 (19, '|'),
 (20, 'car_makers'),
 (21, ':'),
 (22, 'id'),
 (23, ','),
 (24, 'maker'),
 (25, ','),
 (26, 'fullname'),
 (27, ','),
 (28, 'country'),
 (29, '|'),
 (30, 'model_list'),
 (31, ':'),
 (32, 'modelid'),
 (33, ','),
 (34, 'maker'),
 (35, ','),
 (36, 'model'),
 (37, '|'),
 (38, 'car_names'),
 (39, ':'),
 (40, 'makeid'),
 (41, ','),
 (42, 'model'),
 (43, ','),
 (44, 'make'),
 (45, '|'),
 (46, 'cars_data'),
 (47, ':'),
 (48, 'id'),
 (49, ','),
 (50, 'mpg'),
 (51, ','),
 (52, 'cylinders'),
 (53, ','),
 (54, 'edispl'),
 (55, ','),
 (56, 'horsepower'),
 (57, ','),
 (58, 'weight'),
 (59, ','),
 (60, 'accelerate'),
 (61, ','),
 (62, 'year')]

In [140]:
list(enumerate(_toks))

[(0, '|'),
 (1, 'Ġcar'),
 (2, '_'),
 (3, '1'),
 (4, 'Ġ|'),
 (5, 'Ġcontinents'),
 (6, 'Ġ:'),
 (7, 'Ġcont'),
 (8, 'id'),
 (9, 'Ġ,'),
 (10, 'Ġcontinent'),
 (11, 'Ġ('),
 (12, 'Ġeuro'),
 (13, 'pe'),
 (14, 'Ġ)'),
 (15, 'Ġ|'),
 (16, 'Ġcountries'),
 (17, 'Ġ:'),
 (18, 'Ġcountry'),
 (19, 'id'),
 (20, 'Ġ,'),
 (21, 'Ġcountry'),
 (22, 'name'),
 (23, 'Ġ,'),
 (24, 'Ġcontinent'),
 (25, 'Ġ|'),
 (26, 'Ġcar'),
 (27, '_'),
 (28, 'makers'),
 (29, 'Ġ:'),
 (30, 'Ġid'),
 (31, 'Ġ,'),
 (32, 'Ġmaker'),
 (33, 'Ġ,'),
 (34, 'Ġfull'),
 (35, 'name'),
 (36, 'Ġ,'),
 (37, 'Ġcountry'),
 (38, 'Ġ|'),
 (39, 'Ġmodel'),
 (40, '_'),
 (41, 'list'),
 (42, 'Ġ:'),
 (43, 'Ġmodel'),
 (44, 'id'),
 (45, 'Ġ,'),
 (46, 'Ġmaker'),
 (47, 'Ġ,'),
 (48, 'Ġmodel'),
 (49, 'Ġ|'),
 (50, 'Ġcar'),
 (51, '_'),
 (52, 'names'),
 (53, 'Ġ:'),
 (54, 'Ġmake'),
 (55, 'id'),
 (56, 'Ġ,'),
 (57, 'Ġmodel'),
 (58, 'Ġ,'),
 (59, 'Ġmake'),
 (60, 'Ġ|'),
 (61, 'Ġcars'),
 (62, '_'),
 (63, 'data'),
 (64, 'Ġ:'),
 (65, 'Ġid'),
 (66, 'Ġ,'),
 (67, 'Ġm'),
 (68, 'pg'),
 (69

In [141]:
list([(i, mt_uskg_gpt2.tokenizer.decode(tokenized_sent['input_ids'][i])) for i in range(len(tokenized_sent['input_ids']))])

[(0, '|'),
 (1, ' car'),
 (2, '_'),
 (3, '1'),
 (4, ' |'),
 (5, ' continents'),
 (6, ' :'),
 (7, ' cont'),
 (8, 'id'),
 (9, ','),
 (10, ' continent'),
 (11, ' ('),
 (12, ' euro'),
 (13, 'pe'),
 (14, ' )'),
 (15, ' |'),
 (16, ' countries'),
 (17, ' :'),
 (18, ' country'),
 (19, 'id'),
 (20, ','),
 (21, ' country'),
 (22, 'name'),
 (23, ','),
 (24, ' continent'),
 (25, ' |'),
 (26, ' car'),
 (27, '_'),
 (28, 'makers'),
 (29, ' :'),
 (30, ' id'),
 (31, ','),
 (32, ' maker'),
 (33, ','),
 (34, ' full'),
 (35, 'name'),
 (36, ','),
 (37, ' country'),
 (38, ' |'),
 (39, ' model'),
 (40, '_'),
 (41, 'list'),
 (42, ' :'),
 (43, ' model'),
 (44, 'id'),
 (45, ','),
 (46, ' maker'),
 (47, ','),
 (48, ' model'),
 (49, ' |'),
 (50, ' car'),
 (51, '_'),
 (52, 'names'),
 (53, ' :'),
 (54, ' make'),
 (55, 'id'),
 (56, ','),
 (57, ' model'),
 (58, ','),
 (59, ' make'),
 (60, ' |'),
 (61, ' cars'),
 (62, '_'),
 (63, 'data'),
 (64, ' :'),
 (65, ' id'),
 (66, ','),
 (67, ' m'),
 (68, 'pg'),
 (69, ','),
 (7

In [129]:
tokenized_sent.word_to_tokens(12)

TokenSpan(start=12, end=13)

In [144]:
len(tokenized_sent.word_ids()), tokenized_sent.word_ids()

(83,
 [0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  7,
  8,
  9,
  10,
  11,
  11,
  12,
  13,
  14,
  15,
  16,
  16,
  17,
  18,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  45,
  46,
  47,
  48,
  48,
  49,
  50,
  51,
  52,
  53,
  54,
  55,
  56,
  57,
  58,
  59,
  60,
  60,
  61,
  62,
  63,
  64,
  64,
  64,
  65,
  66,
  67,
  68,
  69,
  70,
  71,
  72])

In [134]:
i = 0
_words = []
while True:
    try:
        s, e = tokenized_sent.word_to_chars(i)
        _words.append(_sent[s : e])
        i += 1
    except Exception:  # word_to_chars fails once i is past the last word
        break

In [135]:
_words

['What',
 ' are',
 ' the',
 ' names',
 ' of',
 ' all',
 ' European',
 ' countries',
 ' with',
 ' at',
 ' least',
 ' 3',
 ' manufacturers',
 '?;',
 ' structed',
 ' knowledge',
 ':',
 ' |',
 ' car',
 '_',
 '1',
 ' |',
 ' continents',
 ' :',
 ' contid',
 ' ,',
 ' continent',
 ' (',
 ' europe',
 ' )',
 ' |',
 ' countries',
 ' :',
 ' countryid',
 ' ,',
 ' countryname',
 ' ,',
 ' continent',
 ' |',
 ' car',
 '_',
 'makers',
 ' :',
 ' id',
 ' ,',
 ' maker',
 ' ,',
 ' fullname',
 ' ,',
 ' country',
 ' |',
 ' model',
 '_',
 'list',
 ' :',
 ' modelid',
 ' ,',
 ' maker',
 ' ,',
 ' model',
 ' |',
 ' car',
 '_',
 'names',
 ' :',
 ' makeid',
 ' ,',
 ' model',
 ' ,',
 ' make',
 ' |',
 ' cars',
 '_',
 'data',
 ' :',
 ' id',
 ' ,',
 ' mpg',
 ' ,',
 ' cylinders',
 ' ,',
 ' edispl',
 ' ,',
 ' horsepower',
 ' ,',
 ' weight',
 ' ,',
 ' accelerate',
 ' ,',
 ' year']

In [146]:
tokenized_sent.char_to_token(10)

5

In [147]:
_sent[8:12], _toks[5]

('| co', 'Ġcontinents')
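
The enumerations above show two different notions of "word". `parsed_struct_in` indexes whitespace-separated words (`contid` is word 5), whereas the fast tokenizer's `word_ids()` splits `car_1` into three words. One way to bridge the two, sketched below under the assumption that words are separated by single spaces, is to start a new word at every token that begins with the `Ġ` space marker. The function name is hypothetical.

```python
def whitespace_word_spans(tokens, offset=0):
    """Group byte-level BPE tokens into whitespace-delimited words.

    A token starting with 'Ġ' (an encoded leading space) opens a new word.
    Returns one (start, end) token span per word, end exclusive, shifted by
    `offset` (e.g. the position of the struct section in the full input).
    """
    spans = []
    start = 0
    for i, tok in enumerate(tokens):
        if i > 0 and tok.startswith("Ġ"):
            spans.append((start + offset, i + offset))
            start = i
    if tokens:
        spans.append((start + offset, len(tokens) + offset))
    return spans


struct_prefix_tokens = ["|", "Ġcar", "_", "1", "Ġ|", "Ġcontinents", "Ġ:", "Ġcont",
                        "id", "Ġ,", "Ġcontinent", "Ġ(", "Ġeuro", "pe", "Ġ)", "Ġ|"]
local_spans = whitespace_word_spans(struct_prefix_tokens)
shifted_spans = whitespace_word_spans(struct_prefix_tokens, offset=19)
```

With `offset=19`, the first token of the struct section in the full input, word 5 maps to `(26, 28)` and word 1 to `(20, 23)`, matching the `contid` and `car_1` entries of `struct_node_ranges_dict` shown earlier.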

### GPT2 Model

In [38]:
[k for k, v in gpt2_model.named_modules()]

['',
 'transformer',
 'transformer.wte',
 'transformer.wpe',
 'transformer.drop',
 'transformer.h',
 'transformer.h.0',
 'transformer.h.0.ln_1',
 'transformer.h.0.attn',
 'transformer.h.0.attn.c_attn',
 'transformer.h.0.attn.c_proj',
 'transformer.h.0.attn.attn_dropout',
 'transformer.h.0.attn.resid_dropout',
 'transformer.h.0.ln_2',
 'transformer.h.0.mlp',
 'transformer.h.0.mlp.c_fc',
 'transformer.h.0.mlp.c_proj',
 'transformer.h.0.mlp.dropout',
 'transformer.h.1',
 'transformer.h.1.ln_1',
 'transformer.h.1.attn',
 'transformer.h.1.attn.c_attn',
 'transformer.h.1.attn.c_proj',
 'transformer.h.1.attn.attn_dropout',
 'transformer.h.1.attn.resid_dropout',
 'transformer.h.1.ln_2',
 'transformer.h.1.mlp',
 'transformer.h.1.mlp.c_fc',
 'transformer.h.1.mlp.c_proj',
 'transformer.h.1.mlp.dropout',
 'transformer.h.2',
 'transformer.h.2.ln_1',
 'transformer.h.2.attn',
 'transformer.h.2.attn.c_attn',
 'transformer.h.2.attn.c_proj',
 'transformer.h.2.attn.attn_dropout',
 'transformer.h.2.attn.r
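
The module names follow a regular pattern, `transformer.h.<layer>` with optional `.attn` or `.mlp` suffixes, so the names needed to address a given layer can be built rather than looked up. The helper below is a hypothetical sketch in the spirit of the imported `layername` function, whose actual signature is not shown in this notebook; the layer count of 12 is taken from the `n_layer` value of the plain `gpt2` config.

```python
def gpt2_module_name(layer, kind=None):
    """Build a GPT-2 module name such as 'transformer.h.3.mlp'.

    `kind` selects a submodule of the block: None (whole block), 'attn' or 'mlp'.
    """
    if kind not in (None, "attn", "mlp"):
        raise ValueError(f"unknown module kind: {kind!r}")
    base = f"transformer.h.{layer}"
    return base if kind is None else f"{base}.{kind}"


N_LAYERS = 12  # n_layer for the plain 'gpt2' checkpoint
mlp_names = [gpt2_module_name(i, "mlp") for i in range(N_LAYERS)]
block_names = [gpt2_module_name(i) for i in range(N_LAYERS)]
```

Each generated name can be checked against `dict(gpt2_model.named_modules())` before it is used to attach a hook.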

### Temp debugging

In [21]:
import uskg

from transformers.models.auto.configuration_auto import AutoConfig

In [45]:
import importlib

importlib.reload(uskg.models.prompt.modeling_auto)

from uskg.models.unified import finetune, prefixtuning

from uskg.models.prompt.modeling_auto import AutoModelForPreTraining
from uskg.models.prompt.modeling_gpt2 import GPT2LMHeadModel
from uskg.models.prompt.modeling_t5 import T5ForConditionalGeneration

In [46]:
_m = AutoModelForPreTraining.from_pretrained('gpt2')

In [47]:
type(_m)

uskg.models.prompt.modeling_gpt2.GPT2LMHeadModel

In [48]:
id(type(_m))

94854136059536

In [49]:
isinstance(_m, uskg.models.prompt.modeling_gpt2.GPT2LMHeadModel)

True

In [50]:
isinstance(_m, GPT2LMHeadModel)

True

In [51]:
GPT2LMHeadModel is uskg.models.prompt.modeling_gpt2.GPT2LMHeadModel

True

In [37]:
id(uskg.models.prompt.modeling_gpt2.GPT2LMHeadModel)

94854135704752

In [15]:
type(_m) is uskg.models.prompt.modeling_gpt2.GPT2LMHeadModel

True

In [67]:
type(_m) is GPT2LMHeadModel

False

In [69]:
isinstance(_m, type(_m))

True

In [64]:
type(_m)

uskg.models.prompt.modeling_gpt2.GPT2LMHeadModel

In [65]:
_m.__class__

uskg.models.prompt.modeling_gpt2.GPT2LMHeadModel
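
Mixed results like these, where `isinstance` and `is` checks disagree across cells, may be caused by `importlib.reload`: reloading a module re-executes its source and creates new class objects, so instances and names bound before the reload still refer to the old classes. The self-contained demo below reproduces the effect with a throwaway module; `toy_model_mod` and `ToyModel` are made-up names, not part of `uskg`.

```python
import importlib
import os
import sys
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    # Write a tiny module to disk so it can be imported and reloaded.
    with open(os.path.join(tmp, "toy_model_mod.py"), "w") as f:
        f.write("class ToyModel:\n    pass\n")
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    try:
        import toy_model_mod

        OldToyModel = toy_model_mod.ToyModel
        old_instance = OldToyModel()

        importlib.reload(toy_model_mod)  # re-executes the module body
        NewToyModel = toy_model_mod.ToyModel

        same_class = OldToyModel is NewToyModel
        still_instance = isinstance(old_instance, NewToyModel)
        same_qualified_name = (
            OldToyModel.__module__ == NewToyModel.__module__
            and OldToyModel.__qualname__ == NewToyModel.__qualname__
        )
    finally:
        sys.path.remove(tmp)
        sys.modules.pop("toy_model_mod", None)
```

The old and new classes print identically because they share a qualified name, yet they are distinct objects, so comparing `id(...)` values as in the cells above is a reasonable way to tell them apart. Restarting the kernel avoids the mismatch entirely.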

In [87]:
_cfg = AutoConfig.from_pretrained('gpt2')

In [88]:
_cfg

GPT2Config {
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "gradient_checkpointing": false,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "resid_pdrop": 0.1,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.9.2",
  "use_cache": true,
  "vocab_size": 50257
}

In [None]:
## T5
_m2 = AutoModelForPreTraining.from_pretrained('t5-base')

In [83]:
isinstance(_m2, uskg.models.prompt.modeling_t5.T5ForConditionalGeneration)

True

In [84]:
isinstance(_m2, T5ForConditionalGeneration)

True