In [0]:
%pip install -U -qqqq mlflow openai databricks-agents uv
%pip install -U -qqqq botocore==1.37.11 s3transfer==0.11.0
# The two %pip lines above upgrade (-U) the agent tooling quietly (-qqqq) and then
# force botocore / s3transfer to fixed versions.
# NOTE(review): the reason for the botocore/s3transfer pins is not recorded here —
# presumably a dependency conflict workaround; confirm before changing them.
# Restart the Python process so the packages installed above are the ones imported below.
dbutils.library.restartPython()

In [0]:
%%writefile agent.py
from typing import Any, Generator, Optional
from uuid import uuid4

import mlflow
from databricks.sdk import WorkspaceClient
from mlflow.entities import SpanType
from mlflow.pyfunc.model import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)

# Name of the serving endpoint that the agent sends chat-completion requests to.
LLM_ENDPOINT_NAME = "databricks-claude-3-7-sonnet"

# Enable MLflow's automatic logging for the OpenAI client integration.
mlflow.openai.autolog()

class CustomChatAgent(ChatAgent):
    """Chat agent that prefixes every conversation with a fixed system prompt.

    The prompt asks the model for short title suggestions built from the
    keywords of the user's process-deviation / SOP paragraphs. Requests are
    sent to the serving endpoint named by ``LLM_ENDPOINT_NAME`` through the
    OpenAI-compatible client obtained from the Databricks workspace client.
    """

    def __init__(self):
        """Create the workspace client and the OpenAI-compatible client used for LLM calls."""
        self.workspace_client = WorkspaceClient()
        self.client = self.workspace_client.serving_endpoints.get_open_ai_client()
        self.llm_endpoint = LLM_ENDPOINT_NAME

    def prepare_messages_for_llm(self, messages: list[ChatAgentMessage]) -> list[dict[str, Any]]:
        """Filter out ChatAgentMessage fields that are not compatible with LLM message formats.

        Args:
            messages: Messages to convert.

        Returns:
            One dict per message, keeping only the non-None fields whose key is
            one of ``role``, ``content``, ``name``, ``tool_calls`` or ``tool_call_id``.
        """
        compatible_keys = ["role", "content", "name", "tool_calls", "tool_call_id"]
        return [
            {k: v for k, v in m.model_dump_compat(exclude_none=True).items() if k in compatible_keys} for m in messages
        ]

    # Tracing decorator left disabled exactly as in the original module.
    #@mlflow.trace(span_type=SpanType.AGENT)
    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None
    ) -> ChatAgentResponse:
        """Answer a conversation with a single assistant message.

        Args:
            messages: Conversation so far. The list is NOT modified.
            context: Accepted for interface compatibility; not used here.
            custom_inputs: Accepted for interface compatibility; not used here.

        Returns:
            A ChatAgentResponse holding the first completion choice, tagged
            with a freshly generated UUID as its message id.
        """
        fixed_prompt = "You're and AI assistant, users will provide the process deviation or process SOP related paragraphs from those Identify the most important keywords  and generate a concise, 10-12 word title that includes keywords.that includes them without adding additional notes or explanations Ensure to include any of the following keywords if they appear: SOP, TM-, KG, GMP (maintain original capitalization).  Adopt a formal, business-appropriate tone, Additionally, respond politely to users if they greet you, but focus on delivering high-quality service. Provide only the top 4 suggestions."

        # Build a new list instead of messages.insert(0, ...): inserting in place
        # mutated the caller's list, so reusing the same list for a second call
        # stacked another system prompt and altered the caller's history.
        llm_messages = [ChatAgentMessage(role="system", content=fixed_prompt), *messages]

        resp = self.client.chat.completions.create(
            model=self.llm_endpoint,
            messages=self.prepare_messages_for_llm(llm_messages),
        )

        return ChatAgentResponse(
            messages=[ChatAgentMessage(**resp.choices[0].message.to_dict(), id=str(uuid4()))],
        )


from mlflow.models import set_model

# Module-level agent instance; the notebook imports it with `from agent import AGENT`.
AGENT = CustomChatAgent()
# Hand the instance to MLflow's set_model so this file can be logged as the model
# when it is passed by path (python_model="agent.py") to log_model.
set_model(AGENT)

In [0]:
from mlflow.types.agent import ChatAgentMessage
from agent import AGENT

# Smoke-test input: a single user message sent to the locally constructed agent.
user_message = ChatAgentMessage(
    role="user",
    content="Collaboration is a partnership; a union; the act of producing or making something together. Collaboration can take place between two people or many people, strangers or best friends. To collaborate is to commit to the possibility of producing an outcome greater than one that would be developed in a silo."
)

try:
    # Call the predict method with the message
    response = AGENT.predict(messages=[user_message])

    # Print each returned message as "<role>: <content>"
    if response and response.messages:
        for message in response.messages:
            print(f"{message.role}: {message.content}")
    else:
        print("No response received from the agent.")

except Exception as e:
    print(f"An error occurred: {e}")
    # Re-raise: swallowing the error would let a "Run All" continue on to log,
    # register and deploy an agent whose smoke test just failed.
    raise

In [0]:
import mlflow

In [0]:
import importlib.metadata

import mlflow
from agent import LLM_ENDPOINT_NAME
from mlflow.models.resources import DatabricksServingEndpoint

# Pin each requirement to the version installed in this notebook environment, so
# the environment rebuilt for validation/serving matches the one the agent was
# smoke-tested with instead of resolving whatever is newest at deploy time.
PINNED_PIP_REQUIREMENTS = [
    f"{package}=={importlib.metadata.version(package)}"
    for package in ("mlflow", "openai", "databricks-sdk")
]

# Log agent.py as the model inside a new MLflow run. The serving endpoint is
# declared as a resource of the logged model.
with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        artifact_path="agent",
        python_model="agent.py",
        pip_requirements=PINNED_PIP_REQUIREMENTS,
        resources=[DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME)]
    )

In [0]:
# Pre-deployment check: run the logged model on a minimal chat payload in an
# environment managed by uv.
validation_model_uri = f"runs:/{logged_agent_info.run_id}/agent"
validation_payload = {"messages": [{"role": "user", "content": "Hello!"}]}

mlflow.models.predict(
    model_uri=validation_model_uri,
    input_data=validation_payload,
    env_manager="uv",
)

In [0]:
# Use Unity Catalog as the model registry for the registration below.
mlflow.set_registry_uri("databricks-uc")

# TODO: define the catalog, schema, and model name for your UC model
catalog = "pdm-pdm-dl-quality-docs-genai-dev"
schema = "gvault_test"
model_name = "agent1"
# Three-level name: <catalog>.<schema>.<model>
UC_MODEL_NAME = ".".join([catalog, schema, model_name])

# Register the logged agent under the UC name; the returned info carries the new version.
uc_registered_model_info = mlflow.register_model(
    name=UC_MODEL_NAME,
    model_uri=logged_agent_info.model_uri,
)

In [0]:
from databricks import agents

# Deploy the version that was just registered, with scale-to-zero enabled and a
# tag marking this as a manual test deployment.
agents.deploy(
    UC_MODEL_NAME,
    uc_registered_model_info.version,
    scale_to_zero=True,
    tags={"manual": "testing"},
)