Prepare the UMich text-to-SQL datasets, including:
- ATIS
- GEO
- Restaurants
- Scholar
- Academic
- IMDB
- Yelp
- Advising

Also create a new subset of the Spider dev set with column mentions removed.

In [1]:
import json
import pandas as pd
import os
import sqlite3
import re
import collections

In [2]:
"""
Utility functions for reading the standardised text2sql datasets presented in
`"Improving Text to SQL Evaluation Methodology" <https://arxiv.org/abs/1806.09029>`_
"""
# from allennlp
from typing import List, Dict, NamedTuple, Iterable, Tuple, Set
from collections import defaultdict

class SqlData(NamedTuple):
    """
    A utility class for reading in text2sql data: one (question, SQL) pair.

    Parameters
    ----------
    text : ``List[str]``
        The tokens in the text of the query, with every variable placeholder
        expanded into its concrete words (see ``replace_variables``).
    text_with_variables : ``List[str]``
        The tokens in the text of the query with variables
        mapped to table names/abstract variables (i.e. before expansion).
    variable_tags : ``List[str]``
        One label per token in ``text`` naming the variable placeholder that
        token was expanded from. "O" is used to denote no tag.
    sql : ``List[str]``
        The tokens in the SQL query which corresponds to the text.
    text_variables : ``Dict[str, str]``
        A dictionary of variables associated with the text, e.g. {"city_name0": "san francisco"}
    sql_variables : ``Dict[str, Dict[str, str]]``
        A dictionary of variables associated with the sql query. As built by
        ``process_sql_data``, each variable name maps to
        ``{"text": <example value>, "type": <variable type>}``.
    """

    text: List[str]
    text_with_variables: List[str]
    variable_tags: List[str]
    sql: List[str]
    text_variables: Dict[str, str]
    sql_variables: Dict[str, Dict[str, str]]


class TableColumn(NamedTuple):
    """
    One column of a table in a text2sql schema (as built by ``read_dataset_schema``).

    Parameters
    ----------
    name : ``str``
        The column name; ``read_dataset_schema`` stores it uppercased.
    column_type : ``str``
        The column's type string exactly as given in the schema file's "Type"
        field. ``column_has_string_type`` / ``column_has_numeric_type`` classify
        it by (case-sensitive) substring checks.
    is_primary_key : ``bool``
        Whether the schema file marks this column as a primary key ("y").
    """

    name: str
    column_type: str
    is_primary_key: bool


def column_has_string_type(column: TableColumn) -> bool:
    """Return True if ``column`` is declared with a textual SQL type.

    A type counts as textual when it is exactly ``text`` or ``longtext``, or
    when it contains ``varchar`` anywhere (e.g. ``varchar(255)``). The check
    is case-sensitive.
    """
    declared_type = column.column_type
    if declared_type in ("text", "longtext"):
        return True
    return "varchar" in declared_type


def column_has_numeric_type(column: TableColumn) -> bool:
    """Return True if ``column``'s type string mentions ``int``, ``float`` or ``double``.

    The test is a case-sensitive substring match, so e.g. ``int(11)`` and
    ``double precision`` both count as numeric.
    """
    numeric_markers = ("int", "float", "double")
    return any(marker in column.column_type for marker in numeric_markers)


def replace_variables(
    sentence: List[str], sentence_variables: Dict[str, str]
) -> Tuple[List[str], List[str]]:
    """
    Expand abstract variable placeholders in ``sentence`` into their concrete words.

    Tokens that are keys of ``sentence_variables`` are replaced by the
    whitespace-split words of their value; every other token is kept as is.

    Returns
    -------
    ``(tokens, tags)``, two parallel lists. Each expanded word is tagged with
    the placeholder it came from; untouched tokens are tagged "O".
    """
    expanded_tokens: List[str] = []
    token_tags: List[str] = []
    for word in sentence:
        if word in sentence_variables:
            # One placeholder may expand to several words, all tagged with it.
            replacement_words = sentence_variables[word].split()
            expanded_tokens.extend(replacement_words)
            token_tags.extend([word] * len(replacement_words))
        else:
            expanded_tokens.append(word)
            token_tags.append("O")
    return expanded_tokens, token_tags


def split_table_and_column_names(table: str) -> Iterable[str]:
    """
    Split a ``TABLE.COLUMN`` token into ``["TABLE", ".", "COLUMN"]``.

    Only the first ``.`` is split on; a token without a ``.`` is returned as a
    one-element list. Numeric literals such as ``1.5`` or ``-1.5`` are kept as
    a single token, and an empty string yields an empty list.
    """
    partitioned = [part for part in table.partition(".") if part != ""]
    if not partitioned:
        # Empty input: nothing to split (indexing below would fail).
        return []
    # Avoid splitting decimal strings, including signed ones such as "-1.5".
    if partitioned[0].lstrip("+-").isnumeric() and partitioned[-1].isnumeric():
        return [table]
    return partitioned


def clean_and_split_sql(sql: str) -> List[str]:
    """
    Tokenise and normalise a SQL query string.

    The query is split on whitespace. Double quotes are turned into single
    quotes. A trailing ``(`` glued to a word (e.g. ``COUNT(``) becomes its own
    token, and ``TABLE.COLUMN`` references are split by
    ``split_table_and_column_names``.
    """
    sql_tokens: List[str] = []
    for raw_token in sql.strip().split():
        # '%' is intentionally preserved (a '%'-stripping replace was disabled
        # here), so LIKE wildcards survive tokenisation.
        normalized = raw_token.replace('"', "'")
        pieces = [normalized]
        if len(normalized) > 1 and normalized.endswith("("):
            pieces = [normalized[:-1], normalized[-1]]
        for piece in pieces:
            sql_tokens.extend(split_table_and_column_names(piece))
    return sql_tokens


def resolve_primary_keys_in_schema(
    sql_tokens: List[str], schema: Dict[str, List[TableColumn]]
) -> List[str]:
    """
    Some examples in the text2sql datasets use ID as a column reference to the
    column of a table which has a primary key. This causes problems if you are trying
    to constrain a grammar to only produce the column names directly, because you don't
    know what ID refers to. So instead of dealing with that, we just replace it.

    An ``ID`` token is replaced when the token two positions earlier names a
    table in ``schema``. This is the ``TABLE . ID`` pattern that
    ``clean_and_split_sql`` produces, including when it starts at index 0.
    If no column of a table is flagged as primary key, ``max`` over the
    boolean flag picks the table's first column. Tables with no columns are
    ignored.

    Returns
    -------
    A new token list with resolved ``ID`` references.
    """
    primary_keys_for_tables = {
        name: max(columns, key=lambda x: x.is_primary_key).name
        for name, columns in schema.items()
        if columns
    }
    resolved_tokens = []
    for i, token in enumerate(sql_tokens):
        # i >= 2 (not > 2): a table name at index 0 also has an ID at index 2.
        if i >= 2 and token == "ID":
            table_name = sql_tokens[i - 2]
            if table_name in primary_keys_for_tables:
                token = primary_keys_for_tables[table_name]
        resolved_tokens.append(token)
    return resolved_tokens


def clean_unneeded_aliases(sql_tokens: List[str]) -> List[str]:
    """
    Remove table aliases that merely rename a table to itself.

    A pattern ``<TABLE> AS <TABLE>alias<N>`` (any number of digits) marks
    ``<TABLE>alias<N>`` as unneeded. The ``AS <alias>`` declaration is then
    dropped and every other occurrence of the alias is replaced by ``<TABLE>``.
    Aliases whose preceding token is not the aliased table's own name (e.g.
    ``( ... ) AS DERIVED_TABLEalias0``) are kept.

    Parameters
    ----------
    sql_tokens : ``List[str]``
        SQL tokens, e.g. as produced by ``clean_and_split_sql``.

    Returns
    -------
    A new list of tokens with unneeded aliases removed.
    """
    if not sql_tokens:
        return []

    unneeded_aliases = {}
    for previous_token, token, next_token in zip(sql_tokens, sql_tokens[1:], sql_tokens[2:]):
        if token != "AS":
            continue
        # Check whether the alias is just the table name plus an "alias<N>" suffix.
        alias_match = re.fullmatch(r"(.+)alias\d+", next_token)
        if alias_match is not None and alias_match.group(1) == previous_token:
            unneeded_aliases[next_token] = previous_token

    dealiased_tokens: List[str] = []
    for token in sql_tokens:
        replacement = unneeded_aliases.get(token)
        if replacement is None:
            dealiased_tokens.append(token)
        elif dealiased_tokens and dealiased_tokens[-1] == "AS":
            # This is the "<TABLE> AS <alias>" declaration: drop "AS" and the alias.
            dealiased_tokens.pop()
        else:
            dealiased_tokens.append(replacement)

    return dealiased_tokens


def read_dataset_schema(schema_path: str) -> Dict[str, List[TableColumn]]:
    """
    Reads a schema from the text2sql data, returning a dictionary
    mapping table names to their columns and respective types.
    This handles columns in an arbitrary order and also allows
    either ``{Table, Field}`` or ``{Table, Field} Name`` as headers,
    because both appear in the data. It also uppercases table and
    column names if they are not already uppercase.

    The first line is the comma-separated header. Lines starting with ``-``
    and blank lines are skipped. A column is a primary key when its
    "Primary Key" (or "Is Primary Key") field equals ``y``.

    Parameters
    ----------
    schema_path : ``str``, required.
        The path to the csv schema.
    Returns
    -------
    A dictionary mapping table names to typed columns.
    """
    schema: Dict[str, List[TableColumn]] = defaultdict(list)
    header: List[str] = []
    # ``with`` guarantees the file is closed even if a row fails to parse.
    with open(schema_path, "r") as schema_file:
        for i, line in enumerate(schema_file):
            if i == 0:
                header = [x.strip() for x in line.split(",")]
                continue
            if not line.strip() or line.startswith("-"):
                # Blank lines would otherwise yield table=None and crash on .upper().
                continue

            data = dict(zip(header, (x.strip() for x in line.split(","))))
            table = data.get("Table Name", None) or data.get("Table")
            column = data.get("Field Name", None) or data.get("Field")
            is_primary_key = data.get("Primary Key", data.get("Is Primary Key")) == "y"
            schema[table.upper()].append(TableColumn(column.upper(), data["Type"], is_primary_key))

    return {**schema}


def process_sql_data(
    data: List[Dict],
    use_all_sql: bool = False,
    use_all_queries: bool = False,
    remove_unneeded_aliases: bool = False,
    schema: Dict[str, List[TableColumn]] = None,
) -> Iterable[SqlData]:
    """
    A utility function for reading in text2sql data. The blob is
    the result of loading the json from a file produced by the script
    ``scripts/reformat_text2sql_data.py``.
    Parameters
    ----------
    data : ``JsonDict``
        A list of examples. Each example has "sentences" (each with "text" and
        "variables"), "sql" (a list of equivalent queries) and "variables"
        (each with "name", "example" and "type").
    use_all_sql : ``bool``, optional (default = False)
        Whether to use all of the sql queries which have identical semantics,
        or whether to just use the first one.
    use_all_queries : ``bool``, (default = False)
        Whether or not to enforce query sentence uniqueness. If false,
        duplicated queries will occur in the dataset as separate instances,
        as for a given SQL query, not only are there multiple queries with
        the same template, but there are also duplicate queries.
        De-duplication is applied per sentence (within one example), before
        iterating over the SQL variants, so it combines correctly with
        ``use_all_sql``.
    remove_unneeded_aliases : ``bool``, (default = False)
        The text2sql data by default creates alias names for `all` tables,
        regardless of whether the table is derived or if it is identical to
        the original (e.g SELECT TABLEalias0.COLUMN FROM TABLE AS TABLEalias0).
        This is not necessary and makes the action sequence and grammar manipulation
        much harder in a grammar based decoder. Note that this does not
        remove aliases which are legitimately required, such as when a new
        table is formed by performing operations on the original table.
    schema : ``Dict[str, List[TableColumn]]``, optional, (default = None)
        A schema to resolve primary keys against. Converts 'ID' column names
        to their actual name with respect to the Primary Key for the table
        in the schema.
    Yields
    ------
    One ``SqlData`` per (sentence, SQL) pair selected by the options above.
    """
    for example in data:
        seen_sentences: Set[str] = set()
        for sent_info in example["sentences"]:
            # The text side depends only on the sentence, so it is computed and
            # de-duplicated once per sentence. Doing this inside the SQL loop
            # would make every SQL variant after the first look like a duplicate.
            text_with_variables = sent_info["text"].strip().split()
            text_vars = sent_info["variables"]
            query_tokens, tags = replace_variables(text_with_variables, text_vars)

            if not use_all_queries:
                key = " ".join(query_tokens)
                if key in seen_sentences:
                    continue
                seen_sentences.add(key)

            # Loop over the different sql statements with "equivalent" semantics
            for sql in example["sql"]:
                sql_tokens = clean_and_split_sql(sql)
                if remove_unneeded_aliases:
                    sql_tokens = clean_unneeded_aliases(sql_tokens)
                if schema is not None:
                    sql_tokens = resolve_primary_keys_in_schema(sql_tokens, schema)

                # Built fresh per instance so yielded instances never share this dict.
                sql_variables = {
                    variable["name"]: {"text": variable["example"], "type": variable["type"]}
                    for variable in example["variables"]
                }

                # Copy the per-sentence lists so each instance owns its own lists.
                yield SqlData(
                    text=list(query_tokens),
                    text_with_variables=list(text_with_variables),
                    variable_tags=list(tags),
                    sql=sql_tokens,
                    text_variables=text_vars,
                    sql_variables=sql_variables,
                )

                # Some questions might have multiple equivalent SQL statements.
                # By default, we just use the first one. TODO(Mark): Use the shortest?
                if not use_all_sql:
                    break

In [15]:
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Official evaluation script for natural language to SQL datasets.

Arguments:
    predictions_filepath (str): Path to a predictions file (in JSON format).
    output_filepath (str): Path to the file where the result of execution is
        saved.
    cache_filepath (str): Path to a JSON file containing a mapping from gold SQL
        queries to cached resulting tables. Should be run locally. All filepaths
        above should refer to the local filesystem.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os

import numpy as np
import sqlite3
import timeout_decorator
from tqdm import tqdm

# Maximum allowable timeout for executing predicted and gold queries.
# Passed as ``seconds`` to ``timeout_decorator.timeout`` on ``timeout_execute``.
TIMEOUT = 30

# Maximum number of candidates we should consider
# (``execute_prediction`` keeps only this many highest-scoring predictions).
MAX_CANDIDATE = 20


# These are substrings of exceptions from sqlite3 that indicate certain classes
# of schema and syntax errors.
# ``execute_predictions`` matches them against the lowercased exception message,
# checking the schema set before the syntax set.
SCHEMA_INCOHERENCE_STRINGS = {
        'no such table', 'no such column', 'ambiguous column name'
}
SYNTAX_INCORRECTNESS_STRINGS = {
        'bad syntax', 'unrecognized token', 'incomplete input',
        'misuse of aggregate', 'left and right', 'wrong number of arguments',
        'sub-select returns', '1st order by term does not match any column',
        'no such function', 'clause is required before',
        'incorrect number of bindings', 'datatype mismatch', 'syntax error'
}


def normalize_sql_str(string):
    """Canonicalise a SQL string so two queries can be compared textually.

    Lowercases the string and collapses runs of spaces, then trims it. It then
    removes the space just inside parentheses and before ';', turns double
    quotes into single quotes, and appends ';' if the string contains none.
    """
    normalized = string.lower()
    while '  ' in normalized:
        normalized = normalized.replace('  ', ' ')
    normalized = normalized.strip()
    for old, new in (('( ', '('), (' )', ')'), (' ;', ';'), ('"', '\'')):
        normalized = normalized.replace(old, new)
    return normalized if ';' in normalized else normalized + ';'


def string_acc(s1, s2):
    """Return True when two SQL queries are identical after ``normalize_sql_str``."""
    first, second = (normalize_sql_str(query) for query in (s1, s2))
    return first == second


def result_table_to_string(table):
    """Render a SQL result table for the evaluation log.

    At most the first five rows are shown, one tab-indented row per line; a
    trailing line reports how many further rows were omitted.
    """
    preview_lines = [str(row) for row in table[:5]]
    rendered = '\t' + '\n\t'.join(preview_lines) + '\n'
    hidden_rows = len(table) - 5
    if hidden_rows > 0:
        rendered += '... and %d more rows.\n' % hidden_rows
    return rendered


def _add_nocase_collation(query):
    """Return ``query`` with ' COLLATE NOCASE' inserted after every quoted literal.

    Both ' and " delimit literals. A doubled delimiter inside a literal (the
    SQL escape, e.g. 'O''Brien') stays part of the literal and does not close it.
    """
    pieces = []
    open_quote = ''
    index = 0
    length = len(query)
    while index < length:
        char = query[index]
        pieces.append(char)
        if open_quote:
            if char == open_quote:
                if index + 1 < length and query[index + 1] == open_quote:
                    # Escaped delimiter: keep both characters, stay inside the literal.
                    pieces.append(char)
                    index += 2
                    continue
                open_quote = ''
                pieces.append(' COLLATE NOCASE')
        elif char in ('"', '\''):
            open_quote = char
        index += 1
    return ''.join(pieces)


def try_executing_query(prediction, cursor, case_sensitive=True, verbose=False):
    """Attempts to execute a SQL query against a database given a cursor.

    Args:
        prediction: The SQL query string to execute.
        cursor: A sqlite3 cursor to execute the query with.
        case_sensitive: If False, ' COLLATE NOCASE' is appended after every
            quoted string literal so comparisons against literals ignore case.
        verbose: Whether to print the rewritten case-insensitive query.

    Returns:
        A ``(results, exception_str)`` tuple. ``results`` is the list of rows
        (each a list), or ``[]`` on failure. ``exception_str`` is None on
        success, 'timeout' if ``timeout_execute`` timed out, and otherwise the
        lowercased sqlite3 error message.
    """
    exception_str = None

    if not case_sensitive:
        prediction = _add_nocase_collation(prediction)
        if verbose:
            print('Executing case-insensitive query:')
            print(prediction)

    try:
        pred_results = timeout_execute(cursor, prediction)
    except timeout_decorator.timeout_decorator.TimeoutError:
        print('!time out!')
        pred_results = []
        exception_str = 'timeout'
    except (sqlite3.Warning, sqlite3.Error) as e:
        # sqlite3.Error is the base of DatabaseError, IntegrityError,
        # ProgrammingError, OperationalError and NotSupportedError.
        exception_str = str(e).lower()
        pred_results = []

    return pred_results, exception_str


@timeout_decorator.timeout(seconds=TIMEOUT, use_signals=False)
def timeout_execute(cursor, prediction):
    """Run ``prediction`` on ``cursor`` and return every result row as a list.

    The decorator bounds execution to ``TIMEOUT`` seconds. On expiry it
    presumably raises timeout_decorator's TimeoutError, which
    ``try_executing_query`` catches.
    """
    cursor.execute(prediction)
    return [list(row) for row in cursor.fetchall()]


def find_used_entities_in_string(query, columns, tables):
    """Heuristically find which schema columns and tables a SQL query mentions.

    '.', '(' and ')' are turned into spaces and the query is split on single
    spaces. Each lowercased piece is then looked up in ``columns`` and
    ``tables``, so both should hold lowercase names.

    NOTE(review): other punctuation is not a separator, so e.g. ``name,``
    does not match ``name``. This is kept as is to preserve the metric.

    Returns:
        A ``(used_columns, used_tables)`` pair of sets of lowercase names.
    """
    separated = query
    for separator in ('.', '(', ')'):
        separated = separated.replace(separator, ' ')

    found_columns = set()
    found_tables = set()
    for lowered in (piece.lower() for piece in separated.split(' ')):
        if lowered in columns:
            found_columns.add(lowered)
        if lowered in tables:
            found_tables.add(lowered)
    return found_columns, found_tables


def compute_f1(precision, recall):
    """Return the harmonic mean of precision and recall, or 0. if their sum is not positive."""
    denominator = precision + recall
    if not denominator > 0.:
        return 0.
    return 2 * precision * recall / denominator


def compute_set_f1(pred_set, gold_set):
    """Compute the F1 between a predicted and a gold set of items.

    An empty predicted set counts as precision 1., and an empty gold set as
    recall 1. (so two empty sets score 1.).
    """
    precision = float(len(pred_set & gold_set)) / len(pred_set) if pred_set else 1.
    recall = float(len(pred_set & gold_set)) / len(gold_set) if gold_set else 1.
    return compute_f1(precision, recall)


def col_tab_f1(schema, gold_query, predicted_query):
    """Compute column and table F1 between the schema entities two queries mention.

    Args:
        schema: Mapping from table name to a list of column dicts; each column
            name is read from its 'field name' key.
        gold_query: The reference SQL string.
        predicted_query: The predicted SQL string.

    Returns:
        A ``(column_f1, table_f1)`` pair, each from ``compute_set_f1``.
    """
    # Collect lowercase schema entity names.
    db_columns, db_tables = set(), set()
    for table_name, table_columns in schema.items():
        db_columns.update(column['field name'].lower() for column in table_columns)
        db_tables.add(table_name.lower())

    # Heuristically find the entities each query uses.
    pred_columns, pred_tables = find_used_entities_in_string(
            predicted_query, db_columns, db_tables)
    gold_columns, gold_tables = find_used_entities_in_string(
            gold_query, db_columns, db_tables)

    column_f1 = compute_set_f1(pred_columns, gold_columns)
    table_f1 = compute_set_f1(pred_tables, gold_tables)
    return column_f1, table_f1


def execute_prediction(prediction, empty_table_cursor, cursor, case_sensitive,verbose):
    """Executes a single example's prediction(s).

    If more than one prediction is available, the most likely executable
    prediction is used as the "official" prediction.

    Candidates are taken from ``prediction['predictions']``, ranked by the
    parallel ``prediction['scores']`` (highest first), truncated to
    ``MAX_CANDIDATE``. Each is executed first on the empty database. The first
    candidate that raises no error there is executed on the real database and
    returned. If every candidate fails on the empty database, the first
    (highest-scoring) one is returned together with its real-database result
    and error.

    Args:
        prediction: A dictionary containing information for a single example's
            prediction; must contain 'predictions' and 'scores'.
        empty_table_cursor: The cursor to a database containing no records, to be
            used only to determine whether a query is executable in the database.
        cursor: The sqlite3 database cursor to execute queries on.
        case_sensitive: Boolean indicating whether the execution should be case
            sensitive with respect to string values.
        verbose: Whether to print details about what queries are being executed.

    Returns:
        Tuple containing the highest-ranked executable query, the resulting table,
        and any exception string associated with executing this query. With no
        candidates at all this is ``(None, None, None)``.
    """

    # Go through predictions in order of probability and test their executability
    # until you get an executable prediction. If you don't find one, just
    # "predict" the most probable one.
    paired_preds_and_scores = zip(prediction['predictions'], prediction['scores'])
    sorted_by_scores = sorted(
            paired_preds_and_scores, key=lambda x: x[1], reverse=True)

    best_prediction = None
    pred_results = None
    exception_str = None

    if len(sorted_by_scores) > MAX_CANDIDATE:
        sorted_by_scores = sorted_by_scores[:MAX_CANDIDATE]

    for i, (pred, _) in enumerate(sorted_by_scores):
        # Try predicting
        if verbose:
            print('Trying to execute query:\n\t' + pred)
            print('... on empty database')
        temp_exception_str = try_executing_query(pred, empty_table_cursor,case_sensitive, verbose)[1]

        if temp_exception_str:
            # This candidate failed even on the empty database.
            if i == 0:
                # By default, set the prediction to the first (highest-scoring)
                # one.
                best_prediction = pred

                # Get the actual results
                if verbose:
                    print('... on actual database')
                pred_results, exception_str = try_executing_query(
                        pred, cursor, case_sensitive, verbose)
            # NOTE(review): this tests ``exception_str``, which is assigned only
            # in the i == 0 branch above, not ``temp_exception_str`` of the
            # current candidate. It can therefore only be true at i == 0, and
            # then the first candidate is executed on the real database a
            # second time before breaking. A timeout of a later candidate on
            # the empty database never reaches this branch. ``temp_exception_str``
            # may have been intended; confirm before changing, since this is
            # the official scoring logic.
            if exception_str == 'timeout':
                # Technically, this query didn't have a syntax problem, so
                # continue and set this as the best prediction.
                best_prediction = pred

                if verbose:
                    print('... on actual database')
                pred_results, exception_str = try_executing_query(
                        pred, cursor, case_sensitive, verbose)
                break
        else:
            # First candidate that runs cleanly on the empty database wins.
            best_prediction = pred
            exception_str = None

            if verbose:
                print('No exception... on actual database')
            # Any error on the real database is discarded here (only [0] is kept).
            pred_results = try_executing_query(pred, cursor, case_sensitive,verbose)[0]
            break

    return best_prediction, pred_results, exception_str


def _convert_to_unicode_string(value):
    if isinstance(value, int) or isinstance(value, float):
        return str(value).decode('utf-8', 'ignore')
    elif isinstance(value, unicode):
        return value
    elif isinstance(value, str):
        return value.decode('utf-8', 'ignore')
    else:
        return str(value).decode('utf-8', 'ignore')


def _connect_or_exit(path):
    """Open a sqlite3 connection to ``path`` that returns text as ``str``.

    On sqlite3.OperationalError the error and the path are printed and the
    process exits with status 1.
    """
    try:
        conn = sqlite3.connect(path)
    except sqlite3.OperationalError as e:
        print(e)
        print(path)
        raise SystemExit(1)
    conn.text_factory = str
    return conn


def _classify_execution_error(exception_str):
    """Return 'schema', 'syntax' or 'timeout' for a prediction's error message, or None.

    Schema substrings are checked before syntax substrings, both against the
    lowercased message. 'timeout' is matched against the message as given.
    """
    lowered = exception_str.lower()
    if any(substring in lowered for substring in SCHEMA_INCOHERENCE_STRINGS):
        return 'schema'
    if any(substring in lowered for substring in SYNTAX_INCORRECTNESS_STRINGS):
        return 'syntax'
    if 'timeout' in exception_str:
        return 'timeout'
    return None


def _rows_as_set(rows):
    """Collapse result rows into a set for order-insensitive comparison.

    A list row becomes one space-joined string of its items; any other row is
    added as is.
    """
    row_set = set()
    for row in rows:
        if isinstance(row, list):
            row_set.add(' '.join([_convert_to_unicode_string(item) for item in row]))
        else:
            row_set.add(row)
    return row_set


def _lookup_gold_results(gold_query, utterance, cache_dict, cursor,
                         case_sensitive, verbose, update_cache):
    """Return ``(gold_results, gold_failed)`` for one example.

    The gold query is looked up in ``cache_dict``, either directly or via an
    utterance -> gold-query entry. On a miss it is executed when
    ``update_cache`` is set; a successful result is stored in the cache. When
    ``cache_dict`` is None there is nothing to look up, so the gold query is
    always executed. ``gold_failed`` is True when that execution raised an
    error (the results are then ``[]``).

    Raises:
        ValueError: on a cache miss when a cache exists and ``update_cache``
            is False.
    """
    if cache_dict is not None:
        if gold_query in cache_dict:
            return cache_dict[gold_query], False
        if utterance in cache_dict:
            return cache_dict[cache_dict[utterance]], False
        if not update_cache:
            print(gold_query)
            print(utterance)
            raise ValueError('Cache miss!')

    if verbose:
        print('Trying to execute the gold query:\n\t' + gold_query)
    gold_results, gold_exception_str = try_executing_query(
            gold_query, cursor, case_sensitive, verbose)
    if gold_exception_str:
        return [], True
    if cache_dict is not None:
        cache_dict[gold_query] = gold_results
    return gold_results, False


def _percent(count, total):
    """Return ``count`` as a percentage of ``total`` (0. when ``total`` is 0)."""
    return count * 100. / total if total else 0.


def execute_predictions(predictions, cache_dict, ofile, case_sensitive, verbose, update_cache):
    """Executes predicted/gold queries and computes performance.

    Writes a per-example log followed by aggregate metrics to ofile.

    Args:
        predictions: A list of dictionaries defining the predictions made by a
            model. Each has 'database_path', 'empty_database_path',
            'utterance', 'gold', 'schema', plus the 'predictions'/'scores'
            consumed by ``execute_prediction``.
        cache_dict: A dictionary mapping gold queries (and possibly utterances
            to gold queries) to the resulting tables, or None for no cache. In
            that case gold queries are always executed.
        ofile: A text file object to be written to.
        case_sensitive: A Boolean indicating whether execution of queries should be
            case sensitive with respect to strings.
        verbose: Whether to print detailed information about evaluation (e.g., for
            debugging).
        update_cache: Whether to execute and cache gold queries missing from
            the cache.

    Raises:
        ValueError: on a gold-query cache miss when ``update_cache`` is False.
        SystemExit: if a database cannot be opened, or a prediction fails with
            an error that matches none of the known error classes.

    NOTE(review): the ordered (list) result comparison is used only when the
    gold query contains a lowercase 'order by'. Uppercase gold SQL falls back to
    the set comparison. This is kept as is to preserve the metric.
    """
    # Keeps tracks of metrics throughout all of the evaluation.
    exec_results_same = []
    string_same = []

    precision = []
    recall = []

    column_f1s = []
    table_f1s = []

    conversion_errors = 0
    schema_errors = 0
    syntax_errors = 0
    timeouts = 0
    gold_error = 0

    # Don't use TQDM if verbose: it might mess up the verbose messages
    predictions_iterator = (lambda x: x) if verbose else tqdm

    for i, prediction in enumerate(predictions_iterator(predictions)):
        # Attempt to connect to the databases for executing.
        conn = _connect_or_exit(prediction['database_path'])
        empty_conn = _connect_or_exit(prediction['empty_database_path'])
        try:
            empty_cursor = empty_conn.cursor()
            cursor = conn.cursor()

            ofile.write('Example #' + str(i) + '\n')
            # Python 3: keep everything as str (encoding to bytes would break
            # both the text-mode writes and the cache lookups below).
            printable_utterance = ''.join(prediction['utterance']).strip()
            ofile.write(printable_utterance + '\n')

            if verbose:
                print('Finding the highest-rated prediction for utterance:\n\t' +
                      printable_utterance)

            best_prediction, pred_results, exception_str = execute_prediction(
                    prediction, empty_cursor, cursor, case_sensitive, verbose)

            ofile.write('Predicted query:\n')
            if best_prediction:
                ofile.write('\t' + ''.join(best_prediction).strip() + '\n')
            else:
                ofile.write('ERROR: Cannot write prediction %r\n' % best_prediction)

            # If it didn't execute correctly, check why.
            if exception_str:
                ofile.write(exception_str + '\n')

                error_kind = _classify_execution_error(exception_str)
                if error_kind == 'schema':
                    schema_errors += 1
                elif error_kind == 'syntax':
                    syntax_errors += 1
                elif error_kind == 'timeout':
                    ofile.write('Execution (predicted) took too long.\n')
                    timeouts += 1
                else:
                    # If the error type hasn't been identified, exit and report it.
                    print(best_prediction)
                    print(exception_str)
                    raise SystemExit(1)

                # Predicted table should be empty for all of these cases.
                pred_results = []

            # Compare to gold and update metrics
            gold_query = prediction['gold']

            ofile.write('Gold query:\n')
            ofile.write('\t' + ''.join(gold_query).strip() + '\n')

            gold_results, gold_failed = _lookup_gold_results(
                    gold_query, printable_utterance, cache_dict, cursor,
                    case_sensitive, verbose, update_cache)
            if gold_failed:
                gold_error += 1

            if best_prediction:
                string_same.append(string_acc(gold_query, best_prediction))
                col_f1, tab_f1 = col_tab_f1(prediction['schema'], gold_query,
                                            best_prediction)
                if 'order by' in gold_query:
                    results_equivalent = pred_results == gold_results
                else:
                    results_equivalent = _rows_as_set(pred_results) == _rows_as_set(gold_results)
            else:
                string_same.append(0.)
                col_f1, tab_f1 = 0., 0.
                conversion_errors += 1
                # Only consider correct if the gold table was empty.
                results_equivalent = gold_results == list()

            column_f1s.append(col_f1)
            table_f1s.append(tab_f1)
            ofile.write('Column F1: %f\n' % col_f1)
            ofile.write('Table F1: %f\n' % tab_f1)

            exec_results_same.append(int(results_equivalent))
            ofile.write('Execution was correct? ' + str(results_equivalent) + '\n')

            # Add some debugging information about the tables, and compute the
            # precisions.
            if pred_results:
                if not results_equivalent:
                    ofile.write('Predicted table:\n')
                    ofile.write(result_table_to_string(pred_results))

                precision.append(int(results_equivalent))
            elif best_prediction is None or not results_equivalent:
                ofile.write('Predicted table was EMPTY!\n')

            if gold_results:
                ofile.write('Gold table:\n')
                ofile.write(result_table_to_string(gold_results))

                recall.append(int(results_equivalent))
            else:
                ofile.write('Gold table was EMPTY!\n')

            ofile.write('\n')
            ofile.flush()
        finally:
            # Close both connections even if evaluation of this example raised.
            conn.close()
            empty_conn.close()

    # Write the overall metrics to the file.
    num_predictions = len(predictions)
    num_nonempty_pred = len(precision)
    num_nonempty_gold = len(recall)

    precision = np.mean(np.array(precision))
    recall = np.mean(np.array(recall))

    execution_f1 = compute_f1(precision, recall)

    ofile.write('String accuracy: ' +
                '{0:.2f}'.format(100. * np.mean(np.array(string_same))) + '\n')
    ofile.write('Accuracy: ' +
                '{0:.2f}'.format(100. * np.mean(np.array(exec_results_same))) +
                '\n')
    ofile.write('Precision: ' + '{0:.2f}'.format(100. * precision) + ' ; ' +
                str(num_nonempty_pred) + ' nonempty predicted tables' + '\n')
    ofile.write('Recall: ' + '{0:.2f}'.format(100. * recall) + ' ; ' +
                str(num_nonempty_gold) + ' nonempty gold tables' + '\n')
    ofile.write('Execution F1: ' + '{0:.2f}'.format(100. * execution_f1) + '\n')
    ofile.write('Timeout: ' +
                '{0:.2f}'.format(_percent(timeouts, num_predictions)) + '\n')
    ofile.write('Gold did not execute: ' +
                '{0:.2f}'.format(_percent(gold_error, num_predictions)) + '\n')
    ofile.write('Average column F1: ' +
                '{0:.2f}'.format(100. * np.mean(np.array(column_f1s))) + '\n')
    ofile.write('Average table F1: ' +
                '{0:.2f}'.format(100. * np.mean(np.array(table_f1s))) + '\n')
    ofile.write('Schema errors: ' +
                '{0:.2f}'.format(_percent(schema_errors, num_predictions)) +
                '\n')
    ofile.write('Syntax errors:  ' +
                '{0:.2f}'.format(_percent(syntax_errors, num_predictions)) +
                '\n')
    ofile.write('Conversion errors: ' +
                '{0:.2f}'.format(_percent(conversion_errors, num_predictions)) +
                '\n')


def main(predictions_filepath, output_filepath, cache_filepath, verbose,
         update_cache):
    """Evaluate a predictions file and write the report to ``output_filepath``.

    Args:
        predictions_filepath: Path to a JSON list of prediction dicts (the
            format consumed by ``execute_predictions``).
        output_filepath: Where the per-example log and metrics are written.
        cache_filepath: JSON file mapping gold queries to result tables. It is
            loaded if present and rewritten afterwards, except for Spider.
        verbose: Whether to print detailed progress.
        update_cache: Whether gold queries missing from the cache are executed.

    The dataset is inferred from the predictions file name. Files whose
    lowercased basename contains 'spider' are evaluated without a cache, and
    those containing 'scholar' are evaluated case-insensitively.
    """
    # Load the predictions filepath.
    with open(predictions_filepath) as infile:
        predictions = json.load(infile)
    print('Loaded %d predictions.' % len(predictions))

    # Load or create the cache dictionary mapping from gold queries to resulting
    # tables. The cache is used for every dataset EXCEPT Spider.
    cache_dict = None

    print('cache path: ' + cache_filepath)

    basefilename = os.path.basename(predictions_filepath).lower()
    uses_cache = 'spider' not in basefilename

    if uses_cache:
        cache_dict = dict()
        if os.path.exists(cache_filepath):
            print('Loading cache from %s' % cache_filepath)
            with open(cache_filepath) as infile:
                cache_dict = json.load(infile)
            print('Loaded %d cached queries' % len(cache_dict))

    # Create the text file that results will be written to.
    with open(output_filepath, 'w') as ofile:
        execute_predictions(predictions, cache_dict, ofile,
                            'scholar' not in basefilename, verbose, update_cache)

    if uses_cache:
        try:
            cache_str = json.dumps(cache_dict)
        except (TypeError, ValueError, UnicodeDecodeError) as e:
            # Python 3 json.dumps reports unserialisable values with
            # TypeError/ValueError; UnicodeDecodeError is kept for compatibility.
            print('Could not save the cache dict. Exception:')
            print(e)
        else:
            with open(cache_filepath, 'w') as ofile:
                ofile.write(cache_str)

In [66]:
def filter_in_suhr(example, verbose=False):
    """Classify an executed example with the filters from the Suhr et al. paper.

    Args:
        example: dict with at least the keys ``'query_exception'``,
            ``'result'``, ``'question'`` and ``'query'``, as built by the
            processing loop of this notebook.
        verbose: if True, print every quoted/numeric value of the query that
            cannot be found in the question.

    Returns:
        ``example['query_exception']`` when it is truthy (the gold query
        failed). Otherwise returns the first of ``'not copiable'``,
        ``'empty result'`` and ``'too many selects'`` that applies, or
        ``'pass'`` if none applies.
    """
    def _is_identifier_char(c):
        # Characters that may appear inside a (possibly qualified) identifier.
        return c.isalnum() or c in '_.'

    # Examples whose gold query raised are reported by their exception text.
    if example['query_exception']:
        return example['query_exception']

    # Filter out examples with empty gold tables.
    empty = not example['result']

    gold_query = example['query'].lower()

    # Filter out examples with a result of [0] and that require a count.
    # NOTE(review): this compares against a flat ``[0]``; if result rows are
    # stored as sequences (e.g. ``[[0]]``) it never matches — confirm the
    # format returned by try_executing_query.
    if example['result'] == [0] and gold_query.startswith(
            ('select count', 'select distinct count')):
        empty = True

    # Filter out examples that require copying values that can't be copied.
    prev_value = ''
    last_quote = ''
    utterance = example['question'].lower()
    copiable = True
    in_equality = False
    equality_used = False
    numerical_value = ''
    handled_prefix = False
    too_many_selects = False

    for i, char in enumerate(gold_query):
        # Check that only a single column is selected at the top level by
        # counting commas before the first FROM. FROM must be a whole word so
        # that identifiers such as ``flight.from_airport`` in the SELECT list
        # do not end the prefix early.
        if (not handled_prefix and i >= 4 and gold_query[i - 4:i] == 'from'
                and not _is_identifier_char(char)
                and (i == 4 or not _is_identifier_char(gold_query[i - 5]))):
            handled_prefix = True
            if ',' in gold_query[:i]:
                too_many_selects = True
        # Quoted literal: accumulate it and check it against the question
        # when the closing quote is reached.
        if char == last_quote:
            last_quote = ''
            prev_value = prev_value.replace('%', '').strip()
            if prev_value not in utterance:
                if verbose:
                    print(prev_value)
                copiable = False
            prev_value = ''
        elif last_quote:
            prev_value += char
        elif char in {'"', '\''}:
            last_quote = char
        # Numeric literal right after a comparison operator.
        if char in {'=', '>', '<'}:
            in_equality = True
            equality_used = False
        elif in_equality:
            if char.isdigit() or char == '.':
                if numerical_value or (not prev_value and gold_query[i - 1] == ' '):
                    numerical_value += char
            if char == ' ' and numerical_value:
                in_equality = False
                if numerical_value not in utterance:
                    if verbose:
                        print(numerical_value)
                    copiable = False
                numerical_value = ''
            if char != ' ':
                equality_used = True
            if char == ' ' and not last_quote and equality_used:
                in_equality = False

    # A numeric literal that ends the query (no trailing space) is never
    # checked inside the loop; check it here.
    if numerical_value and numerical_value not in utterance:
        if verbose:
            print(numerical_value)
        copiable = False

    if not copiable:
        return 'not copiable'
    if empty:
        return 'empty result'
    if too_many_selects:
        return 'too many selects'
    return 'pass'

In [5]:
import sys
# Make the project checkout importable so that the Spider code under its
# third_party/ directory can be imported below.
sys.path.append('/home/t-xiaden/workspace/NL2CodeOverData')
from third_party.spider.preprocess.get_tables import dump_db_json_schema

In [6]:
from third_party.spider.process_sql import get_sql
from third_party.spider.preprocess.schema import Schema
%load_ext autoreload
%autoreload 2
def get_schemas_from_json(data):
    """Build lower-cased schema dicts from Spider-style ``tables.json`` entries.

    Args:
        data: iterable (iterated twice) of database entries, each with the keys
            ``'db_id'``, ``'column_names_original'`` (pairs of
            ``[table_index, column_name]``) and ``'table_names_original'``.

    Returns:
        ``(schemas, db_names, tables)`` where ``schemas[db_id]`` maps each
        lower-cased table name to its lower-cased column names, ``db_names``
        lists the ids in input order, and ``tables[db_id]`` keeps the original
        column and table name lists.
    """
    db_names = [entry['db_id'] for entry in data]

    schemas = {}
    tables = {}
    for entry in data:
        key = entry['db_id']
        cols_orig = entry['column_names_original']
        tabs_orig = entry['table_names_original']
        tables[key] = {
            'column_names_original': cols_orig,
            'table_names_original': tabs_orig,
        }
        # Columns are attached to their table by index; the "*" column
        # (table index -1) belongs to no table.
        schemas[key] = {
            str(tab_name.lower()): [
                str(col_name.lower())
                for owner, col_name in cols_orig
                if owner == tab_idx
            ]
            for tab_idx, tab_name in enumerate(tabs_orig)
        }

    return schemas, db_names, tables

In [7]:
# SQL vocabulary tables. These presumably mirror the constants of Spider's
# process_sql module — confirm before editing. Tuple order is kept as-is
# because callers may index into these tuples.
CLAUSE_KEYWORDS = (
    'select', 'from', 'where', 'group', 'order', 'limit',
    'intersect', 'union', 'except',
)
JOIN_KEYWORDS = ('join', 'on', 'as')

WHERE_OPS = (
    'not', 'between', '=', '>', '<', '>=', '<=', '!=',
    'in', 'like', 'is', 'exists',
)
UNIT_OPS = ('none', '-', '+', '*', '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = dict(sql='sql', table_unit='table_unit')

COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')

In [8]:
# Tokens after which a "(" is meaningful (aggregate/function call or operator
# argument list) and must be kept by try_clean_brackets.
legal_ops = {'max', 'min', 'count', 'sum', 'avg', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists'}


def try_clean_brackets(sql):
    """Remove parentheses that are not needed by the Spider parser.

    ``sql`` is split on whitespace, so parentheses must be separate tokens.
    A "(" is kept when the token before it is in ``legal_ops`` or the token
    after it is ``SELECT`` (a subquery); otherwise it and its matching ")"
    are dropped. Removing grouping parentheses can change semantics (e.g.
    ``(A AND B) OR C``); callers re-execute the query to detect that.

    Boundary handling: a "(" at the start or end of the query has no
    previous/next token (no wrap-around, no IndexError), and an unmatched
    ")" is left in place instead of raising.

    Args:
        sql: whitespace-tokenised SQL string.

    Returns:
        The cleaned query with tokens joined by single spaces.
    """
    tokens = sql.split()
    open_stack = []  # (index of "(", keep?) for parentheses not yet closed
    to_remove = set()
    for i, token in enumerate(tokens):
        if token == '(':
            prev_tok = tokens[i - 1].lower() if i > 0 else ''
            next_tok = tokens[i + 1].lower() if i + 1 < len(tokens) else ''
            keep = prev_tok in legal_ops or next_tok == 'select'
            open_stack.append((i, keep))
        elif token == ')':
            if not open_stack:
                continue  # unmatched ")": leave it untouched
            open_idx, keep = open_stack.pop()
            if not keep:
                to_remove.update((open_idx, i))
    return ' '.join(tok for j, tok in enumerate(tokens) if j not in to_remove)

In [12]:
# Question splits of each Michigan dataset to keep (matched against
# example['question-split']); datasets without a train/dev split use the
# ten cross-validation folds "0".."9".
datasets = {
    'atis': ['dev'],
    'geography': ['train', 'dev'],
    'restaurants': list(map(str, range(10))),
    'scholar': ['train', 'dev'],
    'imdb': list(map(str, range(10))),
    'yelp': list(map(str, range(10))),
    'advising': ['train', 'dev'],
    'academic': list(map(str, range(10))),
}
# Root of the text2sql-data checkout (JSON data, sqlite DBs, schema CSVs).
datadir = '/home/t-xiaden/workspace/text2sql-data/data'

In [80]:
# Standalone connection to the IMDB database; text values are decoded as
# latin-1.
conn = sqlite3.connect('/home/t-xiaden/workspace/text2sql-data/data/{}.db'.format('imdb'))
conn.text_factory = lambda raw: str(raw, 'latin1')
cursor = conn.cursor()

## main processing
i. The purpose of the filtering is to remove cases where the current model (trained on Spider) always fails because such structures/operations never appear in its training data, so that we get a fairer evaluation set.

ii. Remove parentheses. Re-execute the query and remove the example if the result changes.
    
    a. (A AND B) AND (C AND D) is ok, but (A AND B) OR (C AND D) will be removed
    
    b. DISTINCT()

iii. JOIN ON vs. join conditions in WHERE: some datasets use a different format to represent JOIN ON, so the Spider parser is changed to support it. Also rewrite all join types (INNER, LEFT, …) to plain JOIN, and remove the example if the execution result changes.

iv. Remove examples containing derived tables/columns, since Spider assumes that queries only select columns from the database.

v. Boolean result, such as SELECT COUNT(*)>0

vi. The original dataset is released with values anonymized. Values that have no reference in the question (in any form) are replaced by LIKE '%'; remove such conditions.

vii. Calculation in sql
"SELECT DISTINCT COURSEalias0.DEPARTMENT , COURSEalias0.NAME , COURSEalias0.NUMBER , SEMESTERalias0.SEMESTER FROM COURSE AS COURSEalias0 , COURSE_OFFERING AS COURSE_OFFERINGalias0 , SEMESTER AS SEMESTERalias0 WHERE COURSEalias0.COURSE_ID = COURSE_OFFERINGalias0.COURSE_ID AND COURSEalias0.DEPARTMENT LIKE '%' AND COURSEalias0.NUMBER BETWEEN 500 AND 500 + 100 AND SEMESTERalias0.SEMESTER IN ( 'FA' , 'WN' ) AND SEMESTERalias0.SEMESTER_ID = COURSE_OFFERINGalias0.SEMESTER AND SEMESTERalias0.YEAR = 2016 ;" 

viii. Multiple columns in COUNT
SELECT COUNT ( DISTINCT COURSEalias1.DEPARTMENT , COURSEalias0.NUMBER ) FROM COURSE

ix. COUNT ( 1 ) -> COUNT( * )

x. COL IN (VAL1, VAL2) -> COL = VAL1 OR COL = VAL2

xi. Filters from the Suhr et al. paper:
    
    a. Remove examples whose values cannot be copied from the question ("not copiable").
    
    b. Remove examples with empty results. Alternatively, keep them but report execution results on the non-empty subset.
    
    c. Remove queries that select too many columns. Alternatively, keep queries with multiple selected columns but report results with the oracle (gold) SELECT clause.

In [16]:
final_data = {}
for dataset_id in datasets:
    print(dataset_id)
    db_path = os.path.join(datadir, '%s.db' % dataset_id)

    # Spider-style schema used by get_sql.
    db = dump_db_json_schema(db_path, dataset_id)
    db = [db]
    schemas, db_names, tables = get_schemas_from_json(db)
    schema = schemas[dataset_id]
    table = tables[dataset_id]
    schema = Schema(schema, table)
    # text2sql-data CSV schema, used by resolve_primary_keys_in_schema.
    schema_1 = read_dataset_schema(os.path.join(datadir, '%s-schema.csv' % dataset_id))
    processed_data = []
    error_sqls = []

    # Run every query against an in-memory copy of the database. The source
    # connection is closed explicitly: using a sqlite3 connection as a context
    # manager only commits/rolls back, it does not close the connection.
    source = sqlite3.connect(db_path)
    try:
        conn = sqlite3.connect(':memory:')
        source.backup(conn)
    finally:
        source.close()
    conn.text_factory = lambda x: str(x, 'latin1')
    cursor = conn.cursor()

    with open(os.path.join(datadir, '{}.json'.format(dataset_id))) as f:
        data = json.load(f)

    # The UMichigan data is split by anonymized queries, where values are
    # anonymized but table/column names are not. However, our experiments are
    # performed on the original splits of the data.
    for query in tqdm(data):
        # Take the first SQL query only. From their Github documentation:
        # "Note - we only use the first query, but retain the variants for
        #  completeness"
        anonymized_sql = query['sql'][0]

        # It's also associated with a number of natural language examples, which
        # also contain anonymous tokens. Save the de-anonymized utterance and query.
        for example in query['sentences']:
            if example['question-split'] not in datasets[dataset_id]:
                continue

            nl = example['text']
            sql = anonymized_sql
            sql_tokens = clean_and_split_sql(sql)
            # Aliases are kept (clean_unneeded_aliases is disabled).
            sql_tokens = resolve_primary_keys_in_schema(sql_tokens, schema_1)
            sql = ' '.join(sql_tokens).replace(' . ', '.')

            # Go through the anonymized values and replace them in both the
            # natural language and the SQL.
            #
            # It's very important to sort these in descending order of name
            # length. If one is a substring of the other, it shouldn't be
            # replaced first lest it ruin the replacement of the superstring.
            for variable_name, value in sorted(
                    example['variables'].items(), key=lambda x: len(x[0]), reverse=True):
                if not value:
                    # TODO(alanesuhr) While the Michigan repo says to use a - here, the
                    # thing that works is using a % and replacing = with LIKE.
                    #
                    # It's possible that I should remove such clauses from the SQL, as
                    # long as they lead to the same table result. They don't align well
                    # to the natural language at least.
                    #
                    # See: https://github.com/jkkummerfeld/text2sql-data/tree/master/data
                    value = '%'

                nl = nl.replace(variable_name, value)
                sql = sql.replace(variable_name, value)

            # In the case that we replaced an empty anonymized value with %, make
            # it compilable by allowing equality with any string.
            sql = re.sub('= \'?%\'?', 'LIKE \'%\'', sql)

            # The normalisations below must stay identical to all_fix(), which
            # re-applies them to the dumped query.
            # Remove the wildcard matching conditions.
            new_sql = sql
            all_wildcards = re.findall(r'(AND|OR)? (\w+\.\w+ (=|LIKE) (\'|\")?%(\'|\")?) (AND|OR)?', new_sql)
            for before_and, wildcard, _, _, _, after_and in all_wildcards:
                if before_and:
                    wildcard = before_and + ' ' + wildcard
                elif after_and:
                    wildcard = wildcard + ' ' + after_and
                new_sql = new_sql.replace(wildcard, '')

            # col IN (val, +) -> col = val OR col = val
            # NOTE(review): the 4th group only holds the LAST repetition, so a
            # numeric list such as IN ( 1 , 2 ) collapses to "col = 2"; such
            # changes are only caught by the result comparison below.
            for in_clause, col, _, vals, _ in re.findall(r'((\w+\.\w+) IN (\(( (\'.+\'|[0-9\.]+) ,?)+\)))', new_sql):
                new_clause = []
                for val in vals.split(','):
                    val = val.strip()
                    new_clause.append(col + ' = ' + val)
                new_clause = ' OR '.join(new_clause)
                new_sql = new_sql.replace(in_clause, new_clause)

            # COUNT ( 1 ) -> COUNT(*)
            new_sql = new_sql.replace('COUNT ( 1 )', 'COUNT ( * )')

            new_sql = new_sql.replace('<>', '!=')

            new_sql = re.sub(r'(INNER JOIN)|(LEFT JOIN)|(OUTER JOIN)|(LEFT OUTER JOIN)', 'JOIN', new_sql, flags=re.I)
            new_sql = try_clean_brackets(new_sql)

            pred_results, exception_str = try_executing_query(sql, cursor, False, False)
            if new_sql != sql:
                new_pred_results, new_exception_str = try_executing_query(new_sql, cursor, False, False)
            else:
                new_pred_results = pred_results
                new_exception_str = exception_str
            try:
                # Keep only queries whose result is unchanged by the
                # transformation. A failing original or transformed query
                # counts as a change. An explicit raise (not assert) keeps
                # this working under ``python -O``.
                if new_sql != sql and (exception_str or new_exception_str
                                       or pred_results != new_pred_results):
                    raise AssertionError('result change')
                processed_data.append({
                    "db_id": dataset_id,
                    "query": sql,
                    "query_toks": sql.split(),
                    "question": nl,
                    "question_toks": nl.split(),
                    "sql": get_sql(schema, new_sql),
                    "result": pred_results,
                    "query_exception": exception_str})

            except Exception as e:
                # Result changes and parser errors are both recorded here.
                error_sqls.append({
                    "db_id": dataset_id,
                    "query": sql,
                    "query_toks": sql.split(),
                    "question": nl,
                    "question_toks": nl.split(),
                    "sql": {},
                    "error": str(e),
                    "result": pred_results,
                    "query_exception": exception_str})
    final_data[dataset_id] = [processed_data, error_sqls]
    print(len(processed_data) + len(error_sqls), len(processed_data), len(error_sqls))
    conn.close()

atis


 15%|█▍        | 139/947 [02:51<07:57,  1.69it/s] 

!time out!


 21%|██▏       | 202/947 [04:13<30:06,  2.42s/it]  

!time out!


 29%|██▉       | 274/947 [04:54<17:38,  1.57s/it]

!time out!


 29%|██▉       | 277/947 [05:54<1:19:51,  7.15s/it]

!time out!


 48%|████▊     | 456/947 [07:06<09:35,  1.17s/it]  

!time out!


 68%|██████▊   | 648/947 [08:02<10:21,  2.08s/it]

!time out!


 87%|████████▋ | 828/947 [08:59<10:00,  5.04s/it]

!time out!


 88%|████████▊ | 836/947 [10:15<20:52, 11.29s/it]

!time out!


 89%|████████▊ | 840/947 [10:47<21:35, 12.11s/it]

!time out!


 89%|████████▉ | 843/947 [10:48<07:48,  4.51s/it]

!time out!


 89%|████████▉ | 844/947 [11:49<36:34, 21.30s/it]

!time out!


 89%|████████▉ | 847/947 [12:20<27:45, 16.65s/it]

!time out!


 90%|████████▉ | 849/947 [12:21<13:48,  8.45s/it]

!time out!


 90%|████████▉ | 850/947 [13:22<38:55, 24.08s/it]

!time out!


 90%|█████████ | 854/947 [13:56<24:21, 15.72s/it]

!time out!


 91%|█████████ | 864/947 [14:50<15:05, 10.91s/it]

!time out!


 92%|█████████▏| 868/947 [15:29<17:09, 13.03s/it]

!time out!


 92%|█████████▏| 869/947 [15:29<12:03,  9.28s/it]

!time out!


 92%|█████████▏| 870/947 [16:30<31:37, 24.64s/it]

!time out!


 92%|█████████▏| 872/947 [17:01<26:40, 21.34s/it]

!time out!


100%|██████████| 947/947 [17:04<00:00,  1.08s/it]
  0%|          | 0/246 [00:00<?, ?it/s]

486 474 12
geography


100%|██████████| 246/246 [02:36<00:00,  1.57it/s]
  0%|          | 0/23 [00:00<?, ?it/s]

598 584 14
restaurants


100%|██████████| 23/23 [01:37<00:00,  4.22s/it]


378 378 0
scholar


 10%|█         | 20/193 [01:29<18:28,  6.41s/it]

!time out!
!time out!
!time out!


 11%|█         | 21/193 [03:31<1:57:53, 41.12s/it]

!time out!


 20%|█▉        | 38/193 [06:17<49:17, 19.08s/it]  

!time out!
!time out!
!time out!
!time out!


 20%|██        | 39/193 [08:19<2:08:16, 49.98s/it]

!time out!


 24%|██▍       | 46/193 [08:39<14:10,  5.79s/it]  

!time out!
!time out!
!time out!
!time out!


 24%|██▍       | 47/193 [11:11<2:01:16, 49.84s/it]

!time out!


 45%|████▍     | 86/193 [17:38<04:15,  2.39s/it]  

!time out!


 45%|████▌     | 87/193 [18:39<35:18, 19.98s/it]

!time out!


 50%|████▉     | 96/193 [19:05<04:09,  2.57s/it]

!time out!
!time out!


 50%|█████     | 97/193 [20:36<46:51, 29.29s/it]

!time out!


 59%|█████▊    | 113/193 [21:16<02:26,  1.84s/it]

!time out!


 59%|█████▉    | 114/193 [22:18<25:50, 19.63s/it]

!time out!


 60%|█████▉    | 115/193 [22:18<18:01, 13.87s/it]

!time out!


 60%|██████    | 116/193 [23:19<35:57, 28.02s/it]

!time out!


 63%|██████▎   | 121/193 [24:00<18:58, 15.81s/it]

!time out!


 63%|██████▎   | 122/193 [24:31<23:50, 20.15s/it]

!time out!


 75%|███████▌  | 145/193 [25:18<07:33,  9.44s/it]

!time out!


 77%|███████▋  | 148/193 [26:25<13:13, 17.63s/it]

!time out!


 78%|███████▊  | 151/193 [27:42<15:33, 22.23s/it]

!time out!


 79%|███████▉  | 152/193 [27:42<10:41, 15.66s/it]

!time out!


 79%|███████▉  | 153/193 [28:43<19:25, 29.15s/it]

!time out!


 83%|████████▎ | 160/193 [29:18<06:31, 11.86s/it]

!time out!


100%|██████████| 193/193 [29:18<00:00,  9.11s/it]


599 581 18
imdb


100%|██████████| 89/89 [01:33<00:00,  1.05s/it]


131 128 3
yelp


 58%|█████▊    | 64/110 [01:09<00:28,  1.64it/s]

!time out!
!time out!


 59%|█████▉    | 65/110 [03:08<26:59, 35.99s/it]

!time out!


 60%|██████    | 66/110 [04:09<31:49, 43.40s/it]

!time out!


 68%|██████▊   | 75/110 [04:19<01:28,  2.52s/it]

!time out!


 69%|██████▉   | 76/110 [05:19<11:18, 19.96s/it]

!time out!


 70%|███████   | 77/110 [05:20<07:48, 14.19s/it]

!time out!


 71%|███████   | 78/110 [06:21<15:01, 28.16s/it]

!time out!


 72%|███████▏  | 79/110 [06:23<10:30, 20.33s/it]

!time out!


 73%|███████▎  | 80/110 [07:24<16:12, 32.43s/it]

!time out!


100%|██████████| 110/110 [09:15<00:00,  5.05s/it]


128 122 6
advising


  1%|          | 2/205 [00:23<46:12, 13.66s/it]

!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!


  1%|▏         | 3/205 [08:04<8:17:35, 147.80s/it]

!time out!


  7%|▋         | 15/205 [09:25<30:08,  9.52s/it]  

!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!


  8%|▊         | 16/205 [16:34<7:06:59, 135.55s/it]

!time out!


 11%|█         | 23/205 [17:35<1:00:13, 19.85s/it] 

!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!


 12%|█▏        | 24/205 [23:44<6:15:35, 124.50s/it]

!time out!


 26%|██▌       | 53/205 [28:43<31:11, 12.31s/it]   

!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!


 26%|██▋       | 54/205 [34:52<4:59:47, 119.12s/it]

!time out!


 41%|████▏     | 85/205 [38:31<12:39,  6.33s/it]   

!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!


 42%|████▏     | 86/205 [46:43<5:01:14, 151.89s/it]

!time out!


 48%|████▊     | 99/205 [50:07<19:21, 10.96s/it]   

!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!


 49%|████▉     | 100/205 [56:46<3:42:51, 127.35s/it]

!time out!


 54%|█████▎    | 110/205 [57:43<14:07,  8.92s/it]   

!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!


 54%|█████▍    | 111/205 [1:06:56<4:29:31, 172.03s/it]

!time out!


 60%|██████    | 124/205 [1:08:14<11:38,  8.63s/it]   

!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!


 61%|██████    | 125/205 [1:13:21<2:10:48, 98.11s/it]

!time out!


 71%|███████   | 146/205 [1:15:45<07:46,  7.90s/it]  

!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!


 72%|███████▏  | 147/205 [1:23:26<2:18:56, 143.73s/it]

!time out!


 93%|█████████▎| 190/205 [1:30:03<02:15,  9.01s/it]   

!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!


 93%|█████████▎| 191/205 [1:36:12<27:17, 116.97s/it]

!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!


 94%|█████████▎| 192/205 [1:41:19<37:42, 174.02s/it]

!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!


 94%|█████████▍| 193/205 [1:50:01<55:40, 278.39s/it]

!time out!


 95%|█████████▌| 195/205 [1:50:20<23:33, 141.39s/it]

!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!
!time out!


 96%|█████████▌| 196/205 [1:56:29<31:26, 209.57s/it]

!time out!


100%|██████████| 205/205 [2:05:40<00:00, 36.78s/it] 


2858 1948 910
academic


 10%|█         | 19/185 [00:56<26:47,  9.68s/it]

!time out!


 27%|██▋       | 50/185 [01:45<21:26,  9.53s/it]

!time out!


 52%|█████▏    | 97/185 [04:15<01:39,  1.13s/it]

!time out!


 53%|█████▎    | 98/185 [05:16<27:37, 19.05s/it]

!time out!


 56%|█████▌    | 104/185 [05:27<04:41,  3.48s/it]

!time out!


 57%|█████▋    | 105/185 [06:28<27:34, 20.68s/it]

!time out!


 70%|██████▉   | 129/185 [06:46<00:31,  1.77it/s]

!time out!


 70%|███████   | 130/185 [07:47<17:05, 18.64s/it]

!time out!


 75%|███████▍  | 138/185 [08:20<08:12, 10.47s/it]

!time out!


 77%|███████▋  | 142/185 [08:24<02:16,  3.18s/it]

!time out!


 77%|███████▋  | 143/185 [09:24<14:19, 20.46s/it]

!time out!


100%|██████████| 185/185 [12:37<00:00,  4.10s/it]


196 181 15


In [124]:
import pickle

# Persist the processed data (per dataset: kept examples and error records).
with open('/home/t-xiaden/workspace/featurestorage/data/spider-20200607/'
          'cached_processed_michigan_data.pkl', 'wb') as f:
    pickle.dump(final_data, f)

In [17]:
def all_fix(sql):
    """Apply the SQL normalisations of the main processing loop to ``sql``.

    This must stay identical to the transformation done inline in the
    processing loop. The dump cell writes ``all_fix(query)`` next to the
    ``sql`` structure that the loop parsed from its own transformed query.

    Steps, in order:
      1. drop ``col = '%'`` / ``col LIKE '%'`` wildcard conditions together
         with one adjacent AND/OR;
      2. rewrite ``col IN ( v1 , v2 )`` as ``col = v1 OR col = v2``;
      3. ``COUNT ( 1 )`` -> ``COUNT ( * )``;
      4. ``<>`` -> ``!=``;
      5. INNER / LEFT / OUTER / LEFT OUTER JOIN -> JOIN (case-insensitive);
      6. drop redundant parentheses with ``try_clean_brackets``.

    NOTE(review): in step 2 the 4th regex group only holds the *last*
    repetition of the value pattern, so a numeric list such as
    ``IN ( 1 , 2 )`` collapses to ``col = 2``. String lists work only because
    the greedy ``'.+'`` swallows the whole list. The rewrite also loses the
    implicit grouping (AND binds tighter than OR). It is kept as-is to match
    the loop.
    """
    fixed = sql

    # Step 1: all matches are found up-front, then removed one by one.
    wildcard_re = r'(AND|OR)? (\w+\.\w+ (=|LIKE) (\'|\")?%(\'|\")?) (AND|OR)?'
    for before, cond, _op, _q1, _q2, after in re.findall(wildcard_re, fixed):
        if before:
            cond = before + ' ' + cond
        elif after:
            cond = cond + ' ' + after
        fixed = fixed.replace(cond, '')

    # Step 2: col IN (val, +) -> col = val OR col = val
    in_re = r'((\w+\.\w+) IN (\(( (\'.+\'|[0-9\.]+) ,?)+\)))'
    for clause, column, _group, values, _value in re.findall(in_re, fixed):
        disjuncts = [column + ' = ' + piece.strip() for piece in values.split(',')]
        fixed = fixed.replace(clause, ' OR '.join(disjuncts))

    # Steps 3 and 4.
    fixed = fixed.replace('COUNT ( 1 )', 'COUNT ( * )').replace('<>', '!=')

    # Step 5.
    fixed = re.sub(r'(INNER JOIN)|(LEFT JOIN)|(OUTER JOIN)|(LEFT OUTER JOIN)', 'JOIN', fixed, flags=re.I)

    # Step 6.
    return try_clean_brackets(fixed)

In [122]:
# Per dataset: print example counts, tally the Suhr filter outcomes of kept
# examples and the error messages of dropped ones. Also collect the queries
# that failed with the "'as'" error for debugging.
errors_for_debug = []
for dataset_id, data in final_data.items():
    kept_examples, failed_examples = data[0], data[1]
    print(dataset_id)
    print(len(kept_examples) + len(failed_examples))
    n_pass = sum(filter_in_suhr(x) == 'pass' for x in kept_examples)
    print(len(kept_examples), n_pass)
    # Filter outcomes, skipping queries that contain " IS ".
    filter_errors = [filter_in_suhr(x) for x in kept_examples if ' IS ' not in x['query']]
    filter_result = collections.Counter(filter_errors)
    display(filter_result.most_common(100))
    parse_errors = collections.Counter(x['error'] for x in failed_examples)
    errors_for_debug.extend(x['query'] for x in failed_examples if x['error'] == "'as'")
    display(parse_errors.most_common(100))
    print('-' * 20)

atis
486
474 286


[('pass', 275),
 ('not copiable', 143),
 ('empty result', 39),
 ('too many selects', 5)]

[('result change', 12)]

--------------------
geography
598
584 525


[('pass', 525), ('not copiable', 43), ('empty result', 16)]

[('Error col: as', 10),
 ('Error col: /', 1),
 ("'as'", 1),
 ('Error col: all', 1),
 ('result change', 1)]

--------------------
restaurants
378
378 39


[('not copiable', 210),
 ('too many selects', 114),
 ('pass', 39),
 ('empty result', 15)]

[]

--------------------
scholar
599
581 396


[('pass', 396),
 ('too many selects', 118),
 ('empty result', 32),
 ('not copiable', 28),
 ('timeout', 4),
 ('no such function: curdate', 3)]

[('result change', 14),
 ('Error col: year0', 1),
 ('Error condition: idx: 24, tok: ==', 1),
 ('Error col: as', 1),
 ("'field'", 1)]

--------------------
imdb
131
128 111


[('pass', 111),
 ('empty result', 13),
 ('too many selects', 3),
 ('not copiable', 1)]

[('result change', 2), ('Error col: as', 1)]

--------------------
yelp
128
122 68


[('pass', 68), ('empty result', 52), ('too many selects', 2)]

[('result change', 5), ("'neighborhood.name'", 1)]

--------------------
advising
2858
1948 281


[('not copiable', 1062),
 ('empty result', 308),
 ('too many selects', 297),
 ('pass', 281)]

[('result change', 322),
 ('Error col: >', 288),
 ('Error col: =', 83),
 ('Error col: +', 77),
 ("'as'", 61),
 ('Error col: year0', 13),
 ('Error col: course_offeringalias0', 11),
 ('Error col: )', 11),
 ('Error col: select', 11),
 ('Error col: ;', 11),
 ('Error col: lower', 11),
 ('Error col: as', 9),
 ('Unexpected quote', 2)]

--------------------
academic
196
181 167


[('pass', 167), ('empty result', 6), ('too many selects', 6), ('timeout', 2)]

[("'as'", 9), ('result change', 6)]

--------------------


In [78]:
# "Not copiable" queries are usually caused by dataset conventions rather than
# by synonyms. Causes observed per dataset:
#   atis        - date transformations and default dates
#   geography   - conventions such as "major city" = population > 150000
#   restaurants - conventions such as "good restaurant" = rating > 2.5
#   scholar     - conventions such as "last year"
#   imdb        - only one such example
#   advising    - conventions such as "next semester" = Fall 2016
for _dataset in ('atis', 'geography', 'restaurants', 'scholar', 'imdb', 'advising'):
    display([[x['query'], x['question']]
             for x in final_data[_dataset][0]
             if filter_in_suhr(x) == 'not copiable'][:3])

[["SELECT DISTINCT FLIGHTalias0.FLIGHT_ID FROM AIRPORT_SERVICE AS AIRPORT_SERVICEalias0 , AIRPORT_SERVICE AS AIRPORT_SERVICEalias1 , CITY AS CITYalias0 , CITY AS CITYalias1 , DATE_DAY AS DATE_DAYalias0 , DAYS AS DAYSalias0 , FLIGHT AS FLIGHTalias0 WHERE ( CITYalias1.CITY_CODE = AIRPORT_SERVICEalias1.CITY_CODE AND CITYalias1.CITY_NAME = 'PHILADELPHIA' AND DATE_DAYalias0.DAY_NUMBER = 20 AND DATE_DAYalias0.MONTH_NUMBER = 1 AND DATE_DAYalias0.YEAR = 1991 AND DAYSalias0.DAY_NAME = DATE_DAYalias0.DAY_NAME AND FLIGHTalias0.FLIGHT_DAYS = DAYSalias0.DAYS_CODE AND FLIGHTalias0.TO_AIRPORT = AIRPORT_SERVICEalias1.AIRPORT_CODE ) AND CITYalias0.CITY_CODE = AIRPORT_SERVICEalias0.CITY_CODE AND CITYalias0.CITY_NAME = 'DENVER' AND FLIGHTalias0.FROM_AIRPORT = AIRPORT_SERVICEalias0.AIRPORT_CODE ;",
  'what flights are available tomorrow from DENVER to PHILADELPHIA'],
 ["SELECT DISTINCT FLIGHTalias0.FLIGHT_ID FROM AIRPORT_SERVICE AS AIRPORT_SERVICEalias0 , AIRPORT_SERVICE AS AIRPORT_SERVICEalias1 , CITY AS

[["SELECT LAKEalias0.LAKE_NAME FROM LAKE AS LAKEalias0 WHERE LAKEalias0.AREA > 750 AND LAKEalias0.STATE_NAME = 'michigan' ;",
  'name the major lakes in michigan'],
 ['SELECT HIGHLOWalias0.HIGHEST_POINT , HIGHLOWalias0.STATE_NAME FROM HIGHLOW AS HIGHLOWalias0 WHERE HIGHLOWalias0.LOWEST_ELEVATION = 0 ;',
  'what is the highest point in each state whose lowest point is sea level'],
 ['SELECT COUNT ( CITYalias0.CITY_NAME ) FROM CITY AS CITYalias0 WHERE CITYalias0.POPULATION > 150000 ;',
  'how many major cities are there']]

[["SELECT LOCATIONalias0.HOUSE_NUMBER , RESTAURANTalias0.NAME FROM LOCATION AS LOCATIONalias0 , RESTAURANT AS RESTAURANTalias0 WHERE LOCATIONalias0.CITY_NAME = 'bethel island' AND LOCATIONalias0.STREET_NAME = 'bethel island rd' AND RESTAURANTalias0.ID = LOCATIONalias0.RESTAURANT_ID AND RESTAURANTalias0.RATING > 2.5 ;",
  'what are some good restaurants on bethel island rd in bethel island ?'],
 ["SELECT LOCATIONalias0.HOUSE_NUMBER , RESTAURANTalias0.NAME FROM LOCATION AS LOCATIONalias0 , RESTAURANT AS RESTAURANTalias0 WHERE LOCATIONalias0.CITY_NAME = 'bethel island' AND LOCATIONalias0.STREET_NAME = 'bethel island rd' AND RESTAURANTalias0.ID = LOCATIONalias0.RESTAURANT_ID AND RESTAURANTalias0.RATING > 2.5 ;",
  'give me some good restaurants on bethel island rd in bethel island ?'],
 ["SELECT LOCATIONalias0.HOUSE_NUMBER , RESTAURANTalias0.NAME FROM LOCATION AS LOCATIONalias0 , RESTAURANT AS RESTAURANTalias0 WHERE LOCATIONalias0.CITY_NAME = 'bethel island' AND LOCATIONalias0.STREET_NAME 

[["SELECT DISTINCT PAPERalias0.PAPERID FROM DATASET AS DATASETalias0 , PAPER AS PAPERalias0 , PAPERDATASET AS PAPERDATASETalias0 , VENUE AS VENUEalias0 WHERE DATASETalias0.DATASETNAME = 'RGB-D Object Dataset' AND PAPERDATASETalias0.DATASETID = DATASETalias0.DATASETID AND PAPERalias0.PAPERID = PAPERDATASETalias0.PAPERID AND PAPERalias0.TITLE = 'Class consistent multi-modal fusion with binary features' AND PAPERalias0.YEAR = 2016 AND VENUEalias0.VENUEID = PAPERalias0.VENUEID AND VENUEalias0.VENUENAME = 'CVPR' ;",
  "What papers were published at CVPR '16 about Class consistent multi-modal fusion with binary features applied to RGB-D Object Dataset ?"],
 ["SELECT DISTINCT PAPERalias0.PAPERID FROM KEYPHRASE AS KEYPHRASEalias0 , PAPER AS PAPERalias0 , PAPERKEYPHRASE AS PAPERKEYPHRASEalias0 , VENUE AS VENUEalias0 WHERE KEYPHRASEalias0.KEYPHRASENAME = 'semantic spaces' AND PAPERKEYPHRASEalias0.KEYPHRASEID = KEYPHRASEalias0.KEYPHRASEID AND PAPERalias0.PAPERID = PAPERKEYPHRASEalias0.PAPERID AND

[["SELECT MOVIEalias0.TITLE FROM COMPANY AS COMPANYalias0 , COPYRIGHT AS COPYRIGHTalias0 , MOVIE AS MOVIEalias0 WHERE COMPANYalias0.NAME = 'company_name0' AND COPYRIGHTalias0.CID = COMPANYalias0.ID AND MOVIEalias0.MID = COPYRIGHTalias0.MSID AND MOVIEalias0.RELEASE_YEAR > 2010 ;",
  'Find all movies produced by " Walt Disney " after 2010']]

[["SELECT DISTINCT COURSEalias0.DEPARTMENT , COURSEalias0.NAME , COURSEalias0.NUMBER FROM COURSE AS COURSEalias0 , COURSE_OFFERING AS COURSE_OFFERINGalias0 , PROGRAM_COURSE AS PROGRAM_COURSEalias0 , SEMESTER AS SEMESTERalias0 WHERE COURSEalias0.COURSE_ID = COURSE_OFFERINGalias0.COURSE_ID AND PROGRAM_COURSEalias0.CATEGORY LIKE '%ULCS%' AND PROGRAM_COURSEalias0.COURSE_ID = COURSEalias0.COURSE_ID AND SEMESTERalias0.SEMESTER = 'FA' AND SEMESTERalias0.SEMESTER_ID = COURSE_OFFERINGalias0.SEMESTER AND SEMESTERalias0.YEAR = 2016 GROUP BY COURSEalias0.DEPARTMENT , COURSEalias0.NUMBER ;",
  'What classes next semester are available as ULCS ?'],
 ["SELECT DISTINCT COURSEalias0.DEPARTMENT , COURSEalias0.NAME , COURSEalias0.NUMBER FROM COURSE AS COURSEalias0 , COURSE_OFFERING AS COURSE_OFFERINGalias0 , PROGRAM_COURSE AS PROGRAM_COURSEalias0 , SEMESTER AS SEMESTERalias0 WHERE COURSEalias0.COURSE_ID = COURSE_OFFERINGalias0.COURSE_ID AND PROGRAM_COURSEalias0.CATEGORY LIKE '%ULCS%' AND PROGRAM_COURSEal

In [125]:
# Set to True to also write each database's Spider-style tables file. The
# schema dump is only computed in that case; the original cell computed it
# for every dataset even though the write was commented out.
dump_tables = False
for dataset_id, data in final_data.items():
    data_to_dump = []
    if dump_tables:
        db = [dump_db_json_schema(os.path.join(datadir, '%s.db' % dataset_id), dataset_id)]
        with open('/home/t-xiaden/workspace/featurestorage/data/spider-20190205/%s_tables.json' % dataset_id, 'w') as f:
            json.dump(db, f)
    for example in data[0]:
        # Keep examples that pass the Suhr filters or are only flagged for an
        # empty result / multiple selected columns.
        if filter_in_suhr(example) not in ('pass', 'empty result', 'too many selects'):
            continue
        fixed_sql = all_fix(example['query'])
        # Remove "IS NULL" / "IS NOT NULL" queries; only affects a few
        # examples in atis.
        if ' IS ' in fixed_sql:
            continue
        data_to_dump.append({
            "db_id": dataset_id,
            "query": fixed_sql,
            "query_toks": fixed_sql.split(),
            "question": example['question'],
            "question_toks": example['question_toks'],
            "sql": example['sql']})
    print(dataset_id, len(data_to_dump))
    with open('/home/t-xiaden/workspace/featurestorage/data/spider-20200607/%s_dev.json' % dataset_id, 'w') as f:
        json.dump(data_to_dump, f)

atis 319
geography 541
restaurants 168
scholar 546
imdb 127
yelp 122
advising 886
academic 179


# Create indexes where possible

In [93]:
# For every UMich SQLite database, create two indexes on each non-primary-key
# column of each table: a plain one and a COLLATE NOCASE one.
# Failures (e.g. sqlite_sequence cannot be indexed) are reported and skipped.
for dataset_id in datasets:
    print(dataset_id)
    # The connection context manager commits on success (it does not close).
    with sqlite3.connect('/home/t-xiaden/workspace/text2sql-data/data/%s.db'%dataset_id) as conn:
        # Decode TEXT values as latin-1 so undecodable bytes do not raise.
        conn.text_factory = lambda x:str(x, 'latin1')
        cursor = conn.cursor()
        meta = cursor.execute("SELECT * FROM sqlite_master").fetchall()
        tables = [x[2] for x in meta if x[0]=='table']
        # (table, first quoted identifier) for existing indexes. Some databases
        # quote identifiers with backticks, others with double quotes; when no
        # backtick identifier is found, findall(...)[0] raises IndexError.
        # NOTE(review): `indexes` is not used below.
        try:
            indexes = [(x[2],re.findall(r'`\w+`',x[4])[0].strip('`')) for x in meta if x[0]=='index' and x[4] is not None]
        except IndexError:
            indexes = [(x[2],re.findall(r'"\w+"',x[4])[0].strip('"')) for x in meta if x[0]=='index' and x[4] is not None]
        for table in tables:
            # Quote the table name so unusual identifiers still parse.
            columns = cursor.execute('pragma table_info("{}")'.format(table)).fetchall()
            # The last field of a table_info row is the pk flag; skip key columns.
            columns = [x[1] for x in columns if x[-1]==0]
            for x in columns:
                index_name = 'idx_{}_{}'.format(table,x)
                print(index_name)
                try:
                    cursor.execute("CREATE INDEX IF NOT EXISTS \"{}\" ON \"{}\" (`{}`)".format(index_name, table, x))
                    cursor.execute("CREATE INDEX IF NOT EXISTS \"{}\" ON \"{}\" (`{}` COLLATE NOCASE)".format(index_name+'_nocase', table, x))
                except sqlite3.Error as e:
                    # Best effort: report and continue with the next column.
                    print(str(e))

atis
idx_aircraft_aircraft_code
idx_aircraft_aircraft_description
idx_aircraft_manufacturer
idx_aircraft_basic_type
idx_aircraft_engines
idx_aircraft_propulsion
idx_aircraft_wide_body
idx_aircraft_wing_span
idx_aircraft_length
idx_aircraft_weight
idx_aircraft_capacity
idx_aircraft_pay_load
idx_aircraft_cruising_speed
idx_aircraft_range_miles
idx_aircraft_pressurized
idx_airline_airline_code
idx_airline_airline_name
idx_airline_note
idx_airport_airport_code
idx_airport_airport_name
idx_airport_airport_location
idx_airport_state_code
idx_airport_country_name
idx_airport_time_zone_code
idx_airport_minimum_connect_time
idx_airport_service_city_code
idx_airport_service_airport_code
idx_airport_service_miles_distant
idx_airport_service_direction
idx_airport_service_minutes_distant
idx_city_city_code
idx_city_city_name
idx_city_state_code
idx_city_country_name
idx_city_time_zone_code
idx_class_of_service_rank
idx_class_of_service_class_description
idx_code_description_description
idx_compartm

idx_OFFERING_INSTRUCTOR_OFFERING_ID
idx_OFFERING_INSTRUCTOR_INSTRUCTOR_ID
idx_PROGRAM_name
idx_PROGRAM_college
idx_PROGRAM_introduction
idx_PROGRAM_COURSE_workload
idx_PROGRAM_REQUIREMENT_min_credit
idx_PROGRAM_REQUIREMENT_additional_req
idx_SEMESTER_semester
idx_SEMESTER_year
idx_STUDENT_lastname
idx_STUDENT_firstname
idx_STUDENT_program_id
idx_STUDENT_declare_major
idx_STUDENT_total_credit
idx_STUDENT_total_gpa
idx_STUDENT_entered_as
idx_STUDENT_admit_term
idx_STUDENT_predicted_graduation_semester
idx_STUDENT_degree
idx_STUDENT_minor
idx_STUDENT_internship
idx_STUDENT_RECORD_semester
idx_STUDENT_RECORD_grade
idx_STUDENT_RECORD_how
idx_STUDENT_RECORD_transfer_source
idx_STUDENT_RECORD_repeat_term
idx_STUDENT_RECORD_test_id
idx_STUDENT_RECORD_offering_id
academic
idx_author_name
idx_author_oid
idx_author_homepage
idx_author_photo
idx_conference_name
idx_conference_full_name
idx_conference_homepage
idx_domain_name
idx_ids_exist
idx_journal_name
idx_journal_full_name
idx_journal_homepage

In [71]:
# Ad-hoc inspection: list the columns of table `paper` through whichever `cursor`
# is still bound from an earlier cell. Each row has the form
# (cid, name, type, notnull, dflt_value, pk).
cursor.execute("pragma table_info(%s)"%"paper").fetchall()

[(0, 'paperId', 'integer', 1, None, 1),
 (1, 'title', 'varchar(300)', 0, 'NULL', 0),
 (2, 'venueId', 'integer', 0, 'NULL', 0),
 (3, 'year', 'integer', 0, 'NULL', 0),
 (4, 'numCiting', 'integer', 0, 'NULL', 0),
 (5, 'numCitedBy', 'integer', 0, 'NULL', 0),
 (6, 'journalId', 'integer', 0, 'NULL', 0)]

# Rebuild foreign keys

In [119]:
# Dump the Spider-style schema of every UMich database. Each schema is wrapped
# in a one-element list (the layout of a Spider *_tables.json file), then
# start an empty registry of hand-curated foreign-key groups.
dbs = {}
for dataset_id in datasets:
    print(dataset_id)
    db_path = '/home/t-xiaden/workspace/text2sql-data/data/%s.db' % dataset_id
    db = [dump_db_json_schema(db_path, dataset_id)]
    dbs[dataset_id] = db
fk_key = {}

atis
geography
restaurants
scholar
imdb
yelp
advising
academic


In [174]:
# Group the atis columns by lower-cased name. Each entry is
# [is_primary_flag, column_index, table_name], with the flag initialised to 0,
# which is the same shape as the fk_key groups.
# Lower-casing matches the column_mapping cells of the other datasets.
db = dbs['atis'][0]
column_mapping = {}
for i, (tab_id, col_name) in enumerate(db['column_names_original']):
    column_mapping.setdefault(col_name.lower(), []).append(
        [0, i, db['table_names_original'][tab_id]])

In [145]:
# Hand-curated foreign-key groups for atis, keyed by column name.
# Each entry is [is_primary_flag, column_index, table_name]. In a group with a
# flag-1 entry, every other column is linked to that flagged column by the
# application cell that follows. Groups without a flag-1 entry add no keys.
fk_key['atis'] = {
    'aircraft_code': [[1, 1, 'aircraft'], [0, 55, 'equipment_sequence']],
    'airline_code': [[1, 16, 'airline'], [0, 81, 'flight']],
    'airport_code': [[1, 19, 'airport'],
                    [0, 27, 'airport_service'],
                    [0, 110, 'ground_service']],
    'state_code': [[0, 22, 'airport'], [0, 33, 'city'], [1, 123, 'state']],
    'country_name': [[0, 23, 'airport'], [0, 34, 'city'], [1, 125, 'state']],
    'time_zone_code': [[0, 24, 'airport'],
                    [0, 35, 'city'],
                    [1, 129, 'time_zone']],
    'city_code': [[0, 26, 'airport_service'],
                [1, 31, 'city'],
                [0, 109, 'ground_service']],
    'booking_class': [[1, 36, 'class_of_service'], [0, 66, 'fare_basis']],
    'code': [[0, 39, 'code_description']],
    'compartment': [[1, 41, 'compartment_class'], [0, 107, 'food_service']],
    'class_type': [[1, 42, 'compartment_class'], [0, 67, 'fare_basis']],
    'month_number': [[0, 43, 'date_day'], [1, 113, 'month']],
    'day_number': [[0, 44, 'date_day']],
    'year': [[0, 45, 'date_day']],
    'day_name': [[0, 46, 'date_day'], [1, 48, 'days']],
    'days_code': [[0, 47, 'days']],
    'main_airline': [[1, 16, 'airline'], [0, 49, 'dual_carrier']],
    'low_flight_number': [[1, 82, 'flight'], [0, 50, 'dual_carrier']],
    'high_flight_number': [[1, 82, 'flight'], [0, 51, 'dual_carrier']],
    'dual_airline': [[1, 16, 'airline'], [0, 52, 'dual_carrier']],
    'aircraft_code_sequence': [[1, 54, 'equipment_sequence'], [0, 83, 'flight']],
    'fare_id': [[1, 56, 'fare'], [0, 90, 'flight_fare']],
    'from_airport': [[1, 19, 'airport'], [0, 57, 'fare'], [0, 76, 'flight']],
    'to_airport': [[1, 19, 'airport'], [0, 58, 'fare'], [0, 77, 'flight']],
    'fare_basis_code': [[0, 59, 'fare'], [1, 65, 'fare_basis']],
    'fare_airline': [[1, 16, 'airline'], [0, 60, 'fare']],
    'restriction_code': [[0, 61, 'fare'], [1, 115, 'restriction']],
    'basis_days': [[1, 47, 'days'], [0, 73, 'fare_basis']],
    'flight_id': [[1, 74, 'flight'],
                    [0, 89, 'flight_fare'],
                    [0, 91, 'flight_leg'],
                    [0, 94, 'flight_stop']],
    'flight_days': [[1, 47, 'days'], [0, 75, 'flight']],
    'departure_time': [[0, 78, 'flight'], [0, 101, 'flight_stop']],
    'arrival_time': [[0, 79, 'flight'], [0, 98, 'flight_stop']],
    'airline_flight': [[0, 80, 'flight']],
    'flight_number': [[0, 82, 'flight']],
    'meal_code': [[0, 84, 'flight'], [1, 105, 'food_service']],
    'leg_flight': [[1, 74, 'flight'], [0, 93, 'flight_leg']],
    'stop_days': [[1, 47, 'days'], [0, 96, 'flight_stop']],
    'stop_airport': [[1, 19, 'airport'], [0, 97, 'flight_stop']],
    'arrival_airline': [[1, 16, 'airline'], [0, 99, 'flight_stop']],
    'arrival_flight_number': [[1, 82, 'flight'], [0, 100, 'flight_stop']],
    'departure_airline': [[1, 16, 'airline'], [0, 102, 'flight_stop']],
    'departure_flight_number': [[1, 82, 'flight'], [0, 103, 'flight_stop']],
    'period': [[0, 126, 'time_interval']],
    'begin_time': [[0, 127, 'time_interval']],
    'end_time': [[0, 128, 'time_interval']],
}

In [146]:
# Convert the curated fk_key['atis'] groups into Spider-style foreign-key pairs
# [child_column_index, referenced_column_index] and write the updated schema.
# Group entries are [is_primary_flag, column_index, table_name]. Every column
# other than the group's single flag-1 column is linked to that column.
target_db = dbs['atis'][0]
# Pairs already in the schema (including ones added by an earlier run of this
# cell) are skipped, so re-running does not append duplicates.
existing_fks = {tuple(fk) for fk in target_db['foreign_keys']}
for cols in fk_key['atis'].values():
    main_key = [x[1] for x in cols if x[0] == 1]
    assert len(main_key) <= 1
    if len(main_key) == 1:
        main_key = main_key[0]
        for _, col_id, _ in cols:
            if col_id != main_key and (col_id, main_key) not in existing_fks:
                target_db['foreign_keys'].append([col_id, main_key])
                existing_fks.add((col_id, main_key))
with open('/home/t-xiaden/workspace/featurestorage/data/spider-20190205/%s_tables.json'%'atis', 'w') as f:
    json.dump(dbs['atis'], f)

In [177]:
# Group the advising columns by lower-cased name. Each entry is
# [is_primary_flag, column_index, table_name] with the flag initialised to 0,
# which is the same shape as the fk_key groups.
db = dbs['advising'][0]
column_mapping = {}
for col_idx, (table_idx, raw_name) in enumerate(db['column_names_original']):
    column_mapping.setdefault(raw_name.lower(), []).append(
        [0, col_idx, db['table_names_original'][table_idx]])

In [181]:
# Hand-curated foreign-key groups for advising, keyed by lower-cased column name.
# Each entry is [is_primary_flag, column_index, table_name]. In a group with a
# flag-1 entry, every other column is linked to the flagged column.
fk_key['advising'] = {
 'course_id': [[0, 1, 'AREA'], [1, 7, 'COURSE'], [0, 26, 'COURSE_OFFERING'], [0, 44, 'COURSE_PREREQUISITE'], 
               [0, 45, 'COURSE_TAGS_COUNT'], [0, 79, 'PROGRAM_COURSE'], [0, 103, 'STUDENT_RECORD']],
 'area': [[0, 2, 'AREA']],
 'instructor_id': [[0, 3, 'COMMENT_INSTRUCTOR'], [1, 68, 'INSTRUCTOR'], [0, 73, 'OFFERING_INSTRUCTOR']],
 'student_id': [[0, 4, 'COMMENT_INSTRUCTOR'], [0, 67, 'GSI'], [1, 89, 'STUDENT'], [0, 102, 'STUDENT_RECORD']],
 'score': [[0, 5, 'COMMENT_INSTRUCTOR']],
 'comment_text': [[0, 6, 'COMMENT_INSTRUCTOR']],
 'name': [[0, 8, 'COURSE'], [0, 69, 'INSTRUCTOR'], [0, 75, 'PROGRAM']],
 'department': [[0, 9, 'COURSE']],
 'number': [[0, 10, 'COURSE']],
 'credits': [[0, 11, 'COURSE']],
 'advisory_requirement': [[0, 12, 'COURSE']],
 'enforced_requirement': [[0, 13, 'COURSE']],
 'description': [[0, 14, 'COURSE']],
 'num_semesters': [[0, 15, 'COURSE']],
 'num_enrolled': [[0, 16, 'COURSE']],
 'has_discussion': [[0, 17, 'COURSE']],
 'has_lab': [[0, 18, 'COURSE']],
 'has_projects': [[0, 19, 'COURSE']],
 'has_exams': [[0, 20, 'COURSE']],
 'num_reviews': [[0, 21, 'COURSE']],
 'clarity_score': [[0, 22, 'COURSE']],
 'easiness_score': [[0, 23, 'COURSE']],
 'helpfulness_score': [[0, 24, 'COURSE']],
 'offering_id': [[1, 25, 'COURSE_OFFERING'], [0, 72, 'OFFERING_INSTRUCTOR'], [0, 111, 'STUDENT_RECORD']],
 # SEMESTER.semester (col 87) is the term label, e.g. 'FA'. COURSE_OFFERING.semester
 # holds SEMESTER.semester_id values: the gold SQL joins
 # SEMESTER.SEMESTER_ID = COURSE_OFFERING.SEMESTER. Those references are therefore
 # grouped under 'semester_id' below.
 'semester': [[0, 87, 'SEMESTER']],
 'section_number': [[0, 28, 'COURSE_OFFERING']],
 'start_time': [[0, 29, 'COURSE_OFFERING']],
 'end_time': [[0, 30, 'COURSE_OFFERING']],
 'monday': [[0, 31, 'COURSE_OFFERING']],
 'tuesday': [[0, 32, 'COURSE_OFFERING']],
 'wednesday': [[0, 33, 'COURSE_OFFERING']],
 'thursday': [[0, 34, 'COURSE_OFFERING']],
 'friday': [[0, 35, 'COURSE_OFFERING']],
 'saturday': [[0, 36, 'COURSE_OFFERING']],
 'sunday': [[0, 37, 'COURSE_OFFERING']],
 'has_final_project': [[0, 38, 'COURSE_OFFERING']],
 'has_final_exam': [[0, 39, 'COURSE_OFFERING']],
 'textbook': [[0, 40, 'COURSE_OFFERING']],
 'class_address': [[0, 41, 'COURSE_OFFERING']],
 'allow_audit': [[0, 42, 'COURSE_OFFERING']],
 # NOTE(review): pre_course_id probably references COURSE.course_id (col 7); not linked.
 'pre_course_id': [[0, 43, 'COURSE_PREREQUISITE']],
 'clear_grading': [[0, 46, 'COURSE_TAGS_COUNT']],
 'pop_quiz': [[0, 47, 'COURSE_TAGS_COUNT']],
 'group_projects': [[0, 48, 'COURSE_TAGS_COUNT']],
 'inspirational': [[0, 49, 'COURSE_TAGS_COUNT']],
 'long_lectures': [[0, 50, 'COURSE_TAGS_COUNT']],
 'extra_credit': [[0, 51, 'COURSE_TAGS_COUNT']],
 'few_tests': [[0, 52, 'COURSE_TAGS_COUNT']],
 'good_feedback': [[0, 53, 'COURSE_TAGS_COUNT']],
 'tough_tests': [[0, 54, 'COURSE_TAGS_COUNT']],
 'heavy_papers': [[0, 55, 'COURSE_TAGS_COUNT']],
 'cares_for_students': [[0, 56, 'COURSE_TAGS_COUNT']],
 'heavy_assignments': [[0, 57, 'COURSE_TAGS_COUNT']],
 'respected': [[0, 58, 'COURSE_TAGS_COUNT']],
 'participation': [[0, 59, 'COURSE_TAGS_COUNT']],
 'heavy_reading': [[0, 60, 'COURSE_TAGS_COUNT']],
 'tough_grader': [[0, 61, 'COURSE_TAGS_COUNT']],
 'hilarious': [[0, 62, 'COURSE_TAGS_COUNT']],
 'would_take_again': [[0, 63, 'COURSE_TAGS_COUNT']],
 'good_lecture': [[0, 64, 'COURSE_TAGS_COUNT']],
 'no_skip': [[0, 65, 'COURSE_TAGS_COUNT']],
 # NOTE(review): probably references COURSE_OFFERING.offering_id (col 25); not linked.
 'course_offering_id': [[0, 66, 'GSI']],
 'uniqname': [[0, 70, 'INSTRUCTOR']],
 'offering_instructor_id': [[0, 71, 'OFFERING_INSTRUCTOR']],
 'program_id': [[1, 74, 'PROGRAM'], [0, 78, 'PROGRAM_COURSE'], [0, 82, 'PROGRAM_REQUIREMENT'], [0, 92, 'STUDENT']],
 'college': [[0, 76, 'PROGRAM']],
 'introduction': [[0, 77, 'PROGRAM']],
 'workload': [[0, 80, 'PROGRAM_COURSE']],
 'category': [[0, 81, 'PROGRAM_COURSE'], [0, 83, 'PROGRAM_REQUIREMENT']],
 'min_credit': [[0, 84, 'PROGRAM_REQUIREMENT']],
 'additional_req': [[0, 85, 'PROGRAM_REQUIREMENT']],
 # SEMESTER.semester_id is the referenced key for semester columns in other tables.
 # STUDENT_RECORD.semester is assumed to follow the COURSE_OFFERING convention;
 # TODO confirm against the advising schema.
 'semester_id': [[1, 86, 'SEMESTER'], [0, 27, 'COURSE_OFFERING'], [0, 104, 'STUDENT_RECORD']],
 'year': [[0, 88, 'SEMESTER']],
 'lastname': [[0, 90, 'STUDENT']],
 'firstname': [[0, 91, 'STUDENT']],
 'declare_major': [[0, 93, 'STUDENT']],
 'total_credit': [[0, 94, 'STUDENT']],
 'total_gpa': [[0, 95, 'STUDENT']],
 'entered_as': [[0, 96, 'STUDENT']],
 'admit_term': [[0, 97, 'STUDENT']],
 'predicted_graduation_semester': [[0, 98, 'STUDENT']],
 'degree': [[0, 99, 'STUDENT']],
 'minor': [[0, 100, 'STUDENT']],
 'internship': [[0, 101, 'STUDENT']],
 'grade': [[0, 105, 'STUDENT_RECORD']],
 'how': [[0, 106, 'STUDENT_RECORD']],
 'transfer_source': [[0, 107, 'STUDENT_RECORD']],
 'earn_credit': [[0, 108, 'STUDENT_RECORD']],
 'repeat_term': [[0, 109, 'STUDENT_RECORD']],
 'test_id': [[0, 110, 'STUDENT_RECORD']]}

In [182]:
# Convert the curated fk_key['advising'] groups into Spider-style foreign-key
# pairs [child_column_index, referenced_column_index] and write the schema.
# Group entries are [is_primary_flag, column_index, table_name]. Every column
# other than the group's single flag-1 column is linked to that column.
target_db = dbs['advising'][0]
# Pairs already in the schema (including ones added by an earlier run of this
# cell) are skipped, so re-running does not append duplicates.
existing_fks = {tuple(fk) for fk in target_db['foreign_keys']}
for cols in fk_key['advising'].values():
    main_key = [x[1] for x in cols if x[0] == 1]
    assert len(main_key) <= 1
    if len(main_key) == 1:
        main_key = main_key[0]
        for _, col_id, _ in cols:
            if col_id != main_key and (col_id, main_key) not in existing_fks:
                target_db['foreign_keys'].append([col_id, main_key])
                existing_fks.add((col_id, main_key))
with open('/home/t-xiaden/workspace/featurestorage/data/spider-20190205/%s_tables.json'%'advising', 'w') as f:
    json.dump(dbs['advising'], f)

In [183]:
# Group the geography columns by lower-cased name. Each entry is
# [is_primary_flag, column_index, table_name] with the flag initialised to 0,
# which is the same shape as the fk_key groups.
db = dbs['geography'][0]
column_mapping = {}
for col_idx, (table_idx, raw_name) in enumerate(db['column_names_original']):
    column_mapping.setdefault(raw_name.lower(), []).append(
        [0, col_idx, db['table_names_original'][table_idx]])

In [126]:
# Hand-curated foreign-key groups for geography, keyed by lower-cased column name.
# Each entry is [is_primary_flag, column_index, table_name]. Columns are linked to
# the flag-1 column of their group: state.state_name (col 24).
# NOTE(review): border_info.border (col 2) looks like it also names a state, but
# it is not linked here. Confirm whether it should join the state_name group.
fk_key['geography'] = {
 'state_name': [[0, 1, 'border_info'], [0, 6, 'city'], [0, 7, 'highlow'], [0, 15, 'lake'], 
                [0, 19, 'mountain'], [1, 24, 'state']],
 'border': [[0, 2, 'border_info']],
 'city_name': [[0, 3, 'city']],
 'population': [[0, 4, 'city'], [0, 25, 'state']],
 'country_name': [[0, 5, 'city'],
  [0, 14, 'lake'],
  [0, 18, 'mountain'],
  [0, 22, 'river'],
  [0, 27, 'state']],
 'highest_elevation': [[0, 8, 'highlow']],
 'lowest_point': [[0, 9, 'highlow']],
 'highest_point': [[0, 10, 'highlow']],
 'lowest_elevation': [[0, 11, 'highlow']],
 'lake_name': [[0, 12, 'lake']],
 'area': [[0, 13, 'lake'], [0, 26, 'state']],
 'mountain_name': [[0, 16, 'mountain']],
 'mountain_altitude': [[0, 17, 'mountain']],
 'river_name': [[0, 20, 'river']],
 'length': [[0, 21, 'river']],
 'traverse': [[0, 23, 'river'], [1, 24, 'state']],
 'capital': [[0, 28, 'state']],
 'density': [[0, 29, 'state']]}

# Convert the curated fk_key['geography'] groups into Spider-style foreign-key
# pairs [child_column_index, referenced_column_index] and write the schema.
# Group entries are [is_primary_flag, column_index, table_name]. Every column
# other than the group's single flag-1 column is linked to that column.
target_db = dbs['geography'][0]
# Pairs already in the schema (including ones added by an earlier run of this
# cell) are skipped, so re-running does not append duplicates.
existing_fks = {tuple(fk) for fk in target_db['foreign_keys']}
for cols in fk_key['geography'].values():
    main_key = [x[1] for x in cols if x[0] == 1]
    assert len(main_key) <= 1
    if len(main_key) == 1:
        main_key = main_key[0]
        for _, col_id, _ in cols:
            if col_id != main_key and (col_id, main_key) not in existing_fks:
                target_db['foreign_keys'].append([col_id, main_key])
                existing_fks.add((col_id, main_key))
with open('/home/t-xiaden/workspace/featurestorage/data/spider-20190205/%s_tables.json'%'geography', 'w') as f:
    json.dump(dbs['geography'], f)

In [120]:
# Group the imdb columns by lower-cased name. Each entry is
# [is_primary_flag, column_index, table_name] with the flag initialised to 0,
# which is the same shape as the fk_key groups.
db = dbs['imdb'][0]
column_mapping = {}
for col_idx, (table_idx, raw_name) in enumerate(db['column_names_original']):
    column_mapping.setdefault(raw_name.lower(), []).append(
        [0, col_idx, db['table_names_original'][table_idx]])

In [121]:
# Hand-curated foreign-key groups for imdb, keyed by lower-cased column name.
# Each entry is [is_primary_flag, column_index, table_name].
# Unlike the other datasets, the 'msid' group has TWO flag-1 entries:
# tv_series col 52 (listed as 'sid') and movie col 38 (listed as 'mid').
# Every msid column is therefore linked to both, and the imdb application loop
# accepts multiple primary entries per group.
fk_key['imdb'] = {
 'aid': [[1, 1, 'actor'], [0, 9, 'cast']],
 'gender': [[0, 2, 'actor'], [0, 26, 'director'], [0, 44, 'producer'], [0, 60, 'writer']],
 'name': [[0, 3, 'actor'], [0, 11, 'sqlite_sequence'], [0, 17, 'company'], [0, 27, 'director'],
          [0, 45, 'producer'], [0, 61, 'writer']],
 'nationality': [[0, 4, 'actor'], [0, 28, 'director'], [0, 46, 'producer'], [0, 62, 'writer']],
 'birth_city': [[0, 5, 'actor'], [0, 29, 'director'], [0, 47, 'producer'], [0, 63, 'writer']],
 'birth_year': [[0, 6, 'actor'], [0, 30, 'director'], [0, 48, 'producer'], [0, 64, 'writer']],
 'id': [[0, 7, 'cast'], [0, 13, 'classification'], [0, 19, 'copyright'],
        [0, 22, 'directed_by'],[0, 33, 'keyword'], [0, 35, 'made_by'], [0, 49, 'tags'], [0, 65, 'written_by']],
 'msid': [[1, 52, 'tv_series'],[1, 38, 'movie'],[0, 8, 'cast'], [0, 14, 'classification'], [0, 20, 'copyright'], [0, 23, 'directed_by'],
          [0, 36, 'made_by'], [0, 50, 'tags'], [0, 66, 'written_by']],
 'role': [[0, 10, 'cast']],
 'gid': [[0, 15, 'classification'], [1, 31, 'genre']],
 'country_code': [[0, 18, 'company']],
 'cid': [[0, 21, 'copyright'], [1, 16, 'company']],
 'did': [[0, 24, 'directed_by'], [1, 25, 'director']],
 'genre': [[0, 32, 'genre']],
 'keyword': [[0, 34, 'keyword']],
 'pid': [[0, 37, 'made_by'], [1, 43, 'producer']],
 'mid': [[0, 38, 'movie']],
 'title': [[0, 39, 'movie'], [0, 53, 'tv_series']],
 'release_year': [[0, 40, 'movie'], [0, 54, 'tv_series']],
 'title_aka': [[0, 41, 'movie'], [0, 57, 'tv_series']],
 'budget': [[0, 42, 'movie'], [0, 58, 'tv_series']],
 'kid': [[1, 33, 'keyword'], [0, 51, 'tags']],
 'sid': [[0, 52, 'tv_series']],
 'num_of_seasons': [[0, 55, 'tv_series']],
 'num_of_episodes': [[0, 56, 'tv_series']],
 'wid': [[1, 59, 'writer'], [0, 67, 'written_by']]}

# Convert the curated fk_key['imdb'] groups into Spider-style foreign-key pairs
# [child_column_index, referenced_column_index] and write the updated schema.
# A group may contain several flag-1 entries (e.g. 'msid'). Every non-flagged
# column is linked to each of them.
target_db = dbs['imdb'][0]
# Pairs already in the schema (including ones added by an earlier run of this
# cell) are skipped, so re-running does not append duplicates.
existing_fks = {tuple(fk) for fk in target_db['foreign_keys']}
for cols in fk_key['imdb'].values():
    main_keys = [x[1] for x in cols if x[0] == 1]
    for main_key in main_keys:
        for is_main, col_id, _ in cols:
            if is_main != 1 and (col_id, main_key) not in existing_fks:
                target_db['foreign_keys'].append([col_id, main_key])
                existing_fks.add((col_id, main_key))
with open('/home/t-xiaden/workspace/featurestorage/data/spider-20190205/%s_tables.json'%'imdb', 'w') as f:
    json.dump(dbs['imdb'], f)

In [200]:
# Group the yelp columns by lower-cased name. Each entry is
# [is_primary_flag, column_index, table_name] with the flag initialised to 0,
# which is the same shape as the fk_key groups.
db = dbs['yelp'][0]
column_mapping = {}
for col_idx, (table_idx, raw_name) in enumerate(db['column_names_original']):
    column_mapping.setdefault(raw_name.lower(), []).append(
        [0, col_idx, db['table_names_original'][table_idx]])

In [203]:
# Hand-curated foreign-key groups for yelp, keyed by lower-cased column name.
# Each entry is [is_primary_flag, column_index, table_name].
# business.business_id (col 2) and user.user_id (col 39) are the referenced
# columns. Single-entry groups contribute no foreign keys.
fk_key['yelp'] = {
 'bid': [[0, 1, 'business']],
 'business_id': [[1, 2, 'business'], [0, 15, 'category'], [0, 18, 'checkin'],
                 [0, 22, 'neighborhood'], [0, 25, 'review'], [0, 32, 'tip']],
 'name': [[0, 3, 'business'], [0, 12, 'sqlite_sequence'], [0, 40, 'user']],
 'full_address': [[0, 4, 'business']],
 'city': [[0, 5, 'business']],
 'latitude': [[0, 6, 'business']],
 'longitude': [[0, 7, 'business']],
 'review_count': [[0, 8, 'business']],
 'is_open': [[0, 9, 'business']],
 'rating': [[0, 10, 'business'], [0, 27, 'review']],
 'state': [[0, 11, 'business']],
 'seq': [[0, 13, 'sqlite_sequence']],
 'id': [[0, 14, 'category'], [0, 21, 'neighborhood']],
 'category_name': [[0, 16, 'category']],
 'cid': [[0, 17, 'checkin']],
 'count': [[0, 19, 'checkin']],
 'day': [[0, 20, 'checkin']],
 'neighborhood_name': [[0, 23, 'neighborhood']],
 'rid': [[0, 24, 'review']],
 'user_id': [[0, 26, 'review'], [0, 34, 'tip'], [1, 39, 'user']],
 'text': [[0, 28, 'review'], [0, 33, 'tip']],
 'year': [[0, 29, 'review'], [0, 36, 'tip']],
 'month': [[0, 30, 'review'], [0, 37, 'tip']],
 'tip_id': [[0, 31, 'tip']],
 'likes': [[0, 35, 'tip']],
 'uid': [[0, 38, 'user']]}

# Convert the curated fk_key['yelp'] groups into Spider-style foreign-key pairs
# [child_column_index, referenced_column_index] and write the updated schema.
# Group entries are [is_primary_flag, column_index, table_name]. Every column
# other than the group's single flag-1 column is linked to that column.
target_db = dbs['yelp'][0]
# Pairs already in the schema (including ones added by an earlier run of this
# cell) are skipped, so re-running does not append duplicates.
existing_fks = {tuple(fk) for fk in target_db['foreign_keys']}
for cols in fk_key['yelp'].values():
    main_key = [x[1] for x in cols if x[0] == 1]
    assert len(main_key) <= 1
    if len(main_key) == 1:
        main_key = main_key[0]
        for _, col_id, _ in cols:
            if col_id != main_key and (col_id, main_key) not in existing_fks:
                target_db['foreign_keys'].append([col_id, main_key])
                existing_fks.add((col_id, main_key))
with open('/home/t-xiaden/workspace/featurestorage/data/spider-20190205/%s_tables.json'%'yelp', 'w') as f:
    json.dump(dbs['yelp'], f)

In [204]:
# Group the academic columns by lower-cased name. Each entry is
# [is_primary_flag, column_index, table_name] with the flag initialised to 0,
# which is the same shape as the fk_key groups.
db = dbs['academic'][0]
column_mapping = {}
for col_idx, (table_idx, raw_name) in enumerate(db['column_names_original']):
    column_mapping.setdefault(raw_name.lower(), []).append(
        [0, col_idx, db['table_names_original'][table_idx]])

In [208]:
# Hand-curated foreign-key groups for academic, keyed by lower-cased column name.
# Each entry is [is_primary_flag, column_index, table_name]. Columns are linked
# to the flag-1 column of their group.
# NOTE(review): cite.citing / cite.cited (cols 6, 7) are not linked to
# publication.pid (col 41). The scholar mapping does link its analogous
# cite columns; confirm whether academic should too.
fk_key['academic'] = {
 'aid': [[1, 1, 'author'], [0, 14, 'domain_author'], [0, 52, 'writes']],
 'name': [[0, 2, 'author'], [0, 9, 'conference'], [0, 13, 'domain'], [0, 29, 'journal'], [0, 38, 'organization']],
 'oid': [[0, 3, 'author'], [1, 37, 'organization']],
 'homepage': [[0, 4, 'author'], [0, 11, 'conference'], [0, 31, 'journal'], [0, 40, 'organization']],
 'photo': [[0, 5, 'author']],
 'citing': [[0, 6, 'cite']],
 'cited': [[0, 7, 'cite']],
 'cid': [[1, 8, 'conference'], [0, 16, 'domain_conference'], [0, 45, 'publication']],
 'full_name': [[0, 10, 'conference'], [0, 30, 'journal']],
 'did': [[1, 12, 'domain'], [0, 15, 'domain_author'], [0, 17, 'domain_conference'],
         [0, 19, 'domain_journal'], [0, 21, 'domain_keyword'], [0, 24, 'domain_publication']],
 'jid': [[0, 18, 'domain_journal'], [1, 28, 'journal'], [0, 46, 'publication']],
 'kid': [[0, 20, 'domain_keyword'], [1, 32, 'keyword'], [0, 35, 'keyword_variations'], [0, 51, 'publication_keyword']],
 'rank': [[0, 22, 'domain_keyword']],
 'pid': [[0, 23, 'domain_publication'], [1, 41, 'publication'], [0, 50, 'publication_keyword'], [0, 53, 'writes']],
 'relation': [[0, 25, 'ids']],
 'id': [[0, 26, 'ids']],
 'exist': [[0, 27, 'ids']],
 'keyword': [[0, 33, 'keyword']],
 'keyword_short': [[0, 34, 'keyword']],
 'variation': [[0, 36, 'keyword_variations']],
 'continent': [[0, 39, 'organization']],
 'title': [[0, 42, 'publication']],
 'abstract': [[0, 43, 'publication']],
 'year': [[0, 44, 'publication']],
 'reference_num': [[0, 47, 'publication']],
 'citation_num': [[0, 48, 'publication']],
 'doi': [[0, 49, 'publication']]}

# Convert the curated fk_key['academic'] groups into Spider-style foreign-key
# pairs [child_column_index, referenced_column_index] and write the schema.
# Group entries are [is_primary_flag, column_index, table_name]. Every column
# other than the group's single flag-1 column is linked to that column.
target_db = dbs['academic'][0]
# Pairs already in the schema (including ones added by an earlier run of this
# cell) are skipped, so re-running does not append duplicates.
existing_fks = {tuple(fk) for fk in target_db['foreign_keys']}
for cols in fk_key['academic'].values():
    main_key = [x[1] for x in cols if x[0] == 1]
    assert len(main_key) <= 1
    if len(main_key) == 1:
        main_key = main_key[0]
        for _, col_id, _ in cols:
            if col_id != main_key and (col_id, main_key) not in existing_fks:
                target_db['foreign_keys'].append([col_id, main_key])
                existing_fks.add((col_id, main_key))
with open('/home/t-xiaden/workspace/featurestorage/data/spider-20190205/%s_tables.json'%'academic', 'w') as f:
    json.dump(dbs['academic'], f)

In [209]:
# Group the scholar columns by lower-cased name. Each entry is
# [is_primary_flag, column_index, table_name] with the flag initialised to 0,
# which is the same shape as the fk_key groups.
db = dbs['scholar'][0]
column_mapping = {}
for col_idx, (table_idx, raw_name) in enumerate(db['column_names_original']):
    column_mapping.setdefault(raw_name.lower(), []).append(
        [0, col_idx, db['table_names_original'][table_idx]])

In [212]:
# Hand-curated foreign-key groups for scholar, keyed by lower-cased column name.
# Each entry is [is_primary_flag, column_index, table_name]. cite.citingpaperid
# and cite.citedpaperid are grouped with paper.paperid (col 11), so both are
# linked to it.
fk_key['scholar'] = {
 'authorid': [[1, 1, 'author'], [0, 25, 'writes']],
 'authorname': [[0, 2, 'author']],
 'citingpaperid': [[1, 11, 'paper'], [0, 3, 'cite']],
 'citedpaperid': [[1, 11, 'paper'], [0, 4, 'cite']],
 'datasetid': [[1, 5, 'dataset'], [0, 19, 'paperDataset']],
 'datasetname': [[0, 6, 'dataset']],
 'journalid': [[1, 7, 'journal'], [0, 17, 'paper']],
 'journalname': [[0, 8, 'journal']],
 'keyphraseid': [[1, 9, 'keyphrase'], [0, 21, 'paperKeyphrase']],
 'keyphrasename': [[0, 10, 'keyphrase']],
 'paperid': [[1, 11, 'paper'],
  [0, 18, 'paperDataset'],
  [0, 20, 'paperKeyphrase'],
  [0, 24, 'writes']],
 'title': [[0, 12, 'paper']],
 'venueid': [[0, 13, 'paper'], [1, 22, 'venue']],
 'year': [[0, 14, 'paper']],
 'numciting': [[0, 15, 'paper']],
 'numcitedby': [[0, 16, 'paper']],
 'venuename': [[0, 23, 'venue']]}

# Convert the curated fk_key['scholar'] groups into Spider-style foreign-key
# pairs [child_column_index, referenced_column_index] and write the schema.
# Group entries are [is_primary_flag, column_index, table_name]. Every column
# other than the group's single flag-1 column is linked to that column.
target_db = dbs['scholar'][0]
# Pairs already in the schema (including ones added by an earlier run of this
# cell) are skipped, so re-running does not append duplicates.
existing_fks = {tuple(fk) for fk in target_db['foreign_keys']}
for cols in fk_key['scholar'].values():
    main_key = [x[1] for x in cols if x[0] == 1]
    assert len(main_key) <= 1
    if len(main_key) == 1:
        main_key = main_key[0]
        for _, col_id, _ in cols:
            if col_id != main_key and (col_id, main_key) not in existing_fks:
                target_db['foreign_keys'].append([col_id, main_key])
                existing_fks.add((col_id, main_key))
with open('/home/t-xiaden/workspace/featurestorage/data/spider-20190205/%s_tables.json'%'scholar', 'w') as f:
    json.dump(dbs['scholar'], f)

In [213]:
# Group the restaurants columns by lower-cased name. Each entry is
# [is_primary_flag, column_index, table_name] with the flag initialised to 0,
# which is the same shape as the fk_key groups.
db = dbs['restaurants'][0]
column_mapping = {}
for col_idx, (table_idx, raw_name) in enumerate(db['column_names_original']):
    column_mapping.setdefault(raw_name.lower(), []).append(
        [0, col_idx, db['table_names_original'][table_idx]])

In [216]:
# Hand-curated foreign-key groups for restaurants, keyed by lower-cased column name.
# Each entry is [is_primary_flag, column_index, table_name]. GEOGRAPHIC.city_name
# (col 1) and RESTAURANT.id (col 4, listed under 'restaurant_id') are the
# referenced columns.
fk_key['restaurants'] = {
 'city_name': [[1, 1, 'GEOGRAPHIC'], [0, 7, 'RESTAURANT'], [0, 12, 'LOCATION']],
 'county': [[0, 2, 'GEOGRAPHIC']],
 'region': [[0, 3, 'GEOGRAPHIC']],
 'id': [[0, 4, 'RESTAURANT']],
 'name': [[0, 5, 'RESTAURANT']],
 'food_type': [[0, 6, 'RESTAURANT']],
 'rating': [[0, 8, 'RESTAURANT']],
 'restaurant_id': [[1, 4, 'RESTAURANT'], [0, 9, 'LOCATION']],
 'house_number': [[0, 10, 'LOCATION']],
 'street_name': [[0, 11, 'LOCATION']]}

# Convert the curated fk_key['restaurants'] groups into Spider-style foreign-key
# pairs [child_column_index, referenced_column_index] and write the schema.
# Group entries are [is_primary_flag, column_index, table_name]. Every column
# other than the group's single flag-1 column is linked to that column.
target_db = dbs['restaurants'][0]
# Pairs already in the schema (including ones added by an earlier run of this
# cell) are skipped, so re-running does not append duplicates.
existing_fks = {tuple(fk) for fk in target_db['foreign_keys']}
for cols in fk_key['restaurants'].values():
    main_key = [x[1] for x in cols if x[0] == 1]
    assert len(main_key) <= 1
    if len(main_key) == 1:
        main_key = main_key[0]
        for _, col_id, _ in cols:
            if col_id != main_key and (col_id, main_key) not in existing_fks:
                target_db['foreign_keys'].append([col_id, main_key])
                existing_fks.add((col_id, main_key))
with open('/home/t-xiaden/workspace/featurestorage/data/spider-20190205/%s_tables.json'%'restaurants', 'w') as f:
    json.dump(dbs['restaurants'], f)

In [214]:
# Ad-hoc inspection: display the schema currently bound to `db`, including the
# foreign keys recorded for it so far.
db

{'db_id': 'restaurants',
 'table_names_original': ['GEOGRAPHIC', 'RESTAURANT', 'LOCATION'],
 'table_names': ['geographic', 'restaurant', 'location'],
 'column_names_original': [(-1, '*'),
  (0, 'CITY_NAME'),
  (0, 'COUNTY'),
  (0, 'REGION'),
  (1, 'ID'),
  (1, 'NAME'),
  (1, 'FOOD_TYPE'),
  (1, 'CITY_NAME'),
  (1, 'RATING'),
  (2, 'RESTAURANT_ID'),
  (2, 'HOUSE_NUMBER'),
  (2, 'STREET_NAME'),
  (2, 'CITY_NAME')],
 'column_names': [(-1, '*'),
  (0, 'city name'),
  (0, 'county'),
  (0, 'region'),
  (1, 'id'),
  (1, 'name'),
  (1, 'food type'),
  (1, 'city name'),
  (1, 'rating'),
  (2, 'restaurant id'),
  (2, 'house number'),
  (2, 'street name'),
  (2, 'city name')],
 'column_types': ['text',
  'text',
  'text',
  'text',
  'number',
  'text',
  'text',
  'text',
  'number',
  'number',
  'number',
  'text',
  'text'],
 'primary_keys': [1, 4, 9],
 'foreign_keys': [[7, 1], [12, 1]]}

# Create a Spider dev subset with questions rewritten to avoid mentioning column names

In [None]:
# Load the Spider dev split; each example carries 'question', 'query' and the
# parsed 'sql' used by the clause collectors below.
spider_dev_path = '/home/t-xiaden/workspace/featurestorage/data/spider-20200607/dev.json'
with open(spider_dev_path) as f:
    dev_data = json.load(f)

In [None]:
# Collectors over Spider-style parsed SQL dicts. They assume the layout used by
# Spider's process_sql:
#   condition unit: [not_op, op_id, val_unit, val1, val2]
#   val_unit:       [unit_op, col_unit1, col_unit2]
#   col_unit:       [agg_id, col_id, is_distinct]
# Conditions are separated by 'and'/'or' connector strings.


def _set_op_subqueries(sql):
    """Return the INTERSECT, UNION and EXCEPT sub-queries of ``sql`` that exist.

    The order (intersect, union, except) fixes the order in which the
    collectors concatenate results from set-operation branches.
    """
    return [sql[key] for key in ('intersect', 'union', 'except') if sql[key] is not None]


def _get_conds(sql, clause):
    """Collect ``[col_id, val1, val2]`` for every condition in ``sql[clause]``.

    ``clause`` is ``'where'`` or ``'having'``. Connector strings are skipped.
    When a value is a nested query (a dict), the same clause of that query is
    collected first and the value is reported as ``None``. Conditions from
    set-operation sub-queries are appended last.
    """
    conds = []
    for unit in sql[clause]:
        if not isinstance(unit, list):
            continue  # 'and' / 'or' connector
        col = unit[2][1][1]  # col_id of the first col_unit
        val1, val2 = unit[3], unit[4]
        if isinstance(val1, dict):
            conds += _get_conds(val1, clause)
            val1 = None
        if isinstance(val2, dict):
            conds += _get_conds(val2, clause)
            val2 = None
        conds.append([col, val1, val2])
    for sub in _set_op_subqueries(sql):
        conds += _get_conds(sub, clause)
    return conds


def get_where_cond(sql):
    """Return ``[col_id, val1, val2]`` for each WHERE condition (see ``_get_conds``)."""
    return _get_conds(sql, 'where')


def get_select_col(sql):
    """Return the col_id of each top-level SELECT item (set-op branches excluded)."""
    return [x[1][1][1] for x in sql['select'][1]]


def get_orderby(sql):
    """Return ORDER BY col_ids, including those of set-operation sub-queries."""
    cols = []
    if sql['orderBy']:
        # orderBy is [direction, [val_unit, ...]]; x[1] is the first col_unit.
        cols += [x[1][1] for x in sql['orderBy'][1]]
    for sub in _set_op_subqueries(sql):
        cols += get_orderby(sub)
    return cols


def get_groupby(sql):
    """Return GROUP BY col_ids, including those of set-operation sub-queries."""
    cols = []
    if sql['groupBy']:
        cols += [x[1] for x in sql['groupBy']]  # groupBy is a list of col_units
    for sub in _set_op_subqueries(sql):
        cols += get_groupby(sub)
    return cols


def get_having_cond(sql):
    """Return ``[col_id, val1, val2]`` for each HAVING condition (see ``_get_conds``)."""
    return _get_conds(sql, 'having')

In [None]:
def isnumber(x):
    """Return True if ``x`` can be converted with ``float()``.

    Strings are stripped of surrounding whitespace and double quotes first, so
    values such as ``'"42"'`` count as numbers. Only conversion failures
    (``TypeError``, ``ValueError``, ``OverflowError``) yield False.
    Interrupts such as KeyboardInterrupt are no longer swallowed.
    """
    try:
        if isinstance(x, str):
            x = x.strip().strip('"')
        float(x)
        return True
    except (TypeError, ValueError, OverflowError):
        return False

In [None]:
# Export every Spider dev example whose SQL has a WHERE/HAVING condition, a
# GROUP BY or an ORDER BY. These are the candidates for rewriting the question
# without column mentions. Each row is
# index <TAB> original question <TAB> previously modified question <TAB> SQL.
# Tabs and line breaks inside a field are replaced with spaces so every
# example stays on exactly one 4-column row. The file is written as UTF-8
# regardless of the platform default encoding.
_TSV_UNSAFE = str.maketrans({'\t': ' ', '\r': ' ', '\n': ' '})
with open('spider_dev_tomodify.tsv', 'w', encoding='utf-8') as f:
    for i, x in enumerate(dev_data):
        sql = x['sql']
        needs_rewrite = (get_where_cond(sql) or get_having_cond(sql)
                         or get_groupby(sql) or get_orderby(sql))
        if not needs_rewrite:
            continue
        modified_q = dev_modified.get(i, ['', ''])[1]  # reload previous modified question
        fields = (i, x['question'], modified_q, x['query'])
        f.write('\t'.join(str(v).translate(_TSV_UNSAFE) for v in fields) + '\n')

In [None]:
import copy
# Build a dev subset containing only examples with a non-empty edited
# question (element 1 of each dev_modified entry). Each example is a deep
# copy with 'question' replaced and 'question_toks' re-split on whitespace.
dev_modified_b = []
for example_idx, entry in dev_modified.items():
    edited_question = entry[1]
    if not edited_question:
        continue
    example = copy.deepcopy(dev_data[example_idx])
    dev_modified_b.append(example)
    example['question'] = edited_question
    example['question_toks'] = edited_question.split()
with open('/home/t-xiaden/workspace/featurestorage/data/spider-20200607/spider_modified_b_dev.json', 'w') as out_file:
    json.dump(dev_modified_b, out_file)

# Result Analysis

In [282]:
def _umich_column_links(preproc_item):
    """Group the schema-linking matches of one preprocessed item by column.

    Returns ``(column_links, all_links)``. Both are dicts keyed by every
    column index ``0 .. len(preproc_item['columns']) - 1``:
    - ``column_links`` holds ``'<question token>: <link type>'`` strings for
      the TSV;
    - ``all_links`` holds the bare link types.

    Matches are read from ``sc_link['q_col_match']``, then
    ``cv_link['num_date_match']``, then ``cv_link['cell_match']``. Each key is
    ``'<q_loc>,<c_loc>'``; ``q_loc`` is resolved through
    ``preproc_item['question_for_copy']`` and ``c_loc`` is the column index.
    """
    q_map = preproc_item['question_for_copy']
    num_columns = len(preproc_item['columns'])
    column_links = {j: [] for j in range(num_columns)}
    all_links = {j: [] for j in range(num_columns)}
    for matches in (preproc_item['sc_link']['q_col_match'],
                    preproc_item['cv_link']['num_date_match'],
                    preproc_item['cv_link']['cell_match']):
        for link, link_type in matches.items():
            q_loc, c_loc = link.split(',')
            column_links[int(c_loc)].append('{}: {}'.format(q_map[int(q_loc)], link_type))
            all_links[int(c_loc)].append(link_type)
    return column_links, all_links


# For every UMich dataset:
# - map de-anonymized questions back to their query index;
# - write ../../logdirs/<dataset>_analysis.tsv (gold vs. predicted SQL,
#   exec/exact scores, and the linking matches of every column);
# - collect one [all_links, select_cols, where_cols] record per dev example
#   into link_analysis[<dataset>].
link_analysis = {}
for dataset_id in datasets:
    print(dataset_id)

    with open(os.path.join(datadir, '{}.json'.format(dataset_id))) as f:
        data = json.load(f)

    # Header labels '<table>.<column>' for every entry of the first schema's
    # column_names_original, in order.
    with open('/home/t-xiaden/workspace/featurestorage/data/spider-20190205/%s_tables.json'%dataset_id,'r') as f:
        schema = json.load(f)
    columns = ['{}.{}'.format(schema[0]["table_names_original"][tab_id], col_name)
               for tab_id, col_name in schema[0]["column_names_original"]]

    # Map every question of the selected splits, with its anonymized variables
    # substituted back in, to the index of its query in `data`.
    nl_to_queryid = {}
    for i, query in enumerate(tqdm(data)):
        for example in query['sentences']:
            if example['question-split'] not in datasets[dataset_id]:
                continue
            nl = example['text']
            # Longest variable names first, so that a name which is a
            # substring of a longer one cannot clobber part of it.
            for variable_name, value in sorted(
                    example['variables'].items(), key=lambda x: len(x[0]), reverse=True):
                if not value:
                    # TODO(alanesuhr) While the Michigan repo says to use a - here, the
                    # thing that works is using a % and replacing = with LIKE.
                    #
                    # It's possible that I should remove such clauses from the SQL, as
                    # long as they lead to the same table result. They don't align well
                    # to the natural language at least.
                    #
                    # See: https://github.com/jkkummerfeld/text2sql-data/tree/master/data
                    value = '%'
                nl = nl.replace(variable_name, value)
            nl_to_queryid[nl] = i

    with open('../../logdirs/%s_bert_value_run_0_true_1_new_nocolvalue-step40000.eval'%dataset_id,'r') as f:
        results = json.load(f)
    with open('/home/t-xiaden/workspace/featurestorage/data/spider-20190205/%s_dev.json'%dataset_id,'r') as f:
        origs = json.load(f)
    link_analysis[dataset_id] = []
    with open('../../logdirs/%s_bert_value_run_0_true_1_new_nocolvalue-step40000.infer'%dataset_id,'r') as f:
        with open('../../logdirs/%s_analysis.tsv'%dataset_id, 'w') as f_out:
            # The header is written up front, so an empty .infer file still
            # yields a header. Each column label is written as-is; schema
            # column 0 is skipped, as in every data row.
            f_out.write('query_id\tquestion\tgold_query\tpred_query\texec\texact\t')
            for col in columns[1:]:
                f_out.write('{}\t'.format(col))
            f_out.write('\n')
            # Line i of the .infer file is aligned with origs[i] and
            # results['per_item'][i].
            for i, line in enumerate(f):
                inferred = json.loads(line.strip())['beams'][0]
                preproc_item = inferred['preproc_item']
                pred_query = inferred['inferred_code']
                orig = origs[i]
                q = orig['question']
                select_cols = get_select_col(orig['sql'])
                where_cols = [cond[0] for cond in get_where_cond(orig['sql'])]
                item_result = results['per_item'][i]
                f_out.write('{}\t{}\t{}\t{}\t{}\t{}\t'.format(
                    nl_to_queryid[q], q, orig["query"], pred_query,
                    item_result['exec'], item_result['exact']))
                column_links, all_links = _umich_column_links(preproc_item)
                # One cell per column (index 0 skipped), matching the header.
                for j in range(1, len(preproc_item['columns'])):
                    f_out.write('{}\t'.format('; '.join(column_links[j])))
                f_out.write('\n')
                link_analysis[dataset_id].append([all_links, select_cols, where_cols])

100%|██████████| 947/947 [00:00<00:00, 340331.24it/s]

atis



100%|██████████| 246/246 [00:00<00:00, 199728.76it/s]

geography



100%|██████████| 23/23 [00:00<00:00, 24299.49it/s]
100%|██████████| 193/193 [00:00<00:00, 138946.22it/s]

restaurants
scholar



100%|██████████| 89/89 [00:00<00:00, 179813.61it/s]
100%|██████████| 110/110 [00:00<00:00, 198611.04it/s]
100%|██████████| 205/205 [00:00<00:00, 30148.40it/s]

imdb
yelp
advising



100%|██████████| 185/185 [00:00<00:00, 256003.38it/s]

academic





In [286]:
def _preproc_link_record(preproc_item, orig):
    """Build one ``[all_links, select_cols, where_cols]`` link-analysis record.

    ``all_links`` maps every column index of ``preproc_item['columns']`` to
    the list of link types hitting that column. The types are gathered from
    ``sc_link['q_col_match']``, ``cv_link['num_date_match']`` and
    ``cv_link['cell_match']``, in that order; keys are ``'<q_loc>,<c_loc>'``.

    ``select_cols`` and ``where_cols`` are the gold SELECT column ids and the
    first element of each gold WHERE condition, taken from ``orig['sql']``.
    """
    select_cols = get_select_col(orig['sql'])
    where_cols = [cond[0] for cond in get_where_cond(orig['sql'])]
    all_links = {j: [] for j in range(len(preproc_item['columns']))}
    for matches in (preproc_item['sc_link']['q_col_match'],
                    preproc_item['cv_link']['num_date_match'],
                    preproc_item['cv_link']['cell_match']):
        for link, link_type in matches.items():
            _, c_loc = link.split(',')
            all_links[int(c_loc)].append(link_type)
    return [all_links, select_cols, where_cols]


def _iter_link_records(orig_path, enc_path):
    """Yield ``(orig, record)`` for every line of a preprocessed JSONL file.

    Line ``i`` of ``enc_path`` is paired with element ``i`` of the JSON list
    stored at ``orig_path``. ``record`` is built by ``_preproc_link_record``.
    """
    with open(orig_path, 'r') as f:
        origs = json.load(f)
    with open(enc_path, 'r') as f:
        for i, line in enumerate(f):
            preproc_item = json.loads(line.strip())
            orig = origs[i]
            yield orig, _preproc_link_record(preproc_item, orig)


_SPIDER_DIR = '/home/t-xiaden/workspace/featurestorage/data/spider-20190205/'
_ENC_DIR = _SPIDER_DIR + 'nl2code-1115,output_from=true,fs=2,emb=bert,cvlink,value,dec_min_freq=10,spideronly,newvalue/enc/'

# Spider dev: records are additionally bucketed per database under
# 'spider_val_<db_id>'.
dataset_id = 'spider_val'
link_analysis[dataset_id] = []
for orig, record in _iter_link_records(_SPIDER_DIR + 'dev.json', _ENC_DIR + 'val.jsonl'):
    link_analysis[dataset_id].append(record)
    link_analysis.setdefault(dataset_id + '_' + orig['db_id'], []).append(list(record))

# Spider train.
dataset_id = 'spider_train'
link_analysis[dataset_id] = [
    record for _, record in _iter_link_records(
        _SPIDER_DIR + 'train_spider.json', _ENC_DIR + 'train.jsonl')
]

# UMich dev sets, preprocessed with the same Spider encoder.
for dataset_id in datasets:
    link_analysis[dataset_id] = [
        record for _, record in _iter_link_records(
            _SPIDER_DIR + '%s_dev.json' % dataset_id,
            _ENC_DIR + 'val_%s.jsonl' % dataset_id)
    ]

In [283]:
# TREQS dev: one [all_links, select_cols, where_cols] record per preprocessed
# example. Line k of the JSONL file is aligned with origs[k].
dataset_id = 'treqs'
with open('/home/t-xiaden/workspace/featurestorage/data/spider-20190205/%s_dev.json'%dataset_id,'r') as gold_file:
    origs = json.load(gold_file)
link_analysis[dataset_id] = []
with open('/home/t-xiaden/workspace/featurestorage/data/spider-20190205/nl2code-1115,output_from=true,fs=2,emb=bert,cvlink,value,dec_min_freq=10,spideronly,newvalue/enc/val_%s.jsonl'%dataset_id,'r') as enc_file:
    for line_no, raw_line in enumerate(enc_file):
        preproc_item = json.loads(raw_line.strip())
        orig = origs[line_no]
        select_cols = get_select_col(orig['sql'])
        where_cols = [cond[0] for cond in get_where_cond(orig['sql'])]
        all_links = {col_idx: [] for col_idx in range(len(preproc_item['columns']))}
        link_sources = (
            preproc_item['sc_link']['q_col_match'],
            preproc_item['cv_link']['num_date_match'],
            preproc_item['cv_link']['cell_match'],
        )
        for source in link_sources:
            # Keys look like '<question index>,<column index>'.
            for link_key, link_type in source.items():
                _, col_part = link_key.split(',')
                all_links[int(col_part)].append(link_type)
        link_analysis[dataset_id].append([all_links, select_cols, where_cols])

In [243]:
def get_where_cond(sql):
    """Flatten WHERE conditions of a parsed query into ``[col, val1, val2]`` lists.

    For each list entry of ``sql['where']``:
    - the column is read from ``entry[2][1][1]``;
    - the operands are read from ``entry[3]`` / ``entry[4]``.

    A dict operand is treated as a nested sub-query. Its own WHERE conditions
    are emitted first, and the operand is recorded as ``None``. Non-list
    entries (presumably the 'and'/'or' connectors) are skipped. Unlike
    ``get_groupby`` / ``get_having_cond``, INTERSECT / UNION / EXCEPT
    branches are not visited.
    """
    triples = []
    for cond_unit in sql['where']:
        if not isinstance(cond_unit, list):
            continue
        column_id = cond_unit[2][1][1]
        operands = [cond_unit[3], cond_unit[4]]
        for pos, operand in enumerate(operands):
            if isinstance(operand, dict):
                triples += get_where_cond(operand)
                operands[pos] = None
        triples.append([column_id, operands[0], operands[1]])
    return triples
def get_select_col(sql):
    """Return the column id of every SELECT item, in order.

    Reads ``sql['select'][1]`` and takes ``item[1][1][1]`` from each item.
    Presumably this is the ``col_id`` of the first ``col_unit`` of each
    item's ``val_unit``, following Spider's parsed-SQL layout.
    """
    column_ids = []
    for select_item in sql['select'][1]:
        val_unit = select_item[1]
        col_unit = val_unit[1]
        column_ids.append(col_unit[1])
    return column_ids

In [244]:
# Show the gold SELECT column ids of the most recently loaded example `orig`.
get_select_col(sql=orig['sql'])

[2]

In [245]:
# Show the gold WHERE conditions ([column id, val1, val2]) of the same `orig`.
get_where_cond(sql=orig['sql'])

[[13, '"Databases"', None], [38, '"University of Michigan"', None]]

In [302]:
link_types = ['CPM','CEM','ANYCELLMATCH','CELLMATCH','CELLTOKENMATCH','NUMBER','TIME']


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or ``0.0`` when the denominator is 0.

    Small subsets, such as a single ``spider_val_<db_id>`` bucket with no
    WHERE clause or no linked columns, would otherwise raise
    ZeroDivisionError.
    """
    return numerator / denominator if denominator else 0.0


# For every link_analysis bucket, compute and print three groups of ratios,
# storing each in link_analysis_processed[bucket]:
#   select_<type>: per-example distinct linked columns of <type> that occur in
#                  the gold SELECT columns / total number of SELECT columns
#   where_<type>:  the same, for the gold WHERE columns
#   linked_<type>: (fraction of <type>-linked columns in SELECT,
#                   fraction of <type>-linked columns in WHERE)
link_analysis_processed = {}
for dataset_id in link_analysis:
    processed = {}
    link_analysis_processed[dataset_id] = processed
    where_num = {'all': 0, 'any': 0}
    select_num = {'all': 0, 'any': 0}
    link_num = {'all': [0, 0, 0]}  # [linked columns, ...in SELECT, ...in WHERE]
    for all_links, select_cols, where_cols in link_analysis[dataset_id]:
        where_num['all'] += len(where_cols)
        select_num['all'] += len(select_cols)
        for col, links in all_links.items():
            links = set(links)
            # Either cell-match flavour also counts as a generic cell match.
            if 'CELLMATCH' in links or 'CELLTOKENMATCH' in links:
                links.add('ANYCELLMATCH')
            in_select = col in select_cols
            in_where = col in where_cols
            if links:
                link_num['all'][0] += 1
                if in_select:
                    link_num['all'][1] += 1
                    select_num['any'] += 1
                if in_where:
                    link_num['all'][2] += 1
                    where_num['any'] += 1
            for link in links:
                counts = link_num.setdefault(link, [0, 0, 0])
                counts[0] += 1
                if in_select:
                    counts[1] += 1
                    select_num[link] = select_num.get(link, 0) + 1
                if in_where:
                    counts[2] += 1
                    where_num[link] = where_num.get(link, 0) + 1
    print(dataset_id, end='\t')
    for k in link_types:
        ratio = _safe_ratio(select_num.get(k, 0), select_num['all'])
        print(k, '%.4f' % ratio, end='\t')
        processed['select_%s' % k] = ratio
    for k in link_types:
        ratio = _safe_ratio(where_num.get(k, 0), where_num['all'])
        print(k, '%.4f' % ratio, end='\t')
        processed['where_%s' % k] = ratio
    for k in ['all'] + link_types:
        # Link types never observed fall back to [1, 0, 0] -> (0.0, 0.0).
        v = link_num.get(k, [1, 0, 0])
        pair = (_safe_ratio(v[1], v[0]), _safe_ratio(v[2], v[0]))
        print(k, '%.4f, %.4f' % pair, end='\t')
        processed['linked_%s' % k] = pair
    print('\n')

atis	CPM 0.7176	CEM 0.0000	ANYCELLMATCH 0.0802	CELLMATCH 0.0802	CELLTOKENMATCH 0.0000	NUMBER 0.0000	TIME 0.0000	CPM 0.0600	CEM 0.0030	ANYCELLMATCH 0.4708	CELLMATCH 0.4708	CELLTOKENMATCH 0.0030	NUMBER 0.0465	TIME 0.0000	all 0.0495, 0.0880	CPM 0.0702, 0.0149	CEM 0.0000, 0.1333	ANYCELLMATCH 0.0167, 0.2500	CELLMATCH 0.0277, 0.4137	CELLTOKENMATCH 0.0000, 0.0040	NUMBER 0.0000, 0.0764	TIME 0.0000, 0.0000	

geography	CPM 0.4624	CEM 0.3237	ANYCELLMATCH 0.2601	CELLMATCH 0.1561	CELLTOKENMATCH 0.1098	NUMBER 0.0000	TIME 0.0000	CPM 0.2461	CEM 0.1061	ANYCELLMATCH 0.5205	CELLMATCH 0.5205	CELLTOKENMATCH 0.0113	NUMBER 0.0000	TIME 0.0000	all 0.0988, 0.1151	CPM 0.1164, 0.0844	CEM 0.3953, 0.1765	ANYCELLMATCH 0.0557, 0.1519	CELLMATCH 0.0394, 0.1788	CELLTOKENMATCH 0.1469, 0.0206	NUMBER 0.0000, 0.0000	TIME 0.0000, 0.0000	

restaurants	CPM 0.0000	CEM 0.0000	ANYCELLMATCH 0.0000	CELLMATCH 0.0000	CELLTOKENMATCH 0.0000	NUMBER 0.0000	TIME 0.0000	CPM 0.0769	CEM 0.0769	ANYCELLMATCH 0.9615	CELLMATCH 0.9231	CELLTOKENMA

spider_train	CPM 0.2621	CEM 0.3190	ANYCELLMATCH 0.0357	CELLMATCH 0.0120	CELLTOKENMATCH 0.0249	NUMBER 0.0170	TIME 0.0209	CPM 0.2516	CEM 0.2233	ANYCELLMATCH 0.2728	CELLMATCH 0.2518	CELLTOKENMATCH 0.0306	NUMBER 0.1261	TIME 0.0151	all 0.0714, 0.0335	CPM 0.0535, 0.0224	CEM 0.2480, 0.0756	ANYCELLMATCH 0.0594, 0.1977	CELLMATCH 0.0347, 0.3171	CELLTOKENMATCH 0.0911, 0.0488	NUMBER 0.0248, 0.0799	TIME 0.0161, 0.0051	



In [306]:
# Per-database breakdown of the Spider dev evaluation. Each bucket gets the
# mean exec / exact accuracy, the mean partial SELECT / WHERE accuracy, and
# the number of examples ('num').
with open('/home/t-xiaden/workspace/featurestorage/data/spider-20190205/dev.json','r') as dev_file:
    origs = json.load(dev_file)
with open('/home/t-xiaden/workspace/NL2CodeOverData/logdirs/bert_value_run_0_true_1_new_nocolvalue-step40000.eval', 'r') as eval_file:
    results = json.load(eval_file)
db_result = {}
for i, per_item in enumerate(results['per_item']):
    bucket = db_result.setdefault(
        origs[i]['db_id'],
        {'exec': [], 'exact': [], 'select': [], 'where': [], 'where(with value)': []},
    )
    bucket['exec'].append(per_item['exec'])
    bucket['exact'].append(per_item['exact'])
    for metric in ('select', 'where', 'where(with value)'):
        bucket[metric].append(per_item['partial'][metric]['acc'])
# Replace each list of per-example scores by its mean.
for bucket in db_result.values():
    bucket['num'] = len(bucket['exec'])
    for metric in ('exec', 'exact', 'select', 'where', 'where(with value)'):
        bucket[metric] = sum(bucket[metric]) / len(bucket[metric])
    

In [307]:
# Databases from lowest to highest execution accuracy. Each is followed by its
# link statistics from the 'spider_val_<db_id>' bucket.
ranked_dbs = sorted(db_result.items(), key=lambda kv: kv[1]['exec'])
for db_id, metrics in ranked_dbs:
    print((db_id, metrics))
    print(link_analysis_processed['spider_val_' + db_id])

('real_estate_properties', {'exec': 0.25, 'exact': 0.25, 'select': 1.0, 'where': 0.75, 'where(with value)': 0.75, 'num': 4})
{'select_CPM': 0.75, 'select_CEM': 0.5, 'select_ANYCELLMATCH': 0.0, 'select_CELLMATCH': 0.0, 'select_CELLTOKENMATCH': 0.0, 'select_NUMBER': 0.0, 'select_TIME': 0.0, 'where_CPM': 1.0, 'where_CEM': 0.0, 'where_ANYCELLMATCH': 0.5, 'where_CELLMATCH': 0.5, 'where_CELLTOKENMATCH': 0.0, 'where_NUMBER': 0.0, 'where_TIME': 0.0, 'linked_all': (0.03333333333333333, 0.022222222222222223), 'linked_CPM': (0.0375, 0.025), 'linked_CEM': (1.0, 0.0), 'linked_ANYCELLMATCH': (0.0, 1.0), 'linked_CELLMATCH': (0.0, 1.0), 'linked_CELLTOKENMATCH': (0.0, 0.0), 'linked_NUMBER': (0.0, 0.0), 'linked_TIME': (0.0, 0.0)}
('world_1', {'exec': 0.4, 'exact': 0.3333333333333333, 'select': 0.75, 'where': 0.575, 'where(with value)': 0.5416666666666666, 'num': 120})
{'select_CPM': 0.042682926829268296, 'select_CEM': 0.7012195121951219, 'select_ANYCELLMATCH': 0.15853658536585366, 'select_CELLMATCH': 0.

In [300]:
# Display the last per-item evaluation record bound by the `per_item` loop
# over results['per_item'].
per_item

{'predicted': 'SELECT Properties.property_name FROM Properties WHERE Properties.room_count > 1.0 OR Properties.room_count > 1.0',
 'gold': 'SELECT property_name FROM Properties WHERE property_type_code  =  "House" UNION SELECT property_name FROM Properties WHERE property_type_code  =  "Apartment" AND room_count  >  1',
 'predicted_parse_error': False,
 'hardness': 'hard',
 'exact': 0,
 'exact (with val)': 0,
 'partial': {'select': {'acc': 1,
   'rec': 1,
   'f1': 1,
   'label_total': 1,
   'pred_total': 1},
  'select(no AGG)': {'acc': 1,
   'rec': 1,
   'f1': 1,
   'label_total': 1,
   'pred_total': 1},
  'where': {'acc': 0, 'rec': 0, 'f1': 0, 'label_total': 1, 'pred_total': 2},
  'where(no OP)': {'acc': 0,
   'rec': 0,
   'f1': 0,
   'label_total': 1,
   'pred_total': 2},
  'where(with value)': {'acc': 0,
   'rec': 0,
   'f1': 0,
   'label_total': 1,
   'pred_total': 2},
  'group(no Having)': {'acc': 1,
   'rec': 1,
   'f1': 1,
   'label_total': 0,
   'pred_total': 0},
  'group': {'ac