In [150]:
from langchain.prompts import ChatPromptTemplate

template = """Given an table schema, use sqlite syntax to generate a sql query by choosing 
one or multiple of the following tables.

For this Problem you can use the following table Schema:
<table_schema>
{schema}
</table_schema

Please provide the SQL query without comment for this question: 
<question>
{question}
</question>

Wirite only query statement correspond to the above question without any comment, description, preamble.
SQL Query:"""

prompt = ChatPromptTemplate.from_template(template)

In [151]:
from langchain_community.chat_models import BedrockChat
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

import boto3

session = boto3.Session()
bedrock_cli = session.client(
  service_name='bedrock-runtime',
  region_name='us-east-1',
  endpoint_url=None
  )

llm = BedrockChat(
  model_id="anthropic.claude-v2:1",
  client=bedrock_cli,
  model_kwargs={"temperature": 0},
  streaming=True,
  callbacks=[StreamingStdOutCallbackHandler()],
)



In [152]:
from langchain_community.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain

db = SQLDatabase.from_uri("sqlite:///./Chinook.db")
db_chain = SQLDatabaseChain.from_llm(llm, db)

def get_schema(_):
    return db.get_table_info()

def run_query(query):
    return db.run(query)

In [153]:
sql_response = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | llm.bind(stop=["\n\nSQLResult:"])
    | StrOutputParser()
)

In [154]:
sql_response.invoke({"question": "How many employees are there?"})

 Here is the SQL query to get the number of employees:

SELECT COUNT(*) AS num_employees 
FROM Employee

' Here is the SQL query to get the number of employees:\n\nSELECT COUNT(*) AS num_employees \nFROM Employee'

In [158]:
def extract_sql(query):
  return query.split(':')[1]

extract_sql(sql_response.invoke({"question": "How many employees are there?"}))

 Here is the SQL query to get the number of employees:

SELECT COUNT(*) AS num_employees 
FROM Employee

'\n\nSELECT COUNT(*) AS num_employees \nFROM Employee'

In [155]:
template = """Based on the table schema, question, sql query, and sql response, write a natural language response:

{schema}

Question: {question}

use only query statement between <SQL> and </SQL> tag:
SQL Query: {query}
SQL Response: {response}"""

prompt_response = ChatPromptTemplate.from_template(template)

full_chain = (
    RunnablePassthrough.assign(query=sql_response).assign(
        schema=get_schema,
        response=lambda x: db.run(x["query"]),
    )
    | prompt_response
    | llm
)

In [156]:
full_chain.invoke({"question": "How many employees are there?"})

 Here is the SQL query to get the number of employees:

SELECT COUNT(*) AS num_employees 
FROM Employee

OperationalError: (sqlite3.OperationalError) near "Here": syntax error
[SQL:  Here is the SQL query to get the number of employees:

SELECT COUNT(*) AS num_employees 
FROM Employee]
(Background on this error at: https://sqlalche.me/e/20/e3q8)