In [38]:
# model 인스턴스 생성
from langchain.llms import Bedrock
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler


llm_text = Bedrock(
  model_id="anthropic.claude-v2:1",
  region_name= 'us-east-1',
  endpoint_url=None,
  model_kwargs={
    "max_tokens_to_sample": 512,
    "temperature": 0,
    "top_k": 256,
    "top_p": 1,
    "stop_sequences": ["\n\nHuman:"],
  },
  streaming=True,
  callbacks=[StreamingStdOutCallbackHandler()]
)


In [39]:
# model 기본 테스트
from langchain_core.runnables import RunnablePassthrough
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# prompt = ChatPromptTemplate.from_template("대한민국의 수도는 어디입니까?")
# llm_text.invoke(prompt)
# chain = prompt | llm_text | StrOutputParser()
# chain.invoke({})

prompt = ChatPromptTemplate.from_template("{country}의 수도는 어디입니까?")

chain = (
  { "country": RunnablePassthrough()}
  | prompt
  | llm_text
  | StrOutputParser()
)

resp=chain.invoke("프랑스")


 프랑스의 수도는 파리입니다.

In [None]:
llm = llm_text

### DB
Connect to a SQLite DB.

In [None]:
from langchain.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///nba_roster.db", sample_rows_in_table_info=0)

def get_schema(_):
  return db.get_table_info()

def run_query(query):
  return db.run(query)

In [40]:
# Prompt
from langchain.prompts import ChatPromptTemplate

# Update the template based on the type of SQL Database like MySQL, Microsoft SQL Server and so on
template = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}

Question: {question}
SQL Query:"""
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "Given an input question, convert it to a SQL query. No pre-amble."),
        ("human", template),
    ]
)

# Chain to query
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

sql_response = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)

sql_response.invoke({"question": "What team is Klay Thompson on?"})

 SELECT "Team"
FROM nba_roster
WHERE "NAME" = 'Klay Thompson'

' SELECT "Team"\nFROM nba_roster\nWHERE "NAME" = \'Klay Thompson\''

In [50]:
# Chain to answer
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}
"""

prompt_response = ChatPromptTemplate.from_messages(
  [
    (
      "system",
      "Given an input question and SQL response, convert it to a natural language answer. No pre-amble"
    ),
    ("human", template)
  ]
)

full_chain = (
  RunnablePassthrough.assign(query=sql_response)
  | RunnablePassthrough.assign(
    schema=get_schema,
    response=lambda x: db.run(x["query"]),
  )
  | prompt_response
  | llm
)

full_chain.invoke({"question": "How many unique teams are there?"})

 ```sql
SELECT COUNT(DISTINCT "Team") 
FROM nba_roster
```

This query counts the number of distinct team names in the nba_roster table. The COUNT(DISTINCT) aggregate function returns the number of unique non-null values from the specified column. By applying it to the "Team" column, we can get the number of unique team names.

OperationalError: (sqlite3.OperationalError) near "```sql
SELECT COUNT(DISTINCT "Team") 
FROM nba_roster
```": syntax error
[SQL:  ```sql
SELECT COUNT(DISTINCT "Team") 
FROM nba_roster
```

This query counts the number of distinct team names in the nba_roster table. The COUNT(DISTINCT) aggregate function returns the number of unique non-null values from the specified column. By applying it to the "Team" column, we can get the number of unique team names.]
(Background on this error at: https://sqlalche.me/e/20/e3q8)

### Chat with a SQL DB

In [None]:
# Prompt
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder

template = """Given an input question, convert it to a SQL query only. Based on the table schema below, write a SQL query that would answer the user's question:
{schema}
"""
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", template),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{question}"),
    ]
)

memory = ConversationBufferMemory(return_messages=True)

# Chain to query with memory
from langchain_core.runnables import RunnableLambda

sql_chain = (
    RunnablePassthrough.assign(
        schema=get_schema,
        history=RunnableLambda(lambda x: memory.load_memory_variables(x)["history"]),
    )
    | prompt
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)


def save(input_output):
    output = {"output": input_output.pop("output")}
    memory.save_context(input_output, output)
    return output["output"]


sql_response_memory = RunnablePassthrough.assign(output=sql_chain) | save
sql_response_memory.invoke({"question": "What team is Klay Thompson on?"})

 SELECT "Team"
FROM nba_roster
WHERE "NAME" = 'Klay Thompson'

' SELECT "Team"\nFROM nba_roster\nWHERE "NAME" = \'Klay Thompson\''

In [None]:
# Chain to answer
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""
prompt_response = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Given an input question and SQL response, convert it to a natural language answer. No pre-amble.",
        ),
        ("human", template),
    ]
)

full_chain = (
    RunnablePassthrough.assign(query=sql_response_memory)
    | RunnablePassthrough.assign(
        schema=get_schema,
        response=lambda x: db.run(x["query"]),
    )
    | prompt_response
    | llm
)

full_chain.invoke({"question": "What is his salary?"})

OperationalError: (sqlite3.OperationalError) near "AI": syntax error
[SQL:  AI:  Sure thing! Here's the SQL query based on the provided table schema:

SELECT SALARY FROM nba_roster WHERE NAME = 'Klay Thompson';]
(Background on this error at: https://sqlalche.me/e/20/e3q8)