In [1]:
import json
import os, sys
from sys import modules
import re
import _jsonnet
from tqdm.notebook import tqdm
import spacy
import networkx as nx
import numpy as np
import random
import importlib
from copy import deepcopy
import editdistance
import datetime
from collections import Counter, defaultdict
import sqlite3
import time
import pickle

from sklearn.linear_model import LogisticRegression

from nltk.translate.bleu_score import corpus_bleu
from nltk.tokenize.treebank import TreebankWordDetokenizer

import torch

from ratsql.utils import registry, batched_sequence
from ratsql.commands.infer import Inferer
from ratsql.datasets.spider import SpiderDataset, SpiderItem, Column, Table, Schema

from language.xsp.data_preprocessing import spider_preprocessing, wikisql_preprocessing, michigan_preprocessing

from sdr_analysis.helpers.general_helpers import collect_link_prediction_samples, db_dict_to_general_fmt
from sdr_analysis.helpers.ratsql_helpers import general_fmt_dict_to_ratsql_schema, extract_probing_samples_link_prediction_new
from sdr_analysis.helpers.legacy import db_dict_to_ratsql_schema, extract_probing_samples_link_prediction



## Read data - Schema

### Original way (spider)

In [None]:
root_dir='/Users/mac/Desktop/syt/Deep-Learning/Repos/rat-sql'
exp_config_path='/Users/mac/Desktop/syt/Deep-Learning/Repos/rat-sql/experiments/spider-glove-ASR-run.jsonnet'
model_dir='/Users/mac/Desktop/syt/Deep-Learning/Repos/rat-sql/logdir/glove_ASR_run/ASR,bs=20,lr=7.4e-04,end_lr=0e0,att=0'
checkpoint_step=40000

In [None]:
exp_config = json.loads(_jsonnet.evaluate_file(exp_config_path))

model_config_path = os.path.join(root_dir, exp_config["model_config"])
model_config_args = exp_config.get("model_config_args")

infer_config = json.loads(_jsonnet.evaluate_file(model_config_path, tla_codes={'args': json.dumps(model_config_args)}))


In [None]:
infer_config

In [None]:
dataset = SpiderDataset(
    db_path='data/spider/database',
    paths=['data/spider/my/train/train_asr_amazon.json'],
    tables_paths=['data/spider/tables.json'],
)

In [None]:
dataset[0].schema.__dict__

### New way (for XSP)

In [4]:
## assume that re.split(r'[_ ]', orig_col_name) == col_name

In [5]:
xsp_data_dir = "/Users/mac/Desktop/syt/Deep-Learning/Repos/Google-Research-Language/language/language/xsp/data"

spider_tables_path = os.path.join(xsp_data_dir, 'spider', 'tables.json')

spider_dbs_dict = spider_preprocessing.load_spider_tables(spider_tables_path)

In [6]:
spider_dbs_dict.keys()

dict_keys(['perpetrator', 'college_2', 'flight_company', 'icfp_1', 'body_builder', 'storm_record', 'pilot_record', 'race_track', 'academic', 'department_store', 'music_4', 'insurance_fnol', 'cinema', 'decoration_competition', 'phone_market', 'store_product', 'assets_maintenance', 'student_assessment', 'dog_kennels', 'music_1', 'company_employee', 'farm', 'solvency_ii', 'city_record', 'swimming', 'flight_2', 'election', 'manufactory_1', 'debate', 'network_2', 'local_govt_in_alabama', 'climbing', 'e_learning', 'scientist_1', 'ship_1', 'entertainment_awards', 'allergy_1', 'imdb', 'products_for_hire', 'candidate_poll', 'chinook_1', 'flight_4', 'pets_1', 'dorm_1', 'journal_committee', 'flight_1', 'medicine_enzyme_interaction', 'local_govt_and_lot', 'station_weather', 'shop_membership', 'driving_school', 'concert_singer', 'music_2', 'sports_competition', 'railway', 'inn_1', 'museum_visit', 'browser_web', 'baseball_1', 'architecture', 'csu_1', 'tracking_orders', 'insurance_policies', 'gas_com

In [22]:
# def _get_clean_name(orig_name):
#     return ' '.join(re.split(r'[_ ]', orig_name.lower()))

# _get_clean_name('WWW3_ID'), _get_clean_name('Home Town'), _get_clean_name('W/L')

('www3 id', 'home town', 'w/l')

In [23]:
db_id = 'imdb'
db_dict = spider_dbs_dict[db_id]
db_dict

{'actor': [{'field name': 'aid',
   'is primary key': True,
   'is foreign key': True,
   'type': 'number'},
  {'field name': 'gender',
   'is primary key': False,
   'is foreign key': False,
   'type': 'text'},
  {'field name': 'name',
   'is primary key': False,
   'is foreign key': False,
   'type': 'text'},
  {'field name': 'nationality',
   'is primary key': False,
   'is foreign key': False,
   'type': 'text'},
  {'field name': 'birth_city',
   'is primary key': False,
   'is foreign key': False,
   'type': 'text'},
  {'field name': 'birth_year',
   'is primary key': False,
   'is foreign key': False,
   'type': 'number'}],
 'copyright': [{'field name': 'id',
   'is primary key': True,
   'is foreign key': False,
   'type': 'number'},
  {'field name': 'msid',
   'is primary key': False,
   'is foreign key': True,
   'type': 'number'},
  {'field name': 'cid',
   'is primary key': False,
   'is foreign key': False,
   'type': 'number'}],
 'cast': [{'field name': 'id',
   'is primar

In [24]:
# test_db_dict = dict()
# for table_name, table_cols in db_dict.items():
#     test_db_dict[table_name] = table_cols[:2]
# test_db_dict

In [None]:
test_schema = db_dict_to_ratsql_schema(db_dict, db_id,
                                       sqlite_path=f"/Users/mac/Desktop/syt/Deep-Learning/Repos/Google-Research-Language/language/language/xsp/data/spider/database/{db_id}/{db_id}.sqlite",
                                       rigorous_foreign_key=True,
                                       debug=True)
test_schema

In [None]:
test_general_fmt_dict = db_dict_to_general_fmt(db_dict, db_id,
                                               sqlite_path=f"/Users/mac/Desktop/syt/Deep-Learning/Repos/Google-Research-Language/language/language/xsp/data/spider/database/{db_id}/{db_id}.sqlite",
                                               rigorous_foreign_key=True,
                                               debug=True)
test_general_fmt_dict

In [None]:
test_schema_2 = general_fmt_dict_to_ratsql_schema(test_general_fmt_dict,
                                                  debug=True)
test_schema_2

In [27]:
str(test_schema.tables) == str(test_schema_2.tables), \
str(test_schema.columns) == str(test_schema_2.columns)

(True, True)

## Load model
- TODO: use unified way to load schema

In [33]:
def Load_Rat_sql(root_dir,
                 exp_config_path,
                 model_dir,
                 checkpoint_step=40000):

    exp_config = json.loads(_jsonnet.evaluate_file(exp_config_path))
    
    model_config_path = os.path.join(root_dir, exp_config["model_config"])
    model_config_args = exp_config.get("model_config_args")
    
    infer_config = json.loads(_jsonnet.evaluate_file(model_config_path, tla_codes={'args': json.dumps(model_config_args)}))

    inferer = Inferer(infer_config)
    inferer.device = torch.device("cpu")
    model = inferer.load_model(model_dir, checkpoint_step)
    
#     dataset = registry.construct('dataset', inferer.config['data']['val'])
#     for _, schema in dataset.schemas.items():
#         model.preproc.enc_preproc._preprocess_schema(schema)
    
    _ret_dict = {
        'model': model,
        'inferer': inferer,
    }
    
    return _ret_dict
    

In [34]:
rat_sql_model_dict = Load_Rat_sql(root_dir='/Users/mac/Desktop/syt/Deep-Learning/Repos/rat-sql',
                                  exp_config_path='/Users/mac/Desktop/syt/Deep-Learning/Repos/rat-sql/experiments/spider-glove-run.jsonnet',
                                  model_dir='/Users/mac/Desktop/syt/Deep-Learning/Repos/rat-sql/logdir/glove_run/bs=20,lr=7.4e-04,end_lr=0e0,att=0',
                                  checkpoint_step=40000)



Loading model from /Users/mac/Desktop/syt/Deep-Learning/Repos/rat-sql/logdir/glove_run/bs=20,lr=7.4e-04,end_lr=0e0,att=0/model_checkpoint-00040000


In [None]:
rat_sql_model_dict['model']

In [36]:
def Question(q, db_schema, model_dict):
    model = model_dict['model']
    inferer = model_dict['inferer']
    
    data_item = SpiderItem(
        text=None,  # intentionally None -- should be ignored when the tokenizer is set correctly
        code=None,
        schema=db_schema,
        orig_schema=db_schema.orig,
        orig={"question": q}
    )
    
    model.preproc.clear_items()
    enc_input = model.preproc.enc_preproc.preprocess_item(data_item, None)
    preproc_data = enc_input, None
    
    with torch.no_grad():
        return inferer._infer_one(model, data_item, preproc_data, beam_size=1, use_heuristic=True)
    

In [None]:
Question('name of people', test_schema, rat_sql_model_dict)

## Infer on WikiSQL

In [17]:
wikisql_tables_path = os.path.join(xsp_data_dir, 'wikisql', 'dev.tables.jsonl')
wikisql_dataset_path = os.path.join(xsp_data_dir, 'wikisql', 'dev.jsonl')
wikisql_sqlite_path = os.path.join(xsp_data_dir, 'wikisql', 'dev.db')

In [18]:
wikisql_dbs_dict = wikisql_preprocessing.load_wikisql_tables(wikisql_tables_path)

with open(wikisql_dataset_path, 'r') as f:
    wikisql_dataset = [json.loads(l) for l in f]

len(wikisql_dbs_dict), len(wikisql_dataset)

(2716, 8421)

In [None]:
## too many duplicated sqlite connections, not ok
# wikisql_schemas_dict = dict()
# for db_id, db_dict in wikisql_dbs_dict.items():
#     wikisql_schemas_dict[db_id] = db_dict_to_ratsql_schema(db_dict, db_id, wikisql_sqlite_path)

In [19]:
wikisql_dataset[0]

{'phase': 1,
 'table_id': '1-10015132-11',
 'question': 'What position does the player who played for butler cc (ks) play?',
 'sql': {'sel': 3, 'conds': [[5, 0, 'Butler CC (KS)']], 'agg': 0}}

In [20]:
sample = wikisql_dataset[0]
db_id = sample['table_id']
db_dict = wikisql_dbs_dict[db_id]
db_schema = db_dict_to_ratsql_schema(db_dict, db_id, wikisql_sqlite_path)

In [22]:
Question(q=sample['question'], db_schema=db_schema, model_dict=rat_sql_model_dict)

[{'orig_question': 'What position does the player who played for butler cc (ks) play?',
  'model_output': {'_type': 'sql',
   'select': {'_type': 'select',
    'is_distinct': False,
    'aggs': [{'_type': 'agg',
      'agg_id': {'_type': 'NoneAggOp'},
      'val_unit': {'_type': 'Column',
       'col_unit1': {'_type': 'col_unit',
        'agg_id': {'_type': 'NoneAggOp'},
        'col_id': 4,
        'is_distinct': False}}},
     {'_type': 'agg',
      'agg_id': {'_type': 'NoneAggOp'},
      'val_unit': {'_type': 'Column',
       'col_unit1': {'_type': 'col_unit',
        'agg_id': {'_type': 'NoneAggOp'},
        'col_id': 1,
        'is_distinct': False}}}]},
   'sql_where': {'_type': 'sql_where',
    'where': {'_type': 'Eq',
     'val_unit': {'_type': 'Column',
      'col_unit1': {'_type': 'col_unit',
       'agg_id': {'_type': 'NoneAggOp'},
       'col_id': 1,
       'is_distinct': False}},
     'val1': {'_type': 'Terminal'}}},
   'sql_groupby': {'_type': 'sql_groupby'},
   'sql_orde

## Infer on Michigan

In [24]:
atis_table_path = os.path.join(xsp_data_dir, 'atis', 'atis_schema.csv')
atis_dataset_path = os.path.join(xsp_data_dir, 'atis', 'atis.json')
atis_sqlite_path = os.path.join(xsp_data_dir, 'databases', 'atis.db')

In [25]:
atis_db_dict = michigan_preprocessing.read_schema(atis_table_path)

with open(atis_dataset_path, 'r') as f:
    atis_dataset = json.load(f)

len(atis_db_dict), len(atis_dataset)

(25, 947)

In [27]:
atis_schema = db_dict_to_ratsql_schema(atis_db_dict, 'atis', atis_sqlite_path)



In [26]:
atis_dataset[0]

{'comments': [],
 'old-name': '',
 'query-split': 'train',
 'sentences': [{'text': 'list all the flights that arrive at airport_code0 from various cities',
   'question-split': 'train',
   'variables': {'airport_code0': 'MKE'}},
  {'text': 'what flights from any city land at airport_code0',
   'question-split': 'train',
   'variables': {'airport_code0': 'MKE'}},
  {'text': 'show me the flights into airport_code0',
   'question-split': 'train',
   'variables': {'airport_code0': 'DAL'}},
  {'text': 'show me the flights arriving at airport_code0',
   'question-split': 'train',
   'variables': {'airport_code0': 'DAL'}},
  {'text': 'list all the flights that arrive at airport_code0',
   'question-split': 'train',
   'variables': {'airport_code0': 'MKE'}},
  {'text': 'list all the arriving flights at airport_code0',
   'question-split': 'train',
   'variables': {'airport_code0': 'MKE'}},
  {'text': 'what flights land at airport_code0',
   'question-split': 'train',
   'variables': {'airport_

In [31]:
sample = atis_dataset[0]
preds = []
for question_dict in sample['sentences'][:3]:
    q = question_dict['text']
    for k, v in question_dict['variables'].items():
        q = q.replace(k, v)
    
    pred = Question(q=q, db_schema=atis_schema, model_dict=rat_sql_model_dict)
    preds.append(pred)

In [32]:
preds

[[{'orig_question': 'list all the flights that arrive at MKE from various cities',
   'model_output': {'_type': 'sql',
    'select': {'_type': 'select',
     'is_distinct': False,
     'aggs': [{'_type': 'agg',
       'agg_id': {'_type': 'NoneAggOp'},
       'val_unit': {'_type': 'Column',
        'col_unit1': {'_type': 'col_unit',
         'agg_id': {'_type': 'NoneAggOp'},
         'col_id': 81,
         'is_distinct': False}}}]},
    'sql_where': {'_type': 'sql_where',
     'where': {'_type': 'Eq',
      'val_unit': {'_type': 'Column',
       'col_unit1': {'_type': 'col_unit',
        'agg_id': {'_type': 'NoneAggOp'},
        'col_id': 88,
        'is_distinct': False}},
      'val1': {'_type': 'Terminal'}}},
    'sql_groupby': {'_type': 'sql_groupby'},
    'sql_orderby': {'_type': 'sql_orderby', 'limit': False},
    'sql_ieu': {'_type': 'sql_ieu'},
    'from': {'_type': 'from',
     'table_units': [{'_type': 'Table', 'table_id': 14}]}},
   'inferred_code': "SELECT FLIGHT.FLIGHT_DAYS

## Get rat-sql graph

In [60]:
def get_rat_sql_graph(question, db_schema, model):
    """
    Args:
        question (str)
        db_schema (ratsql.datasets.spider.Schema): output from db_dict_to_ratsql_schema()
        model (ratsql.models.EncDec)
    
    Return:
        rat_sql_graph_dict: Dict[
            "nodes" (List[str]): the name of nodes (question toks, columns, tables)
            "relations" (np.array): the integer relation matrix, shape = (N, N) where N = #nodes
            "relation_id2name" (Dict[int, ?]): translates the integer relation to readable name (str or tuple)
        ]
    """
    
    data_item = SpiderItem(
        text=None,  # intentionally None -- should be ignored when the tokenizer is set correctly
        code=None,
        schema=db_schema,
        orig_schema=db_schema.orig,
        orig={"question": question}
    )

    model.preproc.clear_items()
    enc_input = model.preproc.enc_preproc.preprocess_item(data_item, None)

    desc = enc_input
    
    ## Adapted from SpiderEncoderV2.forward
    q_enc, _ = model.encoder.question_encoder([[desc['question']]])
    c_enc, c_boundaries = model.encoder.column_encoder([desc['columns']])
    t_enc, t_boundaries = model.encoder.table_encoder([desc['tables']])
    
    ## Adapted from RelationalTransformerUpdate.forward
    enc = batched_sequence.PackedSequencePlus.cat_seqs((q_enc, c_enc, t_enc))

    q_enc_lengths = list(q_enc.orig_lengths())
    c_enc_lengths = list(c_enc.orig_lengths())
    t_enc_lengths = list(t_enc.orig_lengths())
    enc_lengths = list(enc.orig_lengths())
    max_enc_length = max(enc_lengths)

    enc_length = enc_lengths[0]
    relations = model.encoder.encs_update.compute_relations(
        desc,
        enc_length,
        q_enc_lengths[0],
        c_enc_lengths[0],
        c_boundaries[0],
        t_boundaries[0])
    
    ## Collect nodes 
    nodes = []

    nodes.extend(enc_input['question'])

    for c_id, c_toks in enumerate(enc_input['columns']):
        c_name = '_'.join(c_toks[1:])
        t_id = enc_input['column_to_table'][str(c_id)]
        if t_id is None:
            t_name = 'NONE'
        else:
            t_toks = enc_input['tables'][t_id]
            t_name = '_'.join(t_toks)
        c_save_name = f'<C>{t_name}::{c_name}'
        nodes.append(c_save_name)

    for t_toks in enc_input['tables']:
        nodes.append('<T>' + '_'.join(t_toks))
    
    ## Get relation_id2name (a constant, just passing for convenience)
    relation_id2name = {v : k for k, v in model.encoder.encs_update.relation_ids.items()}
    
    return {
        'nodes': nodes,
        'relations': relations,
        'relation_id2name': relation_id2name,
        'q_nodes_orig': enc_input['question_for_copying'],
    }
    

In [62]:
db_id = 'store_product'
db_dict = spider_dbs_dict[db_id]

general_fmt_dict = db_dict_to_general_fmt(db_dict,
                                          db_id,
                                          sqlite_path=f"/Users/mac/Desktop/syt/Deep-Learning/Repos/Google-Research-Language/language/language/xsp/data/spider/database/{db_id}/{db_id}.sqlite")

db_schema = general_fmt_dict_to_ratsql_schema(general_fmt_dict)

# db_schema = db_dict_to_ratsql_schema(db_dict,
#                                      db_id,
#                                      sqlite_path=f"/Users/mac/Desktop/syt/Deep-Learning/Repos/Google-Research-Language/language/language/xsp/data/spider/database/{db_id}/{db_id}.sqlite")

question = "How many singers are from France?"

model = rat_sql_model_dict['model']

rat_sql_graph_dict = get_rat_sql_graph(question=question, db_schema=db_schema, model=model)

In [63]:
rat_sql_graph_dict

{'nodes': ['how',
  'many',
  'singer',
  'be',
  'from',
  'france',
  '?',
  '<C>NONE::*',
  '<C>product::product_id',
  '<C>product::product',
  '<C>product::dimension',
  '<C>product::dpi',
  '<C>product::page_per_minute_color',
  '<C>product::max_page_size',
  '<C>product::interface',
  '<C>store::store_id',
  '<C>store::store_name',
  '<C>store::type',
  '<C>store::area_size',
  '<C>store::number_of_product_category',
  '<C>store::rank',
  '<C>district::district_id',
  '<C>district::district_name',
  '<C>district::headquarter_city',
  '<C>district::city_population',
  '<C>district::city_area',
  '<C>store_product::store_id',
  '<C>store_product::product_id',
  '<C>store_district::store_id',
  '<C>store_district::district_id',
  '<T>product',
  '<T>store',
  '<T>district',
  '<T>store_product',
  '<T>store_district'],
 'relations': array([[ 2,  3,  4, ...,  6,  6,  6],
        [ 1,  2,  3, ...,  6,  6,  6],
        [ 0,  1,  2, ...,  6,  6,  6],
        ...,
        [22, 22, 22, .

In [None]:
rat_sql_graph_dict['relation_id2name']

### Save graph to dataset file

In [71]:
orig_ds = 'train'
orig_dataset_file = f"/Users/mac/Desktop/syt/Deep-Learning/Repos/Google-Research-Language/language/language/xsp/data/spider/{orig_ds}.json"
orig_tables_file = "/Users/mac/Desktop/syt/Deep-Learning/Repos/Google-Research-Language/language/language/xsp/data/spider/tables.json"

output_dir = "/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SDR-analysis/data/spider"
output_dataset_file = os.path.join(output_dir, f"{orig_ds}+ratsql_graph.json")

os.makedirs(output_dir, exist_ok=True)

In [72]:
spider_dbs_dict = spider_preprocessing.load_spider_tables(orig_tables_file)
len(spider_dbs_dict)

166

In [73]:
spider_db_schemas_dict = dict()
for db_id, db_dict in spider_dbs_dict.items():
#     db_schema = db_dict_to_ratsql_schema(db_dict, db_id,
#                                          sqlite_path=os.path.join(xsp_data_dir, f"spider/database/{db_id}/{db_id}.sqlite"),
#                                          rigorous_foreign_key=True)
    general_fmt_dict = db_dict_to_general_fmt(db_dict, db_id,
                                              sqlite_path=os.path.join(xsp_data_dir, f"spider/database/{db_id}/{db_id}.sqlite"),
                                              rigorous_foreign_key=True)
    db_schema = general_fmt_dict_to_ratsql_schema(general_fmt_dict)

    spider_db_schemas_dict[db_id] = db_schema
len(spider_db_schemas_dict)



166

In [74]:
with open(orig_dataset_file, 'r') as f:
    orig_dataset = json.load(f)

len(orig_dataset)

7000

In [75]:
model = rat_sql_model_dict['model']

for d in tqdm(orig_dataset):
    db_id = d['db_id']
    db_schema = spider_db_schemas_dict[db_id]
    question = d['question']

    # get relation matrix
    graph_dict = get_rat_sql_graph(question=question, db_schema=db_schema, model=model)
    nodes = graph_dict['nodes']
    q_nodes_orig = graph_dict['q_nodes_orig']
    relations = json.dumps(graph_dict['relations'].tolist(), indent=None)  # dump to a line to save space in json
    
    d['rat_sql_graph'] = {
        'nodes': nodes,
        'q_nodes_orig': q_nodes_orig,
        'relations': relations
    }
    
    time.sleep(0.2)

HBox(children=(FloatProgress(value=0.0, max=7000.0), HTML(value='')))




In [76]:
with open(output_dataset_file, 'w') as f:
    json.dump(orig_dataset, f, indent=2)

## Get rat-sql encoding

In [7]:
def get_rat_sql_encoder_state(question, db_schema, model):
    """
    Args:
        question (str)
        db_schema (ratsql.datasets.spider.Schema): output from db_dict_to_ratsql_schema()
        model (ratsql.models.EncDec)
    
    Return:
        rat_sql_encoder_state (EncoderState): check SpiderEncoderV2.forward return object
    """
    
    data_item = SpiderItem(
        text=None,  # intentionally None -- should be ignored when the tokenizer is set correctly
        code=None,
        schema=db_schema,
        orig_schema=db_schema.orig,
        orig={"question": question}
    )

    model.preproc.clear_items()
    enc_input = model.preproc.enc_preproc.preprocess_item(data_item, None)
    
    ## Adapted from EncDec.begin_inference
    with torch.no_grad():
        if getattr(model.encoder, 'batched'):
            enc_state, = model.encoder([enc_input])
        else:
            enc_state = model.encoder(enc_input)
            
    return enc_state
    

In [None]:
db_id = 'concert_singer'
db_dict = spider_dbs_dict[db_id]
db_schema = db_dict_to_ratsql_schema(db_dict,
                                     db_id,
                                     sqlite_path=f"/Users/mac/Desktop/syt/Deep-Learning/Repos/Google-Research-Language/language/language/xsp/data/spider/database/{db_id}/{db_id}.sqlite")

question = "How many singers are from France?"

model = rat_sql_model_dict['model']

rat_sql_encoder_state = get_rat_sql_encoder_state(question=question, db_schema=db_schema, model=model)

In [51]:
rat_sql_encoder_state

SpiderEncoderState(state=None, memory=tensor([[[ 0.0853, -0.5688, -0.0971,  ..., -0.1367, -0.1160,  0.3701],
         [ 0.1361, -0.4921, -0.1026,  ..., -0.1142, -0.1253,  0.4126],
         [ 0.4294, -0.2658, -0.3613,  ..., -0.1438,  0.0771, -0.4309],
         ...,
         [ 0.3001, -0.3542, -0.3583,  ..., -0.4167, -0.2776, -0.5714],
         [ 0.7069,  0.0588, -0.6766,  ..., -0.2973,  0.4049,  0.3688],
         [ 0.6961,  0.1930, -0.7123,  ..., -0.1825,  0.3640, -0.0304]]]), question_memory=tensor([[[ 0.0853, -0.5688, -0.0971,  ..., -0.1367, -0.1160,  0.3701],
         [ 0.1361, -0.4921, -0.1026,  ..., -0.1142, -0.1253,  0.4126],
         [ 0.4294, -0.2658, -0.3613,  ..., -0.1438,  0.0771, -0.4309],
         ...,
         [-0.1439, -0.3311,  0.0520,  ...,  0.0463,  0.2806,  0.0438],
         [ 0.1652, -1.5329,  0.6279,  ..., -0.1972, -0.0643, -0.1945],
         [-0.1282, -0.3162, -0.1159,  ..., -0.0256,  0.2365,  0.0471]]]), schema_memory=tensor([[[ 0.1525, -0.2725,  0.4607,  ..., -0.

In [52]:
rat_sql_encoder_state.memory.size(), rat_sql_encoder_state.schema_memory.size()

(torch.Size([1, 33, 256]), torch.Size([1, 26, 256]))

## Probing: link prediction

In [3]:
xsp_data_dir = "/Users/mac/Desktop/syt/Deep-Learning/Repos/Google-Research-Language/language/language/xsp/data"
probing_data_dir = "/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SDR-analysis/data/probing/text2sql/link_prediction/spider/ratsql"
probing_exp_dir = "/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SDR-analysis/experiments/probing/text2sql/link_prediction/spider/ratsql"

os.makedirs(probing_data_dir, exist_ok=True)
os.makedirs(probing_exp_dir, exist_ok=True)

### Data collection

In [18]:
orig_ds = 'dev'
spider_data_path = os.path.join(xsp_data_dir, 'spider', f'{orig_ds}.json')
spider_tables_path = os.path.join(xsp_data_dir, 'spider', 'tables.json')

In [19]:
spider_dbs_dict = spider_preprocessing.load_spider_tables(spider_tables_path)

with open(spider_data_path, 'r') as f:
    orig_dataset = json.load(f)

len(orig_dataset), len(spider_dbs_dict)

(1034, 166)

In [39]:
spider_db_schemas_dict = dict()
for db_id, db_dict in spider_dbs_dict.items():
    db_schema = db_dict_to_ratsql_schema(db_dict, db_id,
                                         sqlite_path=os.path.join(xsp_data_dir, f"spider/database/{db_id}/{db_id}.sqlite"),
                                         rigorous_foreign_key=True)
    spider_db_schemas_dict[db_id] = db_schema
len(spider_db_schemas_dict)

** db_dict_to_ratsql_schema: Deprecated, should use db_dict_to_general_fmt() and general_fmt_dict_to_ratsql_schema()
** db_dict_to_ratsql_schema: Deprecated, should use db_dict_to_general_fmt() and general_fmt_dict_to_ratsql_schema()
** db_dict_to_ratsql_schema: Deprecated, should use db_dict_to_general_fmt() and general_fmt_dict_to_ratsql_schema()
** db_dict_to_ratsql_schema: Deprecated, should use db_dict_to_general_fmt() and general_fmt_dict_to_ratsql_schema()
** db_dict_to_ratsql_schema: Deprecated, should use db_dict_to_general_fmt() and general_fmt_dict_to_ratsql_schema()
** db_dict_to_ratsql_schema: Deprecated, should use db_dict_to_general_fmt() and general_fmt_dict_to_ratsql_schema()
** db_dict_to_ratsql_schema: Deprecated, should use db_dict_to_general_fmt() and general_fmt_dict_to_ratsql_schema()
** db_dict_to_ratsql_schema: Deprecated, should use db_dict_to_general_fmt() and general_fmt_dict_to_ratsql_schema()
** db_dict_to_ratsql_schema: Deprecated, should use db_dict_to_g

** db_dict_to_ratsql_schema: Deprecated, should use db_dict_to_general_fmt() and general_fmt_dict_to_ratsql_schema()
** db_dict_to_ratsql_schema: Deprecated, should use db_dict_to_general_fmt() and general_fmt_dict_to_ratsql_schema()
** db_dict_to_ratsql_schema: Deprecated, should use db_dict_to_general_fmt() and general_fmt_dict_to_ratsql_schema()
** db_dict_to_ratsql_schema: Deprecated, should use db_dict_to_general_fmt() and general_fmt_dict_to_ratsql_schema()
** db_dict_to_ratsql_schema: Deprecated, should use db_dict_to_general_fmt() and general_fmt_dict_to_ratsql_schema()
** db_dict_to_ratsql_schema: Deprecated, should use db_dict_to_general_fmt() and general_fmt_dict_to_ratsql_schema()
** db_dict_to_ratsql_schema: Deprecated, should use db_dict_to_general_fmt() and general_fmt_dict_to_ratsql_schema()
** db_dict_to_ratsql_schema: Deprecated, should use db_dict_to_general_fmt() and general_fmt_dict_to_ratsql_schema()
** db_dict_to_ratsql_schema: Deprecated, should use db_dict_to_g

166

In [None]:
orig_dataset[0]

In [409]:
# Test vector_pair_features()
vector_pair_features(np.array([-1, 0, 1]), np.array([1, 2, 3]))

array([-1,  0,  1,  1,  2,  3, -1,  0,  3])

In [410]:
# Test extract_probing_samples_link_prediction()
d = orig_dataset[228]

X, y, pos = extract_probing_samples_link_prediction(dataset_sample=d,
                                                   db_schemas_dict=spider_db_schemas_dict,
                                                   model=rat_sql_model_dict['model'],
                                                   max_rel_occ=None,
                                                   debug=True)
len(X), len(y), len(pos)

Sampled output idx: 486
i = 21, j = 3
Nodes: <T>publication, who
Relation: 22 (tq_default)
Repr vectors:
[ 0.66455     0.46284533 -0.73639905] ... [-0.23395467  0.34225646  0.00472841]
[-0.03710726  0.37979963 -0.2971748 ] ... [-0.1910254   0.2258268   0.03110231]
Combined vector:
[ 0.66455     0.46284533 -0.73639905] ... [0.04469129 0.07729068 0.00014706]
Label:
22


(529, 529, 529)

In [411]:
# collect probing dataset (500 samples per train/test as planned)

n_samples = 500

random.seed(42)

sel_ids = random.sample(range(len(orig_dataset)), k=n_samples*2)
train_sel_ids = sel_ids[:n_samples]
test_sel_ids = sel_ids[n_samples:]

train_sel_ids[:5], test_sel_ids[:5]

([5238, 912, 204, 6074, 2253], [2525, 2987, 6521, 326, 2931])

In [412]:
train_X = []
train_y = []
train_pos = []

for idx in tqdm(train_sel_ids):
    X, y, pos = extract_probing_samples_link_prediction(dataset_sample=orig_dataset[idx],
                                                        db_schemas_dict=spider_db_schemas_dict,
                                                        model=rat_sql_model_dict['model'],
                                                        max_rel_occ=1)
    train_X.extend(X)
    train_y.extend(y)
    
    pos = [(idx, i, j) for i, j in pos]   # add sample idx 
    train_pos.extend(pos)
    
    time.sleep(0.5)
    
len(train_X), len(train_y), len(train_pos)

HBox(children=(FloatProgress(value=0.0, max=500.0), HTML(value='')))




(16263, 16263, 16263)

In [413]:
output_path_train_X = os.path.join(probing_data_dir, f'{orig_ds}.train.X.pkl')
output_path_train_y = os.path.join(probing_data_dir, f'{orig_ds}.train.y.pkl')
output_path_train_pos = os.path.join(probing_data_dir, f'{orig_ds}.train.pos.txt')

with open(output_path_train_X, 'wb') as f:
    pickle.dump(train_X, f)
with open(output_path_train_y, 'wb') as f:
    pickle.dump(train_y, f)
with open(output_path_train_pos, 'w') as f:
    for idx, i, j in train_pos:
        f.write(f'{idx}\t{i}\t{j}\n')

In [414]:
test_X = []
test_y = []
test_pos = []

for idx in tqdm(test_sel_ids):
    X, y, pos = extract_probing_samples_link_prediction(dataset_sample=orig_dataset[idx],
                                                        db_schemas_dict=spider_db_schemas_dict,
                                                        model=rat_sql_model_dict['model'],
                                                        max_rel_occ=1)
    test_X.extend(X)
    test_y.extend(y)
    
    pos = [(idx, i, j) for i, j in pos]   # add sample idx 
    test_pos.extend(pos)
    
    time.sleep(0.5)
    
len(test_X), len(test_y), len(test_pos)

HBox(children=(FloatProgress(value=0.0, max=500.0), HTML(value='')))




(16252, 16252, 16252)

In [415]:
output_path_test_X = os.path.join(probing_data_dir, f'{orig_ds}.test.X.pkl')
output_path_test_y = os.path.join(probing_data_dir, f'{orig_ds}.test.y.pkl')
output_path_test_pos = os.path.join(probing_data_dir, f'{orig_ds}.test.pos.txt')

with open(output_path_test_X, 'wb') as f:
    pickle.dump(test_X, f)
with open(output_path_test_y, 'wb') as f:
    pickle.dump(test_y, f)
with open(output_path_test_pos, 'w') as f:
    for idx, i, j in test_pos:
        f.write(f'{idx}\t{i}\t{j}\n')

In [416]:
y_counter = Counter(train_y)

for r, c in y_counter.most_common():
    print(relation_id2name[r], ':', c)

('qq_dist', 0) : 500
('qq_dist', 1) : 500
('qq_dist', 2) : 500
qc_default : 500
qt_default : 500
('qq_dist', -1) : 500
('qq_dist', -2) : 500
cq_default : 500
('cc_dist', 0) : 500
cc_default : 500
ct_any_table : 500
cc_table_match : 500
ct_default : 500
ct_table_match : 500
tq_default : 500
tc_any_table : 500
tc_table_match : 500
tc_default : 500
('tt_dist', 0) : 500
ct_primary_key : 496
tc_primary_key : 496
cc_foreign_key_backward : 493
cc_foreign_key_forward : 493
tt_foreign_key_backward : 493
tt_foreign_key_forward : 493
qtTEM : 448
tqTEM : 448
tt_default : 431
qcCEM : 358
cqCEM : 358
qcCPM : 347
cqCPM : 347
qcCELLMATCH : 229
cqCELLMATCH : 229
qtTPM : 156
tqTPM : 156
qcNUMBER : 93
cqNUMBER : 93
qcTIME : 27
cqTIME : 27
ct_foreign_key : 26
tc_foreign_key : 26


In [417]:
len(y_counter), len(relation_id2name)

(42, 51)

In [418]:
y_counter = Counter(test_y)

for r, c in y_counter.most_common():
    print(relation_id2name[r], ':', c)

('qq_dist', 0) : 500
('qq_dist', 1) : 500
('qq_dist', 2) : 500
qc_default : 500
qt_default : 500
('qq_dist', -1) : 500
('qq_dist', -2) : 500
cq_default : 500
('cc_dist', 0) : 500
cc_default : 500
ct_any_table : 500
cc_table_match : 500
ct_default : 500
ct_table_match : 500
tq_default : 500
tc_any_table : 500
tc_table_match : 500
tc_default : 500
('tt_dist', 0) : 500
ct_primary_key : 496
tc_primary_key : 496
cc_foreign_key_backward : 495
cc_foreign_key_forward : 495
tt_foreign_key_backward : 495
tt_foreign_key_forward : 495
qtTEM : 454
tqTEM : 454
tt_default : 436
qcCEM : 350
cqCEM : 350
qcCPM : 344
cqCPM : 344
qcCELLMATCH : 220
cqCELLMATCH : 220
qtTPM : 160
tqTPM : 160
qcNUMBER : 96
cqNUMBER : 96
qcTIME : 32
cqTIME : 32
ct_foreign_key : 16
tc_foreign_key : 16


In [419]:
len(y_counter), len(relation_id2name)

(42, 51)

In [420]:
for r in set(relation_id2name.keys()) - set(y_counter.keys()):
    print(relation_id2name[r])

('tt_dist', -2)
('tt_dist', -1)
('tt_dist', 1)
('tt_dist', 2)
('cc_dist', -2)
('cc_dist', -1)
('cc_dist', 1)
('cc_dist', 2)
tt_foreign_key_both


### Probing experiments

In [422]:
train_train_X_path = '/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SDR-analysis/data/text2sql/link_prediction/spider/train.train.X.pkl'
train_train_y_path = '/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SDR-analysis/data/text2sql/link_prediction/spider/train.train.y.pkl'
train_test_X_path = '/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SDR-analysis/data/text2sql/link_prediction/spider/train.test.X.pkl'
train_test_y_path = '/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SDR-analysis/data/text2sql/link_prediction/spider/train.test.y.pkl'
dev_train_X_path = '/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SDR-analysis/data/text2sql/link_prediction/spider/dev.train.X.pkl'
dev_train_y_path = '/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SDR-analysis/data/text2sql/link_prediction/spider/dev.train.y.pkl'
dev_test_X_path = '/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SDR-analysis/data/text2sql/link_prediction/spider/dev.test.X.pkl'
dev_test_y_path = '/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SDR-analysis/data/text2sql/link_prediction/spider/dev.test.y.pkl'

In [423]:
with open(train_train_X_path, 'rb') as f:
    train_train_X = pickle.load(f)
with open(train_train_y_path, 'rb') as f:
    train_train_y = pickle.load(f)
with open(train_test_X_path, 'rb') as f:
    train_test_X = pickle.load(f)
with open(train_test_y_path, 'rb') as f:
    train_test_y = pickle.load(f)
with open(dev_train_X_path, 'rb') as f:
    dev_train_X = pickle.load(f)
with open(dev_train_y_path, 'rb') as f:
    dev_train_y = pickle.load(f)
with open(dev_test_X_path, 'rb') as f:
    dev_test_X = pickle.load(f)
with open(dev_test_y_path, 'rb') as f:
    dev_test_y = pickle.load(f)
    
len(train_train_X), len(train_train_y), len(train_test_X), len(train_test_y), \
len(dev_train_X), len(dev_train_y), len(dev_test_X), len(dev_test_y)

(16263, 16263, 16252, 16252, 16059, 16059, 16035, 16035)

In [432]:
## Train on train.train

In [424]:
clf_train = LogisticRegression(C=1.0)
clf_train.fit(train_train_X, train_train_y)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression


LogisticRegression()

In [445]:
clf_train_save_path = os.path.join(probing_exp_dir, 'clf_train.pkl')

with open(clf_train_save_path, 'wb') as f:
    pickle.dump(clf_train, f)

In [446]:
with open(clf_train_save_path, 'rb') as f:
    clf_train = pickle.load(f)

In [447]:
# train-train
preds_train_test_y = clf_train.predict(train_test_X)

In [448]:
sum([py == ty for py, ty in zip(preds_train_test_y, train_test_y)]) / len(train_test_y)

0.8451267536303224

In [449]:
# train-dev
preds_dev_test_y = clf_train.predict(dev_test_X)

In [450]:
sum([py == ty for py, ty in zip(preds_dev_test_y, dev_test_y)]) / len(dev_test_y)

0.863548487683193

In [433]:
## Train on dev.train

In [434]:
clf_dev = LogisticRegression(C=1.0)
clf_dev.fit(dev_train_X, dev_train_y)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression


LogisticRegression()

In [444]:
clf_dev_save_path = os.path.join(probing_exp_dir, 'clf_dev.pkl')

with open(clf_dev_save_path, 'wb') as f:
    pickle.dump(clf_dev, f)

In [451]:
with open(clf_dev_save_path, 'rb') as f:
    clf_dev = pickle.load(f)

In [452]:
clf_dev.coef_.shape

(40, 768)

In [453]:
clf_dev.coef_[:10, :10], clf_dev.intercept_

(array([[-0.80019769, -0.310581  ,  0.22105818, -0.97968546, -0.34733584,
          0.31893237,  0.38149306, -0.25577057, -0.60828252, -0.04704865],
        [-0.61489598, -0.2648912 ,  0.1570404 , -0.44630992, -0.13713991,
         -0.0336505 ,  0.35800063, -0.03629656, -0.09981454,  0.74472062],
        [-0.11164215, -0.27722106,  0.33412083,  0.42736333, -0.16961653,
         -0.11628333,  0.13258798, -0.17215397,  0.14638017,  0.06203258],
        [-0.41538832, -0.01712923,  0.04621374,  0.80956922, -0.02668683,
          0.3356642 ,  0.00276802,  0.14841327,  0.0273386 ,  0.3711157 ],
        [-0.066393  , -0.22268648,  0.64487417,  1.11459103, -0.21232745,
          0.21571025, -0.25149596,  0.0589617 ,  0.49173806,  0.44301353],
        [-0.57114421, -0.25044479,  0.29635857,  0.15346643, -0.38270093,
          0.19178963,  0.01220698, -0.18477867,  0.21283573,  0.51224108],
        [-0.56462636, -0.13272939,  0.36375653, -0.05909965, -0.27392543,
          0.26266305,  0.0634358

In [454]:
preds_dev_test_y = clf_dev.predict(dev_test_X)
len(preds_dev_test_y)

16035

In [455]:
sum([py == ty for py, ty in zip(preds_dev_test_y, dev_test_y)]) / len(dev_test_y)

0.8743997505456813

In [430]:
y_counter = Counter(preds_dev_test_y)

for r, c in y_counter.most_common():
    print(relation_id2name[r], ':', c)

('qq_dist', 0) : 622
('tt_dist', 0) : 601
ct_table_match : 556
('cc_dist', 0) : 546
ct_primary_key : 546
tc_primary_key : 537
tc_table_match : 530
cc_foreign_key_backward : 525
tq_default : 507
ct_any_table : 500
tc_any_table : 500
qt_default : 495
cc_default : 491
qc_default : 489
cq_default : 482
('qq_dist', 2) : 481
cc_foreign_key_forward : 481
tt_foreign_key_forward : 481
tt_foreign_key_backward : 479
('qq_dist', -2) : 466
('qq_dist', 1) : 463
('qq_dist', -1) : 462
cc_table_match : 456
tc_default : 433
qtTEM : 427
tqTEM : 426
tt_default : 411
ct_default : 401
qcCEM : 396
cqCEM : 392
cqCPM : 328
qcCPM : 323
cqCELLMATCH : 209
qcCELLMATCH : 201
qtTPM : 102
tqTPM : 93
cqNUMBER : 82
qcNUMBER : 80
qcTIME : 17
cqTIME : 17
tc_foreign_key : 1


## Temp

### Test: getting F-P pairs from sqlite

In [213]:
db_id = 'imdb'
sqlite_path = f"/Users/mac/Desktop/syt/Deep-Learning/Repos/Google-Research-Language/language/language/xsp/data/spider/database/{db_id}/{db_id}.sqlite"

sqlite_conn = sqlite3.connect(str(sqlite_path))

In [214]:
## Copied from stackoverflow 
rows = sqlite_conn.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
tables = [r[0] for r in rows]
tables

['actor',
 'copyright',
 'cast',
 'genre',
 'classification',
 'company',
 'director',
 'producer',
 'directed_by',
 'keyword',
 'made_by',
 'movie',
 'tags',
 'tv_series',
 'writer',
 'written_by']

In [215]:
## Copied from stackoverflow 
def sql_identifier(s):
    return '"' + s.replace('"', '""') + '"'

for table in tables:
    print("table: " + table)
    rows = sqlite_conn.execute("PRAGMA table_info({})".format(sql_identifier(table)))
    print(rows.fetchall())
    rows = sqlite_conn.execute("PRAGMA foreign_key_list({})".format(sql_identifier(table)))
    print(rows.fetchall())

table: actor
[(0, 'aid', 'int', 0, None, 1), (1, 'gender', 'text', 0, None, 0), (2, 'name', 'text', 0, None, 0), (3, 'nationality', 'text', 0, None, 0), (4, 'birth_city', 'text', 0, None, 0), (5, 'birth_year', 'int', 0, None, 0)]
[]
table: copyright
[(0, 'id', 'int', 0, None, 1), (1, 'msid', 'int', 0, None, 0), (2, 'cid', 'int', 0, None, 0)]
[]
table: cast
[(0, 'id', 'int', 0, None, 1), (1, 'msid', 'int', 0, None, 0), (2, 'aid', 'int', 0, None, 0), (3, 'role', 'int', 0, None, 0)]
[(0, 0, 'copyright', 'msid', 'msid', 'NO ACTION', 'NO ACTION', 'NONE'), (1, 0, 'actor', 'aid', 'aid', 'NO ACTION', 'NO ACTION', 'NONE')]
table: genre
[(0, 'gid', 'int', 0, None, 1), (1, 'genre', 'text', 0, None, 0)]
[]
table: classification
[(0, 'id', 'int', 0, None, 1), (1, 'msid', 'int', 0, None, 0), (2, 'gid', 'int', 0, None, 0)]
[(0, 0, 'copyright', 'msid', 'msid', 'NO ACTION', 'NO ACTION', 'NONE'), (1, 0, 'genre', 'gid', 'gid', 'NO ACTION', 'NO ACTION', 'NONE')]
table: company
[(0, 'id', 'int', 0, None, 1

In [None]:
# with sqlite3.connect(str(sqlite_path)) as source:
#     dest = sqlite3.connect(':memory:')
#     dest.row_factory = sqlite3.Row
#     source.backup(dest)

### Test: get rat-sql graph

In [None]:
# def Question(q, db_schema, model_dict):
#     model = model_dict['model']
#     inferer = model_dict['inferer']
    
#     data_item = SpiderItem(
#         text=None,  # intentionally None -- should be ignored when the tokenizer is set correctly
#         code=None,
#         schema=db_schema,
#         orig_schema=db_schema.orig,
#         orig={"question": q}
#     )
    
#     model.preproc.clear_items()
#     enc_input = model.preproc.enc_preproc.preprocess_item(data_item, None)
#     preproc_data = enc_input, None
    
#     with torch.no_grad():
#         return inferer._infer_one(model, data_item, preproc_data, beam_size=1, use_heuristic=True)
    

# Question('name of people', test_schema, rat_sql_model_dict)

In [325]:
db_id = 'concert_singer'
db_dict = spider_dbs_dict[db_id]
db_schema = db_dict_to_ratsql_schema(db_dict,
                                     db_id,
                                     sqlite_path=f"/Users/mac/Desktop/syt/Deep-Learning/Repos/Google-Research-Language/language/language/xsp/data/spider/database/{db_id}/{db_id}.sqlite")

In [None]:
db_schema

In [326]:
question = "How many singers are from France?"

model = rat_sql_model_dict['model']
inferer = rat_sql_model_dict['inferer']

data_item = SpiderItem(
    text=None,  # intentionally None -- should be ignored when the tokenizer is set correctly
    code=None,
    schema=db_schema,
    orig_schema=db_schema.orig,
    orig={"question": question}
)

model.preproc.clear_items()
enc_input = model.preproc.enc_preproc.preprocess_item(data_item, None)

In [327]:
db_schema.orig

{'column_names': [(-1, '*'),
  (0, 'stadium id'),
  (0, 'location'),
  (0, 'name'),
  (0, 'capacity'),
  (0, 'highest'),
  (0, 'lowest'),
  (0, 'average'),
  (1, 'singer id'),
  (1, 'name'),
  (1, 'country'),
  (1, 'song name'),
  (1, 'song release year'),
  (1, 'age'),
  (1, 'is male'),
  (2, 'concert id'),
  (2, 'concert name'),
  (2, 'theme'),
  (2, 'stadium id'),
  (2, 'year'),
  (3, 'concert id'),
  (3, 'singer id')],
 'column_names_original': [(-1, '*'),
  (0, 'Stadium_ID'),
  (0, 'Location'),
  (0, 'Name'),
  (0, 'Capacity'),
  (0, 'Highest'),
  (0, 'Lowest'),
  (0, 'Average'),
  (1, 'Singer_ID'),
  (1, 'Name'),
  (1, 'Country'),
  (1, 'Song_Name'),
  (1, 'Song_release_year'),
  (1, 'Age'),
  (1, 'Is_male'),
  (2, 'concert_ID'),
  (2, 'concert_Name'),
  (2, 'Theme'),
  (2, 'Stadium_ID'),
  (2, 'Year'),
  (3, 'concert_ID'),
  (3, 'Singer_ID')],
 'column_types': ['text',
  'number',
  'text',
  'text',
  'number',
  'number',
  'number',
  'number',
  'number',
  'text',
  'text',

In [328]:
enc_input

{'raw_question': 'How many singers are from France?',
 'question': ['how', 'many', 'singer', 'be', 'from', 'france', '?'],
 'question_for_copying': ['how',
  'many',
  'singers',
  'are',
  'from',
  'france',
  '?'],
 'db_id': 'concert_singer',
 'sc_link': {'q_col_match': {'2,8': 'CPM', '2,21': 'CPM'},
  'q_tab_match': {'2,1': 'TEM', '2,3': 'TPM'}},
 'cv_link': {'num_date_match': {}, 'cell_match': {'5,10': 'CELLMATCH'}},
 'columns': [['<type: text>', '*'],
  ['<type: number>', 'stadium', 'id'],
  ['<type: text>', 'location'],
  ['<type: text>', 'name'],
  ['<type: number>', 'capacity'],
  ['<type: number>', 'highest'],
  ['<type: number>', 'lowest'],
  ['<type: number>', 'average'],
  ['<type: number>', 'singer', 'id'],
  ['<type: text>', 'name'],
  ['<type: text>', 'country'],
  ['<type: text>', 'song', 'name'],
  ['<type: text>', 'song', 'release', 'year'],
  ['<type: number>', 'age'],
  ['<type: others>', 'be', 'male'],
  ['<type: number>', 'concert', 'id'],
  ['<type: text>', 'con

In [329]:
model.encoder.encs_update

RelationalTransformerUpdate(
  (encoder): Encoder(
    (layers): ModuleList(
      (0): EncoderLayer(
        (self_attn): MultiHeadedAttentionWithRelations(
          (linears): ModuleList(
            (0): Linear(in_features=256, out_features=256, bias=True)
            (1): Linear(in_features=256, out_features=256, bias=True)
            (2): Linear(in_features=256, out_features=256, bias=True)
            (3): Linear(in_features=256, out_features=256, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=256, out_features=1024, bias=True)
          (w_2): Linear(in_features=1024, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [330]:
desc = enc_input

In [331]:
# ## Copied from SpiderEncoderV2.forward_unbatched (not correct)

# # q_enc: question len x batch (=1) x recurrent_size
# q_enc, _ = model.encoder.question_encoder([desc['question']])

# # Encode the columns
# # - LookupEmbeddings
# # - Transform embeddings wrt each other?
# # - Summarize each column into one?
# # c_enc: sum of column lens x batch (=1) x recurrent_size
# c_enc, c_boundaries = model.encoder.column_encoder(desc['columns'])
# # column_pointer_maps = {
# #     i: list(range(left, right))
# #     for i, (left, right) in enumerate(zip(c_boundaries, c_boundaries[1:]))
# # }

# # Encode the tables
# # - LookupEmbeddings
# # - Transform embeddings wrt each other?
# # - Summarize each table into one?
# # t_enc: sum of table lens x batch (=1) x recurrent_size
# t_enc, t_boundaries = model.encoder.table_encoder(desc['tables'])
# # c_enc_length = c_enc.shape[0]
# # table_pointer_maps = {
# #     i: [
# #            idx
# #            for col in desc['table_to_columns'][str(i)]
# #            for idx in column_pointer_maps[col]
# #        ] + list(range(left + c_enc_length, right + c_enc_length))
# #     for i, (left, right) in enumerate(zip(t_boundaries, t_boundaries[1:]))
# # }

In [391]:
# token_lists = [desc['columns']]
# boundaries = [
#             np.cumsum([0] + [len(token_list) for token_list in token_lists_for_item])
#             for token_lists_for_item in token_lists]
# boundaries

[array([ 0,  2,  5,  7,  9, 11, 13, 15, 17, 20, 22, 24, 27, 31, 33, 36, 39,
        42, 44, 47, 49, 52, 55])]

In [332]:
## Adapted from SpiderEncoderV2.forward

q_enc, _ = model.encoder.question_encoder([[desc['question']]])

c_enc, c_boundaries = model.encoder.column_encoder([desc['columns']])

t_enc, t_boundaries = model.encoder.table_encoder([desc['tables']])

In [333]:
q_enc.ps.data.size()

torch.Size([7, 256])

In [373]:
model.encoder.column_encoder

Sequential(
  (0): LookupEmbeddings(
    (embedding): Embedding(1580, 300)
  )
  (1): BiLSTM(
    (lstm): LSTM(
      original_name=LSTM
      (cell): RecurrentDropoutLSTMCell(original_name=RecurrentDropoutLSTMCell)
      (cell_reverse): RecurrentDropoutLSTMCell(original_name=RecurrentDropoutLSTMCell)
    )
  )
)

In [334]:
# ## Adapted from RelationalTransformerUpdate.forward_unbatched (not correct)

# q_enc_unbatched = q_enc.select(0).unsqueeze(1)
# c_enc_unbatched = c_enc.select(0).unsqueeze(1)
# t_enc_unbatched = t_enc.select(0).unsqueeze(1)

# # enc shape: total len x batch (=1) x recurrent size
# enc = torch.cat((q_enc_unbatched, c_enc_unbatched, t_enc_unbatched), dim=0)

# # enc shape: batch (=1) x total len x recurrent size
# enc = enc.transpose(0, 1)

# # Catalogue which things are where
# relations = model.encoder.encs_update.compute_relations(
#     desc,
#     enc_length=enc.shape[1],
#     q_enc_length=q_enc_unbatched.shape[0],
#     c_enc_length=c_enc_unbatched.shape[0],
#     c_boundaries=c_boundaries[0],
#     t_boundaries=t_boundaries[0])

In [335]:
## Adapted from RelationalTransformerUpdate.forward

enc = batched_sequence.PackedSequencePlus.cat_seqs((q_enc, c_enc, t_enc))

q_enc_lengths = list(q_enc.orig_lengths())
c_enc_lengths = list(c_enc.orig_lengths())
t_enc_lengths = list(t_enc.orig_lengths())
enc_lengths = list(enc.orig_lengths())
max_enc_length = max(enc_lengths)

enc_length = enc_lengths[0]
relations = model.encoder.encs_update.compute_relations(
    desc,
    enc_length,
    q_enc_lengths[0],
    c_enc_lengths[0],
    c_boundaries[0],
    t_boundaries[0])

In [336]:
relations.shape

(33, 33)

In [337]:
q_enc_lengths, c_enc_lengths, t_enc_lengths

([7], [22], [4])

In [338]:
relations

array([[ 2,  3,  4, ...,  6,  6,  6],
       [ 1,  2,  3, ...,  6,  6,  6],
       [ 0,  1,  2, ..., 39,  6, 43],
       ...,
       [22, 22, 40, ..., 34, 28, 30],
       [22, 22, 22, ..., 28, 34, 30],
       [22, 22, 44, ..., 29, 29, 34]])

In [339]:
nodes = []

nodes.extend(enc_input['question'])

for c_id, c_toks in enumerate(enc_input['columns']):
    c_name = '_'.join(c_toks[1:])
    t_id = enc_input['column_to_table'][str(c_id)]
    if t_id is None:
        t_name = 'NONE'
    else:
        t_toks = enc_input['tables'][t_id]
        t_name = '_'.join(t_toks)
    c_save_name = f'<C>{t_name}::{c_name}'
    nodes.append(c_save_name)
    
for t_toks in enc_input['tables']:
    nodes.append('<T>' + '_'.join(t_toks))

In [340]:
list(enumerate(nodes))

[(0, 'how'),
 (1, 'many'),
 (2, 'singer'),
 (3, 'be'),
 (4, 'from'),
 (5, 'france'),
 (6, '?'),
 (7, '<C>NONE::*'),
 (8, '<C>stadium::stadium_id'),
 (9, '<C>stadium::location'),
 (10, '<C>stadium::name'),
 (11, '<C>stadium::capacity'),
 (12, '<C>stadium::highest'),
 (13, '<C>stadium::lowest'),
 (14, '<C>stadium::average'),
 (15, '<C>singer::singer_id'),
 (16, '<C>singer::name'),
 (17, '<C>singer::country'),
 (18, '<C>singer::song_name'),
 (19, '<C>singer::song_release_year'),
 (20, '<C>singer::age'),
 (21, '<C>singer::be_male'),
 (22, '<C>concert::concert_id'),
 (23, '<C>concert::concert_name'),
 (24, '<C>concert::theme'),
 (25, '<C>concert::stadium_id'),
 (26, '<C>concert::year'),
 (27, '<C>singer_in_concert::concert_id'),
 (28, '<C>singer_in_concert::singer_id'),
 (29, '<T>stadium'),
 (30, '<T>singer'),
 (31, '<T>concert'),
 (32, '<T>singer_in_concert')]

In [None]:
model.encoder.encs_update.relation_ids

In [341]:
relation_id2name = {v : k for k, v in model.encoder.encs_update.relation_ids.items()}

In [342]:
i = 2
j = 15

nodes[i], nodes[j], relation_id2name[relations[i][j]]

('singer', '<C>singer::singer_id', 'qcCPM')

In [343]:
i = 2
j = 30

nodes[i], nodes[j], relation_id2name[relations[i][j]]

('singer', '<T>singer', 'qtTEM')

In [344]:
i = 5
j = 17

nodes[i], nodes[j], relation_id2name[relations[i][j]]

('france', '<C>singer::country', 'qcCELLMATCH')

In [345]:
i = 5
j = 30

nodes[i], nodes[j], relation_id2name[relations[i][j]]

('france', '<T>singer', 'qt_default')

In [348]:
desc['foreign_keys']

{'18': 1, '20': 15, '21': 8}

In [350]:
i = 8
j = 25

nodes[i], nodes[j], relation_id2name[relations[i][j]]

('<C>stadium::stadium_id', '<C>concert::stadium_id', 'cc_foreign_key_backward')

In [352]:
[t.orig_name for t in db_schema.tables]

['stadium', 'singer', 'concert', 'singer_in_concert']

In [359]:
[(i, c.table.unsplit_name, c.unsplit_name) for i, c in list(enumerate(db_schema.columns))[1:]]

[(1, 'stadium', 'stadium id'),
 (2, 'stadium', 'location'),
 (3, 'stadium', 'name'),
 (4, 'stadium', 'capacity'),
 (5, 'stadium', 'highest'),
 (6, 'stadium', 'lowest'),
 (7, 'stadium', 'average'),
 (8, 'singer', 'singer id'),
 (9, 'singer', 'name'),
 (10, 'singer', 'country'),
 (11, 'singer', 'song name'),
 (12, 'singer', 'song release year'),
 (13, 'singer', 'age'),
 (14, 'singer', 'is male'),
 (15, 'concert', 'concert id'),
 (16, 'concert', 'concert name'),
 (17, 'concert', 'theme'),
 (18, 'concert', 'stadium id'),
 (19, 'concert', 'year'),
 (20, 'singer in concert', 'concert id'),
 (21, 'singer in concert', 'singer id')]

In [377]:
_t = db_schema.tables[0]
_t

Table(id=0, name=['stadium'], unsplit_name='stadium', orig_name='stadium', columns=[Column(id=1, table=..., name=['stadium', 'id'], unsplit_name='stadium id', orig_name='Stadium_ID', type='number', foreign_key_for=None), Column(id=2, table=..., name=['location'], unsplit_name='location', orig_name='Location', type='text', foreign_key_for=None), Column(id=3, table=..., name=['name'], unsplit_name='name', orig_name='Name', type='text', foreign_key_for=None), Column(id=4, table=..., name=['capacity'], unsplit_name='capacity', orig_name='Capacity', type='number', foreign_key_for=None), Column(id=5, table=..., name=['highest'], unsplit_name='highest', orig_name='Highest', type='number', foreign_key_for=None), Column(id=6, table=..., name=['lowest'], unsplit_name='lowest', orig_name='Lowest', type='number', foreign_key_for=None), Column(id=7, table=..., name=['average'], unsplit_name='average', orig_name='Average', type='number', foreign_key_for=None)], primary_keys=[Column(id=1, table=..., 

In [378]:
_c = db_schema.columns[18]
_c

Column(id=18, table=Table(id=2, name=['concert'], unsplit_name='concert', orig_name='concert', columns=[Column(id=15, table=..., name=['concert', 'id'], unsplit_name='concert id', orig_name='concert_ID', type='number', foreign_key_for=None), Column(id=16, table=..., name=['concert', 'name'], unsplit_name='concert name', orig_name='concert_Name', type='text', foreign_key_for=None), Column(id=17, table=..., name=['theme'], unsplit_name='theme', orig_name='Theme', type='text', foreign_key_for=None), ..., Column(id=19, table=..., name=['year'], unsplit_name='year', orig_name='Year', type='text', foreign_key_for=None)], primary_keys=[Column(id=15, table=..., name=['concert', 'id'], unsplit_name='concert id', orig_name='concert_ID', type='number', foreign_key_for=None)]), name=['stadium', 'id'], unsplit_name='stadium id', orig_name='Stadium_ID', type='text', foreign_key_for=Column(id=1, table=Table(id=0, name=['stadium'], unsplit_name='stadium', orig_name='stadium', columns=[..., Column(id=2

In [382]:
_c = 18
_t = 0
model.encoder.encs_update.match_foreign_key(desc, _c, _t)

False

In [375]:
desc['foreign_keys']

{'18': 1, '20': 15, '21': 8}

In [380]:
desc['column_to_table']

{'0': None,
 '1': 0,
 '2': 0,
 '3': 0,
 '4': 0,
 '5': 0,
 '6': 0,
 '7': 0,
 '8': 1,
 '9': 1,
 '10': 1,
 '11': 1,
 '12': 1,
 '13': 1,
 '14': 1,
 '15': 2,
 '16': 2,
 '17': 2,
 '18': 2,
 '19': 2,
 '20': 3,
 '21': 3}

In [383]:
foreign_key_for = desc['foreign_keys'].get(str(_c))

foreign_table = desc['column_to_table'][str(foreign_key_for)]

desc['column_to_table'][str(_c)], foreign_table

(2, 0)

In [388]:
model.encoder.column_encoder

Sequential(
  (0): LookupEmbeddings(
    (embedding): Embedding(1580, 300)
  )
  (1): BiLSTM(
    (lstm): LSTM(
      original_name=LSTM
      (cell): RecurrentDropoutLSTMCell(original_name=RecurrentDropoutLSTMCell)
      (cell_reverse): RecurrentDropoutLSTMCell(original_name=RecurrentDropoutLSTMCell)
    )
  )
)

In [387]:
model.encoder.column_encoder[1].summarize

True

### Test collect_link_prediction_samples
- extract_probing_samples_link_prediction_new()

In [27]:
tmp_probing_data_dir = "/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SDR-analysis/data/probing/text2sql/link_prediction/spider/ratsql"
tmp_pos_path = os.path.join(tmp_probing_data_dir, 'dev.train.pos.txt')

tmp_dataset_path = "/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SDR-analysis/data/spider/dev+ratsql_graph.json"


In [28]:
with open(tmp_dataset_path, 'r') as f:
    tmp_dataset = json.load(f)
    
for d in tmp_dataset:
    d['rat_sql_graph']['relations'] = json.loads(d['rat_sql_graph']['relations'])

len(tmp_dataset), tmp_dataset[0].keys()

(1034,
 dict_keys(['db_id', 'query', 'query_toks', 'query_toks_no_value', 'question', 'question_toks', 'sql', 'rat_sql_graph']))

In [29]:
with open(tmp_pos_path, 'r') as f:
    lines = f.read().strip().split('\n')
    tmp_pos = [tuple([int(s) for s in l.split('\t')]) for l in lines]
len(tmp_pos), tmp_pos[0]

(16059, (228, 0, 0))

In [30]:
# Load pos file 
pos_per_sample = defaultdict(list)   # key = ds_idx, value = pos_list: List[(i, j)]

for ds_idx, i, j in tmp_pos:
    pos_per_sample[ds_idx].append((i, j))

len(pos_per_sample)

500

In [31]:
# the first ds_idx in dev.train is 228 
idx = 228
d = tmp_dataset[idx]
d.keys(), len(pos_per_sample[idx])

(dict_keys(['db_id', 'query', 'query_toks', 'query_toks_no_value', 'question', 'question_toks', 'sql', 'rat_sql_graph']),
 30)

In [46]:
X, y, pos = extract_probing_samples_link_prediction_new(dataset_sample=d,
                                                        db_schemas_dict=spider_db_schemas_dict,
                                                        model=rat_sql_model_dict['model'],
                                                        pos=pos_per_sample[idx],
                                                        max_rel_occ=None,
                                                        debug=True)
len(X), len(y), len(pos)

Sampled output idx: 2
i = 2, j = 5
Nodes: code, airport
Relation: 4 (('qq_dist', 2))
Repr vectors:
[ 0.1501283   0.21355523 -0.25271013] ... [0.13540143 0.27801117 0.00371891]
[ 0.31505942 -0.47566113  0.06388362] ... [-0.08114626  0.08402237 -0.43422976]
Combined vector:
[ 0.1501283   0.21355523 -0.25271013] ... [-0.01098732  0.02335916 -0.00161486]
Label:
4


(30, 30, 30)

In [48]:
tmp_X_path = os.path.join(tmp_probing_data_dir, 'dev.train.X.pkl')

with open(tmp_X_path, 'rb') as f:
    tmp_X = pickle.load(f)

len(tmp_X)

16059

In [52]:
np.allclose(X, tmp_X[:30])

True

In [83]:
relation_id2name = {idx : rel for rel, idx in model.encoder.encs_update.relation_ids.items()}
relation_id2name

{0: ('qq_dist', -2),
 1: ('qq_dist', -1),
 2: ('qq_dist', 0),
 3: ('qq_dist', 1),
 4: ('qq_dist', 2),
 5: 'qc_default',
 6: 'qt_default',
 7: 'cq_default',
 8: 'cc_default',
 9: 'cc_foreign_key_forward',
 10: 'cc_foreign_key_backward',
 11: 'cc_table_match',
 12: ('cc_dist', -2),
 13: ('cc_dist', -1),
 14: ('cc_dist', 0),
 15: ('cc_dist', 1),
 16: ('cc_dist', 2),
 17: 'ct_default',
 18: 'ct_foreign_key',
 19: 'ct_primary_key',
 20: 'ct_table_match',
 21: 'ct_any_table',
 22: 'tq_default',
 23: 'tc_default',
 24: 'tc_primary_key',
 25: 'tc_table_match',
 26: 'tc_any_table',
 27: 'tc_foreign_key',
 28: 'tt_default',
 29: 'tt_foreign_key_forward',
 30: 'tt_foreign_key_backward',
 31: 'tt_foreign_key_both',
 32: ('tt_dist', -2),
 33: ('tt_dist', -1),
 34: ('tt_dist', 0),
 35: ('tt_dist', 1),
 36: ('tt_dist', 2),
 37: 'qcCEM',
 38: 'cqCEM',
 39: 'qtTEM',
 40: 'tqTEM',
 41: 'qcCPM',
 42: 'cqCPM',
 43: 'qtTPM',
 44: 'tqTPM',
 45: 'qcNUMBER',
 46: 'cqNUMBER',
 47: 'qcTIME',
 48: 'cqTIME',
 49:

### others

In [229]:
random.sample([1,2], k=3)

[1, 1, 2]

In [231]:
np.concatenate([[1,2], [3,4]])

array([1, 2, 3, 4])

In [54]:
model = rat_sql_model_dict['model']
type(model)

ratsql.models.enc_dec.EncDecModel

In [57]:
model.encoder.preproc.word_emb.__dict__

{'glove': <torchtext.vocab.GloVe at 0x158e96690>,
 'dim': 300,
 'vectors': tensor([[ 0.1838, -0.1212, -0.1199,  ..., -0.0390,  0.1827,  0.1465],
         [-0.2084, -0.1493, -0.0175,  ..., -0.5407,  0.2120, -0.0094],
         [ 0.1088,  0.0022,  0.2221,  ..., -0.2970,  0.1594, -0.1490],
         ...,
         [ 0.2736,  0.0413, -0.1227,  ..., -0.3318,  0.0379,  0.0564],
         [-0.0524,  0.3214,  0.2324,  ..., -0.0813,  0.0481, -0.0872],
         [-0.1197,  0.1602, -0.2492,  ..., -0.0909,  0.2783,  0.1137]]),
 'lemmatize': True,
 'corenlp_annotators': ['tokenize', 'ssplit', 'lemma']}