In [None]:
import mlflow
import mlflow.tracking._model_registry.utils
from dotenv import load_dotenv

# The login and the MLflow model-registry override below are set up this way for local development. If you are running this in Databricks, you can remove the login and the override and use the regular MLflow registry.

# Pull local environment settings from a .env file before authenticating
# (presumably the Databricks host/token used by mlflow.login - confirm in your .env).
load_dotenv()
mlflow.login()


def _unity_catalog_registry_uri():
    """Replacement for MLflow's private Spark-session registry lookup.

    Always answers "databricks-uc", so model-registry calls target Unity Catalog
    even though no Databricks Spark session exists locally.
    """
    return "databricks-uc"


# NOTE: this patches a private MLflow helper; it may break on MLflow upgrades.
mlflow.tracking._model_registry.utils._get_registry_uri_from_spark_session = (
    _unity_catalog_registry_uri
)

mlflow.set_experiment("/SentenceTransformersEmbeddings")

In [None]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True)

# Search queries use the "retrieval.query" task; documents stored in the database use
# "retrieval.passage". The same name is passed as `prompt_name` so `encode` also applies
# the prompt the model registers under that name.
task = "retrieval.query"
query_texts = ["What is the weather like in Berlin today?"]
embeddings = model.encode(query_texts, task=task, prompt_name=task)

In [None]:
example_sentences = ["A sentence to encode.", "Another sentence to encode."]

# Encode-time options, matching how query embeddings are computed earlier in this notebook
# (Jina v3 needs both the task and the prompt of the same name).
# Declaring them as signature params makes them part of the logged model's contract.
# MLflow's sentence-transformers pyfunc forwards inference params to
# `SentenceTransformer.encode`, and params omitted at `predict` time fall back to these
# defaults. NOTE(review): confirm this forwarding with the installed MLflow version.
# Callers can override them per request, e.g. with "retrieval.passage" when embedding documents.
encode_params = {"task": "retrieval.query", "prompt_name": "retrieval.query"}

signature = mlflow.models.infer_signature(
    model_input=example_sentences,
    # Use the same options for the example output as the model will use at inference time.
    model_output=model.encode(example_sentences, **encode_params),
    params=encode_params,
)

signature

In [None]:
# In some cases you might need a custom implementation:
# https://mlflow.org/docs/latest/llms/sentence-transformers/guide/

# This is a different implementation compared to the Transformers library.
# On the MLflow side it is unclear how to define the task type - which influences the embeddings (we might be able to pass this in the inference_config).
# https://mlflow.org/docs/latest/api_reference/python_api/mlflow.sentence_transformers.html#mlflow.sentence_transformers.log_model

# Furthermore, it is unclear what type of pooling is used - 'mean' is assumed as it is the most common.

# The transformers implementation allows for not storing the model weights; the sentence-transformers implementation does not.

# Arguments for logging the SentenceTransformer as an MLflow model. This flavor stores
# the model weights in the run's artifacts.
# Passing task="llm/v1/embeddings" is deliberately left out: it did not seem to convert the
# sentence transformer's output to the OpenAI API standard, although the docs say it should.
log_model_kwargs = {
    "model": model,
    "artifact_path": "model",
    "signature": signature,
    "input_example": example_sentences,
    "pip_requirements": "../../requirements.txt",  # resolved relative to the notebook's CWD
    "inference_config": {"task": "retrieval.query"},
}

with mlflow.start_run():
    model_info = mlflow.sentence_transformers.log_model(**log_model_kwargs)

# Reload the logged artifact through the generic pyfunc interface.
my_model = mlflow.pyfunc.load_model(model_info.model_uri)

# Test the logged model (it is logged to the run only, not registered in the model registry)

In [None]:
# A single string under the "input" key.
single_request = {"input": "This is a test sentence."}
my_model.predict(single_request)

In [None]:
# A batch: a list of strings under the same "input" key.
batch_request = {"input": ["This is a test sentence.", "This is another test sentence."]}
my_model.predict(batch_request)