In [None]:
!pip -q install sagemaker pandas awscli --upgrade

# Train an image classification model on Caltech-256
http://www.vision.caltech.edu/Image_Datasets/Caltech256/

### Create an S3 bucket to store the dataset and the trained model

In [None]:
import boto3
import sagemaker

# Show which SageMaker SDK version the kernel actually loaded.
print(sagemaker.__version__)
# The session's default bucket is used below for the dataset and the training output.
session = sagemaker.Session()
bucket = session.default_bucket()

# Export the bucket URI as the `bucket` environment variable.
%env bucket s3://$bucket

### Get the name of the image classification algorithm in our region

In [None]:
# Resolve the container image of the built-in image classification
# algorithm for the region this notebook's boto3 session points at.
region_name = boto3.Session().region_name
algorithm = sagemaker.amazon.amazon_estimator.get_image_uri(
    region_name,
    repo_name="image-classification",
    repo_version="latest",
)
print("Using algorithm %s" % algorithm)

### Download the Caltech-256 dataset

In [None]:
# Fetch the training and validation splits into the current working directory.
# The .rec files are uploaded to S3 as RecordIO data in the next step.
!wget http://data.mxnet.io/data/caltech-256/caltech-256-60-train.rec
!wget http://data.mxnet.io/data/caltech-256/caltech-256-60-val.rec

### Upload dataset to S3 and define locations

In [None]:
session.upload_data(path='caltech-256-60-train.rec', bucket=bucket, key_prefix='train')
session.upload_data(path='caltech-256-60-val.rec',   bucket=bucket, key_prefix='validation')

s3_train      = 's3://{}/train/'.format(bucket)
s3_validation = 's3://{}/validation/'.format(bucket)
s3_output     = 's3://{}/output'.format(bucket)

%env s3_train      $s3_train
%env s3_validation $s3_validation
%env s3_output     $s3_output

In [None]:
# Sanity check: list what is stored under the training and validation prefixes.
!aws s3 ls $s3_train
!aws s3 ls $s3_validation

### Set dataset parameters

In [None]:
# Both channels are configured identically; only the S3 prefix differs.
channel_settings = {
    'distribution': 'FullyReplicated',         # Copy the full dataset
    'content_type': 'application/x-recordio',  # RecordIO format
    's3_data_type': 'S3Prefix',
}

train_data = sagemaker.session.s3_input(s3_train, **channel_settings)
validation_data = sagemaker.session.s3_input(s3_validation, **channel_settings)

# Channel name -> input definition, passed to fit() later on.
data_channels = {
    'train': train_data,
    'validation': validation_data,
}

### Configure the training job

In [None]:
# Estimator for the algorithm image resolved earlier: one ml.p3.16xlarge
# training instance, with output written under s3_output.
role = sagemaker.get_execution_role()

ic = sagemaker.estimator.Estimator(
    algorithm,
    role,
    train_instance_count=1,
    train_instance_type='ml.p3.16xlarge',
    output_path=s3_output,
    sagemaker_session=session,
)

### Set algorithm parameters

In [None]:
# Hyperparameters for the image classification algorithm, gathered in one
# place so they are easy to read and tweak.
hyperparameters = {
    'num_layers': 18,               # Train a Resnet-18 model
    'use_pretrained_model': 1,      # Fine-tune on our dataset
    'image_shape': "3,224,224",     # 3 channels (RGB), 224x224 pixels
    'num_classes': 257,             # 256 classes + 1 clutter class
    'num_training_samples': 15420,  # Number of training samples
    'mini_batch_size': 128,
    'epochs': 10,                   # Learn the training samples 10 times
    'learning_rate': 0.01,
}

ic.set_hyperparameters(**hyperparameters)

### Train the model

In [None]:
# Launch the training job on the 'train' and 'validation' channels defined
# above, with log output enabled (logs=True).
ic.fit(inputs=data_channels, logs=True)

### Deploy the trained model on a GPU instance, and on an instance with Elastic Inference

In [None]:
# Endpoint backed by a single GPU instance.
ic_endpoint = ic.deploy(
    initial_instance_count=1,
    instance_type='ml.p2.xlarge',  # $1.361/hour in eu-west-1
    endpoint_name='ic-endpoint',
)



In [None]:
# Second endpoint for the same model: a c5.large instance with an
# Elastic Inference accelerator attached.
ic_endpoint_ei = ic.deploy(
    initial_instance_count=1,
    instance_type='ml.c5.large',        # $0.134/hour in eu-west-1
    accelerator_type='ml.eia2.medium',  # $0.181/hour in eu-west-1
    endpoint_name='ic-endpoint-ei',
)

Amazon Elastic Inference allows you to attach low-cost GPU-powered acceleration to Amazon EC2 and Amazon SageMaker instances.

A c5.large instance with an eia2.medium accelerator gives you performance comparable to a p2.xlarge at a ***77% discount***.

You'll save ***$754 per instance per month***. 

### Download a test image

In [None]:
# Download a sample picture (a Buddha statue) to /tmp/test.jpg.
!wget -O /tmp/test.jpg https://upload.wikimedia.org/wikipedia/commons/thumb/3/30/Large_Gautama_Buddha_statue_in_Buddha_Park_of_Ravangla%2C_Sikkim.jpg/220px-Large_Gautama_Buddha_statue_in_Buddha_Park_of_Ravangla%2C_Sikkim.jpg

# Path of the test image; the prediction cell reads the file from here.
file_name = '/tmp/test.jpg'
from IPython.display import Image
# Last expression of the cell, so the image is rendered as the cell output.
Image(file_name)

### Predict test image

In [None]:
# Human-readable class labels. The prediction cell indexes this list with the
# argmax of the endpoint's output, so the order of the entries matters.
# NOTE(review): the list should have one entry per class (num_classes=257) — confirm the count.
object_categories = ['ak47', 'american-flag', 'backpack', 'baseball-bat', 'baseball-glove', 'basketball-hoop', 'bat', 'bathtub', 'bear', 'beer-mug', 'billiards', 'binoculars', 'birdbath', 'blimp', 'bonsai-101', 'boom-box', 'bowling-ball', 'bowling-pin', 'boxing-glove', 'brain-101', 'breadmaker', 'buddha-101', 'bulldozer', 'butterfly', 'cactus', 'cake', 'calculator', 'camel', 'cannon', 'canoe', 'car-tire', 'cartman', 'cd', 'centipede', 'cereal-box', 'chandelier-101', 'chess-board', 'chimp', 'chopsticks', 'cockroach', 'coffee-mug', 'coffin', 'coin', 'comet', 'computer-keyboard', 'computer-monitor', 'computer-mouse', 'conch', 'cormorant', 'covered-wagon', 'cowboy-hat', 'crab-101', 'desk-globe', 'diamond-ring', 'dice', 'dog', 'dolphin-101', 'doorknob', 'drinking-straw', 'duck', 'dumb-bell', 'eiffel-tower', 'electric-guitar-101', 'elephant-101', 'elk', 'ewer-101', 'eyeglasses', 'fern', 'fighter-jet', 'fire-extinguisher', 'fire-hydrant', 'fire-truck', 'fireworks', 'flashlight', 'floppy-disk', 'football-helmet', 'french-horn', 'fried-egg', 'frisbee', 'frog', 'frying-pan', 'galaxy', 'gas-pump', 'giraffe', 'goat', 'golden-gate-bridge', 'goldfish', 'golf-ball', 'goose', 'gorilla', 'grand-piano-101', 'grapes', 'grasshopper', 'guitar-pick', 'hamburger', 'hammock', 'harmonica', 'harp', 'harpsichord', 'hawksbill-101', 'head-phones', 'helicopter-101', 'hibiscus', 'homer-simpson', 'horse', 'horseshoe-crab', 'hot-air-balloon', 'hot-dog', 'hot-tub', 'hourglass', 'house-fly', 'human-skeleton', 'hummingbird', 'ibis-101', 'ice-cream-cone', 'iguana', 'ipod', 'iris', 'jesus-christ', 'joy-stick', 'kangaroo-101', 'kayak', 'ketch-101', 'killer-whale', 'knife', 'ladder', 'laptop-101', 'lathe', 'leopards-101', 'license-plate', 'lightbulb', 'light-house', 'lightning', 'llama-101', 'mailbox', 'mandolin', 'mars', 'mattress', 'megaphone', 'menorah-101', 'microscope', 'microwave', 'minaret', 'minotaur', 'motorbikes-101', 'mountain-bike', 'mushroom', 'mussels', 'necktie', 'octopus', 'ostrich', 
'owl', 'palm-pilot', 'palm-tree', 'paperclip', 'paper-shredder', 'pci-card', 'penguin', 'people', 'pez-dispenser', 'photocopier', 'picnic-table', 'playing-card', 'porcupine', 'pram', 'praying-mantis', 'pyramid', 'raccoon', 'radio-telescope', 'rainbow', 'refrigerator', 'revolver-101', 'rifle', 'rotary-phone', 'roulette-wheel', 'saddle', 'saturn', 'school-bus', 'scorpion-101', 'screwdriver', 'segway', 'self-propelled-lawn-mower', 'sextant', 'sheet-music', 'skateboard', 'skunk', 'skyscraper', 'smokestack', 'snail', 'snake', 'sneaker', 'snowmobile', 'soccer-ball', 'socks', 'soda-can', 'spaghetti', 'speed-boat', 'spider', 'spoon', 'stained-glass', 'starfish-101', 'steering-wheel', 'stirrups', 'sunflower-101', 'superman', 'sushi', 'swan', 'swiss-army-knife', 'sword', 'syringe', 'tambourine', 'teapot', 'teddy-bear', 'teepee', 'telephone-box', 'tennis-ball', 'tennis-court', 'tennis-racket', 'theodolite', 'toaster', 'tomato', 'tombstone', 'top-hat', 'touring-bike', 'tower-pisa', 'traffic-light', 'treadmill', 'triceratops', 'tricycle', 'trilobite-101', 'tripod', 't-shirt', 'tuning-fork', 'tweezer', 'umbrella-101', 'unicorn', 'vcr', 'video-projector', 'washing-machine', 'watch-101', 'waterfall', 'watermelon', 'welding-mask', 'wheelbarrow', 'windmill', 'wine-bottle', 'xylophone', 'yarmulke', 'yo-yo', 'zebra', 'airplanes-101', 'car-side-101', 'faces-easy-101', 'greyhound', 'tennis-shoes', 'toad', 'clutter']

In [None]:
import json
import numpy as np

# Read the test image as raw bytes
with open(file_name, 'rb') as f:
    payload = bytearray(f.read())

# Declare the request body as an image
ic_endpoint.content_type = 'application/x-image'

# Send the image to the endpoint and decode the JSON prediction
prediction = ic_endpoint.predict(payload)
result = json.loads(prediction)

# Report the highest-scoring class together with its score
index = np.argmax(result)
print("Result: label - " + object_categories[index] + ", probability - " + str(result[index]))

### Delete endpoints

In [None]:
# Tear down both endpoints, in the order they were created.
for endpoint in (ic_endpoint, ic_endpoint_ei):
    endpoint.delete_endpoint()