This notebook needs to be run with Python 3.10.

In [6]:
%pip install torch transformers bitsandbytes accelerate sqlparse pystarburst python-dotenv

Note: you may need to restart the kernel to use updated packages.


In [7]:
import torch
import os
import trino
from transformers import AutoTokenizer, AutoModelForCausalLM
from dotenv import load_dotenv

In [10]:
from trino.dbapi import connect

# Load connection settings from a local .env file into the process environment.
load_dotenv()

# NOTE(review): load_dotenv() does not override variables that already exist,
# and generic keys such as "username" may already be defined by the OS
# (e.g. on Windows) -- confirm the values read below come from .env.
conn = connect(
    host=os.environ.get("host"),
    port=os.environ.get("port"),
    http_scheme=os.environ.get("http_scheme"),
    catalog=os.environ.get("catalog"),
    schema=os.environ.get("schema"),
    auth=trino.auth.BasicAuthentication(
        os.environ.get("username"), os.environ.get("password")
    ),
)

# Fetch the table DDL; it is later embedded in the prompt as the schema.
cur = conn.cursor()
cur.execute("show create table sample.burstbank.account")
rows = cur.fetchall()
# Keep the last cell of the result set (or "" when nothing came back).
ddl_cells = [cell for row in rows for cell in row]
schema_string = ddl_cells[-1] if ddl_cells else ""

# Get a sample of the data.
cur = conn.cursor()
cur.execute("select * from sample.burstbank.account limit 10")
column_names = [tup[0] for tup in cur.description]
print(column_names)

rows = cur.fetchall()
# One line per row, using the row's own repr. The previous version prefixed
# each row with its enumerate() index and then chopped off one character,
# which left stray digits in the text as soon as there were more than 10 rows.
table_string = "".join(str(row) + "\n" for row in rows)

print(table_string)

['custkey', 'acctkey', 'products', 'cc_number', 'cc_open_date', 'cc_closed_date', 'cc_balance', 'cc_status', 'cc_default', 'mortgage_id', 'mortgage_open_date', 'mortgage_closed_date', 'mortgage_balance', 'mortgage_status', 'mortgage_default', 'auto_loan_id', 'auto_loan_open_date', 'auto_loan_closed_date', 'auto_loan_balance', 'auto_loan_status', 'auto_loan_default']
['1000001', '1217470', 'credit_card,auto_loan', '180045349625167', '2000-07-03', None, 9209.9, 'open', 'N', None, None, None, None, None, None, '5876198', '2017-05-12', '2018-12-13', None, 'closed', 'Y']
['1000002', '1217471', 'credit_card,mortgage,auto_loan', '180086982231350', '2002-07-31', None, 385.68, 'open', 'N', '4649851', '2003-03-05', None, 29175.75, 'open', 'N', '5876199', '2019-07-23', None, 97687.63, 'open', 'N']
['1000003', '1217472', 'credit_card,mortgage', '676129241615', '2014-10-22', '2019-09-22', 0.0, 'closed', 'N', '4649852', '1997-10-20', '2010-07-24', None, 'closed', 'Y', None, None, None, None, None, N

In [11]:
# Download (or load from ./models/) the SQLCoder checkpoint and its tokenizer.
model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
        model_name,
        cache_dir="./models/"
)


def _escape_braces(text):
    """Double literal braces so a later ``str.format`` call leaves them as-is."""
    return text.replace("{", "{{").replace("}", "}}")


# `prompt` is a str.format template: `{question}` is filled in per question.
# The schema and sample data are inserted verbatim, so any braces they contain
# must be escaped, otherwise `prompt.format(...)` would raise or mangle them.
prompt = """### Task
Generate a SQL query that can run in Presto or Trino to answer [QUESTION]{question}[/QUESTION]


### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'
- The query must be able to run in Trino


### Database Schema
This query will run on a database whose schema is represented in this string:
"""+_escape_braces(schema_string)+"""

### Sample Data
This query will run on a database whose first 10 rows of data is represented in this string:
"""+_escape_braces(table_string)+"""

### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""

print(prompt)

tokenizer_config.json: 100%|██████████| 1.84k/1.84k [00:00<00:00, 5.16MB/s]
tokenizer.model: 100%|██████████| 500k/500k [00:00<00:00, 6.73MB/s]
tokenizer.json: 100%|██████████| 1.84M/1.84M [00:00<00:00, 6.13MB/s]
special_tokens_map.json: 100%|██████████| 515/515 [00:00<00:00, 3.24MB/s]
config.json: 100%|██████████| 691/691 [00:00<00:00, 4.48MB/s]
model.safetensors.index.json: 100%|██████████| 23.9k/23.9k [00:00<00:00, 3.21MB/s]
model-00001-of-00003.safetensors: 100%|██████████| 4.94G/4.94G [05:59<00:00, 13.7MB/s]
model-00002-of-00003.safetensors: 100%|██████████| 4.95G/4.95G [06:30<00:00, 12.7MB/s]
model-00003-of-00003.safetensors: 100%|██████████| 3.59G/3.59G [04:24<00:00, 13.5MB/s]
Downloading shards: 100%|██████████| 3/3 [16:55<00:00, 338.66s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:25<00:00,  8.55s/it]
generation_config.json: 100%|██████████| 111/111 [00:00<00:00, 392kB/s]

### Task
Generate a SQL query that can run in Presto or Trino to answer [QUESTION]{question}[/QUESTION]


### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'
- The query must be able to run in Trino


### Database Schema
This query will run on a database whose schema is represented in this string:
CREATE TABLE sample.burstbank.account (
   custkey varchar,
   acctkey varchar,
   products varchar,
   cc_number varchar,
   cc_open_date varchar,
   cc_closed_date varchar,
   cc_balance double,
   cc_status varchar,
   cc_default varchar,
   mortgage_id varchar,
   mortgage_open_date varchar,
   mortgage_closed_date varchar,
   mortgage_balance double,
   mortgage_status varchar,
   mortgage_default varchar,
   auto_loan_id varchar,
   auto_loan_open_date varchar,
   auto_loan_closed_date varchar,
   auto_loan_balance double,
   auto_loan_status varchar,
   auto_loan_default varchar
)
WITH (
   external_location = 's3://galaxy-spa




In [12]:
import sqlparse

def generate_query(question):
    """Generate a SQL query for a natural-language question.

    Fills the module-level ``prompt`` template with ``question``, runs greedy
    decoding on the module-level ``model`` and returns the text that follows
    the last ``[SQL]`` marker, re-indented by ``sqlparse``.

    Args:
        question: The question to answer, as plain text.

    Returns:
        The generated SQL as a formatted string.
    """
    updated_prompt = prompt.format(question=question)
    # Place the inputs on the same device as the model. A hard-coded "mps"
    # fails when the model was loaded on another device (the default is CPU)
    # and on machines without Apple's Metal backend.
    inputs = tokenizer(updated_prompt, return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,  # greedy decoding: deterministic output
        num_beams=1,
    )
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # The decoded text includes the prompt; keep only what follows the final
    # [SQL] marker, i.e. the model's answer.
    return sqlparse.format(outputs[0].split("[SQL]")[-1], reindent=True)

In [13]:
# Natural-language question to translate into SQL.
question = "What is the average credit card balance?"
# Ask the model for a query; the result is reused by the execution cell below.
generated_sql = generate_query(question)
print(generated_sql)

In [None]:
# Remove surrounding whitespace and any trailing semicolon before sending the
# statement (presumably what the original `[:-1]` was for -- confirm). Unlike
# slicing off the last character, this leaves SQL without a semicolon intact
# and gives the same result when the cell is re-run.
generated_sql = generated_sql.strip().rstrip(";").rstrip()
# NOTE(review): this executes model-generated SQL as-is; review the statement
# (or use a read-only account) before running it against real data.
cur.execute(generated_sql)
rows = cur.fetchall()

print(rows)