- https://github.com/langchain-ai/langchain/blob/master/cookbook/LLaMA2_sql_chat.ipynb

In [1]:
# model 인스턴스 생성
from langchain_community.chat_models import ChatOpenAI
llm_text = ChatOpenAI(model="gpt-3.5-turbo")


In [3]:
# model 기본 테스트
from langchain_core.runnables import RunnablePassthrough
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# prompt = ChatPromptTemplate.from_template("대한민국의 수도는 어디입니까?")
# llm_text.invoke(prompt)
# chain = prompt | llm_text | StrOutputParser()
# chain.invoke({})

prompt = ChatPromptTemplate.from_template("{country}의 수도는 어디입니까?")

chain = (
  { "country": RunnablePassthrough()}
  | prompt
  | llm_text
  | StrOutputParser()
)

chain.invoke("프랑스")


'프랑스의 수도는 파리입니다.'

In [4]:
llm = llm_text

### DB
Connect to a SQLite DB.
- SQLDatabaseChain의  사용하지 않을 경우, SQL 생성 시 AI의 전문(preamble)이 들어가서 정상적인 Query가 수행되지 않는
경우 있음. (특히 chat시)
- 큰 DB는 SQLDatabaseSequenceChain 
- import 시 path 주의할 것. (계속 변경됨)

In [5]:
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain

db = SQLDatabase.from_uri("sqlite:///nba_roster.db", sample_rows_in_table_info=0)
db_chain = SQLDatabaseChain.from_llm(llm, db)
db_chain.return_sql=True

def get_schema(_):
  return db.get_table_info()

def run_query(query):
  return db_chain.run(query)

In [6]:
# Prompt
from langchain.prompts import ChatPromptTemplate

# Update the template based on the type of SQL Database like MySQL, Microsoft SQL Server and so on
template = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}

Question: {question}
SQL Query:"""
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "Given an input question, convert it to a SQL query without preamble."),
        ("human", template),
    ]
)

# Chain to query
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

sql_response = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)

# sql_response.invoke({"question": "What team is Klay Thompson on?"})
sql_response.invoke({"question": " Klay Thompson이 속한 팀은?"})

"SELECT Team\nFROM nba_roster\nWHERE NAME = 'Klay Thompson'"

In [7]:
# Chain to answer
template = """Based on the table schema below, question, sqlite sql query, and sqlite sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}
"""

prompt_response = ChatPromptTemplate.from_messages(
  [
    (
      "system",
      "You're sqlite SQL Expert. Given an input question and SQL response, convert it to a natural language answer. No pre-amble"
    ),
    ("human", template)
  ]
)

full_chain = (
  RunnablePassthrough.assign(query=sql_response)
  | RunnablePassthrough.assign(
    schema=get_schema,
    response=lambda x: db_chain.run(x["query"]),
  )
  | prompt_response
  | llm
)

# full_chain.invoke({"question": "How many unique teams are there?"})
full_chain.invoke({"question": "고유한 소속팀의 수는?"})

AIMessage(content='The number of unique teams in the NBA roster table is represented by the result of the SQL query as "Unique_Teams".')

### Chat with a SQL DB

In [8]:
# Prompt
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder

template = """Given an input question, convert it to a SQL query only. Based on the table schema below, write a SQL query that would answer the user's question:
{schema}
"""
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", template),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{question}"),
    ]
)

memory = ConversationBufferMemory(return_messages=True)

# Chain to query with memory
from langchain_core.runnables import RunnableLambda

sql_chain = (
    RunnablePassthrough.assign(
        schema=get_schema,
        history=RunnableLambda(lambda x: memory.load_memory_variables(x)["history"]),
    )
    | prompt
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)


def save(input_output):
    output = {"output": input_output.pop("output")}
    memory.save_context(input_output, output)
    return output["output"]


sql_response_memory = RunnablePassthrough.assign(output=sql_chain) | save
# sql_response_memory.invoke({"question": "What team is Klay Thompson on?"})
sql_response_memory.invoke({"question": "Klay Thompson이 속한 팀은?"})

"SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson'"

In [9]:
# Chain to answer
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}


Question: {question}
SQL Query: {query}
SQL Response: {response}
"""
prompt_response = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Given an input question and SQL response, convert it to a natural language answer. No preamble.",
        ),
        ("human", template),
    ]
)

full_chain = (
    RunnablePassthrough.assign(query=sql_response_memory)
    | RunnablePassthrough.assign(
        schema=get_schema,
        response=lambda x: db_chain.run(x["query"]),
    )
    | prompt_response
    | llm
)

# full_chain.invoke({"question": "What is his salary?"})
full_chain.invoke({"question": "그의 월급은?"})

AIMessage(content='His salary is $15,500,000.')