In [1]:
import pandas as pd
from transformers import T5Tokenizer, T5ForConditionalGeneration

  from .autonotebook import tqdm as notebook_tqdm


In [41]:
tokenizer = T5Tokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
model = T5ForConditionalGeneration.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")

Downloading (…)ve/main/spiece.model: 100%|██████████| 792k/792k [00:00<00:00, 10.1MB/s]
Downloading (…)cial_tokens_map.json: 1.79kB [00:00, 2.78MB/s]
Downloading (…)okenizer_config.json: 100%|██████████| 25.0/25.0 [00:00<00:00, 13.7kB/s]
Downloading (…)lve/main/config.json: 1.23kB [00:00, 1.78MB/s]
The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.
The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.
Downloading pytorch_model.bin: 100%|██████████| 1.19G/1.19G [00:29<00:00, 40.7MB/s]


In [19]:
# information about the cross-domain tables
with open('spider/tables.json', 'r') as f:
    schema_df = pd.read_json(f)

# datasets
with open('spider/train_spider.json', 'r') as f:
    train_spider = pd.read_json(f)
with open('spider/train_others.json', 'r') as f:
    others_spider = pd.read_json(f)
with open('spider/dev.json', 'r') as f:
    dev_spider = pd.read_json(f)

In [16]:
print(train_spider.columns, len(train_spider))
print(others_spider.columns, len(others_spider))
print(dev_spider.columns, len(dev_spider))
print(schema_df.columns, len(schema_df))

Index(['db_id', 'query', 'query_toks', 'query_toks_no_value', 'question',
       'question_toks', 'sql'],
      dtype='object') 7000
Index(['db_id', 'query', 'query_toks', 'query_toks_no_value', 'question',
       'question_toks', 'sql'],
      dtype='object') 1659
Index(['db_id', 'query', 'query_toks', 'query_toks_no_value', 'question',
       'question_toks', 'sql'],
      dtype='object') 1034
Index(['column_names', 'column_names_original', 'column_types', 'db_id',
       'foreign_keys', 'primary_keys', 'table_names', 'table_names_original'],
      dtype='object') 166


In [22]:
schema = []
f_keys = []
p_keys = []
for index, row in schema_df.iterrows():
    tables = row['table_names_original']
    col_names = row['column_names_original']
    col_types = row['column_types']
    foreign_keys = row['foreign_keys']
    primary_keys = row['primary_keys']
    for col, col_type in zip(col_names, col_types):
        index, col_name = col
        if index == -1:
            for table in tables:
                schema.append([row['db_id'], table, '*', 'text'])
        else:
            schema.append([row['db_id'], tables[index], col_name, col_type])
    for primary_key in primary_keys:
        index, column = col_names[primary_key]
        p_keys.append([row['db_id'], tables[index], column])
    for foreign_key in foreign_keys:
        first, second = foreign_key
        first_index, first_column = col_names[first]
        second_index, second_column = col_names[second]
        f_keys.append([row['db_id'], tables[first_index], tables[second_index], first_column, second_column])
spider_schema = pd.DataFrame(schema, columns=['Database name', ' Table Name', ' Field Name', ' Type'])
spider_primary = pd.DataFrame(p_keys, columns=['Database name', 'Table Name', 'Primary Key'])
spider_foreign = pd.DataFrame(f_keys,
                    columns=['Database name', 'First Table Name', 'Second Table Name', 'First Table Foreign Key',
                                'Second Table Foreign Key'])

In [28]:
print('schema: \n', spider_schema.head())
print('primary key: \n', spider_primary.head())
print('foreign key: \n', spider_foreign.head())

schema: 
   Database name   Table Name      Field Name    Type
0   perpetrator  perpetrator               *    text
1   perpetrator       people               *    text
2   perpetrator  perpetrator  Perpetrator_ID  number
3   perpetrator  perpetrator       People_ID  number
4   perpetrator  perpetrator            Date    text
primary key: 
   Database name   Table Name     Primary Key
0   perpetrator  perpetrator  Perpetrator_ID
1   perpetrator       people       People_ID
2     college_2    classroom        building
3     college_2   department       dept_name
4     college_2       course       course_id
foreign key: 
   Database name First Table Name Second Table Name First Table Foreign Key  \
0   perpetrator      perpetrator            people               People_ID   
1     college_2           course        department               dept_name   
2     college_2       instructor        department               dept_name   
3     college_2          section         classroom          

In [43]:
prefix = 'translate English to SQL:'

dev_questions = list(dev_spider.iloc[:10]['question'])
dev_questions = [prefix + q for q in dev_questions]

inputs = tokenizer(dev_questions, return_tensors="pt", padding=True)

output_tokens = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=64
)

outputs = [tokenizer.decode(i, skip_special_tokens=True) for i in output_tokens]

print(outputs)

['SELECT COUNT Singer(s) FROM table', 'SELECT COUNT Singer(s) FROM table', 'SELECT Name, Country, Age FROM table WHERE Order by Age = oldest to youngest', 'SELECT Name, Country, Age FROM table WHERE Age indescending order = singer', 'SELECT Age (Age) FROM table WHERE Country = france AND Minimum/Max. = minimum', 'SELECT Age (Age) FROM table WHERE Language = french AND Minimum/Maximum = all singers', 'SELECT Name and Release Year FROM table WHERE Name = youngest singer', 'SELECT Names and Release Year FROM table WHERE Name = youngest singer', 'SELECT Country FROM table WHERE Age > 20', 'SELECT Country FROM table WHERE Age > 20']


In [48]:
# Use test-suite-sql-eval-master/evaluation.py for evaluation

queries = list(dev_spider.iloc[:10]['query'])