<a href="https://colab.research.google.com/github/kmeng01/rome/blob/main/notebooks/causal_trace.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

In [1]:
IS_COLAB = False

## Causal Tracing

A demonstration of the double-intervention causal tracing method.

The strategy used by causal tracing is to understand important
states within a transformer by doing two interventions simultaneously:

1. Corrupt a subset of the input.  In our paper, we corrupt the subject tokens
   to frustrate the ability of the transformer to accurately complete factual
   prompts about the subject.
2. Restore a subset of the internal hidden states.  In our paper, we scan
   hidden states at all layers and all tokens, searching for individual states
   that carry the necessary information for the transformer to recover its
   capability to complete the factual prompt.

The traces of decisive states can be shown on a heatmap.  This notebook
demonstrates the code for conducting causal traces and creating these heatmaps.
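
The two interventions can be illustrated on a toy model. The sketch below is not the code in `experiments.causal_trace`: `toy_forward` is a hypothetical three-layer network over scalar "tokens", and the corruption is a fixed offset rather than random noise. It only shows the pattern of a clean run, a corrupted run, and a series of corrupted runs with one clean hidden state restored.

```python
# Toy illustration of causal tracing (hypothetical model, not the ROME code).
def toy_forward(x, patch=None):
    """Run a 3-layer toy network on a list of token values.

    patch: optional dict {(layer, token): value} of hidden states to overwrite.
    Returns (output, hidden), where hidden[layer][token] holds every state.
    """
    patch = patch or {}
    hidden = []
    h = list(x)
    for layer in range(3):
        # Each token mixes in its left neighbour (a crude stand-in for attention).
        h = [h[t] + 0.5 * (h[t - 1] if t > 0 else 0.0) for t in range(len(h))]
        # Restore step: overwrite the selected states with the patched values.
        for (l, t), v in patch.items():
            if l == layer:
                h[t] = v
        hidden.append(list(h))
    return h[-1], hidden  # the "prediction" is read off the last token


clean_x = [1.0, 2.0, 3.0]
# Intervention 1: corrupt the "subject" token (token 0).
corrupt_x = [clean_x[0] + 5.0] + clean_x[1:]

clean_out, clean_hidden = toy_forward(clean_x)
corrupt_out, _ = toy_forward(corrupt_x)

# Intervention 2: restore one clean hidden state at a time and measure
# how much of the gap to the clean output is recovered.
effects = {}
for layer in range(3):
    for tok in range(len(clean_x)):
        restored_out, _ = toy_forward(
            corrupt_x, patch={(layer, tok): clean_hidden[layer][tok]}
        )
        effects[(layer, tok)] = (
            abs(corrupt_out - clean_out) - abs(restored_out - clean_out)
        )
```

The `effects` dictionary plays the role of the heatmap: one recovery score per (layer, token) pair. In this toy, restoring the last token at the final layer recovers the clean output completely, while restoring states that the output never reads gives a score of zero.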

In [442]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


The `experiments.causal_trace` module contains a set of functions for running causal traces.

In this notebook, we reproduce, demonstrate and discuss the interesting functions.

We begin by importing several utility functions that deal with tokens and transformer models.

In [3]:
import os, sys, re, json
import string
import torch
import numpy as np
import copy
from collections import defaultdict, Counter
from util import nethook
from util.globals import DATA_DIR
from experiments.causal_trace import (
    ModelAndTokenizer,
    layername,
    guess_subject,
    plot_trace_heatmap,
)
from experiments.causal_trace import (
    make_inputs,
    decode_tokens,
    find_token_range,
    predict_token,
    predict_from_input,
    collect_embedding_std,
)
from dsets import KnownsDataset

In [4]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f83484ca640>

In [5]:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

## USKG

In [641]:
from transformers import (
    HfArgumentParser,
    set_seed,
    AutoTokenizer
)

# from uskg.models.unified.prefixtuning import Model
from uskg.models.unified import finetune, prefixtuning
from uskg.utils.configue import Configure
from uskg.utils.training_arguments import WrappedSeq2SeqTrainingArguments
from uskg.seq2seq_construction import spider as s2s_spider
from uskg.third_party.spider.preprocess.get_tables import dump_db_json_schema
from uskg.third_party.spider import evaluation as sp_eval
from tqdm.notebook import tqdm

# from nltk.stem.wordnet import WordNetLemmatizer
# import stanza

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

import matplotlib.pyplot as plt
import sqlite3
import ujson
import pickle

from experiments import causal_trace_uskg as ctu

In [7]:
mt_uskg = ctu.ModelAndTokenizer_USKG('t5-large-prefix')

Using tokenizer_uskg: hkunlp/from_all_T5_large_prefix_spider_with_cell_value2
Using tokenizer_fast: t5-large
prefix-tuning sequence length is 10.


In [8]:
list(mt_uskg.task_args.seq2seq)

[('constructor', 'seq2seq_construction.spider'),
 ('schema_serialization_with_db_content', True),
 ('target_with_db_id', False)]

In [9]:
mt_uskg.model.pretrain_model.encoder.embed_tokens is mt_uskg.model.pretrain_model.shared, \
mt_uskg.model.pretrain_model.decoder.embed_tokens is mt_uskg.model.pretrain_model.shared

(True, False)

In [10]:
mt_uskg.model.preseqlen

10

In [None]:
# [k for k,v in mt_uskg.model.named_parameters()]
[k for k,v in mt_uskg.model.named_modules()]

In [87]:
inp = ctu.make_inputs_t5(
    mt_uskg.tokenizer,
    enc_sentences=["Translate to German: My name is Wolfgang and I live in Berlin"],
    dec_prompts=["Mein Name ist Wolfgang"],
    device="cuda:0"
)

In [88]:
out = ctu.run_model_forward_uskg(mt_uskg.model, **inp)

In [89]:
out.keys(), out['logits'].size()

(odict_keys(['logits', 'past_key_values', 'encoder_last_hidden_state']),
 torch.Size([1, 5, 32102]))

In [90]:
logits = out["logits"][0, -1].detach().cpu().numpy()
logits.shape

(32102,)

In [91]:
top_5 = sorted(list(enumerate(logits)), key=lambda p: -p[1])[:5]
top_5

[(11, -1.8642352),
 (6, -9.727753),
 (5, -10.966707),
 (27, -11.037394),
 (213, -12.864212)]

In [92]:
[mt_uskg.tokenizer.decode([p[0]]) for p in top_5]

['and', ',', '.', 'I', 'where']
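
Sorting every vocabulary entry works, but only the five largest logits are needed. A possible alternative, sketched here with the standard-library `heapq.nlargest`, also converts the selected logits to softmax probabilities. The helper name `top_k_with_probs` and the values in `toy_logits` are made up for illustration; they are not model output.

```python
import heapq
import math


def top_k_with_probs(logits, k=5):
    """Return [(token_id, logit, prob)] for the k largest logits."""
    # Numerically stable softmax normaliser over the whole vocabulary.
    m = max(logits)
    z = sum(math.exp(v - m) for v in logits)
    # Select the k largest entries without sorting the full list.
    top = heapq.nlargest(k, enumerate(logits), key=lambda p: p[1])
    return [(i, v, math.exp(v - m) / z) for i, v in top]


toy_logits = [-3.0, 0.5, -1.0, 2.0, 0.0]  # hypothetical values
top2 = top_k_with_probs(toy_logits, k=2)
all_probs = top_k_with_probs(toy_logits, k=len(toy_logits))
```

With the real `logits` array, the token ids returned this way can be passed to `mt_uskg.tokenizer.decode` exactly as in the cell above.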

### Load spider dataset

In [11]:
spider_train_path = '/home/yshao/Projects/SDR-analysis/data/spider/train+ratsql_graph.json'
spider_dev_path = '/home/yshao/Projects/SDR-analysis/data/spider/dev+ratsql_graph.json'
spider_db_dir = '/home/yshao/Projects/language/language/xsp/data/spider/database'

In [12]:
raw_spider_dev = ctu.load_raw_dataset(
    data_filepath = spider_dev_path,
    db_path=spider_db_dir,
#     schema_cache=SCHEMA_CACHE
)
len(raw_spider_dev)

1034

In [13]:
raw_spider_dev[0].keys()

dict_keys(['query', 'question', 'db_id', 'db_path', 'db_table_names', 'db_column_names', 'db_column_types', 'db_primary_keys', 'db_foreign_keys', 'rat_sql_graph'])

In [14]:
mt_uskg.task_args.dataset.use_cache

True

In [15]:
processed_spider_dev = s2s_spider.DevDataset(
    args=mt_uskg.task_args,
    raw_datasets=raw_spider_dev,
    cache_root='../cache')

In [876]:
_id = 5
processed_spider_dev[_id]['text_in'], \
processed_spider_dev[_id]['struct_in'], \
processed_spider_dev[_id]['seq_out']

('What is the average, minimum, and maximum age for all French singers?',
 '| concert_singer | stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id',
 "select avg(age), min(age), max(age) from singer where country = 'France'")

In [17]:
_enc_sentence = f"{processed_spider_dev[_id]['text_in']}; structed knowledge: {processed_spider_dev[_id]['struct_in']}"
_toks = mt_uskg.tokenizer.tokenize(_enc_sentence)
len(_toks)

142

In [18]:
# # _occ_punct = set()

# for _id in range(len(processed_spider_dev)):
#     ex = processed_spider_dev[_id]
# #     _occ_punct.update(set(string.punctuation) & set(ex['seq_out']))
#     if '_(' in ex['struct_in']:
#         print(_id, ex['question'])
#         print(ex['struct_in'])
#         print(ex['seq_out'])
#         print()

In [451]:
## Train set

raw_spider_train = ctu.load_raw_dataset(
    data_filepath = spider_train_path,
    db_path=spider_db_dir,
)
processed_spider_train = s2s_spider.TrainDataset(
    args=mt_uskg.task_args,
    raw_datasets=raw_spider_train,
    cache_root='../cache')
len(processed_spider_train)

7000

In [462]:
_ex = processed_spider_train[3153]
_ex['struct_in'], _ex['text_in'], _ex['seq_out'], _ex['db_table_names']

('| assets_maintenance | third_party_companies : company_id , company_type , company_name , company_address , other_company_details | maintenance_contracts : maintenance_contract_id , maintenance_contract_company_id , contract_start_date , contract_end_date , other_contract_details | parts : part_id , part_name , chargeable_yn , chargeable_amount , other_part_details | skills : skill_id , skill_code , skill_description | staff : staff_id , staff_name , gender , other_staff_details | assets : asset_id , maintenance_contract_id , supplier_company_id , asset_details , asset_make , asset_model , asset_acquired_date , asset_disposed_date , other_asset_details | asset_parts : asset_id , part_id | maintenance_engineers : engineer_id , company_id , first_name , last_name , other_details | engineer_skills : engineer_id , skill_id | fault_log : fault_log_entry_id , asset_id , recorded_by_staff_id , fault_log_entry_datetime , fault_description , other_fault_details | engineer_visits : engineer_vi

In [461]:
_ex.keys()

dict_keys(['query', 'question', 'db_id', 'db_path', 'db_table_names', 'db_column_names', 'db_column_types', 'db_primary_keys', 'db_foreign_keys', 'rat_sql_graph', 'serialized_schema', 'struct_in', 'text_in', 'seq_out'])

In [455]:
len(mt_uskg.tokenizer.tokenize(processed_spider_train[3086]['struct_in']))

497

### Helpers
- Aspect-related helpers are merged into `create_analysis_sample_dicts()`

#### Utils

In [21]:
def exp6_ob_by_exp_tok(samples):
    # samples: usually `good_samples`
    
    # Key: (expect_tok, sect_k, layer) -> [scores]
    trace_scores_by_exp_tok = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    trace_scores_avg_by_exp_tok = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
    trace_scores_cnt_by_exp_tok = defaultdict(int)  # no sect key & layer key 

    trace_sample_ids_by_exp_tok = defaultdict(list)
    
    for i, d in enumerate(samples):
        expect = d['expect']
        trace_sample_ids_by_exp_tok[expect].append(i)
        for sect_k, sect_d in d['trace_scores'].items():
            for layer_k, v in sect_d.items():
                trace_scores_by_exp_tok[expect][sect_k][layer_k].append(v)

    for exp_tok, d1 in trace_scores_by_exp_tok.items():
        if exp_tok.isnumeric(): continue
        for sect_k, d2 in d1.items():
            for layer_k, scores in d2.items():
                if len(scores) <= 2: continue
                trace_scores_avg_by_exp_tok[exp_tok][sect_k][layer_k] = np.mean(scores)
                trace_scores_cnt_by_exp_tok[exp_tok] = len(scores)
    
    return {
        'avg': trace_scores_avg_by_exp_tok,
        'cnt': trace_scores_cnt_by_exp_tok,
        'sample_ids': trace_sample_ids_by_exp_tok,
    }

In [22]:
def reverse_2D_dict(d):
    # defaultdict needs a callable factory; missing cells default to NaN
    out_d = defaultdict(lambda: defaultdict(lambda: np.nan))
    for k1, d1 in d.items():
        for k2, v in d1.items():
            out_d[k2][k1] = v
    return out_d

def format_print_1D_dict(d, sort_by=None, reverse=False, head_col_w=10, col_w=6):
    # sort_by: None, 'key' or 'value'
    
    item_l = list(d.items())
    if sort_by == 'key':
        item_l.sort(reverse=reverse)
    elif sort_by == 'value':
        item_l.sort(key=lambda x: (x[1], x[0]), reverse=reverse)
    
    decm_w = col_w - 2
    
    for k, v in item_l:
        print(f'{k:<{head_col_w}s}{v:.{decm_w}f}')

def format_print_2D_dict(d, 
                         all_k1=None, 
                         all_k2=None, 
                         sort_k1_kwargs=None, 
                         sort_k2_kwargs=None, 
                         head_col_w=12, 
                         col_w=6,
                         decm_w=4):
    if all_k1 is None:
        all_k1 = list(d.keys())
        if sort_k1_kwargs is not None:
            all_k1.sort(**sort_k1_kwargs)
    
    if all_k2 is None:
        for k1, d1 in d.items():
            d1_keys = list(d1.keys())
            if all_k2 is None:
                all_k2 = d1_keys
            else:
                if set(d1_keys) != set(all_k2):
                    print('Warning:\n', d1_keys, '\n', all_k2)
            # all_k2.update(list(d1.keys()))
        if sort_k2_kwargs is not None:
            all_k2.sort(**sort_k2_kwargs)
    
    print_str = '\t'.join(['X' * head_col_w] + [f'{k2:<{col_w}s}' for k2 in all_k2]) + '\n'
    
    for k1 in all_k1:
        d1 = d[k1]
        print_str += f'{k1:<{head_col_w}s}'
        for k2 in all_k2:
            v = d1[k2]
            print_str += f'\t{v:<{col_w}.{decm_w}f}'
        print_str += '\n'
    
    print(print_str)
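
As a usage sketch for the transposition helper, the snippet below swaps the two key levels of a small nested dict and shows what happens to a missing cell. `transpose_nested` is a hypothetical standalone stand-in (standard library only) for the intent of `reverse_2D_dict`, and the numbers are illustrative.

```python
from collections import defaultdict


def transpose_nested(d, fill=float('nan')):
    """Swap the two key levels of a dict-of-dicts; missing cells read as `fill`."""
    out = defaultdict(lambda: defaultdict(lambda: fill))
    for k1, inner in d.items():
        for k2, v in inner.items():
            out[k2][k1] = v
    return out


scores = {
    'text': {'embed': 0.78, 'final_enc': 0.87},
    'all': {'embed': 0.05},  # 'final_enc' deliberately missing
}
by_layer = transpose_nested(scores)
missing = by_layer['final_enc']['all']  # falls back to NaN instead of raising
```

Defaulting to NaN keeps a table printer such as `format_print_2D_dict` from failing on sparse inputs, at the cost of silently creating the missing keys on access.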

#### Evaluator

In [511]:
table_path = '/home/yshao/Projects/language/language/xsp/data/spider/tables.json'
db_dir = '/home/yshao/Projects/language/language/xsp/data/spider/database'

In [512]:
kmaps = sp_eval.build_foreign_key_map_from_json(table_path)
evaluator = sp_eval.Evaluator(db_dir=db_dir, kmaps=kmaps, etype='all')

In [513]:
ctu.evaluate_hardness.evaluator = evaluator

In [514]:
# test
_sql_str = 'select t1.birth_date from people as t1 join poker_player as t2 on t1.people_id = t2.people_id order by t2.earnings asc limit 1'
db_name = 'poker_player'
schema = evaluator.schemas[db_name]
_sql = sp_eval.get_sql(schema, _sql_str)
sp_eval.count_component1(_sql), sp_eval.count_component2(_sql), sp_eval.count_others(_sql), \
evaluator.eval_hardness(_sql)

(3, 0, 0, 'hard')

#### Hardness

In [27]:
ctu.evaluate_hardness(_sql_str, db_name, evaluator=evaluator)

'hard'

In [28]:
ctu.evaluate_hardness.evaluator

<uskg.third_party.spider.evaluation.Evaluator at 0x7f81c83e1f40>

#### Node role

In [None]:
dec_prompt = 'select avg(age), min(age), max(age) from'
ctu.detect_node_role(dec_prompt)

#### Text match

In [None]:
a_dicts = ctu.create_analysis_sample_dicts(
    mt=mt_uskg,
    ex=processed_spider_dev[100],
    subject_type='table'
)
len(a_dicts), [d['expect'] for d in a_dicts]

In [None]:
a_ex = a_dicts[2]
ctu.check_table_text_match(a_ex, 'car_names')

In [None]:
a_ex['text_in']

### Exp-2.3: section corruption effect

#### Load & Check

In [720]:
expect_type = 'table_alias'

res_path = f'/home/yshao/Projects/rome/results/exp2.3_section_corruption_effect/exp=2.3.1_dev_{expect_type}-replace=True-noise=0.0.jsonl'

with open(res_path, 'r') as f:
    all_samples = [json.loads(l) for l in f]
len(all_samples)

1034

In [721]:
total_samples = 0
n_good_samples = 0
n_too_hard = 0      # wrong answer 
n_too_easy = 0      # base - low < 0.5

In [722]:
good_samples = []
bad_samples = []

for i, ex in enumerate(all_samples):
    for d in ex['trace_results']:
        total_samples += 1

        if d.get('is_good_sample', True):
            n_good_samples += 1
            d['ex_id'] = i
            good_samples.append(d)
        elif not d['correct_prediction']:
            n_too_hard += 1
            bad_samples.append(d)
        else:
            assert d['base_score'] - d['low_score'] < 0.5
            n_too_easy += 1
            bad_samples.append(d)
            
total_samples, (n_good_samples, len(good_samples)), n_too_hard, n_too_easy, len(bad_samples), \
f'good / correct = {n_good_samples} / {n_good_samples + n_too_easy}'

(2039, (1438, 1438), 339, 262, 601, 'good / correct = 1438 / 1700')
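
The three buckets in the loop above can be written as a single rule. This is a sketch under assumptions: the result files store a precomputed `is_good_sample` flag, so the rule below is reconstructed only from the comments and the assertion (`base_score - low_score < 0.5` for "too easy"). The name `classify_sample` and the toy records are hypothetical.

```python
def classify_sample(d, min_drop=0.5):
    """Bucket one trace result as 'good', 'too_hard' or 'too_easy'.

    Assumes the keys 'correct_prediction', 'base_score' and 'low_score'.
    `min_drop` mirrors the 0.5 threshold in the assertion; treating a drop
    of exactly 0.5 as good is an assumption.
    """
    if not d['correct_prediction']:
        return 'too_hard'   # the model gets the answer wrong even uncorrupted
    if d['base_score'] - d['low_score'] < min_drop:
        return 'too_easy'   # corruption barely hurts the prediction
    return 'good'


toy = [
    {'correct_prediction': True, 'base_score': 0.95, 'low_score': 0.10},
    {'correct_prediction': False, 'base_score': 0.20, 'low_score': 0.05},
    {'correct_prediction': True, 'base_score': 0.90, 'low_score': 0.80},
]
labels = [classify_sample(d) for d in toy]
```

The counts reported above are consistent with this partition: 1438 good + 339 too hard + 262 too easy = 2039 samples, of which 1438 + 262 = 1700 are correctly predicted.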

#### Overall avg

In [723]:
trace_scores_avg = {sect_k : defaultdict(int) for sect_k in good_samples[0]['trace_scores'].keys()}

for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for k, v in sect_d.items():
            trace_scores_avg[sect_k][k] += v

for sect_k, sect_d in trace_scores_avg.items():
    for k, s in sect_d.items():
        sect_d[k] = s / len(good_samples)

In [724]:
trace_scores_avg

{'text': defaultdict(int,
             {'embed': 0.7823026149835357, 'final_enc': 0.8686982057318997}),
 'struct': defaultdict(int,
             {'embed': 0.8156400979291917, 'final_enc': 0.8783989991454734}),
 'self': defaultdict(int,
             {'embed': 0.9009977694993451, 'final_enc': 0.9186035738517472}),
 'struct_context': defaultdict(int,
             {'embed': 0.8188476237862654, 'final_enc': 0.9366639298185229}),
 'other': defaultdict(int,
             {'embed': 0.9668801645956868, 'final_enc': 0.9874240734377392}),
 'text+other': defaultdict(int,
             {'embed': 0.7898215747748188, 'final_enc': 0.8762273743275982}),
 'all': defaultdict(int,
             {'embed': 0.05134222172979076, 'final_enc': 0.7129756825101383})}

In [725]:
format_print_2D_dict(trace_scores_avg)

XXXXXXXXXXXX	embed 	final_enc
text        	0.7823	0.8687
struct      	0.8156	0.8784
self        	0.9010	0.9186
struct_context	0.8188	0.9367
other       	0.9669	0.9874
text+other  	0.7898	0.8762
all         	0.0513	0.7130



### Exp-4.1: attention weights distribution for all nodes

#### Load & Check results

In [550]:
## Taking too much memory...

# res_path = f'/home/yshao/Projects/rome/results/exp4_1_attention_weights_distribution/exp=4.1_train.jsonl'

# with open(res_path, 'r') as f:
#     all_train_samples = [ujson.loads(l) for l in tqdm(f)]
# len(all_train_samples)

In [551]:
# res_path = f'/home/yshao/Projects/rome/results/exp4_1_attention_weights_distribution/exp=4.1_dev.jsonl'

# with open(res_path, 'r') as f:
#     all_dev_samples = [ujson.loads(l) for l in tqdm(f)]
# len(all_dev_samples)

In [552]:
# _toks = mt_uskg.tokenizer.tokenize(all_samples[100]['trace_results']['enc_sentence'], add_special_tokens=True)
# len(_toks), _toks[136:140]

In [553]:
res_path = f'/home/yshao/Projects/rome/results/exp4_1_attention_weights_distribution/exp=4.1_dev.jsonl'
samples_N = len(processed_spider_dev)

# check one sample (ex)
with open(res_path, 'r') as f:
    for l in f:
        ex = ujson.loads(l)
        break

In [554]:
ex.keys(), ex['trace_results'].keys()

(dict_keys(['ex_id', 'trace_results']),
 dict_keys(['enc_sentence', 'seq_out', 'dec_prompt', 'db_id', 'col_self_ranges', 'col_context_ranges', 'tab_self_ranges', 'tab_context_ranges', 'category', 'occ_cols', 'non_occ_cols', 'occ_tabs', 'non_occ_tabs', 'attentions']))

In [555]:
for k, v in ex['trace_results'].items():
    if k != 'attentions':
        print(k, ':', v)

enc_sentence : How many singers do we have?; structed knowledge: | concert_singer | stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id
seq_out : select count(*) from singer
dec_prompt : select
db_id : concert_singer
col_self_ranges : {'location': [[28, 33]], 'capacity': [[34, 39]], 'highest': [[37, 42]], 'lowest': [[40, 45]], 'average': [[43, 47]], 'country': [[57, 62]], 'song_name': [[60, 67]], 'song_release_year': [[65, 74]], 'age': [[72, 77]], 'is_male': [[75, 81]], 'concert_name': [[88, 95]], 'theme': [[93, 98]], 'year': [[102, 106]]}
col_context_ranges : {'location': [[15, 28], [33, 125]], 'capacity': [[15, 34], [39, 125]], 'highest': [[15, 37], [42, 125]], 'lowest': [[15, 40], [45, 125]], 'average': [[15, 43], [47, 125]], 'country': [[15, 57], [62, 125]], 'so

In [None]:
# Dict: {layer -> {head_id -> {occ_type -> {section -> List[att_w]}}}}; list for all samples, all nodes in each occ_type 
# Perhaps too large...
# att_weights_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
att_weights_sum_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))

# Dict: {occ_type -> int}
att_weights_cnt_dict = defaultdict(int)

# Dict: {layer -> {head_id -> {occ_type -> {section -> avg_att_w}}}}; averaged by all samples, all nodes in each occ_type 
att_weights_avg_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))

# Dict: {occ_type -> List[Tuple(ex_id, node_name)]}
sample_backtrace_dict = defaultdict(list)


with open(res_path, 'r') as f:
    for l in tqdm(f, total=samples_N):
        ex = ujson.loads(l)
        ex_id = ex['ex_id']

        if 'err_msg' in ex:
            continue

        result_d = ex['trace_results']

        all_col_atts = result_d['attentions']['col']
        all_tab_atts = result_d['attentions']['tab']

        for col_occ_type in ['occ_cols', 'non_occ_cols']:
            for col in result_d[col_occ_type]:
                sect_att_dict = all_col_atts[col]
                sample_backtrace_dict[col_occ_type].append((ex_id, col))
                att_weights_cnt_dict[col_occ_type] += 1

                for sect_k, att_mat in sect_att_dict.items():
                    att_mat = np.array(ctu.nested_list_processing(att_mat, func=float))
                    n_layers, n_heads = att_mat.shape
                    for l in range(n_layers):
                        for h in range(n_heads):
                            att_w = att_mat[l, h]
                            # att_weights_dict[l][h][col_occ_type][sect_k].append(att_w)
                            att_weights_sum_dict[l][h][col_occ_type][sect_k] += att_w

        for tab_occ_type in ['occ_tabs', 'non_occ_tabs']:
            for tab in result_d[tab_occ_type]:
                sect_att_dict = all_tab_atts[tab]
                sample_backtrace_dict[tab_occ_type].append((ex_id, tab))
                att_weights_cnt_dict[tab_occ_type] += 1

                for sect_k, att_mat in sect_att_dict.items():
                    att_mat = np.array(ctu.nested_list_processing(att_mat, func=float))
                    n_layers, n_heads = att_mat.shape
                    for l in range(n_layers):
                        for h in range(n_heads):
                            att_w = att_mat[l, h]
                            # att_weights_dict[l][h][tab_occ_type][sect_k].append(att_w)
                            att_weights_sum_dict[l][h][tab_occ_type][sect_k] += att_w

In [557]:
for l_id, layer_d in att_weights_sum_dict.items():
    for h_id, head_d in layer_d.items():
        for occ_type, occ_type_d in head_d.items():
            att_w_cnt = att_weights_cnt_dict[occ_type]
            for sect_k, att_w_sum in occ_type_d.items():
                att_weights_avg_dict[l_id][h_id][occ_type][sect_k] = att_w_sum / att_w_cnt

In [558]:
att_weights_avg_dict[0][0]

defaultdict(<function __main__.<lambda>.<locals>.<lambda>.<locals>.<lambda>()>,
            {'non_occ_cols': defaultdict(float,
                         {'prefix#0': 0.00506160506160524,
                          'prefix#1': 0.018118548118547986,
                          'prefix#2': 0.006291042291042455,
                          'prefix#3': 0.00046553446553446523,
                          'prefix#4': 0.049131535131533094,
                          'prefix#5': 0.03181684981684876,
                          'prefix#6': 0.01746253746253744,
                          'prefix#7': 0.0016443556443556857,
                          'prefix#8': 8.52480852480853e-05,
                          'prefix#9': 0.016844488844488335,
                          'text': 0.016877122877122598,
                          'self': 0.30191874791874695,
                          'context': 0.5137555777555907,
                          'others': 0.014923076923077245}),
             'occ_tabs': defaultdict(float,


#### Dump results 
- do not rerun unless updated

In [559]:
# res_path = f'/home/yshao/Projects/rome/results/exp4_1_attention_weights_distribution/exp=4.1_dev_dump.pkl'

# dump_d = {
#     'att_weights_sum_dict': ctu.nested_json_processing(att_weights_sum_dict, func=lambda x: x),   # defaultdict -> dict 
#     'att_weights_cnt_dict': ctu.nested_json_processing(att_weights_cnt_dict, func=lambda x: x),
#     'att_weights_avg_dict': ctu.nested_json_processing(att_weights_avg_dict, func=lambda x: x),
#     'sample_backtrace_dict': ctu.nested_json_processing(sample_backtrace_dict, func=lambda x: x),
# }

# with open(res_path, 'wb') as f:
#     pickle.dump(dump_d, f)

#### Load results 

In [564]:
res_path = f'/home/yshao/Projects/rome/results/exp4_1_attention_weights_distribution/exp=4.1_train_dump.pkl'

with open(res_path, 'rb') as f:
    dump_d = pickle.load(f)

att_weights_sum_dict = dump_d['att_weights_sum_dict']
att_weights_cnt_dict = dump_d['att_weights_cnt_dict']
att_weights_avg_dict = dump_d['att_weights_avg_dict']
sample_backtrace_dict = dump_d['sample_backtrace_dict']

In [565]:
att_weights_avg_dict[0][0]

{'occ_cols': {'prefix#0': 0.005590211530485306,
  'prefix#1': 0.038914143508916546,
  'prefix#2': 0.011377021982580325,
  'prefix#3': 0.00034840315221899436,
  'prefix#4': 0.021940273745333354,
  'prefix#5': 0.04742015761094831,
  'prefix#6': 0.01801990875155525,
  'prefix#7': 0.0029879717959352893,
  'prefix#8': 6.470344255495649e-05,
  'prefix#9': 0.028151804230609197,
  'text': 0.028184985483199768,
  'self': 0.29681045209456924,
  'context': 0.4812252177519773,
  'others': 0.013333886354210116},
 'non_occ_cols': {'prefix#0': 0.007513540287063308,
  'prefix#1': 0.014389637238746093,
  'prefix#2': 0.007340073865897553,
  'prefix#3': 0.0003575927899990462,
  'prefix#4': 0.017249625825066686,
  'prefix#5': 0.049186182985021795,
  'prefix#6': 0.02434915097285215,
  'prefix#7': 0.002757480806710441,
  'prefix#8': 0.0001255504947723164,
  'prefix#9': 0.012217161438983645,
  'text': 0.016699938624532496,
  'self': 0.29822700304725547,
  'context': 0.5328514821633931,
  'others': 0.01090135

In [None]:
occ_types = ['occ_cols', 'non_occ_cols']

for l_id in [1, 6, 12, 18, 23]:
    # Dict: occ_type -> sect -> List[att_w]; list for all heads 
    layer_ob_dict = defaultdict(lambda: defaultdict(list))

    for h_id in range(len(att_weights_avg_dict[l_id])):
        for occ in occ_types:
            for sect_k, att_w in att_weights_avg_dict[l_id][h_id][occ].items():
                att_w_str = np.format_float_positional(att_w, precision=2, min_digits=2)
                layer_ob_dict[occ][sect_k].append(att_w_str)
    
    print(f'===== Layer {l_id} =====')
    for occ in occ_types:
        print(occ)
        for sect_k, att_w_list in layer_ob_dict[occ].items():
            att_w_list_str = "  ".join(att_w_list)
            print(f'{sect_k:<10s}{att_w_list_str}')
        print()

In [563]:
len(att_weights_avg_dict[0])

16

#### Probing attn - data

In [567]:
train_res_path = f'/home/yshao/Projects/rome/results/exp4_1_attention_weights_distribution/exp=4.1_train_dump.pkl'
dev_res_path = f'/home/yshao/Projects/rome/results/exp4_1_attention_weights_distribution/exp=4.1_dev_dump.pkl'

with open(train_res_path, 'rb') as f:
    train_dump_d = pickle.load(f)

with open(dev_res_path, 'rb') as f:
    dev_dump_d = pickle.load(f)

# att_weights_sum_dict = train_dump_d['att_weights_sum_dict']
# att_weights_cnt_dict = train_dump_d['att_weights_cnt_dict']
# att_weights_avg_dict = train_dump_d['att_weights_avg_dict']
# sample_backtrace_dict = train_dump_d['sample_backtrace_dict']

In [585]:
# Dict: layer -> head -> sect_k -> heuristic diff val 
heu_diff_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
heu_diff_flat_dict = dict()

occ_cls_A = 'occ_cols'
occ_cls_B = 'non_occ_cols'

def heu_diff_func(a, b):
    return (a + 0.1) / (b + 0.1) - (b + 0.1) / (a + 0.1) + a - b

# Dict: layer -> head -> occ -> sect_k -> att_w_avg 
train_att_weights_avg_dict = train_dump_d['att_weights_avg_dict']

for l_id, layer_d in train_att_weights_avg_dict.items():
    for h_id, head_d in layer_d.items():
        for sect_k, _ in head_d[occ_cls_A].items():
            vA = train_att_weights_avg_dict[l_id][h_id][occ_cls_A][sect_k]
            vB = train_att_weights_avg_dict[l_id][h_id][occ_cls_B][sect_k]
            heu_diff = heu_diff_func(vA, vB)
            heu_diff_dict[l_id][h_id][sect_k] = heu_diff
            heu_diff_flat_dict[f'L{l_id}-H{h_id}-{sect_k}'] = heu_diff

In [586]:
len(heu_diff_flat_dict)

5376
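
A quick check of how the heuristic behaves may help when reading the threshold used next. The function is copied from the cell above and the input values are illustrative. The feature count of 5376 matches 24 layers × 16 heads × 14 section keys (ten `prefix#i` keys plus `text`, `self`, `context` and `others`); the figure of 24 layers is an assumption inferred from the layer indices running up to 23.

```python
def heu_diff_func(a, b):
    # Ratio term (smoothed by 0.1) plus the plain difference, as defined above.
    return (a + 0.1) / (b + 0.1) - (b + 0.1) / (a + 0.1) + a - b


same = heu_diff_func(0.3, 0.3)   # identical averages give no signal
fwd = heu_diff_func(0.5, 0.0)    # class A attends much more than class B
bwd = heu_diff_func(0.0, 0.5)    # swapping the classes flips the sign
small = heu_diff_func(0.2, 0.1)  # a modest gap stays well under 2.0

n_layers, n_heads, n_sections = 24, 16, 14  # n_layers is an assumption
n_feats = n_layers * n_heads * n_sections
```

The function is antisymmetric in its arguments, so a threshold on `abs(v)` selects heads whose attention differs strongly between the two classes in either direction.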

In [598]:
use_feats_info = []
feat_name2tuple = dict()

for l_id, layer_d in heu_diff_dict.items():
    for h_id, head_d in layer_d.items():
        for sect_k, v in head_d.items():
            if abs(v) > 2.0:
                feat_name = f'L{l_id}-H{h_id}-{sect_k}'
                use_feats_info.append((feat_name, v))
                feat_name2tuple[feat_name] = (l_id, h_id, sect_k)

len(use_feats_info)

99

In [601]:
use_feats_info_dict = {
    'feats': [feat for feat, v in use_feats_info],
    'feat_vals': use_feats_info,
    'feat_name2tuple': feat_name2tuple,
}

In [602]:
probe_out_dir = f'/home/yshao/Projects/rome/results/exp4_1_attention_weights_distribution/probing/{occ_cls_A}-vs-{occ_cls_B}'
os.makedirs(probe_out_dir, exist_ok=True)

In [603]:
# use_feats_info_path = os.path.join(probe_out_dir, 'use_feats.json')

# with open(use_feats_info_path, 'w') as f:
#     json.dump(use_feats_info_dict, f, indent=2)

In [604]:
# Load feats meta info
probe_out_dir = f'/home/yshao/Projects/rome/results/exp4_1_attention_weights_distribution/probing/{occ_cls_A}-vs-{occ_cls_B}'
use_feats_info_path = os.path.join(probe_out_dir, 'use_feats.json')

with open(use_feats_info_path, 'r') as f:
    use_feats_info_dict = json.load(f)

feats = use_feats_info_dict['feats']
feat_name2tuple = use_feats_info_dict['feat_name2tuple']

len(feats), len(feat_name2tuple)

(99, 99)

In [611]:
cls2id = {
    'occ_cols': 0,
    'non_occ_cols': 1,
}

In [732]:
## Collect X and y from raw output (each node is a sample) 

train_raw_res_path = f'/home/yshao/Projects/rome/results/exp4_1_attention_weights_distribution/exp=4.1_train.jsonl'
samples_N = 7000

train_X = []
train_y = []

with open(train_raw_res_path, 'r') as f:
    for l in tqdm(f, total=samples_N):
        ex = ujson.loads(l)
        ex_id = ex['ex_id']

        if 'err_msg' in ex:
            continue

        result_d = ex['trace_results']

        all_col_atts = result_d['attentions']['col']
#         all_tab_atts = result_d['attentions']['tab']

        for occ_cls, cls_id in cls2id.items():
            for col in list(set(result_d[occ_cls])):   # remove duplicates here 
                sect_att_dict = all_col_atts[col]
                feat_vec = []
                for feat_name in feats:
                    l_id, h_id, sect_k = feat_name2tuple[feat_name]
                    feat_val = float(sect_att_dict[sect_k][l_id][h_id])
                    feat_vec.append(feat_val)
                
                train_X.append(feat_vec)
                train_y.append(cls_id)

len(train_X), len(train_y), len(train_X[0])

(102599, 102599, 99)
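
The inner loop that builds one feature vector can be isolated as a small function. The sketch below runs on a made-up attention record with 2 layers and 2 heads. Storing the weights as strings is an assumption suggested by the `float(...)` conversion, and `build_feature_vector` is a hypothetical name.

```python
def build_feature_vector(sect_att_dict, feats, feat_name2tuple):
    """Pick one attention weight per feature name from a node's record.

    sect_att_dict: {section -> [layer][head] weight}
    feat_name2tuple: {feature name -> (layer, head, section)}
    """
    vec = []
    for name in feats:
        layer, head, sect = feat_name2tuple[name]
        vec.append(float(sect_att_dict[sect][layer][head]))
    return vec


# Hypothetical record for one column node.
toy_att = {
    'self': [['0.30', '0.10'], ['0.25', '0.05']],
    'context': [['0.50', '0.70'], ['0.60', '0.90']],
}
# After a JSON round trip the tuples come back as lists, which unpack the same way.
toy_name2tuple = {
    'L0-H1-self': [0, 1, 'self'],
    'L1-H0-context': [1, 0, 'context'],
}
toy_feats = list(toy_name2tuple)
vec = build_feature_vector(toy_att, toy_feats, toy_name2tuple)
```

Keeping `feats` as an ordered list fixes the column order, so the train and dev matrices line up feature by feature.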

In [733]:
## save X, y
with open(os.path.join(probe_out_dir, 'train_X.pkl'), 'wb') as f:
    pickle.dump(train_X, f)
with open(os.path.join(probe_out_dir, 'train_y.pkl'), 'wb') as f:
    pickle.dump(train_y, f)

In [727]:
dev_raw_res_path = f'/home/yshao/Projects/rome/results/exp4_1_attention_weights_distribution/exp=4.1_dev.jsonl'
samples_N = 1034

test_X = []
test_y = []

with open(dev_raw_res_path, 'r') as f:
    for l in tqdm(f, total=samples_N):
        ex = ujson.loads(l)
        ex_id = ex['ex_id']

        if 'err_msg' in ex:
            continue

        result_d = ex['trace_results']

        all_col_atts = result_d['attentions']['col']
#         all_tab_atts = result_d['attentions']['tab']

        for occ_cls, cls_id in cls2id.items():
            for col in list(set(result_d[occ_cls])):   # remove duplicates here 
                sect_att_dict = all_col_atts[col]
                feat_vec = []
                for feat_name in feats:
                    l_id, h_id, sect_k = feat_name2tuple[feat_name]
                    feat_val = float(sect_att_dict[sect_k][l_id][h_id])
                    feat_vec.append(feat_val)
                
                test_X.append(feat_vec)
                test_y.append(cls_id)

len(test_X), len(test_y), len(test_X[0])

(16640, 16640, 99)

In [731]:
## save X, y
with open(os.path.join(probe_out_dir, 'test_X.pkl'), 'wb') as f:
    pickle.dump(test_X, f)
with open(os.path.join(probe_out_dir, 'test_y.pkl'), 'wb') as f:
    pickle.dump(test_y, f)

In [730]:
## Have duplicates? - no
_check_ids = [i for i in range(len(test_y)) if test_y[i] == 0][:10]
np.array(test_X)[_check_ids, :10]

array([[0.78, 0.58, 0.96, 0.89, 0.98, 0.88, 0.52, 0.95, 0.97, 0.83],
       [0.92, 0.69, 0.96, 0.83, 0.45, 0.73, 0.65, 0.96, 1.  , 0.36],
       [0.63, 0.44, 0.91, 0.82, 0.97, 0.68, 0.52, 0.98, 0.98, 0.74],
       [0.04, 0.27, 0.  , 0.87, 0.32, 0.43, 0.47, 0.96, 1.  , 0.17],
       [0.64, 0.42, 0.94, 0.91, 0.97, 0.69, 0.34, 0.98, 0.88, 0.48],
       [0.  , 0.12, 0.  , 0.17, 0.06, 0.13, 0.19, 0.32, 0.01, 0.04],
       [0.63, 0.43, 0.95, 0.89, 0.98, 0.66, 0.4 , 0.98, 0.93, 0.55],
       [0.  , 0.12, 0.  , 0.42, 0.07, 0.13, 0.16, 0.19, 0.01, 0.01],
       [0.06, 0.15, 0.  , 0.24, 0.17, 0.49, 0.02, 0.12, 0.12, 0.23],
       [0.97, 0.3 , 0.48, 0.43, 0.14, 0.44, 0.19, 0.06, 0.84, 0.37]])

#### Probing attn - training

In [794]:
with open(os.path.join(probe_out_dir, 'train_X.pkl'), 'rb') as f:
    train_X = pickle.load(f)
with open(os.path.join(probe_out_dir, 'train_y.pkl'), 'rb') as f:
    train_y = pickle.load(f)
with open(os.path.join(probe_out_dir, 'test_X.pkl'), 'rb') as f:
    test_X = pickle.load(f)
with open(os.path.join(probe_out_dir, 'test_y.pkl'), 'rb') as f:
    test_y = pickle.load(f)

In [795]:
len(train_X), len(train_y), len(train_X[0]), \
len(test_X), len(test_y), len(test_X[0])

(102599, 102599, 99, 16640, 16640, 99)

In [796]:
# # Only use a subset of features? No improvement, as expected.

# sub_f_ids = []
# for f_i, (feat_name, feat_val) in enumerate(use_feats_info):
#     if abs(feat_val) > 4:
#         sub_f_ids.append(f_i)

# [use_feats_info[i] for i in sub_f_ids]

In [797]:
# train_X = [[x[i] for i in sub_f_ids] for x in train_X]
# test_X = [[x[i] for i in sub_f_ids] for x in test_X]
# np.array(train_X).shape, np.array(test_X).shape

In [798]:
C_val = 1.0

# clf = LogisticRegression(C=1.0)

logreg = LogisticRegression(C=C_val, max_iter=1000)
scaler = StandardScaler()
clf = Pipeline([
    ('scaler', scaler),
    ('logreg', logreg)
])

In [799]:
clf

Pipeline(steps=[('scaler', StandardScaler()),
                ('logreg', LogisticRegression(max_iter=1000))])

In [800]:
clf.fit(train_X, train_y)

Pipeline(steps=[('scaler', StandardScaler()),
                ('logreg', LogisticRegression(max_iter=1000))])

In [801]:
# with open(os.path.join(probe_out_dir, f'trained_scale_clf_C={C_val}.pkl'), 'wb') as f:
#     pickle.dump(clf, f)

In [802]:
def compute_results(preds, labels, f1_cls):
    """Print and return accuracy plus precision / recall / F1 for one class.

    f1_cls: the class to compute F1 score on
    """
    assert len(preds) == len(labels), (len(preds), len(labels))
    
    N = len(preds)
    N_pred = sum(p == f1_cls for p in preds)       # predicted as f1_cls
    N_true = sum(y == f1_cls for y in labels)      # labeled as f1_cls
    N_pred_true = sum(p == y == f1_cls for p, y in zip(preds, labels))  # true positives
    N_corr = sum(p == y for p, y in zip(preds, labels))
    
    # Guard against empty denominators (e.g. f1_cls is never predicted)
    Acc = N_corr / N if N else 0.0
    P = N_pred_true / N_pred if N_pred else 0.0
    R = N_pred_true / N_true if N_true else 0.0
    F1 = (2 * P * R) / (P + R + 1e-9)
    
    res_dict = {
        'Acc': Acc,
        'P': P,
        'R': R,
        'F1': F1,
    }
    
    for k, v in res_dict.items():
        print(f'{k}\t{v:.4f}')
    
    return res_dict
    

In [803]:
logreg.coef_

array([[-0.09639597,  0.01532772,  0.07832486, -0.15533311, -0.01392892,
        -0.22638873, -0.14406257,  0.09896343,  0.13709107, -0.0838726 ,
         0.13207204, -0.07925985, -0.03770774, -0.00595991,  0.0758152 ,
         0.01066828,  0.07825765,  0.16469661,  0.05137414,  0.31637049,
         0.10691322, -0.69427748, -0.03884708,  0.03260927,  0.11772163,
         0.34440898, -0.02701317,  0.36173607, -0.095233  , -0.04127651,
        -0.02898593, -0.61401068,  0.04967187, -0.02934451,  0.25257069,
        -0.69791418, -0.29124649, -0.10762437,  0.33133284,  0.08670616,
        -0.35429363,  0.18354341, -0.07732019,  0.43660091,  0.12727427,
         0.77473455, -0.00536978, -0.18288719,  0.30296189, -0.02659497,
        -0.12644743, -0.28158617,  0.26223986,  0.12994199,  0.2896086 ,
         0.51471449, -0.3773104 , -0.1979374 , -0.28664154,  1.07788551,
        -0.14624692,  0.58240512,  0.9270052 , -0.55910419,  0.13400262,
         0.1038081 ,  1.13520425, -0.05669287,  0.1

In [804]:
train_preds = clf.predict(train_X)
len(train_preds), sum(train_preds)

(102599, 93170)

In [805]:
compute_results(train_preds, train_y, f1_cls=0)

Acc	0.9829
P	0.9230
R	0.8946
F1	0.9086


{'Acc': 0.9829335568572793,
 'P': 0.9230034998409163,
 'R': 0.8946340460526315,
 'F1': 0.9085973790480677}

In [806]:
test_preds = clf.predict(test_X)
len(test_preds), sum(test_preds)

(16640, 15040)

In [807]:
compute_results(test_preds, test_y, f1_cls=0)

Acc	0.9729
P	0.8669
R	0.8535
F1	0.8602


{'Acc': 0.9728966346153847,
 'P': 0.866875,
 'R': 0.8535384615384616,
 'F1': 0.86015503825972}

In [808]:
len(train_y) - sum(train_y), len(test_y) - sum(test_y)

(9728, 1625)
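
As a sanity check, the class-0 confusion counts and a majority-class baseline can be recovered from the printed outputs alone. The sketch below is only illustrative: it copies the test-split figures from the cell outputs above and assumes the labels are binary (0/1), so that `sum(test_preds)` counts class-1 predictions. The variable names are not from the notebook.

```python
# Figures copied from the cell outputs above (test split).
# Assumption: labels are binary, so sum(...) counts class-1 items.
n_test = 16640              # len(test_y)
n_pred_cls1 = 15040         # sum(test_preds)
n_true_cls0 = 1625          # len(test_y) - sum(test_y)
precision_cls0 = 0.866875   # 'P' reported for f1_cls=0

n_pred_cls0 = n_test - n_pred_cls1                 # predicted as class 0
tp_cls0 = round(precision_cls0 * n_pred_cls0)      # correct class-0 predictions
recall_cls0 = tp_cls0 / n_true_cls0
n_errors = (n_pred_cls0 - tp_cls0) + (n_true_cls0 - tp_cls0)
accuracy = 1 - n_errors / n_test

# Baseline: always predict the majority class (class 1)
majority_acc = (n_test - n_true_cls0) / n_test

print(n_pred_cls0, tp_cls0, n_errors)
print(f'recall={recall_cls0:.4f} acc={accuracy:.4f} majority_acc={majority_acc:.4f}')
```

The recovered recall and accuracy agree with the reported `R` and `Acc`. Under the binary-label assumption the majority baseline already reaches roughly 0.90 accuracy, which is one reason to read the class-0 F1 rather than accuracy alone.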

### Exp-5.0: dirty attention vector effect 

#### Load & Check results

In [505]:
expect_type = 'column'
res_path = f'/home/yshao/Projects/rome/results/exp5_0_dirty_attention_vector_effect/exp=5_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

with open(res_path, 'r') as f:
    all_samples = [json.loads(l) for l in f]
len(all_samples)

1034

In [506]:
total_samples = 0
n_good_samples = 0
n_too_hard = 0      # wrong answer 
n_too_easy = 0      # base - low < 0.5

In [507]:
good_samples = []
bad_samples = []

for i, ex in enumerate(all_samples):
    for d in ex['trace_results']:
        total_samples += 1
#         # TEMP adjustment for column results 
#         d['low_score'] = d['trace_scores']['high_layers_corrupt'].get("0", 0.0)  # "0" is key (for layer 0), 0.0 is default 
#         if d['base_score'] - d['low_score'] < 0.5:
#             d['is_good_sample'] = False
#         # END_TEMP
        if d['is_good_sample']:
            n_good_samples += 1
            d['ex_id'] = i
            good_samples.append(d)
        elif not d['correct_prediction']:
            n_too_hard += 1
            bad_samples.append(d)
        else:
            assert d['base_score'] - d['low_score'] < 0.5, (i, d)
            n_too_easy += 1
            bad_samples.append(d)
            
total_samples, (n_good_samples, len(good_samples)), n_too_hard, n_too_easy, len(bad_samples), \
n_good_samples + n_too_easy

(2002, (867, 867), 566, 569, 1135, 1436)
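
The same good / too-hard / too-easy split is repeated for each experiment in this notebook. A minimal sketch of how it could be factored into one helper, assuming every result dict carries the fields used in the loop above (`is_good_sample`, `correct_prediction`, `base_score`, `low_score`). The name `split_samples` and the toy data are hypothetical.

```python
def split_samples(all_samples, min_drop=0.5):
    """Split trace results into (good, too_hard, too_easy) lists.

    Hypothetical helper mirroring the loop above; not part of the repo.
    """
    good, too_hard, too_easy = [], [], []
    for ex_id, ex in enumerate(all_samples):
        for d in ex['trace_results']:
            if d['is_good_sample']:
                d['ex_id'] = ex_id
                good.append(d)
            elif not d['correct_prediction']:
                too_hard.append(d)    # the model's answer is wrong
            else:
                # correct, but corruption barely lowers the score
                assert d['base_score'] - d['low_score'] < min_drop, (ex_id, d)
                too_easy.append(d)
    return good, too_hard, too_easy


# Toy data, made up for illustration
toy = [{'trace_results': [
    {'is_good_sample': True, 'correct_prediction': True, 'base_score': 0.9, 'low_score': 0.1},
    {'is_good_sample': False, 'correct_prediction': False, 'base_score': 0.2, 'low_score': 0.1},
    {'is_good_sample': False, 'correct_prediction': True, 'base_score': 0.9, 'low_score': 0.8},
]}]
good, too_hard, too_easy = split_samples(toy)
print(len(good), len(too_hard), len(too_easy))
```

Returning the three lists directly avoids keeping separate counters in sync with the lists.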

In [None]:
[s for s in bad_samples if s['correct_prediction']][0]

In [None]:
good_samples[0]

#### Overall avg

In [508]:
trace_scores_avg = {k: {str(l): 0 for l in range(24)} for k in good_samples[0]['trace_scores'].keys()}

In [509]:
for d in good_samples:
    for k, layer_d in d['trace_scores'].items():
        for l, s in layer_d.items():
            trace_scores_avg[k][l] += s

for k, layer_d in trace_scores_avg.items():
    for l, s in layer_d.items():
        layer_d[l] = s / len(good_samples)

In [None]:
trace_scores_avg

#### Avg by aspects (category)
- The trend is still roughly linear, as in exp-2.2

In [243]:
d['category']

{'sql_hardness': 'hard', 'node_role': 'where', 'text_match': 'partial'}

In [265]:
# Key: (trace_key, aspect, asp_val, layer) -> [scores]
trace_scores_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
trace_scores_avg_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
trace_scores_cnt_by_aspect = defaultdict(lambda: defaultdict(int))  # no trace key & layer key 

In [266]:
for d in good_samples:
    for trace_k, trace_layer_d in d['trace_scores'].items():
        for aspect, asp_val in d['category'].items():
            for l, s in trace_layer_d.items():
                trace_scores_by_aspect[trace_k][aspect][asp_val][l].append(s)

for trace_k, d1 in trace_scores_by_aspect.items():
    for asp_k, d2 in d1.items():
        for asp_v, d3 in d2.items():
            for l, s in d3.items():
                trace_scores_avg_by_aspect[trace_k][asp_k][asp_v][l] = np.mean(s)
                trace_scores_cnt_by_aspect[asp_k][asp_v] = len(s)

In [267]:
trace_scores_cnt_by_aspect

defaultdict(<function __main__.<lambda>()>,
            {'sql_hardness': defaultdict(int,
                         {'medium': 400,
                          'hard': 155,
                          'easy': 142,
                          'extra': 170}),
             'node_role': defaultdict(int,
                         {'where': 268,
                          'select': 412,
                          'order by': 66,
                          'join': 92,
                          'group by': 25,
                          'having': 4}),
             'text_match': defaultdict(int,
                         {'no-match': 360, 'partial': 148, 'exact': 359})})
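
The four-level `defaultdict` above could alternatively be flattened into tuple keys. The sketch below shows the same per-aspect averaging on toy data, counting each sample once per `(aspect, value)` instead of overwriting the count with `len(s)`. The helper name and toy samples are illustrative assumptions, not code from the repo.

```python
from collections import defaultdict
from statistics import mean


def average_by_aspect(samples):
    """Return (avg, counts).

    avg:    {(trace_k, aspect, asp_val, layer): mean score}
    counts: {(aspect, asp_val): number of samples}
    """
    scores = defaultdict(list)
    counts = defaultdict(int)
    for d in samples:
        for aspect, asp_val in d['category'].items():
            counts[(aspect, asp_val)] += 1
            for trace_k, layer_d in d['trace_scores'].items():
                for layer, s in layer_d.items():
                    scores[(trace_k, aspect, asp_val, layer)].append(s)
    avg = {key: mean(vals) for key, vals in scores.items()}
    return avg, dict(counts)


# Toy samples, made up for illustration
toy_samples = [
    {'category': {'node_role': 'where'},
     'trace_scores': {'high_layers_corrupt': {'0': 0.2, '1': 0.4}}},
    {'category': {'node_role': 'where'},
     'trace_scores': {'high_layers_corrupt': {'0': 0.6, '1': 0.8}}},
    {'category': {'node_role': 'select'},
     'trace_scores': {'high_layers_corrupt': {'0': 1.0, '1': 1.0}}},
]
avg_by_aspect, cnt_by_aspect = average_by_aspect(toy_samples)
print(cnt_by_aspect)
print(avg_by_aspect[('high_layers_corrupt', 'node_role', 'where', '0')])
```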

In [None]:
trace_scores_avg_by_aspect['high_layers_corrupt']

### Exp-5.2: attention section removal effect

#### Load & Check

In [384]:
expect_type = 'table_alias'

res_path = f'/home/yshao/Projects/rome/results/exp5_2_attention_section_removal_effect/exp=5.2.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

with open(res_path, 'r') as f:
    all_samples = [json.loads(l) for l in f]
len(all_samples)

1034

In [385]:
total_samples = 0
n_good_samples = 0
n_too_hard = 0      # wrong answer 
n_too_easy = 0      # base - low < 0.5

In [386]:
good_samples = []
bad_samples = []

for i, ex in enumerate(all_samples):
    for d in ex['trace_results']:
        total_samples += 1

        if d['is_good_sample']:
            n_good_samples += 1
            d['ex_id'] = i
            good_samples.append(d)
        elif not d['correct_prediction']:
            n_too_hard += 1
            bad_samples.append(d)
        else:
            assert d['base_score'] - d['low_score'] < 0.5
            n_too_easy += 1
            bad_samples.append(d)
            
total_samples, (n_good_samples, len(good_samples)), n_too_hard, n_too_easy, len(bad_samples),\
f'good / correct = {n_good_samples} / {n_good_samples + n_too_easy}'

(2039, (168, 168), 339, 1532, 1871, 'good / correct = 168 / 1700')

In [387]:
[s for s in bad_samples if not s['correct_prediction']][0]

{'enc_sentence': 'Show the stadium name and capacity with most number of concerts in year 2014 or after.; structed knowledge: | concert_singer | stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id',
 'seq_out': 'select t2.name, t2.capacity from concert as t1 join stadium as t2 on t1.stadium_id = t2.stadium_id where t1.year >= 2014 group by t2.stadium_id order by count(*) desc limit 1',
 'dec_prompt': 'select t2.name, t2.capacity from concert as t1 join stadium as t2 on t1.stadium_id = t2.stadium_id where t1.year >= 2014 group by',
 'expect': 't2.',
 'expect_type': 'table_alias',
 'db_id': 'concert_singer',
 'expect_input_ranges': [[30, 31]],
 'answer': 't1.',
 'base_score': 0.9849615097045898,
 'answers_t': [3, 17, 5411],
 'correct_prediction': False,
 'category': {

#### Overall avg

In [388]:
trace_scores_avg = {sect_k : defaultdict(int) for sect_k in good_samples[0]['trace_scores'].keys()}

for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for k, v in sect_d.items():
            if k == 'window':
                for l, s in v.items():
                    if int(l) % 4 != 3: continue
                    trace_scores_avg[sect_k][f'{k}-{l}'] += s
            else:
                s = v
                trace_scores_avg[sect_k][k] += s

for sect_k, sect_d in trace_scores_avg.items():
    for k, s in sect_d.items():
        sect_d[k] = s / len(good_samples)
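
The flattening step in the loop above (nested `window` scores become `window-<layer>` keys, and only every fourth layer is kept) can be isolated as a small function. This is a sketch under the assumption that each section dict has the shape `{'window': {layer: score}, 'first_layer': score, ...}`, as the printed output suggests. `flatten_section_scores` is a hypothetical name.

```python
def flatten_section_scores(sect_d, keep_layer=lambda l: l % 4 == 3):
    """Flatten one section's scores into a single-level dict.

    Window entries become 'window-<layer>' keys; keep_layer decides which
    window layers survive (the default mirrors the `% 4 != 3` filter above).
    """
    flat = {}
    for k, v in sect_d.items():
        if k == 'window':
            for l, s in v.items():
                if keep_layer(int(l)):
                    flat[f'{k}-{l}'] = s
        else:
            flat[k] = v
    return flat


# Toy section, made up for illustration
toy_sect = {
    'window': {str(l): l / 10 for l in range(8)},
    'first_layer': 0.95,
    'all_layers': 0.25,
}
flat_scores = flatten_section_scores(toy_sect)
print(flat_scores)
```

With the flat form, averaging over samples reduces to summing one dict per sample key by key.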

In [389]:
trace_scores_avg

{'prefix': defaultdict(int,
             {'window-3': 0.9472156441992238,
              'window-7': 0.9222004726706516,
              'window-11': 0.9004831596948428,
              'window-15': 0.6439282640543488,
              'window-19': 0.5357348926681949,
              'window-23': 0.6425308727503227,
              'first_layer': 0.9484615538801465,
              'last_layer': 0.873940426016426,
              'all_layers': 0.2498462739990746}),
 'text': defaultdict(int,
             {'window-3': 0.947731566571054,
              'window-7': 0.9410505436715626,
              'window-11': 0.9183367850879828,
              'window-15': 0.845033811380225,
              'window-19': 0.8540408059165029,
              'window-23': 0.884705040109111,
              'first_layer': 0.9480909219100362,
              'last_layer': 0.9342825717869259,
              'all_layers': 0.7430854112239325}),
 'struct': defaultdict(int,
             {'window-3': 0.9109604747722014,
              'window-

In [390]:
format_print_2D_dict(trace_scores_avg, col_w=12)

XXXXXXXXXXXX	window-3    	window-7    	window-11   	window-15   	window-19   	window-23   	first_layer 	last_layer  	all_layers  
prefix      	0.9472      	0.9222      	0.9005      	0.6439      	0.5357      	0.6425      	0.9485      	0.8739      	0.2498      
text        	0.9477      	0.9411      	0.9183      	0.8450      	0.8540      	0.8847      	0.9481      	0.9343      	0.7431      
struct      	0.9110      	0.8727      	0.8174      	0.7166      	0.4618      	0.6287      	0.9485      	0.9440      	0.1263      
text+struct 	0.9083      	0.8677      	0.7925      	0.6649      	0.3975      	0.5780      	0.9479      	0.9306      	0.1069      
all         	0.8622      	0.7481      	0.6102      	0.3438      	0.2692      	0.4050      	0.9510      	0.8428      	0.0669      
self        	0.9370      	0.9156      	0.8692      	0.8039      	0.8306      	0.9212      	0.9488      	0.9441      	0.5128      
struct_context	0.9341      	0.9235      	0.9001      	0.7936      	0.4954      	0.6496    

#### Avg by aspects (category)

In [696]:
# TEMP patch for node_len category 
for d in good_samples + bad_samples:
    node_len = len(d['answers_t'])
    assert len(mt_uskg.tokenizer.tokenize(d['expect'])) == node_len, (d['expect'], node_len)
    d['category']['node_len'] = str(node_len) if node_len <= 3 else '4+'

In [701]:
d['category']

{'sql_hardness': 'medium',
 'node_role': 'from',
 'text_match': 'exact',
 'node_len': '1'}

In [651]:
# Key: (sect_k, aspect, asp_val, layer) -> [scores]
trace_scores_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
trace_scores_avg_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
trace_scores_cnt_by_aspect = defaultdict(lambda: defaultdict(int))  # no sect key & layer key 

In [652]:
for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for aspect, asp_val in d['category'].items():
            for k, v in sect_d.items():
                if k == 'window':
                    for l, s in v.items():
                        if not (int(l) % 4 == 3): continue
                        layer_k = f'{k}-{l}'
                        trace_scores_by_aspect[sect_k][aspect][asp_val][layer_k].append(s)
                else:
                    layer_k = k
                    s = v
                    trace_scores_by_aspect[sect_k][aspect][asp_val][layer_k].append(s)
                    
for sect_k, d1 in trace_scores_by_aspect.items():
    for asp_k, d2 in d1.items():
        for asp_v, d3 in d2.items():
            for layer_k, s in d3.items():
                trace_scores_avg_by_aspect[sect_k][asp_k][asp_v][layer_k] = np.mean(s)
                trace_scores_cnt_by_aspect[asp_k][asp_v] = len(s)

In [653]:
for sect_k, sect_d in trace_scores_avg_by_aspect.items():
    sect_d['overall'] = dict()
    for layer_k, s in trace_scores_avg[sect_k].items():
        if layer_k.startswith('window'):
            # only keep a subset of layers 
            _, l = layer_k.split('-')
            if not (int(l) % 4 == 3): continue
        sect_d['overall'][layer_k] = s

In [654]:
trace_scores_cnt_by_aspect

defaultdict(<function __main__.<lambda>()>,
            {'sql_hardness': defaultdict(int,
                         {'medium': 378,
                          'hard': 148,
                          'easy': 134,
                          'extra': 160}),
             'node_role': defaultdict(int,
                         {'where': 248,
                          'select': 393,
                          'order by': 63,
                          'join': 91,
                          'group by': 21,
                          'having': 4}),
             'text_match': defaultdict(int,
                         {'no-match': 345, 'partial': 142, 'exact': 333}),
             'node_len': defaultdict(int,
                         {'1': 329, '3': 238, '4+': 160, '2': 93})})

In [None]:
trace_scores_avg_by_aspect['self']

In [641]:
dump_d = ctu.nested_json_processing(trace_scores_avg_by_aspect, func=lambda x: np.format_float_positional(x, precision=4, min_digits=4))
# dump_d

In [642]:
dump_path = f'/home/yshao/Projects/rome/results/exp5_2_attention_section_removal_effect/summ-exp=5.2.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

with open(dump_path, 'w') as f:
    json.dump(dump_d, f, indent=1)
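
`ctu.nested_json_processing` is not defined in this notebook. The sketch below is only a hypothetical stand-in showing how such a helper might apply a function to every leaf of a nested dict before dumping to JSON. `map_leaves` is an illustrative name, and plain `format` is used here in place of `np.format_float_positional`.

```python
import json


def map_leaves(obj, func):
    """Recursively apply func to every non-container leaf.

    Hypothetical stand-in; the real ctu helper may behave differently.
    dict subclasses (e.g. defaultdict) come back as plain dicts.
    """
    if isinstance(obj, dict):
        return {k: map_leaves(v, func) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [map_leaves(v, func) for v in obj]
    return func(obj)


# Toy nested scores, made up for illustration
nested = {'self': {'overall': {'window-3': 0.93701234, 'all_layers': 0.5128}}}
formatted = map_leaves(nested, lambda x: format(x, '.4f'))
print(json.dumps(formatted, indent=1))
```

Formatting leaves as strings keeps the dumped file compact, at the cost of needing `float(...)` when the summary is loaded again.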

#### (one-time temp patch)

In [531]:
# expect_type = 'table_alias'
# orig_res_path = f'/home/yshao/Projects/rome/results/exp5_2_attention_section_removal_effect/no_structcontext-exp=5.2.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'
# add_res_path = f'/home/yshao/Projects/rome/results/exp5_2_attention_section_removal_effect/exp=5.2.1+structcontext_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

# merge_res_path = f'/home/yshao/Projects/rome/results/exp5_2_attention_section_removal_effect/exp=5.2.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

In [532]:
# with open(orig_res_path, 'r') as f:
#     orig_all_samples = [json.loads(l) for l in f]
# with open(add_res_path, 'r') as f:
#     add_all_samples = [json.loads(l) for l in f]

# f = open(merge_res_path, 'w')
    
# for i, (orig_ex, add_ex) in enumerate(zip(orig_all_samples, add_all_samples)):
#     assert len(orig_ex['trace_results']) == len(add_ex['trace_results']), i
#     # There is randomness in the order of expected node (from set()), thus sorting here 
#     orig_ex['trace_results'].sort(key=lambda d: len(d['dec_prompt']))
#     add_ex['trace_results'].sort(key=lambda d: len(d['dec_prompt']))
#     for j, (orig_d, add_d) in enumerate(zip(orig_ex['trace_results'], add_ex['trace_results'])):
#         assert orig_d['is_good_sample'] == add_d['is_good_sample'], (i, j)
#         if not orig_d['is_good_sample']:
#             continue
            
#         # is good sample: add the new sections 
#         orig_d['trace_scores']['struct_context'] = add_d['trace_scores']['struct_context']
#         orig_d['trace_scores']['text+struct_context'] = add_d['trace_scores']['text+struct_context']
        
#     f.write(json.dumps(orig_ex, indent=None) + '\n')
    
# f.close()

#### Single samples observations

In [702]:
good_samples[0]['trace_scores'].keys()

dict_keys(['prefix', 'text', 'struct', 'text+struct', 'all', 'self', 'struct_context', 'text+struct_context'])

In [703]:
_id = 0

d = good_samples[_id]

check_info_d = defaultdict(dict)

for sect_k, sect_d in d['trace_scores'].items():
    for layer_k, s in sect_d.items():
        if layer_k == 'window':
            layer_k = 'window-19'
            s = s['19']
        if s < 0.5:
            check_info_d[sect_k][layer_k] = s

In [704]:
print(json.dumps(check_info_d, indent=2))

{
  "text+struct": {
    "all_layers": 0.3990614414215088
  },
  "all": {
    "all_layers": 0.4772564172744751
  }
}


##### Layer

In [705]:
### Find the "breaking" window layer per sample, i.e. the first layer with a sudden score drop 
### Current criterion: the score drops by more than _th relative to the previous layer

_th = 0.4
check_info_l = []
for i, d in enumerate(good_samples):
#     for sect_k, sect_d in d['trace_scores'].items():
    sect_k = 'all'
    sect_d = d['trace_scores'][sect_k]
    window_d = sect_d['window']
    for l in range(1, 24):
        if window_d[str(l-1)] - window_d[str(l)] > _th:
            _info_d = {
                'id': i,
                'sect_k': sect_k,
                'layer': l,
                'last_layer_score': window_d[str(l-1)],
                'this_layer_score': window_d[str(l)],
            }
            check_info_l.append(_info_d)
            break
len(check_info_l)

844
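
The break-layer criterion above can be written as a standalone function, which makes it easier to test on hand-made curves. A minimal sketch, assuming `window_d` maps layer indices (as strings) to scores, as in the loop above. The function name and toy inputs are illustrative.

```python
def first_break_layer(window_d, th=0.4, n_layers=24):
    """Return the first layer l whose score is more than th below layer l-1.

    Returns None when no single-layer drop exceeds th.
    """
    for l in range(1, n_layers):
        if window_d[str(l - 1)] - window_d[str(l)] > th:
            return l
    return None


# Toy curves, made up for illustration
step_window = {str(l): (0.95 if l < 14 else 0.30) for l in range(24)}
flat_window = {str(l): 0.90 for l in range(24)}

step_break = first_break_layer(step_window)
flat_break = first_break_layer(flat_window)
print(step_break, flat_break)
```

Samples for which the function returns `None` correspond to the good samples that never enter `check_info_l`.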

In [706]:
len(check_info_l), len(good_samples)

(844, 1207)

In [707]:
break_layer_counter = Counter([_d['layer'] for _d in check_info_l])
sorted(break_layer_counter.items())

[(1, 10),
 (2, 25),
 (3, 31),
 (4, 81),
 (5, 24),
 (6, 3),
 (7, 15),
 (8, 87),
 (9, 69),
 (10, 19),
 (11, 38),
 (12, 36),
 (13, 70),
 (14, 101),
 (15, 27),
 (16, 31),
 (17, 31),
 (18, 79),
 (19, 67)]

In [None]:
for info_d in check_info_l:
    if info_d['layer'] < 6:
        print(info_d)
        sample_id = info_d['id']
        d = good_samples[sample_id]
        print(d['enc_sentence'])
        print(d['dec_prompt'], '---->', d['expect'])
        print('Categories:', d['category'])
        print('--' * 20)

In [713]:
sample_counter_by_aspect = defaultdict(Counter)  # [asp_k, asp_v] -> count 
sample_counter = Counter()

for info_d in check_info_l:
    if info_d['layer'] < 6:
        sample_id = info_d['id']
        d = good_samples[sample_id]
        text_match = d['category']['text_match']
        node_len = d['category']['node_len']
        sample_counter[(text_match, node_len)] += 1
        
        for asp_k, asp_v in d['category'].items():
            sample_counter_by_aspect[asp_k][asp_v] += 1

In [715]:
sample_counter_by_aspect

defaultdict(collections.Counter,
            {'sql_hardness': Counter({'medium': 58,
                      'hard': 42,
                      'extra': 51,
                      'easy': 20}),
             'node_role': Counter({'from': 105, 'join': 66}),
             'text_match': Counter({'partial': 59,
                      'no-match': 74,
                      'exact': 38}),
             'node_len': Counter({'4+': 60, '3': 51, '2': 24, '1': 36})})

In [716]:
sample_counter.most_common()

[(('partial', '3'), 30),
 (('partial', '4+'), 29),
 (('exact', '1'), 22),
 (('no-match', '3'), 21),
 (('no-match', '4+'), 21),
 (('no-match', '2'), 18),
 (('no-match', '1'), 14),
 (('exact', '4+'), 10),
 (('exact', '2'), 6)]

In [717]:
sample_counter_by_aspect = defaultdict(Counter)  # [asp_k, asp_v] -> count 
sample_counter = Counter()

for info_d in check_info_l:
    if info_d['layer'] > 18:
        sample_id = info_d['id']
        d = good_samples[sample_id]
        text_match = d['category']['text_match']
        node_len = d['category']['node_len']
        sample_counter[(text_match, node_len)] += 1
        
        for asp_k, asp_v in d['category'].items():
            sample_counter_by_aspect[asp_k][asp_v] += 1

In [718]:
sample_counter_by_aspect

defaultdict(collections.Counter,
            {'sql_hardness': Counter({'hard': 9,
                      'medium': 27,
                      'easy': 11,
                      'extra': 20}),
             'node_role': Counter({'join': 21, 'from': 46}),
             'text_match': Counter({'exact': 46,
                      'no-match': 20,
                      'partial': 1}),
             'node_len': Counter({'3': 12, '1': 46, '2': 6, '4+': 3})})

In [719]:
sample_counter.most_common()

[(('exact', '1'), 38),
 (('no-match', '3'), 9),
 (('no-match', '1'), 8),
 (('exact', '3'), 3),
 (('no-match', '2'), 3),
 (('exact', '2'), 3),
 (('exact', '4+'), 2),
 (('partial', '4+'), 1)]

In [743]:
# _min_p = 1.0

# for i, d in enumerate(good_samples):
#     sect_k = 'text'
#     sect_d = d['trace_scores'][sect_k]
#     _min_p = min(_min_p, sect_d['all_layers'])

# _min_p

7.385318918917694e-12

In [751]:
### Systematic version: skip samples where corrupting this section has no effect

ob_sect_k = 'all'

_th = 0.4

check_info_l = []
all_layers_eff_cnt = 0  # samples where corrupting ob_sect_k at all layers is effective (score <= 0.5)
window_eff_cnt = 0      # of those, samples where at least one window is also effective (min score <= 0.5)

for i, d in enumerate(good_samples):
#     for sect_k, sect_d in d['trace_scores'].items():
    sect_k = ob_sect_k
    sect_d = d['trace_scores'][sect_k]
    if sect_d['all_layers'] > 0.5:
        # not effective
        continue
    else:
        all_layers_eff_cnt += 1
        
    if min(sect_d['window'].values()) > 0.5:
        # not effective
        continue
    else:
        window_eff_cnt += 1
        
    window_d = sect_d['window']
    for l in range(1, 24):
        if window_d[str(l-1)] - window_d[str(l)] > _th:
            _info_d = {
                'id': i,
                'sect_k': sect_k,
                'layer': l,
                'last_layer_score': window_d[str(l-1)],
                'this_layer_score': window_d[str(l)],
            }
            check_info_l.append(_info_d)
            break
len(check_info_l), window_eff_cnt, all_layers_eff_cnt, len(good_samples)

(840, 916, 1207, 1207)

In [745]:
ctg_list = [(tm, nl) for tm in ['exact', 'partial', 'no-match'] for nl in ['1', '2', '3', '4+']]
layer_list = [str(l) for l in range(1, 24)]

ctg_elem2id = {elem : i for i, elem in enumerate(ctg_list)}
layer_elem2id = {elem : i for i, elem in enumerate(layer_list)}

cnt_matrix = np.zeros((len(ctg_list), len(layer_list)), int)

for info_d in check_info_l:
    sample_id = info_d['id']
    d = good_samples[sample_id]
    text_match = d['category']['text_match']
    node_len = d['category']['node_len']
    _ctg = (text_match, node_len)
    _layer = str(info_d['layer'])
    
    _ctg_idx = ctg_elem2id[_ctg]
    _layer_idx = layer_elem2id[_layer]
    cnt_matrix[_ctg_idx, _layer_idx] += 1

In [None]:
# Create a figure and axis
fig, ax = plt.subplots(figsize=(8, 6))

# Display the matrix using imshow
im = ax.imshow(cnt_matrix, cmap='Blues')

# Set the tick labels for the first and second dimensions
ax.set_xticks(np.arange(len(layer_list)))
ax.set_yticks(np.arange(len(ctg_list)))

# Set the tick labels using the ctg_list and layer_list
ax.set_xticklabels(layer_list)
ax.set_yticklabels(ctg_list)

ax.set_title(f'Section: {ob_sect_k}\n')

# Rotate the x-axis tick labels if needed
# plt.xticks(rotation=90)

# Add a colorbar
cbar = ax.figure.colorbar(im, ax=ax, shrink=0.5)

# Show the plot
plt.show()


In [747]:
plt.close()

### Exp-5.3: attention section mutual removal

#### Load & Check

In [848]:
expect_type = 'column'

# res_path = f'/home/yshao/Projects/rome/results/exp5_3_attention_section_mutual_removal/exp=5.3.2_dev_{expect_type}-attn_crpt=logits.jsonl'
res_path = f'/home/yshao/Projects/rome/results/exp5_3_attention_section_mutual_removal/exp=5.3.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

with open(res_path, 'r') as f:
    all_samples = [json.loads(l) for l in f]
len(all_samples)

1034

In [849]:
total_samples = 0
n_good_samples = 0
n_too_hard = 0      # wrong answer 
n_too_easy = 0      # base - low < 0.5

In [850]:
good_samples = []
# bad_samples = []
too_hard_samples = []
too_easy_samples = []

for i, ex in enumerate(all_samples):
    for d in ex['trace_results']:
        total_samples += 1

        if d['is_good_sample']:
            n_good_samples += 1
            d['ex_id'] = i
            good_samples.append(d)
        elif not d['correct_prediction']:
            n_too_hard += 1
            too_hard_samples.append(d)
        else:
            assert d['base_score'] - d['low_score'] < 0.5
            n_too_easy += 1
            too_easy_samples.append(d)

bad_samples = too_hard_samples + too_easy_samples
            
total_samples, (n_good_samples, len(good_samples)), \
(n_too_hard, len(too_hard_samples)), (n_too_easy, len(too_easy_samples)), \
len(bad_samples), \
f'good / correct = {n_good_samples} / {n_good_samples + n_too_easy}'

(2002,
 (1165, 1165),
 (566, 566),
 (271, 271),
 837,
 'good / correct = 1165 / 1436')

#### Overall avg

In [364]:
trace_scores_avg = {sect_k : defaultdict(int) for sect_k in good_samples[0]['trace_scores'].keys()}

for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for k, v in sect_d.items():
            if k == 'window':
                for l, s in v.items():
                    trace_scores_avg[sect_k][f'{k}-{l}'] += s
            else:
                s = v
                trace_scores_avg[sect_k][k] += s

for sect_k, sect_d in trace_scores_avg.items():
    for k, s in sect_d.items():
        sect_d[k] = s / len(good_samples)

In [None]:
trace_scores_avg

In [None]:
format_print_2D_dict(trace_scores_avg, col_w=12)

#### Avg by aspects (category)

In [842]:
d['category']

{'sql_hardness': 'medium', 'node_role': 'where', 'text_match': 'no-match'}

In [843]:
# # TEMP patch for node_len category 
# for d in good_samples + bad_samples:
#     node_len = len(d['answers_t'])
#     assert len(mt_uskg.tokenizer.tokenize(d['expect'])) == node_len, (d['expect'], node_len)
#     d['category']['node_len'] = str(node_len) if node_len <= 3 else '4+'

In [844]:
d['category']

{'sql_hardness': 'medium',
 'node_role': 'group by',
 'text_match': 'exact',
 'node_len': '3'}

In [845]:
# Key: (sect_k, aspect, asp_val, layer) -> [scores]
trace_scores_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
trace_scores_avg_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
trace_scores_cnt_by_aspect = defaultdict(lambda: defaultdict(int))  # no sect key & layer key 

In [846]:
for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for aspect, asp_val in d['category'].items():
            for k, v in sect_d.items():
                if k == 'window':
                    for l, s in v.items():
                        if not (int(l) % 4 == 3): continue
                        layer_k = f'{k}-{l}'
                        trace_scores_by_aspect[sect_k][aspect][asp_val][layer_k].append(s)
                else:
                    layer_k = k
                    s = v
                    trace_scores_by_aspect[sect_k][aspect][asp_val][layer_k].append(s)
                    
for sect_k, d1 in trace_scores_by_aspect.items():
    for asp_k, d2 in d1.items():
        for asp_v, d3 in d2.items():
            for layer_k, s in d3.items():
                trace_scores_avg_by_aspect[sect_k][asp_k][asp_v][layer_k] = np.mean(s)
                trace_scores_cnt_by_aspect[asp_k][asp_v] = len(s)

In [847]:
for sect_k, sect_d in trace_scores_avg_by_aspect.items():
    sect_d['overall'] = dict()
    for layer_k, s in trace_scores_avg[sect_k].items():
        if layer_k.startswith('window'):
            # only keep a subset of layers 
            _, l = layer_k.split('-')
            if not (int(l) % 4 == 3): continue
        sect_d['overall'][layer_k] = s

In [848]:
trace_scores_cnt_by_aspect

defaultdict(<function __main__.<lambda>()>,
            {'sql_hardness': defaultdict(int,
                         {'medium': 123, 'extra': 175, 'hard': 64, 'easy': 2}),
             'node_role': defaultdict(int,
                         {'select': 191,
                          'group by': 33,
                          'join': 12,
                          'where': 111,
                          'order by': 17}),
             'text_match': defaultdict(int,
                         {'exact': 230, 'partial': 28, 'no-match': 106}),
             'node_len': defaultdict(int, {'3': 364})})

In [None]:
trace_scores_avg_by_aspect['c->p']

In [850]:
dump_d = ctu.nested_json_processing(trace_scores_avg_by_aspect, func=lambda x: np.format_float_positional(x, precision=4, min_digits=4))
# dump_d

In [851]:
# dump_path = f'/home/yshao/Projects/rome/results/exp5_3_attention_section_mutual_removal/summ-exp=5.3.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

# with open(dump_path, 'w') as f:
#     json.dump(dump_d, f, indent=1)

#### Good / correct ratio per class 

In [853]:
# asp_k -> asp_v -> {'good', 'too_easy', 'all_correct', 'ratio'}
good_correct_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))

for d in good_samples:
    for asp_k, asp_v in d['category'].items():
        good_correct_stats[asp_k][asp_v]['good'] += 1
        good_correct_stats[asp_k][asp_v]['all_correct'] += 1

for d in too_easy_samples:
    for asp_k, asp_v in d['category'].items():
        good_correct_stats[asp_k][asp_v]['too_easy'] += 1
        good_correct_stats[asp_k][asp_v]['all_correct'] += 1
        
for asp_k, d1 in good_correct_stats.items():
    for asp_v, asp_d in d1.items():
        asp_d['ratio'] = asp_d['good'] / asp_d['all_correct']
    

In [855]:
good_correct_stats['text_match']

defaultdict(<function __main__.<lambda>.<locals>.<lambda>()>,
            {'exact': defaultdict(int,
                         {'good': 609,
                          'all_correct': 791,
                          'too_easy': 182,
                          'ratio': 0.7699115044247787}),
             'no-match': defaultdict(int,
                         {'good': 360,
                          'all_correct': 443,
                          'too_easy': 83,
                          'ratio': 0.8126410835214447}),
             'partial': defaultdict(int,
                         {'good': 196,
                          'all_correct': 202,
                          'too_easy': 6,
                          'ratio': 0.9702970297029703})})

In [856]:
format_print_2D_dict(good_correct_stats['text_match'])

XXXXXXXXXXXX	good  	all_correct	too_easy	ratio 
exact       	609.0000	791.0000	182.0000	0.7699
no-match    	360.0000	443.0000	83.0000	0.8126
partial     	196.0000	202.0000	6.0000	0.9703
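
The per-class ratio above is `good / (good + too_easy)`, i.e. the share of correctly predicted samples for which corruption was effective. As a minimal sketch (the helper name `good_correct_ratio` is hypothetical and not part of the notebook's utilities), the same numbers can be computed with two `Counter`s, assuming each sample carries a `category` dict as in the records above:

```python
from collections import Counter

def good_correct_ratio(good_samples, too_easy_samples, aspect):
    """Hypothetical helper: good / (good + too_easy) per value of one aspect."""
    # Count samples per aspect value (e.g. 'exact', 'partial', 'no-match').
    good = Counter(d['category'][aspect] for d in good_samples)
    easy = Counter(d['category'][aspect] for d in too_easy_samples)
    # A value seen in only one of the two lists gets a zero count from Counter.
    return {v: good[v] / (good[v] + easy[v]) for v in set(good) | set(easy)}
```

Called as `good_correct_ratio(good_samples, too_easy_samples, 'text_match')`, this should give the `ratio` column of the table above, for example 609 / 791 for `exact`.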



#### Results reload

In [815]:
dump_path = f'/home/yshao/Projects/rome/results/exp5_3_attention_section_mutual_removal/summ-exp=5.3.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

with open(dump_path, 'r') as f:
    dump_d = json.load(f)

dump_d.keys()

dict_keys(['t->s', 's->t', 't<->s', 't->p', 's->p', 'ts->p', 't->t', 's->s', 's->c', 'c->s', 'c->c', 'c->p', 'all'])

In [823]:
# asp_k -> asp_v -> sect_k -> layer_k -> val
results_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))

for sect_k, sect_d in dump_d.items():
    for asp_k, d1 in sect_d.items():
        if asp_k == 'overall':
            continue
        for asp_v, asp_d in d1.items():
            for layer_k in ['window-7', 'window-19', 'all_layers']:
                v = float(asp_d[layer_k])
                results_by_aspect[asp_k][asp_v][sect_k][layer_k] = v

In [None]:
asp_k = 'text_match'

for asp_v, asp_d in results_by_aspect[asp_k].items():
    print(f'{asp_k} = {asp_v}')
    format_print_2D_dict(asp_d, col_w=10)
    print()

#### (one-time temp patch)

In [804]:
# expect_type = 'table_alias'
# orig_res_path = f'/home/yshao/Projects/rome/results/exp5_3_attention_section_mutual_removal/no_c2p_exp=5.3.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'
# add_res_path = f'/home/yshao/Projects/rome/results/exp5_3_attention_section_mutual_removal/exp=5.3.1+c2p_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

# merge_res_path = f'/home/yshao/Projects/rome/results/exp5_3_attention_section_mutual_removal/exp=5.3.1_dev_{expect_type}_encoder-attn=self_attn-corrupt=zero.jsonl'

In [805]:
# with open(orig_res_path, 'r') as f:
#     orig_all_samples = [json.loads(l) for l in f]
# with open(add_res_path, 'r') as f:
#     add_all_samples = [json.loads(l) for l in f]

# f = open(merge_res_path, 'w')
    
# for i, (orig_ex, add_ex) in enumerate(zip(orig_all_samples, add_all_samples)):
#     assert len(orig_ex['trace_results']) == len(add_ex['trace_results']), i
#     # The order of expected nodes is random (they come from a set()), so sort here
#     orig_ex['trace_results'].sort(key=lambda d: len(d['dec_prompt']))
#     add_ex['trace_results'].sort(key=lambda d: len(d['dec_prompt']))
#     for j, (orig_d, add_d) in enumerate(zip(orig_ex['trace_results'], add_ex['trace_results'])):
#         assert orig_d['is_good_sample'] == add_d['is_good_sample'], (i, j)
#         if not orig_d['is_good_sample']:
#             continue
            
#         # is good sample: add the new sections 
#         orig_d['trace_scores']['c->p'] = add_d['trace_scores']['c->p']
        
#         # put all at end in the dict 
#         _t = orig_d['trace_scores']['all']
#         del orig_d['trace_scores']['all']
#         orig_d['trace_scores']['all'] = _t
        
#     f.write(json.dumps(orig_ex, indent=None) + '\n')
    
# f.close()

### Exp-5.4: decoder cross attention removal

#### Load & Check

In [857]:
expect_type = 'column'

# res_path = f'/home/yshao/Projects/rome/results/exp5_4_decoder_cross_attention_removal/exp=5.4.1_dev_{expect_type}-attn_crpt=logits.jsonl'
res_path = f'/home/yshao/Projects/rome/results/exp5_4_decoder_cross_attention_removal/exp=5.4_dev_{expect_type}-attn_crpt=weights.jsonl'

with open(res_path, 'r') as f:
    all_samples = [json.loads(l) for l in f]
len(all_samples)

1034

In [861]:
total_samples = 0
n_good_samples = 0
n_too_hard = 0      # wrong answer 
n_too_easy = 0      # base - low < 0.5

In [862]:
good_samples = []
# bad_samples = []
too_hard_samples = []
too_easy_samples = []

for i, ex in enumerate(all_samples):
    for d in ex['trace_results']:
        total_samples += 1

        if d['is_good_sample']:
            n_good_samples += 1
            d['ex_id'] = i
            good_samples.append(d)
        elif not d['correct_prediction']:
            n_too_hard += 1
            too_hard_samples.append(d)
        else:
            assert d['base_score'] - d['low_score'] < 0.5
            n_too_easy += 1
            too_easy_samples.append(d)

bad_samples = too_hard_samples + too_easy_samples
            
total_samples, (n_good_samples, len(good_samples)), \
(n_too_hard, len(too_hard_samples)), (n_too_easy, len(too_easy_samples)), \
len(bad_samples), \
f'good / correct = {n_good_samples} / {n_good_samples + n_too_easy}'

(2002,
 (1266, 1266),
 (566, 566),
 (170, 170),
 736,
 'good / correct = 1266 / 1436')
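
The three-way split used in the loop can be written as a small classifier. This is a sketch under stated assumptions: `classify_sample` is a hypothetical name, and the rule for a good sample (correct prediction and a drop of at least 0.5 from `base_score` to `low_score`) is inferred from the assertion and comments above. The `is_good_sample` flag itself is precomputed in the result files and its exact definition is not shown here.

```python
def classify_sample(d, min_drop=0.5):
    """Hypothetical helper: label one trace result as 'good', 'too_hard' or 'too_easy'.

    `min_drop` mirrors the 0.5 threshold implied by the assertion in the loop;
    it is an assumption, not a value read from the experiment code.
    """
    if not d['correct_prediction']:
        return 'too_hard'   # the uncorrupted prediction is already wrong
    if d['base_score'] - d['low_score'] < min_drop:
        return 'too_easy'   # corruption barely lowers the answer score
    return 'good'           # correct, and corruption has a clear effect
```

Comparing `classify_sample(d) == 'good'` against `d['is_good_sample']` over all records would be one way to check whether the inferred rule matches the stored flag.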

#### Overall avg

In [831]:
# Dict[sect_k, Dict[layer_k, s]]
trace_scores_avg = {sect_k : defaultdict(int) for sect_k in good_samples[0]['trace_scores'].keys()}

for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for k, s in sect_d.items():
            trace_scores_avg[sect_k][k] += s

for sect_k, sect_d in trace_scores_avg.items():
    for k, s in sect_d.items():
        sect_d[k] = s / len(good_samples)
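
The accumulate-then-divide pattern above can also be sketched as a reusable function. The name `average_trace_scores` is hypothetical. Unlike the cell above, this version divides by a per-key count instead of `len(good_samples)`; the two agree whenever every sample contains every `(sect_k, layer_k)` pair, which is an assumption about the data.

```python
from collections import defaultdict

def average_trace_scores(samples):
    """Hypothetical helper: mean of d['trace_scores'][sect_k][layer_k] over samples."""
    sums = defaultdict(lambda: defaultdict(float))
    counts = defaultdict(lambda: defaultdict(int))
    for d in samples:
        for sect_k, sect_d in d['trace_scores'].items():
            for layer_k, s in sect_d.items():
                sums[sect_k][layer_k] += s
                counts[sect_k][layer_k] += 1
    # Return plain dicts so the result prints and serializes cleanly.
    return {
        sect_k: {layer_k: total / counts[sect_k][layer_k]
                 for layer_k, total in sect_d.items()}
        for sect_k, sect_d in sums.items()
    }
```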

In [832]:
trace_scores_avg

{'all': defaultdict(int,
             {'low_layers': 0.5242457950745857,
              'mid_layers': 0.6081488858922526,
              'high_layers': 0.04873052586141602,
              'all_layers': 0.014953504248225972}),
 'ans->t': defaultdict(int,
             {'low_layers': 0.9694846544320876,
              'mid_layers': 0.9548988270596958,
              'high_layers': 0.9813394950611222,
              'all_layers': 0.9361464705256212}),
 'all->t': defaultdict(int,
             {'low_layers': 0.9533332388484522,
              'mid_layers': 0.9325720264446838,
              'high_layers': 0.9790280977666627,
              'all_layers': 0.8936582017384058}),
 'ans->s': defaultdict(int,
             {'low_layers': 0.8826992521290236,
              'mid_layers': 0.8524766282742821,
              'high_layers': 0.26891499033840577,
              'all_layers': 0.19515848414907608}),
 'all->s': defaultdict(int,
             {'low_layers': 0.8577877470518337,
              'mid_layers': 0.

In [833]:
layers_keys = ['low_layers', 'mid_layers', 'high_layers', 'all_layers']

for sect_k, sect_d in trace_scores_avg.items():
    # Skip the `ans->*` sections; only the unified `all->*` sections are used for analysis
    if sect_k.startswith('ans->'):
        continue
    print_l = f'{sect_k:<8s}'
    for k in layers_keys:
        s = sect_d[k]
        print_l += f'\t{s:.4f}'
    print(print_l)

all     	0.5242	0.6081	0.0487	0.0150
all->t  	0.9533	0.9326	0.9790	0.8937
all->s  	0.8578	0.8522	0.2715	0.2331
all->p  	0.9161	0.9358	0.9780	0.8995
all->o  	0.9872	0.9875	0.9873	0.9849
all->t+o	0.9524	0.9330	0.9773	0.8889
all->c  	0.9342	0.9845	0.9849	0.9352
all->self	0.9439	0.8528	0.3093	0.3013


#### Avg by aspects (category)

In [834]:
# Key: (sect_k, aspect, asp_val, layer) -> [scores]
trace_scores_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
trace_scores_avg_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))
trace_scores_cnt_by_aspect = defaultdict(lambda: defaultdict(int))  # no sect key & layer key 

In [836]:
for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for aspect, asp_val in d['category'].items():
            for k, v in sect_d.items():
#                 if k == 'window':
#                     for l, s in v.items():
#                         if not (int(l) % 4 == 3): continue
#                         layer_k = f'{k}-{l}'
#                         trace_scores_by_aspect[sect_k][aspect][asp_val][layer_k].append(s)
#                 else:
                layer_k = k
                s = v
                trace_scores_by_aspect[sect_k][aspect][asp_val][layer_k].append(s)
                    
for sect_k, d1 in trace_scores_by_aspect.items():
    for asp_k, d2 in d1.items():
        for asp_v, d3 in d2.items():
            for layer_k, s in d3.items():
                trace_scores_avg_by_aspect[sect_k][asp_k][asp_v][layer_k] = np.mean(s)
                trace_scores_cnt_by_aspect[asp_k][asp_v] = len(s)

In [837]:
# for sect_k, sect_d in trace_scores_avg_by_aspect.items():
#     sect_d['overall'] = dict()
#     for layer_k, s in trace_scores_avg[sect_k].items():
#         if layer_k.startswith('window'):
#             # only keep a subset of layers 
#             _, l = layer_k.split('-')
#             if not (int(l) % 4 == 3): continue
#         sect_d['overall'][layer_k] = s

In [838]:
trace_scores_cnt_by_aspect

defaultdict(<function __main__.<lambda>()>,
            {'sql_hardness': defaultdict(int,
                         {'medium': 586,
                          'easy': 229,
                          'hard': 240,
                          'extra': 211}),
             'node_role': defaultdict(int,
                         {'select': 630,
                          'order by': 82,
                          'where': 381,
                          'group by': 82,
                          'join': 85,
                          'having': 6}),
             'text_match': defaultdict(int,
                         {'exact': 653, 'no-match': 419, 'partial': 194}),
             'node_len': defaultdict(int,
                         {'1': 578, '3': 369, '4+': 203, '2': 116})})

In [839]:
trace_scores_avg_by_aspect['all->t']

defaultdict(<function __main__.<lambda>.<locals>.<lambda>()>,
            {'sql_hardness': defaultdict(<function __main__.<lambda>.<locals>.<lambda>.<locals>.<lambda>()>,
                         {'medium': defaultdict(float,
                                      {'low_layers': 0.9511169016978063,
                                       'mid_layers': 0.9341550427679274,
                                       'high_layers': 0.9850538173801341,
                                       'all_layers': 0.8932898050521948}),
                          'easy': defaultdict(float,
                                      {'low_layers': 0.9820877028289339,
                                       'mid_layers': 0.964199215542708,
                                       'high_layers': 0.9875165987896452,
                                       'all_layers': 0.9503825834782225}),
                          'hard': defaultdict(float,
                                      {'low_layers': 0.9362867399049416,
      

In [840]:
dump_d = ctu.nested_json_processing(trace_scores_avg_by_aspect, func=lambda x: np.format_float_positional(x, precision=4, min_digits=4))
# dump_d

In [841]:
dump_path = f'/home/yshao/Projects/rome/results/exp5_4_decoder_cross_attention_removal/summ-exp=5.4_dev_{expect_type}-attn_crpt=weights.jsonl'

with open(dump_path, 'w') as f:
    json.dump(dump_d, f, indent=1)
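
The helper `ctu.nested_json_processing` is not shown in this notebook. From its use above it appears to apply `func` to every leaf of a nested dict before dumping. A minimal sketch of that behaviour, under the hypothetical name `map_nested_leaves`, might look like this (the handling of lists is an assumption):

```python
def map_nested_leaves(obj, func):
    """Hypothetical stand-in for ctu.nested_json_processing: apply func to every leaf."""
    if isinstance(obj, dict):
        # defaultdict is a dict subclass, so nested defaultdicts become plain dicts.
        return {k: map_nested_leaves(v, func) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [map_nested_leaves(v, func) for v in obj]
    return func(obj)
```

Because `np.format_float_positional` returns a string, the dumped leaves are strings, which is why the reload cells convert them back with `float(...)`.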

#### Good / correct ratio per class 

In [863]:
# asp_k -> asp_v -> {'good', 'too_easy', 'all_correct', 'ratio'}
good_correct_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))

for d in good_samples:
    for asp_k, asp_v in d['category'].items():
        good_correct_stats[asp_k][asp_v]['good'] += 1
        good_correct_stats[asp_k][asp_v]['all_correct'] += 1

for d in too_easy_samples:
    for asp_k, asp_v in d['category'].items():
        good_correct_stats[asp_k][asp_v]['too_easy'] += 1
        good_correct_stats[asp_k][asp_v]['all_correct'] += 1
        
for asp_k, d1 in good_correct_stats.items():
    for asp_v, asp_d in d1.items():
        asp_d['ratio'] = asp_d['good'] / asp_d['all_correct']
    

In [864]:
good_correct_stats['text_match']

defaultdict(<function __main__.<lambda>.<locals>.<lambda>()>,
            {'exact': defaultdict(int,
                         {'good': 653,
                          'all_correct': 791,
                          'too_easy': 138,
                          'ratio': 0.8255372945638433}),
             'no-match': defaultdict(int,
                         {'good': 419,
                          'all_correct': 443,
                          'too_easy': 24,
                          'ratio': 0.945823927765237}),
             'partial': defaultdict(int,
                         {'good': 194,
                          'all_correct': 202,
                          'too_easy': 8,
                          'ratio': 0.9603960396039604})})

In [865]:
format_print_2D_dict(good_correct_stats['text_match'])

XXXXXXXXXXXX	good  	all_correct	too_easy	ratio 
exact       	653.0000	791.0000	138.0000	0.8255
no-match    	419.0000	443.0000	24.0000	0.9458
partial     	194.0000	202.0000	8.0000	0.9604



#### Results reload

In [842]:
dump_path = f'/home/yshao/Projects/rome/results/exp5_4_decoder_cross_attention_removal/summ-exp=5.4_dev_{expect_type}-attn_crpt=weights.jsonl'

with open(dump_path, 'r') as f:
    dump_d = json.load(f)

dump_d.keys()

dict_keys(['all', 'ans->t', 'all->t', 'ans->s', 'all->s', 'ans->p', 'all->p', 'ans->o', 'all->o', 'ans->t+o', 'all->t+o', 'ans->c', 'all->c', 'ans->self', 'all->self'])

In [844]:
# asp_k -> asp_v -> sect_k -> layer_k -> val
results_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(float))))

for sect_k, sect_d in dump_d.items():
    if sect_k.startswith('ans->'):
        continue
        
    for asp_k, d1 in sect_d.items():
        if asp_k == 'overall':
            continue
        for asp_v, asp_d in d1.items():
            for layer_k in asp_d.keys():
                v = float(asp_d[layer_k])
                results_by_aspect[asp_k][asp_v][sect_k][layer_k] = v

In [845]:
asp_k = 'text_match'

for asp_v, asp_d in results_by_aspect[asp_k].items():
    print(f'{asp_k} = {asp_v}')
    format_print_2D_dict(asp_d, col_w=10)
    print()

text_match = exact
XXXXXXXXXXXX	low_layers	mid_layers	high_layers	all_layers
all         	0.6579    	0.6245    	0.0514    	0.0160    
all->t      	0.9465    	0.9244    	0.9812    	0.8796    
all->s      	0.9685    	0.9386    	0.4133    	0.3229    
all->p      	0.9769    	0.9584    	0.9874    	0.9565    
all->o      	0.9933    	0.9929    	0.9930    	0.9913    
all->t+o    	0.9449    	0.9254    	0.9798    	0.8760    
all->c      	0.9684    	0.9889    	0.9908    	0.9527    
all->self   	0.9910    	0.9589    	0.4292    	0.4146    


text_match = no-match
XXXXXXXXXXXX	low_layers	mid_layers	high_layers	all_layers
all         	0.3726    	0.6184    	0.0540    	0.0139    
all->t      	0.9630    	0.9462    	0.9749    	0.9221    
all->s      	0.7303    	0.7756    	0.1491    	0.1614    
all->p      	0.8113    	0.9099    	0.9656    	0.7972    
all->o      	0.9759    	0.9774    	0.9764    	0.9729    
all->t+o    	0.9620    	0.9455    	0.9718    	0.9137    
all->c      	0.8764    	0.9808    	0.9897  

### Exp-5.5: both part attention removal

#### Load & Check

In [131]:
expect_type = 'table_alias'

res_path = f'/home/yshao/Projects/rome/results/exp5_5_both_part_attention_removal/exp=5.5.1_dev_{expect_type}-attn_crpt=logits.jsonl'

with open(res_path, 'r') as f:
    all_samples = [json.loads(l) for l in f]
len(all_samples)

1034

In [132]:
total_samples = 0
n_good_samples = 0
n_too_hard = 0      # wrong answer 
n_too_easy = 0      # base - low < 0.5

In [133]:
good_samples = []
bad_samples = []

for i, ex in enumerate(all_samples):
    for d in ex['trace_results']:
        total_samples += 1

        if d['is_good_sample']:
            n_good_samples += 1
            d['ex_id'] = i
            good_samples.append(d)
        elif not d['correct_prediction']:
            n_too_hard += 1
            bad_samples.append(d)
        else:
            assert d['base_score'] - d['low_score'] < 0.5
            n_too_easy += 1
            bad_samples.append(d)
            
total_samples, (n_good_samples, len(good_samples)), n_too_hard, n_too_easy, len(bad_samples), \
f'good / correct = {n_good_samples} / {n_good_samples + n_too_easy}'

(2039, (1631, 1631), 339, 69, 408, 'good / correct = 1631 / 1700')

#### Overall avg

In [134]:
trace_scores_avg = {sect_k : defaultdict(int) for sect_k in good_samples[0]['trace_scores'].keys()}

for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for k, v in sect_d.items():
            if k == 'window':
                for l, s in v.items():
                    trace_scores_avg[sect_k][f'{k}-{l}'] += s
            else:
                s = v
                trace_scores_avg[sect_k][k] += s

for sect_k, sect_d in trace_scores_avg.items():
    for k, s in sect_d.items():
        sect_d[k] = s / len(good_samples)

In [135]:
trace_scores_avg

{'s->t&all->t': defaultdict(int,
             {'E-all&D-all': 0.810587916234049,
              'E-all&D-low': 0.8383651262275511,
              'E-low&D-all': 0.8902601872852249})}

In [136]:
format_print_1D_dict(trace_scores_avg['s->t&all->t'], head_col_w=15)

E-all&D-all    0.8106
E-all&D-low    0.8384
E-low&D-all    0.8903


#### Avg by aspects

In [98]:
# Key: (sect_k, asp_k, layer) -> [scores]
trace_scores_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
trace_scores_avg_by_aspect = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
trace_scores_cnt_by_aspect = defaultdict(int)  # flat: 'aspect=value' -> count; no sect key & layer key

In [99]:
for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for aspect, asp_val in d['category'].items():
            asp_k = f'{aspect}={asp_val}'
            for k, v in sect_d.items():
                layer_k = k
                s = v
                trace_scores_by_aspect[sect_k][asp_k][layer_k].append(s)
                    
for sect_k, d1 in trace_scores_by_aspect.items():
    for asp_k, d2 in d1.items():
        for layer_k, s in d2.items():
            trace_scores_avg_by_aspect[sect_k][asp_k][layer_k] = np.mean(s)
            trace_scores_cnt_by_aspect[asp_k] = len(s)

In [100]:
# for sect_k, sect_d in trace_scores_avg_by_aspect.items():
#     sect_d['overall'] = dict()
#     for layer_k, s in trace_scores_avg[sect_k].items():
#         if layer_k.startswith('window'):
#             # only keep a subset of layers 
#             _, l = layer_k.split('-')
#             if not (int(l) % 4 == 3): continue
#         sect_d['overall'][layer_k] = s

In [101]:
sorted(trace_scores_cnt_by_aspect.items())

[('node_len=1', 578),
 ('node_len=2', 116),
 ('node_len=3', 369),
 ('node_len=4+', 203),
 ('node_role=group by', 82),
 ('node_role=having', 6),
 ('node_role=join', 85),
 ('node_role=order by', 82),
 ('node_role=select', 630),
 ('node_role=where', 381),
 ('sql_hardness=easy', 229),
 ('sql_hardness=extra', 211),
 ('sql_hardness=hard', 240),
 ('sql_hardness=medium', 586),
 ('text_match=exact', 653),
 ('text_match=no-match', 419),
 ('text_match=partial', 194)]

In [102]:
{k: trace_scores_avg_by_aspect['s->t&all->t'][f'sql_hardness={k}'] for k in ['easy', 'medium', 'hard', 'extra']}

{'easy': defaultdict(float,
             {'E-all&D-all': 0.6952432668401013,
              'E-all&D-low': 0.9067031433060689,
              'E-low&D-all': 0.9375170493903875}),
 'medium': defaultdict(float,
             {'E-all&D-all': 0.556048543710707,
              'E-all&D-low': 0.8054229393484426,
              'E-low&D-all': 0.8955880533978087}),
 'hard': defaultdict(float,
             {'E-all&D-all': 0.5484870432557414,
              'E-all&D-low': 0.8360452807140587,
              'E-low&D-all': 0.8290000068771377}),
 'extra': defaultdict(float,
             {'E-all&D-all': 0.6368124696327836,
              'E-all&D-low': 0.8165238064917822,
              'E-low&D-all': 0.8967893739510424})}

In [103]:
{k: trace_scores_avg_by_aspect['s->t&all->t'][f'text_match={k}'] for k in ['exact', 'partial', 'no-match']}

{'exact': defaultdict(float,
             {'E-all&D-all': 0.5947468875597053,
              'E-all&D-low': 0.8776616909043937,
              'E-low&D-all': 0.8787339264019262}),
 'partial': defaultdict(float,
             {'E-all&D-all': 0.49977204983319107,
              'E-all&D-low': 0.8079702937117633,
              'E-low&D-all': 0.8836331785910234}),
 'no-match': defaultdict(float,
             {'E-all&D-all': 0.6342099784024492,
              'E-all&D-low': 0.7701454216605397,
              'E-low&D-all': 0.9127696242686905})}

In [105]:
ob_ids = []
for i, d in enumerate(good_samples):
    if d['category']['text_match'] == 'exact':
        continue
    # here: no exact text match 
    if d['trace_scores']['s->t&all->t']['E-all&D-all'] < 0.5:
        continue
    # here: corrupted pred is correct 
    ob_ids.append(i)

In [109]:
print(len(ob_ids), ob_ids[::10])

360 [9, 77, 96, 116, 138, 150, 177, 200, 228, 255, 282, 297, 375, 386, 439, 487, 530, 584, 604, 635, 673, 755, 849, 874, 892, 911, 926, 962, 986, 1003, 1019, 1039, 1082, 1108, 1156, 1211]


In [110]:
good_samples[9]

{'enc_sentence': 'What is the average, minimum, and maximum age of all singers from France?; structed knowledge: | concert_singer | stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country ( France ) , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id',
 'seq_out': "select avg(age), min(age), max(age) from singer where country = 'France'",
 'dec_prompt': 'select avg(age), min(age), max(age) from singer where',
 'expect': 'country',
 'expect_type': 'column',
 'db_id': 'concert_singer',
 'expect_input_ranges': [[68, 73]],
 'self_ranges': [[66, 75]],
 'expect_table': 'singer',
 'answer': 'country',
 'base_score': 0.9999995231628418,
 'answers_t': [684],
 'correct_prediction': True,
 'category': {'sql_hardness': 'medium',
  'node_role': 'where',
  'text_match': 'no-match',
  'node_len': '1'},
 'corrupted_answers_t': [2306],
 

### Exp-6.0: corruption effect - syntax

#### Load & Check

In [287]:
# expect_type = 'table_alias'

res_path = '/home/yshao/Projects/rome/results/exp6_0_encoding_corruption_effect_syntax/exp=6.0_dev.jsonl'

with open(res_path, 'r') as f:
    all_samples = [json.loads(l) for l in f]
len(all_samples)

1034

In [288]:
total_samples = 0
n_good_samples = 0
n_too_hard = 0      # wrong answer 
n_too_easy = 0      # base - low < 0.5

In [289]:
good_samples = []
bad_samples = []

for i, ex in enumerate(all_samples):
    for d in ex['trace_results']:
        total_samples += 1

        if d.get('is_good_sample', True):
            n_good_samples += 1
            d['ex_id'] = i
            good_samples.append(d)
        elif not d['correct_prediction']:
            n_too_hard += 1
            bad_samples.append(d)
        else:
            assert d['base_score'] - d['low_score'] < 0.5
            n_too_easy += 1
            bad_samples.append(d)
            
total_samples, (n_good_samples, len(good_samples)), n_too_hard, n_too_easy, len(bad_samples), \
f'good / correct = {n_good_samples} / {n_good_samples + n_too_easy}'

(10233, (2261, 2261), 1623, 6349, 7972, 'good / correct = 2261 / 8610')

#### Overall avg

In [290]:
trace_scores_avg = {sect_k : defaultdict(int) for sect_k in good_samples[0]['trace_scores'].keys()}

for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for k, v in sect_d.items():
            trace_scores_avg[sect_k][k] += v

for sect_k, sect_d in trace_scores_avg.items():
    for k, s in sect_d.items():
        sect_d[k] = s / len(good_samples)

In [291]:
trace_scores_avg

{'text': defaultdict(int,
             {'embed': 0.2704069729491695, 'final_enc': 0.43288231847139375}),
 'struct': defaultdict(int,
             {'embed': 0.8434888537914244, 'final_enc': 0.7056294551667914}),
 'columns': defaultdict(int,
             {'embed': 0.8977011346416999, 'final_enc': 0.9094602057515592}),
 'tables': defaultdict(int,
             {'embed': 0.9401333304200568, 'final_enc': 0.9652279544981762}),
 'all': defaultdict(int,
             {'embed': 0.04223158108438767, 'final_enc': 0.14583704401879738})}

In [292]:
format_print_2D_dict(trace_scores_avg)

XXXXXXXXXXXX	embed 	final_enc
text        	0.2704	0.4329
struct      	0.8435	0.7056
columns     	0.8977	0.9095
tables      	0.9401	0.9652
all         	0.0422	0.1458



#### Corruption overall effect

In [293]:
# Dict[str, int]: expect_tok -> num of effective / not effective corruptions (all)
eff_counter = Counter()
neff_counter = Counter()

for d in good_samples:
    eff_counter[d['expect']] += 1
for d in bad_samples:
    if d['correct_prediction']:
        # "too easy", corruption not effective 
        neff_counter[d['expect']] += 1

In [294]:
eff_rate_d = dict()

for k in list(set(eff_counter.keys()) | set(neff_counter.keys())):
    eff_c = eff_counter[k]
    neff_c = neff_counter[k]
    eff_r = 1.0 * eff_c / (eff_c + neff_c)
    eff_rate_d[k] = eff_r
    # print(f'{k:<10s}{eff_c:5d} /{eff_c + neff_c:5d} = {eff_r:.4f}')

In [295]:
for k, eff_r in sorted(eff_rate_d.items(), key=lambda x: x[1], reverse=True):
    eff_c = eff_counter[k]
    neff_c = neff_counter[k]
    all_c = eff_c + neff_c
    if all_c <= 2: continue
    if k.isnumeric(): continue
#     print(f'{k:<10s}{eff_c:5d} /{all_c:5d} = {eff_r:.4f}')
    print(f'{k}\t{eff_c}\t{all_c}\t{eff_r:.4f}')

!=	20	20	1.0000
or	34	34	1.0000
between	6	6	1.0000
intersect	34	34	1.0000
union	6	6	1.0000
asc	19	19	1.0000
min	18	18	1.0000
avg	65	65	1.0000
max	30	30	1.0000
except	21	21	1.0000
like	12	12	1.0000
having	80	81	0.9877
distinct	25	26	0.9615
sum	21	22	0.9545
where	484	516	0.9380
not	42	46	0.9130
group	225	265	0.8491
and	31	39	0.7949
count	267	406	0.6576
order	142	221	0.6425
>	61	101	0.6040
)	11	23	0.4783
=	191	968	0.1973
as	93	952	0.0977
desc	16	164	0.0976
in	4	50	0.0800
join	39	496	0.0786
from	49	1196	0.0410
limit	6	177	0.0339
>=	1	30	0.0333
(	22	675	0.0326
select	0	88	0.0000
on	0	516	0.0000
by	0	516	0.0000
*	0	381	0.0000
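
The table above lists, per expected token, the number of effective corruptions, the number of correctly predicted samples, and their ratio. The three cells that produce it can be condensed into one function. This is a sketch only: `corruption_effect_rates` is a hypothetical name, and `min_count=3` reproduces the `all_c <= 2` filter used above.

```python
from collections import Counter

def corruption_effect_rates(good_samples, bad_samples, min_count=3):
    """Hypothetical helper: expect_tok -> (effective, all_correct, rate), sorted by rate."""
    eff = Counter(d['expect'] for d in good_samples)
    # Among bad samples, only the correctly predicted ones count as "too easy".
    neff = Counter(d['expect'] for d in bad_samples if d['correct_prediction'])
    rates = {}
    for tok in set(eff) | set(neff):
        total = eff[tok] + neff[tok]
        if total < min_count or tok.isnumeric():
            continue  # skip rare tokens and numeric literals, as in the cell above
        rates[tok] = (eff[tok], total, eff[tok] / total)
    return dict(sorted(rates.items(), key=lambda kv: kv[1][2], reverse=True))
```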


#### Avg by expect syntax token

In [296]:
# Key: (expect_tok, sect_k, layer) -> [scores]
trace_scores_by_exp_tok = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
trace_scores_avg_by_exp_tok = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
trace_scores_cnt_by_exp_tok = defaultdict(int)  # no sect key & layer key 

trace_sample_ids_by_exp_tok = defaultdict(list)

In [297]:
for i, d in enumerate(good_samples):
    expect = d['expect']
    trace_sample_ids_by_exp_tok[expect].append(i)
    for sect_k, sect_d in d['trace_scores'].items():
        for layer_k, v in sect_d.items():
            trace_scores_by_exp_tok[expect][sect_k][layer_k].append(v)

for exp_tok, d1 in trace_scores_by_exp_tok.items():
    if exp_tok.isnumeric(): continue
    for sect_k, d2 in d1.items():
        for layer_k, scores in d2.items():
            if len(scores) <= 2: continue
            trace_scores_avg_by_exp_tok[exp_tok][sect_k][layer_k] = np.mean(scores)
            trace_scores_cnt_by_exp_tok[exp_tok] = len(scores)

In [298]:
trace_scores_cnt_by_exp_tok

defaultdict(int,
            {'count': 267,
             'order': 142,
             'avg': 65,
             'min': 18,
             'max': 30,
             'where': 484,
             'distinct': 25,
             '>': 61,
             'group': 225,
             '(': 22,
             'between': 6,
             'from': 49,
             'desc': 16,
             'or': 34,
             'not': 42,
             'intersect': 34,
             'except': 21,
             'as': 93,
             'join': 39,
             'like': 12,
             'and': 31,
             '=': 191,
             'having': 80,
             '!=': 20,
             'union': 6,
             'limit': 6,
             'sum': 21,
             'asc': 19,
             ')': 11,
             'in': 4})

In [299]:
trace_scores_avg_by_exp_tok['count']

defaultdict(<function __main__.<lambda>.<locals>.<lambda>()>,
            {'text': defaultdict(float,
                         {'embed': 0.07181598544092065,
                          'final_enc': 0.08561127370659098}),
             'struct': defaultdict(float,
                         {'embed': 0.9997616432579269,
                          'final_enc': 0.9906901482785686}),
             'columns': defaultdict(float,
                         {'embed': 0.9989915059300397,
                          'final_enc': 0.9982228384035804}),
             'tables': defaultdict(float,
                         {'embed': 0.9940267473124387,
                          'final_enc': 0.9992919242783879}),
             'all': defaultdict(float,
                         {'embed': 0.022402080392785025,
                          'final_enc': 0.07584266595464821})})

In [300]:
sect_k = 'struct'
layer_k = 'final_enc'
scores_d = dict()

for exp_tok, d1 in trace_scores_avg_by_exp_tok.items():
    s = d1[sect_k][layer_k]
    scores_d[exp_tok] = s
    # print(f'{exp_tok:<10s}{s:.4f}')

In [301]:
for k, s in sorted(scores_d.items(), key=lambda x: x[1], reverse=True):
    print(f'{k:<10s}{s:.4f}')

min       0.9988
sum       0.9949
distinct  0.9942
count     0.9907
avg       0.9820
like      0.9440
>         0.9256
max       0.9167
between   0.9151
having    0.8882
and       0.8846
or        0.8764
intersect 0.8033
union     0.7311
!=        0.7206
asc       0.7038
(         0.7033
order     0.6748
where     0.6481
limit     0.6085
except    0.6078
desc      0.5801
=         0.5578
from      0.5052
group     0.4578
)         0.4223
not       0.4213
as        0.3979
in        0.2511
join      0.1060


In [268]:
sect_k = 'text'
layer_k = 'embed'
scores_d = dict()

for exp_tok, d1 in trace_scores_avg_by_exp_tok.items():
    s = d1[sect_k][layer_k]
    scores_d[exp_tok] = s
    # print(f'{exp_tok:<10s}{s:.4f}')

for k, s in sorted(scores_d.items(), key=lambda x: x[1], reverse=True):
    print(f'{k:<10s}{s:.4f}')

from      0.8645
as        0.8363
=         0.8272
in        0.7407
)         0.7101
desc      0.4644
where     0.3572
>         0.3472
join      0.3131
not       0.2707
and       0.2300
having    0.1737
min       0.1698
(         0.1657
like      0.1639
union     0.1435
order     0.1126
group     0.1046
asc       0.0877
count     0.0718
avg       0.0311
or        0.0211
except    0.0106
distinct  0.0097
max       0.0073
sum       0.0017
intersect 0.0005
!=        0.0000
between   0.0000
limit     0.0000


#### Corrupted answer

In [279]:
# Dict[str, Dict[str, int]]: exp_tok -> c_ans, cnt
confusion_counter = defaultdict(Counter)

for d in good_samples:
    exp_tok = d['expect']
    c_ans = d['corrupted_answer']
    if exp_tok.isnumeric():
        exp_tok = 'NUM'
    if c_ans.isnumeric():
        c_ans = 'NUM'
    confusion_counter[exp_tok][c_ans] += 1

In [280]:
confusion_counter

defaultdict(collections.Counter,
            {'count': Counter({'*': 182, '': 64, 'sum': 19, 'count': 2}),
             'order': Counter({'</s>': 96,
                      'join': 12,
                      ')': 2,
                      'where': 9,
                      'NUM': 5,
                      'order': 3,
                      'union': 2,
                      'and': 2,
                      'select': 4,
                      's': 6,
                      '_': 1}),
             'avg': Counter({'*tvg': 43,
                      'maxtvg': 8,
                      'maxavg': 2,
                      'tvg': 6,
                      'mintvg': 2,
                      'counttvg': 4}),
             'min': Counter({'': 9, 'max': 5, '*': 4}),
             'max': Counter({'': 13, '*': 16, 'min': 1}),
             'where': Counter({'group': 19,
                      'except': 2,
                      '</s>': 239,
                      'order': 114,
                      'where': 9,
        

#### Case study
- For now, this covers only the low-score setting (all-layer, embed).

In [312]:
for idx in trace_sample_ids_by_exp_tok['max']:
    d = good_samples[idx]
    _text = d['enc_sentence'].split(';')[0]
    print(f"{_text}\n{d['dec_prompt']} --> {d['corrupted_answer']} ({d['expect']})")
    print()

What is the average, minimum, and maximum age of all singers from France?
select avg(age), min(age), -->  (max)

What is the average, minimum, and maximum age for all French singers?
select avg(age), min(age), -->  (max)

What is the maximum capacity and the average of all stadiums ?
select --> * (max)

What is the average and maximum capacities for all stadiums ?
select avg(capacity), -->  (max)

Find the maximum weight for each type of pet. List the maximum weight and pet type.
select --> * (max)

List the maximum weight and type for each type of pet.
select --> * (max)

Find the average and maximum age for each type of pet.
select avg(pet_age), -->  (max)

What is the average and maximum age for each pet type?
select avg(pet_age), -->  (max)

What is the maximum accelerate for different number of cylinders?
select --> * (max)

What is the maximum accelerate for all the different cylinders?
select --> * (max)

What is the maximum miles per gallon of the car with 8 cylinders or produc

### Exp-6.1: attention corruption effect - syntax

#### Load & Check

In [180]:
# expect_type = 'table_alias'

res_path = '/home/yshao/Projects/rome/results/exp6_1_attention_corruption_effect_syntax/exp=6.1.1_dev-attn_crpt=logits.jsonl'

with open(res_path, 'r') as f:
    all_samples = [json.loads(l) for l in f]
len(all_samples)

1034

In [181]:
total_samples = 0
n_good_samples = 0
n_too_hard = 0      # clean prediction is already wrong
n_too_easy = 0      # prediction is correct but corruption is not effective (base - low < 0.5)

In [182]:
good_samples = []
bad_samples = []

for i, ex in enumerate(all_samples):
    for d in ex['trace_results']:
        total_samples += 1

        if d.get('is_good_sample', True):
            n_good_samples += 1
            d['ex_id'] = i
            good_samples.append(d)
        elif not d['correct_prediction']:
            n_too_hard += 1
            bad_samples.append(d)
        else:
            assert d['base_score'] - d['low_score'] < 0.5
            n_too_easy += 1
            bad_samples.append(d)
            
total_samples, (n_good_samples, len(good_samples)), n_too_hard, n_too_easy, len(bad_samples), \
f'good / correct = {n_good_samples} / {n_good_samples + n_too_easy}'

(10233, (2375, 2375), 1623, 6235, 7858, 'good / correct = 2375 / 8610')
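
The same load-and-split cell is repeated for each experiment in this notebook. As a minimal sketch, the split could be factored into one helper. The function name `split_trace_results` and the `min_drop` parameter are hypothetical (not part of `experiments.causal_trace`), and unlike the cell above this version copies each dict instead of writing `ex_id` into it in place.

```python
from typing import Any, Dict, List, Tuple


def split_trace_results(
    all_samples: List[Dict[str, Any]],
    min_drop: float = 0.5,  # hypothetical name for the "base - low < 0.5" threshold
) -> Tuple[List[dict], List[dict], Dict[str, int]]:
    """Split per-token trace results into good and bad samples."""
    good, bad = [], []
    counts = {"total": 0, "good": 0, "too_hard": 0, "too_easy": 0}
    for ex_id, ex in enumerate(all_samples):
        for d in ex["trace_results"]:
            counts["total"] += 1
            if d.get("is_good_sample", True):
                # As in the notebook cell, a missing flag counts as good.
                good.append(dict(d, ex_id=ex_id))
                counts["good"] += 1
            elif not d["correct_prediction"]:
                # "Too hard": the clean prediction is already wrong.
                counts["too_hard"] += 1
                bad.append(d)
            else:
                # "Too easy": correct, but corruption barely lowers the score.
                assert d["base_score"] - d["low_score"] < min_drop
                counts["too_easy"] += 1
                bad.append(d)
    return good, bad, counts


# Toy records that only mimic the fields read above.
toy_samples = [
    {"trace_results": [
        {"is_good_sample": True, "correct_prediction": True,
         "base_score": 0.9, "low_score": 0.1},
        {"is_good_sample": False, "correct_prediction": False},
    ]},
    {"trace_results": [
        {"is_good_sample": False, "correct_prediction": True,
         "base_score": 0.9, "low_score": 0.8},
    ]},
]
toy_good, toy_bad, toy_counts = split_trace_results(toy_samples)
print(toy_counts)
```

Returning the counts in a dict avoids the module-level counters, which would otherwise keep accumulating if a cell is re-run.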

In [183]:
# good_samples[0]['trace_scores']

#### Overall avg

In [184]:
trace_scores_avg = {sect_k : defaultdict(int) for sect_k in good_samples[0]['trace_scores'].keys()}

for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for k, v in sect_d.items():
            trace_scores_avg[sect_k][k] += v

for sect_k, sect_d in trace_scores_avg.items():
    for k, s in sect_d.items():
        sect_d[k] = s / len(good_samples)

In [185]:
trace_scores_avg

{'t->s': defaultdict(int,
             {'low_layers': 0.9873034850396216,
              'high_layers': 0.9794485656297967,
              'all_layers': 0.9705235183159128}),
 's->t': defaultdict(int,
             {'low_layers': 0.9914981258135093,
              'high_layers': 0.9701850305394514,
              'all_layers': 0.9438354478774255}),
 't<->s': defaultdict(int,
             {'low_layers': 0.9846226840465281,
              'high_layers': 0.9617710263219753,
              'all_layers': 0.9230146904192555}),
 't->p': defaultdict(int,
             {'low_layers': 0.9806906042866831,
              'high_layers': 0.889322833697198,
              'all_layers': 0.8364523902912913}),
 's->p': defaultdict(int,
             {'low_layers': 0.9860267275776025,
              'high_layers': 0.9480603549151314,
              'all_layers': 0.9385463390567251}),
 'ts->p': defaultdict(int,
             {'low_layers': 0.9725087593267217,
              'high_layers': 0.7826832817374065,
           

In [186]:
layers_keys = ['low_layers', 'high_layers', 'all_layers']

for sect_k, sect_d in trace_scores_avg.items():
    print_l = f'{sect_k:<8s}'
    for k in layers_keys:
        s = sect_d[k]
        print_l += f'\t{s:.4f}'
    print(print_l)

t->s    	0.9873	0.9794	0.9705
s->t    	0.9915	0.9702	0.9438
t<->s   	0.9846	0.9618	0.9230
t->p    	0.9807	0.8893	0.8365
s->p    	0.9860	0.9481	0.9385
ts->p   	0.9725	0.7827	0.6456
t->t    	0.9536	0.9105	0.7346
s->s    	0.8818	0.9009	0.8486
all     	0.6233	0.2327	0.0349
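
The row-by-row printing above can be generalized. The following is a rough sketch of a table formatter for a nested `section -> layer -> score` dict. It is not the notebook's `format_print_2D_dict`; the name `format_2d_table` and its parameters are assumptions made for illustration.

```python
from typing import Dict, List, Optional


def format_2d_table(
    table: Dict[str, Dict[str, float]],
    cols: Optional[List[str]] = None,
    head_w: int = 8,
    col_w: int = 8,
) -> str:
    """Render a two-level dict as a tab-separated text table."""
    rows = list(table.keys())
    if cols is None:
        # Default to the column order of the first row.
        cols = list(table[rows[0]].keys()) if rows else []
    header = [" " * head_w] + [f"{c:<{col_w}s}" for c in cols]
    lines = ["\t".join(header)]
    for r in rows:
        cells = [f"{table[r][c]:<{col_w}.4f}" for c in cols]
        lines.append("\t".join([f"{r:<{head_w}s}"] + cells))
    return "\n".join(lines)


# Two rows copied (rounded) from the output above.
demo_table = {
    "t->s": {"low_layers": 0.9873, "high_layers": 0.9794, "all_layers": 0.9705},
    "all": {"low_layers": 0.6233, "high_layers": 0.2327, "all_layers": 0.0349},
}
demo_text = format_2d_table(demo_table)
print(demo_text)
```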


#### Corruption overall effect

In [301]:
# Dict[str, int]: expect_tok -> num of effective / not effective corruptions (all)
eff_counter = Counter()
neff_counter = Counter()

for d in good_samples:
    eff_counter[d['expect']] += 1
for d in bad_samples:
    if d['correct_prediction']:
        # "too easy", corruption not effective 
        neff_counter[d['expect']] += 1

In [302]:
eff_rate_d = dict()

for k in list(set(eff_counter.keys()) | set(neff_counter.keys())):
    eff_c = eff_counter[k]
    neff_c = neff_counter[k]
    eff_r = 1.0 * eff_c / (eff_c + neff_c)
    eff_rate_d[k] = eff_r
    # print(f'{k:<10s}{eff_c:5d} /{eff_c + neff_c:5d} = {eff_r:.4f}')

In [303]:
for k, eff_r in sorted(eff_rate_d.items(), key=lambda x: x[1], reverse=True):
    eff_c = eff_counter[k]
    neff_c = neff_counter[k]
    all_c = eff_c + neff_c
    if all_c <= 2: continue
    if k.isnumeric(): continue
#     print(f'{k:<10s}{eff_c:5d} /{all_c:5d} = {eff_r:.4f}')
    print(f'{k}\t{eff_c}\t{all_c}\t{eff_r:.4f}')

union	6	6	1.0000
!=	20	20	1.0000
like	12	12	1.0000
or	34	34	1.0000
asc	19	19	1.0000
distinct	26	26	1.0000
between	6	6	1.0000
except	21	21	1.0000
intersect	34	34	1.0000
not	45	46	0.9783
avg	63	65	0.9692
max	29	30	0.9667
having	77	81	0.9506
group	241	265	0.9094
sum	20	22	0.9091
order	197	221	0.8914
min	14	18	0.7778
and	29	39	0.7436
count	294	406	0.7241
where	350	516	0.6783
>	68	101	0.6733
)	13	23	0.5652
desc	82	164	0.5000
>=	11	30	0.3667
from	263	1196	0.2199
in	8	50	0.1600
limit	26	177	0.1469
(	98	675	0.1452
=	128	968	0.1322
join	44	496	0.0887
as	52	952	0.0546
*	10	381	0.0262
by	0	516	0.0000
on	0	516	0.0000
select	0	88	0.0000
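
The three cells above compute, per expected token, the share of correctly predicted samples for which corruption was effective. A compact sketch of the same computation follows. The helper name `effectiveness_rates` and the `min_count` parameter are hypothetical; the filtering mirrors the `all_c <= 2` and `isnumeric()` checks above.

```python
from collections import Counter
from typing import Dict, List, Tuple


def effectiveness_rates(
    good_samples: List[dict],
    bad_samples: List[dict],
    min_count: int = 3,
) -> Dict[str, Tuple[int, int, float]]:
    """Map expected token -> (effective count, total count, effective rate)."""
    eff = Counter(d["expect"] for d in good_samples)
    # Only correctly predicted bad samples count as "corruption not effective".
    neff = Counter(d["expect"] for d in bad_samples if d["correct_prediction"])
    rates = {}
    for tok in set(eff) | set(neff):
        total = eff[tok] + neff[tok]
        if total < min_count or tok.isnumeric():
            continue  # skip rare tokens and numeric literals
        rates[tok] = (eff[tok], total, eff[tok] / total)
    return rates


# Toy data: "where" is effective in 2 of 3 correctly predicted samples.
toy_good_eff = [{"expect": "where"}, {"expect": "where"},
                {"expect": "12"}, {"expect": "12"}, {"expect": "12"}]
toy_bad_eff = [
    {"expect": "where", "correct_prediction": True},
    {"expect": "where", "correct_prediction": False},
    {"expect": "as", "correct_prediction": True},
]
toy_rates = effectiveness_rates(toy_good_eff, toy_bad_eff)
for tok, (e, n, r) in sorted(toy_rates.items(), key=lambda x: x[1][2], reverse=True):
    print(f"{tok}\t{e}\t{n}\t{r:.4f}")
```

Every token in `rates` appears in at least one counter, so `total` is never zero and the division is safe.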


#### Avg by expect syntax token

In [289]:
# Key: (expect_tok, sect_k, layer) -> [scores]
trace_scores_by_exp_tok = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
trace_scores_avg_by_exp_tok = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
trace_scores_cnt_by_exp_tok = defaultdict(int)  # no sect key & layer key 

trace_sample_ids_by_exp_tok = defaultdict(list)

In [290]:
for i, d in enumerate(good_samples):
    expect = d['expect']
    trace_sample_ids_by_exp_tok[expect].append(i)
    for sect_k, sect_d in d['trace_scores'].items():
        for layer_k, v in sect_d.items():
            trace_scores_by_exp_tok[expect][sect_k][layer_k].append(v)

for exp_tok, d1 in trace_scores_by_exp_tok.items():
    if exp_tok.isnumeric(): continue
    for sect_k, d2 in d1.items():
        for layer_k, scores in d2.items():
            if len(scores) <= 2: continue
            trace_scores_avg_by_exp_tok[exp_tok][sect_k][layer_k] = np.mean(scores)
            trace_scores_cnt_by_exp_tok[exp_tok] = len(scores)

In [291]:
trace_scores_cnt_by_exp_tok

defaultdict(int,
            {'count': 294,
             'order': 197,
             'desc': 82,
             'avg': 63,
             'min': 14,
             'max': 29,
             'where': 350,
             '=': 128,
             'distinct': 26,
             'from': 263,
             '>': 68,
             'group': 241,
             '(': 98,
             'between': 6,
             'limit': 26,
             'or': 34,
             '>=': 11,
             'not': 45,
             'intersect': 34,
             'except': 21,
             'as': 52,
             'join': 44,
             'like': 12,
             'and': 29,
             'in': 8,
             'having': 77,
             '!=': 20,
             '*': 10,
             'union': 6,
             'sum': 20,
             'asc': 19,
             ')': 13})

In [292]:
trace_scores_avg_by_exp_tok['count']

defaultdict(<function __main__.<lambda>.<locals>.<lambda>()>,
            {'t->s': defaultdict(float,
                         {'low_layers': 0.9978861685107354,
                          'high_layers': 0.9932368714425738,
                          'all_layers': 0.9911034725571596}),
             's->t': defaultdict(float,
                         {'low_layers': 0.9994523721892817,
                          'high_layers': 0.9979493310865091,
                          'all_layers': 0.997748848329596}),
             't<->s': defaultdict(float,
                         {'low_layers': 0.998062480874613,
                          'high_layers': 0.995554579890707,
                          'all_layers': 0.9959644333643167}),
             't->p': defaultdict(float,
                         {'low_layers': 0.9941589293855347,
                          'high_layers': 0.9844484830047099,
                          'all_layers': 0.9640641826280487}),
             's->p': defaultdict(float,
        

In [296]:
sect_k = 't<->s'
layer_k = 'all_layers'
scores_d = dict()

for exp_tok, d1 in trace_scores_avg_by_exp_tok.items():
    s = d1[sect_k][layer_k]
    scores_d[exp_tok] = s
    # print(f'{exp_tok:<10s}{s:.4f}')

In [297]:
for k, s in sorted(scores_d.items(), key=lambda x: x[1], reverse=True):
    print(f'{k:<10s}{s:.4f}')

>=        1.0000
>         1.0000
like      1.0000
*         1.0000
min       1.0000
between   1.0000
in        0.9992
avg       0.9972
count     0.9960
(         0.9946
desc      0.9942
having    0.9894
order     0.9826
!=        0.9825
=         0.9786
group     0.9771
max       0.9741
)         0.9716
where     0.9702
limit     0.9684
from      0.9671
or        0.9582
intersect 0.9522
asc       0.9437
except    0.9415
not       0.9360
sum       0.8969
as        0.8349
distinct  0.8284
union     0.8037
and       0.6727
join      0.4976


In [308]:
all_exp_toks = sorted(list(trace_scores_cnt_by_exp_tok.keys()))
all_sections = list(good_samples[0]['trace_scores'].keys())

print_str = '\t'.join(['Syntax-tok'] + all_sections + ['Eff_cnt', 'All_cnt', 'Eff_rate']) + '\n'

for exp_tok in all_exp_toks:
    print_str += f'{exp_tok:<10s}'
    for sect_k in all_sections:
        s = trace_scores_avg_by_exp_tok[exp_tok][sect_k]['all_layers']
        print_str += f'\t{s:.4f}'
    eff_c = eff_counter[exp_tok]
    neff_c = neff_counter[exp_tok]
    all_c = eff_c + neff_c
    eff_r = eff_rate_d[exp_tok]
    print_str += f'\t{eff_c:<7d}\t{all_c:<7d}\t{eff_r:.4f}'
    print_str += '\n'
    

In [309]:
print(print_str)

Syntax-tok	t->s	s->t	t<->s	t->p	s->p	ts->p	t->t	s->s	all	Eff_cnt	All_cnt	Eff_rate
!=        	0.9984	0.9995	0.9825	0.8485	0.9998	0.9070	0.3025	0.9193	0.0284	20     	20     	1.0000
(         	0.9999	0.9794	0.9946	0.9719	0.9795	0.9685	0.8983	0.9467	0.1856	98     	675    	0.1452
)         	0.9257	0.9926	0.9716	0.9914	0.9074	0.8420	0.9244	0.4715	0.1108	13     	23     	0.5652
*         	1.0000	0.9999	1.0000	0.9997	0.9999	0.9998	0.9999	0.9982	0.1802	10     	381    	0.0262
=         	0.9784	0.9923	0.9786	0.9378	0.8940	0.7439	0.9466	0.5933	0.1149	128    	968    	0.1322
>         	1.0000	1.0000	1.0000	0.9509	0.9990	0.9652	0.8193	0.9845	0.0219	68     	101    	0.6733
>=        	1.0000	1.0000	1.0000	0.9151	1.0000	0.9048	0.2593	0.9964	0.1703	11     	30     	0.3667
and       	0.7931	0.8981	0.6727	0.4962	0.8409	0.3468	0.3603	0.6749	0.0554	29     	39     	0.7436
as        	0.9212	0.8936	0.8349	0.7992	0.5034	0.3457	0.8478	0.3288	0.1002	52     	952    	0.0546
asc       	0.9723	0.9096	0.9437	0.4939	0.8596

In [310]:
# Check for a potential issue: some expected tokens are split into multiple sub-tokens
for exp_tok in all_exp_toks:
    print(mt_uskg.tokenizer.tokenize(exp_tok))

['▁', '!', '=']
['▁(']
['▁', ')']
['▁*']
['▁=']
['▁>']
['▁>', '=']
['▁and']
['▁as']
['▁as', 'c']
['▁', 'a', 'v', 'g']
['▁between']
['▁count']
['▁des', 'c']
['▁distinct']
['▁except']
['▁from']
['▁group']
['▁having']
['▁in']
['▁intersect']
['▁join']
['▁like']
['▁limit']
['▁max']
['▁min']
['▁not']
['▁or']
['▁order']
['▁sum']
['▁union']
['▁where']
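
The listing above shows that several expected tokens (for example `!=`, `asc`, `avg`, `desc`) are split into more than one sub-token. A small sketch for collecting those cases programmatically is given below. The helper `find_multi_token` is hypothetical, and the dictionary-backed tokenizer is only a stand-in built from the printed output; in the notebook one would presumably pass `mt_uskg.tokenizer.tokenize` instead.

```python
from typing import Callable, Dict, Iterable, List


def find_multi_token(
    exp_toks: Iterable[str],
    tokenize: Callable[[str], List[str]],
) -> Dict[str, List[str]]:
    """Return the expected tokens that map to more than one sub-token."""
    multi = {}
    for tok in exp_toks:
        pieces = tokenize(tok)
        if len(pieces) > 1:
            multi[tok] = pieces
    return multi


# Stand-in tokenizer: a lookup table copied from the output above.
toy_pieces = {
    "!=": ["▁", "!", "="],
    "(": ["▁("],
    "asc": ["▁as", "c"],
    "where": ["▁where"],
}
multi_toks = find_multi_token(toy_pieces.keys(), toy_pieces.__getitem__)
print(multi_toks)
```

Note that this counts a leading `▁` piece as a sub-token of its own; whether such cases matter for the analysis is left open here.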


### Exp-6.2: decoder cross-attention corruption effect - syntax

#### Load & Check

In [187]:
# expect_type = 'table_alias'

res_path = '/home/yshao/Projects/rome/results/exp6_2_decoder_cross_attention_corruption_syntax/exp=6.2.1_dev-attn_crpt=logits.jsonl'

with open(res_path, 'r') as f:
    all_samples = [json.loads(l) for l in f]
len(all_samples)

1034

In [188]:
total_samples = 0
n_good_samples = 0
n_too_hard = 0      # clean prediction is already wrong
n_too_easy = 0      # prediction is correct but corruption is not effective (base - low < 0.5)

In [189]:
good_samples = []
bad_samples = []

for i, ex in enumerate(all_samples):
    for d in ex['trace_results']:
        total_samples += 1

        if d.get('is_good_sample', True):
            n_good_samples += 1
            d['ex_id'] = i
            good_samples.append(d)
        elif not d['correct_prediction']:
            n_too_hard += 1
            bad_samples.append(d)
        else:
            assert d['base_score'] - d['low_score'] < 0.5
            n_too_easy += 1
            bad_samples.append(d)
            
total_samples, (n_good_samples, len(good_samples)), n_too_hard, n_too_easy, len(bad_samples), \
f'good / correct = {n_good_samples} / {n_good_samples + n_too_easy}'

(10233, (5731, 5731), 1623, 2879, 4502, 'good / correct = 5731 / 8610')

In [190]:
# good_samples[0]['trace_scores']

#### Overall avg

In [191]:
trace_scores_avg = {sect_k : defaultdict(int) for sect_k in good_samples[0]['trace_scores'].keys()}

for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for k, v in sect_d.items():
            trace_scores_avg[sect_k][k] += v

for sect_k, sect_d in trace_scores_avg.items():
    for k, s in sect_d.items():
        sect_d[k] = s / len(good_samples)

In [192]:
trace_scores_avg

{'all': defaultdict(int,
             {'q1_layers': 0.8251662824628329,
              'q2_layers': 0.9568811011561555,
              'q3_layers': 0.9585451913219025,
              'q4_layers': 0.9586033384734522,
              'low_layers': 0.5032536523294612,
              'mid_layers': 0.7412946459743097,
              'high_layers': 0.7327423463639,
              'all_layers': 0.057349480442763175}),
 'ans->t': defaultdict(int,
             {'q1_layers': 0.9954168156351754,
              'q2_layers': 0.9905822788459185,
              'q3_layers': 0.9671664178380461,
              'q4_layers': 0.9757743479523219,
              'low_layers': 0.9790280398693059,
              'mid_layers': 0.8837345000447828,
              'high_layers': 0.8696248513307187,
              'all_layers': 0.7785073726874928}),
 'all->t': defaultdict(int,
             {'q1_layers': 0.9946716999747656,
              'q2_layers': 0.9847246179079293,
              'q3_layers': 0.9633879014188113,
             

In [193]:
# layers_keys = trace_scores_avg['all'].keys()

# for sect_k, sect_d in trace_scores_avg.items():
#     print_l = f'{sect_k:<8s}'
#     for k in layers_keys:
#         s = sect_d[k]
#         print_l += f'\t{s:.4f}'
#     print(print_l)

format_print_2D_dict(trace_scores_avg, head_col_w=7, col_w=11)

XXXXXXX	q1_layers  	q2_layers  	q3_layers  	q4_layers  	low_layers 	mid_layers 	high_layers	all_layers 
all    	0.8252     	0.9569     	0.9585     	0.9586     	0.5033     	0.7413     	0.7327     	0.0573     
ans->t 	0.9954     	0.9906     	0.9672     	0.9758     	0.9790     	0.8837     	0.8696     	0.7785     
all->t 	0.9947     	0.9847     	0.9634     	0.9758     	0.9655     	0.8466     	0.8653     	0.7453     
ans->s 	0.9868     	0.9901     	0.9947     	0.9962     	0.9646     	0.9871     	0.9943     	0.9606     
all->s 	0.9496     	0.9846     	0.9946     	0.9961     	0.8974     	0.9810     	0.9944     	0.8988     
ans->p 	0.9946     	0.9927     	0.9532     	0.9777     	0.9876     	0.9266     	0.8739     	0.8133     
all->p 	0.9803     	0.9783     	0.9428     	0.9773     	0.9316     	0.8687     	0.8454     	0.6509     
ans->o 	0.9975     	0.9975     	0.9976     	0.9976     	0.9974     	0.9974     	0.9976     	0.9973     
all->o 	0.9974     	0.9976     	0.9976     	0.9976     	0.9973  

#### Avg by expect syntax token

In [379]:
res_by_exp_tok = exp6_ob_by_exp_tok(good_samples)

In [380]:
trace_scores_cnt_by_exp_tok = res_by_exp_tok['cnt']
trace_scores_cnt_by_exp_tok

defaultdict(int,
            {'count': 394,
             '(': 407,
             'order': 219,
             'avg': 63,
             'min': 18,
             'max': 28,
             'where': 502,
             'distinct': 26,
             'from': 702,
             '>': 65,
             'group': 265,
             'by': 33,
             'between': 6,
             'desc': 156,
             'or': 34,
             'as': 561,
             'on': 474,
             '>=': 3,
             'not': 46,
             'intersect': 34,
             'select': 44,
             'except': 21,
             '=': 409,
             'join': 20,
             'like': 10,
             'in': 15,
             'having': 62,
             'and': 6,
             'limit': 16,
             '!=': 20,
             '*': 10,
             'union': 6,
             'sum': 22,
             'asc': 19,
             ')': 17})

In [381]:
trace_scores_avg_by_exp_tok = res_by_exp_tok['avg']
trace_scores_avg_by_exp_tok['count']

defaultdict(<function __main__.exp6_ob_by_exp_tok.<locals>.<lambda>.<locals>.<lambda>()>,
            {'all': defaultdict(float,
                         {'q1_layers': 0.9964363076889575,
                          'q2_layers': 0.9963203995739143,
                          'q3_layers': 0.9551561746639813,
                          'q4_layers': 0.9972184849572061,
                          'low_layers': 0.8215906578539822,
                          'mid_layers': 0.46418598421683044,
                          'high_layers': 0.7123522221803182,
                          'all_layers': 0.014818084030064214}),
             'ans->t': defaultdict(float,
                         {'q1_layers': 0.9994401957480435,
                          'q2_layers': 0.9989948929263855,
                          'q3_layers': 0.9765654405254729,
                          'q4_layers': 0.9960574505255004,
                          'low_layers': 0.9980983920206273,
                          'mid_layers': 0.713294777

In [383]:
sect_k = 'ans->t'
layer_k = 'all_layers'
scores_d = dict()

for exp_tok, d1 in trace_scores_avg_by_exp_tok.items():
    s = d1[sect_k][layer_k]
    scores_d[exp_tok] = s
    # print(f'{exp_tok:<10s}{s:.4f}')

In [388]:
format_print_1D_dict(scores_d, sort_by='value', reverse=True)

select    1.0000
by        1.0000
on        1.0000
as        0.9977
in        0.9977
)         0.9955
from      0.9942
*         0.9826
=         0.9823
(         0.9583
where     0.9210
not       0.8927
join      0.8654
desc      0.8452
having    0.8170
>=        0.6667
group     0.6521
limit     0.6250
like      0.5985
>         0.5678
min       0.5502
except    0.5067
count     0.4403
sum       0.3922
or        0.3801
order     0.3762
union     0.3335
distinct  0.3008
and       0.1670
max       0.1403
asc       0.0839
intersect 0.0266
!=        0.0208
avg       0.0062
between   0.0000


In [399]:
all_exp_toks = sorted(list(trace_scores_cnt_by_exp_tok.keys()))
all_sections = list(good_samples[0]['trace_scores'].keys())

_d = {exp_tok:
          {sect_k: trace_scores_avg_by_exp_tok[exp_tok][sect_k]['all_layers']
           for sect_k in all_sections}
      for exp_tok in all_exp_toks}

format_print_2D_dict(_d, all_k1=all_exp_toks, all_k2=all_sections)

XXXXXXXXXX	all	ans->t	all->t	ans->s	all->s	ans->p	all->p
!=        	0.0013	0.0208	0.0001	0.9509	0.8203	0.7167	0.5739
(         	0.0747	0.9583	0.9549	0.9997	0.9923	0.7309	0.6955
)         	0.0212	0.9955	0.9951	0.8845	0.8465	0.9982	0.9014
*         	0.1240	0.9826	0.9720	0.9978	0.9306	0.7433	0.5726
=         	0.1195	0.9823	0.9761	0.8742	0.8298	0.9526	0.9377
>         	0.0210	0.5678	0.4234	0.9599	0.9359	0.9659	0.9533
>=        	0.1964	0.6667	0.6667	1.0000	1.0000	0.9994	0.8787
and       	0.1812	0.1670	0.1309	0.9885	0.6761	0.9118	0.6287
as        	0.0883	0.9977	0.9969	0.9404	0.8392	0.9844	0.8963
asc       	0.0000	0.0839	0.1381	0.8809	0.8523	0.8530	0.8652
avg       	0.0018	0.0062	0.0049	1.0000	1.0000	0.9441	0.9177
between   	0.0006	0.0000	0.0000	0.9992	0.9991	0.9994	0.9984
by        	0.0757	1.0000	1.0000	1.0000	1.0000	1.0000	1.0000
count     	0.0148	0.4403	0.3509	1.0000	0.9959	0.9574	0.9006
desc      	0.0511	0.8452	0.8698	0.9805	0.9593	0.9229	0.7996
distinct  	0.0005	0.3008	0.2718	0.9965	0.99

#### By aspects

In [401]:
sample_ids_by_aspect = defaultdict(list)

for i, d in enumerate(good_samples):
    for asp_k, asp_v in d['category'].items():
        asp_str_k = f'{asp_k}={asp_v}'
        sample_ids_by_aspect[asp_str_k].append(i)

In [403]:
{(k, len(l)) for k, l in sample_ids_by_aspect.items()}

{('sql_hardness=easy', 689),
 ('sql_hardness=extra', 1402),
 ('sql_hardness=hard', 1004),
 ('sql_hardness=medium', 1843)}

In [418]:
# asp_k -> avg/cnt/sample_ids -> exp_tok -> sect_k -> layer_k -> s
all_res_by_exp_tok = {asp_k : exp6_ob_by_exp_tok(good_samples[i] for i in asp_sample_ids)
                      for asp_k, asp_sample_ids in sample_ids_by_aspect.items()}

# exp_tok -> [sect_k -> [layer_k -> s]]
avg_d = all_res_by_exp_tok['sql_hardness=extra']['avg']

all_exp_toks = sorted(list(avg_d.keys()))
all_sections = list(avg_d[all_exp_toks[0]].keys())

_d = {exp_tok:
          {sect_k: avg_d[exp_tok][sect_k]['all_layers']
           for sect_k in all_sections}
      for exp_tok in all_exp_toks}

In [419]:
format_print_2D_dict(_d, all_k1=all_exp_toks, all_k2=all_sections, col_w=6)

XXXXXXXXXXXX	all   	ans->t	all->t	ans->s	all->s	ans->p	all->p
(           	0.0960	0.9150	0.9145	0.9995	0.9683	0.5015	0.4892
)           	0.0147	0.9977	0.9928	0.9659	0.9458	0.9992	0.9988
*           	0.2047	1.0000	1.0000	0.9963	0.8843	0.9579	0.6780
=           	0.1128	0.9823	0.9769	0.7711	0.7308	0.9014	0.8690
>           	0.0111	0.5996	0.5905	0.9682	0.8337	0.8913	0.8308
as          	0.0735	0.9936	0.9916	0.9374	0.8383	0.9773	0.8630
avg         	0.0001	0.0003	0.0001	1.0000	1.0000	0.9980	0.9884
count       	0.0166	0.7992	0.7450	1.0000	1.0000	0.9679	0.9111
desc        	0.0754	0.9459	0.9655	1.0000	1.0000	0.9224	0.8590
distinct    	0.0005	0.5546	0.5506	0.9998	0.9998	0.8881	0.8270
except      	0.0000	0.2832	0.0007	0.8574	0.6992	0.5941	0.5798
from        	0.0672	0.9886	0.9863	0.9704	0.9614	0.7393	0.6001
group       	0.0083	0.3118	0.1122	0.9870	0.8914	0.8986	0.3441
having      	0.0316	0.8593	0.4232	0.7806	0.9744	0.7360	0.7438
in          	0.1954	1.0000	1.0000	1.0000	0.7843	0.9224	0.7463
intersec

In [446]:
_d = dict()

h_list = ['easy', 'medium', 'hard', 'extra']

for exp_tok in all_exp_toks:
    _d[exp_tok] = {
        h: all_res_by_exp_tok[f'sql_hardness={h}']['avg'][exp_tok]['all->t']['all_layers']
        for h in h_list
    }
    for h in h_list:
        _cnt = all_res_by_exp_tok[f'sql_hardness={h}']['cnt'][exp_tok]
        if _cnt < 3:
            # _d[exp_tok][h] = - 1 - _cnt
            _d[exp_tok][h] = np.nan

In [447]:
format_print_2D_dict(_d, sort_k1_kwargs={}, col_w=8)

XXXXXXXXXXXX	easy    	medium  	hard    	extra   
(           	1.0000  	1.0000  	0.8646  	0.9145  
)           	nan     	nan     	0.9994  	0.9928  
*           	nan     	nan     	nan     	1.0000  
=           	0.9633  	0.9855  	0.9654  	0.9769  
>           	0.4182  	0.4960  	0.2578  	0.5905  
as          	1.0000  	0.9998  	1.0000  	0.9916  
avg         	0.0046  	0.0101  	0.0000  	0.0001  
count       	0.1406  	0.1884  	0.5781  	0.7450  
desc        	0.4946  	0.6938  	0.9553  	0.9655  
distinct    	0.0554  	0.1229  	0.3660  	0.5506  
except      	nan     	nan     	0.0244  	0.0007  
from        	0.9417  	0.9906  	0.9807  	0.9863  
group       	0.5330  	0.7302  	0.3732  	0.1122  
having      	0.6407  	0.7282  	0.2358  	0.4232  
in          	nan     	nan     	0.9526  	1.0000  
intersect   	nan     	nan     	0.0005  	0.0007  
join        	nan     	nan     	0.9998  	0.8096  
limit       	nan     	0.9985  	nan     	0.3333  
min         	nan     	0.1869  	nan     	0.3291  
not         	nan    

In [424]:
# _d = dict()

# for exp_tok in all_exp_toks:
#     _d[exp_tok] = {
#         'easy': all_res_by_exp_tok['sql_hardness=easy']['cnt'][exp_tok],
#         'extra': all_res_by_exp_tok['sql_hardness=extra']['cnt'][exp_tok],
#     }

### Exp-7.0: layer skipping
- Including 7.0.[0-3]

#### Load & Check

In [567]:
# expect_type = 'table_alias'

# res_path = f'/home/yshao/Projects/rome/results/exp7_0_1_decoder_layer_skip_effect/exp=7.0.1_dev_{expect_type}.jsonl'

res_path = '/home/yshao/Projects/rome/results/exp7_0_3_decoder_syntax_layer_skip_effect/exp=7.0.3_dev.jsonl'

with open(res_path, 'r') as f:
    all_samples = [json.loads(l) for l in f]
len(all_samples)

1034

In [568]:
total_samples = 0
n_good_samples = 0
n_too_hard = 0      # clean prediction is already wrong
n_too_easy = 0      # prediction is correct but corruption is not effective (base - low < 0.5)

In [569]:
good_samples = []
bad_samples = []

for i, ex in enumerate(all_samples):
    for d in ex['trace_results']:
        total_samples += 1

        if d.get('is_good_sample', True):
            n_good_samples += 1
            d['ex_id'] = i
            good_samples.append(d)
        elif not d['correct_prediction']:
            n_too_hard += 1
            bad_samples.append(d)
        else:
            assert d['base_score'] - d['low_score'] < 0.5
            n_too_easy += 1
            bad_samples.append(d)
            
total_samples, (n_good_samples, len(good_samples)), n_too_hard, n_too_easy, len(bad_samples), \
f'good / correct = {n_good_samples} / {n_good_samples + n_too_easy}'

(10233, (8609, 8609), 1623, 1, 1624, 'good / correct = 8609 / 8610')

#### Overall avg

In [570]:
trace_scores_avg = {sect_k : defaultdict(int) for sect_k in good_samples[0]['trace_scores'].keys()}

for d in good_samples:
    for sect_k, sect_d in d['trace_scores'].items():
        for k, v in sect_d.items():
            trace_scores_avg[sect_k][k] += v

for sect_k, sect_d in trace_scores_avg.items():
    for k, s in sect_d.items():
        sect_d[k] = s / len(good_samples)

In [571]:
# trace_scores_avg
format_print_2D_dict(trace_scores_avg, head_col_w=16, col_w=11)

XXXXXXXXXXXXXXXX	q1_layers  	q2_layers  	q3_layers  	q4_layers  	low_layers 	mid_layers 	high_layers	all_layers 
ans             	0.0037     	0.9564     	0.7786     	0.5126     	0.0035     	0.2946     	0.0064     	0.0000     
all             	0.0025     	0.8803     	0.7603     	0.5126     	0.0024     	0.2397     	0.0064     	0.0000     



### Detailed inspection of section splitting

In [450]:
ex_id = 111
a_ex_id = 0

ex = processed_spider_dev[ex_id]
ex['text_in'], \
ex['struct_in'], \
ex['seq_out']

('What is the accelerate of the car make amc hornet sportabout (sw)?',
 '| car_1 | continents : contid , continent | countries : countryid , countryname , continent | car_makers : id , maker ( amc ) , fullname , country | model_list : modelid , maker , model ( amc ) | car_names : makeid , model ( amc ) , make ( amc hornet , amc hornet sportabout (sw) ) | cars_data : id , mpg , cylinders , edispl , horsepower , weight , accelerate , year',
 "select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id = t2.makeid where t2.make = 'amc hornet sportabout (sw)';")

In [451]:
a_ex_list = ctu.create_analysis_sample_dicts(
                mt_uskg, ex,
                subject_type='column',
                remove_struct_duplicate_nodes=True)

In [None]:
d = a_ex_list[a_ex_id]

In [None]:
d

In [455]:
_enc_toks = ctu.decode_tokens(mt_uskg.tokenizer, d['enc_tokenized']['input_ids'])

In [457]:
print(_enc_toks)

['What', 'is', 'the', 'accelerate', 'of', 'the', 'car', 'make', 'am', 'c', '', 'horn', 'e', 't', 'sport', 'about', '(', 's', 'w', ')', '?', ';', '', 'struct', 'e', 'd', 'knowledge', ':', '|', 'car', '_', '1', '|', 'continent', 's', '', ':', 'cont', 'i', 'd', '', ',', 'continent', '|', 'countries', '', ':', 'country', 'i', 'd', '', ',', 'country', 'name', '', ',', 'continent', '|', 'car', '_', 'makers', '', ':', '', 'i', 'd', '', ',', 'maker', '(', 'am', 'c', '', ')', '', ',', 'full', 'name', '', ',', 'country', '|', 'model', '_', 'list', '', ':', 'model', 'i', 'd', '', ',', 'maker', '', ',', 'model', '(', 'am', 'c', '', ')', '|', 'car', '_', 'name', 's', '', ':', 'make', 'i', 'd', '', ',', 'model', '(', 'am', 'c', '', ')', '', ',', 'make', '(', 'am', 'c', '', 'horn', 'e', 't', '', ',', 'am', 'c', '', 'horn', 'e', 't', 'sport', 'about', '(', 's', 'w', ')', '', ')', '|', 'cars', '_', 'data', '', ':', '', 'i', 'd', '', ',', '', 'mp', 'g', '', ',', '', 'cylinder', 's', '', ',', '', 'e', 'd

In [461]:
text_st, text_ed = d['text_range']
print(_enc_toks[text_st : text_ed])

['What', 'is', 'the', 'accelerate', 'of', 'the', 'car', 'make', 'am', 'c', '', 'horn', 'e', 't', 'sport', 'about', '(', 's', 'w', ')', '?']


In [460]:
struct_st, struct_ed = d['struct_range']
print(_enc_toks[struct_st : struct_ed])

['|', 'car', '_', '1', '|', 'continent', 's', '', ':', 'cont', 'i', 'd', '', ',', 'continent', '|', 'countries', '', ':', 'country', 'i', 'd', '', ',', 'country', 'name', '', ',', 'continent', '|', 'car', '_', 'makers', '', ':', '', 'i', 'd', '', ',', 'maker', '(', 'am', 'c', '', ')', '', ',', 'full', 'name', '', ',', 'country', '|', 'model', '_', 'list', '', ':', 'model', 'i', 'd', '', ',', 'maker', '', ',', 'model', '(', 'am', 'c', '', ')', '|', 'car', '_', 'name', 's', '', ':', 'make', 'i', 'd', '', ',', 'model', '(', 'am', 'c', '', ')', '', ',', 'make', '(', 'am', 'c', '', 'horn', 'e', 't', '', ',', 'am', 'c', '', 'horn', 'e', 't', 'sport', 'about', '(', 's', 'w', ')', '', ')', '|', 'cars', '_', 'data', '', ':', '', 'i', 'd', '', ',', '', 'mp', 'g', '', ',', '', 'cylinder', 's', '', ',', '', 'e', 'disp', 'l', '', ',', 'horsepower', '', ',', 'weight', '', ',', 'accelerate', '', ',', 'year']


In [462]:
for self_st, self_ed in d['self_ranges']:
    print(_enc_toks[self_st : self_ed])

['', ',', 'accelerate', '', ',']


In [463]:
for s, e in d['context_ranges']:
    print(_enc_toks[s : e])

['|', 'car', '_', '1', '|', 'continent', 's', '', ':', 'cont', 'i', 'd', '', ',', 'continent', '|', 'countries', '', ':', 'country', 'i', 'd', '', ',', 'country', 'name', '', ',', 'continent', '|', 'car', '_', 'makers', '', ':', '', 'i', 'd', '', ',', 'maker', '(', 'am', 'c', '', ')', '', ',', 'full', 'name', '', ',', 'country', '|', 'model', '_', 'list', '', ':', 'model', 'i', 'd', '', ',', 'maker', '', ',', 'model', '(', 'am', 'c', '', ')', '|', 'car', '_', 'name', 's', '', ':', 'make', 'i', 'd', '', ',', 'model', '(', 'am', 'c', '', ')', '', ',', 'make', '(', 'am', 'c', '', 'horn', 'e', 't', '', ',', 'am', 'c', '', 'horn', 'e', 't', 'sport', 'about', '(', 's', 'w', ')', '', ')', '|', 'cars', '_', 'data', '', ':', '', 'i', 'd', '', ',', '', 'mp', 'g', '', ',', '', 'cylinder', 's', '', ',', '', 'e', 'disp', 'l', '', ',', 'horsepower', '', ',', 'weight']
['year']


## Tests

### create_analysis_sample_dicts

In [508]:
# ex_id = 111
ex_id = 4755
a_ex_id = 0

# ex = processed_spider_dev[ex_id]
ex = processed_spider_train[ex_id]

ex['text_in'], \
ex['struct_in'], \
ex['seq_out']

('What is id of the staff who had a Staff Department Assignment earlier than any Clerical Staff?',
 '| department_store | addresses : address_id , address_details | staff : staff_id , staff_gender , staff_name | suppliers : supplier_id , supplier_name , supplier_phone | department_store_chain : dept_store_chain_id , dept_store_chain_name | customers : customer_id , payment_method_code , customer_code , customer_name , customer_address , customer_phone , customer_email | products : product_id , product_type_code , product_name , product_price | supplier_addresses : supplier_id , address_id , date_from , date_to | customer_addresses : customer_id , address_id , date_from , date_to | customer_orders : order_id , customer_id , order_status_code , order_date | department_stores : dept_store_id , dept_store_chain_id , store_name , store_address , store_phone , store_email | departments : department_id , dept_store_id , department_name | order_items : order_item_id , order_id , product_id | p

In [509]:
# temp test
# ex['seq_out'] = 'select year from cars_data'

In [517]:
a_ex_list = ctu.create_analysis_sample_dicts(
                mt_uskg, ex,
                subject_type='table',
                remove_struct_duplicate_nodes=True)

In [518]:
a_ex_list[a_ex_id].keys()

dict_keys(['query', 'question', 'db_id', 'db_path', 'db_table_names', 'db_column_names', 'db_column_types', 'db_primary_keys', 'db_foreign_keys', 'rat_sql_graph', 'serialized_schema', 'struct_in', 'text_in', 'seq_out', 'enc_sentence', 'enc_tokenized', 'text_range', 'struct_range', 'parsed_struct_in', 'alias2table', 'struct_node_ranges_dict', 'dec_prompt', 'expect', 'expect_type', 'remove_struct_duplicate_nodes', 'col2table', 'token_ranges_dict', 'node_name_ranges', 'expect_input_ranges', 'self_ranges', 'context_ranges', 'category'])

In [519]:
a_ex_list[a_ex_id]['alias2table']

{}

In [None]:
[(k, v) for k, v in a_ex_list[1].items() if k != 'rat_sql_graph']

In [None]:
[(d['dec_prompt'], d['expect'], d['node_name_ranges'], d['expect_input_ranges'], '------',
  d['self_ranges'], d['context_ranges'],
  d['category'], '------' * 2) for d in a_ex_list]

In [None]:
d = dict(a_ex_list[a_ex_id])
d

In [None]:
d = ctu.add_clean_prediction(mt_uskg, d)

In [None]:
d

#### parse_sql_alias2table

In [185]:
_sql = 'SELECT t2.aaa , t3.ccc FROM table_name as t1 JOIN other_table as t2 on table_name.a_a = other_table.b_a JOIN ttt as t3 on other_table.asth = ttt.asth'
ctu.parse_sql_alias2table(_sql)

{'t1': 'table_name', 't2': 'other_table', 't3': 'ttt'}

#### for syntax

In [149]:
_ex = copy.deepcopy(ex)
# _ex['seq_out'] += 'order by t1.mpg'
a_ex_list_syntax = ctu.create_syntax_analysis_sample_dicts(mt_uskg, _ex)

from 	 None 	 False
as 	 None 	 False
join 	 None 	 False
as 	 None 	 False
on 	 None 	 False
= 	 None 	 False
where 	 None 	 False
= 	 None 	 False
' 	 ' 	 True
amc 	 ' 	 True
hornet 	 ' 	 True
sportabout 	 ' 	 True
( 	 ' 	 True
sw 	 ' 	 True
) 	 ' 	 True
' 	 None 	 True


In [150]:
for a_ex in a_ex_list_syntax:
    print(a_ex['dec_prompt'], ' --> ', a_ex['expect'])

select t1.accelerate  -->  from
select t1.accelerate from cars_data  -->  as
select t1.accelerate from cars_data as t1  -->  join
select t1.accelerate from cars_data as t1 join car_names  -->  as
select t1.accelerate from cars_data as t1 join car_names as t2  -->  on
select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id  -->  =
select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id = t2.makeid  -->  where
select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id = t2.makeid where t2.make  -->  =


In [156]:
a_ex['enc_tokenized']

{'input_ids': [363, 19, 8, 16845, 13, 8, 443, 143, 183, 75, 3, 6293, 15, 17, 2600, 7932, 41, 7, 210, 61, 58, 117, 3, 7593, 15, 26, 1103, 10, 1820, 443, 834, 536, 1820, 10829, 7, 3, 10, 3622, 23, 26, 3, 6, 10829, 1820, 1440, 3, 10, 684, 23, 26, 3, 6, 684, 4350, 3, 6, 10829, 1820, 443, 834, 8910, 3, 10, 3, 23, 26, 3, 6, 13762, 41, 183, 75, 3, 61, 3, 6, 423, 4350, 3, 6, 684, 1820, 825, 834, 3350, 3, 10, 825, 23, 26, 3, 6, 13762, 3, 6, 825, 41, 183, 75, 3, 61, 1820, 443, 834, 4350, 7, 3, 10, 143, 23, 26, 3, 6, 825, 41, 183, 75, 3, 61, 3, 6, 143, 41, 183, 75, 3, 6293, 15, 17, 3, 6, 183, 75, 3, 6293, 15, 17, 2600, 7932, 41, 7, 210, 61, 3, 61, 1820, 2948, 834, 6757, 3, 10, 3, 23, 26, 3, 6, 3, 1167, 122, 3, 6, 3, 12980, 7, 3, 6, 3, 15, 10475, 40, 3, 6, 28906, 3, 6, 1293, 3, 6, 16845, 3, 6, 215, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

In [159]:
col_name_ranges = a_ex['token_ranges_dict']['col_name_ranges']
# col_name_indices = [i for ranges in col_name_ranges.values() for s, e in ranges for i in range(s, e)]
for ranges in col_name_ranges.values():
    print([mt_uskg.tokenizer.decode(a_ex['enc_tokenized']['input_ids'][s:e]) for s, e in ranges])

['contid']
['continent', 'continent']
['countryid']
['countryname']
['id', 'id']
['maker ( amc )', 'maker']
['fullname']
['country']
['modelid']
['model ( amc )', 'model ( amc )']
['makeid']
['make ( amc hornet, amc hornet sportabout (sw) )']
['mpg']
['cylinders']
['edispl']
['horsepower']
['weight']
['accelerate']
['year']


In [None]:
col_name_ranges

In [160]:
table_name_ranges = a_ex['token_ranges_dict']['table_name_ranges']
# table_name_indices = [i for ranges in table_name_ranges.values() for s, e in ranges for i in range(s, e)]
for ranges in table_name_ranges.values():
    print([mt_uskg.tokenizer.decode(a_ex['enc_tokenized']['input_ids'][s:e]) for s, e in ranges])

['continents']
['countries']
['car_makers']
['model_list']
['car_names']
['cars_data']


In [161]:
table_name_ranges

defaultdict(list,
            {'continents': [(33, 35)],
             'countries': [(44, 45)],
             'car_makers': [(58, 61)],
             'model_list': [(82, 85)],
             'car_names': [(102, 106)],
             'cars_data': [(146, 149)]})

#### for all nodes

In [196]:
combined_a_ex = ctu.create_analysis_sample_dicts_all_nodes(
                    mt_uskg, ex,
                    remove_struct_duplicate_nodes=True)

In [197]:
combined_a_ex.keys()

dict_keys(['query', 'question', 'db_id', 'db_path', 'db_table_names', 'db_column_names', 'db_column_types', 'db_primary_keys', 'db_foreign_keys', 'rat_sql_graph', 'serialized_schema', 'struct_in', 'text_in', 'seq_out', 'enc_sentence', 'enc_tokenized', 'text_range', 'struct_range', 'struct_node_ranges_dict', 'dec_prompt', 'remove_struct_duplicate_nodes', 'parsed_struct_in', 'col2table', 'token_ranges_dict', 'alias2table', 'category', 'occ_cols', 'occ_tabs', 'non_occ_cols', 'non_occ_tabs', 'col_self_ranges', 'col_context_ranges', 'tab_self_ranges', 'tab_context_ranges'])

In [202]:
combined_a_ex['enc_sentence'], \
combined_a_ex['seq_out']

('What is the accelerate of the car make amc hornet sportabout (sw)?; structed knowledge: | car_1 | continents : contid , continent | countries : countryid , countryname , continent | car_makers : id , maker ( amc ) , fullname , country | model_list : modelid , maker , model ( amc ) | car_names : makeid , model ( amc ) , make ( amc hornet , amc hornet sportabout (sw) ) | cars_data : id , mpg , cylinders , edispl , horsepower , weight , accelerate , year',
 "select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id = t2.makeid where t2.make = 'amc hornet sportabout (sw)';")

In [200]:
for k, v in combined_a_ex.items():
    if 'occ' in k:
        print(k, v)

occ_cols ['accelerate', 'makeid', 'make']
occ_tabs ['cars_data', 'car_names']
non_occ_cols ['contid', 'countryid', 'countryname', 'fullname', 'country', 'modelid', 'mpg', 'cylinders', 'edispl', 'horsepower', 'weight', 'year']
non_occ_tabs ['continents', 'countries', 'car_makers', 'model_list']


In [203]:
for k, v in combined_a_ex.items():
    if '_self' in k:
        print(k, v)

col_self_ranges {'accelerate': [(176, 181)], 'makeid': [(106, 113)], 'make': [(119, 146)], 'contid': [(35, 42)], 'countryid': [(45, 52)], 'countryname': [(50, 56)], 'fullname': [(74, 80)], 'country': [(78, 82)], 'modelid': [(85, 92)], 'mpg': [(154, 161)], 'cylinders': [(159, 166)], 'edispl': [(164, 172)], 'horsepower': [(170, 175)], 'weight': [(173, 178)], 'year': [(179, 182)]}
tab_self_ranges {'cars_data': [(145, 151)], 'car_names': [(101, 108)], 'continents': [(32, 37)], 'countries': [(43, 47)], 'car_makers': [(57, 63)], 'model_list': [(81, 87)]}


In [204]:
for k, v in combined_a_ex.items():
    if '_context' in k:
        print(k, v)

col_context_ranges {'accelerate': [(28, 176), (181, 182)], 'makeid': [(28, 106), (113, 182)], 'make': [(28, 119), (146, 182)], 'contid': [(28, 35), (42, 182)], 'countryid': [(28, 45), (52, 182)], 'countryname': [(28, 50), (56, 182)], 'fullname': [(28, 74), (80, 182)], 'country': [(28, 78), (82, 182)], 'modelid': [(28, 85), (92, 182)], 'mpg': [(28, 154), (161, 182)], 'cylinders': [(28, 159), (166, 182)], 'edispl': [(28, 164), (172, 182)], 'horsepower': [(28, 170), (175, 182)], 'weight': [(28, 173), (178, 182)], 'year': [(28, 179)]}
tab_context_ranges {'cars_data': [(28, 145), (151, 182)], 'car_names': [(28, 101), (108, 182)], 'continents': [(28, 32), (37, 182)], 'countries': [(28, 43), (47, 182)], 'car_makers': [(28, 57), (63, 182)], 'model_list': [(28, 81), (87, 182)]}
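In the output above, each node's context ranges look like the complement of its self ranges within the span `(28, 182)`. The following is a minimal sketch of that relationship, not the `ctu` implementation: `complement_ranges` is a hypothetical helper, and the span `(28, 182)` is an assumption inferred from the printed values.

```python
def complement_ranges(span, self_ranges):
    """Return the half-open sub-ranges of `span` not covered by `self_ranges`."""
    start, end = span
    context = []
    cursor = start
    for s, e in sorted(self_ranges):
        if s > cursor:
            context.append((cursor, s))
        cursor = max(cursor, e)
    if cursor < end:
        context.append((cursor, end))
    return context


struct_span = (28, 182)  # assumption: read off the printed context ranges

# Self ranges copied from the printed col_self_ranges.
accelerate_context = complement_ranges(struct_span, [(176, 181)])
year_context = complement_ranges(struct_span, [(179, 182)])

# Expand ranges into individual token indices, as the trace cells do.
context_tok_indices = [i for s, e in accelerate_context for i in range(s, e)]
```

For `accelerate` this reproduces the two printed ranges, and for `year`, whose self range reaches the end of the span, only the left-hand range remains.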


### utils


#### Separate punct

In [445]:
ctu.separate_punct('select A from B as t1 join C as t2 on t1.b = t2.c where t2.N = "Fast as a shark";')

'select A from B as t1 join C as t2 on t1 . b = t2 . c where t2 . N = " Fast as a shark " ;'
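The behavior shown above can be approximated by padding every punctuation character with spaces and then collapsing whitespace. This is only a sketch under that assumption: `separate_punct_sketch` is a hypothetical name, and the real `ctu.separate_punct` may treat cases such as multi-character operators (`<=`) differently, which this version would split.

```python
import re


def separate_punct_sketch(text):
    """Surround each non-word, non-space character with spaces, then normalize whitespace."""
    spaced = re.sub(r"([^\w\s])", r" \1 ", text)
    return " ".join(spaced.split())


sql = 'select A from B as t1 join C as t2 on t1.b = t2.c where t2.N = "Fast as a shark";'
separated = separate_punct_sketch(sql)
```

Because `_` counts as a word character in `\w`, identifiers such as `has_pet` stay in one piece.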

#### Separate punct by offset

In [778]:
# _sql = 'SELECT t2.aaa, DISTINCT(t3.ccc), COUNT(*) FROM table_name as t1 JOIN other_table as t2 on table_name.a_a = other_table.b_a JOIN ttt as t3 on other_table.asth = ttt.asth WHERE t2.col like "%hey%" AND t3.p <= 40'.lower()
_sql = 'select t2.petid from has_pet as t1 join pets as t2 on t2.petid = t1.petid join student as t3 on t3.stuid = t1.stuid where t3.lname = ‘Smith’'
_tok_ranges = ctu.separate_punct_by_offset(_sql)
print([_sql[s:e] for s, e in _tok_ranges])

['select', 't2.', 'petid', 'from', 'has_pet', 'as', 't1', 'join', 'pets', 'as', 't2', 'on', 't2.', 'petid', '=', 't1.', 'petid', 'join', 'student', 'as', 't3', 'on', 't3.', 'stuid', '=', 't1.', 'stuid', 'where', 't3.', 'lname', '=', '‘', 'Smith', '’']


#### Build attention mask

In [470]:
_test_a_ex = {
    'enc_sentence': 'which school is good? structed_knowledge: school | school : school_name, is_good',
    'dec_prompt': 'select distinct',
    'expect': 'school_name',  # ['▁school', '_', 'name']
    'answers_t': [1,2,3],
    'answer': 'school_name',
    'text_range': [0, 5],
    'struct_range': [15, 25],
    'self_ranges': [[18, 21]],
    'context_ranges': [[15, 18], [21, 25]],
}

_test_att_masks = ctu.build_dec_cross_attention_mask(
    a_ex=_test_a_ex,
    mt=mt_uskg,
    use_self_node=True
)

In [471]:
_test_att_masks['all->t+o']

tensor([[[[False, False, False, False, False, False, False, False, False, False,
            True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True, False, False, False, False, False,
           False, False, False, False, False,  True],
          [False, False, False, False, False, False, False, False, False, False,
            True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True, False, False, False, False, False,
           False, False, False, False, False,  True],
          [False, False, False, False, False, False, False, False, False, False,
            True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True, False, False, False, False, False,
           False, False, False, False, False,  True],
          [False, False, False, False, False, False, False, False, False, False,
            True,  True,  Tr

In [472]:
_test_att_masks['all->t+o'] | _test_att_masks['all->s'] | _test_att_masks['all->p']

tensor([[[[True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True, True, True, True, True, True, True, True,
           True, True, True],
          [True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True, True, True, True, True, True, True, True,
           True, True, True],
          [True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True, True, True, True, True, True, True, True,
           True, True, True],
          [True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True, True, True, True, True, True

In [473]:
(_test_att_masks['all->s'] | _test_att_masks['all->p']) & _test_att_masks['all->t+o']

tensor([[[[False, False, False, False, False, False, False, False, False, False,
           False, False, False, False, False, False, False, False, False, False,
           False, False, False, False, False, False, False, False, False, False,
           False, False, False, False, False, False],
          [False, False, False, False, False, False, False, False, False, False,
           False, False, False, False, False, False, False, False, False, False,
           False, False, False, False, False, False, False, False, False, False,
           False, False, False, False, False, False],
          [False, False, False, False, False, False, False, False, False, False,
           False, False, False, False, False, False, False, False, False, False,
           False, False, False, False, False, False, False, False, False, False,
           False, False, False, False, False, False],
          [False, False, False, False, False, False, False, False, False, False,
           False, False, Fal

#### Parse SQL alias2table

In [449]:
ctu.parse_sql_alias2table('select A from B as t1 join C as t2 on t1.b = t2.c where t2.N = "Fast as a shark";')

{'t1': 'b', 't2': 'c'}
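The quoted value in this test contains ` as `, which a pattern that ignores the `FROM`/`JOIN` context would mistake for an alias. A minimal sketch of one way to get the outputs shown here, assuming aliases are only declared as `<table> as <alias>` directly after `from` or `join` (`alias2table_sketch` is a hypothetical name, not the `ctu` function):

```python
import re

# Only accept "<table> as <alias>" when it directly follows FROM or JOIN.
ALIAS_PATTERN = re.compile(r"\b(?:from|join)\s+(\w+)\s+as\s+(\w+)")


def alias2table_sketch(sql):
    """Map each table alias to its table name; the SQL is lowercased first."""
    return {alias: table for table, alias in ALIAS_PATTERN.findall(sql.lower())}


quoted = alias2table_sketch(
    'select A from B as t1 join C as t2 on t1.b = t2.c where t2.N = "Fast as a shark";'
)
joined = alias2table_sketch(
    "SELECT t2.aaa , t3.ccc FROM table_name as t1 JOIN other_table as t2 "
    "on table_name.a_a = other_table.b_a JOIN ttt as t3 on other_table.asth = ttt.asth"
)
```

This sketch does not handle subqueries or aliases declared without `as`.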

#### Categorize tokens offset

In [775]:
# _sql = 'SELECT t2.aaa, DISTINCT(t3.ccc), COUNT(*) FROM table_name as t1 JOIN other_table as t2 on table_name.a_a = other_table.b_a JOIN ttt as t3 on other_table.asth = ttt.asth WHERE t2.col like "%hey%" AND t3.p <= 40'.lower()
# _sql = 'select date_incident_start, date_incident_end, d+e from behavior_incident join tb where incident_type_code = "NOISE" and 3 < tbcol'
_sql = 'select t2.petid from has_pet as t1 join pets as t2 on t2.petid = t1.petid join student as t3 on t3.stuid = t1.stuid where t3.lname = ‘Smith’'
_tok_ranges = ctu.separate_punct_by_offset(_sql)
print([_sql[s:e] for s, e in _tok_ranges])

['select', 't2.', 'petid', 'from', 'has_pet', 'as', 't1', 'join', 'pets', 'as', 't2', 'on', 't2.', 'petid', '=', 't1.', 'petid', 'join', 'student', 'as', 't3', 'on', 't3.', 'stuid', '=', 't1.', 'stuid', 'where', 't3.', 'lname', '=', '‘', 'Smith', '’']


In [776]:
_rgs2type = ctu.categorize_tokens_offset(_sql, _tok_ranges)

In [777]:
for _rg, _type in _rgs2type.items():
    s, e = _rg
    print(_sql[s:e], _type)

select syntax
t2. table_alias
petid column
from syntax
has_pet table
as syntax
t1 table_alias
join syntax
pets table
as syntax
t2 table_alias
on syntax
t2. table_alias
petid column
= syntax
t1. table_alias
petid column
join syntax
student table
as syntax
t3 table_alias
on syntax
t3. table_alias
stuid column
= syntax
t1. table_alias
stuid column
where syntax
t3. table_alias
lname column
= syntax
‘ val
Smith val
’ val


### trace

In [525]:
a_ex = dict(a_ex_list[a_ex_id])
a_ex = ctu.add_clean_prediction(mt_uskg, a_ex)

In [526]:
result = ctu.make_basic_result_dict(a_ex)
result

{'enc_sentence': 'What is the accelerate of the car make amc hornet sportabout (sw)?; structed knowledge: | car_1 | continents : contid , continent | countries : countryid , countryname , continent | car_makers : id , maker ( amc ) , fullname , country | model_list : modelid , maker , model ( amc ) | car_names : makeid , model ( amc ) , make ( amc hornet , amc hornet sportabout (sw) ) | cars_data : id , mpg , cylinders , edispl , horsepower , weight , accelerate , year',
 'seq_out': "select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id = t2.makeid where t2.make = 'amc hornet sportabout (sw)';",
 'dec_prompt': 'select t1.',
 'expect': 'accelerate',
 'expect_type': 'column',
 'db_id': 'car_1',
 'expect_input_ranges': [(178, 179)],
 'self_ranges': [(176, 181)],
 'expect_table': 'cars_data',
 'answer': 'acc',
 'base_score': 0.9999758005142212,
 'answers_t': [6004],
 'correct_prediction': False,
 'category': {'sql_hardness': 'medium',
  'node_role': 'select',
  'text_matc

In [527]:
enc_sentence = a_ex['enc_sentence']
dec_prompt = a_ex['dec_prompt']
expect = a_ex['expect']
answer = result['answer']
answers_t = result['answers_t']

inp = ctu.make_inputs_t5(
    mt_uskg.tokenizer,
    [enc_sentence] * 2,
    [dec_prompt] * 2,
    answer=expect)

_, enc_seq_len = inp['input_ids'].size()
_, dec_seq_len = inp['decoder_input_ids'].size()

text_range = a_ex['text_range']
struct_range = a_ex['struct_range']

self_ranges = a_ex['self_ranges']
context_ranges = a_ex['context_ranges']

self_tok_indices = [i for s, e in self_ranges for i in range(s, e)]
context_tok_indices = corrupt_tok_indices = [i for s, e in context_ranges for i in range(s, e)]
text_tok_indices = list(range(*text_range))
struct_tok_indices = list(range(*struct_range))

#### Repatch-uskg

In [210]:
_score = ctu.trace_with_repatch_uskg(
    model=mt_uskg.model,
    inp=inp,
#     states_to_patch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', mt_uskg.num_enc_layers - 1))
#                     for tnum in range(*struct_range)],
    states_to_patch=[],
    states_to_unpatch=[],
    answers_t=answers_t,
    tokens_to_mix=text_tok_indices,
    tokens_to_mix_individual_indices=True,
    replace=True,
).item()

In [211]:
answers_t, answer, _score

([6726], 'city', 0.8450507521629333)

In [212]:
states_to_corrupt = [(tnum, ctu.layername_uskg(mt_uskg.model, "encoder", 0, "embed"))
                for tnum in text_tok_indices]

_score = ctu.trace_with_repatch_uskg(
    model=mt_uskg.model,
    inp=inp,
    states_to_patch=[],
    states_to_unpatch=[],
    answers_t=answers_t,
    states_to_corrupt=states_to_corrupt,
#     tokens_to_mix=corrupt_tok_indices,
#     tokens_to_mix_individual_indices=True,
    replace=True,
).item()
_score

0.8450507521629333

In [213]:
# Pair of identical inputs to test correctness

_score = ctu.trace_with_repatch_uskg(
    model=mt_uskg.model,
    inp=inp,
#     states_to_patch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', l))
#                     for tnum in text_tok_indices for l in range(mt_uskg.num_enc_layers - 1)],
    states_to_patch=[],
    states_to_patch_1st_pass=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', 12))
                    for tnum in text_tok_indices],
    states_to_unpatch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', 23))
                    for tnum in struct_tok_indices],
    answers_t=answers_t,
    tokens_to_mix=text_tok_indices,
    tokens_to_mix_individual_indices=True,
    tokens_to_mix_1st_pass=context_tok_indices,
    replace=True,
).item()

_score

tensor(0.9952, device='cuda:0')

In [214]:
_score = ctu.trace_with_repatch_uskg(
    model=mt_uskg.model,
    inp=inp,
#     states_to_patch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', l))
#                     for tnum in text_tok_indices for l in range(mt_uskg.num_enc_layers - 1)],
    states_to_patch=[],
    states_to_patch_1st_pass=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', 12))
                    for tnum in text_tok_indices],
    states_to_unpatch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', 23))
                    for tnum in struct_tok_indices],
    answers_t=answers_t,
    states_to_corrupt=[(tnum, ctu.layername_uskg(mt_uskg.model, "encoder", 0, "embed"))
                    for tnum in text_tok_indices],
    states_to_corrupt_1st_pass=[(tnum, ctu.layername_uskg(mt_uskg.model, "encoder", 0, "embed"))
                    for tnum in context_tok_indices],
    replace=True,
).item()

_score

tensor(0.9952, device='cuda:0')

In [215]:
# Test corrupting attention 
_score = ctu.trace_with_repatch_uskg(
    model=mt_uskg.model,
    inp=inp,
#     states_to_patch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', l))
#                     for tnum in text_tok_indices for l in range(mt_uskg.num_enc_layers - 1)],
    states_to_patch=[],
    states_to_unpatch=[],
    answers_t=answers_t,
    states_to_corrupt=[(tnum, ctu.layername_uskg(mt_uskg.model, "encoder", l, "self_attn"))
                    for tnum in text_tok_indices for l in range(mt_uskg.num_enc_layers)],
    replace=True,
).item()

_score

0.7253002524375916

In [None]:
[n for n, w in mt_uskg.model.named_parameters()]

In [216]:
vocab_probs = ctu.run_repatch_uskg_multi_token(
    model=mt_uskg.model,
    inp=inp,
#     states_to_patch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', mt_uskg.num_enc_layers - 1))
#                     for tnum in range(*struct_range)],
    states_to_patch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', l))
                    for tnum in self_tok_indices for l in range(mt_uskg.num_enc_layers - 1)],
    states_to_unpatch=[(tnum, ctu.layername_uskg(mt_uskg.model, 'encoder', mt_uskg.num_enc_layers - 1))
                    for tnum in self_tok_indices],
    answer_len=len(answers_t),
    tokens_to_mix=corrupt_tok_indices,
    tokens_to_mix_individual_indices=True,
    replace=True,
)

In [217]:
vocab_probs.size()

torch.Size([1, 32102])

In [218]:
torch.max(vocab_probs, dim=-1)

torch.return_types.max(
values=tensor([1.], device='cuda:0'),
indices=tensor([7634], device='cuda:0'))

In [219]:
vocab_probs[0, 7634]

tensor(1., device='cuda:0')

In [220]:
vocab_probs

tensor([[2.2642e-25, 1.2223e-15, 7.3942e-18,  ..., 9.1578e-20, 2.6884e-39,
         2.8131e-39]], device='cuda:0')

#### Attention-manipulation

In [111]:
small_inp = ctu.make_inputs_t5(
    mt_uskg.tokenizer,
    ['name of singer'] * 2,
    ['select'] * 2,
    answer='name')

In [112]:
bs, enc_seq_len = small_inp['input_ids'].size()
bs, dec_seq_len = small_inp['decoder_input_ids'].size()
prefix_len = mt_uskg.model.preseqlen

bs, enc_seq_len, dec_seq_len, prefix_len

(2, 4, 2, 10)

In [113]:
mix_mask = torch.zeros(1, 1, enc_seq_len, enc_seq_len + prefix_len).bool()
mix_mask[:, :, 1:3, 11:13] = True
mix_mask

tensor([[[[False, False, False, False, False, False, False, False, False, False,
           False, False, False, False],
          [False, False, False, False, False, False, False, False, False, False,
           False,  True,  True, False],
          [False, False, False, False, False, False, False, False, False, False,
           False,  True,  True, False],
          [False, False, False, False, False, False, False, False, False, False,
           False, False, False, False]]]])
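The mask above has `enc_seq_len` rows and `enc_seq_len + prefix_len` columns, and the slice `11:13` appears to address key tokens 1 and 2 shifted by the 10 prefix positions. The sketch below illustrates that indexing in plain Python; it assumes the prefix keys come first along the key axis, and `build_mix_mask` is a hypothetical helper rather than part of `ctu`.

```python
def build_mix_mask(enc_seq_len, prefix_len, query_tokens, key_tokens):
    """Mark attention cells (query token -> key token) to corrupt.

    Assumption: key positions are laid out as [prefix..., tokens...],
    so key token k sits at column prefix_len + k.
    """
    mask = [[False] * (prefix_len + enc_seq_len) for _ in range(enc_seq_len)]
    for q in query_tokens:
        for k in key_tokens:
            mask[q][prefix_len + k] = True
    return mask


mask = build_mix_mask(enc_seq_len=4, prefix_len=10, query_tokens=[1, 2], key_tokens=[1, 2])
true_cells = [(q, k) for q, row in enumerate(mask) for k, flag in enumerate(row) if flag]
```

With these arguments the marked cells are the same ones set by `mix_mask[:, :, 1:3, 11:13] = True`.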

In [None]:
# Check the attention logits/weights to verify that the corruption works

corrupted_vocab_probs = ctu.run_attention_manip_uskg_multi_token(
    model=mt_uskg.model,
    inp=small_inp,
    answer_len=1,
    mix_mask_per_layer={ctu.layername_uskg(mt_uskg.model, 'encoder', l, 'self_attn') : mix_mask for l in [0, 12, 23]},
    replace=True,
    attn_corrupt_type='logits',
)

In [115]:
corrupted_vocab_probs.size()

torch.Size([1, 32102])

#### Layer-copy-uskg

In [528]:
_score = ctu.trace_with_repatch_uskg(
    model=mt_uskg.model,
    inp=inp,
    states_to_patch=[],
    states_to_unpatch=[],
    answers_t=answers_t,
#     states_to_copy_from=states_to_copy_from,
#     states_to_copy_to=states_to_copy_to,
#     answer_len=len(answers_t),
    states_to_corrupt=[],
    replace=True,
).item()
_score

0.9999758005142212

In [529]:
states_to_copy_to = [
    (tnum, ctu.layername_uskg(mt_uskg.model, "encoder", 17))
    for tnum in range(enc_seq_len)
]

_score = ctu.trace_with_repatch_uskg(
    model=mt_uskg.model,
    inp=inp,
    states_to_patch=[],
    states_to_unpatch=[],
    answers_t=answers_t,
#     states_to_copy_from=states_to_copy_from,
#     states_to_copy_to=states_to_copy_to,
#     answer_len=len(answers_t),
    states_to_corrupt=states_to_copy_to,
    replace=True,
).item()
_score

1.8760302111786586e-07

In [533]:
# Implementation 1: layer-copy

states_to_copy_from = [
    (tnum, ctu.layername_uskg(mt_uskg.model, "encoder", 0, "embed"))
    for tnum in range(enc_seq_len)
]

states_to_copy_to = [
    (tnum, ctu.layername_uskg(mt_uskg.model, "encoder", 12))
    for tnum in range(enc_seq_len)
]

vocab_probs = ctu.run_layer_copy_uskg_multi_token(
    model=mt_uskg.model,
    inp=inp,
#     states_to_patch=[],
#     states_to_unpatch=[],
#     answers_t=answers_t,
    states_to_copy_from=states_to_copy_from,
    states_to_copy_to=states_to_copy_to,
    answer_len=len(answers_t),
#     states_to_corrupt=states_to_corrupt,
#     tokens_to_mix=corrupt_tok_indices,
#     tokens_to_mix_individual_indices=True,
    replace=True,
)

ans_probs = [vocab_probs[i, _t].item() for i, _t in enumerate(answers_t)]

ans_probs

[6.155608645030952e-08]

In [534]:
# Implementation 2: sublayer-zero
# Should give the same score as implementation 1 (confirmed by the output below)

states_to_zero = [
    (tnum, ctu.layername_uskg(mt_uskg.model, "encoder", l, sublayer))
    for tnum in range(enc_seq_len) for l in range(0, 13) for sublayer in ['self_attn', 'mlp']
]

_score = ctu.trace_with_repatch_uskg(
    model=mt_uskg.model,
    inp=inp,
    states_to_patch=[],
    states_to_unpatch=[],
    answers_t=answers_t,
#     states_to_copy_from=states_to_copy_from,
#     states_to_copy_to=states_to_copy_to,
#     answer_len=len(answers_t),
    states_to_corrupt=states_to_zero,
    replace=True,
    noise=0.0,
).item()
_score

6.155608645030952e-08
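The two implementations agree presumably because the encoder blocks are residual: each sublayer adds its output to the hidden state, so zeroing every sublayer output up to block 12 leaves the embedding unchanged at that block's output, which is what the layer-copy writes there directly. The toy model below illustrates this reasoning only. It is not the USKG/T5 code, and `run_toy_encoder`, `toy_attn`, and `toy_mlp` are hypothetical stand-ins.

```python
def run_toy_encoder(embedding, blocks, zero_upto=None, copy_embedding_to=None):
    """Toy residual stream: every sublayer adds its output to the hidden state.

    zero_upto: replace sublayer outputs with zeros for blocks 0..zero_upto.
    copy_embedding_to: overwrite that block's output with the embedding.
    """
    hidden = list(embedding)
    for layer, sublayers in enumerate(blocks):
        for sublayer in sublayers:
            if zero_upto is not None and layer <= zero_upto:
                update = [0.0] * len(hidden)
            else:
                update = sublayer(hidden)
            hidden = [h + u for h, u in zip(hidden, update)]
        if layer == copy_embedding_to:
            hidden = list(embedding)
    return hidden


def toy_attn(h):
    # Stand-in for attention: mixes information across positions.
    mean = sum(h) / len(h)
    return [0.5 * mean for _ in h]


def toy_mlp(h):
    # Stand-in for the feed-forward sublayer: acts per position.
    return [0.25 * x + 1.0 for x in h]


blocks = [(toy_attn, toy_mlp)] * 4
embedding = [1.0, -2.0, 3.0]

clean = run_toy_encoder(embedding, blocks)
layer_copy = run_toy_encoder(embedding, blocks, copy_embedding_to=1)
sublayer_zero = run_toy_encoder(embedding, blocks, zero_upto=1)
```

In this toy setting `layer_copy` and `sublayer_zero` are identical and both differ from `clean`, mirroring the matching scores of the two cells above.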

In [None]:
#  'pretrain_model',
#  'pretrain_model.shared',
#  'pretrain_model.encoder',
#  'pretrain_model.encoder.block',
#  'pretrain_model.encoder.block.0',
#  'pretrain_model.encoder.block.0.layer',
#  'pretrain_model.encoder.block.0.layer.0',
#  'pretrain_model.encoder.block.0.layer.0.SelfAttention',
#  'pretrain_model.encoder.block.0.layer.0.SelfAttention.q',
#  'pretrain_model.encoder.block.0.layer.0.SelfAttention.k',
#  'pretrain_model.encoder.block.0.layer.0.SelfAttention.v',
#  'pretrain_model.encoder.block.0.layer.0.SelfAttention.o',
#  'pretrain_model.encoder.block.0.layer.0.SelfAttention.relative_attention_bias',
#  'pretrain_model.encoder.block.0.layer.0.layer_norm',
#  'pretrain_model.encoder.block.0.layer.0.dropout',
#  'pretrain_model.encoder.block.0.layer.1',
#  'pretrain_model.encoder.block.0.layer.1.DenseReluDense',
#  'pretrain_model.encoder.block.0.layer.1.DenseReluDense.wi',
#  'pretrain_model.encoder.block.0.layer.1.DenseReluDense.wo',
#  'pretrain_model.encoder.block.0.layer.1.DenseReluDense.dropout',
#  'pretrain_model.encoder.block.0.layer.1.layer_norm',
#  'pretrain_model.encoder.block.0.layer.1.dropout',

## Other observations

### Model

#### Embedding

In [604]:
embs = mt_uskg.model.pretrain_model.encoder.embed_tokens.weight
embs.size()

torch.Size([32102, 1024])

In [605]:
mt_uskg.tokenizer.batch_encode_plus(['name', 'age', 'nation', 'singer'], add_special_tokens=False)

{'input_ids': [[564], [1246], [2982], [7634]], 'attention_mask': [[1], [1], [1], [1]]}

In [613]:
embs[[564, 1246, 2982, 7634]]

tensor([[  2.8750,  17.0000,   5.3438,  ..., -15.8750,  -0.8789,   2.6562],
        [  6.3750,  13.5000, -35.7500,  ...,   2.3281,  15.7500,   3.5938],
        [  3.3750, -16.8750, -25.2500,  ...,   0.9258,  -5.8750,   4.5625],
        [  0.4629,  -3.0156, -10.1875,  ...,  -8.4375,   1.8828,   3.3750]],
       device='cuda:0')

In [687]:
embs_std, embs_mean = torch.std_mean(embs, dim=0)
embs_std, embs_mean

(tensor([12.6989, 15.4454, 16.5982,  ..., 12.3047, 12.5982,  9.7114],
        device='cuda:0'),
 tensor([-4.3081, -2.5801,  2.5176,  ..., -2.8731,  3.0601, 12.4257],
        device='cuda:0'))

In [688]:
torch.std_mean(embs_std), torch.std_mean(embs_mean)

((tensor(2.3013, device='cuda:0'), tensor(13.3482, device='cuda:0')),
 (tensor(4.8516, device='cuda:0'), tensor(0.1390, device='cuda:0')))

In [609]:
embs_norm = torch.linalg.norm(embs, ord=2, dim=1)
embs_norm

tensor([639.1667, 281.5377, 303.9011,  ..., 242.6456, 397.0485, 398.6314],
       device='cuda:0')

In [610]:
torch.std_mean(embs_norm)

(tensor(66.8115, device='cuda:0'), tensor(455.5204, device='cuda:0'))

In [639]:
tgt_wid = 564
tgt_vec = embs[tgt_wid]
# delta = 5.0 * torch.randn_like(tgt_vec)
# tgt_vec = tgt_vec + delta

embs_dist = torch.linalg.norm(embs - tgt_vec, ord=2, dim=1)
embs_dist

tensor([719.9138, 459.3446, 463.6613,  ..., 422.5326, 498.5402, 499.7012],
       device='cuda:0')

In [640]:
torch.std_mean(embs_dist), torch.min(embs_dist), torch.argmin(embs_dist)

((tensor(54.9653, device='cuda:0'), tensor(565.9248, device='cuda:0')),
 tensor(0., device='cuda:0'),
 tensor(564, device='cuda:0'))

In [641]:
sorted_dists = sorted((dist, i) for i, dist in enumerate(embs_dist.cpu().tolist()))
sorted_dists[:10]

[(0.0, 564),
 (273.6109924316406, 5570),
 (279.52862548828125, 3056),
 (315.5259704589844, 4350),
 (353.9427490234375, 23954),
 (359.2679443359375, 2650),
 (369.5257873535156, 10016),
 (405.4482421875, 2233),
 (411.2537841796875, 3),
 (411.6006164550781, 2862)]

In [642]:
ctu.decode_tokens(mt_uskg.tokenizer, [tok_id for _, tok_id in sorted_dists[:10]])

['name',
 'Name',
 'names',
 'name',
 'Name',
 'named',
 'Namen',
 'title',
 '',
 'identify']

In [650]:
tgt_wid = 564
tgt_vec = embs[tgt_wid]
delta = 5.0 * torch.randn_like(tgt_vec)
tgt_vec = tgt_vec + delta

embs_dist = torch.linalg.norm(embs - tgt_vec, ord=2, dim=1)
embs_dist

tensor([737.6123, 487.0907, 491.6216,  ..., 456.4087, 529.4299, 530.4269],
       device='cuda:0')

In [651]:
torch.std_mean(embs_dist), torch.min(embs_dist), torch.argmin(embs_dist)

((tensor(53.0791, device='cuda:0'), tensor(590.1509, device='cuda:0')),
 tensor(158.3963, device='cuda:0'),
 tensor(564, device='cuda:0'))

In [652]:
sorted_dists = sorted((dist, i) for i, dist in enumerate(embs_dist.cpu().tolist()))
sorted_dists[:10]

[(158.39627075195312, 564),
 (316.92431640625, 5570),
 (323.1575927734375, 3056),
 (357.3472595214844, 4350),
 (393.6805419921875, 23954),
 (400.9082946777344, 10016),
 (402.192138671875, 2650),
 (433.865234375, 2233),
 (441.6644592285156, 3),
 (443.2170715332031, 2862)]

In [646]:
ctu.decode_tokens(mt_uskg.tokenizer, [tok_id for _, tok_id in sorted_dists[:10]])

['name', 'names', 'Name', 'name', 'named', 'Name', 'Namen', '', 'title', '.']

#### Encoding

In [653]:
ex_id = 111
a_ex_id = 0

ex = processed_spider_dev[ex_id]
ex['text_in'], \
ex['struct_in'], \
ex['seq_out']

('What is the accelerate of the car make amc hornet sportabout (sw)?',
 '| car_1 | continents : contid , continent | countries : countryid , countryname , continent | car_makers : id , maker ( amc ) , fullname , country | model_list : modelid , maker , model ( amc ) | car_names : makeid , model ( amc ) , make ( amc hornet , amc hornet sportabout (sw) ) | cars_data : id , mpg , cylinders , edispl , horsepower , weight , accelerate , year',
 "select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id = t2.makeid where t2.make = 'amc hornet sportabout (sw)';")

In [654]:
a_ex_list = ctu.create_analysis_sample_dicts(
                mt_uskg, ex,
                subject_type='column',
                remove_struct_duplicate_nodes=True)

In [758]:
enc_sentence = a_ex['enc_sentence']
dec_prompt = a_ex['dec_prompt']
expect = a_ex['expect']
answer = result['answer']
answers_t = result['answers_t']

text_range = a_ex['text_range']
struct_range = a_ex['struct_range']

self_ranges = a_ex['self_ranges']
struct_context_ranges = a_ex['context_ranges']

In [655]:
inp = ctu.make_inputs_t5(
    mt_uskg.tokenizer,
    [enc_sentence] * 2,
    [dec_prompt] * 2,
    answer=expect)

In [656]:
from util import nethook

In [659]:
with torch.no_grad(), nethook.TraceDict(
    mt_uskg.model,
    [ctu.layername_uskg(mt_uskg.model, "encoder", l) for l in [0, 12, 23]]
) as td:
    outputs_exp = ctu.run_model_forward_uskg(mt_uskg.model, **inp)

In [660]:
td.keys()

odict_keys(['pretrain_model.encoder.block.0', 'pretrain_model.encoder.block.12', 'pretrain_model.encoder.block.23'])

In [665]:
hidden, attn = td['pretrain_model.encoder.block.23'].output
hidden.size(), attn.size()

(torch.Size([2, 183, 1024]), torch.Size([2, 16, 183, 193]))

In [666]:
hidden[0]

tensor([[ 431.2646,   22.2026,  -46.9836,  ...,  236.4605, -101.4929,
          -13.4173],
        [ 680.9009, -155.7530,   -3.9374,  ...,  357.9876,  -74.5153,
         -248.9389],
        [ 726.4300,  -97.9734,   82.3419,  ...,  280.2939,  -92.4423,
         -208.4543],
        ...,
        [ -12.3197,  -82.2136,  -86.8491,  ...,  343.5234,   46.3294,
          -30.3550],
        [ 430.2145,  197.3166, -176.3856,  ...,  235.9048,  -84.3116,
         -104.0558],
        [ 935.1066,  134.0760,   52.9306,  ...,  -17.0239,  -42.8478,
          -13.8100]], device='cuda:0')

In [673]:
torch.std_mean(hidden[0]), torch.std_mean(torch.norm(hidden[0], dim=-1))

((tensor(3933.3369, device='cuda:0'), tensor(-22.5662, device='cuda:0')),
 (tensor(120171.7969, device='cuda:0'), tensor(38477.7578, device='cuda:0')))

In [685]:
sorted_norms = sorted((dist, i) for i, dist in enumerate(torch.norm(hidden[0], dim=-1).cpu().tolist()))
sorted_norms[-10:]

[(124030.015625, 28),
 (145044.171875, 40),
 (199170.0625, 50),
 (409527.75, 164),
 (489051.65625, 143),
 (592587.4375, 135),
 (672203.0, 78),
 (684664.5625, 170),
 (685764.875, 74),
 (766906.25, 54)]

In [683]:
tokenized_input = mt_uskg.tokenizer.tokenize(enc_sentence, add_special_tokens=True)
print(len(tokenized_input), tokenized_input)

183 ['▁What', '▁is', '▁the', '▁accelerate', '▁of', '▁the', '▁car', '▁make', '▁am', 'c', '▁', 'horn', 'e', 't', '▁sport', 'about', '▁(', 's', 'w', ')', '?', ';', '▁', 'struct', 'e', 'd', '▁knowledge', ':', '▁|', '▁car', '_', '1', '▁|', '▁continent', 's', '▁', ':', '▁cont', 'i', 'd', '▁', ',', '▁continent', '▁|', '▁countries', '▁', ':', '▁country', 'i', 'd', '▁', ',', '▁country', 'name', '▁', ',', '▁continent', '▁|', '▁car', '_', 'makers', '▁', ':', '▁', 'i', 'd', '▁', ',', '▁maker', '▁(', '▁am', 'c', '▁', ')', '▁', ',', '▁full', 'name', '▁', ',', '▁country', '▁|', '▁model', '_', 'list', '▁', ':', '▁model', 'i', 'd', '▁', ',', '▁maker', '▁', ',', '▁model', '▁(', '▁am', 'c', '▁', ')', '▁|', '▁car', '_', 'name', 's', '▁', ':', '▁make', 'i', 'd', '▁', ',', '▁model', '▁(', '▁am', 'c', '▁', ')', '▁', ',', '▁make', '▁(', '▁am', 'c', '▁', 'horn', 'e', 't', '▁', ',', '▁am', 'c', '▁', 'horn', 'e', 't', '▁sport', 'about', '▁(', 's', 'w', ')', '▁', ')', '▁|', '▁cars', '_', 'data', '▁', ':', '▁', 'i

In [686]:
print([tokenized_input[i] for _, i in sorted_norms[-20:]])

[':', ';', '▁', '▁', '▁', '▁', '</s>', 's', '▁', '_', '▁|', '▁', '▁', '▁', '▁', 'e', '▁', '▁', '▁', '▁']


In [None]:
# hidden[0, ::10, ::100]

#### Edit output

In [726]:
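def untuple(x):
    # A hooked block may return a tuple; its first element is the hidden states.
    return x[0] if isinstance(x, tuple) else x

def patch_rep(x, layer):
    # Hook for `edit_output`: corrupt the tokens listed in the global
    # `corrupt_spec` for this layer, leaving batch row 0 as the clean run.
    h = untuple(x)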
    if layer in corrupt_spec:
        toks_to_mix = corrupt_spec[layer]
        if toks_to_mix:
            mix_len = len(toks_to_mix)
#             noise_data = noise_fn(
#                 torch.from_numpy(prng(h.shape[0] - 1, mix_len, h.shape[2]))
#             ).to(device=h.device, dtype=h.dtype)
#             if replace:
#                 h[1:, toks_to_mix] = noise_data
#             else:
#                 h[1:, toks_to_mix] += noise_data
            h[1:, toks_to_mix] = 0  # zero-ablate in place; row 0 stays clean

#     # If this layer is in the patch_spec, restore the uncorrupted hidden state
#     # for selected tokens.
#     toks_to_patch = patch_spec.get(layer, [])
#     toks_to_unpatch = unpatch_spec.get(layer, [])

#     for t in toks_to_patch:
#         h[1:, t] = h[0, t]
#     for t in toks_to_unpatch:
#         h[1:, t] = untuple(first_pass_trace[layer].output)[1:, t]
    return x
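
To make the indexing in `patch_rep` concrete, here is a torch-free sketch of what `h[1:, toks_to_mix] = 0` does to a batch laid out as `[batch][token][feature]`. The helper name `zero_ablate` and the toy batch are illustrative assumptions; they are not part of `ctu` or `nethook`.

```python
import copy

def zero_ablate(batch, token_positions):
    """Return a copy of `batch` where rows 1.. have the given tokens zeroed.

    Row 0 is left untouched so it can serve as the clean reference run.
    (Hypothetical helper; the notebook edits the tensor in place instead.)
    """
    out = copy.deepcopy(batch)
    for row in out[1:]:
        for t in token_positions:
            row[t] = [0.0] * len(row[t])
    return out

# Toy batch: 2 copies of a 3-token sequence with 2 features per token.
clean = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
batch = [copy.deepcopy(clean), copy.deepcopy(clean)]
patched = zero_ablate(batch, token_positions=[0, 1])
print(patched[0])  # clean run is unchanged
print(patched[1])  # tokens 0 and 1 are zeroed, token 2 is kept
```

Keeping row 0 clean is what lets a single forward pass yield both the reference and the corrupted prediction.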

In [759]:
text_range

(0, 21)

In [771]:
# corrupt_spec = {ctu.layername_uskg(mt_uskg.model, "encoder", 0, "embed") : list(range(*text_range))}
corrupt_spec = {ctu.layername_uskg(mt_uskg.model, "encoder", 23) : list(range(*text_range))}

hook_layers = [ctu.layername_uskg(mt_uskg.model, "encoder", l) for l in [0, 12, 23]] + \
    [ctu.layername_uskg(mt_uskg.model, "encoder", 0, "embed")]

with torch.no_grad(), nethook.TraceDict(
    mt_uskg.model,
    layers=hook_layers,
    edit_output=patch_rep,
) as td:
    outputs_exp = ctu.run_model_forward_uskg(mt_uskg.model, **inp)

In [772]:
hidden, attn = td['pretrain_model.encoder.block.0'].output
hidden.size(), attn.size()

(torch.Size([2, 183, 1024]), torch.Size([2, 16, 183, 193]))

In [773]:
hidden[0], hidden[1]

(tensor([[ -0.2820,  -0.0777, -10.0541,  ...,  18.3919,  19.2083,   4.6788],
         [  2.3624,  -4.3468,   2.1581,  ...,   0.8777,  -2.4387,   1.6466],
         [ -0.4869,   5.1502,   1.1220,  ...,   9.9191,  -7.7408,  -6.1049],
         ...,
         [  7.9017,  -2.7019, -14.2257,  ...,   6.8594,  -0.4216,   0.3843],
         [ -2.5844,  13.8755,   1.1908,  ...,   1.7322,  -7.5010,   6.5653],
         [ 19.2494, -10.1652,   1.5651,  ...,   9.0200,  13.4675,  31.8201]],
        device='cuda:0'),
 tensor([[ -0.2820,  -0.0777, -10.0541,  ...,  18.3919,  19.2083,   4.6788],
         [  2.3624,  -4.3468,   2.1581,  ...,   0.8777,  -2.4387,   1.6466],
         [ -0.4869,   5.1502,   1.1220,  ...,   9.9191,  -7.7408,  -6.1049],
         ...,
         [  7.9017,  -2.7019, -14.2257,  ...,   6.8594,  -0.4216,   0.3843],
         [ -2.5844,  13.8755,   1.1908,  ...,   1.7322,  -7.5010,   6.5653],
         [ 19.2494, -10.1652,   1.5651,  ...,   9.0200,  13.4675,  31.8201]],
        device='cuda

In [None]:
hidden, attn = td['pretrain_model.encoder.block.12'].output
hidden[0], hidden[1]

In [774]:
hidden, attn = td['pretrain_model.encoder.block.23'].output
hidden[0], hidden[1]

(tensor([[ 431.2646,   22.2026,  -46.9836,  ...,  236.4605, -101.4929,
           -13.4173],
         [ 680.9009, -155.7530,   -3.9374,  ...,  357.9876,  -74.5153,
          -248.9389],
         [ 726.4300,  -97.9734,   82.3419,  ...,  280.2939,  -92.4423,
          -208.4543],
         ...,
         [ -12.3197,  -82.2136,  -86.8491,  ...,  343.5234,   46.3294,
           -30.3550],
         [ 430.2145,  197.3166, -176.3856,  ...,  235.9048,  -84.3116,
          -104.0558],
         [ 935.1066,  134.0760,   52.9306,  ...,  -17.0239,  -42.8478,
           -13.8100]], device='cuda:0'),
 tensor([[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
             0.0000],
         [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
             0.0000],
         [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
             0.0000],
         ...,
         [ -12.3197,  -82.2136,  -86.8491,  ...,  343.5234,   46.3294,
           -30.3550],
         [ 430.2145,  1

In [775]:
torch.std_mean(hidden[0]), torch.std_mean(hidden[1])

((tensor(3933.3369, device='cuda:0'), tensor(-22.5662, device='cuda:0')),
 (tensor(3931.8728, device='cuda:0'), tensor(-22.0143, device='cuda:0')))

In [776]:
probs = torch.softmax(outputs_exp.logits[:, -len(answers_t):, :], dim=-1)
probs.size()

torch.Size([2, 1, 32102])

In [777]:
sorted_probs = sorted([(p, i) for i, p in enumerate(probs[0, 0].cpu().tolist())], reverse=True)
sorted_probs[:10]

[(0.9999758005142212, 6004),
 (1.38331379275769e-05, 9),
 (7.489517884096131e-06, 30819),
 (1.270908683181915e-06, 20246),
 (5.996853360556997e-07, 21007),
 (4.935201332045835e-07, 26389),
 (1.2063669885264972e-07, 11584),
 (1.1599770033399182e-07, 6500),
 (1.004895580081211e-07, 291),
 (3.185817476492048e-08, 12497)]

In [778]:
print(ctu.decode_tokens(mt_uskg.tokenizer, [i for p, i in sorted_probs[:10]]))

['acc', 'a', 'accelerating', 'inclin', 'acco', 'accelerated', 'fast', 'assi', 'ar', 'appro']


In [779]:
sorted_probs = sorted([(p, i) for i, p in enumerate(probs[1, 0].cpu().tolist())], reverse=True)
sorted_probs[:10]

[(0.8888899087905884, 6004),
 (0.10935894399881363, 9),
 (0.0005981624126434326, 30819),
 (0.0003909562074113637, 21007),
 (0.0002998432901222259, 9993),
 (8.253266423707828e-05, 291),
 (7.506454858230427e-05, 144),
 (5.3470714192371815e-05, 8010),
 (5.256106669548899e-05, 11584),
 (4.682378130382858e-05, 23)]

In [780]:
print(ctu.decode_tokens(mt_uskg.tokenizer, [i for p, i in sorted_probs[:10]]))

['acc', 'a', 'accelerating', 'acco', 'speed', 'ar', 'at', 'auto', 'fast', 'i']
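
The cells above sort the full 32102-entry vocabulary to read off the ten most likely tokens; the clean row puts about 0.99998 on token 6004 (`acc`) while the zero-corrupted row puts about 0.889 on it. A possible shortcut, sketched here on a toy list, is `heapq.nlargest`, which returns the same pairs as `sorted(..., reverse=True)[:k]`. The helper name `top_k` and the toy probabilities are assumptions for illustration.

```python
import heapq

def top_k(probs, k=10):
    """Return the k largest (probability, token_id) pairs, highest first."""
    return heapq.nlargest(k, ((p, i) for i, p in enumerate(probs)))

# Toy distribution over a 4-token vocabulary (made-up numbers).
toy_probs = [0.05, 0.70, 0.20, 0.05]
print(top_k(toy_probs, k=2))
```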


## Temp

### One-time patch

In [496]:
# # expect_type = 'table_alias'
# orig_res_path = f'/home/yshao/Projects/rome/results/exp6_2_decoder_cross_attention_corruption_syntax/no_o_exp=6.2_dev_corrupt=zero.jsonl'
# add_res_path = f'/home/yshao/Projects/rome/results/exp6_2_decoder_cross_attention_corruption_syntax/exp=6.2+o_dev_corrupt=zero.jsonl'

# merge_res_path = f'/home/yshao/Projects/rome/results/exp6_2_decoder_cross_attention_corruption_syntax/exp=6.2_dev_corrupt=zero.jsonl'

In [497]:
# with open(orig_res_path, 'r') as f:
#     orig_all_samples = [json.loads(l) for l in f]
# with open(add_res_path, 'r') as f:
#     add_all_samples = [json.loads(l) for l in f]

# f = open(merge_res_path, 'w')

# for i, (orig_ex, add_ex) in enumerate(zip(orig_all_samples, add_all_samples)):
#     assert len(orig_ex['trace_results']) == len(add_ex['trace_results']), i
#     # There is randomness in the order of expected node (from set()), thus sorting here 
#     orig_ex['trace_results'].sort(key=lambda d: len(d['dec_prompt']))
#     add_ex['trace_results'].sort(key=lambda d: len(d['dec_prompt']))
#     for j, (orig_d, add_d) in enumerate(zip(orig_ex['trace_results'], add_ex['trace_results'])):
#         assert orig_d['is_good_sample'] == add_d['is_good_sample'], (i, j)
#         if not orig_d['is_good_sample']:
#             continue
            
#         # is good sample: add the new sections 
#         for k, v in add_d['trace_scores'].items():
#             if k in orig_d['trace_scores']:
#                 continue
#             orig_d['trace_scores'][k] = add_d['trace_scores'][k]
        
#     f.write(json.dumps(orig_ex, indent=None) + '\n')
    
# f.close()

### Debugging exp

In [497]:
ex = processed_spider_dev[97]
text_in = ex['text_in']
struct_in = ex['struct_in']

enc_sentence = f"{text_in}; structed knowledge: {struct_in}"
dec_prompt = "select t1.model from"
expect = "car_names"

In [498]:
inp = ctu.make_inputs_t5(
    mt_uskg.tokenizer,
    enc_sentences=[enc_sentence]*11,
    dec_prompts=[dec_prompt]*11,
    answer=expect
)

### RE

In [70]:
seq = 'aa,bb< cc  \t dd(  )ee <= ff=5 %h% "06-15".'
sep_pattern = r'\s+|\W'

all_matches = re.finditer(sep_pattern, seq)

In [71]:
all_matches = list(all_matches)
all_matches

[<re.Match object; span=(2, 3), match=','>,
 <re.Match object; span=(5, 6), match='<'>,
 <re.Match object; span=(6, 7), match=' '>,
 <re.Match object; span=(9, 13), match='  \t '>,
 <re.Match object; span=(15, 16), match='('>,
 <re.Match object; span=(16, 18), match='  '>,
 <re.Match object; span=(18, 19), match=')'>,
 <re.Match object; span=(21, 22), match=' '>,
 <re.Match object; span=(22, 23), match='<'>,
 <re.Match object; span=(23, 24), match='='>,
 <re.Match object; span=(24, 25), match=' '>,
 <re.Match object; span=(27, 28), match='='>,
 <re.Match object; span=(29, 30), match=' '>,
 <re.Match object; span=(30, 31), match='%'>,
 <re.Match object; span=(32, 33), match='%'>,
 <re.Match object; span=(33, 34), match=' '>,
 <re.Match object; span=(34, 35), match='"'>,
 <re.Match object; span=(37, 38), match='-'>,
 <re.Match object; span=(40, 41), match='"'>,
 <re.Match object; span=(41, 42), match='.'>]

In [72]:
all_matches[0].span()

(2, 3)

In [73]:
splits = [0] + [i for m in all_matches for i in m.span()] + [len(seq)]
splits = sorted(list(set(splits)))
print(splits)

[0, 2, 3, 5, 6, 7, 9, 13, 15, 16, 18, 19, 21, 22, 23, 24, 25, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 38, 40, 41, 42]


In [74]:
st = 0
SP = ["<=", ">=", "<>", "!="]
toks = []

for s, e in zip(splits[:-1], splits[1:]):
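    if not seq[s:e].strip():
        # whitespace run: skip it and start the next token after it
        st = e
    else:
        # a word or a single punctuation character
        if seq[s:s+2] in SP:
            # first half of a two-character operator: wait for the next piece
            continue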
        toks.append(seq[st:e])
        st = e
        

In [75]:
print(toks)

['aa', ',', 'bb', '<', 'cc', 'dd', '(', ')', 'ee', '<=', 'ff', '=', '5', '%', 'h', '%', '"', '06', '-', '15', '"', '.']
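
The cells in this section can be folded into one reusable function. The sketch below follows the same steps (collect separator boundaries, skip whitespace, glue two-character operators); the function name `split_tokens` and the constant name `MULTI_CHAR_OPS` are assumptions introduced here.

```python
import re

# Two-character operators that must stay together as one token.
MULTI_CHAR_OPS = ("<=", ">=", "<>", "!=")

def split_tokens(seq, sep_pattern=r"\s+|\W"):
    """Split `seq` into words and punctuation, dropping whitespace."""
    # Every separator boundary, plus the two ends of the string.
    bounds = {0, len(seq)}
    for m in re.finditer(sep_pattern, seq):
        bounds.update(m.span())
    bounds = sorted(bounds)

    toks, start = [], 0
    for s, e in zip(bounds[:-1], bounds[1:]):
        if not seq[s:e].strip():
            # Whitespace run: the next token starts after it.
            start = e
        elif seq[s:s + 2] in MULTI_CHAR_OPS:
            # First half of a two-character operator: wait for the second half.
            continue
        else:
            toks.append(seq[start:e])
            start = e
    return toks

print(split_tokens('aa,bb< cc  \t dd(  )ee <= ff=5 %h% "06-15".'))
```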


### temp observation

In [269]:
att_weights_dict[23][7]['occ_cols']['prefix#0']

[0.98,
 0.74,
 0.96,
 0.82,
 0.15,
 0.36,
 0.82,
 0.03,
 0.45,
 0.18,
 0.1,
 0.82,
 0.82,
 0.17,
 0.96,
 0.96,
 0.94,
 0.17,
 0.01]

In [270]:
sample_backtrace_dict['occ_cols']

[(1, 'accelerate'),
 (1, 'makeid'),
 (1, 'make'),
 (2, 'city'),
 (2, 'airportcode'),
 (2, 'destairport'),
 (2, 'city'),
 (4, 'loser_name'),
 (5, 'first_name'),
 (5, 'middle_name'),
 (5, 'last_name'),
 (5, 'date_first_registered'),
 (6, 'birth_date'),
 (6, 'earnings'),
 (7, 'continent'),
 (7, 'continent'),
 (8, 'name'),
 (8, 'id'),
 (9, 'treatment_type_description')]

In [274]:
layer_id = 23
head_id = 8
occ_k = 'occ_cols'
sect_k = 'text'

for att_w, (ex_id, col_name) in zip(att_weights_dict[layer_id][head_id][occ_k][sect_k], sample_backtrace_dict[occ_k]):
    ex = processed_spider_dev[ex_id * 111]
    print(ex['text_in'], '-->', ex['seq_out'])
    print(f'{col_name}: {att_w:.4f}')
    print()

What is the accelerate of the car make amc hornet sportabout (sw)? --> select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id = t2.makeid where t2.make = 'amc hornet sportabout (sw)';
accelerate: 0.5900

What is the accelerate of the car make amc hornet sportabout (sw)? --> select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id = t2.makeid where t2.make = 'amc hornet sportabout (sw)';
makeid: 0.0900

What is the accelerate of the car make amc hornet sportabout (sw)? --> select t1.accelerate from cars_data as t1 join car_names as t2 on t1.id = t2.makeid where t2.make = 'amc hornet sportabout (sw)';
make: 0.8200

Which city has the most frequent destination airport? --> select t1.city from airports as t1 join flights as t2 on t1.airportcode = t2.destairport group by t1.city order by count(*) desc limit 1
city: 0.6100

Which city has the most frequent destination airport? --> select t1.city from airports as t1 join flights as t2 on t1.airportcode = t2.des

In [435]:
ex = processed_spider_dev[2]
ex['question']

'Show name, country, age for all singers ordered by age from the oldest to the youngest.'

In [710]:
_path = '/home/yshao/Projects/rome/results/exp4_1_attention_weights_distribution/exp=4.1_dev.jsonl'

with open(_path, 'r') as f:
    all_dev_samples = [ujson.loads(l) for l in tqdm(f)]
len(all_dev_samples)

1034

In [713]:
# del all_dev_samples
# with open(_path, 'r') as f:
#     all_dev_samples = [json.loads(l) for l in tqdm(f)]
# len(all_dev_samples)

In [715]:
all_dev_samples[0]['trace_results'].keys()

dict_keys(['enc_sentence', 'seq_out', 'dec_prompt', 'db_id', 'col_self_ranges', 'col_context_ranges', 'tab_self_ranges', 'tab_context_ranges', 'category', 'occ_cols', 'non_occ_cols', 'occ_tabs', 'non_occ_tabs', 'attentions'])

In [718]:
all_dev_samples[0]['trace_results']['attentions']['col'].keys()

dict_keys(['location', 'capacity', 'highest', 'lowest', 'average', 'country', 'song_name', 'song_release_year', 'age', 'is_male', 'concert_name', 'theme', 'year'])

In [716]:
for i, ex in enumerate(all_dev_samples):
    occ_cols = ex['trace_results']['occ_cols']
    if len(set(occ_cols)) != len(occ_cols):
        print(i, occ_cols)

2 ['country', 'age', 'age']
3 ['country', 'age', 'age']
4 ['age', 'age', 'age', 'country']
5 ['age', 'age', 'age', 'country']
10 ['country', 'country']
11 ['country', 'country']
12 ['song_name', 'age', 'age']
13 ['song_name', 'age', 'age']
17 ['capacity', 'capacity']
20 ['year', 'year']
21 ['year', 'year']
26 ['year', 'year']
27 ['year', 'year']
30 ['country', 'age', 'country', 'age']
41 ['location', 'year', 'location', 'year']
42 ['location', 'year', 'location', 'year']
49 ['weight', 'pettype', 'pettype']
50 ['weight', 'pettype', 'pettype']
57 ['fname', 'pettype', 'pettype']
58 ['fname', 'pettype', 'pettype']
59 ['fname', 'pettype', 'fname', 'pettype']
60 ['fname', 'pettype', 'fname', 'pettype']
65 ['fname', 'age', 'pettype', 'pettype']
66 ['fname', 'age', 'pettype', 'pettype']
71 ['pet_age', 'pet_age', 'pettype', 'pettype']
72 ['pet_age', 'pet_age', 'pettype', 'pettype']
73 ['weight', 'pettype', 'pettype']
74 ['weight', 'pettype', 'pettype']
89 ['contid', 'contid', 'contid']
90 ['con
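
The loop above only flags that a sample contains repeated columns. If it is also useful to know which columns repeat and how often, a `Counter` gives that directly. This is a small sketch; the helper name `duplicated_items` is an assumption, and the two inputs are copied from the output above.

```python
from collections import Counter

def duplicated_items(items):
    """Map each value that occurs more than once to its count."""
    return {v: c for v, c in Counter(items).items() if c > 1}

print(duplicated_items(['country', 'age', 'age']))
print(duplicated_items(['location', 'year', 'location', 'year']))
```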

### other temp

In [311]:
mt_uskg.tokenizer.tokenize('school_name')

['▁school', '_', 'name']

In [236]:
mt_uskg.tokenizer.tokenize('cylinder ')

['▁', 'cylinder']

In [233]:
mt_uskg.tokenizer.tokenize('cylinder xa')

['▁', 'cylinder', '▁', 'x', 'a']

In [240]:
mt_uskg.tokenizer.tokenize('structed_input : a | b')

['▁', 'struct', 'e', 'd', '_', 'in', 'put', '▁', ':', '▁', 'a', '▁|', '▁', 'b']

In [242]:
_sql = 'SELECT t2.aaa , COUNT(distinct t1.name) FROM cars_data as t1 JOIN models as t2 on cars_data.a_a = models.b_a'

mt_uskg.tokenizer.tokenize(_sql)

['▁',
 'SEL',
 'ECT',
 '▁',
 't',
 '2.',
 'a',
 'a',
 'a',
 '▁',
 ',',
 '▁CO',
 'UNT',
 '(',
 'distin',
 'c',
 't',
 '▁',
 't',
 '1.',
 'name',
 ')',
 '▁FROM',
 '▁cars',
 '_',
 'data',
 '▁as',
 '▁',
 't',
 '1',
 '▁',
 'JO',
 'IN',
 '▁models',
 '▁as',
 '▁',
 't',
 '2',
 '▁on',
 '▁cars',
 '_',
 'data',
 '.',
 'a',
 '_',
 'a',
 '▁=',
 '▁models',
 '.',
 'b',
 '_',
 'a']

In [342]:
N_layers = mt_uskg.num_dec_layers

layers_range_dict = {
    'q1_layers': range(0, N_layers // 4),
    'q2_layers': range(N_layers // 4, N_layers // 2),
    'q3_layers': range(N_layers // 2, N_layers * 3 // 4),
    'q4_layers': range(N_layers * 3 // 4, N_layers),
    'low_layers': range(0, N_layers // 2),
    'mid_layers': range(N_layers // 4, N_layers * 3 // 4),
    'high_layers': range(N_layers // 2, N_layers),
    'all_layers': range(N_layers),
}

layers_range_dict

{'q1_layers': range(0, 6),
 'q2_layers': range(6, 12),
 'q3_layers': range(12, 18),
 'q4_layers': range(18, 24),
 'low_layers': range(0, 12),
 'mid_layers': range(6, 18),
 'high_layers': range(12, 24),
 'all_layers': range(0, 24)}
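
The same quartile and half splits can be wrapped in a function so they can be rebuilt for a different layer count. This is a sketch; the name `make_layer_ranges` is an assumption, and the keys mirror the cell above.

```python
def make_layer_ranges(n_layers):
    """Build the quartile / half / full layer ranges for `n_layers` layers."""
    q1, half, q3 = n_layers // 4, n_layers // 2, n_layers * 3 // 4
    return {
        'q1_layers': range(0, q1),
        'q2_layers': range(q1, half),
        'q3_layers': range(half, q3),
        'q4_layers': range(q3, n_layers),
        'low_layers': range(0, half),
        'mid_layers': range(q1, q3),
        'high_layers': range(half, n_layers),
        'all_layers': range(n_layers),
    }

print(make_layer_ranges(24))
```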

In [391]:
d = {
    'a': {
        '1': 1.0,
        '2': 2.0,
    },
    'b': {
        '1': 2.0,
        '2': 1.0,
    },
}
reverse_2D_dict(d)

defaultdict(<function __main__.reverse_2D_dict.<locals>.<lambda>()>,
            {'1': defaultdict(float, {'a': 1.0, 'b': 2.0}),
             '2': defaultdict(float, {'a': 2.0, 'b': 1.0})})
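
`reverse_2D_dict` is defined elsewhere in the notebook and is not shown in this section. Judging only from the output above (an outer `defaultdict` whose values are `defaultdict(float)`), it swaps the two key levels. The sketch below is a guess at an equivalent implementation, named `reverse_2d_dict_sketch` to make clear that it is not the original.

```python
from collections import defaultdict

def reverse_2d_dict_sketch(d):
    """Swap the key levels of a nested dict: d[a][b] becomes out[b][a]."""
    out = defaultdict(lambda: defaultdict(float))
    for outer_key, inner in d.items():
        for inner_key, value in inner.items():
            out[inner_key][outer_key] = value
    return out

d = {'a': {'1': 1.0, '2': 2.0}, 'b': {'1': 2.0, '2': 1.0}}
print(reverse_2d_dict_sketch(d))
```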

In [None]:
{**{'a': 1}, **{'b': 2}}  # dicts do not support `+`; merge by unpacking instead

In [811]:
mt_uskg.tokenizer.tokenize('(distinct)')

['▁(', 'distin', 'c', 't', ')']

## (placeholder)