# Connect to MLflow in Databricks

In [None]:
import mlflow
import mlflow.tracking._model_registry.utils
from dotenv import load_dotenv

# Load variables from a local .env file into the environment
# (presumably the Databricks credentials used by the login below - confirm).
load_dotenv()

# The login and the MLflow model registry are configured explicitly because of
# local development. If you are running this in Databricks, you can remove the
# login and use the regular MLflow registry.

mlflow.login()

# Point the model registry at Unity Catalog through the public API instead of
# monkey-patching the private `_get_registry_uri_from_spark_session` helper,
# which is an MLflow implementation detail that may change between releases.
mlflow.set_registry_uri("databricks-uc")

mlflow.set_experiment("/TransformersEmbeddings")

# Register the model 

In [None]:
# Use a pipeline as a high-level helper
from transformers import AutoModel, AutoTokenizer, pipeline

# References:
# - Model card: https://huggingface.co/jinaai/jina-embeddings-v3
# - MLflow embeddings: https://mlflow.org/docs/latest/llms/deployments#embeddings
# - MLflow Transformers guide: https://mlflow.org/docs/latest/llms/transformers/guide/

# Hugging Face model id; other models can be substituted here.
model_name = "jinaai/jina-embeddings-v3"
task = "retrieval.query"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Alternative load settings kept for reference:
#   attn_implementation="sdpa", output_attentions=False
model = AutoModel.from_pretrained(
    model_name,
    attn_implementation="eager",
    trust_remote_code=True,
)

# NOTE(review): the looked-up value is discarded and `task` is not passed to
# the pipeline below, so this line only fails if `task` is missing from the
# model's adaptation map - presumably a leftover inspection; confirm intent.
model._adaptation_map[task]

# "feature-extraction" is the Transformers pipeline task that represents the
# 'embeddings' use case. See https://huggingface.co/tasks
extractor = pipeline(
    task="feature-extraction",
    model=model,
    tokenizer=tokenizer,
    trust_remote_code=True,
)

In [None]:
# API reference:
# https://mlflow.org/docs/latest/api_reference/python_api/mlflow.transformers.html

# Arguments for logging the feature-extraction pipeline as an MLflow model.
# NOTE(review): `pip_requirements` is a relative path, presumably resolved
# against the notebook's working directory - confirm before moving this file.
log_model_kwargs = {
    "transformers_model": extractor,
    "artifact_path": "model",
    "task": "llm/v1/embeddings",
    "save_pretrained": False,
    "pip_requirements": "../../requirements.txt",
}

with mlflow.start_run():
    model_info = mlflow.transformers.log_model(**log_model_kwargs)

# Load the freshly logged model back through the pyfunc interface.
model_uri = model_info.model_uri
my_model = mlflow.pyfunc.load_model(model_uri)

# Test the registered model

In [None]:
# Send one sentence through the loaded model.
single_request = {"input": "This is a test sentence."}
my_model.predict(single_request)

In [None]:
# Send a list of sentences through the loaded model in a single request.
sentences = ["This is a test sentence.", "This is another test sentence."]
my_model.predict({"input": sentences})