## Stable Diffusion deployment

How to use DJL
- DJL container list : https://docs.aws.amazon.com/sagemaker/latest/dg/large-model-inference-dlc.html
- DJL config list : https://docs.aws.amazon.com/sagemaker/latest/dg/large-model-inference-configuration.html
- DJL tutorial : https://docs.aws.amazon.com/sagemaker/latest/dg/large-model-inference-tutorials.html
- DJL default Stable Diffusion inference script : https://github.com/deepjavalibrary/djl-serving/blob/master/engines/python/setup/djl_python/stable-diffusion.py

### Additional optimization
Stable diffusion w/ DeepSpeed (1.4)
- https://www.philschmid.de/stable-diffusion-deepspeed-inference

A few ways to make it faster (this approach does not use diffusers)
- https://lightning.ai/pages/community/serve-stable-diffusion-three-times-faster/


### Containers used for deployment
- deepspeed: `763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117`
- fastertransformer: `763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-fastertransformer5.3.0-cu117`

In [None]:
%store -r

In [None]:
# Display the value of sd21_basic_model_path as the cell output.
# NOTE(review): presumably restored by `%store -r` from another notebook — confirm it is defined.
# sd20_basic_model_path
sd21_basic_model_path

In [None]:
# Display the value of sd21_model_artifact as the cell output.
# NOTE(review): presumably restored by `%store -r` from another notebook — confirm it is defined.
# sd20_model_artifact
sd21_model_artifact

In [None]:
import boto3
import sagemaker
from sagemaker.utils import name_from_base
from sagemaker import image_uris

In [None]:
# DJL engine backend. It is interpolated into the framework name ("djl-<engine>")
# used for the container image lookup; swap the commented line to change engine.
llm_engine = "deepspeed"
# llm_engine = "fastertransformer"

In [None]:
# Shared SageMaker handles used by the rest of the notebook.
sagemaker_session = sagemaker.Session()
# Execution role; passed as ExecutionRoleArn when the model is created.
role = sagemaker.get_execution_role()
# Client used for create_model / create_endpoint_config / create_endpoint / describe_endpoint.
sm_client = sagemaker_session.sagemaker_client
# Client used for invoke_endpoint.
sm_runtime_client = sagemaker_session.sagemaker_runtime_client

In [None]:
# Look up the DJL inference container image for the selected engine,
# in the region of the current SageMaker session, pinned to version 0.21.0.
framework_name = f"djl-{llm_engine}"
inference_image_uri = image_uris.retrieve(
    framework=framework_name,
    region=sagemaker_session.boto_session.region_name,
    version="0.21.0",
)

print(f"Inference container uri: {inference_image_uri}")

In [None]:
# S3 prefix inside the session's default bucket; the code tarball is copied here.
s3_target = f"s3://{sagemaker_session.default_bucket()}/llm/stable-diffusion/code/"
print(s3_target)

In [None]:
!rm -rf sd21-src.tar.gz
!tar zcvf sd21-src.tar.gz sd21-src --exclude ".ipynb_checkpoints" --exclude "__pycache__"
!aws s3 cp sd21-src.tar.gz {s3_target}

In [None]:
# Full S3 URI of the uploaded tarball; passed as ModelDataUrl when the model is created.
model_uri = f"{s3_target}sd21-src.tar.gz"
print(model_uri)

In [None]:
# Register a SageMaker model that pairs the DJL container image with the
# uploaded code tarball. name_from_base produces the model name from the base string.
# model_name = name_from_base("sd20-djl")
model_name = name_from_base("sd21-djl")
print(model_name)

primary_container = {
    "Image": inference_image_uri,
    "ModelDataUrl": model_uri,
}
create_model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer=primary_container,
)
model_arn = create_model_response["ModelArn"]

print(f"Created Model: {model_arn}")

In [None]:
# Endpoint configuration: one production variant serving the model on a single instance.
# Alternative instance types (uncomment one to switch):
# instance_type = "ml.g5.2xlarge"
# instance_type = "ml.g5.xlarge"
instance_type = "ml.g4dn.xlarge"

# Both names are derived from the model name so the resources are easy to correlate.
endpoint_config_name = f"{model_name}-config"
endpoint_name = f"{model_name}-endpoint"

production_variant = {
    "VariantName": "variant1",
    "ModelName": model_name,
    "InstanceType": instance_type,
    "InitialInstanceCount": 1,
    # Container startup health-check timeout, in seconds.
    "ContainerStartupHealthCheckTimeoutInSeconds": 600,
}
endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[production_variant],
)
print(endpoint_config_response)

In [None]:
# Request creation of the endpoint from the endpoint configuration and print its ARN.
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)
print(f"Created Endpoint: {create_endpoint_response['EndpointArn']}")

In [None]:
import time

# Poll the endpoint once a minute until it leaves the "Creating" state,
# printing the status after every describe call.
while True:
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)
    if status != "Creating":
        break
    time.sleep(60)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

# If the endpoint did not reach InService, show why instead of leaving only the
# bare status; FailureReason is only present in the response on failure.
if status != "InService":
    print("Failure reason: " + resp.get("FailureReason", "not reported"))

In [None]:
import json

In [None]:
# Text prompt sent to the endpoint by the invocation cell.
# prompt = "Sage are playing games with his pet, disney style"
prompt = "John snow from game of throne, disney style"

In [None]:
%%time
# Invoke the endpoint with a JSON body. The prompt is sent twice: as a
# one-element list under "text" and as a plain string under "prompt".
# "upload_s3_bucket" carries the session's default bucket name.
prompts = [prompt]
request_payload = {
    "text": prompts,
    "upload_s3_bucket": sagemaker_session.default_bucket(),
    "prompt": prompt,
}
response_model = sm_runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=json.dumps(request_payload),
    ContentType="application/json",
)

In [None]:
# # Direct return of image
# import io
# from PIL import Image
# img_output = response_model["Body"].read()
# data_io = io.BytesIO(img_output)
# img = Image.open(data_io)
# img

In [None]:
# Read the response body and decode it as UTF-8 text.
output = response_model["Body"].read().decode("utf-8")

In [None]:
# Show the decoded response text as the cell output.
output

In [None]:
import os
import boto3
from IPython.display import Image

# Module-level S3 client; view_s3_file uses it to download objects.
s3_client = boto3.client('s3')

def view_s3_file(s3_uri):
    """Download an S3 object into ./test-output and display it inline as an image.

    Args:
        s3_uri: Object location in the form ``s3://<bucket>/<key>``. The last
            path segment of the key is used as the local file name.

    Raises:
        ValueError: If ``s3_uri`` is not of the form ``s3://<bucket>/<key>``
            or the key has no file name (ends with ``/``).
    """
    scheme = "s3://"
    if not s3_uri.startswith(scheme):
        raise ValueError(f"Expected an s3://<bucket>/<key> URI, got: {s3_uri!r}")

    # Everything up to the first "/" after the scheme is the bucket; the rest is the key.
    bucket, _, object_name = s3_uri[len(scheme):].partition("/")
    filename = object_name.rsplit("/", 1)[-1]
    if not bucket or not filename:
        raise ValueError(f"Expected an s3://<bucket>/<key> URI, got: {s3_uri!r}")

    # download_file does not create missing parent directories, so make sure
    # the output directory exists before writing into it.
    local_dir = "./test-output"
    os.makedirs(local_dir, exist_ok=True)
    local_path = os.path.join(local_dir, filename)

    s3_client.download_file(bucket, object_name, local_path)
    display(Image(filename=local_path))
    


In [None]:
# Download and display the object named by the endpoint response.
# NOTE(review): assumes the response text is a bare s3:// URI — confirm against the inference script.
view_s3_file(output)