# Introduction to the SageMaker Large Model Inference (LMI) Container 
### Deploy a Stable Diffusion model on a SageMaker multi-model endpoint with LMI

In this notebook, we explore how to host multiple Stable Diffusion models behind a Multi-Model Endpoint on SageMaker using the [Large Model Inference](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-large-model-inference.html) container. This container is optimized for hosting large models using DJLServing. DJLServing is a high-performance universal model serving solution powered by the Deep Java Library (DJL) that is programming language agnostic. To learn more about DJL and DJLServing, you can refer to our recent [blog post](https://aws.amazon.com/blogs/machine-learning/deploy-large-models-on-amazon-sagemaker-using-djlserving-and-deepspeed-model-parallel-inference/).

This notebook was tested on an `ml.g5.xlarge` SageMaker Notebook instance using the `conda_pytorch_p39` kernel.

## Save a pre-trained Stable Diffusion model to deploy
As a first step, we import the relevant libraries and configure several global variables, such as the hosting image to use and the S3 location of our model artifacts. We also download the model weights with the [Diffusers](https://huggingface.co/docs/diffusers/index) library and save them to S3.

In [None]:
!pip install sagemaker boto3 diffusers --upgrade  --quiet

In [None]:
import sagemaker
from sagemaker.model import Model
from sagemaker import serializers, deserializers
from sagemaker import image_uris
import boto3
import os
import time
import json
import jinja2
from pathlib import Path

In [None]:
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # SageMaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house the code artifacts
model_bucket = sess.default_bucket()  # bucket for model artifacts (the same default bucket here)
s3_code_prefix = "large-model-djl-sd/code"  # folder within bucket where code artifact will go
s3_model_prefix = "hf-large-model-djl-sd/model"  # folder where model checkpoint will go

region = sess.boto_region_name  # region name of the current SageMaker session
account_id = sess.account_id()  # account ID of the current SageMaker session

s3_client = boto3.client("s3")  # client to interact with the S3 API
sm_client = boto3.client("sagemaker")  # client to interact with SageMaker
smr_client = boto3.client("sagemaker-runtime")  # client to interact with SageMaker endpoints
jinja_env = jinja2.Environment()  # jinja environment to generate model configuration templates

We get the container image for the Large Model Inference container.

In [None]:
inference_image_uri = image_uris.retrieve(
    framework="djl-deepspeed",
    region=sess.boto_session.region_name,
    version="0.21.0",
)

At the time of writing, the latest version of the LMI container does not include PyTorch 2.0, which adds several useful optimization features. We therefore extend the public container and push the result to ECR so that we can accelerate the diffusion model's inference.

In [None]:
new_image_name = 'extended-djl-deepspeed'

In [None]:
# Pulling public DLCs often fails with missing ECR permission errors,
# so we first log in to the ECR registry that hosts the public image.
public_acct_number = inference_image_uri.split('.')[0]
!aws ecr get-login-password --region "$region" | docker login --username AWS --password-stdin "$public_acct_number".dkr.ecr."$region".amazonaws.com
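The cell above takes the registry account number from the first dot-separated field of the image URI. As a minimal sketch of how such a URI breaks down, the helper below splits it into account, region, repository, and tag. The name `parse_ecr_uri` and the example URI are illustrative assumptions, not part of the SageMaker SDK.

```python
import re
from typing import NamedTuple


class EcrImage(NamedTuple):
    account_id: str
    region: str
    repository: str
    tag: str


# ECR image URIs follow <account>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>
_ECR_URI = re.compile(
    r"^(?P<account_id>\d{12})\.dkr\.ecr\.(?P<region>[a-z0-9-]+)\.amazonaws\.com(?:\.cn)?"
    r"/(?P<repository>[^:]+):(?P<tag>.+)$"
)


def parse_ecr_uri(uri: str) -> EcrImage:
    """Split an ECR image URI into its parts; raise ValueError if it does not match."""
    match = _ECR_URI.match(uri.strip())
    if match is None:
        raise ValueError(f"Not a recognised ECR image URI: {uri!r}")
    return EcrImage(**match.groupdict())


# Illustrative URI only: the account ID and tag are placeholders.
example = parse_ecr_uri(
    "123456789012.dkr.ecr.us-east-1.amazonaws.com/djl-inference:example-tag"
)
print(example.account_id, example.region, example.repository, example.tag)
```

Validating the full shape, rather than only splitting on the first dot, surfaces a malformed URI before it reaches `docker login`.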

In [None]:
%%capture build_output
!cd docker && bash build_and_push.sh "$new_image_name" 0.21.0 "$inference_image_uri" "$region"

The next cell checks whether the container build succeeded and, if so, retrieves the URI of our new container image.

In [45]:
if 'Error response from daemon' in str(build_output):
    print(build_output)
    raise SystemExit('\n\n!!There was an error with the container build!!')
else:
    extended_djl_image_uri = str(build_output).strip().split('\n')[-1]

We define a base URL location for our pretrained model artifacts.

In [None]:
pretrained_model_location = f's3://{bucket}/{s3_model_prefix}/sd'

Grab a pre-trained Stable Diffusion v1.4 model in fp16 precision (faster inference than fp32).

In [None]:
import diffusers
import torch

pipeline = diffusers.DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    cache_dir="hf_cache",
    torch_dtype=torch.float16,
    revision="fp16",
)

Save the model locally.

In [None]:
pipeline.save_pretrained('stable_diff')

Now we simulate having several different fine-tuned Stable Diffusion models to serve by copying the same model artifacts to several different S3 locations.

In [None]:
n_models = 4

for i in range(n_models):
    str_idx = str(i)
    !aws s3 cp --recursive stable_diff/ "$pretrained_model_location"/sd-"$str_idx"

Check out all your models on S3.

In [None]:
!aws s3 ls "$pretrained_model_location"/

## Deploying Stable Diffusion using HF Accelerate


First, we create a `serving.properties` file for each model. It specifies which LMI backend to use, backend-specific settings (see the full list [here](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-large-model-configuration.html)), the name of the inference handler (a custom script must be packaged in the same directory as the properties file), and the S3 URL of the corresponding model artifacts.
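To make the shape of such a file concrete, here is a minimal sketch that renders and re-parses a properties file. The template contents are an assumption (the real `accelerate_src/serving.template` may set other options), the bucket name is a placeholder, and `string.Template` from the standard library stands in for jinja2 so the sketch runs without extra dependencies.

```python
from string import Template

# Hypothetical template: only `engine` and `option.s3url` are shown here.
SERVING_TEMPLATE = Template(
    "engine=Python\n"
    "option.s3url=$s3url\n"
)


def render_serving_properties(s3url: str) -> str:
    """Fill in the S3 location of one model's uncompressed artifacts."""
    return SERVING_TEMPLATE.substitute(s3url=s3url)


def parse_properties(text: str) -> dict:
    """Parse simple key=value lines, skipping blank lines and # comments."""
    pairs = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        key, _, value = line.partition("=")
        pairs[key.strip()] = value.strip()
    return pairs


# Placeholder bucket name; each model gets its own prefix.
rendered = render_serving_properties("s3://example-bucket/hf-large-model-djl-sd/model/sd/sd-0/")
print(rendered)
```

Parsing the rendered text back is a cheap way to confirm that each model's file points at its own prefix before it is packaged.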

We then tar each code directory (properties file plus inference script) and upload it to S3. With other containers, SageMaker downloads a .tar.gz file containing both your model artifact and your inference code. With the LMI container, MME works a bit differently: the archive that SageMaker downloads automatically is the code artifact .tar.gz. Inside it, the `serving.properties` file holds the S3 URL of the uncompressed model artifacts, and the container (more precisely, the serving software inside it) is responsible for downloading them.
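The layout of one code archive can be sketched with the standard library alone. The helper name `package_code_dir` and the file contents below are placeholders for illustration; only the directory and file names follow the notebook.

```python
import tarfile
import tempfile
from pathlib import Path


def package_code_dir(code_dir: Path, archive_path: Path) -> list:
    """Create a gzip-compressed tarball of code_dir and return its member names."""
    with tarfile.open(archive_path, "w:gz") as tar:
        # arcname keeps the directory itself as the top-level entry,
        # similar to `tar czvf archive.tar.gz accelerate_src/`.
        tar.add(code_dir, arcname=code_dir.name)
    with tarfile.open(archive_path, "r:gz") as tar:
        return sorted(tar.getnames())


with tempfile.TemporaryDirectory() as tmp:
    code_dir = Path(tmp) / "accelerate_src"
    code_dir.mkdir()
    # Placeholder contents: the archive carries configuration and code, not weights.
    (code_dir / "serving.properties").write_text(
        "engine=Python\noption.s3url=s3://example-bucket/sd-0/\n"
    )
    (code_dir / "model.py").write_text("# inference handler placeholder\n")
    members = package_code_dir(code_dir, Path(tmp) / "acc_model0.tar.gz")
    print(members)
```

Because the weights stay in S3, each archive is only a few kilobytes, which keeps the per-model download that SageMaker performs small.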


In [None]:
# Read the template once; only the S3 URL changes between models
template = jinja_env.from_string(Path("accelerate_src/serving.template").read_text())

for i in range(n_models):
    # Render the properties file for model i, pointing at its own S3 prefix
    Path("accelerate_src/serving.properties").write_text(
        template.render(s3url=f"{pretrained_model_location}/sd-{i}/")
    )
    !pygmentize accelerate_src/serving.properties | cat -n

    # Package the properties file and inference script, then upload the archive
    !tar czvf acc_model"$i".tar.gz accelerate_src/

    ds_s3_code_artifact = sess.upload_data(f"acc_model{i}.tar.gz", bucket, s3_code_prefix)
    print(f"S3 Code or Model tar ball uploaded to --- > {ds_s3_code_artifact}")

Notice that for each configuration file, we don't specify an [option.entryPoint](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-large-model-configuration.html#realtime-endpoints-large-model-configuration-general); by default, the server looks for a file located in the same directory as the properties file, and that file must be named exactly `model.py`.

The inference code will be the same for all of the deployed models. You can understand the required structure for the inference handler [here](https://docs.djl.ai/docs/serving/serving/docs/modes.html#python-mode), and see other example inference handlers [here](https://github.com/deepjavalibrary/djl-serving/tree/master/engines/python/setup/djl_python).

In [None]:
!pygmentize accelerate_src/model.py

## SageMaker multi-model endpoint

Having uploaded all our artifacts to S3, we will now deploy a SageMaker multi-model endpoint.

In [42]:
from sagemaker.multidatamodel import MultiDataModel

Create a MultiDataModel, deploy it, and instantiate a Predictor to run predictions against.

In [43]:
ds_endpoint_name = sagemaker.utils.name_from_base("lmi-mme-test")

model = MultiDataModel(
    ds_endpoint_name,
    # With LMI, this prefix holds the code tar.gz files, not the model artifacts
    f"s3://{bucket}/{s3_code_prefix}/",
    # Use the extended image built earlier so that PyTorch 2.0 is available
    image_uri=extended_djl_image_uri,
    role=role,
)

model.deploy(initial_instance_count=1,
             instance_type="ml.g5.xlarge",
             endpoint_name=ds_endpoint_name)

# Requests are sent as JSON and responses come back as PNG bytes, so we set the serializer and deserializer accordingly
predictor = sagemaker.Predictor(
    endpoint_name=ds_endpoint_name,
    sagemaker_session=sess,
    serializer=serializers.JSONSerializer(),
    deserializer=deserializers.BytesDeserializer(),
)


Invoke each model one time.

In [None]:
for i in range(n_models):
    start = time.time()
    response = predictor.predict(
        {"prompt": "astronaut spending money at a luxury store"},
        target_model=f"acc_model{i}.tar.gz",
    )

    print(f"Took {time.time() - start} seconds")
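The same request can also be sent through the low-level `sagemaker-runtime` client created earlier. The sketch below only builds the keyword arguments for `invoke_endpoint`, so it runs without AWS credentials; the helper name `build_invoke_args` and the endpoint name are illustrative assumptions.

```python
import json


def build_invoke_args(endpoint_name: str, model_index: int, prompt: str) -> dict:
    """Keyword arguments for sagemaker-runtime invoke_endpoint, targeting one archive."""
    return {
        "EndpointName": endpoint_name,
        # TargetModel is the archive path relative to the endpoint's S3 prefix
        "TargetModel": f"acc_model{model_index}.tar.gz",
        "ContentType": "application/json",
        "Body": json.dumps({"prompt": prompt}),
    }


# Placeholder endpoint name for illustration.
args = build_invoke_args(
    "lmi-mme-test-example", 0, "astronaut spending money at a luxury store"
)
print(args["TargetModel"])

# With AWS credentials and a live endpoint, the call would look like:
#   response = boto3.client("sagemaker-runtime").invoke_endpoint(**args)
#   png_bytes = response["Body"].read()
```

`TargetModel` plays the same role as the `target_model` argument of `predictor.predict`: it selects which archive under the endpoint's S3 prefix serves the request.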

Decode the last response and show the image corresponding to it.

In [None]:
from io import BytesIO
from PIL import Image

stream = BytesIO(response)  # PNG bytes returned by the last invocation
image = Image.open(stream).convert("RGBA")
stream.close()
image.show()

Shut down the endpoint.

In [None]:
predictor.delete_endpoint()