# Dynamic Prompt
<img src="./assets/LC_DynamicPrompts.png" width="500">

## Setup

Load the environment variables from `.env` and verify that the required ones are set.

In [1]:
from dotenv import load_dotenv
from env_utils import doublecheck_env

# Load environment variables from the local .env file.
load_dotenv()

# Project helper from env_utils. Judging by the cell output it prints each
# expected variable with the value masked — confirm against env_utils.
doublecheck_env(".env")

OPENAI_API_KEY=****here
LANGSMITH_API_KEY=****754b
LANGSMITH_TRACING=true
LANGSMITH_PROJECT=****ials


In [2]:
from langchain_community.utilities import SQLDatabase

# Open the Chinook SQLite file; the URI is relative to the working directory.
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

In [3]:
from dataclasses import dataclass


@dataclass
class RuntimeContext:
    """Per-run context passed to the agent through ``context=``."""

    # False for non-employees; dynamic_system_prompt then adds a table restriction.
    is_employee: bool
    # Database handle that the execute_sql tool reads from the runtime context.
    db: SQLDatabase

In [4]:
from langchain_core.tools import tool
from langgraph.runtime import get_runtime

@tool
def execute_sql(query: str) -> str:
    """Execute a SQLite command and return results."""
    # The docstring above is what @tool exposes as the tool description, so it
    # is left as-is; implementation notes are kept in comments.
    # The database handle is supplied per run through RuntimeContext.
    database = get_runtime(RuntimeContext).context.db

    try:
        result = database.run(query)
    except Exception as exc:
        # Hand the failure back as text; the system prompt tells the model to
        # revise its SQL when it sees "Error:".
        return f"Error: {exc}"
    return result

In [5]:
# Base system prompt for the agent. The {table_limits} placeholder is filled in
# by dynamic_system_prompt via str.format (an empty string for employees).
SYSTEM_PROMPT_TEMPLATE = """You are a careful SQLite analyst.

Rules:
- Think step-by-step.
- When you need data, call the tool `execute_sql` with ONE SELECT query.
- Read-only only; no INSERT/UPDATE/DELETE/ALTER/DROP/CREATE/REPLACE/TRUNCATE.
- Limit to 5 rows unless the user explicitly asks otherwise.
{table_limits}
- If the tool returns 'Error:', revise the SQL and try again.
- Prefer explicit column lists; avoid SELECT *.
"""

## Build a Dynamic Prompt
Use the runtime context and middleware to generate the system prompt dynamically.

In [6]:
from langchain.agents.middleware.types import ModelRequest, dynamic_prompt


@dynamic_prompt
def dynamic_system_prompt(request: ModelRequest) -> str:
    """Build the system prompt for this request from the runtime context.

    Non-employees get an extra rule restricting which tables may be queried;
    for employees the placeholder is filled with an empty string.
    """
    is_employee = request.runtime.context.is_employee
    table_limits = (
        ""
        if is_employee
        else "- Limit access to these tables: Album, Artist, Genre, Playlist, PlaylistTrack, Track."
    )
    return SYSTEM_PROMPT_TEMPLATE.format(table_limits=table_limits)

Pass the middleware to `create_agent`.

In [7]:
from langchain_ollama import ChatOllama

# Local Ollama chat model. NOTE(review): temperature=0.8 makes the generated
# SQL non-deterministic; the later cells in this notebook use temperature=0 —
# confirm this value is intended.
llm = ChatOllama(model="qwen2.5-coder:7b", temperature=0.8)

In [8]:
from langchain.agents import create_agent

# Assemble the agent: dynamic_system_prompt supplies the system prompt as
# middleware, and context_schema declares the type passed as `context=`.
agent = create_agent(
    model=llm,
    tools=[execute_sql],
    middleware=[dynamic_system_prompt],
    context_schema=RuntimeContext,
)

In [9]:
question = "What is the most costly purchase by Frank Harris?"

# Run as a non-employee (is_employee=False), so dynamic_system_prompt adds the
# table-restriction rule. Only the newest message of each streamed state is printed.
for step in agent.stream(
    {"messages": [{"role": "user", "content": question}]},
    context=RuntimeContext(is_employee=False, db=db),
    stream_mode="values",
):
    step["messages"][-1].pretty_print()


What is the most costly purchase by Frank Harris?

To determine the most costly purchase by Frank Harris, we need to join the `Invoice` and `Customer` tables based on the customer ID. Then, filter the results to only include invoices for Frank Harris and sort them by the total cost in descending order to find the most expensive one.

Here is the SQL query to achieve this:

```sql
SELECT i.InvoiceId, i.Total 
FROM Invoice i
JOIN Customer c ON i.CustomerId = c.CustomerId
WHERE c.FirstName = 'Frank' AND c.LastName = 'Harris'
ORDER BY i.Total DESC
LIMIT 1;
```

Let's execute this SQL query to get the result.


In [10]:
question = "What is the most costly purchase by Frank Harris?"

# Same question as an employee (is_employee=True): the prompt carries no table
# restriction. Only the newest message of each streamed state is printed.
for step in agent.stream(
    {"messages": [{"role": "user", "content": question}]},
    context=RuntimeContext(is_employee=True, db=db),
    stream_mode="values",
):
    step["messages"][-1].pretty_print()


What is the most costly purchase by Frank Harris?

To find the most costly purchase by Frank Harris, we need to:

1. Identify the table that contains purchase information.
2. Filter the purchases to include only those made by Frank Harris.
3. Determine the cost of each purchase.
4. Find the purchase with the highest cost.

Let's assume the table is named `purchases` and it has columns `customer_name`, `item_name`, and `cost`. We will use the following SQL query:

```sql
SELECT item_name, MAX(cost) as max_cost
FROM purchases
WHERE customer_name = 'Frank Harris';
```

This query selects the most expensive item and its cost for Frank Harris.

{"name": "execute_sql", "arguments": {"query": "SELECT item_name, MAX(cost) as max_cost FROM purchases WHERE customer_name = 'Frank Harris' LIMIT 1"}}


In [None]:
import json
import re
import nest_asyncio
from typing import Annotated, List, TypedDict, Literal
from langchain_community.utilities import SQLDatabase
from langchain_core.tools import tool
from langchain_ollama import ChatOllama
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.runnables import RunnableConfig

# 1. Setup
# nest_asyncio patches asyncio so nested event-loop use works inside a notebook.
nest_asyncio.apply()
# Rebinds the notebook-level `db` to a new handle on the same SQLite file.
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

class AgentState(TypedDict):
    """Graph state shared by the nodes: the conversation messages."""

    # add_messages is the reducer for this key, so messages returned by a node
    # are merged into the list rather than replacing it.
    messages: Annotated[List[BaseMessage], add_messages]

@tool
def execute_sql(query: str, config: RunnableConfig) -> str:
    """Execute a SQLite command and return results."""
    # The docstring above is the tool description exposed to the model, so it
    # is left unchanged. The database handle travels in config["configurable"]["db"].
    database = config["configurable"].get("db")
    try:
        # Local models often wrap the SQL in markdown fences and/or quotes.
        clean_query = query.replace('```sql', '').replace('```', '').strip()
        # Unwrap quotes only when the SAME quote character encloses the whole
        # string. A blanket .strip("'") also removes the closing quote of a
        # trailing literal (e.g. ... WHERE LastName = 'Harris') and breaks the SQL.
        while (
            len(clean_query) >= 2
            and clean_query[0] == clean_query[-1]
            and clean_query[0] in "'\""
        ):
            clean_query = clean_query[1:-1].strip()
        return str(database.run(clean_query))
    except Exception as e:
        # Returned as text (not raised) so the model can read the error and retry.
        return f"Error: {e}"

tools = [execute_sql]
# ToolNode runs the tool calls attached to the latest AI message.
tool_node = ToolNode(tools)
# Bind the tool schema to the model. Without bind_tools the model is never told
# that execute_sql exists, so it cannot emit native tool calls and the graph
# depends entirely on parsing SQL out of free text.
llm = ChatOllama(model="qwen2.5-coder:7b", temperature=0).bind_tools(tools)

# 2. THE FIX: Custom Routing Logic
def manual_router(state: AgentState) -> Literal["tools", "__end__"]:
    """Choose the next node after the model has responded.

    Returns "tools" only when the newest message carries at least one tool
    call, otherwise "__end__".

    The tools node can only execute entries in ``tool_calls``. Routing there
    merely because the text mentions SELECT and FROM gives it nothing to run,
    and the tools -> agent edge can then send an unchanged conversation back to
    the model over and over. A prose answer such as "I selected the rows from
    Track" would trigger that as well. Turning SQL text into a tool call is
    done in call_model, so no text heuristic is applied here.
    """
    last_message = state["messages"][-1]

    # getattr also covers message types without a tool_calls attribute, and
    # avoids touching .content, which is not guaranteed to be a string.
    if getattr(last_message, "tool_calls", None):
        return "tools"

    return "__end__"

# 3. Model Node with Forced Parsing
def _sql_from_reply(text):
    """Return the SQL query embedded in a plain-text model reply, or None.

    Two shapes are recognised:
    1. The whole reply is a JSON tool call such as
       {"name": "execute_sql", "arguments": {"query": "..."}} — the format the
       system prompt asks for. The query needs no trailing semicolon.
    2. The reply contains a SELECT statement terminated by a semicolon.
    """
    if not isinstance(text, str):
        # Content that is not a plain string (e.g. a list of blocks) is not parsed.
        return None

    stripped = text.strip()
    if stripped.startswith("{"):
        try:
            payload = json.loads(stripped)
        except json.JSONDecodeError:
            payload = None
        if isinstance(payload, dict) and payload.get("name") == "execute_sql":
            arguments = payload.get("arguments")
            if isinstance(arguments, dict) and isinstance(arguments.get("query"), str):
                return arguments["query"]

    sql_match = re.search(r'SELECT\s+.*?;', text, re.IGNORECASE | re.DOTALL)
    return sql_match.group() if sql_match else None


def call_model(state: AgentState, config: RunnableConfig):
    """Model node: build the system prompt, call the LLM, and repair missing tool calls.

    Reads ``is_employee`` (default False) and ``db`` from ``config["configurable"]``.
    When the reply has no native tool calls but contains a query (as a JSON tool
    call or as SQL text), an ``execute_sql`` tool call is attached to the message.

    Returns:
        A state update containing the model's response message.
    """
    is_employee = config["configurable"].get("is_employee", False)
    database = config["configurable"].get("db")
    
    limits = "Full access." if is_employee else "ONLY access: Album, Artist, Genre, Track."
    
    sys_msg = SystemMessage(content=f"""You are a SQL analyst. 
    {limits}
    ALWAYS use 'execute_sql' to get data. 
    If you need data, output ONLY the JSON for the tool call.
    Schema: {database.get_table_info()}""")
    
    # Invoke model
    response = llm.invoke([sys_msg] + state["messages"])
    
    # If the model wrote a query but didn't trigger 'tool_calls', attach one so
    # the tools node has something to execute.
    if not response.tool_calls:
        query = _sql_from_reply(response.content)
        if query:
            response.tool_calls = [{
                "name": "execute_sql", 
                "args": {"query": query}, 
                "id": "manual_fix"
            }]
            
    return {"messages": [response]}

# 4. Build Graph
workflow = StateGraph(AgentState)
workflow.add_node("agent", call_model)
workflow.add_node("tools", tool_node)

# Every run enters at the model node.
workflow.add_edge(START, "agent")
# manual_router's return value ("tools" or "__end__") selects the next step.
workflow.add_conditional_edges("agent", manual_router) # Uses our fixed router
# After the tools node runs, control always goes back to the model.
workflow.add_edge("tools", "agent")

# MemorySaver keeps checkpoints in memory; runs are keyed by the thread_id
# given in the run config.
app = workflow.compile(checkpointer=MemorySaver())

# 5. Run Test
async def run_test(is_emp: bool):
    """Stream the test question through the graph and print each node's newest message.

    Args:
        is_emp: Forwarded to the graph as ``is_employee`` in the run config.
    """
    from uuid import uuid4  # local import so the cell's import list is untouched

    print(f"\n--- Testing (Employee={is_emp}) ---")
    # The graph is compiled with a checkpointer, so reusing one thread_id would
    # carry earlier runs' messages (including another role's query results)
    # into this conversation. Each test run gets its own thread.
    config = {
        "configurable": {
            "thread_id": f"test-{uuid4()}",
            "db": db,
            "is_employee": is_emp,
        }
    }
    question = HumanMessage(content="What is the total of Frank Harris's last invoice?")
    async for event in app.astream({"messages": [question]}, config):
        for value in event.values():
            messages = value.get("messages") if isinstance(value, dict) else None
            if messages:  # a node may contribute no messages in a given step
                messages[-1].pretty_print()

# Top-level await: works in a notebook cell, not in a plain Python script.
await run_test(is_emp=True)

In [None]:
import json
import nest_asyncio
from dataclasses import dataclass
from langchain_community.utilities import SQLDatabase
from langchain_core.tools import tool
from langchain_ollama import ChatOllama
from langgraph.graph import StateGraph, MessagesState, START, END
from langgraph.prebuilt import ToolNode

# 1. Setup
# nest_asyncio patches asyncio so nested event-loop use works inside a notebook.
nest_asyncio.apply()
# Module-level handle; execute_sql and call_model below read this `db` directly.
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

# 2. Tool with logic to access DB
@tool
def execute_sql(query: str) -> str:
    """Execute a SQLite command and return results."""
    # The docstring above is the tool description exposed to the model, so it
    # is left unchanged. Queries run against the module-level `db`.
    try:
        # Clean potential markdown fences from local LLM output.
        clean_query = query.replace('```sql', '').replace('```', '').strip()
        # Unwrap quotes only when the SAME quote character encloses the whole
        # string. A blanket .strip("'") also removes the closing quote of a
        # trailing literal (e.g. ... WHERE LastName = 'Harris') and breaks the SQL.
        while (
            len(clean_query) >= 2
            and clean_query[0] == clean_query[-1]
            and clean_query[0] in "'\""
        ):
            clean_query = clean_query[1:-1].strip()
        return db.run(clean_query)
    except Exception as e:
        # Returned as text (not raised) so the model can read the error and retry.
        return f"Error: {e}"

tools = [execute_sql]
# ToolNode runs the tool calls attached to the latest AI message.
tool_node = ToolNode(tools)
# bind_tools attaches the tool schema so the model can emit native tool calls.
llm = ChatOllama(model="qwen2.5-coder:7b", temperature=0).bind_tools(tools)

# 3. Dynamic System Prompt Logic
# Template with two str.format placeholders, both filled in by call_model:
#   {table_limits} - the role-specific access rule
#   {schema}       - the output of db.get_table_info()
SYSTEM_PROMPT_TEMPLATE = """You are a careful SQLite analyst.
Rules:
- Think step-by-step.
- Call `execute_sql` with ONE SELECT query.
- Read-only: no INSERT/UPDATE/DELETE.
- Limit to 5 rows.
{table_limits}
- Return tool calls in JSON format.
- Schema:
{schema}
"""

# 4. Agent Node with Security Logic
def call_model(state: MessagesState, config: dict):
    """Agent node: prepend a role-specific system prompt and query the model.

    The role comes from ``config["configurable"]["is_employee"]`` and defaults
    to False (restricted) when absent. Returns a state update holding the
    model's response.
    """
    configurable = config.get("configurable", {})
    is_employee = configurable.get("is_employee", False)

    # Pick the access rule injected into the prompt for this caller.
    table_limits = (
        "- SECURITY: You have full employee access to all tables including Customer and Invoice."
        if is_employee
        else "- SECURITY: Limit access ONLY to: Album, Artist, Genre, Playlist, PlaylistTrack, Track. DO NOT access Invoice or Customer tables."
    )

    system_message = {
        "role": "system",
        "content": SYSTEM_PROMPT_TEMPLATE.format(
            table_limits=table_limits,
            schema=db.get_table_info(),
        ),
    }

    reply = llm.invoke([system_message] + state["messages"])
    return {"messages": [reply]}

# 5. Custom Router (The Fix for local JSON formatting)
def custom_router(state: MessagesState):
    """Route to the tools node when the newest message requests a tool call.

    Native ``tool_calls`` are honoured first. As a fallback for local models
    that print the call as plain JSON text, a reply that is a JSON object with
    "name" and "arguments" is converted into a tool call on the message itself.

    Returns:
        "tools" when there is a tool call to run, otherwise END.
    """
    last_message = state["messages"][-1]

    if getattr(last_message, "tool_calls", None):
        return "tools"

    content = last_message.content
    if not isinstance(content, str):
        # Only plain-string content is parsed; anything else ends the run.
        return END

    content = content.strip()
    if content.startswith("{") and "name" in content:
        try:
            tool_data = json.loads(content)
            tool_call = {
                "name": tool_data["name"],
                "args": tool_data["arguments"],
                "id": "manual_call_id"
            }
        except (ValueError, KeyError, TypeError):
            # Malformed JSON (JSONDecodeError is a ValueError), a missing key,
            # or a non-object payload: treat the reply as a final answer.
            return END
        if not isinstance(tool_call["args"], dict):
            # Tool arguments must be a mapping of parameter names to values.
            return END
        last_message.tool_calls = [tool_call]
        return "tools"

    return END

# 6. Build Graph
workflow = StateGraph(MessagesState)
workflow.add_node("agent", call_model)
workflow.add_node("tools", tool_node)

# Every run enters at the model node.
workflow.add_edge(START, "agent")
# custom_router returns "tools" or END to select the next step.
workflow.add_conditional_edges("agent", custom_router)
# After the tools node runs, control always goes back to the model.
workflow.add_edge("tools", "agent")

# Compiled without a checkpointer.
app = workflow.compile()

# 7. Run Tests
async def run_test(is_employee: bool, question: str):
    """Stream one question through the graph under the given role.

    Prints a banner naming the role, then pretty-prints the newest message of
    every state the graph streams back.
    """
    print(f"\n--- Testing as {'Employee' if is_employee else 'Customer'} ---")

    run_config = {"configurable": {"is_employee": is_employee}}
    payload = {"messages": [("user", question)]}

    async for snapshot in app.astream(payload, config=run_config, stream_mode="values"):
        newest = snapshot["messages"][-1]
        newest.pretty_print()

# Run scenarios
# 1. Customer asking for invoice data. The restriction lives only in the system
#    prompt, so this relies on the model complying; execute_sql itself does not
#    block any table.
await run_test(is_employee=False, question="What is the most costly purchase by Frank Harris?")

# 2. Employee asking for invoice data (the prompt grants access to all tables).
await run_test(is_employee=True, question="What is the most costly purchase by Frank Harris?")