# Connect to MLflow in Databricks

In [None]:
import os

import mlflow
import mlflow.tracking._model_registry.utils
from dotenv import load_dotenv

# Load environment variables from a local .env file, if one is present.
load_dotenv()

# The login and the MLflow model registry are configured explicitly because of
# local development. If you are running this in Databricks, you can remove the
# login and use the regular MLflow registry.
mlflow.login()

# Point the model registry at Unity Catalog through the public API instead of
# monkeypatching the private `_get_registry_uri_from_spark_session` helper,
# which is an MLflow internal and may change or disappear between releases.
mlflow.set_registry_uri("databricks-uc")

mlflow.set_experiment("/TransformersTextGeneration")

In [None]:
from huggingface_hub import login

# Logging in is only needed when downloading a Hugging Face model that
# requires access permission. The token is read from the environment.
hf_token = os.environ.get("HUGGINGFACE_TOKEN")
login(token=hf_token)

In [None]:
from transformers import pipeline

# Sometimes you need a custom implementation:
# https://mlflow.org/docs/latest/llms/transformers/guide/
# https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.TextGenerationPipeline
# https://huggingface.co/docs/transformers/v4.50.0/en/internal/pipelines_utils#transformers.pipelines.ArgumentHandler

# Build the text-generation pipeline for the TinyLlama chat checkpoint.
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
generator = pipeline(task="text-generation", model=model_id)

# Register the model

In [None]:
from transformers import pipeline

# Using the "llm/v1/chat" task lets the logged model be used as a chat model:
# https://mlflow.org/docs/latest/llms/transformers/task/
log_options = {
    "artifact_path": "model",
    "transformers_model": generator,
    "task": "llm/v1/chat",
    "streamable": True,
    "save_pretrained": False,
    "pip_requirements": "../requirements.txt",
}

# Log the pipeline inside a fresh MLflow run.
with mlflow.start_run():
    model_info = mlflow.transformers.log_model(**log_options)

# Load the logged model back through the generic pyfunc interface.
logged_model_uri = model_info.model_uri
my_model = mlflow.pyfunc.load_model(logged_model_uri)

# Test the registered model

In [None]:
# Send a single user message in chat format to the loaded model and print
# whatever it returns.
chat_request = {
    "messages": [
        {"role": "user", "content": "Write me a hello world program in python"},
    ]
}
response = my_model.predict(chat_request)

print(response)

In [None]:
# This specific model does not support streaming

# for i in my_model.predict_stream(
#     {
#         "messages": [
#             {"role": "user", "content": "Write me a hello world program in python"}
#         ]
#     }
# ):
#     print(i)