In [325]:
import sys
import os
import time
import torch
import datasets
from transformers import (
    HfArgumentParser,
    set_seed,
    AutoTokenizer
)
from utils.configue import Configure
from utils.training_arguments import WrappedSeq2SeqTrainingArguments
from models.unified import finetune, prefixtuning
from models.unified.prefixtuning import Model

import nltk

# from filelock import FileLock
# with FileLock(".lock") as lock:
#     nltk.download("punkt", quiet=True)
#     nltk.download("stopwords", quiet=True)

In [182]:
import json
from copy import deepcopy
from collections import Counter, defaultdict
import importlib
import pickle

from seq2seq_construction import spider
from third_party.spider.preprocess.get_tables import dump_db_json_schema

import numpy as np
from tqdm.notebook import tqdm
import editdistance
from nltk.translate.bleu_score import corpus_bleu
from nltk.tokenize.treebank import TreebankWordDetokenizer

# from SpeakQL.Allennlp_models.utils.spider import process_sql, evaluation
# from SpeakQL.Allennlp_models.utils.misc_utils import EvaluateSQL, EvaluateSQL_full, \
#     Postprocess_rewrite_seq, Postprocess_rewrite_seq_freeze_POS, Postprocess_rewrite_seq_modify_POS

In [15]:
from language.xsp.data_preprocessing import spider_preprocessing, wikisql_preprocessing, michigan_preprocessing

import sdr_analysis
importlib.reload(sdr_analysis.helpers.general_helpers)
from sdr_analysis.helpers.general_helpers import db_dict_to_general_fmt, collect_link_prediction_samples

## Read data - Schema

### Spider

#### Original loading test

In [20]:
schema_cache = dict()

db_path = '/Users/mac/Desktop/syt/Deep-Learning/Dataset/spider/database'

def format_spider_schema(db_id):
    if db_id not in schema_cache:
        schema_cache[db_id] = dump_db_json_schema(
            db_path + "/" + db_id + "/" + db_id + ".sqlite", db_id)
    schema = schema_cache[db_id]

    return {
        "db_id": db_id,
        "db_path": db_path,
        "db_table_names": schema["table_names_original"],
        "db_column_names": {
            "table_id": [table_id for table_id, column_name in schema["column_names_original"]],
            "column_name": [column_name for table_id, column_name in schema["column_names_original"]]
        },
        "db_column_types": schema["column_types"],
        "db_primary_keys": [{"column_id": column_id} for column_id in schema["primary_keys"]],
        "db_foreign_keys": [
            {"column_id": column_id, "other_column_id": other_column_id}
            for column_id, other_column_id in schema["foreign_keys"]
        ],
    }

In [166]:
fmt_schema = format_spider_schema('world_1')
fmt_schema

{'db_id': 'world_1',
 'db_path': '/Users/mac/Desktop/syt/Deep-Learning/Dataset/spider/database',
 'db_table_names': ['city', 'sqlite_sequence', 'country', 'countrylanguage'],
 'db_column_names': {'table_id': [-1,
   0,
   0,
   0,
   0,
   0,
   1,
   1,
   2,
   2,
   2,
   2,
   2,
   2,
   2,
   2,
   2,
   2,
   2,
   2,
   2,
   2,
   2,
   3,
   3,
   3,
   3],
  'column_name': ['*',
   'ID',
   'Name',
   'CountryCode',
   'District',
   'Population',
   'name',
   'seq',
   'Code',
   'Name',
   'Continent',
   'Region',
   'SurfaceArea',
   'IndepYear',
   'Population',
   'LifeExpectancy',
   'GNP',
   'GNPOld',
   'LocalName',
   'GovernmentForm',
   'HeadOfState',
   'Capital',
   'Code2',
   'CountryCode',
   'Language',
   'IsOfficial',
   'Percentage']},
 'db_column_types': ['text',
  'number',
  'text',
  'text',
  'text',
  'number',
  'text',
  'text',
  'text',
  'text',
  'text',
  'text',
  'number',
  'number',
  'number',
  'number',
  'number',
  'number',
  'te

In [22]:
schema_cache['concert_singer']

{'db_id': 'concert_singer',
 'table_names_original': ['stadium', 'singer', 'concert', 'singer_in_concert'],
 'table_names': ['stadium', 'singer', 'concert', 'singer in concert'],
 'column_names_original': [(-1, '*'),
  (0, 'Stadium_ID'),
  (0, 'Location'),
  (0, 'Name'),
  (0, 'Capacity'),
  (0, 'Highest'),
  (0, 'Lowest'),
  (0, 'Average'),
  (1, 'Singer_ID'),
  (1, 'Name'),
  (1, 'Country'),
  (1, 'Song_Name'),
  (1, 'Song_release_year'),
  (1, 'Age'),
  (1, 'Is_male'),
  (2, 'concert_ID'),
  (2, 'concert_Name'),
  (2, 'Theme'),
  (2, 'Stadium_ID'),
  (2, 'Year'),
  (3, 'concert_ID'),
  (3, 'Singer_ID')],
 'column_names': [(-1, '*'),
  (0, 'stadium id'),
  (0, 'location'),
  (0, 'name'),
  (0, 'capacity'),
  (0, 'highest'),
  (0, 'lowest'),
  (0, 'average'),
  (1, 'singer id'),
  (1, 'name'),
  (1, 'country'),
  (1, 'song name'),
  (1, 'song release year'),
  (1, 'age'),
  (1, 'is male'),
  (2, 'concert id'),
  (2, 'concert name'),
  (2, 'theme'),
  (2, 'stadium id'),
  (2, 'year'),


In [23]:
list(enumerate(fmt_schema['db_column_names']['column_name']))

[(0, '*'),
 (1, 'Stadium_ID'),
 (2, 'Location'),
 (3, 'Name'),
 (4, 'Capacity'),
 (5, 'Highest'),
 (6, 'Lowest'),
 (7, 'Average'),
 (8, 'Singer_ID'),
 (9, 'Name'),
 (10, 'Country'),
 (11, 'Song_Name'),
 (12, 'Song_release_year'),
 (13, 'Age'),
 (14, 'Is_male'),
 (15, 'concert_ID'),
 (16, 'concert_Name'),
 (17, 'Theme'),
 (18, 'Stadium_ID'),
 (19, 'Year'),
 (20, 'concert_ID'),
 (21, 'Singer_ID')]

#### New loading from nested dict

In [4]:
xsp_data_dir = "/Users/mac/Desktop/syt/Deep-Learning/Repos/Google-Research-Language/language/language/xsp/data"

spider_tables_path = os.path.join(xsp_data_dir, 'spider', 'tables.json')

spider_dbs_dict = spider_preprocessing.load_spider_tables(spider_tables_path)

In [5]:
spider_dbs_dict.keys()

dict_keys(['perpetrator', 'college_2', 'flight_company', 'icfp_1', 'body_builder', 'storm_record', 'pilot_record', 'race_track', 'academic', 'department_store', 'music_4', 'insurance_fnol', 'cinema', 'decoration_competition', 'phone_market', 'store_product', 'assets_maintenance', 'student_assessment', 'dog_kennels', 'music_1', 'company_employee', 'farm', 'solvency_ii', 'city_record', 'swimming', 'flight_2', 'election', 'manufactory_1', 'debate', 'network_2', 'local_govt_in_alabama', 'climbing', 'e_learning', 'scientist_1', 'ship_1', 'entertainment_awards', 'allergy_1', 'imdb', 'products_for_hire', 'candidate_poll', 'chinook_1', 'flight_4', 'pets_1', 'dorm_1', 'journal_committee', 'flight_1', 'medicine_enzyme_interaction', 'local_govt_and_lot', 'station_weather', 'shop_membership', 'driving_school', 'concert_singer', 'music_2', 'sports_competition', 'railway', 'inn_1', 'museum_visit', 'browser_web', 'baseball_1', 'architecture', 'csu_1', 'tracking_orders', 'insurance_policies', 'gas_com

In [6]:
spider_dbs_dict['concert_singer']

{'stadium': [{'field name': 'Stadium_ID',
   'is primary key': True,
   'is foreign key': True,
   'type': 'number'},
  {'field name': 'Location',
   'is primary key': False,
   'is foreign key': False,
   'type': 'text'},
  {'field name': 'Name',
   'is primary key': False,
   'is foreign key': False,
   'type': 'text'},
  {'field name': 'Capacity',
   'is primary key': False,
   'is foreign key': False,
   'type': 'number'},
  {'field name': 'Highest',
   'is primary key': False,
   'is foreign key': False,
   'type': 'number'},
  {'field name': 'Lowest',
   'is primary key': False,
   'is foreign key': False,
   'type': 'number'},
  {'field name': 'Average',
   'is primary key': False,
   'is foreign key': False,
   'type': 'number'}],
 'singer': [{'field name': 'Singer_ID',
   'is primary key': True,
   'is foreign key': True,
   'type': 'number'},
  {'field name': 'Name',
   'is primary key': False,
   'is foreign key': False,
   'type': 'text'},
  {'field name': 'Country',
   'is

In [7]:
# def db_dict_to_fmt_schema(db_dict):
#     """
#     Args:
#         db_dict: Dict[table_name, List[column_dict["field name", "is primary key", "is foreign key", "type"]]]
    
#     Output:
#         fmt_schema = (for reference) {
#             "db_table_names": schema["table_names_original"],
#             "db_column_names": {
#                 "table_id": [table_id for table_id, column_name in schema["column_names_original"]],
#                 "column_name": [column_name for table_id, column_name in schema["column_names_original"]]
#             },
#             "db_column_types": schema["column_types"],
#             "db_primary_keys": [{"column_id": column_id} for column_id in schema["primary_keys"]],
#             "db_foreign_keys": [
#                 {"column_id": column_id, "other_column_id": other_column_id}
#                 for column_id, other_column_id in schema["foreign_keys"]
#             ],
#         }
#     """

#     db_table_names = []
#     db_column_names = ['*']  # default values in spider, same below 
#     db_column_table_ids = [-1]
#     db_column_types = ['text']
#     db_primary_keys = []
#     db_foreign_keys = []
    
#     # for a column name, find the primary key idx. If this is a foreign key and not primary key, then found a f-p pair
#     # (assume the f-p pair have the same name)
#     # This is not always true. Example DB: architecture::architect_id, college_1::PROF_NUM
#     col_name2p_key_column_idx = dict()
#     col_name2f_keys_column_idx = defaultdict(list)
    
#     for table_name, table_columns in db_dict.items():
#         table_idx = len(db_table_names)
#         db_table_names.append(table_name)
        
#         for col_dict in table_columns:
#             col_idx = len(db_column_names)
#             db_column_names.append(col_dict["field name"])
#             db_column_table_ids.append(table_idx)
#             db_column_types.append(col_dict["type"])
            
#             # in michigan datasets, "is primary||foreign key" is "y"/"n"; in spider and wikisql, it is true/false
#             if col_dict["is primary key"] in {True, 'y'}:
#                 db_primary_keys.append({"column_id": col_idx})
#                 if col_dict["field name"] in col_name2p_key_column_idx:
#                     ## already exists?? Yes, there could be f-p pairs where both are primary! 
#                     # print(f'Warning: {col_dict["field name"]} already exists')
#                     pass
#                 else:
#                     col_name2p_key_column_idx[col_dict["field name"]] = col_idx
            
#             if col_dict["is foreign key"] in {True, 'y'}:
#                 col_name2f_keys_column_idx[col_dict["field name"]].append(col_idx)
        
#     for col_name in col_name2f_keys_column_idx.keys():
#         f_key_column_ids = col_name2f_keys_column_idx[col_name]
#         if col_name in col_name2p_key_column_idx:
#             p_key_column_idx = col_name2p_key_column_idx[col_name]
#         else:
#             print(f'Warning: {col_name}, no primary key found')
#             continue
            
#         for f_key_column_idx in f_key_column_ids:
#             if f_key_column_idx != p_key_column_idx:
#                 db_foreign_keys.append({"column_id": f_key_column_idx, "other_column_id": p_key_column_idx})
    
#     return {
#         "db_table_names": db_table_names,
#         "db_column_names": {
#             "table_id": db_column_table_ids,
#             "column_name": db_column_names,
#         },
#         "db_column_types": db_column_types,
#         "db_primary_keys": db_primary_keys,
#         "db_foreign_keys": db_foreign_keys
#     }


In [43]:
def general_fmt_dict_to_uskg_schema(general_fmt_dict):
    """
    Args:
        general_fmt_dict (Dict): {
            "db_id": str
            "table_names_original": List[str], original table name (concert_singer)
            "table_names_clean": List[str], clean table names (concert_singer)
            "column_names_original": List[str], original column name (singer_id)
            "column_names_clean": List[str], clean columns names (singer id)
            "column_db_full_names": List[str], name of table::column in DB (may differ from column_names) (singer::singer_id)
            "column_table_ids": List[int], for each column, the corresponding table index
            "column_types": List[str], column types
            "primary_keys": List[int], the columns indices that are primary key
            "foreign_keys": List[[int, int]], the f-p column index pairs (fk_id, pk_id)
            "sqlite_path": str
            "sqlite_conn": sqlite3.Connection
        }
    
    Output:
        uskg_schema = (for reference) {
            "db_table_names": schema["table_names_original"],
            "db_column_names": {
                "table_id": [table_id for table_id, column_name in schema["column_names_original"]],
                "column_name": [column_name for table_id, column_name in schema["column_names_original"]]
            },
            "db_column_types": schema["column_types"],
            "db_primary_keys": [{"column_id": column_id} for column_id in schema["primary_keys"]],
            "db_foreign_keys": [
                {"column_id": column_id, "other_column_id": other_column_id}
                for column_id, other_column_id in schema["foreign_keys"]
            ],
        }
    """

    db_id = general_fmt_dict["db_id"]
    db_table_orig_names = general_fmt_dict["table_names_original"]
    db_table_clean_names = general_fmt_dict["table_names_clean"]
    db_column_orig_names = general_fmt_dict["column_names_original"]
    db_column_clean_names = general_fmt_dict["column_names_clean"]
    col_db_full_names = general_fmt_dict["column_db_full_names"]
    db_column_table_ids = general_fmt_dict["column_table_ids"]
    db_column_types = general_fmt_dict["column_types"]
    db_primary_keys = general_fmt_dict["primary_keys"]
    db_foreign_keys = general_fmt_dict["foreign_keys"]
    sqlite_path = general_fmt_dict["sqlite_path"]
    sqlite_conn = general_fmt_dict["sqlite_conn"]
    
    # USKG specific
    uskg_primary_keys = [{"column_id": col_idx} for col_idx in db_primary_keys]
    uskg_foreign_keys = [{"column_id": fk_idx, "other_column_id": pk_idx} for fk_idx, pk_idx in db_foreign_keys]

    uskg_schema = {
        "db_id": db_id,
        "db_table_names": db_table_orig_names,
        "db_column_names": {
            "table_id": db_column_table_ids,
            "column_name": db_column_orig_names,
        },
        "db_column_types": db_column_types,
        "db_primary_keys": [{"column_id": column_id} for column_id in db_primary_keys],
        "db_foreign_keys": [
            {"column_id": column_id, "other_column_id": other_column_id}
            for column_id, other_column_id in db_foreign_keys
        ],
    }
    
    return uskg_schema


In [44]:
db_id = 'architecture'
db_dict = spider_dbs_dict[db_id]

general_fmt_dict = db_dict_to_general_fmt(db_dict, db_id,
                                          sqlite_path=f"/Users/mac/Desktop/syt/Deep-Learning/Repos/Google-Research-Language/language/language/xsp/data/spider/database/{db_id}/{db_id}.sqlite",
                                          rigorous_foreign_key=True)

In [45]:
general_fmt_dict

{'db_id': 'architecture',
 'table_names_original': ['architect', 'bridge', 'mill'],
 'table_names_clean': ['architect', 'bridge', 'mill'],
 'column_names_original': ['*',
  'id',
  'name',
  'nationality',
  'gender',
  'architect_id',
  'id',
  'name',
  'location',
  'length_meters',
  'length_feet',
  'architect_id',
  'id',
  'location',
  'name',
  'type',
  'built_year',
  'notes'],
 'column_names_clean': ['*',
  'id',
  'name',
  'nationality',
  'gender',
  'architect id',
  'id',
  'name',
  'location',
  'length meters',
  'length feet',
  'architect id',
  'id',
  'location',
  'name',
  'type',
  'built year',
  'notes'],
 'column_db_full_names': ['NONE::*',
  'architect::id',
  'architect::name',
  'architect::nationality',
  'architect::gender',
  'bridge::architect_id',
  'bridge::id',
  'bridge::name',
  'bridge::location',
  'bridge::length_meters',
  'bridge::length_feet',
  'mill::architect_id',
  'mill::id',
  'mill::location',
  'mill::name',
  'mill::type',
  'mil

In [46]:
fmt_schema_2 = general_fmt_dict_to_uskg_schema(general_fmt_dict)

In [47]:
fmt_schema_2

{'db_id': 'architecture',
 'db_table_names': ['architect', 'bridge', 'mill'],
 'db_column_names': {'table_id': [-1,
   0,
   0,
   0,
   0,
   1,
   1,
   1,
   1,
   1,
   1,
   2,
   2,
   2,
   2,
   2,
   2,
   2],
  'column_name': ['*',
   'id',
   'name',
   'nationality',
   'gender',
   'architect_id',
   'id',
   'name',
   'location',
   'length_meters',
   'length_feet',
   'architect_id',
   'id',
   'location',
   'name',
   'type',
   'built_year',
   'notes']},
 'db_column_types': ['text',
  'text',
  'text',
  'text',
  'text',
  'number',
  'number',
  'text',
  'text',
  'number',
  'number',
  'number',
  'number',
  'text',
  'text',
  'text',
  'number',
  'text'],
 'db_primary_keys': [{'column_id': 1}, {'column_id': 6}, {'column_id': 12}],
 'db_foreign_keys': [{'column_id': 5, 'other_column_id': 1},
  {'column_id': 11, 'other_column_id': 1}]}

In [None]:
fmt_schema = format_spider_schema(db_id)
# del fmt_schema['db_id']
del fmt_schema['db_path']

fmt_schema == fmt_schema_2

In [None]:
sorted(fmt_schema["db_foreign_keys"], key=lambda d: d['column_id']), sorted(fmt_schema_2["db_foreign_keys"], key=lambda d: d['column_id'])

In [None]:
list(enumerate(fmt_schema['db_column_names']['column_name']))

### WikiSQL
- Should be able to reuse db_dict_to_fmt_schema()

In [103]:
wikisql_tables_path = os.path.join(xsp_data_dir, 'wikisql', 'dev.tables.jsonl')

wikisql_dbs_dict = wikisql_preprocessing.load_wikisql_tables(wikisql_tables_path)

In [104]:
wikisql_dbs_dict['1-29690363-3'].keys(), wikisql_dbs_dict['1-29690363-3']['RACE_RESULTS']

(dict_keys(['RACE_RESULTS']),
 [{'field name': 'RD',
   'is primary key': False,
   'is foreign key': False,
   'type': 'text'},
  {'field name': 'RACE',
   'is primary key': False,
   'is foreign key': False,
   'type': 'text'},
  {'field name': 'POLE_POSITION',
   'is primary key': False,
   'is foreign key': False,
   'type': 'text'},
  {'field name': 'FASTEST_LAP',
   'is primary key': False,
   'is foreign key': False,
   'type': 'text'},
  {'field name': 'MOST_LAPS_LED',
   'is primary key': False,
   'is foreign key': False,
   'type': 'text'},
  {'field name': 'WINNING_DRIVER',
   'is primary key': False,
   'is foreign key': False,
   'type': 'text'},
  {'field name': 'WINNING_TEAM',
   'is primary key': False,
   'is foreign key': False,
   'type': 'text'}])

In [105]:
db_dict_to_fmt_schema(wikisql_dbs_dict['1-29690363-3'])

{'db_table_names': ['RACE_RESULTS'],
 'db_column_names': {'table_id': [-1, 0, 0, 0, 0, 0, 0, 0],
  'column_name': ['*',
   'RD',
   'RACE',
   'POLE_POSITION',
   'FASTEST_LAP',
   'MOST_LAPS_LED',
   'WINNING_DRIVER',
   'WINNING_TEAM']},
 'db_column_types': ['text',
  'text',
  'text',
  'text',
  'text',
  'text',
  'text',
  'text'],
 'db_primary_keys': [],
 'db_foreign_keys': []}

### Michigan
- Should be able to reuse db_dict_to_fmt_schema()

In [106]:
atis_schema_path = os.path.join(xsp_data_dir, 'atis', 'atis_schema.csv')

atis_db_dict = michigan_preprocessing.read_schema(atis_schema_path)

In [None]:
atis_db_dict

In [None]:
db_dict_to_fmt_schema(atis_db_dict)

## Get USKG encoding

In [None]:
# TODO:
#     - set up inference (for sanity check)
#     - look into source (of transformers Bert.generate) to find ways to get the encoding

### Check USKG inference

In [31]:
def play_pred(txt, model, tokenizer):
    tokenized_txt = tokenizer([txt], max_length=1024, padding="max_length", truncation=True)
    pred = tokenizer.batch_decode(
      model.generate(
        torch.LongTensor(tokenized_txt.data['input_ids']),
        torch.LongTensor(tokenized_txt.data['attention_mask']),
        num_beams=1, 
        max_length=256
        ), 
      skip_special_tokens=True 
    )
    return pred

In [32]:
# Set args here for runnning on notebook, we make them out here to make it more illustrative.
sys.argv = ['/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py', # This is the name of your .py launcher when you run this line of code.
            # belows are the parameters we set, take spider for example
            '--cfg', 'Salesforce/T5_large_prefix_spider_with_cell_value.cfg', 
            '--output_dir', './tmp']
parser = HfArgumentParser((WrappedSeq2SeqTrainingArguments,))
training_args, = parser.parse_args_into_dataclasses()
set_seed(training_args.seed)
args = Configure.Get(training_args.cfg)

In [33]:
model_path = 'hkunlp/from_all_T5_large_prefix_spider_with_cell_value2'
# model_path = 'hkunlp/from_all_T5_large_prefix_spider_with_cell_value2'
# model_path = '/Users/mac/Desktop/syt/Deep-Learning/Repos/UnifiedSKG/output/server_runs/A-T5_base_prefix_spider_with_cell_value-asr_mixed/checkpoint-79500/'
# model_path = '/Users/mac/Desktop/syt/Deep-Learning/Repos/UnifiedSKG/output/server_runs/A-T5_base_prefix_spider_with_cell_value-rewritten_mixed/checkpoint-56500/'

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

# for reconstruction
tokenizer_fast = AutoTokenizer.from_pretrained('t5-base', use_fast=True)

model = Model(args)
model.load(model_path)

prefix-tuning sequence length is 10.


In [263]:
tokenizer_large_fast = AutoTokenizer.from_pretrained('t5-large', use_fast=True)

In [34]:
struct_in = "| concert_singer | stadium : stadium_id , location , name , capacity , highest , lowest , average | singer : singer_id , name , country ( France ) , song_name , song_release_year , age , is_male | concert : concert_id , concert_name , theme , stadium_id , year | singer_in_concert : concert_id , singer_id"
text_in = "what is the minimum, average, and maximum age of all singers from France?"

play_pred("{}; structed knowledge: {}".format(text_in, struct_in), model, tokenizer)

['select min(age), avg(age), max(age) from singer where country = "France"']

In [35]:
play_pred("{}; structed knowledge: {}".format(text_in, struct_in), model, tokenizer_fast)

['select min(age), avg(age), max(age) from singer where country = "France"']

### Check USKG inference on dataset samples

In [48]:
xsp_data_dir = "/Users/mac/Desktop/syt/Deep-Learning/Repos/Google-Research-Language/language/language/xsp/data"

orig_dataset_path = "/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SDR-analysis/data/spider/dev+ratsql_graph.json"
orig_tables_path = "/Users/mac/Desktop/syt/Deep-Learning/Repos/Google-Research-Language/language/language/xsp/data/spider/tables.json"

In [49]:
spider_dbs_dict = spider_preprocessing.load_spider_tables(orig_tables_path)
len(spider_dbs_dict)

166

In [50]:
uskg_schemas_dict = dict()
for db_id, db_dict in spider_dbs_dict.items():
    general_fmt_dict = db_dict_to_general_fmt(db_dict, db_id,
                                              sqlite_path=os.path.join(xsp_data_dir, f"spider/database/{db_id}/{db_id}.sqlite"),
                                              rigorous_foreign_key=True)
    
    uskg_schema = general_fmt_dict_to_uskg_schema(general_fmt_dict)
    
    uskg_schemas_dict[db_id] = uskg_schema
    
len(uskg_schemas_dict)



166

In [124]:
with open(orig_dataset_path, 'r') as f:
    orig_dataset = json.load(f)

for d in orig_dataset:
    d['rat_sql_graph']['relations'] = json.loads(d['rat_sql_graph']['relations'])

len(orig_dataset)

1034

In [125]:
orig_dataset[0].keys()

dict_keys(['db_id', 'query', 'query_toks', 'query_toks_no_value', 'question', 'question_toks', 'sql', 'rat_sql_graph'])

In [126]:
SPIDER_DB_PATH = os.path.join(xsp_data_dir, "spider/database")

def uskg_sample_to_struct_input(uskg_sample):
    db_id = uskg_sample["db_id"]
    uskg_schema = uskg_schemas_dict[db_id]
    
    return spider.serialize_schema(
        question=uskg_sample["question"],
        db_path=SPIDER_DB_PATH,
        db_id=db_id,
        db_column_names=uskg_schema["db_column_names"],
        db_table_names=uskg_schema["db_table_names"],
        schema_serialization_type="peteshaw",
        schema_serialization_randomized=False,
        schema_serialization_with_db_id=True,
        schema_serialization_with_db_content=True,
        normalize_query=True,
    )

In [128]:
idx = 777
sample = orig_dataset[idx]
sample.keys(), sample['rat_sql_graph'].keys()

(dict_keys(['db_id', 'query', 'query_toks', 'query_toks_no_value', 'question', 'question_toks', 'sql', 'rat_sql_graph']),
 dict_keys(['nodes', 'q_nodes_orig', 'relations']))

In [129]:
text_in = sample['question']
struct_in = uskg_sample_to_struct_input(sample)
text_in, struct_in

('What are the Asian countries which have a population larger than that of any country in Africa?',
 ' | world_1 | city : id , name , countrycode , district , population | sqlite_sequence : name , seq | country : code , name , continent ( Africa , Asia ) , region , surfacearea , indepyear , population , lifeexpectancy , gnp , gnpold , localname , governmentform , headofstate , capital , code2 | countrylanguage : countrycode , language , isofficial , percentage')

In [None]:
play_pred("{}; structed knowledge: {}".format(text_in, struct_in), model, tokenizer_fast)

### Tokenized pieces-nodes mapping

In [130]:
# TODO (done):
#     find mapping between tokenized pieces to ratsql nodes
#         ratsql nodes to sentence char ids (done)
#         sentence char ids to tokenized pieces (done)
#     pool the encodings of pieces in a node (done)

In [165]:
class StructCharRangesCollector:
    def __init__(self):
        self.initialize()

    def initialize(self):
        self.db_id2char_ranges = dict()
        self.table2char_ranges = dict()
        self.column2char_ranges = dict()
        
        # Due to rat-sql stemming tokens, rat-sql nodes and uskg text may mismatch,
        # so we save a list and use the order instead of name for indexing 
        self.db_id_char_ranges_list = []
        self.column_char_ranges_list = []
        self.table_char_ranges_list = []

        self.bar_cnt = 0
        self.curr_table = None
        self.curr_node_type = None   # [None, 'db_id', 'table', 'column']
        self.curr_node_toks = []
        self.curr_node_char_start = None
        self.curr_node_char_end = None
        self.open_bracket = False
        
    def _register_curr_node(self):
        curr_node_name = ' '.join(self.curr_node_toks)
        curr_range = (self.curr_node_char_start, self.curr_node_char_end)
        
        if self.curr_node_type == 'db_id':
            self.db_id2char_ranges[curr_node_name] = curr_range
            self.db_id_char_ranges_list.append(curr_range)   
        elif self.curr_node_type == 'table':
            self.table2char_ranges[curr_node_name] = curr_range
            self.table_char_ranges_list.append(curr_range)
            self.curr_table = curr_node_name
        elif self.curr_node_type == 'column':
            self.column2char_ranges[(self.curr_table, curr_node_name)] = curr_range
            self.column_char_ranges_list.append(curr_range)
        else:
            raise ValueError(curr_node_type)

        self.curr_node_toks = []
        self.curr_node_char_start = None
        self.curr_node_char_end = None
    
    def collect(self, struct_in, tokenized_txt, _n_words_before_struct):
#         struct_words = struct_in.strip().split(' ')
        struct_words = struct_in.strip().split()
        
        for sw_id, sw in enumerate(struct_words):
            char_range = tokenized_txt.word_to_chars(sw_id + _n_words_before_struct)

            # print(sw_id, char_range, sw, self.curr_node_type, self.open_bracket)

            if sw == '(':
                self.open_bracket = True
                continue

            if sw == ')':
                self.open_bracket = False
                self.curr_node_char_end = char_range[1]
                continue

            if self.open_bracket:
                # in the list of cells, do not add tokens here to name 
                continue

            if sw == '|':
                if self.curr_node_type is not None:
                    self._register_curr_node()
                self.bar_cnt += 1
                if self.bar_cnt == 1:
                    self.curr_node_type = 'db_id'
                if self.bar_cnt > 1:
                    self.curr_node_type = 'table'
                continue

            if sw == ':':
                assert self.curr_node_type == 'table'
                self._register_curr_node()
                self.curr_node_type = 'column'
                continue

            if sw == ',':
                assert self.curr_node_type == 'column'
                self._register_curr_node()
                self.curr_node_type = 'column'
                continue

            self.curr_node_toks.append(sw)
            if self.curr_node_char_start is None:
                self.curr_node_char_start = char_range[0]
            self.curr_node_char_end = char_range[1]

        self._register_curr_node()


In [166]:
## Combine experiment codes 
def collect_node_char_ranges(sample, tokenizer=None, tokenizer_args=None, txt=None, tokenized_txt=None, debug=False):
    text_in = sample['question']
    struct_in = uskg_sample_to_struct_input(sample)
    _splitter = "; structed knowledge: "
    
    if tokenizer_args is None:
        tokenizer_args = dict()
        
    if txt is None:
        txt = "{}{}{}".format(text_in, _splitter, struct_in)
    
    if tokenized_txt is None:
        # tokenized_txt = tokenizer([txt], max_length=1024, padding="max_length", truncation=True)
        tokenized_txt = tokenizer([txt], **tokenizer_args)
        ## possible problem: exceeding max length!
    
    ratsql_graph_nodes = sample['rat_sql_graph']['nodes']
#     question_toks = sample['question_toks']
    question_toks = sample['rat_sql_graph']['q_nodes_orig']

    _q_nodes = []  # [stem token (node name)]
    q_nodes = []  # [(stem token (node name), orig question token)]
    c_nodes = []  # [(orig table name, orig column name)]
    t_nodes = []  # [orig table name]

    for n in ratsql_graph_nodes:
        if n.startswith('<C>'):
            _n = n[3:]
            _t, _c = _n.split('::')
            c_nodes.append((_t, _c))
        elif n.startswith('<T>'):
            _n = n[3:]
            t_nodes.append(_n)
        else:
            _q_nodes.append(n)

    assert len(_q_nodes) == len(question_toks), (_q_nodes, question_toks)
    q_nodes = list(zip(_q_nodes, question_toks))
    
    # Collection char ranges 
    q_node_chars = []   # [(st, ed)]; same below
    c_node_chars = []
    t_node_chars = []
    
    # Text part
    # Assumption: the mismatch between whitespace words (text_words) and question words only come from trailing puncts
    # Currently the code can handle combining question toks into whitespace words
#     text_words = text_in.strip().split(' ') + ['<SENTINAL>']
    text_words = text_in.lower().strip().split() + ['<SENTINAL>']
    text_word_char_ranges = [tokenized_txt.word_to_chars(i) for i in range(len(text_words) - 1)] + [(None, None)]  # -1 to remove the sentinal 

    curr_tw_idx = 0
    curr_tw = text_words[0]
    curr_tw_char_range = text_word_char_ranges[0]
    curr_char_ptr = 0
    for stem_tok, orig_tok in q_nodes:
        if curr_tw == orig_tok:
            # finishing current word 
            q_node_chars.append((curr_char_ptr, curr_char_ptr + len(orig_tok)))   # curr pos to curr pos + len 
            curr_tw_idx += 1
            curr_tw = text_words[curr_tw_idx]
            curr_tw_char_range = text_word_char_ranges[curr_tw_idx]
            curr_char_ptr = curr_tw_char_range[0]
        else:
            # not finishing current word 
            assert curr_tw.startswith(orig_tok), (curr_tw, orig_tok)
            q_node_chars.append((curr_char_ptr, curr_char_ptr + len(orig_tok)))   # curr pos to curr pos + len 
            curr_char_ptr += len(orig_tok)     # move ptr forward by len 
            curr_tw = curr_tw[len(orig_tok):]  # get the remaining chars in the word 

    assert [txt[st:ed].lower() for st, ed in q_node_chars] == question_toks, ([txt[st:ed] for st, ed in q_node_chars], question_toks)
    
    # Struct part 
    _str_before_struct = text_in + _splitter
    _n_words_before_struct = len(_str_before_struct.strip().split())

    struct_ranges_collector = StructCharRangesCollector()
    struct_ranges_collector.collect(struct_in, tokenized_txt, _n_words_before_struct)
    
    # Due to rat-sql stemming tokens, rat-sql nodes and uskg text may mismatch
#     for c_node in c_nodes:
#         if c_node == ('NONE', '*'):
#             # the special column in spider, using db_id 
#             c_node_chars.append(list(struct_ranges_collector.db_id2char_ranges.values())[0])   # assuming only 1 db_id, which should be true...
#         else:
#             c_node_chars.append(struct_ranges_collector.column2char_ranges[c_node])

#     for t_node in t_nodes:
#         t_node_chars.append(struct_ranges_collector.table2char_ranges[t_node])

    c_node_chars.extend(struct_ranges_collector.db_id_char_ranges_list + struct_ranges_collector.column_char_ranges_list)
    t_node_chars.extend(struct_ranges_collector.table_char_ranges_list)

    ## Check all 
    if debug:
        for q_node, (st, ed) in zip(q_nodes, q_node_chars):
            print(q_node, txt[st:ed])
        print()
        for t_node, (st, ed) in zip(t_nodes, t_node_chars):
            print(t_node, txt[st:ed])
        print()
        for c_node, (st, ed) in zip(c_nodes, c_node_chars):
            print(c_node, txt[st:ed])
        
    return {
        "q_node_chars": q_node_chars,
        "c_node_chars": c_node_chars,
        "t_node_chars": t_node_chars,
    }
    

In [167]:
idx = 209
sample = orig_dataset[idx]

text_in = sample['question']
struct_in = uskg_sample_to_struct_input(sample)

txt = "{}; structed knowledge: {}".format(text_in, struct_in)

tokenized_txt = tokenizer_fast([txt], max_length=1024, padding="max_length", truncation=True)

In [168]:
# tokenizer_args = {
#     "max_length": 1024,
#     "padding": "max_length",
#     "truncation": True
# }

# char_ranges_dict = collect_node_char_ranges(sample, tokenizer=tokenizer_fast, tokenizer_args=tokenizer_args, debug=True)
char_ranges_dict = collect_node_char_ranges(sample, txt=txt, tokenized_txt=tokenized_txt, debug=True)

('how', 'how') How
('many', 'many') many
('flight', 'flights') flights
('arrive', 'arriving') arriving
('in', 'in') in
('aberdeen', 'aberdeen') Aberdeen
('city', 'city') city
('?', '?') ?

airline airlines
airport airports
flight flights

('NONE', '*') flight_2
('airline', 'uid') uid
('airline', 'airline') airline
('airline', 'abbreviation') abbreviation
('airline', 'country') country
('airport', 'city') city ( Aberdeen  )
('airport', 'airportcode') airportcode
('airport', 'airportname') airportname
('airport', 'country') country
('airport', 'countryabbrev') countryabbrev
('flight', 'airline') airline
('flight', 'flightno') flightno
('flight', 'sourceairport') sourceairport
('flight', 'destairport') destairport


In [164]:
txt

'How many flights arriving in Aberdeen city?; structed knowledge:  | flight_2 | airlines : uid , airline , abbreviation , country | airports : city ( Aberdeen  ) , airportcode , airportname , country , countryabbrev | flights : airline , flightno , sourceairport , destairport'

In [169]:
char_ranges_dict

{'q_node_chars': [(0, 3),
  (4, 8),
  (9, 16),
  (17, 25),
  (26, 28),
  (29, 37),
  (38, 42),
  (42, 43)],
 'c_node_chars': [(68, 76),
  (90, 93),
  (96, 103),
  (106, 118),
  (121, 128),
  (142, 160),
  (163, 174),
  (177, 188),
  (191, 198),
  (201, 214),
  (227, 234),
  (237, 245),
  (248, 261),
  (264, 275)],
 't_node_chars': [(79, 87), (131, 139), (217, 224)]}

In [174]:
tokenized_txt.char_to_token(79)

20

In [175]:
tokenized_txt.tokens()[20]

'▁airlines'

In [196]:
sum(tokenized_txt.data['attention_mask'][0])

87

### Get encoding

In [191]:
def get_USKG_node_encodings(sample, model, tokenizer, tokenizer_args=None, pooling_func=None, debug=False):
    """
    Args:
        pooling_func (Callable): np.array(n_pieces, dim) ==> np.array(dim,); default is np.mean
    """
    text_in = sample['question']
    struct_in = uskg_sample_to_struct_input(sample)
    
    _splitter = "; structed knowledge: "
    txt = "{}{}{}".format(text_in, _splitter, struct_in)

    if tokenizer_args is None:
        tokenizer_args = {
            "max_length": 1024,
            "padding": "max_length",
            "truncation": True
        }
    if pooling_func is None:
        pooling_func = lambda l: np.mean(l, axis=0)
        
    tokenized_txt = tokenizer([txt], **tokenizer_args)
    
    # Get encoding tensor 
    with torch.no_grad():
        past_prompt = model.get_prompt(
            bsz=1,              # bsz = input_ids.shape[0]
            sample_size=1,      # sample_size=kwargs['num_beams']
            description=None,   
            knowledge=None,     
        )
        encoder_outputs = model.pretrain_model.encoder(
            input_ids=torch.LongTensor(tokenized_txt.data['input_ids']),
            attention_mask=torch.LongTensor(tokenized_txt.data['attention_mask']),
            past_prompt=past_prompt,
        )
    encoder_output_hidden_states = encoder_outputs.last_hidden_state.detach().squeeze(0).cpu().numpy()
    if debug:
        print('encoder_output_hidden_states:', encoder_output_hidden_states.shape)
    
    # Get node-pieces mapping via char ranges 
    char_ranges_dict = collect_node_char_ranges(sample, txt=txt, tokenized_txt=tokenized_txt)
    node_char_ranges = char_ranges_dict['q_node_chars'] + char_ranges_dict['c_node_chars'] + char_ranges_dict['t_node_chars']

    # some chars can be mapped to multiple tokens (e.g. 'i' => '▁', 'i' )
    char_to_tokens_dict = defaultdict(list)
    for token_idx, tok in enumerate(tokenized_txt.tokens()):
        if tok == '</s>':
            break
        char_span = tokenized_txt.token_to_chars(token_idx)
        for char_idx in range(char_span[0], char_span[1]):
            char_to_tokens_dict[char_idx].append(token_idx)
    
    node_pieces_ranges = []
    for st, ed in node_char_ranges:
        piece_ids = []
        for char_idx in range(st, ed):
            _piece_ids = char_to_tokens_dict[char_idx]
            piece_ids.extend(_piece_ids)

        piece_st = piece_ids[0]
        piece_ed = piece_ids[-1] + 1
        # the collected piece_ids should be continuous 
        # ^ not true... some chars can be mapped to multiple tokens (started by ▁ )
        # re-collect a char-to-token
        assert set(range(piece_st, piece_ed)) == set(piece_ids), piece_ids

        node_pieces_ranges.append((piece_st, piece_ed))
    
    if debug:
        print('node_pieces_ranges:', node_pieces_ranges)
    
    # Pool the encodings per node 
    node_encodings = []
    for piece_st, piece_ed in node_pieces_ranges:
        enc_vecs = encoder_output_hidden_states[piece_st : piece_ed]
        enc_pooled = pooling_func(enc_vecs)
        node_encodings.append(enc_pooled)
    
    return node_encodings

In [177]:
idx = 209
sample = orig_dataset[idx]

text_in = sample['question']
struct_in = uskg_sample_to_struct_input(sample)

txt = "{}; structed knowledge: {}".format(text_in, struct_in)

tokenized_txt = tokenizer_fast([txt], max_length=1024, padding="max_length", truncation=True)

In [178]:
node_encodings = get_USKG_node_encodings(sample, model, tokenizer_fast, debug=True)

encoder_output_hidden_states: (1024, 1024)
node_pieces_ranges: [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (16, 19), (23, 27), (29, 30), (32, 36), (38, 39), (44, 49), (51, 53), (55, 57), (59, 60), (62, 66), (70, 71), (73, 76), (78, 81), (83, 86), (20, 21), (40, 42), (67, 68)]


In [179]:
node_encodings[0].shape

(1024,)

## Probing: link prediction

### Data collection

In [149]:
def extract_probing_samples_link_prediction_uskg(dataset_sample,
                                                 db_schemas_dict,
                                                 model,
                                                 tokenizer,
                                                 pos=None,
                                                 max_rel_occ=None,
                                                 debug=False):
    """
    Args:
        dataset_sample (Dict): a sample dict from spider dataset
        db_schemas_dict (Dict): db_id => db_schema, precomputed for all DBs (not used here)
        model (EncDec): the rat-sql model
        pos (List[Tuple]): the position pairs to use. If none, will randomly generate
        max_rel_occ (int): each relation occur at most this many times in each (original) sample
    
    Return:
        X (List[np.array]): input features, "shape" = (n, dim)
        y (List): output labels, "shape" = (n,)
        pos (List[Tuple]): actual position (node-id) pairs for X and y
    """
    
    d = dataset_sample
    
    db_id = d['db_id']
    # db_schema = db_schemas_dict[db_id]
    question = d['question']

    # get relation matrix (relation_id2name not available as it needs rat-sql model)
    graph_dict = dataset_sample['rat_sql_graph']
    # graph_dict['relation_id2name'] = {v : k for k, v in model.encoder.encs_update.relation_ids.items()}
    
    # get encodings
    # rat_sql_encoder_state = get_rat_sql_encoder_state(question=question, db_schema=db_schema, model=model)
    # enc_repr = rat_sql_encoder_state.memory.squeeze(0).detach().cpu().numpy()
    enc_repr = get_USKG_node_encodings(sample=dataset_sample,
                                       model=model,
                                       tokenizer=tokenizer,
                                       debug=debug)
    
    X, y, pos = collect_link_prediction_samples(
        graph_dict,
        enc_repr,
        pos=pos,
        max_rel_occ=max_rel_occ,
        debug=debug)
    
    return X, y, pos



In [187]:
probing_data_dir = "/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SDR-analysis/data/probing/text2sql/link_prediction/spider/ratsql"

orig_ds = 'dev'
prob_ds = 'test'
dataset_path = f"/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SDR-analysis/data/spider/{orig_ds}+ratsql_graph.json"

pos_file_path = os.path.join(probing_data_dir, f'{orig_ds}.{prob_ds}.pos.txt')

In [188]:
with open(dataset_path, 'r') as f:
    orig_dataset = json.load(f)
    
for d in orig_dataset:
    d['rat_sql_graph']['relations'] = json.loads(d['rat_sql_graph']['relations'])

len(orig_dataset), orig_dataset[0].keys()

(1034,
 dict_keys(['db_id', 'query', 'query_toks', 'query_toks_no_value', 'question', 'question_toks', 'sql', 'rat_sql_graph']))

In [189]:
with open(pos_file_path, 'r') as f:
    lines = f.read().strip().split('\n')
    all_pos_triplets = [tuple([int(s) for s in l.split('\t')]) for l in lines]
len(all_pos_triplets), all_pos_triplets[0]

# Load pos file 
sample_ds_indices = []               # [ds_idx], based on occurring order 
pos_per_sample = defaultdict(list)   # key = ds_idx, value = pos_list: List[(i, j)]

for ds_idx, i, j in all_pos_triplets:
    if not sample_ds_indices or sample_ds_indices[-1] != ds_idx:
        sample_ds_indices.append(ds_idx)
    pos_per_sample[ds_idx].append((i, j))

len(sample_ds_indices), len(pos_per_sample)

(500, 500)

In [190]:
# test loading pos file 
set(sample_ds_indices) == set(pos_per_sample.keys())

True

In [180]:
# TODO: for each sample_ds_idx, get the sample & pos_list, get the X, y, pos; save. 

all_X = []
all_y = []
all_pos = []

for sample_ds_idx in tqdm(sample_ds_indices):
    dataset_sample = orig_dataset[sample_ds_idx]
    pos_list = pos_per_sample[sample_ds_idx]

    X, y, pos = extract_probing_samples_link_prediction_uskg(dataset_sample=dataset_sample,
                                                             db_schemas_dict=None,
                                                             model=model,
                                                             tokenizer=tokenizer_fast,
                                                             pos=pos_list,
                                                             max_rel_occ=None,  # when given pos, this is not needed 
                                                             debug=False)
    
    all_X.extend(X)
    all_y.extend(y)
    pos = [(sample_ds_idx, i, j) for i, j in pos]   # add sample idx 
    all_pos.extend(pos)
    
    time.sleep(0.5)

len(all_X), len(all_y), len(all_pos)

  0%|          | 0/500 [00:00<?, ?it/s]

(16059, 16059, 16059)

In [155]:
sample_ds_idx

209

In [183]:
probing_data_out_dir = "/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SDR-analysis/data/probing/text2sql/link_prediction/spider/uskg"
os.makedirs(probing_data_out_dir, exist_ok=True)

output_path_test_X = os.path.join(probing_data_out_dir, f'{orig_ds}.{prob_ds}.X.pkl')
output_path_test_y = os.path.join(probing_data_out_dir, f'{orig_ds}.{prob_ds}.y.pkl')
output_path_test_pos = os.path.join(probing_data_out_dir, f'{orig_ds}.{prob_ds}.pos.txt')

with open(output_path_test_X, 'wb') as f:
    pickle.dump(all_X, f)
with open(output_path_test_y, 'wb') as f:
    pickle.dump(all_y, f)
with open(output_path_test_pos, 'w') as f:
    for idx, i, j in all_pos:
        f.write(f'{idx}\t{i}\t{j}\n')

## Temp

### Test tokenizer

In [108]:
tokenizer_fast = AutoTokenizer.from_pretrained('t5-base', use_fast=True)

In [None]:
test_txt = "This is t5's tokenization.; structed knowledge: | model | plm(t5), rnn"

In [128]:
tok_fast = tokenizer_fast.tokenize(test_txt)
tok_fast

['▁This',
 '▁is',
 '▁',
 't',
 '5',
 "'",
 's',
 '▁token',
 'ization',
 '.',
 ';',
 '▁',
 'struct',
 'e',
 'd',
 '▁knowledge',
 ':',
 '▁|',
 '▁model',
 '▁|',
 '▁pl',
 'm',
 '(',
 't',
 '5)',
 ',',
 '▁',
 'r',
 'n',
 'n']

In [113]:
encoded = tokenizer.encode(test_txt)
encoded_fast = tokenizer_fast.encode(test_txt)
type(encoded), type(encoded_fast)

(list, list)

In [115]:
list(zip(encoded, encoded_fast))

[(100, 100),
 (19, 19),
 (3, 3),
 (17, 17),
 (755, 755),
 (31, 31),
 (7, 7),
 (14145, 14145),
 (1707, 1707),
 (5, 5),
 (117, 117),
 (3, 3),
 (7593, 7593),
 (15, 15),
 (26, 26),
 (1103, 1103),
 (10, 10),
 (1820, 1820),
 (825, 825),
 (1820, 1820),
 (4752, 4752),
 (51, 51),
 (599, 599),
 (17, 17),
 (9120, 9120),
 (6, 6),
 (3, 3),
 (52, 52),
 (29, 29),
 (29, 29),
 (1, 1)]

In [118]:
ed_fast = tokenizer_fast(test_txt)
type(ed_fast), ed_fast

(transformers.tokenization_utils_base.BatchEncoding,
 {'input_ids': [100, 19, 3, 17, 755, 31, 7, 14145, 1707, 5, 117, 3, 7593, 15, 26, 1103, 10, 1820, 825, 1820, 4752, 51, 599, 17, 9120, 6, 3, 52, 29, 29, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]})

In [133]:
len(ed_fast.data['input_ids'])

31

In [124]:
ed_fast.word_ids()

[0,
 1,
 2,
 2,
 2,
 2,
 2,
 3,
 3,
 3,
 3,
 4,
 4,
 4,
 4,
 5,
 5,
 6,
 7,
 8,
 9,
 9,
 9,
 9,
 9,
 9,
 10,
 10,
 10,
 10,
 None]

In [131]:
len(ed_fast.word_ids())

31

In [132]:
len(tok_fast)

30

In [135]:
for sw_id, w_id in enumerate(ed_fast.word_ids()[:-1]):
    # the last is eos
    sw = tok_fast[sw_id]
    w = test_txt.split(' ')[w_id]
    print(f'{sw:<15s}{w}')

▁This          This
▁is            is
▁              t5's
t              t5's
5              t5's
'              t5's
s              t5's
▁token         tokenization.;
ization        tokenization.;
.              tokenization.;
;              tokenization.;
▁              structed
struct         structed
e              structed
d              structed
▁knowledge     knowledge:
:              knowledge:
▁|             |
▁model         model
▁|             |
▁pl            plm(t5),
m              plm(t5),
(              plm(t5),
t              plm(t5),
5)             plm(t5),
,              plm(t5),
▁              rnn
r              rnn
n              rnn
n              rnn


In [143]:
ed_fast.word_to_tokens(3)

TokenSpan(start=7, end=11)

In [144]:
ed = tokenizer(test_txt)
type(ed), ed

(transformers.tokenization_utils_base.BatchEncoding,
 {'input_ids': [100, 19, 3, 17, 755, 31, 7, 14145, 1707, 5, 117, 3, 7593, 15, 26, 1103, 10, 1820, 825, 1820, 4752, 51, 599, 17, 9120, 6, 3, 52, 29, 29, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]})

In [None]:
ed.word_ids()

#### Different tokenizers consistency

In [301]:
ds_path = "/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SDR-analysis/data/spider/dev+ratsql_graph.json"

with open(ds_path, 'r') as f:
    dataset2 = json.load(f)

for d in dataset2:
    d['rat_sql_graph']['relations'] = json.loads(d['rat_sql_graph']['relations'])

len(dataset2)

1034

In [302]:
for i, d in enumerate(tqdm(dataset2)):
    text_in = d['question']
    struct_in = uskg_sample_to_struct_input(d)
    txt = "{}; structed knowledge: {}".format(text_in, struct_in)
    
    if "<" not in txt:
        continue
    
    uskg_tokenized_txt = tokenizer(txt, max_length=1024, padding="max_length", truncation=True)
    base_tokenized_txt = tokenizer_fast(txt, max_length=1024, padding="max_length", truncation=True)
    large_tokenized_txt = tokenizer_large_fast(txt, max_length=1024, padding="max_length", truncation=True)
    
    assert uskg_tokenized_txt.data == base_tokenized_txt.data == large_tokenized_txt.data, (i, txt)
    

  0%|          | 0/1034 [00:00<?, ?it/s]

In [288]:
vocab_uskg = tokenizer.get_vocab()
vocab_base = tokenizer_fast.get_vocab()
vocab_large = tokenizer_large_fast.get_vocab()

In [298]:
for k in set(vocab_uskg.keys()) | set(vocab_base.keys()) | set(vocab_large.keys()):
    v = vocab_uskg.get(k, None)
    v1 = vocab_base.get(k, None)
    v2 = vocab_large.get(k, None)
    if not v == v1 == v2:
        print(k, v, v1, v2)
#     assert v == v2, (k, v, v2)

 <= 32101 None None
 < 32100 None None


In [295]:
vocab_uskg[' <']

32100

In [303]:
for i, d in enumerate(tqdm(dataset2)):
    text_in = d['question']
    struct_in = uskg_sample_to_struct_input(d)
    txt = "{}; structed knowledge: {}".format(text_in, struct_in)
    
    if "<" in txt:
        print(i, txt)

  0%|          | 0/1034 [00:00<?, ?it/s]

#### max length

In [305]:
idx = 777
text_in = dataset2[idx]['question']
struct_in = uskg_sample_to_struct_input(d)
txt = "{}; structed knowledge: {}".format(text_in, struct_in)
txt

'What are the Asian countries which have a population larger than that of any country in Africa?; structed knowledge:  | real_estate_properties | ref_feature_types : feature_type_code , feature_type_name | ref_property_types : property_type_code ( Apartment , House ) , property_type_description | other_available_features : feature_id , feature_type_code , feature_name , feature_description | properties : property_id , property_type_code ( Apartment , House ) , date_on_market , date_sold , property_name , property_address , room_count , vendor_requested_price , buyer_offered_price , agreed_selling_price , apt_feature_1 , apt_feature_2 , apt_feature_3 , fld_feature_1 , fld_feature_2 , fld_feature_3 , hse_feature_1 , hse_feature_2 , hse_feature_3 , oth_feature_1 , oth_feature_2 , oth_feature_3 , shp_feature_1 , shp_feature_2 , shp_feature_3 , other_property_details | other_property_features : property_id , feature_id , property_feature_description'

In [308]:
_tkn_txt = tokenizer_fast(txt, max_length=32, truncation=True)

In [316]:
for i in range(31):
    st, ed = _tkn_txt.token_to_chars(i)
    print(txt[st:ed])

What
are
the
Asian
countries
which
have
a
a
population
larger
than
that
of
any
country
in
Africa
?
;
s
struct
e
d
knowledge
:
|
real
_
e
state


In [318]:
list(enumerate(txt.split()))

[(0, 'What'),
 (1, 'are'),
 (2, 'the'),
 (3, 'Asian'),
 (4, 'countries'),
 (5, 'which'),
 (6, 'have'),
 (7, 'a'),
 (8, 'population'),
 (9, 'larger'),
 (10, 'than'),
 (11, 'that'),
 (12, 'of'),
 (13, 'any'),
 (14, 'country'),
 (15, 'in'),
 (16, 'Africa?;'),
 (17, 'structed'),
 (18, 'knowledge:'),
 (19, '|'),
 (20, 'real_estate_properties'),
 (21, '|'),
 (22, 'ref_feature_types'),
 (23, ':'),
 (24, 'feature_type_code'),
 (25, ','),
 (26, 'feature_type_name'),
 (27, '|'),
 (28, 'ref_property_types'),
 (29, ':'),
 (30, 'property_type_code'),
 (31, '('),
 (32, 'Apartment'),
 (33, ','),
 (34, 'House'),
 (35, ')'),
 (36, ','),
 (37, 'property_type_description'),
 (38, '|'),
 (39, 'other_available_features'),
 (40, ':'),
 (41, 'feature_id'),
 (42, ','),
 (43, 'feature_type_code'),
 (44, ','),
 (45, 'feature_name'),
 (46, ','),
 (47, 'feature_description'),
 (48, '|'),
 (49, 'properties'),
 (50, ':'),
 (51, 'property_id'),
 (52, ','),
 (53, 'property_type_code'),
 (54, '('),
 (55, 'Apartment'),

In [320]:
st, ed = _tkn_txt.word_to_chars(20)
st, ed, txt[st:ed]

(120, 131, 'real_estate')

In [322]:
_tkn_txt.word_to_chars(21)

TypeError: type object argument after * must be an iterable, not NoneType

### Test get encoding

In [60]:
type(model.pretrain_model.encoder)

models.prompt.modeling_t5.T5Stack

In [61]:
# pred = tokenizer.batch_decode(
#   model.generate(
#     torch.LongTensor(tokenized_txt.data['input_ids']),
#     torch.LongTensor(tokenized_txt.data['attention_mask']),
#     num_beams=1, 
#     max_length=256
#     ), 
#   skip_special_tokens=True 
# )

In [62]:
# generated_ids = self.pretrain_model.generate(
#     input_ids=input_ids,
#     attention_mask=attention_mask,
#     past_prompt=past_prompt,
#     use_cache=True,
#     **kwargs,
# )

# model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

# outputs = self(
#     **model_inputs,
#     return_dict=True,
#     output_attentions=output_attentions,
#     output_hidden_states=output_hidden_states,
# )

In [64]:
past_prompt = model.get_prompt(
    bsz=1,              # bsz = input_ids.shape[0]
    sample_size=1,      # sample_size=kwargs['num_beams']
    description=None,   
    knowledge=None,     
)

In [65]:
encoder_outputs = model.pretrain_model.encoder(
    input_ids=torch.LongTensor(tokenized_txt.data['input_ids']),
    attention_mask=torch.LongTensor(tokenized_txt.data['attention_mask']),
    past_prompt=past_prompt,
)

In [66]:
past_prompt[0].keys()

dict_keys(['decoder_prompt', 'cross_attention_prompt', 'encoder_prompt'])

In [67]:
encoder_outputs.__dict__.keys()

dict_keys(['last_hidden_state', 'past_key_values', 'hidden_states', 'attentions', 'cross_attentions'])

In [68]:
encoder_outputs.last_hidden_state.size()

torch.Size([1, 1024, 1024])

In [69]:
sum(tokenized_txt.data['attention_mask'][0])

158

In [70]:
# has zero-padding and EOS, no BOS  
len(tokenized_txt.data['input_ids'][0]), np.greater(tokenized_txt.data['input_ids'][0], 0).sum()

(1024, 158)

### Tokenized pieces-nodes mapping (old)

In [None]:
tokenized_txt.data['input_ids'][0][:150]

In [213]:
len(tokenizer.tokenize(txt))

157

In [214]:
tokenizer.tokenize(txt)

['▁What',
 '▁are',
 '▁the',
 '▁Asian',
 '▁countries',
 '▁which',
 '▁have',
 '▁',
 'a',
 '▁population',
 '▁larger',
 '▁than',
 '▁that',
 '▁of',
 '▁any',
 '▁country',
 '▁in',
 '▁Africa',
 '?',
 ';',
 '▁',
 'struct',
 'e',
 'd',
 '▁knowledge',
 ':',
 '▁|',
 '▁world',
 '_',
 '1',
 '▁|',
 '▁city',
 '▁',
 ':',
 '▁',
 'i',
 'd',
 '▁',
 ',',
 '▁name',
 '▁',
 ',',
 '▁country',
 'code',
 '▁',
 ',',
 '▁district',
 '▁',
 ',',
 '▁population',
 '▁|',
 '▁sq',
 'lite',
 '_',
 's',
 'e',
 'que',
 'nce',
 '▁',
 ':',
 '▁name',
 '▁',
 ',',
 '▁se',
 'q',
 '▁|',
 '▁country',
 '▁',
 ':',
 '▁code',
 '▁',
 ',',
 '▁name',
 '▁',
 ',',
 '▁continent',
 '▁(',
 '▁Africa',
 '▁',
 ',',
 '▁Asia',
 '▁',
 ')',
 '▁',
 ',',
 '▁region',
 '▁',
 ',',
 '▁surface',
 'area',
 '▁',
 ',',
 '▁in',
 'de',
 'p',
 'year',
 '▁',
 ',',
 '▁population',
 '▁',
 ',',
 '▁life',
 'ex',
 'pe',
 'c',
 't',
 'ancy',
 '▁',
 ',',
 '▁',
 'g',
 'n',
 'p',
 '▁',
 ',',
 '▁',
 'g',
 'n',
 'pol',
 'd',
 '▁',
 ',',
 '▁local',
 'name',
 '▁',
 ',',
 '▁gove

In [215]:
' '.join(sample['rat_sql_graph']['nodes']), txt

('what be the asian country which have a population larger than that of any country in africa ? <C>NONE::* <C>city::id <C>city::name <C>city::countrycode <C>city::district <C>city::population <C>sqlite_sequence::name <C>sqlite_sequence::seq <C>country::code <C>country::name <C>country::continent <C>country::region <C>country::surfacearea <C>country::indepyear <C>country::population <C>country::lifeexpectancy <C>country::gnp <C>country::gnpold <C>country::localname <C>country::governmentform <C>country::headofstate <C>country::capital <C>country::code2 <C>countrylanguage::countrycode <C>countrylanguage::language <C>countrylanguage::isofficial <C>countrylanguage::percentage <T>city <T>sqlite_sequence <T>country <T>countrylanguage',
 'What are the Asian countries which have a population larger than that of any country in Africa?; structed knowledge:  | world_1 | city : id , name , countrycode , district , population | sqlite_sequence : name , seq | country : code , name , continent ( Afri

In [216]:
' '.join(sample['question_toks'])

'What are the Asian countries which have a population larger than that of any country in Africa ?'

In [218]:
type(tokenized_txt)

transformers.tokenization_utils_base.BatchEncoding

#### Nodes to char ranges

In [219]:
# ratsql nodes to sentence char ids 

ratsql_graph_nodes = sample['rat_sql_graph']['nodes']
question_toks = sample['question_toks']

_q_nodes = []  # [stem token (node name)]
q_nodes = []  # [(stem token (node name), orig question token)]
c_nodes = []  # [(orig table name, orig column name)]
t_nodes = []  # [orig table name]

for n in ratsql_graph_nodes:
    if n.startswith('<C>'):
        _n = n[3:]
        _t, _c = _n.split('::')
        c_nodes.append((_t, _c))
    elif n.startswith('<T>'):
        _n = n[3:]
        t_nodes.append(_n)
    else:
        _q_nodes.append(n)

assert len(_q_nodes) == len(question_toks), (_q_nodes, question_toks)
q_nodes = list(zip(_q_nodes, question_toks))

q_nodes, c_nodes, t_nodes

([('what', 'What'),
  ('be', 'are'),
  ('the', 'the'),
  ('asian', 'Asian'),
  ('country', 'countries'),
  ('which', 'which'),
  ('have', 'have'),
  ('a', 'a'),
  ('population', 'population'),
  ('larger', 'larger'),
  ('than', 'than'),
  ('that', 'that'),
  ('of', 'of'),
  ('any', 'any'),
  ('country', 'country'),
  ('in', 'in'),
  ('africa', 'Africa'),
  ('?', '?')],
 [('NONE', '*'),
  ('city', 'id'),
  ('city', 'name'),
  ('city', 'countrycode'),
  ('city', 'district'),
  ('city', 'population'),
  ('sqlite_sequence', 'name'),
  ('sqlite_sequence', 'seq'),
  ('country', 'code'),
  ('country', 'name'),
  ('country', 'continent'),
  ('country', 'region'),
  ('country', 'surfacearea'),
  ('country', 'indepyear'),
  ('country', 'population'),
  ('country', 'lifeexpectancy'),
  ('country', 'gnp'),
  ('country', 'gnpold'),
  ('country', 'localname'),
  ('country', 'governmentform'),
  ('country', 'headofstate'),
  ('country', 'capital'),
  ('country', 'code2'),
  ('countrylanguage', 'count

In [220]:
txt

'What are the Asian countries which have a population larger than that of any country in Africa?; structed knowledge:  | world_1 | city : id , name , countrycode , district , population | sqlite_sequence : name , seq | country : code , name , continent ( Africa , Asia ) , region , surfacearea , indepyear , population , lifeexpectancy , gnp , gnpold , localname , governmentform , headofstate , capital , code2 | countrylanguage : countrycode , language , isofficial , percentage'

In [222]:
# tokenized_txt = tokenizer_fast([txt], max_length=1024, padding="max_length", truncation=True)

In [223]:
# txt = "This is t5's tokenization.; structed knowledge: | model | plm(t5), rnn"

# tokenized_txt = tokenizer_fast([test_txt], max_length=1024, padding="max_length", truncation=True)

In [224]:
# number of whitespace words 
max([w_id for w_id in tokenized_txt.word_ids() if w_id is not None]) + 1, len(txt.split())

(86, 86)

In [230]:
_splitter = "; structed knowledge:  "  # struct_in has a preceding white space 
text_in, struct_in = txt.split(_splitter)

q_node_chars = []   # [(st, ed)]; same below
c_node_chars = []
t_node_chars = []

In [231]:
# Text part
# Assumption: the mismatch between whitespace words (text_words) and question words only come from trailing puncts

text_words = text_in.strip().split(' ') + ['<SENTINAL>']
text_word_char_ranges = [tokenized_txt.word_to_chars(i) for i in range(len(text_words) - 1)] + [(None, None)]  # -1 to remove the sentinal 

len(text_words), len(text_word_char_ranges), text_word_char_ranges[0]

(18, 18, CharSpan(start=0, end=4))

In [232]:
curr_tw_idx = 0
curr_tw = text_words[0]
curr_tw_char_range = text_word_char_ranges[0]
curr_char_ptr = 0

for stem_tok, orig_tok in q_nodes:
    if curr_tw == orig_tok:
        # finishing current word 
        q_node_chars.append((curr_char_ptr, curr_char_ptr + len(orig_tok)))   # curr pos to curr pos + len 
        curr_tw_idx += 1
        curr_tw = text_words[curr_tw_idx]
        curr_tw_char_range = text_word_char_ranges[curr_tw_idx]
        curr_char_ptr = curr_tw_char_range[0]
    else:
        # not finishing current word 
        assert curr_tw.startswith(orig_tok), (curr_tw, orig_tok)
        q_node_chars.append((curr_char_ptr, curr_char_ptr + len(orig_tok)))   # curr pos to curr pos + len 
        curr_char_ptr += len(orig_tok)     # move ptr forward by len 
        curr_tw = curr_tw[len(orig_tok):]  # get the remaining chars in the word 


In [233]:
[txt[st:ed] for st, ed in q_node_chars] == sample['question_toks']

True

In [235]:
# Struct part 
assert not struct_in.startswith(' ')
_str_before_struct = text_in + _splitter
_n_words_before_struct = len(_str_before_struct.strip().split(' '))

struct_words_st_idx = _n_words_before_struct
assert len(_str_before_struct) == tokenized_txt.word_to_chars(struct_words_st_idx)[0], \
    (len(_str_before_struct), tokenized_txt.word_to_chars(struct_words_st_idx)[0])   # len(before) == starting char idx 

In [236]:
struct_in

'| world_1 | city : id , name , countrycode , district , population | sqlite_sequence : name , seq | country : code , name , continent ( Africa , Asia ) , region , surfacearea , indepyear , population , lifeexpectancy , gnp , gnpold , localname , governmentform , headofstate , capital , code2 | countrylanguage : countrycode , language , isofficial , percentage'

In [287]:
struct_ranges_collector = StructCharRangesCollector()
struct_ranges_collector.collect(struct_in, tokenized_txt, _n_words_before_struct)

In [288]:
for db_id_name, (st, ed) in struct_ranges_collector.db_id2char_ranges.items():
    assert txt[st:ed] == db_id_name, (st, ed, txt[st:ed], db_id_name) 

for table_name, (st, ed) in struct_ranges_collector.table2char_ranges.items():
    assert txt[st:ed] == table_name, (st, ed, txt[st:ed], table_name) 

for (_, col_name), (st, ed) in struct_ranges_collector.column2char_ranges.items():
    txt_piece = txt[st:ed].split(' ( ')[0]
    assert txt_piece == col_name, (st, ed, txt[st:ed], col_name) 

In [289]:
# for db_id_name, (st, ed) in struct_ranges_collector.db_id2char_ranges.items():
#     print(st, ed, txt[st:ed], db_id_name) 
# print()
# for table_name, (st, ed) in struct_ranges_collector.table2char_ranges.items():
#     print(st, ed, txt[st:ed], table_name) 
# print()
# for (_, col_name), (st, ed) in struct_ranges_collector.column2char_ranges.items():
#     txt_piece = txt[st:ed].split(' ( ')[0]
#     print(st, ed, txt[st:ed], col_name) 

In [290]:
for c_node in c_nodes:
    if c_node == ('NONE', '*'):
        # the special column in spider, using db_id 
        c_node_chars.append(list(struct_ranges_collector.db_id2char_ranges.values())[0])   # assuming only 1 db_id, which should be true...
    else:
        c_node_chars.append(struct_ranges_collector.column2char_ranges[c_node])

for t_node in t_nodes:
    t_node_chars.append(struct_ranges_collector.table2char_ranges[t_node])

#### Nodes to tokenized pieces

In [78]:
node_char_ranges = char_ranges_dict['q_node_chars'] + char_ranges_dict['c_node_chars'] + char_ranges_dict['t_node_chars']
len(node_char_ranges), len(sample['rat_sql_graph']['nodes'])

(49, 49)

In [79]:
# some chars can be mapped to multiple tokens (e.g. 'i' => '▁', 'i' )
char_to_tokens_dict = defaultdict(list)

for token_idx, tok in enumerate(tokenized_txt.tokens()):
    if tok == '</s>':
        break
    char_span = tokenized_txt.token_to_chars(token_idx)
    for char_idx in range(char_span[0], char_span[1]):
        char_to_tokens_dict[char_idx].append(token_idx)

len(char_to_tokens_dict)

393

In [None]:
node_pieces_ranges = []

for st, ed in node_ranges:
    piece_ids = []
    for char_idx in range(st, ed):
        _piece_ids = char_to_tokens_dict[char_idx]
        piece_ids.extend(_piece_ids)
    
    piece_st = piece_ids[0]
    piece_ed = piece_ids[-1] + 1
    # the collected piece_ids should be continuous 
    # ^ not true... some chars can be mapped to multiple tokens (started by ▁ )
    # re-collect a char-to-token
    assert set(range(piece_st, piece_ed)) == set(piece_ids), piece_ids
    
    node_pieces_ranges.append((piece_st, piece_ed))

In [None]:
for n, (p_st, p_ed) in zip(sample['rat_sql_graph']['nodes'], node_pieces_ranges):
    print(n, '\t', tokenized_txt.tokens()[p_st:p_ed])

### data: server vs. local

In [197]:
local_data_path = "/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SDR-analysis/data/probing/text2sql/link_prediction/spider/uskg/local/dev.train.X.pkl"
server_data_path = "/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SDR-analysis/data/probing/text2sql/link_prediction/spider/uskg/dev.train.X.pkl"

with open(local_data_path, 'rb') as f:
    local_data_X = pickle.load(f)
with open(server_data_path, 'rb') as f:
    server_data_X = pickle.load(f)

type(local_data_X), type(server_data_X)

(list, list)

In [201]:
local_data_X = np.array(local_data_X)
server_data_X = np.array(server_data_X)
local_data_X.shape, server_data_X.shape

((16059, 3072), (16059, 3072))

In [211]:
large_diff_ids = np.argsort(np.max(np.abs(local_data_X - server_data_X), axis=1))[::-1]
large_diff_ids[:10]

array([ 5264,  5255,  5247,   288,   292,   294, 11989, 11979,  8755,
        8756])

In [216]:
ds_path = "/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SDR-analysis/data/spider/dev+ratsql_graph.json"
pos_path = "/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SDR-analysis/data/probing/text2sql/link_prediction/spider/uskg/dev.train.pos.txt"

with open(ds_path, 'r') as f:
    ds_samples = json.load(f)
    for d in ds_samples:
        d['rat_sql_graph']['relations'] = json.loads(d['rat_sql_graph']['relations'])
with open(pos_path, 'r') as f:
    pos_lines = f.read().strip().split('\n')

len(ds_samples), len(pos_lines)

(1034, 16059)

In [258]:
_idx = large_diff_ids[51]
_idx

5245

In [259]:
# _idx = 8970
local_data_X[_idx][::300], server_data_X[_idx][::300], max(local_data_X[_idx] - server_data_X[_idx])

(array([-0.16274624, -0.06520639,  0.11448385,  0.02375873,  0.00876573,
        -0.10075585,  0.11073106,  0.00379219,  0.07125453,  0.01576799,
         0.01186033], dtype=float32),
 array([-0.17650707, -0.07676765,  0.13225155,  0.01257438,  0.00267248,
        -0.1266068 ,  0.12796973,  0.00241799,  0.05103843,  0.03357478,
         0.01657346], dtype=float32),
 0.07821107)

In [260]:
sid, i, j = [int(s) for s in pos_lines[_idx].split('\t')]
sid, i, j

(696, 6, 6)

In [261]:
_nodes = ds_samples[sid]['rat_sql_graph']['nodes']
_nodes

['what',
 'be',
 'the',
 'number',
 'of',
 'vote',
 'from',
 'state',
 '`',
 'ny',
 "'",
 'or',
 '`',
 'ca',
 "'",
 '?',
 '<C>NONE::*',
 '<C>area_code_state::area_code',
 '<C>area_code_state::state',
 '<C>contestant::contestant_number',
 '<C>contestant::contestant_name',
 '<C>vote::vote_id',
 '<C>vote::phone_number',
 '<C>vote::state',
 '<C>vote::contestant_number',
 '<C>vote::create',
 '<T>area_code_state',
 '<T>contestant',
 '<T>vote']

In [262]:
_nodes[i], _nodes[j]

('from', 'from')

### Model params

In [323]:
# Set args here for runnning on notebook, we make them out here to make it more illustrative.
sys.argv = ['/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py', # This is the name of your .py launcher when you run this line of code.
            # belows are the parameters we set, take spider for example
            '--cfg', 'Salesforce/T5_base_finetune_spider_with_cell_value.cfg', 
            '--output_dir', './tmp']
parser = HfArgumentParser((WrappedSeq2SeqTrainingArguments,))
training_args, = parser.parse_args_into_dataclasses()
set_seed(training_args.seed)
tmp_args = Configure.Get(training_args.cfg)

In [330]:
tmp_args.bert.location

't5-base'

In [340]:
# model_path = 't5-base'
# model_path = 'hkunlp/from_all_T5_large_prefix_spider_with_cell_value2'
# model_path = '/Users/mac/Desktop/syt/Deep-Learning/Repos/UnifiedSKG/output/server_runs/A-T5_base_prefix_spider_with_cell_value-asr_mixed/checkpoint-79500/'
# model_path = '/Users/mac/Desktop/syt/Deep-Learning/Repos/UnifiedSKG/output/server_runs/A-T5_base_prefix_spider_with_cell_value-rewritten_mixed/checkpoint-56500/'

# tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

# for reconstruction
# tokenizer_fast = AutoTokenizer.from_pretrained('t5-base', use_fast=True)

tmp_model = finetune.Model(tmp_args)

In [341]:
tmp_model.pretrain_model.encoder.block[0].layer[0].SelfAttention.q.weight[:3,:3]

tensor([[ 0.0762, -0.0471,  0.0309],
        [ 0.0381, -0.0075,  0.0003],
        [-0.0047, -0.0262, -0.0298]], grad_fn=<SliceBackward>)

In [342]:
# model_path = 't5-base'
# tmp_model.load(model_path)
# # need tmp_model.pretrained_model.load(...)?

In [343]:
tmp_model.pretrain_model.init_weights()

In [344]:
tmp_model.pretrain_model.encoder.block[0].layer[0].SelfAttention.q.weight[:3,:3]

tensor([[ 0.0013,  0.0083,  0.0015],
        [ 0.0008, -0.0017,  0.0039],
        [-0.0088, -0.0019, -0.0036]], grad_fn=<SliceBackward>)

### others

In [53]:
args.bert.location

't5-base'

In [62]:
args.model.__dict__

{'__self__': None,
 '__default__': {'__call__',
  '__class__',
  '__default__',
  '__delattr__',
  '__dict__',
  '__dir__',
  '__doc__',
  '__eq__',
  '__format__',
  '__ge__',
  '__getattribute__',
  '__gt__',
  '__hash__',
  '__init__',
  '__init_subclass__',
  '__iter__',
  '__le__',
  '__len__',
  '__lt__',
  '__module__',
  '__ne__',
  '__new__',
  '__reduce__',
  '__reduce_ex__',
  '__repr__',
  '__self__',
  '__setattr__',
  '__sizeof__',
  '__str__',
  '__subclasshook__',
  '__weakref__'},
 'name': 'unified.prefixtuning',
 'use_description': False,
 'concatenate_description': False,
 'map_description': False,
 'knowledge_usage': 'concatenate',
 'freeze_plm': True,
 'freeze_prefix': False}