We need to use Python 3.10

In [1]:
%pip install torch transformers bitsandbytes accelerate sqlparse pystarburst python-dotenv

Collecting sqlparse
  Using cached sqlparse-0.4.4-py3-none-any.whl.metadata (4.0 kB)
Collecting pystarburst
  Using cached pystarburst-0.7.0-py3-none-any.whl.metadata (2.8 kB)
Collecting trino<0.328.0,>=0.327.0 (from pystarburst)
  Using cached trino-0.327.0-py3-none-any.whl.metadata (17 kB)
Collecting urllib3<3.0.0,>=2.2.0 (from pystarburst)
  Using cached urllib3-2.2.1-py3-none-any.whl.metadata (6.4 kB)
Using cached sqlparse-0.4.4-py3-none-any.whl (41 kB)
Using cached pystarburst-0.7.0-py3-none-any.whl (130 kB)
Using cached trino-0.327.0-py3-none-any.whl (49 kB)
Using cached urllib3-2.2.1-py3-none-any.whl (121 kB)
Installing collected packages: urllib3, sqlparse, trino, pystarburst
  Attempting uninstall: urllib3
    Found existing installation: urllib3 1.26.6
    Uninstalling urllib3-1.26.6:
      Successfully uninstalled urllib3-1.26.6
Successfully installed pystarburst-0.7.0 sqlparse-0.4.4 trino-0.327.0 urllib3-2.2.1
Note: you may need to restart the kernel to use updated packages

In [2]:
import torch
import os
import trino
from transformers import AutoTokenizer, AutoModelForCausalLM
from dotenv import load_dotenv

  from .autonotebook import tqdm as notebook_tqdm


In [18]:
from trino.dbapi import connect

load_dotenv()

conn = connect(
    host = os.environ.get("host"),
    port=os.environ.get("port"),
    http_scheme=os.environ.get("http_scheme"),
    catalog=os.environ.get("catalog"),
    schema=os.environ.get("schema"),
    auth= trino.auth.BasicAuthentication(os.environ.get("username"),os.environ.get("password"))
)

cur = conn.cursor()
cur.execute("show create table sample.burstbank.account")
rows = cur.fetchall()
schema_string = ""
for elem in rows:
    for elem1 in elem:
        schema_string = elem1
# print(myString)

# get a sample of the data        
cur = conn.cursor()
cur.execute("select * from sample.burstbank.account limit 10")
column_names = [tup[0] for tup in cur.description]
print(column_names)

rows = cur.fetchall()
table_string = ""
for row in enumerate(rows):
    t = ''.join(map(str, row))
    t = t[1:]
    table_string = table_string + t +"\n"

print(table_string)

['custkey', 'acctkey', 'products', 'cc_number', 'cc_open_date', 'cc_closed_date', 'cc_balance', 'cc_status', 'cc_default', 'mortgage_id', 'mortgage_open_date', 'mortgage_closed_date', 'mortgage_balance', 'mortgage_status', 'mortgage_default', 'auto_loan_id', 'auto_loan_open_date', 'auto_loan_closed_date', 'auto_loan_balance', 'auto_loan_status', 'auto_loan_default']
['1000001', '1217470', 'credit_card,auto_loan', '180045349625167', '2000-07-03', None, 9209.9, 'open', 'N', None, None, None, None, None, None, '5876198', '2017-05-12', '2018-12-13', None, 'closed', 'Y']
['1000002', '1217471', 'credit_card,mortgage,auto_loan', '180086982231350', '2002-07-31', None, 385.68, 'open', 'N', '4649851', '2003-03-05', None, 29175.75, 'open', 'N', '5876199', '2019-07-23', None, 97687.63, 'open', 'N']
['1000003', '1217472', 'credit_card,mortgage', '676129241615', '2014-10-22', '2019-09-22', 0.0, 'closed', 'N', '4649852', '1997-10-20', '2010-07-24', None, 'closed', 'Y', None, None, None, None, None, N

In [23]:
model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
        model_name,
        cache_dir="./models/"
)

prompt = """### Task
Generate a SQL query that can run in Presto or Trino to answer [QUESTION]{question}[/QUESTION]


### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'
- The query must be able to run in Trino


### Database Schema
This query will run on a database whose schema is represented in this string:
"""+schema_string+"""

### Sample Data
This query will run on a database whose first 10 rows of data is represented in this string:
"""+table_string+"""

### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""

print(prompt)

Loading checkpoint shards: 100%|██████████| 3/3 [00:10<00:00,  3.49s/it]


### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

The SQL needs to be run in Trino

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'
- Remember that revenue is price multiplied by quantity
- Remember that cost is supply_price multiplied by quantity

### Database Schema
This query will run on a database whose schema is represented in this string:
CREATE TABLE sample.burstbank.account (
   custkey varchar,
   acctkey varchar,
   products varchar,
   cc_number varchar,
   cc_open_date varchar,
   cc_closed_date varchar,
   cc_balance double,
   cc_status varchar,
   cc_default varchar,
   mortgage_id varchar,
   mortgage_open_date varchar,
   mortgage_closed_date varchar,
   mortgage_balance double,
   mortgage_status varchar,
   mortgage_default varchar,
   auto_loan_id varchar,
   auto_loan_open_date varchar,
   auto_loan_closed_date varchar,
   auto_loan_balance double,
   auto_loan_status varchar,
  

In [24]:
import sqlparse

def generate_query(question):
    updated_prompt = prompt.format(question=question)
    inputs = tokenizer(updated_prompt, return_tensors="pt").to("cpu")
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
    )
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)


    return sqlparse.format(outputs[0].split("[SQL]")[-1], reindent=True)

In [25]:
question = "What is the average credit card balance?"
generated_sql = generate_query(question)
print(generated_sql)


SELECT AVG(a.cc_balance) AS average_credit_card_balance
FROM sample.burstbank.account a
WHERE a.products ilike '%credit_card%';


In [22]:
generated_sql = generated_sql[:-1]
cur.execute(generated_sql)
rows = cur.fetchall()

print(rows)

TrinoUserError: TrinoUserError(type=USER_ERROR, name=SYNTAX_ERROR, message="line 4:18: mismatched input 'ilike'. Expecting: '%', '*', '+', '-', '.', '/', 'AND', 'AT', 'EXCEPT', 'FETCH', 'GROUP', 'HAVING', 'INTERSECT', 'LIMIT', 'OFFSET', 'OR', 'ORDER', 'UNION', 'WINDOW', '[', '||', <EOF>, <predicate>", query_id=20240229_090430_01129_pekug)