# LangGraph custom schema agent notebook

This notebook shows you how to write a LangGraph AI agent compatible with Mosaic AI Agent Framework that accepts custom inputs and returns custom outputs. 

To ensure compatibility, the agent must conform to the Mosaic AI Agent Framework schema requirements. See the agent schema documentation ([AWS](https://docs.databricks.com/en/generative-ai/agent-framework/agent-schema.html) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/agent-framework/agent-schema)).
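
As a rough illustration of what "custom inputs" and "custom outputs" mean here, the sketch below builds a chat-completion-style request and response as plain dictionaries. The helper names `make_request` and `make_response` and the sample values are hypothetical, and the shapes are simplified; the schema documentation linked above remains the authoritative reference.

```python
from typing import Any, Dict, Optional


def make_request(user_text: str, custom_inputs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Build a chat-completion-style request (hypothetical helper)."""
    request: Dict[str, Any] = {"messages": [{"role": "user", "content": user_text}]}
    if custom_inputs is not None:
        # Agent-specific fields travel alongside the standard `messages` list
        request["custom_inputs"] = custom_inputs
    return request


def make_response(content: str, custom_outputs: Dict[str, Any]) -> Dict[str, Any]:
    """Build a chat-completion-style response (hypothetical helper)."""
    return {
        "choices": [{"index": 0, "message": {"role": "assistant", "content": content}}],
        # Agent-specific fields are returned next to the standard `choices` list
        "custom_outputs": custom_outputs,
    }


request = make_request("Hello", custom_inputs={"id": 123})
response = make_response("Hi there", custom_outputs={"id": 123})
```

Keeping the standard `messages` and `choices` keys intact means existing chat-completion clients can still read the payload, while the extra keys stay optional.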

### Models-as-code notebook

Mosaic AI Agent Framework uses MLflow's **Models-as-code** development workflow, which requires two notebooks: 

- An agent notebook that defines the agent's logic (this notebook)
- A driver notebook that logs, registers, and deploys the agent
  - You can find the driver notebook for this agent, **custom-langgraph-schema-driver**, on Databricks documentation ([AWS](https://docs.databricks.com/en/generative-ai/agent-framework/agent-schema.html#langgraph-custom-schema-driver-notebook) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/agent-framework/agent-schema#langgraph-custom-schema-driver-notebook))

For more information on Models-as-code, see MLflow's [Models as code guide](https://mlflow.org/docs/latest/model/models-from-code.html).

## Requirements

This notebook requires a Unity Catalog enabled workspace.

 **_NOTE:_**  This notebook uses LangGraph, but Mosaic AI Agent Framework is compatible with other agent authoring frameworks, like LlamaIndex.

In [0]:
# %pip install -U -qqqq mlflow>=2.19.0 langchain==0.2.16 langgraph-checkpoint==1.0.12 langchain_core langgraph==0.2.16 pydantic databricks-langchain

# %pip install -U -qqqq mlflow langchain==0.2.16 langgraph-checkpoint==1.0.12 langchain_core langgraph==0.2.16 pydantic databricks-langchain

In [0]:
%pip install -U mlflow langchain langgraph-checkpoint langchain_core langgraph pydantic databricks-langchain
dbutils.library.restartPython()

## Define the chat model
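Import `ChatDatabricks`, a chat model class that supports [LangGraph tool](https://langchain-ai.github.io/langgraph/how-tos/tool-calling/) calling, and enable MLflow autologging. The model itself is instantiated later, when the agent graph is built.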

In [0]:
import mlflow
# from databricks_langchain import VectorSearchRetrieverTool, ChatDatabricks
from databricks_langchain import ChatDatabricks

mlflow.langchain.autolog()

## Parse LangGraph output

The following cell defines helper functions that convert LangGraph output messages into the recommended output schema for Mosaic AI Agent Framework. The `wrap_output` helper returns chat-completion-compatible messages with an additional `custom_outputs` field that carries the custom outputs.

In [0]:
from typing import Iterator, Dict, Any
from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    ToolMessage,
    MessageLikeRepresentation,
)
from mlflow.types.llm import ChatCompletionRequest, ChatCompletionResponse, ChatChoice, ChatMessage
from random import randint
from dataclasses import asdict
import logging

import json

# You can add additional fields to the return object below
def create_flexible_chat_completion_response(content: str, id: int = 0) -> Dict:
    return asdict(ChatCompletionResponse(
        choices=[ChatChoice(message=ChatMessage(role="assistant", content=content))],
        custom_outputs={
            "id": id
        },
    ))

def wrap_output(stream: Iterator[MessageLikeRepresentation]) -> Iterator[Dict]:
    """
    Process and yield formatted outputs from the message stream.
    The invoke and stream langchain functions produce different output formats.
    This function handles both cases.
    """
    for event in stream:
        # the agent was called with invoke()
        if "messages" in event:
            output_content = ""
            for msg in event["messages"]:
                output_content += msg.content
            # Note: you can pass additional fields from your LangGraph nodes to the output here
            yield create_flexible_chat_completion_response(content=output_content, id=randint(100000000, 10000000000))
        # the agent was called with stream()
        else:
            for node in event:
                for key, messages in event[node].items():
                    if isinstance(messages, list):
                        for msg in messages:
                            # Note: you can pass additional fields from your LangGraph nodes to the output here
                            yield create_flexible_chat_completion_response(content=msg.content, id=randint(100000000, 10000000000))
                    else:
                        logging.warning(f"Unexpected value {messages} for key {key}. Expected a list of `MessageLikeRepresentation`'s")
                        yield create_flexible_chat_completion_response(content=str(messages))
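
To make the two branches of `wrap_output` easier to follow, here is a dependency-free sketch of the same idea. `FakeMessage`, `to_response`, and `wrap_events` are hypothetical stand-ins (not part of LangChain or MLflow), and the event shapes are assumptions based on the comments in the cell above: an `invoke()`-style event holds a `messages` list, while a `stream()`-style event maps a node name to that node's update.

```python
from dataclasses import dataclass
from typing import Any, Dict, Iterable, Iterator


@dataclass
class FakeMessage:
    """Stand-in for a LangChain message; only `content` is modeled."""
    content: str


def to_response(content: str, id: int = 0) -> Dict[str, Any]:
    """Simplified version of the response dict built in the cell above."""
    return {
        "choices": [{"message": {"role": "assistant", "content": content}}],
        "custom_outputs": {"id": id},
    }


def wrap_events(stream: Iterable[Dict[str, Any]]) -> Iterator[Dict[str, Any]]:
    for event in stream:
        if "messages" in event:
            # invoke()-style: concatenate every message into one response
            yield to_response("".join(m.content for m in event["messages"]))
        else:
            # stream()-style: {node_name: {"messages": [...]}}, one response per message
            for node_update in event.values():
                for value in node_update.values():
                    if isinstance(value, list):
                        for m in value:
                            yield to_response(m.content)
                    else:
                        yield to_response(str(value))


invoke_style = [{"messages": [FakeMessage("Hello, "), FakeMessage("world")]}]
stream_style = [
    {"chatbot": {"messages": [FakeMessage("Hello")]}},
    {"tools": {"messages": [FakeMessage("42")]}},
]

invoke_out = list(wrap_events(invoke_style))
stream_out = list(wrap_events(stream_style))
```

The `invoke()`-style input collapses into a single response, while the `stream()`-style input yields one response per message, which mirrors how the real helper behaves.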

## Create the agent
Use the LangGraph [`create_react_agent` function](https://langchain-ai.github.io/langgraph/how-tos/create-react-agent/#usage) to build a simple graph. For more customization, you can create your own LangGraph agent by following [LangGraph - Quick Start](https://langchain-ai.github.io/langgraph/tutorials/introduction/).

### RAG retrieval with Databricks Vector Search

Create a DatabricksVectorSearch instance.  
[Reference](https://python.langchain.com/docs/integrations/vectorstores/databricks_vector_search/)

In [0]:
%sql
CREATE OR REPLACE FUNCTION trainer_catalog.05_vector_search_index_for_nssol.product_docs_vector_search (
  -- The agent uses this comment to determine how to generate the query string parameter.
  query STRING
  COMMENT 'The query string for searching our product documentation.'
) RETURNS TABLE
-- The agent uses this comment to determine when to call this tool. It describes the types of documents and information contained within the index.
COMMENT 'Executes a search on product documentation to retrieve text documents most relevant to the input query.' RETURN
SELECT
  id as id,
  map('url', url, 'content', content) as metadata
FROM
  vector_search(
    -- Specify your Vector Search index name here
    index => 'trainer_catalog.05_vector_search_index_for_nssol.product_documentation_vs_index',
    query => query,
    num_results => 5
  )

In [0]:
# vs_tool = VectorSearchRetrieverTool(
#     # endpoint="vs_endpoint",
#     index_name="trainer_catalog.05_vector_search_index_for_nssol.product_documentation_vs_index",
#     tool_name="product_docs_retriever",
#     tool_description="Retrieves information about our products from official documentation."
# )

# Run a query against the vector search index locally for testing
# vs_tool.invoke("修理の手順は?")  # "What is the repair procedure?"

### Create the AI functions

In [0]:
func_name = [
    "trainer_catalog.03_data_analysis_by_gen_ai_for_nssol.product_with_many_inquiries",
    "trainer_catalog.05_vector_search_index_for_nssol.product_docs_vector_search",
]

In [0]:
from unitycatalog.ai.langchain.toolkit import UCFunctionToolkit
from unitycatalog.ai.core.databricks import DatabricksFunctionClient

In [0]:
# Create the Databricks Function client
client = DatabricksFunctionClient()

# Use UCFunctionToolkit to register the Unity Catalog functions as tools
toolkit = UCFunctionToolkit(
    function_names=func_name,
    client=client
)

tools = toolkit.tools



In [0]:
# tools = [tool,vs_tool]
tools

### Build the agent graph

In [0]:
from typing import Annotated

# from langchain_anthropic import ChatAnthropic
# from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import BaseMessage
from typing_extensions import TypedDict

from langgraph.graph import StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition


class State(TypedDict):
    messages: Annotated[list, add_messages]


graph_builder = StateGraph(State)


# tool = TavilySearchResults(max_results=2)
# tools = [tool]
llm = ChatDatabricks(endpoint="openhack-gpt-4o")
llm_with_tools = llm.bind_tools(tools)


def chatbot(state: State):
    return {"messages": [llm_with_tools.invoke(state["messages"])]}


graph_builder.add_node("chatbot", chatbot)

tool_node = ToolNode(tools=tools)
graph_builder.add_node("tools", tool_node)

graph_builder.add_conditional_edges(
    "chatbot",
    tools_condition,
)
# Any time a tool is called, we return to the chatbot to decide the next step
graph_builder.add_edge("tools", "chatbot")
graph_builder.set_entry_point("chatbot")
graph = graph_builder.compile()
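
The control flow that this graph encodes can be sketched without LangGraph. The code below is an illustration only: `route`, `run_loop`, and `fake_model` are hypothetical names, messages are modeled as plain dictionaries, and the `add` tool is invented for the example. It mimics the idea that the chatbot node runs, the routing condition sends control to the tools node whenever the model requested a tool call, and the tool results go back to the chatbot until no more tool calls are requested.

```python
from typing import Any, Callable, Dict, List

Message = Dict[str, Any]


def route(last_message: Message) -> str:
    """Go to "tools" if the model requested a tool call, otherwise end."""
    return "tools" if last_message.get("tool_calls") else "end"


def run_loop(
    messages: List[Message],
    model: Callable[[List[Message]], Message],
    tools: Dict[str, Callable[..., Any]],
    max_steps: int = 10,
) -> List[Message]:
    messages = list(messages)
    for _ in range(max_steps):
        reply = model(messages)          # "chatbot" node
        messages.append(reply)
        if route(reply) == "end":        # conditional edge
            break
        for call in reply["tool_calls"]:  # "tools" node
            result = tools[call["name"]](**call["args"])
            messages.append({"role": "tool", "name": call["name"], "content": str(result)})
    return messages


def fake_model(messages: List[Message]) -> Message:
    """Hypothetical model: asks for one tool call, then answers from its result."""
    if messages[-1]["role"] == "tool":
        return {"role": "assistant", "content": f"The answer is {messages[-1]['content']}."}
    return {
        "role": "assistant",
        "content": "",
        "tool_calls": [{"name": "add", "args": {"a": 2, "b": 3}}],
    }


history = run_loop(
    [{"role": "user", "content": "What is 2 + 3?"}],
    fake_model,
    {"add": lambda a, b: a + b},
)
```

The `max_steps` guard is a design choice of this sketch; it prevents an endless chatbot/tools cycle if the model keeps requesting tools.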

In [0]:
from IPython.display import Image, display

try:
    display(Image(graph.get_graph().draw_mermaid_png()))
except Exception:
    # This requires some extra dependencies and is optional
    pass

In [0]:
def stream_graph_updates(user_input: str):
    for event in graph.stream({"messages": [{"role": "user", "content": user_input}]}):
        for value in event.values():
            print("Assistant:", value["messages"][-1].content)

In [0]:
# while True:
#     try:
#         user_input = input("User: ")
#         if user_input.lower() in ["quit", "exit", "q"]:
#             print("Goodbye!")
#             break

#         stream_graph_updates(user_input)
#     except:
#         # fallback if input() is not available
#         user_input = "What do you know about LangGraph?"
#         print("User: " + user_input)
#         stream_graph_updates(user_input)
#         break

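# Example question: "問い合わせの多い製品の修理手順を教えてください"
# (Please tell me the repair procedure for the product with the most inquiries)
# Enter "quit", "exit", or "q" to end the session

# End of this section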

In [0]:
from langchain_core.runnables import RunnableGenerator
from langgraph.prebuilt import create_react_agent

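# Build the ReAct agent. `create_react_agent` binds the tools to the model itself,
# so pass the base `llm` rather than `llm_with_tools`.
agent_with_raw_output = create_react_agent(llm, tools)

# Pipe the raw LangGraph output through `wrap_output` to produce the custom output schema
agent = agent_with_raw_output | RunnableGenerator(wrap_output)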

## Test the agent

Interact with the agent to test its output. Because this notebook calls `mlflow.langchain.autolog()`, you can view the trace for each step the agent takes.

Replace this placeholder input with an appropriate domain-specific example for your agent.

In [0]:
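# TODO: replace this placeholder input example with an appropriate domain-specific example for your agent
# The sample question means: "Please tell me the repair procedure for the product with the most inquiries"
input_messages = [ChatMessage(role="user", content="問い合わせの多い製品の修理手順を教えてください")]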
input_example = asdict(ChatCompletionRequest(messages=input_messages))

for event in agent.stream(input_example):
    print(event, "---" * 20 + "\n")
# for event in agent(input_example):
#     print(event, "---" * 20 + "\n")

In [0]:
mlflow.models.set_model(agent)

## Next steps

You can rerun the cells above to iterate and test the agent.

See the driver notebook, **custom-langgraph-schema-driver** ([AWS](https://docs.databricks.com/en/generative-ai/agent-framework/agent-schema.html#langchain-custom-schema-driver-notebook) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/agent-framework/agent-schema#langchain-custom-schemas)), to learn how to log, register, and deploy this agent.