In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset

In [2]:
# Keep using GPU 1 when at least two GPUs are present. "cuda:1" is an invalid
# device ordinal on a single-GPU machine, so fall back to GPU 0 there, and to
# the CPU when CUDA is not available at all.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [3]:
# WikiSQL from the Hugging Face Hub; the cells below read its "test" split.
hf_dataset_name = "wikisql"
dataset = load_dataset(path=hf_dataset_name)

In [4]:
# FLAN-T5 base: the model weights and the tokenizer come from the same repo id.
model_name = "google/flan-t5-base"
# nn.Module.to() returns the module itself, so the device move chains onto the load.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

# No Prompt Engineering

In [6]:
# Baseline: feed the raw natural-language question with no instruction at all.
example_indices = [40, 200]

dash_line = "-".join(" " for _ in range(100))

for i, index in enumerate(example_indices):
    example = dataset["test"][index]
    question = example["question"]
    sql_query = example["sql"]["human_readable"]

    # Named `inputs` (not `input`) so the builtin `input()` stays usable in this kernel.
    inputs = tokenizer(question, return_tensors="pt").to(device)
    # Unpacking the encoding forwards the attention mask to generate() as well.
    output = tokenizer.decode(
        model.generate(**inputs, max_new_tokens=50)[0],
        skip_special_tokens=True,
    )

    print(dash_line)
    print("Example", i+1)
    print(dash_line)
    print(f"INPUT_PROMPT:\n{question}")
    print(dash_line)
    print(f"BASELINE SQL QUERY:\n{sql_query}")
    print(dash_line)
    print(f"MODEL-GENERATION:\n{output}")
    print(dash_line)

 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
Example 1
 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
NATURAL LANGUAGE QUERY:
Which Frequency is used for WEGP calls?
 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
BASELINE SQL QUERY:
SELECT Frequency FROM table WHERE Calls = WEGP
 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
MODEL-GENERATION:
dvd
 - - - - - - - - - - - - - - - - - - 

# Zero Shot Inference with an Instruction Prompt

In [7]:
# Zero-shot: wrap the question in a short natural-language instruction.
dash_line = "-".join(" " for _ in range(100))

for i, index in enumerate(example_indices):
    example = dataset["test"][index]
    question = example["question"]
    sql_query = example["sql"]["human_readable"]

    # The closing quotes are indented, so the prompt ends with "\n    ".
    prompt = f"""
Translate this query into SQL:

{question}

SQL:
    """

    # Named `inputs` (not `input`) so the builtin `input()` stays usable in this kernel.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Unpacking the encoding forwards the attention mask to generate() as well.
    output = tokenizer.decode(
        model.generate(**inputs, max_new_tokens=50)[0],
        skip_special_tokens=True,
    )

    print(dash_line)
    print("Example", i+1)
    print(dash_line)
    print(f"INPUT_PROMPT:\n{prompt}")
    print(dash_line)
    print(f"BASELINE SQL QUERY:\n{sql_query}")
    print(dash_line)
    print(f"MODEL-GENERATION:\n{output}")
    print(dash_line)

 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
Example 1
 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
INPUT_PROMPT:

Translate this query into SQL:

Which Frequency is used for WEGP calls?

SQL:
    
 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
BASELINE SQL QUERY:
SELECT Frequency FROM table WHERE Calls = WEGP
 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
MODEL-GENERATION:
Which F

# Zero Shot Inference with FLAN-T5 Prompt Template

In [8]:
# Zero-shot with a "Task / Question / SQL Query" template.
dash_line = "-".join(" " for _ in range(100))

for i, index in enumerate(example_indices):
    example = dataset["test"][index]
    question = example["question"]
    sql_query = example["sql"]["human_readable"]

    # The closing quotes are indented, so the prompt ends with "\n    ".
    prompt = f"""
Task: Convert the following natural language question into an SQL query.


Question: {question}


SQL Query:
    """

    # Named `inputs` (not `input`) so the builtin `input()` stays usable in this kernel.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Unpacking the encoding forwards the attention mask to generate() as well.
    output = tokenizer.decode(
        model.generate(**inputs, max_new_tokens=50)[0],
        skip_special_tokens=True,
    )

    print(dash_line)
    print("Example", i+1)
    print(dash_line)
    print(f"INPUT_PROMPT:\n{prompt}")
    print(dash_line)
    print(f"BASELINE SQL QUERY:\n{sql_query}")
    print(dash_line)
    print(f"MODEL-GENERATION:\n{output}")
    print(dash_line)

 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
Example 1
 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
INPUT_PROMPT:

Task: Convert the following natural language question into an SQL query.


Question: Which Frequency is used for WEGP calls?


SQL Query:
    
 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
BASELINE SQL QUERY:
SELECT Frequency FROM table WHERE Calls = WEGP
 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 

# One Shot Inference

In [9]:
def create_prompt(example_indices_ref, example_index):
    """Assemble an in-context text-to-SQL prompt from WikiSQL test rows.

    Every index in ``example_indices_ref`` becomes a solved demonstration
    (question followed by its human-readable SQL). The row at
    ``example_index`` is appended last with an empty "SQL Query:" slot for
    the model to complete. An empty ``example_indices_ref`` therefore yields
    the plain zero-shot template.

    Rows are read from the notebook-level ``dataset["test"]`` split.

    Returns:
        str: the concatenated prompt. Every section starts with a newline and
        ends with a newline followed by four spaces.
    """
    instruction = "Task: Convert the following natural language question into an SQL query."

    def render_section(question_text, answer_text):
        # Demonstrations and the final query share one layout; only the text
        # after "SQL Query:" differs (" <sql>" for demonstrations, "" for the query).
        return (
            f"\n{instruction}\n"
            "\n\n"
            f"Question: {question_text}\n"
            "\n\n"
            f"SQL Query:{answer_text}\n"
            "    "
        )

    test_rows = dataset["test"]

    sections = []
    for ref_index in example_indices_ref:
        ref_row = test_rows[ref_index]
        sections.append(
            render_section(ref_row["question"], f" {ref_row['sql']['human_readable']}")
        )

    target_question = test_rows[example_index]["question"]
    sections.append(render_section(target_question, ""))

    return "".join(sections)

In [11]:
# One-shot: prepend a single solved example (test row 20) before each query.
example_indices_ref = [20]

# Demonstrations come from the same test split being evaluated, so a target row
# that is also a demonstration would leak its own answer into the prompt.
assert not set(example_indices_ref) & set(example_indices), "demonstration rows overlap evaluated rows"

dash_line = "-".join(" " for _ in range(100))

for i, index in enumerate(example_indices):
    prompt = create_prompt(example_indices_ref, index)
    sql_query = dataset["test"][index]["sql"]["human_readable"]

    # Named `inputs` (not `input`) so the builtin `input()` stays usable in this kernel.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Unpacking the encoding forwards the attention mask to generate() as well.
    output = tokenizer.decode(
        model.generate(**inputs, max_new_tokens=50)[0],
        skip_special_tokens=True,
    )

    print(dash_line)
    print("Example", i+1)
    print(dash_line)
    print(f"INPUT_PROMPT:\n{prompt}")
    print(dash_line)
    print(f"BASELINE SQL QUERY:\n{sql_query}")
    print(dash_line)
    print(f"MODEL-GENERATION:\n{output}")
    print(dash_line)

 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
Example 1
 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
INPUT_PROMPT:

Task: Convert the following natural language question into an SQL query.


Question: If you are a pilot officer in the commonwealth then what will you called as in the US air force?


SQL Query: SELECT US Air Force equivalent FROM table WHERE Commonwealth equivalent = Pilot Officer
    
Task: Convert the following natural language question into an SQL query.


Question: Which Frequency is used for WEGP calls?


SQL Query:
    
 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

# Few Shot Inference

In [13]:
# Few-shot: prepend two solved examples (test rows 10 and 20) before each query.
example_indices_ref = [10, 20]

# Demonstrations come from the same test split being evaluated, so a target row
# that is also a demonstration would leak its own answer into the prompt.
assert not set(example_indices_ref) & set(example_indices), "demonstration rows overlap evaluated rows"

dash_line = "-".join(" " for _ in range(100))

for i, index in enumerate(example_indices):
    prompt = create_prompt(example_indices_ref, index)
    sql_query = dataset["test"][index]["sql"]["human_readable"]

    # Named `inputs` (not `input`) so the builtin `input()` stays usable in this kernel.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Unpacking the encoding forwards the attention mask to generate() as well.
    output = tokenizer.decode(
        model.generate(**inputs, max_new_tokens=50)[0],
        skip_special_tokens=True,
    )

    print(dash_line)
    print("Example", i+1)
    print(dash_line)
    print(f"INPUT_PROMPT:\n{prompt}")
    print(dash_line)
    print(f"BASELINE SQL QUERY:\n{sql_query}")
    print(dash_line)
    print(f"MODEL-GENERATION:\n{output}")
    print(dash_line)

 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
Example 1
 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
INPUT_PROMPT:

Task: Convert the following natural language question into an SQL query.


Question: How many different nationalities do the players of New Jersey Devils come from?


SQL Query: SELECT COUNT Nationality FROM table WHERE NHL team = New Jersey Devils
    
Task: Convert the following natural language question into an SQL query.


Question: If you are a pilot officer in the commonwealth then what will you called as in the US air force?


SQL Query: SELECT US Air Force equivalent FROM table WHERE Commonwealth equivalent = Pilot Officer
    
Task: Convert the following natur