## 初期設定

In [None]:
# Initial setup: AWS clients, region, IAM role ARN and default S3 bucket that
# every later cell relies on.
import json

import boto3
import sagemaker
from sagemaker.estimator import Estimator  # not used by the visible cells

# Region comes from the standard boto3 configuration chain (env/profile).
boto_session = boto3.session.Session()
region = boto_session.region_name

# "sagemaker" is the control-plane API (create/describe/update/delete models,
# endpoint configs, endpoints); "sagemaker-runtime" is used for invoke_endpoint.
client = boto3.client(service_name="sagemaker")
runtime = boto3.client(service_name="sagemaker-runtime")

print(region)

sagemaker_session = sagemaker.Session()
base_job_prefix = "demo-sagemaker-inference"

# The execution role ARN is assembled by hand from the account ID.
# NOTE(review): assumes a role named "SagemakerExecutionRole" exists under
# service-role/ in this account. Inside a SageMaker notebook environment,
# sagemaker.get_execution_role() could presumably be used instead.
account_id = sagemaker_session.account_id()
role = f"arn:aws:iam::{account_id}:role/service-role/SagemakerExecutionRole"
print(role)

default_bucket = sagemaker_session.default_bucket()
print(f"default_bucket = {default_bucket}")

## モデル1の設定

In [None]:
from time import gmtime, strftime

# Inference image in this account's ECR repository named after base_job_prefix
# (tag "latest"). Assumes the image has already been pushed there.
image_uri = f"{account_id}.dkr.ecr.{region}.amazonaws.com/{base_job_prefix}:latest"

# A UTC timestamp suffix keeps the model name unique across runs.
model1_name = "demo-serverless-model1-" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print(f"Model name: {model1_name}")

model1_artifacts = f"s3://{default_bucket}/{base_job_prefix}/model1.tar.gz"
model1_env_vars = {
    "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
    "SOME_ENV_VAR": "myEnvVar",
}

# Single-model container definition: image + model archive + environment.
model1_container = {
    "Image": image_uri,
    "Mode": "SingleModel",
    "ModelDataUrl": model1_artifacts,
    "Environment": model1_env_vars,
}
create_model_response = client.create_model(
    ModelName=model1_name,
    Containers=[model1_container],
    ExecutionRoleArn=role,
)
print(f"Model Arn: {create_model_response['ModelArn']}")

In [None]:
# Serverless endpoint configuration for model 1: 1024 MB memory and at most
# one concurrent invocation.
epc1_name = "demo-serverless-epc1-" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())

epc1_variant = {
    "VariantName": "Variant1",
    "ModelName": model1_name,
    "ServerlessConfig": {"MemorySizeInMB": 1024, "MaxConcurrency": 1},
}
endpoint_config_response = client.create_endpoint_config(
    EndpointConfigName=epc1_name,
    ProductionVariants=[epc1_variant],
)
print(
    f"Endpoint Configuration Arn: {endpoint_config_response['EndpointConfigArn']}"
)

## モデル2の設定 (最初は使わない)

In [None]:
# Second model, created up front so the endpoint can later be switched to it.
# It reuses the same container image as model 1 with a different archive.
model2_name = "demo-serverless-model2-" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
# Fixed: previously printed model1_name here.
print("Model name: " + model2_name)

model2_artifacts = f"s3://{default_bucket}/{base_job_prefix}/model2.tar.gz"

model2_env_vars = {"SAGEMAKER_CONTAINER_LOG_LEVEL": "20", "SOME_ENV_VAR": "myEnvVar"}

create_model_response = client.create_model(
    ModelName=model2_name,
    Containers=[
        {
            "Image": image_uri,
            "Mode": "SingleModel",
            "ModelDataUrl": model2_artifacts,
            "Environment": model2_env_vars,
        }
    ],
    ExecutionRoleArn=role,
)

print("Model Arn: " + create_model_response["ModelArn"])

In [None]:
# Serverless endpoint configuration for model 2, with the same sizing as
# epc1: 1024 MB memory and at most one concurrent invocation.
epc2_name = "demo-serverless-epc2-" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())

epc2_variant = {
    "VariantName": "Variant2",
    "ModelName": model2_name,
    "ServerlessConfig": {"MemorySizeInMB": 1024, "MaxConcurrency": 1},
}
endpoint_config_response = client.create_endpoint_config(
    EndpointConfigName=epc2_name,
    ProductionVariants=[epc2_variant],
)
print(
    f"Endpoint Configuration Arn: {endpoint_config_response['EndpointConfigArn']}"
)

## エンドポイントの設定

In [None]:
# Create the endpoint, initially backed by model 1 via epc1.
# The "-" before the timestamp matches the naming of the other resources
# (it was missing, producing e.g. "demo-serverless-ep2024-...").
endpoint_name = "demo-serverless-ep-" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())

create_endpoint_response = client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=epc1_name,
)

print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

## InServiceになるまで待機

In [None]:
%%time

# Poll describe_endpoint until the endpoint is InService.
# Stop (and raise) on a failure status instead of polling forever, and do
# not sleep again once InService has been observed.
import time

_FAILED_STATUSES = ("Failed", "UpdateRollbackFailed")

describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)
endpoint_status = describe_endpoint_response["EndpointStatus"]

while endpoint_status != "InService" and endpoint_status not in _FAILED_STATUSES:
    print(endpoint_status)
    time.sleep(15)
    describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)
    endpoint_status = describe_endpoint_response["EndpointStatus"]

print(endpoint_status)
if endpoint_status != "InService":
    raise RuntimeError(
        f"Endpoint {endpoint_name} ended in status {endpoint_status}: "
        f"{describe_endpoint_response.get('FailureReason')}"
    )

describe_endpoint_response

## 呼び出し

In [None]:
%%time

# Send a small JSON request to the endpoint and parse its JSON reply.
# json.dumps yields exactly '{"key": "b"}' with the default separators.
request_body = json.dumps({"key": "b"}).encode()

response = runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="application/json",
    Body=request_body,
)

# The response Body is a stream; read it fully, then decode the JSON.
body = response["Body"].read()
data = json.loads(body)

## 削除

In [None]:
# Tear down the endpoint, both endpoint configs and both models.
# delete_endpoint is asynchronous. Per the DeleteEndpointConfig docs, an
# endpoint config must not be deleted while a live endpoint uses it, so wait
# until the endpoint is actually gone before removing configs and models.
client.delete_endpoint(EndpointName=endpoint_name)
client.get_waiter("endpoint_deleted").wait(EndpointName=endpoint_name)

client.delete_endpoint_config(EndpointConfigName=epc1_name)
client.delete_model(ModelName=model1_name)
client.delete_endpoint_config(EndpointConfigName=epc2_name)
client.delete_model(ModelName=model2_name)

In [None]:
# 完全に消すには、ECRやS3のファイルも消す必要がある

## エンドポイントの切り替え（上の「削除」セルより前に実行すること。更新後は待機セルを再実行して InService を確認する）

In [None]:
# Switch the existing endpoint to model 2's configuration (epc2).
# UpdateEndpoint is asynchronous: the endpoint reports "Updating" until the
# switch completes, so re-run the wait cell before invoking.
update_endpoint_response = client.update_endpoint(
    EndpointConfigName=epc2_name,
    EndpointName=endpoint_name,
)

In [None]:
# Switch the endpoint back to model 1's configuration (epc1).
# As with any UpdateEndpoint call this is asynchronous; wait for InService
# again before invoking.
update_endpoint_response = client.update_endpoint(
    EndpointConfigName=epc1_name,
    EndpointName=endpoint_name,
)