In [8]:
import sqlite3
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool
from langchain_community.utilities.sql_database import SQLDatabase

In [2]:
# Relative path of the SQLite database file shared by the raw sqlite3 connection
# and the SQLAlchemy engine created in later cells.
database_path = "my_database.db"

In [3]:
# Raw sqlite3 connection used to create the tables and seed the demo data.
# check_same_thread=False allows the connection to be used from threads other
# than the one that created it.
connection = sqlite3.connect(database_path, check_same_thread=False)

In [4]:
# DDL for the two demo tables. IF NOT EXISTS makes table creation safe to re-run,
# but an existing table keeps its rows, so any re-seeding must clear them first.
# NOTE(review): employees.department holds the department name as free text; it is
# not declared as a foreign key to departments.
schema = """
CREATE TABLE IF NOT EXISTS employees (
    id INTEGER PRIMARY KEY,
    name TEXT NOT NULL,
    department TEXT,
    salary REAL
);
CREATE TABLE IF NOT EXISTS departments (
    id INTEGER PRIMARY KEY,
    name TEXT NOT NULL
);
"""

In [5]:
# Execute the multi-statement DDL script to create both tables.
connection.executescript(schema)

<sqlite3.Cursor at 0x133d5342840>

In [6]:
# SQLAlchemy engine over the same SQLite file; LangChain's SQLDatabase needs an
# engine rather than a raw sqlite3 connection.
# StaticPool reuses one DB-API connection for every checkout, and
# check_same_thread=False lets that single connection cross threads.
engine = create_engine(
    f"sqlite:///{database_path}",
    connect_args=dict(check_same_thread=False),
    poolclass=StaticPool,
)

In [9]:
# Wrap the engine in LangChain's SQLDatabase so the SQL toolkit can list tables,
# read schemas and run queries against it.
db = SQLDatabase(engine=engine)

In [10]:
# Cursor on the raw sqlite3 connection, used for the bulk inserts that seed the tables.
cursor = connection.cursor()

In [11]:
# Seed rows for `employees` as (name, department, salary) tuples, in the same
# column order as the parameterized INSERT that loads them.
employee_records = [
    ("Alice", "Engineering", 75000),
    ("Bob", "Sales", 60000),
    ("Charlie", "Engineering", 80000),
    ("Daisy", "Sales", 62000),
    ("Ethan", "Marketing", 55000),
    ("Fiona", "Engineering", 85000),
    ("George", "Marketing", 58000),
    ("Hannah", "HR", 52000),
    ("Ivy", "Finance", 90000),
    ("Jack", "Finance", 91000),
    ("Katie", "HR", 53000),
    ("Liam", "Sales", 70000),
    ("Mia", "Engineering", 92000),
    ("Noah", "Sales", 64000),
    ("Olivia", "IT", 95000),
    ("Parker", "IT", 97000),
    ("Quinn", "Marketing", 61000),
    ("Ruby", "HR", 54000),
    ("Sam", "Finance", 89000),
    ("Tina", "Sales", 71000),
    ("Uma", "Engineering", 86000),
    ("Vince", "Finance", 94000),
    ("Wendy", "IT", 98000),
    ("Xander", "Marketing", 60000),
    ("Yara", "Sales", 69000),
    ("Zane", "HR", 56000)
]

In [12]:
# Seed rows for `departments`: one single-element tuple per department name,
# matching the single `?` placeholder of the INSERT that loads them.
department_records = [
    (department_name,)
    for department_name in (
        "Engineering",
        "Sales",
        "Marketing",
        "HR",
        "Finance",
        "IT",
        "Legal",
        "Operations",
        "Customer Service",
    )
]

In [13]:
# Seed both tables.
# The tables are created with IF NOT EXISTS, so on a re-run (Restart & Run All)
# they still contain the rows from the previous run. Clear them first so the
# notebook is idempotent instead of appending duplicate employees/departments.
# Because `id` is a plain INTEGER PRIMARY KEY (no AUTOINCREMENT), SQLite assigns
# ids starting from 1 again once a table is empty.
# The DELETEs and INSERTs share one implicit transaction, committed in the next cell.
cursor.execute("DELETE FROM employees")
cursor.execute("DELETE FROM departments")
cursor.executemany("INSERT INTO employees (name, department, salary) VALUES (?, ?, ?)", employee_records)
cursor.executemany("INSERT INTO departments (name) VALUES (?)", department_records)

<sqlite3.Cursor at 0x133d78c3d40>

In [14]:
# Commit the seeded rows so other connections to the file, such as the SQLAlchemy
# engine used by the agent, can see them.
connection.commit()

In [15]:
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
import os

In [16]:
# Load variables from a local .env file into the process environment (returns True if a file was loaded).
load_dotenv()

True

In [17]:
# OpenAI API key taken from the environment (populated from .env); None when unset.
my_openai_key = os.environ.get("my_openai_key")

In [18]:
# Chat model that drives the SQL agent and its query-checker tool.
llm = ChatOpenAI(
    api_key=my_openai_key,
    model="gpt-4o-mini",
)

In [19]:
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit

In [20]:
# Bundle LangChain's SQL tools bound to `db`. The LLM is passed in as well;
# presumably it backs the query-checker tool.
toolkit = SQLDatabaseToolkit(llm=llm, db=db)

In [21]:
# Show the tools the toolkit exposes to the agent: query, schema info, table listing and query checker.
toolkit.get_tools()

[QuerySQLDataBaseTool(description="Input to this tool is a detailed and correct SQL query, output is a result from the database. If the query is not correct, an error message will be returned. If an error is returned, rewrite the query, check the query, and try again. If you encounter an issue with Unknown column 'xxxx' in 'field list', use sql_db_schema to query the correct table fields.", db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x00000133D78F6290>),
 InfoSQLDatabaseTool(description='Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. Be sure that the tables actually exist by calling sql_db_list_tables first! Example Input: table1, table2, table3', db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x00000133D78F6290>),
 ListSQLDatabaseTool(db=<langchain_community.utilities.sql_database.SQLDatabase object at 0x00000133D78F6290>),
 QuerySQLCheckerTool(description='Use this tool to 

In [22]:
from langchain import hub

# Pull the published SQL-agent system prompt from the LangChain Hub (requires network access).
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")

# The prompt is expected to consist of exactly one message; fail fast if the hub version differs.
assert len(prompt_template.messages) == 1
# Show which template variables must be supplied when formatting the prompt.
print(prompt_template.input_variables)



['dialect', 'top_k']


In [23]:
# Fill in the two template variables reported above: the SQL dialect and
# `top_k` (presumably the default cap on result rows the agent should return).
system_message = prompt_template.format(top_k=5, dialect="SQLite")

In [24]:
from langgraph.prebuilt import create_react_agent

# ReAct-style agent: at each step the LLM chooses one of the SQL tools, guided by
# the formatted system prompt.
# NOTE(review): newer langgraph releases deprecate `state_modifier` in favour of
# `prompt`; confirm against the installed version before upgrading.
agent_executor = create_react_agent(
    model=llm,
    tools=toolkit.get_tools(),
    state_modifier=system_message,
)

In [25]:
example_query = "Which department has the highest average salary?"

# Stream the agent run; with stream_mode="values" each event carries the whole
# message list, so printing only its newest entry shows the trace step by step.
events = agent_executor.stream(
    {"messages": [("user", example_query)]},
    stream_mode="values",
)
for event in events:
    newest_message = event["messages"][-1]
    newest_message.pretty_print()


Which department has the highest average salary?
Tool Calls:
  sql_db_list_tables (call_Vifdxlvwc7zLz5q5KCyaMX19)
 Call ID: call_Vifdxlvwc7zLz5q5KCyaMX19
  Args:
Name: sql_db_list_tables

departments, employees
Tool Calls:
  sql_db_schema (call_FuY7UUbNk6CBhrmTLnzu022k)
 Call ID: call_FuY7UUbNk6CBhrmTLnzu022k
  Args:
    table_names: departments
  sql_db_schema (call_9n2qXSoNykzxPKhHNluaBCZL)
 Call ID: call_9n2qXSoNykzxPKhHNluaBCZL
  Args:
    table_names: employees
Name: sql_db_schema


CREATE TABLE employees (
	id INTEGER, 
	name TEXT NOT NULL, 
	department TEXT, 
	salary REAL, 
	PRIMARY KEY (id)
)

/*
3 rows from employees table:
id	name	department	salary
1	Alice	Engineering	75000.0
2	Bob	Sales	60000.0
3	Charlie	Engineering	80000.0
*/
Tool Calls:
  sql_db_query_checker (call_jMmQnFB5Mo1jt8Yhq2d2I310)
 Call ID: call_jMmQnFB5Mo1jt8Yhq2d2I310
  Args:
    query: SELECT department, AVG(salary) AS average_salary FROM employees GROUP BY department ORDER BY average_salary DESC LIMIT 1;
Nam

In [26]:
example_query = "List all employees in the Finance department, ordered by salary."

# Stream the agent run; with stream_mode="values" each event carries the whole
# message list, so printing only its newest entry shows the trace step by step.
events = agent_executor.stream(
    {"messages": [("user", example_query)]},
    stream_mode="values",
)
for event in events:
    newest_message = event["messages"][-1]
    newest_message.pretty_print()


List all employees in the Finance department, ordered by salary.
Tool Calls:
  sql_db_list_tables (call_AlZjmYN4AsYIiy6VLHIzEoQN)
 Call ID: call_AlZjmYN4AsYIiy6VLHIzEoQN
  Args:
Name: sql_db_list_tables

departments, employees
Tool Calls:
  sql_db_schema (call_F33K66AbThVSovgFHwYa3WHq)
 Call ID: call_F33K66AbThVSovgFHwYa3WHq
  Args:
    table_names: departments
  sql_db_schema (call_PhYEviG1m4wodH4JYg0xPUvZ)
 Call ID: call_PhYEviG1m4wodH4JYg0xPUvZ
  Args:
    table_names: employees
Name: sql_db_schema


CREATE TABLE employees (
	id INTEGER, 
	name TEXT NOT NULL, 
	department TEXT, 
	salary REAL, 
	PRIMARY KEY (id)
)

/*
3 rows from employees table:
id	name	department	salary
1	Alice	Engineering	75000.0
2	Bob	Sales	60000.0
3	Charlie	Engineering	80000.0
*/
Tool Calls:
  sql_db_query_checker (call_ppu77pFtzAdCzJQ1WyFey0C7)
 Call ID: call_ppu77pFtzAdCzJQ1WyFey0C7
  Args:
    query: SELECT name, salary FROM employees WHERE department = 'Finance' ORDER BY salary;
Name: sql_db_query_checker

