# Deploy a cross-encoder model for re-ranking to an Amazon SageMaker endpoint

In this notebook, we demonstrate how to package and deploy a cross-encoder model for re-ranking.

## Bi-Encoder vs. Cross-Encoder

First, it is important to understand the difference between Bi-Encoders and Cross-Encoders.

A Bi-Encoder produces a sentence embedding for a given sentence. We pass sentences A and B independently to a BERT model, which produces the sentence embeddings u and v. These sentence embeddings can then be compared using cosine similarity:

![Bi vs Cross-encoder](https://raw.githubusercontent.com/UKPLab/sentence-transformers/master/docs/img/Bi_vs_Cross-Encoder.png)
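To make the Bi-Encoder comparison step concrete, here is a minimal pure-Python sketch of cosine similarity, (u · v) / (|u| |v|), between two embeddings. The vectors are made-up toy values, not outputs of a real model.

```python
import math


def cosine_similarity(u: list[float], v: list[float]) -> float:
    """Cosine of the angle between two embedding vectors: (u . v) / (|u| * |v|)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


# Toy 3-dimensional "embeddings"; real sentence embeddings have hundreds of dimensions
u = [1.0, 2.0, 3.0]
v = [2.0, 4.0, 6.0]   # same direction as u, so the similarity is (numerically) 1
w = [-3.0, 0.0, 1.0]  # orthogonal to u, so the similarity is 0

print(cosine_similarity(u, v))
print(cosine_similarity(u, w))
```

Because each embedding is computed once and independently, embeddings for a whole corpus can be precomputed and indexed, which is exactly what a Cross-Encoder cannot offer.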

In contrast, a Cross-Encoder receives both sentences simultaneously as input to the Transformer network. It then produces an output value between 0 and 1 indicating the similarity of the input sentence pair.

A Cross-Encoder _does not produce_ a sentence embedding. Also, we are not able to pass individual sentences to a Cross-Encoder.
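In practice, a Cross-Encoder's classification head often emits one unbounded score (a logit) per sentence pair, and a value between 0 and 1 is commonly obtained with the logistic sigmoid. Whether and where a sigmoid is applied for this particular model depends on its head and on the inference script, so treat the following as an illustrative sketch with made-up scores:

```python
import math


def sigmoid(logit: float) -> float:
    """Logistic function: squashes an unbounded score into the open interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-logit))


# Hypothetical raw scores for three (query, candidate) pairs
raw_scores = {"pair_a": 4.0, "pair_b": 0.0, "pair_c": -4.0}
probabilities = {pair: sigmoid(score) for pair, score in raw_scores.items()}
print(probabilities)
```

Because the sigmoid is monotonic, it changes the scale of the scores but never their order, so re-ranking by raw logits or by sigmoid scores gives the same ranking.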

As detailed in the [Sentence-BERT paper](https://arxiv.org/abs/1908.10084), Cross-Encoders achieve better accuracy than Bi-Encoders.

However, for many applications they are not practical, because they do not produce embeddings that we could, e.g., index or compare efficiently using cosine similarity.
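This trade-off is why Cross-Encoders are typically used for re-ranking: a fast Bi-Encoder retrieves a shortlist from a large corpus, and the Cross-Encoder re-scores only that shortlist. A minimal sketch of the two-stage pattern follows; the two scoring functions are hypothetical word-overlap stand-ins for the real models, chosen only so the example runs without downloading anything.

```python
from typing import Callable


def retrieve_then_rerank(
    query: str,
    corpus: list[str],
    cheap_score: Callable[[str, str], float],
    pair_score: Callable[[str, str], float],
    top_k: int = 2,
) -> list[tuple[str, float]]:
    """Two-stage search: shortlist with a cheap scorer, then re-score the shortlist pairwise."""
    # Stage 1 (Bi-Encoder role): score every document cheaply and keep the top_k
    shortlist = sorted(corpus, key=lambda doc: cheap_score(query, doc), reverse=True)[:top_k]
    # Stage 2 (Cross-Encoder role): score each (query, doc) pair jointly, highest first
    reranked = [(doc, pair_score(query, doc)) for doc in shortlist]
    return sorted(reranked, key=lambda item: item[1], reverse=True)


def word_overlap(a: str, b: str) -> float:
    """Hypothetical stand-in for embedding similarity: number of shared lowercase words."""
    return float(len(set(a.lower().split()) & set(b.lower().split())))


def jaccard(a: str, b: str) -> float:
    """Hypothetical stand-in for a Cross-Encoder score: Jaccard similarity of the word sets."""
    sa, sb = set(a.lower().split()), set(b.lower().split())
    return len(sa & sb) / len(sa | sb)


query = "I love Berlin"
corpus = ["I love Paris", "Berlin is a city", "I love Berlin in spring", "Bananas are yellow"]
result = retrieve_then_rerank(query, corpus, word_overlap, jaccard, top_k=2)
print(result)  # → [('I love Berlin in spring', 0.6), ('I love Paris', 0.5)]
```

Only `top_k` documents ever reach the expensive pairwise scorer, which keeps the Cross-Encoder's cost bounded regardless of corpus size.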

## Models

- Cross-encoder model for re-ranking
  - [cross-encoder/ms-marco-MiniLM-L-12-v2](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-12-v2)

## Inference script for re-ranking

Refer to [models/cross-encoders/ms-marco-MiniLM-L-12-v2/code/inference.py](./models/cross-encoders/ms-marco-MiniLM-L-12-v2/code/inference.py) script for implementation details.
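The script itself is not reproduced here. As a hypothetical sketch of the request flow such a script might implement (not the actual contents of `inference.py`): parse a JSON payload with the `sentence` and `candidates` keys used in the Predict step, pair the query with every candidate, score the pairs, and return the candidates sorted by score. The model call is injected as a callable so the sketch runs offline, and the response shape (`candidate`/`score` objects) is an assumption.

```python
import json
from typing import Callable

# Hypothetical scorer type: takes (query, candidate) pairs, returns one score per pair
PairScorer = Callable[[list[tuple[str, str]]], list[float]]


def handle_rerank_request(body: str, score_pairs: PairScorer) -> str:
    """Parse a JSON re-ranking request, score each candidate, and return JSON ranked high to low."""
    request = json.loads(body)
    query = request["sentence"]
    candidates = request["candidates"]
    # A Cross-Encoder sees the query and each candidate together, as one input pair
    pairs = [(query, candidate) for candidate in candidates]
    scores = score_pairs(pairs)
    ranked = sorted(zip(candidates, scores), key=lambda item: item[1], reverse=True)
    return json.dumps([{"candidate": c, "score": s} for c, s in ranked])


def fake_scorer(pairs: list[tuple[str, str]]) -> list[float]:
    # Hypothetical stand-in for the model: longer candidates simply score higher
    return [float(len(candidate)) for _, candidate in pairs]


request_body = json.dumps(
    {
        "sentence": "I love Berlin",
        "candidates": ["I love Paris", "I love Dusseldorf", "I love Hannover"],
    }
)
response_body = handle_rerank_request(request_body, fake_scorer)
print(response_body)
```

In a real handler, `fake_scorer` would be replaced by a batched forward pass of the Cross-Encoder over all pairs at once.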


In [None]:
# !pip install -U sagemaker rich watermark --quiet

In [None]:
import json
import os
import subprocess
from datetime import datetime
from pathlib import Path
from uuid import uuid4

import boto3
import sagemaker
from rich import print
from sagemaker import get_execution_role
from sagemaker.deserializers import JSONDeserializer
from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker.predictor import Predictor
from sagemaker.s3 import S3Uploader, s3_path_join
from sagemaker.serializers import JSONSerializer
from sagemaker.session import Session

In [None]:
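session = sagemaker.Session()
bucket_name = session.default_bucket()
role = get_execution_role()
region = session.boto_region_name
# Define a SageMaker client object to call SageMaker APIs
sm_client = boto3.client("sagemaker", region_name=region)

# Instance settings used by `create()` (to select the CPU inference image) and by the
# optional real-time deployment below; ml.m5.xlarge is an assumed default, adjust as needed
instance_type = "ml.m5.xlarge"
instance_count = 1

model_base_name = "ms-marco-MiniLM-L-12-v2"
model_folder = Path(f"./models/cross-encoders/{model_base_name}").absolute().resolve()
model_archive_path = model_folder.joinpath("model.tar.gz")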

In [None]:
model_folder

### Create Model

- Compress model artifacts to `model.tar.gz`
- Upload model to S3
- Create Model object


In [None]:
files_to_compress = [
    "pytorch_model.bin",
    "config.json",
    "vocab.txt",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "code",
]

In [None]:
# Compress the model artifacts (requires the `pigz` binary on the PATH)
model_archive_path = model_folder.joinpath("model.tar.gz")

if not model_archive_path.exists():
    print(f"Compressing artifacts in {model_folder}")
    # Run tar inside the model folder (cwd=...) so the files sit at the archive root;
    # check=True raises CalledProcessError if tar or pigz fails
    subprocess.run(
        ["tar", "--use-compress-program=pigz", "-cf", "model.tar.gz", *files_to_compress],
        cwd=model_folder,
        check=True,
    )
    print("model.tar.gz created successfully!")
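The compression cell relies on the `pigz` binary being installed. If it is not, the standard-library `tarfile` module can produce an equivalent gzip-compressed `model.tar.gz` (single-threaded, so slower for large weights). A sketch, where the helper name `create_model_archive` is hypothetical; the demo runs against placeholder files in a temporary folder rather than the real model folder.

```python
import tarfile
import tempfile
from pathlib import Path


def create_model_archive(model_dir: Path, names: list[str], archive_name: str = "model.tar.gz") -> Path:
    """Gzip-compress the given files/folders from model_dir into model_dir/archive_name."""
    archive_path = model_dir / archive_name
    with tarfile.open(archive_path, "w:gz") as tar:
        for name in names:
            # arcname=name keeps entries relative to the archive root, e.g. code/inference.py;
            # folders such as "code" are added recursively
            tar.add(str(model_dir / name), arcname=name)
    return archive_path


# Demo with placeholder files instead of the real model artifacts
with tempfile.TemporaryDirectory() as tmp:
    demo_dir = Path(tmp)
    (demo_dir / "config.json").write_text("{}")
    (demo_dir / "code").mkdir()
    (demo_dir / "code" / "inference.py").write_text("# placeholder\n")
    demo_archive = create_model_archive(demo_dir, ["config.json", "code"])
    with tarfile.open(demo_archive) as tar:
        member_names = sorted(tar.getnames())
    print(member_names)  # → ['code', 'code/inference.py', 'config.json']
```

Keeping `code/inference.py` at that relative path inside the archive matters, because the Hugging Face inference container looks for custom code in the `code/` folder of the extracted model.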

In [None]:
# Upload model artifact to S3
suffix = f"/models/txt-embedding-models/cross-encoders/{model_base_name}"
upload_path_s3 = s3_path_join(f"s3://{bucket_name}", suffix)
print(f"Uploading the model to {upload_path_s3}")
model_data_url = S3Uploader.upload(
    local_path=str(model_archive_path),
    desired_s3_uri=upload_path_s3,
    sagemaker_session=session,
)
print(f"Model Data URL: {model_data_url}")

In [None]:
suffix = f"{str(uuid4())[:5]}-{datetime.now().strftime('%d%b%Y')}"
model_name = f"{model_base_name}-{suffix}"
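SageMaker rejects model, endpoint-config and endpoint names that are longer than 63 characters or contain anything other than letters, digits and hyphens. Since the endpoint config name later appends `-epc` to `model_name`, a quick local check can catch problems before any API call. The pattern below mirrors the constraint documented in the SageMaker API reference for `CreateModel` (`ModelName`); double-check it against the current docs. Variable names are chosen so they do not overwrite the notebook's `model_name`.

```python
import re
from datetime import datetime
from uuid import uuid4

# At most 63 characters, starting and ending with a letter or digit, hyphens allowed in between
NAME_PATTERN = re.compile(r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}")
MAX_NAME_LENGTH = 63


def is_valid_sagemaker_name(name: str) -> bool:
    """Return True if name fits the length limit and the allowed character pattern."""
    return len(name) <= MAX_NAME_LENGTH and NAME_PATTERN.fullmatch(name) is not None


# Same naming scheme as the notebook: base name + 5 hex characters + date
example_name = f"ms-marco-MiniLM-L-12-v2-{str(uuid4())[:5]}-{datetime.now().strftime('%d%b%Y')}"

for candidate_name in (example_name, f"{example_name}-epc", "my_model", "-leading-hyphen"):
    print(f"{candidate_name!r}: {is_valid_sagemaker_name(candidate_name)}")
```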

Create a `HuggingFaceModel` with the model data and the custom `inference.py` script. See the [SageMaker Python SDK Hugging Face model documentation](https://sagemaker.readthedocs.io/en/stable/frameworks/huggingface/sagemaker.huggingface.html#hugging-face-model).


In [None]:
print(f"Creating model: {model_name}")
txt_embed_model = HuggingFaceModel(
    model_data=model_data_url,
    role=role,
    entry_point="inference.py",
    transformers_version="4.26.0",
    pytorch_version="1.13.1",
    sagemaker_session=session,
    py_version="py39",
    name=model_name,
    env={"SAGEMAKER_CONTAINER_LOG_LEVEL": "10"},
)

txt_embed_model.create(instance_type=instance_type)

### Deploy Model


### Deploy to serverless endpoint


In [None]:
## Serverless endpoint

endpoint_name = model_name
endpoint_config_name = f"{model_name}-epc"

# Memory size in MB (passed as MemorySizeInMB below)
memory = 2048
max_concurrency = 10

# Create endpoint config
epc_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "ModelName": model_name,
            "VariantName": "AllTraffic",
            "ServerlessConfig": {
                "MemorySizeInMB": memory,
                "MaxConcurrency": max_concurrency,
            },
        }
    ],
)
status_code = epc_response["ResponseMetadata"]["HTTPStatusCode"]
epc_arn = epc_response["EndpointConfigArn"]

if status_code == 200:
    print(f"EPC : {endpoint_config_name} created")
    print(f"Creating endpoint: {endpoint_name} ...")
    ep_response = sm_client.create_endpoint(
        EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
    )
    status_code = ep_response["ResponseMetadata"]["HTTPStatusCode"]
    print(f"Endpoint: {endpoint_name}; Status Code: {status_code}")
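Because `memory` is passed as `MemorySizeInMB`, it must be one of the sizes serverless inference accepts. The hypothetical helper below (`serverless_variant` is not part of boto3) builds the same `ProductionVariants` entry as the cell above and validates the values locally. It assumes the limits stated in its comments, which reflect the `ProductionVariantServerlessConfig` API reference at the time of writing and may change.

```python
# Assumed limits: memory in 1 GB steps from 1024 to 6144 MB, concurrency from 1 to 200
# (account-level quotas may be lower than the API maximum)
VALID_MEMORY_MB = {1024, 2048, 3072, 4096, 5120, 6144}
MAX_CONCURRENCY_LIMIT = 200


def serverless_variant(model_name: str, memory_mb: int, max_concurrency: int) -> dict:
    """Build one ProductionVariants entry for create_endpoint_config, validating inputs first."""
    if memory_mb not in VALID_MEMORY_MB:
        raise ValueError(f"memory_mb must be one of {sorted(VALID_MEMORY_MB)}, got {memory_mb}")
    if not 1 <= max_concurrency <= MAX_CONCURRENCY_LIMIT:
        raise ValueError(f"max_concurrency must be in [1, {MAX_CONCURRENCY_LIMIT}], got {max_concurrency}")
    return {
        "ModelName": model_name,
        "VariantName": "AllTraffic",
        "ServerlessConfig": {"MemorySizeInMB": memory_mb, "MaxConcurrency": max_concurrency},
    }


variant = serverless_variant("example-model", memory_mb=2048, max_concurrency=10)
print(variant)

try:
    serverless_variant("example-model", memory_mb=2000, max_concurrency=10)
except ValueError as err:
    print(f"Rejected: {err}")
```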

### Wait for the endpoint to reach the `InService` state


In [None]:
status = sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
print(f"Endpoint [b]{endpoint_name}[/b] Status: [i]{status}[/i]")

# Get the waiter object
waiter = sm_client.get_waiter("endpoint_in_service")
# Apply the waiter on the endpoint
waiter.wait(EndpointName=endpoint_name)

# Get endpoint status using describe endpoint
status = sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
print(f"Endpoint [b]{endpoint_name}[/b] Status: [i]{status}[/i]")
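The boto3 waiter repeatedly calls `DescribeEndpoint` until the endpoint reaches a terminal state. A hand-rolled equivalent can help when you want custom logging or a different timeout. This is a sketch: the 30-second delay and 120-attempt defaults are assumptions chosen to resemble the waiter, and `describe` is injected so the demo can use a stub instead of AWS.

```python
import time
from typing import Callable


def wait_for_endpoint(
    describe: Callable[[], dict],
    delay_seconds: float = 30.0,
    max_attempts: int = 120,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Poll describe() until EndpointStatus is InService; fail fast if it is Failed."""
    for _ in range(max_attempts):
        response = describe()
        status = response["EndpointStatus"]
        if status == "InService":
            return status
        if status == "Failed":
            raise RuntimeError(f"Endpoint creation failed: {response.get('FailureReason', 'unknown')}")
        sleep(delay_seconds)
    raise TimeoutError(f"Endpoint not InService after {max_attempts} attempts")


# Real usage (not run here):
# wait_for_endpoint(lambda: sm_client.describe_endpoint(EndpointName=endpoint_name))

# Offline demo: a stub that reports "Creating" twice, then "InService"
fake_statuses = iter(["Creating", "Creating", "InService"])
recorded_sleeps = []
final_status = wait_for_endpoint(
    lambda: {"EndpointStatus": next(fake_statuses)}, sleep=recorded_sleeps.append
)
print(final_status)  # → InService
```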

### Deploy to real-time endpoint (Optional)

Uncomment the code below to deploy the model to a real-time endpoint instead.


In [None]:
# endpoint_name = model_name

# predictor = txt_embed_model.deploy(
#     instance_type=instance_type,
#     initial_instance_count=instance_count,
#     endpoint_name=endpoint_name,
#     serializer=JSONSerializer(),
#     deserializer=JSONDeserializer(),
#     wait=False,
# )

### Predict


In [None]:
predictor = Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=session,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

input_data = {
    "sentence": "I love Berlin",
    "candidates": ["I love Paris", "I love Dusseldorf", "I love Hannover"],
}

rankings = predictor.predict(input_data)

print(rankings)
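`Predictor` is a thin wrapper around the SageMaker Runtime `InvokeEndpoint` API, so the same request can be sent from any application with boto3's `sagemaker-runtime` client. A sketch, where the helper `invoke_rerank` and the echoing fake client are hypothetical (the fake only lets the example run without AWS); the JSON content types match what `JSONSerializer` and `JSONDeserializer` use.

```python
import io
import json


def invoke_rerank(runtime_client, endpoint_name: str, payload: dict):
    """Send the payload as JSON via InvokeEndpoint and decode the JSON response body."""
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Accept="application/json",
        Body=json.dumps(payload),
    )
    # Body is a streaming object; read() returns the raw bytes of the model response
    return json.loads(response["Body"].read())


class FakeRuntimeClient:
    """Offline stand-in that echoes the request body back as the response body."""

    def invoke_endpoint(self, **kwargs):
        self.last_call = kwargs
        return {"Body": io.BytesIO(kwargs["Body"].encode("utf-8"))}


# Real usage (not run here):
# runtime_client = boto3.client("sagemaker-runtime", region_name=region)
# invoke_rerank(runtime_client, endpoint_name, input_data)

example_payload = {"sentence": "I love Berlin", "candidates": ["I love Paris", "I love Hannover"]}
fake_client = FakeRuntimeClient()
echoed = invoke_rerank(fake_client, "example-endpoint", example_payload)
print(echoed)
```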

### Cleanup


In [None]:
# predictor.delete_model()
# predictor.delete_endpoint()
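Because the serverless endpoint in this notebook was created with the low-level `sm_client`, it can also be cleaned up with the matching boto3 calls. A sketch: the helper `delete_serverless_endpoint` and the recording fake client are hypothetical, and the order shown (endpoint, then endpoint config, then model) is a conservative choice so nothing is removed while still in use.

```python
def delete_serverless_endpoint(sm_client, endpoint_name: str, endpoint_config_name: str, model_name: str) -> None:
    """Delete the endpoint first, then its endpoint config, then the model it served."""
    sm_client.delete_endpoint(EndpointName=endpoint_name)
    sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    sm_client.delete_model(ModelName=model_name)


class RecordingClient:
    """Offline stand-in for boto3's SageMaker client that records every method call."""

    def __init__(self):
        self.calls = []

    def __getattr__(self, method_name):
        # Only called for attributes that do not exist, i.e. the delete_* methods
        def record(**kwargs):
            self.calls.append((method_name, kwargs))

        return record


# Real usage (not run here):
# delete_serverless_endpoint(sm_client, endpoint_name, endpoint_config_name, model_name)

recording_client = RecordingClient()
delete_serverless_endpoint(recording_client, "example-ep", "example-ep-epc", "example-model")
print([method for method, _ in recording_client.calls])
```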