# Dynamic Prompt
<img src="./assets/LC_DynamicPrompts.png" width="500">

## Setup

Load the required environment variables from `.env` and check that they are set.

In [1]:
from dotenv import load_dotenv
from env_utils import doublecheck_env

# Load environment variables from .env into the process environment
# (the API keys and LangSmith settings listed in the output below).
load_dotenv()

# Check and print results. doublecheck_env comes from the local env_utils
# module; judging from the cell output it lists each expected variable with its
# value masked. Kept as the cell's last expression: if it returns a value,
# Jupyter displays it.
doublecheck_env(".env")

OPENAI_API_KEY=****here
LANGSMITH_API_KEY=****754b
LANGSMITH_TRACING=true
LANGSMITH_PROJECT=****ials


In [3]:
# Import required modules
import re
import nest_asyncio
from typing import Annotated, List, TypedDict, Literal, Set
from langchain_community.utilities import SQLDatabase
from langchain_core.tools import tool
from langchain_ollama import ChatOllama
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.memory import MemorySaver

# Tables a non-employee may query. execute_sql enforces this set and
# get_schema_info uses it to filter the schema shown to the model.
ALLOWED_TABLES_NON_EMPLOYEE: Set[str] = {
    "Album",
    "Artist",
    "Genre",
    "Playlist",
    "PlaylistTrack",
    "Track",
}

# Access flag read by execute_sql; run_query overwrites it before each run.
CURRENT_IS_EMPLOYEE = False

# Chinook sample database; "sqlite:///" + a relative path resolves against the
# current working directory.
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

# Patch asyncio so nested event loops are allowed (needed for the
# `await run_query(...)` calls inside an already-running notebook loop).
nest_asyncio.apply()

# Tokens following FROM: a (possibly quoted / schema-qualified) name, or a
# single punctuation character that can delimit a FROM list.
_FROM_TOKEN_RE = re.compile(r'[^\s(),;]+|[(),;]')

# Tokens that end a comma-separated FROM list; never read them as aliases.
_FROM_LIST_TERMINATORS = frozenset({
    "WHERE", "JOIN", "INNER", "LEFT", "RIGHT", "FULL", "CROSS", "NATURAL",
    "OUTER", "ON", "USING", "GROUP", "ORDER", "LIMIT", "OFFSET", "HAVING",
    "UNION", "INTERSECT", "EXCEPT", "WINDOW", "(", ")", ";",
})


def _normalize_table_name(raw: str) -> str:
    """Drop any schema qualifier and identifier quoting from a table token."""
    return raw.split('.')[-1].strip('"`[]')


def extract_table_names(query: str) -> Set[str]:
    """Return the (upper-cased) table names referenced by a SELECT statement.

    Collects the table after every JOIN keyword and *every* entry of each
    comma-separated FROM list (``FROM a x, b AS y``), so a comma join cannot
    slip a table past the caller's allow-list. Schema qualifiers and quoting
    are removed (``"main"."Album"`` -> ``ALBUM``). A parenthesis right after
    FROM is a subquery; its own FROM keyword is visited separately.

    Best-effort token scan, not a SQL parser: e.g. a comma-joined table that
    follows a ``JOIN ... ON <expr>`` clause is not detected. Do not treat this
    as the only access-control boundary.

    Args:
        query: SQL text in any letter case.

    Returns:
        Set of table names in upper case; empty when none are found.
    """
    q = query.upper()
    tables: Set[str] = set()

    # JOIN targets: exactly one table reference follows the keyword.
    for match in re.finditer(r'\bJOIN\s+([^\s(),;]+)', q):
        name = _normalize_table_name(match.group(1))
        if name:
            tables.add(name)

    # FROM lists: "t1 [AS] [alias], t2 [AS] [alias], ..." until a terminator.
    for match in re.finditer(r'\bFROM\s+', q):
        expect_table = True
        for tok in _FROM_TOKEN_RE.finditer(q, match.end()):
            token = tok.group()
            if token in _FROM_LIST_TERMINATORS:
                break
            if token == ",":
                expect_table = True
            elif expect_table:
                name = _normalize_table_name(token)
                if name:
                    tables.add(name)
                expect_table = False
            # Otherwise it is AS, an alias or a qualifier: skip until a comma.
    return tables

@tool(description="Execute a read-only SQLite SELECT query and return results as a string.")
def execute_sql(query: str) -> str:
    # Documented with comments rather than a docstring: the tool description
    # comes from the decorator, and a docstring may also end up in the tool
    # schema shown to the model.
    #
    # Runs one SELECT against `db`. Every failure (non-SELECT input, access
    # denial, SQL error) is returned as an "Error: ..." string so the model can
    # read it and react, instead of raising inside the graph.
    try:
        # Remove Markdown fences (any case) *before* validating; previously a
        # fenced "```sql\nSELECT ..." failed the SELECT-prefix check.
        clean_query = re.sub(r'```(?:sql)?', '', query, flags=re.IGNORECASE)
        clean_query = clean_query.strip().rstrip(';').rstrip()
        if not clean_query.upper().startswith("SELECT"):
            return "Error: Only SELECT queries are allowed."
        if not CURRENT_IS_EMPLOYEE:
            # extract_table_names upper-cases names while the allow-list is
            # mixed-case ("Album"), so a plain set difference flagged every
            # table. Compare case-insensitively (SQLite table names are
            # case-insensitive as well).
            allowed = {name.upper() for name in ALLOWED_TABLES_NON_EMPLOYEE}
            forbidden = {
                name for name in extract_table_names(clean_query)
                if name.upper() not in allowed
            }
            if forbidden:
                return (f"Error: Access denied to tables: {sorted(forbidden)}. "
                        f"Allowed: {sorted(ALLOWED_TABLES_NON_EMPLOYEE)}")
        return str(db.run(clean_query + ';'))
    except Exception as e:
        # Deliberate catch-all at the tool boundary: errors go back to the model.
        return f"Error: {e}"

# The agent exposes a single tool; ToolNode executes the model's calls to it.
tools = [execute_sql]
tool_node = ToolNode(tools=tools)

class AgentState(TypedDict):
    """Graph state shared by every node of the SQL agent workflow."""

    # Conversation history. The add_messages reducer merges each node's
    # returned messages into this list instead of replacing it.
    messages: Annotated[List[BaseMessage], add_messages]
    # Access level for the run; selects the schema and prompt variant.
    is_employee: bool
    # How many times the tools node has run since run_query reset it to 0;
    # route_after_agent stops the loop once it reaches 3.
    tool_call_count: int

def get_schema_info(is_employee: bool) -> str:
    """Return the schema text the model is allowed to see.

    Employees get ``db.get_table_info()`` unchanged. For everyone else the same
    text is filtered line by line: each ``CREATE TABLE`` line opens a section,
    and the section (that line plus everything up to the next ``CREATE TABLE``)
    is kept only when the first double-quoted name on the line is in
    ALLOWED_TABLES_NON_EMPLOYEE. Text before the first ``CREATE TABLE`` is kept.
    """
    full_info = db.get_table_info()
    if is_employee:
        return full_info

    kept: List[str] = []
    keep_section = True
    for row in full_info.split('\n'):
        if row.startswith('CREATE TABLE'):
            quoted = re.search(r'"([^"]+)"', row)
            section_table = quoted.group(1) if quoted else ""
            keep_section = section_table in ALLOWED_TABLES_NON_EMPLOYEE
        if keep_section:
            kept.append(row)
    return '\n'.join(kept)

# Local Ollama model used both to write SQL and to phrase the final answer.
# Temperature 0 keeps query text and tool-call formatting deterministic;
# sampling at 0.8 made the generated SQL vary from run to run.
llm = ChatOllama(model="qwen2.5-coder:7b", temperature=0)

def get_system_message(is_employee: bool) -> SystemMessage:
    """Build the per-turn system prompt for the SQL agent.

    Args:
        is_employee: True for full-schema access; False adds an access note
            and limits the schema to ALLOWED_TABLES_NON_EMPLOYEE.

    Returns:
        A SystemMessage with the access note (non-employees only), the schema
        text from get_schema_info, and the answering rules.
    """
    schema = get_schema_info(is_employee)
    if is_employee:
        access_note = ""
    else:
        # Built from the same set execute_sql enforces so prompt and guard
        # cannot drift apart; sorted() reproduces the previous listing order.
        allowed = ", ".join(sorted(ALLOWED_TABLES_NON_EMPLOYEE))
        access_note = (
            f"- You ONLY have access to: {allowed}.\n"
            "- NEVER use Customer, Invoice, etc.\n"
        )
    # Fixes vs. the previous prompt: money is requested as plain $X.XX (the old
    # **$X.XX** example was itself markdown, contradicting the no-markdown
    # rule), and a mis-encoded dash ("â€”") is replaced with ASCII " - ".
    return SystemMessage(content=f"""You are a helpful SQL analyst for the Chinook music store.

{access_note}
Here is the database schema:
{schema}

Rules:
- Use correct table/column names (e.g., Customer.FirstName, Invoice.Total).
- When returning a final answer, use plain English.
- Format monetary amounts as $X.XX (e.g., "$13.86").
- DO NOT use LaTeX, \\(\\), \\boxed{{}}, markdown, or code formatting in your final answer.
- If you cannot answer due to restrictions, say so clearly.
- Limit queries to 5 rows unless asked for more.
- Stop immediately after an "Access denied" error - do not retry unrelated queries.
""")

def call_model(state: AgentState):
    """Run one model turn and return its reply as a state update.

    The system prompt is rebuilt each turn from ``state["is_employee"]``. If the
    model writes SQL in plain text instead of making a native tool call, the
    first ``SELECT ... FROM ...`` statement is lifted into a synthetic
    ``execute_sql`` tool call so the graph still executes it.
    """
    sys_msg = get_system_message(state["is_employee"])
    response = llm.invoke([sys_msg] + state["messages"])
    content = response.content
    if not getattr(response, 'tool_calls', None) and isinstance(content, str):
        # Require SELECT ... FROM as whole words: the old `SELECT\s+.*?` also
        # fired on prose such as "please select another question". The match
        # ends at a semicolon, a closing Markdown fence, or the end of text.
        sql_match = re.search(
            r'\bSELECT\b.*?\bFROM\b.*?(?=;|```|\Z)',
            content,
            re.IGNORECASE | re.DOTALL,
        )
        if sql_match:
            query = sql_match.group(0).strip() + ";"
            response.tool_calls = [{
                "name": "execute_sql",
                "args": {"query": query},
                "id": "manual_fix",
                # Mark it explicitly as a ToolCall (langchain_core schema).
                "type": "tool_call",
            }]
    return {"messages": [response]}

def route_after_agent(state: AgentState) -> Literal["tools", "__end__"]:
    """Decide whether the agent's last message goes to the tools node or ends the run.

    Returns ``"__end__"`` once three tool executions have happened (no 4th
    attempt). Otherwise returns ``"tools"`` only when the last message carries
    ``tool_calls``.
    """
    # Hard cap: after 3 tool executions, finish instead of retrying again.
    if state.get("tool_call_count", 0) >= 3:
        return "__end__"
    last = state["messages"][-1]
    # ToolNode only executes entries in `tool_calls`. The former fallback
    # (text containing "SELECT" and "FROM") sent messages without tool calls
    # to ToolNode, which ran nothing and just re-invoked the model; call_model
    # already converts SQL text into a tool call when it can.
    if getattr(last, 'tool_calls', None):
        return "tools"
    return "__end__"

def update_tool_count(state: AgentState) -> dict:
    """Return a state update that bumps ``tool_call_count`` by one (missing counts as 0)."""
    previous = state.get("tool_call_count", 0)
    return dict(tool_call_count=previous + 1)

def _build_agent_graph() -> StateGraph:
    """Assemble the uncompiled loop: agent -> (tools -> update_count -> agent)* -> END."""
    graph = StateGraph(AgentState)
    for node_name, action in (
        ("agent", call_model),
        ("tools", tool_node),
        ("update_count", update_tool_count),
    ):
        graph.add_node(node_name, action)
    graph.add_edge(START, "agent")
    # route_after_agent picks "tools" or "__end__".
    graph.add_conditional_edges("agent", route_after_agent)
    graph.add_edge("tools", "update_count")
    graph.add_edge("update_count", "agent")
    return graph


workflow = _build_agent_graph()
# MemorySaver keeps per-thread graph state in memory.
app = workflow.compile(checkpointer=MemorySaver())

def _clean_answer(text) -> str:
    """Normalise a final model answer for console display.

    Removes LaTeX boxed wrappers and markdown bold markers (the system prompt
    forbids both). If the text contains no dollar sign at all, standalone
    two-decimal numbers are prefixed with one.
    """
    if not isinstance(text, str):
        text = str(text)
    # Remove \(\boxed{...}\) and bare \boxed{...}
    text = re.sub(r'\\\(\s*\\boxed\{([^}]+)\}\s*\\\)', r'\1', text)
    text = re.sub(r'\\boxed\{([^}]+)\}', r'\1', text)
    text = text.replace('**', '')
    if '$' not in text:
        # Lookarounds keep "3.141" or "1.2.34" from being partially prefixed.
        text = re.sub(r'(?<![\d.])(\d+\.\d{2})(?!\d)', r'$\1', text)
    return text


def _print_message(msg) -> None:
    """Print one streamed message as [TOOL CALL], [AI] or [RESULT/ERROR]."""
    if isinstance(msg, AIMessage):
        if msg.tool_calls:
            # Show every requested call, and tolerate args without "query"
            # instead of raising KeyError.
            for call in msg.tool_calls:
                args = call.get("args") or {}
                print(f"[TOOL CALL] Query: {args.get('query', args)}")
        else:
            print(f"[AI] {_clean_answer(msg.content)}")
    elif hasattr(msg, 'content'):
        print(f"[RESULT/ERROR] {msg.content}")


async def run_query(question: str, is_employee: bool, *, thread_id: "str | None" = None):
    """Stream one question through the agent graph and print each step.

    Args:
        question: Natural-language question for the agent.
        is_employee: Access level for this run; sets CURRENT_IS_EMPLOYEE (read
            by execute_sql) and the graph state's ``is_employee``.
        thread_id: Optional checkpointer thread to continue. By default every
            call gets a fresh thread, so no history carries over between runs.
    """
    import uuid  # only needed here, for per-run thread ids

    global CURRENT_IS_EMPLOYEE
    # execute_sql reads this module-level flag, not graph state.
    # NOTE(review): concurrent run_query calls would race on it.
    CURRENT_IS_EMPLOYEE = is_employee
    print(f"\n{'='*70}")
    print(f"User: {question}")
    print(f"Employee Access: {is_employee}")
    print('='*70)

    inputs = {
        "messages": [HumanMessage(content=question)],
        "is_employee": is_employee,
        "tool_call_count": 0,
    }
    # A fixed thread_id ("final_run") made MemorySaver append every run to one
    # shared history, so a non-employee run could see query results fetched by
    # an earlier employee run. Use a fresh thread unless the caller picks one.
    config = {"configurable": {"thread_id": thread_id or f"run-{uuid.uuid4()}"}}

    async for event in app.astream(inputs, config):
        # Each event maps a node name to the state update that node returned.
        for update in event.values():
            messages = update.get("messages") if isinstance(update, dict) else None
            # Guard against an empty message list, which made `[-1]` raise.
            if messages:
                _print_message(messages[-1])

# Run both tests: the same question first with non-employee access (restricted
# tables), then with employee access (full schema).
# Top-level `await` relies on running inside IPython/Jupyter (see the
# `In [n]:` cell markers); a plain .py script would need asyncio.run(...).
await run_query("What is the most costly purchase by Frank Harris?", is_employee=False)
await run_query("What is the most costly purchase by Frank Harris?", is_employee=True)


User: What is the most costly purchase by Frank Harris?
Employee Access: False
[AI] I'm sorry, but I cannot answer this question due to restrictions. The provided schema does not include tables related to customers or purchases, so I don't have access to information about individual purchases made by specific customers like Frank Harris.

User: What is the most costly purchase by Frank Harris?
Employee Access: True
[TOOL CALL] Query: SELECT i.InvoiceId, i.Total
FROM Invoice i
JOIN Customer c ON i.CustomerId = c.CustomerId
WHERE c.FirstName = 'Frank' AND c.LastName = 'Harris'
ORDER BY i.Total DESC
LIMIT 1;
[RESULT/ERROR] [(145, 13.86)]
[AI] Frank Harris's most costly purchase was for an invoice with a total of **$13.86**.
