# Deploy `BAAI/bge-large-en` Text Embedding Model (1024 Dimension) to Amazon SageMaker

In this notebook, we demonstrate how to package the `BAAI/bge-large-en` embedding model (1024 dimensions) and deploy it to Amazon SageMaker.

`bge` is short for BAAI general embedding.

*NOTE*: If you need to retrieve long relevant passages for a short query (s2p retrieval task), add the instruction to the query; in all other cases no instruction is needed, so use the original query directly. In all cases, no instruction needs to be added to passages.


**NOTE:** bge model sizes and dimensions
- `BAAI/bge-base-en`: **~438MB** (Dimensions: 768)
- `BAAI/bge-large-en`: **~1.34GB** (Dimensions: 1024)

## References
https://github.com/FlagOpen/FlagEmbedding

## Inference script to handle both embedding and re-ranking

Refer to [./models/bi-encoders/bge-large-en/code/inference.py](./models/bi-encoders/bge-large-en/code/inference.py) for implementation details.

In [None]:
# Install or upgrade (-U), quietly (-q), the packages this notebook relies on.
!pip install -Uq sagemaker rich watermark ipywidgets

# Load the IPython extensions shipped with `rich` and `watermark`.
%load_ext rich
%load_ext watermark
# Report the installed versions of the listed packages.
%watermark -p sagemaker,ipywidgets

In [None]:
import json
import os
import subprocess
from datetime import datetime
from pathlib import Path
from uuid import uuid4

import boto3
import sagemaker
from huggingface_hub import snapshot_download
from rich import print
from sagemaker import get_execution_role
from sagemaker.deserializers import JSONDeserializer
from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker.predictor import Predictor
from sagemaker.s3 import S3Uploader, s3_path_join
from sagemaker.serializers import JSONSerializer
from sagemaker.serverless import ServerlessInferenceConfig
from sagemaker.session import Session
from tqdm import tqdm

In [None]:
# Remember where the notebook was started; later cells return to this folder.
current_dir = os.getcwd()
print(current_dir)

# SageMaker session details: default bucket, execution role and region.
session = Session()
bucket_name = session.default_bucket()
role = get_execution_role()
region = session.boto_region_name

# Hugging Face model id; its last path component names the local model folder.
HF_MODEL_ID = "BAAI/bge-large-en"
model_base_name = HF_MODEL_ID.rsplit("/", 1)[-1]
model_folder = (Path("./models/bi-encoders") / model_base_name).absolute().resolve()
model_archive_path = model_folder / "model.tar.gz"

print(f"Region: [i]{region}[/i]")
print(f"bucket name: {bucket_name}")
for local_path in (model_folder, model_archive_path):
    print(local_path)

In [None]:
# The PyTorch weights file is used as the marker for "model already present".
model_bin = model_folder / "pytorch_model.bin"

if model_bin.exists():
    print("Model already downloaded.")
else:
    print("Downloading model ...")
    # Only fetch the files matching these patterns from the Hub repository.
    wanted_patterns = ["1_Pooling", "*.txt", "*.json", "*.bin", "*.safetensors"]
    snapshot_download(
        repo_id=HF_MODEL_ID,
        local_dir=model_folder,
        local_dir_use_symlinks=False,
        allow_patterns=wanted_patterns,
    )

### Create Model

- Compress model artifacts to `model.tar.gz`
- Upload model to S3
- Create Model object


In [None]:
# Entries (relative to the model folder) that are packed into model.tar.gz,
# grouped by purpose. The order is the order they are handed to tar.
files_to_compress = (
    # Model weights, architecture config and module layout.
    ["pytorch_model.bin", "config.json", "modules.json"]
    # Tokenizer vocabulary and settings.
    + ["vocab.txt", "tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"]
    # Sentence-embedding specific configuration files.
    + ["config_sentence_transformers.json", "sentence_bert_config.json"]
    # Pooling configuration and the custom inference code.
    + ["1_Pooling", "code"]
)
# Location of the archive that will be created inside the model folder.
model_archive_path = model_folder / "model.tar.gz"

In [None]:
# Show the archive path as a plain string (the cell's last expression is displayed).
str(model_archive_path)

In [None]:
# Show the notebook's working directory and the model folder before archiving.
for location in (current_dir, model_folder):
    print(location)

In [None]:
# Start from the notebook directory in case an earlier, interrupted run left
# the kernel inside the model folder.
os.chdir(current_dir)
model_archive_path = model_folder.joinpath("model.tar.gz")

# Always rebuild the archive so it reflects the current model files.
if model_archive_path.exists():
    model_archive_path.unlink()

print(str(model_folder))

# The command is passed as an argument list (no shell), so file names are never
# re-interpreted by a shell. Compression is delegated to pigz only: GNU tar
# rejects `-z` together with `--use-compress-program` as conflicting
# compression options, so the `z` flag is intentionally not used.
tar_command = [
    "tar",
    "-cvf",
    model_archive_path.name,
    "--use-compress-program=pigz",
    "--exclude=**/.ipynb_checkpoints",
    "--exclude=.DS_Store",
    *files_to_compress,
]

try:
    # `cwd` runs tar inside the model folder without changing the kernel's
    # working directory, so nothing has to be restored if tar fails.
    subprocess.run(tar_command, cwd=str(model_folder), check=True)
except subprocess.CalledProcessError as err:
    # Do not leave a truncated archive behind for the upload cell to pick up.
    if model_archive_path.exists():
        model_archive_path.unlink()
    print(f"tar failed with exit code {err.returncode}; see the tar output above.")
    raise
else:
    print(f"tar created successfully at {model_archive_path}!")

In [None]:
# Verify contents of the model archive by listing it (t = list, v = verbose, f = archive file).
!tar tvf $model_archive_path

In [None]:
# Build the S3 destination inside the session's default bucket, then upload
# the local archive there.
suffix = f"/models/txt-embedding-models/{model_base_name}"
upload_path_s3 = s3_path_join(f"s3://{bucket_name}", suffix)
print(f"Uploading the model to {upload_path_s3} ...")

upload_args = {
    "local_path": str(model_archive_path),
    "desired_s3_uri": upload_path_s3,
    "sagemaker_session": session,
}
model_data_url = S3Uploader.upload(**upload_args)

print(f"Model Data URL: {model_data_url}")

Create HuggingFaceModel with model data and custom `inference.py` script

https://sagemaker.readthedocs.io/en/stable/frameworks/huggingface/sagemaker.huggingface.html#hugging-face-model


In [None]:
# Make the model name unique: five characters of a random UUID plus today's date.
unique_tag = uuid4().hex[:5]
date_stamp = datetime.now().strftime("%d%b%Y")
suffix = f"{unique_tag}-{date_stamp}"
model_name = f"{model_base_name}-{suffix}"

print(f"Creating model: {model_name}")

# Framework and Python versions requested for the Hugging Face model.
framework_versions = {
    "transformers_version": "4.26.0",
    "pytorch_version": "1.13.1",
    "py_version": "py39",
}

txt_embed_model = HuggingFaceModel(
    name=model_name,
    model_data=model_data_url,
    role=role,
    sagemaker_session=session,
    **framework_versions,
)

### Deploy Model

### Deploy to serverless endpoint

In [None]:
from sagemaker.serverless import ServerlessInferenceConfig

# The endpoint shares its name with the model created above.
endpoint_name = model_name

# Serverless sizing: memory is expressed in MB (passed as `memory_size_in_mb`).
memory = 4096
max_concurrency = 10
serverless_config = ServerlessInferenceConfig(
    memory_size_in_mb=memory,
    max_concurrency=max_concurrency,
)

print(f"Creating endpoint: [b]{endpoint_name}[/b] ...")

# `wait=False` returns immediately; a later cell waits for the endpoint.
deploy_options = {
    "endpoint_name": endpoint_name,
    "serverless_inference_config": serverless_config,
    "serializer": JSONSerializer(),
    "deserializer": JSONDeserializer(),
    "wait": False,
}
predictor = txt_embed_model.deploy(**deploy_options)

### Wait for the endpoint to reach the `InService` state

In [None]:
sm_client = boto3.client("sagemaker")


def _show_endpoint_status():
    """Look up the endpoint's current status, print it and return it."""
    current = sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
    print(f"Endpoint [b]{endpoint_name}[/b] Status: [i]{current}[/i]")
    return current


# Status before waiting.
status = _show_endpoint_status()

# Block on the "endpoint_in_service" waiter for this endpoint.
waiter = sm_client.get_waiter("endpoint_in_service")
waiter.wait(EndpointName=endpoint_name)

# Status after the waiter returns.
status = _show_endpoint_status()

### Deploy to real-time endpoint (Optional)

Uncomment the code below to deploy the model to a real-time endpoint instead.

In [None]:
# Real-time deployment alternative (intentionally left commented out).
# NOTE(review): `instance_type` and `instance_count` are not defined in the
# cells shown in this notebook - define them before uncommenting this code.
# endpoint_name = model_name

# predictor = txt_embed_model.deploy(
#     instance_type=instance_type,
#     initial_instance_count=instance_count,
#     endpoint_name=endpoint_name,
#     serializer=JSONSerializer(),
#     deserializer=JSONDeserializer(),
#     wait=False,
# )

### Predict

```python
def generate_embeddings(texts, model, tokenizer, normalize=True):
    """
    Generate embeddings for a list of texts using a pre-trained model.

    Args:
        texts (List[str]): List of texts to calculate embeddings for.
        model (AutoModel): Pre-trained model. It is called with the tokenized
            batch after the batch has been moved to the selected device.
        tokenizer (AutoTokenizer): Tokenizer corresponding to the pre-trained model.
        normalize (bool, optional): Whether to L2-normalize the embeddings. Defaults to True.

    Returns:
        List[List[float]]: One embedding (a list of floats) per input text.
    """

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Tokenize the texts; inputs longer than 512 tokens are truncated.
    encoded_input = tokenizer(
        texts, max_length=512, padding=True, truncation=True, return_tensors="pt"
    )

    encoded_input = encoded_input.to(device)

    # Get the embeddings for the texts; no gradients are needed for inference.
    with torch.no_grad():
        model_output = model(**encoded_input)

        # Perform pooling. In this case, cls pooling: take position 0 of the
        # first model output for every text in the batch.
        sentence_embeddings = model_output[0][:, 0]

    # Normalize embeddings if required (L2 norm along the embedding dimension).
    if normalize:
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

    # Move to CPU and convert to plain nested Python lists.
    return sentence_embeddings.cpu().tolist()


def model_fn(model_dir):
    """
    Load the tokenizer and the embedding model from ``model_dir``.

    The model is switched to evaluation mode and moved to CUDA when it is
    available, otherwise to the CPU.

    Args:
        model_dir (str): Directory containing the model and tokenizer files.

    Returns:
        dict: ``{"embeddings_model": ..., "embeddings_tokenizer": ...}``.
    """
    logger.info("model_fn")
    if torch.cuda.is_available():
        target_device = torch.device("cuda")
    else:
        target_device = torch.device("cpu")

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    encoder = AutoModel.from_pretrained(model_dir)
    encoder.eval()
    encoder.to(target_device)

    return {"embeddings_model": encoder, "embeddings_tokenizer": tokenizer}


def predict_fn(texts, model):
    """
    Compute embeddings for ``texts`` with the objects returned by ``model_fn``.

    Args:
        texts (List[str]): Texts to embed.
        model (dict): Mapping holding "embeddings_model" and "embeddings_tokenizer".

    Returns:
        The value returned by ``generate_embeddings`` for these texts.
    """
    logger.info("predict_fn")

    return generate_embeddings(
        texts,
        model["embeddings_model"],
        model["embeddings_tokenizer"],
    )
```

Ref: <https://huggingface.co/BAAI/bge-large-en>


#### Uncomment the code block below if you are invoking an existing endpoint

In [None]:
# Attach a Predictor to an endpoint that already exists (left commented out).
# NOTE(review): the endpoint name below is a placeholder that refers to a
# gte-base deployment - replace it with the name of your own endpoint.
# endpoint_name = "gte-base-cc0cc-03Aug2023"
# predictor = Predictor(
#     endpoint_name=endpoint_name,
#     sagemaker_session=session,
#     serializer=JSONSerializer(),
#     deserializer=JSONDeserializer()
# )

In [None]:
# Two sample sentences, sent to the endpoint in a single request.
sentences = [
    "That is a happy person",
    "That is a very happy person",
]

embeddings = predictor.predict(sentences)

# Inspect the first returned embedding: its length, then its values.
first_embedding = embeddings[0]
print(f"Embedding dimensions: {len(first_embedding)}")
print(first_embedding)

### Cleanup

In [None]:
# Remove the resources created by this notebook.
# NOTE(review): the model is deleted before the endpoint; presumably the model
# lookup relies on the endpoint still existing - confirm before reordering.
predictor.delete_model()
predictor.delete_endpoint()