### SageMaker Stable Diffusion Quick Kit - Inference Deployment (Stable Diffusion XL, SDXL LoRA)
   [SageMaker Stable Diffusion Quick Kit](https://github.com/aws-samples/sagemaker-stablediffusion-quick-kit) provides a set of ready-to-use code and configuration files that help customers quickly build Stable Diffusion AI image generation services on AWS using Amazon SageMaker, AWS Lambda, and Amazon CloudFront.
   
   ![Architecture](https://raw.githubusercontent.com/aws-samples/sagemaker-stablediffusion-quick-kit/main/images/architecture.png)


#### Prerequisites
1. Amazon AWS Account
2. Recommended instance type: ml.g5.xlarge (note that the endpoint deployed below uses ml.g5.2xlarge)

### Notebook Deployment Steps
1. Upgrade boto3, sagemaker python sdk
2. Build docker image
3. Deploy AIGC inference service
    * Configure model parameters
    * Configure asynchronous inference
    * Deploy SageMaker Endpoint
4. Test the endpoint (SDXL, SDXL refiner, LoRA and ControlNet)
5. Clean up resources

### 1. Upgrade boto3, sagemaker python sdk

In [None]:
# Upgrade boto3 and the SageMaker Python SDK used by the cells below.
!pip install --upgrade boto3 sagemaker

In [None]:
#Import corresponding libraries

import re
import os
import json
import uuid

import numpy as np
import pandas as pd
from time import gmtime, strftime


import boto3
import sagemaker

from sagemaker import get_execution_role,session

# Resolve the IAM execution role, the default SageMaker S3 bucket and the AWS
# region that the rest of the notebook works with.
role = get_execution_role()

sage_session = session.Session()
bucket = sage_session.default_bucket()
aws_region = boto3.Session().region_name

print(
    f'sagemaker sdk version: {sagemaker.__version__}'
    f'\nrole:  {role}  '
    f'\nbucket:  {bucket}'
)


### 2. Build docker image (sdxl-inference-v2)

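Use a local environment (or any other machine with Docker) to build the container image and push it to Amazon ECR with the `build_push.sh` script:
- Check the region defined in the script; it must match the region this notebook runs in, because the image URI in step 3.2 is built from the current region
- Keep the `pip install` commands split into separate steps to avoid long dependency-resolution times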

In [None]:
# !./build_push.sh

### 3. Deploy AIGC Inference Service

#### 3.1 Create dummy model_data file (actual model is loaded using code/inference.py)

In [None]:
# Build a placeholder model archive. It only provides the ModelDataUrl for the
# SageMaker model; the actual model is loaded by code/inference.py.
# NOTE(review): assumes sagemaker-logo-small.png exists in the working directory.
!touch dummy
!tar czvf model.tar.gz dummy sagemaker-logo-small.png
assets_dir = 's3://{0}/{1}/assets/'.format(bucket, 'stablediffusion')
model_data = 's3://{0}/{1}/assets/model.tar.gz'.format(bucket, 'stablediffusion')
# Upload the archive to the assets prefix, then remove the local copies.
!aws s3 cp model.tar.gz $assets_dir
!rm -f dummy model.tar.gz

#### 3.2 Create model configuration

In [None]:

boto3_session = boto3.session.Session()
current_region = boto3_session.region_name

# Look up the AWS account id; it is part of the ECR image URI.
sts_client = boto3.client("sts")
account_id = sts_client.get_caller_identity()["Account"]

# `client` is the SageMaker control-plane client used by the following cells.
client = boto3.client('sagemaker')

# Image built and pushed in step 2 (default repository name: sdxl-inference-v2).
container = f'{account_id}.dkr.ecr.{current_region}.amazonaws.com/sdxl-inference-v2'
model_data = f's3://{bucket}/stablediffusion/assets/model.tar.gz'

model_name = 'AIGC-Quick-Kit-' + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
role = get_execution_role()

# Environment passed to the container: the S3 bucket and the Hugging Face model
# id to load (SDXL 1.0 base) - presumably consumed by code/inference.py.
container_env = {
    's3_bucket': bucket,
    'model_name': 'stabilityai/stable-diffusion-xl-base-1.0',
}
primary_container = {
    'Image': container,
    'ModelDataUrl': model_data,
    'Environment': container_env,
}

create_model_response = client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer=primary_container,
)

In [None]:
# Endpoint configuration: one ml.g5.2xlarge variant serving the model created
# above, with asynchronous inference enabled.
_time_tag = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
_variant_name = 'AIGC-Quick-Kit-' + _time_tag
endpoint_config_name = 'AIGC-Quick-Kit-' + _time_tag

# Successful async responses are written under S3OutputPath. S3FailurePath also
# persists error responses, so a failed request can be inspected in S3 instead
# of looking like one that is still running.
async_output_prefix = f's3://{bucket}/stablediffusion/asyncinvoke'

response = client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            'VariantName': _variant_name,
            'ModelName': model_name,
            'InitialInstanceCount': 1,
            'InstanceType': 'ml.g5.2xlarge',
            'InitialVariantWeight': 1,
        },
    ],
    AsyncInferenceConfig={
        'OutputConfig': {
            'S3OutputPath': f'{async_output_prefix}/out/',
            'S3FailurePath': f'{async_output_prefix}/failure/',
        }
    },
)

#### Create test helper functions

In [None]:
# Standard-library and Pillow imports used by the helper functions in this cell.
# NOTE(review): traceback does not appear to be used in this cell.
import time
import uuid
import io
import traceback
from PIL import Image


# Module-level S3 resource shared by the helpers (e.g. draw_image) to read objects.
s3_resource = boto3.resource('s3')

def get_bucket_and_key(s3uri):
    """Split an ``s3://bucket/key`` URI into its bucket and key parts.

    Args:
        s3uri: URI of the form ``s3://<bucket>/<key>``. The first five
            characters (the ``s3://`` scheme) are skipped without validation.

    Returns:
        A ``(bucket, key)`` tuple. ``key`` is ``''`` when the URI names only a
        bucket (``s3://bucket`` or ``s3://bucket/``).
    """
    # Skip the 5-character "s3://" scheme and split once on the first "/".
    # partition() yields an empty key when there is no "/", whereas the old
    # find()-based slicing returned a truncated bucket and the whole URI as key.
    bucket, _, key = s3uri[5:].partition('/')
    return bucket, key


def predict_async(endpoint_name, payload, timeout=600):
    """Submit ``payload`` to an async SageMaker endpoint and display the result.

    The payload is serialized to JSON and uploaded to
    ``s3://{bucket}/stablediffusion/asyncinvoke/input/<uuid>.json`` (``bucket`` is
    the module-level default bucket). The upload is passed to
    ``invoke_endpoint_async``, and the returned output location is polled with
    ``wait_async_result``.

    Args:
        endpoint_name: Name of the SageMaker endpoint to invoke.
        payload: JSON-serializable request body for the inference container.
        timeout: Seconds to wait for the async result. The default is generous
            because the first requests may have to download models first.

    Returns:
        The ``OutputLocation`` reported by SageMaker ('' if none was returned).
    """
    runtime_client = boto3.client('runtime.sagemaker')
    input_key = f'stablediffusion/asyncinvoke/input/{uuid.uuid4()}.json'
    input_location = f's3://{bucket}/{input_key}'

    # Reuse the module-level S3 resource instead of building a new one per call.
    s3_resource.Object(bucket, input_key).put(Body=json.dumps(payload).encode('utf-8'))
    print(f'input_location: {input_location}')

    response = runtime_client.invoke_endpoint_async(
        EndpointName=endpoint_name,
        InputLocation=input_location,
    )
    output_location = response.get("OutputLocation", '')
    # FailureLocation is only returned when the endpoint config sets S3FailurePath.
    failure_location = response.get("FailureLocation")
    if failure_location:
        print(f'failure_location: {failure_location}')
    if not output_location:
        # Polling an empty location would only produce confusing S3 errors.
        print('invoke_endpoint_async returned no OutputLocation')
        return output_location
    wait_async_result(output_location, timeout=timeout)
    return output_location


def s3_object_exists(s3_path):
    """Return True if the S3 object at ``s3_path`` exists, False otherwise.

    Used to poll for async-inference output. A "not found" response from
    ``head_object`` means the result has not been written yet. Other client
    errors (e.g. access denied) are printed instead of being reported as
    "not completed", so they are not mistaken for a job that is still running.
    """
    s3 = boto3.client('s3')
    bucket, key = get_bucket_and_key(s3_path)
    try:
        s3.head_object(Bucket=bucket, Key=key)
    except s3.exceptions.ClientError as ex:
        error_code = ex.response.get('Error', {}).get('Code', '')
        if error_code in ('404', 'NoSuchKey', 'NotFound'):
            print("job is not completed, waiting...")
        else:
            print(f"could not check {s3_path}: {ex}")
        return False
    return True
    
def draw_image(output_location):
    """Load an async-inference result from S3 and display its images.

    The object at ``output_location`` is parsed as JSON; its ``result`` entry is
    printed and iterated as a list of S3 URIs. Each image is downloaded, scaled
    to 50% and displayed with ``Image.show()``. Errors are printed, not raised,
    so one bad result does not abort the notebook.
    """
    try:
        out_bucket, out_key = get_bucket_and_key(output_location)
        body = s3_resource.Object(out_bucket, out_key).get()['Body'].read().decode('utf-8')
        predictions = json.loads(body)
        image_uris = predictions['result']
        print(image_uris)
        for image_uri in image_uris:
            img_bucket, img_key = get_bucket_and_key(image_uri)
            image_bytes = s3_resource.Object(img_bucket, img_key).get()['Body'].read()
            image = Image.open(io.BytesIO(image_bytes))
            # Resize to 50% of the original size for display.
            half = 0.5
            out_image = image.resize(tuple(int(half * s) for s in image.size))
            out_image.show()
    except Exception as e:
        # This runs after the output object exists, so a failure here is a real
        # error (bad JSON, missing 'result', S3 access, ...), not "still running"
        # as the previous message claimed.
        print(f"failed to load result from {output_location}: {e!r}")
    

    
def wait_async_result(output_location, timeout=60, poll_interval=5):
    """Poll S3 until the async result at ``output_location`` exists, then draw it.

    Args:
        output_location: S3 URI returned as ``OutputLocation`` by
            ``invoke_endpoint_async``.
        timeout: Maximum number of seconds to wait.
        poll_interval: Seconds to sleep between checks.

    Returns:
        True if the result appeared (and was drawn) within ``timeout``,
        False otherwise.
    """
    # Track real elapsed time; the previous counter was never advanced, so the
    # loop could spin forever when no result was ever written.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if s3_object_exists(output_location):
            print("have async result")
            draw_image(output_location)
            return True
        time.sleep(poll_interval)
    print(f"no async result after {timeout}s; check {output_location} later")
    return False

            
        
def check_sendpoint_status(endpoint_name, timeout=600, poll_interval=10):
    """Wait until ``endpoint_name`` is InService, has Failed, or ``timeout`` passes.

    Args:
        endpoint_name: Name of the SageMaker endpoint to poll.
        timeout: Maximum number of seconds to wait.
        poll_interval: Seconds to sleep between ``describe_endpoint`` calls.

    Returns:
        True when the endpoint is InService; False if it failed or did not
        become ready within ``timeout``.
    """
    client = boto3.client('sagemaker')
    # Use real elapsed time; the previous counter was never advanced, so a
    # Failed or stuck endpoint was polled forever.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            status_info = client.describe_endpoint(EndpointName=endpoint_name)
        except client.exceptions.ClientError as ex:
            # Report API errors and keep polling, as before.
            print(f'{endpoint_name} not ready , please wait.... ({ex})')
        else:
            status = status_info['EndpointStatus']
            if status == 'InService':
                print(f'{endpoint_name} is ready, status: {status}')
                return True
            if status == 'Failed':
                # A failed endpoint will never become InService; stop polling.
                reason = status_info.get('FailureReason', 'unknown')
                print(f'{endpoint_name} failed to deploy: {reason}')
                return False
            print(f'{endpoint_name} not ready , please wait....')
        time.sleep(poll_interval)
    print(f'{endpoint_name} is still not InService after {timeout}s')
    return False
        

#### 3.3 Deploy the SageMaker endpoint (only needs to be run once!)

### Local mode

In [None]:
# # CHECK if GPUs are available and set the corresponding "instance_type"
# import os
# import subprocess

# set_instance_type_local = True

# if set_instance_type_local:
#     instance_type = "local"
#     if subprocess.call("nvidia-smi") == 0:
#         ## Set type to GPU if one is present
#         instance_type = "local_gpu"
# else:
#     instance_type = 'ml.g5.2xlarge'


# print("Instance type = " + instance_type)

In [None]:
# from sagemaker import Model

# import sagemaker
# from sagemaker.local import LocalSession

# session_local = LocalSession()
# session_local.config = {"local": {"local_code": True}}
# print(type(session_local))

# model_data = model_data

# estimator = Model(
#     image_uri=container,
#     model_data=model_data,
#     role=role,
#     # source_dir="container/bert-topic",  # this argument is used to override internal container entrypoint, if needed!
#     # entry_point="bert-topic-inference.py",  # this argument is used to override internal container entrypoint, if needed!
#     sagemaker_session=session_local,  # local session
#     #                   predictor_cls=None,
#                       env={"s3_bucket": bucket},
#     #                   name=None,
#     #                   vpc_config=None,
#     #                   enable_network_isolation=False,
#     #                   model_kms_key=None,
#     #                   image_config=None,
#     #                   code_location=None,
#     #                   container_log_level=20,
#     #                   dependencies=None,
#     #                   git_config=None
# )

# predictor = estimator.deploy(1, instance_type, container_startup_health_check_timeout=600)

#### Test locally

In [None]:
# bucket

In [None]:
# import json

# sagemaker_session = LocalSession()
# sagemaker_session.config = {"local": {"local_code": True}}

# # payload={
# #     "prompt": "a fantasy creaturefractal dragon",
# #     "steps":20,
# #     "sampler":"euler_a",
# #     "count":1
# #   }

# payload={
#                     "prompt": "a fantasy creaturefractal dragon",
#                     "steps":20,
#                     "sampler":"euler_a",
#                     "count":1,
#                     "control_net_enable":"disable",
#                      "sdxl_refiner":"enable",
#                      "lora_name":"dragon",
#                     # "lora_url":"https://civitai.com/api/download/models/129363"
#                     # NOTE: never commit pre-signed URLs - they embed temporary AWS credentials.
#                     "lora_url":"<pre-signed URL for your copy of DTStyle.safetensors>"
# }

# # # Convert the native request to JSON.
# # request = json.dumps(native_request)

# sm_client = sagemaker_session.sagemaker_runtime_client
# response = sm_client.invoke_endpoint(
#     EndpointName="local-endpoint",
#     ContentType="application/json",
#     Body=json.dumps(payload),
# )

# r = response["Body"]
# r_dec = r.read().decode()
# print("RESULT r.read().decode():", r_dec)

In [None]:
# obj = s3_resource.Object(bucket, "stablediffusion/asyncinvoke/images/"+"d62b26f9-6344-4927-a544-70312befc202.jpg")
# bytes = obj.get()['Body'].read()
# image = Image.open(io.BytesIO(bytes))
# #resize image to 50% size
# half = 0.5
# out_image = image.resize( [int(half * s) for s in image.size] )
# out_image.show()

### Deploy

In [None]:
# Append a random UUID so that every run gets a distinct endpoint name.
endpoint_name = 'AIGC-Quick-Kit-' + str(uuid.uuid4())

print(f'Endpoint:{endpoint_name} is being created. Model is loading during first startup, please wait patiently. Check status in console')

create_endpoint_args = dict(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)
response = client.create_endpoint(**create_endpoint_args)



### 4. Testing endpoint

#### Check endpoint status

In [None]:
# Reuse endpoint_name from the create-endpoint cell. Re-deriving it from
# response["EndpointArn"] needed a case-restoring replace() and broke whenever
# another cell had overwritten the generic `response` variable.
check_sendpoint_status(endpoint_name)

### 4.1 Test SDXL without refiner
On the first request, SageMaker pulls the stabilityai/stable-diffusion-xl-base-1.0 model from Hugging Face, so the first result takes a while.

In [None]:

# Plain SDXL base generation; the refiner is not requested.
payload = dict(
    prompt="a fantasy creaturefractal dragon",
    steps=20,
    sampler="euler_a",
    count=1,
)

predict_async(endpoint_name, payload)


### 4.2 Test SDXL with refiner
On the first request, SageMaker pulls the stabilityai/stable-diffusion-xl-refiner-1.0 model from Hugging Face, so the first result takes a while.

Set `sdxl_refiner` to `enable` in the payload.

In [None]:
# Same prompt as the base test, with the SDXL refiner switched on.
payload = dict(
    prompt="a fantasy creaturefractal dragon",
    steps=20,
    sampler="euler_a",
    count=1,
    sdxl_refiner="enable",
)

predict_async(endpoint_name, payload)

### 4.3 Test SDXL with LORA

In [None]:
# The first request that uses a LoRA downloads it and loads it into memory, so it
# takes longer than later requests.
# The Civitai download link for this LoRA (https://civitai.com/api/download/models/129363)
# requires being logged in, so the file was uploaded to S3 and is fetched through a
# pre-signed (or public) URL instead. Replace lora_url with a URL to your own copy.
# Prompt uses the DTStyle dragon LoRA, cf. https://civitai.com/models/119157
payload = {
    "prompt": "a fantasy creaturefractal dragon,<lora:DTStyle:1> DTstyle,",
    "steps": 20,
    "sampler": "euler_a",
    "count": 1,
    "control_net_enable": "disable",
    "sdxl_refiner": "enable",
    "lora_name": "dragonlora",
    "lora_url": "https://sagemaker-us-east-1-123456789012.s3.us-east-1.amazonaws.com/sagemaker/DTStyle.safetensors",
}

predict_async(endpoint_name, payload)

In [None]:
# Baseline without any LoRA, for comparison with the Picasso LoRA request.
payload = {
    "prompt": "man playing guitar in a subway station",
    "steps": 50,
    "sampler": "euler_a",
    "count": 1,
    "control_net_enable": "disable",
    "sdxl_refiner": "enable",
    # To apply the Picasso LoRA instead, add:
    # "lora_name": "picassolora",
    # "lora_url": "https://civitai.com/api/download/models/140189?type=Model&format=SafeTensor",
}

predict_async(endpoint_name, payload)

In [None]:
# Picasso art-style LoRA, cf. https://civitai.com/models/128076/pablo-picasso-sdxl-10-art-style-lora
# The "<lora:p5c0:1>p5c0" prefix in the prompt activates it.
payload = dict(
    prompt="<lora:p5c0:1>p5c0 man playing guitar in a subway station",
    steps=50,
    sampler="euler_a",
    count=1,
    control_net_enable="disable",
    sdxl_refiner="enable",
    lora_name="picassolora",
    lora_url="https://civitai.com/api/download/models/140189?type=Model&format=SafeTensor",
)

predict_async(endpoint_name, payload)

### 4.4 Test SDXL with controlnet

We test two ControlNets from Hugging Face: Canny (diffusers/controlnet-canny-sdxl-1.0-small) and Depth (diffusers/controlnet-depth-sdxl-1.0-small).

First, test Canny

In [None]:
# ControlNet (canny) conditioned on the Hugging Face logo image, with the refiner on.
canny_input_image = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
payload = dict(
    prompt="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting",
    steps=20,
    sampler="euler_a",
    count=1,
    control_net_enable="enable",
    sdxl_refiner="enable",
    control_net_model="canny",
    input_image=canny_input_image,
)

predict_async(endpoint_name, payload)

SDXL Depth Testing

In [None]:
# ControlNet (depth) conditioned on the stormtrooper sample image, with the refiner on.
depth_input_image = "https://huggingface.co/lllyasviel/sd-controlnet-depth/resolve/main/images/stormtrooper.png"
payload = dict(
    prompt="stormtrooper lecture, photorealistic",
    steps=20,
    sampler="euler_a",
    count=1,
    control_net_enable="enable",
    sdxl_refiner="enable",
    control_net_model="depth",
    input_image=depth_input_image,
)

predict_async(endpoint_name, payload)

### 5 Clean up resources

In [None]:
# Guard cell: deliberately stop "Run All" here so the clean-up cell below only
# runs when executed on purpose. The previous top-level `break` achieved this
# only by being a SyntaxError.
raise RuntimeError("Stopped before clean-up: run the next cell manually to delete the endpoint.")

In [None]:
# Delete the endpoint and its configuration so the endpoint instance stops
# incurring charges.
response = client.delete_endpoint(
    EndpointName=endpoint_name
)

response = client.delete_endpoint_config(
    EndpointConfigName=endpoint_config_name
)

# Also delete the model registered with create_model; it was previously left behind.
response = client.delete_model(
    ModelName=model_name
)

print(f'Endpoint:{endpoint_name} has been deleted, please check status in console')
