# Deploying Stable Diffusion using Stability AI DLC on AWS SageMaker

## Example: Stable Diffusion 2.1 on PyTorch 1.13.1 with Diffusers and Xformers

This example will deploy an endpoint running Stable Diffusion 2.1 with an easy-to-use HTTP JSON interface that mirrors the [Stability AI REST API](https://api.stability.ai/docs).

The [Stability SDK](https://github.com/Stability-AI/stability-sdk) provides this contract and is pre-installed in the Stability AI Deep Learning Container (DLC).


In [None]:
!pip install "sagemaker==2.116.0" "huggingface_hub==0.10.1" "stability-sdk[sagemaker] @ git+https://github.com/Stability-AI/stability-sdk.git@sagemaker" --upgrade --quiet

import sagemaker
from sagemaker import ModelPackage, get_execution_role
from stability_sdk_sagemaker.predictor import StabilityPredictor
from stability_sdk_sagemaker.models import get_model_package_arn
from stability_sdk.api import GenerationRequest, GenerationResponse, TextPrompt

from PIL import Image
from typing import Union
import io
import os
import base64
import boto3

## 1. Custom Inference Script Creation

In [None]:
!mkdir -p code

### Inference Script: Stability API contract, Text2Image, Image2Image

In [None]:
%%writefile code/stable_diffusion_inference.py
import base64
import torch
from io import BytesIO
import json
from PIL import Image
from pydantic import Field, ValidationError
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler, StableDiffusionImg2ImgPipeline
from stability_sdk.api import GenerationRequest, GenerationResponse, GenerationErrorResponse, BinaryArtifact, TextPrompt


# GenerationRequest is the pydantic class used by StabilityPredictor to send requests, so we extend it 
# to customize validation and defaults. This is optional, but recommended.
# For custom implementations, you can extend StabilityPredictor to use your own data model.
class DiffusersGenerationRequest(GenerationRequest):
    """Request schema for this endpoint: ``GenerationRequest`` with explicit bounds and defaults."""

    # Between one and two prompts are accepted.
    text_prompts: list[TextPrompt] = Field(..., min_items=1, max_items=2)
    # Output size in pixels: 128..1024 in multiples of 64, default 512.
    height: int = Field(512, ge=128, le=1024, multiple_of=64)
    width: int = Field(512, ge=128, le=1024, multiple_of=64)
    # Number of inference steps: 0..150, default 30.
    steps: int = Field(30, ge=0, le=150)
    # Number of images per request: 1..8, default 1.
    samples: int = Field(1, ge=1, le=8)
    # Guidance scale: 0.0..35.0, default 7.5.
    cfg_scale: float = Field(7.5, ge=0.0, le=35.0)
    # Optional seed (default None), bounded to 0..2**32 inclusive.
    seed: int = Field(None, ge=0, le=2**32)
    # Optional init image (default None); predict_fn base64-decodes it and
    # switches to image-to-image mode when it is set.
    init_image: str = Field(None)
    # Image-to-image strength: 0.0..1.0, default 0.8.
    image_strength: float = Field(0.8, ge=0.0, le=1.0)    

def model_fn(model_dir):
    """Load the image-to-image and text-to-image pipelines from ``model_dir``.

    Both pipelines are loaded in float16, switched to xformers
    memory-efficient attention and moved to the GPU. The text-to-image
    pipeline additionally swaps its scheduler for a
    ``DPMSolverMultistepScheduler`` built from the existing scheduler config.

    Args:
        model_dir: Directory holding the pretrained model files.

    Returns:
        dict: ``{"text2image": ..., "image2image": ...}`` mapping mode names
        to their pipelines.
    """
    target_device = "cuda"
    half_precision = torch.float16

    # Image-to-image pipeline (keeps the scheduler it was loaded with).
    img2img = StableDiffusionImg2ImgPipeline.from_pretrained(
        model_dir, torch_dtype=half_precision
    )
    img2img.enable_xformers_memory_efficient_attention()
    img2img = img2img.to(target_device)

    # Text-to-image pipeline with the DPM-Solver multistep scheduler.
    txt2img = StableDiffusionPipeline.from_pretrained(
        model_dir, torch_dtype=half_precision
    )
    txt2img.scheduler = DPMSolverMultistepScheduler.from_config(
        txt2img.scheduler.config
    )
    txt2img.enable_xformers_memory_efficient_attention()
    txt2img = txt2img.to(target_device)

    return {"text2image": txt2img, "image2image": img2img}


def input_fn(request_body, request_content_type):
    """Deserialize a JSON request body.

    Args:
        request_body: Raw request payload (``str`` or ``bytes``) containing JSON.
        request_content_type: Value of the request's content-type header.

    Returns:
        The decoded JSON document.

    Raises:
        ValueError: If the content type is not ``application/json`` (also
            raised by ``json.loads`` when the body is not valid JSON).
    """
    # Compare only the media type: clients may legitimately append parameters
    # such as "; charset=utf-8", and media types are case-insensitive.
    media_type = (request_content_type or "").split(";", 1)[0].strip().lower()
    if media_type != "application/json":
        raise ValueError("Content-type must be application/json")
    return json.loads(request_body)

def output_fn(prediction, accept):
    """Serialize the prediction to JSON and pair it with the accept type.

    Args:
        prediction: Object exposing ``.json(exclude_unset=...)``.
        accept: Content type requested by the client; returned unchanged.

    Returns:
        tuple: ``(json_body, accept)``; fields never explicitly set on
        ``prediction`` are left out of the body.
    """
    serialized = prediction.json(exclude_unset=True)
    return serialized, accept
    
def _image_to_artifact(image, seed):
    """Encode a PIL image as a base64 PNG ``BinaryArtifact`` tagged with ``seed``."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return BinaryArtifact(
        seed=seed,
        base64=base64.b64encode(buffered.getvalue()).decode(),
        finishReason="SUCCESS",
    )


def predict_fn(data: dict, pipe):
    """Validate a generation request and run text-to-image or image-to-image.

    Args:
        data: Decoded JSON request body (the value returned by ``input_fn``);
            validated here against ``DiffusersGenerationRequest``.
        pipe: Mapping returned by ``model_fn`` with the keys ``"text2image"``
            and ``"image2image"``.

    Returns:
        GenerationResponse: ``result="success"`` with one artifact per image,
        or ``result="error"`` carrying a ``GenerationErrorResponse``. Failures
        are reported in the response rather than raised.
    """
    device = "cuda"

    # Validate the input using the pydantic model.
    # This is done in predict_fn so we can return a custom error response.
    try:
        request = DiffusersGenerationRequest.parse_obj(data)
    except ValidationError as e:
        error = e.errors()[0]
        error_msg = f'{error["loc"][0]}: {error["msg"]}'
        return GenerationResponse(
            result="error",
            error=GenerationErrorResponse(id="0", name=error["type"], message=error_msg),
        )

    # Weights could be supported using prompt_embeds; for now only the first
    # positive and the first negative prompt are used.
    prompt = None
    negative_prompt = None
    for text_prompt in request.text_prompts:
        weight = text_prompt.weight
        # A missing weight is treated as positive (assumes TextPrompt.weight
        # may be unset — TODO confirm against the SDK model).
        if weight is not None and weight < 0:
            if negative_prompt is None:
                negative_prompt = text_prompt.text
        elif prompt is None:
            prompt = text_prompt.text

    if prompt is None:
        # Previously surfaced as an opaque "list index out of range" error.
        return GenerationResponse(
            result="error",
            error=GenerationErrorResponse(
                id="0",
                name="invalid_prompts",
                message="text_prompts: at least one prompt with a non-negative weight is required",
            ),
        )

    mode = 'image2image' if request.init_image else 'text2image'

    try:
        generator = torch.Generator(device=device)
        if mode == 'text2image':
            latent_shape = (
                1,
                pipe[mode].unet.in_channels,
                request.height // 8,
                request.width // 8,
            )
            seeds = []
            latent_batches = []
            next_seed = request.seed
            for _ in range(request.samples):
                if request.seed is None:
                    # Draw a fresh random seed and remember it for the response.
                    sample_seed = generator.seed()
                else:
                    # `is None` (not truthiness) so an explicit seed of 0 is honoured.
                    # TODO: derive follow-up seeds randomly from the previous
                    # seed instead of incrementing.
                    sample_seed = next_seed
                    next_seed += 1
                generator.manual_seed(sample_seed)
                seeds.append(sample_seed)
                latent_batches.append(
                    torch.randn(latent_shape, generator=generator, device=device)
                )
            latents = torch.cat(latent_batches)

            # run generation with parameters
            with torch.autocast("cuda"):
                generated_images = pipe['text2image'](
                    prompt=[prompt] * request.samples,
                    height=request.height,
                    width=request.width,
                    num_inference_steps=request.steps,
                    guidance_scale=request.cfg_scale,
                    negative_prompt=(
                        [negative_prompt] * request.samples
                        if negative_prompt is not None
                        else None
                    ),
                    latents=latents,
                )["images"]

            # Pair every image with the seed that produced its latents.
            artifacts = [
                _image_to_artifact(image, sample_seed)
                for image, sample_seed in zip(generated_images, seeds)
            ]
        else:
            # image2image: seed the generator and hand it to the pipeline so
            # the seed reported in the response is the one actually used.
            seed = request.seed if request.seed is not None else generator.seed()
            generator.manual_seed(seed)

            init_image = Image.open(BytesIO(base64.b64decode(request.init_image))).convert("RGB")
            init_image = init_image.resize((request.width, request.height))

            # run generation with parameters
            generated_images = pipe['image2image'](
                num_images_per_prompt=request.samples,
                prompt=prompt,
                image=init_image,
                num_inference_steps=request.steps,
                guidance_scale=request.cfg_scale,
                strength=request.image_strength,
                negative_prompt=negative_prompt,
                generator=generator,
            )["images"]

            artifacts = [_image_to_artifact(image, seed) for image in generated_images]

        return GenerationResponse(result="success", artifacts=artifacts)

    except Exception as e:
        # Endpoint boundary: report any inference failure in the response body.
        return GenerationResponse(
            result="error",
            error=GenerationErrorResponse(id="0", message=str(e), name="inference_error"),
        )


## SageMaker Session Creation

In [None]:
import sagemaker
import boto3
# Initial session; used below to look up the default bucket.
sess = sagemaker.Session()
# sagemaker session bucket -> used for uploading data, models and logs
# sagemaker will automatically create this bucket if it does not exist
sagemaker_session_bucket=None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

# Resolve the execution role. On ValueError (presumably raised when not running
# inside a SageMaker environment — confirm) fall back to looking up the IAM
# role named 'sagemaker_execution_role'.
try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

# Recreate the session bound to the chosen bucket.
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")


## Retrieving Image URI and Model URI

In [None]:
# Model identifier and version passed to model_uris.retrieve below
# ("*" is a wildcard version — presumably resolves to the latest; confirm).
model_id, model_version = 'model-txt2img-stabilityai-stable-diffusion-v2-1-base', "*"
from sagemaker import image_uris, model_uris, script_uris, hyperparameters
from sagemaker.model import Model
from sagemaker.predictor import Predictor
from sagemaker.utils import name_from_base


# Endpoint name generated from a fixed base string.
endpoint_name = name_from_base(f"stable-diffusion-v2-1-base")

# Retrieve the model URI for the "inference" scope. Per the original note, this
# includes the pre-trained model and parameters as well as the inference
# scripts and their dependencies.
model_uri = model_uris.retrieve(
    model_id=model_id, model_version=model_version, model_scope="inference"
)
# Show the resolved model URI.
print(model_uri)
# Or use your own stored model
#s3_model_uri = 's3://sagemaker-us-west-2-499172972132/stable-diffusion-v2-1/model.tar.gz'
#print(s3_model_uri)



## Create and deploy a model and perform real-time inference

In [None]:
# NOTE(review): bare expression that is not the cell's last statement, so it
# likely displays nothing — confirm whether it is still needed.
model_uri
# Inference container image (PyTorch 1.13.1, diffusers 0.14.0, xformers, per the tag).
# NOTE(review): the URI points at an ECR registry in us-east-1 — confirm this
# matches the region the endpoint is deployed in.
inference_image_uri = '740929234339.dkr.ecr.us-east-1.amazonaws.com/stabilityai-pytorch-inference:1.13.1-diffusers0.14.0-gpu-xformers-py39-cu117-ubuntu20.04-2023-06-12-01-22-38'

In [None]:
from sagemaker.pytorch.model import PyTorchModel

# Describe the model: artifacts, container image and the custom inference
# script written to ./code/ earlier in this notebook.
pytorch_model = PyTorchModel(
    model_data=model_uri,                         # model artifact URI
    image_uri=inference_image_uri,                # inference container image
    entry_point="stable_diffusion_inference.py",  # custom inference script
    source_dir="./code/",
    role=role,                                    # IAM role with permissions to create an Endpoint
    predictor_cls=StabilityPredictor,             # StabilityPredictor provides serialization
)

# Deploy the endpoint.
# This will take a while as it repackages the model then waits for deployment.
# The endpoint is created under the `endpoint_name` generated earlier, which was
# previously computed but never used. Regenerate that name before re-deploying,
# since endpoint names must be unique.
deployed_model = pytorch_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.xlarge",
    endpoint_name=endpoint_name,
)



In [None]:
# You can also attach the predictor to an existing endpoint by name, e.g.:
#deployed_model = StabilityPredictor(endpoint_name="<existing-endpoint-name>")
# Show the name of the endpoint this predictor is bound to.
deployed_model.endpoint_name

## A. Text to image

In [None]:
# Text-to-image request with a fixed seed.
output = deployed_model.predict(
    GenerationRequest(
        text_prompts=[
            TextPrompt(text="A photograph of fresh pizza with basil and tomatoes, from a traditional oven")
        ],
        seed=2,
    )
)

In [None]:
def decode_and_show(model_response: GenerationResponse) -> None:
    """
    Decode and display the first image of a Stable Diffusion response.

    Args:
        model_response (GenerationResponse): The response object from the deployed model.

    Returns:
        None

    Raises:
        ValueError: If the response carries no image artifacts (for example
            when the endpoint returned an error response).
    """
    artifacts = model_response.artifacts
    if not artifacts:
        # Surface the endpoint's error instead of failing on `None[0]`.
        error = getattr(model_response, "error", None)
        raise ValueError(f"Response contains no image artifacts: {error}")

    image_data = base64.b64decode(artifacts[0].base64.encode())
    image = Image.open(io.BytesIO(image_data))
    display(image)

decode_and_show(output)

## B. Image to image

In [None]:
def encode_image(image_path: str, resize: bool = True) -> Union[str, None]:
    """
    Encode an image as a base64 string, optionally resizing it to 512x512.

    When ``resize`` is true, a resized copy is written next to the original as
    ``<name>_resized.png`` and that copy is what gets encoded.

    Args:
        image_path (str): The path to the image file.
        resize (bool, optional): Whether to resize the image. Defaults to True.

    Returns:
        Union[str, None]: The encoded image as a string, or None if encoding failed.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
        ValueError: If the image to encode is not 512x512.
    """
    # Explicit exceptions instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")

    if resize:
        image_base = os.path.splitext(image_path)[0]
        image_resized_path = f"{image_base}_resized.png"
        # Context manager so the source file handle is released.
        with Image.open(image_path) as image:
            image.resize((512, 512)).save(image_resized_path)
        image_path = image_resized_path

    with Image.open(image_path) as image:
        if image.size != (512, 512):
            raise ValueError(
                f"Image {image_path} must be 512x512, got {image.size[0]}x{image.size[1]}"
            )

    with open(image_path, "rb") as image_file:
        img_byte_array = image_file.read()

    # Encode the byte array as a Base64 string
    try:
        return base64.b64encode(img_byte_array).decode("utf-8")
    except Exception as e:
        print(f"Failed to encode image {image_path} as base64 string.")
        print(e)
        return None

In [None]:
! wget https://platform.stability.ai/Cat_August_2010-4.jpg

In [None]:
# Here is the original image:
display(Image.open('Cat_August_2010-4.jpg'))

In [None]:
# Encode the downloaded cat picture and use it as the init image.
cat_path = "Cat_August_2010-4.jpg"
cat_data = encode_image(cat_path)

# Image-to-image request: high image_strength with a moderate guidance scale.
output = deployed_model.predict(
    GenerationRequest(
        text_prompts=[TextPrompt(text="cat in watercolour")],
        init_image=cat_data,
        cfg_scale=9,
        image_strength=0.8,
        seed=42,
    )
)
decode_and_show(output)

In [None]:
# Same init image with a different prompt, a higher guidance scale and a
# lower image_strength.
output = deployed_model.predict(
    GenerationRequest(
        text_prompts=[TextPrompt(text="cat painted by Basquiat")],
        init_image=cat_data,
        cfg_scale=17,
        image_strength=0.4,
        seed=42,
    )
)
decode_and_show(output)

In [None]:
!aws sagemaker list-endpoints

In [None]:
# Clean up every resource created by deploy(), not just the endpoint.
# Deleting only the endpoint leaves the endpoint configuration and the model
# behind. Assumes StabilityPredictor inherits these methods from the SageMaker
# `Predictor` base class — confirm.
# The model is deleted first because the predictor resolves it through the
# endpoint's configuration.
deployed_model.delete_model()
deployed_model.delete_endpoint()  # also removes the endpoint configuration by default

# Rerun the aws cli command above to confirm that it's gone.