In [None]:
import os
from dotenv import load_dotenv
from typing import Dict, List, Any, Optional, Annotated, Union, Literal
from typing_extensions import TypedDict

from langchain_openai import ChatOpenAI
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_community.vectorstores import SKLearnVectorStore
from langchain_openai import OpenAIEmbeddings
from langchain_core.tools import tool
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage, ToolMessage
from langchain_core.runnables import RunnablePassthrough
from langgraph.graph import MessagesState

from langgraph.prebuilt import tools_condition, ToolNode

from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import ToolNode

from IPython.display import Image, display

In [None]:
# Load environment variables (presumably the OpenAI / LangSmith credentials) from a
# local .env file; variables already set in the environment are not overridden.
load_dotenv(override=False)

In [None]:
# Chinook sample database; `sqlite:///` + a relative path resolves against the
# notebook's current working directory.
db = SQLDatabase.from_uri(
    database_uri="sqlite:///../sql-support-bot/chinook.db",
)
print(db.get_usable_table_names())

In [None]:
# Project name under which LangSmith traces are grouped (when tracing is enabled).
os.environ.update({"LANGCHAIN_PROJECT": "music-store-support-demo-prep"})

## Set up state

In [None]:
class MusicStoreChatbotState(MessagesState):
    # Graph state: the inherited `messages` history plus routing metadata.
    customer_id: Union[int, None]  # None until get_customer_id extracts an ID
    current_mode: Literal["router", "account", "music"]  # set by callers; no node shown here reads it yet

In [None]:
# Shared chat model used by every LLM-backed node in the graph.
model = ChatOpenAI(
    model_name="gpt-4o-mini",
    temperature=0,  # minimize sampling randomness for routing / ID extraction
)

## Helper Functions

In [None]:
def get_last_human_message(messages: List[BaseMessage]) -> Optional[HumanMessage]:
    """Return the most recent HumanMessage in ``messages``.

    Args:
        messages: Conversation history, assumed to be in chronological order
            (oldest first).

    Returns:
        The last element that is a HumanMessage, or None if there is none
        (including when ``messages`` is empty).
    """
    newest_first_humans = (
        candidate for candidate in reversed(messages)
        if isinstance(candidate, HumanMessage)
    )
    return next(newest_first_humans, None)

## Graph definition

### Node: Route based on Customer ID

In [None]:
def is_customer_id_known(state: MusicStoreChatbotState):
    """Entry node of the graph: logs and leaves the state unchanged.

    The node does no work itself; the choice of the next node is made by the
    conditional edge attached to it when the graph is built.

    This must be a plain function, not a ``@tool``: LangGraph calls a node with
    the whole state dict, which a ``@tool`` would validate against its single
    ``state`` argument and reject (the dict has no ``"state"`` key).

    Args:
        state: Current graph state (unused apart from being passed through).

    Returns:
        An empty update dict, i.e. no state changes.
    """
    print("T" * 50)
    print("is_customer_id_known")
    print("T" * 50)
    # Nothing to change; an empty update avoids re-submitting every message
    # through the `messages` reducer.
    return {}

### Tools

In [None]:
@tool
def get_customer_info(customer_id: int):
    """Look up customer info given their ID. Requires customer authentication."""
    # The docstring above is the tool description sent to the model; keep it as-is.
    # When called as a tool, LangChain validates `customer_id` against the `int`
    # annotation before this body runs.
    lookup_sql = f"SELECT * FROM customers WHERE CustomerId = {customer_id};"
    return db.run(lookup_sql)

@tool
def update_customer_info(customer_id: int, field: str, value: str):
    """
    Update a customer's information. This is a sensitive operation that requires human approval.
    - customer_id: The ID of the customer to update
    - field: The field to update (FirstName, LastName, Company, Address, City, State, Country, PostalCode, Phone, Email)
    - value: The new value for the field
    """
    # The docstring above is the tool description sent to the model; keep it as-is.
    # Editable columns, in the order they are listed back in the error message.
    editable_columns = (
        "FirstName", "LastName", "Company", "Address", "City",
        "State", "Country", "PostalCode", "Phone", "Email",
    )

    if field in editable_columns:
        # Verify the customer exists before reporting success.
        existing_row = db.run(f"SELECT * FROM customers WHERE CustomerId = {customer_id};")
        if existing_row:
            # NOTE: no UPDATE statement is executed; this message stands in for the
            # write that would run after approval.
            return f"Successfully updated {field} to '{value}' for customer {customer_id}"
        return f"Error: No customer found with ID {customer_id}"

    return f"Error: Cannot update field '{field}'. Allowed fields are: {', '.join(editable_columns)}"


# Tools grouped by the graph node that executes them. The *_names lists are what
# the account-query router compares tool-call names against.
get_tools = [get_customer_info]
get_tool_names = [lookup_tool.name for lookup_tool in get_tools]
update_tools = [update_customer_info]
update_tool_names = [write_tool.name for write_tool in update_tools]

### Node: Process Account Query

In [None]:
# System prompt for the account-specialist model call in `handle_account_query`.
# This is runtime text sent to the model: any edit changes the agent's behavior.
# NOTE(review): approval is only *requested* here; confirm the graph enforces it
# (e.g. an interrupt before the update-tools node).
account_system_message = """You are a customer account specialist at a music store.
You can help customers:
1. Retrieve their account information
2. Update their profile details

IMPORTANT: 
- Always verify you have the customer_id before taking any actions.
- For security, any updates to customer information require human manager approval.
- DO NOT make updates without explicit approval.
"""

def handle_account_query(state: MusicStoreChatbotState):
    """Let the account-specialist model answer, or call a tool for, the current request.

    Builds the prompt from ``account_system_message``, the most recent
    conversation messages and the authenticated customer ID. It then invokes the
    model with the customer lookup/update tools bound.

    This is a graph node, so it must be a plain function rather than a ``@tool``:
    LangGraph calls it with the whole state dict.

    Args:
        state: Graph state; ``customer_id`` must already be set.

    Returns:
        A state update containing only the model's reply. The ``messages``
        reducer appends it to the history.

    Raises:
        AssertionError: if ``state["customer_id"]`` is None.
    """
    assert state["customer_id"] is not None
    print("T" * 50)
    print("handle_account_query")
    print("T" * 50)

    # Keep the last few messages as context, but never start the window on a
    # ToolMessage: the OpenAI API rejects a tool result whose preceding assistant
    # tool-call message has been cut off.
    context_messages = list(state["messages"][-5:])
    while context_messages and isinstance(context_messages[0], ToolMessage):
        context_messages.pop(0)

    messages = [SystemMessage(content=account_system_message), *context_messages]

    # Add customer context
    customer_id = state["customer_id"]
    messages.append(SystemMessage(content=f"Current authenticated customer ID: {customer_id}"))

    # Let the agent decide whether to answer directly or call a tool.
    response = model.bind_tools(tools=[get_customer_info, update_customer_info]).invoke(messages)

    print(f"handle_account_query response = {response}")

    # Return only the new message; the `add_messages` reducer appends it.
    return {"messages": [response]}


def handle_account_query_route_condition(state: MusicStoreChatbotState) -> str:
    """Choose the node that runs after ``handle_account_query``.

    Inspects the latest message. If it requests a tool, route to the node that
    owns that tool: ``"lookup_tools"`` for names in ``get_tool_names``, or
    ``"update_tools"`` for names in ``update_tool_names``. Otherwise the agent
    has answered, and the graph ends.

    Only the first tool call is used for routing.
    NOTE(review): if the model issues parallel calls spanning both groups, the
    other call is sent to a node that does not own it.

    Args:
        state: Graph state whose last message is the model's reply.

    Returns:
        ``"lookup_tools"``, ``"update_tools"`` or ``END``.
    """
    last_message = state["messages"][-1]
    # Default: no tool call, or an unrecognized tool name, ends the turn.
    destination = END

    tool_calls = getattr(last_message, "tool_calls", None)
    if tool_calls:
        tool_name = tool_calls[0]["name"]
        if tool_name in get_tool_names:
            destination = "lookup_tools"
        elif tool_name in update_tool_names:
            destination = "update_tools"

    print("~" * 50)
    print(f"Destination = {destination}")
    print("~" * 50)
    return destination

### TODO: Separate out get and update customer tool nodes

### Node: Extract the Customer ID from the query, if present (TODO: 1. verify the ID, 2. prompt the user for it when missing)

In [None]:
# System prompt for `get_customer_id`. The model is asked to reply with just the ID
# or the literal UNKNOWN. This is runtime text sent to the model.
customer_id_prompt = """Your task is to help determine the Customer ID of the user based on their presented chat history. 
If you are able to determine their Customer ID, return just the ID. If the ID cannot be inferred, return UNKNOWN."""

def get_customer_id(state: MusicStoreChatbotState):
    """Try to extract the customer ID from the recent conversation.

    Asks the model (prompted with ``customer_id_prompt``) to reply with the bare
    ID or ``UNKNOWN``. Only a reply made of decimal digits, ignoring surrounding
    whitespace, is accepted as an ID.

    This is a graph node, not a ``@tool``: LangGraph calls it with the whole state.

    Args:
        state: Graph state; ``customer_id`` is expected to still be None.

    Returns:
        ``{"customer_id": <int>}`` when an ID was found. Otherwise an empty
        update, which leaves the state unchanged.

    Raises:
        AssertionError: if ``state["customer_id"]`` is already set.
    """
    print("T" * 50)
    print("get_customer_id")
    print("T" * 50)
    assert state["customer_id"] is None

    # Recent conversation context. Drop leading ToolMessages whose assistant
    # tool-call message fell outside the window; the OpenAI API rejects them.
    context_messages = list(state["messages"][-5:])
    while context_messages and isinstance(context_messages[0], ToolMessage):
        context_messages.pop(0)
    messages = [SystemMessage(content=customer_id_prompt), *context_messages]

    response = model.invoke(messages)
    print(f"Raw router response = {response.content}")

    # Models often add whitespace or a trailing newline, so strip before parsing.
    # isdecimal() (unlike isdigit()) only accepts characters int() can parse;
    # for example, it rejects superscript digits such as "²".
    content = response.content.strip()
    if content.isdecimal():
        return {"customer_id": int(content)}
    return {}

In [None]:
def customer_id_known_route_condition(state: MusicStoreChatbotState) -> bool:
    """Report whether the customer has been identified.

    Used as the routing function on conditional edges whose path maps are keyed
    by ``True`` / ``False``.

    Returns:
        True if ``state["customer_id"]`` is not None, else False.
    """
    known = state["customer_id"] is not None
    separator = "~" * 50
    print(separator)
    print(f"Destination = {known}")
    print(separator)
    return known

# Graph wiring:
#   START -> is_customer_id_known --(ID known?)--> handle_account_query | get_customer_id
#   get_customer_id --(ID found?)--> is_customer_id_known | END
#   handle_account_query --(tool call?)--> lookup_tools | update_tools | END
#   lookup_tools / update_tools -> handle_account_query
entry_builder = StateGraph(MusicStoreChatbotState)
entry_builder.add_node("is_customer_id_known", is_customer_id_known)
entry_builder.add_node("handle_account_query", handle_account_query)
entry_builder.add_node("get_customer_id", get_customer_id)
entry_builder.add_node("lookup_tools", ToolNode(get_tools))
entry_builder.add_node("update_tools", ToolNode(update_tools))

# Add the starting edge
entry_builder.add_edge(START, "is_customer_id_known")
# NOTE: "is_customer_id_known" deliberately has no unconditional outgoing edge;
# its successor is chosen only by the conditional edge below. A static edge to
# "get_customer_id" would also fire when the ID is already known, running
# get_customer_id (which asserts the ID is None) in parallel with
# handle_account_query.
entry_builder.add_edge("lookup_tools", "handle_account_query")
entry_builder.add_edge("update_tools", "handle_account_query")

entry_builder.add_conditional_edges(
    "is_customer_id_known",  # Source node
    customer_id_known_route_condition,  # Function that returns the condition value
    {
        True: "handle_account_query",
        False: "get_customer_id",
    }
)

entry_builder.add_conditional_edges(
    "handle_account_query",  # Source node
    handle_account_query_route_condition,  # Function that returns the condition value
    {
        "update_tools": "update_tools",
        "lookup_tools": "lookup_tools",
        END: END
    }
)

entry_builder.add_conditional_edges(
    "get_customer_id",  # Source node
    customer_id_known_route_condition,  # Function that returns the condition value
    {
        True: "is_customer_id_known",
        False: END,
    }
)

memory = MemorySaver()
# To require human approval before customer updates, interrupt before the update node:
# `entry_builder.compile(interrupt_before=["update_tools"], checkpointer=memory)`
graph = entry_builder.compile(checkpointer=memory)

In [None]:
# Render the compiled graph's topology as a Mermaid PNG.
# NOTE(review): the default Mermaid renderer may need network access.
display(Image(data=graph.get_graph().draw_mermaid_png()))

### Test

In [None]:
# Define a thread config dictionary
thread_config = {"configurable": {"thread_id": "21"}}

# Then use it in your invoke call
graph.invoke(
    {"messages": [HumanMessage(content="Hey, my customer ID is 2. Please show me all my information")], 
     "current_mode": "router", 
     "customer_id": None}, 
    thread_config
)

In [None]:
# Define a thread config dictionary
thread_config = {"configurable": {"thread_id": "1"}}

# Then use it in your invoke call
graph.invoke(
    {"messages": [HumanMessage(content="Hey. recommend music by Amy Winehouse")], 
     "current_mode": "router", 
     "customer_id": None}, 
    thread_config
)

In [None]:
# Define a thread config dictionary
thread_config = {"configurable": {"thread_id": "1"}}

# Then use it in your invoke call
graph.invoke(
    {"messages": [HumanMessage(content="Hey, my customer ID is 2. Please update my email to a@b.com")], 
     "current_mode": "router", 
     "customer_id": None}, 
    thread_config
)