In [1]:
# install pre-trained model packages
# transformers is pinned to 4.12 to match the transformers version in the
# inference container image used for deployment further down
# (huggingface-pytorch-inference:1.9-transformers4.12-...), so the local smoke
# test below runs the same library version as the deployed endpoint.
# NOTE(review): torch/torchvision/torchaudio are unpinned; the deployment image
# tag suggests PyTorch 1.9 — consider pinning to match.
# Restart the kernel after this cell if pip reports updated packages.
%pip install torch torchvision torchaudio -f https://download.pytorch.org/whl/torch_stable.html
%pip install huggingface_hub==0.1.0 
%pip install transformers==4.12

[0mLooking in links: https://download.pytorch.org/whl/torch_stable.html
[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.


In [None]:
# Smoke-test the pre-trained checkpoint locally before deploying it: load the
# tokenizer and classification model from the Hugging Face Hub and classify one
# sentence. The same checkpoint id is served by the SageMaker endpoint below.
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import TextClassificationPipeline

hf_checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"

tokenizer = AutoTokenizer.from_pretrained(hf_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(hf_checkpoint)

# return_all_scores=True asks the pipeline for a score per label
pipe = TextClassificationPipeline(
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
)
pipe("I love Amazon SageMaker Studio Lab!")

In [2]:
# install AWS packages
# boto3 supplies the SageMaker and SageMaker-runtime clients used below; the
# sagemaker SDK is imported in the execution-role cell.
# NOTE(review): both are unpinned — pin versions for a reproducible setup.
%pip install boto3
%pip install sagemaker

[0mNote: you may need to restart the kernel to use updated packages.
[0mNote: you may need to restart the kernel to use updated packages.


In [3]:
# Authenticate through a named AWS profile rather than entering credentials in
# the notebook, and deploy everything into a single region.
profile_name, region_name = 'default', 'us-west-2'

In [4]:
# Open a boto3 session for the chosen profile, build a SageMaker client for the
# target region, and check connectivity by listing the existing endpoints.
import boto3

session = boto3.Session(profile_name=profile_name)
sm_client = session.client(service_name='sagemaker', region_name=region_name)

response = sm_client.list_endpoints()
print(response)

{'Endpoints': [], 'ResponseMetadata': {'RequestId': 'e2e3b5c8-6bde-44b5-a72c-7fffa580336f', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': 'e2e3b5c8-6bde-44b5-a72c-7fffa580336f', 'content-type': 'application/x-amz-json-1.1', 'content-length': '16', 'date': 'Fri, 15 Jul 2022 23:54:36 GMT'}, 'RetryAttempts': 0}}


In [5]:
# set model name and endpoint configuration name
import time

ml_model_name = "distilbert-text-classification"
# One UTC timestamp shared by all three names, so the model, endpoint config and
# endpoint created in this run are easy to match up and are unique per run
# (to the second).
timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())
model_name = ml_model_name + '-model' + timestamp
endpoint_config_name = ml_model_name + '-epc' + timestamp
endpoint_name = ml_model_name + '-ep' + timestamp

# SageMaker model, endpoint-config and endpoint names are limited to 63
# characters. Check up front so a longer ml_model_name fails here instead of
# part-way through creating AWS resources in the cells below.
SAGEMAKER_MAX_NAME_LENGTH = 63
for resource_name in (model_name, endpoint_config_name, endpoint_name):
    if len(resource_name) > SAGEMAKER_MAX_NAME_LENGTH:
        raise ValueError(
            f"{resource_name!r} is {len(resource_name)} characters long; "
            f"SageMaker names must be at most {SAGEMAKER_MAX_NAME_LENGTH}"
        )

print(model_name)
print(endpoint_config_name)
print(endpoint_name)

distilbert-text-classification-model-2022-07-15-23-54-36
distilbert-text-classification-epc-2022-07-15-23-54-36
distilbert-text-classification-ep-2022-07-15-23-54-36


In [6]:
# set sagemaker execution role
import os
import sagemaker

# IAM role passed to create_model as ExecutionRoleArn. Create one via the
# SageMaker console, then either export SAGEMAKER_EXECUTION_ROLE_ARN before
# starting the kernel or replace the fallback ARN below, which points at one
# specific AWS account (105065840964).
role = os.environ.get(
    'SAGEMAKER_EXECUTION_ROLE_ARN',
    'arn:aws:iam::105065840964:role/sagemaker-execution-role',
)

In [7]:
# Serve the model from an AWS Hugging Face Deep Learning Container (DLC).
# Available DLC images are listed at:
# https://github.com/aws/deep-learning-containers/blob/master/available_images.md
# NOTE(review): the registry account 763104351884 is not the same in every
# region — confirm it against that list if region_name changes.
model_image_url = (
    f"763104351884.dkr.ecr.{region_name}.amazonaws.com/"
    "huggingface-pytorch-inference:1.9-transformers4.12-cpu-py38-ubuntu20.04"
)
print(model_image_url)

# Container definition: no model artifact is supplied, only the Hub model id
# (HF_MODEL_ID) and the pipeline task (HF_TASK) via environment variables.
# SAGEMAKER_CONTAINER_LOG_LEVEL '20' presumably maps to Python logging.INFO.
container_config = dict(
    Image=model_image_url,
    Mode='SingleModel',
    Environment=dict(
        HF_MODEL_ID='distilbert-base-uncased-finetuned-sst-2-english',
        HF_TASK='text-classification',
        SAGEMAKER_CONTAINER_LOG_LEVEL='20',
        SAGEMAKER_REGION=region_name,
    ),
)
print(container_config)

# Register the model with SageMaker.
# ... models console: https://console.aws.amazon.com/sagemaker/home?#/models
# Network isolation presumably has to stay off: with no ModelDataUrl, the
# container must reach the Hugging Face Hub to fetch HF_MODEL_ID.
response = sm_client.create_model(
    ModelName=model_name,
    PrimaryContainer=container_config,
    ExecutionRoleArn=role,
    EnableNetworkIsolation=False,
)
print(response)

763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:1.9-transformers4.12-cpu-py38-ubuntu20.04
{'Image': '763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:1.9-transformers4.12-cpu-py38-ubuntu20.04', 'Mode': 'SingleModel', 'Environment': {'HF_MODEL_ID': 'distilbert-base-uncased-finetuned-sst-2-english', 'HF_TASK': 'text-classification', 'SAGEMAKER_CONTAINER_LOG_LEVEL': '20', 'SAGEMAKER_REGION': 'us-west-2'}}
{'ModelArn': 'arn:aws:sagemaker:us-west-2:105065840964:model/distilbert-text-classification-model-2022-07-15-23-54-36', 'ResponseMetadata': {'RequestId': '38584e44-1f8b-4870-bda1-db92b7beb5da', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '38584e44-1f8b-4870-bda1-db92b7beb5da', 'content-type': 'application/x-amz-json-1.1', 'content-length': '118', 'date': 'Fri, 15 Jul 2022 23:54:41 GMT'}, 'RetryAttempts': 0}}


In [8]:
# Create a serverless endpoint configuration: one "AllTraffic" variant serving
# the model, sized by memory and max concurrent invocations instead of by
# instance type/count.
# ... endpoint configs console: https://console.aws.amazon.com/sagemaker/home?#/endpointConfig
serverless_variant = {
    "ModelName": model_name,
    "VariantName": "AllTraffic",
    "ServerlessConfig": {
        "MemorySizeInMB": 3072,
        "MaxConcurrency": 10,
    },
}
endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[serverless_variant],
)
print(endpoint_config_response)

print(f'Endpoint configuration name: {endpoint_config_name}')
print(f"Endpoint configuration arn:  {endpoint_config_response['EndpointConfigArn']}")

{'EndpointConfigArn': 'arn:aws:sagemaker:us-west-2:105065840964:endpoint-config/distilbert-text-classification-epc-2022-07-15-23-54-36', 'ResponseMetadata': {'RequestId': 'd61a8f6a-d810-46f8-851a-d0da30d45e91', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': 'd61a8f6a-d810-46f8-851a-d0da30d45e91', 'content-type': 'application/x-amz-json-1.1', 'content-length': '135', 'date': 'Fri, 15 Jul 2022 23:54:42 GMT'}, 'RetryAttempts': 0}}
Endpoint configuration name: distilbert-text-classification-epc-2022-07-15-23-54-36
Endpoint configuration arn:  arn:aws:sagemaker:us-west-2:105065840964:endpoint-config/distilbert-text-classification-epc-2022-07-15-23-54-36


In [9]:
# Launch the endpoint from the endpoint configuration created above.
# Creation is asynchronous: the call returns while the endpoint is still being
# provisioned.
# ... endpoints console: https://console.aws.amazon.com/sagemaker/home?#/endpoints
endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)
print(endpoint_response)

endpoint_arn = endpoint_response['EndpointArn']
print(f'Endpoint name: {endpoint_name}')
print(f'Endpoint arn:  {endpoint_arn}')

{'EndpointArn': 'arn:aws:sagemaker:us-west-2:105065840964:endpoint/distilbert-text-classification-ep-2022-07-15-23-54-36', 'ResponseMetadata': {'RequestId': '88994ef1-5dea-4c47-b216-10671d9a4902', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '88994ef1-5dea-4c47-b216-10671d9a4902', 'content-type': 'application/x-amz-json-1.1', 'content-length': '121', 'date': 'Fri, 15 Jul 2022 23:54:50 GMT'}, 'RetryAttempts': 0}}
Endpoint name: distilbert-text-classification-ep-2022-07-15-23-54-36
Endpoint arn:  arn:aws:sagemaker:us-west-2:105065840964:endpoint/distilbert-text-classification-ep-2022-07-15-23-54-36


In [10]:
# Block until the endpoint created above is InService. Invoking an endpoint that
# is still being created fails, so without this wait "Run All" breaks here. The
# waiter returns as soon as the endpoint is InService (immediately if it already
# is) and raises if the endpoint ends up Failed.
sm_client.get_waiter('endpoint_in_service').wait(EndpointName=endpoint_name)

# invoke endpoint by endpoint name
import json
sm_runtime = session.client("sagemaker-runtime", region_name=region_name)

content_type = "application/json"

# JSON request body: the text to classify goes under the "inputs" key
data = {
    "inputs": "Hi, I am Dogbert."
}

response = sm_runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType=content_type,
    Body=json.dumps(data)
)
print(response)
# Body is a stream; it can only be read once
print(response["Body"].read().decode("utf-8"))

{'ResponseMetadata': {'RequestId': 'eac152c0-107d-4dad-a644-5e76b532fe29', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': 'eac152c0-107d-4dad-a644-5e76b532fe29', 'x-amzn-invoked-production-variant': 'AllTraffic', 'date': 'Fri, 15 Jul 2022 23:56:51 GMT', 'content-type': 'application/json', 'content-length': '48'}, 'RetryAttempts': 0}, 'ContentType': 'application/json', 'InvokedProductionVariant': 'AllTraffic', 'Body': <botocore.response.StreamingBody object at 0x7ff08b642c40>}
[{"label":"POSITIVE","score":0.997832715511322}]


In [None]:
# Clean up: delete the endpoint, its endpoint config and the model created in
# this notebook. Off by default so "Run All" does not tear down the endpoint;
# set delete_resources = True and re-run this cell to clean up.
delete_resources = False


def delete_sagemaker_resources(client, endpoint_name, endpoint_config_name, model_name):
    """Delete a SageMaker endpoint, its endpoint config and its model.

    Deletion runs in the reverse of creation order: the endpoint (built from
    the config) first, then the config (which references the model), then
    the model.

    Args:
        client: boto3 SageMaker client.
        endpoint_name: name passed to create_endpoint.
        endpoint_config_name: name passed to create_endpoint_config.
        model_name: name passed to create_model.
    """
    client.delete_endpoint(EndpointName=endpoint_name)
    client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    client.delete_model(ModelName=model_name)


if delete_resources:
    delete_sagemaker_resources(sm_client, endpoint_name, endpoint_config_name, model_name)