<a href="https://colab.research.google.com/github/kmeng01/rome/blob/main/notebooks/causal_trace.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

In [2]:
IS_COLAB = False

## Causal Tracing

A demonstration of the double-intervention causal tracing method.

The strategy used by causal tracing is to understand important
states within a transformer by doing two interventions simultaneously:

1. Corrupt a subset of the input.  In our paper, we corrupt the subject tokens
   to frustrate the ability of the transformer to accurately complete factual
   prompts about the subject.
2. Restore a subset of the internal hidden states.  In our paper, we scan
   hidden states at all layers and all tokens, searching for individual states
   that carry the necessary information for the transformer to recover its
   capability to complete the factual prompt.

The traces of decisive states can be shown on a heatmap.  This notebook
demonstrates the code for conducting causal traces and creating these heatmaps.
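
The two interventions described above can be sketched on a toy model. This is only an illustration under stated assumptions: `run_toy_model` and `causal_trace` are hypothetical names, the "transformer" is a scalar-state stand-in, and the corruption is assumed to be Gaussian noise added to the subject token embeddings. It is not the code in `experiments.causal_trace`.

```python
import math
import random


def run_toy_model(embeddings, n_layers=3, patch=None):
    """Toy stand-in for a transformer: scalar state per token, causal mixing.

    patch maps (layer, token) -> value; the state at that position is
    overwritten right after the layer is computed (the "restore" step).
    Returns (score, states) with states[layer][token].
    """
    patch = patch or {}
    h = list(embeddings)
    states = []
    for layer in range(n_layers):
        new_h = []
        for t in range(len(h)):
            # Each token sees itself and the tokens before it.
            context = sum(h[: t + 1]) / (t + 1)
            new_h.append(math.tanh(h[t] + context))
        for (p_layer, p_tok), value in patch.items():
            if p_layer == layer:
                new_h[p_tok] = value
        h = new_h
        states.append(list(h))
    score = 1.0 / (1.0 + math.exp(-4.0 * h[-1]))  # read out from the last token
    return score, states


def causal_trace(embeddings, subject_range, noise_std=3.0, seed=0):
    """Return (clean_score, corrupted_score, table[layer][token])."""
    rng = random.Random(seed)
    clean_score, clean_states = run_toy_model(embeddings)

    # Intervention 1: corrupt the subject token embeddings with noise.
    corrupted = list(embeddings)
    for t in range(*subject_range):
        corrupted[t] += rng.gauss(0.0, noise_std)
    corrupted_score, _ = run_toy_model(corrupted)

    # Intervention 2: restore one clean hidden state at a time and re-run.
    table = [
        [run_toy_model(corrupted, patch={(layer, t): clean_states[layer][t]})[0]
         for t in range(len(embeddings))]
        for layer in range(len(clean_states))
    ]
    return clean_score, corrupted_score, table


clean, low, table = causal_trace([1.0, 0.8, -0.2, 0.1], subject_range=(0, 2))
print(f"clean={clean:.3f} corrupted={low:.3f}")
for layer, row in enumerate(table):
    print(layer, [f"{s:.3f}" for s in row])
```

Each entry of `table` is the output score after restoring a single clean state into the corrupted run. Entries close to the clean score mark states that carry the information needed to recover the prediction, which is what the heatmaps visualize.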

In [3]:
%load_ext autoreload
%autoreload 2

The `experiments.causal_trace` module contains a set of functions for running causal traces.

In this notebook, we reproduce, demonstrate, and discuss its key functions.

We begin by importing several utility functions that deal with tokens and transformer models.

In [4]:
import os, sys, re, json
import string
import torch
import numpy as np
import copy
from collections import defaultdict, Counter
from util import nethook
from util.globals import DATA_DIR
from experiments.causal_trace import (
    ModelAndTokenizer,
    layername,
    guess_subject,
    plot_trace_heatmap,
)
from experiments.causal_trace import (
    make_inputs,
    decode_tokens,
    find_token_range,
    predict_token,
    predict_from_input,
    collect_embedding_std,
)
from dsets import KnownsDataset

In [5]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f79bc4e45e0>

In [6]:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

## USKG

In [7]:
from transformers import (
    HfArgumentParser,
    set_seed,
    AutoTokenizer
)

# from uskg.models.unified.prefixtuning import Model
from uskg.models.unified import finetune, prefixtuning
from uskg.utils.configue import Configure
from uskg.utils.training_arguments import WrappedSeq2SeqTrainingArguments
from uskg.seq2seq_construction import spider as s2s_spider
from uskg.third_party.spider.preprocess.get_tables import dump_db_json_schema
from uskg.third_party.spider import evaluation as sp_eval
from tqdm.notebook import tqdm

# from nltk.stem.wordnet import WordNetLemmatizer
# import stanza

import matplotlib.pyplot as plt
import sqlite3

from experiments import causal_trace_uskg as ctu

In [82]:
mt_uskg = ctu.ModelAndTokenizer_USKG('t5-large-prefix')

Using tokenizer_uskg: hkunlp/from_all_T5_large_prefix_spider_with_cell_value2
Using tokenizer_fast: t5-large
prefix-tuning sequence length is 10.


In [83]:
list(mt_uskg.task_args.seq2seq)

[('constructor', 'seq2seq_construction.spider'),
 ('schema_serialization_with_db_content', True),
 ('target_with_db_id', False)]

In [84]:
mt_uskg.model.pretrain_model.encoder.embed_tokens is mt_uskg.model.pretrain_model.shared, \
mt_uskg.model.pretrain_model.decoder.embed_tokens is mt_uskg.model.pretrain_model.shared

(True, False)

In [85]:
mt_uskg.model.preseqlen

10

In [86]:
# [k for k,v in mt_uskg.model.named_parameters()]
# [k for k,v in mt_uskg.model.named_modules()]

In [87]:
inp = ctu.make_inputs_t5(
    mt_uskg.tokenizer,
    enc_sentences=["Translate to German: My name is Wolfgang and I live in Berlin"],
    dec_prompts=["Mein Name ist Wolfgang"],
    device="cuda:0"
)

In [88]:
out = ctu.run_model_forward_uskg(mt_uskg.model, **inp)

In [89]:
out.keys(), out['logits'].size()

(odict_keys(['logits', 'past_key_values', 'encoder_last_hidden_state']),
 torch.Size([1, 5, 32102]))

In [90]:
logits = out["logits"][0, -1].detach().cpu().numpy()
logits.shape

(32102,)

In [91]:
top_5 = sorted(list(enumerate(logits)), key=lambda p: -p[1])[:5]
top_5

[(11, -1.8642352),
 (6, -9.727753),
 (5, -10.966707),
 (27, -11.037394),
 (213, -12.864212)]

In [92]:
[mt_uskg.tokenizer.decode([p[0]]) for p in top_5]

['and', ',', '.', 'I', 'where']

### Load spider dataset

In [93]:
spider_train_path = '/home/yshao/Projects/SDR-analysis/data/spider/train+ratsql_graph.json'
spider_dev_path = '/home/yshao/Projects/SDR-analysis/data/spider/dev+ratsql_graph.json'
spider_db_dir = '/home/yshao/Projects/language/language/xsp/data/spider/database'

In [94]:
raw_spider_dev = ctu.load_raw_dataset(
    data_filepath = spider_dev_path,
    db_path=spider_db_dir,
#     schema_cache=SCHEMA_CACHE
)
len(raw_spider_dev)

1034

In [95]:
raw_spider_dev[0].keys()

dict_keys(['query', 'question', 'db_id', 'db_path', 'db_table_names', 'db_column_names', 'db_column_types', 'db_primary_keys', 'db_foreign_keys', 'rat_sql_graph'])

In [96]:
mt_uskg.task_args.dataset.use_cache

True

In [97]:
processed_spider_dev = s2s_spider.DevDataset(
    args=mt_uskg.task_args,
    raw_datasets=raw_spider_dev,
    cache_root='../cache')

In [98]:
_id = 130
processed_spider_dev[_id]['text_in'], \
processed_spider_dev[_id]['struct_in'], \
processed_spider_dev[_id]['seq_out']

('What are the names of all European countries with at least 3 manufacturers?',
 '| car_1 | continents : contid , continent ( europe ) | countries : countryid , countryname , continent | car_makers : id , maker , fullname , country | model_list : modelid , maker , model | car_names : makeid , model , make | cars_data : id , mpg , cylinders , edispl , horsepower , weight , accelerate , year',
 "select t1.countryname from countries as t1 join continents as t2 on t1.continent = t2.contid join car_makers as t3 on t1.countryid = t3.country where t2.continent = 'europe' group by t1.countryname having count(*) >= 3;")

In [99]:
_enc_sentence = f"{processed_spider_dev[_id]['text_in']}; structed knowledge: {processed_spider_dev[_id]['struct_in']}"
_toks = mt_uskg.tokenizer.tokenize(_enc_sentence)
len(_toks)

142

In [100]:
# # _occ_punct = set()

# for _id in range(len(processed_spider_dev)):
#     ex = processed_spider_dev[_id]
# #     _occ_punct.update(set(string.punctuation) & set(ex['seq_out']))
#     if '_(' in ex['struct_in']:
#         print(_id, ex['question'])
#         print(ex['struct_in'])
#         print(ex['seq_out'])
#         print()

In [101]:
# ## Train set

# raw_spider_train = ctu.load_raw_dataset(
#     data_filepath = spider_train_path,
#     db_path=spider_db_dir,
# )
# processed_spider_train = s2s_spider.TrainDataset(
#     args=mt_uskg.task_args,
#     raw_datasets=raw_spider_train,
#     cache_root='../cache')
# len(processed_spider_train)

In [102]:
# processed_spider_train[5441]

### Helpers
- Aspect-related helpers are merged into `ctu.create_analysis_sample_dicts()`

#### Utils

In [None]:
def exp6_ob_by_exp_tok(samples):
    # samples: usually `good_samples`
    
    # Key: (expect_tok, sect_k, layer) -> [scores]
    trace_scores_by_exp_tok = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    trace_scores_avg_by_exp_tok = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
    trace_scores_cnt_by_exp_tok = defaultdict(int)  # no sect key & layer key 

    trace_sample_ids_by_exp_tok = defaultdict(list)
    
    for i, d in enumerate(samples):
        expect = d['expect']
        trace_sample_ids_by_exp_tok[expect].append(i)
        for sect_k, sect_d in d['trace_scores'].items():
            for layer_k, v in sect_d.items():
                trace_scores_by_exp_tok[expect][sect_k][layer_k].append(v)

    for exp_tok, d1 in trace_scores_by_exp_tok.items():
        if exp_tok.isnumeric(): continue
        for sect_k, d2 in d1.items():
            for layer_k, scores in d2.items():
                if len(scores) <= 2: continue
                trace_scores_avg_by_exp_tok[exp_tok][sect_k][layer_k] = np.mean(scores)
                trace_scores_cnt_by_exp_tok[exp_tok] = len(scores)
    
    return {
        'avg': trace_scores_avg_by_exp_tok,
        'cnt': trace_scores_cnt_by_exp_tok,
        'sample_ids': trace_sample_ids_by_exp_tok,
    }

In [440]:
def reverse_2D_dict(d):
    # defaultdict needs a callable factory; np.nan itself is not callable
    out_d = defaultdict(lambda: defaultdict(lambda: np.nan))
    for k1, d1 in d.items():
        for k2, v in d1.items():
            out_d[k2][k1] = v
    return out_d

def format_print_1D_dict(d, sort_by=None, reverse=False, head_col_w=10, col_w=6):
    # sort_by: None, 'key' or 'value'
    
    item_l = list(d.items())
    if sort_by == 'key':
        item_l.sort(reverse=reverse)
    elif sort_by == 'value':
        item_l.sort(key=lambda x: (x[1], x[0]), reverse=reverse)
    
    decm_w = col_w - 2
    
    for k, v in item_l:
        print(f'{k:<{head_col_w}s}{v:.{decm_w}f}')

def format_print_2D_dict(d, 
                         all_k1=None, 
                         all_k2=None, 
                         sort_k1_kwargs=None, 
                         sort_k2_kwargs=None, 
                         head_col_w=12, 
                         col_w=6,
                         decm_w=4):
    if all_k1 is None:
        all_k1 = list(d.keys())
        if sort_k1_kwargs is not None:
            all_k1.sort(**sort_k1_kwargs)
    
    if all_k2 is None:
        for k1, d1 in d.items():
            d1_keys = list(d1.keys())
            if all_k2 is None:
                all_k2 = d1_keys
            else:
                if set(d1_keys) != set(all_k2):
                    print('Warning:\n', d1_keys, '\n', all_k2)
            # all_k2.update(list(d1.keys()))
        if sort_k2_kwargs is not None:
            all_k2.sort(**sort_k2_kwargs)
    
    print_str = '\t'.join(['X' * head_col_w] + [f'{k2:<{col_w}s}' for k2 in all_k2]) + '\n'
    
    for k1 in all_k1:
        d1 = d[k1]
        print_str += f'{k1:<{head_col_w}s}'
        for k2 in all_k2:
            v = d1[k2]
            print_str += f'\t{v:<{col_w}.{decm_w}f}'
        print_str += '\n'
    
    print(print_str)
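
A minimal standalone sketch of the dict-transposing idea used by the helpers above, where missing cells read as NaN. The name `transpose_nested` and the sample scores are hypothetical, and `float("nan")` stands in for `np.nan` to keep the example free of dependencies.

```python
from collections import defaultdict


def transpose_nested(d):
    """Swap the two key levels of a dict-of-dicts; absent cells read as NaN."""
    # defaultdict calls its factory, so the NaN default must be wrapped in a lambda.
    out = defaultdict(lambda: defaultdict(lambda: float("nan")))
    for k1, inner in d.items():
        for k2, v in inner.items():
            out[k2][k1] = v
    return out


# Hypothetical scores keyed by section, then by layer.
scores = {"text": {"0": 0.9, "1": 0.4}, "struct": {"0": 0.7}}
by_layer = transpose_nested(scores)
print(by_layer["0"]["text"], by_layer["1"]["struct"])
```

Reading a missing cell inserts the NaN default into the result, so a table printer can iterate over a fixed key grid without raising `KeyError`.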

#### Evaluator

In [121]:
table_path = '/home/yshao/Projects/language/language/xsp/data/spider/tables.json'
db_dir = '/home/yshao/Projects/language/language/xsp/data/spider/database'

In [122]:
kmaps = sp_eval.build_foreign_key_map_from_json(table_path)
evaluator = sp_eval.Evaluator(db_dir=db_dir, kmaps=kmaps, etype='all')

In [123]:
ctu.evaluate_hardness.evaluator = evaluator

In [124]:
# test
_sql_str = 'select t1.birth_date from people as t1 join poker_player as t2 on t1.people_id = t2.people_id order by t2.earnings asc limit 1'
db_name = 'poker_player'
schema = evaluator.schemas[db_name]
_sql = sp_eval.get_sql(schema, _sql_str)
sp_eval.count_component1(_sql), sp_eval.count_component2(_sql), sp_eval.count_others(_sql), \
evaluator.eval_hardness(_sql)

(3, 0, 0, 'hard')

#### Hardness

In [125]:
ctu.evaluate_hardness(_sql_str, db_name, evaluator=evaluator)

'hard'

In [126]:
ctu.evaluate_hardness.evaluator

<uskg.third_party.spider.evaluation.Evaluator at 0x7f7814ed2040>

#### Node role

In [None]:
dec_prompt = 'select avg(age), min(age), max(age) from'
ctu.detect_node_role(dec_prompt)

#### Text match

In [None]:
a_dicts = ctu.create_analysis_sample_dicts(
    mt=mt_uskg,
    ex=processed_spider_dev[100],
    subject_type='table'
)
len(a_dicts), [d['expect'] for d in a_dicts]

In [None]:
a_ex = a_dicts[2]
ctu.check_table_text_match(a_ex, 'car_names')

In [None]:
a_ex['text_in']

### Exp-5.0: dirty attention vector effect 

#### Load & Check results

In [255]:
expect_type = 'table_alias'
res_path = f'/home/yshao/Projects/rome/results/exp5_0_dirty_attention_vector_effect/exp=5_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

with open(res_path, 'r') as f:
    all_samples = [json.loads(l) for l in f]
len(all_samples)

1034

In [256]:
total_samples = 0
n_good_samples = 0
n_too_hard = 0      # wrong answer 
n_too_easy = 0      # base - low < 0.5

In [257]:
good_samples = []
bad_samples = []

for i, ex in enumerate(all_samples):
    for d in ex['trace_results']:
        total_samples += 1
#         # TEMP adjustment for column results 
#         d['low_score'] = d['trace_scores']['high_layers_corrupt'].get("0", 0.0)  # "0" is key (for layer 0), 0.0 is default 
#         if d['base_score'] - d['low_score'] < 0.5:
#             d['is_good_sample'] = False
#         # END_TEMP
        if d['is_good_sample']:
            n_good_samples += 1
            d['ex_id'] = i
            good_samples.append(d)
        elif not d['correct_prediction']:
            n_too_hard += 1
            bad_samples.append(d)
        else:
            assert d['base_score'] - d['low_score'] < 0.5, (i, d)
            n_too_easy += 1
            bad_samples.append(d)
            
total_samples, (n_good_samples, len(good_samples)), n_too_hard, n_too_easy, len(bad_samples), \
n_good_samples + n_too_easy

(2039, (164, 164), 339, 1536, 1875, 1700)
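
The three-way split above can be written as a single rule. This is a sketch under one assumption: that `is_good_sample` means "correct prediction and a corruption gap of at least 0.5", which is what the `assert` in the loop implies but does not state outright. `classify_trace_result` is a hypothetical helper, not part of `ctu`.

```python
def classify_trace_result(d, min_gap=0.5):
    """Label one trace result as 'good', 'too_hard', or 'too_easy'.

    Assumption: a sample is "good" when the prediction is correct and
    corruption lowers the answer probability by at least min_gap.
    """
    if not d["correct_prediction"]:
        return "too_hard"   # the model's answer was wrong
    if d["base_score"] - d["low_score"] < min_gap:
        return "too_easy"   # corruption barely changes the prediction
    return "good"


# Made-up records, for illustration only.
examples = [
    {"correct_prediction": True, "base_score": 0.99, "low_score": 0.10},
    {"correct_prediction": True, "base_score": 0.99, "low_score": 0.80},
    {"correct_prediction": False, "base_score": 0.99},
]
labels = [classify_trace_result(d) for d in examples]
print(labels)
```

Checking `correct_prediction` first means wrong-answer records are never asked for a `low_score`, which they may not carry.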

In [None]:
[s for s in bad_samples if s['correct_prediction']][0]

In [None]:
good_samples[0]

#### Overall avg

In [193]:
trace_scores_avg = {k: {str(l): 0 for l in range(24)} for k in good_samples[0]['trace_scores'].keys()}

In [194]:
for d in good_samples:
    for k, layer_d in d['trace_scores'].items():
        for l, s in layer_d.items():
            trace_scores_avg[k][l] += s

for k, layer_d in trace_scores_avg.items():
    for l, s in layer_d.items():
        layer_d[l] = s / len(good_samples)

In [None]:
trace_scores_avg

#### Avg by aspects (category)
- Still kind of linear, as in exp-2.2

In [243]:
d['category']

{'sql_hardness': 'hard', 'node_role': 'where', 'text_match': 'partial'}

In [265]:
# Key: (trace_key, aspect, asp_val, layer) -> [scores]
trace_scores_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
trace_scores_avg_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
trace_scores_cnt_by_aspect = defaultdict(lambda: defaultdict(int))  # no trace key & layer key 

In [266]:
for d in good_samples:
    for trace_k, trace_layer_d in d['trace_scores'].items():
        for aspect, asp_val in d['category'].items():
            for l, s in trace_layer_d.items():
                trace_scores_by_aspect[trace_k][aspect][asp_val][l].append(s)

for trace_k, d1 in trace_scores_by_aspect.items():
    for asp_k, d2 in d1.items():
        for asp_v, d3 in d2.items():
            for l, s in d3.items():
                trace_scores_avg_by_aspect[trace_k][asp_k][asp_v][l] = np.mean(s)
                trace_scores_cnt_by_aspect[asp_k][asp_v] = len(s)

In [267]:
trace_scores_cnt_by_aspect

defaultdict(<function __main__.<lambda>()>,
            {'sql_hardness': defaultdict(int,
                         {'medium': 400,
                          'hard': 155,
                          'easy': 142,
                          'extra': 170}),
             'node_role': defaultdict(int,
                         {'where': 268,
                          'select': 412,
                          'order by': 66,
                          'join': 92,
                          'group by': 25,
                          'having': 4}),
             'text_match': defaultdict(int,
                         {'no-match': 360, 'partial': 148, 'exact': 359})})
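
The per-aspect averaging above can also be expressed with flat tuple keys instead of four nested `defaultdict` levels. A minimal sketch, assuming each sample has `trace_scores[trace_key][layer]` and a `category` dict as in the cells above; `average_by_aspect` and the two sample records are hypothetical.

```python
from collections import defaultdict
from statistics import mean


def average_by_aspect(samples):
    """Group per-layer scores by (trace key, aspect, aspect value, layer) and average."""
    grouped = defaultdict(list)
    for d in samples:
        for trace_k, layer_d in d["trace_scores"].items():
            for aspect, asp_val in d["category"].items():
                for layer, score in layer_d.items():
                    grouped[(trace_k, aspect, asp_val, layer)].append(score)
    averages = {key: mean(scores) for key, scores in grouped.items()}
    # Sample counts do not depend on trace key or layer, so key them by (aspect, value).
    counts = {key[1:3]: len(scores) for key, scores in grouped.items()}
    return averages, counts


# Two made-up samples sharing one category value.
samples = [
    {"trace_scores": {"high_layers_corrupt": {"0": 0.25}}, "category": {"sql_hardness": "hard"}},
    {"trace_scores": {"high_layers_corrupt": {"0": 0.75}}, "category": {"sql_hardness": "hard"}},
]
averages, counts = average_by_aspect(samples)
print(averages, counts)
```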

In [None]:
trace_scores_avg_by_aspect['high_layers_corrupt']

### Exp-5.2: attention section removal effect

#### Load & Check

In [679]:
expect_type = 'table'

res_path = f'/home/yshao/Projects/rome/results/exp5_2_attention_section_removal_effect/exp=5.2.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

with open(res_path, 'r') as f:
    all_samples = [json.loads(l) for l in f]
len(all_samples)

1034

In [680]:
total_samples = 0
n_good_samples = 0
n_too_hard = 0      # wrong answer 
n_too_easy = 0      # base - low < 0.5

In [681]:
good_samples = []
bad_samples = []

for i, ex in enumerate(all_samples):
    for d in ex['trace_results']:
        total_samples += 1

        if d['is_good_sample']:
            n_good_samples += 1
            d['ex_id'] = i
            good_samples.append(d)
        elif not d['correct_prediction']:
            n_too_hard += 1
            bad_samples.append(d)
        else:
            assert d['base_score'] - d['low_score'] < 0.5
            n_too_easy += 1
            bad_samples.append(d)
            
total_samples, (n_good_samples, len(good_samples)), n_too_hard, n_too_easy, len(bad_samples),\
f'good / correct = {n_good_samples} / {n_good_samples + n_too_easy}'

(1683, (1207, 1207), 136, 340, 476, 'good / correct = 1207 / 1547')

In [682]:
[s for s in bad_samples if not s['correct_prediction']][0]

{'enc_sentence': 'Find the number of concerts happened in the stadium with the highest capacity .; structed knowledge: | concert_singer | stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id',
 'seq_out': 'select count(*) from concert where stadium_id = (select stadium_id from stadium order by capacity desc limit 1)',
 'dec_prompt': 'select count(*) from',
 'expect': 'concert',
 'expect_type': 'table',
 'db_id': 'concert_singer',
 'expect_input_ranges': [[88, 89]],
 'expect_table': 'concert',
 'answer': 'stadium',
 'base_score': 0.9941871166229248,
 'answers_t': [14939],
 'correct_prediction': False,
 'category': {'sql_hardness': 'hard',
  'node_role': 'from',
  'text_match': 'exact'},
 'self_ranges': [[87, 91]],
 'struct_context_ranges': [[22, 87], [91, 132]],
 'is_

In [None]:
good_samples[0]

#### Overall avg

In [647]:
trace_scores_avg = {sect_k : defaultdict(int) for sect_k in good_samples[0]['trace_scores'].keys()}

for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for k, v in sect_d.items():
            if k == 'window':
                for l, s in v.items():
                    trace_scores_avg[sect_k][f'{k}-{l}'] += s
            else:
                s = v
                trace_scores_avg[sect_k][k] += s

for sect_k, sect_d in trace_scores_avg.items():
    for k, s in sect_d.items():
        sect_d[k] = s / len(good_samples)

In [None]:
trace_scores_avg
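
In this experiment each section holds a scalar `all_layers` score plus a nested `window` dict keyed by layer, so the averaging loop flattens the window entries into `window-<layer>` keys. A minimal sketch of that flattening, with hypothetical helper names and made-up scores:

```python
def flatten_section_scores(sect_d):
    """Flatten {'all_layers': x, 'window': {'0': a, ...}} into single-level keys."""
    flat = {}
    for k, v in sect_d.items():
        if k == "window":
            for layer, score in v.items():
                flat[f"window-{layer}"] = score
        else:
            flat[k] = v
    return flat


def average_flat_scores(section_dicts):
    """Average flattened scores over samples; assumes all samples share the same keys."""
    flats = [flatten_section_scores(s) for s in section_dicts]
    return {k: sum(f[k] for f in flats) / len(flats) for k in flats[0]}


# Made-up scores for one section of two samples.
section_scores = [
    {"all_layers": 0.25, "window": {"0": 1.0, "1": 0.5}},
    {"all_layers": 0.75, "window": {"0": 0.5, "1": 0.0}},
]
avg = average_flat_scores(section_scores)
print(avg)
```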

#### Avg by aspects (category)

In [696]:
# TEMP patch for node_len category 
for d in good_samples + bad_samples:
    node_len = len(d['answers_t'])
    assert len(mt_uskg.tokenizer.tokenize(d['expect'])) == node_len, (d['expect'], node_len)
    d['category']['node_len'] = str(node_len) if node_len <= 3 else '4+'

In [701]:
d['category']

{'sql_hardness': 'medium',
 'node_role': 'from',
 'text_match': 'exact',
 'node_len': '1'}

In [651]:
# Key: (sect_k, aspect, asp_val, layer) -> [scores]
trace_scores_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
trace_scores_avg_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
trace_scores_cnt_by_aspect = defaultdict(lambda: defaultdict(int))  # no sect key & layer key 

In [652]:
for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for aspect, asp_val in d['category'].items():
            for k, v in sect_d.items():
                if k == 'window':
                    for l, s in v.items():
                        if not (int(l) % 4 == 3): continue
                        layer_k = f'{k}-{l}'
                        trace_scores_by_aspect[sect_k][aspect][asp_val][layer_k].append(s)
                else:
                    layer_k = k
                    s = v
                    trace_scores_by_aspect[sect_k][aspect][asp_val][layer_k].append(s)
                    
for sect_k, d1 in trace_scores_by_aspect.items():
    for asp_k, d2 in d1.items():
        for asp_v, d3 in d2.items():
            for layer_k, s in d3.items():
                trace_scores_avg_by_aspect[sect_k][asp_k][asp_v][layer_k] = np.mean(s)
                trace_scores_cnt_by_aspect[asp_k][asp_v] = len(s)

In [653]:
for sect_k, sect_d in trace_scores_avg_by_aspect.items():
    sect_d['overall'] = dict()
    for layer_k, s in trace_scores_avg[sect_k].items():
        if layer_k.startswith('window'):
            # only keep a subset of layers 
            _, l = layer_k.split('-')
            if not (int(l) % 4 == 3): continue
        sect_d['overall'][layer_k] = s

In [654]:
trace_scores_cnt_by_aspect

defaultdict(<function __main__.<lambda>()>,
            {'sql_hardness': defaultdict(int,
                         {'medium': 378,
                          'hard': 148,
                          'easy': 134,
                          'extra': 160}),
             'node_role': defaultdict(int,
                         {'where': 248,
                          'select': 393,
                          'order by': 63,
                          'join': 91,
                          'group by': 21,
                          'having': 4}),
             'text_match': defaultdict(int,
                         {'no-match': 345, 'partial': 142, 'exact': 333}),
             'node_len': defaultdict(int,
                         {'1': 329, '3': 238, '4+': 160, '2': 93})})

In [None]:
trace_scores_avg_by_aspect['self']

In [641]:
dump_d = ctu.nested_json_processing(trace_scores_avg_by_aspect, func=lambda x: np.format_float_positional(x, precision=4, min_digits=4))
# dump_d

In [642]:
dump_path = f'/home/yshao/Projects/rome/results/exp5_2_attention_section_removal_effect/summ-exp=5.2.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

with open(dump_path, 'w') as f:
    json.dump(dump_d, f, indent=1)

#### (one-time temp patch)

In [531]:
# expect_type = 'table_alias'
# orig_res_path = f'/home/yshao/Projects/rome/results/exp5_2_attention_section_removal_effect/no_structcontext-exp=5.2.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'
# add_res_path = f'/home/yshao/Projects/rome/results/exp5_2_attention_section_removal_effect/exp=5.2.1+structcontext_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

# merge_res_path = f'/home/yshao/Projects/rome/results/exp5_2_attention_section_removal_effect/exp=5.2.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

In [532]:
# with open(orig_res_path, 'r') as f:
#     orig_all_samples = [json.loads(l) for l in f]
# with open(add_res_path, 'r') as f:
#     add_all_samples = [json.loads(l) for l in f]

# f = open(merge_res_path, 'w')
    
# for i, (orig_ex, add_ex) in enumerate(zip(orig_all_samples, add_all_samples)):
#     assert len(orig_ex['trace_results']) == len(add_ex['trace_results']), i
#     # There is randomness in the order of expected node (from set()), thus sorting here 
#     orig_ex['trace_results'].sort(key=lambda d: len(d['dec_prompt']))
#     add_ex['trace_results'].sort(key=lambda d: len(d['dec_prompt']))
#     for j, (orig_d, add_d) in enumerate(zip(orig_ex['trace_results'], add_ex['trace_results'])):
#         assert orig_d['is_good_sample'] == add_d['is_good_sample'], (i, j)
#         if not orig_d['is_good_sample']:
#             continue
            
#         # is good sample: add the new sections 
#         orig_d['trace_scores']['struct_context'] = add_d['trace_scores']['struct_context']
#         orig_d['trace_scores']['text+struct_context'] = add_d['trace_scores']['text+struct_context']
        
#     f.write(json.dumps(orig_ex, indent=None) + '\n')
    
# f.close()

#### Single samples observations

In [702]:
good_samples[0]['trace_scores'].keys()

dict_keys(['prefix', 'text', 'struct', 'text+struct', 'all', 'self', 'struct_context', 'text+struct_context'])

In [703]:
_id = 0

d = good_samples[_id]

check_info_d = defaultdict(dict)

for sect_k, sect_d in d['trace_scores'].items():
    for layer_k, s in sect_d.items():
        if layer_k == 'window':
            layer_k = 'window-19'
            s = s['19']
        if s < 0.5:
            check_info_d[sect_k][layer_k] = s

In [704]:
print(json.dumps(check_info_d, indent=2))

{
  "text+struct": {
    "all_layers": 0.3990614414215088
  },
  "all": {
    "all_layers": 0.4772564172744751
  }
}


##### Layer

In [705]:
### Find the "breaking" window layer of each sample, i.e. the first layer with a sudden score drop
### For now: a drop of more than _th from the previous layer

_th = 0.4
check_info_l = []
for i, d in enumerate(good_samples):
#     for sect_k, sect_d in d['trace_scores'].items():
    sect_k = 'all'
    sect_d = d['trace_scores'][sect_k]
    window_d = sect_d['window']
    for l in range(1, 24):
        if window_d[str(l-1)] - window_d[str(l)] > _th:
            _info_d = {
                'id': i,
                'sect_k': sect_k,
                'layer': l,
                'last_layer_score': window_d[str(l-1)],
                'this_layer_score': window_d[str(l)],
            }
            check_info_l.append(_info_d)
            break
len(check_info_l)

844

In [706]:
len(check_info_l), len(good_samples)

(844, 1207)
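
The break-layer search above can be factored into a small function. This sketch assumes `window` scores are keyed by layer index strings `'0'` to `'23'`, as in the loop; `first_break_layer` is a hypothetical name and the scores are made up.

```python
def first_break_layer(window_d, threshold=0.4, n_layers=24):
    """Return the first layer whose score drops by more than threshold
    relative to the previous layer, or None if there is no such drop."""
    for layer in range(1, n_layers):
        if window_d[str(layer - 1)] - window_d[str(layer)] > threshold:
            return layer
    return None


# Scores stay high until layer 8, where they fall sharply.
window = {str(l): (0.9 if l < 8 else 0.3) for l in range(24)}
flat_window = {str(l): 0.9 for l in range(24)}
print(first_break_layer(window), first_break_layer(flat_window))
```

Returning `None` for samples without a sharp drop makes it explicit why `check_info_l` has fewer entries than `good_samples`.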

In [707]:
break_layer_counter = Counter([_d['layer'] for _d in check_info_l])
sorted(break_layer_counter.items())

[(1, 10),
 (2, 25),
 (3, 31),
 (4, 81),
 (5, 24),
 (6, 3),
 (7, 15),
 (8, 87),
 (9, 69),
 (10, 19),
 (11, 38),
 (12, 36),
 (13, 70),
 (14, 101),
 (15, 27),
 (16, 31),
 (17, 31),
 (18, 79),
 (19, 67)]

In [None]:
for info_d in check_info_l:
    if info_d['layer'] < 6:
        print(info_d)
        sample_id = info_d['id']
        d = good_samples[sample_id]
        print(d['enc_sentence'])
        print(d['dec_prompt'], '---->', d['expect'])
        print('Categories:', d['category'])
        print('--' * 20)

In [713]:
sample_counter_by_aspect = defaultdict(Counter)  # [asp_k, asp_v] -> count 
sample_counter = Counter()

for info_d in check_info_l:
    if info_d['layer'] < 6:
        sample_id = info_d['id']
        d = good_samples[sample_id]
        text_match = d['category']['text_match']
        node_len = d['category']['node_len']
        sample_counter[(text_match, node_len)] += 1
        
        for asp_k, asp_v in d['category'].items():
            sample_counter_by_aspect[asp_k][asp_v] += 1

In [715]:
sample_counter_by_aspect

defaultdict(collections.Counter,
            {'sql_hardness': Counter({'medium': 58,
                      'hard': 42,
                      'extra': 51,
                      'easy': 20}),
             'node_role': Counter({'from': 105, 'join': 66}),
             'text_match': Counter({'partial': 59,
                      'no-match': 74,
                      'exact': 38}),
             'node_len': Counter({'4+': 60, '3': 51, '2': 24, '1': 36})})

In [716]:
sample_counter.most_common()

[(('partial', '3'), 30),
 (('partial', '4+'), 29),
 (('exact', '1'), 22),
 (('no-match', '3'), 21),
 (('no-match', '4+'), 21),
 (('no-match', '2'), 18),
 (('no-match', '1'), 14),
 (('exact', '4+'), 10),
 (('exact', '2'), 6)]

In [717]:
sample_counter_by_aspect = defaultdict(Counter)  # [asp_k, asp_v] -> count 
sample_counter = Counter()

for info_d in check_info_l:
    if info_d['layer'] > 18:
        sample_id = info_d['id']
        d = good_samples[sample_id]
        text_match = d['category']['text_match']
        node_len = d['category']['node_len']
        sample_counter[(text_match, node_len)] += 1
        
        for asp_k, asp_v in d['category'].items():
            sample_counter_by_aspect[asp_k][asp_v] += 1

In [718]:
sample_counter_by_aspect

defaultdict(collections.Counter,
            {'sql_hardness': Counter({'hard': 9,
                      'medium': 27,
                      'easy': 11,
                      'extra': 20}),
             'node_role': Counter({'join': 21, 'from': 46}),
             'text_match': Counter({'exact': 46,
                      'no-match': 20,
                      'partial': 1}),
             'node_len': Counter({'3': 12, '1': 46, '2': 6, '4+': 3})})

In [719]:
sample_counter.most_common()

[(('exact', '1'), 38),
 (('no-match', '3'), 9),
 (('no-match', '1'), 8),
 (('exact', '3'), 3),
 (('no-match', '2'), 3),
 (('exact', '2'), 3),
 (('exact', '4+'), 2),
 (('partial', '4+'), 1)]

In [743]:
# _min_p = 1.0

# for i, d in enumerate(good_samples):
#     sect_k = 'text'
#     sect_d = d['trace_scores'][sect_k]
#     _min_p = min(_min_p, sect_d['all_layers'])

# _min_p

7.385318918917694e-12

In [751]:
### Systematic 

ob_sect_k = 'all'

_th = 0.4

check_info_l = []
all_layers_eff_cnt = 0  # for this section to observe, how many samples are effective with all_layers
window_eff_cnt = 0      # for this section to observe, how many samples are effective with any window 

for i, d in enumerate(good_samples):
#     for sect_k, sect_d in d['trace_scores'].items():
    sect_k = ob_sect_k
    sect_d = d['trace_scores'][sect_k]
    if sect_d['all_layers'] > 0.5:
        # not effective
        continue
    else:
        all_layers_eff_cnt += 1
        
    if min(sect_d['window'].values()) > 0.5:
        # not effective
        continue
    else:
        window_eff_cnt += 1
        
    window_d = sect_d['window']
    for l in range(1, 24):
        if window_d[str(l-1)] - window_d[str(l)] > _th:
            _info_d = {
                'id': i,
                'sect_k': sect_k,
                'layer': l,
                'last_layer_score': window_d[str(l-1)],
                'this_layer_score': window_d[str(l)],
            }
            check_info_l.append(_info_d)
            break
len(check_info_l), window_eff_cnt, all_layers_eff_cnt, len(good_samples)

(840, 916, 1207, 1207)

In [745]:
ctg_list = [(tm, nl) for tm in ['exact', 'partial', 'no-match'] for nl in ['1', '2', '3', '4+']]
layer_list = [str(l) for l in range(1, 24)]

ctg_elem2id = {elem : i for i, elem in enumerate(ctg_list)}
layer_elem2id = {elem : i for i, elem in enumerate(layer_list)}

cnt_matrix = np.zeros((len(ctg_list), len(layer_list)), int)

for info_d in check_info_l:
    sample_id = info_d['id']
    d = good_samples[sample_id]
    text_match = d['category']['text_match']
    node_len = d['category']['node_len']
    _ctg = (text_match, node_len)
    _layer = str(info_d['layer'])
    
    _ctg_idx = ctg_elem2id[_ctg]
    _layer_idx = layer_elem2id[_layer]
    cnt_matrix[_ctg_idx, _layer_idx] += 1

In [None]:
# Create a figure and axis
fig, ax = plt.subplots(figsize=(8, 6))

# Display the matrix using imshow
im = ax.imshow(cnt_matrix, cmap='Blues')

# Set the tick labels for the first and second dimensions
ax.set_xticks(np.arange(len(layer_list)))
ax.set_yticks(np.arange(len(ctg_list)))

# Set the tick labels using the ctg_list and layer_list
ax.set_xticklabels(layer_list)
ax.set_yticklabels(ctg_list)

ax.set_title(f'Section: {ob_sect_k}\n')

# Rotate the x-axis tick labels if needed
# plt.xticks(rotation=90)

# Add a colorbar
cbar = ax.figure.colorbar(im, ax=ax, shrink=0.5)

# Show the plot
plt.show()


In [747]:
plt.close()

### Exp-5.3: attention section mutual removal

#### Load & Check

In [837]:
expect_type = 'table_alias'

res_path = f'/home/yshao/Projects/rome/results/exp5_3_attention_section_mutual_removal/exp=5.3.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

with open(res_path, 'r') as f:
    all_samples = [json.loads(l) for l in f]
len(all_samples)

1034

In [838]:
total_samples = 0
n_good_samples = 0
n_too_hard = 0      # wrong answer 
n_too_easy = 0      # base - low < 0.5

In [839]:
good_samples = []
bad_samples = []

for i, ex in enumerate(all_samples):
    for d in ex['trace_results']:
        total_samples += 1

        if d['is_good_sample']:
            n_good_samples += 1
            d['ex_id'] = i
            good_samples.append(d)
        elif not d['correct_prediction']:
            n_too_hard += 1
            bad_samples.append(d)
        else:
            assert d['base_score'] - d['low_score'] < 0.5
            n_too_easy += 1
            bad_samples.append(d)
            
total_samples, (n_good_samples, len(good_samples)), n_too_hard, n_too_easy, len(bad_samples), \
f'good / correct = {n_good_samples} / {n_good_samples + n_too_easy}'

(2039, (364, 364), 339, 1336, 1675, 'good / correct = 364 / 1700')
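
The good / too-hard / too-easy split above is repeated for each experiment in this notebook. A minimal sketch of a reusable helper follows. The name `partition_samples` and its `threshold` parameter are hypothetical (not part of the original code), and the sketch assumes every trace result carries the `is_good_sample`, `correct_prediction`, `base_score` and `low_score` fields used above.

```python
def partition_samples(all_samples, threshold=0.5):
    """Split trace results into (good, too_hard, too_easy) lists.

    Hypothetical helper mirroring the loop above:
      - good:     flagged by 'is_good_sample'
      - too_hard: the model's prediction was wrong
      - too_easy: prediction correct, but corruption lowered the score
                  by less than `threshold`
    """
    good, too_hard, too_easy = [], [], []
    for ex_id, ex in enumerate(all_samples):
        for d in ex['trace_results']:
            if d['is_good_sample']:
                d['ex_id'] = ex_id  # remember the source example
                good.append(d)
            elif not d['correct_prediction']:
                too_hard.append(d)
            else:
                assert d['base_score'] - d['low_score'] < threshold
                too_easy.append(d)
    return good, too_hard, too_easy
```

With this helper, `bad_samples` would correspond to `too_hard + too_easy`, and the summary counts are the lengths of the three lists.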

#### Overall avg

In [840]:
trace_scores_avg = {sect_k : defaultdict(int) for sect_k in good_samples[0]['trace_scores'].keys()}

for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for k, v in sect_d.items():
            if k == 'window':
                for l, s in v.items():
                    trace_scores_avg[sect_k][f'{k}-{l}'] += s
            else:
                s = v
                trace_scores_avg[sect_k][k] += s

for sect_k, sect_d in trace_scores_avg.items():
    for k, s in sect_d.items():
        sect_d[k] = s / len(good_samples)
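
The averaging above flattens the nested `'window'` entry into keys of the form `window-<layer>` and divides each sum by the number of good samples. The same computation can be sketched as a function; `average_trace_scores` is a hypothetical name, and the sketch assumes every sample has the same set of sections and keys.

```python
from collections import defaultdict

def average_trace_scores(samples):
    """Average trace scores per (section, layer key).

    The nested 'window' dict is flattened to 'window-<layer>' keys,
    as in the cell above. Hypothetical helper, not from the original code.
    """
    sums = defaultdict(lambda: defaultdict(float))
    for d in samples:
        for sect_k, sect_d in d['trace_scores'].items():
            for k, v in sect_d.items():
                if k == 'window':
                    for l, s in v.items():
                        sums[sect_k][f'{k}-{l}'] += s
                else:
                    sums[sect_k][k] += v
    n = len(samples)
    return {sect_k: {k: s / n for k, s in sect_d.items()}
            for sect_k, sect_d in sums.items()}
```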

In [None]:
trace_scores_avg

#### Avg by aspects (category)

In [842]:
d['category']

{'sql_hardness': 'medium', 'node_role': 'where', 'text_match': 'no-match'}

In [843]:
# # TEMP patch for node_len category 
# for d in good_samples + bad_samples:
#     node_len = len(d['answers_t'])
#     assert len(mt_uskg.tokenizer.tokenize(d['expect'])) == node_len, (d['expect'], node_len)
#     d['category']['node_len'] = str(node_len) if node_len <= 3 else '4+'

In [844]:
d['category']

{'sql_hardness': 'medium',
 'node_role': 'group by',
 'text_match': 'exact',
 'node_len': '3'}

In [845]:
# Key: (sect_k, aspect, asp_val, layer) -> [scores]
trace_scores_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
trace_scores_avg_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
trace_scores_cnt_by_aspect = defaultdict(lambda: defaultdict(int))  # no sect key & layer key 

In [846]:
for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for aspect, asp_val in d['category'].items():
            for k, v in sect_d.items():
                if k == 'window':
                    for l, s in v.items():
                        # only keep a subset of layers (3, 7, 11, ...)
                        if int(l) % 4 != 3: continue
                        layer_k = f'{k}-{l}'
                        trace_scores_by_aspect[sect_k][aspect][asp_val][layer_k].append(s)
                else:
                    layer_k = k
                    s = v
                    trace_scores_by_aspect[sect_k][aspect][asp_val][layer_k].append(s)
                    
for sect_k, d1 in trace_scores_by_aspect.items():
    for asp_k, d2 in d1.items():
        for asp_v, d3 in d2.items():
            for layer_k, s in d3.items():
                trace_scores_avg_by_aspect[sect_k][asp_k][asp_v][layer_k] = np.mean(s)
                trace_scores_cnt_by_aspect[asp_k][asp_v] = len(s)

In [847]:
for sect_k, sect_d in trace_scores_avg_by_aspect.items():
    sect_d['overall'] = dict()
    for layer_k, s in trace_scores_avg[sect_k].items():
        if layer_k.startswith('window'):
            # only keep a subset of layers 
            _, l = layer_k.split('-')
            if not (int(l) % 4 == 3): continue
        sect_d['overall'][layer_k] = s

In [848]:
trace_scores_cnt_by_aspect

defaultdict(<function __main__.<lambda>()>,
            {'sql_hardness': defaultdict(int,
                         {'medium': 123, 'extra': 175, 'hard': 64, 'easy': 2}),
             'node_role': defaultdict(int,
                         {'select': 191,
                          'group by': 33,
                          'join': 12,
                          'where': 111,
                          'order by': 17}),
             'text_match': defaultdict(int,
                         {'exact': 230, 'partial': 28, 'no-match': 106}),
             'node_len': defaultdict(int, {'3': 364})})

In [None]:
trace_scores_avg_by_aspect['c->p']

In [850]:
dump_d = ctu.nested_json_processing(trace_scores_avg_by_aspect, func=lambda x: np.format_float_positional(x, precision=4, min_digits=4))
# dump_d

In [851]:
dump_path = f'/home/yshao/Projects/rome/results/exp5_3_attention_section_mutual_removal/summ-exp=5.3.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

with open(dump_path, 'w') as f:
    json.dump(dump_d, f, indent=1)
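
The helper `ctu.nested_json_processing` is not shown in this notebook. Assuming it applies `func` to every leaf value of a nested dict, a stand-in could look like the sketch below. `nested_map` is a hypothetical name, and plain string formatting is used here instead of `np.format_float_positional` to keep the example free of dependencies.

```python
import json

def nested_map(obj, func):
    """Apply func to every leaf of a nested dict/list structure.

    Hypothetical stand-in for ctu.nested_json_processing; the real
    helper may behave differently.
    """
    if isinstance(obj, dict):  # also covers defaultdict
        return {k: nested_map(v, func) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [nested_map(v, func) for v in obj]
    return func(obj)

# Usage: format all scores to 4 decimals before dumping to JSON
example = {'c->p': {'overall': {'window-3': 0.123456}}}
formatted = nested_map(example, lambda x: f'{x:.4f}')
dumped = json.dumps(formatted, indent=1)
```

Converting the nested `defaultdict` objects to plain dicts this way also keeps the dumped structure independent of the lambda default factories.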

#### (one-time temp patch)

In [804]:
# expect_type = 'table_alias'
# orig_res_path = f'/home/yshao/Projects/rome/results/exp5_3_attention_section_mutual_removal/no_c2p_exp=5.3.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'
# add_res_path = f'/home/yshao/Projects/rome/results/exp5_3_attention_section_mutual_removal/exp=5.3.1+c2p_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

# merge_res_path = f'/home/yshao/Projects/rome/results/exp5_3_attention_section_mutual_removal/exp=5.3.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

In [805]:
# with open(orig_res_path, 'r') as f:
#     orig_all_samples = [json.loads(l) for l in f]
# with open(add_res_path, 'r') as f:
#     add_all_samples = [json.loads(l) for l in f]

# f = open(merge_res_path, 'w')
    
# for i, (orig_ex, add_ex) in enumerate(zip(orig_all_samples, add_all_samples)):
#     assert len(orig_ex['trace_results']) == len(add_ex['trace_results']), i
#     # There is randomness in the order of expected node (from set()), thus sorting here 
#     orig_ex['trace_results'].sort(key=lambda d: len(d['dec_prompt']))
#     add_ex['trace_results'].sort(key=lambda d: len(d['dec_prompt']))
#     for j, (orig_d, add_d) in enumerate(zip(orig_ex['trace_results'], add_ex['trace_results'])):
#         assert orig_d['is_good_sample'] == add_d['is_good_sample'], (i, j)
#         if not orig_d['is_good_sample']:
#             continue
            
#         # is good sample: add the new sections 
#         orig_d['trace_scores']['c->p'] = add_d['trace_scores']['c->p']
        
#         # put all at end in the dict 
#         _t = orig_d['trace_scores']['all']
#         del orig_d['trace_scores']['all']
#         orig_d['trace_scores']['all'] = _t
        
#     f.write(json.dumps(orig_ex, indent=None) + '\n')
    
# f.close()

### Exp-5.4: decoder cross attention removal

#### Load & Check

In [336]:
expect_type = 'table_alias'

res_path = f'/home/yshao/Projects/rome/results/exp5_4_decoder_cross_attention_removal/exp=5.4_dev_{expect_type}-corrupt=zero.jsonl'

with open(res_path, 'r') as f:
    all_samples = [json.loads(l) for l in f]
len(all_samples)

1034

In [337]:
total_samples = 0
n_good_samples = 0
n_too_hard = 0      # wrong answer 
n_too_easy = 0      # base - low < 0.5

In [338]:
good_samples = []
bad_samples = []

for i, ex in enumerate(all_samples):
    for d in ex['trace_results']:
        total_samples += 1

        if d['is_good_sample']:
            n_good_samples += 1
            d['ex_id'] = i
            good_samples.append(d)
        elif not d['correct_prediction']:
            n_too_hard += 1
            bad_samples.append(d)
        else:
            assert d['base_score'] - d['low_score'] < 0.5
            n_too_easy += 1
            bad_samples.append(d)
            
total_samples, (n_good_samples, len(good_samples)), n_too_hard, n_too_easy, len(bad_samples), \
f'good / correct = {n_good_samples} / {n_good_samples + n_too_easy}'

(2039, (395, 395), 339, 1305, 1644, 'good / correct = 395 / 1700')

#### Overall avg

In [339]:
# Dict[sect_k, Dict[layer_k, s]]
trace_scores_avg = {sect_k : defaultdict(int) for sect_k in good_samples[0]['trace_scores'].keys()}

for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for k, s in sect_d.items():
            trace_scores_avg[sect_k][k] += s

for sect_k, sect_d in trace_scores_avg.items():
    for k, s in sect_d.items():
        sect_d[k] = s / len(good_samples)

In [340]:
trace_scores_avg

{'all': defaultdict(int,
             {'low_layers': 0.5355958729745377,
              'high_layers': 0.8155413844211263,
              'all_layers': 0.1056054830929244}),
 'ans->t': defaultdict(int,
             {'low_layers': 0.9304732342151586,
              'high_layers': 0.9741443378360842,
              'all_layers': 0.8695232256844865}),
 'all->t': defaultdict(int,
             {'low_layers': 0.9166845399916218,
              'high_layers': 0.9720493624806168,
              'all_layers': 0.8491711175152609}),
 'ans->s': defaultdict(int,
             {'low_layers': 0.9289467685215815,
              'high_layers': 0.9696432898008954,
              'all_layers': 0.9363590360931409}),
 'all->s': defaultdict(int,
             {'low_layers': 0.82528036536228,
              'high_layers': 0.9648957481430593,
              'all_layers': 0.8213527668459242}),
 'ans->p': defaultdict(int,
             {'low_layers': 0.8902059622092646,
              'high_layers': 0.8411666841540542,
     

In [341]:
layers_keys = ['low_layers', 'high_layers', 'all_layers']

for sect_k, sect_d in trace_scores_avg.items():
    # 'all->?' results seem similar to 'ans->?' and make less intuitive sense; skip for now 
    if sect_k.startswith('all->'):
        continue
    print_l = f'{sect_k:<8s}'
    for k in layers_keys:
        s = sect_d[k]
        print_l += f'\t{s:.4f}'
    print(print_l)

all     	0.5356	0.8155	0.1056
ans->t  	0.9305	0.9741	0.8695
ans->s  	0.9289	0.9696	0.9364
ans->p  	0.8902	0.8412	0.7503
ans->c  	0.9460	0.9760	0.9602
ans->self	0.9771	0.9714	0.9579


### Exp-6.0: corruption effect - syntax

#### Load & Check

In [243]:
# expect_type = 'table_alias'

res_path = f'/home/yshao/Projects/rome/results/exp6_0_encoding_corruption_effect_syntax/exp=6.0_dev.jsonl'

with open(res_path, 'r') as f:
    all_samples = [json.loads(l) for l in f]
len(all_samples)

1034

In [244]:
total_samples = 0
n_good_samples = 0
n_too_hard = 0      # wrong answer 
n_too_easy = 0      # base - low < 0.5

In [245]:
good_samples = []
bad_samples = []

for i, ex in enumerate(all_samples):
    for d in ex['trace_results']:
        total_samples += 1

        if d.get('is_good_sample', True):
            n_good_samples += 1
            d['ex_id'] = i
            good_samples.append(d)
        elif not d['correct_prediction']:
            n_too_hard += 1
            bad_samples.append(d)
        else:
            assert d['base_score'] - d['low_score'] < 0.5
            n_too_easy += 1
            bad_samples.append(d)
            
total_samples, (n_good_samples, len(good_samples)), n_too_hard, n_too_easy, len(bad_samples), \
f'good / correct = {n_good_samples} / {n_good_samples + n_too_easy}'

(10233, (2261, 2261), 1623, 6349, 7972, 'good / correct = 2261 / 8610')

#### Overall avg

In [246]:
trace_scores_avg = {sect_k : defaultdict(int) for sect_k in good_samples[0]['trace_scores'].keys()}

for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for k, v in sect_d.items():
            trace_scores_avg[sect_k][k] += v

for sect_k, sect_d in trace_scores_avg.items():
    for k, s in sect_d.items():
        sect_d[k] = s / len(good_samples)

In [247]:
trace_scores_avg

{'text': defaultdict(int,
             {'embed': 0.2704069729491695, 'final_enc': 0.43288231847139375}),
 'struct': defaultdict(int,
             {'embed': 0.8434888537914244, 'final_enc': 0.7056294551667914}),
 'columns': defaultdict(int,
             {'embed': 0.8977011346416999, 'final_enc': 0.9094602057515592}),
 'tables': defaultdict(int,
             {'embed': 0.9401333304200568, 'final_enc': 0.9652279544981762}),
 'all': defaultdict(int,
             {'embed': 0.04223158108438767, 'final_enc': 0.14583704401879738})}

#### Corruption overall effect

In [248]:
# Dict[str, int]: expect_tok -> num of effective / not effective corruptions (all)
eff_counter = Counter()
neff_counter = Counter()

for d in good_samples:
    eff_counter[d['expect']] += 1
for d in bad_samples:
    if d['correct_prediction']:
        # "too easy", corruption not effective 
        neff_counter[d['expect']] += 1

In [249]:
eff_rate_d = dict()

for k in list(set(eff_counter.keys()) | set(neff_counter.keys())):
    eff_c = eff_counter[k]
    neff_c = neff_counter[k]
    eff_r = 1.0 * eff_c / (eff_c + neff_c)
    eff_rate_d[k] = eff_r
    # print(f'{k:<10s}{eff_c:5d} /{eff_c + neff_c:5d} = {eff_r:.4f}')

In [270]:
for k, eff_r in sorted(eff_rate_d.items(), key=lambda x: x[1], reverse=True):
    eff_c = eff_counter[k]
    neff_c = neff_counter[k]
    all_c = eff_c + neff_c
    if all_c <= 2: continue
    if k.isnumeric(): continue
#     print(f'{k:<10s}{eff_c:5d} /{all_c:5d} = {eff_r:.4f}')
    print(f'{k}\t{eff_c}\t{all_c}\t{eff_r:.4f}')

union	6	6	1.0000
!=	20	20	1.0000
like	12	12	1.0000
or	34	34	1.0000
min	18	18	1.0000
asc	19	19	1.0000
max	30	30	1.0000
between	6	6	1.0000
except	21	21	1.0000
avg	65	65	1.0000
intersect	34	34	1.0000
having	80	81	0.9877
distinct	25	26	0.9615
sum	21	22	0.9545
where	484	516	0.9380
not	42	46	0.9130
group	225	265	0.8491
and	31	39	0.7949
count	267	406	0.6576
order	142	221	0.6425
>	61	101	0.6040
)	11	23	0.4783
=	191	968	0.1973
as	93	952	0.0977
desc	16	164	0.0976
in	4	50	0.0800
join	39	496	0.0786
from	49	1196	0.0410
limit	6	177	0.0339
>=	1	30	0.0333
(	22	675	0.0326
*	0	381	0.0000
by	0	516	0.0000
on	0	516	0.0000
select	0	88	0.0000
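
The table above lists, per expected syntax token, the number of effective corruptions, the number of correctly predicted samples, and their ratio. The three cells that produce it can be condensed into one function, sketched here. `corruption_effect_rates` and `min_count` are hypothetical names; `min_count=3` matches the `all_c <= 2` filter used above.

```python
from collections import Counter

def corruption_effect_rates(good_samples, bad_samples, min_count=3):
    """Return (token, eff_count, all_count, eff_rate) rows, highest rate first.

    A corruption counts as effective for good samples, and as not
    effective for "too easy" samples (correct prediction, not good).
    Numeric tokens and tokens with fewer than `min_count` samples are skipped.
    """
    eff, neff = Counter(), Counter()
    for d in good_samples:
        eff[d['expect']] += 1
    for d in bad_samples:
        if d['correct_prediction']:
            neff[d['expect']] += 1

    rows = []
    for tok in set(eff) | set(neff):
        total = eff[tok] + neff[tok]
        if total < min_count or tok.isnumeric():
            continue
        rows.append((tok, eff[tok], total, eff[tok] / total))
    rows.sort(key=lambda r: r[3], reverse=True)
    return rows
```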


#### Avg by expect syntax token

In [271]:
# Key: (expect_tok, sect_k, layer) -> [scores]
trace_scores_by_exp_tok = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
trace_scores_avg_by_exp_tok = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
trace_scores_cnt_by_exp_tok = defaultdict(int)  # no sect key & layer key 

trace_sample_ids_by_exp_tok = defaultdict(list)

In [272]:
for i, d in enumerate(good_samples):
    expect = d['expect']
    trace_sample_ids_by_exp_tok[expect].append(i)
    for sect_k, sect_d in d['trace_scores'].items():
        for layer_k, v in sect_d.items():
            trace_scores_by_exp_tok[expect][sect_k][layer_k].append(v)

for exp_tok, d1 in trace_scores_by_exp_tok.items():
    if exp_tok.isnumeric(): continue
    for sect_k, d2 in d1.items():
        for layer_k, scores in d2.items():
            if len(scores) <= 2: continue
            trace_scores_avg_by_exp_tok[exp_tok][sect_k][layer_k] = np.mean(scores)
            trace_scores_cnt_by_exp_tok[exp_tok] = len(scores)

In [264]:
trace_scores_cnt_by_exp_tok

defaultdict(int,
            {'count': 267,
             'order': 142,
             'avg': 65,
             'min': 18,
             'max': 30,
             'where': 484,
             'distinct': 25,
             '>': 61,
             'group': 225,
             '(': 22,
             'between': 6,
             'from': 49,
             'desc': 16,
             'or': 34,
             'not': 42,
             'intersect': 34,
             'except': 21,
             'as': 93,
             'join': 39,
             'like': 12,
             'and': 31,
             '=': 191,
             'having': 80,
             '!=': 20,
             'union': 6,
             'limit': 6,
             'sum': 21,
             'asc': 19,
             ')': 11,
             'in': 4})

In [265]:
trace_scores_avg_by_exp_tok['count']

defaultdict(<function __main__.<lambda>.<locals>.<lambda>()>,
            {'text': defaultdict(float,
                         {'embed': 0.07181598544092065,
                          'final_enc': 0.08561127370659098}),
             'struct': defaultdict(float,
                         {'embed': 0.9997616432579269,
                          'final_enc': 0.9906901482785686}),
             'columns': defaultdict(float,
                         {'embed': 0.9989915059300397,
                          'final_enc': 0.9982228384035804}),
             'tables': defaultdict(float,
                         {'embed': 0.9940267473124387,
                          'final_enc': 0.9992919242783879}),
             'all': defaultdict(float,
                         {'embed': 0.022402080392785025,
                          'final_enc': 0.07584266595464821})})

In [266]:
sect_k = 'struct'
layer_k = 'embed'
scores_d = dict()

for exp_tok, d1 in trace_scores_avg_by_exp_tok.items():
    s = d1[sect_k][layer_k]
    scores_d[exp_tok] = s
    # print(f'{exp_tok:<10s}{s:.4f}')

In [267]:
for k, s in sorted(scores_d.items(), key=lambda x: x[1], reverse=True):
    print(f'{k:<10s}{s:.4f}')

like      1.0000
min       1.0000
count     0.9998
avg       0.9997
limit     0.9993
!=        0.9968
union     0.9831
or        0.9807
>         0.9722
having    0.9691
intersect 0.9634
except    0.9442
sum       0.9386
group     0.9373
order     0.9079
)         0.9057
max       0.9033
desc      0.8710
where     0.8684
distinct  0.8665
(         0.8557
between   0.8547
asc       0.8026
not       0.7477
and       0.6863
from      0.6251
=         0.6110
in        0.4806
as        0.2753
join      0.0001


In [268]:
sect_k = 'text'
layer_k = 'embed'
scores_d = dict()

for exp_tok, d1 in trace_scores_avg_by_exp_tok.items():
    s = d1[sect_k][layer_k]
    scores_d[exp_tok] = s
    # print(f'{exp_tok:<10s}{s:.4f}')

for k, s in sorted(scores_d.items(), key=lambda x: x[1], reverse=True):
    print(f'{k:<10s}{s:.4f}')

from      0.8645
as        0.8363
=         0.8272
in        0.7407
)         0.7101
desc      0.4644
where     0.3572
>         0.3472
join      0.3131
not       0.2707
and       0.2300
having    0.1737
min       0.1698
(         0.1657
like      0.1639
union     0.1435
order     0.1126
group     0.1046
asc       0.0877
count     0.0718
avg       0.0311
or        0.0211
except    0.0106
distinct  0.0097
max       0.0073
sum       0.0017
intersect 0.0005
!=        0.0000
between   0.0000
limit     0.0000


#### Corrupted answer

In [279]:
# Dict[str, Counter]: exp_tok -> {c_ans: cnt}
confusion_counter = defaultdict(Counter)

for d in good_samples:
    exp_tok = d['expect']
    c_ans = d['corrupted_answer']
    if exp_tok.isnumeric():
        exp_tok = 'NUM'
    if c_ans.isnumeric():
        c_ans = 'NUM'
    confusion_counter[exp_tok][c_ans] += 1
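
The full confusion counter can be long. A sketch that keeps only the most frequent corrupted answers per expected token, using `Counter.most_common`, is given below. `top_confusions` and `k` are hypothetical names; numeric tokens are collapsed to `'NUM'` as in the cell above.

```python
from collections import Counter, defaultdict

def top_confusions(samples, k=3):
    """Map each expected token to its k most frequent corrupted answers."""
    confusion = defaultdict(Counter)
    for d in samples:
        exp_tok = 'NUM' if d['expect'].isnumeric() else d['expect']
        c_ans = 'NUM' if d['corrupted_answer'].isnumeric() else d['corrupted_answer']
        confusion[exp_tok][c_ans] += 1
    return {tok: cnt.most_common(k) for tok, cnt in confusion.items()}
```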

In [280]:
confusion_counter

defaultdict(collections.Counter,
            {'count': Counter({'*': 182, '': 64, 'sum': 19, 'count': 2}),
             'order': Counter({'</s>': 96,
                      'join': 12,
                      ')': 2,
                      'where': 9,
                      'NUM': 5,
                      'order': 3,
                      'union': 2,
                      'and': 2,
                      'select': 4,
                      's': 6,
                      '_': 1}),
             'avg': Counter({'*tvg': 43,
                      'maxtvg': 8,
                      'maxavg': 2,
                      'tvg': 6,
                      'mintvg': 2,
                      'counttvg': 4}),
             'min': Counter({'': 9, 'max': 5, '*': 4}),
             'max': Counter({'': 13, '*': 16, 'min': 1}),
             'where': Counter({'group': 19,
                      'except': 2,
                      '</s>': 239,
                      'order': 114,
                      'where': 9,
        

#### Case study

In [281]:
for idx in trace_sample_ids_by_exp_tok['(']:
    d = good_samples[idx]
    print(f"{d['dec_prompt']} --> {d['corrupted_answer']} ({d['expect']})")

select song_name from singer where age > --> = (()
select song_name from singer where age > --> = (()
select count(*) from concert where stadium_id = -->  (()
select count(*) from concert where stadium_id = -->  (()
select t2.make, t1.year from cars_data as t1 join car_names as t2 on t1.id = t2.makeid where t1.year = --> 2004 (()
select t2.make, t1.year from cars_data as t1 join car_names as t2 on t1.id = t2.makeid where t1.year = --> 2004 (()
select count(*) from cars_data where accelerate > --> = (()
select count(*) from cars_data where accelerate > --> = (()
select t2.makeid, t2.make from cars_data as t1 join car_names as t2 on t1.id = t2.makeid where t1.horsepower > -->  (()
select name from shop where number_products > --> = (()
select name from shop where number_products > --> = (()
select name from museum where num_of_staff > --> = (()
select name from country where surfacearea > --> = (()
select name from country where surfacearea > --> = (()
select name from country where cont

### Exp-6.1: attention corruption effect - syntax

#### Load & Check

In [343]:
# expect_type = 'table_alias'

res_path = f'/home/yshao/Projects/rome/results/exp6_1_attention_corruption_effect_syntax/exp=6.1_dev_encoder_corrupt=zero.jsonl'

with open(res_path, 'r') as f:
    all_samples = [json.loads(l) for l in f]
len(all_samples)

1034

In [344]:
total_samples = 0
n_good_samples = 0
n_too_hard = 0      # wrong answer 
n_too_easy = 0      # base - low < 0.5

In [345]:
good_samples = []
bad_samples = []

for i, ex in enumerate(all_samples):
    for d in ex['trace_results']:
        total_samples += 1

        if d.get('is_good_sample', True):
            n_good_samples += 1
            d['ex_id'] = i
            good_samples.append(d)
        elif not d['correct_prediction']:
            n_too_hard += 1
            bad_samples.append(d)
        else:
            assert d['base_score'] - d['low_score'] < 0.5
            n_too_easy += 1
            bad_samples.append(d)
            
total_samples, (n_good_samples, len(good_samples)), n_too_hard, n_too_easy, len(bad_samples), \
f'good / correct = {n_good_samples} / {n_good_samples + n_too_easy}'

(10233, (2524, 2524), 1623, 6086, 7709, 'good / correct = 2524 / 8610')

In [346]:
# good_samples[0]['trace_scores']

#### Overall avg

In [347]:
trace_scores_avg = {sect_k : defaultdict(int) for sect_k in good_samples[0]['trace_scores'].keys()}

for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for k, v in sect_d.items():
            trace_scores_avg[sect_k][k] += v

for sect_k, sect_d in trace_scores_avg.items():
    for k, s in sect_d.items():
        sect_d[k] = s / len(good_samples)

In [348]:
trace_scores_avg

{'t->s': defaultdict(int,
             {'low_layers': 0.990021244916296,
              'high_layers': 0.9839921388335588,
              'all_layers': 0.9773429089978287}),
 's->t': defaultdict(int,
             {'low_layers': 0.9933494876220781,
              'high_layers': 0.9826286324625741,
              'all_layers': 0.9748112447138703}),
 't<->s': defaultdict(int,
             {'low_layers': 0.988124023103812,
              'high_layers': 0.9725914119834563,
              'all_layers': 0.9583643697837391}),
 't->p': defaultdict(int,
             {'low_layers': 0.9813802007457678,
              'high_layers': 0.9107886133535986,
              'all_layers': 0.8583388319168208}),
 's->p': defaultdict(int,
             {'low_layers': 0.9871366122154147,
              'high_layers': 0.949990724645636,
              'all_layers': 0.9415907011434735}),
 'ts->p': defaultdict(int,
             {'low_layers': 0.9728856815640369,
              'high_layers': 0.8155521331662294,
             

In [350]:
layers_keys = ['low_layers', 'high_layers', 'all_layers']

for sect_k, sect_d in trace_scores_avg.items():
    print_l = f'{sect_k:<8s}'
    for k in layers_keys:
        s = sect_d[k]
        print_l += f'\t{s:.4f}'
    print(print_l)

t->s    	0.9900	0.9840	0.9773
s->t    	0.9933	0.9826	0.9748
t<->s   	0.9881	0.9726	0.9584
t->p    	0.9814	0.9108	0.8583
s->p    	0.9871	0.9500	0.9416
ts->p   	0.9729	0.8156	0.6507
t->t    	0.9668	0.9419	0.8062
s->s    	0.9435	0.9212	0.7679
all     	0.7465	0.4480	0.0683


#### Corruption overall effect

In [301]:
# Dict[str, int]: expect_tok -> num of effective / not effective corruptions (all)
eff_counter = Counter()
neff_counter = Counter()

for d in good_samples:
    eff_counter[d['expect']] += 1
for d in bad_samples:
    if d['correct_prediction']:
        # "too easy", corruption not effective 
        neff_counter[d['expect']] += 1

In [302]:
eff_rate_d = dict()

for k in list(set(eff_counter.keys()) | set(neff_counter.keys())):
    eff_c = eff_counter[k]
    neff_c = neff_counter[k]
    eff_r = 1.0 * eff_c / (eff_c + neff_c)
    eff_rate_d[k] = eff_r
    # print(f'{k:<10s}{eff_c:5d} /{eff_c + neff_c:5d} = {eff_r:.4f}')

In [303]:
for k, eff_r in sorted(eff_rate_d.items(), key=lambda x: x[1], reverse=True):
    eff_c = eff_counter[k]
    neff_c = neff_counter[k]
    all_c = eff_c + neff_c
    if all_c <= 2: continue
    if k.isnumeric(): continue
#     print(f'{k:<10s}{eff_c:5d} /{all_c:5d} = {eff_r:.4f}')
    print(f'{k}\t{eff_c}\t{all_c}\t{eff_r:.4f}')

union	6	6	1.0000
!=	20	20	1.0000
like	12	12	1.0000
or	34	34	1.0000
asc	19	19	1.0000
distinct	26	26	1.0000
between	6	6	1.0000
except	21	21	1.0000
intersect	34	34	1.0000
not	45	46	0.9783
avg	63	65	0.9692
max	29	30	0.9667
having	77	81	0.9506
group	241	265	0.9094
sum	20	22	0.9091
order	197	221	0.8914
min	14	18	0.7778
and	29	39	0.7436
count	294	406	0.7241
where	350	516	0.6783
>	68	101	0.6733
)	13	23	0.5652
desc	82	164	0.5000
>=	11	30	0.3667
from	263	1196	0.2199
in	8	50	0.1600
limit	26	177	0.1469
(	98	675	0.1452
=	128	968	0.1322
join	44	496	0.0887
as	52	952	0.0546
*	10	381	0.0262
by	0	516	0.0000
on	0	516	0.0000
select	0	88	0.0000


#### Avg by expect syntax token

In [289]:
# Key: (expect_tok, sect_k, layer) -> [scores]
trace_scores_by_exp_tok = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
trace_scores_avg_by_exp_tok = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
trace_scores_cnt_by_exp_tok = defaultdict(int)  # no sect key & layer key 

trace_sample_ids_by_exp_tok = defaultdict(list)

In [290]:
for i, d in enumerate(good_samples):
    expect = d['expect']
    trace_sample_ids_by_exp_tok[expect].append(i)
    for sect_k, sect_d in d['trace_scores'].items():
        for layer_k, v in sect_d.items():
            trace_scores_by_exp_tok[expect][sect_k][layer_k].append(v)

for exp_tok, d1 in trace_scores_by_exp_tok.items():
    if exp_tok.isnumeric(): continue
    for sect_k, d2 in d1.items():
        for layer_k, scores in d2.items():
            if len(scores) <= 2: continue
            trace_scores_avg_by_exp_tok[exp_tok][sect_k][layer_k] = np.mean(scores)
            trace_scores_cnt_by_exp_tok[exp_tok] = len(scores)

In [291]:
trace_scores_cnt_by_exp_tok

defaultdict(int,
            {'count': 294,
             'order': 197,
             'desc': 82,
             'avg': 63,
             'min': 14,
             'max': 29,
             'where': 350,
             '=': 128,
             'distinct': 26,
             'from': 263,
             '>': 68,
             'group': 241,
             '(': 98,
             'between': 6,
             'limit': 26,
             'or': 34,
             '>=': 11,
             'not': 45,
             'intersect': 34,
             'except': 21,
             'as': 52,
             'join': 44,
             'like': 12,
             'and': 29,
             'in': 8,
             'having': 77,
             '!=': 20,
             '*': 10,
             'union': 6,
             'sum': 20,
             'asc': 19,
             ')': 13})

In [292]:
trace_scores_avg_by_exp_tok['count']

defaultdict(<function __main__.<lambda>.<locals>.<lambda>()>,
            {'t->s': defaultdict(float,
                         {'low_layers': 0.9978861685107354,
                          'high_layers': 0.9932368714425738,
                          'all_layers': 0.9911034725571596}),
             's->t': defaultdict(float,
                         {'low_layers': 0.9994523721892817,
                          'high_layers': 0.9979493310865091,
                          'all_layers': 0.997748848329596}),
             't<->s': defaultdict(float,
                         {'low_layers': 0.998062480874613,
                          'high_layers': 0.995554579890707,
                          'all_layers': 0.9959644333643167}),
             't->p': defaultdict(float,
                         {'low_layers': 0.9941589293855347,
                          'high_layers': 0.9844484830047099,
                          'all_layers': 0.9640641826280487}),
             's->p': defaultdict(float,
        

In [296]:
sect_k = 't<->s'
layer_k = 'all_layers'
scores_d = dict()

for exp_tok, d1 in trace_scores_avg_by_exp_tok.items():
    s = d1[sect_k][layer_k]
    scores_d[exp_tok] = s
    # print(f'{exp_tok:<10s}{s:.4f}')

In [297]:
for k, s in sorted(scores_d.items(), key=lambda x: x[1], reverse=True):
    print(f'{k:<10s}{s:.4f}')

>=        1.0000
>         1.0000
like      1.0000
*         1.0000
min       1.0000
between   1.0000
in        0.9992
avg       0.9972
count     0.9960
(         0.9946
desc      0.9942
having    0.9894
order     0.9826
!=        0.9825
=         0.9786
group     0.9771
max       0.9741
)         0.9716
where     0.9702
limit     0.9684
from      0.9671
or        0.9582
intersect 0.9522
asc       0.9437
except    0.9415
not       0.9360
sum       0.8969
as        0.8349
distinct  0.8284
union     0.8037
and       0.6727
join      0.4976


In [308]:
all_exp_toks = sorted(list(trace_scores_cnt_by_exp_tok.keys()))
all_sections = list(good_samples[0]['trace_scores'].keys())

print_str = '\t'.join(['Syntax-tok'] + all_sections + ['Eff_cnt', 'All_cnt', 'Eff_rate']) + '\n'

for exp_tok in all_exp_toks:
    print_str += f'{exp_tok:<10s}'
    for sect_k in all_sections:
        s = trace_scores_avg_by_exp_tok[exp_tok][sect_k]['all_layers']
        print_str += f'\t{s:.4f}'
    eff_c = eff_counter[exp_tok]
    neff_c = neff_counter[exp_tok]
    all_c = eff_c + neff_c
    eff_r = eff_rate_d[exp_tok]
    print_str += f'\t{eff_c:<7d}\t{all_c:<7d}\t{eff_r:.4f}'
    print_str += '\n'
    

In [309]:
print(print_str)

Syntax-tok	t->s	s->t	t<->s	t->p	s->p	ts->p	t->t	s->s	all	Eff_cnt	All_cnt	Eff_rate
!=        	0.9984	0.9995	0.9825	0.8485	0.9998	0.9070	0.3025	0.9193	0.0284	20     	20     	1.0000
(         	0.9999	0.9794	0.9946	0.9719	0.9795	0.9685	0.8983	0.9467	0.1856	98     	675    	0.1452
)         	0.9257	0.9926	0.9716	0.9914	0.9074	0.8420	0.9244	0.4715	0.1108	13     	23     	0.5652
*         	1.0000	0.9999	1.0000	0.9997	0.9999	0.9998	0.9999	0.9982	0.1802	10     	381    	0.0262
=         	0.9784	0.9923	0.9786	0.9378	0.8940	0.7439	0.9466	0.5933	0.1149	128    	968    	0.1322
>         	1.0000	1.0000	1.0000	0.9509	0.9990	0.9652	0.8193	0.9845	0.0219	68     	101    	0.6733
>=        	1.0000	1.0000	1.0000	0.9151	1.0000	0.9048	0.2593	0.9964	0.1703	11     	30     	0.3667
and       	0.7931	0.8981	0.6727	0.4962	0.8409	0.3468	0.3603	0.6749	0.0554	29     	39     	0.7436
as        	0.9212	0.8936	0.8349	0.7992	0.5034	0.3457	0.8478	0.3288	0.1002	52     	952    	0.0546
asc       	0.9723	0.9096	0.9437	0.4939	0.8596

In [310]:
# Issue checking: multi-token 
for exp_tok in all_exp_toks:
    print(mt_uskg.tokenizer.tokenize(exp_tok))

['▁', '!', '=']
['▁(']
['▁', ')']
['▁*']
['▁=']
['▁>']
['▁>', '=']
['▁and']
['▁as']
['▁as', 'c']
['▁', 'a', 'v', 'g']
['▁between']
['▁count']
['▁des', 'c']
['▁distinct']
['▁except']
['▁from']
['▁group']
['▁having']
['▁in']
['▁intersect']
['▁join']
['▁like']
['▁limit']
['▁max']
['▁min']
['▁not']
['▁or']
['▁order']
['▁sum']
['▁union']
['▁where']


### Exp-6.2: dec cross attention corruption effect - syntax

#### Load & Check

In [351]:
# expect_type = 'table_alias'

res_path = f'/home/yshao/Projects/rome/results/exp6_2_decoder_cross_attention_corruption_syntax/exp=6.2_dev_corrupt=zero.jsonl'

with open(res_path, 'r') as f:
    all_samples = [json.loads(l) for l in f]
len(all_samples)

1034

In [352]:
total_samples = 0
n_good_samples = 0
n_too_hard = 0      # wrong answer 
n_too_easy = 0      # base - low < 0.5

In [353]:
good_samples = []
bad_samples = []

for i, ex in enumerate(all_samples):
    for d in ex['trace_results']:
        total_samples += 1

        if d.get('is_good_sample', True):
            n_good_samples += 1
            d['ex_id'] = i
            good_samples.append(d)
        elif not d['correct_prediction']:
            n_too_hard += 1
            bad_samples.append(d)
        else:
            assert d['base_score'] - d['low_score'] < 0.5
            n_too_easy += 1
            bad_samples.append(d)
            
total_samples, (n_good_samples, len(good_samples)), n_too_hard, n_too_easy, len(bad_samples), \
f'good / correct = {n_good_samples} / {n_good_samples + n_too_easy}'

(10233, (4938, 4938), 1623, 3672, 5295, 'good / correct = 4938 / 8610')
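
The bookkeeping above amounts to a three-way classification of each trace result. Below is a minimal sketch of that rule, assuming each record carries the `is_good_sample`, `correct_prediction`, `base_score` and `low_score` fields used in the loop. The function name `classify_trace_result`, the `min_drop` parameter and the example records are hypothetical; the 0.5 threshold is taken from the assertion above.

```python
def classify_trace_result(d, min_drop=0.5):
    """Return 'good', 'too_hard', or 'too_easy' for one trace-result record.

    Hypothetical helper that mirrors the branching of the loop above.
    """
    if d.get('is_good_sample', True):
        return 'good'
    if not d['correct_prediction']:
        # The clean (uncorrupted) prediction was already wrong.
        return 'too_hard'
    # Correct prediction, but corruption lowers the score by less than min_drop.
    assert d['base_score'] - d['low_score'] < min_drop
    return 'too_easy'


# Made-up records, for illustration only.
examples = [
    {'is_good_sample': True, 'correct_prediction': True, 'base_score': 0.99, 'low_score': 0.05},
    {'is_good_sample': False, 'correct_prediction': False, 'base_score': 0.20, 'low_score': 0.10},
    {'is_good_sample': False, 'correct_prediction': True, 'base_score': 0.99, 'low_score': 0.90},
]
labels = [classify_trace_result(d) for d in examples]
```

Note that a record without an `is_good_sample` key is counted as good, which matches the `d.get(..., True)` default in the loop.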

In [354]:
# good_samples[0]['trace_scores']

#### Overall avg

In [355]:
trace_scores_avg = {sect_k : defaultdict(int) for sect_k in good_samples[0]['trace_scores'].keys()}

for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for k, v in sect_d.items():
            trace_scores_avg[sect_k][k] += v

for sect_k, sect_d in trace_scores_avg.items():
    for k, s in sect_d.items():
        sect_d[k] = s / len(good_samples)
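
The accumulate-then-divide pattern above can be wrapped in a small function so that it can be reused on any subset of samples. This is a sketch under the assumption that every sample has the same `trace_scores` layout (`section -> layer key -> score`); the name `average_trace_scores` and the toy data are hypothetical.

```python
from collections import defaultdict


def average_trace_scores(samples):
    """Average samples[i]['trace_scores'][section][layer_key] over all samples."""
    sums = defaultdict(lambda: defaultdict(float))
    for d in samples:
        for sect_k, sect_d in d['trace_scores'].items():
            for k, v in sect_d.items():
                sums[sect_k][k] += v
    n = len(samples)
    # Convert to plain dicts so that missing keys raise instead of silently returning 0.
    return {sect_k: {k: s / n for k, s in sect_d.items()}
            for sect_k, sect_d in sums.items()}


# Toy data, for illustration only.
toy_samples = [
    {'trace_scores': {'all': {'all_layers': 0.0}, 'ans->t': {'all_layers': 1.0}}},
    {'trace_scores': {'all': {'all_layers': 0.5}, 'ans->t': {'all_layers': 0.5}}},
]
avg = average_trace_scores(toy_samples)
```

Returning plain dictionaries is a design choice of this sketch; the cell above keeps `defaultdict` objects instead.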

In [356]:
trace_scores_avg

{'all': defaultdict(int,
             {'q1_layers': 0.9430973235928074,
              'q2_layers': 0.958486134229112,
              'q3_layers': 0.9611448908818483,
              'q4_layers': 0.96011300946364,
              'low_layers': 0.612874021720125,
              'mid_layers': 0.7143970828654054,
              'high_layers': 0.8043229374130009,
              'all_layers': 0.056718012975393826}),
 'ans->t': defaultdict(int,
             {'q1_layers': 0.9958142755001073,
              'q2_layers': 0.9903281497861435,
              'q3_layers': 0.9766086787565446,
              'q4_layers': 0.9752564481665504,
              'low_layers': 0.9848858261341927,
              'mid_layers': 0.9215631818847314,
              'high_layers': 0.8960677835880787,
              'all_layers': 0.7925715927119913}),
 'all->t': defaultdict(int,
             {'q1_layers': 0.9956391696137693,
              'q2_layers': 0.9868496488591167,
              'q3_layers': 0.9741849958150045,
              

In [445]:
# layers_keys = trace_scores_avg['all'].keys()

# for sect_k, sect_d in trace_scores_avg.items():
#     print_l = f'{sect_k:<8s}'
#     for k in layers_keys:
#         s = sect_d[k]
#         print_l += f'\t{s:.4f}'
#     print(print_l)

format_print_2D_dict(trace_scores_avg, head_col_w=7, col_w=11)

XXXXXXX	q1_layers  	q2_layers  	q3_layers  	q4_layers  	low_layers 	mid_layers 	high_layers	all_layers 
all    	0.9431     	0.9585     	0.9611     	0.9601     	0.6129     	0.7144     	0.8043     	0.0567     
ans->t 	0.9958     	0.9903     	0.9766     	0.9753     	0.9849     	0.9216     	0.8961     	0.7926     
all->t 	0.9956     	0.9868     	0.9742     	0.9753     	0.9686     	0.8916     	0.8913     	0.7414     
ans->s 	0.9906     	0.9921     	0.9955     	0.9970     	0.9592     	0.9905     	0.9947     	0.9581     
all->s 	0.9844     	0.9896     	0.9953     	0.9970     	0.9246     	0.9864     	0.9947     	0.9204     
ans->p 	0.9937     	0.9916     	0.9775     	0.9799     	0.9862     	0.9607     	0.9272     	0.8813     
all->p 	0.9806     	0.9769     	0.9748     	0.9795     	0.9317     	0.8975     	0.9103     	0.7091     



#### Average by expected syntax token

In [379]:
res_by_exp_tok = exp6_ob_by_exp_tok(good_samples)

In [380]:
trace_scores_cnt_by_exp_tok = res_by_exp_tok['cnt']
trace_scores_cnt_by_exp_tok

defaultdict(int,
            {'count': 394,
             '(': 407,
             'order': 219,
             'avg': 63,
             'min': 18,
             'max': 28,
             'where': 502,
             'distinct': 26,
             'from': 702,
             '>': 65,
             'group': 265,
             'by': 33,
             'between': 6,
             'desc': 156,
             'or': 34,
             'as': 561,
             'on': 474,
             '>=': 3,
             'not': 46,
             'intersect': 34,
             'select': 44,
             'except': 21,
             '=': 409,
             'join': 20,
             'like': 10,
             'in': 15,
             'having': 62,
             'and': 6,
             'limit': 16,
             '!=': 20,
             '*': 10,
             'union': 6,
             'sum': 22,
             'asc': 19,
             ')': 17})

In [381]:
trace_scores_avg_by_exp_tok = res_by_exp_tok['avg']
trace_scores_avg_by_exp_tok['count']

defaultdict(<function __main__.exp6_ob_by_exp_tok.<locals>.<lambda>.<locals>.<lambda>()>,
            {'all': defaultdict(float,
                         {'q1_layers': 0.9964363076889575,
                          'q2_layers': 0.9963203995739143,
                          'q3_layers': 0.9551561746639813,
                          'q4_layers': 0.9972184849572061,
                          'low_layers': 0.8215906578539822,
                          'mid_layers': 0.46418598421683044,
                          'high_layers': 0.7123522221803182,
                          'all_layers': 0.014818084030064214}),
             'ans->t': defaultdict(float,
                         {'q1_layers': 0.9994401957480435,
                          'q2_layers': 0.9989948929263855,
                          'q3_layers': 0.9765654405254729,
                          'q4_layers': 0.9960574505255004,
                          'low_layers': 0.9980983920206273,
                          'mid_layers': 0.713294777

In [383]:
sect_k = 'ans->t'
layer_k = 'all_layers'
scores_d = dict()

for exp_tok, d1 in trace_scores_avg_by_exp_tok.items():
    s = d1[sect_k][layer_k]
    scores_d[exp_tok] = s
    # print(f'{exp_tok:<10s}{s:.4f}')

In [388]:
format_print_1D_dict(scores_d, sort_by='value', reverse=True)

select    1.0000
by        1.0000
on        1.0000
as        0.9977
in        0.9977
)         0.9955
from      0.9942
*         0.9826
=         0.9823
(         0.9583
where     0.9210
not       0.8927
join      0.8654
desc      0.8452
having    0.8170
>=        0.6667
group     0.6521
limit     0.6250
like      0.5985
>         0.5678
min       0.5502
except    0.5067
count     0.4403
sum       0.3922
or        0.3801
order     0.3762
union     0.3335
distinct  0.3008
and       0.1670
max       0.1403
asc       0.0839
intersect 0.0266
!=        0.0208
avg       0.0062
between   0.0000


In [399]:
all_exp_toks = sorted(list(trace_scores_cnt_by_exp_tok.keys()))
all_sections = list(good_samples[0]['trace_scores'].keys())

_d = {exp_tok:
          {sect_k: trace_scores_avg_by_exp_tok[exp_tok][sect_k]['all_layers']
           for sect_k in all_sections}
      for exp_tok in all_exp_toks}

format_print_2D_dict(_d, all_k1=all_exp_toks, all_k2=all_sections)

XXXXXXXXXX	all	ans->t	all->t	ans->s	all->s	ans->p	all->p
!=        	0.0013	0.0208	0.0001	0.9509	0.8203	0.7167	0.5739
(         	0.0747	0.9583	0.9549	0.9997	0.9923	0.7309	0.6955
)         	0.0212	0.9955	0.9951	0.8845	0.8465	0.9982	0.9014
*         	0.1240	0.9826	0.9720	0.9978	0.9306	0.7433	0.5726
=         	0.1195	0.9823	0.9761	0.8742	0.8298	0.9526	0.9377
>         	0.0210	0.5678	0.4234	0.9599	0.9359	0.9659	0.9533
>=        	0.1964	0.6667	0.6667	1.0000	1.0000	0.9994	0.8787
and       	0.1812	0.1670	0.1309	0.9885	0.6761	0.9118	0.6287
as        	0.0883	0.9977	0.9969	0.9404	0.8392	0.9844	0.8963
asc       	0.0000	0.0839	0.1381	0.8809	0.8523	0.8530	0.8652
avg       	0.0018	0.0062	0.0049	1.0000	1.0000	0.9441	0.9177
between   	0.0006	0.0000	0.0000	0.9992	0.9991	0.9994	0.9984
by        	0.0757	1.0000	1.0000	1.0000	1.0000	1.0000	1.0000
count     	0.0148	0.4403	0.3509	1.0000	0.9959	0.9574	0.9006
desc      	0.0511	0.8452	0.8698	0.9805	0.9593	0.9229	0.7996
distinct  	0.0005	0.3008	0.2718	0.9965	0.99

#### By aspects

In [401]:
sample_ids_by_aspect = defaultdict(list)

for i, d in enumerate(good_samples):
    for asp_k, asp_v in d['category'].items():
        asp_str_k = f'{asp_k}={asp_v}'
        sample_ids_by_aspect[asp_str_k].append(i)

In [403]:
{(k, len(l)) for k, l in sample_ids_by_aspect.items()}

{('sql_hardness=easy', 689),
 ('sql_hardness=extra', 1402),
 ('sql_hardness=hard', 1004),
 ('sql_hardness=medium', 1843)}

In [418]:
# Nesting: asp_k -> avg/cnt/sample_ids -> exp_tok -> sect_k -> layer_k -> score
# A list (not a generator) is passed so the samples can be iterated more than once.
all_res_by_exp_tok = {asp_k : exp6_ob_by_exp_tok([good_samples[i] for i in asp_sample_ids])
                      for asp_k, asp_sample_ids in sample_ids_by_aspect.items()}

# exp_tok -> [sect_k -> [layer_k -> s]]
avg_d = all_res_by_exp_tok['sql_hardness=extra']['avg']

all_exp_toks = sorted(list(avg_d.keys()))
all_sections = list(avg_d[all_exp_toks[0]].keys())

_d = {exp_tok:
          {sect_k: avg_d[exp_tok][sect_k]['all_layers']
           for sect_k in all_sections}
      for exp_tok in all_exp_toks}

In [419]:
format_print_2D_dict(_d, all_k1=all_exp_toks, all_k2=all_sections, col_w=6)

XXXXXXXXXXXX	all   	ans->t	all->t	ans->s	all->s	ans->p	all->p
(           	0.0960	0.9150	0.9145	0.9995	0.9683	0.5015	0.4892
)           	0.0147	0.9977	0.9928	0.9659	0.9458	0.9992	0.9988
*           	0.2047	1.0000	1.0000	0.9963	0.8843	0.9579	0.6780
=           	0.1128	0.9823	0.9769	0.7711	0.7308	0.9014	0.8690
>           	0.0111	0.5996	0.5905	0.9682	0.8337	0.8913	0.8308
as          	0.0735	0.9936	0.9916	0.9374	0.8383	0.9773	0.8630
avg         	0.0001	0.0003	0.0001	1.0000	1.0000	0.9980	0.9884
count       	0.0166	0.7992	0.7450	1.0000	1.0000	0.9679	0.9111
desc        	0.0754	0.9459	0.9655	1.0000	1.0000	0.9224	0.8590
distinct    	0.0005	0.5546	0.5506	0.9998	0.9998	0.8881	0.8270
except      	0.0000	0.2832	0.0007	0.8574	0.6992	0.5941	0.5798
from        	0.0672	0.9886	0.9863	0.9704	0.9614	0.7393	0.6001
group       	0.0083	0.3118	0.1122	0.9870	0.8914	0.8986	0.3441
having      	0.0316	0.8593	0.4232	0.7806	0.9744	0.7360	0.7438
in          	0.1954	1.0000	1.0000	1.0000	0.7843	0.9224	0.7463
intersec

In [446]:
_d = dict()

h_list = ['easy', 'medium', 'hard', 'extra']

for exp_tok in all_exp_toks:
    _d[exp_tok] = {
        h: all_res_by_exp_tok[f'sql_hardness={h}']['avg'][exp_tok]['all->t']['all_layers']
        for h in h_list
    }
    for h in h_list:
        _cnt = all_res_by_exp_tok[f'sql_hardness={h}']['cnt'][exp_tok]
        if _cnt < 3:
            # Too few samples for a meaningful average: mask the cell.
            _d[exp_tok][h] = np.nan

In [447]:
format_print_2D_dict(_d, sort_k1_kwargs={}, col_w=8)

XXXXXXXXXXXX	easy    	medium  	hard    	extra   
(           	1.0000  	1.0000  	0.8646  	0.9145  
)           	nan     	nan     	0.9994  	0.9928  
*           	nan     	nan     	nan     	1.0000  
=           	0.9633  	0.9855  	0.9654  	0.9769  
>           	0.4182  	0.4960  	0.2578  	0.5905  
as          	1.0000  	0.9998  	1.0000  	0.9916  
avg         	0.0046  	0.0101  	0.0000  	0.0001  
count       	0.1406  	0.1884  	0.5781  	0.7450  
desc        	0.4946  	0.6938  	0.9553  	0.9655  
distinct    	0.0554  	0.1229  	0.3660  	0.5506  
except      	nan     	nan     	0.0244  	0.0007  
from        	0.9417  	0.9906  	0.9807  	0.9863  
group       	0.5330  	0.7302  	0.3732  	0.1122  
having      	0.6407  	0.7282  	0.2358  	0.4232  
in          	nan     	nan     	0.9526  	1.0000  
intersect   	nan     	nan     	0.0005  	0.0007  
join        	nan     	nan     	0.9998  	0.8096  
limit       	nan     	0.9985  	nan     	0.3333  
min         	nan     	0.1869  	nan     	0.3291  
not         	nan    

In [424]:
# _d = dict()

# for exp_tok in all_exp_toks:
#     _d[exp_tok] = {
#         'easy': all_res_by_exp_tok['sql_hardness=easy']['cnt'][exp_tok],
#         'extra': all_res_by_exp_tok['sql_hardness=extra']['cnt'][exp_tok],
#     }

## Tests

### create_analysis_samples

In [136]:
ex_id = 111
a_ex_id = 0

ex = processed_spider_dev[ex_id]
ex['text_in'], \
ex['struct_in'], \
ex['seq_out']

('What is the accelerate of the car make amc hornet sportabout (sw)?',
 '| car_1 | continents : contid , continent | countries : countryid , countryname , continent | car_makers : id , maker ( amc ) , fullname , country | model_list : modelid , maker , model ( amc ) | car_names : makeid , model ( amc ) , make ( amc hornet , amc hornet sportabout (sw) ) | cars_data : id , mpg , cylinders , edispl , horsepower , weight , accelerate , year',
 "select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id = t2.makeid where t2.make = 'amc hornet sportabout (sw)';")

In [137]:
# temp test
# ex['seq_out'] = 'select year from cars_data'

In [138]:
a_ex_list = ctu.create_analysis_sample_dicts(
                mt_uskg, ex,
                subject_type='column',
                remove_struct_duplicate_nodes=True)

In [139]:
a_ex_list[a_ex_id].keys()

dict_keys(['query', 'question', 'db_id', 'db_path', 'db_table_names', 'db_column_names', 'db_column_types', 'db_primary_keys', 'db_foreign_keys', 'rat_sql_graph', 'serialized_schema', 'struct_in', 'text_in', 'seq_out', 'enc_sentence', 'enc_tokenized', 'text_range', 'struct_range', 'struct_node_ranges_dict', 'dec_prompt', 'expect', 'expect_type', 'remove_struct_duplicate_nodes', 'parsed_struct_in', 'col2table', 'token_ranges_dict', 'node_name_ranges', 'expect_input_ranges', 'alias2table', 'self_ranges', 'context_ranges', 'category'])

In [140]:
a_ex_list[a_ex_id]['alias2table']

{'t1': 'cars_data', 't2': 'car_names'}

In [None]:
[(d['dec_prompt'], d['expect'], d['node_name_ranges'], d['expect_input_ranges'], '------',\
  d['self_ranges'], d['context_ranges'],\
  d['category'], '------' * 2) for d in a_ex_list]

In [None]:
d = dict(a_ex_list[a_ex_id])
d

In [None]:
d = ctu.add_clean_prediction(mt_uskg, d)

In [None]:
d

#### parse_sql_alias2table

In [185]:
_sql = 'SELECT t2.aaa , t3.ccc FROM table_name as t1 JOIN other_table as t2 on table_name.a_a = other_table.b_a JOIN ttt as t3 on other_table.asth = ttt.asth'
ctu.parse_sql_alias2table(_sql)

{'t1': 'table_name', 't2': 'other_table', 't3': 'ttt'}
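
For simple queries, the alias mapping shown above can be approximated with a single regular expression. The sketch below is not the implementation of `ctu.parse_sql_alias2table`; it is a hypothetical stand-in (`parse_alias2table_sketch`) that only recognizes the `<table> AS <alias>` pattern.

```python
import re


def parse_alias2table_sketch(sql):
    """Map each alias to its table for `<table> AS <alias>` clauses (case-insensitive)."""
    pattern = re.compile(r'(\w+)\s+as\s+(\w+)', flags=re.IGNORECASE)
    return {alias.lower(): table.lower() for table, alias in pattern.findall(sql)}


_sql = ('SELECT t2.aaa , t3.ccc FROM table_name as t1 '
        'JOIN other_table as t2 on table_name.a_a = other_table.b_a '
        'JOIN ttt as t3 on other_table.asth = ttt.asth')
alias2table = parse_alias2table_sketch(_sql)
```

The pattern would also pick up column aliases such as `name as n` and does not handle aliased subqueries, so it is only suitable for quick checks on queries like the one above.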

#### for syntax

In [149]:
_ex = copy.deepcopy(ex)
# _ex['seq_out'] += 'order by t1.mpg'
a_ex_list_syntax = ctu.create_syntax_analysis_sample_dicts(mt_uskg, _ex)

from 	 None 	 False
as 	 None 	 False
join 	 None 	 False
as 	 None 	 False
on 	 None 	 False
= 	 None 	 False
where 	 None 	 False
= 	 None 	 False
' 	 ' 	 True
amc 	 ' 	 True
hornet 	 ' 	 True
sportabout 	 ' 	 True
( 	 ' 	 True
sw 	 ' 	 True
) 	 ' 	 True
' 	 None 	 True


In [150]:
for a_ex in a_ex_list_syntax:
    print(a_ex['dec_prompt'], ' --> ', a_ex['expect'])

select t1.accelerate  -->  from
select t1.accelerate from cars_data  -->  as
select t1.accelerate from cars_data as t1  -->  join
select t1.accelerate from cars_data as t1 join car_names  -->  as
select t1.accelerate from cars_data as t1 join car_names as t2  -->  on
select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id  -->  =
select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id = t2.makeid  -->  where
select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id = t2.makeid where t2.make  -->  =


In [156]:
a_ex['enc_tokenized']

{'input_ids': [363, 19, 8, 16845, 13, 8, 443, 143, 183, 75, 3, 6293, 15, 17, 2600, 7932, 41, 7, 210, 61, 58, 117, 3, 7593, 15, 26, 1103, 10, 1820, 443, 834, 536, 1820, 10829, 7, 3, 10, 3622, 23, 26, 3, 6, 10829, 1820, 1440, 3, 10, 684, 23, 26, 3, 6, 684, 4350, 3, 6, 10829, 1820, 443, 834, 8910, 3, 10, 3, 23, 26, 3, 6, 13762, 41, 183, 75, 3, 61, 3, 6, 423, 4350, 3, 6, 684, 1820, 825, 834, 3350, 3, 10, 825, 23, 26, 3, 6, 13762, 3, 6, 825, 41, 183, 75, 3, 61, 1820, 443, 834, 4350, 7, 3, 10, 143, 23, 26, 3, 6, 825, 41, 183, 75, 3, 61, 3, 6, 143, 41, 183, 75, 3, 6293, 15, 17, 3, 6, 183, 75, 3, 6293, 15, 17, 2600, 7932, 41, 7, 210, 61, 3, 61, 1820, 2948, 834, 6757, 3, 10, 3, 23, 26, 3, 6, 3, 1167, 122, 3, 6, 3, 12980, 7, 3, 6, 3, 15, 10475, 40, 3, 6, 28906, 3, 6, 1293, 3, 6, 16845, 3, 6, 215, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

In [159]:
col_name_ranges = a_ex['token_ranges_dict']['col_name_ranges']
# Decode each column's token ranges back to text as a sanity check
for ranges in col_name_ranges.values():
    print([mt_uskg.tokenizer.decode(a_ex['enc_tokenized']['input_ids'][s:e]) for s, e in ranges])

['contid']
['continent', 'continent']
['countryid']
['countryname']
['id', 'id']
['maker ( amc )', 'maker']
['fullname']
['country']
['modelid']
['model ( amc )', 'model ( amc )']
['makeid']
['make ( amc hornet, amc hornet sportabout (sw) )']
['mpg']
['cylinders']
['edispl']
['horsepower']
['weight']
['accelerate']
['year']


In [None]:
col_name_ranges

In [160]:
table_name_ranges = a_ex['token_ranges_dict']['table_name_ranges']
# Decode each table's token ranges back to text as a sanity check
for ranges in table_name_ranges.values():
    print([mt_uskg.tokenizer.decode(a_ex['enc_tokenized']['input_ids'][s:e]) for s, e in ranges])

['continents']
['countries']
['car_makers']
['model_list']
['car_names']
['cars_data']


In [161]:
table_name_ranges

defaultdict(list,
            {'continents': [(33, 35)],
             'countries': [(44, 45)],
             'car_makers': [(58, 61)],
             'model_list': [(82, 85)],
             'car_names': [(102, 106)],
             'cars_data': [(146, 149)]})

### utils


In [117]:
_sql = 'SELECT t2.aaa, DISTINCT(t3.ccc), COUNT(*) FROM table_name as t1 JOIN other_table as t2 on table_name.a_a = other_table.b_a JOIN ttt as t3 on other_table.asth = ttt.asth WHERE t2.col like %hey% AND t3.p <= 40'.lower()
_tok_ranges = ctu.separate_punct_by_offset(_sql)
print([_sql[s:e] for s, e in _tok_ranges])

['select', 't2.', 'aaa', ',', 'distinct', '(', 't3.', 'ccc', ')', ',', 'count', '(', '*', ')', 'from', 'table_name', 'as', 't1', 'join', 'other_table', 'as', 't2', 'on', 'table_name', '.', 'a_a', '=', 'other_table', '.', 'b_a', 'join', 'ttt', 'as', 't3', 'on', 'other_table', '.', 'asth', '=', 'ttt', '.', 'asth', 'where', 't2.', 'col', 'like', '%', 'hey', '%', 'and', 't3.', 'p', '<=', '40']


In [315]:
_toks = mt_uskg.tokenizer.tokenize('which school is good? structed_knowledge: school | school : school_name, is_good')
print(len(_toks), _toks)

25 ['▁which', '▁school', '▁is', '▁good', '?', '▁', 'struct', 'e', 'd', '_', 'know', 'ledge', ':', '▁school', '▁|', '▁school', '▁', ':', '▁school', '_', 'name', ',', '▁is', '_', 'good']


In [322]:
_test_a_ex = {
    'enc_sentence': 'which school is good? structed_knowledge: school | school : school_name, is_good',
    'dec_prompt': 'select distinct',
    'expect': 'school_name',  # ['▁school', '_', 'name']
    'answers_t': [1,2,3],
    'answer': 'school_name',
    'text_range': [0, 5],
    'struct_range': [15, 25],
    'self_ranges': [[18, 21]],
    'context_ranges': [[15, 18], [21, 25]],
}

_test_att_masks = ctu.build_dec_cross_attention_mask(
    a_ex=_test_a_ex,
    mt=mt_uskg,
    use_self_node=True
)

In [None]:
_test_att_masks

### trace

In [207]:
a_ex = dict(a_ex_list[a_ex_id])
a_ex = ctu.add_clean_prediction(mt_uskg, a_ex)

In [208]:
result = ctu.make_basic_result_dict(a_ex)
result

{'enc_sentence': 'Which city has the most frequent destination airport?; structed knowledge: | flight_2 | airlines : uid , airline , abbreviation , country | airports : city , airportcode , airportname , country , countryabbrev | flights : airline , flightno , sourceairport , destairport',
 'seq_out': 'select t1.city from airports as t1 join flights as t2 on t1.airportcode = t2.destairport group by t1.city order by count(*) desc limit 1',
 'dec_prompt': 'select t1.city from airports as t1 join flights as t2 on t1.airportcode = t2.destairport group by t1.',
 'expect': 'city',
 'expect_type': 'column',
 'db_id': 'flight_2',
 'expect_input_ranges': [(45, 46)],
 'expect_table': 'airports',
 'answer': 'city',
 'base_score': 0.9983423948287964,
 'answers_t': [6726],
 'correct_prediction': True,
 'category': {'sql_hardness': 'extra',
  'node_role': 'group by',
  'text_match': 'exact'}}

In [209]:
enc_sentence = a_ex['enc_sentence']
dec_prompt = a_ex['dec_prompt']
expect = a_ex['expect']
answer = result['answer']
answers_t = result['answers_t']

inp = ctu.make_inputs_t5(
    mt_uskg.tokenizer,
    [enc_sentence] * 11,
    [dec_prompt] * 11,
    answer=expect)

text_range = a_ex['text_range']
struct_range = a_ex['struct_range']

self_ranges = a_ex['self_ranges']
context_ranges = a_ex['context_ranges']

self_tok_indices = [i for s, e in self_ranges for i in range(s, e)]
context_tok_indices = corrupt_tok_indices = [i for s, e in context_ranges for i in range(s, e)]
text_tok_indices = list(range(*text_range))
struct_tok_indices = list(range(*struct_range))

In [210]:
_score = ctu.trace_with_repatch_uskg(
    model=mt_uskg.model,
    inp=inp,
#     states_to_patch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', mt_uskg.num_enc_layers - 1))
#                     for tnum in range(*struct_range)],
    states_to_patch=[],
    states_to_unpatch=[],
    answers_t=answers_t,
    tokens_to_mix=text_tok_indices,
    tokens_to_mix_individual_indices=True,
    replace=True,
).item()

In [211]:
answers_t, answer, _score

([6726], 'city', 0.8450507521629333)

In [212]:
states_to_corrupt = [(tnum, ctu.layername_uskg(mt_uskg.model, "encoder", 0, "embed"))
                for tnum in text_tok_indices]

_score = ctu.trace_with_repatch_uskg(
    model=mt_uskg.model,
    inp=inp,
    states_to_patch=[],
    states_to_unpatch=[],
    answers_t=answers_t,
    states_to_corrupt=states_to_corrupt,
#     tokens_to_mix=corrupt_tok_indices,
#     tokens_to_mix_individual_indices=True,
    replace=True,
).item()
_score

0.8450507521629333

In [213]:
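# Consistency check: this cell and the next specify the same corruption in two ways
# (tokens_to_mix vs. states_to_corrupt on the encoder embedding) and should give the same score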

_score = ctu.trace_with_repatch_uskg(
    model=mt_uskg.model,
    inp=inp,
#     states_to_patch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', l))
#                     for tnum in text_tok_indices for l in range(mt_uskg.num_enc_layers - 1)],
    states_to_patch=[],
    states_to_patch_1st_pass=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', 12))
                    for tnum in text_tok_indices],
    states_to_unpatch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', 23))
                    for tnum in struct_tok_indices],
    answers_t=answers_t,
    tokens_to_mix=text_tok_indices,
    tokens_to_mix_individual_indices=True,
    tokens_to_mix_1st_pass=context_tok_indices,
    replace=True,
).item()

_score

tensor(0.9952, device='cuda:0')

In [214]:
_score = ctu.trace_with_repatch_uskg(
    model=mt_uskg.model,
    inp=inp,
#     states_to_patch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', l))
#                     for tnum in text_tok_indices for l in range(mt_uskg.num_enc_layers - 1)],
    states_to_patch=[],
    states_to_patch_1st_pass=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', 12))
                    for tnum in text_tok_indices],
    states_to_unpatch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', 23))
                    for tnum in struct_tok_indices],
    answers_t=answers_t,
    states_to_corrupt=[(tnum, ctu.layername_uskg(mt_uskg.model, "encoder", 0, "embed"))
                    for tnum in text_tok_indices],
    states_to_corrupt_1st_pass=[(tnum, ctu.layername_uskg(mt_uskg.model, "encoder", 0, "embed"))
                    for tnum in context_tok_indices],
    replace=True,
).item()

_score

tensor(0.9952, device='cuda:0')

In [215]:
# Test corrupting the encoder self-attention of the text tokens at every layer
_score = ctu.trace_with_repatch_uskg(
    model=mt_uskg.model,
    inp=inp,
#     states_to_patch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', l))
#                     for tnum in text_tok_indices for l in range(mt_uskg.num_enc_layers - 1)],
    states_to_patch=[],
    states_to_unpatch=[],
    answers_t=answers_t,
    states_to_corrupt=[(tnum, ctu.layername_uskg(mt_uskg.model, "encoder", l, "self_attn"))
                    for tnum in text_tok_indices for l in range(mt_uskg.num_enc_layers)],
    replace=True,
).item()

_score

0.7253002524375916

In [None]:
[n for n, w in mt_uskg.model.named_parameters()]

In [216]:
vocab_probs = ctu.run_repatch_uskg_multi_token(
    model=mt_uskg.model,
    inp=inp,
#     states_to_patch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', mt_uskg.num_enc_layers - 1))
#                     for tnum in range(*struct_range)],
    states_to_patch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', l))
                    for tnum in self_tok_indices for l in range(mt_uskg.num_enc_layers - 1)],
    states_to_unpatch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', mt_uskg.num_enc_layers - 1))
                    for tnum in self_tok_indices],
    answer_len=len(answers_t),
    tokens_to_mix=corrupt_tok_indices,
    tokens_to_mix_individual_indices=True,
    replace=True,
)

In [217]:
vocab_probs.size()

torch.Size([1, 32102])

In [218]:
torch.max(vocab_probs, dim=-1)

torch.return_types.max(
values=tensor([1.], device='cuda:0'),
indices=tensor([7634], device='cuda:0'))

In [219]:
vocab_probs[0, 7634]

tensor(1., device='cuda:0')

In [220]:
vocab_probs

tensor([[2.2642e-25, 1.2223e-15, 7.3942e-18,  ..., 9.1578e-20, 2.6884e-39,
         2.8131e-39]], device='cuda:0')

## Temp

### Debugging exp

In [497]:
ex = processed_spider_dev[97]
text_in = ex['text_in']
struct_in = ex['struct_in']

enc_sentence = f"{text_in}; structed knowledge: {struct_in}"
dec_prompt = "select t1.model from"
expect = "car_names"

In [498]:
inp = ctu.make_inputs_t5(
    mt_uskg.tokenizer,
    enc_sentences=[enc_sentence]*11,
    dec_prompts=[dec_prompt]*11,
    answer=expect
)

### RE

In [70]:
seq = 'aa,bb< cc  \t dd(  )ee <= ff=5 %h% "06-15".'
sep_pattern = r'\s+|\W'

all_matches = re.finditer(sep_pattern, seq)

In [71]:
all_matches = list(all_matches)
all_matches

[<re.Match object; span=(2, 3), match=','>,
 <re.Match object; span=(5, 6), match='<'>,
 <re.Match object; span=(6, 7), match=' '>,
 <re.Match object; span=(9, 13), match='  \t '>,
 <re.Match object; span=(15, 16), match='('>,
 <re.Match object; span=(16, 18), match='  '>,
 <re.Match object; span=(18, 19), match=')'>,
 <re.Match object; span=(21, 22), match=' '>,
 <re.Match object; span=(22, 23), match='<'>,
 <re.Match object; span=(23, 24), match='='>,
 <re.Match object; span=(24, 25), match=' '>,
 <re.Match object; span=(27, 28), match='='>,
 <re.Match object; span=(29, 30), match=' '>,
 <re.Match object; span=(30, 31), match='%'>,
 <re.Match object; span=(32, 33), match='%'>,
 <re.Match object; span=(33, 34), match=' '>,
 <re.Match object; span=(34, 35), match='"'>,
 <re.Match object; span=(37, 38), match='-'>,
 <re.Match object; span=(40, 41), match='"'>,
 <re.Match object; span=(41, 42), match='.'>]

In [72]:
all_matches[0].span()

(2, 3)

In [73]:
splits = [0] + [i for m in all_matches for i in m.span()] + [len(seq)]
splits = sorted(list(set(splits)))
print(splits)

[0, 2, 3, 5, 6, 7, 9, 13, 15, 16, 18, 19, 21, 22, 23, 24, 25, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 38, 40, 41, 42]


In [74]:
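st = 0  # start offset of the token currently being built
SP = ["<=", ">=", "<>", "!="]  # two-character operators that must stay in one piece
toks = []

for s, e in zip(splits[:-1], splits[1:]):
    if not seq[s:e].strip():
        # whitespace segment: skip it and start the next token after it
        st = e
    else:
        # word or punctuation segment
        if seq[s:s+2] in SP:
            # first character of a two-character operator: wait for the second one
            continue
        toks.append(seq[st:e])
        st = e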
        

In [75]:
print(toks)

['aa', ',', 'bb', '<', 'cc', 'dd', '(', ')', 'ee', '<=', 'ff', '=', '5', '%', 'h', '%', '"', '06', '-', '15', '"', '.']
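
The steps in the cells above (find separators, collect split offsets, merge two-character operators) can be collected into one function. This is a minimal sketch with a hypothetical name, `split_keep_operators`; it reproduces the experiment in this section and is not the implementation of `ctu.separate_punct_by_offset`, which, judging by its output earlier in the notebook, treats prefixes such as `t2.` differently.

```python
import re

TWO_CHAR_OPS = ("<=", ">=", "<>", "!=")


def split_keep_operators(seq):
    """Split on whitespace and punctuation, keeping two-character operators whole."""
    matches = re.finditer(r'\s+|\W', seq)
    # Every separator start/end is a split point, plus both ends of the string.
    splits = sorted({0, len(seq), *(i for m in matches for i in m.span())})
    toks, st = [], 0
    for s, e in zip(splits[:-1], splits[1:]):
        if not seq[s:e].strip():
            # Whitespace: drop it and start the next token after it.
            st = e
        elif seq[s:s + 2] in TWO_CHAR_OPS:
            # First half of a two-character operator: emit it with the next segment.
            continue
        else:
            toks.append(seq[st:e])
            st = e
    return toks


toks = split_keep_operators('aa,bb< cc  \t dd(  )ee <= ff=5 %h% "06-15".')
```

On the test string from this section the function returns the same token list as printed above.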


### other temp

In [311]:
mt_uskg.tokenizer.tokenize('school_name')

['▁school', '_', 'name']

In [236]:
mt_uskg.tokenizer.tokenize('cylinder ')

['▁', 'cylinder']

In [233]:
mt_uskg.tokenizer.tokenize('cylinder xa')

['▁', 'cylinder', '▁', 'x', 'a']

In [240]:
mt_uskg.tokenizer.tokenize('structed_input : a | b')

['▁', 'struct', 'e', 'd', '_', 'in', 'put', '▁', ':', '▁', 'a', '▁|', '▁', 'b']

In [242]:
_sql = 'SELECT t2.aaa , COUNT(distinct t1.name) FROM cars_data as t1 JOIN models as t2 on cars_data.a_a = models.b_a'

mt_uskg.tokenizer.tokenize(_sql)

['▁',
 'SEL',
 'ECT',
 '▁',
 't',
 '2.',
 'a',
 'a',
 'a',
 '▁',
 ',',
 '▁CO',
 'UNT',
 '(',
 'distin',
 'c',
 't',
 '▁',
 't',
 '1.',
 'name',
 ')',
 '▁FROM',
 '▁cars',
 '_',
 'data',
 '▁as',
 '▁',
 't',
 '1',
 '▁',
 'JO',
 'IN',
 '▁models',
 '▁as',
 '▁',
 't',
 '2',
 '▁on',
 '▁cars',
 '_',
 'data',
 '.',
 'a',
 '_',
 'a',
 '▁=',
 '▁models',
 '.',
 'b',
 '_',
 'a']

In [342]:
N_layers = mt_uskg.num_dec_layers

layers_range_dict = {
    'q1_layers': range(0, N_layers // 4),
    'q2_layers': range(N_layers // 4, N_layers // 2),
    'q3_layers': range(N_layers // 2, N_layers * 3 // 4),
    'q4_layers': range(N_layers * 3 // 4, N_layers),
    'low_layers': range(0, N_layers // 2),
    'mid_layers': range(N_layers // 4, N_layers * 3 // 4),
    'high_layers': range(N_layers // 2, N_layers),
    'all_layers': range(N_layers),
}

layers_range_dict

{'q1_layers': range(0, 6),
 'q2_layers': range(6, 12),
 'q3_layers': range(12, 18),
 'q4_layers': range(18, 24),
 'low_layers': range(0, 12),
 'mid_layers': range(6, 18),
 'high_layers': range(12, 24),
 'all_layers': range(0, 24)}
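
The same layer groups can be generated for any layer count. In the sketch below, `make_layers_range_dict` is a hypothetical helper; it uses the same integer divisions as the cell above, so the four quarters are disjoint while `low`, `mid` and `high` are overlapping windows that each cover half of the layers.

```python
def make_layers_range_dict(n_layers):
    """Build the quarter / half / full layer groups used for corruption experiments."""
    q1_end = n_layers // 4
    half = n_layers // 2
    q3_end = n_layers * 3 // 4
    return {
        'q1_layers': range(0, q1_end),
        'q2_layers': range(q1_end, half),
        'q3_layers': range(half, q3_end),
        'q4_layers': range(q3_end, n_layers),
        'low_layers': range(0, half),
        'mid_layers': range(q1_end, q3_end),
        'high_layers': range(half, n_layers),
        'all_layers': range(n_layers),
    }


ranges_24 = make_layers_range_dict(24)
```

For 24 layers this gives the same ranges as the output above.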

In [391]:
d = {
    'a': {
        '1': 1.0,
        '2': 2.0,
    },
    'b': {
        '1': 2.0,
        '2': 1.0,
    },
}
reverse_2D_dict(d)

defaultdict(<function __main__.reverse_2D_dict.<locals>.<lambda>()>,
            {'1': defaultdict(float, {'a': 1.0, 'b': 2.0}),
             '2': defaultdict(float, {'a': 2.0, 'b': 1.0})})
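
Judging from the output above, `reverse_2D_dict` swaps the outer and inner keys of a nested dictionary. The following is a sketch of that behavior under this assumption; `swap_nested_keys` is a hypothetical name and is not the notebook's own definition.

```python
from collections import defaultdict


def swap_nested_keys(d):
    """Turn d[k1][k2] into out[k2][k1] for a two-level dictionary."""
    out = defaultdict(lambda: defaultdict(float))
    for k1, inner in d.items():
        for k2, v in inner.items():
            out[k2][k1] = v
    return out


swapped = swap_nested_keys({
    'a': {'1': 1.0, '2': 2.0},
    'b': {'1': 2.0, '2': 1.0},
})
```

Applying the function twice returns the original nesting, which is a quick way to check it.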

## (placeholder)