In [6]:
from langchain.prompts import ChatPromptTemplate

template = """Given an table schema, use sqlite syntax to generate a sql query by choosing 
one or multiple of the following tables.

For this Problem you can use the following table Schema:
<table_schema>
{schema}
</table_schema

Please provide the SQL query without comment for this question: 
<question>
{question}
</question>

Wirite only query statement simply between <SQL> and </SQL> tags correspond to the above question. No preface
SQL Query:"""

prompt = ChatPromptTemplate.from_template(template)

In [7]:
from langchain_community.chat_models import BedrockChat
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda

import boto3

session = boto3.Session()
bedrock_cli = session.client(
  service_name='bedrock-runtime',
  region_name='us-east-1',
  endpoint_url=None
  )

llm = BedrockChat(
  model_id="anthropic.claude-v2:1",
  client=bedrock_cli,
  model_kwargs={"temperature": 0.1},
)


In [8]:
from langchain_community.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain

db = SQLDatabase.from_uri("sqlite:///Chinook.db")
db_chain = SQLDatabaseChain.from_llm(llm, db)

def get_schema(_):
    return db.get_table_info()

def run_query(query):
    return db.run(query)

# claude는 query 앞에 별도의 설명이 붙기 때문에 처리 필요
# prompt에서 생성된 query는 <SQl></SQL> tag로 묶음.
import xml.etree.ElementTree as ET

def claude_query(query):
    text = "<root>" + query + "</root>"
    # Parsing the text as XML
    root = ET.fromstring(text)
    sql = root.find('SQL').text
    return sql

In [9]:

sql_response = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | llm.bind(stop=["\n\nSQLResult"])
    | StrOutputParser()
    | RunnableLambda(claude_query)
)

In [10]:
sql_response.invoke({"question": "How many employees are there?"})

'\nSELECT COUNT(*) AS num_employees \nFROM Employee\n'

In [12]:
template = """Based on the table schema, question, sql query, and sql response, write a natural language response:

{schema}

Question: {question}

use only query statement between <SQL> and </SQL> tag:
SQL Query: {query}
SQL Response: {response}"""

prompt_response = ChatPromptTemplate.from_template(template)


In [13]:
full_chain = (
    RunnablePassthrough.assign(query=sql_response).assign(
        schema=get_schema,
        response=lambda x: db.run(x["query"]),
    )
    | prompt_response
    | llm
)

In [16]:
from langchain.callbacks import StdOutCallbackHandler
full_chain.invoke({"question": "How many employees are there?"} )
# full_chain.invoke({"question": "How many employees are there?"}, config={'callbacks': [StdOutCallbackHandler()]})



[1m> Entering new RunnableSequence chain...[0m


[1m> Entering new RunnableAssign chain...[0m


[1m> Entering new RunnableParallel chain...[0m


[1m> Entering new RunnableSequence chain...[0m


[1m> Entering new RunnableAssign chain...[0m


[1m> Entering new RunnableParallel chain...[0m


[1m> Entering new RunnableLambda chain...[0m

[1m> Finished chain.[0m

[1m> Finished chain.[0m

[1m> Finished chain.[0m


[1m> Entering new ChatPromptTemplate chain...[0m

[1m> Finished chain.[0m


[1m> Entering new StrOutputParser chain...[0m

[1m> Finished chain.[0m


[1m> Entering new RunnableLambda chain...[0m

[1m> Finished chain.[0m

[1m> Finished chain.[0m

[1m> Finished chain.[0m

[1m> Finished chain.[0m


[1m> Entering new RunnableAssign chain...[0m


[1m> Entering new RunnableParallel chain...[0m


[1m> Entering new RunnableLambda chain...[0m


[1m> Entering new RunnableLambda chain...[0m

[1m> Finished chain.[0m

[1m> Finished chain.[0m



AIMessage(content=' <SQL>\nSELECT COUNT(*) AS num_employees \nFROM Employee\n</SQL>\n\nThe SQL query counts the total number of rows in the Employee table. The response shows there are 8 employees.')