In [5]:
from langchain_community.llms import Ollama
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from sqlalchemy.exc import OperationalError

llm = Ollama(model='llama3')


db = SQLDatabase.from_uri("sqlite:///../sqlite_db/northwind.db")


def get_db_schema(_):
    return db.get_table_info()


def run_query(query):
    try:
        return db.run(query)
    except (OperationalError, Exception) as e:
        return str(e)


gen_sql_prompt = ChatPromptTemplate.from_messages([
    ('system', 'Based on the table schema below, write a SQL query that would answer the user\'s question: {db_schema}'),
    ('user', 'Please generate a SQL query for the following question: "{input}". \
     The query should be formatted as follows without any additional explanation: \
     SQL> <sql_query>\
    '),
])


class SqlQueryParser(StrOutputParser):
    def parse(self, s):
        r = s.split('SQL> ')
        if len(r) > 0:
            return r[1]
        return s


gen_query_chain = (
    RunnablePassthrough.assign(db_schema=get_db_schema)
    | gen_sql_prompt
    | llm
    | SqlQueryParser()
)


gen_answer_prompt = ChatPromptTemplate.from_template("""
Based on the provided question, SQL query, and query result, write a natural language response.
No additional explanations should be included.

Question: {input}
SQL Query: {query}
Query Result: {result}

The response should be formatted as follows:
'''
Executed: {query}
Answer: <answer>
'''
""")


chain = (
    RunnablePassthrough.assign(query=gen_query_chain).assign(
        result=lambda x: run_query(x["query"]),
    )
    | gen_answer_prompt
    | llm
)

input_text = input('>>> ')
while input_text.lower() != 'bye':
    if input_text:
        response = chain.invoke({
            'input': input_text,
        })
        print(response)
    input_text = input('>>> ')


CREATE TABLE "Categories" (
	"CategoryID" INTEGER, 
	"CategoryName" TEXT, 
	"Description" TEXT, 
	"Picture" BLOB, 
	PRIMARY KEY ("CategoryID")
)

/*
3 rows from Categories table:
CategoryID	CategoryName	Description	Picture
1	Beverages	Soft drinks, coffees, teas, beers, and ales	b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x02\x00\x00d\x00d\x00\x00\xff\xec\x00\x11Ducky\x00\x01\x00\x0
2	Condiments	Sweet and savory sauces, relishes, spreads, and seasonings	b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x02\x00\x00d\x00d\x00\x00\xff\xec\x00\x11Ducky\x00\x01\x00\x0
3	Confections	Desserts, candies, and sweet breads	b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x02\x00\x00d\x00d\x00\x00\xff\xec\x00\x11Ducky\x00\x01\x00\x0
*/


CREATE TABLE "CustomerCustomerDemo" (
	"CustomerID" TEXT NOT NULL, 
	"CustomerTypeID" TEXT NOT NULL, 
	PRIMARY KEY ("CustomerID", "CustomerTypeID"), 
	FOREIGN KEY("CustomerTypeID") REFERENCES "CustomerDemographics" ("CustomerTypeID"), 
	FOREIGN KEY("CustomerID") REFERENCES "Customers" ("Cu

IndexError: list index out of range