### References
- https://github.com/huggingface/notebooks/blob/main/sagemaker/19_serverless_inference/sagemaker-notebook.ipynb
- https://sagemaker.readthedocs.io/en/stable/frameworks/huggingface/sagemaker.huggingface.html

### Imports

In [12]:
import base64
import boto3
import json
import sagemaker
from sagemaker.huggingface import HuggingFaceModel
from sagemaker.pytorch import PyTorchModel, PyTorchPredictor
from sagemaker.serverless import ServerlessInferenceConfig
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

### Upload the Model Archives (tar.gz) to S3

In [2]:
# S3 bucket that holds the packaged model archives (uploaded under the "model/" prefix below).
MODEL_BUCKET_NAME = "smu-ils-model"
s3_resource = boto3.resource("s3")
aws_ils_model_bucket = s3_resource.Bucket(MODEL_BUCKET_NAME)

In [3]:
# Upload each locally packaged model archive to s3://smu-ils-model/model/<name>.tar.gz.
# `upload_file` goes through boto3's managed transfer, which switches to multipart
# uploads for large archives instead of a single PUT request (capped at 5 GB).
for model_name in ("clip_image", "clip_text"):
    aws_ils_model_bucket.upload_file(
        Filename=f"../model/{model_name}/{model_name}.tar.gz",
        Key=f"model/{model_name}.tar.gz",
    )

### Serverless Config and Role

In [4]:
# Serverless endpoint sizing: only memory is set per model.
# max_concurrency and provisioned_concurrency are left at their SDK defaults
# (concurrency 5, no provisioned capacity). Cold starts are accepted to save costs.
image_serverless_config = ServerlessInferenceConfig(memory_size_in_mb=4096)
text_serverless_config = ServerlessInferenceConfig(memory_size_in_mb=3072)

In [None]:
# Use an existing IAM role as the SageMaker execution role; the models only need its ARN.
# NOTE(review): presumably this role already trusts sagemaker.amazonaws.com and can
# read the model bucket — confirm in IAM.
iam = boto3.client("iam")
role_description = iam.get_role(RoleName="aws-elasticbeanstalk-ec2-role")
role = role_description["Role"]["Arn"]

### Models

In [6]:
# Both CLIP halves live under the same S3 prefix and use the same role,
# PyTorch framework version, and Python version.
model_bucket = "smu-ils-model/model"
shared_model_kwargs = dict(role=role, framework_version="2.0.0", py_version="py310")

image_model = PyTorchModel(
    model_data=f"s3://{model_bucket}/clip_image.tar.gz",
    **shared_model_kwargs,
)
text_model = PyTorchModel(
    model_data=f"s3://{model_bucket}/clip_text.tar.gz",
    **shared_model_kwargs,
)

### One-time Deployment

In [7]:
def _deploy_serverless(model, serverless_config):
    """Deploy `model` to a serverless endpoint whose requests and responses are JSON."""
    return model.deploy(
        serverless_inference_config=serverless_config,
        serializer=JSONSerializer(),
        deserializer=JSONDeserializer(),
    )


# Each model gets its own serverless endpoint.
image_predictor = _deploy_serverless(image_model, image_serverless_config)
text_predictor = _deploy_serverless(text_model, text_serverless_config)

----!

### Inference

In [8]:
# Send one ImageNet "tench" test image to the image endpoint as a base64 string.
sample_image_path = "../data/imagenet/tench/test/n01440764_1383.JPEG"
with open(sample_image_path, "rb") as f:
    image_bytes = f.read()
image_b64 = base64.b64encode(image_bytes).decode("utf-8")
image_data = dict(img_b64=image_b64)
image_features = image_predictor.predict(data=image_data)
len(image_features)

768

In [None]:
# Query the text endpoint with a single class name.
# NOTE(review): the "class_name" key presumably must match what the deployed
# inference script reads — confirm against that script.
text_data = dict(class_name="dog")
text_features = text_predictor.predict(data=text_data)
len(text_features)

768


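### Inference via Predictors Attached to Existing Endpoints

Reattach to the already-deployed endpoints by name (e.g. in a new kernel session) instead of redeploying the models.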

In [9]:
# Show the auto-generated endpoint names so they can be pasted into the reattach cells.
image_model_endpoint_name = image_predictor.endpoint_name
text_model_endpoint_name = text_predictor.endpoint_name
print(image_model_endpoint_name, text_model_endpoint_name, sep="\n")

pytorch-inference-2024-01-14-08-45-53-556
pytorch-inference-2024-01-10-09-34-42-078


In [10]:
# Reattach to the already-deployed image endpoint by name, without redeploying.
image_model_endpoint_name = "pytorch-inference-2024-01-14-08-45-53-556"
loaded_image_predictor = PyTorchPredictor(
    endpoint_name=image_model_endpoint_name,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

with open("../data/imagenet/tench/test/n01440764_1383.JPEG", "rb") as f:
    image_bytes = f.read()
image_b64 = base64.b64encode(image_bytes).decode("utf-8")
image_data = {"img_b64": image_b64}
# Query through the reattached predictor. Calling `image_predictor` here would
# bypass the object this cell just built, and would raise NameError in a session
# where the deploy cell was never run.
image_features = loaded_image_predictor.predict(image_data)
len(image_features)

768

In [11]:
# Reattach to the already-deployed text endpoint by name and query it with one class name.
text_model_endpoint_name = "pytorch-inference-2024-01-10-09-34-42-078"
json_codec = dict(serializer=JSONSerializer(), deserializer=JSONDeserializer())
loaded_text_predictor = PyTorchPredictor(endpoint_name=text_model_endpoint_name, **json_codec)

text_data = dict(class_name="tench")
text_features = loaded_text_predictor.predict(data=text_data)
len(text_features)

768

### Delete (Optional)

In [None]:
# Tear down the endpoints and models once they are no longer needed.
# Kept commented out, so running the notebook top to bottom leaves the endpoints in place.
# NOTE(review): in a session that only reattached to existing endpoints, use
# loaded_image_predictor / loaded_text_predictor, because image_predictor and
# text_predictor (and the *_model objects) only exist if the deploy cells ran.
# image_predictor.delete_endpoint()
# text_predictor.delete_endpoint()

# image_model.delete_model()
# text_model.delete_model()