In [None]:
import os
import re

import pandas as pd
from langchain_community.llms import LlamaCpp
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler
from langchain_core.prompts import PromptTemplate
from sqlalchemy import create_engine

In [None]:

# Prompt template sent to the model. It is filled with `question` and `db_schema`
# via str.format and ends with an opening [SQL] tag for the model to continue.
prompt = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'
- SQL Dialect is SQLite
- Generate only one variation of the query
- Do not suggest alternative variations for the query

### Database Schema
This query will run on a database whose schema is represented in this string:
{db_schema}
### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""

In [None]:
# Echo streamed LLM tokens to stdout so the completion is visible while it is generated.
callback = CallbackManager(handlers=[StreamingStdOutCallbackHandler()])

In [None]:
# Chinook sample database. Set CHINOOK_DB_PATH to point elsewhere instead of
# editing this cell.
db_file_path = os.environ.get(
    "CHINOOK_DB_PATH", "/home/shamit/proj/genai/data/Chinook_Sqlite.sqlite"
)
# By default SQLite creates a new, empty database when the file does not exist,
# which would produce an empty schema string and a meaningless prompt. Fail loudly.
if not os.path.isfile(db_file_path):
    raise FileNotFoundError(f"SQLite database not found: {db_file_path}")
# Open the database read-only: the SQL executed later in this notebook is
# generated by an LLM and must not be able to modify the data.
db_uri = f"sqlite:///file:{db_file_path}?mode=ro&uri=true"
db = SQLDatabase.from_uri(db_uri, sample_rows_in_table_info=0)
db_engine = create_engine(db_uri)
# Schema text (no sample rows) injected into the prompt's {db_schema} slot.
db_schema_str = db.get_table_info()

In [None]:
# Inspect the schema text that will be substituted into the prompt.
print(f"{db_schema_str}")

In [None]:
# Local GGUF model weights. Set LLM_MODEL_PATH to use a different file without
# editing this cell.
llm_file_path = os.environ.get(
    "LLM_MODEL_PATH", "/home/shamit/proj/models/Phi-3-mini-4k-instruct-q4.gguf"
)

In [None]:
# Local llama.cpp model. temperature=0 and a fixed seed are intended to make
# the generated SQL repeatable across runs.
llm = LlamaCpp(
    model_path=llm_file_path,
    n_ctx=4096,
    temperature=0,
    seed=4381,
    # NOTE(review): this exceeds n_ctx; llama-cpp-python caps generation at the
    # context left after the prompt, so this effectively means "no limit" —
    # confirm for the installed version.
    max_tokens=10000,
    verbose=True,  # Verbose is required to pass to the callback manager
    streaming=True,
    # `callbacks` replaces the deprecated `callback_manager` argument.
    callbacks=callback,
)

In [None]:
# First trial question; `output` keeps the raw completion for the parsing cells below.
question_text = "Which customer generated max sales?"
inp = prompt.format(db_schema=db_schema_str, question=question_text)
output = llm.invoke(inp)

In [None]:
# Extract every ```sql fenced block from the completion. The quantifier is
# non-greedy so that several fenced blocks come back as separate items; a greedy
# `.*` with DOTALL would capture everything from the first opening fence to the
# last closing fence, including prose and fences in between.
queries = re.findall(r'```sql(.*?)```', output, flags=re.DOTALL | re.IGNORECASE)

In [None]:
# Show the raw regex matches: a list of strings, empty if no fenced SQL block was found.
queries 

In [None]:
# Run the first extracted query, if there is one; both names stay None otherwise.
qry, df = None, None
if queries:
    qry = queries[0].strip()
    print(qry)
    df = pd.read_sql(qry, db_engine)

In [None]:
# Result of the query from the cell above; None if no SQL was extracted.
df

In [None]:
def _extract_sql_queries(text):
    """Return candidate SQL statements found in an LLM completion.

    Fenced ```sql blocks are preferred. If there are none, fall back to bare
    ``SELECT ...;`` statements. Both patterns are non-greedy, so several
    candidates in one completion come back as separate list items. A greedy
    pattern would instead return one span from the first opener to the last closer.
    """
    fenced = re.findall(r'```sql(.*?)```', text, flags=re.DOTALL | re.IGNORECASE)
    if fenced:
        return fenced
    return re.findall(r'SELECT\s.*?;', text, flags=re.DOTALL)


def answer_question(question):
    """Ask the LLM for a SQL query that answers `question`, then execute it.

    The prompt is built from the module-level `prompt` template and
    `db_schema_str`. The completion comes from `llm`, and the first extracted
    query is printed and run against `db_engine`.

    Args:
        question: natural-language question about the database.

    Returns:
        A DataFrame with the query result, or None when the completion contains
        no recognisable SQL (e.g. the model answered 'I do not know').

    Raises:
        Propagates any exception from ``pd.read_sql`` (e.g. invalid generated SQL).
    """
    inp = prompt.format(question=question, db_schema=db_schema_str)
    output = llm.invoke(inp)
    queries = _extract_sql_queries(output)
    if not queries:
        print("No SQL query found in model output.")
        return None
    qry = queries[0].strip()
    print(qry)
    return pd.read_sql(qry, db_engine)
    

In [None]:
# End-to-end run of the helper on the same question as the manual cells above.
df = answer_question(question="Which customer generated max sales")

In [None]:
# Result of answer_question; None when no SQL was extracted from the completion.
df

In [None]:
# Second question, this time aggregating sales per album.
df = answer_question(question="Which album generated max sales")