# Training Image Classification Models

This notebook trains an Image Classification model for skin conditions and rashes. 

## Setup

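Here we import the required libraries, get the execution role, instantiate a SageMaker session and an S3 resource, and look up the URI of the training image that contains the built-in image classification algorithm. The S3 paths for the input and output data are defined in the cells that follow.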

In [131]:
import boto3
import re
import sagemaker
from sagemaker import get_execution_role
import time
from time import gmtime, strftime
import json

role = get_execution_role()
sess = sagemaker.Session()
s3 = boto3.resource('s3')

training_image = sagemaker.amazon.amazon_estimator.get_image_uri(boto3.Session().region_name, 'image-classification', repo_version='latest')

Provide the inputs for the training and validation datasets.

In [None]:
augmented_manifest_filename_train = 'train.manifest' # Replace with the filename for your training data.
augmented_manifest_filename_validation = 'validation.manifest' # Replace with the filename for your validation data.
bucket_name = "sciencemit" # Replace with your bucket name.
s3_prefix = '' # Replace with the S3 prefix where your data files reside.
s3_output_path = 's3://{}/output'.format(bucket_name) # Replace with your desired output directory.

The setup section concludes with a few more definitions and constants.

In [None]:
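# Defines paths for use in the training job request.
import posixpath  # Joins S3 key components with '/' and skips an empty prefix.

s3_train_data_path = 's3://{}/{}'.format(bucket_name, posixpath.join(s3_prefix, augmented_manifest_filename_train))
s3_validation_data_path = 's3://{}/{}'.format(bucket_name, posixpath.join(s3_prefix, augmented_manifest_filename_validation))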

print("Augmented manifest for training data: {}".format(s3_train_data_path))
print("Augmented manifest for validation data: {}".format(s3_validation_data_path))
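Later cells need the bucket and the object key of these `s3://` URIs separately. Here is a minimal sketch, using only plain string handling, of splitting such a URI into its two parts. `split_s3_uri` is a hypothetical helper, not part of boto3 or the SageMaker SDK.

```python
def split_s3_uri(uri):
    """Split an 's3://bucket/key' URI into (bucket, key).

    Hypothetical helper; it is not part of boto3 or the SageMaker SDK.
    """
    scheme = 's3://'
    if not uri.startswith(scheme):
        raise ValueError('Not an S3 URI: {}'.format(uri))
    # Everything up to the first '/' after the scheme is the bucket; the rest is the key.
    bucket, _, key = uri[len(scheme):].partition('/')
    if not bucket:
        raise ValueError('Missing bucket name in: {}'.format(uri))
    return bucket, key


example_bucket, example_key = split_s3_uri('s3://sciencemit/train.manifest')
print(example_bucket, example_key)  # sciencemit train.manifest
```

Applied to `s3_train_data_path`, it would return `bucket_name` and the manifest's object key, including any prefix.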

### Preview Input Data

Let's read the augmented manifest so we can inspect its contents to better understand the format.

In [None]:
# Key of the manifest object: everything after 's3://<bucket>/'.
augmented_manifest_s3_key = s3_train_data_path[len('s3://{}/'.format(bucket_name)):]
s3_obj = s3.Object(bucket_name, augmented_manifest_s3_key)
augmented_manifest = s3_obj.get()['Body'].read().decode('utf-8')
augmented_manifest_lines = augmented_manifest.split('\n')

# Compute the number of training samples for use in the training job request,
# skipping blank lines such as the one left by a trailing newline.
num_training_samples = len([line for line in augmented_manifest_lines if line.strip()])


print('Preview of Augmented Manifest File Contents')
print('-------------------------------------------')
print('\n')

for i in range(2):
    print('Line {}'.format(i+1))
    print(augmented_manifest_lines[i])
    print('\n')
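Each non-empty line of an augmented manifest is a standalone JSON object (JSON Lines). Here is a minimal sketch, using only the standard library, of parsing the manifest text into records and counting samples while skipping the empty string that `split('\n')` produces after a trailing newline. The two sample lines are illustrative only: their keys follow the `source-ref` and `job-science-1201` attributes and the `-metadata` fields used elsewhere in this notebook, and the values are made up.

```python
import json


def parse_augmented_manifest(text):
    """Parse JSON Lines text into a list of dicts, ignoring blank lines."""
    return [json.loads(line) for line in text.split('\n') if line.strip()]


# Illustrative manifest text (hypothetical values), ending with a trailing newline.
sample_text = (
    '{"source-ref": "s3://example-bucket/img1.jpg", "job-science-1201": 0, '
    '"job-science-1201-metadata": {"class-name": "chicken pox", "confidence": 0.95, "human-annotated": "yes"}}\n'
    '{"source-ref": "s3://example-bucket/img2.jpg", "job-science-1201": 5, '
    '"job-science-1201-metadata": {"class-name": "hives", "confidence": 0.71, "human-annotated": "no"}}\n'
)

parsed_records = parse_augmented_manifest(sample_text)
print(len(sample_text.split('\n')))  # 3: the trailing newline yields an empty last element
print(len(parsed_records))           # 2 actual samples
```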
    


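The `AttributeNames` parameter in the training job request tells SageMaker which fields of each augmented-manifest line to pass to the algorithm, in the order listed: here, the image location (`source-ref`) followed by the label attribute (`job-science-1201`). Define it so that it matches the field names in the manifest previewed above.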

In [None]:
attribute_names = ["source-ref","job-science-1201"] # Replace as appropriate for your augmented manifest.
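A mismatch between `attribute_names` and the manifest's keys is cheaper to catch locally than from a failed training job. Here is a sketch, assuming the manifest lines have already been parsed into dicts (for example with `json.loads`). `missing_attributes` is a hypothetical helper, and the example records are made up.

```python
def missing_attributes(records, required_names):
    """Return (line_number, missing_names) for each record lacking a required attribute.

    Hypothetical helper; line numbers are 1-based positions in `records`.
    """
    problems = []
    for line_no, record in enumerate(records, start=1):
        missing = [name for name in required_names if name not in record]
        if missing:
            problems.append((line_no, missing))
    return problems


required = ["source-ref", "job-science-1201"]
example_records = [
    {"source-ref": "s3://example-bucket/img1.jpg", "job-science-1201": 0},
    {"source-ref": "s3://example-bucket/img2.jpg"},  # label attribute missing
]
print(missing_attributes(example_records, required))  # [(2, ['job-science-1201'])]
```

An empty list means every record carries all of the attributes named in `AttributeNames`.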

## Create Training Job

First, we'll construct the request for the training job.

In [None]:
# Create unique job name 
job_name_prefix = 'job-science-1201'
timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())
job_name = job_name_prefix + timestamp

training_params = \
{
    "AlgorithmSpecification": {
        "TrainingImage": training_image, # NB. This is one of the named constants defined in the first cell.
        "TrainingInputMode": "Pipe"
    },
    "RoleArn": role,
    "OutputDataConfig": {
        "S3OutputPath": s3_output_path
    },
    "ResourceConfig": {
        "InstanceCount": 1,   
        "InstanceType": "ml.p3.2xlarge",
        "VolumeSizeInGB": 50
    },
    "TrainingJobName": job_name,
    "HyperParameters": { # NB. These hyperparameters are at the user's discretion and are beyond the scope of this demo.
        "image_shape": "3,224,224",
        "num_training_samples": str(num_training_samples), # Computed from the training manifest above.
        "num_classes": "11",
        "mini_batch_size": "32",
        "epochs": "30",
        "learning_rate": "0.1",
        "num_layers": "101"
    },
    "StoppingCondition": {
        "MaxRuntimeInSeconds": 86400
    },
    "InputDataConfig": [
        {
            "ChannelName": "train",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "AugmentedManifestFile", # NB. Augmented Manifest
                    "S3Uri": s3_train_data_path,
                    "S3DataDistributionType": "FullyReplicated",
                    "AttributeNames": attribute_names # NB. This must correspond to the JSON field names in your augmented manifest.
                }
            },
            "ContentType": "application/x-recordio",
            "RecordWrapperType": "RecordIO",
            "CompressionType": "None"
        },
        {
            "ChannelName": "validation",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "AugmentedManifestFile", # NB. Augmented Manifest
                    "S3Uri": s3_validation_data_path,
                    "S3DataDistributionType": "FullyReplicated",
                    "AttributeNames": attribute_names # NB. This must correspond to the JSON field names in your augmented manifest.
                }
            },
            "ContentType": "application/x-recordio",
            "RecordWrapperType": "RecordIO",
            "CompressionType": "None"
        }
    ]
}
 
print('Training job name: {}'.format(job_name))
print('\nInput Data Location: {}'.format(training_params['InputDataConfig'][0]['DataSource']['S3DataSource']))
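The `train` and `validation` channels in the request differ only in `ChannelName` and `S3Uri`. Here is a sketch of a small builder that produces the same channel dictionary, so the two entries cannot drift apart. `make_manifest_channel` is hypothetical, not part of boto3, and the URIs in the example are placeholders.

```python
def make_manifest_channel(channel_name, s3_uri, attribute_names):
    """Build one InputDataConfig entry for an augmented manifest streamed as RecordIO."""
    return {
        "ChannelName": channel_name,
        "DataSource": {
            "S3DataSource": {
                "S3DataType": "AugmentedManifestFile",
                "S3Uri": s3_uri,
                "S3DataDistributionType": "FullyReplicated",
                # Must match the JSON keys present on every manifest line.
                "AttributeNames": list(attribute_names),
            }
        },
        "ContentType": "application/x-recordio",
        "RecordWrapperType": "RecordIO",
        "CompressionType": "None",
    }


names = ["source-ref", "job-science-1201"]
input_data_config = [
    make_manifest_channel("train", "s3://sciencemit/train.manifest", names),
    make_manifest_channel("validation", "s3://sciencemit/validation.manifest", names),
]
print([channel["ChannelName"] for channel in input_data_config])  # ['train', 'validation']
```

In the notebook, the URIs would be `s3_train_data_path` and `s3_validation_data_path`, and the result could be assigned to `training_params['InputDataConfig']`.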

Now we create the Amazon SageMaker training job.

In [None]:
client = boto3.client(service_name='sagemaker')
client.create_training_job(**training_params)

# Confirm that the training job has started
status = client.describe_training_job(TrainingJobName=job_name)['TrainingJobStatus']
print('Training job current status: {}'.format(status))


In [None]:
# Poll every 60 seconds until the job reaches a terminal state.
# 'Stopped' is included so that a manually stopped job does not keep this loop running forever.
while True:
    description = client.describe_training_job(TrainingJobName=job_name)
    TrainingJobStatus = description['TrainingJobStatus']
    SecondaryStatus = description['SecondaryStatus']
    print(TrainingJobStatus, SecondaryStatus)
    if TrainingJobStatus in ('Completed', 'Failed', 'Stopped'):
        break
    time.sleep(60)

In [None]:
training_info = client.describe_training_job(TrainingJobName=job_name)
print(training_info)
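Printing the whole `describe_training_job` response is verbose. Here is a sketch that pulls out the fields most useful at this point, assuming the response is the plain dictionary boto3 returns. `TrainingJobStatus`, `SecondaryStatus`, `FailureReason`, `ModelArtifacts.S3ModelArtifacts`, and `BillableTimeInSeconds` are fields of that response. Optional fields are read with `.get` because they are absent in some job states. `summarize_training_job` and the example response are hypothetical.

```python
def summarize_training_job(info):
    """Return a compact summary of a DescribeTrainingJob response dictionary."""
    status = info.get("TrainingJobStatus")
    if status == "Failed":
        # FailureReason explains why a failed job failed.
        raise RuntimeError("Training failed: {}".format(info.get("FailureReason", "unknown")))
    return {
        "status": status,
        "secondary_status": info.get("SecondaryStatus"),
        # Present once training has produced model artifacts.
        "model_artifacts": info.get("ModelArtifacts", {}).get("S3ModelArtifacts"),
        "billable_seconds": info.get("BillableTimeInSeconds"),
    }


# Hypothetical response fragment, trimmed to the fields read above.
example_info = {
    "TrainingJobStatus": "Completed",
    "SecondaryStatus": "Completed",
    "ModelArtifacts": {"S3ModelArtifacts": "s3://sciencemit/output/example-job/output/model.tar.gz"},
    "BillableTimeInSeconds": 1200,
}
print(summarize_training_job(example_info))
```

With the live job, `summarize_training_job(training_info)` gives the same summary.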

## Inference and Predictions

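We now deploy the trained image classification model to a SageMaker endpoint and use it for inference. The first step is to register the training job's model artifacts as a SageMaker model.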

In [None]:
%%time
import boto3
from sagemaker.amazon.amazon_estimator import get_image_uri

sage = boto3.Session().client(service_name='sagemaker')

# Derive the model name from the training job name so that re-running the notebook
# does not collide with an existing model of the same name.
model_name = job_name + '-model'
print(model_name)
info = sage.describe_training_job(TrainingJobName=job_name)
model_data = info['ModelArtifacts']['S3ModelArtifacts']
print(model_data)

hosting_image = get_image_uri(boto3.Session().region_name, 'image-classification')

primary_container = {
    'Image': hosting_image,
    'ModelDataUrl': model_data,
}

create_model_response = sage.create_model(
    ModelName = model_name,
    ExecutionRoleArn = role,
    PrimaryContainer = primary_container)

print(create_model_response['ModelArn'])

**Create Endpoint**

Next, we create an endpoint configuration that specifies the model and the instance type and count used to host it. We then create an endpoint from that configuration. The result is an endpoint that can be validated and called from applications.

In [None]:
# Endpoint configuration: host the model on a single ml.m4.xlarge instance.

timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())
endpoint_config_name = job_name_prefix + '-epc-' + timestamp
endpoint_config_response = sage.create_endpoint_config(
    EndpointConfigName = endpoint_config_name,
    ProductionVariants=[{
        'InstanceType':'ml.m4.xlarge',
        'InitialInstanceCount':1,
        'ModelName':model_name,
        'VariantName':'AllTraffic'}])

print('Endpoint configuration name: {}'.format(endpoint_config_name))
print('Endpoint configuration arn:  {}'.format(endpoint_config_response['EndpointConfigArn']))
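The training job, endpoint configuration, and endpoint names in this notebook are all built the same way: a prefix, an optional suffix, and a UTC timestamp. Here is a sketch of that pattern as one helper. It assumes the naming pattern documented for these SageMaker resources: letters, digits, and hyphens, starting and ending with a letter or digit, and at most 63 characters. `timestamped_name` is hypothetical, and its `now` argument exists only to make the output reproducible.

```python
import re
import time

# Assumed naming rule for SageMaker training jobs, models, endpoint configs, and endpoints.
NAME_PATTERN = re.compile(r'^[a-zA-Z0-9](-*[a-zA-Z0-9])*$')
MAX_NAME_LENGTH = 63


def timestamped_name(prefix, suffix='', now=None):
    """Return '<prefix><suffix>-YYYY-mm-dd-HH-MM-SS' in UTC, after checking it is a valid name.

    Hypothetical helper; `now` is seconds since the epoch (None means the current time).
    """
    name = prefix + suffix + time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime(now))
    if len(name) > MAX_NAME_LENGTH or not NAME_PATTERN.match(name):
        raise ValueError('Invalid SageMaker resource name: {}'.format(name))
    return name


print(timestamped_name('job-science-1201', '-epc', now=0))  # job-science-1201-epc-1970-01-01-00-00-00
```

Calling it without `now` uses the current time, as the cells in this notebook do.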

In [None]:
%%time
import time

timestamp = time.strftime('-%Y-%m-%d-%H-%M-%S', time.gmtime())
endpoint_name = job_name_prefix + '-ep-' + timestamp
print('Endpoint name: {}'.format(endpoint_name))

endpoint_params = {
    'EndpointName': endpoint_name,
    'EndpointConfigName': endpoint_config_name,
}
endpoint_response = sage.create_endpoint(**endpoint_params)
print('EndpointArn = {}'.format(endpoint_response['EndpointArn']))

In [None]:
# get the status of the endpoint
response = sage.describe_endpoint(EndpointName=endpoint_name)
status = response['EndpointStatus']
print('EndpointStatus = {}'.format(status))


# wait until the status has changed
sage.get_waiter('endpoint_in_service').wait(EndpointName=endpoint_name)


# print the status of the endpoint
endpoint_response = sage.describe_endpoint(EndpointName=endpoint_name)
status = endpoint_response['EndpointStatus']
print('Endpoint creation ended with EndpointStatus = {}'.format(status))

if status != 'InService':
    raise Exception('Endpoint creation failed.')

**Perform Inference**

To validate the model, we create a SageMaker runtime client for sending inference requests to the endpoint.

In [None]:
import boto3
runtime = boto3.Session().client(service_name='runtime.sagemaker') 

**Download labeled images and evaluate per class**

In [None]:
import os
import json
from collections import Counter

import numpy as np
import matplotlib.pyplot as plt


# Load the output manifest's annotations.
OUTPUT_MANIFEST = 's3://sciencemit/output.manifest' # Replace with the S3 URI for your output manifest.
!aws s3 cp {OUTPUT_MANIFEST} 'output.manifest'

with open('output.manifest', 'r') as f:
    output = [json.loads(line.strip()) for line in f.readlines()]

# Create data arrays.
img_uris = [None] * len(output)
confidences = np.zeros(len(output))
labels = [None] * len(output)
human = np.zeros(len(output))

# Find the job name the manifest corresponds to.
keys = list(output[0].keys())
metakey = keys[np.where([('-metadata' in k) for k in keys])[0][0]]
jobname = metakey[:-9]

# Extract the data.
for datum_id, datum in enumerate(output):
    img_uris[datum_id] = datum['source-ref']
    labels[datum_id] = datum[metakey]['class-name']
    confidences[datum_id] = datum[metakey]['confidence']
    human[datum_id] = int(datum[metakey]['human-annotated'] == 'yes')
labels = np.array(labels)

# Compute the number of annotations in each class.
n_classes = len(set(labels))
sorted_clnames, class_sizes = zip(*Counter(labels).most_common(n_classes))

# Find ids of human-annotated images.
human_sizes = [human[labels == clname].sum() for clname in sorted_clnames]
class_sizes = np.array(class_sizes)
human_sizes = np.array(human_sizes)

# Collect the annotation confidences per class, separately for human and machine annotations.
# Plain lists are used because the per-class arrays have different lengths.
human_confidences = [confidences[np.logical_and(labels == clname, human)]
                     for clname in sorted_clnames]
machine_confidences = [confidences[np.logical_and(labels == clname, 1 - human)]
                       for clname in sorted_clnames]

# If no images are annotated as a specific class, use a single confidence of 0 for that class.
for class_id in range(n_classes):
    if human_confidences[class_id].size == 0:
        human_confidences[class_id] = np.array([0])
    if machine_confidences[class_id].size == 0:
        machine_confidences[class_id] = np.array([0])

LOCAL_IMG_DIR = './dataset' # Replace with the name of a local directory to store images.
DATASET_SIZE = len(img_uris) # Reduce this to a reasonable number if your dataset is much larger than 10K images.
os.makedirs(LOCAL_IMG_DIR, exist_ok=True)

# Randomly select DATASET_SIZE images and keep the per-image arrays aligned with them.
subset_ids = np.random.choice(range(len(img_uris)), DATASET_SIZE, replace=False)
img_uris = [img_uris[idx] for idx in subset_ids]

labels = labels[subset_ids]
confidences = confidences[subset_ids]
human = human[subset_ids]

# Download each selected image once; img_fnames[i] is the local copy of img_uris[i].
img_fnames = [None] * len(img_uris)
for img_uri_id, img_uri in enumerate(img_uris):
    target_fname = os.path.join(
        LOCAL_IMG_DIR, img_uri.split('/')[-1])
    if not os.path.isfile(target_fname):
        !aws s3 cp {img_uri} {target_fname}
    img_fnames[img_uri_id] = target_fname
    
N_SHOW = 15  # Number of human-annotated images per class to send to the endpoint.

# Class names in label-index order: position i is the class the model predicts as index i.
# NB. This ordering must match the label ids used in the training manifest.
object_categories = ['chicken pox', 'heat rash', 'impetigo', 'eczema', 'measles', 'hives', 'baby acne', 'dermatitis', 'fifth disease', 'cold sores', 'scarlet fever']

actual_class = []             # Classes for which at least one prediction succeeded.
mean_prediction_classes = []  # Mean top-class probability for each class in actual_class.

for class_name in sorted_clnames:

    class_prob = 0
    num_ids = 0

    class_ids = np.where(np.logical_and(np.array(labels) == class_name, human))[0]
    try:
        show_ids = np.random.choice(class_ids, N_SHOW, replace=False)
    except ValueError:
        print('Not enough human annotations to show {}'.format(class_name))
        continue
    for show_id in show_ids:

        file_name = img_fnames[show_id]

        with open(file_name, 'rb') as f:
            payload = bytearray(f.read())
        try:
            response = runtime.invoke_endpoint(EndpointName=endpoint_name,
                                               ContentType='application/x-image',
                                               Body=payload)
        except Exception as e:
            print('Inference failed for {}: {}'.format(file_name, e))
            continue
        # The result is a JSON list with one probability per class.
        result = json.loads(response['Body'].read())

        # Find the class with the maximum probability.
        index = int(np.argmax(result))
        probability = result[index]

        class_prob = class_prob + probability
        num_ids = num_ids + 1

        print('Ground truth: ', class_name, ' Label confidence: ', str(confidences[show_id]),
              ' Predicted class: ', object_categories[index], ' Predicted probability: ', str(probability))

    if num_ids == 0:
        print('No successful predictions for {}'.format(class_name))
        continue
    print('Class name: ', class_name, ' ', 'Mean predicted probability: ', str(class_prob / num_ids))
    actual_class.append(class_name)
    mean_prediction_classes.append(class_prob / num_ids)

# Plot one bar per evaluated class, labeled with that class's name.
y_pos = np.arange(len(actual_class))

plt.figure(figsize=(9, 3), facecolor='white', dpi=100)
plt.bar(y_pos, mean_prediction_classes, align='center', alpha=0.5)
plt.xticks(y_pos, actual_class, rotation=90)
plt.ylabel('predicted confidences')
plt.title('Mean of predicted confidences')
 
plt.show()
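The per-class bookkeeping in the cell above groups records by the `class-name` in each record's `-metadata` entry and separates human from machine annotations. The same bookkeeping can also be written without NumPy. Here is a pure-Python sketch over parsed output-manifest records, using the `class-name`, `confidence`, and `human-annotated` fields read above. `class_confidence_stats`, the metadata key, and the sample records are hypothetical.

```python
from collections import defaultdict


def class_confidence_stats(records, metakey):
    """Per class: number of labels, number of human labels, and mean human/machine confidence.

    Classes with no human (or no machine) labels get a mean of 0, mirroring the
    zero fallback used in the notebook cell.
    """
    grouped = defaultdict(lambda: {"human": [], "machine": []})
    for record in records:
        meta = record[metakey]
        source = "human" if meta["human-annotated"] == "yes" else "machine"
        grouped[meta["class-name"]][source].append(meta["confidence"])

    stats = {}
    for class_name, conf in grouped.items():
        stats[class_name] = {
            "count": len(conf["human"]) + len(conf["machine"]),
            "human_count": len(conf["human"]),
            "mean_human": sum(conf["human"]) / len(conf["human"]) if conf["human"] else 0.0,
            "mean_machine": sum(conf["machine"]) / len(conf["machine"]) if conf["machine"] else 0.0,
        }
    return stats


example_metakey = "job-science-1201-metadata"  # Hypothetical; the notebook detects this key automatically.
example_records = [
    {example_metakey: {"class-name": "hives", "confidence": 0.9, "human-annotated": "yes"}},
    {example_metakey: {"class-name": "hives", "confidence": 0.5, "human-annotated": "no"}},
    {example_metakey: {"class-name": "eczema", "confidence": 0.8, "human-annotated": "yes"}},
]
print(class_confidence_stats(example_records, example_metakey))
```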

 

In [None]:
!wget -O /tmp/test.jpg https://assets.nhs.uk/prod/images/D1BMFN.2e16d0ba.fill-920x613.jpg
file_name = '/tmp/test.jpg'
# test image
from IPython.display import Image
Image(file_name)  

**Evaluation**

Send the test image to the endpoint for inference. The network outputs one probability per class; typically, the class with the maximum probability is selected as the final prediction.
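As the evaluation loop above assumes, the endpoint's response body is a JSON array of class probabilities indexed by label id. Here is a minimal sketch of turning that body into a (class name, probability) pair. It assumes the `object_categories` ordering used in the evaluation cell matches the training label ids. It uses a made-up response body so it runs without an endpoint. Against the live endpoint, the body would come from `runtime.invoke_endpoint(EndpointName=endpoint_name, ContentType='application/x-image', Body=payload)['Body'].read()`.

```python
import json

# Assumed to be in the same order as the label ids used during training.
CATEGORIES = ['chicken pox', 'heat rash', 'impetigo', 'eczema', 'measles', 'hives',
              'baby acne', 'dermatitis', 'fifth disease', 'cold sores', 'scarlet fever']


def top_prediction(body, categories=CATEGORIES):
    """Decode an endpoint response body and return (class_name, probability) for the arg-max class."""
    probabilities = json.loads(body)
    if len(probabilities) != len(categories):
        raise ValueError('Expected {} probabilities, got {}'.format(len(categories), len(probabilities)))
    best = max(range(len(probabilities)), key=probabilities.__getitem__)
    return categories[best], probabilities[best]


# Made-up response body: 11 probabilities, highest at index 3.
example_body = b'[0.01, 0.02, 0.05, 0.70, 0.03, 0.04, 0.05, 0.04, 0.02, 0.02, 0.02]'
print(top_prediction(example_body))  # ('eczema', 0.7)
```

The length check guards against a category list that does not match the model's `num_classes`.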

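**Clean up**

When you are done, delete the endpoint; it continues to incur charges for as long as it is running.

In [None]:
sage.delete_endpoint(EndpointName=endpoint_name)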