#### Environment configuration

In [0]:
%restart_python

In [0]:
import os

# Copy the provider API keys from the Databricks "langgraph" secret scope into
# the environment variables that the OpenAI / Anthropic clients read.
for env_var, secret_key in (
    ("OPENAI_API_KEY", "openai-api-key"),
    ("ANTHROPIC_API_KEY", "anthropic-api-key"),
):
    os.environ[env_var] = dbutils.secrets.get(scope="langgraph", key=secret_key)

#### Define in-memory store

In [0]:
from langchain.embeddings import init_embeddings
from langgraph.store.memory import InMemoryStore

# Semantic index for long-term memories: the "content" field of every stored
# item is embedded with text-embedding-3-small into 1536-dimensional vectors.
_memory_index_config = {
    "embed": init_embeddings("openai:text-embedding-3-small"),
    "dims": 1536,
    "fields": ["content"],
}
in_memory_embd_store = InMemoryStore(index=_memory_index_config)

# Memories are written under a namespace keyed by the (hard-coded) user id.
user_id = "1"
namespace_for_memory = (user_id, "embd_memory_store")

#### Define tools and model

In [0]:
from langchain_core.tools import tool
from langchain.chat_models import init_chat_model
import requests

# init LLM: GPT-4o through LangChain's provider-agnostic factory; temperature 0
# keeps sampling randomness to a minimum.
llm = init_chat_model("openai:gpt-4o", temperature=0)

# Define tools
@tool
def get_weather(latitude: float, longitude: float) -> float:
    """Get the current temperature (°C) for a given latitude and longitude."""
    # The docstring above is the tool description the model sees, and the type
    # hints become the argument schema, so both are kept precise.
    # Let requests encode the query string; only the "current" block is read
    # below, so the unused hourly series is no longer requested.
    response = requests.get(
        "https://api.open-meteo.com/v1/forecast",
        params={
            "latitude": latitude,
            "longitude": longitude,
            "current": "temperature_2m,wind_speed_10m",
        },
        timeout=10,  # never let a slow upstream hang the agent indefinitely
    )
    # Fail with a clear HTTP error instead of a later KeyError when the response
    # body carries no "current" field.
    response.raise_for_status()
    return response.json()["current"]["temperature_2m"]

In [0]:
from langchain_mcp_adapters.client import MultiServerMCPClient

In [0]:
client = MultiServerMCPClient(
    {
        "get_news": {
            # Make sure you start your weather server on port 8000
            "url": "https://mcp-server-sse-91511769710.europe-west3.run.app/sse",
            "transport": "sse",
        }
    }
)
mcp_tools = await client.get_tools()

In [0]:
# List the names of the loaded MCP tools. A dedicated loop variable is used
# because a top-level `for tool in ...` would rebind the global `tool`,
# clobbering the `@tool` decorator imported from langchain_core.tools.
for mcp_tool in mcp_tools:
    print(mcp_tool.get_name())

In [0]:
# Augment the LLM with tools: the local weather tool plus every MCP tool,
# indexed by name so the tool node can dispatch model-requested calls.
tools = [get_weather, *mcp_tools]
tools_by_name = {registered.name: registered for registered in tools}
llm_with_tools = llm.bind_tools(tools)

In [0]:
# Display the combined tool list (get_weather followed by the MCP tools).
tools

#### Define state structure

In [0]:
from langchain_core.messages import AnyMessage
from typing_extensions import TypedDict, Annotated
import operator

# Message history merged by list concatenation: each node's returned list is
# appended to the existing one.
_MessageHistory = Annotated[list[AnyMessage], operator.add]


class MessagesState(TypedDict):
    """Graph state shared by every node of the agent."""

    messages: _MessageHistory
    llm_calls: int  # model-invocation counter, incremented by llm_call

#### Define model node

In [0]:
from langchain_core.messages.utils import (
    trim_messages,
    count_tokens_approximately
)

def message_trimmer(state: MessagesState):
    """Invoke the tool-bound model on a token-budgeted window of recent history.

    The window keeps the newest messages that fit an approximate 256-token
    budget, starting on a human message and ending on a human or tool message.
    The model reply is returned for appending to ``messages``, and ``llm_calls``
    is incremented, matching what ``llm_call`` does for its invocations.
    """
    # Local import so this node does not depend on a later cell having run.
    from langchain_core.messages import SystemMessage

    window = trim_messages(
        state["messages"],
        strategy="last",
        token_counter=count_tokens_approximately,
        max_tokens=256,
        start_on="human",
        end_on=("human", "tool"),
    )
    # trim_messages can return an empty list when even the newest message is
    # over budget; never send the model an empty conversation.
    if not window:
        window = state["messages"][-1:]

    response = llm_with_tools.invoke(
        [
            # Same system prompt as llm_call, so this model turn is also told
            # to use the tools.
            SystemMessage(
                content='You are a helpful agent! Use the available tools to find the answer!'
            ),
            *window,
        ]
    )
    return {
        "messages": [response],
        "llm_calls": state.get("llm_calls", 0) + 1,
    }

In [0]:
from langchain_core.messages import SystemMessage

def llm_call(state: dict):
    """Let the tool-bound model decide whether to call a tool or answer.

    The full message history is sent, prefixed with the system prompt. Returns
    the model reply under ``messages`` and the incremented ``llm_calls`` count.
    """
    prompt = [
        SystemMessage(
            content='You are a helpful agent! Use the available tools to find the answer!'
        ),
        *state["messages"],
    ]
    reply = llm_with_tools.invoke(prompt)
    calls_so_far = state.get('llm_calls', 0)
    return {"messages": [reply], "llm_calls": calls_so_far + 1}

#### Define tool node

In [0]:
from langchain_core.messages import ToolMessage, AIMessage, HumanMessage

async def tool_node(state: dict):
    """Execute every tool call requested by the latest model message.

    Tools are awaited via ``ainvoke``: the MCP adapter tools are async-only, and
    a synchronous ``invoke`` on them raises. An unknown tool name or a failing
    tool produces a ToolMessage with ``status="error"``, so the model can react
    instead of the whole graph run aborting.

    Returns:
        ``{"messages": [ToolMessage, ...]}``, one message per requested call.
    """
    results = []
    for tool_call in state["messages"][-1].tool_calls:
        selected = tools_by_name.get(tool_call["name"])
        if selected is None:
            content, status = f"Error: unknown tool {tool_call['name']!r}.", "error"
        else:
            try:
                content, status = await selected.ainvoke(tool_call["args"]), "success"
            except Exception as exc:  # report tool failures back to the model
                content, status = f"Error: {type(exc).__name__}: {exc}", "error"
        results.append(
            ToolMessage(
                content=content,
                name=tool_call["name"],
                tool_call_id=tool_call["id"],
                status=status,
            )
        )
    return {"messages": results}

#### Add should-continue routing logic

In [0]:
from typing import Literal
from langgraph.graph import StateGraph, START, END

# Conditional edge: go to the tool node when the model asked for tools,
# otherwise finish the turn by persisting it to memory.
def should_continue(state: MessagesState) -> Literal["tool_node", "update_memory"]:
    """Pick the next node from the newest message's tool calls.

    Returns "tool_node" when the latest message requested at least one tool
    call, otherwise "update_memory".
    """
    requested_calls = state["messages"][-1].tool_calls
    return "tool_node" if requested_calls else "update_memory"

#### Define the memory-update node, then build and show the agent graph

In [0]:
def update_memory(state: dict):
    """Persist the conversation transcript as one embedded memory item.

    Human and AI messages are flattened into a single "Human: ... AI: ..."
    string. The string is stored in ``in_memory_embd_store`` under
    ``namespace_for_memory`` with a fresh UUID key, and its ``content`` field
    is indexed. Messages whose text is empty are skipped; these are typically
    AI turns that only request tools. Nothing is written when no text remains.

    Returns:
        An empty update; this node does not modify graph state.
    """
    # Local import so this node does not depend on a later cell having run.
    import uuid

    parts = []
    for m in state["messages"]:
        if isinstance(m, HumanMessage):
            speaker = "Human"
        elif isinstance(m, AIMessage):
            speaker = "AI"
        else:
            continue  # tool results are not part of the stored transcript
        text = str(m.content).strip()
        if text:
            parts.append(f"{speaker}: {text}")

    memory_content = " ".join(parts)
    if not memory_content:
        return {}

    in_memory_embd_store.put(
        namespace_for_memory,
        str(uuid.uuid4()),
        {"content": memory_content},  # serializable payload
        index=["content"],  # embed the 'content' field
    )
    return {}

In [0]:
from langgraph.checkpoint.memory import InMemorySaver

saver = InMemorySaver()

# Build workflow
agent_builder = StateGraph(MessagesState)

# Add nodes
agent_builder.add_node("pre_model_hook", message_trimmer)
agent_builder.add_node("llm_call", llm_call)
agent_builder.add_node("update_memory", update_memory)
agent_builder.add_node("tool_node", tool_node)

# Add edges.
# message_trimmer itself invokes the model (on a trimmed window), so its reply
# is routed exactly like llm_call's. Feeding it straight into llm_call would
# call the model a second time. If the first reply requested tools, its tool
# calls would have no ToolMessages, and the OpenAI chat API rejects that history.
agent_builder.add_edge(START, "pre_model_hook")
agent_builder.add_conditional_edges(
    "pre_model_hook",
    should_continue,
    ["tool_node", "update_memory"],
)
agent_builder.add_conditional_edges(
    "llm_call",
    should_continue,
    ["tool_node", "update_memory"],
)
agent_builder.add_edge("tool_node", "llm_call")
agent_builder.add_edge("update_memory", END)

# Compile the agent: the checkpointer keeps per-thread history, and the store
# holds the embedded long-term memories.
agent = agent_builder.compile(checkpointer=saver,
                              store=in_memory_embd_store)

from IPython.display import Image, display
# Show the agent
display(Image(agent.get_graph(xray=True).draw_mermaid_png()))

#### Set up mlflow experiment to trace agent workflow

In [0]:
import mlflow

# Authenticate MLflow, then select (creating it if needed) the experiment that
# receives the agent traces.
MLFLOW_EXPERIMENT = '/Agent-Workflow'
mlflow.login()
mlflow.set_experiment(MLFLOW_EXPERIMENT)

In [0]:
# Enable MLflow autologging for LangChain, so the agent runs below are traced
# into the active experiment.
mlflow.langchain.autolog()

#### Create thread and run agent

In [0]:
# Configure thread
import uuid
config = {"configurable": {"thread_id": str(uuid.uuid4())}}

input_message = input('Input your question for the agent: ')

# Invoke
from langchain_core.messages import HumanMessage, BaseMessage # Import BaseMessage for type checking
messages = [HumanMessage(content=input_message)]
messages = await agent.ainvoke({"messages": messages}, config)
#for m in messages["messages"]:
#    if isinstance(m, BaseMessage): # Check if it's a message object
#        m.pretty_print()
#    else:
#        print(m) # Print the string directly if not a message object

In [0]:
# Preview the first five traces returned by MLflow for the active experiment.
mlflow.search_traces().head(5)