In [None]:
import sagemaker
import boto3
# Bootstrap session, used here only to discover the account's default bucket.
sess = sagemaker.Session()

# S3 bucket for uploading data, models and logs. Leave it as None to use the
# session's default bucket (SageMaker creates that bucket if it does not exist
# yet), or set a bucket name to override it.
sagemaker_session_bucket = None
if sess is not None and sagemaker_session_bucket is None:
    sagemaker_session_bucket = sess.default_bucket()

# Prefer the role attached to the current SageMaker environment. If
# get_execution_role() cannot resolve one it raises ValueError (typically when
# running outside SageMaker), so look up a named IAM role instead.
try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

# Rebind the session so that everything below uses the resolved bucket.
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

In [ ]:
from sagemaker.huggingface.model import HuggingFaceModel

# S3 location of the packaged model artifact (model weights plus the custom
# inference script, bundled as model.tar.gz).
s3_model_uri = "s3://sagemaker-ap-northeast-2-590183743566/photomaker/model.tar.gz"

# Describe the model for SageMaker. The transformers/pytorch/python versions
# select the Hugging Face serving container the artifact runs in, and `role`
# must be allowed to create an endpoint.
huggingface_model = HuggingFaceModel(
    model_data=s3_model_uri,
    role=role,
    transformers_version="4.37.0",
    pytorch_version="2.1.0",
    py_version='py310',
)

# Deploy the model to a real-time endpoint backed by one ml.g4dn.xlarge
# instance; `predictor` is the handle used to call and later delete it.
predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g4dn.xlarge",
)

In [ ]:
from PIL import Image
from io import BytesIO
from IPython.display import display
import base64
import matplotlib.pyplot as plt

# helper decoder
def decode_base64_image(image_string):
  base64_image = base64.b64decode(image_string)
  buffer = BytesIO(base64_image)
  return Image.open(buffer)

def display_images(images=None, columns=3, width=100, height=100):
    """Plot a sequence of images on a matplotlib subplot grid.

    Args:
        images: Sized sequence of images accepted by ``plt.imshow`` (for
            example PIL images). ``None`` or an empty sequence draws nothing.
        columns: Number of grid columns.
        width: Figure width passed to ``plt.figure(figsize=...)``.
        height: Figure height passed to ``plt.figure(figsize=...)``.
    """
    # Nothing to draw: avoid len(None) and an empty, oversized figure.
    if images is None:
        return
    count = len(images)
    if count == 0:
        return

    # Ceiling division gives exactly enough rows for every image.
    # int(count / columns + 1) would add an empty row whenever count is a
    # multiple of columns.
    rows = (count + columns - 1) // columns

    plt.figure(figsize=(width, height))
    # subplot indices are 1-based and fill the grid row by row.
    for position, image in enumerate(images, start=1):
        plt.subplot(rows, columns, position)
        plt.axis('off')
        plt.imshow(image)

In [ ]:
# Text prompt describing the look to generate.
prompt = (
    "A luxurious and sophisticated look: Tailored blazer with sharp lines, "
    "silk blouse, wide-leg trousers or a midi skirt, designer heels, "
    "statement jewelry, sleek hair and makeup, structured handbag."
)

# ID reference images passed to the endpoint by URL.
# NOTE(review): presumably photos of the subject that the model's inference
# script downloads itself, so they must be reachable from the endpoint; confirm.
photo_base_url = (
    "https://sagemaker-ap-northeast-2-590183743566.s3.ap-northeast-2"
    ".amazonaws.com/photomaker/photo/"
)
input_id_images = [photo_base_url + name for name in ("1.jpg", "2.jpg", "3.png")]

# Build the request payload once, from the prompt that is actually sent, so
# `data` always reflects the real request. The keys must match what the
# deployed inference script reads.
data = {
    "inputs": prompt,
    "input_id_images": input_id_images,
}

# run prediction
response = predictor.predict(data=data)

# The response carries base64-encoded images under "generated_images".
decoded_images = [
    decode_base64_image(encoded) for encoded in response["generated_images"]
]

# visualize generation
display_images(decoded_images)

In [ ]:
# Clean up the resources created by deploy(): the SageMaker model and the
# endpoint.
# NOTE(review): keep delete_model() before delete_endpoint(). The SDK appears
# to resolve the model name through the endpoint config, which
# delete_endpoint() removes by default. Confirm against the SageMaker Python
# SDK docs before reordering.
predictor.delete_model()
predictor.delete_endpoint()