In [1]:
import numpy as np
import os
import json
import collections

# Dataset characteristic
1. number of DBs
2. number of tables per DB
3. number of columns per table
4. column matching ratio
5. value matching ratio

In [43]:
def get_where_cond(sql):
    conds = []
    for unit in sql['where']:
        if isinstance(unit, list):
            col = unit[2][1][1]
            val1 = unit[3]
            val2 = unit[4]
            if isinstance(val1, dict):
                conds += get_where_cond(val1)
                val1 = None
            if isinstance(val2, dict):
                conds += get_where_cond(val2)
                val2 = None
            conds.append([col, val1, val2])
    if sql['intersect'] is not None:
        conds += get_where_cond(sql['intersect'])
    if sql['union'] is not None:
        conds += get_where_cond(sql['union'])
    if sql['except'] is not None:
        conds += get_where_cond(sql['except'])
    return conds
def get_select_col(sql):
    return [x[1][1][1] for x in sql['select'][1]]
def get_orderby(sql):
    cols = []
    if sql['orderBy']:
        cols += [x[1][1] for x in sql['orderBy'][1]]
    if sql['intersect'] is not None:
        cols += get_orderby(sql['intersect'])
    if sql['union'] is not None:
        cols += get_orderby(sql['union'])
    if sql['except'] is not None:
        cols += get_orderby(sql['except'])
    return cols
def get_groupby(sql):
    cols = []
    if sql['groupBy']:
        cols += [x[1] for x in sql['groupBy']]
    if sql['intersect'] is not None:
        cols += get_groupby(sql['intersect'])
    if sql['union'] is not None:
        cols += get_groupby(sql['union'])
    if sql['except'] is not None:
        cols += get_groupby(sql['except'])
    return cols
def get_having_cond(sql):
    conds = []
    for unit in sql['having']:
        if isinstance(unit, list):
            col = unit[2][1][1]
            val1 = unit[3]
            val2 = unit[4]
            if isinstance(val1, dict):
                conds += get_having_cond(val1)
                val1 = None
            if isinstance(val2, dict):
                conds += get_having_cond(val2)
                val2 = None
            conds.append([col, val1, val2])
    if sql['intersect'] is not None:
        conds += get_having_cond(sql['intersect'])
    if sql['union'] is not None:
        conds += get_having_cond(sql['union'])
    if sql['except'] is not None:
        conds += get_having_cond(sql['except'])
    return conds

def process_val(field_value):
    if isinstance(field_value, bytes):
        field_value_str = field_value.encode('latin1')
    elif isinstance(field_value, str):
        field_value_str = field_value
    elif isinstance(field_value, float):
        if float(field_value) == int(field_value):
            field_value_str = str(int(field_value))
        else:
            field_value_str = str(float(field_value))
    else:
        field_value_str = str(field_value)
    field_value_str = field_value_str.strip("\"")
    # regualr expression
    field_value_str = field_value_str.strip("%")
    return field_value_str.lower()

def get_matching_ratio(dataset, debug=False):
    cmp_col_matching = []
    cmp_col_matching_per_example = []
    non_select_col_matching = []
    select_col_matching = []
    val_matching = []
    per_db_cmp_cols = {}
    per_db_non_select_cols = {}
    per_db_select_cols = {}
    
    for i,x in enumerate(dataset):
        if x['db_id'] not in per_db_cmp_cols:
            per_db_cmp_cols[x['db_id']] = set()
            per_db_non_select_cols[x['db_id']] = set()
            per_db_select_cols[x['db_id']] = set()
            
        where_conds = [z for z in get_where_cond(x['sql'])]
        select = get_select_col(x['sql'])
        groupby = get_groupby(x['sql'])
        orderby = get_orderby(x['sql'])
        having_conds = [z for z in get_having_cond(x['sql'])]
        orig_q = x['question'].lower()
        def get_col_names(cols):
            cols = set(cols)
#             cols.discard(0)
            return sorted([dbs[x['db_id']]['column_names'][col][1].lower() for col in cols])
        cmp_cols = get_col_names([z[0] for z in where_conds+having_conds if type(z[1]) in {float, str} or type(z[2]) in {float, str}])
#         cmp_cols = get_col_names([z[0] for z in where_conds if type(z[1]) in {float, str} or type(z[2]) in {float, str}])
        cmp_cols = [z for z in cmp_cols if z!='*']
        if len(cmp_cols)!= 0:
            cmp_col_matching.append(len([col_name for col_name in cmp_cols if orig_q.find(col_name)!=-1])/len(cmp_cols))
            if cmp_col_matching[-1]==1:
                cmp_col_matching_per_example.append(1)
            else:
                cmp_col_matching_per_example.append(0)
        else:
            pass
#             cmp_col_matching_per_example.append(1)
        per_db_cmp_cols[x['db_id']].add(tuple(cmp_cols))
        if debug:
            if len([col_name for col_name in cmp_cols if orig_q.find(col_name)!=-1])!=len(cmp_cols):
                print(i, orig_q)
                print('cmp',' | '.join(cmp_cols))
        non_select_cols = get_col_names([z[0] for z in where_conds+having_conds]+groupby+orderby)
        if len(non_select_cols)!= 0:
            non_select_col_matching.append(len([col_name for col_name in non_select_cols if orig_q.find(col_name)!=-1])/len(non_select_cols))
        per_db_non_select_cols[x['db_id']].add(tuple(non_select_cols))
        select_cols = get_col_names(select)
        if len(select_cols)!= 0:
            select_col_matching.append(len([col_name for col_name in select_cols if orig_q.find(col_name)!=-1])/len(select_cols))
#         if debug:
#             if len([col_name for col_name in select_cols if orig_q.find(col_name)!=-1])!=len(select_cols):
#                 print(i, orig_q)
#                 print('select',' | '.join(select_cols))
        per_db_select_cols[x['db_id']].add(tuple(select_cols))
        vals = [z for cond in where_conds+having_conds for z in cond[1:] if z is not None]
        vals = set([process_val(val) for val in vals])
        vals.discard('')
        if len(vals)!=0:
            val_matching.append(len([val for val in vals if val.lower() in orig_q])/len(vals))
    print('column matching ratio: %.3f, %.3f, %.3f'%(np.mean(cmp_col_matching), np.mean(non_select_col_matching), np.mean(select_col_matching)))
    print('cmp column matching ratio per example: %.3f'%(np.mean(cmp_col_matching_per_example)))
    print('unique columns: %.3f, %.3f, %.3f'%(
        np.mean([len(x) for _,x in per_db_cmp_cols.items()]),
        np.mean([len(x) for _,x in per_db_non_select_cols.items()]),
        np.mean([len(x) for _,x in per_db_select_cols.items()])
    ))
    print('value matching ratio', np.mean(val_matching))
    return per_db_select_cols

In [5]:
datasets = {}
dbs = {}

In [6]:
with open('../../../featurestorage/data/spider-20200607/train_spider.json', 'r') as f:
    datasets['spider_train'] = json.load(f)

In [7]:
with open('../../../featurestorage/data/spider-20200607/dev.json', 'r') as f:
    dataset = json.load(f)
with open('../../../featurestorage/data/spider-20200607/tables.json', 'r') as f:
    for db in json.load(f):
        dbs[db['db_id']] = db
datasets['spider_dev'] = dataset
with open('../../../featurestorage/data/spider-20200607/spider_modified_b_dev.json', 'r') as f:
    datasets['spider_dev_modified'] = json.load(f)
for db_id in ['atis', 'geography', 'restaurants', 'scholar', 'imdb', 'yelp', 'advising', 'academic']:
    with open(f'../../../featurestorage/data/spider-20200607/{db_id}_dev.json', 'r') as f:
        datasets[f'{db_id}_dev'] = json.load(f)
    with open(f'../../../featurestorage/data/spider-20200607/{db_id}_tables.json', 'r') as f:
        for db in json.load(f):
            dbs[db['db_id']] = db

In [216]:
get_matching_ratio(datasets['spider_dev'], True)

4 what is the average, minimum, and maximum age of all singers from france?
cmp country
5 what is the average, minimum, and maximum age for all french singers?
cmp country
21 how many concerts occurred in 2014 or 2015?
cmp year
25 what is the name and capacity of the stadium with the most concerts after 2013 ?
cmp year
32 what are the names of all stadiums that did not have a concert in 2014?
cmp year
38 what are the names of the singers who performed in a concert in 2014?
cmp year
39 what is the name and nation of the singer who have a song having 'hey' in its name?
cmp song name
40 what is the name and country of origin of every singer who has a song with the word 'hey' in its title?
cmp song name
42 what are the names and locations of the stadiums that had concerts that occurred in both 2014 and 2015?
cmp year
51 find number of pets owned by students who are older than 20.
cmp age
53 find the number of dog pets that are raised by female students (with sex f).
cmp pet type | sex
54 h

{'concert_singer': {('*',),
  ('*', 'concert name', 'theme'),
  ('*', 'country'),
  ('*', 'name'),
  ('age',),
  ('age', 'country', 'name'),
  ('average', 'capacity'),
  ('capacity',),
  ('capacity', 'name'),
  ('country',),
  ('country', 'name'),
  ('location', 'name'),
  ('name',),
  ('song name',),
  ('song name', 'song release year'),
  ('year',)},
 'pets_1': {('*',),
  ('*', 'student id'),
  ('age',),
  ('age', 'first name'),
  ('age', 'major'),
  ('first name',),
  ('first name', 'sex'),
  ('last name',),
  ('pet age', 'pet type'),
  ('pet id',),
  ('pet id', 'weight'),
  ('pet type',),
  ('pet type', 'weight'),
  ('student id',),
  ('weight',)},
 'car_1': {('*',),
  ('*', 'cont id', 'continent'),
  ('*', 'continent'),
  ('*', 'full name'),
  ('*', 'full name', 'id'),
  ('accelerate',),
  ('accelerate', 'cylinders'),
  ('country id', 'country name'),
  ('country name',),
  ('cylinders',),
  ('edispl',),
  ('full name', 'id'),
  ('horsepower',),
  ('horsepower', 'make'),
  ('id', 

In [76]:
np.mean([x for _, x in collections.Counter([x[0] for x in dbs['atis']['column_names'][1:]]).items()])

5.24

In [30]:
col_type = set()
for _,dataset in datasets.items():
    for example in dataset:
        where_conds = get_where_cond(example['sql'])
        cols = [type(z[1]) for z in where_conds]+[type(z[2]) for z in where_conds]
        col_type |= set(cols)
display(col_type)

{NoneType, float, list, str}

In [44]:
print('dataset\t# examples\t# DBs\t# table per DB\t# column per table')
for dataset_id, dataset in datasets.items():
    print(dataset_id)
    db_ids = list(set([x['db_id'] for x in dataset]))
    print('%d,%d,%.3f,%.3f'%(
        len(dataset),len(db_ids),
        np.mean([len(dbs[db_id]['table_names']) for db_id in db_ids]),
        np.mean([np.mean([x for _, x in collections.Counter([x[0] for x in dbs[db_id]['column_names'][1:]]).items()]) for db_id in db_ids])
    ))
    get_matching_ratio(dataset)
    print('-'*20)

dataset	# examples	# DBs	# table per DB	# column per table
spider_train
7000,140,5.264,5.186
column matching ratio: 0.408, 0.416, 0.524
cmp column matching ratio per example: 0.380
unique columns: 7.371, 13.950, 15.386
value matching ratio 0.8901880196523054
--------------------
spider_dev
1034,20,4.050,5.464
column matching ratio: 0.418, 0.399, 0.556
cmp column matching ratio per example: 0.392
unique columns: 6.750, 13.800, 15.750
value matching ratio 0.9146586345381527
--------------------
spider_dev_modified
508,19,4.053,5.528
column matching ratio: 0.019, 0.053, 0.543
cmp column matching ratio per example: 0.018
unique columns: 5.842, 8.842, 8.684
value matching ratio 0.9327731092436975
--------------------
atis_dev
283,1,25.000,5.240
column matching ratio: 0.000, 0.002, 0.000
cmp column matching ratio per example: 0.000
unique columns: 32.000, 37.000, 14.000
value matching ratio 0.96903073286052
--------------------
geography_dev
525,1,7.000,4.143
column matching ratio: 0.039, 0.

In [42]:
print('dataset\t# examples\t# DBs\t# table per DB\t# column per table')
for dataset_id, dataset in datasets.items():
    print(dataset_id)
    db_ids = list(set([x['db_id'] for x in dataset]))
    print('%d,%d,%.3f,%.3f'%(
        len(dataset),len(db_ids),
        np.mean([len(dbs[db_id]['table_names']) for db_id in db_ids]),
        np.mean([np.mean([x for _, x in collections.Counter([x[0] for x in dbs[db_id]['column_names'][1:]]).items()]) for db_id in db_ids])
    ))
    get_matching_ratio(dataset)
    print('-'*20)

dataset	# examples	# DBs	# table per DB	# column per table
spider_train
7000,140,5.264,5.186
column matching ratio: 0.408, 0.416, 0.524
cmp column matching ratio per example: 0.716
unique columns: 7.371, 13.950, 15.386
value matching ratio 0.8901880196523054
--------------------
spider_dev
1034,20,4.050,5.464
column matching ratio: 0.418, 0.399, 0.556
cmp column matching ratio per example: 0.744
unique columns: 6.750, 13.800, 15.750
value matching ratio 0.9146586345381527
--------------------
spider_dev_modified
508,19,4.053,5.528
column matching ratio: 0.019, 0.053, 0.543
cmp column matching ratio per example: 0.341
unique columns: 5.842, 8.842, 8.684
value matching ratio 0.9327731092436975
--------------------
atis_dev
283,1,25.000,5.240
column matching ratio: 0.000, 0.002, 0.000
cmp column matching ratio per example: 0.004
unique columns: 32.000, 37.000, 14.000
value matching ratio 0.96903073286052
--------------------
geography_dev
525,1,7.000,4.143
column matching ratio: 0.039, 0.

In [148]:
spider_train_column_names = set()
spider_train_cmp_cols = set()
spider_train_non_select_cols = set()
spider_train_select_cols = set()
spider_train_values = set()
db_ids = list(set([x['db_id'] for x in datasets['spider_train']]))
for db_id in db_ids:
    spider_train_column_names |= set([x[1].lower() for x in dbs[db_id]['column_names'][1:]])
for x in datasets['spider_train']:
    where_conds = [z for z in get_where_cond(x['sql'])]
    select = get_select_col(x['sql'])
    groupby = get_groupby(x['sql'])
    orderby = get_orderby(x['sql'])
    having_conds = [z for z in get_having_cond(x['sql'])]
    def get_col_names(cols):
        cols = set(cols)
#             cols.discard(0)
        return sorted([dbs[x['db_id']]['column_names'][col][1].lower() for col in cols])
    cmp_cols = get_col_names([z[0] for z in where_conds+having_conds if z[1] is not None or z[2] is not None])
    spider_train_cmp_cols |= set(cmp_cols)
    non_select_cols = get_col_names([z[0] for z in where_conds+having_conds]+groupby+orderby)
    spider_train_non_select_cols |= set(non_select_cols)
    select_cols = get_col_names(select)
    spider_train_select_cols |= set(select_cols)
    vals = [z for cond in where_conds+having_conds for z in cond[1:] if z is not None]
    vals = set([process_val(val) for val in vals if val])
    vals.discard('')
    spider_train_values |= vals

In [193]:
totto_column_names = set()
totto_non_select_cols = set()
totto_values = set()
def decode(tokenized):
    decoded = []
    for x in tokenized:
        if not x.startswith('##'):
            decoded.append(x)
        else:
            decoded[-1] += x[2:]
    return ' '.join(decoded)
with open('../../../table_pretrain/data/train_with_value.json', 'r') as f:
    totto_train = json.load(f)
for headers, h_label, _, _, column_vals, _  in totto_train:
    totto_column_names |= set([decode(x) for x in headers])
    totto_non_select_cols |= set([decode(x) for i,x in enumerate(headers) if h_label[i]==1])
    totto_values |= set([decode(x) for i,col in enumerate(column_vals) if h_label[i]==1 for x in col])

In [197]:
totto_non_select_cols

{'rd - 0213',
 'opening acts north america – leg 1 north america – leg 2 europe — leg 3 total',
 'qty',
 'suzette',
 'album us country',
 'reached number one 1960 1961 1962 1963 1964 1965 1966 1967 2019',
 '2017 – 18 season total 2016 – 17 season total 2015 – 16 season total 2014 – 15 season total 2013 – 14 season total 2012 – 13 season total 2011 – 12 season total',
 'tds 13',
 'opponents',
 'league goals 92 career totals',
 'contributors , sortable by editor ( s ) ( in bold )',
 'date africa oceania north america',
 'athlete boys under 20 ( junior ) girls under 20 ( junior ) boys under 17 ( youth ) girls under 17 ( youth )',
 'club performance league japan thailand japan japan 4 total',
 'climate data for sanya ( 1971 – 2000 ) jan',
 'artist additional',
 'year representing egypt',
 'united states presidential election in massachusetts , 1916 party',
 'representing university of texas at austin distance',
 'n / a kyle sandilands mel b dannii minogue dannii minogue mel b',
 'internati

In [178]:
for dataset_id, dataset in datasets.items():
    print(dataset_id)
    val_overlap = []
    per_db_column_name = {}
    cmp_cols_overlap = []
    non_select_cols_overlap = []
    select_cols_overlap = []
    for x in dataset:
        if x['db_id'] not in per_db_column_name:
            per_db_column_name[x['db_id']] = set([x[1].lower() for x in dbs[x['db_id']]['column_names'][1:]])
            per_db_val_used[x['db_id']] = set()
            per_db_cmp_cols[x['db_id']] = set()
            per_db_non_select_cols[x['db_id']] = set()
            per_db_select_cols[x['db_id']] = set()
        where_conds = [z for z in get_where_cond(x['sql'])]
        select = get_select_col(x['sql'])
        groupby = get_groupby(x['sql'])
        orderby = get_orderby(x['sql'])
        having_conds = [z for z in get_having_cond(x['sql'])]
        def get_col_names(cols):
            cols = set(cols)
    #             cols.discard(0)
            return sorted([dbs[x['db_id']]['column_names'][col][1].lower() for col in cols])
        cmp_cols = set(get_col_names([z[0] for z in where_conds+having_conds if z[1] is not None or z[2] is not None]))
        if len(cmp_cols)!=0:
            cmp_cols_overlap.append(len(cmp_cols&spider_train_cmp_cols)/len(cmp_cols))
        non_select_cols = set(get_col_names([z[0] for z in where_conds+having_conds]+groupby+orderby))
        if len(non_select_cols)!=0:
            non_select_cols_overlap.append(len(non_select_cols&spider_train_non_select_cols)/len(non_select_cols))
        select_cols = set(get_col_names(select))
        if len(select_cols)!=0:
            select_cols_overlap.append(len(select_cols&spider_train_select_cols)/len(select_cols))
        vals = [z for cond in where_conds+having_conds for z in cond[1:] if z is not None]
        vals = set([process_val(val) for val in vals if val])
        vals.discard('')
        if len(vals)!=0:
            val_overlap.append(len(vals&spider_train_values)/len(vals))
    print('%.3f,%.3f,%.3f,%.3f'%(
        np.mean([len(x&spider_train_column_names)/len(x) for _,x in per_db_column_name.items()]),
        np.mean(cmp_cols_overlap),
        np.mean(non_select_cols_overlap),
        np.mean(select_cols_overlap)))
    print('%.3f'%np.mean(val_overlap))
    print('-'*20)

spider_dev
0.486,0.470,0.503,0.634
0.499
--------------------
spider_dev_modified
0.495,0.437,0.515,0.690
0.457
--------------------
atis_dev
0.147,0.017,0.017,0.003
0.204
--------------------
geography_dev
0.278,0.032,0.226,0.262
0.446
--------------------
restaurants_dev
0.600,0.098,0.414,0.616
0.286
--------------------
scholar_dev
0.294,0.112,0.173,0.160
0.116
--------------------
imdb_dev
0.444,0.769,0.741,0.898
0.217
--------------------
yelp_dev
0.538,0.665,0.667,0.807
0.241
--------------------
advising_dev
0.178,0.405,0.399,0.618
0.092
--------------------
academic_dev
0.222,0.865,0.846,0.774
0.183
--------------------
spider_train
1.000,1.000,1.000,1.000
1.000
--------------------


In [144]:
spider_train_column_values

set()

In [141]:
all_vals & spider_train_column_values

set()

In [143]:
'mtw' in spider_train_column_values

False

In [175]:
dataset_id = 'yelp_dev'
dataset = datasets[dataset_id]
print(dataset_id)
val_overlap = []
per_db_column_name = {}
cmp_cols_overlap = []
non_select_cols_overlap = []
select_cols_overlap = []
for x in dataset:
    if x['db_id'] not in per_db_column_name:
        per_db_column_name[x['db_id']] = set([x[1].lower() for x in dbs[x['db_id']]['column_names'][1:]])
        per_db_val_used[x['db_id']] = set()
        per_db_cmp_cols[x['db_id']] = set()
        per_db_non_select_cols[x['db_id']] = set()
        per_db_select_cols[x['db_id']] = set()
    where_conds = [z for z in get_where_cond(x['sql'])]
    select = get_select_col(x['sql'])
    groupby = get_groupby(x['sql'])
    orderby = get_orderby(x['sql'])
    having_conds = [z for z in get_having_cond(x['sql'])]
    def get_col_names(cols):
        cols = set(cols)
#             cols.discard(0)
        return sorted([dbs[x['db_id']]['column_names'][col][1].lower() for col in cols])
    cmp_cols = set(get_col_names([z[0] for z in where_conds+having_conds if z[1] is not None or z[2] is not None]))
    if len(cmp_cols)!=0:
        cmp_cols_overlap.append(len(cmp_cols&spider_train_cmp_cols)/len(cmp_cols))
    non_select_cols = set(get_col_names([z[0] for z in where_conds+having_conds]+groupby+orderby))
    if len(non_select_cols)!=0:
        non_select_cols_overlap.append(len(non_select_cols&spider_train_non_select_cols)/len(non_select_cols))
    select_cols = set(get_col_names(select))
    if len(select_cols)!=0:
        select_cols_overlap.append(len(select_cols&spider_train_select_cols)/len(select_cols))
    vals = [z for cond in where_conds+having_conds for z in cond[1:] if z is not None]
    vals = set([process_val(val) for val in vals if val])
    vals.discard('')
    if len(vals)!=0:
        val_overlap.append(len(vals&spider_train_values)/len(vals))
print(np.mean([len(x&spider_train_column_names)/len(x) for _,x in per_db_column_name.items()]))
print(np.mean(cmp_cols_overlap))
print(np.mean(non_select_cols_overlap))
print(np.mean(select_cols_overlap))
print(np.mean(val_overlap))
print('-'*20)

yelp_dev
0.5384615384615384
0.66454802259887
0.6673497267759562
0.8073770491803278
0.24081920903954804
--------------------


In [162]:
cmp_cols_overlap

[0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0

In [176]:
per_db_column_name['yelp']&spider_train_cmp_cols

{'city',
 'id',
 'latitude',
 'month',
 'name',
 'rating',
 'seq',
 'state',
 'text',
 'year'}

In [202]:
for dataset_id, dataset in datasets.items():
    print(dataset_id)
    val_overlap = []
    per_db_column_name = {}
    non_select_cols_overlap = []
    for x in dataset:
        if x['db_id'] not in per_db_column_name:
            per_db_column_name[x['db_id']] = set([x[1].lower() for x in dbs[x['db_id']]['column_names'][1:]])
            per_db_val_used[x['db_id']] = set()
            per_db_non_select_cols[x['db_id']] = set()
        where_conds = [z for z in get_where_cond(x['sql'])]
        select = get_select_col(x['sql'])
        groupby = get_groupby(x['sql'])
        orderby = get_orderby(x['sql'])
        having_conds = [z for z in get_having_cond(x['sql'])]
        def get_col_names(cols):
            cols = set(cols)
    #             cols.discard(0)
            return sorted([dbs[x['db_id']]['column_names'][col][1].lower() for col in cols])
        cmp_cols = set(get_col_names([z[0] for z in where_conds+having_conds if z[1] is not None or z[2] is not None]))
        if len(cmp_cols)!=0:
            cmp_cols_overlap.append(len(cmp_cols&totto_non_select_cols)/len(cmp_cols))
        non_select_cols = set(get_col_names([z[0] for z in where_conds+having_conds]+groupby+orderby))
        if len(non_select_cols)!=0:
            non_select_cols_overlap.append(len(non_select_cols&totto_non_select_cols)/len(non_select_cols))
        vals = [z for cond in where_conds+having_conds for z in cond[1:] if z is not None]
        vals = set([process_val(val) for val in vals if val])
        vals.discard('')
        if len(vals)!=0:
            val_overlap.append(len(vals&totto_values)/len(vals))
    print('%.3f,%.3f,%.3f'%(
        np.mean([len(x&totto_column_names)/len(x) for _,x in per_db_column_name.items()]),
        np.mean(cmp_cols_overlap),
        np.mean(non_select_cols_overlap)))
    print('%.3f'%np.mean(val_overlap))
    print('-'*20)

spider_dev
0.477,0.947,0.499
0.824
--------------------
spider_dev_modified
0.494,0.926,0.644
0.825
--------------------
atis_dev
0.275,0.916,0.763
0.960
--------------------
geography_dev
0.778,0.871,0.475
0.997
--------------------
restaurants_dev
0.700,0.874,0.964
0.643
--------------------
scholar_dev
0.118,0.802,0.149
0.217
--------------------
imdb_dev
0.407,0.804,0.856
0.904
--------------------
yelp_dev
0.538,0.802,0.723
0.702
--------------------
advising_dev
0.267,0.823,0.898
0.569
--------------------
academic_dev
0.407,0.824,0.863
0.291
--------------------
spider_train
0.437,0.721,0.469
0.786
--------------------


In [48]:
datasets = {'imdb':
[39, 49, 50, 51, 52, 53, 57, 58, 64, 70, 71, 78, 85, 86, 91, 95, 102, 103, 104, 110, 121, 122, 126, 127],
'geography':
[76, 102, 119, 128, 130, 138, 149, 169, 170, 172, 285, 286, 301, 303, 304, 308, 320, 323, 327, 328, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 371, 372, 458, 463, 469, 471, 498, 504, 505, 506, 507, 508, 521, 522, 524, 525, 527, 530, 538, 558, 563, 565, 573, 574, 582, 585, 590, 593, 595],
'restaurants':
[0, 1, 2, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377],
'academic':
[50, 64, 67, 71, 75, 77, 99, 125, 126, 131, 132, 138, 139, 162, 189, 193],
'atis':
[3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 28, 29, 30, 31, 32, 113, 115, 119, 120, 125, 126, 135, 141, 142, 143, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 167, 168, 169, 184, 185, 188, 205, 207, 214, 215, 222, 223, 224, 225, 233, 234, 237, 238, 239, 240, 241, 242, 243, 245, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 270, 271, 273, 281, 283, 285, 286, 287, 288, 290, 296, 297, 298, 299, 300, 301, 304, 305, 306, 307, 308, 310, 311, 312, 313, 316, 319, 321, 322, 323, 324, 330, 331, 332, 333, 334, 336, 341, 342, 343, 344, 349, 350, 351, 352, 362, 367, 368, 371, 372, 376, 380, 384, 385, 387, 388, 389, 390, 391, 394, 397, 398, 402, 403, 409, 415, 416, 418, 421, 423, 424, 426, 429, 430, 431, 432, 435, 436, 437, 438, 439, 441, 442, 444, 445, 447, 449, 450, 453, 454, 456, 457, 458, 461, 463, 464, 466, 468, 469, 470, 471, 473, 474, 475, 477, 478, 479, 480, 481, 483],
'advising':
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 45, 48, 55, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 339, 341, 342, 343, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 360, 362, 365, 367, 368, 370, 371, 372, 373, 374, 375, 376, 377, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 513, 514, 516, 519, 529, 532, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, 842, 843, 844, 845, 846, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, 905, 906, 907, 908, 909, 910, 911, 912, 924, 925, 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1032, 1033, 1034, 1035, 1036, 1037, 1038, 1039, 1040, 1041, 1042, 1043, 1044, 1045, 1046, 1047, 1048, 1049, 1050, 1051, 1052, 1053, 1054, 1055, 1056, 1057, 1058, 1059, 1060, 1061, 1062, 1063, 1064, 1065, 1066, 1067, 1068, 1069, 1070, 1071, 1072, 1073, 1074, 1075, 1076, 1077, 1078, 1079, 1080, 1081, 1082, 1083, 1084, 1085, 1086, 1087, 1088, 1089, 1090, 1091, 1092, 1093, 1094, 1095, 1096, 1097, 1098, 1099, 1100, 1101, 1102, 1103, 1104, 1105, 1106, 1107, 1108, 1109, 1110, 1111, 1112, 1113, 1114, 1127, 1128, 1129, 1130, 1131, 1132, 1133, 1134, 1135, 1136, 1137, 1138, 1139, 1140, 1141, 1142, 1143, 1144, 1155, 1156, 1157, 1158, 1159, 1160, 1161, 1162, 1163, 1164, 1165, 1166, 1167, 1168, 1169, 1170, 1171, 1172, 1173, 1174, 1175, 1176, 1177, 1178, 1179, 1199, 1200, 1201, 1202, 1203, 1204, 1205, 1206, 1207, 1208, 1209, 1210, 1211, 1212, 1213, 1214, 1215, 1216, 1217, 1218, 1219, 1220, 1221, 1222, 1223, 1224, 1225, 1226, 1227, 1228, 1229, 1230, 1231, 1232, 1233, 1234, 1235, 1236, 1237, 1238, 1239, 1240, 1241, 1242, 1243, 1244, 1245, 1246, 1247, 1248, 1249, 1250, 1251, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259, 1260, 1261, 1262, 1263, 1264, 1265, 1266, 1267, 1268, 1269, 1270, 1271, 1272, 1273, 1274, 1275, 1276, 1277, 1278, 1279, 1294, 1295, 1296, 1297, 1298, 1299, 1300, 1301, 1302, 1303, 1304, 1305, 1306, 1307, 1308, 1309, 1310, 1311, 1312, 1313, 1314, 1315, 1316, 1317, 1318, 1319, 1320, 1321, 1322, 1323, 1324, 1336, 1337, 1338, 1339, 1340, 1341, 1342, 1343, 1344, 1345, 1346, 1347, 1348, 1349, 1350, 1351, 1352, 1353, 1354, 1355, 1356, 1357, 1358, 1359, 1360, 1373, 1374, 1375, 1376, 1377, 1378, 1379, 1380, 1381, 1382, 1383, 1384, 1385, 1386, 1387, 1388, 1389, 1390, 1391, 1392, 1393, 1394, 1395, 1396, 1397, 1398, 1399, 1400, 1401, 1402, 1403, 1404, 1405, 1406, 1407, 1408, 1409, 1410, 1411, 1412, 1413, 1414, 1415, 1416, 1417, 1418, 1419, 1420, 1421, 1422, 1423, 1424, 1425, 1426, 1427, 1428, 1429, 1430, 1431, 1432, 1433, 1434, 1435, 1436, 1437, 1438, 1439, 1440, 1441, 1442, 1443, 1444, 1445, 1446, 1447, 1448, 1449, 1450, 1451, 1452, 1453, 1454, 1455, 1456, 1457, 1458, 1459, 1460, 1461, 1462, 1463, 1464, 1465, 1466, 1467, 1468, 1469, 1470, 1471, 1472, 1473, 1474, 1475, 1476, 1477, 1478, 1479, 1480, 1481, 1482, 1483, 1484, 1485, 1486, 1487, 1488, 1489, 1490, 1491, 1492, 1493, 1494, 1495, 1496, 1497, 1498, 1499, 1500, 1501, 1502, 1503, 1504, 1505, 1506, 1507, 1508, 1509, 1510, 1511, 1512, 1513, 1514, 1515, 1516, 1517, 1518, 1519, 1520, 1521, 1522, 1523, 1524, 1525, 1526, 1527, 1528, 1529, 1530, 1531, 1532, 1533, 1534, 1535, 1536, 1537, 1538, 1539, 1540, 1541, 1542, 1543, 1544, 1545, 1546, 1547, 1548, 1549, 1550, 1551, 1552, 1561, 1564, 1565, 1566, 1567, 1568, 1569, 1570, 1571, 1572, 1573, 1574, 1575, 1576, 1577, 1578, 1579, 1580, 1581, 1582, 1583, 1584, 1585, 1586, 1587, 1588, 1589, 1590, 1591, 1592, 1593, 1594, 1595, 1596, 1597, 1598, 1599, 1600, 1601, 1602, 1603, 1604, 1605, 1606, 1607, 1608, 1609, 1610, 1611, 1612, 1613, 1614, 1615, 1616, 1617, 1618, 1619, 1620, 1621, 1622, 1623, 1624, 1625, 1626, 1627, 1628, 1629, 1630, 1631, 1632, 1633, 1634, 1635, 1636, 1637, 1638, 1639, 1640, 1641, 1642, 1643, 1644, 1645, 1646, 1647, 1648, 1649, 1650, 1651, 1652, 1653, 1654, 1655, 1656, 1657, 1658, 1659, 1660, 1661, 1662, 1663, 1664, 1665, 1666, 1667, 1668, 1669, 1670, 1671, 1672, 1673, 1674, 1675, 1676, 1677, 1678, 1679, 1680, 1681, 1682, 1683, 1684, 1685, 1686, 1687, 1688, 1689, 1690, 1691, 1692, 1693, 1694, 1695, 1696, 1697, 1698, 1699, 1700, 1701, 1702, 1703, 1704, 1705, 1706, 1707, 1708, 1709, 1710, 1711, 1712, 1713, 1714, 1715, 1716, 1717, 1718, 1719, 1720, 1721, 1722, 1723, 1724, 1725, 1726, 1727, 1728, 1729, 1730, 1731, 1732, 1733, 1734, 1735, 1736, 1737, 1738, 1739, 1740, 1741, 1742, 1743, 1744, 1745, 1746, 1747, 1748, 1749, 1750, 1751, 1752, 1753, 1754, 1755, 1756, 1757, 1758, 1759, 1760, 1761, 1762, 1763, 1764, 1765, 1766, 1767, 1768, 1769, 1770, 1771, 1772, 1773, 1774, 1775, 1776, 1777, 1778, 1779, 1780, 1781, 1782, 1783, 1784, 1785, 1786, 1787, 1788, 1789, 1790, 1791, 1792, 1793, 1794, 1795, 1796, 1797, 1798, 1799, 1800, 1801, 1802, 1803, 1804, 1805, 1806, 1807, 1808, 1809, 1810, 1811, 1812, 1813, 1814, 1815, 1816, 1817, 1818, 1819, 1820, 1821, 1822, 1823, 1824, 1825, 1826, 1827, 1828, 1829, 1830, 1831, 1832, 1833, 1834, 1835, 1836, 1837, 1838, 1839, 1840, 1841, 1842, 1843, 1844, 1845, 1846, 1847, 1848, 1849, 1850, 1851, 1852, 1853, 1854, 1855, 1856, 1857, 1858, 1859, 1860, 1861, 1862, 1863, 1864, 1865, 1866, 1867, 1868, 1869, 1870, 1871, 1872, 1873, 1874, 1875, 1876, 1877, 1878, 1879, 1880, 1881, 1882, 1883, 1884, 1885, 1886, 1887, 1888, 1889, 1890, 1891, 1892, 1893, 1894, 1895, 1917, 1918, 1919, 1920, 1921, 1922, 1923, 1924, 1925, 1926, 1927, 1928, 1929, 1930, 1931, 1932, 1933, 1934, 1935, 1936, 1937, 1938, 1939, 1940, 1941, 1942, 1943, 1944, 1945, 1946, 1947, 1948, 1949, 1950, 1951, 1952, 1953, 1954, 1955, 1956, 1957, 1958, 1959, 1960, 1961, 1962, 1963, 1964, 1965, 1966, 1967, 1968, 1969, 1970, 1971, 1972, 1973, 1974, 1975, 1976, 1977, 1978, 1979, 1980, 1981, 1982, 1983, 1984, 1985, 1986, 1987, 1988, 1989, 1990, 1991, 1992, 1993, 1994, 1996, 1997, 1998, 2000, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024, 2025, 2026, 2027, 2028, 2029, 2030, 2041, 2042, 2043, 2044, 2045, 2046, 2047, 2048, 2049, 2050, 2051, 2052, 2053, 2054, 2055, 2056, 2057, 2058, 2059, 2060, 2061, 2062, 2063, 2064, 2065, 2066, 2067, 2068, 2069, 2070, 2071, 2072, 2073, 2074, 2075, 2087, 2088, 2089, 2091, 2092, 2093, 2094, 2095, 2096, 2097, 2098, 2099, 2100, 2101, 2102, 2103, 2104, 2105, 2106, 2107, 2108, 2109, 2110, 2111, 2112, 2113, 2114, 2115, 2116, 2117, 2118, 2119, 2120, 2121, 2122, 2123, 2124, 2125, 2126, 2127, 2128, 2129, 2130, 2131, 2132, 2133, 2134, 2135, 2136, 2137, 2138, 2139, 2140, 2141, 2142, 2143, 2144, 2145, 2146, 2147, 2148, 2149, 2150, 2151, 2152, 2153, 2154, 2155, 2156, 2157, 2158, 2159, 2160, 2161, 2162, 2163, 2164, 2165, 2166, 2167, 2168, 2169, 2170, 2171, 2172, 2173, 2174, 2175, 2176, 2177, 2178, 2179, 2180, 2181, 2193, 2194, 2195, 2196, 2197, 2198, 2199, 2200, 2201, 2202, 2203, 2204, 2205, 2206, 2207, 2208, 2209, 2210, 2211, 2212, 2213, 2214, 2215, 2216, 2217, 2218, 2219, 2220, 2221, 2222, 2223, 2224, 2225, 2226, 2227, 2228, 2229, 2230, 2231, 2232, 2233, 2234, 2235, 2236, 2237, 2238, 2239, 2240, 2241, 2242, 2243, 2244, 2245, 2246, 2247, 2248, 2249, 2250, 2251, 2252, 2253, 2254, 2255, 2256, 2257, 2258, 2259, 2260, 2261, 2262, 2263, 2264, 2265, 2266, 2267, 2268, 2269, 2270, 2271, 2272, 2273, 2274, 2275, 2276, 2277, 2278, 2279, 2280, 2281, 2282, 2283, 2284, 2285, 2286, 2287, 2288, 2289, 2302, 2307, 2308, 2309, 2315, 2318, 2319, 2320, 2321, 2322, 2323, 2324, 2325, 2326, 2327, 2328, 2329, 2330, 2331, 2332, 2333, 2334, 2335, 2336, 2337, 2339, 2340, 2341, 2342, 2343, 2344, 2345, 2346, 2347, 2348, 2349, 2350, 2351, 2352, 2353, 2354, 2355, 2356, 2357, 2358, 2359, 2360, 2361, 2362, 2363, 2364, 2365, 2376, 2377, 2378, 2379, 2380, 2381, 2382, 2383, 2384, 2385, 2386, 2387, 2388, 2389, 2390, 2391, 2392, 2393, 2394, 2395, 2396, 2397, 2398, 2399, 2400, 2401, 2402, 2403, 2404, 2405, 2406, 2407, 2408, 2409, 2410, 2411, 2412, 2413, 2414, 2415, 2416, 2417, 2418, 2419, 2420, 2421, 2422, 2423, 2424, 2425, 2426, 2427, 2428, 2429, 2430, 2431, 2432, 2433, 2434, 2435, 2436, 2437, 2438, 2439, 2440, 2441, 2442, 2443, 2444, 2445, 2446, 2447, 2448, 2449, 2450, 2451, 2452, 2453, 2454, 2455, 2466, 2467, 2468, 2469, 2470, 2471, 2472, 2473, 2474, 2475, 2476, 2477, 2478, 2479, 2480, 2481, 2482, 2483, 2484, 2485, 2486, 2487, 2488, 2489, 2490, 2491, 2492, 2493, 2494, 2495, 2496, 2497, 2498, 2499, 2500, 2501, 2502, 2503, 2504, 2505, 2506, 2507, 2508, 2509, 2510, 2511, 2512, 2513, 2514, 2515, 2516, 2517, 2518, 2519, 2520, 2521, 2522, 2523, 2524, 2525, 2526, 2527, 2528, 2529, 2530, 2531, 2532, 2533, 2534, 2535, 2536, 2537, 2539, 2545, 2547, 2548, 2549, 2550, 2551, 2552, 2553, 2554, 2555, 2556, 2557, 2558, 2559, 2560, 2561, 2562, 2563, 2564, 2565, 2566, 2567, 2568, 2569, 2570, 2571, 2572, 2573, 2574, 2575, 2576, 2577, 2578, 2579, 2580, 2581, 2582, 2583, 2584, 2585, 2586, 2587, 2588, 2589, 2590, 2591, 2592, 2593, 2594, 2595, 2596, 2597, 2598, 2599, 2603, 2610, 2613, 2614, 2615, 2616, 2617, 2618, 2619, 2620, 2621, 2622, 2623, 2624, 2625, 2626, 2627, 2628, 2629, 2630, 2631, 2632, 2633, 2634, 2635, 2636, 2637, 2638, 2639, 2640, 2641, 2642, 2643, 2644, 2645, 2646, 2647, 2648, 2649, 2650, 2651, 2652, 2653, 2654, 2655, 2656, 2657, 2658, 2659, 2660, 2661, 2662, 2663, 2664, 2665, 2666, 2667, 2668, 2669, 2670, 2671, 2672, 2673, 2674, 2675, 2676, 2677, 2678, 2679, 2680, 2681, 2682, 2683, 2684, 2685, 2686, 2687, 2688, 2689, 2690, 2691, 2692, 2693, 2694, 2695, 2696, 2697, 2698, 2699, 2700, 2701, 2702, 2703, 2704, 2705, 2706, 2707, 2708, 2709, 2710, 2711, 2712, 2713, 2714, 2715, 2716, 2717, 2718, 2719, 2720, 2721, 2722, 2723, 2724, 2725, 2726, 2727, 2728, 2729, 2730, 2731, 2732, 2733, 2734, 2735, 2736, 2737, 2738, 2739, 2740, 2741, 2742, 2743, 2744, 2745, 2746, 2747, 2748, 2749, 2750, 2751, 2752, 2753, 2754, 2755, 2756, 2757, 2758, 2759, 2760, 2761, 2762, 2763, 2764, 2765, 2766, 2767, 2768, 2769, 2770, 2771, 2772, 2773, 2774, 2775, 2776, 2777, 2778, 2779, 2780, 2781, 2782, 2783, 2784, 2785, 2786, 2787, 2788, 2789, 2790, 2791, 2792, 2793, 2794, 2795, 2796, 2797, 2798, 2799, 2800, 2801, 2802, 2803, 2804, 2805, 2806, 2807, 2808, 2809, 2810, 2811, 2812, 2813, 2814, 2815, 2816, 2817, 2818, 2819, 2820, 2821, 2822, 2823, 2824, 2825, 2826, 2827, 2828, 2829, 2830, 2831, 2832, 2833, 2834, 2835, 2836, 2837, 2838, 2839, 2840, 2841, 2842, 2843, 2844, 2845, 2846, 2847, 2848, 2849, 2850, 2851, 2852, 2853, 2854, 2855, 2856, 2857],
'scholar':
[0, 1, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 46, 67, 97, 102, 105, 106, 107, 108, 109, 110, 128, 173, 179, 180, 181, 187, 188, 189, 190, 192, 193, 197, 204, 206, 207, 211, 237, 238, 239, 240, 241, 242, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 274, 283, 286, 288, 295, 302, 311, 312, 313, 314, 315, 325, 327, 328, 329, 330, 331, 336, 337, 338, 339, 342, 345, 346, 347, 348, 356, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 390, 391, 399, 400, 401, 403, 409, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 432, 433, 434, 438, 441, 442, 443, 444, 445, 446, 447, 448, 449, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 479, 480, 481, 482, 483, 484, 485, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 502, 503, 515, 517, 528, 529, 530, 532, 534, 535, 544, 545, 546, 555, 556, 557, 558, 562, 563, 564, 565, 566, 577, 578, 579, 580, 583, 584, 585, 586, 588, 590, 591, 592, 594, 596, 597],
'yelp':
[4, 7, 8, 9, 13, 14, 15, 16, 17, 18, 19, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 38, 39, 40, 41, 43, 44, 45, 46, 47, 49, 51, 52, 53, 54, 55, 56, 59, 62, 63, 64, 67, 68, 69, 73, 81, 82, 83, 84, 85, 86, 87, 89, 91, 95, 96, 97, 102, 108, 110, 111, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124]
}
datadir = '/datadrive/xiaden/workspace/featurestorage/data/spider-20200607'

In [94]:
import re
def is_value(unit):
    if unit[0] in {'"',"'"} or unit[0].isdigit():
        return True
    return False
def is_column(unit):
    if '.' in unit and not is_value(unit):
        return True
    return False
after_column = set()
for dataset_id in datasets:
    print(dataset_id)
    with open(os.path.join(datadir, f'{dataset_id}_origin.json'), 'r') as f:
        data = json.load(f)
    num_match = 0
    num_ok = 0
    num_col = 0
    for i, example in enumerate(data):
        if i in datasets[dataset_id]:
            continue
        question = example['question'].lower()
        query = example['query'].lower()
        query = re.sub('>=|<=|>|<|not like|like|between| not in | in | is not | is |=',' = ', query)
        query = query.split()
        all_cmp_columns = set()
        for i in range(len(query)):
            unit = query[i]
            if is_column(unit):
                if i<len(query):
                    after_column.add(query[i+1])
                    if query[i+1] == '=' and is_value(query[i+2]):
                        all_cmp_columns.add(unit)
        num_ok += 1
        if len(all_cmp_columns)!=0:
            num_col += 1
            if all([question.find(col.split('.')[-1].replace('_',' '))!=-1 for col in all_cmp_columns]):
                num_match+=1
    print(num_ok,'%.1f, %.1f'%(num_match/num_col*100,(num_match+num_ok-num_col)/num_ok*100))

imdb
107 1.0, 2.8
geography
532 3.9, 35.3
restaurants
27 0.0, 0.0
academic
180 11.4, 13.9
atis
289 0.0, 0.3
advising
324 0.3, 0.3
scholar
393 0.0, 1.5
yelp
54 8.0, 14.8


In [86]:
all_cmp_columns

set()

In [59]:
is_value('businessalias0.name')

False

In [88]:
num_match

0