# Deploy `thenlper/gte-base` Text Embedding Model (768 Dimensions) to Amazon SageMaker

In this notebook, we demonstrate how to package the `thenlper/gte-base` embedding model, which produces 768-dimensional embeddings, and deploy it to Amazon SageMaker.

**General Text Embeddings (GTE) model**

The GTE models are trained by Alibaba DAMO Academy. They are based mainly on the BERT framework and come in three sizes: GTE-large, GTE-base, and GTE-small. They are trained on a large-scale corpus of relevance text pairs that covers a wide range of domains and scenarios. As a result, they can be applied to many downstream text-embedding tasks, including information retrieval, semantic textual similarity, and text reranking.

**NOTE:** GTE models are comparatively small next to other top-performing embedding models:

- `thenlper/gte-small`: **~67MB**
- `thenlper/gte-base`: **~220MB**
- `thenlper/gte-large`: **~670MB**

## Papers

N/A as of 3 August 2023.

## Models

- [`thenlper/gte-small`](https://hf.co/thenlper/gte-small)
- [`thenlper/gte-base`](https://hf.co/thenlper/gte-base)
- [`thenlper/gte-large`](https://hf.co/thenlper/gte-large)

## Inference script to handle both embedding and re-ranking

Refer to [./models/bi-encoders/gte-base/code/inference.py](./models/bi-encoders/gte-base/code/inference.py) for implementation details.
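The usage example on the GTE model card builds a sentence embedding by averaging the encoder's token embeddings over the attention mask. The downloaded repository's `1_Pooling` folder holds the matching sentence-transformers pooling config. The dependency-free sketch below shows that mean-pooling step followed by L2 normalization. It illustrates the technique under that assumption and is not the contents of the linked `inference.py`, which would normally perform this step on PyTorch tensors.

```python
import math
from typing import List, Sequence


def mean_pool(token_embeddings: Sequence[Sequence[float]], attention_mask: Sequence[int]) -> List[float]:
    """Average the token vectors whose attention-mask entry is 1 (padding is ignored)."""
    dim = len(token_embeddings[0])
    summed = [0.0] * dim
    count = 0
    for vector, keep in zip(token_embeddings, attention_mask):
        if keep:
            count += 1
            for i, value in enumerate(vector):
                summed[i] += value
    if count == 0:
        raise ValueError("attention_mask has no non-padding tokens")
    return [value / count for value in summed]


def l2_normalize(vector: Sequence[float]) -> List[float]:
    """Scale a vector to unit length so dot products equal cosine similarity."""
    norm = math.sqrt(sum(v * v for v in vector))
    return [v / norm for v in vector] if norm > 0 else list(vector)


# Toy example: 3 tokens with 2-dim embeddings; the last token is padding
tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]
print(mean_pool(tokens, mask))  # [2.0, 3.0]
```

Masking matters because padded batches would otherwise pull every embedding toward the padding-token vectors.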


In [None]:
# !pip install -U sagemaker huggingface_hub rich watermark --quiet

In [None]:
import json
import os
import subprocess
from datetime import datetime
from pathlib import Path
from uuid import uuid4

import boto3
import sagemaker
from huggingface_hub import snapshot_download
from rich import print
from sagemaker import get_execution_role
from sagemaker.deserializers import JSONDeserializer
from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker.predictor import Predictor
from sagemaker.s3 import S3Uploader, s3_path_join
from sagemaker.serializers import JSONSerializer
from sagemaker.serverless import ServerlessInferenceConfig
from sagemaker.session import Session

In [None]:
session = Session()
bucket_name = session.default_bucket()
role = get_execution_role()
region = session.boto_region_name
# SageMaker client, used later to poll the endpoint status
sm_client = boto3.client("sagemaker", region_name=region)

HF_MODEL_ID = "thenlper/gte-base"
model_base_name = HF_MODEL_ID.split("/")[-1]
model_folder = Path(f"./models/bi-encoders/{model_base_name}").absolute().resolve()
model_archive_path = model_folder.joinpath("model.tar.gz")
current_dir = os.getcwd()

print(model_folder)
print(model_archive_path)

In [None]:
if not model_folder.exists():
    print("Downloading model ...")
    snapshot_download(repo_id=HF_MODEL_ID, local_dir=model_folder, local_dir_use_symlinks=False)
else:
    print("Model already downloaded.")

### Create Model

- Compress model artifacts to `model.tar.gz`
- Upload model to S3
- Create Model object


In [None]:
files_to_compress = [
    "pytorch_model.bin",
    "config.json",
    "vocab.txt",
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "sentence_bert_config.json",
    "1_Pooling",
    "code",
]
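`tar` exits with an error if any of these paths is missing. The `code` folder, which contains the custom `inference.py`, must sit alongside the downloaded files. A quick pre-check gives a clearer error than the `tar` failure. `find_missing_artifacts` is a hypothetical helper, not part of the SageMaker SDK.

```python
import tempfile
from pathlib import Path
from typing import Iterable, List


def find_missing_artifacts(model_dir: Path, names: Iterable[str]) -> List[str]:
    """Return the entries in `names` (files or folders) that do not exist under `model_dir`."""
    return [name for name in names if not (Path(model_dir) / name).exists()]


# Self-contained demo on a temporary folder
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "config.json").write_text("{}")
    (Path(tmp) / "code").mkdir()
    demo_missing = find_missing_artifacts(Path(tmp), ["config.json", "code", "vocab.txt"])
print(demo_missing)  # ['vocab.txt']

# In this notebook (assumes model_folder and files_to_compress are defined):
# missing = find_missing_artifacts(model_folder, files_to_compress)
# if missing:
#     raise FileNotFoundError(f"Missing model artifacts: {missing}")
```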

In [None]:
# Show the current working directory and the model directory used for packaging
print(current_dir)
print(model_folder)

In [None]:
# Remove any stale archive so it is always rebuilt from the current files
if model_archive_path.exists():
    model_archive_path.unlink()

print(f"Compressing model files in {model_folder}")
# Run tar inside the model folder (cwd=...) so archive paths are relative to it.
# Requires `pigz` on PATH; check=True raises CalledProcessError if tar fails.
subprocess.run(
    ["tar", "-cf", "model.tar.gz", "--use-compress-program=pigz", *files_to_compress],
    cwd=model_folder,
    check=True,
)
print("Model files compressed successfully")
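If `pigz` is not available on your machine, Python's standard `tarfile` module can build an equivalent gzip-compressed archive. It compresses on a single thread, so it is slower for large weight files. The sketch below stores entries relative to the model folder, such as `config.json` and `code/inference.py`, which is the flat layout the `tar` command produces. `create_model_archive` is a hypothetical helper name.

```python
import tarfile
import tempfile
from pathlib import Path
from typing import Iterable


def create_model_archive(model_dir: Path, names: Iterable[str], archive_name: str = "model.tar.gz") -> Path:
    """Write a gzip-compressed tarball of `names` (files or folders) inside `model_dir`."""
    model_dir = Path(model_dir)
    archive_path = model_dir / archive_name
    with tarfile.open(archive_path, "w:gz") as tar:
        for name in names:
            # tar.add recurses into folders such as `code` and `1_Pooling`;
            # arcname keeps the stored path relative to model_dir
            tar.add(model_dir / name, arcname=name)
    return archive_path


# Self-contained demo on a temporary folder
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "config.json").write_text("{}")
    (root / "code").mkdir()
    (root / "code" / "inference.py").write_text("# stub\n")
    archive = create_model_archive(root, ["config.json", "code"])
    with tarfile.open(archive, "r:gz") as tar:
        demo_members = sorted(tar.getnames())
print(demo_members)  # ['code', 'code/inference.py', 'config.json']

# In this notebook: create_model_archive(model_folder, files_to_compress)
```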

In [None]:
# Upload model artifact to S3
suffix = f"/models/txt-embedding-models/{model_base_name}"
upload_path_s3 = s3_path_join(f"s3://{bucket_name}", suffix)
print(f"Uploading the model to {upload_path_s3}")
model_data_url = S3Uploader.upload(
    local_path=str(model_archive_path),
    desired_s3_uri=upload_path_s3,
    sagemaker_session=session,
)
print(f"Model Data URL: {model_data_url}")
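`S3Uploader.upload` returns the S3 URI of the uploaded archive, and that URI is what `HuggingFaceModel` receives as `model_data`. To confirm the object exists before creating the model, you can split the URI into bucket and key and call the S3 `head_object` API. `split_s3_uri` and the bucket name `my-bucket` are illustrative assumptions.

```python
from typing import Tuple
from urllib.parse import urlparse


def split_s3_uri(uri: str) -> Tuple[str, str]:
    """Split 's3://bucket/key/path' into ('bucket', 'key/path')."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"Not an S3 URI: {uri}")
    return parsed.netloc, parsed.path.lstrip("/")


bucket, key = split_s3_uri("s3://my-bucket/models/txt-embedding-models/gte-base/model.tar.gz")
print(bucket, key)  # my-bucket models/txt-embedding-models/gte-base/model.tar.gz

# Hypothetical check against the real upload (requires AWS credentials):
# import boto3
# b, k = split_s3_uri(model_data_url)
# size = boto3.client("s3").head_object(Bucket=b, Key=k)["ContentLength"]
```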

In [None]:
suffix = f"{str(uuid4())[:5]}-{datetime.now().strftime('%d%b%Y')}"
model_name = f"{model_base_name}-{suffix}"
print(f"Model Name: {model_name}")
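The same string serves as both the model name and the endpoint name. SageMaker restricts these names to at most 63 characters of letters, digits, and hyphens, starting and ending with a letter or digit. This is assumed from the documented `ModelName`/`EndpointName` pattern, so check the current API reference. The sketch below builds the name the same way as the cell above and validates it before any API call. `build_resource_name` is a hypothetical helper that takes the random token and date as arguments so it can be tested deterministically.

```python
import re
from datetime import datetime

# Letters/digits, hyphens only between them; length checked separately
SAGEMAKER_NAME_RE = re.compile(r"^[a-zA-Z0-9](-*[a-zA-Z0-9])*$")
MAX_NAME_LENGTH = 63


def build_resource_name(base: str, token: str, when: datetime) -> str:
    """Combine base name, first 5 chars of a token and a date, e.g. 'gte-base-cc0cc-03Aug2023'."""
    name = f"{base}-{token[:5]}-{when.strftime('%d%b%Y')}"
    if len(name) > MAX_NAME_LENGTH or not SAGEMAKER_NAME_RE.match(name):
        raise ValueError(f"Invalid SageMaker resource name: {name!r}")
    return name


print(build_resource_name("gte-base", "cc0cc123", datetime(2023, 8, 3)))  # gte-base-cc0cc-03Aug2023

# In this notebook: build_resource_name(model_base_name, str(uuid4()), datetime.now())
```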

Create a `HuggingFaceModel` from the uploaded model data and the custom `inference.py` script. SageMaker Python SDK reference:

https://sagemaker.readthedocs.io/en/stable/frameworks/huggingface/sagemaker.huggingface.html#hugging-face-model


In [None]:
print(f"Creating model: {model_name}")
txt_embed_model = HuggingFaceModel(
    model_data=model_data_url,
    role=role,
    transformers_version="4.26.0",
    pytorch_version="1.13.1",
    sagemaker_session=session,
    py_version="py39",
    name=model_name,
)

### Deploy Model


### Deploy to serverless endpoint


In [None]:
from sagemaker.serverless import ServerlessInferenceConfig

# Memory in MB (serverless endpoints accept 1024 to 6144 MB in 1024 MB steps)
memory = 2048
max_concurrency = 10
endpoint_name = model_name
serverless_config = ServerlessInferenceConfig(
    memory_size_in_mb=memory, max_concurrency=max_concurrency
)

print(f"Creating endpoint: [b]{endpoint_name}[/b] ...")

# Returns a HuggingFacePredictor
predictor = txt_embed_model.deploy(
    endpoint_name=endpoint_name,
    serverless_inference_config=serverless_config,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
    wait=False,
)
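`ServerlessInferenceConfig` passes these values through to the endpoint configuration. The SageMaker API documents memory sizes of 1024 to 6144 MB in 1024 MB steps and a per-endpoint max concurrency of 1 to 200. Treat those limits as assumptions to re-check, because account-level quotas can be lower. The hypothetical helper below checks the values locally so a typo fails fast instead of during endpoint creation.

```python
from typing import Tuple

# Assumed limits for SageMaker Serverless Inference (verify against current docs/quotas)
VALID_MEMORY_MB = (1024, 2048, 3072, 4096, 5120, 6144)
MAX_CONCURRENCY_LIMIT = 200


def check_serverless_settings(memory_mb: int, max_concurrency: int) -> Tuple[int, int]:
    """Raise ValueError if the serverless memory or concurrency value is out of range."""
    if memory_mb not in VALID_MEMORY_MB:
        raise ValueError(f"memory_size_in_mb must be one of {VALID_MEMORY_MB}, got {memory_mb}")
    if not 1 <= max_concurrency <= MAX_CONCURRENCY_LIMIT:
        raise ValueError(
            f"max_concurrency must be between 1 and {MAX_CONCURRENCY_LIMIT}, got {max_concurrency}"
        )
    return memory_mb, max_concurrency


print(check_serverless_settings(2048, 10))  # (2048, 10)
```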

### Wait for the endpoint to reach the `InService` state


In [None]:
status = sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
print(f"Endpoint [b]{endpoint_name}[/b] Status: [i]{status}[/i]")

# Get the waiter object
waiter = sm_client.get_waiter("endpoint_in_service")
# Apply the waiter on the endpoint
waiter.wait(EndpointName=endpoint_name)

# Get endpoint status using describe endpoint
status = sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
print(f"Endpoint [b]{endpoint_name}[/b] Status: [i]{status}[/i]")
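The boto3 `endpoint_in_service` waiter repeatedly calls `DescribeEndpoint` until the status is `InService`, and it gives up if the endpoint reaches `Failed`. If you need custom logging or a different timeout, a hand-rolled polling loop does the same job. The defaults below (30 s delay, 120 attempts) are assumed to mirror the boto3 waiter's configuration. `wait_for_status` is a hypothetical helper, and its `sleep` parameter is injectable so the loop can be exercised without waiting.

```python
import time
from typing import Callable


def wait_for_status(
    get_status: Callable[[], str],
    target: str = "InService",
    failure: str = "Failed",
    delay_s: float = 30.0,
    max_attempts: int = 120,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Poll `get_status` until it returns `target`; raise on `failure` or after max_attempts."""
    status = ""
    for attempt in range(max_attempts):
        status = get_status()
        if status == target:
            return status
        if status == failure:
            raise RuntimeError(f"Endpoint entered {failure!r} state")
        if attempt < max_attempts - 1:
            sleep(delay_s)
    raise TimeoutError(f"Status still {status!r} after {max_attempts} checks")


# Simulated status sequence instead of a live endpoint
statuses = iter(["Creating", "Creating", "InService"])
demo_status = wait_for_status(lambda: next(statuses), sleep=lambda _: None)
print(demo_status)  # InService

# Against the real endpoint (hypothetical wiring):
# wait_for_status(lambda: sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"])
```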

### Deploy to real-time endpoint (Optional)

Uncomment the code below to deploy to a real-time endpoint instead of a serverless one.


In [None]:
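# endpoint_name = model_name
# instance_type = "ml.m5.xlarge"  # example CPU instance; pick one that fits your workload
# instance_count = 1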

# predictor = txt_embed_model.deploy(
#     instance_type=instance_type,
#     initial_instance_count=instance_count,
#     endpoint_name=endpoint_name,
#     serializer=JSONSerializer(),
#     deserializer=JSONDeserializer(),
#     wait=False,
# )

### Predict

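**NOTE:** The guidance below comes from the FAQ of the E5 embedding models (`intfloat/e5-base-v2`, linked below), which are trained with these prefixes. The usage example on the `thenlper/gte-base` model card embeds raw text without prefixes, so confirm against the GTE model card before adding them.

Do I need to add the prefix "query: " and "passage: " to input texts?

For E5, yes: this is how the model is trained, and omitting the prefixes degrades performance.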

Here are some rules of thumb:

- Use _"query: "_ and _"passage: "_ correspondingly for **asymmetric tasks** such as passage retrieval in open QA, ad-hoc information retrieval.
- Use **"query: "** prefix for **symmetric tasks** such as semantic similarity, paraphrase retrieval.
- Use **"query: "** prefix if you want to use embeddings as features, such as linear probing classification, clustering.

Ref: <https://huggingface.co/intfloat/e5-base-v2#faq>
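If you follow the prefix convention above, for instance when serving an E5 model with this same workflow, a small helper keeps the prefixes consistent and avoids adding them twice. `add_prefixes` is a hypothetical helper, and the sample texts are illustrative.

```python
from typing import List, Sequence


def add_prefixes(texts: Sequence[str], role: str) -> List[str]:
    """Prepend the E5-style 'query: ' or 'passage: ' prefix, skipping texts that already have it."""
    if role not in ("query", "passage"):
        raise ValueError("role must be 'query' or 'passage'")
    prefix = f"{role}: "
    return [text if text.startswith(prefix) else prefix + text for text in texts]


# Asymmetric retrieval: queries and passages get different prefixes
queries = add_prefixes(["how to implement quick sort in python?"], "query")
passages = add_prefixes(["sorting algorithms", "passage: already prefixed"], "passage")
print(queries)   # ['query: how to implement quick sort in python?']
print(passages)  # ['passage: sorting algorithms', 'passage: already prefixed']
```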


#### Uncomment the code block below if you are invoking an existing endpoint


In [None]:
# endpoint_name = "gte-base-cc0cc-03Aug2023"
# predictor = Predictor(
#     endpoint_name=endpoint_name,
#     sagemaker_session=session,
#     serializer=JSONSerializer(),
#     deserializer=JSONDeserializer()
# )

In [None]:
sentences = ["That is a happy person", "That is a very happy person"]

embeddings = predictor.predict(sentences)

print(f"Embedding dimensions: {len(embeddings[0])}")
print(embeddings[0])
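The two example sentences should land close together in embedding space. The GTE model card normalizes embeddings and scores pairs by their dot product, which is the same as cosine similarity. Sorting candidates by that score is a simple form of re-ranking. The sketch below uses toy 3-dimensional vectors in place of the 768-dimensional embeddings returned by the endpoint. `rank_by_similarity` is a hypothetical helper and may differ from the re-ranking in the notebook's `inference.py`.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors (1.0 = same direction)."""
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        raise ValueError("cosine similarity is undefined for zero vectors")
    return sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)


def rank_by_similarity(query: Sequence[float], candidates: Sequence[Sequence[float]]) -> List[Tuple[int, float]]:
    """Return (candidate index, score) pairs sorted from most to least similar."""
    scores = [(i, cosine_similarity(query, c)) for i, c in enumerate(candidates)]
    return sorted(scores, key=lambda pair: pair[1], reverse=True)


demo = rank_by_similarity([1.0, 0.0, 0.0], [[0.0, 1.0, 0.0], [1.0, 1.0, 0.0], [2.0, 0.0, 0.0]])
print([i for i, _ in demo])  # [2, 1, 0]

# With the endpoint output above:
# print(cosine_similarity(embeddings[0], embeddings[1]))
```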

### Cleanup


In [None]:
# predictor.delete_model()
# predictor.delete_endpoint()