# Deploying Deep Floyd IF using Stability AI DLC on AWS SageMaker
### Example: Deep Floyd IF v1.0 on PyTorch 2.0 with Diffusers

This example deploys an endpoint running DeepFloyd IF (stages I and II) with an easy-to-use HTTP JSON interface that mirrors the [Stability AI REST API](https://api.stability.ai/docs).

The [Stability SDK](https://github.com/Stability-AI/stability-sdk) provides this contract and is pre-installed in the Stability AI DLC.

**The DeepFloyd IF weights are licensed for non-commercial research use only. The weights must be downloaded from the Hugging Face Hub by supplying an access token, or otherwise provided. Use of the weights with this container must follow the agreed license.**

In [None]:
!pip install sagemaker huggingface_hub "stability-sdk[sagemaker] @ git+https://github.com/Stability-AI/stability-sdk.git@sagemaker" --upgrade --quiet

import sagemaker
from sagemaker import ModelPackage, get_execution_role
from stability_sdk_sagemaker.predictor import StabilityPredictor
from stability_sdk_sagemaker.models import get_model_package_arn
from stability_sdk.api import GenerationRequest, GenerationResponse, TextPrompt

from PIL import Image
from typing import Union
import io
import os
import base64
import boto3

## 1. Download the model weights.

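You need a 🤗 Hugging Face Hub access token to download the weights. The DeepFloyd IF repositories are gated, so you must first accept the license on the [IF-I-XL-v1.0](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0) and [IF-II-L-v1.0](https://huggingface.co/DeepFloyd/IF-II-L-v1.0) model pages with the same account. Then replace `YOUR_TOKEN_HERE` in the next cell with your token.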

In [None]:
# Replace the placeholder with your Hugging Face Hub access token.
hf_token = 'YOUR_TOKEN_HERE'
# Only export a real token: exporting the placeholder would overwrite any token
# already present in HUGGING_FACE_HUB_TOKEN and make the downloads fail.
if hf_token and hf_token != 'YOUR_TOKEN_HERE':
    os.environ['HUGGING_FACE_HUB_TOKEN'] = hf_token

In [None]:
import os
from huggingface_hub import snapshot_download

local_dir = './model/if'
# Only the fp16 weights, configs, watermarker and tokenizer are needed by the inference script.
allow_patterns = [
    "*.json",
    "*.fp16*safetensors",
    "watermarker/diffusion_pytorch_model.safetensors",
    "tokenizer/spiece.model",
]

# (Hub repository, sub-directory under local_dir, extra snapshot_download arguments)
if_stages = [
    ("DeepFloyd/IF-I-XL-v1.0", "IF-I-XL", {}),
    # Stage II is fed stage I's prompt embeddings, so its text encoder is not downloaded.
    ("DeepFloyd/IF-II-L-v1.0", "IF-II-L", {"ignore_patterns": ["text_encoder/*"]}),
]
for hub_repo, stage_subdir, extra_kwargs in if_stages:
    snapshot_download(
        repo_id=hub_repo,
        allow_patterns=allow_patterns,
        local_dir=os.path.join(local_dir, stage_subdir),
        local_dir_use_symlinks=False,
        **extra_kwargs,
    )


## 2. Create custom inference script

In [None]:
!mkdir -p model/code

### Inference Script: Stability API contract, Text2Image, Image2Image

In [None]:
%%writefile model/code/deepfloyd_if_inference.py
from diffusers import DiffusionPipeline
from diffusers.utils import pt_to_pil
from io import BytesIO
import base64, torch, os, time, json, uuid
from stability_sdk.api import CreateRequest, CreateResponse
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation

def model_fn(model_dir):
    """Load the DeepFloyd IF pipelines from the unpacked model archive.

    Args:
        model_dir: Directory the model archive was extracted to. Stage weights are
            read from ``<model_dir>/if/IF-I-XL`` and ``<model_dir>/if/IF-II-L``.

    Returns:
        dict with keys 'stage_1', 'stage_2' and 'stage_3'. 'stage_3' (the upscaler)
        is always None in this example.
    """
    weights_root = os.path.join(model_dir, 'if')
    # Shared loading options: fp16 weight variant, read from local files only.
    fp16_local = {"variant": "fp16", "torch_dtype": torch.float16, "local_files_only": True}

    first_stage = DiffusionPipeline.from_pretrained(os.path.join(weights_root, "IF-I-XL"), **fp16_local)
    # CPU offload lets stages I and II share a single NVIDIA A10 (ml.g5.2xlarge). For better
    # performance, deploy each stage as its own model on separate ml.g5.xlarge endpoints and chain them.
    first_stage.enable_model_cpu_offload()

    # Stage II receives stage I's prompt embeddings, so it is loaded without a text encoder.
    second_stage = DiffusionPipeline.from_pretrained(
        os.path.join(weights_root, "IF-II-L"), text_encoder=None, **fp16_local
    )
    second_stage.enable_model_cpu_offload()

    # Stage III (x4 upscaler) is disabled: it pushes generation past 1 minute, which real-time
    # endpoints cannot serve. It can be used on ml.g5.4xlarge with asynchronous inference.
    # Its weights are not downloaded by this notebook; once provided, it could be loaded with:
    #   safety_modules = {"feature_extractor": first_stage.feature_extractor,
    #                     "safety_checker": first_stage.safety_checker,
    #                     "watermarker": first_stage.watermarker}
    #   third_stage = DiffusionPipeline.from_pretrained(
    #       os.path.join(weights_root, "sd-x4-upscaler"), **safety_modules,
    #       torch_dtype=torch.float16, local_files_only=True)
    #   third_stage.enable_model_cpu_offload()
    return {'stage_1': first_stage, 'stage_2': second_stage, 'stage_3': None}

def input_fn(request_body, request_content_type):
    """Deserialize an endpoint invocation payload into a generation request.

    Args:
        request_body: Raw request body (bytes or str).
        request_content_type: The request's Content-Type. Media-type parameters such
            as ``; charset=utf-8`` are ignored and matching is case-insensitive.

    Returns:
        For ``application/json`` bodies, the object built by ``CreateRequest`` from the
        decoded JSON (Stability REST contract). For ``application/x-protobuf`` bodies,
        a parsed ``generation.Request``.

    Raises:
        ValueError: if the content type is neither JSON nor protobuf.
    """
    # Compare on the bare media type so e.g. "application/json; charset=utf-8" is accepted.
    media_type = (request_content_type or "").split(";", 1)[0].strip().lower()
    if media_type == "application/json":
        return CreateRequest(json.loads(request_body))
    if media_type == "application/x-protobuf":
        request = generation.Request()
        request.ParseFromString(request_body)
        return request
    raise ValueError(
        f"Unsupported content type {request_content_type!r}: "
        "expected application/json or application/x-protobuf"
    )

def _seed_from_request(request):
    """Return the first seed from the request's image parameters, or 0 if none was supplied."""
    seeds = request.image.seed
    return seeds[0] if seeds else 0

def _answer_batch_from_image(pil_image, request_id, start_time):
    """Wrap one PIL image as a PNG artifact in a single-answer ``generation.AnswerBatch``."""
    image_buf = BytesIO()
    pil_image.save(image_buf, format="PNG")
    artifact = generation.Artifact(
        type=generation.ARTIFACT_IMAGE,
        mime="image/png",
        binary=image_buf.getvalue(),
        finish_reason=generation.NULL)
    answer = generation.Answer(
        answer_id=str(uuid.uuid4()),
        request_id=request_id,
        created=int(time.time() * 1000),
        received=int(start_time * 1000))
    answer.artifacts.append(artifact)
    batch = generation.AnswerBatch()
    batch.batch_id = request_id
    batch.answers.append(answer)
    return batch

def predict_fn(input_object, model):
    """Generate one image for the request's prompt with the IF cascade.

    Only the text of the first prompt and the first value of ``image.seed`` (default 0)
    are used. Other request parameters (init image, cfg scale, steps, sample count)
    are currently ignored.

    Args:
        input_object: The request returned by ``input_fn``.
        model: The dict returned by ``model_fn``.

    Returns:
        A ``generation.AnswerBatch`` with one answer holding one PNG artifact.

    Raises:
        ValueError: if the request contains no prompt.
    """
    start_time = time.time()
    if not input_object.prompt:
        raise ValueError("Request must contain at least one text prompt")
    prompt = input_object.prompt[0].text
    seed = _seed_from_request(input_object)

    # Stage II was loaded without a text encoder, so the prompt is embedded once by
    # stage I and the same embeddings are passed to both stages.
    prompt_embeds, negative_embeds = model['stage_1'].encode_prompt(prompt)
    # NOTE: torch.manual_seed seeds and returns the process-wide default generator;
    # stage I and then stage II draw from it in this order.
    generator = torch.manual_seed(seed)
    image = model['stage_1'](prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds,
                             generator=generator, output_type="pt").images
    image = model['stage_2'](image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds,
                             generator=generator, output_type="pt").images

    # Skip the upscaler if not loaded; otherwise convert stage II's tensor output to PIL.
    if model['stage_3']:
        image = model['stage_3'](prompt=prompt, image=image, generator=generator, noise_level=100).images
    else:
        image = pt_to_pil(image)
    return _answer_batch_from_image(image[0], input_object.request_id, start_time)

def output_fn(prediction, accept):
    """Serialize the AnswerBatch from ``predict_fn`` according to the Accept type.

    - ``application/x-protobuf``: the whole AnswerBatch, protobuf-serialized.
    - ``application/json*``, or any answer whose converted result is "error": the
      first answer converted with ``CreateResponse``, as JSON.
    - anything else (including ``*/*`` or no Accept): the first artifact's raw PNG bytes.

    Args:
        prediction: ``generation.AnswerBatch`` returned by ``predict_fn``.
        accept: The request's Accept value; may be None or empty.

    Returns:
        (body, content_type) tuple. content_type describes the body actually
        returned, rather than echoing ``accept``.
    """
    accept = accept or ""
    if accept == "application/x-protobuf":
        return prediction.SerializeToString(), accept
    response = CreateResponse(prediction.answers[0])
    wants_json = accept.startswith("application/json")
    if response.result == "error" or wants_json:
        # Error bodies are JSON even when the client asked for an image, so label them as JSON.
        return response.json(exclude_unset=True), accept if wants_json else "application/json"
    # Default to image/png, reporting the real media type instead of e.g. "*/*".
    return prediction.answers[0].artifacts[0].binary, "image/png"


## 3. Package and upload model archive

In [None]:
import sagemaker
import boto3

# S3 bucket for uploading models, data and logs. Leave as None to use the session's
# default bucket, which SageMaker creates if it does not already exist.
sagemaker_session_bucket = None
if sagemaker_session_bucket is None:
    sagemaker_session_bucket = sagemaker.Session().default_bucket()

try:
    # Resolves the attached execution role when running on SageMaker.
    role = sagemaker.get_execution_role()
except ValueError:
    # Otherwise look up an IAM role named 'sagemaker_execution_role'.
    role = boto3.client('iam').get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

In [None]:
from sagemaker.utils import name_from_base

# Unique prefix for this upload; the archive is written to s3://<bucket>/<prefix>/model.tar.gz.
model_package_name = name_from_base("deepfloyd-if")
model_uri = "s3://{}/{}/model.tar.gz".format(sagemaker_session_bucket, model_package_name)

In [None]:
print(f'Packaging and uploading model to {model_uri}, this will take a while...')
# Stream a gzipped tarball straight to S3 without writing it locally. "if" (weights) and
# "code" (inference script) end up at the archive root; model_fn reads <model_dir>/if.
!tar -cf - -C model if code | gzip --fast | aws s3 cp - {model_uri}
print("Done!")

## 4. Create and deploy a model and perform real-time inference

In [None]:
# Use the Stability AI DLC inference image.
# NOTE(review): this registry is in us-east-1 — confirm the image URI for your region.
_dlc_registry = '740929234339.dkr.ecr.us-east-1.amazonaws.com'
_dlc_tag = '2.0.0-diffusers0.17.0-gpu-py310-cu118-ubuntu20.04-2023-06-12-14-30-49'
inference_image_uri = f'{_dlc_registry}/stabilityai-pytorch-inference:{_dlc_tag}'

In [None]:
from sagemaker.model import Model

endpoint_name = name_from_base("deepfloyd-if")

model_env = {
    # Run the packaged inference script instead of the image's default handler.
    "SAGEMAKER_PROGRAM": "deepfloyd_if_inference.py",
    # Raise the model server's response timeout for slow generations.
    "TS_DEFAULT_RESPONSE_TIMEOUT": "1000",
}

model = Model(
    model_data=model_uri,              # S3 archive with the weights and inference code
    image_uri=inference_image_uri,     # Stability AI DLC image in ECR
    env=model_env,
    role=role,                         # IAM role permitted to create the endpoint
    predictor_cls=StabilityPredictor,  # handles request/response serialization
)

# Deploy the endpoint on a single A10 GPU instance.
deployed_model = model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
    endpoint_name=endpoint_name,
)

In [None]:
# Show the endpoint name. To reuse an already-running endpoint instead of redeploying,
# attach a predictor to it by name, e.g.:
#deployed_model = StabilityPredictor(endpoint_name="<your-endpoint-name>")
deployed_model.endpoint_name

## A. Text to image

The first inference request may time out due to deferred loading of the weights. Subsequent requests should succeed.

In [None]:
pizza_request = GenerationRequest(
    text_prompts=[TextPrompt(text="A photograph of fresh pizza with basil and tomatoes, from a traditional oven")],
    seed=2,
)
output = deployed_model.predict(pizza_request)

In [None]:
def decode_and_show(model_response: GenerationResponse) -> None:
    """
    Decode the first image artifact of a generation response and display it inline.

    Args:
        model_response (GenerationResponse): Response returned by the deployed
            DeepFloyd IF endpoint's predictor.

    Returns:
        None
    """
    encoded = model_response.artifacts[0].base64
    raw_bytes = base64.b64decode(encoded.encode())
    display(Image.open(io.BytesIO(raw_bytes)))

decode_and_show(output)

## B. Image to image

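Image-to-image is not supported by the supplied inference script. `predict_fn` reads only the prompt text and the seed, ignoring `init_image`, `image_strength` and `cfg_scale`. As a result, the requests below return ordinary text-to-image results.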

In [None]:
def encode_image(image_path: str, resize: bool = True) -> str:
    """
    Encode an image file as a base64 string, optionally resizing it to 512x512 first.

    When ``resize`` is True, the image is stretched to exactly 512x512 (aspect ratio is
    not preserved) and saved next to the original as ``<name>_resized.png``. That PNG
    is the file that gets encoded.

    Args:
        image_path (str): The path to the image file.
        resize (bool, optional): Whether to resize the image first. Defaults to True.

    Returns:
        str: The base64-encoded contents of the (possibly resized) image file.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
        ValueError: If the image to encode is not 512x512 (only possible when
            ``resize`` is False).
    """
    # Real exceptions instead of `assert`, which is stripped when Python runs with -O.
    if not os.path.exists(image_path):
        raise FileNotFoundError(image_path)

    if resize:
        image_resized_path = f"{os.path.splitext(image_path)[0]}_resized.png"
        # Context managers close the file handles Image.open keeps open.
        with Image.open(image_path) as image:
            image.resize((512, 512)).save(image_resized_path)
        image_path = image_resized_path

    with Image.open(image_path) as image:
        size = image.size
    if size != (512, 512):
        raise ValueError(f"Expected a 512x512 image, got {size[0]}x{size[1]}: {image_path}")

    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")

In [None]:
! wget https://platform.stability.ai/Cat_August_2010-4.jpg

In [None]:
# Here is the original image (the `with` block closes the file once it has been displayed):
with Image.open('Cat_August_2010-4.jpg') as original_image:
    display(original_image)

In [None]:
cat_path = "Cat_August_2010-4.jpg"
cat_data = encode_image(cat_path)

# NOTE: the inference script ignores init_image, image_strength and cfg_scale,
# so this currently yields a plain text-to-image result.
watercolour_request = GenerationRequest(
    text_prompts=[TextPrompt(text="cat in watercolour")],
    init_image=cat_data,
    cfg_scale=9,
    image_strength=0.8,
    seed=42,
)
output = deployed_model.predict(watercolour_request)
decode_and_show(output)

In [None]:
# NOTE: as above, init_image and image_strength are ignored by the inference script.
basquiat_request = GenerationRequest(
    text_prompts=[TextPrompt(text="cat painted by Basquiat")],
    init_image=cat_data,
    cfg_scale=17,
    image_strength=0.4,
    seed=42,
)
output = deployed_model.predict(basquiat_request)
decode_and_show(output)

In [None]:
!aws sagemaker list-endpoints

In [None]:
# Delete everything deploy() created: the model, the endpoint configuration and the endpoint.
# NOTE(review): assumes StabilityPredictor subclasses sagemaker.predictor.Predictor.
# delete_model() resolves the model names via the endpoint configuration, so run it first.
deployed_model.delete_model()
deployed_model.delete_endpoint(delete_endpoint_config=True)

# Rerun the `aws sagemaker list-endpoints` cell above to confirm that it's gone.