<a href="https://colab.research.google.com/github/kmeng01/rome/blob/main/notebooks/causal_trace.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

In [1]:
IS_COLAB = False

## Causal Tracing

A demonstration of the double-intervention causal tracing method.

Causal tracing identifies important states within a transformer by
applying two interventions simultaneously:

1. Corrupt a subset of the input.  In our paper, we corrupt the subject tokens
   to frustrate the ability of the transformer to accurately complete factual
   prompts about the subject.
2. Restore a subset of the internal hidden states.  In our paper, we scan
   hidden states at all layers and all tokens, searching for individual states
   that carry the necessary information for the transformer to recover its
   capability to complete the factual prompt.

The traces of decisive states can be shown on a heatmap.  This notebook
demonstrates the code for conducting causal traces and creating these heatmaps.
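
The two interventions can be illustrated on a toy model before touching a real transformer. The sketch below is not the `experiments.causal_trace` implementation: `run_toy_model` and `causal_trace` are hypothetical names, the "network" is a stack of scalar averaging layers that only look one token to the left, and the recovery score (how much of the output damage disappears) is an assumption chosen for readability.

```python
import random


def run_toy_model(inputs, n_layers=3, patch=None):
    """Run a toy causal network over scalar token states.

    `patch` maps (layer, token) -> value and overwrites that hidden state
    as soon as it is computed; layer 0 is the input embedding.
    Returns (hidden states per layer, output read at the last token).
    """
    patch = patch or {}
    hidden = [[patch.get((0, t), x) for t, x in enumerate(inputs)]]
    for layer in range(1, n_layers + 1):
        prev = hidden[-1]
        cur = []
        for t in range(len(prev)):
            left = prev[t - 1] if t > 0 else 0.0  # causal: look left only
            h = 0.5 * prev[t] + 0.5 * left
            cur.append(patch.get((layer, t), h))
        hidden.append(cur)
    return hidden, hidden[-1][-1]


def causal_trace(clean_inputs, subject_range, noise=3.0, n_layers=3, seed=0):
    """Corrupt the subject tokens, then restore one clean state at a time."""
    rng = random.Random(seed)
    clean_hidden, clean_out = run_toy_model(clean_inputs, n_layers)

    # Intervention 1: corrupt the subject tokens with Gaussian noise.
    corrupted = list(clean_inputs)
    for t in range(*subject_range):
        corrupted[t] += rng.gauss(0.0, noise)
    _, corrupt_out = run_toy_model(corrupted, n_layers)
    damage = abs(clean_out - corrupt_out)

    # Intervention 2: restore a single clean hidden state and re-run.
    recovered = {}
    for layer in range(n_layers + 1):
        for t in range(len(clean_inputs)):
            patch = {(layer, t): clean_hidden[layer][t]}
            _, out = run_toy_model(corrupted, n_layers, patch)
            recovered[(layer, t)] = damage - abs(clean_out - out)
    return damage, recovered


damage, recovered = causal_trace([1.0, 2.0, 3.0, 4.0], subject_range=(0, 1))
```

In this toy setting, only states that lie on a path from the corrupted token to the output can recover any of the damage; plotting `recovered` over (layer, token) gives the kind of heatmap described above.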

In [2]:
%load_ext autoreload
%autoreload 2

The `experiments.causal_trace` module contains a set of functions for running causal traces.

In this notebook, we reproduce, demonstrate, and discuss the most interesting of these functions.

We begin by importing several utility functions that deal with tokens and transformer models.

In [3]:
import os, sys, re, json
import string
import torch
import numpy as np
import copy
from collections import defaultdict, Counter
from util import nethook
from util.globals import DATA_DIR
from experiments.causal_trace import (
    ModelAndTokenizer,
    layername,
    guess_subject,
    plot_trace_heatmap,
)
from experiments.causal_trace import (
    make_inputs,
    decode_tokens,
    find_token_range,
    predict_token,
    predict_from_input,
    collect_embedding_std,
)
from dsets import KnownsDataset

In [4]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fd1ec0d1bb0>

In [5]:
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

## USKG

In [20]:
from transformers import (
    HfArgumentParser,
    set_seed,
    AutoTokenizer
)
from transformers import AutoModelForPreTraining

import uskg
# from uskg.models.unified.prefixtuning import Model
from uskg.models.unified import finetune, prefixtuning
from uskg.utils.configue import Configure
from uskg.utils.training_arguments import WrappedSeq2SeqTrainingArguments
from uskg.seq2seq_construction import spider as s2s_spider
from uskg.third_party.spider.preprocess.get_tables import dump_db_json_schema
from uskg.third_party.spider import evaluation as sp_eval

from uskg.utils.dataset import gpt2_construct_input

from tqdm.notebook import tqdm


# from nltk.stem.wordnet import WordNetLemmatizer
# import stanza

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

import matplotlib.pyplot as plt
import sqlite3
import ujson
import pickle

# import experiments
from experiments import causal_trace_uskg as ctu
from experiments import causal_trace_uskg_gpt2 as ctu2

# import util

In [8]:
# Attempted to re-import modules to fix the class mismatch problem; this did not work.

# import importlib

# importlib.reload(uskg.models.prompt.modeling_auto)
# importlib.reload(uskg.models.prompt.modeling_gpt2)
# importlib.reload(uskg.models.unified)
# importlib.reload(experiments)
# importlib.reload(ctu)
# importlib.reload(util)
# importlib.reload(util.uskg)
# importlib.reload(uskg.utils)

# from uskg.models.unified import finetune, prefixtuning

# from uskg.models.prompt.modeling_auto import AutoModelForPreTraining
# from uskg.models.prompt.modeling_gpt2 import GPT2LMHeadModel
# from uskg.models.prompt.modeling_t5 import T5ForConditionalGeneration

# from experiments import causal_trace_uskg as ctu

In [9]:
from util import uskg as uu
from util import uskg_gpt2 as uu2

### Loading model

In [10]:
# uskg_gpt2_model, tokenizer_uskg, tokenizer_fast, training_args, model_args, task_args = ctu.load_model_uskg('gpt2-prefix', untie_embeddings=False)
mt_uskg_gpt2 = ctu.ModelAndTokenizer_USKG_GPT2('gpt2-medium-prefix')

Using tokenizer_uskg: /home/yshao/Projects/UnifiedSKG/output/A-GPT2_medium_prefix_spider_with_cell_value-pfx=20/run-20231113/checkpoint-23500
Using tokenizer_fast: gpt2-medium
gpt2-medium
prefix-tuning sequence length is 20.


In [None]:
mt_uskg_gpt2.model

In [None]:
# [k for k, v in mt_uskg_gpt2.model.named_modules()]

### Loading dataset

In [12]:
spider_train_path = '/home/yshao/Projects/SDR-analysis/data/spider/train+ratsql_graph.json'
spider_dev_path = '/home/yshao/Projects/SDR-analysis/data/spider/dev+ratsql_graph.json'
spider_db_dir = '/home/yshao/Projects/language/language/xsp/data/spider/database'

In [13]:
raw_spider_dev = ctu.load_raw_dataset(
    data_filepath = spider_dev_path,
    db_path=spider_db_dir,
#     schema_cache=SCHEMA_CACHE
)
len(raw_spider_dev)

1034

In [14]:
raw_spider_dev[0].keys()

dict_keys(['query', 'question', 'db_id', 'db_path', 'db_table_names', 'db_column_names', 'db_column_types', 'db_primary_keys', 'db_foreign_keys', 'rat_sql_graph'])

In [15]:
mt_uskg_gpt2.task_args.dataset.use_cache

True

In [16]:
processed_spider_dev = s2s_spider.DevDataset(
    args=mt_uskg_gpt2.task_args,
    raw_datasets=raw_spider_dev,
    cache_root='../cache')

In [17]:
_id = 130
processed_spider_dev[_id]['text_in'], \
processed_spider_dev[_id]['struct_in'], \
processed_spider_dev[_id]['seq_out']

('What are the names of all European countries with at least 3 manufacturers?',
 '| car_1 | continents : contid , continent ( europe ) | countries : countryid , countryname , continent | car_makers : id , maker , fullname , country | model_list : modelid , maker , model | car_names : makeid , model , make | cars_data : id , mpg , cylinders , edispl , horsepower , weight , accelerate , year',
 "select t1.countryname from countries as t1 join continents as t2 on t1.continent = t2.contid join car_makers as t3 on t1.countryid = t3.country where t2.continent = 'europe' group by t1.countryname having count(*) >= 3;")

In [18]:
_enc_sentence = f"{processed_spider_dev[_id]['text_in']}; structed knowledge: {processed_spider_dev[_id]['struct_in']}"
_toks = mt_uskg_gpt2.tokenizer.tokenize(_enc_sentence)
len(_toks)

102

### Utils

#### Misc

In [227]:
def reverse_2D_dict(d):
    # default_factory must be callable, so wrap np.nan in a lambda
    out_d = defaultdict(lambda: defaultdict(lambda: np.nan))
    for k1, d1 in d.items():
        for k2, v in d1.items():
            out_d[k2][k1] = v
    return out_d

def format_print_1D_dict(d, sort_by=None, reverse=False, head_col_w=10, col_w=6):
    # sort_by: None, 'key', or 'value'
    
    item_l = list(d.items())
    if sort_by == 'key':
        item_l.sort(reverse=reverse)
    elif sort_by == 'value':
        item_l.sort(key=lambda x: (x[1], x[0]), reverse=reverse)
    
    decm_w = col_w - 2
    
    for k, v in item_l:
        print(f'{k:<{head_col_w}s}{v:.{decm_w}f}')

def format_print_2D_dict(d, 
                         all_k1=None, 
                         all_k2=None, 
                         sort_k1_kwargs=None, 
                         sort_k2_kwargs=None, 
                         head_col_w=12, 
                         col_w=6,
                         decm_w=4):
    if all_k1 is None:
        all_k1 = list(d.keys())
        if sort_k1_kwargs is not None:
            all_k1.sort(**sort_k1_kwargs)
    
    if all_k2 is None:
        for k1, d1 in d.items():
            d1_keys = list(d1.keys())
            if all_k2 is None:
                all_k2 = d1_keys
            else:
                if set(d1_keys) != set(all_k2):
                    print('Warning:\n', d1_keys, '\n', all_k2)
            # all_k2.update(list(d1.keys()))
        if sort_k2_kwargs is not None:
            all_k2.sort(**sort_k2_kwargs)
    
    print_str = '\t'.join(['X' * head_col_w] + [f'{k2:<{col_w}s}' for k2 in all_k2]) + '\n'
    
    for k1 in all_k1:
        d1 = d[k1]
        print_str += f'{k1:<{head_col_w}s}'
        for k2 in all_k2:
            v = d1[k2]
            print_str += f'\t{v:<{col_w}.{decm_w}f}'
        print_str += '\n'
    
    print(print_str)

#### Evaluator

In [24]:
table_path = '/home/yshao/Projects/language/language/xsp/data/spider/tables.json'
db_dir = '/home/yshao/Projects/language/language/xsp/data/spider/database'

In [25]:
kmaps = sp_eval.build_foreign_key_map_from_json(table_path)
evaluator = sp_eval.Evaluator(db_dir=db_dir, kmaps=kmaps, etype='all')

In [26]:
ctu.evaluate_hardness.evaluator = evaluator

## Testing

### find_text_struct_in_range_gpt2()

In [80]:
_id = 130
ex = processed_spider_dev[_id]
seq_in = f"{ex['text_in']}; structed knowledge: {ex['struct_in']}"
seq_out = ex['seq_out']
_test_seq_input = gpt2_construct_input(seq_in, seq_out, mt_uskg_gpt2.tokenizer,
                                       in_maxlen=362,
                                       out_maxlen=128,
                                       padding=False)

In [81]:
_test_seq = mt_uskg_gpt2.tokenizer.decode(_test_seq_input['input_ids'])
_test_seq

"What are the names of all European countries with at least 3 manufacturers?; structed knowledge: | car_1 | continents : contid, continent ( europe ) | countries : countryid, countryname, continent | car_makers : id, maker, fullname, country | model_list : modelid, maker, model | car_names : makeid, model, make | cars_data : id, mpg, cylinders, edispl, horsepower, weight, accelerate, year ; SQL: select t1.countryname from countries as t1 join continents as t2 on t1.continent = t2.contid join car_makers as t3 on t1.countryid = t3.country where t2.continent = 'europe' group by t1.countryname having count(*) >= 3; END OF SQL"

In [82]:
_toks = uu.decode_tokens(mt_uskg_gpt2.tokenizer, _test_seq_input['input_ids'])
print(_toks)

['What', ' are', ' the', ' names', ' of', ' all', ' European', ' countries', ' with', ' at', ' least', ' 3', ' manufacturers', '?', ';', ' struct', 'ed', ' knowledge', ':', ' |', ' car', '_', '1', ' |', ' continents', ' :', ' cont', 'id', ',', ' continent', ' (', ' euro', 'pe', ' )', ' |', ' countries', ' :', ' country', 'id', ',', ' country', 'name', ',', ' continent', ' |', ' car', '_', 'makers', ' :', ' id', ',', ' maker', ',', ' full', 'name', ',', ' country', ' |', ' model', '_', 'list', ' :', ' model', 'id', ',', ' maker', ',', ' model', ' |', ' car', '_', 'names', ' :', ' make', 'id', ',', ' model', ',', ' make', ' |', ' cars', '_', 'data', ' :', ' id', ',', ' m', 'pg', ',', ' cylinders', ',', ' ed', 'is', 'pl', ',', ' horsepower', ',', ' weight', ',', ' accelerate', ',', ' year', ' ;', ' SQL', ':', ' select', ' t', '1', '.', 'country', 'name', ' from', ' countries', ' as', ' t', '1', ' join', ' continents', ' as', ' t', '2', ' on', ' t', '1', '.', 'cont', 'inent', ' =', ' t', '

In [83]:
text_range, struct_range, out_range = uu2.find_text_struct_in_range_gpt2(mt_uskg_gpt2.tokenizer, _test_seq_input['input_ids'])

In [84]:
text_range, struct_range, out_range

((0, 14), (19, 102), (105, 176))

In [85]:
text_st, text_ed = text_range
struct_st, struct_ed = struct_range
out_st, out_ed = out_range

''.join(_toks[text_st : text_ed]), \
''.join(_toks[struct_st : struct_ed]), \
''.join(_toks[out_st : out_ed])

('What are the names of all European countries with at least 3 manufacturers?',
 ' | car_1 | continents : contid, continent ( europe ) | countries : countryid, countryname, continent | car_makers : id, maker, fullname, country | model_list : modelid, maker, model | car_names : makeid, model, make | cars_data : id, mpg, cylinders, edispl, horsepower, weight, accelerate, year',
 " select t1.countryname from countries as t1 join continents as t2 on t1.continent = t2.contid join car_makers as t3 on t1.countryid = t3.country where t2.continent = 'europe' group by t1.countryname having count(*) >= 3;")

In [86]:
''.join(_toks[36 : 40])

' : countryid,'
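
The ranges printed above are half-open token ranges: the question occupies tokens 0 to 13, the `; structed knowledge:` delimiter tokens 14 to 18, and so on. One way such ranges could be located is by searching the decoded token list for the delimiter token sequences. The sketch below is only an illustration under that assumption; `find_subsequence` and `find_section_ranges` are hypothetical helpers, not the `uu2.find_text_struct_in_range_gpt2` implementation, and the sketch does not strip a trailing `END OF SQL` marker.

```python
# Delimiter token sequences as they appear in the decoded token list above.
STRUCT_DELIM = [';', ' struct', 'ed', ' knowledge', ':']
SQL_DELIM = [' ;', ' SQL', ':']


def find_subsequence(tokens, pattern, start=0):
    """Return the first index >= start where `pattern` occurs, or -1."""
    for i in range(start, len(tokens) - len(pattern) + 1):
        if tokens[i:i + len(pattern)] == pattern:
            return i
    return -1


def find_section_ranges(tokens):
    """Split a token list into (text, struct, out) half-open ranges."""
    struct_delim_at = find_subsequence(tokens, STRUCT_DELIM)
    if struct_delim_at < 0:
        raise ValueError('struct delimiter not found')
    struct_start = struct_delim_at + len(STRUCT_DELIM)
    sql_delim_at = find_subsequence(tokens, SQL_DELIM, start=struct_start)
    if sql_delim_at < 0:
        raise ValueError('SQL delimiter not found')
    out_start = sql_delim_at + len(SQL_DELIM)
    return (0, struct_delim_at), (struct_start, sql_delim_at), (out_start, len(tokens))


toy_tokens = ['Show', ' names', '?', ';', ' struct', 'ed', ' knowledge', ':',
              ' |', ' db', ' |', ' t', ' :', ' a', ' ;', ' SQL', ':',
              ' select', ' a']
toy_text_range, toy_struct_range, toy_out_range = find_section_ranges(toy_tokens)
```

Because the ranges are half-open, `toy_tokens[start:end]` recovers each section directly, which is how the cells above reassemble the three strings.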

### create_analysis_sample_dicts() (& syntax)

In [215]:
_id = 222
ex = copy.deepcopy(processed_spider_dev[_id])

In [216]:
ex['text_in'], ex['struct_in'], ex['seq_out']

('Which city has the most frequent destination airport?',
 '| flight_2 | airlines : uid , airline , abbreviation , country | airports : city , airportcode , airportname , country , countryabbrev | flights : airline , flightno , sourceairport , destairport',
 'select t1.city from airports as t1 join flights as t2 on t1.airportcode = t2.destairport group by t1.city order by count(*) desc limit 1')

In [217]:
a_ex_list = ctu2.create_analysis_sample_dicts_gpt2(mt_uskg_gpt2, ex, subject_type='column')

In [218]:
a_ex_list[0].keys()

dict_keys(['query', 'question', 'db_id', 'db_path', 'db_table_names', 'db_column_names', 'db_column_types', 'db_primary_keys', 'db_foreign_keys', 'rat_sql_graph', 'serialized_schema', 'struct_in', 'text_in', 'seq_out', 'enc_sentence', 'tokenized_item', 'pre_sql_sequence', 'text_range', 'struct_range', 'sql_range', 'parsed_struct_in', 'alias2table', 'col2table', 'col_name_counter', 'tab_name_counter', 'struct_node_ranges_dict', 'sql_tokens', 'sql_token_ranges', 'tok_ranges2type', 'type2tok_ranges', 'sql_col_nodes', 'sql_tab_nodes', 'sql_alias_nodes', 'dec_prompt', 'expect', 'expect_type', 'remove_struct_duplicate_nodes', 'token_ranges_dict', 'node_name_ranges', 'expect_input_ranges', 'self_ranges', 'context_ranges', 'category'])

In [219]:
a_ex_list[0]['alias2table']

{'t1': 'airports', 't2': 'flights'}

In [220]:
a_ex_list[0]['type2tok_ranges']

defaultdict(list,
            {'syntax': [(0, 6),
              (15, 19),
              (29, 31),
              (35, 39),
              (48, 50),
              (54, 56),
              (72, 73),
              (89, 94),
              (95, 97),
              (106, 111),
              (112, 114),
              (115, 120),
              (120, 121),
              (121, 122),
              (122, 123),
              (124, 128),
              (129, 134)],
             'table_alias': [(7, 10),
              (32, 34),
              (51, 53),
              (57, 60),
              (74, 77),
              (98, 101)],
             'column': [(10, 14), (60, 71), (77, 88), (101, 105)],
             'table': [(20, 28), (40, 47)],
             'val': [(135, 136)]})

In [221]:
_toks = mt_uskg_gpt2.tokenizer.tokenize(a_ex_list[0]['seq_out'])
len(_toks)

43

In [222]:
_toks[7:10]

['Ġas', 'Ġt', '1']

In [223]:
a_ex_list[0]['seq_out'][7:10]

't1.'
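
The two cells above suggest that the ranges in `type2tok_ranges` index characters of `seq_out`, not GPT-2 tokens: slicing the token list with `(7, 10)` yields `['Ġas', 'Ġt', '1']`, while slicing the string yields `'t1.'`. A small check of this reading against a few of the ranges printed earlier is sketched below; `spans` is copied by hand from that output and `slice_spans` is a hypothetical helper, not part of `ctu2`.

```python
seq_out = ('select t1.city from airports as t1 join flights as t2 '
           'on t1.airportcode = t2.destairport group by t1.city '
           'order by count(*) desc limit 1')

# A subset of the ranges shown in the `type2tok_ranges` output above.
spans = {
    'syntax': [(0, 6), (15, 19)],
    'table_alias': [(7, 10)],
    'column': [(10, 14), (60, 71)],
    'table': [(20, 28), (40, 47)],
    'val': [(135, 136)],
}


def slice_spans(text, spans_by_type):
    """Map each node type to the substrings its character ranges select."""
    return {
        node_type: [text[start:end] for start, end in ranges]
        for node_type, ranges in spans_by_type.items()
    }


sliced = slice_spans(seq_out, spans)
for node_type, pieces in sliced.items():
    print(node_type, pieces)
```

Each slice lands on a whole SQL element (for example `'city'` and `'airportcode'` for `column`), which is consistent with the character-offset interpretation.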

In [29]:
a_ex_list = ctu2.create_analysis_sample_dicts_gpt2(mt_uskg_gpt2, ex, subject_type='table')

In [None]:
a_ex_list[0]

In [31]:
a_ex_list = ctu2.create_syntax_analysis_sample_dicts_gpt2(mt_uskg_gpt2, ex)

In [None]:
a_ex_list[0]

#### create_analysis_sample_dicts_all_nodes_gpt2

In [239]:
_id = 222
ex = copy.deepcopy(processed_spider_dev[_id])

In [240]:
a_ex = ctu2.create_analysis_sample_dicts_all_nodes_gpt2(mt_uskg_gpt2, ex)

In [243]:
for k, v in a_ex.items():
    if 'col' in k or 'tab' in k:
        print(k, ':', v)

db_table_names : ['airlines', 'airports', 'flights']
db_column_names : {'table_id': [-1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2], 'column_name': ['*', 'uid', 'Airline', 'Abbreviation', 'Country', 'City', 'AirportCode', 'AirportName', 'Country', 'CountryAbbrev', 'Airline', 'FlightNo', 'SourceAirport', 'DestAirport']}
db_column_types : ['text', 'number', 'text', 'text', 'text', 'text', 'text', 'text', 'text', 'text', 'number', 'number', 'text', 'text']
alias2table : {'t1': 'airports', 't2': 'flights'}
col2table : defaultdict(<class 'list'>, {'uid': ['airlines'], 'airline': ['airlines', 'flights'], 'abbreviation': ['airlines'], 'country': ['airlines', 'airports'], 'city': ['airports'], 'airportcode': ['airports'], 'airportname': ['airports'], 'countryabbrev': ['airports'], 'flightno': ['flights'], 'sourceairport': ['flights'], 'destairport': ['flights']})
col_name_counter : Counter({'airline': 2, 'country': 2, 'uid': 1, 'abbreviation': 1, 'city': 1, 'airportcode': 1, 'airportname': 1, 'cou

In [244]:
a_ex['seq_out']

'select t1.city from airports as t1 join flights as t2 on t1.airportcode = t2.destairport group by t1.city order by count(*) desc limit 1'

In [245]:
_toks = mt_uskg_gpt2.tokenizer.tokenize(a_ex_list[0]['enc_sentence'])
len(_toks)

62

In [248]:
for k, v in a_ex['col_self_ranges'].items():
    print(k, _toks[v[0][0]:v[0][1]])

airportcode ['Ġ,', 'Ġairport', 'code', 'Ġ,']
city ['Ġ:', 'Ġcity', 'Ġ,']
destairport ['Ġ,', 'Ġdest', 'air', 'port']
uid ['Ġ:', 'Ġu', 'id', 'Ġ,']
abbreviation ['Ġ,', 'Ġabbre', 'viation', 'Ġ,']
airportname ['Ġ,', 'Ġairport', 'name', 'Ġ,']
countryabbrev ['Ġ,', 'Ġcountry', 'ab', 'bre', 'v', 'Ġ|']
flightno ['Ġ,', 'Ġflight', 'no', 'Ġ,']
sourceairport ['Ġ,', 'Ġsource', 'air', 'port', 'Ġ,']


### add_clean_prediction()

In [43]:
_id = 130
ex = copy.deepcopy(processed_spider_dev[_id])
a_ex_list = ctu2.create_analysis_sample_dicts_gpt2(mt_uskg_gpt2, ex, subject_type='column')
a_ex = a_ex_list[0]

In [44]:
a_ex['dec_prompt'], a_ex['expect']

('What are the names of all European countries with at least 3 manufacturers?; structed knowledge: | car_1 | continents : contid, continent ( europe ) | countries : countryid, countryname, continent | car_makers : id, maker, fullname, country | model_list : modelid, maker, model | car_names : makeid, model, make | cars_data : id, mpg, cylinders, edispl, horsepower, weight, accelerate, year ; SQL: select t1.',
 'countryname')

In [45]:
a_ex = ctu2.add_clean_prediction_gpt2(mt_uskg_gpt2, a_ex, samples=2)

In [46]:
a_ex['answer_len'], \
a_ex['base_score'], \
a_ex['answers_t'], \
a_ex['answer'], \
a_ex['correct_prediction']

(2,
 0.9999053478240967,
 tensor([19315,  3672], device='cuda:0'),
 'countryname',
 True)

### find_struct_name_ranges()

In [154]:
_id = 130
ex = copy.deepcopy(processed_spider_dev[_id])
ctu2.add_basic_analysis_info_gpt2(mt_uskg_gpt2, ex)

In [155]:
ex.keys()

dict_keys(['query', 'question', 'db_id', 'db_path', 'db_table_names', 'db_column_names', 'db_column_types', 'db_primary_keys', 'db_foreign_keys', 'rat_sql_graph', 'serialized_schema', 'struct_in', 'text_in', 'seq_out', 'enc_sentence', 'tokenized_item', 'pre_sql_sequence', 'text_range', 'struct_range', 'sql_range', 'parsed_struct_in', 'alias2table', 'col2table', 'col_name_counter', 'tab_name_counter', 'struct_node_ranges_dict', 'sql_tokens', 'sql_token_ranges', 'tok_ranges2type', 'type2tok_ranges', 'sql_col_nodes', 'sql_tab_nodes', 'sql_alias_nodes'])

In [156]:
struct_node_ranges_dict = uu2.find_struct_name_ranges_gpt2(mt_uskg_gpt2.tokenizer, ex)

In [157]:
struct_node_ranges_dict

{'db_id_ranges': defaultdict(list, {'car_1': [(20, 23)]}),
 'table_name_ranges': defaultdict(list,
             {'continents': [(24, 25)],
              'countries': [(35, 36)],
              'car_makers': [(45, 48)],
              'model_list': [(58, 61)],
              'car_names': [(69, 72)],
              'cars_data': [(80, 83)]}),
 'col_name_ranges': defaultdict(list,
             {'contid': [(26, 28)],
              'continent': [(29, 34), (43, 44)],
              'countryid': [(37, 39)],
              'countryname': [(40, 42)],
              'id': [(49, 50), (84, 85)],
              'maker': [(51, 52), (65, 66)],
              'fullname': [(53, 55)],
              'country': [(56, 57)],
              'modelid': [(62, 64)],
              'model': [(67, 68), (76, 77)],
              'makeid': [(73, 75)],
              'make': [(78, 79)],
              'mpg': [(86, 88)],
              'cylinders': [(89, 90)],
              'edispl': [(91, 94)],
              'horsepower': [(95, 96)

In [158]:
_toks = mt_uskg_gpt2.tokenizer.convert_ids_to_tokens(ex['tokenized_item']['input_ids'])
print(_toks)

['What', 'Ġare', 'Ġthe', 'Ġnames', 'Ġof', 'Ġall', 'ĠEuropean', 'Ġcountries', 'Ġwith', 'Ġat', 'Ġleast', 'Ġ3', 'Ġmanufacturers', '?', ';', 'Ġstruct', 'ed', 'Ġknowledge', ':', 'Ġ|', 'Ġcar', '_', '1', 'Ġ|', 'Ġcontinents', 'Ġ:', 'Ġcont', 'id', 'Ġ,', 'Ġcontinent', 'Ġ(', 'Ġeuro', 'pe', 'Ġ)', 'Ġ|', 'Ġcountries', 'Ġ:', 'Ġcountry', 'id', 'Ġ,', 'Ġcountry', 'name', 'Ġ,', 'Ġcontinent', 'Ġ|', 'Ġcar', '_', 'makers', 'Ġ:', 'Ġid', 'Ġ,', 'Ġmaker', 'Ġ,', 'Ġfull', 'name', 'Ġ,', 'Ġcountry', 'Ġ|', 'Ġmodel', '_', 'list', 'Ġ:', 'Ġmodel', 'id', 'Ġ,', 'Ġmaker', 'Ġ,', 'Ġmodel', 'Ġ|', 'Ġcar', '_', 'names', 'Ġ:', 'Ġmake', 'id', 'Ġ,', 'Ġmodel', 'Ġ,', 'Ġmake', 'Ġ|', 'Ġcars', '_', 'data', 'Ġ:', 'Ġid', 'Ġ,', 'Ġm', 'pg', 'Ġ,', 'Ġcylinders', 'Ġ,', 'Ġed', 'is', 'pl', 'Ġ,', 'Ġhorsepower', 'Ġ,', 'Ġweight', 'Ġ,', 'Ġaccelerate', 'Ġ,', 'Ġyear', 'Ġ;', 'ĠSQL', ':', 'Ġselect', 'Ġt', '1', '.', 'country', 'name', 'Ġfrom', 'Ġcountries', 'Ġas', 'Ġt', '1', 'Ġjoin', 'Ġcontinents', 'Ġas', 'Ġt', '2', 'Ġon', 'Ġt', '1', '.', 'cont', 'inen

In [159]:
for k, v in struct_node_ranges_dict['col_name_ranges'].items():
    print(k)
    for s, e in v:
        print(_toks[s:e])
    print('-'*20)

contid
['Ġcont', 'id']
--------------------
continent
['Ġcontinent', 'Ġ(', 'Ġeuro', 'pe', 'Ġ)']
['Ġcontinent']
--------------------
countryid
['Ġcountry', 'id']
--------------------
countryname
['Ġcountry', 'name']
--------------------
id
['Ġid']
['Ġid']
--------------------
maker
['Ġmaker']
['Ġmaker']
--------------------
fullname
['Ġfull', 'name']
--------------------
country
['Ġcountry']
--------------------
modelid
['Ġmodel', 'id']
--------------------
model
['Ġmodel']
['Ġmodel']
--------------------
makeid
['Ġmake', 'id']
--------------------
make
['Ġmake']
--------------------
mpg
['Ġm', 'pg']
--------------------
cylinders
['Ġcylinders']
--------------------
edispl
['Ġed', 'is', 'pl']
--------------------
horsepower
['Ġhorsepower']
--------------------
weight
['Ġweight']
--------------------
accelerate
['Ġaccelerate']
--------------------
year
['Ġyear']
--------------------
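
As the listing shows, a schema name may span several GPT-2 sub-tokens, and a short name such as `continent` must not match inside `continents`. The sketch below shows one way whole-name token ranges could be found; it is an assumption-laden illustration, not the `uu2.find_struct_name_ranges_gpt2` implementation. `find_name_ranges` is a hypothetical helper that relies on the `Ġ` word-start marker, and unlike the output above it does not extend a column's range over attached cell values such as `( europe )`.

```python
def find_name_ranges(tokens, name, space_marker='Ġ'):
    """Return half-open token ranges whose joined text equals `name` exactly."""
    ranges = []
    n = len(tokens)
    for start in range(n):
        # A name can only begin at a token that starts a new word.
        if not tokens[start].startswith(space_marker):
            continue
        joined = tokens[start][len(space_marker):]
        end = start + 1
        # Absorb continuation sub-tokens (those without the word-start marker).
        while len(joined) < len(name) and end < n \
                and not tokens[end].startswith(space_marker):
            joined += tokens[end]
            end += 1
        # Require a word boundary right after the match.
        at_boundary = end == n or tokens[end].startswith(space_marker)
        if joined == name and at_boundary:
            ranges.append((start, end))
    return ranges


tokens = ['Ġ|', 'Ġcar', '_', '1', 'Ġ|', 'Ġcontinents', 'Ġ:', 'Ġcont', 'id',
          'Ġ,', 'Ġcontinent', 'Ġ(', 'Ġeuro', 'pe', 'Ġ)']
contid_ranges = find_name_ranges(tokens, 'contid')
continent_ranges = find_name_ranges(tokens, 'continent')
db_ranges = find_name_ranges(tokens, 'car_1')
```

Matching on whole words rather than substrings is what keeps `continent` from being found inside `Ġcontinents` or `Ġcont` + `id`.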


In [160]:
ex['struct_in'], ex['parsed_struct_in']

('| car_1 | continents : contid , continent ( europe ) | countries : countryid , countryname , continent | car_makers : id , maker , fullname , country | model_list : modelid , maker , model | car_names : makeid , model , make | cars_data : id , mpg , cylinders , edispl , horsepower , weight , accelerate , year',
 ((1, 'car_1', 'car_1'),
  [((3, 'continents', 'continents'),
    [[(5, 'contid', 'contid'), []],
     [(7, 'continent', 'continent ( europe )'), [(9, 'europe', 'europe')]]]),
   ((12, 'countries', 'countries'),
    [[(14, 'countryid', 'countryid'), []],
     [(16, 'countryname', 'countryname'), []],
     [(18, 'continent', 'continent'), []]]),
   ((20, 'car_makers', 'car_makers'),
    [[(22, 'id', 'id'), []],
     [(24, 'maker', 'maker'), []],
     [(26, 'fullname', 'fullname'), []],
     [(28, 'country', 'country'), []]]),
   ((30, 'model_list', 'model_list'),
    [[(32, 'modelid', 'modelid'), []],
     [(34, 'maker', 'maker'), []],
     [(36, 'model', 'model'), []]]),
   ((

### test exp2

In [168]:
from experiments.exp_2_gpt2 import trace_exp2_section_corrupt_restore_gpt2

In [190]:
_id = 10
ex = copy.deepcopy(processed_spider_dev[_id])
a_ex_list = ctu2.create_analysis_sample_dicts_gpt2(mt_uskg_gpt2, ex, subject_type='column')
len(a_ex_list)

2

In [191]:
a_ex = a_ex_list[0]
a_ex = ctu2.add_clean_prediction_gpt2(mt_uskg_gpt2, a_ex, samples=2)

a_ex['answer_len'], \
a_ex['base_score'], \
a_ex['answers_t'], \
a_ex['answer'], \
a_ex['correct_prediction']

(1, 0.9999260902404785, tensor([1499], device='cuda:0'), ' country', True)

In [192]:
result = trace_exp2_section_corrupt_restore_gpt2(mt_uskg_gpt2, a_ex)

In [193]:
a_ex['text_range'], a_ex['struct_range'], a_ex['sql_range'], a_ex['dec_prompt'], a_ex['expect']

((0, 11),
 (16, 97),
 (100, 112),
 'Show all countries and the number of singers in each country.; structed knowledge: | concert_singer | stadium : stadium_id, location, name, capacity, highest, lowest, average | singer : singer_id, name, country, song_name, song_release_year, age, is_male | concert : concert_id, concert_name, theme, stadium_id, year | singer_in_concert : concert_id, singer_id ; SQL: select',
 'country')

In [198]:
a_ex['self_ranges']

[(47, 50)]

In [195]:
dec_toks = mt_uskg_gpt2.tokenizer.tokenize(a_ex['dec_prompt'])
len(dec_toks)

101

In [200]:
dec_toks[a_ex['self_ranges'][0][0] : a_ex['self_ranges'][0][1]]

[',', 'Ġcountry', ',']

In [201]:
len(mt_uskg_gpt2.tokenizer.tokenize(a_ex['expect']))

1

In [202]:
result

{'enc_sentence': 'Show all countries and the number of singers in each country.; structed knowledge: | concert_singer | stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id',
 'seq_out': 'select country, count(*) from singer group by country',
 'dec_prompt': 'Show all countries and the number of singers in each country.; structed knowledge: | concert_singer | stadium : stadium_id, location, name, capacity, highest, lowest, average | singer : singer_id, name, country, song_name, song_release_year, age, is_male | concert : concert_id, concert_name, theme, stadium_id, year | singer_in_concert : concert_id, singer_id ; SQL: select',
 'db_id': 'concert_singer',
 'expect': 'country',
 'expect_type': 'column',
 'expect_input_ranges': [(48, 49)],
 'self_ranges': [(47, 50)],


### test exp4, exp4.1

In [271]:
from experiments.exp_4_gpt2 import trace_exp4_1_attention_patterns_gpt2

In [293]:
_id = 111
ex = copy.deepcopy(processed_spider_dev[_id])
a_ex = ctu2.create_analysis_sample_dicts_all_nodes_gpt2(mt_uskg_gpt2, ex)
a_ex.keys()

dict_keys(['query', 'question', 'db_id', 'db_path', 'db_table_names', 'db_column_names', 'db_column_types', 'db_primary_keys', 'db_foreign_keys', 'rat_sql_graph', 'serialized_schema', 'struct_in', 'text_in', 'seq_out', 'enc_sentence', 'tokenized_item', 'pre_sql_sequence', 'text_range', 'struct_range', 'sql_range', 'parsed_struct_in', 'alias2table', 'col2table', 'col_name_counter', 'tab_name_counter', 'struct_node_ranges_dict', 'sql_tokens', 'sql_token_ranges', 'tok_ranges2type', 'type2tok_ranges', 'sql_col_nodes', 'sql_tab_nodes', 'sql_alias_nodes', 'dec_prompt', 'occ_cols', 'occ_tabs', 'non_occ_cols', 'non_occ_tabs', 'col_self_ranges', 'col_context_ranges', 'tab_self_ranges', 'tab_context_ranges'])

In [294]:
a_ex['enc_sentence'], a_ex['seq_out']

('What is the accelerate of the car make amc hornet sportabout (sw)?; structed knowledge: | car_1 | continents : contid , continent | countries : countryid , countryname , continent | car_makers : id , maker ( amc ) , fullname , country | model_list : modelid , maker , model ( amc ) | car_names : makeid , model ( amc ) , make ( amc hornet , amc hornet sportabout (sw) ) | cars_data : id , mpg , cylinders , edispl , horsepower , weight , accelerate , year',
 "select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id = t2.makeid where t2.make = 'amc hornet sportabout (sw)';")

In [295]:
res = trace_exp4_1_attention_patterns_gpt2(mt_uskg_gpt2, a_ex)

In [296]:
res.keys()

dict_keys(['enc_sentence', 'seq_out', 'dec_prompt', 'db_id', 'col_self_ranges', 'col_context_ranges', 'tab_self_ranges', 'tab_context_ranges', 'occ_cols', 'non_occ_cols', 'occ_tabs', 'non_occ_tabs', 'attentions'])

In [297]:
res['attentions']['col'].keys()

dict_keys(['makeid', 'make', 'accelerate', 'contid', 'countryid', 'countryname', 'fullname', 'country', 'modelid', 'mpg', 'cylinders', 'edispl', 'horsepower', 'weight', 'year'])

In [298]:
_c = 'accelerate'

for sect_k, sect_l in res['attentions']['col'][_c].items():
    for layer_i, layer_l in enumerate(sect_l):
        for head_i, w in enumerate(layer_l):
            if float(w) > 0.5:
                print(f'{sect_k} Layer={layer_i} Head={head_i} -> {w}')

prefix#0 Layer=22 Head=9 -> 0.55
prefix#1 Layer=17 Head=13 -> 0.67
prefix#1 Layer=18 Head=1 -> 0.61
prefix#1 Layer=18 Head=12 -> 0.7
prefix#2 Layer=18 Head=6 -> 0.78
prefix#2 Layer=21 Head=11 -> 0.67
prefix#4 Layer=18 Head=5 -> 0.54
prefix#4 Layer=18 Head=14 -> 0.88
prefix#4 Layer=21 Head=0 -> 0.56
prefix#4 Layer=21 Head=6 -> 0.6
prefix#5 Layer=17 Head=1 -> 0.54
prefix#6 Layer=10 Head=1 -> 0.92
prefix#6 Layer=19 Head=0 -> 0.77
prefix#7 Layer=23 Head=14 -> 0.62
prefix#8 Layer=16 Head=10 -> 0.52
prefix#8 Layer=19 Head=13 -> 0.62
prefix#8 Layer=21 Head=9 -> 0.52
prefix#10 Layer=5 Head=8 -> 0.55
prefix#10 Layer=10 Head=2 -> 0.54
prefix#10 Layer=10 Head=6 -> 0.82
prefix#10 Layer=14 Head=0 -> 0.56
prefix#10 Layer=21 Head=2 -> 0.71
prefix#11 Layer=13 Head=11 -> 0.92
prefix#11 Layer=15 Head=14 -> 0.89
prefix#11 Layer=19 Head=2 -> 0.78
prefix#11 Layer=19 Head=15 -> 0.78
prefix#12 Layer=18 Head=13 -> 0.61
prefix#12 Layer=20 Head=7 -> 0.77
prefix#12 Layer=20 Head=14 -> 0.61
prefix#12 Layer=22 Hea

In [299]:
_c = 'contid'

for sect_k, sect_l in res['attentions']['col'][_c].items():
    for layer_i, layer_l in enumerate(sect_l):
        for head_i, w in enumerate(layer_l):
            if float(w) > 0.5:
                print(f'{sect_k} Layer={layer_i} Head={head_i} -> {w}')

prefix#0 Layer=12 Head=6 -> 0.53
prefix#0 Layer=21 Head=5 -> 0.54
prefix#1 Layer=17 Head=5 -> 0.86
prefix#1 Layer=18 Head=1 -> 0.56
prefix#1 Layer=18 Head=8 -> 0.55
prefix#1 Layer=21 Head=0 -> 0.90
prefix#2 Layer=15 Head=8 -> 0.57
prefix#3 Layer=13 Head=4 -> 0.55
prefix#3 Layer=14 Head=4 -> 0.56
prefix#3 Layer=18 Head=3 -> 0.52
prefix#3 Layer=23 Head=15 -> 0.84
prefix#4 Layer=23 Head=6 -> 0.72
prefix#5 Layer=7 Head=2 -> 0.81
prefix#5 Layer=21 Head=15 -> 0.75
prefix#6 Layer=14 Head=1 -> 0.62
prefix#6 Layer=15 Head=10 -> 0.63
prefix#6 Layer=19 Head=4 -> 0.6
prefix#6 Layer=20 Head=4 -> 0.76
prefix#6 Layer=23 Head=9 -> 0.74
prefix#6 Layer=23 Head=12 -> 0.85
prefix#7 Layer=21 Head=14 -> 0.72
prefix#8 Layer=16 Head=2 -> 0.94
prefix#8 Layer=21 Head=3 -> 0.56
prefix#9 Layer=19 Head=10 -> 0.53
prefix#9 Layer=21 Head=10 -> 0.53
prefix#9 Layer=23 Head=14 -> 0.69
prefix#10 Layer=5 Head=8 -> 0.66
prefix#10 Layer=11 Head=5 -> 0.62
prefix#10 Layer=18 Head=5 -> 0.76
prefix#10 Layer=20 Head=3 -> 0.84
p

## Check results

### Exp2

In [229]:
expect_type = 'column'

res_path = f'/home/yshao/Projects/rome/results/gpt2_tracing/exp2_section_corrupt_restore_gpt2/exp=gpt2_2.0_dev_{expect_type}-replace=True-noise=0.0.jsonl'

with open(res_path, 'r') as f:
    all_samples = [json.loads(l) for l in f]
len(all_samples)

1034

In [230]:
total_samples = 0
n_good_samples = 0
n_too_hard = 0      # clean prediction is wrong
n_too_easy = 0      # corruption barely hurts: base_score - low_score < 0.5

good_samples = []
bad_samples = []

for i, ex in enumerate(all_samples):
    for d in ex['trace_results']:
        total_samples += 1

        if d.get('is_good_sample', True):
            n_good_samples += 1
            d['ex_id'] = i
            good_samples.append(d)
        elif not d['correct_prediction']:
            n_too_hard += 1
            bad_samples.append(d)
        else:
            assert d['base_score'] - d['low_score'] < 0.5
            n_too_easy += 1
            bad_samples.append(d)
            
total_samples, (n_good_samples, len(good_samples)), n_too_hard, n_too_easy, len(bad_samples), \
f'good / correct = {n_good_samples} / {n_good_samples + n_too_easy}'

(2002, (1364, 1364), 614, 24, 638, 'good / correct = 1364 / 1388')

In [231]:
good_samples[0]

{'enc_sentence': 'Show name, country, age for all singers ordered by age from the oldest to the youngest.; structed knowledge: | concert_singer | stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id',
 'seq_out': 'select name, country, age from singer order by age desc',
 'dec_prompt': 'Show name, country, age for all singers ordered by age from the oldest to the youngest.; structed knowledge: | concert_singer | stadium : stadium_id, location, name, capacity, highest, lowest, average | singer : singer_id, name, country, song_name, song_release_year, age, is_male | concert : concert_id, concert_name, theme, stadium_id, year | singer_in_concert : concert_id, singer_id ; SQL: select name,',
 'db_id': 'concert_singer',
 'expect': 'country',
 'expect_type': 'column',
 'ex

In [233]:
#### Overall avg

trace_scores_avg = {
    'corrupt' : defaultdict(float),
    'corrupt_text_restore': defaultdict(float),
}

for d in good_samples:
    for type_k, type_d in d['trace_scores'].items():
        for k, v in type_d.items():
            trace_scores_avg[type_k][k] += v

for type_k, type_d in trace_scores_avg.items():
    for k, s in type_d.items():
        type_d[k] = s / len(good_samples)

In [234]:
trace_scores_avg

{'corrupt': defaultdict(float,
             {'text': 0.041550320565249876,
              'struct': 0.0037280468717201805,
              'self': 0.3648502066803721,
              'struct_context': 0.006297542297234418,
              'all': 5.139980438708388e-05}),
 'corrupt_text_restore': defaultdict(float,
             {'text': 0.07921842629209634,
              'struct': 0.2023287613100771,
              'self': 0.11909362401478271,
              'struct_context': 0.11375640216889772,
              'all': 0.24608358883900305})}

In [235]:
format_print_2D_dict(trace_scores_avg)

XXXXXXXXXXXX	text  	struct	self  	struct_context	all   
corrupt     	0.0416	0.0037	0.3649	0.0063	0.0001
corrupt_text_restore	0.0792	0.2023	0.1191	0.1138	0.2461
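
`format_print_2D_dict` is defined elsewhere in the notebook and is not shown here. As a rough idea of what such a helper might do, here is a hypothetical stand-in, `format_2d_dict`, that renders a `{row -> {col -> float}}` dict as tab-separated lines. The header cell is left empty rather than padded, so the layout is only approximately the same.

```python
def format_2d_dict(table, digits=4):
    """Render {row -> {col -> float}} as tab-separated text (hypothetical helper)."""
    # Column order is taken from the first row.
    cols = list(next(iter(table.values())).keys())
    lines = ['\t'.join([''] + cols)]
    for row, vals in table.items():
        cells = [f'{vals[c]:.{digits}f}' for c in cols]
        lines.append('\t'.join([row] + cells))
    return '\n'.join(lines)

# Two of the columns from trace_scores_avg above.
toy_scores = {
    'corrupt': {'text': 0.041550320565249876, 'self': 0.3648502066803721},
    'corrupt_text_restore': {'text': 0.07921842629209634, 'self': 0.11909362401478271},
}

rendered = format_2d_dict(toy_scores)
print(rendered)
```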



### Exp4: inspect attention

In [335]:
exp_type = 'table'

res_path = f'/home/yshao/Projects/rome/results/gpt2_tracing/exp4_inspect_attention_gpt2/exp=4_dev_{exp_type}.jsonl'

with open(res_path, 'r') as f:
    all_samples = [json.loads(l) for l in f]
len(all_samples)

104

In [336]:
all_samples[1]['trace_results'][0].keys()

dict_keys(['enc_sentence', 'seq_out', 'dec_prompt', 'db_id', 'expect', 'expect_type', 'expect_input_ranges', 'self_ranges', 'expect_table', 'category', 'answer', 'probs', 'base_score', 'answers_t', 'correct_prediction', 'attentions'])

In [337]:
_id = 6
d = all_samples[_id]['trace_results'][0]
d['enc_sentence'], d['dec_prompt'], d['expect']

("What are the students' first names who have both cats and dogs as pets?; structed knowledge: | pets_1 | student : stuid , lname , fname , age , sex , major , advisor , city_code | has_pet : stuid , petid | pets : petid , pettype ( dog , cat ) , pet_age , weight",
 "What are the students' first names who have both cats and dogs as pets?; structed knowledge: | pets_1 | student : stuid, lname, fname, age, sex, major, advisor, city_code | has_pet : stuid, petid | pets : petid, pettype ( dog, cat ), pet_age, weight ; SQL: select t1.fname from",
 'student')

In [338]:
d.keys(), d['attentions'].keys(), d['attentions']['all'].keys()

(dict_keys(['enc_sentence', 'seq_out', 'dec_prompt', 'db_id', 'expect', 'expect_type', 'expect_input_ranges', 'self_ranges', 'expect_table', 'category', 'answer', 'probs', 'base_score', 'answers_t', 'correct_prediction', 'attentions']),
 dict_keys(['all']),
 dict_keys(['attn', 'head_tokens', 'cand_tokens']))

In [339]:
prompt_tokens = d['attentions']['all']['cand_tokens']
prompt_tokens = '@'.join(prompt_tokens).split('@ ;@ SQL@:@')[1].split('@')
prompt_tokens

[' select', ' t', '1', '.', 'f', 'name', ' from']
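
The `'@'.join(...).split(...)` trick works as long as no token contains `@`. A sketch of an alternative that searches the token list directly is below; `tokens_after_marker` is a hypothetical helper, and the marker tokens `' ;'`, `' SQL'`, `':'` are read off the split string used above.

```python
def tokens_after_marker(tokens, marker=(' ;', ' SQL', ':')):
    """Return the tokens that follow the first occurrence of `marker`."""
    marker = list(marker)
    n = len(marker)
    for start in range(len(tokens) - n + 1):
        if tokens[start:start + n] == marker:
            return tokens[start + n:]
    raise ValueError('marker not found in tokens')

# Toy token list (hypothetical prefix, real SQL-side tokens from the output above).
toy_tokens = ['What', ' pets', '?', ' ;', ' SQL', ':',
              ' select', ' t', '1', '.', 'f', 'name', ' from']

sql_side = tokens_after_marker(toy_tokens)
print(sql_side)
```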

#### Plotting

In [340]:
def _draw_single_plot_2(ax, val_mat, x_labels=None, y_labels=None, title=None):
    """
    X: # heads
    Y: cand tokens
    Title: head token (together with full expect)
    """
    if isinstance(val_mat, list):
        val_mat = np.array(val_mat)
    h = ax.pcolormesh(
        val_mat,
        cmap="Reds",
        vmax=1.0,
        vmin=0.0,
    )
    ax.invert_yaxis()
    ax.set_yticks([0.5 + i for i in range(val_mat.shape[0])])
    ax.set_xticks([0.5 + i for i in range(val_mat.shape[1])])
    if x_labels is not None:
        ax.set_xticklabels(x_labels, fontsize=8)
    if y_labels is not None:
        ax.set_yticklabels(y_labels, fontsize=8)

    if title is not None:
        ax.set_title(title)
    ax.set_xlabel("# Head")
    ax.set_ylabel("Attention candidate tokens")
    
    # cb = plt.colorbar(h)
    # divider = make_axes_locatable(ax)
    # cax = divider.append_axes('right', size='5%', pad=0.05)
    # cb = fig.colorbar(h, cax=cax)
    cb = plt.colorbar(h, ax=ax)

#     if xlabel is not None:
#         ax.set_xlabel(xlabel)
#     elif answer is not None:
#         # The following should be cb.ax.set_xlabel, but this is broken in matplotlib 3.5.1.
#         cb.ax.set_title(f"p({str(answer).strip()})", y=-0.16, fontsize=10)

In [341]:
def plot_uskg_attention_gpt2(d, inspect_layers=None, savepdf=None):
    """
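    Plot the attention weights in d['attentions']['all'] as heatmaps:
    one row per head token, one column per inspected layer.

    Assumes 24 layers and 16 heads (the default layer choices below rely on this).
    inspect_layers: list of layer indices, 'all', or None for [0, 6, 12, 18, 23].
    savepdf: output path; if not given, the figure is shown instead.
    """

    ## attention from each head token to all candidate tokens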
    if inspect_layers is None:
        inspect_layers = [0, 6, 12, 18, 23]
    elif inspect_layers == 'all':
        inspect_layers = [i for i in range(24)]
    att_dict = d['attentions']['all']
    
    cand_len = len(att_dict['cand_tokens'])
    head_len = len(att_dict['head_tokens'])
    # prompt_tokens: for title
    prompt_tokens = d['attentions']['all']['cand_tokens']
    prompt_tokens = '@'.join(prompt_tokens).split('@ ;@ SQL@:@')[1].split('@')
    prompt_len = len(prompt_tokens)

    fig_w = len(inspect_layers) * 4 + 2
    fig_h = (0.11*cand_len + 1) * head_len + 2
    fig, ax_list = plt.subplots(
        nrows=head_len,
        ncols=len(inspect_layers),
        squeeze=False,
        figsize=(fig_w, fig_h))

    att_mat = ctu.nested_list_processing(att_dict['attn'], func=float)
    att_mat = np.array(att_mat)
    
    for expect_i in range(len(att_dict['head_tokens'])):
        for l_id, layer in enumerate(inspect_layers):
            val_mat = att_mat[layer, :, expect_i, :]  # layer, all heads, expect tok i -> all toks 
            val_mat = val_mat.transpose()    # (cand_toks, n_heads)
            x_labels = range(val_mat.shape[1])
            y_labels = att_dict['cand_tokens']
#             if att_part == 'enc':
#                 # enc: correct tokens
#                 title_toks = att_dict['head_tokens'][:expect_i] + [f"*{att_dict['head_tokens'][expect_i]}*"]
#             else:
            # cross / dec: use gold tokens from dec_prompt for previous steps and predicted token at this step
            # (dec_prompt ends with the first (head_len-1) tokens of the target node)
            title_toks = prompt_tokens[prompt_len - (head_len-1) : prompt_len - (head_len-1) + expect_i] + [f"*{att_dict['head_tokens'][expect_i]}*"]
            
            title = f"L{layer}  Head token: {' '.join(title_toks)}\n"
            
            ax = ax_list[expect_i, l_id]
            _draw_single_plot_2(ax,
                                val_mat=val_mat, 
                                x_labels=x_labels, 
                                y_labels=y_labels,
                                title=title)
            
    fig.tight_layout()
    if savepdf:
        plt.savefig(savepdf, bbox_inches="tight")
        plt.close()
    else:
        plt.show()
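
The indexing `att_mat[layer, :, expect_i, :]` in the function above implies an array layout of `(layer, head, head_token, cand_token)`. The sketch below checks that reading on random data; the sizes and values are made up, and only the indexing and the transpose mirror the plotting code.

```python
import numpy as np

# Assumed layout: (n_layers, n_heads, head_len, cand_len); sizes are illustrative.
n_layers, n_heads, head_len, cand_len = 24, 16, 2, 7
rng = np.random.default_rng(0)
att_mat = rng.random((n_layers, n_heads, head_len, cand_len))
att_mat /= att_mat.sum(axis=-1, keepdims=True)   # each attention row sums to 1

layer, expect_i = 12, 1
val_mat = att_mat[layer, :, expect_i, :]   # (n_heads, cand_len)
val_mat = val_mat.transpose()              # (cand_len, n_heads): rows = y axis, cols = x axis

print(val_mat.shape)
```

After the transpose, each column is one head's distribution over the candidate tokens, which is why the x axis is labeled by head and the y axis by candidate token.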

In [342]:
savepdf_dir = '/home/yshao/Projects/rome/results/figs/gpt2_tracing/exp4_inspect_attention'
os.makedirs(savepdf_dir, exist_ok=True)

In [343]:
savepdf_tmpl = os.path.join(savepdf_dir, f'tmp-{_id}.pdf')
plot_uskg_attention_gpt2(d, savepdf=savepdf_tmpl)

In [344]:
## Run for all! 

fig_dir = os.path.join(savepdf_dir, f'dev_{exp_type}')
os.makedirs(fig_dir, exist_ok=True)

global_ex_id = 0
for ex_id in tqdm(range(len(all_samples))):
    for a_ex_id in range(len(all_samples[ex_id]['trace_results'])):
        d = all_samples[ex_id]['trace_results'][a_ex_id]
        _suffix = '-WRONGPRED' if not d['correct_prediction'] else ''
        savepdf_path = os.path.join(fig_dir, f'{global_ex_id}-ex={ex_id}.{a_ex_id}{_suffix}.pdf')
        plot_uskg_attention_gpt2(d, savepdf=savepdf_path)
        # print(f'{global_ex_id}-ex={ex_id}.{a_ex_id}')
        global_ex_id += 1

In [349]:
# Check correctness stats

# corr_cnt = 0
err_cnt = 0
global_ex_id = 0
for ex_id in range(len(all_samples)):
    for a_ex_id in range(len(all_samples[ex_id]['trace_results'])):
        d = all_samples[ex_id]['trace_results'][a_ex_id]
        if not d['correct_prediction']:
            print(f'Not correct: {global_ex_id}-ex={ex_id}.{a_ex_id}')
#             savepdf_path = os.path.join(fig_dir, f'{global_ex_id}-ex={ex_id}.{a_ex_id}.pdf')
#             new_path = os.path.join(fig_dir, f'{global_ex_id}-ex={ex_id}.{a_ex_id}-WRONGPRED.pdf')
#             os.rename(savepdf_path, new_path)
            err_cnt += 1
        global_ex_id += 1

Not correct: 0-ex=0.0
Not correct: 1-ex=1.0
Not correct: 3-ex=3.0
Not correct: 4-ex=3.1
Not correct: 5-ex=4.0
Not correct: 6-ex=5.0
Not correct: 9-ex=6.2
Not correct: 12-ex=6.5
Not correct: 13-ex=7.0
Not correct: 16-ex=9.0
Not correct: 17-ex=9.1
Not correct: 19-ex=10.1
Not correct: 20-ex=10.2
Not correct: 21-ex=10.3
Not correct: 22-ex=11.0
Not correct: 24-ex=12.0
Not correct: 25-ex=13.0
Not correct: 26-ex=13.1
Not correct: 31-ex=16.0
Not correct: 32-ex=16.1
Not correct: 33-ex=17.0
Not correct: 34-ex=18.0
Not correct: 35-ex=19.0
Not correct: 36-ex=20.0
Not correct: 37-ex=21.0
Not correct: 38-ex=21.1
Not correct: 39-ex=22.0
Not correct: 40-ex=22.1
Not correct: 41-ex=22.2
Not correct: 42-ex=23.0
Not correct: 43-ex=23.1
Not correct: 44-ex=24.0
Not correct: 45-ex=24.1
Not correct: 46-ex=24.2
Not correct: 47-ex=24.3
Not correct: 48-ex=25.0
Not correct: 49-ex=26.0
Not correct: 51-ex=28.0
Not correct: 52-ex=28.1
Not correct: 53-ex=29.0
Not correct: 54-ex=30.0
Not correct: 55-ex=31.0
Not correc

In [350]:
err_cnt

105

### Exp4.1: attention pattern for schema linking

#### Check sanity

In [345]:
res_path = '/home/yshao/Projects/rome/results/gpt2_tracing/exp4_1_attention_weights_distribution_gpt2/exp=4.1_dev.jsonl'
samples_N = len(processed_spider_dev)

# check one sample (ex)
with open(res_path, 'r') as f:
    for l in f:
        ex = ujson.loads(l)
        break

In [346]:
ex.keys(), ex['trace_results'].keys()

(dict_keys(['ex_id', 'trace_results']),
 dict_keys(['enc_sentence', 'seq_out', 'dec_prompt', 'db_id', 'col_self_ranges', 'col_context_ranges', 'tab_self_ranges', 'tab_context_ranges', 'occ_cols', 'non_occ_cols', 'occ_tabs', 'non_occ_tabs', 'attentions']))

In [347]:
for k, v in ex['trace_results'].items():
    if k != 'attentions':
        print(k, ':', v)

enc_sentence : How many singers do we have?; structed knowledge: | concert_singer | stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id
seq_out : select count(*) from singer
dec_prompt : select
db_id : concert_singer
col_self_ranges : {'location': [[23, 26]], 'capacity': [[27, 30]], 'highest': [[29, 32]], 'lowest': [[31, 34]], 'average': [[33, 36]], 'country': [[43, 46]], 'song_name': [[45, 50]], 'song_release_year': [[49, 56]], 'age': [[55, 58]], 'is_male': [[57, 62]], 'concert_name': [[67, 72]], 'theme': [[71, 74]], 'year': [[77, 80]]}
col_context_ranges : {'location': [[12, 23], [26, 93]], 'capacity': [[12, 27], [30, 93]], 'highest': [[12, 29], [32, 93]], 'lowest': [[12, 31], [34, 93]], 'average': [[12, 33], [36, 93]], 'country': [[12, 43], [46, 93]], 'song_name'

#### Full loading & processing

In [None]:
ds = 'dev'

res_path = f'/home/yshao/Projects/rome/results/gpt2_tracing/exp4_1_attention_weights_distribution_gpt2/exp=4.1_{ds}.jsonl'
samples_N = len(processed_spider_dev)

In [351]:
# Dict: {layer -> {head_id -> {occ_type -> {section -> List[att_w]}}}}; list for all samples, all nodes in each occ_type 
# Perhaps too large...
# att_weights_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
att_weights_sum_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))

# Dict: {occ_type -> int}
att_weights_cnt_dict = defaultdict(int)

# Dict: {layer -> {head_id -> {occ_type -> {section -> avg_att_w}}}}; averaged by all samples, all nodes in each occ_type 
att_weights_avg_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))

# Dict: {occ_type -> List[Tuple(ex_id, node_name)]}
sample_backtrace_dict = defaultdict(list)


with open(res_path, 'r') as f:
    for line in tqdm(f, total=samples_N):   # named `line` because the inner layer loops reuse `l`
        ex = ujson.loads(line)
        ex_id = ex['ex_id']

        if 'err_msg' in ex:
            continue

        result_d = ex['trace_results']

        all_col_atts = result_d['attentions']['col']
        all_tab_atts = result_d['attentions']['tab']

        for col_occ_type in ['occ_cols', 'non_occ_cols']:
            for col in result_d[col_occ_type]:
                sect_att_dict = all_col_atts[col]
                sample_backtrace_dict[col_occ_type].append((ex_id, col))
                att_weights_cnt_dict[col_occ_type] += 1

                for sect_k, att_mat in sect_att_dict.items():
                    att_mat = np.array(ctu.nested_list_processing(att_mat, func=float))
                    n_layers, n_heads = att_mat.shape
                    for l in range(n_layers):
                        for h in range(n_heads):
                            att_w = att_mat[l, h]
                            # att_weights_dict[l][h][col_occ_type][sect_k].append(att_w)
                            att_weights_sum_dict[l][h][col_occ_type][sect_k] += att_w

        for tab_occ_type in ['occ_tabs', 'non_occ_tabs']:
            for tab in result_d[tab_occ_type]:
                sect_att_dict = all_tab_atts[tab]
                sample_backtrace_dict[tab_occ_type].append((ex_id, tab))
                att_weights_cnt_dict[tab_occ_type] += 1

                for sect_k, att_mat in sect_att_dict.items():
                    att_mat = np.array(ctu.nested_list_processing(att_mat, func=float))
                    n_layers, n_heads = att_mat.shape
                    for l in range(n_layers):
                        for h in range(n_heads):
                            att_w = att_mat[l, h]
                            # att_weights_dict[l][h][tab_occ_type][sect_k].append(att_w)
                            att_weights_sum_dict[l][h][tab_occ_type][sect_k] += att_w

In [352]:
for l_id, layer_d in att_weights_sum_dict.items():
    for h_id, head_d in layer_d.items():
        for occ_type, occ_type_d in head_d.items():
            att_w_cnt = att_weights_cnt_dict[occ_type]
            for sect_k, att_w_sum in occ_type_d.items():
                att_weights_avg_dict[l_id][h_id][occ_type][sect_k] = att_w_sum / att_w_cnt
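
The sum, count and average pattern used above is easier to see on a smaller structure. The sketch below drops the layer and head levels and keeps only `{occ_type -> {section -> value}}`; the toy weights are invented for illustration.

```python
from collections import defaultdict

sum_dict = defaultdict(lambda: defaultdict(float))   # {occ_type -> {section -> sum of weights}}
cnt_dict = defaultdict(int)                          # {occ_type -> number of nodes seen}

# Toy (occ_type, {section -> attention weight}) pairs, one per schema node.
toy_nodes = [
    ('occ_cols', {'text': 0.2, 'self': 0.6}),
    ('occ_cols', {'text': 0.4, 'self': 0.2}),
    ('non_occ_cols', {'text': 0.1, 'self': 0.3}),
]

for occ_type, sect_weights in toy_nodes:
    cnt_dict[occ_type] += 1
    for sect, w in sect_weights.items():
        sum_dict[occ_type][sect] += w

# Average per node: divide every section sum by the node count of its occ_type.
avg_dict = {
    occ: {sect: s / cnt_dict[occ] for sect, s in sects.items()}
    for occ, sects in sum_dict.items()
}
print(avg_dict)
```

The count is kept per `occ_type` rather than per section because every node contributes one value to each section.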

In [353]:
att_weights_avg_dict[0][0]

defaultdict(<function __main__.<lambda>.<locals>.<lambda>.<locals>.<lambda>()>,
            {'non_occ_cols': defaultdict(float,
                         {'prefix#0': 0.00020246420246420108,
                          'prefix#1': 0.00043223443223442597,
                          'prefix#2': 0.00011388611388611397,
                          'prefix#3': 0.0013413253413253647,
                          'prefix#4': 0.0001092241092241093,
                          'prefix#5': 0.003972027972027774,
                          'prefix#6': 0.00046486846486845797,
                          'prefix#7': 0.001931401931402047,
                          'prefix#8': 0.0007958707958707819,
                          'prefix#9': 7.992007992007997e-05,
                          'prefix#10': 0.0005328005328005245,
                          'prefix#11': 0.0010236430236430047,
                          'prefix#12': 0.0002477522477522454,
                          'prefix#13': 0.006751248751249487,
             

#### Dump results 
- do not rerun unless updated

In [354]:
res_path = f'/home/yshao/Projects/rome/results/gpt2_tracing/exp4_1_attention_weights_distribution_gpt2/exp=4.1_{ds}_dump.pkl'

dump_d = {
    'att_weights_sum_dict': ctu.nested_json_processing(att_weights_sum_dict, func=lambda x: x),   # defaultdict -> dict 
    'att_weights_cnt_dict': ctu.nested_json_processing(att_weights_cnt_dict, func=lambda x: x),
    'att_weights_avg_dict': ctu.nested_json_processing(att_weights_avg_dict, func=lambda x: x),
    'sample_backtrace_dict': ctu.nested_json_processing(sample_backtrace_dict, func=lambda x: x),
}

with open(res_path, 'wb') as f:
    pickle.dump(dump_d, f)
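
The `defaultdict -> dict` conversion matters because a `defaultdict` whose factory is a lambda cannot be pickled. `ctu.nested_json_processing` is the notebook's own helper and is not shown here; the sketch below uses a hypothetical stand-in, `to_plain_dict`, to illustrate the idea.

```python
import pickle
from collections import defaultdict

def to_plain_dict(obj):
    """Recursively convert dict subclasses (e.g. defaultdict) to plain dicts."""
    if isinstance(obj, dict):
        return {k: to_plain_dict(v) for k, v in obj.items()}
    return obj

nested = defaultdict(lambda: defaultdict(float))
nested[0]['text'] += 0.25

# Pickling the defaultdict directly fails: the lambda factory is not picklable.
try:
    pickle.dumps(nested)
    lambda_pickled = True
except Exception:
    lambda_pickled = False

plain = to_plain_dict(nested)
restored = pickle.loads(pickle.dumps(plain))
print(lambda_pickled, restored)
```

One consequence of the conversion is that the loaded dicts no longer create missing keys on access, so a lookup of an absent layer or head raises `KeyError` after loading.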

#### Load results 

In [355]:
res_path = '/home/yshao/Projects/rome/results/gpt2_tracing/exp4_1_attention_weights_distribution_gpt2/exp=4.1_dev_dump.pkl'

with open(res_path, 'rb') as f:
    dump_d = pickle.load(f)

att_weights_sum_dict = dump_d['att_weights_sum_dict']
att_weights_cnt_dict = dump_d['att_weights_cnt_dict']
att_weights_avg_dict = dump_d['att_weights_avg_dict']
sample_backtrace_dict = dump_d['sample_backtrace_dict']

In [356]:
att_weights_avg_dict[0][0]

{'non_occ_cols': {'prefix#0': 0.00020246420246420108,
  'prefix#1': 0.00043223443223442597,
  'prefix#2': 0.00011388611388611397,
  'prefix#3': 0.0013413253413253647,
  'prefix#4': 0.0001092241092241093,
  'prefix#5': 0.003972027972027774,
  'prefix#6': 0.00046486846486845797,
  'prefix#7': 0.001931401931402047,
  'prefix#8': 0.0007958707958707819,
  'prefix#9': 7.992007992007997e-05,
  'prefix#10': 0.0005328005328005245,
  'prefix#11': 0.0010236430236430047,
  'prefix#12': 0.0002477522477522454,
  'prefix#13': 0.006751248751249487,
  'prefix#14': 0.014741258741258464,
  'prefix#15': 0.00138927738927742,
  'prefix#16': 0.00025308025308025064,
  'prefix#17': 0.006893772893773302,
  'prefix#18': 0.0002823842823842793,
  'prefix#19': 0.00019114219114219,
  'text': 0.27961971361970833,
  'self': 0.05460073260072387,
  'context': 0.5567678987678978,
  'others': 0.04184948384947878},
 'occ_tabs': {'prefix#0': 0.0001237785016286645,
  'prefix#1': 0.0008664495114006521,
  'prefix#2': 0.0001172

In [357]:
occ_types = ['occ_cols', 'non_occ_cols']

for l_id in [1, 6, 12, 18, 23]:
    # Dict: occ_type -> sect -> List[att_w]; list for all heads 
    layer_ob_dict = defaultdict(lambda: defaultdict(list))

    for h_id in range(len(att_weights_avg_dict[l_id])):
        for occ in occ_types:
            for sect_k, att_w in att_weights_avg_dict[l_id][h_id][occ].items():
                att_w_str = np.format_float_positional(att_w, precision=2, min_digits=2)
                layer_ob_dict[occ][sect_k].append(att_w_str)
    
    print(f'===== Layer {l_id} =====')
    for occ in occ_types:
        print(occ)
        for sect_k, att_w_list in layer_ob_dict[occ].items():
            att_w_list_str = "  ".join(att_w_list)
            print(f'{sect_k:<10s}{att_w_list_str}')
        print()

===== Layer 1 =====
occ_cols
prefix#0  0.01  0.01  0.00  0.01  0.00  0.01  0.01  0.00  0.00  0.02  0.04  0.02  0.02  0.00  0.01  0.00
prefix#1  0.01  0.02  0.01  0.02  0.00  0.02  0.00  0.01  0.01  0.01  0.05  0.01  0.01  0.02  0.03  0.00
prefix#2  0.01  0.01  0.00  0.02  0.00  0.01  0.01  0.01  0.01  0.01  0.02  0.02  0.00  0.01  0.00  0.01
prefix#3  0.01  0.01  0.01  0.05  0.00  0.01  0.01  0.00  0.00  0.01  0.03  0.01  0.00  0.03  0.00  0.02
prefix#4  0.00  0.01  0.01  0.01  0.00  0.01  0.00  0.00  0.00  0.01  0.03  0.01  0.02  0.01  0.01  0.01
prefix#5  0.02  0.04  0.02  0.08  0.00  0.06  0.02  0.00  0.01  0.01  0.04  0.01  0.02  0.02  0.00  0.00
prefix#6  0.01  0.01  0.01  0.04  0.00  0.01  0.00  0.00  0.00  0.02  0.02  0.02  0.03  0.01  0.03  0.00
prefix#7  0.00  0.02  0.01  0.00  0.00  0.01  0.00  0.00  0.00  0.01  0.05  0.03  0.03  0.01  0.01  0.01
prefix#8  0.01  0.01  0.02  0.03  0.00  0.01  0.00  0.01  0.00  0.01  0.02  0.01  0.01  0.04  0.00  0.01
prefix#9  0.01  0.01  0.01

In [358]:
len(att_weights_avg_dict[0])

16

## Temp

In [10]:
gpt2_model = AutoModelForPreTraining.from_pretrained('gpt2')

In [9]:
gpt2_tokenizer = AutoTokenizer.from_pretrained('gpt2')

### Tokenizer

In [11]:
gpt2_tokenizer

PreTrainedTokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_len=1024, is_fast=True, padding_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'})

In [12]:
gpt2_tokenizer.eos_token_id

50256

In [13]:
_toks = gpt2_tokenizer.tokenize(' Select * from singer')
_toks

['ĠSelect', 'Ġ*', 'Ġfrom', 'Ġsinger']

In [18]:
_toks = gpt2_tokenizer.tokenize('x ; SQL: Select')
_toks

['x', 'Ġ;', 'ĠSQL', ':', 'ĠSelect']

In [19]:
gpt2_tokenizer.convert_tokens_to_string(_toks)

'x ; SQL: Select'
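
The `Ġ` prefix in the token strings is how GPT-2's byte-level BPE vocabulary writes a leading space. A minimal sketch of the round trip is below, assuming only the space marker needs handling; the real `convert_tokens_to_string` decodes the full byte-level mapping, and `detok_space_marker` is a hypothetical name.

```python
def detok_space_marker(tokens):
    """Join GPT-2 style token strings, mapping the 'Ġ' marker back to a space.

    Simplification: only the space marker is handled, not other remapped bytes.
    """
    return ''.join(tokens).replace('Ġ', ' ')

toks = ['x', 'Ġ;', 'ĠSQL', ':', 'ĠSelect']
text = detok_space_marker(toks)
print(repr(text))
```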

In [20]:
gpt2_tokenizer.convert_tokens_to_ids(_toks)

[87, 2162, 16363, 25, 9683]

In [30]:
_sql = "select role_code from project_staff where date_from > '2003-04-19 15:06:20' and date_to < '2016-03-15 00:33:18'"
_toks = gpt2_tokenizer.tokenize(_sql)

In [31]:
print(_toks)

['select', 'Ġrole', '_', 'code', 'Ġfrom', 'Ġproject', '_', 'staff', 'Ġwhere', 'Ġdate', '_', 'from', 'Ġ>', "Ġ'", '2003', '-', '04', '-', '19', 'Ġ15', ':', '06', ':', '20', "'", 'Ġand', 'Ġdate', '_', 'to', 'Ġ<', "Ġ'", '2016', '-', '03', '-', '15', 'Ġ00', ':', '33', ':', '18', "'"]


In [32]:
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

In [33]:
gpt2_tokenizer.convert_tokens_to_string(_toks)

"select role_code from project_staff where date_from > '2003-04-19 15:06:20' and date_to < '2016-03-15 00:33:18'"

In [34]:
len(_toks)

42

In [142]:
ex['parsed_struct_in']

((1, 'car_1', 'car_1'),
 [((3, 'continents', 'continents'),
   [[(5, 'contid', 'contid'), []],
    [(7, 'continent', 'continent ( europe )'), [(9, 'europe', 'europe')]]]),
  ((12, 'countries', 'countries'),
   [[(14, 'countryid', 'countryid'), []],
    [(16, 'countryname', 'countryname'), []],
    [(18, 'continent', 'continent'), []]]),
  ((20, 'car_makers', 'car_makers'),
   [[(22, 'id', 'id'), []],
    [(24, 'maker', 'maker'), []],
    [(26, 'fullname', 'fullname'), []],
    [(28, 'country', 'country'), []]]),
  ((30, 'model_list', 'model_list'),
   [[(32, 'modelid', 'modelid'), []],
    [(34, 'maker', 'maker'), []],
    [(36, 'model', 'model'), []]]),
  ((38, 'car_names', 'car_names'),
   [[(40, 'makeid', 'makeid'), []],
    [(42, 'model', 'model'), []],
    [(44, 'make', 'make'), []]]),
  ((46, 'cars_data', 'cars_data'),
   [[(48, 'id', 'id'), []],
    [(50, 'mpg', 'mpg'), []],
    [(52, 'cylinders', 'cylinders'), []],
    [(54, 'edispl', 'edispl'), []],
    [(56, 'horsepower', 'ho

In [136]:
# _sent = 'What are the names of all European countries with at least 3 manufacturers?; structed knowledge: | car_1 | continents : contid , continent ( europe ) | countries : countryid , countryname , continent | car_makers : id , maker , fullname , country | model_list : modelid , maker , model | car_names : makeid , model , make | cars_data : id , mpg , cylinders , edispl , horsepower , weight , accelerate , year'
_sent = '| car_1 | continents : contid , continent ( europe ) | countries : countryid , countryname , continent | car_makers : id , maker , fullname , country | model_list : modelid , maker , model | car_names : makeid , model , make | cars_data : id , mpg , cylinders , edispl , horsepower , weight , accelerate , year'
_sent

'| car_1 | continents : contid , continent ( europe ) | countries : countryid , countryname , continent | car_makers : id , maker , fullname , country | model_list : modelid , maker , model | car_names : makeid , model , make | cars_data : id , mpg , cylinders , edispl , horsepower , weight , accelerate , year'

In [137]:
_toks = mt_uskg_gpt2.tokenizer.tokenize(_sent)

In [138]:
tokenized_sent = mt_uskg_gpt2.tokenizer(_sent)

In [139]:
list(enumerate(_sent.split()))

[(0, '|'),
 (1, 'car_1'),
 (2, '|'),
 (3, 'continents'),
 (4, ':'),
 (5, 'contid'),
 (6, ','),
 (7, 'continent'),
 (8, '('),
 (9, 'europe'),
 (10, ')'),
 (11, '|'),
 (12, 'countries'),
 (13, ':'),
 (14, 'countryid'),
 (15, ','),
 (16, 'countryname'),
 (17, ','),
 (18, 'continent'),
 (19, '|'),
 (20, 'car_makers'),
 (21, ':'),
 (22, 'id'),
 (23, ','),
 (24, 'maker'),
 (25, ','),
 (26, 'fullname'),
 (27, ','),
 (28, 'country'),
 (29, '|'),
 (30, 'model_list'),
 (31, ':'),
 (32, 'modelid'),
 (33, ','),
 (34, 'maker'),
 (35, ','),
 (36, 'model'),
 (37, '|'),
 (38, 'car_names'),
 (39, ':'),
 (40, 'makeid'),
 (41, ','),
 (42, 'model'),
 (43, ','),
 (44, 'make'),
 (45, '|'),
 (46, 'cars_data'),
 (47, ':'),
 (48, 'id'),
 (49, ','),
 (50, 'mpg'),
 (51, ','),
 (52, 'cylinders'),
 (53, ','),
 (54, 'edispl'),
 (55, ','),
 (56, 'horsepower'),
 (57, ','),
 (58, 'weight'),
 (59, ','),
 (60, 'accelerate'),
 (61, ','),
 (62, 'year')]

In [140]:
list(enumerate(_toks))

[(0, '|'),
 (1, 'Ġcar'),
 (2, '_'),
 (3, '1'),
 (4, 'Ġ|'),
 (5, 'Ġcontinents'),
 (6, 'Ġ:'),
 (7, 'Ġcont'),
 (8, 'id'),
 (9, 'Ġ,'),
 (10, 'Ġcontinent'),
 (11, 'Ġ('),
 (12, 'Ġeuro'),
 (13, 'pe'),
 (14, 'Ġ)'),
 (15, 'Ġ|'),
 (16, 'Ġcountries'),
 (17, 'Ġ:'),
 (18, 'Ġcountry'),
 (19, 'id'),
 (20, 'Ġ,'),
 (21, 'Ġcountry'),
 (22, 'name'),
 (23, 'Ġ,'),
 (24, 'Ġcontinent'),
 (25, 'Ġ|'),
 (26, 'Ġcar'),
 (27, '_'),
 (28, 'makers'),
 (29, 'Ġ:'),
 (30, 'Ġid'),
 (31, 'Ġ,'),
 (32, 'Ġmaker'),
 (33, 'Ġ,'),
 (34, 'Ġfull'),
 (35, 'name'),
 (36, 'Ġ,'),
 (37, 'Ġcountry'),
 (38, 'Ġ|'),
 (39, 'Ġmodel'),
 (40, '_'),
 (41, 'list'),
 (42, 'Ġ:'),
 (43, 'Ġmodel'),
 (44, 'id'),
 (45, 'Ġ,'),
 (46, 'Ġmaker'),
 (47, 'Ġ,'),
 (48, 'Ġmodel'),
 (49, 'Ġ|'),
 (50, 'Ġcar'),
 (51, '_'),
 (52, 'names'),
 (53, 'Ġ:'),
 (54, 'Ġmake'),
 (55, 'id'),
 (56, 'Ġ,'),
 (57, 'Ġmodel'),
 (58, 'Ġ,'),
 (59, 'Ġmake'),
 (60, 'Ġ|'),
 (61, 'Ġcars'),
 (62, '_'),
 (63, 'data'),
 (64, 'Ġ:'),
 (65, 'Ġid'),
 (66, 'Ġ,'),
 (67, 'Ġm'),
 (68, 'pg'),
 (69

In [141]:
[(i, mt_uskg_gpt2.tokenizer.decode(tok_id)) for i, tok_id in enumerate(tokenized_sent['input_ids'])]

[(0, '|'),
 (1, ' car'),
 (2, '_'),
 (3, '1'),
 (4, ' |'),
 (5, ' continents'),
 (6, ' :'),
 (7, ' cont'),
 (8, 'id'),
 (9, ','),
 (10, ' continent'),
 (11, ' ('),
 (12, ' euro'),
 (13, 'pe'),
 (14, ' )'),
 (15, ' |'),
 (16, ' countries'),
 (17, ' :'),
 (18, ' country'),
 (19, 'id'),
 (20, ','),
 (21, ' country'),
 (22, 'name'),
 (23, ','),
 (24, ' continent'),
 (25, ' |'),
 (26, ' car'),
 (27, '_'),
 (28, 'makers'),
 (29, ' :'),
 (30, ' id'),
 (31, ','),
 (32, ' maker'),
 (33, ','),
 (34, ' full'),
 (35, 'name'),
 (36, ','),
 (37, ' country'),
 (38, ' |'),
 (39, ' model'),
 (40, '_'),
 (41, 'list'),
 (42, ' :'),
 (43, ' model'),
 (44, 'id'),
 (45, ','),
 (46, ' maker'),
 (47, ','),
 (48, ' model'),
 (49, ' |'),
 (50, ' car'),
 (51, '_'),
 (52, 'names'),
 (53, ' :'),
 (54, ' make'),
 (55, 'id'),
 (56, ','),
 (57, ' model'),
 (58, ','),
 (59, ' make'),
 (60, ' |'),
 (61, ' cars'),
 (62, '_'),
 (63, 'data'),
 (64, ' :'),
 (65, ' id'),
 (66, ','),
 (67, ' m'),
 (68, 'pg'),
 (69, ','),
 (7

In [129]:
tokenized_sent.word_to_tokens(12)

TokenSpan(start=12, end=13)

In [144]:
len(tokenized_sent.word_ids()), tokenized_sent.word_ids()

(83,
 [0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  7,
  8,
  9,
  10,
  11,
  11,
  12,
  13,
  14,
  15,
  16,
  16,
  17,
  18,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  45,
  46,
  47,
  48,
  48,
  49,
  50,
  51,
  52,
  53,
  54,
  55,
  56,
  57,
  58,
  59,
  60,
  60,
  61,
  62,
  63,
  64,
  64,
  64,
  65,
  66,
  67,
  68,
  69,
  70,
  71,
  72])
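
Repeated entries in `word_ids()` mark words that were split into several tokens. A sketch of turning such a list into half-open token spans per word, in the spirit of `word_to_tokens`, is below; `word_token_spans` is a hypothetical helper, and the toy list is the first ten entries of the output above.

```python
def word_token_spans(word_ids):
    """Map each word id to a half-open (start, end) token span."""
    spans = {}
    for tok_i, w in enumerate(word_ids):
        if w is None:          # special tokens carry no word id
            continue
        start, _ = spans.get(w, (tok_i, tok_i))
        spans[w] = (start, tok_i + 1)
    return spans

toy_word_ids = [0, 1, 2, 3, 4, 5, 6, 7, 7, 8]
spans = word_token_spans(toy_word_ids)
print(spans[7], spans[8])
```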

In [134]:
i = 0
_words = []
while True:
    try:
        s, e = tokenized_sent.word_to_chars(i)
        _words.append(_sent[s : e])
        i += 1
    except Exception:  # word_to_chars fails once i is past the last word
        break

In [135]:
_words

['What',
 ' are',
 ' the',
 ' names',
 ' of',
 ' all',
 ' European',
 ' countries',
 ' with',
 ' at',
 ' least',
 ' 3',
 ' manufacturers',
 '?;',
 ' structed',
 ' knowledge',
 ':',
 ' |',
 ' car',
 '_',
 '1',
 ' |',
 ' continents',
 ' :',
 ' contid',
 ' ,',
 ' continent',
 ' (',
 ' europe',
 ' )',
 ' |',
 ' countries',
 ' :',
 ' countryid',
 ' ,',
 ' countryname',
 ' ,',
 ' continent',
 ' |',
 ' car',
 '_',
 'makers',
 ' :',
 ' id',
 ' ,',
 ' maker',
 ' ,',
 ' fullname',
 ' ,',
 ' country',
 ' |',
 ' model',
 '_',
 'list',
 ' :',
 ' modelid',
 ' ,',
 ' maker',
 ' ,',
 ' model',
 ' |',
 ' car',
 '_',
 'names',
 ' :',
 ' makeid',
 ' ,',
 ' model',
 ' ,',
 ' make',
 ' |',
 ' cars',
 '_',
 'data',
 ' :',
 ' id',
 ' ,',
 ' mpg',
 ' ,',
 ' cylinders',
 ' ,',
 ' edispl',
 ' ,',
 ' horsepower',
 ' ,',
 ' weight',
 ' ,',
 ' accelerate',
 ' ,',
 ' year']

In [146]:
tokenized_sent.char_to_token(10)

5

In [147]:
_sent[8:12], _toks[5]

('| co', 'Ġcontinents')

### GPT2 Model

In [38]:
[k for k, v in gpt2_model.named_modules()]

['',
 'transformer',
 'transformer.wte',
 'transformer.wpe',
 'transformer.drop',
 'transformer.h',
 'transformer.h.0',
 'transformer.h.0.ln_1',
 'transformer.h.0.attn',
 'transformer.h.0.attn.c_attn',
 'transformer.h.0.attn.c_proj',
 'transformer.h.0.attn.attn_dropout',
 'transformer.h.0.attn.resid_dropout',
 'transformer.h.0.ln_2',
 'transformer.h.0.mlp',
 'transformer.h.0.mlp.c_fc',
 'transformer.h.0.mlp.c_proj',
 'transformer.h.0.mlp.dropout',
 'transformer.h.1',
 'transformer.h.1.ln_1',
 'transformer.h.1.attn',
 'transformer.h.1.attn.c_attn',
 'transformer.h.1.attn.c_proj',
 'transformer.h.1.attn.attn_dropout',
 'transformer.h.1.attn.resid_dropout',
 'transformer.h.1.ln_2',
 'transformer.h.1.mlp',
 'transformer.h.1.mlp.c_fc',
 'transformer.h.1.mlp.c_proj',
 'transformer.h.1.mlp.dropout',
 'transformer.h.2',
 'transformer.h.2.ln_1',
 'transformer.h.2.attn',
 'transformer.h.2.attn.c_attn',
 'transformer.h.2.attn.c_proj',
 'transformer.h.2.attn.attn_dropout',
 'transformer.h.2.attn.r

### Attention Format

In [249]:
_id = 130
ex = copy.deepcopy(processed_spider_dev[_id])
ctu2.add_basic_analysis_info_gpt2(mt_uskg_gpt2, ex)

In [250]:
ex.keys()

dict_keys(['query', 'question', 'db_id', 'db_path', 'db_table_names', 'db_column_names', 'db_column_types', 'db_primary_keys', 'db_foreign_keys', 'rat_sql_graph', 'serialized_schema', 'struct_in', 'text_in', 'seq_out', 'enc_sentence', 'tokenized_item', 'pre_sql_sequence', 'text_range', 'struct_range', 'sql_range', 'parsed_struct_in', 'alias2table', 'col2table', 'col_name_counter', 'tab_name_counter', 'struct_node_ranges_dict', 'sql_tokens', 'sql_token_ranges', 'tok_ranges2type', 'type2tok_ranges', 'sql_col_nodes', 'sql_tab_nodes', 'sql_alias_nodes'])

In [251]:
_item = ex['tokenized_item']
_item.keys()

dict_keys(['input_ids', 'attention_mask', 'predict_input_ids', 'predict_attention_mask', 'labels'])

In [252]:
dec_prompt = ex['pre_sql_sequence'] + ' ' + ex['seq_out']
dec_prompt

"What are the names of all European countries with at least 3 manufacturers?; structed knowledge: | car_1 | continents : contid, continent ( europe ) | countries : countryid, countryname, continent | car_makers : id, maker, fullname, country | model_list : modelid, maker, model | car_names : makeid, model, make | cars_data : id, mpg, cylinders, edispl, horsepower, weight, accelerate, year ; SQL: select t1.countryname from countries as t1 join continents as t2 on t1.continent = t2.contid join car_makers as t3 on t1.countryid = t3.country where t2.continent = 'europe' group by t1.countryname having count(*) >= 3;"

In [263]:
inp = ctu2.make_inputs_gpt2(mt_uskg_gpt2.tokenizer, [dec_prompt] * 2)
inp.keys(), inp['input_ids'].size()

(dict_keys(['input_ids', 'attention_mask']), torch.Size([2, 176]))

In [264]:
with torch.no_grad():
    outputs_exp = ctu2.run_model_forward_uskg_gpt2(mt_uskg_gpt2.model, **inp, output_attentions=True)

In [265]:
outputs_exp.keys()

odict_keys(['logits', 'past_key_values', 'attentions'])

In [266]:
len(outputs_exp['attentions'])

24

In [267]:
outputs_exp['attentions'][0].size()

torch.Size([2, 16, 176, 196])

In [None]:
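## outputs_exp['attentions']: tuple of n_layers tensors, each (bsz, n_heads, seq_len, prefix_len + seq_len)
## here seq_len = 176 and prefix_len = 196 - 176 = 20 (cf. the 'prefix#0'..'prefix#19' sections)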
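
The key dimension (196) is longer than the query dimension (176) by 20, which matches the 20 `prefix#k` sections seen earlier. A sketch of separating the two parts of one layer's attention is below. It assumes the prefix keys occupy the first 20 columns, which is an assumption not verified in this notebook, and it uses random NumPy data in place of real model outputs.

```python
import numpy as np

bsz, n_heads, seq_len, prefix_len = 2, 16, 176, 20   # sizes taken from the shapes above
rng = np.random.default_rng(0)
attn = rng.random((bsz, n_heads, seq_len, prefix_len + seq_len))
attn /= attn.sum(axis=-1, keepdims=True)             # rows sum to 1, like softmax output

# ASSUMPTION: prefix keys come first along the last axis.
prefix_attn = attn[..., :prefix_len]   # (bsz, n_heads, seq_len, prefix_len)
token_attn = attn[..., prefix_len:]    # (bsz, n_heads, seq_len, seq_len)

prefix_mass = prefix_attn.sum(axis=-1)   # attention mass each query puts on the prefix
print(prefix_attn.shape, token_attn.shape)
```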

### Temp debugging

In [21]:
import uskg

from transformers.models.auto.configuration_auto import AutoConfig

In [45]:
import importlib

importlib.reload(uskg.models.prompt.modeling_auto)

from uskg.models.unified import finetune, prefixtuning

from uskg.models.prompt.modeling_auto import AutoModelForPreTraining
from uskg.models.prompt.modeling_gpt2 import GPT2LMHeadModel
from uskg.models.prompt.modeling_t5 import T5ForConditionalGeneration

In [46]:
_m = AutoModelForPreTraining.from_pretrained('gpt2')

In [47]:
type(_m)

uskg.models.prompt.modeling_gpt2.GPT2LMHeadModel

In [48]:
id(type(_m))

94854136059536

In [49]:
isinstance(_m, uskg.models.prompt.modeling_gpt2.GPT2LMHeadModel)

True

In [50]:
isinstance(_m, GPT2LMHeadModel)

True

In [51]:
GPT2LMHeadModel is uskg.models.prompt.modeling_gpt2.GPT2LMHeadModel

True

In [37]:
id(uskg.models.prompt.modeling_gpt2.GPT2LMHeadModel)

94854135704752

In [15]:
type(_m) is uskg.models.prompt.modeling_gpt2.GPT2LMHeadModel

True

In [67]:
type(_m) is GPT2LMHeadModel

False
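
One plausible explanation for the `False` above is the `importlib.reload` call. This is an assumption, since the cell numbers show the cells were run out of order. Reloading a module re-executes its source and creates new class objects, so an instance built before the reload no longer matches the class found in the module afterwards. The sketch below reproduces this with a throwaway module; `toy_models_mod` and `ToyModel` are hypothetical names.

```python
import importlib
import os
import sys
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    # Write a tiny module to disk so that it can be imported and reloaded.
    with open(os.path.join(tmp, 'toy_models_mod.py'), 'w') as f:
        f.write('class ToyModel:\n    pass\n')
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    try:
        toy_mod = importlib.import_module('toy_models_mod')
        OldToyModel = toy_mod.ToyModel
        obj = OldToyModel()                      # instance of the pre-reload class

        importlib.reload(toy_mod)                # re-executes the class statement

        same_class = type(obj) is toy_mod.ToyModel          # new class object
        still_instance = isinstance(obj, toy_mod.ToyModel)
        matches_old = type(obj) is OldToyModel              # old reference still matches
    finally:
        sys.path.remove(tmp)
        sys.modules.pop('toy_models_mod', None)

print(same_class, still_instance, matches_old)
```

Names bound with `from module import Cls` before a reload keep pointing at the old class, so re-running the import cell and re-creating the model after a reload brings the two back in sync.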

In [69]:
isinstance(_m, type(_m))

True

In [64]:
type(_m)

uskg.models.prompt.modeling_gpt2.GPT2LMHeadModel

In [65]:
_m.__class__

uskg.models.prompt.modeling_gpt2.GPT2LMHeadModel

In [87]:
_cfg = AutoConfig.from_pretrained('gpt2')

In [88]:
_cfg

GPT2Config {
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "gradient_checkpointing": false,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "resid_pdrop": 0.1,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.9.2",
  "use_cache": true,
  "vocab_size": 50257
}

In [None]:
## T5
_m2 = AutoModelForPreTraining.from_pretrained('t5-base')

In [83]:
isinstance(_m2, uskg.models.prompt.modeling_t5.T5ForConditionalGeneration)

True

In [84]:
isinstance(_m2, T5ForConditionalGeneration)

True