In [1]:
import os
from dotenv import load_dotenv
from typing import Dict, List, Any, Optional, Annotated, Union, Literal
from typing_extensions import TypedDict

from langchain_openai import ChatOpenAI
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_community.vectorstores import SKLearnVectorStore
from langchain_openai import OpenAIEmbeddings
from langchain_core.tools import tool
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage, ToolMessage
from langchain_core.runnables import RunnablePassthrough
from langgraph.graph import MessagesState

from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import ToolNode

from IPython.display import Image, display

In [2]:
# Load environment variables from a local .env file into os.environ
# (presumably the OpenAI API key needed by ChatOpenAI / OpenAIEmbeddings below — confirm).
load_dotenv()

True

In [3]:
# Deterministic (temperature 0) chat model, used both by the router and as the
# base for the tool-calling LLM (`llm_with_tools`).
model = ChatOpenAI(model="gpt-4o-mini", temperature=0)

In [4]:
# Chinook sample database; the sqlite path is relative to the notebook's working directory.
db = SQLDatabase.from_uri("sqlite:///sql-support-bot/chinook.db")
usable_tables = db.get_usable_table_names()
print(usable_tables)

['albums', 'artists', 'customers', 'employees', 'genres', 'invoice_items', 'invoices', 'media_types', 'playlist_track', 'playlists', 'tracks']


In [5]:
# Tracing project name read by LangChain/LangSmith, so runs from this notebook
# are grouped under one project.
os.environ["LANGCHAIN_PROJECT"] = "music-store-support-demo-prep"

## Set up state

In [6]:
class MusicStoreChatbotState(MessagesState):
    """Graph state: the inherited `messages` list plus routing and customer fields."""

    # Customer ID for account operations. The demo cells pass None, and nothing in
    # this notebook sets it yet (even when the user states an ID).
    customer_id: Optional[int]
    # Department chosen by route_conversation; route_condition reads it to pick the next node.
    current_mode: Literal["router", "account", "music"]

## Define Tools

### Tool helper: fuzzy (embedding-based) retrieval of artists and tracks

In [7]:
# Pull every artist and track row (dict-like rows keyed by column name) for fuzzy lookup.
# NOTE(review): `_execute` is a private SQLDatabase method and may change between versions.
artists = db._execute("SELECT * FROM artists")
songs = db._execute("SELECT * FROM tracks")

# Embed artist names; each document carries its full row (incl. ArtistId) as metadata.
artist_names = [row['Name'] for row in artists]
artist_store = SKLearnVectorStore.from_texts(artist_names, OpenAIEmbeddings(), metadatas=artists)
artist_retriever = artist_store.as_retriever()

In [8]:
# Embed track names; each document carries its full track row (incl. TrackId) as metadata.
song_names = [track['Name'] for track in songs]
song_store = SKLearnVectorStore.from_texts(song_names, OpenAIEmbeddings(), metadatas=songs)
song_retriever = song_store.as_retriever()

### Tool definitions

In [9]:
# Account Management Tools

#######     1      #######
@tool
def get_customer_info(customer_id: int):
    """Look up customer info given their ID. Requires customer authentication."""
    # The docstring above is the tool description the LLM sees, so it is kept verbatim.
    # NOTE(review): no authentication happens here; the graph must enforce it
    # before this tool is reachable.
    # Bind the ID as a query parameter instead of formatting it into the SQL text,
    # since tool arguments originate from model output.
    return db.run(
        "SELECT * FROM customers WHERE CustomerId = :customer_id;",
        parameters={"customer_id": customer_id},
    )


#######     2      #######
@tool
def update_customer_info(customer_id: int, field: str, value: str):
    """
    Update a customer's information. This is a sensitive operation that requires human approval.
    - customer_id: The ID of the customer to update
    - field: The field to update (FirstName, LastName, Company, Address, City, State, Country, PostalCode, Phone, Email)
    - value: The new value for the field
    """
    # The docstring above is the tool description the LLM sees, so it is kept verbatim.
    # NOTE(review): this is a stub. It validates the field and checks the customer
    # exists, but performs no UPDATE and has no approval gate of its own.
    allowed_fields = ["FirstName", "LastName", "Company", "Address", "City",
                      "State", "Country", "PostalCode", "Phone", "Email"]

    # Column names cannot be bound as SQL parameters, so this whitelist is what
    # keeps any future UPDATE statement safe to build.
    if field not in allowed_fields:
        return f"Error: Cannot update field '{field}'. Allowed fields are: {', '.join(allowed_fields)}"

    # Verify customer exists first; the ID is bound rather than interpolated.
    customer = db.run(
        "SELECT * FROM customers WHERE CustomerId = :customer_id;",
        parameters={"customer_id": customer_id},
    )
    if not customer:  # an empty result is falsy
        return f"Error: No customer found with ID {customer_id}"

    # This is what would actually run after approval
    return f"Successfully updated {field} to '{value}' for customer {customer_id}"


# Music Recommendation Tools

#######     3      #######
@tool
def get_albums_by_artist(artist: str):
    """Get albums by an artist (or similar artists)."""
    # The docstring above is the tool description the LLM sees, so it is kept verbatim.
    # Fuzzy-match the requested name against embedded artist names; `invoke`
    # replaces the deprecated `get_relevant_documents`.
    docs = artist_retriever.invoke(artist)
    # IDs come from our own table metadata; int() guarantees only numbers reach the SQL text.
    artist_ids = ", ".join(str(int(d.metadata['ArtistId'])) for d in docs)
    return db.run(
        f"SELECT Title, Name as ArtistName FROM albums LEFT JOIN artists ON albums.ArtistId = artists.ArtistId WHERE albums.ArtistId in ({artist_ids});",
        include_columns=True)


#######     4      #######
@tool
def get_tracks_by_artist(artist: str):
    """Get songs by an artist (or similar artists)."""
    # The docstring above is the tool description the LLM sees, so it is kept verbatim.
    # Fuzzy-match the requested name against embedded artist names; `invoke`
    # replaces the deprecated `get_relevant_documents`.
    docs = artist_retriever.invoke(artist)
    # IDs come from our own table metadata; int() guarantees only numbers reach the SQL text.
    artist_ids = ", ".join(str(int(d.metadata['ArtistId'])) for d in docs)
    return db.run(
        f"SELECT tracks.Name as SongName, artists.Name as ArtistName FROM albums LEFT JOIN artists ON albums.ArtistId = artists.ArtistId LEFT JOIN tracks ON tracks.AlbumId = albums.AlbumId WHERE albums.ArtistId in ({artist_ids});",
        include_columns=True)


#######     5      #######
@tool
def search_songs(song_title: str):
    """Search for songs by title."""
    # The docstring above is the tool description the LLM sees, so it is kept verbatim.
    # Fuzzy (embedding) match against track names; `invoke` replaces the deprecated
    # `get_relevant_documents`. At most five hits are returned.
    docs = song_retriever.invoke(song_title)
    return [{"Title": doc.page_content, "TrackId": doc.metadata["TrackId"]} for doc in docs[:5]]

In [10]:
# Tools grouped by the department that will use them; `tools` is the combined
# list (same order as before) bound to the chat model.
account_tools = [get_customer_info, update_customer_info]
music_tools = [get_albums_by_artist, get_tracks_by_artist, search_songs]
tools = account_tools + music_tools
llm_with_tools = model.bind_tools(tools)

## Graph definition

### Router Agent

In [11]:
# Router Agent
# System prompt for the router model. route_conversation lower-cases the model's
# free-text reply and searches it for "account" and then "music", so the department
# names below double as routing keywords.
# NOTE(review): a reply that echoes item 3 ("...does not fit into ACCOUNT or MUSIC")
# would be routed to account. Also, the router's reply is never added to the
# conversation, so the "please ask for the ID" instruction has no visible effect yet.
router_system_message = """You are a helpful customer support assistant for a music store.
Your job is to determine what the customer needs help with and route them to the appropriate department:

1. ACCOUNT - For updating personal information, account details, or authentication
2. MUSIC - For music recommendations, searching for songs/artists, or questions about music inventory
3. ROUTER - for anything that does not fit into ACCOUNT or MUSIC

Respond with the department that best matches their query.
If they want to access personal information and have not provided a customer ID, please ask for the ID.
"""


def route_conversation(state: MusicStoreChatbotState) -> Dict[str, str]:
    """Classify the latest customer request and record the target department.

    Only the most recent HumanMessage (no earlier history) is sent to the router
    model, together with ``router_system_message``. The model's free-text reply is
    mapped onto a mode by keyword search.

    Returns:
        A partial state update ``{"current_mode": mode}`` where ``mode`` is
        ``"account"``, ``"music"`` or ``"router"``. route_condition reads this key
        to choose the next node.
    """
    print("Routing conversation...")
    last_human_message = get_last_human_message(state["messages"])
    if last_human_message is None:
        # Nothing to classify: explicitly fall back to router mode (which ends the
        # run) rather than returning the whole state, whose stale current_mode
        # could send the run to a department with nothing to handle.
        print("No human message found...")
        return {"current_mode": "router"}

    # Prepare conversation for the router
    router_messages = [SystemMessage(content=router_system_message), last_human_message]
    response = model.invoke(router_messages)

    # Keyword match on the reply. "account" is checked first, so a reply that
    # mentions both departments is routed to account.
    content = response.content.lower()
    if "account" in content:
        department = "account"
    elif "music" in content:
        department = "music"
    else:
        department = "router"  # Stay in router mode if unclear

    # Return the routing decision as a partial state update
    return {"current_mode": department}


def get_last_human_message(messages: List[BaseMessage]) -> Optional[HumanMessage]:
    """Return the most recent HumanMessage in ``messages``, or None if there is none."""
    # Walk from newest to oldest and stop at the first human turn.
    newest_first_humans = (msg for msg in reversed(messages) if isinstance(msg, HumanMessage))
    return next(newest_first_humans, None)

In [12]:
def music_node(state: MusicStoreChatbotState):
    """Placeholder for the music department: logs the call and returns ``state`` as-is."""
    print("music_node invoked")
    return state


def account_node(state: MusicStoreChatbotState):
    """Placeholder for the account department: logs the call and returns ``state`` as-is."""
    print("account_node invoked")
    return state


def route_condition(state: MusicStoreChatbotState) -> str:
    """Return the department key stored in ``state["current_mode"]``.

    The value is looked up in the mapping passed to ``add_conditional_edges`` to
    choose the node that runs after ``route_conversation``.
    """
    return state["current_mode"]

entry_builder = StateGraph(MusicStoreChatbotState)

# Register the router plus one placeholder node per department.
# To plug in a sub-graph later, pass its *compiled* graph instead of the node
# function, e.g. entry_builder.add_node("music_node", qs_builder.compile())
for _node_name, _node_fn in (
    ("route_conversation", route_conversation),
    ("music_node", music_node),
    ("account_node", account_node),
):
    entry_builder.add_node(_node_name, _node_fn)

# Every run starts at the router.
entry_builder.add_edge(START, "route_conversation")

# route_condition returns state["current_mode"]; map each mode to the next node.
# "router" (unclear or off-topic request) ends the run immediately.
entry_builder.add_conditional_edges(
    "route_conversation",
    route_condition,
    {"music": "music_node", "account": "account_node", "router": END},
)

# Department nodes are terminal for now.
for _leaf in ("music_node", "account_node"):
    entry_builder.add_edge(_leaf, END)

# Checkpointer keyed by thread_id, so repeated invokes on the same thread share history.
# Interruptions can be hard-coded via `entry_builder.compile(interrupt_before=["tools"], checkpointer=memory)`.
memory = MemorySaver()
graph = entry_builder.compile(checkpointer=memory)

In [13]:
# All demo calls share thread "1", so the checkpointer accumulates the
# conversation across cells (the `messages` list grows in each output).
thread_config = {"configurable": {"thread_id": "1"}}

# A music request: expected to be routed to music_node.
music_request = {
    "messages": [HumanMessage(content="Hey. recommend music by Amy Winehouse")],
    "current_mode": "router",
    "customer_id": None,
}
graph.invoke(music_request, thread_config)

Routing conversation...
music_node invoked


{'messages': [HumanMessage(content='Hey. recommend music by Amy Winehouse', additional_kwargs={}, response_metadata={}, id='cb91263a-f713-4c26-b78a-7badd1d0b77a')],
 'customer_id': None,
 'current_mode': 'music'}

In [14]:
# Same thread as the previous demo, so earlier messages are carried along.
thread_config = {"configurable": {"thread_id": "1"}}

# An account request: expected to be routed to account_node.
# NOTE: the ID mentioned in the text is not extracted into `customer_id`
# (the returned state still shows None).
account_request = {
    "messages": [HumanMessage(content="Hey, my customer ID is 2. Please update my email to a@b.com")],
    "current_mode": "router",
    "customer_id": None,
}
graph.invoke(account_request, thread_config)

Routing conversation...
account_node invoked


{'messages': [HumanMessage(content='Hey. recommend music by Amy Winehouse', additional_kwargs={}, response_metadata={}, id='cb91263a-f713-4c26-b78a-7badd1d0b77a'),
  HumanMessage(content='Hey, my customer ID is 2. Please update my email to a@b.com', additional_kwargs={}, response_metadata={}, id='fb80f5b5-f275-4d91-b78e-a2c64b3d9723')],
 'customer_id': None,
 'current_mode': 'account'}

In [15]:
# Same thread again; history keeps accumulating.
thread_config = {"configurable": {"thread_id": "1"}}

# An off-topic request: the router should stay in "router" mode, which ends the
# run without visiting a department node.
off_topic_request = {
    "messages": [HumanMessage(content="What is the meaning of life?")],
    "current_mode": "router",
    "customer_id": None,
}
graph.invoke(off_topic_request, thread_config)

Routing conversation...


{'messages': [HumanMessage(content='Hey. recommend music by Amy Winehouse', additional_kwargs={}, response_metadata={}, id='cb91263a-f713-4c26-b78a-7badd1d0b77a'),
  HumanMessage(content='Hey, my customer ID is 2. Please update my email to a@b.com', additional_kwargs={}, response_metadata={}, id='fb80f5b5-f275-4d91-b78e-a2c64b3d9723'),
  HumanMessage(content='What is the meaning of life?', additional_kwargs={}, response_metadata={}, id='de8480b9-6b0a-4b42-9ecb-a7cb9c4d99dc')],
 'customer_id': None,
 'current_mode': 'router'}

In [16]:
# display(Image(graph.get_graph().draw_mermaid_png()))

In [17]:
# # Request an update that requires approval
# message = "Please change my email to newemail@example.com"
# response, needs_approval, update_info = await process_message(message, "user123")
# print(f"User: {message}")
# print(f"Bot: {response}")

# if needs_approval:
#     print("\n[Human manager approval required]")
#     # Simulate manager approval
#     approval_response = await resume_with_approval(True, update_info, "user123")
#     print(f"Bot (after approval): {approval_response}")