In [1]:
! ls

cnn14max.tar.gz                    sound-event-deployment-Copy1.ipynb
[34mcode[m[m                               sound-event-deployment.ipynb
[34mcontainer[m[m


In [3]:
import os
import json

import boto3
import sagemaker
from sagemaker.pytorch import PyTorchModel
from sagemaker import get_execution_role, Session

sess = Session()

# for acoustic account & local deployment
role = 'arn:aws:iam::302145289873:role/service-role/AmazonSageMaker-ExecutionRole-20211223T112159'

# upload the trained pytorch model to S3
cnn14max_model_data = sess.upload_data(
    path="cnn14max.tar.gz", bucket=sess.default_bucket(), key_prefix="sound-event-model/pytorch"
)

In [5]:
cnn14max_model_data

's3://sagemaker-ap-southeast-1-302145289873/sound-event-model/pytorch/cnn14max.tar.gz'

In [6]:
model = PyTorchModel(
    entry_point = "inference.py",
    source_dir = "code",
    role = role,
    model_data = cnn14max_model_data,
    framework_version = "1.9.0",
    py_version = "py38",
    image_uri = '302145289873.dkr.ecr.ap-southeast-1.amazonaws.com/sound-event-detection:v1'
)

In [7]:
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

# set local_mode to False if you want to deploy on a remote
# SageMaker instance

local_mode = False

if local_mode:
    instance_type = "local"
else:
    instance_type = "ml.c4.xlarge"

predictor = model.deploy(
    initial_instance_count =  1,
    instance_type =instance_type,
    serializer = JSONSerializer(),
    deserializer = JSONDeserializer()
)

----------!

In [8]:
# TESTING
import numpy as np
import sagemaker

rng = np.random.default_rng()
audio = 2 * rng.random(size=44100, dtype=np.float32) - 1
data_dict = {"inputs" : audio.tolist()}
predictor.predict(data_dict)

[[0.0010335678234696388,
  7.137481588870287e-06,
  1.1322070349706337e-05,
  8.696159966348205e-06,
  2.402715608695871e-06,
  3.4492586564738303e-06,
  3.335505027735053e-07,
  1.1256815923843533e-05,
  2.9589864425361156e-05,
  0.00020050859893672168,
  9.063819743460044e-05,
  8.019575761863962e-05,
  2.305712087036227e-06,
  2.182652497140225e-05,
  0.00010989647853421047,
  3.1973701197784976e-07,
  3.839698820229387e-06,
  2.102123744407436e-06,
  2.528662889744737e-06,
  4.753420398628805e-06,
  1.3424393046079786e-06,
  1.5409175375680206e-06,
  3.921057214029133e-06,
  7.313649803109001e-06,
  1.3773871614830568e-05,
  3.7409429296531016e-06,
  3.3732176234479994e-05,
  3.089029996772297e-05,
  6.311317974905251e-06,
  5.385564918469754e-07,
  2.7414491796662332e-06,
  8.466465260426048e-06,
  8.325172530021518e-06,
  5.767331458628178e-06,
  2.6276686639903346e-06,
  8.911444638215471e-06,
  4.501685452851234e-06,
  2.6850377707887674e-06,
  0.0001060889262589626,
  8.301894

In [9]:
import boto3
import json
runtime= boto3.client('runtime.sagemaker')
response = runtime.invoke_endpoint(EndpointName=predictor.endpoint,
                                       ContentType='application/json',
                                       Body=json.dumps(data_dict))

The endpoint attribute has been renamed in sagemaker>=2.
See: https://sagemaker.readthedocs.io/en/stable/v2.html for details.


In [10]:
result = json.loads(response['Body'].read().decode())

In [11]:
result

[[0.0010335678234696388,
  7.137481588870287e-06,
  1.1322070349706337e-05,
  8.696159966348205e-06,
  2.402715608695871e-06,
  3.4492586564738303e-06,
  3.335505027735053e-07,
  1.1256815923843533e-05,
  2.9589864425361156e-05,
  0.00020050859893672168,
  9.063819743460044e-05,
  8.019575761863962e-05,
  2.305712087036227e-06,
  2.182652497140225e-05,
  0.00010989647853421047,
  3.1973701197784976e-07,
  3.839698820229387e-06,
  2.102123744407436e-06,
  2.528662889744737e-06,
  4.753420398628805e-06,
  1.3424393046079786e-06,
  1.5409175375680206e-06,
  3.921057214029133e-06,
  7.313649803109001e-06,
  1.3773871614830568e-05,
  3.7409429296531016e-06,
  3.3732176234479994e-05,
  3.089029996772297e-05,
  6.311317974905251e-06,
  5.385564918469754e-07,
  2.7414491796662332e-06,
  8.466465260426048e-06,
  8.325172530021518e-06,
  5.767331458628178e-06,
  2.6276686639903346e-06,
  8.911444638215471e-06,
  4.501685452851234e-06,
  2.6850377707887674e-06,
  0.0001060889262589626,
  8.301894