In [None]:
! pip install huggingface_hub
! pip install tensorflow
! pip install transformers

In [None]:
# Smoke-test the fine-tuned model locally: generate five continuations of a
# fixed prompt, seeding first so the samples are repeatable.
from transformers import pipeline, set_seed

generator = pipeline(
    'text-generation',
    model="rwang5688/distilgpt2-finetuned-wikitext2-tf",
)
set_seed(42)
generator(
    "Hello, I'm a language model,",
    max_length=30,
    num_return_sequences=5,
)

In [None]:
# install AWS packages
! pip install boto3
! pip install sagemaker

In [None]:
# AWS settings used by the cells below:
#   region_name  - region the SageMaker clients and the container image URL use
#   profile_name - named profile the boto3 session is created from, so that
#                  credentials do not have to be entered in the notebook
region_name = 'us-west-2'
profile_name = 'default'

In [None]:
# Create a boto3 session from the named profile, build a SageMaker client in
# the chosen region, and list endpoints as a quick connectivity check.
import boto3

session = boto3.Session(profile_name=profile_name)
sm_client = session.client(
    'sagemaker',
    region_name=region_name,
)
response = sm_client.list_endpoints()
print(response)

In [None]:
# Build the model, endpoint configuration, and endpoint names: each is the
# base model name plus a role suffix plus one shared UTC timestamp.
import time

ml_model_name = "distilgpt2-tf"
timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())

model_name = ''.join((ml_model_name, '-model', timestamp))
endpoint_config_name = ''.join((ml_model_name, '-epc', timestamp))
endpoint_name = ''.join((ml_model_name, '-ep', timestamp))

# one name per line
print(model_name, endpoint_config_name, endpoint_name, sep='\n')

In [None]:
# set sagemaker execution role
# NOTE(review): `sagemaker` is imported here but not referenced in the cells visible in this
# notebook — confirm whether the import is still needed.
import sagemaker
# create a sagemaker execution role via the AWS SageMaker console, then paste in the arn here
# (this value is passed as ExecutionRoleArn when the model is created; the account id below
# looks like a placeholder — replace it with your own role arn)
role = 'arn:aws:iam::123456789012:role/sagemaker-execution-role'

In [None]:
# Container image for the model, addressed through the ECR registry of the
# selected region. Available deep learning container (DLC) images are listed at:
# https://github.com/aws/deep-learning-containers/blob/master/available_images.md
model_image_url = "".join([
    "763104351884.dkr.ecr.",
    region_name,
    ".amazonaws.com/",
    "huggingface-tensorflow-inference:2.11.0-transformers4.26.0-cpu-py39-ubuntu20.04",
])
print(model_image_url)

# Container definition passed as PrimaryContainer when the model is created:
# the image to run, single-model mode, and the environment variables handed to
# the container.
container_config = {
    'Image': model_image_url,
    'Mode': 'SingleModel',
    'Environment': {
        # Same model id as the local pipeline test. The previous value carried a
        # stray trailing double quote inside the string, which is not part of
        # the model id.
        'HF_MODEL_ID': 'rwang5688/distilgpt2-finetuned-wikitext2-tf',
        'HF_TASK' : 'text-generation',
        'SAGEMAKER_CONTAINER_LOG_LEVEL' : '20',
        'SAGEMAKER_REGION' : region_name
    }
}
print(container_config)

# Register the model with SageMaker: name, container definition, execution role.
# ... models console: https://console.aws.amazon.com/sagemaker/home?#/models
create_model_args = {
    'ModelName': model_name,
    'PrimaryContainer': container_config,
    'ExecutionRoleArn': role,
    # NOTE(review): network isolation is off, presumably so the container can
    # download the model named in HF_MODEL_ID — confirm.
    'EnableNetworkIsolation': False,
}
response = sm_client.create_model(**create_model_args)
print(response)

In [None]:
# Create a serverless endpoint configuration with a single production variant
# that serves the model created earlier.
# ... endpoint configs console: https://console.aws.amazon.com/sagemaker/home?#/endpointConfig
serverless_variant = {
    "ModelName": model_name,
    "VariantName": "AllTraffic",
    # serverless sizing: memory size (in MB) and maximum concurrency
    "ServerlessConfig": {"MemorySizeInMB": 3072, "MaxConcurrency": 10},
}
endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[serverless_variant],
)
print(endpoint_config_response)

print('Endpoint configuration name: {}'.format(endpoint_config_name))
print('Endpoint configuration arn:  {}'.format(endpoint_config_response['EndpointConfigArn']))

In [None]:
# Create the endpoint from the endpoint configuration, then report its name and ARN.
# ... endpoints console: https://console.aws.amazon.com/sagemaker/home?#/endpoints
create_endpoint_args = {
    'EndpointName': endpoint_name,
    'EndpointConfigName': endpoint_config_name,
}
endpoint_response = sm_client.create_endpoint(**create_endpoint_args)
print(endpoint_response)

print('Endpoint name: {}'.format(endpoint_name))
print('Endpoint arn:  {}'.format(endpoint_response['EndpointArn']))

In [None]:
# Invoke the endpoint by name once it is "InService".
import json

# Endpoint creation is asynchronous, so block here until the endpoint reports
# "InService" instead of relying on a manual pause; this lets the cell run
# under Run All. The waiter polls the endpoint status and raises if the
# endpoint fails to come up or the wait times out.
sm_client.get_waiter('endpoint_in_service').wait(EndpointName=endpoint_name)

sm_runtime = session.client("sagemaker-runtime", region_name=region_name)

content_type = "application/json"

# request payload: the prompt goes under the "inputs" key
data = {
   "inputs": "Hello, I'm a language model,"
}

response = sm_runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType=content_type,
    Body=json.dumps(data)
)
print(response)
# the response "Body" is a stream: read it fully and decode it as UTF-8 text
print(response["Body"].read().decode("utf-8"))

In [None]:
# clean up: when finished, uncomment and run the following lines to delete the endpoint, endpoint config, and model created above
#sm_client.delete_endpoint(EndpointName=endpoint_name)
#sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
#sm_client.delete_model(ModelName=model_name)