以下のガイドを参考に text-to-sql を試してみる

https://python.langchain.com/v0.2/docs/how_to/sql_prompting/


In [1]:
from jawsume import jawsume

# Ask for a profile name interactively and hand it to jawsume.
# NOTE(review): presumably this sets up AWS credentials for the Bedrock call
# made later in the notebook — confirm against jawsume's documentation.
profile_name = input("Enter the profile name: ")
jawsume(profile_name=profile_name)

In [2]:
import os
import getpass

# Read the LangChain API key from an interactive, non-echoing prompt instead
# of hard-coding it in the notebook.
os.environ["LANGCHAIN_API_KEY"] = getpass.getpass()
# NOTE(review): presumably turns on LangSmith tracing for the chains run
# below — confirm against the LangChain/LangSmith docs.
os.environ["LANGCHAIN_TRACING_V2"] = "true"

In [3]:
from langchain_community.utilities import SQLDatabase

# Open the Chinook sample database (a SQLite file under assets/).
# NOTE(review): sample_rows_in_table_info=3 presumably controls how many
# example rows per table are included in the schema description given to the
# model — confirm against the SQLDatabase docs.
db = SQLDatabase.from_uri("sqlite:///assets/Chinook.db", sample_rows_in_table_info=3)
# Sanity checks: show the dialect, the usable table names, and the first
# 10 rows of the Artist table.
print(db.dialect)
print(db.get_usable_table_names())
print(db.run("SELECT * FROM Artist LIMIT 10;"))

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]


In [4]:
from langchain.chains.sql_database.prompt import SQL_PROMPTS

# Display the keys of SQL_PROMPTS; the cell output lists SQL dialect names
# (for example 'sqlite' and 'postgresql').
list(SQL_PROMPTS)

['crate',
 'duckdb',
 'googlesql',
 'mssql',
 'mysql',
 'mariadb',
 'oracle',
 'postgresql',
 'sqlite',
 'clickhouse',
 'prestodb']

In [5]:
from langchain_aws.chat_models import ChatBedrock

# Chat model used to write the SQL: the Claude 3 Haiku model ID, accessed
# through ChatBedrock in the us-east-1 region.
llm = ChatBedrock(
    model_id="anthropic.claude-3-haiku-20240307-v1:0",
    region_name="us-east-1",
)

In [6]:
import re
from typing import Optional


def extract_sql_query(text: str) -> Optional[str]:
    """Extract the SQL statement that follows the ``SQLQuery:`` marker.

    The chain built with ``create_sql_query_chain`` returns a string of the
    following form, so only the SQL part is pulled out:

    ```txt
    Question: How many employees are there?
    SQLQuery: SELECT COUNT("EmployeeId") AS "EmployeeCount" FROM "Employee";
    ```

    The statement may span several lines. Anything from a following
    ``SQLResult:`` section onwards is ignored.

    Args:
        text: Raw text produced by the query-writing chain.

    Returns:
        The SQL statement with surrounding whitespace removed, or ``None``
        when ``text`` contains no ``SQLQuery:`` marker.
    """
    # DOTALL lets the capture run across newlines (multi-line SQL). The lazy
    # group stops at a trailing "SQLResult:" section or at the end of the
    # text, whichever comes first.
    match = re.search(r"SQLQuery:\s*(.*?)\s*(?:SQLResult:|\Z)", text, re.DOTALL)
    return match.group(1) if match else None

In [19]:
from langchain.chains.sql_database.query import create_sql_query_chain
from langchain_core.runnables import RunnableLambda

# Query-writing chain: the output of create_sql_query_chain is piped through
# extract_sql_query so that only the SQL statement is passed on.
write_query = create_sql_query_chain(llm, db) | RunnableLambda(extract_sql_query)

In [21]:
# Ask a question in Japanese ("How many employees are there?") and display
# the SQL string the chain produced.
response = write_query.invoke({"question": "従業員は何人いますか？"})
response

'SELECT COUNT(*) AS "従業員数" FROM Employee;'

In [22]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

# Tool that runs a SQL string against db.
execute_query = QuerySQLDataBaseTool(db=db)
# Full pipeline: write the SQL for the question, then execute it. The cell
# output shows the result rendered as a string.
chain = write_query | execute_query
chain.invoke({"question": "従業員は何人いますか？"})

'[(8,)]'