# Deploying Stable Diffusion using Stability AI DLC on AWS SageMaker

## Example: Stable Diffusion XL v0.9 on PyTorch 2.0.1

This example will deploy an endpoint running Stable Diffusion XL on AWS SageMaker using the Stability AI DLC. This example can provide inference as-is or serve as a basis for custom development & deployment scenarios.

If you are looking for a turnkey solution for inference with a full-featured API, check out [SDXL on AWS Marketplace](https://aws.amazon.com/marketplace/seller-profile?id=seller-mybtdwpr2puau) and the related [Jumpstart notebooks](https://github.com/Stability-AI/aws-jumpstart-examples).

In [None]:
# NOTE: You may have to restart your kernel after installing boto3
!pip install "sagemaker>=2.173.0" "huggingface_hub>=0.16.4" "boto3>=1.28.9" --upgrade --quiet

import sagemaker
from sagemaker import ModelPackage, get_execution_role

from PIL import Image
from typing import Union
import io
import os
import base64
import boto3

## 1. Download the model weights

### SDXL 0.9 is a non-commercial model. You must apply for access to download the [base model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9) and the [refiner model](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-0.9), then add your [HuggingFace token](https://huggingface.co/settings/tokens) below.

In [None]:
# Replace the placeholder with your own HuggingFace access token before running
# the download cell. Avoid saving or committing the notebook with a real token in it.
os.environ['HUGGING_FACE_HUB_TOKEN'] = 'YOUR TOKEN HERE'

In [None]:
import os
from huggingface_hub import snapshot_download

# Both checkpoints are written into the same local folder as real files
# (no symlinks into the HuggingFace cache).
local_dir = './model'
download_options = {
    "local_dir": local_dir,
    "local_dir_use_symlinks": False,
}

# Base model: fetch only the single safetensors file from the repo.
snapshot_download(
    repo_id="stabilityai/stable-diffusion-xl-base-0.9",
    allow_patterns="sd_xl_base_0.9.safetensors",
    **download_options,
)
# Refiner model: same, from the refiner repo.
snapshot_download(
    repo_id="stabilityai/stable-diffusion-xl-refiner-0.9",
    allow_patterns="sd_xl_refiner_0.9.safetensors",
    **download_options,
)


## 2. Custom Inference Script Creation

In [None]:
!mkdir -p model/code

### Inference Script: Text2Image, Image2Image

In [None]:
%%writefile model/code/sdxl_inference.py
from io import BytesIO
from einops import rearrange
import json
from omegaconf import OmegaConf
from pathlib import Path
from PIL import Image
from pytorch_lightning import seed_everything
import numpy as np
from sagemaker_inference.errors import BaseInferenceToolkitError
from sgm.inference.helpers import (
    load_model_from_config,
    get_input_image_tensor,
    get_sampler,
    get_discretization,
    get_guider,
    do_sample,
    do_img2img,
    apply_refiner,
    embed_watermark,
    Img2ImgDiscretizationWrapper,
)
import os


def _load_half_precision_model(label, config_path, weights_path):
    """Load one SDXL model and switch its conditioner and diffusion model to half precision.

    Args:
        label: Name used in the log messages ("base" or "refiner").
        config_path: Path of the OmegaConf YAML config for the model.
        weights_path: Path of the model's .safetensors weights file.

    Returns:
        The loaded model object.
    """
    print(f"Loading {label} model config from {config_path}")
    config = OmegaConf.load(config_path)
    print(f"Loading {label} model from {weights_path}")
    model, _ = load_model_from_config(config, weights_path)
    model.conditioner.half()
    model.model.half()
    return model


def model_fn(model_dir, context=None):
    """Load the SDXL base model and, unless disabled, the refiner model.

    Args:
        model_dir: Directory containing sd_xl_base_0.9.safetensors and
            sd_xl_refiner_0.9.safetensors.
        context: Unused.

    Returns:
        dict with keys "base" (always a model) and "refiner" (a model, or None
        when the SDXL_DISABLE_REFINER environment variable is "true").
    """
    # Enable the refiner by default
    disable_refiner = os.environ.get("SDXL_DISABLE_REFINER", "false").lower() == "true"

    # This hardcoded path is needed due to an import error in the sgm library
    sgm_path = "/opt/ml/stability-ai/generative-models"  # os.path.dirname(sgm.__file__)
    # The config paths are joined with sgm_path exactly once here; they are
    # passed to OmegaConf.load as-is.
    base_config_path = os.path.join(sgm_path, "configs/inference/sd_xl_base.yaml")
    refiner_config_path = os.path.join(sgm_path, "configs/inference/sd_xl_refiner.yaml")
    base_model_path = os.path.join(model_dir, "sd_xl_base_0.9.safetensors")
    refiner_model_path = os.path.join(model_dir, "sd_xl_refiner_0.9.safetensors")

    base_model = _load_half_precision_model("base", base_config_path, base_model_path)

    if disable_refiner:
        print("Refiner model disabled by SDXL_DISABLE_REFINER environment variable")
        refiner_model = None
    else:
        refiner_model = _load_half_precision_model(
            "refiner", refiner_config_path, refiner_model_path
        )

    return {"base": base_model, "refiner": refiner_model}


def input_fn(request_body, request_content_type):
    """Parse and minimally validate the JSON request body.

    Args:
        request_body: Raw request payload (str or bytes) containing a JSON object.
        request_content_type: Content type of the request; must be "application/json".

    Returns:
        The decoded request as a dict; it is guaranteed to contain "text_prompts".

    Raises:
        BaseInferenceToolkitError: with status 400 when the content type is not
            application/json, the body is not valid JSON, the body is not a JSON
            object, or "text_prompts" is missing.
    """
    if request_content_type != "application/json":
        raise BaseInferenceToolkitError(400, "Content-type must be application/json")

    try:
        model_input = json.loads(request_body)
    except ValueError as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses;
        # report them as a client error instead of an unhandled server error.
        raise BaseInferenceToolkitError(400, f"Request body is not valid JSON: {e}") from e

    # A JSON string/array would otherwise pass the membership test below
    # (or fail with TypeError for numbers) and break later stages.
    if not isinstance(model_input, dict):
        raise BaseInferenceToolkitError(400, "Request body must be a JSON object")
    if "text_prompts" not in model_input:
        raise BaseInferenceToolkitError(400, "text_prompts missing")
    return model_input


def _split_text_prompts(data):
    """Split the request's text_prompts into positive and negative prompt texts.

    A prompt with a negative weight is treated as a negative prompt; a missing
    weight defaults to 1.0. The request dict itself is not modified.
    """
    prompts = []
    negative_prompts = []
    for text_prompt in data.get("text_prompts", []):
        if not isinstance(text_prompt, dict) or "text" not in text_prompt:
            raise BaseInferenceToolkitError(400, "text missing from text_prompt")
        weight = text_prompt.get("weight", 1.0)
        if not isinstance(weight, (int, float)):
            # Comparing a non-number with 0 below would raise TypeError (a 500).
            raise BaseInferenceToolkitError(400, "text_prompt weight must be a number")
        if weight < 0:
            negative_prompts.append(text_prompt["text"])
        else:
            prompts.append(text_prompt["text"])
    return prompts, negative_prompts


def predict_fn(data, model, context=None):
    """Generate images for one request with the base model and, optionally, the refiner.

    Args:
        data: Request dict. Keys read here: "text_prompts" (list of dicts with
            "text" and optional "weight"), and the optional overrides "height",
            "width", "sampler", "cfg_scale", "steps", "seed" and "use_pipeline".
        model: Dict with "base" and "refiner" entries; "refiner" may be None.
        context: Unused.

    Returns:
        list of PNG-encoded images as bytes, one per generated sample.

    Raises:
        BaseInferenceToolkitError: with status 400 when the prompts are invalid,
            when the refiner pipeline is requested but no refiner is loaded, or
            when sampling raises ValueError.
    """
    # Only a single positive and optionally a single negative prompt are supported by this example.
    prompts, negative_prompts = _split_text_prompts(data)

    if len(prompts) != 1:
        raise BaseInferenceToolkitError(
            400, "One prompt with positive or default weight must be supplied"
        )
    if len(negative_prompts) > 1:
        raise BaseInferenceToolkitError(
            400, "Only one negative weighted prompt can be supplied"
        )

    prompt = prompts[0]
    negative_prompt = negative_prompts[0] if negative_prompts else ""

    # Request values override the defaults.
    height = data.get("height", 1024)
    width = data.get("width", 1024)
    sampler_name = data.get("sampler", "DPMPP2MSampler")
    cfg_scale = data.get("cfg_scale", 7.0)
    steps = data.get("steps", 50)
    # The refiner stage runs by default whenever a refiner model was loaded.
    use_pipeline = data.get("use_pipeline", model["refiner"] is not None)
    # Seeding only happens when the request supplies a seed.
    if "seed" in data:
        seed_everything(data["seed"])

    if model["refiner"] is None and use_pipeline:
        raise BaseInferenceToolkitError(400, "Pipeline is not available")

    value_dict = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "aesthetic_score": 6.0,
        "negative_aesthetic_score": 2.5,
        "orig_height": height,
        "orig_width": width,
        "target_height": height,
        "target_width": width,
        "crop_coords_top": 0,
        "crop_coords_left": 0,
    }

    try:
        sampler = get_sampler(
            sampler_name=sampler_name,
            steps=steps,
            discretization_config=get_discretization("LegacyDDPMDiscretization"),
            guider_config=get_guider(guider="VanillaCFG", scale=cfg_scale),
        )

        output = do_sample(
            model=model["base"],
            sampler=sampler,
            value_dict=value_dict,
            num_samples=1,
            H=height,
            W=width,
            C=4,
            F=8,
            return_latents=use_pipeline,
        )

        # With return_latents the result is handled as a (samples, latents) pair.
        if isinstance(output, (tuple, list)):
            samples, samples_z = output
        else:
            samples = output
            samples_z = None

        if use_pipeline and samples_z is not None:
            print("Running Refinement Stage")
            refiner_sampler = get_sampler(
                sampler_name="EulerEDMSampler",
                steps=50,
                discretization_config=get_discretization("LegacyDDPMDiscretization"),
                guider_config=get_guider(guider="VanillaCFG", scale=5.0),
            )
            # Wrap the refiner sampler's own discretization (built from the same
            # LegacyDDPMDiscretization config) so the refiner stage does not
            # share an object with the base sampler.
            refiner_sampler.discretization = Img2ImgDiscretizationWrapper(
                refiner_sampler.discretization, strength=0.3
            )

            samples = apply_refiner(
                input=samples_z,
                model=model["refiner"],
                sampler=refiner_sampler,
                num_samples=1,
                prompt=prompt,
                negative_prompt=negative_prompt,
            )

        samples = embed_watermark(samples)
        images = []
        for sample in samples:
            # Scale to 0-255 and move channels last before handing the array to PIL.
            sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
            image_bytes = BytesIO()
            Image.fromarray(sample.astype(np.uint8)).save(image_bytes, format="PNG")
            images.append(image_bytes.getvalue())

        return images

    except ValueError as e:
        # Surface ValueErrors raised during sampling as a client error,
        # keeping the original exception as the cause.
        raise BaseInferenceToolkitError(400, str(e)) from e


def output_fn(prediction, accept):
    """Return the first generated image together with its content type.

    Only "image/png" is accepted; any other Accept value is rejected with a
    400 error. The example code only ever produces one image, so only the
    first element of `prediction` is returned.
    """
    if accept == "image/png":
        first_image = prediction[0]
        return first_image, accept
    raise BaseInferenceToolkitError(400, "Accept header must be image/png")


## 3. Package and upload model archive

In [None]:
import sagemaker
import boto3

# Bucket used by the SageMaker session for uploading data, models and logs.
# Leave as None to use the session's default bucket; SageMaker creates that
# bucket automatically if it does not exist.
sagemaker_session_bucket = None

sess = sagemaker.Session()
if sess is not None and sagemaker_session_bucket is None:
    # No bucket name was given, so fall back to the session default.
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    # get_execution_role raised ValueError: look the role up by name through IAM instead.
    iam = boto3.client('iam')
    role_description = iam.get_role(RoleName='sagemaker_execution_role')
    role = role_description['Role']['Arn']

# Recreate the session so that it is bound to the chosen bucket.
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")


In [None]:
# Rerun this cell only if you need to re-upload the weights, otherwise you can reuse the existing model_package_name and upload only your new code 
from sagemaker.utils import name_from_base
model_package_name = name_from_base(f"sdxl-v0-9") # You may want to make this a fixed name of your choosing instead
model_uri = f's3://{sagemaker_session_bucket}/{model_package_name}/'
print(f'Uploading base model to {model_uri}, this will take a while...')
!aws s3 cp model/sd_xl_base_0.9.safetensors {model_uri}
print(f'Uploading refiner model to {model_uri}, this will take a while...')
!aws s3 cp model/sd_xl_refiner_0.9.safetensors {model_uri}

In [None]:
# Rerun this cell when you have changed the code or are uploading a fresh copy of the weights
print(f'Uploading code to {model_uri}code')
!aws s3 cp model/code/sdxl_inference.py {model_uri}code/sdxl_inference.py
print("Done!")

## 4. Create and deploy a model and perform real-time inference

boto3 is being used to deploy the model here to take advantage of [Uncompressed model downloads](https://docs.aws.amazon.com/sagemaker/latest/dg/large-model-inference-uncompressed.html)

In [None]:
inference_image_uri = '740929234339.dkr.ecr.us-east-1.amazonaws.com/stabilityai-pytorch-inference:2.0.1-sgm0.0.1-gpu-py310-cu118-ubuntu20.04-sagemaker-2023-07-24-03-47-13'

In [None]:
# A single generated name is reused for the model, the endpoint config and the endpoint.
endpoint_name = name_from_base("sdxl-v0-9")
sagemaker_client = boto3.client('sagemaker')

# S3 Data Source configuration
model_data_source = {
    "S3DataSource": {
        "S3Uri": model_uri,          # path to your model and script
        "S3DataType": "S3Prefix",    # causes SageMaker to download from a prefix
        "CompressionType": "None",   # disables compression
    }
}

container_environment = {
    # This script was uploaded to the data source above
    "SAGEMAKER_PROGRAM": "sdxl_inference.py",
    # We need a long timeout for the model to load
    "TS_DEFAULT_RESPONSE_TIMEOUT": "1000",
    # Put this cache somewhere with a lot of space
    "HUGGINGFACE_HUB_CACHE": "/tmp/cache/huggingface/hub",
}

create_model_response = sagemaker_client.create_model(
    ModelName=endpoint_name,
    ExecutionRoleArn=role,
    PrimaryContainer={
        "Image": inference_image_uri,
        "ModelDataSource": model_data_source,
        "Environment": container_environment,
    },
)

production_variant = {
    "ModelName": endpoint_name,
    "VariantName": "sdxl",
    "InitialInstanceCount": 1,
    "InstanceType": "ml.g5.4xlarge",  # 4xlarge is required to load the model
}
create_endpoint_config_response = sagemaker_client.create_endpoint_config(
    EndpointConfigName=endpoint_name,
    ProductionVariants=[production_variant],
)

deploy_model_response = sagemaker_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_name,
)

# Block until the endpoint reports InService.
print('Waiting for the endpoint to be in service, this can take 5-10 minutes...')
waiter = sagemaker_client.get_waiter('endpoint_in_service')
waiter.wait(EndpointName=endpoint_name)
print(f'Endpoint {endpoint_name} is in service, but the model is still loading. This may take another 5-10 minutes.')

In [None]:
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import BytesDeserializer

# Requests are sent as JSON; responses are read back as raw bytes with
# "image/png" as the accepted content type.
request_serializer = JSONSerializer()
response_deserializer = BytesDeserializer(accept="image/png")

predictor = Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
    serializer=request_serializer,
    deserializer=response_deserializer,
)

## A. Text to image

**Note**: The endpoint will be "InService" before the model has finished loading, so this request will initially time out. Check the endpoint logs in CloudWatch for status.

In [None]:
# One positively weighted prompt plus a fixed seed.
text_to_image_request = {
    "text_prompts": [
        {
            "text": "A photograph of fresh pizza with basil and tomatoes, from a traditional oven",
            "weight": 1,
        },
    ],
    "seed": 2,
}
output = predictor.predict(text_to_image_request)

In [None]:
def decode_and_show(model_response: bytes) -> None:
    """
    Decodes and displays an image from SDXL output

    Args:
        model_response (bytes): The raw image bytes returned by the endpoint
            (the predictor in this notebook deserializes responses as bytes
            with accept type image/png).

    Returns:
        None
    """
    # Wrap the bytes in a file-like object so PIL can read them.
    image = Image.open(io.BytesIO(model_response))
    # `display` is not imported in this notebook; it is presumably the
    # IPython/Jupyter built-in that renders the image inline.
    display(image)

decode_and_show(output)

## B. Image to image **NOT YET IMPLEMENTED IN THIS EXAMPLE**

In [None]:
def encode_image(image_path: str, resize: bool = True) -> Union[str, None]:
    """
    Encode an image as a base64 string, optionally resizing it to 512x512.

    When `resize` is True, a resized copy is written next to the original as
    "<original name without extension>_resized.png" and that copy is encoded.

    Args:
        image_path (str): The path to the image file.
        resize (bool, optional): Whether to resize the image. Defaults to True.

    Returns:
        Union[str, None]: The encoded image as a string, or None if encoding failed.

    Raises:
        FileNotFoundError: If `image_path` does not exist.
        ValueError: If the image to encode is not 512x512 (only possible when
            `resize` is False).
    """
    # Explicit exceptions instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")

    if resize:
        # The context manager closes the source file once the resized copy exists.
        with Image.open(image_path) as image:
            resized_image = image.resize((512, 512))
        image_base = os.path.splitext(image_path)[0]
        image_resized_path = f"{image_base}_resized.png"
        resized_image.save(image_resized_path)
        image_path = image_resized_path

    with Image.open(image_path) as image:
        if image.size != (512, 512):
            raise ValueError(
                f"Image {image_path} must be 512x512, got {image.size[0]}x{image.size[1]}"
            )

    with open(image_path, "rb") as image_file:
        img_byte_array = image_file.read()

    # Encode the byte array as a Base64 string
    try:
        return base64.b64encode(img_byte_array).decode("utf-8")
    except Exception as e:
        # Deliberate best-effort: report the failure and return None as documented.
        print(f"Failed to encode image {image_path} as base64 string.")
        print(e)
        return None

In [None]:
! wget https://platform.stability.ai/Cat_August_2010-4.jpg

In [None]:
# Here is the original image:
display(Image.open('Cat_August_2010-4.jpg'))

In [None]:
# NOTE(review): `deployed_model`, `GenerationRequest` and `TextPrompt` are not
# defined or imported anywhere in this notebook as shown, so this cell raises
# NameError as written. The section heading marks image-to-image as not yet
# implemented, and the inference script's predict_fn does not read
# `init_image` or `image_strength`.
cat_path = "Cat_August_2010-4.jpg"
# Base64-encode the downloaded picture (resized to 512x512 by encode_image).
cat_data = encode_image(cat_path)

output = deployed_model.predict(GenerationRequest(text_prompts=[TextPrompt(text="cat in watercolour")],
                                                  init_image= cat_data,
                                                  cfg_scale=9,
                                                  image_strength=0.8,
                                                  seed=42
                                                  ))
decode_and_show(output)

In [None]:
# NOTE(review): `deployed_model`, `GenerationRequest` and `TextPrompt` are not
# defined or imported anywhere in this notebook as shown, so this cell raises
# NameError as written (image-to-image is marked as not yet implemented).
output = deployed_model.predict(GenerationRequest(text_prompts=[TextPrompt(text="cat painted by Basquiat")],
                                                  init_image= cat_data,
                                            cfg_scale=17,
                                            image_strength=0.4,
                                             seed=42
                                             ))
decode_and_show(output)

In [None]:
!aws sagemaker list-endpoints

In [None]:
# Delete an endpoint
# NOTE: only the endpoint is deleted here; the model and the endpoint config
# created earlier under the same name are left in place.
sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
# Rerun the aws cli command above to confirm that it's gone.