##### NOTE: This example has been updated to work with SageMaker SDK 2.x, which introduces breaking changes. Upgrade the SageMaker SDK using the commands below.

In [None]:
!pip install --upgrade pip
!pip -q install sagemaker awscli boto3 pandas --upgrade 

## Example: PyTorch deployments using TorchServe and Amazon SageMaker

In this example, we’ll show you how you can build a TorchServe container and host it using Amazon SageMaker. With Amazon SageMaker hosting you get a fully managed hosting experience. Just specify the instance type and the minimum and maximum number of instances, and SageMaker takes care of the rest.

With a few lines of code, you can ask Amazon SageMaker to launch the instances, download your model from Amazon S3 to your TorchServe container, and set up a secure HTTPS endpoint for your application. On the client side, you get predictions with a simple API call to this secure endpoint backed by TorchServe.

Code, configuration files, Jupyter notebooks and Dockerfiles used in this example are available here:
https://github.com/shashankprasanna/torchserve-examples.git


### Clone the TorchServe repository and install torch-model-archiver

You'll use `torch-model-archiver` to create a model archive file (.mar). The .mar file packages the model definition together with its serialized checkpoint, the `state_dict` (a dictionary object that maps each layer to its parameter tensor).

In [None]:
!git clone https://github.com/pytorch/serve.git
!pip install serve/model-archiver/

### Download a PyTorch model and create a TorchServe archive

In [None]:
!wget -q https://download.pytorch.org/models/densenet161-8d451a50.pth
    
model_file_name = 'densenet161'

!torch-model-archiver --model-name {model_file_name} \
--version 1.0 --model-file serve/examples/image_classifier/densenet_161/model.py \
--serialized-file densenet161-8d451a50.pth \
--extra-files serve/examples/image_classifier/index_to_name.json \
--handler image_classifier

!ls *.mar
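
To check what the archiver packed, you can list the entries of the generated file. The sketch below assumes the archive was written in the archiver's default format, which is a zip file with a `.mar` extension; `list_mar_contents` is a hypothetical helper name, not part of TorchServe.

```python
import zipfile

def list_mar_contents(mar_path):
    """Return the sorted names of the files packed inside a .mar archive."""
    # Assumption: the archive uses the default zip-based .mar format.
    with zipfile.ZipFile(mar_path) as archive:
        return sorted(archive.namelist())
```

Calling `list_mar_contents('densenet161.mar')` after the cell above should show the model file, the serialized weights, and the extra files passed to the archiver.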

### Upload the generated densenet161.mar archive file to Amazon S3
Amazon SageMaker expects model artifacts in a tar.gz file, so compress densenet161.mar into a tar.gz archive and upload it to your default Amazon SageMaker S3 bucket under the models directory.
The upload needs a SageMaker session to look up the default bucket, so the session is created first.

### Create a boto3 session and specify a role with SageMaker access

In [None]:
import boto3, time, json
sess    = boto3.Session()
sm      = sess.client('sagemaker')
region  = sess.region_name
account = boto3.client('sts').get_caller_identity().get('Account')

In [None]:
import sagemaker
role = sagemaker.get_execution_role()
sagemaker_session = sagemaker.Session(boto_session=sess)

In [None]:
bucket_name = sagemaker_session.default_bucket()
prefix = 'torchserve'

!tar cvfz {model_file_name}.tar.gz {model_file_name}.mar
!aws s3 cp {model_file_name}.tar.gz s3://{bucket_name}/{prefix}/models/
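
If you prefer to stay in Python instead of shelling out to `tar`, the packaging step can be sketched with the standard library. `make_model_tarball` is a hypothetical helper; it places the .mar file at the root of the tarball, which is the same layout the `tar` command above produces.

```python
import os
import tarfile

def make_model_tarball(mar_path, tarball_path):
    """Pack a .mar file at the root of a gzip-compressed tarball."""
    with tarfile.open(tarball_path, 'w:gz') as tar:
        # arcname drops any leading directories so the .mar sits at the archive root
        tar.add(mar_path, arcname=os.path.basename(mar_path))
    return tarball_path
```

The resulting file could then be uploaded with `sagemaker_session.upload_data(path='densenet161.tar.gz', bucket=bucket_name, key_prefix=f'{prefix}/models')` in place of the `aws s3 cp` call.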

### Create an Amazon ECR registry
Create a new Docker container registry for your TorchServe container images.

In [None]:
registry_name = 'torchserve-1'
!aws ecr create-repository --repository-name {registry_name}

### Build a TorchServe Docker container and push it to Amazon ECR

In [None]:
image_label = 'v1'
image = f'{account}.dkr.ecr.{region}.amazonaws.com/{registry_name}:{image_label}'

!docker build -t {registry_name}:{image_label} .
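# `aws ecr get-login` was removed in AWS CLI v2; `get-login-password` works in recent v1 releases and in v2
!aws ecr get-login-password --region {region} | docker login --username AWS --password-stdin {account}.dkr.ecr.{region}.amazonaws.com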
!docker tag {registry_name}:{image_label} {image}
!docker push {image}
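
The image URI and the registry host that Docker authenticates against share the same pieces, so it can help to build both from one place. This is a minimal sketch with hypothetical helper names; it assumes a standard commercial AWS region, where ECR hosts end in `amazonaws.com`.

```python
def ecr_registry(account, region):
    """Return the ECR registry host for an account and region."""
    # Assumption: standard AWS partition (other partitions use a different domain suffix).
    return f'{account}.dkr.ecr.{region}.amazonaws.com'

def ecr_image_uri(account, region, repository, tag):
    """Return the full image URI used for docker tag, docker push, and SageMaker."""
    return f'{ecr_registry(account, region)}/{repository}:{tag}'
```

With the variables from this notebook, `ecr_image_uri(account, region, registry_name, image_label)` yields the same string as `image`.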

### Deploy an endpoint and make predictions using the Amazon SageMaker SDK

In [None]:
from sagemaker.model import Model
from sagemaker.predictor import Predictor

model_data = f's3://{bucket_name}/{prefix}/models/{model_file_name}.tar.gz'
sm_model_name = 'torchserve-densenet161'

torchserve_model = Model(model_data = model_data, 
                         image_uri = image,
                         role  = role,
                         predictor_cls=Predictor,
                         name  = sm_model_name)

In [None]:
endpoint_name = 'torchserve-endpoint-' + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

predictor = torchserve_model.deploy(instance_type='ml.m4.xlarge',
                                    initial_instance_count=1,
                                    endpoint_name = endpoint_name)

#### Test the TorchServe hosted model

In [None]:
!wget -q https://s3.amazonaws.com/model-server/inputs/kitten.jpg    
file_name = 'kitten.jpg'
with open(file_name, 'rb') as f:
    payload = f.read()
    
response = predictor.predict(data=payload)
print(*json.loads(response), sep = '\n')

### Deploy an endpoint and make predictions using the AWS SDK for Python (Boto3)

In [None]:
model_data = f's3://{bucket_name}/{prefix}/models/{model_file_name}.tar.gz'
sm_model_name = 'torchserve-densenet161-boto'

container = {
    'Image': image,
    'ModelDataUrl': model_data
}

create_model_response = sm.create_model(
    ModelName         = sm_model_name,
    ExecutionRoleArn  = role,
    PrimaryContainer  = container)

print(create_model_response['ModelArn'])

In [None]:
import time
endpoint_config_name = 'torchserve-endpoint-config-' + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
print(endpoint_config_name)

create_endpoint_config_response = sm.create_endpoint_config(
    EndpointConfigName = endpoint_config_name,
    ProductionVariants = [{
        'InstanceType'        : 'ml.m4.xlarge',
        'InitialVariantWeight': 1,
        'InitialInstanceCount': 1,
        'ModelName'           : sm_model_name,
        'VariantName'         : 'AllTraffic'}])

print("Endpoint Config Arn: " + create_endpoint_config_response['EndpointConfigArn'])

In [None]:
endpoint_name = 'torchserve-endpoint-' + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
print(endpoint_name)

create_endpoint_response = sm.create_endpoint(
    EndpointName         = endpoint_name,
    EndpointConfigName   = endpoint_config_name)
print(create_endpoint_response['EndpointArn'])

In [None]:
resp = sm.describe_endpoint(EndpointName=endpoint_name)
status = resp['EndpointStatus']
print("Status: " + status)

while status=='Creating':
    time.sleep(60)
    resp = sm.describe_endpoint(EndpointName=endpoint_name)
    status = resp['EndpointStatus']
    print("Status: " + status)

print("Arn: " + resp['EndpointArn'])
print("Status: " + status)

In [None]:
!wget https://s3.amazonaws.com/model-server/inputs/kitten.jpg    
file_name = 'kitten.jpg'
with open(file_name, 'rb') as f:
    payload = f.read()

In [None]:
import json
client = boto3.client('sagemaker-runtime')

response = client.invoke_endpoint(EndpointName=endpoint_name, 
                                   ContentType='application/x-image', 
                                   Body=payload)

print(*json.loads(response['Body'].read()), sep = '\n')
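
To work with the predictions rather than just print them, the response body can be turned into sorted `(label, score)` pairs. This is a sketch under an assumption about the response shape: depending on the TorchServe version, the image classifier handler may return either a single JSON object mapping labels to scores or a list of single-entry objects, so the hypothetical `parse_predictions` helper accepts both.

```python
import json

def parse_predictions(body):
    """Return (label, score) pairs sorted by descending score."""
    data = json.loads(body)
    # Assumption: the handler returns either one dict of label -> score
    # or a list of such dicts; both shapes are flattened here.
    items = [data] if isinstance(data, dict) else data
    pairs = [(label, score) for item in items for label, score in item.items()]
    return sorted(pairs, key=lambda pair: pair[1], reverse=True)
```

Because `response['Body']` is a stream that can only be read once, read it into a variable first (`body = response['Body'].read()`) and pass those bytes to `parse_predictions(body)`; the first pair is the top prediction.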