# Training and Deploying the Linfa Image on Amazon SageMaker

This example notebook shows how to train and deploy a model using the SageMaker Python SDK and the boto3 SageMaker client.

**Note:** You must build and push the container image to Amazon ECR before running this notebook. See the `build-and-push` target in the `Makefile` (`make build-and-push`).

The code in `src/bin` uses the Linfa library to train a decision tree classifier on the Iris dataset.

# Training

In [None]:
import argparse
import json
import time
from datetime import datetime

import boto3
import pandas as pd
import sagemaker as sage
from sagemaker import get_execution_role

## Configurations
Modify the configuration below as needed before running the rest of the notebook.

In [None]:
# ECR repository name of the Linfa container image. It is also reused below as the
# S3 key prefix for the dataset/output and as the prefix of the training job name.
image_name = "linfa"

# S3 bucket that holds the dataset, the model artifact and other SageMaker resources.
# Give the plain bucket name (no ARN, no "s3://" URI). The bucket must already exist;
# create it yourself before running the notebook.
default_bucket = "your-bucket-name"

In [None]:
# IAM role attached to the notebook. get_execution_role() is meant for running inside
# a SageMaker notebook; elsewhere, substitute an explicit IAM role ARN.
role = get_execution_role()

# High-level SageMaker session pinned to the configured bucket, plus a low-level
# boto3 client for the model/endpoint API calls made later in the notebook.
sess = sage.Session(default_bucket=default_bucket)
sagemaker_client = boto3.client('sagemaker')

In [None]:
# Build the ECR URI of the pushed image from the caller's AWS account ID and the
# session's region (no tag is given, so the registry's default tag is used).
caller_identity = sess.boto_session.client("sts").get_caller_identity()
account = caller_identity["Account"]
region = sess.boto_session.region_name
image = f"{account}.dkr.ecr.{region}.amazonaws.com/{image_name}"
print(image)

In [None]:
# Upload the local training data to the session's default bucket. The returned
# S3 URI (data_location) is what the training job reads as its input.
# NOTE(review): the original author marked this prefix "Fixme" without saying
# what is wrong with it; confirm the intended S3 layout. `prefix` is also used
# by the cleanup cell, so keep the two in sync.
prefix = f"{image_name}/input/data/training"
dataset_dir = "test_dir/input/data/training/"
data_location = sess.upload_data(dataset_dir, key_prefix=prefix)

In [None]:
# Train with the custom Linfa image on a single ml.m5.large instance. The model
# artifact is written under s3://<bucket>/<image_name>/output.
model = sage.estimator.Estimator(
    image_uri=image,
    role=role,
    instance_count=1,
    instance_type="ml.m5.large",
    output_path=f"s3://{sess.default_bucket()}/{image_name}/output",
    sagemaker_session=sess,
)

# Timestamp the job name so every run creates a distinct training job.
timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
job_name = f"{image_name}-train-{timestamp}"

# Launch training on the uploaded dataset under the chosen job name.
model.fit(inputs=data_location, job_name=job_name)

In [None]:
# Summarize the finished training job: its name, start/end timestamps and the
# wall-clock duration (boto3 returns the timestamps as datetime objects, so the
# difference is a timedelta).
res = sagemaker_client.describe_training_job(TrainingJobName=job_name)
result = {
    'TrainingJobName': [res['TrainingJobName']],
    'TrainingStartTime': [res['TrainingStartTime']],
    'TrainingEndTime': [res['TrainingEndTime']],
    'ProcessingTime': [res['TrainingEndTime'] - res['TrainingStartTime']],
}
# Display the one-row summary as a table instead of dumping the full describe response.
pd.DataFrame(result)

# Deployment

In [None]:

# Name the SageMaker model after the training job that produced it, mirroring the
# `endpoint-{job_name}` endpoint name. A fixed name would make CreateModel fail on
# a re-run (or for a new training run) because a model with that name already exists.
model_name = f'model-{job_name}'


In [None]:

# Register the trained artifact as a SageMaker model served by the same image.
container = {
    'Image': image,
    'Mode': 'SingleModel',
    # S3 location of the model artifact produced by the training job above.
    'ModelDataUrl': model.model_data,
    # 20 matches Python's logging.INFO level; whether the Linfa container reads
    # this variable is not visible from this notebook.
    'Environment': {'SAGEMAKER_CONTAINER_LOG_LEVEL': '20'},
}
create_model_response = sagemaker_client.create_model(
    ModelName=model_name,
    Containers=[container],
    ExecutionRoleArn=role,
)

status_code = create_model_response['ResponseMetadata']['HTTPStatusCode']
if status_code != 200:
    print('Model creation failed')
    print(create_model_response)
else:
    print('Model created successfully')

In [None]:
# Serverless endpoint configuration: one variant serving `model_name` with 1 GB of
# memory and at most one concurrent invocation. The name is derived from the
# training job, like the endpoint below, so repeated runs do not collide with an
# endpoint configuration that already exists.
endpoint_config_name = f'endpoint-config-{job_name}'
endpoint_config_response = sagemaker_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            'VariantName': 'byoVariant',
            'ModelName': model_name,
            'ServerlessConfig': {
                'MemorySizeInMB': 1024,
                'MaxConcurrency': 1
            }
        }
    ]
)

# Create the endpoint. This call returns immediately; the endpoint starts in the
# 'Creating' state, so wait for it before invoking it.
endpoint_name = f'endpoint-{job_name}'
create_endpoint_response = sagemaker_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name
)
if create_endpoint_response['ResponseMetadata']['HTTPStatusCode'] == 200:
    print('Endpoint created successfully!')
else:
    print('Endpoint creation failed!')
    print(create_endpoint_response)

In [None]:

# Poll the endpoint every 15 s while it is still being created, then make sure it
# actually came up. Invoking an endpoint that ended in 'Failed' would otherwise
# fail later with a much less helpful error.
describe_endpoint_response = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
while describe_endpoint_response['EndpointStatus'] == 'Creating':
    print(describe_endpoint_response['EndpointStatus'])
    time.sleep(15)
    describe_endpoint_response = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)

endpoint_status = describe_endpoint_response['EndpointStatus']
print(endpoint_status)
if endpoint_status != 'InService':
    failure_reason = describe_endpoint_response.get('FailureReason', 'no failure reason reported')
    raise RuntimeError(
        f"Endpoint {endpoint_name} finished in status {endpoint_status}: {failure_reason}"
    )


In [None]:
# Send one CSV-encoded sample to the endpoint and print the raw response body.
# NOTE(review): the Iris dataset has 4 features, but this payload has 5 values.
# Confirm the input format the Linfa inference code in src/bin expects.
payload = b'1,2,3,4,5'
runtime = boto3.client('sagemaker-runtime')
response = runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType='text/csv',
    Body=payload,
)
prediction = response['Body'].read()
print(prediction)

# Cleanup


In [None]:
# Tear down the deployed resources: the endpoint, its configuration, then the model.
sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sagemaker_client.delete_model(ModelName=model_name)

# Delete the training data uploaded under `prefix` in the configured bucket.
# The training output written under <image_name>/output is left in place.
# list_objects_v2 returns at most 1000 keys per call, so page through the listing.
s3 = boto3.client('s3')
paginator = s3.get_paginator('list_objects_v2')
for page in paginator.paginate(Bucket=default_bucket, Prefix=prefix):
    # 'Contents' is absent from the response when no object matches the prefix.
    for obj in page.get('Contents', []):
        s3.delete_object(Bucket=default_bucket, Key=obj['Key'])