In [2]:
import json
import re
import os
import pandas as pd
import time

import sqlvalidator
import googletrans as gt

from datasets import load_dataset
from langdetect import detect


def translate(source, target_lang='en'):
    """Translate ``source`` into ``target_lang`` through the ``googletrans`` module.

    Args:
        source: Value to translate; forwarded unchanged as the first argument.
        target_lang: Target language code, forwarded as the second argument.
            Defaults to ``'en'``.

    Returns:
        Whatever ``gt.translate`` returns for the two arguments.
    """
    # NOTE(review): this relies on the installed googletrans package exposing a
    # module-level ``translate`` callable -- confirm against the pinned version.
    translated = gt.translate(source, target_lang)
    return translated

# Directory passed as `cache_dir` to `load_dataset` further down in the notebook.
CACHE_DIR = "../cache"
# Directory containing the Spider JSON files read by this notebook
# (tables.json, train_spider.json, train_others.json, dev.json).
SPIDER_DATASET_DIR = "../dataset/spider/"

In [2]:
def find_primary_keys_MYSQL_like(db_name, spider_primary):
    """Format the primary keys of one database as ``[table.column,...]\\n``.

    Args:
        db_name: Value matched against the ``'Database name'`` column.
        spider_primary: DataFrame with the columns ``'Database name'``,
            ``'Table Name'`` and ``'Primary Key'``.

    Returns:
        A string such as ``"[head.head_ID,department.Department_ID]\\n"``.
        A database with no matching rows yields ``"[]\\n"``.
    """
    matching_rows = spider_primary[spider_primary['Database name'] == db_name]
    qualified_keys = [
        f"{table_name}.{primary_key}"
        for table_name, primary_key in zip(
            matching_rows['Table Name'], matching_rows['Primary Key']
        )
    ]
    # Joining (instead of appending ',' and slicing off the last character)
    # keeps the opening bracket intact when there are no keys at all.
    return "[" + ",".join(qualified_keys) + "]\n"
def creating_schema(DATASET_JSON):
    """Flatten a Spider-style ``tables.json`` file into three DataFrames.

    Args:
        DATASET_JSON: Path (or anything ``pd.read_json`` accepts) to a JSON
            file whose records carry ``db_id``, ``table_names_original``,
            ``column_names_original``, ``column_types``, ``foreign_keys``,
            ``primary_keys``, ``column_names`` and ``table_names``.

    Returns:
        A tuple ``(spider_schema, spider_primary, spider_foreign)``:

        * ``spider_schema`` -- one row per column, with the columns
          ``'Database name'``, ``' Table Name'``, ``' Field Name'`` and
          ``' Type'`` (the leading spaces are part of the column names).
          A column whose table index is ``-1`` is expanded into one
          ``'*'`` / ``'text'`` row per table.
        * ``spider_primary`` -- one row per primary key, with the columns
          ``'Database name'``, ``'Table Name'`` and ``'Primary Key'``.
        * ``spider_foreign`` -- one row per foreign-key pair, naming both
          tables and both columns.
    """
    databases = pd.read_json(DATASET_JSON).drop(['column_names', 'table_names'], axis=1)

    schema_rows = []
    primary_rows = []
    foreign_rows = []

    for _, db in databases.iterrows():
        table_names = db['table_names_original']
        columns = db['column_names_original']
        column_types = db['column_types']
        foreign_pairs = db['foreign_keys']
        primary_indices = db['primary_keys']

        for (table_idx, column_name), column_type in zip(columns, column_types):
            if table_idx == -1:
                # Table index -1 is not tied to a table: emit a '*' row per table.
                schema_rows.extend(
                    [db['db_id'], table_name, '*', 'text'] for table_name in table_names
                )
            else:
                schema_rows.append(
                    [db['db_id'], table_names[table_idx], column_name, column_type]
                )

        # Keys are stored as positions into the column list.
        for column_idx in primary_indices:
            table_idx, column_name = columns[column_idx]
            primary_rows.append([db['db_id'], table_names[table_idx], column_name])

        for source_idx, target_idx in foreign_pairs:
            source_table_idx, source_column = columns[source_idx]
            target_table_idx, target_column = columns[target_idx]
            foreign_rows.append([
                db['db_id'],
                table_names[source_table_idx],
                table_names[target_table_idx],
                source_column,
                target_column,
            ])

    spider_schema = pd.DataFrame(
        schema_rows,
        columns=['Database name', ' Table Name', ' Field Name', ' Type'],
    )
    spider_primary = pd.DataFrame(
        primary_rows,
        columns=['Database name', 'Table Name', 'Primary Key'],
    )
    spider_foreign = pd.DataFrame(
        foreign_rows,
        columns=[
            'Database name', 'First Table Name', 'Second Table Name', 'First Table Foreign Key',
            'Second Table Foreign Key'
        ],
    )
    return spider_schema, spider_primary, spider_foreign

In [3]:
def convert_type_to_sql_type(type):
    """Map a column type label to a SQL type name.

    Args:
        type: Type label to translate (the parameter name shadows the builtin
            and is kept for call compatibility).

    Returns:
        ``'VARCHAR'``, ``'INTEGER'``, ``'DATETIME'``, ``'BOOLEAN'`` or
        ``'FLOAT'``. Any label not listed below falls back to ``'VARCHAR'``.
    """
    # (labels, SQL type) pairs, checked in order; the first match wins.
    label_groups = (
        (('text',), 'VARCHAR'),
        (('integer', 'number', 'int'), 'INTEGER'),
        (('time',), 'DATETIME'),
        (('boolean',), 'BOOLEAN'),
        (('real', 'float', 'double'), 'FLOAT'),
        (('others',), 'BOOLEAN'),
    )
    for labels, sql_type in label_groups:
        if type in labels:
            return sql_type
    return 'VARCHAR'

def get_context_with_db_name(db_name, spider_schema, spider_primary, spider_foreign):
    """Render the schema of one database as ``CREATE TABLE`` statements.

    Args:
        db_name: Value matched against the ``'Database name'`` column.
        spider_schema: DataFrame with the columns ``'Database name'``,
            ``' Table Name'``, ``' Field Name'`` and ``' Type'`` (leading
            spaces included).
        spider_primary: Not used; kept so existing call sites keep working.
        spider_foreign: Not used; kept so existing call sites keep working.

    Returns:
        One ``CREATE TABLE name (col type, ...)`` statement per table, joined
        with ``"; "``. Tables appear in ``groupby`` key order and column types
        are emitted exactly as stored in ``' Type'``. Returns ``""`` when no
        row matches ``db_name``.
    """
    db_rows = spider_schema[spider_schema['Database name'] == db_name]

    statements = []
    for table_name, table_rows in db_rows.groupby(' Table Name'):
        column_defs = []
        for col_name, col_type in zip(table_rows[' Field Name'], table_rows[' Type']):
            if col_name == '*':
                # The wildcard pseudo-column is not a real column.
                continue
            if ' ' in col_name:
                # Quote identifiers that contain spaces.
                col_name = '"' + col_name + '"'
            column_defs.append(col_name + " " + col_type)
        # Joining the kept definitions places separators only between real
        # columns, so a skipped '*' in last position cannot leave a dangling
        # ", " before the closing parenthesis.
        statements.append("CREATE TABLE " + table_name + " (" + ", ".join(column_defs) + ")")

    return "; ".join(statements)

In [21]:
def preprocess_text(text):
    """Normalize whitespace in a question or SQL string.

    Steps, in order:

    1. Collapse every whitespace run (newlines included) to a single space.
    2. Drop whitespace directly after ``(`` and directly before ``)``.
    3. Drop the space in front of a comma.

    Leading and trailing spaces are not stripped, and a free-standing
    ``" . "`` is left as is.

    Args:
        text: String to normalize.

    Returns:
        The normalized string.
    """
    # `\s+` already matches newlines, so no separate newline replacement is
    # needed before collapsing.
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'\(\s+', '(', text)
    text = re.sub(r'\s+\)', ')', text)
    # After collapsing, no double space exists, so patterns requiring two
    # consecutive spaces (' ,  ', ' .  ') could never match; only the
    # single-space comma fix has any effect.
    return text.replace(' ,', ',')

In [22]:
# Build the schema, primary-key and foreign-key tables from tables.json.
tables_json_path = os.path.join(SPIDER_DATASET_DIR, 'tables.json')
spider_schema, spider_primary, spider_foreign = creating_schema(tables_json_path)

In [23]:
# Load train_spider.json, then report its size and preview the first rows.
train_spider = pd.read_json(os.path.join(SPIDER_DATASET_DIR, 'train_spider.json'))
# Inspect the frame that was just loaded; `train_spider` is also the name the
# later processing loop iterates over.
print(len(train_spider))
train_spider.head()

7000


Unnamed: 0,db_id,query,query_toks,query_toks_no_value,question,question_toks,sql
0,department_management,SELECT count(*) FROM head WHERE age > 56,"[SELECT, count, (, *, ), FROM, head, WHERE, ag...","[select, count, (, *, ), from, head, where, ag...",How many heads of the departments are older th...,"[How, many, heads, of, the, departments, are, ...","{'from': {'table_units': [['table_unit', 1]], ..."
1,department_management,"SELECT name , born_state , age FROM head ORD...","[SELECT, name, ,, born_state, ,, age, FROM, he...","[select, name, ,, born_state, ,, age, from, he...","List the name, born state and age of the heads...","[List, the, name, ,, born, state, and, age, of...","{'from': {'table_units': [['table_unit', 1]], ..."
2,department_management,"SELECT creation , name , budget_in_billions ...","[SELECT, creation, ,, name, ,, budget_in_billi...","[select, creation, ,, name, ,, budget_in_billi...","List the creation year, name and budget of eac...","[List, the, creation, year, ,, name, and, budg...","{'from': {'table_units': [['table_unit', 0]], ..."
3,department_management,"SELECT max(budget_in_billions) , min(budget_i...","[SELECT, max, (, budget_in_billions, ), ,, min...","[select, max, (, budget_in_billions, ), ,, min...",What are the maximum and minimum budget of the...,"[What, are, the, maximum, and, minimum, budget...","{'from': {'table_units': [['table_unit', 0]], ..."
4,department_management,SELECT avg(num_employees) FROM department WHER...,"[SELECT, avg, (, num_employees, ), FROM, depar...","[select, avg, (, num_employees, ), from, depar...",What is the average number of employees of the...,"[What, is, the, average, number, of, employees...","{'from': {'table_units': [['table_unit', 0]], ..."


In [24]:
# Load train_others.json, then report its size and preview the first rows.
train_others_path = os.path.join(SPIDER_DATASET_DIR, 'train_others.json')
train_others = pd.read_json(train_others_path)
print(len(train_others))
train_others.head()

1659


Unnamed: 0,db_id,query,query_toks,query_toks_no_value,question,question_toks,sql
0,geo,SELECT city_name FROM city WHERE population =...,"[SELECT, city_name, FROM, city, WHERE, populat...","[select, city_name, from, city, where, populat...",what is the biggest city in wyoming,"[what, is, the, biggest, city, in, wyoming]","{'from': {'table_units': [['table_unit', 1]], ..."
1,geo,SELECT city_name FROM city WHERE population =...,"[SELECT, city_name, FROM, city, WHERE, populat...","[select, city_name, from, city, where, populat...",what wyoming city has the largest population,"[what, wyoming, city, has, the, largest, popul...","{'from': {'table_units': [['table_unit', 1]], ..."
2,geo,SELECT city_name FROM city WHERE population =...,"[SELECT, city_name, FROM, city, WHERE, populat...","[select, city_name, from, city, where, populat...",what is the largest city in wyoming,"[what, is, the, largest, city, in, wyoming]","{'from': {'table_units': [['table_unit', 1]], ..."
3,geo,SELECT city_name FROM city WHERE population =...,"[SELECT, city_name, FROM, city, WHERE, populat...","[select, city_name, from, city, where, populat...",where is the most populated area of wyoming,"[where, is, the, most, populated, area, of, wy...","{'from': {'table_units': [['table_unit', 1]], ..."
4,geo,SELECT city_name FROM city WHERE population =...,"[SELECT, city_name, FROM, city, WHERE, populat...","[select, city_name, from, city, where, populat...",which city in wyoming has the largest population,"[which, city, in, wyoming, has, the, largest, ...","{'from': {'table_units': [['table_unit', 1]], ..."


In [25]:
# Sanity check: render the schema context for the 'geo' database.
get_context_with_db_name(
    db_name='geo',
    spider_schema=spider_schema,
    spider_primary=spider_primary,
    spider_foreign=spider_foreign,
)

'CREATE TABLE border_info (state_name text, border text); CREATE TABLE city (city_name text, population number, country_name text, state_name text); CREATE TABLE highlow (state_name text, highest_elevation text, lowest_point text, highest_point text, lowest_elevation text); CREATE TABLE lake (lake_name text, area number, country_name text, state_name text); CREATE TABLE mountain (mountain_name text, mountain_altitude number, country_name text, state_name text); CREATE TABLE river (river_name text, length number, country_name text, traverse text); CREATE TABLE state (state_name text, population number, area number, country_name text, capital text, density number)'

In [28]:
# Build the 'train' split: one {context, question, answer} record per sample
# from both training frames.
processed_dataset = {'train': []}
# The schema context depends only on db_id, so render it once per database
# instead of repeating the filter + groupby for every sample.
train_context_by_db = {}
for database in [train_spider, train_others]:
    for _, sample in database.iterrows():
        db_id = sample['db_id']
        answer = preprocess_text(sample['query'])
        if db_id not in train_context_by_db:
            train_context_by_db[db_id] = get_context_with_db_name(
                db_id, spider_schema, spider_primary, spider_foreign
            )
        context = train_context_by_db[db_id]
        question = preprocess_text(sample['question'])
        processed_dataset['train'].append({
            'context': context,
            'question': question,
            'answer': answer
        })


In [36]:
# Build the 'dev' split from dev.json with the same record layout as 'train'.
dev_spider = pd.read_json(os.path.join(SPIDER_DATASET_DIR, 'dev.json'))
processed_dataset['dev'] = []
# The schema context depends only on db_id, so render it once per database
# instead of repeating the filter + groupby for every sample.
dev_context_by_db = {}
for _, sample in dev_spider.iterrows():
    db_id = sample['db_id']
    answer = preprocess_text(sample['query'])
    if db_id not in dev_context_by_db:
        dev_context_by_db[db_id] = get_context_with_db_name(
            db_id, spider_schema, spider_primary, spider_foreign
        )
    context = dev_context_by_db[db_id]
    question = preprocess_text(sample['question'])
    processed_dataset['dev'].append({
        'context': context,
        'question': question,
        'answer': answer
    })

In [37]:
# Write every split to its own JSON Lines file: one JSON object per line,
# with non-ASCII characters written as-is.
JSON_DIR = "../dataset/SpiderInstruct_raw"
os.makedirs(JSON_DIR, exist_ok=True)
for subset, samples in processed_dataset.items():
    with open(f"{JSON_DIR}/{subset}.jsonl", 'w', encoding='utf-8') as jsonl_file:
        jsonl_file.writelines(
            json.dumps(sample, ensure_ascii=False) + '\n' for sample in samples
        )

In [39]:
import os
import tarfile

def makr_tarfile(output_filename, source_dir):
    """Pack ``source_dir`` into a gzip-compressed tar archive.

    Args:
        output_filename: Path of the ``.tar.gz`` archive to create.
        source_dir: Path of the file or directory to add. It is stored in the
            archive under its base name only.
    """
    with tarfile.open(output_filename, mode="w:gz") as archive:
        member_name = os.path.basename(source_dir)
        archive.add(source_dir, arcname=member_name)
        
# Archive each split's .jsonl file as <subset>.tar.gz inside JSON_DIR.
for subset in processed_dataset:
    archive_path = os.path.join(JSON_DIR, f"{subset}.tar.gz")
    jsonl_path = os.path.join(JSON_DIR, f"{subset}.jsonl")
    makr_tarfile(archive_path, jsonl_path)

In [6]:
from datasets import load_dataset

# Load 'tmnam20/SpiderInstruct' with no configuration name, caching under
# CACHE_DIR, and display the resulting object.
dataset = load_dataset(
    'tmnam20/SpiderInstruct',
    name=None,
    cache_dir=CACHE_DIR,
)
dataset

DatasetDict({
    train: Dataset({
        features: ['answer', 'question', 'context'],
        num_rows: 8659
    })
    validation: Dataset({
        features: ['answer', 'question', 'context'],
        num_rows: 1034
    })
})

In [5]:
# Check that the loaded dataset has a "validation" split.
"validation" in dataset

True