We need to use Python 3.10

In [None]:
%pip install torch transformers bitsandbytes accelerate sqlparse pystarburst python-dotenv

In [None]:
import torch
import os
import trino
from transformers import AutoTokenizer, AutoModelForCausalLM
from dotenv import load_dotenv

In [None]:
from trino.dbapi import connect

# Load the key/value pairs of a local .env file into the process environment.
load_dotenv()

# Every connection setting is read from the environment; a variable that is
# not set comes back as None.
env = os.environ.get
credentials = trino.auth.BasicAuthentication(env("username"), env("password"))

conn = connect(
    host=env("host"),
    port=env("port"),
    http_scheme=env("http_scheme"),
    catalog=env("catalog"),
    schema=env("schema"),
    auth=credentials,
)

# Fetch the DDL of the account table; it is embedded in the prompt below.
cur = conn.cursor()
cur.execute("show create table sample.burstbank.account")
rows = cur.fetchall()

# Keep the last cell of the result set, or "" when nothing came back.
cells = [cell for record in rows for cell in record]
schema_string = cells[-1] if cells else ""

# Get a sample of the data; the rows are embedded in the prompt as examples.
cur = conn.cursor()
cur.execute("select * from sample.burstbank.account limit 10")
column_names = [column[0] for column in cur.description]
print(column_names)

rows = cur.fetchall()
# One text line per row. Each row is formatted directly instead of joining an
# (index, row) pair and slicing the index off again, which only worked while
# the index was a single digit.
table_string = "".join(f"{row}\n" for row in rows)

print(table_string)

In [None]:
# Hugging Face identifier of the text-to-SQL checkpoint used by this notebook.
model_name = "defog/sqlcoder-7b-2"

# The tokenizer and the model are loaded from the same checkpoint; only the
# model call is given an explicit cache directory.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="./models/")

def _escape_braces(text):
    """Double every brace in *text* so str.format() reproduces it literally."""
    return text.replace("{", "{{").replace("}", "}}")


# `prompt` is a str.format() template: {question} is filled in later for each
# question. The schema and the sample rows are data, not template text, so
# their braces are doubled; otherwise a "{" or "}" in the DDL or in a row
# would make the later .format(question=...) call fail or substitute into it.
prompt = """### Task
Generate a SQL query that can run in Presto or Trino to answer [QUESTION]{question}[/QUESTION]


### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'
- The query must be able to run in Trino


### Database Schema
This query will run on a database whose schema is represented in this string:
"""+_escape_braces(schema_string)+"""

### Sample Data
This query will run on a database whose first 10 rows of data is represented in this string:
"""+_escape_braces(table_string)+"""

### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""

# Shows the raw template: {question} is still unfilled and any braces coming
# from the schema or the data appear doubled.
print(prompt)

In [27]:
import sqlparse

def generate_query(question):
    """Generate a SQL statement that answers *question*.

    The module-level ``prompt`` template is filled with the question,
    tokenized, and passed to the module-level ``model``. Generation does not
    sample and uses a single beam, with at most 400 new tokens.

    Args:
        question: Natural-language question substituted into the prompt.

    Returns:
        The text after the last ``[SQL]`` marker of the first decoded
        sequence (the whole sequence if the marker is absent), re-indented
        by ``sqlparse``.
    """
    filled_prompt = prompt.format(question=question)
    encoded = tokenizer(filled_prompt, return_tensors="pt").to("cpu")

    generation_options = {
        "num_return_sequences": 1,
        "eos_token_id": tokenizer.eos_token_id,
        # The end-of-sequence token doubles as the padding token.
        "pad_token_id": tokenizer.eos_token_id,
        "max_new_tokens": 400,
        "do_sample": False,
        "num_beams": 1,
    }
    generated_ids = model.generate(**encoded, **generation_options)

    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    # The decoded text includes the prompt, so keep only what follows the
    # last "[SQL]" marker.
    _, _, sql_text = decoded[0].rpartition("[SQL]")
    return sqlparse.format(sql_text, reindent=True)

In [28]:
# Example run: turn a natural-language question into SQL and show the result.
# `generated_sql` is reused by the next cell, which executes it.
question = "What is the average credit card balance?"
generated_sql = generate_query(question)
print(generated_sql)

In [None]:
# Drop surrounding whitespace and any trailing statement terminator before
# executing (presumably the Trino client rejects a trailing ";"; confirm).
# Unlike slicing off the last character, this leaves a query without a ";"
# intact and is safe to re-run on the same cell.
generated_sql = generated_sql.strip().rstrip(";").rstrip()
cur.execute(generated_sql)
rows = cur.fetchall()

print(rows)