In [None]:
import os

# clone the repository
%cd /content
if not os.path.exists('MODNet'):
  !git clone https://github.com/ZHKKKe/MODNet
%cd MODNet/

# dowload the pre-trained ckpt for image matting
pretrained_ckpt = 'pretrained/modnet_photographic_portrait_matting.ckpt'
if not os.path.exists(pretrained_ckpt):
  !gdown --id 1mcr7ALciuAsHCpLnrtG_eop5-EYhbCmz \
          -O pretrained/modnet_photographic_portrait_matting.ckpt


In [None]:
# import os
# import shutil

# # clean and rebuild the image folders
# input_folder = 'demo/image_matting/colab/input'
# if os.path.exists(input_folder):
#   shutil.rmtree(input_folder)
# os.makedirs(input_folder)

# output_folder = 'demo/image_matting/colab/output'
# if os.path.exists(output_folder):
#   shutil.rmtree(output_folder)
# os.makedirs(output_folder)


In [None]:
!python -m demo.image_matting.colab.inference \
        --input-path images/input \
        --output-path images/output \
        --ckpt-path ./pretrained/modnet_photographic_portrait_matting.ckpt


In [None]:
import numpy as np
from PIL import Image

def combined_display(image, matte):
  """Build a side-by-side preview of an image, its matted foreground, and the matte.

  Args:
    image: PIL image of the original input (any mode; it is converted to RGB).
    matte: PIL image of the predicted alpha matte, where 255 keeps the input
      pixel and 0 replaces it with white (any mode; it is converted to 'L').

  Returns:
    A PIL RGB image laid out as [input | foreground on white | matte], scaled so
    the combined strip is 800 px wide.
  """
  # Normalise modes up front. convert('RGB') covers grayscale, RGBA and palette
  # ('P') inputs; np.asarray on a 'P' image would yield palette indices, not
  # colours. convert('L') guarantees the matte is a single channel.
  image = image.convert('RGB')
  matte = matte.convert('L')
  # np.concatenate below needs equal heights/widths; align the matte if needed.
  if matte.size != image.size:
    matte = matte.resize(image.size)

  # calculate display resolution: three panels side by side, 800 px total width
  # (height clamped to >= 1 so very wide images do not produce a zero-size resize)
  w, h = image.width, image.height
  rw, rh = 800, max(1, int(h * 800 / (3 * w)))

  # obtain predicted foreground, composited over a white background
  image = np.asarray(image)
  matte = np.repeat(np.asarray(matte)[:, :, None], 3, axis=2) / 255
  foreground = image * matte + np.full(image.shape, 255) * (1 - matte)

  # combine image, foreground, and alpha into one line
  combined = np.concatenate((image, foreground, matte * 255), axis=1)
  combined = Image.fromarray(np.uint8(combined)).resize((rw, rh))
  return combined

# visualize all images: for every file in the input folder, show
# [input | foreground | matte] using the matte written to the output folder
input_folder = 'images/input'
output_folder = 'images/output'

# os.listdir order is arbitrary; sort so the display order is reproducible.
image_names = sorted(os.listdir(input_folder))
for image_name in image_names:
  if os.path.isdir(os.path.join(input_folder, image_name)):
    continue

  # Matte file name = text before the first '.' plus '.png'.
  # NOTE(review): presumably mirrors the inference script's output naming; keep in sync.
  matte_name = image_name.split('.')[0] + '.png'
  # Context managers close both file handles once the preview has been built.
  with Image.open(os.path.join(input_folder, image_name)) as image:
    with Image.open(os.path.join(output_folder, matte_name)) as matte:
      display(combined_display(image, matte))
  print(image_name, '\n')

In [None]:
# Create model artifact: pack the MODNet checkpoint into model.tar.gz, the
# archive that the next cell uploads to S3 and that PyTorchModel receives as
# model_data. The checkpoint keeps its relative pretrained/ path inside the tar.

import tarfile

with tarfile.open(name="model.tar.gz", mode="w:gz") as f:
    f.add(name="pretrained/modnet_photographic_portrait_matting.ckpt")

In [None]:
# Upload model archive to S3
# Resolves the execution role and the session's default bucket, then uploads
# model.tar.gz under model/pytorch/modnet; the resulting S3 URI is printed and
# kept in pt_modnet_model_data for the PyTorchModel cell.

import boto3
import sagemaker

role = sagemaker.get_execution_role()
sess = sagemaker.Session()
region = sess.boto_region_name
bucket = sess.default_bucket()

# Reuse the bucket resolved above rather than querying the session again.
pt_modnet_model_data = sess.upload_data(
    path="model.tar.gz", bucket=bucket, key_prefix="model/pytorch/modnet"
)

print(pt_modnet_model_data)

In [None]:
from sagemaker.pytorch import PyTorchModel

# Wrap the uploaded archive as a SageMaker PyTorch model. Request handling lives
# in inference.py inside the local `code` directory (shipped via source_dir);
# the serving container is selected by framework_version / py_version.
model = PyTorchModel(
    model_data=pt_modnet_model_data,
    role=role,
    framework_version="1.5.0",
    py_version="py3",
    source_dir="code",
    entry_point="inference.py",
)

In [None]:
from sagemaker.serializers import IdentitySerializer
from sagemaker.deserializers import BytesDeserializer

# Deployment target: local_mode = True serves from a local container
# (instance_type "local"); False deploys to a hosted ml.g4dn.xlarge instance.
# Local mode is not supported in SageMaker Studio.
local_mode = False
instance_type = "local" if local_mode else "ml.g4dn.xlarge"

# Request bodies are forwarded unchanged (IdentitySerializer) and responses are
# returned as raw bytes (BytesDeserializer).
predictor = model.deploy(
    instance_type=instance_type,
    initial_instance_count=1,
    serializer=IdentitySerializer(),
    deserializer=BytesDeserializer(),
)

In [None]:
# use this to update endpoint with new inference code

# predictor.delete_endpoint(delete_endpoint_config=True)
# model.delete_model()

# predictor = model.deploy(
#     initial_instance_count=1,
#     instance_type=instance_type,
#     serializer=IdentitySerializer(),
#     deserializer=BytesDeserializer(),
# )

In [None]:
# test predictor with sagemaker SDK

import io
from PIL import Image

# Read the payload up front so the file handle is closed immediately; the
# predictor's IdentitySerializer forwards these bytes unchanged.
with open(os.path.join('images/input', 'download.jpg'), 'rb') as image_file:
    payload = image_file.read()

res = predictor.predict(payload, {'Accept': 'application/octet-stream'})

# NOTE(review): assumes the endpoint returns an encoded image (the boto3 cell
# decodes the response body with Image.open too); render it instead of
# printing raw bytes.
display(Image.open(io.BytesIO(res)))

In [None]:
# test predictor with API
# do this call from Lambda/EC2/Container etc
# (outside this notebook, set endpoint_name explicitly instead of reading it
# from `predictor`, and use an image-preview helper in place of combined_display)
import boto3
import os
from PIL import Image
import io

runtime = boto3.client('sagemaker-runtime')

# Read the input once; the same bytes are sent to the endpoint and decoded locally.
with open(os.path.join('images/input', 'download.jpg'), 'rb') as image_file:
    image_bytes = image_file.read()
endpoint_name = predictor.endpoint_name

# 'application/octet-stream' is the standard MIME type for raw bytes and matches
# the Accept header used in the SDK test.
response = runtime.invoke_endpoint(
  EndpointName=endpoint_name,
  Accept='application/octet-stream',
  ContentType='application/octet-stream',
  Body=image_bytes
)

res_byte_im = response['Body'].read()

api_im = Image.open(io.BytesIO(image_bytes))
api_matte = Image.open(io.BytesIO(res_byte_im))
display(combined_display(api_im, api_matte))