# Initialise Bedrock LLM

In [1]:
# from simd.llm import llm

from langchain_aws import ChatBedrock
import dotenv
import os

# Load environment variables via python-dotenv before anything reads them.
dotenv.load_dotenv()

# Assigned after load_dotenv() so this project name overrides any value
# that was already present in the environment.
os.environ["LANGSMITH_PROJECT"] = "agent_sql_database"

# Claude is a good alternative to GPT-4o: https://blog.promptlayer.com/big-differences-claude-3-5-vs-gpt-4o/
# Alternative Bedrock model ids that can be swapped in for `model_id`:
#   us.meta.llama3-3-70b-instruct-v1:0
#   us.meta.llama3-2-3b-instruct-v1:0
#   us.deepseek.r1-v1:0
llm = ChatBedrock(
    model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
    model_kwargs={"temperature": 0.2},
)

# Create Database Engine

In [2]:
import sqlite3

import requests
from langchain_community.utilities.sql_database import SQLDatabase
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool


def get_engine_for_chinook_db():
    """Download the Chinook SQL script, load it into in-memory SQLite, and build an engine.

    The script is fetched over HTTP, executed against a fresh ``:memory:``
    SQLite connection, and that single connection is then handed to
    SQLAlchemy through ``creator`` so every checkout reuses the populated
    database.

    Returns:
        The SQLAlchemy engine produced by ``create_engine``.

    Raises:
        requests.HTTPError: If the download responds with a 4xx/5xx status.
        requests.RequestException: On connection failure or timeout.
        sqlite3.Error: If the downloaded script cannot be executed.
    """
    url = "https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql"
    # A timeout keeps the cell from hanging forever on a stalled connection.
    response = requests.get(url, timeout=30)
    # Fail here with a clear HTTP error rather than feeding an error page to
    # SQLite and getting an unrelated syntax error.
    response.raise_for_status()
    sql_script = response.text

    connection = sqlite3.connect(":memory:", check_same_thread=False)
    try:
        connection.executescript(sql_script)
    except Exception:
        # Do not leak the connection if the script turns out to be invalid.
        connection.close()
        raise

    # The creator always returns the same already-populated connection.
    return create_engine(
        "sqlite://",
        creator=lambda: connection,
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )


# Build the engine backed by the in-memory Chinook database.
engine = get_engine_for_chinook_db()

# Wrap the engine in LangChain's SQLDatabase utility; the toolkit below consumes it.
db = SQLDatabase(engine)

In [3]:
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit

# Bundle the database wrapper and the LLM into a toolkit; its tools are
# passed to the agent further down.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

# Test SQL Database Toolkit

In [4]:
# Display the tools the toolkit exposes (query, schema info, table listing, query checker).
toolkit.get_tools()

[QuerySQLDatabaseTool(description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x10ebb6ba0>),
 InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x10ebb6ba0>),
 ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x10ebb6ba0>),
 QuerySQLCheckerTool(description='Use this tool to double check if your 

# Create SQL Database ReAct Agent

In [5]:
from langchain import hub

# Pull the SQL-agent system prompt from the LangChain Hub.
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")

# Sanity check: the template is expected to hold exactly one message.
assert len(prompt_template.messages) == 1
# Show which variables must be supplied when formatting the prompt.
print(prompt_template.input_variables)

['dialect', 'top_k']


In [7]:
from langgraph.prebuilt import create_react_agent
# Fill in the template's two input variables: the SQL dialect and top_k
# (presumably the maximum rows per query -- confirm against the prompt text).
system_message = prompt_template.format(dialect="SQLite", top_k=5)

# Build a ReAct agent from the LLM, the toolkit's tools, and the formatted prompt.
agent_executor = create_react_agent(llm, toolkit.get_tools(), prompt=system_message)

In [9]:
# Natural-language question sent to the agent as the user message.
example_query = "Which Table has the most records??"

# Stream the agent run instead of waiting for the final answer, so each
# intermediate step (tool calls, tool results, replies) can be shown.
events = agent_executor.stream(
    {"messages": [("user", example_query)]},
    stream_mode="values",
)
for event in events:
    # Each streamed event exposes a "messages" list; print only its latest entry.
    event["messages"][-1].pretty_print()


Which Table has the most records??

I'll help you find out which table has the most records in the database. Let's start by listing all the tables in the database.
Tool Calls:
  sql_db_list_tables (toolu_bdrk_01FzdgqTEMGAs9F5WuzjNFrY)
 Call ID: toolu_bdrk_01FzdgqTEMGAs9F5WuzjNFrY
  Args:
    tool_input:
Name: sql_db_list_tables

Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track

Now that I have the list of tables, I need to count the number of records in each table to determine which one has the most. Let me create a query to do this.
Tool Calls:
  sql_db_query_checker (toolu_bdrk_01GsUCUua4WRiEAMgbZJuGdh)
 Call ID: toolu_bdrk_01GsUCUua4WRiEAMgbZJuGdh
  Args:
    query: WITH TableCounts AS (
    SELECT 'Album' AS TableName, COUNT(*) AS RecordCount FROM Album
    UNION ALL
    SELECT 'Artist' AS TableName, COUNT(*) AS RecordCount FROM Artist
    UNION ALL
    SELECT 'Customer' AS TableName, COUNT(*) AS RecordCount FROM Customer
  