In [None]:
import importlib
import json
import os
import sys
from collections import Counter

import numpy as np
import pandas as pd
import tqdm
from tqdm import notebook

# Directory containing the `split_logic` package (path to src code) - fill in before running.
SRC_PATH = ""
# Guard so that re-running this cell does not keep appending duplicate entries to sys.path.
if SRC_PATH not in sys.path:
    sys.path.append(SRC_PATH)

from split_logic.grammar.sparql_parser import SPARQLParser
from split_logic.grammar.sql_parser import SQLParser
from split_logic.grammar import atom_and_compound_cache

# Reload so that edits to atom_and_compound_cache are picked up without restarting the kernel.
importlib.reload(atom_and_compound_cache)

In [None]:
# Root of the project checkout; dataset paths below are joined onto it.
# Can be set through the PROJECT_PATH environment variable instead of editing the notebook.
PROJECT_PATH = os.environ.get('PROJECT_PATH', '')

In [None]:
# Path to the whole (unsplit) SPARQL dataset - fill in before running.
SPARQL_DATASET_PATH = ''
with open(SPARQL_DATASET_PATH, 'r') as dataset_file:
    sparql_dataset = json.load(dataset_file)

sparql_list = [sample["masked_query"] for sample in sparql_dataset]
sparql_parser = SPARQLParser(sparql_list)
# A single parser registered under the 'wikidata' KB id.
# NOTE(review): assumes every SPARQL sample's kb_id is 'wikidata' - confirm.
sparql_parser_dict = {'wikidata': sparql_parser}
sparql_compound_path = os.path.join(PROJECT_PATH, 'dataset/lcquad/tmcd_split')
cached_sparql_parser = atom_and_compound_cache.AtomAndCompoundCache(
    sparql_parser_dict,
    query_key_name=None,
    kb_id_key_name=None,
    return_compound_list_flag=False,
    compound_cache_path=sparql_compound_path,
)

In [None]:
# Path to the SQL dataset - fill in before running.
SQL_DATASET_PATH = ''
with open(SQL_DATASET_PATH) as dataset_file:
    sql_dataset = json.load(dataset_file)
# Attribute lists used to build each table's SQLParser, keyed by the sample's kb_id.
with open(os.path.join(PROJECT_PATH, "dataset/wikisql/table_id2new_attrs_for_parsing.json")) as attrs_file:
    db2attr_dict = json.load(attrs_file)

sql_compound_path = os.path.join(PROJECT_PATH, 'dataset/wikisql/tmcd_split')

# One SQLParser per distinct kb_id. dict.fromkeys removes duplicates while keeping first-seen
# order, so each table's attributes are looked up and parsed only once.
unique_db_ids = list(dict.fromkeys(sample['kb_id'] for sample in sql_dataset))
sql_parser_dict = {
    db_id: SQLParser(db2attr_dict[db_id])
    for db_id in tqdm.tqdm(unique_db_ids, total=len(unique_db_ids))
}

cached_sql_parser = atom_and_compound_cache.AtomAndCompoundCache(
    sql_parser_dict,
    query_key_name=None,
    kb_id_key_name=None,
    return_compound_list_flag=False,
    compound_cache_path=sql_compound_path,
)

## Reading data

### SPARQL

In [None]:
# The original LC-QuAD split has only train and test parts (no dev).
sparql_os_dir = os.path.join(PROJECT_PATH, 'dataset/lcquad/original_split')
with open(os.path.join(sparql_os_dir, 'english_train_split.json'), 'r') as split_file:
    sparql_os_train = json.load(split_file)
with open(os.path.join(sparql_os_dir, 'english_test_split.json'), 'r') as split_file:
    sparql_os_test = json.load(split_file)

len(sparql_os_train), len(sparql_os_test)

In [None]:
sparql_tmcd_dir = os.path.join(PROJECT_PATH, 'dataset/lcquad/tmcd_split')
with open(os.path.join(sparql_tmcd_dir, 'english_train_split_coef_0.1.json'), 'r') as split_file:
    sparql_tmcd_train = json.load(split_file)
with open(os.path.join(sparql_tmcd_dir, 'english_dev_split_coef_0.1.json'), 'r') as split_file:
    sparql_tmcd_dev = json.load(split_file)
with open(os.path.join(sparql_tmcd_dir, 'english_test_split_coef_0.1.json'), 'r') as split_file:
    sparql_tmcd_test = json.load(split_file)

len(sparql_tmcd_train), len(sparql_tmcd_dev), len(sparql_tmcd_test)

In [None]:
sparql_tl_dir = os.path.join(PROJECT_PATH, 'dataset/lcquad/target_length_split')
# NOTE(review): the train file is '..._above_50_percentile' while dev/test are
# '..._below_50_percentile' - confirm this direction of the split is intended.
with open(os.path.join(sparql_tl_dir, 'english_train_split_above_50_percentile.json'), 'r') as split_file:
    sparql_tl_train = json.load(split_file)
with open(os.path.join(sparql_tl_dir, 'english_dev_split_below_50_percentile.json'), 'r') as split_file:
    sparql_tl_dev = json.load(split_file)
with open(os.path.join(sparql_tl_dir, 'english_test_split_below_50_percentile.json'), 'r') as split_file:
    sparql_tl_test = json.load(split_file)

len(sparql_tl_train), len(sparql_tl_dev), len(sparql_tl_test)

In [None]:
# The "iid" variables hold the language-variation split.
sparql_iid_dir = os.path.join(PROJECT_PATH, 'dataset/lcquad/language_variation_split')
with open(os.path.join(sparql_iid_dir, 'english_train_split.json'), 'r') as split_file:
    sparql_iid_train = json.load(split_file)
with open(os.path.join(sparql_iid_dir, 'english_dev_split.json'), 'r') as split_file:
    sparql_iid_dev = json.load(split_file)
with open(os.path.join(sparql_iid_dir, 'english_test_split.json'), 'r') as split_file:
    sparql_iid_test = json.load(split_file)

len(sparql_iid_train), len(sparql_iid_dev), len(sparql_iid_test)

### SQL

In [None]:
sql_os_dir = os.path.join(PROJECT_PATH, 'dataset/wikisql/original_split')
with open(os.path.join(sql_os_dir, 'train_split.json'), 'r') as split_file:
    sql_os_train = json.load(split_file)
with open(os.path.join(sql_os_dir, 'dev_split.json'), 'r') as split_file:
    sql_os_dev = json.load(split_file)
with open(os.path.join(sql_os_dir, 'test_split.json'), 'r') as split_file:
    sql_os_test = json.load(split_file)

len(sql_os_train), len(sql_os_dev), len(sql_os_test)

In [None]:
sql_tmcd_dir = os.path.join(PROJECT_PATH, 'dataset/wikisql/tmcd_split')
with open(os.path.join(sql_tmcd_dir, 'english_train_split_coef_0.1.json'), 'r') as split_file:
    sql_tmcd_train = json.load(split_file)
with open(os.path.join(sql_tmcd_dir, 'english_dev_split_coef_0.1.json'), 'r') as split_file:
    sql_tmcd_dev = json.load(split_file)
with open(os.path.join(sql_tmcd_dir, 'english_test_split_coef_0.1.json'), 'r') as split_file:
    sql_tmcd_test = json.load(split_file)

len(sql_tmcd_train), len(sql_tmcd_dev), len(sql_tmcd_test)

In [None]:
sql_tl_dir = os.path.join(PROJECT_PATH, 'dataset/wikisql/target_length_split')
with open(os.path.join(sql_tl_dir, 'english_train_split_85_percentile.json'), 'r') as split_file:
    sql_tl_train = json.load(split_file)
with open(os.path.join(sql_tl_dir, 'english_dev_split_85_percentile.json'), 'r') as split_file:
    sql_tl_dev = json.load(split_file)
with open(os.path.join(sql_tl_dir, 'english_test_split_85_percentile.json'), 'r') as split_file:
    sql_tl_test = json.load(split_file)

len(sql_tl_train), len(sql_tl_dev), len(sql_tl_test)

In [None]:
# The "iid" variables hold the language-variation split.
sql_iid_dir = os.path.join(PROJECT_PATH, 'dataset/wikisql/language_variation_split')
with open(os.path.join(sql_iid_dir, 'english_train_split.json'), 'r') as split_file:
    sql_iid_train = json.load(split_file)
with open(os.path.join(sql_iid_dir, 'english_dev_split.json'), 'r') as split_file:
    sql_iid_dev = json.load(split_file)
with open(os.path.join(sql_iid_dir, 'english_test_split.json'), 'r') as split_file:
    sql_iid_test = json.load(split_file)

len(sql_iid_train), len(sql_iid_dev), len(sql_iid_test)

## Question/query overlap between train and test

In [None]:
def check_question_query_interection(train_set, dev_set, test_set, dataset_name):
    """Print unique question/query counts per split and the test-in-train overlap.

    Args:
        train_set, test_set: lists of sample dicts with 'question' and 'masked_query' keys.
        dev_set: same shape as above, or a falsy value (e.g. None) when the split has no dev part.
        dataset_name: label used in the printed header.

    The "Percent" values are fractions in [0, 1], rounded to 2 digits. Each is the share of
    unique test questions (resp. queries) that also occur in train. An empty test split
    prints 0 instead of raising ZeroDivisionError.
    """
    def unique_values(samples, field):
        return {sample[field] for sample in samples}

    def share_in_train(test_values, train_values):
        if not test_values:
            return 0
        return len(test_values & train_values) / len(test_values)

    train_questions_set = unique_values(train_set, 'question')
    test_questions_set = unique_values(test_set, 'question')
    train_query_set = unique_values(train_set, 'masked_query')
    test_query_set = unique_values(test_set, 'masked_query')

    print(f"Stats for {dataset_name}")

    print('Unique questions in Train: ', len(train_questions_set))
    print('Unique questions in Test: ', len(test_questions_set))
    if dev_set:
        print('Unique questions in Dev: ', len(unique_values(dev_set, 'question')))
    print()

    test_train_question_intersection = share_in_train(test_questions_set, train_questions_set)
    print('Test questions in Train Percent: ', round(test_train_question_intersection, 2))

    print("---------------------------------")

    print('Unique queries in Train: ', len(train_query_set))
    # Count of unique test *queries* (not questions).
    print('Unique queries in Test: ', len(test_query_set))
    if dev_set:
        print('Unique queries in Dev: ', len(unique_values(dev_set, 'masked_query')))
    print()

    # Normalised by the number of unique test queries, so the value is a true fraction of queries.
    test_train_query_intersection = share_in_train(test_query_set, train_query_set)
    print('Test queries in Train Percent: ', round(test_train_query_intersection, 2))


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
check_question_query_intersection = check_question_query_interection

### SPARQL

In [None]:
# The original LC-QuAD split has no dev part.
check_question_query_interection(
    train_set=sparql_os_train,
    dev_set=None,
    test_set=sparql_os_test,
    dataset_name='Sparql Original Split',
)

In [None]:
check_question_query_interection(
    train_set=sparql_iid_train,
    dev_set=sparql_iid_dev,
    test_set=sparql_iid_test,
    dataset_name='Sparql Lang Var Split',
)

In [None]:
check_question_query_interection(
    train_set=sparql_tl_train,
    dev_set=sparql_tl_dev,
    test_set=sparql_tl_test,
    dataset_name='Sparql Target Length Split',
)

In [None]:
check_question_query_interection(
    train_set=sparql_tmcd_train,
    dev_set=sparql_tmcd_dev,
    test_set=sparql_tmcd_test,
    dataset_name='Sparql TMCD Split',
)

### SQL

In [None]:
check_question_query_interection(
    train_set=sql_os_train,
    dev_set=sql_os_dev,
    test_set=sql_os_test,
    dataset_name='SQL Original Split',
)

In [None]:
check_question_query_interection(
    train_set=sql_iid_train,
    dev_set=sql_iid_dev,
    test_set=sql_iid_test,
    dataset_name='SQL Language Variation Split',
)

In [None]:
check_question_query_interection(
    train_set=sql_tl_train,
    dev_set=sql_tl_dev,
    test_set=sql_tl_test,
    dataset_name='SQL Target Length Split',
)

In [None]:
check_question_query_interection(
    train_set=sql_tmcd_train,
    dev_set=sql_tmcd_dev,
    test_set=sql_tmcd_test,
    dataset_name='SQL TMCD Split',
)

## Target query length distribution

In [None]:
def calculate_split_target_length(train_set, test_set, dataset_name, percentile=95):
    """Print the mean and an upper percentile of target-query length for train and test.

    Length is the number of whitespace-separated tokens in sample['masked_query'].

    Args:
        train_set, test_set: lists of sample dicts with a 'masked_query' key.
        dataset_name: label used in the printed header.
        percentile: which percentile of the length distribution to report (default 95).

    An empty split is reported as such; np.percentile would otherwise raise on it.
    """
    split_lengths = {
        'Train': [len(sample['masked_query'].split()) for sample in train_set],
        'Test': [len(sample['masked_query'].split()) for sample in test_set],
    }

    print(f'Dataset stats {dataset_name}')

    for split_name, lengths in split_lengths.items():
        if not lengths:
            print(f'{split_name} is empty')
            continue
        print(f'{split_name} Average length: ', round(np.mean(lengths), 3))
        print(f'{split_name} {percentile} percentile: ', round(np.percentile(lengths, percentile), 3))

### SPARQL

In [None]:
calculate_split_target_length(
    train_set=sparql_os_train,
    test_set=sparql_os_test,
    dataset_name='Original split SPARQL',
)

In [None]:
calculate_split_target_length(
    train_set=sparql_iid_train,
    test_set=sparql_iid_test,
    dataset_name='Lang var SPARQL',
)

In [None]:
calculate_split_target_length(
    train_set=sparql_tl_train,
    test_set=sparql_tl_test,
    dataset_name='target length SPARQL',
)

In [None]:
calculate_split_target_length(
    train_set=sparql_tmcd_train,
    test_set=sparql_tmcd_test,
    dataset_name='TMCD SPARQL',
)

### SQL

In [None]:
calculate_split_target_length(
    train_set=sql_os_train,
    test_set=sql_os_test,
    dataset_name='Original split SQL',
)

In [None]:
calculate_split_target_length(
    train_set=sql_iid_train,
    test_set=sql_iid_test,
    dataset_name='Lang var SQL',
)

In [None]:
calculate_split_target_length(
    train_set=sql_tl_train,
    test_set=sql_tl_test,
    dataset_name='target length SQL',
)

In [None]:
calculate_split_target_length(
    train_set=sql_tmcd_train,
    test_set=sql_tmcd_test,
    dataset_name='TMCD SQL',
)

## Train/test compound distribution

In [None]:
def check_compound_train_test_distr(train_set, test_set, parser, dataset_name):
    """Print, per compound type, the number of unique compounds divided by the number of samples.

    Args:
        train_set, test_set: lists of sample dicts with 'masked_query' and 'kb_id' keys.
        parser: cached parser exposing `query_parser_dict` (the first value's
            `compound_parsers_dict` provides the compound types) and
            `get_compounds(query, kb_id)`, which returns a dict of compound type -> compounds.
        dataset_name: label used in the printed header.

    A split with no samples gets proportion 0 instead of raising ZeroDivisionError.
    """
    # Compound types come from the first underlying parser.
    # NOTE(review): assumes all KB parsers share the same compound types - confirm.
    compound_types = list(next(iter(parser.query_parser_dict.values())).compound_parsers_dict)

    def collect_unique_compounds(samples):
        # Parse every sample and accumulate the unique compounds per type.
        unique_compounds = {key: set() for key in compound_types}
        for sample in tqdm.tqdm(samples, total=len(samples)):
            compound_dict = parser.get_compounds(sample['masked_query'], sample['kb_id'])
            for key, compounds in compound_dict.items():
                # setdefault: tolerate a compound type missing from the first parser.
                unique_compounds.setdefault(key, set()).update(compounds)
        return unique_compounds

    def print_proportions(split_name, unique_compounds, n_samples):
        # Print unique-compound count / sample count for each compound type.
        for key, compounds in unique_compounds.items():
            proportion = round(len(compounds) / n_samples, 3) if n_samples else 0
            print(f"{split_name} {key} proportion: ", proportion)

    print('Parsing train queries!')
    train_compound_dict = collect_unique_compounds(train_set)
    print('Parsing test queries')
    test_compound_dict = collect_unique_compounds(test_set)

    print(f'Unique Compound proportion in {dataset_name}')
    print_proportions('train', train_compound_dict, len(train_set))
    print()
    print_proportions('test', test_compound_dict, len(test_set))

### SPARQL

In [None]:
check_compound_train_test_distr(
    train_set=sparql_os_train,
    test_set=sparql_os_test,
    parser=cached_sparql_parser,
    dataset_name='Original split SPARQL',
)

In [None]:
check_compound_train_test_distr(
    train_set=sparql_iid_train,
    test_set=sparql_iid_test,
    parser=cached_sparql_parser,
    dataset_name='Lang var SPARQL',
)

In [None]:
check_compound_train_test_distr(
    train_set=sparql_tl_train,
    test_set=sparql_tl_test,
    parser=cached_sparql_parser,
    dataset_name='Target length SPARQL',
)

In [None]:
check_compound_train_test_distr(
    train_set=sparql_tmcd_train,
    test_set=sparql_tmcd_test,
    parser=cached_sparql_parser,
    dataset_name='TMCD SPARQL',
)

### SQL

In [None]:
check_compound_train_test_distr(
    train_set=sql_os_train,
    test_set=sql_os_test,
    parser=cached_sql_parser,
    dataset_name='Original split SQL',
)

In [None]:
check_compound_train_test_distr(
    train_set=sql_iid_train,
    test_set=sql_iid_test,
    parser=cached_sql_parser,
    dataset_name='Lang var SQL',
)

In [None]:
check_compound_train_test_distr(
    train_set=sql_tl_train,
    test_set=sql_tl_test,
    parser=cached_sql_parser,
    dataset_name='Target length SQL',
)

In [None]:
check_compound_train_test_distr(
    train_set=sql_tmcd_train,
    test_set=sql_tmcd_test,
    parser=cached_sql_parser,
    dataset_name='TMCD SQL',
)

## Train/test compound intersection

In [None]:
def check_compound_intersection(train_set, test_set, parser, dataset_name):
    """Print, per compound type, the share of unique test compounds that also occur in train.

    Args:
        train_set, test_set: lists of sample dicts with 'masked_query' and 'kb_id' keys.
        parser: cached parser exposing `query_parser_dict` (the first value's
            `compound_parsers_dict` provides the compound types) and
            `get_compounds(query, kb_id)`, which returns a dict of compound type -> compounds.
        dataset_name: label used in the printed header.

    Ratios are rounded to 2 digits; a compound type with no test compounds prints 0.
    """
    # Compound types come from the first underlying parser.
    # NOTE(review): assumes all KB parsers share the same compound types - confirm.
    compound_types = list(next(iter(parser.query_parser_dict.values())).compound_parsers_dict)

    def collect_unique_compounds(samples):
        # Parse every sample and accumulate the unique compounds per type.
        unique_compounds = {key: set() for key in compound_types}
        for sample in tqdm.tqdm(samples, total=len(samples)):
            compound_dict = parser.get_compounds(sample['masked_query'], sample['kb_id'])
            for key, compounds in compound_dict.items():
                # setdefault: tolerate a compound type missing from the first parser.
                unique_compounds.setdefault(key, set()).update(compounds)
        return unique_compounds

    print('Parsing train queries!')
    train_compound_dict = collect_unique_compounds(train_set)
    print('Parsing test queries')
    test_compound_dict = collect_unique_compounds(test_set)

    print(f'Compound intersection for {dataset_name}')
    for key, compound_train_set in train_compound_dict.items():
        compound_test_set = test_compound_dict.get(key, set())
        if not compound_test_set:
            intersect_ratio = 0
        else:
            intersect_size = len(compound_train_set & compound_test_set)
            intersect_ratio = round(intersect_size / len(compound_test_set), 2)
        print(f'Amount of train samples in test for compound {key}: ', intersect_ratio)

### SPARQL

In [None]:
check_compound_intersection(
    train_set=sparql_os_train,
    test_set=sparql_os_test,
    parser=cached_sparql_parser,
    dataset_name='Original split SPARQL',
)

In [None]:
check_compound_intersection(
    train_set=sparql_iid_train,
    test_set=sparql_iid_test,
    parser=cached_sparql_parser,
    dataset_name='Language variation SPARQL',
)

In [None]:
check_compound_intersection(
    train_set=sparql_tl_train,
    test_set=sparql_tl_test,
    parser=cached_sparql_parser,
    dataset_name='target length SPARQL',
)

In [None]:
check_compound_intersection(
    train_set=sparql_tmcd_train,
    test_set=sparql_tmcd_test,
    parser=cached_sparql_parser,
    dataset_name='TMCD SPARQL',
)

### SQL

In [None]:
check_compound_intersection(
    train_set=sql_os_train,
    test_set=sql_os_test,
    parser=cached_sql_parser,
    dataset_name='Original split SQL',
)

In [None]:
# Dataset label typo fixed ('Lanugage' -> 'Language').
check_compound_intersection(
    train_set=sql_iid_train,
    test_set=sql_iid_test,
    parser=cached_sql_parser,
    dataset_name='Language variation split SQL',
)

In [None]:
check_compound_intersection(
    train_set=sql_tl_train,
    test_set=sql_tl_test,
    parser=cached_sql_parser,
    dataset_name='Target length SQL',
)

In [None]:
check_compound_intersection(
    train_set=sql_tmcd_train,
    test_set=sql_tmcd_test,
    parser=cached_sql_parser,
    dataset_name='TMCD SQL',
)

In [None]:
import matplotlib.pylab as plt
import numpy as np

%pylab inline

In [None]:
# Whole (unsplit) WikiSQL dataset. Override with the WHOLE_WIKISQL_PATH environment variable.
# TODO: the default is an absolute path from one author's machine; make it relative to
# PROJECT_PATH once the file's location inside the project is confirmed.
WHOLE_WIKISQL_PATH = os.environ.get(
    'WHOLE_WIKISQL_PATH',
    '/Users/somov-od/Documents/phd/projects/CompGen/my_dataset/wikisql/whole_wikisql.json',
)
with open(WHOLE_WIKISQL_PATH, 'r') as dataset_file:
    data = json.load(dataset_file)

len(data)

In [None]:
# Display the first record of the whole WikiSQL dataset to inspect its fields.
data[0]

In [None]:
# Per-query target length (whitespace tokens). It is computed here so that this cell also works
# on a fresh kernel when the notebook runs top to bottom.
length_list = [len(sample['masked_query'].split()) for sample in data]
# Median target-query length.
np.percentile(length_list, 50)

In [None]:
# Number of queries strictly longer than the median target length. The median is computed once
# rather than for every element, and the lengths are derived from `data` so the cell is
# self-contained.
query_lengths = [len(sample['masked_query'].split()) for sample in data]
median_length = np.percentile(query_lengths, 50)
sum(1 for length in query_lengths if length > median_length)

In [None]:
# Target-query length (whitespace tokens) for every sample of the whole WikiSQL dataset.
length_list = [len(sample['masked_query'].split()) for sample in data]

fig, ax = plt.subplots()
ax.hist(length_list, bins=50)
ax.set(
    title='WikiSQL target query length distribution',
    xlabel='Query length (whitespace tokens)',
    ylabel='Number of queries',
);

In [None]:
# Number of entries in length_list (one per sample of `data` when built by the histogram cell).
len(length_list)