In [None]:
# Uncomment to upgrade the required SDKs in the kernel's environment:
# %pip install -q --upgrade sagemaker awscli boto3

In [None]:
import tarfile
import sagemaker
import tensorflow as tf
import tensorflow.keras as keras
import shutil
import os
import json
import numpy as np
import time

# Resolve the SageMaker context once; later cells reuse these module-level names.
sess = sagemaker.Session()
role = sagemaker.get_execution_role()  # passed to create_model as ExecutionRoleArn
bucket = sess.default_bucket()         # the session's default S3 bucket
region = sess.boto_region_name         # used when looking up the serving image

In [None]:
import boto3
# Low-level boto3 clients:
#   client  - used below for create_model / create_endpoint_config / create_endpoint / describe_endpoint
#   runtime - used below for invoke_endpoint
client = boto3.client('sagemaker')
runtime = boto3.client('sagemaker-runtime')

## Download model and upload to S3

In [None]:
from tensorflow.keras.applications.resnet50 import ResNet50

def load_save_resnet50_model(model_path):
    """Export an ImageNet-pretrained ResNet50 in TensorFlow SavedModel format.

    The model is instantiated first; then any existing directory at
    ``model_path`` is removed so the export always starts from a clean
    location, and the model is saved without optimizer state.

    Args:
        model_path: Directory the model is written to.
    """
    resnet = ResNet50(weights='imagenet')
    # Clear a previous export; a missing directory is not treated as an error.
    shutil.rmtree(model_path, ignore_errors=True)
    resnet.save(model_path, save_format='tf', include_optimizer=False)

# The model is exported to <saved_model_dir>/<model_ver>, i.e. a versioned
# sub-directory underneath the directory that is later archived.
model_ver = '1'
saved_model_dir = 'resnet50_saved_model'
model_path = os.path.join(saved_model_dir, model_ver)

load_save_resnet50_model(model_path)

In [None]:
shutil.rmtree('model.tar.gz', ignore_errors=True)
!tar cvfz model.tar.gz code -C resnet50_saved_model .

In [None]:
# Upload the packaged model through the SageMaker session; the returned value
# is used later as the ModelDataUrl when the model is created.
prefix = 'keras_models_serverless'
s3_model_path = sess.upload_data(
    path='model.tar.gz',
    key_prefix=prefix,
)

## Get model serving container

In [None]:
# Look up the TensorFlow inference container image for this region.
image_uri = sagemaker.image_uris.retrieve(
    framework="tensorflow",
    version="2.1",
    py_version="py3",
    image_scope='inference',
    instance_type='ml.c5.large',
    region=region,
)

print('Container image with TensorFlow Serving:')
image_uri

## 1. Create a model

In [None]:
from time import gmtime, strftime

# A UTC timestamp suffix keeps the model name distinct across runs.
model_name = 'keras-serverless' + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print('Model name: ' + model_name)

# Single-container definition: serving image plus the model archive in S3.
primary_container = {
    "Image": image_uri,
    "Mode": "SingleModel",
    "ModelDataUrl": s3_model_path,
}

create_model_response = client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    Containers=[primary_container],
)

print("Model Arn: " + create_model_response['ModelArn'])

## 2. Define endpoint configuration

**MemorySizeInMB:** Choose one of the following memory sizes: 1024 MB, 2048 MB, 3072 MB, 4096 MB, 5120 MB, or 6144 MB.

**MaxConcurrency:** Choose a value between 1 and 50.

In [None]:
# A UTC timestamp suffix keeps the endpoint-config name distinct across runs.
keras_epc_name = "keras-serverless-epc" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

# The ServerlessConfig entry (memory size and max concurrency) replaces the
# instance type/count that a real-time variant would specify.
serverless_variant = {
    "VariantName": "kerasVariant",
    "ModelName": model_name,
    "ServerlessConfig": {
        "MemorySizeInMB": 4096,
        "MaxConcurrency": 1,
    },
}

endpoint_config_response = client.create_endpoint_config(
    EndpointConfigName=keras_epc_name,
    ProductionVariants=[serverless_variant],
)

print("Serverless Endpoint Configuration Arn: " + endpoint_config_response['EndpointConfigArn'])

## 3. Create an endpoint

In [None]:
# Name the endpoint with a UTC timestamp suffix, built via the `time` module
# in the same way as the endpoint-config name.
endpoint_name = "keras-serverless-ep" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

# Start creating the endpoint from the serverless endpoint configuration.
create_endpoint_response = client.create_endpoint(
    EndpointConfigName=keras_epc_name,
    EndpointName=endpoint_name,
)

print('Endpoint Arn: ' + create_endpoint_response['EndpointArn'])

In [None]:
%%time
# wait for endpoint to reach a terminal state (InService) using describe endpoint
import time

describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)

while describe_endpoint_response["EndpointStatus"] == "Creating":
    describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)
    print(describe_endpoint_response["EndpointStatus"])
    time.sleep(15)

### Test endpoint

In [None]:
file_name = 'kitten.jpg'

!wget -q https://s3.amazonaws.com/model-server/inputs/kitten.jpg -O {file_name}
with open(file_name, 'rb') as f:
    image_data = f.read()

In [None]:
%%time
response = runtime.invoke_endpoint(EndpointName=endpoint_name, 
                                   ContentType='application/x-image', 
                                   Body=image_data)

pred = np.array(json.loads(response['Body'].read())['predictions'][0]).argsort()[-5:][::-1] 

In [None]:
# NOTE(review): imagenet_class_index.json is read from the working directory,
# but no cell in view downloads it - confirm it ships alongside the notebook.
with open('imagenet_class_index.json', 'r') as index_file:
    labels = json.load(index_file)

# Entries are keyed by the class index as a string; element [1] of each entry
# is what gets printed (presumably the human-readable class name - verify).
for class_idx in pred:
    print(labels[str(class_idx)][1])

In [None]:
%%time
response = runtime.invoke_endpoint(EndpointName=endpoint_name, 
                                   ContentType='application/x-image', 
                                   Body=image_data)

pred = np.array(json.loads(response['Body'].read())['predictions'][0]).argsort()[-5:][::-1] 

In [None]:
# Second endpoint configuration for the same model, this time backed by a
# provisioned instance (InstanceType / InitialInstanceCount) instead of a
# ServerlessConfig.
keras_realtime_epc_name = "keras-realtime-epc" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

realtime_variant = {
    "VariantName": "kerasVariant",
    "ModelName": model_name,
    "InstanceType": "ml.c5.xlarge",
    "InitialInstanceCount": 1,
}

endpoint_config_response = client.create_endpoint_config(
    EndpointConfigName=keras_realtime_epc_name,
    ProductionVariants=[realtime_variant],
)

print("Realtime Endpoint Configuration Arn: " + endpoint_config_response['EndpointConfigArn'])

In [None]:
# Point the existing endpoint at the real-time configuration; the endpoint
# name is unchanged, so the later invoke cell keeps using endpoint_name.
realtime_endpoint_response = client.update_endpoint(
    EndpointConfigName=keras_realtime_epc_name,
    EndpointName=endpoint_name,
)

In [None]:
%%time
# wait for endpoint to reach a terminal state (InService) using describe endpoint
describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)

while describe_endpoint_response["EndpointStatus"] == "Updating":
    describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)
    print(describe_endpoint_response["EndpointStatus"])
    time.sleep(15)

In [None]:
%%time
response = runtime.invoke_endpoint(EndpointName=endpoint_name, 
                                   ContentType='application/x-image', 
                                   Body=image_data)

pred = np.array(json.loads(response['Body'].read())['predictions'][0]).argsort()[-5:][::-1] 