# Deploy Text Embedding Models to Amazon SageMaker

In this notebook, we demonstrate how to package both a bi-encoder model (for embeddings) and a cross-encoder model (for re-ranking) into a single model archive (`model.tar.gz`) and deploy it, with a custom inference script, to an Amazon SageMaker endpoint (serverless, or alternatively real-time).


## Models

- Bi-Encoder model for embeddings
  - [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)
- Cross-encoder model for re-ranking
  - [cross-encoder/ms-marco-MiniLM-L-12-v2](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-12-v2)

## Inference script to handle both embedding and re-ranking

Refer to [./models/bi-cross-encoder/code/inference.py](./models/bi-cross-encoder/code/inference.py) for implementation details.

In [None]:
# !pip install -U sagemaker rich watermark --quiet

In [None]:
import json
import os
import subprocess
import tarfile
from datetime import datetime
from pathlib import Path
from uuid import uuid4

import boto3
import sagemaker
from rich import print
from sagemaker import get_execution_role, image_uris, model_uris, script_uris
from sagemaker.deserializers import JSONDeserializer
from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker.predictor import Predictor
from sagemaker.s3 import S3Downloader, S3Uploader, s3_path_join
from sagemaker.serializers import JSONSerializer
from sagemaker.session import Session

In [None]:
session = sagemaker.Session()
bucket_name = session.default_bucket()
# IAM role passed to HuggingFaceModel below. Inside SageMaker (Studio /
# notebook instances) get_execution_role() returns the attached role; outside
# SageMaker it raises ValueError. In that case take the role ARN from the
# SAGEMAKER_ROLE_ARN environment variable instead of hard-coding an
# account-specific ARN in the notebook.
try:
    role = get_execution_role(sagemaker_session=session)
except ValueError as err:
    role = os.environ.get("SAGEMAKER_ROLE_ARN")
    if not role:
        raise RuntimeError(
            "No SageMaker execution role found; set the SAGEMAKER_ROLE_ARN "
            "environment variable to the ARN of an IAM role SageMaker can assume."
        ) from err
region = session.boto_region_name
# Low-level boto3 SageMaker client, used for the endpoint-config / endpoint calls.
sm_client = boto3.client("sagemaker", region_name=region)

model_base_name = "bi-cross-encoder"
# Local directory with the model artifacts; made absolute and symlink-resolved.
model_folder = Path(f"./models/{model_base_name}").absolute().resolve()
model_archive_path = model_folder.joinpath("model.tar.gz")

In [None]:
# Display the absolute, symlink-resolved path of the local model directory.
model_folder

### Create Model

- Compress model artifacts to `model.tar.gz`
- Upload model to S3
- Create Model object


In [None]:
# Package the model directory into model.tar.gz (skipped if the archive already
# exists). Every top-level entry of model_folder (presumably the model
# directories plus code/inference.py) is stored at the archive root.
#
# Python's tarfile replaces the former `tar ... **/*.*` shell command. Under
# /bin/sh `**` is not recursive, so that command silently dropped files nested
# more than one directory deep and any name without a dot. It also required
# pigz and changed the working directory.
current_dir = os.getcwd()
print(current_dir)
print(str(model_folder))


def _exclude_hidden(tarinfo):
    """tarfile filter: drop hidden files/dirs (e.g. .ipynb_checkpoints)."""
    if any(part.startswith(".") for part in Path(tarinfo.name).parts):
        return None
    return tarinfo


if not model_archive_path.exists():
    try:
        with tarfile.open(model_archive_path, "w:gz") as archive:
            for entry in sorted(model_folder.iterdir()):
                # tarfile.open() has already created model.tar.gz inside
                # model_folder; never add the archive to itself.
                if entry == model_archive_path:
                    continue
                archive.add(entry, arcname=entry.name, filter=_exclude_hidden)
    except BaseException:
        # Don't leave a truncated archive behind: the exists() check above
        # would otherwise skip re-packaging on the next run.
        model_archive_path.unlink(missing_ok=True)
        raise

In [None]:
# Display the local path of the archive that is uploaded to S3 in the next cell.
model_archive_path

In [None]:
# Upload the local model.tar.gz to
# s3://<default bucket>/models/txt-embedding-models/<model_base_name>
s3_model_prefix = f"/models/txt-embedding-models/{model_base_name}"
upload_path_s3 = s3_path_join(f"s3://{bucket_name}", s3_model_prefix)
print(f"Uploading the model to {upload_path_s3}")

# The returned S3 URI becomes the model_data of the HuggingFaceModel.
model_data_url = S3Uploader.upload(
    sagemaker_session=session,
    desired_s3_uri=upload_path_s3,
    local_path=str(model_archive_path),
)
print(f"Model Data URL: {model_data_url}")

In [None]:
# Unique-per-run model name: "<base>-<5 uuid hex chars>-<DDMonYYYY>".
# NOTE(review): %b is locale-dependent; in a non-English locale the month
# abbreviation may contain characters SageMaker names reject — confirm.
run_tag = str(uuid4())[:5]
date_tag = datetime.now().strftime("%d%b%Y")
suffix = f"{run_tag}-{date_tag}"
model_name = f"{model_base_name}-{suffix}"

# Compute settings: instance_type is used by txt_embed_model.create();
# instance_count only by the (commented-out) real-time deployment.
instance_type = "ml.c5.2xlarge"
instance_count = 1

See the [SageMaker Python SDK `HuggingFaceModel` reference](https://sagemaker.readthedocs.io/en/stable/frameworks/huggingface/sagemaker.huggingface.html#hugging-face-model) for all constructor arguments.

In [None]:
print(f"Creating model: {model_name}")
txt_embed_model = HuggingFaceModel(
    name=model_name,
    role=role,
    sagemaker_session=session,
    model_data=model_data_url,
    # NOTE(review): passing entry_point makes the SDK repack model_data
    # (download, add code/, re-upload). The archive already ships
    # code/inference.py, so this may be avoidable — confirm.
    entry_point="inference.py",
    # Hugging Face inference container versions.
    transformers_version="4.26.0",
    pytorch_version="1.13.1",
    py_version="py39",
    # "10" is presumably Python's logging.DEBUG level (verbose container logs).
    env={"SAGEMAKER_CONTAINER_LOG_LEVEL": "10"},
)

# Register the SageMaker Model resource (no endpoint yet). instance_type is
# used to pick the matching container image.
txt_embed_model.create(instance_type=instance_type)

### Deploy Model

#### Deploy to Serverless endpoint

In [None]:
## Serverless endpoint

endpoint_name = model_name
endpoint_config_name = f"{model_name}-epc"
# Memory per serverless worker, in MB (sent as MemorySizeInMB; 2048 MB = 2 GB).
memory = 2048
# Upper bound on concurrent invocations served by this endpoint.
max_concurrency = 10

# Create an endpoint config with a single serverless production variant that
# serves the model registered under model_name.
epc_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "ModelName": model_name,
            "VariantName": "AllTraffic",
            "ServerlessConfig": {
                "MemorySizeInMB": memory,
                "MaxConcurrency": max_concurrency,
            },
        }
    ],
)
status_code = epc_response["ResponseMetadata"]["HTTPStatusCode"]
epc_arn = epc_response["EndpointConfigArn"]

# boto3 raises ClientError for failed API calls, so a non-200 here is
# unexpected. Fail loudly instead of silently skipping endpoint creation,
# which would only surface later as a confusing "endpoint not found".
if status_code != 200:
    raise RuntimeError(
        f"Creating endpoint config {endpoint_config_name} returned HTTP {status_code}"
    )

print(f"EPC : {endpoint_config_name} created")
print(f"Creating endpoint: {endpoint_name}")
# create_endpoint starts provisioning; the following cell waits for InService.
ep_response = sm_client.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)
status_code = ep_response["ResponseMetadata"]["HTTPStatusCode"]
print(f"Endpoint: {endpoint_name}; Status Code: {status_code}")

### Wait for the endpoint to reach the `InService` state

In [None]:
def _print_endpoint_status(endpoint_name):
    """Fetch the endpoint's status via DescribeEndpoint, print it, and return it."""
    status = sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
    print(f"Endpoint [b]{endpoint_name}[/b] Status: [i]{status}[/i]")
    return status


# Status before waiting.
status = _print_endpoint_status(endpoint_name)

# Block until the endpoint reaches InService using boto3's built-in waiter.
waiter = sm_client.get_waiter("endpoint_in_service")
waiter.wait(EndpointName=endpoint_name)

# Status after the waiter returns.
status = _print_endpoint_status(endpoint_name)

#### Deploy to a real-time endpoint (alternative to serverless)

In [None]:
# Alternative to the serverless endpoint: deploy the same model to a
# provisioned real-time endpoint on `instance_count` x `instance_type`.
# Uncomment to use.
# NOTE(review): this reuses endpoint_name = model_name, the same name as the
# serverless endpoint. Run only one of the two deployment cells (or pick a
# different name); otherwise endpoint creation presumably fails on a name clash.
# With wait=False, deploy() returns immediately — wait for InService as above
# before predicting.
# endpoint_name = model_name

# predictor = txt_embed_model.deploy(
#     instance_type=instance_type,
#     initial_instance_count=instance_count,
#     endpoint_name=endpoint_name,
#     serializer=JSONSerializer(),
#     deserializer=JSONDeserializer(),
#     wait=False,
# )

### Predict

In [None]:
# Client for the deployed endpoint: payloads are serialized to JSON and
# responses parsed from JSON. (Predictor is already imported at the top.)
predictor = Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=session,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

# The "kind" field presumably selects the model inside code/inference.py.
# Embedding (bi-encoder) request example:
#   {"kind": "embeddings", "sentence": "I love Berlin"}
# Re-ranking (cross-encoder) request: score each candidate against "sentence".
input_data = {
    "kind": "cross-encoder",
    "sentence": "I love Berlin",
    "candidates": ["I love Paris", "I love Stuttgart"],
}

# Neutral name: for a cross-encoder request the result is re-ranking output,
# not embeddings.
response = predictor.predict(input_data)

print(response)

### Cleanup

In [None]:
# Tear down what this notebook created. Delete the model(s) first: per the
# SageMaker SDK, the Predictor looks up model names through the endpoint
# config, which delete_endpoint() also removes (delete_endpoint_config=True
# is its default).
predictor.delete_model()
predictor.delete_endpoint(delete_endpoint_config=True)