In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# local modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import mlflow
import mlflow.tracking._model_registry.utils
from dotenv import load_dotenv

# Load environment variables from the .env file.
load_dotenv()

# Local-development setup: the explicit login and the registry selection below
# are only needed when running outside Databricks. If you are running this in
# Databricks, you can remove the login and use the regular MLflow registry.
mlflow.login()

# Select the Unity Catalog model registry through the public API instead of
# monkeypatching the private `_get_registry_uri_from_spark_session` helper,
# which is an MLflow internal and may change or disappear between releases.
# NOTE(review): unlike the old patch, this explicit setting takes precedence
# over an MLFLOW_REGISTRY_URI environment variable - confirm that is desired.
mlflow.set_registry_uri("databricks-uc")

mlflow.set_experiment("/ChatAgentTools")

In [None]:
# Log ChatAgentTools.py as a pyfunc model inside an MLflow run, then load the
# logged model back so the cells below can query it.
#
# Optional log_model arguments that are intentionally not passed for now:
#   code_paths=None
#       location of the dependencies, this way you do not need to build a
#       wheel file.
#   registered_model_name="ai_recruiter.candidates.agent_prototype"
#   prompts=None
#       a list of prompts registered in the prompt registry;
#       prompt:/<name>/<version> - not supported in databricks yet
log_model_kwargs = dict(
    artifact_path="model",
    python_model="ChatAgentTools.py",
    pip_requirements="../requirements.txt",
    streamable=True,
    metadata={
        "max_tokens": 8000,
        "model_name": "databricks-meta-llama-3-3-70b-instruct",
    },
    input_example={
        "messages": [{"role": "user", "content": "Hello, how are you?"}]
    },
)

with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(**log_model_kwargs)

my_model = mlflow.pyfunc.load_model(model_uri=model_info.model_uri)

In [None]:
# Send a simple arithmetic question through the streaming interface and print
# every item as predict_stream yields it.
stream_request = {
    "messages": [{"role": "user", "content": "What is 12 + 8?"}],
}
stream_generator = my_model.predict_stream(data=stream_request)

for response in stream_generator:
    print(response)

In [None]:
# Send a simple arithmetic question through the non-streaming predict() call.
predict_request = {
    "messages": [{"role": "user", "content": "What is 12 + 8?"}],
}
result = my_model.predict(data=predict_request)

# Bare last expression so the notebook renders the value as the cell output.
result