# Augmented Test to SQL Grammar Parser
Uses an augmented context free grammar (CFG) to parse natural language queries into SQL queries to search the Air Traffic Information Systems (ATIS) database

In [36]:
import os
import nltk
from cryptography.fernet import Fernet
import copy
import datetime
import math
import re
import sys
import warnings
import wget
import sqlite3
import torch
import torch.nn as nn
import torchtext.legacy as tt
from cryptography.fernet import Fernet
from tqdm import tqdm

In [37]:
# Set random seeds
seed = 1234
torch.manual_seed(seed)
# Set timeout for executing SQL
TIMEOUT = 3 # seconds

# GPU check: Set runtime type to use GPU where available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print (device)

cpu


In [38]:
## Download needed scripts and data
os.makedirs('data', exist_ok=True)
os.makedirs('scripts', exist_ok=True)
source_url = "https://raw.githubusercontent.com/nlp-course/data/master"

# Grammar to augment
if not os.path.isfile('data/grammar'):
  wget.download(f"{source_url}/ATIS/grammar_distrib4.crypt", out="data/")

  # Decrypt the grammar file
  key = b'bfksTY2BJ5VKKK9xZb1PDDLaGkdu7KCDFYfVePSEfGY='
  fernet = Fernet(key)
  with open('./data/grammar_distrib4.crypt', 'rb') as f:
    restored = Fernet(key).decrypt(f.read())
  with open('./data/grammar', 'wb') as f:
    f.write(restored)

# Download scripts and ATIS database
wget.download(f"{source_url}/scripts/trees/transform.py", out="scripts/")
wget.download(f"{source_url}/ATIS/atis_sqlite.db", out="data/")

100% [....................................................] 16404480 / 16404480

'data//atis_sqlite (1).db'

In [39]:
# Import downloaded scripts for parsing augmented grammars
sys.path.insert(1, './scripts')
import transform as xform

Making grammar specific convenience functions for augmentations

In [40]:
def constant(value):
  """Return `value`, ignoring any arguments"""
  return lambda *args: value

def first(*args):
  """Return the value of the first (and perhaps only) subconstituent, 
     ignoring any others"""  
  return args[0]

def numeric_template(rhs):
  """Ignore the subphrase meanings and lookup the first right-hand-side symbol 
     as a number"""
  return constant({'zero':0, 'one':1, 'two':2, 'three':3, 'four':4, 'five':5,
          'six':6, 'seven':7, 'eight':8, 'nine':9, 'ten':10}[rhs[0]])

def forward(F, A):
  """Forward application: Return the application of the first 
     argument to the second"""
  return F(A)

def backward(A, F):
  """Backward application: Return the application of the second 
     argument to the first"""
  return F(A)

def second(*args):
  """Return the value of the second subconstituent, ignoring any others"""
  return args[1]

def ignore(*args):
  """Return `None`, ignoring everything about the constituent. (Good as a
     placeholder until a better augmentation can be devised.)"""
  return None

def upper(term):
  return '"' + term.upper() + '"'

def weekday(day):
  return f"flight.flight_days IN (SELECT days.days_code FROM days WHERE days.day_name = '{day.upper()}')"

def month_name(month):
  return {'JANUARY' : 1,
          'FEBRUARY' : 2,
          'MARCH' : 3,
          'APRIL' : 4,
          'MAY' : 5,
          'JUNE' : 6,
          'JULY' : 7,
          'AUGUST' : 8,
          'SEPTEMBER' : 9,
          'OCTOBER' : 10,
          'NOVEMBER' : 11,
          'DECEMBER' : 12}[month.upper()]

def airports_from_airport_name(airport_name):
  return f"(SELECT airport.airport_code FROM airport WHERE airport.airport_name = {upper(airport_name)})"

def airports_from_city(city):
  return f"""
    (SELECT airport_service.airport_code FROM airport_service WHERE airport_service.city_code IN
      (SELECT city.city_code FROM city WHERE city.city_name = {upper(city)}))
  """

def null_condition(*args, **kwargs):
  return 1

def depart_around(time):
  return f"""
    flight.departure_time >= {add_delta(miltime(time), -15).strftime('%H%M')}
    AND flight.departure_time <= {add_delta(miltime(time), 15).strftime('%H%M')}
    """.strip()

def arrive_around(time):
  return f"""
    flight.arrival_time >= {add_delta(miltime(time), -15).strftime('%H%M')}
    AND flight.arrival_time <= {add_delta(miltime(time), 15).strftime('%H%M')}
    """.strip()

def arrive_before(time):
  return f"""
    flight.arrival_time < {add_delta(miltime(time), -15).strftime('%H%M')}
    """.strip()

def arrive_after(time):
  return f"""
    flight.arrival_time > {add_delta(miltime(time), -15).strftime('%H%M')}
    """.strip()

def arrive(time):
  return f"""
    flight.arrival_time = {add_delta(miltime(time), -15).strftime('%H%M')}
    """.strip()

def depart_before(time):
  return f"""
    flight.departure_time <= {add_delta(miltime(time), -15).strftime('%H%M')}
    """.strip()

def depart_after(time):
  return f"""
    flight.departure_time < {add_delta(miltime(time), -15).strftime('%H%M')}
    """.strip()

def depart(time):
  return f"""
    flight.departure_time = {add_delta(miltime(time), -15).strftime('%H%M')}
    """.strip()

def add_delta(tme, delta):
    # transform to a full datetime first
    return (datetime.datetime.combine(datetime.date.today(), tme) + 
            datetime.timedelta(minutes=delta)).time()

def miltime(minutes):
  return datetime.time(hour=int(minutes/100), minute=(minutes % 100))


def s_node(NP):
  return f'SELECT DISTINCT flight.flight_id FROM flight WHERE {NP}'

def to_airport(place):
  return f'flight.to_airport IN {place}'

def from_airport(place):
  return f'flight.from_airport IN {place}'

def between_airports(origin, destination):
  return f'flight.from_airport IN {origin} AND flight.to_airport IN {destination}'

def conjoin(A,B):
  return f'{B} AND {A}'

def conjoin_forward(A, B):
  return f'{A} AND {B}'

def airline_code(airline):
  return f"flight.airline_code = '{airline}'"

Load and preprocess the grammar

In [41]:
# Acquire the datasets - training, development, and test splits of the 
# ATIS queries and corresponding SQL queries
wget.download(f"{source_url}/ATIS/test_flightid.nl", out="data/")
wget.download(f"{source_url}/ATIS/test_flightid.sql", out="data/")
wget.download(f"{source_url}/ATIS/dev_flightid.nl", out="data/")
wget.download(f"{source_url}/ATIS/dev_flightid.sql", out="data/")
wget.download(f"{source_url}/ATIS/train_flightid.nl", out="data/")
wget.download(f"{source_url}/ATIS/train_flightid.sql", out="data/")

100% [......................................................] 2591248 / 2591248

'data//train_flightid (1).sql'

Use torchtext to process the data, with field SRC for the natural language questions and TGT for the SQL queries.

In [42]:
## Tokenizer
tokenizer = nltk.tokenize.RegexpTokenizer('\d+|st\.|[\w-]+|\$[\d\.]+|\S+')
def tokenize(string):
  return tokenizer.tokenize(string.lower())

In [43]:
SRC = tt.data.Field(include_lengths=True,         # include lengths
                    batch_first=False,            # batches will be max_len x batch_size
                    tokenize=tokenize,            # use our tokenizer
                   ) 
TGT = tt.data.Field(include_lengths=False,
                    batch_first=False,            # batches will be max_len x batch_size
                    tokenize=lambda x: x.split(), # use split to tokenize
                    init_token="<bos>",           # prepend <bos>
                    eos_token="<eos>")            # append <eos>
fields = [('src', SRC), ('tgt', TGT)]

In [44]:
 # Make splits for data
train_data, val_data, test_data = tt.datasets.TranslationDataset.splits(
    ('_flightid.nl', '_flightid.sql'), fields, path='./data/',
    train='train', validation='dev', test='test')

MIN_FREQ = 3
SRC.build_vocab(train_data.src, min_freq=MIN_FREQ)
TGT.build_vocab(train_data.tgt, min_freq=MIN_FREQ)

print (f"Size of English vocab: {len(SRC.vocab)}")
print (f"Most common English words: {SRC.vocab.freqs.most_common(10)}\n")

print (f"Size of SQL vocab: {len(TGT.vocab)}")
print (f"Most common SQL words: {TGT.vocab.freqs.most_common(10)}\n")

print (f"Index for start of sequence token: {TGT.vocab.stoi[TGT.init_token]}")
print (f"Index for end of sequence token: {TGT.vocab.stoi[TGT.eos_token]}")

Size of English vocab: 421
Most common English words: [('to', 3478), ('from', 3019), ('flights', 2094), ('the', 1550), ('on', 1230), ('me', 973), ('flight', 972), ('show', 845), ('what', 833), ('boston', 813)]

Size of SQL vocab: 392
Most common SQL words: [('=', 38876), ('AND', 36564), (',', 22772), ('airport_service', 8314), ('city', 8313), ('(', 6432), (')', 6432), ('flight_1.flight_id', 4536), ('flight', 4221), ('SELECT', 4178)]

Index for start of sequence token: 2
Index for end of sequence token: 3


Batch the data to facilitate processing on a GPU. sort_key function allows for sorting on length, to minimize the padding on the source side.

In [45]:
BATCH_SIZE = 16 # batch size for training/validation
TEST_BATCH_SIZE = 1 # batch size for test, we use 1 to make beam search implementation easier

train_iter, val_iter = tt.data.BucketIterator.splits((train_data, val_data),
                                                     batch_size=BATCH_SIZE, 
                                                     device=device,
                                                     repeat=False, 
                                                     sort_key=lambda x: len(x.src), 
                                                     sort_within_batch=True)
test_iter = tt.data.BucketIterator(test_data, 
                                   batch_size=TEST_BATCH_SIZE, 
                                   device=device,
                                   repeat=False, 
                                   sort=False, 
                                   train=False)

In [46]:
batch = next(iter(train_iter))
train_batch_text, train_batch_text_lengths = batch.src
train_batch_sql = batch.tgt

Set up a SQL database to test the parses correctly return the right database entries using sqlite3 module.

In [47]:
def execute_sql(sql):
  conn = sqlite3.connect('data/atis_sqlite.db')  # establish the DB based on the downloaded data
  c = conn.cursor()                              # build a "cursor"
  c.execute(sql)
  results = list(c.fetchall())
  c.close()
  conn.close()
  return results

To run query, use execute function and retrieve results with fetchall. We can also build a parser with the augmented grammar. To interpret the tree, we recursively add the meanings of the child nodes until we have a completed tree representation to a semantic representation.

In [48]:
def interpret(tree, augmentations):
  syntactic_rule = tree.productions()[0]
  semantic_rule = augmentations[syntactic_rule]
  child_meanings = [interpret(child, augmentations) 
                    for child in tree 
                    if isinstance(child, nltk.Tree)]
  return semantic_rule(*child_meanings)

An example of a parse tree for a query is shown below:

In [49]:
def parse_tree(sentence):
  """Parse a sentence and return the parse tree, or None if failure."""
  try:
    parses = list(atis_parser.parse(tokenize(sentence)))
    if len(parses) == 0:
      return None
    else:
      return parses[0]
  except:
    return None



sample_query = "flights to boston"
print(tokenize(sample_query))
sample_tree = parse_tree(sample_query)
sample_tree.pretty_print()

['flights', 'to', 'boston']
                S                         
                |                          
            NP_FLIGHT                     
                |                          
            NOM_FLIGHT                    
                |                          
             N_FLIGHT                     
      __________|_________                 
     |                    PP              
     |                    |                
     |                 PP_PLACE           
     |           _________|_________       
  N_FLIGHT      |                N_PLACE  
     |          |                   |      
TERM_FLIGHT  P_PLACE            TERM_PLACE
     |          |                   |      
  flights       to                boston  



Given a sentence, we first construct its parse tree using the syntactic rules, then compose the corresponding semantic rules bottom-up, until eventually we arrive at the root node with a finished SQL statement. 

In [51]:
atis_grammar, atis_augmentations = xform.read_augmented_grammar('data/grammar', globals=globals())
atis_parser = nltk.parse.BottomUpChartParser(atis_grammar)
predicted_sql = interpret(sample_tree, atis_augmentations)

# print out the predicted SQL from the grammar
print("Predicted SQL:\n\n", predicted_sql, "\n")

Predicted SQL:

 SELECT DISTINCT flight.flight_id FROM flight WHERE flight.to_airport IN 
    (SELECT airport_service.airport_code FROM airport_service WHERE airport_service.city_code IN
      (SELECT city.city_code FROM city WHERE city.city_name = "BOSTON"))
   AND 1 



We can create a function verify to compare the predicted SQL to the ground truth SQL from the database

In [52]:
def verify(predicted_sql, gold_sql, silent=True):
  """
  Compare the correctness of the generated SQL by executing on the 
  ATIS database and comparing the returned results.
  Arguments:
      predicted_sql: the predicted SQL query
      gold_sql: the reference SQL query to compare against
      silent: print outputs or not
  Returns: True if the returned results are the same, otherwise False
  """
  # Execute predicted SQL
  try:
    predicted_result = execute_sql(predicted_sql)
  except BaseException as e:
    if not silent:
      print(f"predicted sql exec failed: {e}")
    return False
  if not silent:
    print("Predicted DB result:\n\n", predicted_result[:10], "\n")

  # Execute gold SQL
  try:
    gold_result = execute_sql(gold_sql)
  except BaseException as e:
    if not silent:
      print(f"gold sql exec failed: {e}")
    return False
  if not silent:
    print("Gold DB result:\n\n", gold_result[:10], "\n")
  
  # Verify correctness
  if gold_result == predicted_result:
    return True

We can make a simple checking function to see how accurate our parses are.

In [53]:
def rule_based_trial(sentence, gold_sql):
  print("Sentence: ", sentence, "\n")
  tree = parse_tree(sentence)
  print("Parse:\n\n")
  tree.pretty_print()

  predicted_sql = interpret(tree, atis_augmentations)
  print("Predicted SQL:\n\n", predicted_sql, "\n")

  if verify(predicted_sql, gold_sql, silent=False):
    print ('Correct!')
  else:
    print ('Incorrect!')

In [54]:
example_1 = 'flights from phoenix to milwaukee'
gold_sql_1 = """
  SELECT DISTINCT flight_1.flight_id 
  FROM flight flight_1 , 
       airport_service airport_service_1 , 
       city city_1 , 
       airport_service airport_service_2 , 
       city city_2 
  WHERE flight_1.from_airport = airport_service_1.airport_code 
        AND airport_service_1.city_code = city_1.city_code 
        AND city_1.city_name = 'PHOENIX' 
        AND flight_1.to_airport = airport_service_2.airport_code 
        AND airport_service_2.city_code = city_2.city_code 
        AND city_2.city_name = 'MILWAUKEE'
  """

rule_based_trial(example_1, gold_sql_1)

Sentence:  flights from phoenix to milwaukee 

Parse:


                                  S                                 
                                  |                                  
                              NP_FLIGHT                             
                                  |                                  
                              NOM_FLIGHT                            
                                  |                                  
                               N_FLIGHT                             
                __________________|_________________                 
            N_FLIGHT                                |               
      _________|________                            |                
     |                  PP                          PP              
     |                  |                           |                
     |               PP_PLACE                    PP_PLACE           
     |          ________|_________       

In [55]:
# Example 2
example_2 = 'i would like a united flight'
gold_sql_2 = """
  SELECT DISTINCT flight_1.flight_id
  FROM flight flight_1 
  WHERE flight_1.airline_code = 'UA'
  """

rule_based_trial(example_2, gold_sql_2)

Sentence:  i would like a united flight 

Parse:


                                                 S                                                                      
                                     ____________|____________________________________________________                   
                                    |                                                             NP_FLIGHT             
                                    |                                                                 |                  
                                PREIGNORE                                                         NOM_FLIGHT            
        ____________________________|____________                                          ___________|___________       
       |                                     PREIGNORE                                   ADJ                      |     
       |                _________________________|____________                            |        

Rather than one off checks, we can more systematically check our translation from natural language to SQL query with a function that checks for the precision, recall, and F1 of the predictions. It takes as an argument a predictor function that maps token sequences to predicted SQL queries. 

The augmented parser is not augmented to capture all sentences, only about a precision of 70% is expected, since not all sentences are able to be predicted in the first place.

In [56]:
def evaluate(predictor, dataset, num_examples=0, silent=True):
  """Evaluate accuracy of `predictor` by executing predictions on a
  SQL database and comparing returned results against those of gold queries.
  
  Arguments:
      predictor:    a function that maps a token sequence (provided by torchtext)
                    to a predicted SQL query string
      dataset:      the dataset of token sequences and gold SQL queries
      num_examples: number of examples from `dataset` to use; all of
                    them if 0
      silent: if set to False, will print out logs
  Returns: precision, recall, and F1 score
  """
  # Prepare to count results
  if num_examples <= 0:
    num_examples = len(dataset)
  example_count = 0
  predicted_count = 0
  correct = 0
  incorrect = 0

  # Process the examples from the dataset
  for example in tqdm(dataset[:num_examples]):
    example_count += 1
    # obtain query SQL
    predicted_sql = predictor(example.src)
    if predicted_sql == None:
      continue
    predicted_count += 1
    # obtain gold SQL
    gold_sql = ' '.join(example.tgt)

    # check that they're compatible
    if verify(predicted_sql, gold_sql):
      correct += 1
    else:
      incorrect += 1
   
  # Compute and return precision, recall, F1
  precision = correct / predicted_count if predicted_count > 0 else 0
  recall = correct / example_count
  f1 = (2 * precision * recall) / (precision + recall) if precision + recall > 0 else 0
  return precision, recall, f1

In [57]:
def rule_based_predictor(tokens):
  query = ' '.join(tokens)    # detokenized query
  tree = parse_tree(query)
  if tree is None:
    return None
  try:
    predicted_sql = interpret(tree, atis_augmentations)
  except Exception as err:
    return None
  return predicted_sql

In [58]:
precision, recall, f1 = evaluate(rule_based_predictor, test_iter.dataset, num_examples=0)
print(f"precision: {precision:3.2f}")
print(f"recall:    {recall:3.2f}")
print(f"F1:        {f1:3.2f}")

100%|████████████████████████████████████████| 332/332 [00:02<00:00, 138.81it/s]

precision: 0.73
recall:    0.27
F1:        0.39



