
# Deploy a Pretrained Sentence Transformer Model to Databricks Model Serving

**NOTE**: Run this notebook on Databricks Runtime 14.3 LTS ML to avoid a Python issue related to `snappy-c.h` when performing model deployment.

In [0]:
%pip install -U databricks-sdk
%pip install -U sentence-transformers
%pip install -U mlflow
# Pinned version; see the snappy-c.h note at the top of the notebook.
%pip install python-snappy==0.7.3
# NOTE(review): presumably needed by the remote modelling code of the default embedding model — confirm.
%pip install einops
# Installed last on purpose: this pin overrides whatever torch version the upgrades above may have pulled in.
%pip install torch==2.4.0 torchvision==0.19.0
# Restart Python so the newly installed package versions are the ones that get imported.
dbutils.library.restartPython()

In [0]:
# Notebook parameters as (widget name, display label, default value); created in this order.
_widget_specs = (
    ("target_catalog", "Catalog", "dev_catalog"),
    ("target_schema", "Schema", "dev_schema"),
    ("embedding_model", "Embedding model", "Snowflake/snowflake-arctic-embed-m-long"),
)
for _widget_name, _widget_label, _widget_default in _widget_specs:
    dbutils.widgets.text(name=_widget_name, label=_widget_label, defaultValue=_widget_default)

In [0]:
# Read the notebook parameters. Surrounding whitespace is stripped because these values
# are interpolated into SQL statements and Unity Catalog / endpoint names below.
target_catalog = dbutils.widgets.get("target_catalog").strip()
target_schema = dbutils.widgets.get("target_schema").strip()
embedding_model = dbutils.widgets.get("embedding_model").strip()

# Fail fast on blank widgets rather than after the model has been downloaded and logged.
_missing_widgets = [
    widget_name
    for widget_name, widget_value in (
        ("target_catalog", target_catalog),
        ("target_schema", target_schema),
        ("embedding_model", embedding_model),
    )
    if not widget_value
]
if _missing_widgets:
    raise ValueError(f"Missing widget value(s): {', '.join(_missing_widgets)}")

print(f"target catalog: {target_catalog}, target schema: {target_schema}, embedding model: {embedding_model}")

# Download the Pretrained Embedding Model from Hugging Face

In [0]:
from sentence_transformers import SentenceTransformer
import pandas as pd
import torch

In [0]:
# trust_remote_code=True allows the Hugging Face repository to run its own Python
# modelling code. NOTE(review): this executes code from whatever repo the
# `embedding_model` widget names — only point it at repositories you trust.
model = SentenceTransformer(embedding_model, trust_remote_code=True)

# Sample documents used to smoke-test the model; they are reused below as the
# MLflow signature input and input example.
documents = ['The Data Cloud!', 'Mexico City of Course!']
document_embeddings = model.encode(documents)

# Expect exactly one embedding row per input document.
assert len(document_embeddings) == len(documents), "expected one embedding per document"

# Show the (num_documents, embedding_dim) shape instead of dumping the raw float matrix.
document_embeddings.shape

In [0]:
# Derive the registered-model and endpoint names from the Hugging Face id, e.g.
# "Snowflake/snowflake-arctic-embed-m-long" -> "snowflake-arctic-embed-m-long".
# Dots become "p" (e.g. "v1.5" -> "v1p5") because "." separates the levels of the
# three-level Unity Catalog name built on the next line.
model_name = embedding_model.split('/')[-1].replace('.', 'p')
model_uc_path = f'{target_catalog}.{target_schema}.{model_name}'
endpoint_name = model_name
# Model Serving compute type used by the endpoint config below.
workload_type = "GPU_SMALL"

# Backtick-quote the identifiers (doubling any embedded backtick) so catalog/schema
# names containing characters such as "-" still form valid SQL.
_quoted_catalog = "`" + target_catalog.replace("`", "``") + "`"
_quoted_schema = "`" + target_schema.replace("`", "``") + "`"
spark.sql(f"USE CATALOG {_quoted_catalog}")
spark.sql(f"USE SCHEMA {_quoted_schema}")

In [0]:
import mlflow
from mlflow.models import infer_signature
from mlflow.tracking import MlflowClient

# Register models in Unity Catalog rather than the workspace model registry.
# (Call mlflow.set_experiment(...) here if runs should go to a specific experiment.)
mlflow.set_registry_uri("databricks-uc")

# Infer the model signature from the sample documents and the embeddings the model
# produces for them.
signature = infer_signature(
    model_input=documents,
    model_output=model.encode(documents),
)

# Log the model with the sentence_transformers flavor and register it as a new version
# of `model_uc_path`; the resulting `logged_model_info.registered_model_version` is what
# the serving endpoint config points at.
with mlflow.start_run():
    logged_model_info = mlflow.sentence_transformers.log_model(
        model=model,
        artifact_path='custom_embedding',
        registered_model_name=model_uc_path,
        signature=signature,
        input_example=documents,
    )

# Serve the Model with Databricks Model Serving

In [0]:
# Build the serving-endpoint configuration for the model version registered above.
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import EndpointCoreConfigInput, ServedModelInput

config = EndpointCoreConfigInput.from_dict({
    "served_models": [
        {
            "name": endpoint_name,
            "model_name": model_uc_path,
            # Normalize to str: the serving API takes the model version as a string.
            "model_version": str(logged_model_info.registered_model_version),
            "workload_type": workload_type,
            "workload_size": "Small",
            # A real boolean, not the string "True" (a string "False" would still look truthy).
            "scale_to_zero_enabled": True,
        }
    ]
})

In [0]:
import datetime
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()
# Endpoint provisioning can be slow (GPU workload); wait up to 30 minutes.
deploy_timeout = datetime.timedelta(minutes=30)

# Create the endpoint on the first run; on re-runs update the existing endpoint's
# config instead, because creating an endpoint under a name that is already taken fails.
endpoint_exists = any(ep.name == endpoint_name for ep in w.serving_endpoints.list())
if endpoint_exists:
    endpoint = w.serving_endpoints.update_config_and_wait(
        name=endpoint_name,
        served_models=config.served_models,
        timeout=deploy_timeout,
    )
else:
    endpoint = w.serving_endpoints.create_and_wait(
        name=endpoint_name,
        config=config,
        timeout=deploy_timeout,
    )
endpoint