<a href="https://colab.research.google.com/github/shameeryaseen/Picture_Healer/blob/main/stable_diffusion_api_on_sagemaker.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install "sagemaker==2.116.0" "huggingface_hub==0.10.1" --upgrade

Collecting sagemaker==2.116.0
  Downloading sagemaker-2.116.0.tar.gz (592 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/592.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━[0m [32m501.8/592.4 kB[0m [31m15.0 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m592.4/592.4 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting huggingface_hub==0.10.1
  Downloading huggingface_hub-0.10.1-py3-none-any.whl.metadata (6.1 kB)
Collecting attrs<23,>=20.3.0 (from sagemaker==2.116.0)
  Downloading attrs-22.2.0-py3-none-any.whl.metadata (13 kB)
Collecting boto3<2.0,>=1.20.21 (from sagemaker==2.116.0)
  Downloading boto3-1.35.81-py3-none-any.whl.metadata (6.7 kB)
Collecting protobuf<4.0,>=3.1 (from sagemaker==2.116.0)
  Downloading protobuf-3.20.3-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.w

In [None]:
!pip install python-dotenv

In [None]:
import os

import boto3
import sagemaker
from dotenv import load_dotenv

# Load AWS credentials/region from a local .env file into os.environ.
load_dotenv()

# Configure boto3's default session so plain `boto3.client(...)` calls pick up these
# credentials too. setup_default_session() returns None, so the session object itself
# must be read back from boto3.DEFAULT_SESSION rather than from the call's result.
boto3.setup_default_session(
    aws_access_key_id=os.environ.get('aws_access_key_id'),
    aws_secret_access_key=os.environ.get('aws_secret_access_key'),
    region_name=os.environ.get('region_name'),
)
boto_session = boto3.DEFAULT_SESSION

# Create the SageMaker session only after credentials are configured (creating one
# earlier can fail or bind to the wrong region).
sess = sagemaker.Session(boto_session=boto_session)

print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

In [None]:
import os

# exist_ok makes the cell safe to re-run (plain `mkdir code` fails once code/ exists).
os.makedirs("code", exist_ok=True)

In [None]:
%%writefile code/requirements.txt
# Extra packages for code/inference.py, which imports diffusers.
# NOTE(review): presumably installed by the SageMaker Hugging Face container at start-up,
# replacing the image's transformers 4.17 with the pin below — confirm.
diffusers==0.6.0
transformers==4.23.1

In [None]:
%%writefile code/inference.py
import base64
import torch
from io import BytesIO
from diffusers import StableDiffusionPipeline


def model_fn(model_dir):
    """Build the Stable Diffusion pipeline stored in ``model_dir``.

    Weights are loaded in half precision (float16) and the pipeline is moved
    to the CUDA device before being returned.
    NOTE(review): presumably invoked once at container start-up by the SageMaker
    inference toolkit, which then passes the result to ``predict_fn`` — confirm.
    """
    half_precision_pipe = StableDiffusionPipeline.from_pretrained(
        model_dir,
        torch_dtype=torch.float16,
    )
    return half_precision_pipe.to("cuda")


def predict_fn(data, pipe):
    """Generate images for one request and return them base64-encoded.

    Args:
        data: deserialized request dict. ``"inputs"`` (required) is the prompt passed
            to ``pipe``. Optional ``num_inference_steps`` (default 20),
            ``guidance_scale`` (default 7.5) and ``num_images_per_prompt``
            (default 1) are forwarded as keyword arguments. These keys are popped
            from ``data``.
        pipe: callable pipeline (the object returned by ``model_fn``), expected to
            return a mapping whose ``"images"`` entry is an iterable of images
            supporting ``save(buffer, format=...)``.

    Returns:
        ``{"generated_images": [...]}`` with one base64-encoded JPEG string per image.

    Raises:
        ValueError: if ``"inputs"`` is missing or None.
    """
    # get prompt & parameters
    prompt = data.pop("inputs", None)
    if prompt is None:
        # Fail fast with a clear message instead of passing the request dict itself
        # to the pipeline as the prompt.
        raise ValueError('Request body must contain an "inputs" prompt')

    # set valid HP for stable diffusion
    num_inference_steps = data.pop("num_inference_steps", 20)
    guidance_scale = data.pop("guidance_scale", 7.5)
    num_images_per_prompt = data.pop("num_images_per_prompt", 1)

    # run generation with parameters
    generated_images = pipe(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        num_images_per_prompt=num_images_per_prompt,
    )["images"]

    # encode every image as base64 JPEG text so it can be returned in a JSON response
    encoded_images = []
    for image in generated_images:
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        encoded_images.append(base64.b64encode(buffered.getvalue()).decode())

    # create response
    return {"generated_images": encoded_images}

In [None]:
from distutils.dir_util import copy_tree
from pathlib import Path
from huggingface_hub import snapshot_download
import random

import os
import shutil

HF_MODEL_ID = "CompVis/stable-diffusion-v1-4"
HF_TOKEN = os.environ.get('HF_TOKEN')  # your hf token: https://huggingface.co/settings/tokens
# Test truthiness rather than len(): an unset variable yields None, and len(None) would
# raise a TypeError instead of showing this message.
assert HF_TOKEN, "Please set HF_TOKEN to your huggingface token. You can find it here: https://huggingface.co/settings/tokens"

# download snapshot (the "fp16" revision of the model repo)
snapshot_dir = snapshot_download(repo_id=HF_MODEL_ID, revision="fp16", use_auth_token=HF_TOKEN)

# create model dir (random suffix so each run gets its own directory)
model_tar = Path(f"model-{random.getrandbits(16)}")
model_tar.mkdir(exist_ok=True)

# copy snapshot to model dir. shutil.copytree replaces distutils' copy_tree (distutils is
# deprecated by PEP 632 and removed in Python 3.12); with the default symlinks=False it
# copies the contents of symlinked files, as copy_tree did.
shutil.copytree(snapshot_dir, model_tar, dirs_exist_ok=True)

In [None]:
import shutil

# copy code/ (inference.py + requirements.txt) into the model dir under code/.
# shutil.copytree replaces the deprecated distutils copy_tree; dirs_exist_ok keeps re-runs safe.
shutil.copytree("code", model_tar.joinpath("code"), dirs_exist_ok=True)

In [None]:
import tarfile
import os

# helper to create the model.tar.gz
def compress(tar_dir=None, output_file="model.tar.gz"):
    """Pack the *contents* of ``tar_dir`` into a gzip-compressed tarball.

    Entries are stored relative to ``tar_dir`` (no top-level folder), so
    ``tar_dir/code/x.py`` is archived as ``code/x.py``. Each top-level entry name
    is printed as it is added, in sorted order so archives are reproducible.

    Args:
        tar_dir: directory whose contents are archived (required).
        output_file: archive path, resolved against the current working directory.

    Returns:
        The path of the written archive.

    Raises:
        ValueError: if ``tar_dir`` is None.
    """
    if tar_dir is None:
        raise ValueError("tar_dir must be the directory to archive")
    output_path = os.path.join(os.getcwd(), output_file)
    # Add entries by full path with an explicit arcname instead of os.chdir(): a failure
    # mid-archive can no longer leave the kernel in a different working directory.
    with tarfile.open(output_path, "w:gz") as tar:
        for item in sorted(os.listdir(tar_dir)):
            print(item)
            tar.add(os.path.join(tar_dir, item), arcname=item)
    return output_path

# Pack the model directory (snapshot weights + code/) into ./model.tar.gz for the S3 upload below.
compress(str(model_tar))

In [None]:
from sagemaker.s3 import S3Uploader

# upload model.tar.gz to s3
# upload model.tar.gz to s3 using the configured session, so the upload uses the same
# credentials/region as the session that owns the target bucket.
s3_model_uri = S3Uploader.upload(
    local_path="model.tar.gz",
    desired_s3_uri=f"s3://{sess.default_bucket()}/stable-diffusion-v1-4",
    sagemaker_session=sess,
)

print(f"model uploaded to: {s3_model_uri}")


In [None]:
from sagemaker.huggingface.model import HuggingFaceModel


import os

# IAM role ARN with permissions to create an Endpoint, read from .env like the AWS keys.
SAGEMAKER_ROLE = os.environ.get("sagemaker_role_arn")
assert SAGEMAKER_ROLE, "Please set sagemaker_role_arn in your .env to an IAM role with permissions to create an Endpoint"

# create Hugging Face Model Class pointing at the archive uploaded above
huggingface_model = HuggingFaceModel(
   model_data=s3_model_uri,      # S3 URI of model.tar.gz (weights + code/)
   role=SAGEMAKER_ROLE,          # iam role with permissions to create an Endpoint
   transformers_version="4.17",  # transformers version used
   pytorch_version="1.10",       # pytorch version used
   py_version='py38',            # python version used
   sagemaker_session=sess,       # reuse the session configured from .env
)

# deploy the endpoint
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g4dn.xlarge"
)

In [None]:
import boto3
from PIL import Image
from io import BytesIO
from IPython.display import display
import base64
import matplotlib.pyplot as plt
import time

# helper decoder
def decode_base64_image(image_string):
    """Decode a base64-encoded image string into a PIL ``Image``."""
    raw_bytes = base64.b64decode(image_string)
    return Image.open(BytesIO(raw_bytes))

# display PIL images as grid
def display_images(images=None, columns=3, width=100, height=100):
    """Draw images in a grid with ``columns`` images per row.

    Args:
        images: sequence of images accepted by ``plt.imshow``; None or empty draws nothing.
        columns: number of grid columns.
        width: figure width in inches, passed to ``plt.figure``.
        height: figure height in inches, passed to ``plt.figure``.
    """
    if images is None or len(images) == 0:
        return
    # Ceiling division allocates exactly enough rows. The former int(n / columns + 1)
    # added an empty row whenever n was a multiple of columns, shrinking every image.
    rows = (len(images) + columns - 1) // columns
    plt.figure(figsize=(width, height))
    for position, image in enumerate(images, start=1):
        plt.subplot(rows, columns, position)
        plt.axis('off')
        plt.imshow(image)

import json

# Endpoint to call. Leave as None to use the endpoint deployed above, or set it to the
# name of an already-running endpoint (e.g. "huggingface-pytorch-inference-2023-11-05-04-19-03-160")
# to skip redeploying.
ENDPOINT_NAME = None
endpoint_name = ENDPOINT_NAME or predictor.endpoint_name

# Runtime client built from the configured session (same credentials/region as the deployment).
client = sess.boto_session.client('sagemaker-runtime')

prompt = "A dog trying to catch a flying pizza art"
num_images_per_prompt = 1

payload = {
    "inputs": prompt,
    "num_images_per_prompt": num_images_per_prompt
}
serialized_payload = json.dumps(payload)  # Serialize the payload to JSON format

# run prediction; the timer covers only the request/response round trip
start_time = time.time()
response = client.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=serialized_payload,
    ContentType='application/json'  # Specify the format of the payload
)
response_payload = json.loads(response['Body'].read().decode("utf-8"))
end_time = time.time()

elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time} seconds")

# decode images
decoded_images = [decode_base64_image(image) for image in response_payload["generated_images"]]

# visualize generation
display_images(decoded_images)