<a href="https://colab.research.google.com/github/kmeng01/rome/blob/main/notebooks/causal_trace.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

In [2]:
IS_COLAB = False

## Causal Tracing

A demonstration of the double-intervention causal tracing method.

The strategy used by causal tracing is to understand important
states within a transformer by doing two interventions simultaneously:

1. Corrupt a subset of the input.  In our paper, we corrupt the subject tokens
   to frustrate the ability of the transformer to accurately complete factual
   prompts about the subject.
2. Restore a subset of the internal hidden states.  In our paper, we scan
   hidden states at all layers and all tokens, searching for individual states
   that carry the necessary information for the transformer to recover its
   capability to complete the factual prompt.

The traces of decisive states can be shown on a heatmap.  This notebook
demonstrates the code for conducting causal traces and creating these heatmaps.
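
The two interventions can be illustrated on a toy model. The sketch below is not the implementation in `experiments.causal_trace`: `toy_forward` and `toy_causal_trace` are hypothetical names, and the "transformer" is only a stack of layers that mix each token with the sequence mean. It shows the corrupt-then-restore loop that yields one score per (layer, token) cell of a heatmap, under those assumptions.

```python
import random

def toy_forward(embeddings, n_layers=3, restore=None):
    """Toy stand-in for a transformer (assumption, not the real model).

    Each layer averages every token with the sequence mean. `restore`
    maps (layer, token) -> hidden value to overwrite after that layer.
    Returns the last token's final state and all per-layer states.
    """
    restore = restore or {}
    h = list(embeddings)
    trace = []
    for layer in range(n_layers):
        mean = sum(h) / len(h)
        h = [0.5 * x + 0.5 * mean for x in h]
        for tok in range(len(h)):
            if (layer, tok) in restore:
                h[tok] = restore[(layer, tok)]
        trace.append(list(h))
    return h[-1], trace

def toy_causal_trace(embeddings, subject_range, noise=3.0, seed=0, n_layers=3):
    """Corrupt the subject tokens, then restore one clean state at a time."""
    rng = random.Random(seed)
    clean_out, clean_trace = toy_forward(embeddings, n_layers)

    # Intervention 1: add Gaussian noise to the subject token embeddings.
    corrupted = list(embeddings)
    for tok in range(*subject_range):
        corrupted[tok] += rng.gauss(0.0, noise)
    corrupt_out, _ = toy_forward(corrupted, n_layers)

    # Intervention 2: restore a single clean hidden state and record how far
    # the output still is from the clean output (0.0 means fully recovered).
    grid = []
    for layer in range(n_layers):
        row = []
        for tok in range(len(embeddings)):
            patch = {(layer, tok): clean_trace[layer][tok]}
            out, _ = toy_forward(corrupted, n_layers, restore=patch)
            row.append(abs(out - clean_out))
        grid.append(row)
    return clean_out, corrupt_out, grid

clean_out, corrupt_out, grid = toy_causal_trace([1.0, 2.0, 3.0, 4.0], (0, 2))
```

In this toy setting, restoring the last token at the last layer recovers the clean output exactly, because the output is read directly from that state.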

In [3]:
%load_ext autoreload
%autoreload 2

The `experiments.causal_trace` module contains a set of functions for running causal traces.

In this notebook, we reproduce, demonstrate and discuss the interesting functions.

We begin by importing several utility functions that deal with tokens and transformer models.

In [4]:
import os, sys, re, json
import string
import torch
import numpy as np
import copy
from collections import defaultdict, Counter
from util import nethook
from util.globals import DATA_DIR
from experiments.causal_trace import (
    ModelAndTokenizer,
    layername,
    guess_subject,
    plot_trace_heatmap,
)
from experiments.causal_trace import (
    make_inputs,
    decode_tokens,
    find_token_range,
    predict_token,
    predict_from_input,
    collect_embedding_std,
)
from dsets import KnownsDataset

In [5]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f79bc4e45e0>

In [6]:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

## USKG

In [7]:
from transformers import (
    HfArgumentParser,
    set_seed,
    AutoTokenizer
)

# from uskg.models.unified.prefixtuning import Model
from uskg.models.unified import finetune, prefixtuning
from uskg.utils.configue import Configure
from uskg.utils.training_arguments import WrappedSeq2SeqTrainingArguments
from uskg.seq2seq_construction import spider as s2s_spider
from uskg.third_party.spider.preprocess.get_tables import dump_db_json_schema
from uskg.third_party.spider import evaluation as sp_eval
from tqdm.notebook import tqdm

# from nltk.stem.wordnet import WordNetLemmatizer
# import stanza

import matplotlib.pyplot as plt
import sqlite3

from experiments import causal_trace_uskg as ctu

In [82]:
mt_uskg = ctu.ModelAndTokenizer_USKG('t5-large-prefix')

Using tokenizer_uskg: hkunlp/from_all_T5_large_prefix_spider_with_cell_value2
Using tokenizer_fast: t5-large
prefix-tuning sequence length is 10.


In [83]:
list(mt_uskg.task_args.seq2seq)

[('constructor', 'seq2seq_construction.spider'),
 ('schema_serialization_with_db_content', True),
 ('target_with_db_id', False)]

In [84]:
mt_uskg.model.pretrain_model.encoder.embed_tokens is mt_uskg.model.pretrain_model.shared, \
mt_uskg.model.pretrain_model.decoder.embed_tokens is mt_uskg.model.pretrain_model.shared

(True, False)

In [85]:
mt_uskg.model.preseqlen

10

In [86]:
# [k for k,v in mt_uskg.model.named_parameters()]
# [k for k,v in mt_uskg.model.named_modules()]

In [87]:
inp = ctu.make_inputs_t5(
    mt_uskg.tokenizer,
    enc_sentences=["Translate to German: My name is Wolfgang and I live in Berlin"],
    dec_prompts=["Mein Name ist Wolfgang"],
    device="cuda:0"
)

In [88]:
out = ctu.run_model_forward_uskg(mt_uskg.model, **inp)

In [89]:
out.keys(), out['logits'].size()

(odict_keys(['logits', 'past_key_values', 'encoder_last_hidden_state']),
 torch.Size([1, 5, 32102]))

In [90]:
logits = out["logits"][0, -1].detach().cpu().numpy()
logits.shape

(32102,)

In [91]:
top_5 = sorted(list(enumerate(logits)), key=lambda p: -p[1])[:5]
top_5

[(11, -1.8642352),
 (6, -9.727753),
 (5, -10.966707),
 (27, -11.037394),
 (213, -12.864212)]

In [92]:
[mt_uskg.tokenizer.decode([p[0]]) for p in top_5]

['and', ',', '.', 'I', 'where']

### Load spider dataset

In [93]:
spider_train_path = '/home/yshao/Projects/SDR-analysis/data/spider/train+ratsql_graph.json'
spider_dev_path = '/home/yshao/Projects/SDR-analysis/data/spider/dev+ratsql_graph.json'
spider_db_dir = '/home/yshao/Projects/language/language/xsp/data/spider/database'

In [94]:
raw_spider_dev = ctu.load_raw_dataset(
    data_filepath = spider_dev_path,
    db_path=spider_db_dir,
#     schema_cache=SCHEMA_CACHE
)
len(raw_spider_dev)

1034

In [95]:
raw_spider_dev[0].keys()

dict_keys(['query', 'question', 'db_id', 'db_path', 'db_table_names', 'db_column_names', 'db_column_types', 'db_primary_keys', 'db_foreign_keys', 'rat_sql_graph'])

In [96]:
mt_uskg.task_args.dataset.use_cache

True

In [97]:
processed_spider_dev = s2s_spider.DevDataset(
    args=mt_uskg.task_args,
    raw_datasets=raw_spider_dev,
    cache_root='../cache')

In [98]:
_id = 130
processed_spider_dev[_id]['text_in'], \
processed_spider_dev[_id]['struct_in'], \
processed_spider_dev[_id]['seq_out']

('What are the names of all European countries with at least 3 manufacturers?',
 '| car_1 | continents : contid , continent ( europe ) | countries : countryid , countryname , continent | car_makers : id , maker , fullname , country | model_list : modelid , maker , model | car_names : makeid , model , make | cars_data : id , mpg , cylinders , edispl , horsepower , weight , accelerate , year',
 "select t1.countryname from countries as t1 join continents as t2 on t1.continent = t2.contid join car_makers as t3 on t1.countryid = t3.country where t2.continent = 'europe' group by t1.countryname having count(*) >= 3;")

In [99]:
_enc_sentence = f"{processed_spider_dev[_id]['text_in']}; structed knowledge: {processed_spider_dev[_id]['struct_in']}"
_toks = mt_uskg.tokenizer.tokenize(_enc_sentence)
len(_toks)

142

In [100]:
# # _occ_punct = set()

# for _id in range(len(processed_spider_dev)):
#     ex = processed_spider_dev[_id]
# #     _occ_punct.update(set(string.punctuation) & set(ex['seq_out']))
#     if '_(' in ex['struct_in']:
#         print(_id, ex['question'])
#         print(ex['struct_in'])
#         print(ex['seq_out'])
#         print()

In [101]:
# ## Train set

# raw_spider_train = ctu.load_raw_dataset(
#     data_filepath = spider_train_path,
#     db_path=spider_db_dir,
# )
# processed_spider_train = s2s_spider.TrainDataset(
#     args=mt_uskg.task_args,
#     raw_datasets=raw_spider_train,
#     cache_root='../cache')
# len(processed_spider_train)

In [102]:
# processed_spider_train[5441]

### Helpers
- Merged into `create_analysis_sample_dicts()`

#### Evaluator

In [121]:
table_path = '/home/yshao/Projects/language/language/xsp/data/spider/tables.json'
db_dir = '/home/yshao/Projects/language/language/xsp/data/spider/database'

In [122]:
kmaps = sp_eval.build_foreign_key_map_from_json(table_path)
evaluator = sp_eval.Evaluator(db_dir=db_dir, kmaps=kmaps, etype='all')

In [123]:
ctu.evaluate_hardness.evaluator = evaluator

In [124]:
# test
_sql_str = 'select t1.birth_date from people as t1 join poker_player as t2 on t1.people_id = t2.people_id order by t2.earnings asc limit 1'
db_name = 'poker_player'
schema = evaluator.schemas[db_name]
_sql = sp_eval.get_sql(schema, _sql_str)
sp_eval.count_component1(_sql), sp_eval.count_component2(_sql), sp_eval.count_others(_sql), \
evaluator.eval_hardness(_sql)

(3, 0, 0, 'hard')

#### Hardness

In [125]:
ctu.evaluate_hardness(_sql_str, db_name, evaluator=evaluator)

'hard'

In [126]:
ctu.evaluate_hardness.evaluator

<uskg.third_party.spider.evaluation.Evaluator at 0x7f7814ed2040>

#### Node role

In [None]:
dec_prompt = 'select avg(age), min(age), max(age) from'
ctu.detect_node_role(dec_prompt)

#### Text match

In [None]:
a_dicts = ctu.create_analysis_sample_dicts(
    mt=mt_uskg,
    ex=processed_spider_dev[100],
    subject_type='table'
)
len(a_dicts), [d['expect'] for d in a_dicts]

In [None]:
a_ex = a_dicts[2]
ctu.check_table_text_match(a_ex, 'car_names')

In [None]:
a_ex['text_in']

### Exp-5.0: dirty attention vector effect 

#### Load & Check results

In [255]:
expect_type = 'table_alias'
res_path = f'/home/yshao/Projects/rome/results/exp5_0_dirty_attention_vector_effect/exp=5_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

with open(res_path, 'r') as f:
    all_samples = [json.loads(l) for l in f]
len(all_samples)

1034

In [256]:
total_samples = 0
n_good_samples = 0
n_too_hard = 0      # wrong answer 
n_too_easy = 0      # base - low < 0.5

In [257]:
good_samples = []
bad_samples = []

for i, ex in enumerate(all_samples):
    for d in ex['trace_results']:
        total_samples += 1
#         # TEMP adjustment for column results 
#         d['low_score'] = d['trace_scores']['high_layers_corrupt'].get("0", 0.0)  # "0" is key (for layer 0), 0.0 is default 
#         if d['base_score'] - d['low_score'] < 0.5:
#             d['is_good_sample'] = False
#         # END_TEMP
        if d['is_good_sample']:
            n_good_samples += 1
            d['ex_id'] = i
            good_samples.append(d)
        elif not d['correct_prediction']:
            n_too_hard += 1
            bad_samples.append(d)
        else:
            assert d['base_score'] - d['low_score'] < 0.5, (i, d)
            n_too_easy += 1
            bad_samples.append(d)
            
total_samples, (n_good_samples, len(good_samples)), n_too_hard, n_too_easy, len(bad_samples), \
n_good_samples + n_too_easy

(2039, (164, 164), 339, 1536, 1875, 1700)
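
The same three-way split is repeated for each experiment in this notebook, so it could be factored into a helper. The sketch below is one possible version; `split_trace_results` is a hypothetical name, and unlike the cell above it copies each good record instead of adding `ex_id` in place. The demo records are made up for illustration.

```python
def split_trace_results(all_samples):
    """Split trace results into good / too-hard / too-easy lists.

    Relies only on the `is_good_sample` and `correct_prediction` fields
    stored in each record, as in the loop above.
    """
    good, too_hard, too_easy = [], [], []
    for ex_id, ex in enumerate(all_samples):
        for d in ex['trace_results']:
            if d['is_good_sample']:
                # Keep the index of the source example for later lookup.
                good.append({**d, 'ex_id': ex_id})
            elif not d['correct_prediction']:
                too_hard.append(d)   # the model's answer was wrong
            else:
                too_easy.append(d)   # correct, but corruption had little effect
    return good, too_hard, too_easy

# Made-up records, for illustration only.
demo_samples = [
    {'trace_results': [
        {'is_good_sample': True, 'correct_prediction': True},
        {'is_good_sample': False, 'correct_prediction': False},
    ]},
    {'trace_results': [
        {'is_good_sample': False, 'correct_prediction': True},
    ]},
]
demo_good, demo_hard, demo_easy = split_trace_results(demo_samples)
```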

In [None]:
[s for s in bad_samples if s['correct_prediction']][0]

In [None]:
good_samples[0]

#### Overall avg

In [193]:
trace_scores_avg = {k: {str(l): 0 for l in range(24)} for k in good_samples[0]['trace_scores'].keys()}

In [194]:
for d in good_samples:
    for k, layer_d in d['trace_scores'].items():
        for l, s in layer_d.items():
            trace_scores_avg[k][l] += s

for k, layer_d in trace_scores_avg.items():
    for l, s in layer_d.items():
        layer_d[l] = s / len(good_samples)

In [None]:
trace_scores_avg

#### Avg by aspects (category)
- Still roughly linear, as in Exp-2.2

In [243]:
d['category']

{'sql_hardness': 'hard', 'node_role': 'where', 'text_match': 'partial'}

In [265]:
# Key: (trace_key, aspect, asp_val, layer) -> [scores]
trace_scores_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
trace_scores_avg_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
trace_scores_cnt_by_aspect = defaultdict(lambda: defaultdict(int))  # no trace key & layer key 

In [266]:
for d in good_samples:
    for trace_k, trace_layer_d in d['trace_scores'].items():
        for aspect, asp_val in d['category'].items():
            for l, s in trace_layer_d.items():
                trace_scores_by_aspect[trace_k][aspect][asp_val][l].append(s)

for trace_k, d1 in trace_scores_by_aspect.items():
    for asp_k, d2 in d1.items():
        for asp_v, d3 in d2.items():
            for l, s in d3.items():
                trace_scores_avg_by_aspect[trace_k][asp_k][asp_v][l] = np.mean(s)
                trace_scores_cnt_by_aspect[asp_k][asp_v] = len(s)

In [267]:
trace_scores_cnt_by_aspect

defaultdict(<function __main__.<lambda>()>,
            {'sql_hardness': defaultdict(int,
                         {'medium': 400,
                          'hard': 155,
                          'easy': 142,
                          'extra': 170}),
             'node_role': defaultdict(int,
                         {'where': 268,
                          'select': 412,
                          'order by': 66,
                          'join': 92,
                          'group by': 25,
                          'having': 4}),
             'text_match': defaultdict(int,
                         {'no-match': 360, 'partial': 148, 'exact': 359})})
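
The four-level `defaultdict` nesting can be hard to read. As an alternative sketch, the same grouping can be expressed with flat tuple keys; `average_by_aspect` is a hypothetical helper, not part of `ctu`, and it assumes every sample has `trace_scores` and `category` dicts shaped like the ones above.

```python
from collections import defaultdict
from statistics import mean

def average_by_aspect(samples):
    """Group scores by (trace_key, aspect, aspect_value, layer) and average.

    Returns (avg, counts). `counts` is keyed by (aspect, aspect_value) and,
    as in the nested version, assumes each sample contributes one score
    per trace key and layer.
    """
    scores = defaultdict(list)
    for d in samples:
        for trace_k, layer_d in d['trace_scores'].items():
            for aspect, asp_val in d['category'].items():
                for layer, s in layer_d.items():
                    scores[(trace_k, aspect, asp_val, layer)].append(s)
    avg = {key: mean(vals) for key, vals in scores.items()}
    counts = {(key[1], key[2]): len(vals) for key, vals in scores.items()}
    return avg, counts

# Made-up samples, for illustration only.
demo = [
    {'trace_scores': {'t': {'0': 0.25}}, 'category': {'sql_hardness': 'easy'}},
    {'trace_scores': {'t': {'0': 0.75}}, 'category': {'sql_hardness': 'easy'}},
]
demo_avg, demo_counts = average_by_aspect(demo)
```

Flat keys trade the convenient `d[trace_k][aspect]` slicing for simpler construction; which form is preferable depends on how the result is consumed.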

In [None]:
trace_scores_avg_by_aspect['high_layers_corrupt']

### Exp-5.2: attention section removal effect

#### Load & Check

In [679]:
expect_type = 'table'

res_path = f'/home/yshao/Projects/rome/results/exp5_2_attention_section_removal_effect/exp=5.2.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

with open(res_path, 'r') as f:
    all_samples = [json.loads(l) for l in f]
len(all_samples)

1034

In [680]:
total_samples = 0
n_good_samples = 0
n_too_hard = 0      # wrong answer 
n_too_easy = 0      # base - low < 0.5

In [681]:
good_samples = []
bad_samples = []

for i, ex in enumerate(all_samples):
    for d in ex['trace_results']:
        total_samples += 1

        if d['is_good_sample']:
            n_good_samples += 1
            d['ex_id'] = i
            good_samples.append(d)
        elif not d['correct_prediction']:
            n_too_hard += 1
            bad_samples.append(d)
        else:
            assert d['base_score'] - d['low_score'] < 0.5
            n_too_easy += 1
            bad_samples.append(d)
            
total_samples, (n_good_samples, len(good_samples)), n_too_hard, n_too_easy, len(bad_samples),\
f'good / correct = {n_good_samples} / {n_good_samples + n_too_easy}'

(1683, (1207, 1207), 136, 340, 476, 'good / correct = 1207 / 1547')

In [682]:
[s for s in bad_samples if not s['correct_prediction']][0]

{'enc_sentence': 'Find the number of concerts happened in the stadium with the highest capacity .; structed knowledge: | concert_singer | stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id',
 'seq_out': 'select count(*) from concert where stadium_id = (select stadium_id from stadium order by capacity desc limit 1)',
 'dec_prompt': 'select count(*) from',
 'expect': 'concert',
 'expect_type': 'table',
 'db_id': 'concert_singer',
 'expect_input_ranges': [[88, 89]],
 'expect_table': 'concert',
 'answer': 'stadium',
 'base_score': 0.9941871166229248,
 'answers_t': [14939],
 'correct_prediction': False,
 'category': {'sql_hardness': 'hard',
  'node_role': 'from',
  'text_match': 'exact'},
 'self_ranges': [[87, 91]],
 'struct_context_ranges': [[22, 87], [91, 132]],
 'is_

In [None]:
good_samples[0]

#### Overall avg

In [647]:
trace_scores_avg = {sect_k : defaultdict(int) for sect_k in good_samples[0]['trace_scores'].keys()}

for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for k, v in sect_d.items():
            if k == 'window':
                for l, s in v.items():
                    trace_scores_avg[sect_k][f'{k}-{l}'] += s
            else:
                s = v
                trace_scores_avg[sect_k][k] += s

for sect_k, sect_d in trace_scores_avg.items():
    for k, s in sect_d.items():
        sect_d[k] = s / len(good_samples)
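
The cell above flattens each section's nested `window` dict into `window-<layer>` keys before averaging. A minimal sketch of that step in isolation is given here; both function names are hypothetical, and the layout (scalar entries such as `all_layers` next to a `window` sub-dict) is inferred from the loop above.

```python
def flatten_section_scores(sect_d):
    """Turn {'all_layers': s, 'window': {'0': s0, ...}} into one flat dict."""
    flat = {}
    for k, v in sect_d.items():
        if k == 'window':
            for layer, s in v.items():
                flat[f'window-{layer}'] = s
        else:
            flat[k] = v
    return flat

def average_flat_scores(sections):
    """Average a list of section dicts key by key.

    Assumes all dicts share the same keys, as the notebook's data appears to.
    """
    totals = {}
    for sect_d in sections:
        for k, s in flatten_section_scores(sect_d).items():
            totals[k] = totals.get(k, 0.0) + s
    return {k: s / len(sections) for k, s in totals.items()}

# Made-up scores, for illustration only.
demo_sections = [
    {'all_layers': 0.5, 'window': {'0': 1.0}},
    {'all_layers': 0.0, 'window': {'0': 0.5}},
]
demo_flat_avg = average_flat_scores(demo_sections)
```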

In [None]:
trace_scores_avg

#### Avg by aspects (category)

In [696]:
# TEMP patch for node_len category 
for d in good_samples + bad_samples:
    node_len = len(d['answers_t'])
    assert len(mt_uskg.tokenizer.tokenize(d['expect'])) == node_len, (d['expect'], node_len)
    d['category']['node_len'] = str(node_len) if node_len <= 3 else '4+'
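
The bucketing rule in the patch above (exact length up to 3 tokens, then `'4+'`) can be written as a small standalone function. This is only a sketch; `node_len_bucket` and its `max_exact` parameter are hypothetical names.

```python
def node_len_bucket(n_tokens, max_exact=3):
    """Return the token count as a string, capped at '<max_exact + 1>+'."""
    if n_tokens <= max_exact:
        return str(n_tokens)
    return f'{max_exact + 1}+'

# Buckets for a few token counts.
demo_buckets = [node_len_bucket(n) for n in (1, 3, 4, 9)]
```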

In [701]:
d['category']

{'sql_hardness': 'medium',
 'node_role': 'from',
 'text_match': 'exact',
 'node_len': '1'}

In [651]:
# Key: (sect_k, aspect, asp_val, layer) -> [scores]
trace_scores_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
trace_scores_avg_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
trace_scores_cnt_by_aspect = defaultdict(lambda: defaultdict(int))  # no sect key & layer key 

In [652]:
for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for aspect, asp_val in d['category'].items():
            for k, v in sect_d.items():
                if k == 'window':
                    for l, s in v.items():
                        if not (int(l) % 4 == 3): continue
                        layer_k = f'{k}-{l}'
                        trace_scores_by_aspect[sect_k][aspect][asp_val][layer_k].append(s)
                else:
                    layer_k = k
                    s = v
                    trace_scores_by_aspect[sect_k][aspect][asp_val][layer_k].append(s)
                    
for sect_k, d1 in trace_scores_by_aspect.items():
    for asp_k, d2 in d1.items():
        for asp_v, d3 in d2.items():
            for layer_k, s in d3.items():
                trace_scores_avg_by_aspect[sect_k][asp_k][asp_v][layer_k] = np.mean(s)
                trace_scores_cnt_by_aspect[asp_k][asp_v] = len(s)

In [653]:
for sect_k, sect_d in trace_scores_avg_by_aspect.items():
    sect_d['overall'] = dict()
    for layer_k, s in trace_scores_avg[sect_k].items():
        if layer_k.startswith('window'):
            # only keep a subset of layers 
            _, l = layer_k.split('-')
            if not (int(l) % 4 == 3): continue
        sect_d['overall'][layer_k] = s

In [654]:
trace_scores_cnt_by_aspect

defaultdict(<function __main__.<lambda>()>,
            {'sql_hardness': defaultdict(int,
                         {'medium': 378,
                          'hard': 148,
                          'easy': 134,
                          'extra': 160}),
             'node_role': defaultdict(int,
                         {'where': 248,
                          'select': 393,
                          'order by': 63,
                          'join': 91,
                          'group by': 21,
                          'having': 4}),
             'text_match': defaultdict(int,
                         {'no-match': 345, 'partial': 142, 'exact': 333}),
             'node_len': defaultdict(int,
                         {'1': 329, '3': 238, '4+': 160, '2': 93})})

In [None]:
trace_scores_avg_by_aspect['self']

In [641]:
dump_d = ctu.nested_json_processing(trace_scores_avg_by_aspect, func=lambda x: np.format_float_positional(x, precision=4, min_digits=4))
# dump_d

In [642]:
dump_path = f'/home/yshao/Projects/rome/results/exp5_2_attention_section_removal_effect/summ-exp=5.2.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

with open(dump_path, 'w') as f:
    json.dump(dump_d, f, indent=1)
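
`ctu.nested_json_processing` is not shown in this notebook. Judging only from the call above, it appears to apply `func` to every leaf of a nested structure so that floats are written with fixed precision. A minimal sketch of such a helper, under that assumption, might look as follows; `apply_to_leaves` is a hypothetical name and may differ from the real implementation.

```python
def apply_to_leaves(obj, func):
    """Recursively apply `func` to every non-container value.

    Dicts (including defaultdicts) become plain dicts; lists and tuples
    become lists, which is what JSON serialization would produce anyway.
    """
    if isinstance(obj, dict):
        return {k: apply_to_leaves(v, func) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [apply_to_leaves(v, func) for v in obj]
    return func(obj)

# Format every score with four decimal places (made-up input).
demo_formatted = apply_to_leaves(
    {'self': {'overall': {'all_layers': 0.5}}, 'scores': [1.0]},
    lambda x: f'{x:.4f}',
)
```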

#### (one-time temp patch)

In [531]:
# expect_type = 'table_alias'
# orig_res_path = f'/home/yshao/Projects/rome/results/exp5_2_attention_section_removal_effect/no_structcontext-exp=5.2.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'
# add_res_path = f'/home/yshao/Projects/rome/results/exp5_2_attention_section_removal_effect/exp=5.2.1+structcontext_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

# merge_res_path = f'/home/yshao/Projects/rome/results/exp5_2_attention_section_removal_effect/exp=5.2.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

In [532]:
# with open(orig_res_path, 'r') as f:
#     orig_all_samples = [json.loads(l) for l in f]
# with open(add_res_path, 'r') as f:
#     add_all_samples = [json.loads(l) for l in f]

# f = open(merge_res_path, 'w')
    
# for i, (orig_ex, add_ex) in enumerate(zip(orig_all_samples, add_all_samples)):
#     assert len(orig_ex['trace_results']) == len(add_ex['trace_results']), i
#     # There is randomness in the order of expected node (from set()), thus sorting here 
#     orig_ex['trace_results'].sort(key=lambda d: len(d['dec_prompt']))
#     add_ex['trace_results'].sort(key=lambda d: len(d['dec_prompt']))
#     for j, (orig_d, add_d) in enumerate(zip(orig_ex['trace_results'], add_ex['trace_results'])):
#         assert orig_d['is_good_sample'] == add_d['is_good_sample'], (i, j)
#         if not orig_d['is_good_sample']:
#             continue
            
#         # is good sample: add the new sections 
#         orig_d['trace_scores']['struct_context'] = add_d['trace_scores']['struct_context']
#         orig_d['trace_scores']['text+struct_context'] = add_d['trace_scores']['text+struct_context']
        
#     f.write(json.dumps(orig_ex, indent=None) + '\n')
    
# f.close()

#### Single samples observations

In [702]:
good_samples[0]['trace_scores'].keys()

dict_keys(['prefix', 'text', 'struct', 'text+struct', 'all', 'self', 'struct_context', 'text+struct_context'])

In [703]:
_id = 0

d = good_samples[_id]

check_info_d = defaultdict(dict)

for sect_k, sect_d in d['trace_scores'].items():
    for layer_k, s in sect_d.items():
        if layer_k == 'window':
            layer_k = 'window-19'
            s = s['19']
        if s < 0.5:
            check_info_d[sect_k][layer_k] = s

In [704]:
print(json.dumps(check_info_d, indent=2))

{
  "text+struct": {
    "all_layers": 0.3990614414215088
  },
  "all": {
    "all_layers": 0.4772564172744751
  }
}


##### Layer

In [705]:
### Check "breaking" window layers, i.e. those with a sudden drop in score
### For now: the first layer whose score drops by more than _th from the previous layer

_th = 0.4
check_info_l = []
for i, d in enumerate(good_samples):
#     for sect_k, sect_d in d['trace_scores'].items():
    sect_k = 'all'
    sect_d = d['trace_scores'][sect_k]
    window_d = sect_d['window']
    for l in range(1, 24):
        if window_d[str(l-1)] - window_d[str(l)] > _th:
            _info_d = {
                'id': i,
                'sect_k': sect_k,
                'layer': l,
                'last_layer_score': window_d[str(l-1)],
                'this_layer_score': window_d[str(l)],
            }
            check_info_l.append(_info_d)
            break
len(check_info_l)

844
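
The break-layer rule used above (the first layer whose window score drops by more than `_th` relative to the previous layer) can be isolated into a function that is easy to test on its own. This is a sketch; `first_break_layer` is a hypothetical name and the demo scores are made up.

```python
def first_break_layer(window_d, threshold=0.4, n_layers=24):
    """Return the first layer l with window_d[l-1] - window_d[l] > threshold.

    `window_d` maps layer indices stored as strings ('0', '1', ...) to scores.
    Returns None if no consecutive pair of layers drops by more than
    `threshold`.
    """
    for l in range(1, n_layers):
        if window_d[str(l - 1)] - window_d[str(l)] > threshold:
            return l
    return None

# Made-up scores: high until layer 6, low from layer 7 onward.
demo_window = {str(l): (0.9 if l < 7 else 0.1) for l in range(24)}
demo_break = first_break_layer(demo_window)
```

Returning `None` for samples without a break makes the difference between the number of breaks and the number of good samples explicit.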

In [706]:
len(check_info_l), len(good_samples)

(844, 1207)

In [707]:
break_layer_counter = Counter([_d['layer'] for _d in check_info_l])
sorted(break_layer_counter.items())

[(1, 10),
 (2, 25),
 (3, 31),
 (4, 81),
 (5, 24),
 (6, 3),
 (7, 15),
 (8, 87),
 (9, 69),
 (10, 19),
 (11, 38),
 (12, 36),
 (13, 70),
 (14, 101),
 (15, 27),
 (16, 31),
 (17, 31),
 (18, 79),
 (19, 67)]

In [None]:
for info_d in check_info_l:
    if info_d['layer'] < 6:
        print(info_d)
        sample_id = info_d['id']
        d = good_samples[sample_id]
        print(d['enc_sentence'])
        print(d['dec_prompt'], '---->', d['expect'])
        print('Categories:', d['category'])
        print('--' * 20)

In [713]:
sample_counter_by_aspect = defaultdict(Counter)  # [asp_k, asp_v] -> count 
sample_counter = Counter()

for info_d in check_info_l:
    if info_d['layer'] < 6:
        sample_id = info_d['id']
        d = good_samples[sample_id]
        text_match = d['category']['text_match']
        node_len = d['category']['node_len']
        sample_counter[(text_match, node_len)] += 1
        
        for asp_k, asp_v in d['category'].items():
            sample_counter_by_aspect[asp_k][asp_v] += 1

In [715]:
sample_counter_by_aspect

defaultdict(collections.Counter,
            {'sql_hardness': Counter({'medium': 58,
                      'hard': 42,
                      'extra': 51,
                      'easy': 20}),
             'node_role': Counter({'from': 105, 'join': 66}),
             'text_match': Counter({'partial': 59,
                      'no-match': 74,
                      'exact': 38}),
             'node_len': Counter({'4+': 60, '3': 51, '2': 24, '1': 36})})

In [716]:
sample_counter.most_common()

[(('partial', '3'), 30),
 (('partial', '4+'), 29),
 (('exact', '1'), 22),
 (('no-match', '3'), 21),
 (('no-match', '4+'), 21),
 (('no-match', '2'), 18),
 (('no-match', '1'), 14),
 (('exact', '4+'), 10),
 (('exact', '2'), 6)]

In [717]:
sample_counter_by_aspect = defaultdict(Counter)  # [asp_k, asp_v] -> count 
sample_counter = Counter()

for info_d in check_info_l:
    if info_d['layer'] > 18:
        sample_id = info_d['id']
        d = good_samples[sample_id]
        text_match = d['category']['text_match']
        node_len = d['category']['node_len']
        sample_counter[(text_match, node_len)] += 1
        
        for asp_k, asp_v in d['category'].items():
            sample_counter_by_aspect[asp_k][asp_v] += 1

In [718]:
sample_counter_by_aspect

defaultdict(collections.Counter,
            {'sql_hardness': Counter({'hard': 9,
                      'medium': 27,
                      'easy': 11,
                      'extra': 20}),
             'node_role': Counter({'join': 21, 'from': 46}),
             'text_match': Counter({'exact': 46,
                      'no-match': 20,
                      'partial': 1}),
             'node_len': Counter({'3': 12, '1': 46, '2': 6, '4+': 3})})

In [719]:
sample_counter.most_common()

[(('exact', '1'), 38),
 (('no-match', '3'), 9),
 (('no-match', '1'), 8),
 (('exact', '3'), 3),
 (('no-match', '2'), 3),
 (('exact', '2'), 3),
 (('exact', '4+'), 2),
 (('partial', '4+'), 1)]

In [743]:
# _min_p = 1.0

# for i, d in enumerate(good_samples):
#     sect_k = 'text'
#     sect_d = d['trace_scores'][sect_k]
#     _min_p = min(_min_p, sect_d['all_layers'])

# _min_p

7.385318918917694e-12

In [751]:
### Systematic 

ob_sect_k = 'all'

_th = 0.4

check_info_l = []
all_layers_eff_cnt = 0  # for this section to observe, how many samples are effective with all_layers
window_eff_cnt = 0      # for this section to observe, how many samples are effective with any window 

for i, d in enumerate(good_samples):
#     for sect_k, sect_d in d['trace_scores'].items():
    sect_k = ob_sect_k
    sect_d = d['trace_scores'][sect_k]
    if sect_d['all_layers'] > 0.5:
        # not effective
        continue
    else:
        all_layers_eff_cnt += 1
        
    if min(sect_d['window'].values()) > 0.5:
        # not effective
        continue
    else:
        window_eff_cnt += 1
        
    window_d = sect_d['window']
    for l in range(1, 24):
        if window_d[str(l-1)] - window_d[str(l)] > _th:
            _info_d = {
                'id': i,
                'sect_k': sect_k,
                'layer': l,
                'last_layer_score': window_d[str(l-1)],
                'this_layer_score': window_d[str(l)],
            }
            check_info_l.append(_info_d)
            break
len(check_info_l), window_eff_cnt, all_layers_eff_cnt, len(good_samples)

(840, 916, 1207, 1207)

In [745]:
ctg_list = [(tm, nl) for tm in ['exact', 'partial', 'no-match'] for nl in ['1', '2', '3', '4+']]
layer_list = [str(l) for l in range(1, 24)]

ctg_elem2id = {elem : i for i, elem in enumerate(ctg_list)}
layer_elem2id = {elem : i for i, elem in enumerate(layer_list)}

cnt_matrix = np.zeros((len(ctg_list), len(layer_list)), int)

for info_d in check_info_l:
    sample_id = info_d['id']
    d = good_samples[sample_id]
    text_match = d['category']['text_match']
    node_len = d['category']['node_len']
    _ctg = (text_match, node_len)
    _layer = str(info_d['layer'])
    
    _ctg_idx = ctg_elem2id[_ctg]
    _layer_idx = layer_elem2id[_layer]
    cnt_matrix[_ctg_idx, _layer_idx] += 1
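
The category-by-layer count matrix can also be built without index lookup tables. The sketch below uses `collections.Counter` and returns nested lists; `build_count_matrix` is a hypothetical helper. Unlike the cell above, it silently ignores pairs whose labels are not listed instead of raising a `KeyError`.

```python
from collections import Counter

def build_count_matrix(pairs, row_labels, col_labels):
    """Count (row, col) pairs into a len(row_labels) x len(col_labels) grid."""
    counts = Counter(pairs)
    return [[counts[(r, c)] for c in col_labels] for r in row_labels]

# Made-up (category, break layer) pairs, for illustration only.
demo_pairs = [
    (('exact', '1'), '19'),
    (('exact', '1'), '19'),
    (('partial', '3'), '4'),
]
demo_rows = [('exact', '1'), ('partial', '3')]
demo_cols = ['4', '19']
demo_matrix = build_count_matrix(demo_pairs, demo_rows, demo_cols)
```

The nested list can be passed to `ax.imshow` directly, since Matplotlib accepts array-like input.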

In [None]:
# Create a figure and axis
fig, ax = plt.subplots(figsize=(8, 6))

# Display the matrix using imshow
im = ax.imshow(cnt_matrix, cmap='Blues')

# Set the tick labels for the first and second dimensions
ax.set_xticks(np.arange(len(layer_list)))
ax.set_yticks(np.arange(len(ctg_list)))

# Set the tick labels using the ctg_list and layer_list
ax.set_xticklabels(layer_list)
ax.set_yticklabels(ctg_list)

ax.set_title(f'Section: {ob_sect_k}\n')

# Rotate the x-axis tick labels if needed
# plt.xticks(rotation=90)

# Add a colorbar
cbar = ax.figure.colorbar(im, ax=ax, shrink=0.5)

# Show the plot
plt.show()


In [747]:
plt.close()

### Exp-5.3: attention section mutual removal

#### Load & Check

In [837]:
expect_type = 'table_alias'

res_path = f'/home/yshao/Projects/rome/results/exp5_3_attention_section_mutual_removal/exp=5.3.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

with open(res_path, 'r') as f:
    all_samples = [json.loads(l) for l in f]
len(all_samples)

1034

In [838]:
total_samples = 0
n_good_samples = 0
n_too_hard = 0      # wrong answer 
n_too_easy = 0      # base - low < 0.5

In [839]:
good_samples = []
bad_samples = []

for i, ex in enumerate(all_samples):
    for d in ex['trace_results']:
        total_samples += 1

        if d['is_good_sample']:
            n_good_samples += 1
            d['ex_id'] = i
            good_samples.append(d)
        elif not d['correct_prediction']:
            n_too_hard += 1
            bad_samples.append(d)
        else:
            assert d['base_score'] - d['low_score'] < 0.5
            n_too_easy += 1
            bad_samples.append(d)
            
total_samples, (n_good_samples, len(good_samples)), n_too_hard, n_too_easy, len(bad_samples), \
f'good / correct = {n_good_samples} / {n_good_samples + n_too_easy}'

(2039, (364, 364), 339, 1336, 1675, 'good / correct = 364 / 1700')

#### Overall avg

In [840]:
trace_scores_avg = {sect_k : defaultdict(int) for sect_k in good_samples[0]['trace_scores'].keys()}

for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for k, v in sect_d.items():
            if k == 'window':
                for l, s in v.items():
                    trace_scores_avg[sect_k][f'{k}-{l}'] += s
            else:
                s = v
                trace_scores_avg[sect_k][k] += s

for sect_k, sect_d in trace_scores_avg.items():
    for k, s in sect_d.items():
        sect_d[k] = s / len(good_samples)

In [None]:
trace_scores_avg

#### Avg by aspects (category)

In [842]:
d['category']

{'sql_hardness': 'medium', 'node_role': 'where', 'text_match': 'no-match'}

In [843]:
# TEMP patch for node_len category 
for d in good_samples + bad_samples:
    node_len = len(d['answers_t'])
    assert len(mt_uskg.tokenizer.tokenize(d['expect'])) == node_len, (d['expect'], node_len)
    d['category']['node_len'] = str(node_len) if node_len <= 3 else '4+'

In [844]:
d['category']

{'sql_hardness': 'medium',
 'node_role': 'group by',
 'text_match': 'exact',
 'node_len': '3'}

In [845]:
# Key: (sect_k, aspect, asp_val, layer) -> [scores]
trace_scores_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
trace_scores_avg_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
trace_scores_cnt_by_aspect = defaultdict(lambda: defaultdict(int))  # no sect key & layer key 

In [846]:
for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for aspect, asp_val in d['category'].items():
            for k, v in sect_d.items():
                if k == 'window':
                    for l, s in v.items():
                        if not (int(l) % 4 == 3): continue
                        layer_k = f'{k}-{l}'
                        trace_scores_by_aspect[sect_k][aspect][asp_val][layer_k].append(s)
                else:
                    layer_k = k
                    s = v
                    trace_scores_by_aspect[sect_k][aspect][asp_val][layer_k].append(s)
                    
for sect_k, d1 in trace_scores_by_aspect.items():
    for asp_k, d2 in d1.items():
        for asp_v, d3 in d2.items():
            for layer_k, s in d3.items():
                trace_scores_avg_by_aspect[sect_k][asp_k][asp_v][layer_k] = np.mean(s)
                trace_scores_cnt_by_aspect[asp_k][asp_v] = len(s)

In [847]:
for sect_k, sect_d in trace_scores_avg_by_aspect.items():
    sect_d['overall'] = dict()
    for layer_k, s in trace_scores_avg[sect_k].items():
        if layer_k.startswith('window'):
            # Keep only every 4th layer (3, 7, 11, ...), matching the per-aspect filter
            _, l = layer_k.split('-')
            if int(l) % 4 != 3:
                continue
        sect_d['overall'][layer_k] = s

In [848]:
trace_scores_cnt_by_aspect

defaultdict(<function __main__.<lambda>()>,
            {'sql_hardness': defaultdict(int,
                         {'medium': 123, 'extra': 175, 'hard': 64, 'easy': 2}),
             'node_role': defaultdict(int,
                         {'select': 191,
                          'group by': 33,
                          'join': 12,
                          'where': 111,
                          'order by': 17}),
             'text_match': defaultdict(int,
                         {'exact': 230, 'partial': 28, 'no-match': 106}),
             'node_len': defaultdict(int, {'3': 364})})

In [None]:
trace_scores_avg_by_aspect['c->p']

In [850]:
dump_d = ctu.nested_json_processing(trace_scores_avg_by_aspect, func=lambda x: np.format_float_positional(x, precision=4, min_digits=4))
# dump_d
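
The implementation of `ctu.nested_json_processing` is not shown in this notebook. As a rough sketch of what such a helper might do, the hypothetical function `map_leaves` below applies a formatter to every leaf of a nested structure and returns plain dicts and lists that `json.dump` accepts. The f-string formatter is a stdlib stand-in for `np.format_float_positional`, and the input values are invented.

```python
import json


def map_leaves(obj, func):
    """Recursively apply `func` to every value that is not a dict, list or tuple."""
    if isinstance(obj, dict):  # also covers defaultdict
        return {k: map_leaves(v, func) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [map_leaves(v, func) for v in obj]
    return func(obj)


# Hypothetical nested averages (values are made up for illustration)
nested = {"c->p": {"overall": {"all": 0.123456, "window-3": 1.0}}}
rounded = map_leaves(nested, lambda x: f"{x:.4f}")
print(json.dumps(rounded, indent=1))
```

Converting to plain dicts also drops the lambda-based default factories, which cannot be serialized.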

In [851]:
dump_path = f'/home/yshao/Projects/rome/results/exp5_3_attention_section_mutual_removal/summ-exp=5.3.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

with open(dump_path, 'w') as f:
    json.dump(dump_d, f, indent=1)

#### (one-time temp patch)

In [804]:
# expect_type = 'table_alias'
# orig_res_path = f'/home/yshao/Projects/rome/results/exp5_3_attention_section_mutual_removal/no_c2p_exp=5.3.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'
# add_res_path = f'/home/yshao/Projects/rome/results/exp5_3_attention_section_mutual_removal/exp=5.3.1+c2p_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

# merge_res_path = f'/home/yshao/Projects/rome/results/exp5_3_attention_section_mutual_removal/exp=5.3.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

In [805]:
# with open(orig_res_path, 'r') as f:
#     orig_all_samples = [json.loads(l) for l in f]
# with open(add_res_path, 'r') as f:
#     add_all_samples = [json.loads(l) for l in f]

# f = open(merge_res_path, 'w')
    
# for i, (orig_ex, add_ex) in enumerate(zip(orig_all_samples, add_all_samples)):
#     assert len(orig_ex['trace_results']) == len(add_ex['trace_results']), i
#     # The order of expected nodes is nondeterministic (they come from a set()), so sort here
#     orig_ex['trace_results'].sort(key=lambda d: len(d['dec_prompt']))
#     add_ex['trace_results'].sort(key=lambda d: len(d['dec_prompt']))
#     for j, (orig_d, add_d) in enumerate(zip(orig_ex['trace_results'], add_ex['trace_results'])):
#         assert orig_d['is_good_sample'] == add_d['is_good_sample'], (i, j)
#         if not orig_d['is_good_sample']:
#             continue
            
#         # is good sample: add the new sections 
#         orig_d['trace_scores']['c->p'] = add_d['trace_scores']['c->p']
        
#         # put all at end in the dict 
#         _t = orig_d['trace_scores']['all']
#         del orig_d['trace_scores']['all']
#         orig_d['trace_scores']['all'] = _t
        
#     f.write(json.dumps(orig_ex, indent=None) + '\n')
    
# f.close()

## Tests

### create_analysis_samples

In [107]:
ex_id = 111
a_ex_id = 0

ex = processed_spider_dev[ex_id]
ex['text_in'], \
ex['struct_in'], \
ex['seq_out']

('What is the accelerate of the car make amc hornet sportabout (sw)?',
 '| car_1 | continents : contid , continent | countries : countryid , countryname , continent | car_makers : id , maker ( amc ) , fullname , country | model_list : modelid , maker , model ( amc ) | car_names : makeid , model ( amc ) , make ( amc hornet , amc hornet sportabout (sw) ) | cars_data : id , mpg , cylinders , edispl , horsepower , weight , accelerate , year',
 "select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id = t2.makeid where t2.make = 'amc hornet sportabout (sw)';")

In [108]:
# temp test
# ex['seq_out'] = 'select year from cars_data'

In [109]:
a_ex_list = ctu.create_analysis_sample_dicts(
                mt_uskg, ex,
                subject_type='column',
                remove_struct_duplicate_nodes=True)

In [110]:
a_ex_list[a_ex_id].keys()

dict_keys(['query', 'question', 'db_id', 'db_path', 'db_table_names', 'db_column_names', 'db_column_types', 'db_primary_keys', 'db_foreign_keys', 'rat_sql_graph', 'serialized_schema', 'struct_in', 'text_in', 'seq_out', 'enc_sentence', 'enc_tokenized', 'text_range', 'struct_range', 'struct_node_ranges_dict', 'dec_prompt', 'expect', 'expect_type', 'remove_struct_duplicate_nodes', 'parsed_struct_in', 'col2table', 'token_ranges_dict', 'node_name_ranges', 'expect_input_ranges', 'alias2table', 'self_ranges', 'context_ranges', 'category'])

In [111]:
a_ex_list[a_ex_id]['alias2table']

{'t1': 'cars_data', 't2': 'car_names'}

In [None]:
[(d['dec_prompt'], d['expect'], d['node_name_ranges'], d['expect_input_ranges'], '------',
  d['self_ranges'], d['context_ranges'],
  d['category'], '------' * 2) for d in a_ex_list]

In [None]:
d = dict(a_ex_list[a_ex_id])
d

In [None]:
d = ctu.add_clean_prediction(mt_uskg, d)

In [None]:
d

#### parse_sql_alias2table

In [185]:
_sql = 'SELECT t2.aaa , t3.ccc FROM table_name as t1 JOIN other_table as t2 on table_name.a_a = other_table.b_a JOIN ttt as t3 on other_table.asth = ttt.asth'
ctu.parse_sql_alias2table(_sql)

{'t1': 'table_name', 't2': 'other_table', 't3': 'ttt'}
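
One way to get this mapping is a single regular expression over `FROM/JOIN <table> AS <alias>` clauses. The sketch below is an assumption about the approach, not the actual `ctu.parse_sql_alias2table` code. The function name `parse_alias2table_sketch` is hypothetical, and it ignores aliases written without `AS`, subqueries, and quoted identifiers.

```python
import re


def parse_alias2table_sketch(sql):
    """Map each alias to its table for `FROM/JOIN <table> AS <alias>` clauses."""
    pattern = re.compile(r"\b(?:from|join)\s+(\w+)\s+as\s+(\w+)", re.IGNORECASE)
    return {alias: table for table, alias in pattern.findall(sql)}


sql = (
    "SELECT t2.aaa , t3.ccc FROM table_name as t1 "
    "JOIN other_table as t2 on table_name.a_a = other_table.b_a "
    "JOIN ttt as t3 on other_table.asth = ttt.asth"
)
alias2table = parse_alias2table_sketch(sql)
print(alias2table)
```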

#### for syntax

In [130]:
a_ex_list_syntax = ctu.create_syntax_analysis_sample_dicts(mt_uskg, ex)

In [131]:
for a_ex in a_ex_list_syntax:
    print(a_ex['dec_prompt'], ' --> ', a_ex['expect'])

select t1.accelerate  -->  from
select t1.accelerate from cars_data  -->  as
select t1.accelerate from cars_data as t1  -->  join
select t1.accelerate from cars_data as t1 join car_names  -->  as
select t1.accelerate from cars_data as t1 join car_names as t2  -->  on
select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id  -->  =
select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id = t2.makeid  -->  where
select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id = t2.makeid where t2.make  -->  =
select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id = t2.makeid where t2.make =  -->  '
select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id = t2.makeid where t2.make = '  -->  amc
select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id = t2.makeid where t2.make = 'amc  -->  hornet
select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id = t2.makeid where t2.make = 'amc hornet 

#### utils


In [117]:
_sql = 'SELECT t2.aaa, DISTINCT(t3.ccc), COUNT(*) FROM table_name as t1 JOIN other_table as t2 on table_name.a_a = other_table.b_a JOIN ttt as t3 on other_table.asth = ttt.asth WHERE t2.col like %hey% AND t3.p <= 40'.lower()
_tok_ranges = ctu.separate_punct_by_offset(_sql)
print([_sql[s:e] for s, e in _tok_ranges])

['select', 't2.', 'aaa', ',', 'distinct', '(', 't3.', 'ccc', ')', ',', 'count', '(', '*', ')', 'from', 'table_name', 'as', 't1', 'join', 'other_table', 'as', 't2', 'on', 'table_name', '.', 'a_a', '=', 'other_table', '.', 'b_a', 'join', 'ttt', 'as', 't3', 'on', 'other_table', '.', 'asth', '=', 'ttt', '.', 'asth', 'where', 't2.', 'col', 'like', '%', 'hey', '%', 'and', 't3.', 'p', '<=', '40']


### trace

In [207]:
a_ex = dict(a_ex_list[a_ex_id])
a_ex = ctu.add_clean_prediction(mt_uskg, a_ex)

In [208]:
result = ctu.make_basic_result_dict(a_ex)
result

{'enc_sentence': 'Which city has the most frequent destination airport?; structed knowledge: | flight_2 | airlines : uid , airline , abbreviation , country | airports : city , airportcode , airportname , country , countryabbrev | flights : airline , flightno , sourceairport , destairport',
 'seq_out': 'select t1.city from airports as t1 join flights as t2 on t1.airportcode = t2.destairport group by t1.city order by count(*) desc limit 1',
 'dec_prompt': 'select t1.city from airports as t1 join flights as t2 on t1.airportcode = t2.destairport group by t1.',
 'expect': 'city',
 'expect_type': 'column',
 'db_id': 'flight_2',
 'expect_input_ranges': [(45, 46)],
 'expect_table': 'airports',
 'answer': 'city',
 'base_score': 0.9983423948287964,
 'answers_t': [6726],
 'correct_prediction': True,
 'category': {'sql_hardness': 'extra',
  'node_role': 'group by',
  'text_match': 'exact'}}

In [209]:
enc_sentence = a_ex['enc_sentence']
dec_prompt = a_ex['dec_prompt']
expect = a_ex['expect']
answer = result['answer']
answers_t = result['answers_t']

inp = ctu.make_inputs_t5(
    mt_uskg.tokenizer,
    [enc_sentence] * 11,
    [dec_prompt] * 11,
    answer=expect)

text_range = a_ex['text_range']
struct_range = a_ex['struct_range']

self_ranges = a_ex['self_ranges']
context_ranges = a_ex['context_ranges']

self_tok_indices = [i for s, e in self_ranges for i in range(s, e)]
context_tok_indices = corrupt_tok_indices = [i for s, e in context_ranges for i in range(s, e)]
text_tok_indices = list(range(*text_range))
struct_tok_indices = list(range(*struct_range))

In [210]:
_score = ctu.trace_with_repatch_uskg(
    model=mt_uskg.model,
    inp=inp,
#     states_to_patch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', mt_uskg.num_enc_layers - 1))
#                     for tnum in range(*struct_range)],
    states_to_patch=[],
    states_to_unpatch=[],
    answers_t=answers_t,
    tokens_to_mix=text_tok_indices,
    tokens_to_mix_individual_indices=True,
    replace=True,
).item()

In [211]:
answers_t, answer, _score

([6726], 'city', 0.8450507521629333)

In [212]:
states_to_corrupt = [(tnum, ctu.layername_uskg(mt_uskg.model, "encoder", 0, "embed"))
                for tnum in text_tok_indices]

_score = ctu.trace_with_repatch_uskg(
    model=mt_uskg.model,
    inp=inp,
    states_to_patch=[],
    states_to_unpatch=[],
    answers_t=answers_t,
    states_to_corrupt=states_to_corrupt,
#     tokens_to_mix=corrupt_tok_indices,
#     tokens_to_mix_individual_indices=True,
    replace=True,
).item()
_score

0.8450507521629333

In [213]:
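# A pair of equivalent calls (this cell and the next) to test correctness:
# here corruption is given via tokens_to_mix, in the next cell via explicit states_to_corrupt on the embedding layer.
# Both calls should return the same score.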

_score = ctu.trace_with_repatch_uskg(
    model=mt_uskg.model,
    inp=inp,
#     states_to_patch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', l))
#                     for tnum in text_tok_indices for l in range(mt_uskg.num_enc_layers - 1)],
    states_to_patch=[],
    states_to_patch_1st_pass=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', 12))
                    for tnum in text_tok_indices],
    states_to_unpatch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', 23))
                    for tnum in struct_tok_indices],
    answers_t=answers_t,
    tokens_to_mix=text_tok_indices,
    tokens_to_mix_individual_indices=True,
    tokens_to_mix_1st_pass=context_tok_indices,
    replace=True,
).item()

_score

tensor(0.9952, device='cuda:0')

In [214]:
_score = ctu.trace_with_repatch_uskg(
    model=mt_uskg.model,
    inp=inp,
#     states_to_patch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', l))
#                     for tnum in text_tok_indices for l in range(mt_uskg.num_enc_layers - 1)],
    states_to_patch=[],
    states_to_patch_1st_pass=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', 12))
                    for tnum in text_tok_indices],
    states_to_unpatch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', 23))
                    for tnum in struct_tok_indices],
    answers_t=answers_t,
    states_to_corrupt=[(tnum, ctu.layername_uskg(mt_uskg.model, "encoder", 0, "embed"))
                    for tnum in text_tok_indices],
    states_to_corrupt_1st_pass=[(tnum, ctu.layername_uskg(mt_uskg.model, "encoder", 0, "embed"))
                    for tnum in context_tok_indices],
    replace=True,
).item()

_score

tensor(0.9952, device='cuda:0')

In [215]:
# Test corrupting attention 
_score = ctu.trace_with_repatch_uskg(
    model=mt_uskg.model,
    inp=inp,
#     states_to_patch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', l))
#                     for tnum in text_tok_indices for l in range(mt_uskg.num_enc_layers - 1)],
    states_to_patch=[],
    states_to_unpatch=[],
    answers_t=answers_t,
    states_to_corrupt=[(tnum, ctu.layername_uskg(mt_uskg.model, "encoder", l, "self_attn"))
                    for tnum in text_tok_indices for l in range(mt_uskg.num_enc_layers)],
    replace=True,
).item()

_score

0.7253002524375916

In [None]:
[n for n, w in mt_uskg.model.named_parameters()]

In [216]:
vocab_probs = ctu.run_repatch_uskg_multi_token(
    model=mt_uskg.model,
    inp=inp,
#     states_to_patch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', mt_uskg.num_enc_layers - 1))
#                     for tnum in range(*struct_range)],
    states_to_patch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', l))
                    for tnum in self_tok_indices for l in range(mt_uskg.num_enc_layers - 1)],
    states_to_unpatch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', mt_uskg.num_enc_layers - 1))
                    for tnum in self_tok_indices],
    answer_len=len(answers_t),
    tokens_to_mix=corrupt_tok_indices,
    tokens_to_mix_individual_indices=True,
    replace=True,
)

In [217]:
vocab_probs.size()

torch.Size([1, 32102])

In [218]:
torch.max(vocab_probs, dim=-1)

torch.return_types.max(
values=tensor([1.], device='cuda:0'),
indices=tensor([7634], device='cuda:0'))

In [219]:
vocab_probs[0, 7634]

tensor(1., device='cuda:0')

In [220]:
vocab_probs

tensor([[2.2642e-25, 1.2223e-15, 7.3942e-18,  ..., 9.1578e-20, 2.6884e-39,
         2.8131e-39]], device='cuda:0')

## Temp

### Debugging exp

In [497]:
ex = processed_spider_dev[97]
text_in = ex['text_in']
struct_in = ex['struct_in']

enc_sentence = f"{text_in}; structed knowledge: {struct_in}"
dec_prompt = "select t1.model from"
expect = "car_names"

In [498]:
inp = ctu.make_inputs_t5(
    mt_uskg.tokenizer,
    enc_sentences=[enc_sentence]*11,
    dec_prompts=[dec_prompt]*11,
    answer=expect
)

### RE

In [70]:
seq = 'aa,bb< cc  \t dd(  )ee <= ff=5 %h% "06-15".'
sep_pattern = r'\s+|\W'

all_matches = re.finditer(sep_pattern, seq)

In [71]:
all_matches = list(all_matches)
all_matches

[<re.Match object; span=(2, 3), match=','>,
 <re.Match object; span=(5, 6), match='<'>,
 <re.Match object; span=(6, 7), match=' '>,
 <re.Match object; span=(9, 13), match='  \t '>,
 <re.Match object; span=(15, 16), match='('>,
 <re.Match object; span=(16, 18), match='  '>,
 <re.Match object; span=(18, 19), match=')'>,
 <re.Match object; span=(21, 22), match=' '>,
 <re.Match object; span=(22, 23), match='<'>,
 <re.Match object; span=(23, 24), match='='>,
 <re.Match object; span=(24, 25), match=' '>,
 <re.Match object; span=(27, 28), match='='>,
 <re.Match object; span=(29, 30), match=' '>,
 <re.Match object; span=(30, 31), match='%'>,
 <re.Match object; span=(32, 33), match='%'>,
 <re.Match object; span=(33, 34), match=' '>,
 <re.Match object; span=(34, 35), match='"'>,
 <re.Match object; span=(37, 38), match='-'>,
 <re.Match object; span=(40, 41), match='"'>,
 <re.Match object; span=(41, 42), match='.'>]

In [72]:
all_matches[0].span()

(2, 3)

In [73]:
splits = [0] + [i for m in all_matches for i in m.span()] + [len(seq)]
splits = sorted(list(set(splits)))
print(splits)

[0, 2, 3, 5, 6, 7, 9, 13, 15, 16, 18, 19, 21, 22, 23, 24, 25, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 38, 40, 41, 42]


In [74]:
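st = 0  # start offset of the token currently being built
SP = ["<=", ">=", "<>", "!="]  # two-character operators kept as a single token
toks = []

for s, e in zip(splits[:-1], splits[1:]):
    if not seq[s:e].strip():
        # Whitespace span: skip it and start the next token after it
        st = e
    else:
        # Word or single punctuation character
        if seq[s:s+2] in SP:
            # First half of a two-character operator: defer so the next span completes the token
            continue
        toks.append(seq[st:e])
        st = e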
        

In [75]:
print(toks)

['aa', ',', 'bb', '<', 'cc', 'dd', '(', ')', 'ee', '<=', 'ff', '=', '5', '%', 'h', '%', '"', '06', '-', '15', '"', '.']
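
The cells above can be packaged into one function that returns `(start, end)` offsets instead of strings, which is the shape `ctu.separate_punct_by_offset` returns earlier in this notebook. This is a sketch built only from the prototype cells above. The name `split_with_offsets` is hypothetical, and the real helper may behave differently (for example, its output above keeps `t2.` as a single token).

```python
import re

TWO_CHAR_OPS = ("<=", ">=", "<>", "!=")


def split_with_offsets(seq):
    """Return (start, end) offsets of words and punctuation, skipping whitespace."""
    # Collect every boundary of a whitespace run or a single non-word character
    bounds = {0, len(seq)}
    for m in re.finditer(r"\s+|\W", seq):
        bounds.update(m.span())
    bounds = sorted(bounds)

    ranges = []
    start = 0
    for s, e in zip(bounds[:-1], bounds[1:]):
        if not seq[s:e].strip():
            start = e  # whitespace: the next token starts after it
        elif seq[s:s + 2] in TWO_CHAR_OPS:
            continue  # first half of a two-character operator: extend the token
        else:
            ranges.append((start, e))
            start = e
    return ranges


seq = 'aa,bb< cc  \t dd(  )ee <= ff=5 %h% "06-15".'
tokens = [seq[s:e] for s, e in split_with_offsets(seq)]
print(tokens)
```

Returning offsets rather than strings lets each token be mapped back to character positions in the original string.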


### other temp

## (placeholder)