In [180]:
from langchain import PromptTemplate, HuggingFaceHub, LLMChain
import os
from dotenv import load_dotenv
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
# Read key/value pairs from a local .env file into the process environment.
# Presumably this is where the OpenAI / HuggingFace API keys used by the
# cells below come from -- TODO confirm against the project's .env.
load_dotenv()


True

In [266]:
# Build (or reuse) the Chroma index over the table schemas.
# schema.txt is split on the "###" separator only, so each table definition
# stays in one document; a table longer than chunk_size is kept whole and the
# splitter just logs a "Created a chunk of size ..." warning.
loader = TextLoader('./schema.txt')
tables = loader.load()

table_splitter = CharacterTextSplitter(
    separator = "###",
    chunk_size= 500,
    chunk_overlap = 0,
    length_function = len
)

docs = table_splitter.split_documents(tables)
persist_directory = './db'

embeddings = OpenAIEmbeddings()

# Chroma.from_documents adds to whatever collection already lives in
# persist_directory, so re-running this cell against an existing index would
# embed and store every table again and similarity_search would start
# returning duplicates. Embed only when no index is on disk yet; delete ./db
# to force a rebuild after schema.txt changes.
if os.path.isdir(persist_directory) and os.listdir(persist_directory):
    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
else:
    db = Chroma.from_documents(documents=docs, embedding=embeddings, persist_directory=persist_directory)
    db.persist()

Created a chunk of size 735, which is longer than the specified 500
Created a chunk of size 603, which is longer than the specified 500


In [267]:
# Natural-language request that the rest of the notebook turns into SQL.
question = "Show me customers with amount greater than $100000"

# Re-open the index persisted under ./db and fetch the table descriptions
# that are most similar to the question.
persist_directory = './db'
embedding = OpenAIEmbeddings()
vector_store = Chroma(
    embedding_function=embedding,
    persist_directory=persist_directory,
)
docs = vector_store.similarity_search(question)

In [268]:
# Domain hints for the first-pass prompt. They are interpolated into the
# f-string below immediately, so only {context} and {question} remain as
# PromptTemplate variables.
instructions="""
New York City is named NYC
city name is city column of the offices table
use addressline1 from offices table for all address questions
offices does not have customerNumber column
please check the column name before including it in your query 
Only include city if asked specifically
amount is in payments table
orderNumber is not in payments table
amount is prefixed by $ sign in the input question
"""

# Retrieved table schemas ({context}) and the user's question ({question}).
PROMPT_SUFFIX = """Only use as few tables from the following as possible:
{context}
Question: {question}"""

# The SQL dialect named here must match the database the query is executed
# against; this notebook runs the generated SQL through mysql.connector, so
# the model is asked for MySQL rather than PostgreSQL.
_DEFAULT_TEMPLATE = f"""You're a MySQL expert. Given an input question, first identify the relevant tables and columns that are needed to answer, then solve step by step to create a syntactically correct MySQL query. You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only search for relevant columns given the question.
Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Follow the instructions one by one: {instructions}
There might be some errors in the query, so please correct them and rewrite the query.
You MUST double check your query before giving the final answer and give a score (1-100) on how well the query answers the input question.
Format your response as indented JSON with fields Reasoning, SQL, Score and Explanation and do not format the SQL query. Make sure the response can be parsed by json.loads()
json
"""

# Keep a blank line between the question and the instruction text; without it
# the question runs straight into "You're a ... expert".
prompt_template = PROMPT_SUFFIX + "\n\n" + _DEFAULT_TEMPLATE

prompt_template_with_query = PromptTemplate(
        template=prompt_template, input_variables=["question", "context"]
)


In [269]:
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import OpenAI
import json
# First pass: ask the LLM for a SQL query, with the retrieved table schemas
# stuffed into the prompt as context.
chain = load_qa_chain(OpenAI(temperature=0), chain_type="stuff", prompt=prompt_template_with_query)
res = chain({"input_documents": docs, "question": question}, return_only_outputs=True)

# The model is only *asked* to answer in JSON. When it does not, report the
# raw text so the failure can be diagnosed, instead of a bare JSONDecodeError
# that shows nothing of what the model actually returned.
raw_answer = res['output_text']
try:
    query = json.loads(raw_answer)
except json.JSONDecodeError as exc:
    raise ValueError(f"LLM response is not valid JSON:\n{raw_answer}") from exc
print(query['SQL'])

SELECT customerNumber, amount FROM payments WHERE amount > 100000 ORDER BY amount DESC;


In [270]:
# Hints for the review pass; the f-string below inlines them right away, so
# they are not a PromptTemplate variable.
instructions="""
New York City is named NYC
city name is city column of the offices table
use addressline1 from offices table for all address questions
offices does not have customerNumber column
please check the column name before including it in your query
Only include city if asked specifically
amount is in payments table
"""

# Framing for the review: the original question, the first-pass answer
# ({query}) and the retrieved table schemas ({context}).
PROMPT_SUFFIX = """I asked my friend (who is a self proclaimed genius) for help for writing an SQL query for the following question - {question}
This is the answer I got - {query}
I know you are way smarter than him. Did the query answer the question? Please correct any errors and rewrite the query.
Only use as few tables from the following as possible:
{context}"""

# Instructions and the required JSON output format.
_DEFAULT_TEMPLATE = f"""

Follow the instructions if you have doubts: {instructions}
You MUST double check your query before giving the final answer and give a score (1-100) on how well the query answers the input question.
Format your response as indented JSON with fields Reasoning, SQL, Score and Explanation and do not format the SQL query. Make sure the response can be parsed by json.loads()
json
"""

# Full review prompt: framing first, then instructions / output format.
_recheck_prompt_template = "".join([PROMPT_SUFFIX, _DEFAULT_TEMPLATE])

recheck_prompt_template = PromptTemplate(
    input_variables=["question", "query", "context"],
    template=_recheck_prompt_template,
)

In [271]:
import mysql.connector
import decimal

def dec_serializer(o):
    """``default`` hook for ``json.dumps`` that renders ``Decimal`` as float.

    Args:
        o: A value the JSON encoder could not serialize on its own.

    Returns:
        ``float(o)`` when ``o`` is a ``decimal.Decimal``.

    Raises:
        TypeError: For any other type. Falling off the end would return
            ``None``, which ``json.dumps`` writes as ``null`` and so silently
            replaces real values in the output.
    """
    if isinstance(o, decimal.Decimal):
        return float(o)
    raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")
    
# Connection settings come from the environment (e.g. the .env file read by
# load_dotenv()), so the database password is not stored in the notebook.
# MYSQL_PASSWORD is required; the other settings fall back to the demo values.
mydb = mysql.connector.connect(
  host=os.environ.get("MYSQL_HOST", "demo-database-1.cc8nminyw6rb.us-east-1.rds.amazonaws.com"),
  user=os.environ.get("MYSQL_USER", "admin"),
  password=os.environ["MYSQL_PASSWORD"],
  database=os.environ.get("MYSQL_DATABASE", "classicmodels")
)

# NOTE(review): query['SQL'] is LLM-generated text executed verbatim. Run it
# with a read-only database account; nothing here stops a destructive
# statement from being executed.
try:
    cursor = mydb.cursor(dictionary=True)
    try:
        cursor.execute(query['SQL'])
        result = cursor.fetchall()
    finally:
        # Release the cursor even when the generated SQL is invalid.
        cursor.close()
finally:
    # Do not leave a connection to the remote database open for the rest of
    # the kernel session.
    mydb.close()

print(json.dumps(result, default=dec_serializer, indent=4))



[
    {
        "customerNumber": 141,
        "amount": 120166.58
    },
    {
        "customerNumber": 141,
        "amount": 116208.4
    },
    {
        "customerNumber": 124,
        "amount": 111654.4
    },
    {
        "customerNumber": 148,
        "amount": 105743.0
    },
    {
        "customerNumber": 124,
        "amount": 101244.59
    }
]


In [122]:
# Second pass: have the LLM review and, if needed, rewrite the first-pass SQL.
chain = load_qa_chain(OpenAI(temperature=0), chain_type="stuff", prompt=recheck_prompt_template)

# Fill the prompt's {query} slot with the SQL text only. Passing the whole
# parsed dict would paste its Python repr (single-quoted keys plus the
# Reasoning/Score/Explanation fields) into the prompt instead of the query
# that is supposed to be reviewed.
res = chain({"input_documents": docs, "question": question, "query": query['SQL']}, return_only_outputs=True)
print(res)

{'output_text': '\n{\n  \'Reasoning\': \'New York City is named NYC and the address is in the addressLine1 column of the offices table.\', \n  \'SQL\': "SELECT addressLine1 FROM offices WHERE city = \'NYC\';", \n  \'Score\': 100, \n  \'Explanation\': \'This query correctly identifies the relevant table and column and returns the address of the NYC office.\'\n}'}


In [16]:
# Prompt that puts the question first and then cues the model to reason step
# by step ("tesmp" in the cue was a typo for "step").
template = """Question: {question}

Answer: Let's think step by step. """

# Wrap the template; {question} is its only variable.
prompt = PromptTemplate(
    input_variables=["question"],
    template=template,
)

In [17]:
# Hosted google/flan-t5-xl model on the HuggingFace Hub, configured with
# temperature 0 and a max_length of 64.
hub_llm = HuggingFaceHub(
    repo_id="google/flan-t5-xl",
    model_kwargs={"temperature":0, "max_length":64},
)

# Chain that fills the prompt and sends it to the hub model.
llm_chain = LLMChain(llm=hub_llm, prompt=prompt)

In [18]:
# Test question for the HuggingFace chain ("captial" was a typo).
question = "what is the capital of France?"

In [19]:
# Run the chain on the question and show the model's answer.
answer = llm_chain.run(question)
print(answer)

KeyboardInterrupt: 