# Connect to MLflow in Databricks

In [None]:
import mlflow
import mlflow.tracking._model_registry.utils
from dotenv import load_dotenv

load_dotenv()

# Local development setup: authenticate against a Databricks workspace and use
# the Unity Catalog model registry ("databricks-uc"). Presumably the .env file
# loaded above supplies the Databricks credentials mlflow.login() picks up —
# confirm. When running inside Databricks itself, the login call can be
# removed and the workspace's regular MLflow registry used instead.
mlflow.login()

# Select the registry through MLflow's public API rather than patching private
# MLflow internals, which may change (and silently stop working) between releases.
mlflow.set_registry_uri("databricks-uc")

mlflow.set_experiment("/ChatAgentLangGraph")

# Log the model as code

In [None]:
# Log ChatAgentLangGraph.py as a models-from-code pyfunc model inside a fresh
# MLflow run, then load it back from the run so the cells below can query it.
example_request = {
    "messages": [{"role": "user", "content": "Hello, how are you?"}]
}

with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model="ChatAgentLangGraph.py",
        # code_paths lists local files/dirs shipped alongside the model (an
        # alternative to building a wheel). None: nothing extra is packaged.
        # NOTE(review): if ChatAgentLangGraph.py imports sibling modules, list
        # them here — confirm.
        code_paths=None,
        pip_requirements="../../requirements.txt",
        streamable=True,
        input_example=example_request,
    )

my_model = mlflow.pyfunc.load_model(model_info.model_uri)

# Test the logged model

In [None]:
# Non-streaming call; as the cell's last expression, the response is displayed.
calculation_request = {
    "messages": [
        {"role": "user", "content": "Calculate 9 + 1 with your available tools"}
    ]
}
my_model.predict(calculation_request)

In [None]:
# Streaming call: print every chunk as the model yields it.
stream_request = {"messages": [{"role": "user", "content": "What is 12 + 8?"}]}
for stream_chunk in my_model.predict_stream(stream_request):
    print(stream_chunk)

## LangGraph Details

In [None]:
import os
import sys

# Make the parent of the current working directory importable, presumably so
# that `scratch.LangGraph` below resolves. Guarded so that re-running this
# cell does not keep appending duplicate entries to sys.path.
parent_dir = os.path.dirname(os.getcwd())
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from IPython.display import Image, display
from langchain_core.runnables.graph import MermaidDrawMethod
from scratch.LangGraph import agent

# Render the agent's graph as a PNG using Mermaid's API draw method
# (NOTE(review): likely a remote rendering service, so it needs network access).
# Rendering is best-effort: on any failure, fall back to printing the graph.
if not hasattr(agent, "get_graph"):
    print("Agent doesn't have a graph visualization method")
else:
    graph = agent.get_graph()
    try:
        png_bytes = graph.draw_mermaid_png(draw_method=MermaidDrawMethod.API)
        display(Image(png_bytes))
    except Exception as e:
        print(f"Visualization error: {e}")
        print("Graph structure:", graph)

In [None]:
import json

# Mimic a frontend by rendering the agent's streamed responses differently depending on the message type.


def _tool_call_name(tool_call):
    """Return the tool name of one raw tool-call dict, or "" if unknown.

    Streamed tool-call deltas may carry ``None`` for ``function``, ``name`` or
    ``id``, so every lookup falls back to an empty value instead of assuming a
    string is present.
    """
    function = tool_call.get("function") or {}
    name = function.get("name") or ""
    if not name:
        # Fallback heuristic: take the second "_"-separated segment of the call
        # id. NOTE(review): for ids like "call_abc123" this yields an opaque
        # token rather than a tool name — confirm the endpoint's id format.
        call_id = tool_call.get("id") or ""
        if "_" in call_id:
            name = call_id.split("_")[1]
    return name


def _print_tool_call(tool_call):
    """Print one tool call (name and arguments, if any) in yellow."""
    name = _tool_call_name(tool_call)
    arguments = (tool_call.get("function") or {}).get("arguments") or ""
    if not (name or arguments):
        return  # nothing meaningful in this delta

    print("\033[33m", end="")  # Yellow color for tool calls
    try:
        print(f"TOOL CALL: {name}")
        if arguments.strip():
            try:
                parsed = json.loads(arguments)
            except json.JSONDecodeError:
                # Streamed arguments arrive in fragments that are not valid
                # JSON on their own; show them verbatim instead of failing.
                print(f"ARGUMENTS: {arguments}")
            else:
                print(f"ARGUMENTS: {json.dumps(parsed, indent=2)}")
    finally:
        print("\033[0m", end="")  # Always reset the color


def parse_and_display_api_response(chunk):
    """Render one chunk from ``agent.stream(..., stream_mode=["messages"])``.

    Only chunks shaped like ``("messages", (message_chunk, metadata))`` are
    rendered; anything else is ignored. Tool calls found in
    ``message_chunk.additional_kwargs["tool_calls"]`` are printed in yellow,
    text content is printed inline (no newline, flushed) so streamed tokens
    join up, and a green ``[FINISHED: ...]`` marker is printed when
    ``message_chunk.response_metadata`` carries a ``finish_reason``.

    This is a best-effort display helper: unexpected errors are reported in
    red together with the raw chunk instead of being raised.

    Args:
        chunk: A single item yielded by the LangGraph stream.
    """
    try:
        if not (
            isinstance(chunk, tuple) and len(chunk) >= 2 and chunk[0] == "messages"
        ):
            return
        # chunk[1] is (message_chunk, metadata); only the message is used.
        message_chunk = chunk[1][0]

        additional_kwargs = getattr(message_chunk, "additional_kwargs", None) or {}
        for tool_call in additional_kwargs.get("tool_calls") or []:
            _print_tool_call(tool_call)

        content = getattr(message_chunk, "content", None)
        if content:
            print(content, end="", flush=True)

        response_metadata = getattr(message_chunk, "response_metadata", None) or {}
        finish_reason = response_metadata.get("finish_reason")
        if finish_reason:
            print(f"\n\033[32m[FINISHED: {finish_reason}]\033[0m")

    except Exception as e:
        # Display boundary: report and keep the caller's stream loop running.
        print(f"\n\033[31mError parsing chunk: {e}\033[0m")
        print(f"Raw chunk: {chunk}")

In [None]:
# Stream the raw LangGraph agent in "messages" mode and render each chunk the
# way a chat frontend might, using parse_and_display_api_response.
for i in agent.stream(
    {
        "messages": [
            {
                "role": "user",
                "content": "What is 1000000 + 4? Break it down in detail, use the function, then explain the answer.",
            }
        ]
    },
    stream_mode=["messages"],
):
    parse_and_display_api_response(i)