# Deploy hkunlp/instructor embedding models to Amazon SageMaker

In this notebook, we demonstrate how to package and deploy the hkunlp/instructor-\* embedding models (768-dimensional embeddings) to Amazon SageMaker.

Instructor👨‍🏫 is an instruction-finetuned text embedding model that can generate text embeddings tailored to any task (e.g., classification, retrieval, clustering, text evaluation, etc.) and domain (e.g., science, finance, etc.) simply by providing the task instruction, without any finetuning. Instructor achieves state-of-the-art results on 70 diverse embedding tasks. The model is easy to use with the authors' customized sentence-transformer library.

## Papers

- https://arxiv.org/abs/2212.09741
- https://instructor-embedding.github.io/
- https://huggingface.co/papers/2212.09741

## Models

- instructor-base
  - [hkunlp/instructor-base](https://huggingface.co/hkunlp/instructor-base)
- instructor-large
  - [hkunlp/instructor-large](https://huggingface.co/hkunlp/instructor-large)

## Inference script to handle both embedding and re-ranking

Refer to [./models/bi-encoders/instructor-base/code/inference.py](./models/bi-encoders/instructor-base/code/inference.py) for implementation details.


In [None]:
# Uncomment to install dependencies (%pip installs into the kernel's environment).
# huggingface_hub provides `snapshot_download`, which is used below.
# %pip install -U sagemaker rich watermark InstructorEmbedding huggingface_hub --quiet

In [None]:
import json
import os
import subprocess
from datetime import datetime
from pathlib import Path
from uuid import uuid4

import boto3
import sagemaker
from huggingface_hub import snapshot_download
from rich import print
from sagemaker import get_execution_role
from sagemaker.deserializers import JSONDeserializer
from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker.predictor import Predictor
from sagemaker.s3 import S3Downloader, S3Uploader, s3_path_join
from sagemaker.serializers import JSONSerializer
from sagemaker.session import Session

In [None]:
# SageMaker session plus the bucket, execution role and region derived from it.
session = sagemaker.Session()
bucket_name = session.default_bucket()
role = get_execution_role()
region = session.boto_region_name
# Low-level boto3 SageMaker client, pinned to the session's region.
sm_client = boto3.client("sagemaker", region_name=region)

# Hugging Face model to package; its short name drives local paths, S3 keys and resource names.
HF_MODEL_ID = "hkunlp/instructor-base"
model_base_name = HF_MODEL_ID.rsplit("/", maxsplit=1)[-1]
model_folder = (Path.cwd() / "models" / "bi-encoders" / model_base_name).resolve()
model_archive_path = model_folder / "model.tar.gz"
code_archive_path = model_folder / "sourcedir.tar.gz"

# Working directory the notebook was started from.
current_dir = os.getcwd()
print(current_dir)

In [None]:
# The model folder also contains the hand-written `code/inference.py`, so the
# folder existing does not mean the weights were downloaded. Check for the
# actual weight/config files before deciding to skip the download.
required_model_files = ("config.json", "pytorch_model.bin")
missing_model_files = [
    name for name in required_model_files if not model_folder.joinpath(name).exists()
]
if missing_model_files:
    print(f"Downloading {HF_MODEL_ID} to: {model_folder}")
    snapshot_download(HF_MODEL_ID, local_dir=str(model_folder), local_dir_use_symlinks=False)
else:
    print(f"Model {HF_MODEL_ID} exists at: {str(model_folder)}")

### Create Model

- Compress model artifacts to `model.tar.gz`
- Upload model to S3
- Create Model object


In [None]:
# Files and sub-folders from the Hugging Face snapshot that are packed into
# model.tar.gz (weights, tokenizer files and sentence-transformers configuration).
model_files_to_compress = """
pytorch_model.bin
spiece.model
config.json
tokenizer.json
tokenizer_config.json
special_tokens_map.json
modules.json
sentence_bert_config.json
config_sentence_transformers.json
1_Pooling
2_Dense
""".split()

In [None]:
import shutil

# Build model.tar.gz from the files listed above. tar runs with cwd=model_folder,
# so archive members are stored relative to the model folder (no directory
# prefix), and the notebook's own working directory is never changed.
model_archive_path = model_folder.joinpath("model.tar.gz")
# Always rebuild so the archive reflects the current model files.
if model_archive_path.exists():
    model_archive_path.unlink()

missing_files = [
    name for name in model_files_to_compress if not model_folder.joinpath(name).exists()
]
if missing_files:
    raise FileNotFoundError(f"Missing model files in {model_folder}: {missing_files}")

# pigz is a parallel gzip; fall back to tar's built-in gzip (-z) when it is not installed.
compress_flag = "--use-compress-program=pigz" if shutil.which("pigz") else "-z"

print(f"Compressing model files in {model_folder} to model.tar.gz...")
# Argument list (no shell) so file names are passed verbatim; check=True raises
# subprocess.CalledProcessError on a non-zero exit status.
subprocess.run(
    ["tar", "-cf", model_archive_path.name, compress_flag, *model_files_to_compress],
    cwd=str(model_folder),
    check=True,
)
print("Successfully compressed model files")

In [None]:
# Upload model artifact to S3
s3_suffix = f"models/txt-embedding-models/{model_base_name}"
upload_path_s3 = s3_path_join(f"s3://{bucket_name}", s3_suffix)
print(f"Uploading model ...")
model_data_url = S3Uploader.upload(
    local_path=str(model_archive_path),
    desired_s3_uri=upload_path_s3,
    sagemaker_session=session,
)
print(f"Model Data URL: {model_data_url}")

In [None]:
# Random 5-hex-char id plus a date tag (e.g. 02Aug2023) keeps model/endpoint names unique.
short_id = uuid4().hex[:5]
date_tag = datetime.now().strftime("%d%b%Y")
suffix = f"{short_id}-{date_tag}"
model_name = "-".join((model_base_name, suffix))
print(f"Model name: [b]{model_name}[/b]")

Create a `HuggingFaceModel` with the model data and a custom `inference.py` script.

https://sagemaker.readthedocs.io/en/stable/frameworks/huggingface/sagemaker.huggingface.html#hugging-face-model

**NOTE:** If you specify the `entry_point=` parameter in `HuggingFaceModel`, the model artifacts are repacked together with the inference code and uploaded to the root of the default SageMaker S3 bucket. If it is omitted, the model is created directly from the artifacts at `model_data_url` in S3.


In [None]:
# Show the working directory the notebook started in (captured with os.getcwd() in the setup cell).
print(current_dir)

In [None]:
print(f"Creating model: {model_name}")

# A relative `source_dir` is resolved against the current working directory,
# but the custom inference script lives inside the model folder
# (models/bi-encoders/<model>/code). Prefer that absolute path and fall back
# to the bare "code" directory only if the script is not found there.
inference_code_dir = model_folder.joinpath("code")
source_dir = (
    str(inference_code_dir)
    if inference_code_dir.joinpath("inference.py").exists()
    else "code"
)

txt_embed_model = HuggingFaceModel(
    model_data=model_data_url,
    role=role,
    transformers_version="4.26.0",
    source_dir=source_dir,
    entry_point="inference.py",
    pytorch_version="1.13.1",
    sagemaker_session=session,
    py_version="py39",
    name=model_name,
)

### Deploy Model to Serverless endpoint


In [None]:
from sagemaker.serverless import ServerlessInferenceConfig

# Serverless sizing: memory is given in MB (memory_size_in_mb), so 4096 = 4 GB;
# max_concurrency limits concurrent invocations of this endpoint.
memory = 4096
max_concurrency = 10
endpoint_name = model_name
serverless_config = ServerlessInferenceConfig(
    memory_size_in_mb=memory, max_concurrency=max_concurrency
)

# Container environment variables are part of the model definition (Model.env),
# not a deploy() argument, so attach the log level to the model before deploying.
txt_embed_model.env.update({"SAGEMAKER_CONTAINER_LOG_LEVEL": "20"})

print(f"Creating endpoint: [b]{endpoint_name}[/b] ...")

# Returns a HuggingFacePredictor; wait=False returns without waiting for InService.
predictor = txt_embed_model.deploy(
    endpoint_name=endpoint_name,
    serverless_inference_config=serverless_config,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
    wait=False,
)

### Wait for the endpoint to reach the `InService` state


In [None]:
def _endpoint_status(name):
    """Return the EndpointStatus string reported by describe_endpoint for `name`."""
    return sm_client.describe_endpoint(EndpointName=name)["EndpointStatus"]


status = _endpoint_status(endpoint_name)
print(f"Endpoint [b]{endpoint_name}[/b] Status: [i]{status}[/i]")

# Block until the endpoint is InService, using boto3's built-in waiter.
waiter = sm_client.get_waiter("endpoint_in_service")
waiter.wait(EndpointName=endpoint_name)

# Re-read the status now that the waiter has returned.
status = _endpoint_status(endpoint_name)
print(f"Endpoint [b]{endpoint_name}[/b] Status: [i]{status}[/i]")

### Deploy to real-time endpoint (Optional)

Uncomment the code below to deploy the model to a real-time endpoint instead. Define `instance_type` and `instance_count` first.


In [None]:
# Real-time (instance-backed) alternative to the serverless deployment.
# `instance_type` and `instance_count` are not defined elsewhere in this notebook;
# set them before uncommenting, e.g.:
# instance_type = "ml.m5.xlarge"  # example value only; choose an instance type for your workload
# instance_count = 1
# endpoint_name = model_name

# predictor = txt_embed_model.deploy(
#     instance_type=instance_type,
#     initial_instance_count=instance_count,
#     endpoint_name=endpoint_name,
#     serializer=JSONSerializer(),
#     deserializer=JSONDeserializer(),
#     wait=False,
# )

### Predict

**NOTE:**

Inputs must contain an "instruction" that describes the embedding task.

For example, for information retrieval:

```python

import numpy as np
from InstructorEmbedding import INSTRUCTOR
from sklearn.metrics.pairwise import cosine_similarity

model = INSTRUCTOR('hkunlp/instructor-base')

query  = [['Represent the question for retrieving documents: ','Which regions is Amazon SageMaker available?']]
corpus = [['Represent the document for retrieval: ','Amazon SageMaker is a fully managed service to prepare data and build, train, and deploy machine learning (ML) models for any use case with fully managed infrastructure, tools, and workflows. For a list of the supported Amazon SageMaker AWS Regions, please visit the AWS Regional Services page. Also, for more information, see Regional endpoints in the AWS general reference guide.'],
          ['Represent the document for retrieval: ',"The disparate impact theory is especially controversial under the Fair Housing Act because the Act regulates many activities relating to housing, insurance, and mortgage loans—and some scholars have argued that the theory's use under the Fair Housing Act, combined with extensions of the Community Reinvestment Act, contributed to rise of sub-prime lending and the crash of the U.S. housing market and ensuing global economic recession"],
          ['Represent the document for retrieval: ','Disparate impact in United States labor law refers to practices in employment, housing, and other areas that adversely affect one group of people of a protected characteristic more than another, even though rules applied by employers or landlords are formally neutral. Although the protected classes vary by statute, most federal civil rights laws protect based on race, color, religion, national origin, and sex as protected traits, and some laws include disability status and other traits as well.']]

query_embeddings = model.encode(query)
corpus_embeddings = model.encode(corpus)

similarities = cosine_similarity(query_embeddings,corpus_embeddings)

retrieved_doc_id = np.argmax(similarities)

print(retrieved_doc_id)
```

Ref: <https://huggingface.co/hkunlp/instructor-base>


### Uncomment the cell below to use an existing endpoint


In [None]:
# Attach a predictor to an already-deployed endpoint instead of deploying a new one.
# Replace the name below with your own endpoint's name before uncommenting.
# endpoint_name ="instructor-base-77ffb-02Aug2023"

# predictor = sagemaker.huggingface.HuggingFacePredictor(
#     endpoint_name=endpoint_name,
#     serializer=JSONSerializer(),
#     deserializer=JSONDeserializer(),
#     sagemaker_session=session
# )

In [None]:
print(f"Invoking endpoint: {endpoint_name}")

# Each input is an [instruction, text] pair; the instruction tells the model which
# task the embedding is for (query vs. document retrieval here).
query_pair = [
    "Represent the question for retrieving documents: ",
    "This is an example question sentence",
]
document_pair = ["Represent the document for retrieval: ", "This is an example corpus document"]
sentences = [query_pair, document_pair]

embeddings = predictor.predict(sentences)

first_embedding = embeddings[0]
print(f"Embedding dimensions: {len(first_embedding)}")

### Verify Logs emitted by the endpoint in CloudWatch


In [None]:
import time

from rich.markup import escape

# Use the notebook's SageMaker region so we read the log group that belongs to the endpoint.
logs_client = boto3.client("logs", region_name=region)

# CloudWatch Logs timestamps are milliseconds since the epoch; look back 10 minutes.
lookback_minutes = 10
current_time = int(time.time() * 1000)
ten_minutes_ago = current_time - (lookback_minutes * 60 * 1000)

log_group_name = f"/aws/sagemaker/Endpoints/{endpoint_name}"

# filter_log_events searches every stream in the group, and the paginator follows
# nextToken. describe_log_streams / get_log_events each return only one page per
# call, so they can silently miss streams or events.
paginator = logs_client.get_paginator("filter_log_events")
try:
    for page in paginator.paginate(
        logGroupName=log_group_name,
        startTime=ten_minutes_ago,
        endTime=current_time,
    ):
        for event in page["events"]:
            # `print` here is rich's; escape so "[...]" in log lines is not parsed as markup.
            print(escape(event["message"]))
except logs_client.exceptions.ResourceNotFoundException:
    print(f"Log group {log_group_name} does not exist yet - the endpoint may not have emitted logs.")

### Cleanup


In [None]:
# Uncomment to delete the SageMaker model and the endpoint created in this notebook.
# predictor.delete_model()
# predictor.delete_endpoint()