In [11]:
from langchain.chains import create_sql_query_chain
from langchain.chat_models import ChatOpenAI
from langchain.utilities import SQLDatabase
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnablePassthrough

import os


# Setup for the database
db = SQLDatabase.from_uri("sqlite:///./company.db")

In [16]:
from langchain.prompts import PromptTemplate

TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

{table_info}.

Some examples of SQL queries that correspond to questions are:


Question: {input}"""

CUSTOM_PROMPT = PromptTemplate(
    input_variables=["input", "table_info", "dialect"],
    template=TEMPLATE,
)

In [20]:
sql_query_chain = create_sql_query_chain(ChatOpenAI(), db)
sql_query_chain.dict


<bound method BaseModel.dict of {
  input: RunnableLambda(...),
  top_k: RunnableLambda(...),
  table_info: RunnableLambda(...)
}
| PromptTemplate(input_variables=['input', 'table_info', 'top_k'], template='You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column 

In [21]:
sql_query_chain = create_sql_query_chain(ChatOpenAI(), db, prompt=CUSTOM_PROMPT)
sql_query_chain.dict


<bound method BaseModel.dict of {
  input: RunnableLambda(...),
  top_k: RunnableLambda(...),
  table_info: RunnableLambda(...),
  dialect: RunnableLambda(lambda _: (db.dialect, prompt_to_use))
}
| PromptTemplate(input_variables=['dialect', 'input', 'table_info'], template='Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.\nUse the following format:\n\nQuestion: "Question here"\nSQLQuery: "SQL Query to run"\nSQLResult: "Result of the SQLQuery"\nAnswer: "Final answer here"\n\nOnly use the following tables:\n\n{table_info}.\n\nSome examples of SQL queries that correspond to questions are:\n\n\nQuestion: {input}')
| RunnableBinding(bound=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x000001990FCF3820>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x000001990ECC33D0>, openai_api_key='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'

In [30]:

# Create the SQL query chain with ChatOpenAI

# Function to get the schema
def get_schema(_):
    return db.get_table_info()

# Template for the natural language response
response_template = ChatPromptTemplate.from_template("""
Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}""")

# Full chain combining SQL query generation and natural language response generation
full_chain = (
    RunnablePassthrough.assign(query=sql_query_chain, schema=get_schema)
    | RunnablePassthrough.assign(response=lambda x: db.run(x["query"]))
    | response_template
    | ChatOpenAI()
)

# Example invocation
result = full_chain.invoke({"question": "How many employees are there?"})
result

RateLimitError: Error code: 429 - {'error': {'message': 'Rate limit reached for gpt-3.5-turbo in organization org-acRzOCBCV2hUntscAm83omST on requests per min (RPM): Limit 3, Used 3, Requested 1. Please try again in 20s. Visit https://platform.openai.com/account/rate-limits to learn more. You can increase your rate limit by adding a payment method to your account at https://platform.openai.com/account/billing.', 'type': 'requests', 'param': None, 'code': 'rate_limit_exceeded'}}

In [24]:
from langchain.prompts import ChatPromptTemplate

template = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}

Question: {question}
SQL Query:"""
prompt = ChatPromptTemplate.from_template(template)

In [25]:
from langchain.chat_models import ChatOpenAI
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough

model = ChatOpenAI()

sql_response = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | model.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)

In [27]:
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""
prompt_response = ChatPromptTemplate.from_template(template)

In [28]:
full_chain = (
    RunnablePassthrough.assign(query=sql_response)
    | RunnablePassthrough.assign(
        schema=get_schema,
        response=lambda x: db.run(x["query"]),
    )
    | prompt_response
    | model
)

In [29]:
full_chain.invoke({"question": "How many employees are there?"})

AIMessage(content='There are 9 employees.')