In [1]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Initialize the tokenizer from Hugging Face Transformers library
tokenizer = T5Tokenizer.from_pretrained('t5-small')

  from .autonotebook import tqdm as notebook_tqdm
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [2]:
# Load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = T5ForConditionalGeneration.from_pretrained('cssupport/t5-small-awesome-text-to-sql')
model = model.to(device)
model.eval()

T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Drop

In [3]:
def generate_sql(input_prompt):
    # Tokenize the input prompt
    inputs = tokenizer(input_prompt, padding=True, truncation=True, return_tensors="pt").to(device)
    
    # Forward pass
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=512)
    
    # Decode the output IDs to a string (SQL query in this case)
    generated_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    return generated_sql

In [4]:
# Feeding the table schema as a prompt

user_input = "get all the clinics with zip code 27"

table_schema = """CREATE TABLE dentists (
    dentist_name text, 
    clinic_name text,
    address_town text,
    citytown text,
    province text,
    region text
    zip_code real,
    contact_number text,
    schedule text,
    id integer,
    by_appointment int,
    monday_start int ,
    monday_end int ,
    tuesday_start int ,
    tuesday_end int ,
    wednesday_start int ,
    wednesday_end int ,
    thursday_start int ,
    thursday_end int ,
    friday_start int ,
    friday_end int ,
    saturday_start int ,
    saturday_end int ,
    sunday_start int ,
    sunday_end int)"""

prompt = f"tables: {table_schema}\n \nquery for: {user_input}"

In [5]:
# Test the function
generated_sql = generate_sql(prompt)

print(f"The generated SQL query is: {generated_sql}")

The generated SQL query is: SELECT clinic_name FROM dentists WHERE zip_code = 27


In [6]:
user_input = "Show me clinics that are open on Monday"
prompt = f"tables: {table_schema}\n \nquery for: {user_input}"

generated_sql_2 = generate_sql(prompt)

print(f"The generated SQL query is: {generated_sql_2}")

The generated SQL query is: SELECT clinic_name FROM dentists WHERE NOT dentist_name IN (SELECT dentist_name FROM dentists WHERE monday_start = 'Monday')


In [7]:
user_input = "Show me clinics in Malabon"
prompt = f"tables: {table_schema}\n \nquery for: {user_input}"

generated_sql_2 = generate_sql(prompt)

print(f"The generated SQL query is: {generated_sql_2}")

The generated SQL query is: SELECT clinic_name FROM dentists WHERE city = "Malabon"


This is a smaller model finetuned to convert English statement to SQL queries. 
The demo only shows performance of model when fed with database schema only.
Maybe this will have improved performance when fine tuned with own english-to-sql dataset.