# Memory

<div style="display: flex; justify-content: flex-start; gap: 10px;">
  <img src="./assets/LC_Memory_after.png" style="width:300px; border:1px solid #ccc; border-radius:6px;">
</div>

Persisting messages, or 'agent state' between invocations of the agent.

## Setup

Load and/or check for needed environmental variables

In [1]:
from dotenv import load_dotenv
from env_utils import doublecheck_env

# Load environment variables from .env
load_dotenv()

# Check and print results
doublecheck_env(".env")

OPENAI_API_KEY=****here
LANGSMITH_API_KEY=****754b
LANGSMITH_TRACING=true
LANGSMITH_PROJECT=****ials


In [15]:
import json
import nest_asyncio
from langchain_community.utilities import SQLDatabase
from langchain_core.tools import tool
from langchain_ollama import ChatOllama
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.prebuilt import ToolNode

nest_asyncio.apply()
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

@tool
def execute_sql(query: str) -> str:
    """Execute a SQLite command and return results."""
    try:
        return db.run(query.strip().strip("'").strip('"'))
    except Exception as e:
        return f"Error: {e}"

tools = [execute_sql]
tool_node = ToolNode(tools)
llm = ChatOllama(model="qwen2.5-coder:7b", temperature=0).bind_tools(tools)

# --- THE FIX: Custom Routing Logic ---
def custom_router(state: MessagesState):
    last_message = state["messages"][-1]
    
    # 1. Check if the model actually used the 'tool_calls' property
    if hasattr(last_message, "tool_calls") and last_message.tool_calls:
        return "tools"
    
    # 2. If not, check if the TEXT content looks like a tool call JSON
    content = last_message.content.strip()
    if content.startswith("{") and "name" in content and "arguments" in content:
        # We manually inject the tool call into the message so ToolNode can find it
        try:
            tool_data = json.loads(content)
            last_message.tool_calls = [{
                "name": tool_data["name"],
                "args": tool_data["arguments"],
                "id": "manual_call_id" # unique ID required by LangGraph
            }]
            return "tools"
        except:
            pass
            
    return END

def call_model(state: MessagesState):
    # Prompt emphasizing the JSON format for the model
    system_prompt = f"You are a SQL expert. Return tool calls in JSON format.\nSchema:\n{db.get_table_info()}"
    messages = [{"role": "system", "content": system_prompt}] + state["messages"]
    return {"messages": [llm.invoke(messages)]}

# --- Build the Graph ---
workflow = StateGraph(MessagesState)
workflow.add_node("agent", call_model)
workflow.add_node("tools", tool_node)

workflow.add_edge(START, "agent")
workflow.add_conditional_edges("agent", custom_router) # Use our custom router
workflow.add_edge("tools", "agent")

app = workflow.compile()

# --- Run ---
async def run_agent():
    inputs = {"messages": [("user", "This is Frank Harris, What was the total on my last invoice?")]}
    async for chunk in app.astream(inputs, stream_mode="values"):
        chunk["messages"][-1].pretty_print()

await run_agent()


This is Frank Harris, What was the total on my last invoice?

{
  "name": "execute_sql",
  "arguments": {
    "query": "SELECT Total FROM Invoice WHERE CustomerId = (SELECT CustomerId FROM Customer WHERE FirstName = 'Frank' AND LastName = 'Harris') ORDER BY InvoiceDate DESC LIMIT 1"
  }
}
Tool Calls:
  execute_sql (manual_call_id)
 Call ID: manual_call_id
  Args:
    query: SELECT Total FROM Invoice WHERE CustomerId = (SELECT CustomerId FROM Customer WHERE FirstName = 'Frank' AND LastName = 'Harris') ORDER BY InvoiceDate DESC LIMIT 1
Name: execute_sql

[(5.94,)]

The total on your last invoice was $5.94.


In [16]:
# --- Run ---
async def run_agent():
    inputs = {"messages": [("user", "What were the titles?")]}
    async for chunk in app.astream(inputs, stream_mode="values"):
        chunk["messages"][-1].pretty_print()

await run_agent()


What were the titles?

{
  "name": "execute_sql",
  "arguments": {
    "query": "SELECT Title FROM Album"
  }
}
Tool Calls:
  execute_sql (manual_call_id)
 Call ID: manual_call_id
  Args:
    query: SELECT Title FROM Album
Name: execute_sql

[('For Those About To Rock We Salute You',), ('Balls to the Wall',), ('Restless and Wild',), ('Let There Be Rock',), ('Big Ones',), ('Jagged Little Pill',), ('Facelift',), ('Warner 25 Anos',), ('Plays Metallica By Four Cellos',), ('Audioslave',), ('Out Of Exile',), ('BackBeat Soundtrack',), ('The Best Of Billy Cobham',), ('Alcohol Fueled Brewtality Live! [Disc 1]',), ('Alcohol Fueled Brewtality Live! [Disc 2]',), ('Black Sabbath',), ('Black Sabbath Vol. 4 (Remaster)',), ('Body Count',), ('Chemical Wedding',), ('The Best Of Buddy Guy - The Millenium Collection',), ('Prenda Minha',), ('Sozinho Remix Ao Vivo',), ('Minha Historia',), ('Afrociberdelia',), ('Da Lama Ao Caos',), ('Acústico MTV [Live]',), ('Cidade Negra - Hits',), ('Na Pista',), ('Axé Bah

## Add memory

## Try your own queries
Now that there is memory, check the agents recall!

In [17]:
import json
import nest_asyncio
from langchain_community.utilities import SQLDatabase
from langchain_core.tools import tool
from langchain_ollama import ChatOllama
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.memory import MemorySaver # For memory persistence

# 1. Initialize for Notebook environment
nest_asyncio.apply()

# 2. Database Connection
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

# 3. Tool Definition
@tool
def execute_sql(query: str) -> str:
    """Execute a SQLite command and return results."""
    try:
        # Some local models add extra quotes to the query string
        clean_query = query.strip().strip("'").strip('"')
        return db.run(clean_query)
    except Exception as e:
        return f"Error: {e}"

tools = [execute_sql]
tool_node = ToolNode(tools)

# 4. Model Setup
# Qwen 2.5 is bound with the tools to understand the execute_sql schema
llm = ChatOllama(model="qwen2.5-coder:7b", temperature=0).bind_tools(tools)

# 5. Routing Logic (The "Fix" for local model JSON output)
def custom_router(state: MessagesState):
    last_message = state["messages"][-1]
    
    # Check for native tool calls first
    if hasattr(last_message, "tool_calls") and last_message.tool_calls:
        return "tools"
    
    # If model outputs JSON text instead of a tool call object, parse it manually
    content = last_message.content.strip()
    if content.startswith("{") and "name" in content:
        try:
            tool_data = json.loads(content)
            last_message.tool_calls = [{
                "name": tool_data["name"],
                "args": tool_data["arguments"],
                "id": "manual_call_id"
            }]
            return "tools"
        except:
            pass
    return END

def call_model(state: MessagesState):
    system_prompt = (
        "You are a SQL expert. Use the 'execute_sql' tool for data.\n"
        f"Schema:\n{db.get_table_info()}\n"
        "Rules: Return ONLY the tool call if you need data. "
        "After getting results, answer the user's question naturally."
    )
    # Combine system prompt with conversation history
    messages = [{"role": "system", "content": system_prompt}] + state["messages"]
    response = llm.invoke(messages)
    return {"messages": [response]}

# 6. Build the StateGraph
workflow = StateGraph(MessagesState)

workflow.add_node("agent", call_model)
workflow.add_node("tools", tool_node)

workflow.add_edge(START, "agent")
workflow.add_conditional_edges("agent", custom_router)
workflow.add_edge("tools", "agent")

# 7. Compile with Memory Checkpointer
checkpointer = MemorySaver()
app = workflow.compile(checkpointer=checkpointer)

# 8. Execution with Thread ID
async def run_chat():
    # 'thread_id' is the key that tells the agent which conversation to load
    config = {"configurable": {"thread_id": "frank_harris_session"}}
    
    # Step 1: Introduction
    print("--- FIRST TURN ---")
    inputs = {"messages": [("user", "This is Frank Harris. What was the total on my last invoice?")]}
    async for chunk in app.astream(inputs, config=config, stream_mode="values"):
        chunk["messages"][-1].pretty_print()

    # Step 2: Follow-up (Agent remembers who 'I' am)
    print("\n--- SECOND TURN (Memory) ---")
    follow_up = {"messages": [("user", "What is my email address in your system?")]}
    async for chunk in app.astream(follow_up, config=config, stream_mode="values"):
        chunk["messages"][-1].pretty_print()
    
    # Step 3: Another Follow-up (Agent remembers who 'I' am)
    print("\n--- THIRD TURN (Memory) ---")
    follow_up = {"messages": [("user", "What were the titles?")]}
    async for chunk in app.astream(follow_up, config=config, stream_mode="values"):
        chunk["messages"][-1].pretty_print()

await run_chat()

--- FIRST TURN ---

This is Frank Harris. What was the total on my last invoice?

{
  "name": "execute_sql",
  "arguments": {
    "query": "SELECT Total FROM Invoice WHERE CustomerId = (SELECT CustomerId FROM Customer WHERE FirstName = 'Frank' AND LastName = 'Harris') ORDER BY InvoiceDate DESC LIMIT 1"
  }
}
Tool Calls:
  execute_sql (manual_call_id)
 Call ID: manual_call_id
  Args:
    query: SELECT Total FROM Invoice WHERE CustomerId = (SELECT CustomerId FROM Customer WHERE FirstName = 'Frank' AND LastName = 'Harris') ORDER BY InvoiceDate DESC LIMIT 1
Name: execute_sql

[(5.94,)]

The total on your last invoice was $5.94.

--- SECOND TURN (Memory) ---

What is my email address in your system?

{
  "name": "execute_sql",
  "arguments": {
    "query": "SELECT Email FROM Customer WHERE FirstName = 'Frank' AND LastName = 'Harris'"
  }
}
Tool Calls:
  execute_sql (manual_call_id)
 Call ID: manual_call_id
  Args:
    query: SELECT Email FROM Customer WHERE FirstName = 'Frank' AND LastName 

In [18]:
import sqlite3
import pandas as pd

# 1. Connect to the Chinook database
# Ensure Chinook.db is in the same folder as your notebook
conn = sqlite3.connect('Chinook.db')

# 2. Define the SQL query
# This joins the Customer and Invoice tables to find Frank Harris's last total
query = """
SELECT 
    i.Total, 
    i.InvoiceDate 
FROM Invoice i
JOIN Customer c ON i.CustomerId = c.CustomerId
WHERE c.FirstName = 'Frank' AND c.LastName = 'Harris'
ORDER BY i.InvoiceDate DESC
LIMIT 1;
"""

# 3. Execute the query and display the result using Pandas for a clean table
df = pd.read_sql_query(query, conn)

# 4. Display the result
if not df.empty:
    print(f"Frank, your last invoice total was: ${df['Total'].iloc[0]}")
    display(df)
else:
    print("No invoice records found for Frank Harris.")

# 5. Close the connection
conn.close()

Frank, your last invoice total was: $5.94


Unnamed: 0,Total,InvoiceDate
0,5.94,2013-07-04 00:00:00
