# Dynamic Prompt
<img src="./assets/LC_DynamicPrompts.png" width="500">

## Setup

Load and/or check for needed environmental variables

In [1]:
from dotenv import load_dotenv
from env_utils import doublecheck_env

# Load environment variables from .env
load_dotenv()

# Check and print results
doublecheck_env(".env")

OPENAI_API_KEY=****here
LANGSMITH_API_KEY=****754b
LANGSMITH_TRACING=true
LANGSMITH_PROJECT=****ials


In [2]:
# Import required modules
import re
import nest_asyncio
from typing import Annotated, List, TypedDict, Literal, Set
from langchain_community.utilities import SQLDatabase
from langchain_core.tools import tool
from langchain_ollama import ChatOllama
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.memory import MemorySaver

nest_asyncio.apply()

# Load database
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

# Global access flag
CURRENT_IS_EMPLOYEE = False

ALLOWED_TABLES_NON_EMPLOYEE: Set[str] = {
    "Album", "Artist", "Genre", "Playlist", "PlaylistTrack", "Track"
}

def extract_table_names(query: str) -> Set[str]:
    q = query.upper()
    tables = set()
    patterns = [r'\bFROM\s+([^\s\(\),;]+)', r'\bJOIN\s+([^\s\(\),;]+)']
    for pattern in patterns:
        for match in re.findall(pattern, q):
            table = match.strip().split()[0].strip('"`[]')
            if table:
                tables.add(table)
    return tables

@tool(description="Execute a read-only SQLite SELECT query and return results as a string.")
def execute_sql(query: str) -> str:
    global CURRENT_IS_EMPLOYEE
    try:
        if not query.strip().upper().startswith("SELECT"):
            return "Error: Only SELECT queries are allowed."
        
        tables_used = extract_table_names(query)
        
        if not CURRENT_IS_EMPLOYEE:
            forbidden = tables_used - ALLOWED_TABLES_NON_EMPLOYEE
            if forbidden:
                return (f"Error: Access denied to tables: {sorted(forbidden)}. "
                        f"Allowed: {sorted(ALLOWED_TABLES_NON_EMPLOYEE)}")
        
        clean_query = query.strip().replace('```sql', '').replace('```', '').rstrip(';') + ';'
        return str(db.run(clean_query))
    
    except Exception as e:
        return f"Error: {e}"

tools = [execute_sql]
tool_node = ToolNode(tools)

class AgentState(TypedDict):
    messages: Annotated[List[BaseMessage], add_messages]
    is_employee: bool
    tool_call_count: int

# Get filtered schema
def get_schema_info(is_employee: bool) -> str:
    if is_employee:
        return db.get_table_info()  # Full schema
    else:
        # Only include allowed tables
        all_info = db.get_table_info()
        lines = all_info.split("\n")
        filtered = []
        current_table = None
        for line in lines:
            if line.startswith("CREATE TABLE"):
                # Extract table name: CREATE TABLE "Album" ...
                table_match = re.search(r'CREATE TABLE ["`]([^"`]+)["`]', line)
                current_table = table_match.group(1) if table_match else None
                if current_table in ALLOWED_TABLES_NON_EMPLOYEE:
                    filtered.append(line)
            elif current_table in ALLOWED_TABLES_NON_EMPLOYEE:
                filtered.append(line)
            elif line.strip() == ");":
                if current_table in ALLOWED_TABLES_NON_EMPLOYEE:
                    filtered.append(line)
                current_table = None
        return "\n".join(filtered)

llm = ChatOllama(model="qwen2.5-coder:7b", temperature=0.8)

def get_system_message(is_employee: bool) -> SystemMessage:
    schema_info = get_schema_info(is_employee)
    if is_employee:
        limits = ""
    else:
        limits = ("- You ONLY have access to these tables: Album, Artist, Genre, Playlist, PlaylistTrack, Track.\n"
                  "- NEVER mention or use Customer, Invoice, InvoiceLine, Employee, etc.")
    
    return SystemMessage(content=f"""You are a careful SQLite analyst working with the Chinook music store database.

{limits}

Here is the relevant database schema:
{schema_info}

Rules:
- Think step-by-step.
- Use ONLY the tables and column names shown above (case-sensitive!).
- For employees: Customer.FirstName, Customer.LastName, Invoice.Total, InvoiceLine.Quantity, Track.Name, etc.
- Always use explicit column names; avoid SELECT *.
- Limit results to 5 rows unless asked for more.
- If you get an error, revise the query using the correct table/column names from the schema.
- If access is denied, stop and explain clearly.
""")

def call_model(state: AgentState):
    sys_msg = get_system_message(state["is_employee"])
    response = llm.invoke([sys_msg] + state["messages"])
    
    if not getattr(response, 'tool_calls', None):
        sql_match = re.search(r'(SELECT\s+.*?)(?:;|\s*$)', response.content, re.IGNORECASE | re.DOTALL)
        if sql_match:
            query = sql_match.group(1).strip() + ";"
            response.tool_calls = [{
                "name": "execute_sql",
                "args": {"query": query},
                "id": "manual_fix"
            }]
    
    return {"messages": [response]}

def route_after_agent(state: AgentState) -> Literal["tools", "__end__"]:
    last = state["messages"][-1]
    if state.get("tool_call_count", 0) >= 3:
        return "__end__"
    if getattr(last, 'tool_calls', None):
        return "tools"
    content = last.content.upper()
    if "SELECT" in content and "FROM" in content:
        return "tools"
    return "__end__"

def update_tool_count(state: AgentState) -> dict:
    return {"tool_call_count": state.get("tool_call_count", 0) + 1}

workflow = StateGraph(AgentState)
workflow.add_node("agent", call_model)
workflow.add_node("tools", tool_node)
workflow.add_node("update_count", update_tool_count)
workflow.add_edge(START, "agent")
workflow.add_conditional_edges("agent", route_after_agent)
workflow.add_edge("tools", "update_count")
workflow.add_edge("update_count", "agent")

app = workflow.compile(checkpointer=MemorySaver())

async def run_query(question: str, is_employee: bool):
    global CURRENT_IS_EMPLOYEE
    CURRENT_IS_EMPLOYEE = is_employee
    
    print(f"\n{'='*70}")
    print(f"User: {question}")
    print(f"Employee Access: {is_employee}")
    print('='*70)
    
    inputs = {
        "messages": [HumanMessage(content=question)],
        "is_employee": is_employee,
        "tool_call_count": 0
    }
    config = {"configurable": {"thread_id": "run_fixed"}}
    
    async for event in app.astream(inputs, config):
        for value in event.values():
            if "messages" in value:
                msg = value["messages"][-1]
                if isinstance(msg, AIMessage):
                    if msg.tool_calls:
                        query = msg.tool_calls[0]['args']['query']
                        print(f"[TOOL CALL] Query: {query}")
                    else:
                        print(f"[AI] {msg.content}")
                elif hasattr(msg, 'content'):
                    print(f"[RESULT/ERROR] {msg.content}")

# Test both modes
await run_query("What is the most costly purchase by Frank Harris?", is_employee=False)
await run_query("What is the most costly purchase by Frank Harris?", is_employee=True)


User: What is the most costly purchase by Frank Harris?
Employee Access: False
[TOOL CALL] Query: SELECT i.Total
FROM Customer c
JOIN Invoice i ON c.CustomerId = i.CustomerId
WHERE c.FirstName = 'Frank' AND c.LastName = 'Harris'
ORDER BY i.Total DESC
LIMIT 1;
[RESULT/ERROR] Error: Access denied to tables: ['CUSTOMER', 'INVOICE']. Allowed: ['Album', 'Artist', 'Genre', 'Playlist', 'PlaylistTrack', 'Track']
[TOOL CALL] Query: SELECT t.TrackId, t.Name, COUNT(pt.PlaylistId) AS PlaylistCount
FROM Track t
JOIN PlaylistTrack pt ON t.TrackId = pt.TrackId
GROUP BY t.TrackId, t.Name
ORDER BY PlaylistCount DESC
LIMIT 1;
[RESULT/ERROR] Error: Access denied to tables: ['PLAYLISTTRACK', 'TRACK']. Allowed: ['Album', 'Artist', 'Genre', 'Playlist', 'PlaylistTrack', 'Track']
[TOOL CALL] Query: SELECT Title
FROM Album
ORDER BY Title ASC;
[RESULT/ERROR] Error: Access denied to tables: ['ALBUM']. Allowed: ['Album', 'Artist', 'Genre', 'Playlist', 'PlaylistTrack', 'Track']
[TOOL CALL] Query: SELECT Name
FROM