In [2]:
import os
import sys
import json
import pickle
import torch

from preprocess.common_utils import quote_normalization

In [3]:
!python --version

Python 3.6.5 :: Anaconda, Inc.


## Data inspection

### Tables

In [4]:
# Preprocessed schemas; later cells look entries up as table_preproc[db_id].
table_preproc_path = 'data/tables.bin'
with open(table_preproc_path, mode='rb') as tables_file:
    table_preproc = pickle.load(tables_file)
len(table_preproc)

166

In [7]:
# Inspect one database's preprocessed schema entry (column names, original names, types, ...).
table_preproc['concert_singer']

{'column_names': [[-1, '*'],
  [0, 'stadium id'],
  [0, 'location'],
  [0, 'name'],
  [0, 'capacity'],
  [0, 'highest'],
  [0, 'lowest'],
  [0, 'average'],
  [1, 'singer id'],
  [1, 'name'],
  [1, 'country'],
  [1, 'song name'],
  [1, 'song release year'],
  [1, 'age'],
  [1, 'is male'],
  [2, 'concert id'],
  [2, 'concert name'],
  [2, 'theme'],
  [2, 'stadium id'],
  [2, 'year'],
  [3, 'concert id'],
  [3, 'singer id']],
 'column_names_original': [[-1, '*'],
  [0, 'Stadium_ID'],
  [0, 'Location'],
  [0, 'Name'],
  [0, 'Capacity'],
  [0, 'Highest'],
  [0, 'Lowest'],
  [0, 'Average'],
  [1, 'Singer_ID'],
  [1, 'Name'],
  [1, 'Country'],
  [1, 'Song_Name'],
  [1, 'Song_release_year'],
  [1, 'Age'],
  [1, 'Is_male'],
  [2, 'concert_ID'],
  [2, 'concert_Name'],
  [2, 'Theme'],
  [2, 'Stadium_ID'],
  [2, 'Year'],
  [3, 'concert_ID'],
  [3, 'Singer_ID']],
 'column_types': ['text',
  'number',
  'text',
  'text',
  'number',
  'number',
  'number',
  'number',
  'number',
  'text',
  'text',

In [14]:
# Number of processed table names and column names for concert_singer
# (column_names above lists 22 entries, including the '*' placeholder).
concert_singer_schema = table_preproc['concert_singer']
tuple(len(concert_singer_schema[key]) for key in ('processed_table_names', 'processed_column_names'))

(4, 22)

### Original preprocessed datasets

In [11]:
# Dev-split examples after the original (non-graph) preprocessing
# (presumably the Spider dev set -- 1034 examples; confirm against the data source).
dev_preproc_path = 'data/dev.bin'
with open(dev_preproc_path, mode='rb') as dev_file:
    dev_preproc_dataset = pickle.load(dev_file)
len(dev_preproc_dataset)

1034

In [12]:
# First dev example: raw fields plus preprocessing fields (processed tokens, POS tags, relations, ...).
dev_preproc_dataset[0]

{'db_id': 'concert_singer',
 'query': 'SELECT count(*) FROM singer',
 'query_toks': ['SELECT', 'count', '(', '*', ')', 'FROM', 'singer'],
 'query_toks_no_value': ['select', 'count', '(', '*', ')', 'from', 'singer'],
 'question': 'How many singers do we have?',
 'question_toks': ['How', 'many', 'singers', 'do', 'we', 'have', '?'],
 'sql': {'except': None,
  'from': {'conds': [], 'table_units': [['table_unit', 1]]},
  'groupBy': [],
  'having': [],
  'intersect': None,
  'limit': None,
  'orderBy': [],
  'select': [False, [[3, [0, [0, 0, False], None]]]],
  'union': None,
  'where': []},
 'raw_question_toks': ['how', 'many', 'singers', 'do', 'we', 'have', '?'],
 'processed_question_toks': ['how', 'many', 'singer', 'do', 'we', 'have', '?'],
 'pos_tags': ['WRB', 'JJ', 'NNS', 'VBP', 'PRP', 'VB', '.'],
 'relations': [['question-question-identity',
   'question-question-dist1',
   'question-question-dist2',
   'question-question-generic',
   'question-question-generic',
   'question-question-

In [23]:
# Field names carried by a preprocessed example.
dev_preproc_dataset[0].keys()

dict_keys(['db_id', 'query', 'query_toks', 'query_toks_no_value', 'question', 'question_toks', 'sql', 'raw_question_toks', 'processed_question_toks', 'pos_tags', 'relations', 'schema_linking', 'used_tables', 'used_columns', 'ast', 'actions'])

In [13]:
# Number of processed question tokens in the first example (same count as question_toks here).
len(dev_preproc_dataset[0]['processed_question_toks'])

7

In [21]:
# Same dev examples, each additionally carrying a 'graph' object (see the cells below).
dev_lgesql_path = 'data/dev.lgesql.bin'
with open(dev_lgesql_path, mode='rb') as lgesql_file:
    dev_lgesql_dataset = pickle.load(lgesql_file)
len(dev_lgesql_dataset)

Using backend: pytorch
  return f(*args, **kwds)


1034

In [24]:
# Same fields as the plain preprocessed example plus an extra 'graph' entry.
dev_lgesql_dataset[0].keys()

dict_keys(['db_id', 'query', 'query_toks', 'query_toks_no_value', 'question', 'question_toks', 'sql', 'raw_question_toks', 'processed_question_toks', 'pos_tags', 'relations', 'schema_linking', 'used_tables', 'used_columns', 'ast', 'actions', 'graph'])

In [27]:
# The 'graph' entry is a GraphExample instance.
dev_lgesql_dataset[0]['graph']

<utils.graph_example.GraphExample at 0x7fe8fcd42630>

In [28]:
# Attributes stored on the GraphExample instance (global/local graphs and edges, masks, ...).
dev_lgesql_dataset[0]['graph'].__dict__.keys()

dict_keys(['global_g', 'global_edges', 'local_g', 'local_edges', 'question_mask', 'schema_mask', 'gp', 'node_label', 'lg'])

In [None]:
# Global edge list of the first example's graph.
# NOTE(review): this cell is saved as `In [None]` (not executed), yet later cells
# read `_g_edges`; run it before them on a fresh kernel.
_g_edges = dev_lgesql_dataset[0]['graph'].global_edges
_g_edges

In [None]:
# Local edge list of the first example's graph.
# NOTE(review): this cell is saved as `In [None]` (not executed), yet later cells
# read `_l_edges`; run it before them on a fresh kernel.
_l_edges = dev_lgesql_dataset[0]['graph'].local_edges
_l_edges

In [36]:
# Local edges absent from the global edge list; an empty set means every local edge is also global.
set(_l_edges) - set(_g_edges)

set()

In [None]:
# Print the global edges that are missing from the local edge list (the converse of
# the set difference above). Membership is checked against a set rather than the
# list, so each lookup is constant time instead of a scan of `_l_edges`.
# Printing order still follows `_g_edges`.
_l_edge_set = set(_l_edges)
for e in _g_edges:
    if e not in _l_edge_set:
        print(e)

### Datasets with RAT-SQL graphs (`rat_sql_graph`)

In [55]:
# Spider examples augmented with RAT-SQL graphs ('rat_sql_graph').
# NOTE(review): despite the `dev_` variable names, this path points at the
# `train_others` split, and the reported size (1659) differs from the 1034-example
# dev split loaded above. Cells below with lower In[] numbers were executed earlier
# and may show outputs from a different file than the one loaded here.
# Set SDR_DATA_DIR to a local SDR-analysis data directory; the default is the
# original absolute path. pickle.load can execute arbitrary code, so only point
# this at trusted files.
sdr_data_dir = os.environ.get('SDR_DATA_DIR', '/home/yshao/Projects/SDR-analysis/data')
dev_ratsql_g_preproc_path = os.path.join(sdr_data_dir, 'spider', 'train_others+ratsql_graph.bin')
with open(dev_ratsql_g_preproc_path, 'rb') as f:
    dev_ratsql_g_preproc_dataset = pickle.load(f)
len(dev_ratsql_g_preproc_dataset)

1659

In [17]:
# Like the plain preprocessed examples, plus a 'rat_sql_graph' entry.
dev_ratsql_g_preproc_dataset[0].keys()

dict_keys(['db_id', 'query', 'query_toks', 'query_toks_no_value', 'question', 'question_toks', 'sql', 'rat_sql_graph', 'raw_question_toks', 'processed_question_toks', 'pos_tags', 'relations', 'schema_linking', 'used_tables', 'used_columns', 'ast', 'actions'])

In [18]:
# RAT-SQL graph: question tokens first, then '<C>table::column' and '<T>table' schema nodes;
# 'relations' is stored as a string.
dev_ratsql_g_preproc_dataset[0]['rat_sql_graph']

{'nodes': ['how',
  'many',
  'singer',
  'do',
  'we',
  'have',
  '?',
  '<C>NONE::*',
  '<C>stadium::stadium_id',
  '<C>stadium::location',
  '<C>stadium::name',
  '<C>stadium::capacity',
  '<C>stadium::highest',
  '<C>stadium::lowest',
  '<C>stadium::average',
  '<C>singer::singer_id',
  '<C>singer::name',
  '<C>singer::country',
  '<C>singer::song_name',
  '<C>singer::song_release_year',
  '<C>singer::age',
  '<C>singer::be_male',
  '<C>concert::concert_id',
  '<C>concert::concert_name',
  '<C>concert::theme',
  '<C>concert::stadium_id',
  '<C>concert::year',
  '<C>singer_in_concert::concert_id',
  '<C>singer_in_concert::singer_id',
  '<T>stadium',
  '<T>singer',
  '<T>concert',
  '<T>singer_in_concert'],
 'q_nodes_orig': ['how', 'many', 'singers', 'do', 'we', 'have', '?'],
 'relations': '[[2, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6], [1, 2, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6,

In [25]:
# Compare RAT-SQL and LGESQL question tokenizations by length and print mismatches.
# RAT-SQL graph nodes list the question tokens first, followed by schema nodes with
# a '<C>' (column) or '<T>' (table) prefix. Only those two prefixes are filtered out,
# so a question token that merely starts with '<' (e.g. a literal '<' or '<=')
# is still counted as a question token.
SCHEMA_NODE_PREFIXES = ('<C>', '<T>')
for d in dev_ratsql_g_preproc_dataset:
    lgesql_toks = d['processed_question_toks']
    ratsql_toks = [n for n in d['rat_sql_graph']['nodes'] if not n.startswith(SCHEMA_NODE_PREFIXES)]
    if len(lgesql_toks) != len(ratsql_toks):
        print('ratsql:', ratsql_toks)
        print('lgesql:', lgesql_toks)
        print()

ratsql: ['what', 'be', 'the', 'name', 'and', 'country', 'of', 'origin', 'of', 'every', 'singer', 'who', 'have', 'a', 'song', 'with', 'the', 'word', '`', 'hey', "'", 'in', 'its', 'title', '?']
lgesql: ['what', 'be', 'the', 'name', 'and', 'country', 'of', 'origin', 'of', 'every', 'singer', 'who', 'have', 'a', 'song', 'with', 'the', 'word', '"', 'hey', '"', 'in', 'it', "'s", 'title', '?']

ratsql: ['for', 'each', 'continent', ',', 'list', 'its', 'id', ',', 'name', ',', 'and', 'how', 'many', 'country', 'it', 'have', '?']
lgesql: ['for', 'each', 'continent', ',', 'list', 'it', 'be', 'id', ',', 'name', ',', 'and', 'how', 'many', 'country', 'it', 'have', '?']

ratsql: ['what', 'be', 'the', 'full', 'name', 'of', 'each', 'car', 'maker', ',', 'along', 'with', 'its', 'id', 'and', 'how', 'many', 'model', 'it', 'produce', '?']
lgesql: ['what', 'be', 'the', 'full', 'name', 'of', 'each', 'car', 'maker', ',', 'along', 'with', 'it', "'s", 'id', 'and', 'how', 'many', 'model', 'it', 'produce', '?']

rats

In [36]:
# Compare schema sizes between LGESQL's table preprocessing and each example's
# RAT-SQL graph nodes: '<T>table' nodes vs. processed table names, and
# '<C>table::column' nodes vs. processed column names. Only the counts are compared;
# the '_' -> ' ' replacement just makes the printed lists easier to eyeball.
for d in dev_ratsql_g_preproc_dataset:
    db_id = d['db_id']
    table_d = table_preproc[db_id]
    graph_nodes = d['rat_sql_graph']['nodes']

    lgesql_tables = list(table_d['processed_table_names'])
    # Strip the '<T>' prefix exactly once instead of splitting on every '<T>'.
    ratsql_tables = [n[len('<T>'):].replace('_', ' ') for n in graph_nodes if n.startswith('<T>')]
    if len(lgesql_tables) != len(ratsql_tables):
        print('ratsql tables:', ratsql_tables)
        print('lgesql tables:', lgesql_tables)
        print()

    lgesql_cols = list(table_d['processed_column_names'])
    # Split on the first '::' only, so a '::' inside the column name is kept intact.
    ratsql_cols = [n.split('::', 1)[1].replace('_', ' ') for n in graph_nodes if n.startswith('<C>')]
    if len(lgesql_cols) != len(ratsql_cols):
        print('ratsql cols:', ratsql_cols)
        print('lgesql cols:', lgesql_cols)
        print()

### WikiSQL

In [44]:
# WikiSQL dev examples carrying both a RAT-SQL graph ('rat_sql_graph') and an
# LGESQL 'graph' object.
# Set SDR_DATA_DIR to a local SDR-analysis data directory; the default is the
# original absolute path. pickle.load can execute arbitrary code, so only point
# this at trusted files.
sdr_data_dir = os.environ.get('SDR_DATA_DIR', '/home/yshao/Projects/SDR-analysis/data')
wikisql_dev_lgesql_path = os.path.join(sdr_data_dir, 'wikisql', 'dev+ratsql_graph.lgesql.bin')
with open(wikisql_dev_lgesql_path, 'rb') as f:
    wikisql_dev_lgesql_dataset = pickle.load(f)
len(wikisql_dev_lgesql_dataset)

8421

In [52]:
# WikiSQL examples carry 'phase'/'table_id' and a 'graph' entry, but no 'query' string.
wikisql_dev_lgesql_dataset[0].keys()

dict_keys(['phase', 'table_id', 'question', 'sql', 'rat_sql_graph', 'db_id', 'question_toks', 'raw_question_toks', 'processed_question_toks', 'pos_tags', 'relations', 'schema_linking', 'used_tables', 'used_columns', 'ast', 'actions', 'graph'])

In [53]:
# WikiSQL 'sql' is a flat dict (sel / conds / agg), unlike Spider's nested 'sql' structure.
wikisql_dev_lgesql_dataset[0]['sql']

{'sel': 3, 'conds': [[5, 0, 'Butler CC (KS)']], 'agg': 0}

In [45]:
# Full first WikiSQL example. In the saved output, '(' / ')' in the question appear as
# '-lrb-' / '-rrb-' among the rat_sql_graph nodes.
wikisql_dev_lgesql_dataset[0]

{'phase': 1,
 'table_id': '1-10015132-11',
 'question': 'What position does the player who played for butler cc (ks) play?',
 'sql': {'sel': 3, 'conds': [[5, 0, 'Butler CC (KS)']], 'agg': 0},
 'rat_sql_graph': {'nodes': ['what',
   'position',
   'do',
   'the',
   'player',
   'who',
   'play',
   'for',
   'butler',
   'cc',
   '-lrb-',
   'k',
   '-rrb-',
   'play',
   '?',
   '<C>NONE::*',
   '<C>l::player',
   '<C>l::no',
   '<C>l::nationality',
   '<C>l::position',
   '<C>l::year_in_toronto',
   '<C>l::school_or_club_team',
   '<T>l'],
  'q_nodes_orig': ['what',
   'position',
   'does',
   'the',
   'player',
   'who',
   'played',
   'for',
   'butler',
   'cc',
   '(',
   'ks',
   ')',
   'play',
   '?'],
  'relations': '[[2, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 6], [1, 2, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 37, 5, 5, 6], [0, 1, 2, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 6], [0, 0, 1, 2, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,

In [46]:
# Local edges of the first WikiSQL example as (source node, target node, relation) triples.
wikisql_dev_lgesql_dataset[0]['graph'].local_edges

[(0, 1, 'question-question-dist1'),
 (0, 15, 'question-table-nomatch'),
 (0, 16, 'question-column-nomatch'),
 (0, 17, 'question-column-nomatch'),
 (0, 18, 'question-column-nomatch'),
 (0, 19, 'question-column-nomatch'),
 (0, 20, 'question-column-nomatch'),
 (0, 21, 'question-column-nomatch'),
 (0, 22, 'question-column-nomatch'),
 (1, 0, 'question-question-dist-1'),
 (1, 2, 'question-question-dist1'),
 (1, 15, 'question-table-nomatch'),
 (1, 16, 'question-column-nomatch'),
 (1, 17, 'question-column-nomatch'),
 (1, 18, 'question-column-nomatch'),
 (1, 19, 'question-column-nomatch'),
 (1, 20, 'question-column-exactmatch'),
 (1, 21, 'question-column-nomatch'),
 (1, 22, 'question-column-nomatch'),
 (2, 1, 'question-question-dist-1'),
 (2, 3, 'question-question-dist1'),
 (2, 15, 'question-table-nomatch'),
 (2, 16, 'question-column-nomatch'),
 (2, 17, 'question-column-nomatch'),
 (2, 18, 'question-column-nomatch'),
 (2, 19, 'question-column-nomatch'),
 (2, 20, 'question-column-nomatch'),
 (2,

## Scratch experiments

In [12]:
# Sample question containing Unicode curly quotes (U+2019 and U+2018).
q = "What is the checking balance of the account whose owner\u2019s name contains the substring \u2018ee\u2019?"
q

'What is the checking balance of the account whose owner’s name contains the substring ‘ee’?'

In [13]:
# print() shows the rendered text rather than the repr.
print('Question:', q)

Question: What is the checking balance of the account whose owner’s name contains the substring ‘ee’?


In [17]:
# Whitespace-style tokens of `q` (curly quotes kept), passed through the project's
# quote_normalization. In the saved output, the '‘ee’' token is split into
# '"', 'ee', '"', while the apostrophe inside 'owner’s' is left unchanged.
q_toks = [
    "What",
    "is",
    "the",
    "checking",
    "balance",
    "of",
    "the",
    "account",
    "whose",
    "owner\u2019s",
    "name",
    "contains",
    "the",
    "substring",
    "\u2018ee\u2019",
    "?"]
quote_normalization(q_toks)

['What',
 'is',
 'the',
 'checking',
 'balance',
 'of',
 'the',
 'account',
 'whose',
 'owner’s',
 'name',
 'contains',
 'the',
 'substring',
 '"',
 'ee',
 '"',
 '?']

In [7]:
# `q` contains curly quotes (U+2018/U+2019), which have no ASCII encoding.
# The expected error is caught and shown so "Restart & Run All" is not halted here.
try:
    q.encode('ascii')
    q_ascii_error = None
except UnicodeEncodeError as err:
    # Keep a reference: `err` is unbound once the except block ends.
    q_ascii_error = err
    print('UnicodeEncodeError:', err)

UnicodeEncodeError: 'ascii' codec can't encode character '\u2018' in position 65: ordinal not in range(128)

In [11]:
# 'ç' (U+00E7) is outside ASCII too, so encoding fails at position 0.
# The expected error is caught and shown so "Restart & Run All" is not halted here.
s = 'ç'
try:
    s.encode('ascii')
    s_ascii_error = None
except UnicodeEncodeError as err:
    # Keep a reference: `err` is unbound once the except block ends.
    s_ascii_error = err
    print('UnicodeEncodeError:', err)

UnicodeEncodeError: 'ascii' codec can't encode character '\xe7' in position 0: ordinal not in range(128)

### Checking `torch.masked_scatter`

In [58]:
# Toy inputs for checking torch.masked_scatter:
#   A    - (10, 5) all-zero int64 target,
#   mask - (10, 1) boolean column selecting rows 2, 5, 6 and 9,
#   src  - (5, 5) int64 identity matrix used as the source values.
# torch.tensor(..., dtype=torch.bool) replaces the legacy torch.BoolTensor constructor.
A = torch.zeros((10, 5), dtype=torch.long)
mask = torch.tensor([0, 0, 1, 0, 0, 1, 1, 0, 0, 1], dtype=torch.bool).unsqueeze(-1)
src = torch.eye(5, dtype=torch.long)
A, mask, src

(tensor([[0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]]),
 tensor([[False],
         [False],
         [ True],
         [False],
         [False],
         [ True],
         [ True],
         [False],
         [False],
         [ True]]),
 tensor([[1, 0, 0, 0, 0],
         [0, 1, 0, 0, 0],
         [0, 0, 1, 0, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 0, 1]]))

In [59]:
# Out-of-place masked_scatter: rows of A selected by the (broadcast) mask are filled,
# in row-major order, with consecutive elements of src; A itself is left unchanged.
A.masked_scatter(mask, src)

tensor([[0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0]])