# Agentic RAG Experiment

In [1]:
%pip install -U --quiet langchain-community tiktoken langchain-openai langchainhub chromadb langchain langgraph langchain-text-splitters

Note: you may need to restart the kernel to use updated packages.


In [2]:
import getpass
import os


def _set_env(key: str):
    """Make sure the environment variable ``key`` exists.

    When ``key`` is already present in ``os.environ`` nothing happens.
    Otherwise the value is read with ``getpass.getpass`` (prompt text is
    ``"<key>:"``) and stored in ``os.environ`` under ``key``.
    """
    if key in os.environ:
        return
    os.environ[key] = getpass.getpass(f"{key}:")


# Prompt for the OpenAI API key only if it is not already in the environment.
_set_env("OPENAI_API_KEY")

In [4]:
from langchain_community.utilities import SQLDatabase

# Connection URI for the PostgreSQL instance. Set AIAGENT_DB_URI in the
# environment to supply real connection details without editing (and
# committing) credentials in the notebook; when it is unset, the original
# local development URI is used. `os` is imported in the API-key cell above.
connection_uri = os.environ.get(
    "AIAGENT_DB_URI", "postgresql://postgres:postgres@localhost:5432/aiagent"
)

# Include the schema in the search path so unqualified table names
# (e.g. `customer`) resolve inside the `aiagent` schema.
db = SQLDatabase.from_uri(connection_uri, engine_args={"connect_args": {"options": "-csearch_path=aiagent"}})

# Print database dialect
print(db.dialect)

# Print usable table names
print(db.get_usable_table_names())

# Example query: db.run() returns the rows rendered as a string
# (an empty string when the query yields no rows).
result = db.run("SELECT * FROM customer LIMIT 10;")
print(result)


postgresql
['customer']



# Generate Sample Data

In [5]:
!pip install faker

Collecting faker
  Downloading Faker-33.3.1-py3-none-any.whl.metadata (15 kB)
Downloading Faker-33.3.1-py3-none-any.whl (1.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faker
Successfully installed faker-33.3.1


In [7]:
import psycopg2
import random
import json
from faker import Faker
from datetime import datetime

# Shared Faker instance; generate_customer_data() uses it for names and birth dates.
fake = Faker()

# PostgreSQL connection details, passed as keyword arguments to
# psycopg2.connect(). Each value can be overridden through an AIAGENT_DB_*
# environment variable so real credentials never have to be written into the
# notebook; the fallbacks are the original local development defaults.
# `os` is imported in the API-key cell near the top of the notebook.
connection_details = {
    "dbname": os.environ.get("AIAGENT_DB_NAME", "aiagent"),
    "user": os.environ.get("AIAGENT_DB_USER", "postgres"),
    "password": os.environ.get("AIAGENT_DB_PASSWORD", "postgres"),
    "host": os.environ.get("AIAGENT_DB_HOST", "localhost"),
    # Environment values are strings; keep the port an int as before.
    "port": int(os.environ.get("AIAGENT_DB_PORT", "5432")),
    # Schema is selected via the session search_path in `options`.
    "options": "-csearch_path=aiagent"
}

# Generate random data for the customer table
def generate_customer_data():
    """Build one random customer row.

    Returns:
        A 6-tuple ``(name, portfolio, products_json, dob, sin, gender)``,
        in the same column order used by the INSERT statement in
        ``insert_data_into_table``. ``products_json`` is a JSON string of
        the form ``{"products": [...]}``.
    """
    name = fake.name()
    # Portfolio value between 1,000 and 100,000, rounded to cents.
    portfolio = round(random.uniform(1000, 100000), 2)

    # Predefined list of investment products; each customer holds a random,
    # non-repeating subset containing at least one of them.
    investment_products = ["RRSP", "TFSA", "RESP", "Non-Registered Account", "LIRA", "Annuity"]
    product_count = random.randint(1, len(investment_products))
    held_products = random.sample(investment_products, product_count)
    products_json = json.dumps({"products": held_products})

    dob = fake.date_of_birth(minimum_age=18, maximum_age=80)
    sin = random.randint(100000000, 999999999)  # Random 9-digit SIN
    gender = random.choice(["Male", "Female", "Non-binary", "Other"])

    return name, portfolio, products_json, dob, sin, gender

# Insert data into the database
def insert_data_into_table(conn, data):
    """Bulk-insert customer rows into ``aiagent.customer`` and commit.

    Args:
        conn: An open database connection exposing ``cursor()`` (usable as a
            context manager) and ``commit()``.
        data: Sequence of 6-item rows ordered as name, portfolio, products,
            date of birth, SIN, gender.
    """
    insert_sql = """
    INSERT INTO aiagent.customer (
        customer_name,
        customer_portfolio,
        customer_products,
        customer_dob,
        customer_sin,
        customer_gender
    )
    VALUES (%s, %s, %s, %s, %s, %s);
    """
    with conn.cursor() as cursor:
        cursor.executemany(insert_sql, data)
        # Commit before the cursor context exits.
        conn.commit()

def main():
    """Seed the customer table with 100 randomly generated rows.

    Connects to PostgreSQL with the notebook-level ``connection_details``,
    generates 100 rows via ``generate_customer_data`` and inserts them with
    ``insert_data_into_table``. Any exception is reported on stdout instead
    of being raised, and the connection is closed whenever one was opened.
    """
    # Bind the name before the try block: if psycopg2.connect() itself raises,
    # the finally clause would otherwise hit an UnboundLocalError that masks
    # the real connection error.
    conn = None
    try:
        conn = psycopg2.connect(**connection_details)
        print("Connection successful!")

        # Generate and insert 100 random customers
        data = [generate_customer_data() for _ in range(100)]
        insert_data_into_table(conn, data)

        print("Data inserted successfully!")
    except Exception as e:
        print("An error occurred:", e)
    finally:
        if conn is not None:
            conn.close()

# Run the seeding routine when this code executes as the top-level module.
if __name__ == "__main__":
    main()


Connection successful!
Data inserted successfully!


# Agent Flow

In [8]:
from typing import Any

from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langgraph.prebuilt import ToolNode


def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """
    Wrap ``tools`` in a ToolNode whose failures are routed to ``handle_tool_error``.

    The fallback is registered with ``exception_key="error"``, which is the
    key ``handle_tool_error`` reads the exception from, so tool errors are
    surfaced to the agent as messages instead of propagating.
    """
    tool_node = ToolNode(tools)
    error_fallback = RunnableLambda(handle_tool_error)
    return tool_node.with_fallbacks([error_fallback], exception_key="error")


def handle_tool_error(state) -> dict:
    """Turn a tool failure into one ToolMessage per pending tool call.

    Reads the exception from ``state["error"]`` (``None`` if absent) and the
    tool calls from the last entry of ``state["messages"]``. Returns
    ``{"messages": [...]}`` with an error ToolMessage addressed to each
    tool call id.
    """
    failure = state.get("error")
    pending_calls = state["messages"][-1].tool_calls

    replies = []
    for call in pending_calls:
        reply = ToolMessage(
            content=f"Error: {repr(failure)}\n please fix your mistakes.",
            tool_call_id=call["id"],
        )
        replies.append(reply)
    return {"messages": replies}

In [12]:
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_openai import ChatOpenAI

# Build the SQL toolkit on top of the shared `db` handle and collect its tools.
toolkit = SQLDatabaseToolkit(db=db, llm=ChatOpenAI(model="gpt-4o"))
tools = toolkit.get_tools()

# Select the two tools used directly below by name; next() raises
# StopIteration if the toolkit does not provide a tool with that name.
list_tables_tool = next(t for t in tools if t.name == "sql_db_list_tables")
get_schema_tool = next(t for t in tools if t.name == "sql_db_schema")

# Sanity check: list the tables, then show the schema of `customer`.
print(list_tables_tool.invoke(""))

print(get_schema_tool.invoke("customer"))

customer

CREATE TABLE customer (
	customer_id SERIAL NOT NULL, 
	customer_name VARCHAR(100) NOT NULL, 
	customer_portfolio NUMERIC(10, 2), 
	customer_products JSONB, 
	customer_dob DATE, 
	customer_sin BIGINT, 
	customer_gender VARCHAR(10), 
	CONSTRAINT customer_pkey PRIMARY KEY (customer_id)
)

/*
3 rows from customer table:
customer_id	customer_name	customer_portfolio	customer_products	customer_dob	customer_sin	customer_gender
101	Jonathan Scott	77764.15	{'products': ['Non-Registered Account', 'Annuity', 'TFSA', 'RESP', 'RRSP', 'LIRA']}	1974-02-08	580635182	Non-binary
102	William Swanson	81953.82	{'products': ['Annuity', 'RESP']}	2002-03-15	637863547	Male
103	Kevin Lawson	64930.05	{'products': ['Non-Registered Account', 'RRSP', 'LIRA', 'Annuity', 'RESP']}	1958-06-29	323174348	Other
*/


In [13]:
from langchain_core.tools import tool

# NOTE: the docstring below is used by @tool as the tool description, so it is
# kept short and written for the model rather than as parameter documentation.
#
# query: SQL text to execute through the notebook-level `db` handle.
# Returns the result string from db.run_no_throw(), or a notice when the
# query produced no rows.
@tool
def db_query_tool(query: str) -> str:
    """
    Execute a SQL query against the database and get back the result.
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    If the query succeeds but matches no rows, a message saying so is returned.
    """
    result = db.run_no_throw(query)
    if not result:
        # run_no_throw() reports failures as a non-empty "Error: ..." string;
        # an empty result means the statement ran but produced no rows, so do
        # not tell the agent to rewrite a query that was actually valid.
        return "Query executed successfully but returned no rows."
    return result


# Smoke-test the tool by invoking it with a query string as its input.
print(db_query_tool.invoke("SELECT * FROM customer LIMIT 10;"))

[(101, 'Jonathan Scott', Decimal('77764.15'), {'products': ['Non-Registered Account', 'Annuity', 'TFSA', 'RESP', 'RRSP', 'LIRA']}, datetime.date(1974, 2, 8), 580635182, 'Non-binary'), (102, 'William Swanson', Decimal('81953.82'), {'products': ['Annuity', 'RESP']}, datetime.date(2002, 3, 15), 637863547, 'Male'), (103, 'Kevin Lawson', Decimal('64930.05'), {'products': ['Non-Registered Account', 'RRSP', 'LIRA', 'Annuity', 'RESP']}, datetime.date(1958, 6, 29), 323174348, 'Other'), (104, 'Jaime Fitzgerald', Decimal('6789.04'), {'products': ['RESP', 'TFSA']}, datetime.date(1954, 4, 4), 691973921, 'Other'), (105, 'Christina Grimes', Decimal('38739.76'), {'products': ['Non-Registered Account']}, datetime.date(1993, 8, 31), 465610759, 'Female'), (106, 'Kimberly Rios', Decimal('6174.18'), {'products': ['RRSP', 'LIRA']}, datetime.date(1980, 8, 16), 516915108, 'Male'), (107, 'Jason Schmidt', Decimal('93916.93'), {'products': ['Non-Registered Account', 'RRSP', 'Annuity', 'LIRA', 'RESP', 'TFSA']}, d