# Connect to MLflow in Databricks


In [None]:
import mlflow
import mlflow.tracking._model_registry.utils
from dotenv import load_dotenv

load_dotenv()

# Local-development setup: log in to Databricks and point the MLflow model
# registry at Unity Catalog. When running this inside Databricks, the login can
# be removed and the regular MLflow registry used instead.
mlflow.login()

# Select the Unity Catalog registry through the public API rather than by
# monkeypatching a private MLflow helper, which may change between releases.
mlflow.set_registry_uri("databricks-uc")

mlflow.set_experiment("/ConversationalRAG")

# Register the model as code


In [None]:
# Name the model is registered under, assembled from its three dot-separated
# parts (presumably catalog.schema.model -- confirm against the workspace).
UC_CATALOG_NAME = ".".join(
    ("generative_ai_toolkit", "use_cases", "conversational_rag")
)

In [None]:
# Single-turn chat request stored with the model as its input example.
input_example = {
    "messages": [
        {"role": "user", "content": "What is Retrieval-augmented Generation?"},
    ],
}

# Log the model defined in ConversationalRAG.py inside a tagged run and
# register it under UC_CATALOG_NAME in the same call.
with mlflow.start_run(
    description="Conversational RAG API using the ChatAgent class",
    tags={"ml_type": "genai", "architecture": "rag"},
    run_name="ConversationalRAG",
):
    model_info = mlflow.pyfunc.log_model(
        python_model="ConversationalRAG.py",
        artifact_path="model",
        registered_model_name=UC_CATALOG_NAME,
        pip_requirements="../../requirements.txt",
        input_example=input_example,
        example_no_conversion=True,
        streamable=True,
    )


# Report the run and model URI so later cells can refer to the logged model.
print(
    f"MLflow Run: {model_info.run_id}",
    f"Model URI: {model_info.model_uri}",
    sep="\n",
)

In [None]:
from mlflow.tracking import MlflowClient

# Look up every registered version of the model and pick the highest one.
client = MlflowClient()
model_versions = client.search_model_versions(f"name='{UC_CATALOG_NAME}'")

# Fail with an actionable message instead of max()'s generic
# "empty sequence" error when nothing has been registered yet.
if not model_versions:
    raise ValueError(
        f"No registered versions found for model '{UC_CATALOG_NAME}'; "
        "log and register the model before looking up its latest version."
    )

# Convert to int so versions compare numerically (10 > 9) rather than as text.
latest_version = max(int(mv.version) for mv in model_versions)

print(f"The latest version of model '{UC_CATALOG_NAME}' is: {latest_version}")

In [None]:
from databricks import agents
import os

# Deploying is optional: the model can also be tested locally, as shown in the
# next cell. The deployment has to be switched to a GPU instance in the UI,
# because that cannot be configured from code.
# NOTE: os.getenv returns None for any variable that is not set, and that None
# is forwarded to the endpoint configuration unchanged.
credential_names = (
    "DATABRICKS_HOST",
    "DATABRICKS_CLIENT_ID",
    "DATABRICKS_CLIENT_SECRET",
)
endpoint_env = {name: os.getenv(name) for name in credential_names}

deployment = agents.deploy(
    model_name=UC_CATALOG_NAME,
    model_version=latest_version,
    scale_to_zero=True,
    environment_vars=endpoint_env,
    endpoint_name="conversational_rag_endpoint",
)

# Locally test the registered model


In [None]:
# Load the model logged above back as a pyfunc model so it can be called
# locally.
my_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)

In [None]:
# Send one user question through the non-streaming interface and show the
# complete reply.
question = {
    "role": "user",
    "content": "What is the medallion arthicecture?",
}
response = my_model.predict({"messages": [question]})

print(response)

In [None]:
# Stream the answer chunk by chunk, supplying an explicit system prompt ahead
# of the user question.
for token in my_model.predict_stream(
    {
        "messages": [
            {
                "role": "system",
                "content": "You are a helpful assistant that can answer questions about data engineering, specifically about the delta lake definitive guide which is book you have access to. You will recieve the user question and number of extracts from the book that you should use to answer the question. If you can not answer the question based on the information provided, just say so. Do not make up information. Never mention that you have recieved extracts from the book.",
            },
            {
                "role": "user",
                "content": "How did ETL work in the first generation platforms?",
            },
        ]
    }
):
    # A chunk with no "delta" would crash the chained .get(), and a chunk with
    # no text content would print the literal "None"; skip both.
    # (Assumes each chunk is a dict -- confirm against the model's output.)
    content = (token.get("delta") or {}).get("content")
    if content is not None:
        # Flush per chunk so the text appears as it is generated.
        print(content, flush=True, end="")