In [None]:
# Install the third-party packages for this notebook into the kernel's environment.
# NOTE(review): `trino` is imported below but is not listed here — presumably it is
# pulled in as a dependency of pystarburst; TODO confirm.
# NOTE(review): bitsandbytes and accelerate are not referenced by the cells visible
# here — verify they are still needed.
%pip install torch transformers bitsandbytes accelerate sqlparse pystarburst python-dotenv

In [16]:
import torch
import os
import trino
from transformers import AutoTokenizer, AutoModelForCausalLM
from dotenv import load_dotenv

In [None]:
from trino.dbapi import connect

# Load connection settings from a local .env file into the process environment.
load_dotenv()

# Open a Trino DB-API connection. Every setting is read from the environment;
# a variable that is not set is passed through as None.
# NOTE(review): load_dotenv() presumably keeps variables that already exist in the
# environment, so an OS-level "username" could win over the .env value — TODO confirm.
conn = connect(
    host=os.environ.get("host"),
    port=os.environ.get("port"),
    http_scheme=os.environ.get("http_scheme"),
    catalog=os.environ.get("catalog"),
    schema=os.environ.get("schema"),
    auth=trino.auth.BasicAuthentication(
        os.environ.get("username"),
        os.environ.get("password"),
    ),
)

# The cursor is kept open at module level because a later cell reuses it to run
# the generated SQL.
cur = conn.cursor()
cur.execute("show create table sample.burstbank.account")
rows = cur.fetchall()

# Join every returned cell into the schema text instead of overwriting the variable
# on each iteration (which kept only the last cell). For a single-cell result this
# yields exactly that cell; for an empty result it yields "".
myString = "\n".join(str(cell) for row in rows for cell in row)

In [None]:
# Hugging Face checkpoint used for text-to-SQL generation.
model_name = "defog/sqlcoder-7b-2"

# The tokenizer is loaded without an explicit cache directory, while the model
# is loaded with cache_dir="./models/".
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="./models/")

# `prompt` is a str.format template: {question} is substituted later. The schema
# text is inserted verbatim, so any literal braces it contains are doubled here;
# otherwise format() would treat them as placeholders and raise or mis-substitute.
# After formatting, the doubled braces collapse back to the original characters.
schema_text = myString.replace("{", "{{").replace("}", "}}")

prompt = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'
- Remember that revenue is price multiplied by quantity
- Remember that cost is supply_price multiplied by quantity

### Database Schema
This query will run on a database whose schema is represented in this string:
"""+schema_text+"""
### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""

#print(prompt)

In [20]:
import sqlparse

def generate_query(question, max_new_tokens=400):
    """Generate a re-indented SQL query for a natural-language question.

    Fills the module-level ``prompt`` template with ``question``, runs it through
    the module-level ``tokenizer`` and ``model`` with sampling disabled and a
    single beam, and formats the text that follows the last ``[SQL]`` marker
    with ``sqlparse``.

    Args:
        question: Natural-language question to turn into SQL.
        max_new_tokens: Upper bound on the number of newly generated tokens.
            Defaults to 400, the value that was previously hard-coded.

    Returns:
        The generated SQL text as returned by ``sqlparse.format(..., reindent=True)``.
    """
    updated_prompt = prompt.format(question=question)

    # Put the inputs on the device the model lives on rather than assuming CPU,
    # so the function keeps working if the model is moved to an accelerator.
    inputs = tokenizer(updated_prompt, return_tensors="pt").to(model.device)

    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        # The end-of-sequence token doubles as the padding token here.
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_beams=1,
    )
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Keep only the text after the last [SQL] marker — presumably the decoded
    # output still begins with the prompt, which ends with that marker.
    sql_text = outputs[0].split("[SQL]")[-1]
    return sqlparse.format(sql_text, reindent=True)

In [None]:
# Natural-language question to translate; the resulting SQL is printed so it can
# be inspected before it is executed in the next cell.
question = "What is the average credit card balance?"
generated_sql = generate_query(question=question)
print(generated_sql)

In [22]:
# Remove only trailing semicolons/whitespace instead of blindly slicing off the
# last character. This leaves a statement without a trailing ";" intact and is
# idempotent, so re-running this cell does not eat further into the SQL.
generated_sql = generated_sql.strip().rstrip(";").rstrip()
if not generated_sql:
    raise ValueError("The model did not return a SQL statement to execute.")

# NOTE(review): this executes model-generated SQL as-is; inspect the printed
# statement before running it against anything other than a sandbox.
cur.execute(generated_sql)
rows = cur.fetchall()

print(rows)

[[None]]
