In [None]:
import os
from dotenv import load_dotenv
from typing import Dict, List, Any, Optional, Annotated, Union, Literal
from typing_extensions import TypedDict

from langchain_openai import ChatOpenAI
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_community.vectorstores import SKLearnVectorStore
from langchain_openai import OpenAIEmbeddings
from langchain_core.tools import tool
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AIMessage, ToolMessage
from langchain_core.runnables import RunnablePassthrough
from langgraph.graph import MessagesState

from langgraph.prebuilt import tools_condition, ToolNode

from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import ToolNode

from IPython.display import Image, display

In [None]:
# Pull settings from a local .env file; variables already set in the environment win.
load_dotenv(override=False)

In [None]:
# Chinook sample database. Override the location with CHINOOK_DB_PATH if your checkout
# differs. Fail fast when the file is missing: SQLite would otherwise silently create an
# empty database here and every later query would fail with "no such table".
chinook_db_path = os.environ.get("CHINOOK_DB_PATH", "../sql-support-bot/chinook.db")
if not os.path.isfile(chinook_db_path):
    raise FileNotFoundError(
        f"Chinook database not found at {chinook_db_path!r}; set CHINOOK_DB_PATH to its location."
    )
db = SQLDatabase.from_uri(f"sqlite:///{chinook_db_path}")
print(db.get_usable_table_names())

In [None]:
# Group LangSmith traces from this notebook under a dedicated project name.
os.environ.update(LANGCHAIN_PROJECT="music-store-support-demo-prep")

## Set up state

In [None]:
class MusicStoreChatbotState(MessagesState):
    """Graph state: the chat transcript inherited from MessagesState plus two routing fields."""

    # Customer the conversation concerns; the demo cells start every run with None.
    customer_id: Union[int, None]
    # Presumably which sub-agent handles the turn (the demo cells start in "router") — confirm.
    current_mode: Literal["router", "account", "music"]

In [None]:
# Deterministic (temperature 0) chat model; not used by the cells below — presumably
# intended for the router graph.
model = ChatOpenAI(model="gpt-4o-mini", temperature=0)

## Helper Functions

In [None]:
def get_last_human_message(messages: List[BaseMessage]) -> Optional[HumanMessage]:
    """Return the most recent HumanMessage in ``messages``.

    Scans from the end of the list backwards and stops at the first match.
    Returns None when the list holds no HumanMessage (including when it is empty).
    """
    return next(
        (candidate for candidate in reversed(messages) if isinstance(candidate, HumanMessage)),
        None,
    )

In [None]:
# Router demo: a music request.
# NOTE(review): `graph` (the top-level router graph) is not defined anywhere above this
# cell, so on a fresh kernel this call raised NameError. The guard keeps Restart & Run All
# going; move this cell below the router graph's definition once it exists.
# NOTE(review): the router demo cells all use thread_id "1" and pass customer_id=None, so
# the prompts share one conversation and each call presumably resets any stored
# customer_id — confirm that is intended.
thread_config = {"configurable": {"thread_id": "1"}}

if "graph" in globals():
    display(
        graph.invoke(
            {"messages": [HumanMessage(content="Hey. recommend music by Amy Winehouse")],
             "current_mode": "router",
             "customer_id": None},
            thread_config,
        )
    )
else:
    print("Skipping router demo: `graph` has not been built yet in this notebook.")

In [None]:
# Router demo: an account request (email update for customer 2).
# NOTE(review): `graph` (the top-level router graph) is not defined anywhere above this
# cell, so on a fresh kernel this call raised NameError. The guard keeps Restart & Run All
# going; move this cell below the router graph's definition once it exists.
# NOTE(review): this reuses thread_id "1" and passes customer_id=None, which presumably
# clears any customer_id an earlier turn in this thread stored — confirm that is intended.
thread_config = {"configurable": {"thread_id": "1"}}

if "graph" in globals():
    display(
        graph.invoke(
            {"messages": [HumanMessage(content="Hey, my customer ID is 2. Please update my email to a@b.com")],
             "current_mode": "router",
             "customer_id": None},
            thread_config,
        )
    )
else:
    print("Skipping router demo: `graph` has not been built yet in this notebook.")

In [None]:
# Router demo: an off-topic question.
# NOTE(review): `graph` (the top-level router graph) is not defined anywhere above this
# cell, so on a fresh kernel this call raised NameError. The guard keeps Restart & Run All
# going; move this cell below the router graph's definition once it exists.
# NOTE(review): this reuses thread_id "1" and passes customer_id=None, which presumably
# clears any customer_id an earlier turn in this thread stored — confirm that is intended.
thread_config = {"configurable": {"thread_id": "1"}}

if "graph" in globals():
    display(
        graph.invoke(
            {"messages": [HumanMessage(content="What is the meaning of life?")],
             "current_mode": "router",
             "customer_id": None},
            thread_config,
        )
    )
else:
    print("Skipping router demo: `graph` has not been built yet in this notebook.")

## Recommendation Agent

### Tool helpers: fuzzy retrieval of artists and tracks (songs)

In [None]:
# Load every artist and track row. Each row dict is attached to its document as metadata,
# so a fuzzy name hit can be traced back to its ArtistId / TrackId.
artists = db._execute("SELECT * FROM artists")
songs = db._execute("SELECT * FROM tracks")

# Embedding indexes over the names. A misspelled or partial name still lands near the
# right row. Note: every name is embedded again (via the OpenAI API) each time this cell runs.
artist_retriever = SKLearnVectorStore.from_texts(
    [row['Name'] for row in artists], OpenAIEmbeddings(), metadatas=artists
).as_retriever()

song_retriever = SKLearnVectorStore.from_texts(
    [row['Name'] for row in songs], OpenAIEmbeddings(), metadatas=songs
).as_retriever()

### LLM with Tools

In [None]:
@tool
def get_albums_by_artist(artist: str):
    """Get albums by an artist (or similar artists)."""
    # The docstring above is also the tool description the model sees, so it is kept as-is.
    print("T" * 50)
    print("get_albums_by_artist tool called")
    print("T" * 50)
    # `invoke` is the current retriever entry point (`get_relevant_documents` is deprecated).
    # It returns the nearest entries from the artist-name embedding index.
    docs = artist_retriever.invoke(artist)
    if not docs:
        return f"No artists found matching '{artist}'"
    # ArtistId values come from the artists table (document metadata), not from the
    # caller, so inlining them into the IN (...) list is safe.
    artist_ids = ", ".join(str(d.metadata['ArtistId']) for d in docs)
    return db.run(
        f"SELECT Title, Name as ArtistName FROM albums LEFT JOIN artists ON albums.ArtistId = artists.ArtistId WHERE albums.ArtistId in ({artist_ids});",
        include_columns=True)


@tool
def get_tracks_by_artist(artist: str):
    """Get songs by an artist (or similar artists)."""
    # The docstring above is also the tool description the model sees, so it is kept as-is.
    print("T" * 50)
    print("get_tracks_by_artist tool called")
    print("T" * 50)
    # `invoke` is the current retriever entry point (`get_relevant_documents` is deprecated).
    docs = artist_retriever.invoke(artist)
    if not docs:
        return f"No artists found matching '{artist}'"
    # ArtistId values come from the artists table (document metadata), not from the
    # caller, so inlining them into the IN (...) list is safe.
    artist_ids = ", ".join(str(d.metadata['ArtistId']) for d in docs)
    # NOTE(review): there is no LIMIT, so prolific artists return every track. That can be a
    # very large tool result in the model's context — consider capping it.
    return db.run(
        f"SELECT tracks.Name as SongName, artists.Name as ArtistName FROM albums LEFT JOIN artists ON albums.ArtistId = artists.ArtistId LEFT JOIN tracks ON tracks.AlbumId = albums.AlbumId WHERE albums.ArtistId in ({artist_ids});",
        include_columns=True)


@tool
def search_songs(song_title: str):
    """Search for songs by title."""
    # The docstring above is also the tool description the model sees, so it is kept as-is.
    print("T" * 50)
    print("search_songs tool called")
    print("T" * 50)
    # `invoke` is the current retriever entry point (`get_relevant_documents` is deprecated).
    docs = song_retriever.invoke(song_title)
    # page_content is the track name that was embedded. TrackId comes from the tracks row
    # stored as metadata. At most five hits are returned.
    return [{"Title": doc.page_content, "TrackId": doc.metadata["TrackId"]} for doc in docs[:5]]


@tool
def get_similar_music(genre: str, artist: str = ""):
    """Get music recommendations based on genre and optionally artist."""
    # The docstring above is also the tool description the model sees, so it is kept as-is.
    print("T" * 50)
    print("get_similar_music tool called")
    print("T" * 50)

    # Resolve the genre with a bound LIKE parameter. `genre` is model-generated text, so
    # f-string interpolation broke on quotes (e.g. "Rock 'n' Roll") and was injectable.
    # `_execute` returns row dicts, which avoids parsing `db.run` output. That output is
    # the str() of a list of tuples such as "[(1,)]", not pipe-separated text.
    genre_rows = db._execute(
        "SELECT GenreId FROM genres WHERE Name LIKE :pattern LIMIT 1",
        parameters={"pattern": f"%{genre}%"},
    )
    if not genre_rows:
        return f"No genre found matching '{genre}'"

    genre_id = int(genre_rows[0]["GenreId"])

    # Optionally restrict to the artists nearest to `artist` in the embedding index.
    artist_filter = ""
    if artist:
        docs = artist_retriever.invoke(artist)
        if docs:
            # ArtistId values come from the database (document metadata), not the caller.
            artist_ids = ", ".join(str(d.metadata['ArtistId']) for d in docs)
            artist_filter = f"AND albums.ArtistId IN ({artist_ids})"

    query = f"""
    SELECT tracks.Name as TrackName, artists.Name as ArtistName, albums.Title as AlbumTitle
    FROM tracks
    JOIN albums ON tracks.AlbumId = albums.AlbumId
    JOIN artists ON albums.ArtistId = artists.ArtistId
    WHERE tracks.GenreId = :genre_id {artist_filter}
    ORDER BY RANDOM()
    LIMIT 5
    """

    result = db.run(query, include_columns=True, parameters={"genre_id": genre_id})
    # `db.run` returns "" when no rows match. Give the model an explicit message instead.
    if not result:
        return f"No tracks found for genre '{genre}'" + (f" and artist '{artist}'" if artist else "")
    return result

In [None]:
# Every tool the music agent may call. The same list is handed to ToolNode when the graph
# is built, so the model can only request tools the graph is able to execute.
tools = [get_albums_by_artist, get_tracks_by_artist, search_songs, get_similar_music]

music_recommendation_model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
# bind_tools attaches the tool schemas to each request so the model can emit tool calls.
music_recommendation_model_with_tools = music_recommendation_model.bind_tools(tools)

### Graph Definition

In [None]:
# Music Recommendation Agent
music_system_message = """You are a music specialist at a music store.
You can help customers:
1. Find music by specific artists
2. Discover new songs similar to their interests
3. Search for specific tracks
4. Get recommendations based on genres

Always be conversational and enthusiastic about music. If you don't find exactly what they're looking for, suggest alternatives.
"""


def handle_music_query(state: MusicStoreChatbotState):
    """Handle music-related queries"""
    print("Processing music query...")
    
    # Build messages for the music agent
    messages = [SystemMessage(content=music_system_message)]
    # Add some conversation context
    context_messages = state["messages"][-5:] if len(state["messages"]) > 5 else state["messages"]
    messages.extend(context_messages)

    # Let the agent determine what to do next
    response = music_recommendation_model_with_tools.invoke(messages)
    print(f"Music agent: {response.content}")

    return {
        "messages": state["messages"] + [response]
    }

In [None]:
music_recommendation_graph_builder = StateGraph(MusicStoreChatbotState)

# Agent node (LLM with tools bound) plus a ToolNode that executes the tool calls it emits.
music_recommendation_graph_builder.add_node("handle_music_query", handle_music_query)
music_recommendation_graph_builder.add_node("tools", ToolNode(tools))

music_recommendation_graph_builder.add_edge(START, "handle_music_query")
# tools_condition routes to "tools" when the latest AI message has tool calls and to END
# otherwise, so it is the agent node's only exit. A separate unconditional
# handle_music_query -> END edge would be redundant and would draw a misleading
# "always stops here" arrow in the diagram.
music_recommendation_graph_builder.add_conditional_edges("handle_music_query", tools_condition)
# After tools run, return their results to the agent so it can answer or call more tools.
music_recommendation_graph_builder.add_edge("tools", "handle_music_query")

# In-memory checkpointer: per-thread_id state lives only for this kernel session.
memory = MemorySaver()
# Can hard-code interruptions using `builder.compile(interrupt_before=["tools"], checkpointer=memory)`
music_recommendation_graph = music_recommendation_graph_builder.compile(checkpointer=memory)

In [None]:
# Render the compiled graph's topology. Mermaid PNG rendering presumably goes through the
# mermaid.ink web service by default, so this needs network access — confirm.
display(Image(data=music_recommendation_graph.get_graph().draw_mermaid_png()))

### Test

In [None]:
# Demo (thread "1"): artist-based recommendation. The returned state is the cell output.
thread_config = {"configurable": {"thread_id": "1"}}
music_recommendation_graph.invoke(
    {
        "messages": [HumanMessage(content="Hey, recommend music by Amy Winehouse")],
        "current_mode": "router",
        "customer_id": None,
    },
    config=thread_config,
)

In [None]:
# Demo (thread "2"): song recommendations for one artist. The final state is kept in
# `response`; the agent's reply is printed by handle_music_query.
thread_config = {"configurable": {"thread_id": "2"}}
response = music_recommendation_graph.invoke(
    {
        "messages": [HumanMessage(content="Hey, recommend songs by Green Day")],
        "current_mode": "router",
        "customer_id": None,
    },
    config=thread_config,
)

In [None]:
# Demo (thread "3"): look up a specific track by title. The final state is kept in
# `response`; the agent's reply is printed by handle_music_query.
thread_config = {"configurable": {"thread_id": "3"}}
response = music_recommendation_graph.invoke(
    {
        "messages": [HumanMessage(content="Hey, is Boulevard Of Broken Dreams available for purchase?")],
        "current_mode": "router",
        "customer_id": None,
    },
    config=thread_config,
)

In [None]:
# Demo (thread "4"): recommendations seeded by a liked song. The final state is kept in
# `response`; the agent's reply is printed by handle_music_query.
thread_config = {"configurable": {"thread_id": "4"}}
response = music_recommendation_graph.invoke(
    {
        "messages": [HumanMessage(content="Hey, I like Boulevard Of Broken Dreams. What other albums and songs would you recommend?")],
        "current_mode": "router",
        "customer_id": None,
    },
    config=thread_config,
)