# PyTorch Training and Serving in SageMaker "Script Mode"

Script mode is a training script format for PyTorch that lets you execute any PyTorch training script in SageMaker with minimal modification. The [SageMaker Python SDK](https://github.com/aws/sagemaker-python-sdk) handles transferring your script to a SageMaker training instance. On the training instance, SageMaker's native PyTorch support sets up training-related environment variables and executes your training script. In this tutorial, we use the SageMaker Python SDK to launch a training job and deploy the trained model.

Script mode supports training with a Python script, a Python module, or a shell script. In this example, we use a Python script to train a classification model on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/), showing how to train a model on SageMaker using a PyTorch script and the SageMaker Python SDK. This notebook also demonstrates real-time inference with the [SageMaker PyTorch Serving container](https://github.com/aws/sagemaker-pytorch-serving-container), which is the default inference method for script mode. For full documentation on deploying PyTorch models, see [Deploy PyTorch Models](https://github.com/aws/sagemaker-python-sdk/blob/master/doc/using_pytorch.rst#deploy-pytorch-models).

## Contents

1. [Background](#Background)
1. [Setup](#Setup)
1. [Data](#Data)
1. [Train](#Train)
1. [Host](#Host)

---

## Background

MNIST is a widely used dataset for handwritten digit classification. It consists of 70,000 labeled 28x28 pixel grayscale images of hand-written digits. The dataset is split into 60,000 training images and 10,000 test images. There are 10 classes (one for each of the 10 digits). This tutorial will show how to train and test an MNIST model on SageMaker using PyTorch.

For more information about PyTorch in SageMaker, see the [sagemaker-pytorch-containers](https://github.com/aws/sagemaker-pytorch-containers) and [sagemaker-python-sdk](https://github.com/aws/sagemaker-python-sdk) GitHub repositories.

---

## Setup

_This notebook was created and tested on an ml.m4.xlarge notebook instance._

## Install SageMaker Python SDK

In [3]:
!pip install sagemaker --upgrade --ignore-installed --no-cache --user


In [4]:
!pip install torch==1.3.1 torchvision==0.4.2 --upgrade --ignore-installed --no-cache --user


Forcing `pillow==6.2.1` due to https://discuss.pytorch.org/t/cannot-import-name-pillow-version-from-pil/66096

In [10]:
!pip uninstall -y pillow

Skipping pillow as it is not installed.
You are using pip version 10.0.1, however version 19.3.1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.


In [11]:
!pip install pillow==6.2.1 --upgrade --ignore-installed --no-cache --user

Collecting pillow==6.2.1
  Downloading https://files.pythonhosted.org/packages/10/5c/0e94e689de2476c4c5e644a3bd223a1c1b9e2bdb7c510191750be74fa786/Pillow-6.2.1-cp36-cp36m-manylinux1_x86_64.whl (2.1MB)
docker-compose 1.24.1 has requirement requests!=2.11.0,!=2.12.2,!=2.18.0,<2.21,>=2.6.1, but you'll have requests 2.22.0 which is incompatible.
awscli 1.16.283 has requirement botocore==1.13.19, but you'll have botocore 1.13.49 which is incompatible.
awscli 1.16.283 has requirement rsa<=3.5.0,>=3.1.2, but you'll have rsa 4.0 which is incompatible.
Installing collected packages: pillow
Successfully installed pillow-6.2.1
You are using pip version 10.0.1, however version 19.3.1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.


## Restart the Kernel to Recognize New Dependencies Above

In [None]:
from IPython.display import display_html
display_html("<script>Jupyter.notebook.kernel.restart()</script>", raw=True)

In [1]:
!pip3 list

Package              Version   
-------------------- ----------
absl-py              0.9.0     
astor                0.8.1     
astroid              1.6.6     
attrs                19.3.0    
awscli               1.16.76   
awscli-cwlogs        1.4.6     
backcall             0.1.0     
bleach               3.1.0     
bokeh                1.0.4     
boto                 2.49.0    
boto3                1.10.49   
botocore             1.13.49   
cachetools           4.0.0     
certifi              2019.11.28
chardet              3.0.4     
cloudpickle          1.2.2     
cmake                3.13.3    
colorama             0.3.9     
coremltools          2.0       
cpplint              1.3.0     
cycler               0.10.0    
dask                 2.6.0     
decorator            4.4.1     
defusedxml           0.6.0     
docutils             0.15.2    
entrypoints          0.3       
environment-kernels  1.1.1     
future               0.17.1    
gast                 0.2.2     
google-a

## Create the SageMaker Session

In [2]:
import os
import sagemaker
from sagemaker import get_execution_role

sagemaker_session = sagemaker.Session()

## Setup the Service Execution Role and Region
Get the IAM role ARN used to give training and hosting access to your data. See the documentation for how to create these roles. Note: if more than one role is required for notebook instances, training, and/or hosting, replace `sagemaker.get_execution_role()` with the appropriate full IAM role ARN string(s).

In [3]:
role = get_execution_role()
print('RoleARN:  {}\n'.format(role))

region = sagemaker_session.boto_session.region_name
print('Region:  {}'.format(region))

RoleARN:  arn:aws:iam::362377691630:role/service-role/AmazonSageMaker-ExecutionRole-20200109T002600

Region:  us-east-1
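
If you need to pass an explicit role ARN instead of the notebook's execution role, a small helper can make the choice visible and catch malformed strings early. The following is a minimal sketch, not part of the SageMaker SDK: `resolve_role`, the loose ARN pattern, and the example ARNs are all hypothetical. In a notebook, `lookup` would typically be `sagemaker.get_execution_role`.

```python
import re
from typing import Callable, Optional

# Loose shape check (an assumption, not an official rule):
# partition, 12-digit account ID, then a role path/name.
ROLE_ARN_PATTERN = re.compile(r"^arn:aws[a-z-]*:iam::\d{12}:role/.+$")


def resolve_role(explicit_arn: Optional[str], lookup: Callable[[], str]) -> str:
    """Return explicit_arn if provided, otherwise the result of lookup()."""
    role = explicit_arn or lookup()
    if not ROLE_ARN_PATTERN.match(role):
        raise ValueError("Not an IAM role ARN: {!r}".format(role))
    return role


# Hypothetical ARNs used purely for illustration.
training_role = resolve_role(
    "arn:aws:iam::123456789012:role/ExampleTrainingRole",
    lookup=lambda: "arn:aws:iam::123456789012:role/ExampleNotebookRole",
)
print(training_role)
```

Passing `None` as the first argument falls back to `lookup()`, which mirrors the default behavior of this notebook.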


## Training Data

### Copy the Training Data to Your Notebook Disk

In [4]:
local_data_path = './data'

In [6]:
from torchvision import datasets, transforms

normalization_mean = 0.1307
normalization_std = 0.3081

# Download the dataset into local_data_path (./data).
# The transform (ToTensor + Normalize) is attached to the returned dataset and applied when samples are read.
datasets.MNIST(local_data_path, download=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((normalization_mean,), (normalization_std,))
]))

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.1307,), std=(0.3081,))
           )
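
To make the effect of the normalization constants concrete, here is a small pure-Python sketch of the arithmetic applied to each pixel: scale the raw byte to `[0, 1]` (as `ToTensor` does for 8-bit images), then compute `(value - mean) / std`. The helper names are hypothetical and the snippet does not use torchvision itself.

```python
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def to_unit_range(byte_value):
    """Scale a raw 0-255 pixel to [0, 1]."""
    return byte_value / 255.0


def normalize_pixel(value, mean=MNIST_MEAN, std=MNIST_STD):
    """Apply (value - mean) / std to a pixel already scaled to [0, 1]."""
    return (value - mean) / std


black = normalize_pixel(to_unit_range(0))
white = normalize_pixel(to_unit_range(255))
print("black -> {:.4f}, white -> {:.4f}".format(black, white))
```

With these constants a black pixel maps to roughly -0.42 and a white pixel to roughly 2.82, so the mostly dark MNIST images end up centered near zero.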

In [7]:
!ls -R {local_data_path}

./data:
MNIST

./data/MNIST:
processed  raw

./data/MNIST/processed:
test.pt  training.pt

./data/MNIST/raw:
t10k-images-idx3-ubyte	   train-images-idx3-ubyte
t10k-images-idx3-ubyte.gz  train-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte	   train-labels-idx1-ubyte
t10k-labels-idx1-ubyte.gz  train-labels-idx1-ubyte.gz


### Upload the Data to S3 for Distributed Training Across Many Workers
We use the `sagemaker.Session.upload_data` function to upload our dataset to an S3 location. The return value, stored in `training_data_uri`, identifies that location; we use it later when we start the training job.

This is the S3 bucket and prefix that you want to use for training and model data. The bucket should be in the same region as the notebook instance, training, and hosting.

In [8]:
bucket = sagemaker_session.default_bucket()
data_prefix = 'sagemaker/pytorch-mnist/data'

In [9]:
training_data_uri = sagemaker_session.upload_data(path=local_data_path, bucket=bucket, key_prefix=data_prefix)
print('Input spec (S3 path): {}'.format(training_data_uri))

Input spec (S3 path): s3://sagemaker-us-east-1-362377691630/sagemaker/pytorch-mnist/data
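
The relationship between the local directory, the key prefix, and the resulting S3 location can be sketched in plain Python. This is an illustration only: `s3_uri`, `s3_key_for`, and the bucket name `example-bucket` are hypothetical, and the sketch assumes that relative paths under the local directory are preserved under the key prefix, which is consistent with the listing in this notebook.

```python
import os
import posixpath
from urllib.parse import urlparse


def s3_uri(bucket, key_prefix):
    """Build the s3:// URI for a bucket and key prefix."""
    return "s3://{}/{}".format(bucket, key_prefix.strip("/"))


def s3_key_for(local_root, local_file, key_prefix):
    """Key for a file under local_root, keeping its relative path under key_prefix."""
    relative = os.path.relpath(local_file, local_root).replace(os.sep, "/")
    return posixpath.join(key_prefix.strip("/"), relative)


uri = s3_uri("example-bucket", "sagemaker/pytorch-mnist/data")
key = s3_key_for("./data", "./data/MNIST/processed/test.pt", "sagemaker/pytorch-mnist/data")

parsed = urlparse(uri)  # netloc is the bucket, path is the prefix
print(uri)
print(key)
print(parsed.netloc, parsed.path.lstrip("/"))
```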


In [10]:
!aws s3 ls --recursive {training_data_uri}

2020-01-09 05:51:34    7920466 sagemaker/pytorch-mnist/data/MNIST/processed/test.pt
2020-01-09 05:51:34   47520466 sagemaker/pytorch-mnist/data/MNIST/processed/training.pt
2020-01-09 05:51:34    7840016 sagemaker/pytorch-mnist/data/MNIST/raw/t10k-images-idx3-ubyte
2020-01-09 05:51:33    1648877 sagemaker/pytorch-mnist/data/MNIST/raw/t10k-images-idx3-ubyte.gz
2020-01-09 05:51:34      10008 sagemaker/pytorch-mnist/data/MNIST/raw/t10k-labels-idx1-ubyte
2020-01-09 05:51:33       4542 sagemaker/pytorch-mnist/data/MNIST/raw/t10k-labels-idx1-ubyte.gz
2020-01-09 05:51:33   47040016 sagemaker/pytorch-mnist/data/MNIST/raw/train-images-idx3-ubyte
2020-01-09 05:51:34    9912422 sagemaker/pytorch-mnist/data/MNIST/raw/train-images-idx3-ubyte.gz
2020-01-09 05:51:33      60008 sagemaker/pytorch-mnist/data/MNIST/raw/train-labels-idx1-ubyte
2020-01-09 05:51:34      28881 sagemaker/pytorch-mnist/data/MNIST/raw/train-labels-idx1-ubyte.gz


## Train
### Training Script
The `mnist_pytorch.py` script provides all the code we need for training and hosting a SageMaker model (including the `model_fn` function that loads the model).
The training script is very similar to one you might run outside of SageMaker, but it can read useful properties of the training environment from environment variables, such as:

* `SM_MODEL_DIR`: A string representing the path to the directory to write model artifacts to.
  These artifacts are uploaded to S3 for model hosting.
* `SM_NUM_GPUS`: The number of GPUs available in the current container.
* `SM_CURRENT_HOST`: The name of the current container on the container network.
* `SM_HOSTS`: A JSON-encoded list containing all the hosts.

Supposing one input channel, 'training', was used in the call to the PyTorch estimator's `fit()` method, the following will be set, following the format `SM_CHANNEL_[channel_name]`:

* `SM_CHANNEL_TRAINING`: A string representing the path to the directory containing data in the 'training' channel.

For more information about training environment variables, please visit [SageMaker Containers](https://github.com/aws/sagemaker-containers).
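
As a rough illustration of how a script might consume these variables, the sketch below reads them from a mapping (defaulting to `os.environ`) and derives a host rank and world size from `SM_HOSTS` and `SM_CURRENT_HOST`. The function name `read_training_env`, the fallback defaults for local runs, and the example values are assumptions for illustration, not values taken from this training job.

```python
import json
import os


def read_training_env(env=None):
    """Collect SageMaker training settings from environment variables."""
    env = os.environ if env is None else env
    hosts = json.loads(env.get("SM_HOSTS", '["localhost"]'))
    current_host = env.get("SM_CURRENT_HOST", hosts[0])
    return {
        "model_dir": env.get("SM_MODEL_DIR", "./model"),
        "training_dir": env.get("SM_CHANNEL_TRAINING", "./data"),
        "num_gpus": int(env.get("SM_NUM_GPUS", "0")),
        "hosts": hosts,
        "current_host": current_host,
        # Position of this container in the host list; useful for distributed training.
        "host_rank": hosts.index(current_host),
        "world_size": len(hosts),
    }


# Illustrative values only.
example_env = {
    "SM_MODEL_DIR": "/opt/ml/model",
    "SM_CHANNEL_TRAINING": "/opt/ml/input/data/training",
    "SM_NUM_GPUS": "0",
    "SM_HOSTS": '["algo-1", "algo-2"]',
    "SM_CURRENT_HOST": "algo-2",
}
settings = read_training_env(example_env)
print(settings["host_rank"], settings["world_size"])
```

Taking the environment as a parameter keeps the function easy to exercise outside SageMaker, where none of the `SM_*` variables are set.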

A typical training script loads data from the input channels, configures training with hyperparameters, trains a model, and saves a model to `model_dir` so that it can be hosted later. Hyperparameters are passed to your script as arguments and can be retrieved with an `argparse.ArgumentParser` instance.
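
A minimal sketch of the argument parsing might look like the following. The argument names `--epochs` and `--backend` match the hyperparameters used in this notebook; the `--model-dir` argument, the default values, and the `parse_args` helper are assumptions for illustration rather than the contents of the actual script.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse hyperparameters passed to the training script as command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--backend", type=str, default=None)
    # Default the output directory to the path SageMaker provides, if any.
    parser.add_argument(
        "--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "./model")
    )
    return parser.parse_args(argv)


# Simulate the arguments produced by hyperparameters={'epochs': 5, 'backend': 'gloo'}.
args = parse_args(["--epochs", "5", "--backend", "gloo"])
print(args.epochs, args.backend, args.model_dir)
```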

Because SageMaker imports the training script, you should put your training code in a main guard (``if __name__=='__main__':``) if you are using the same script to host your model, as we do in this example, so that SageMaker does not inadvertently run your training code at the wrong point in execution.
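
The overall shape of such a dual-purpose script can be sketched as follows. To keep the example self-contained, a small JSON file stands in for the model artifact; a real PyTorch script would save and load model weights with PyTorch instead. The file name `model.json` and the body of `train` are hypothetical.

```python
import json
import os
import tempfile

ARTIFACT_NAME = "model.json"  # hypothetical artifact name for this sketch


def train(model_dir):
    """Stand-in for training: write a tiny 'model' artifact into model_dir."""
    os.makedirs(model_dir, exist_ok=True)
    with open(os.path.join(model_dir, ARTIFACT_NAME), "w") as f:
        json.dump({"weights": [0.0, 1.0]}, f)


def model_fn(model_dir):
    """Hosting hook: load and return the model saved by train()."""
    with open(os.path.join(model_dir, ARTIFACT_NAME)) as f:
        return json.load(f)


if __name__ == "__main__":
    # Runs only when the file is executed as a script, not when it is imported for hosting.
    train(os.environ.get("SM_MODEL_DIR") or tempfile.mkdtemp())
```

Importing this module exposes `model_fn` without triggering `train`, which is the behavior the main guard is meant to protect.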

For example, the script run by this notebook:

In [11]:
!ls ./src/mnist_pytorch.py

./src/mnist_pytorch.py


You can list additional Python packages in the `src/requirements.txt` file. They are installed automatically and made available to your training script.

In [12]:
!cat ./src/requirements.txt

# Python dependencies go here

### Train with SageMaker `PyTorch` Estimator

The `PyTorch` class allows us to run our training function as a training job on SageMaker infrastructure. We need to configure it with our training script, an IAM role, the number of training instances, the training instance type, and hyperparameters. In this case we run our training job on one `ml.c5.2xlarge` instance. Alternatively, you can specify another instance type, such as `ml.p3.2xlarge` (GPU) or `ml.c4.xlarge` (CPU). This example can be run on one or multiple CPU or GPU instances ([full list of available instances](https://aws.amazon.com/sagemaker/pricing/instance-types/)). The `hyperparameters` parameter is a dict of values that will be passed to your training script -- you can see how to access these values in the `mnist_pytorch.py` script above.

After we've constructed our `PyTorch` object, we can fit it using the data we uploaded to S3. SageMaker makes sure our data is available in the local filesystem of each worker, so our training script can simply read the data from disk.

### `fit` the Model (Approx. 15 mins)

To start a training job, we call `estimator.fit(training_data_uri)`.

In [13]:
from sagemaker.pytorch import PyTorch
import time

model_output_path = 's3://{}/sagemaker/pytorch-mnist/training-runs'.format(bucket)

mnist_estimator = PyTorch(
                  entry_point='mnist_pytorch.py',
                  source_dir='./src',
                  output_path=model_output_path,
                  role=role,
                  framework_version='1.3.1',
                  train_instance_count=1,
                  train_instance_type='ml.c5.2xlarge',
                  hyperparameters={
                    'epochs': 5,
                    'backend': 'gloo'
                  },
                  # Assuming the logline from the PyTorch training job is as follows:
                  #    Test set: Average loss: 0.3230, Accuracy: 9103/10000 (91%)
                  metric_definitions=[
                     {'Name':'test:loss', 'Regex':'Test set: Average loss: (.*?),'},
                     {'Name':'test:accuracy', 'Regex':r'Accuracy: [0-9]+/[0-9]+ \(([0-9.]+)%\)'}
                  ]
)

mnist_estimator.fit(inputs={'training': training_data_uri},
                                        wait=False)

training_job_name = mnist_estimator.latest_training_job.name

print('training_job_name:  {}'.format(training_job_name))

training_job_name:  pytorch-training-2020-01-09-05-51-37-112
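Metric regexes can be checked locally against a sample log line before launching a job. The sketch below uses the log line format quoted in the comment of the estimator definition. The two patterns are illustrative assumptions, and the value captured by the first group is what would be reported as the metric.

```python
import re

# Log line format taken from the comment in the estimator definition.
log_line = "Test set: Average loss: 0.3230, Accuracy: 9103/10000 (91%)"

# Illustrative patterns; each has exactly one capture group for the metric value.
loss_regex = r"Test set: Average loss: (.*?),"
accuracy_regex = r"Accuracy: [0-9]+/[0-9]+ \(([0-9.]+)%\)"

loss = float(re.search(loss_regex, log_line).group(1))
accuracy = float(re.search(accuracy_regex, log_line).group(1))

print(loss, accuracy)
```

A pattern that finds no match in the training logs produces no data points for that metric, so a quick check like this can save a training run.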


Attach to a training job to monitor the logs.

_Note: Each instance in the training job appears in a different color in the logs, one color per instance. This example uses a single instance._

In [14]:
mnist_estimator = PyTorch.attach(training_job_name=training_job_name)

2020-01-09 05:51:37 Starting - Starting the training job...
2020-01-09 05:51:38 Starting - Launching requested ML instances......
2020-01-09 05:52:44 Starting - Preparing the instances for training...
2020-01-09 05:53:26 Downloading - Downloading input data...
2020-01-09 05:54:06 Training - Training image download completed. Training in progress..
bash: cannot set terminal process group (-1): Inappropriate ioctl for device
bash: no job control in this shell
2020-01-09 05:54:07,354 sagemaker-containers INFO     Imported framework sagemaker_pytorch_container.training
2020-01-09 05:54:07,356 sagemaker-containers INFO     No GPUs detected (normal if no gpus installed)
2020-01-09 05:54:07,365 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.
2020-01-09 05:54:08,814 sagemaker_pytorch_container.training INFO     Invoking user training script.
2020-01-09 05:54:09,080 sagemaker-containers INFO     Mo

Test set: Average loss: 0.2076, Accuracy: 9400/10000 (94%)
INFO:__main__:Test set: Average loss: 0.1338, Accuracy: 9594/10000 (96%)
Test set: Average loss: 0.1338, Accuracy: 9594/10000 (96%)
INFO:__main__:Test set: Average loss: 0.1025, Accuracy: 9684/10000 (97%)
Test set: Average loss: 0.1025, Accuracy: 9684/10000 (97%)
INFO:__main__:Test set: Average loss: 0.0907, Accuracy: 9713/10000 (97%)
Test set: Average loss: 0.0907, Accuracy: 9713/10000 (97%)

2020-01-09 05:55:35 Uploading - Uploading generated training model
2020-01-09 05:55:35 Completed - Training job completed
INFO:__main__:Test set: Average loss: 0.0759, Accuracy: 9769/10000 (98%)
INFO:__main__:Saving the model.
Test set: Average loss: 0.0759, Accuracy: 9769/10000 (98%)
Saving the model.
[2020-01-09 05:55:26.693 algo-1:49 INFO utils.py:27] The end of training job file will not be w

## Option 1:  Perform Batch Predictions Directly in the Notebook

Use PyTorch directly to load the trained model artifact stored under `model_output_path`.

In [15]:
!aws --region {region} s3 ls --recursive {model_output_path}/{training_job_name}/output/

2020-01-09 05:55:31      82012 sagemaker/pytorch-mnist/training-runs/pytorch-training-2020-01-09-05-51-37-112/output/model.tar.gz


In [16]:
!aws --region {region} s3 cp {model_output_path}/{training_job_name}/output/model.tar.gz ./model/model.tar.gz

download: s3://sagemaker-us-east-1-362377691630/sagemaker/pytorch-mnist/training-runs/pytorch-training-2020-01-09-05-51-37-112/output/model.tar.gz to model/model.tar.gz


In [17]:
!ls ./model

model.tar.gz


In [18]:
!tar -xzvf ./model/model.tar.gz -C ./model

model.pth


In [19]:
# Based on https://github.com/pytorch/examples/blob/master/mnist/main.py
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [20]:
import torch

loaded_model = Net().to('cpu')
# single-machine multi-gpu case or single-machine or multi-machine cpu case
loaded_model = torch.nn.DataParallel(loaded_model)
print(loaded_model)

DataParallel(
  (module): Net(
    (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
    (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
    (conv2_drop): Dropout2d(p=0.5, inplace=False)
    (fc1): Linear(in_features=320, out_features=50, bias=True)
    (fc2): Linear(in_features=50, out_features=10, bias=True)
  )
)


In [21]:
loaded_model.load_state_dict(torch.load('./model/model.pth', map_location='cpu'))

<All keys matched successfully>
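A model wrapped in `torch.nn.DataParallel` stores its parameters under keys prefixed with `module.`, which is why the bare `Net` is wrapped before the checkpoint is loaded. If you would rather load into an unwrapped `Net`, one option is to strip that prefix from the keys first. The sketch below assumes the checkpoint was saved from a wrapped model. `strip_module_prefix` is a hypothetical helper, and plain strings stand in for tensors so the example runs without PyTorch.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading DataParallel prefix removed."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Plain strings stand in for tensors in this illustration.
example = OrderedDict([("module.conv1.weight", "w"), ("module.fc2.bias", "b")])
cleaned = strip_module_prefix(example)
print(list(cleaned))
```

With a real checkpoint, the cleaned dictionary could then be passed to `Net().load_state_dict(...)` without the `DataParallel` wrapper.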

In [22]:
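from torchvision import datasets, transforms

# Load the MNIST test split, normalized with the standard MNIST mean and std
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=256,
    shuffle=True
)

# Take the first raw test image; indexing dataset.data bypasses the transforms above
single_loaded_img = test_loader.dataset.data[0]
single_loaded_img = single_loaded_img.to('cpu')
single_loaded_img = single_loaded_img[None, None]  # add batch and channel dimensions: (1, 1, 28, 28)
single_loaded_img = single_loaded_img.type('torch.FloatTensor')  # convert uint8 pixels to float32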

print(single_loaded_img.numpy())

[[[[  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
      0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
      0.   0.]
   [  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
      0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
      0.   0.]
   [  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
      0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
      0.   0.]
   [  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
      0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
      0.   0.]
   [  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
      0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
      0.   0.]
   [  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
      0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
      0.   0.]
   [  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
      0.   0.

In [23]:
from matplotlib import pyplot as plt

plt.imshow(single_loaded_img.numpy().reshape(28, 28), cmap='Greys')

<matplotlib.image.AxesImage at 0x7feb2992c080>

In [24]:
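loaded_model.eval()  # switch off dropout for inference
with torch.no_grad():
    result = loaded_model(single_loaded_img)
# The index of the largest log-probability is the predicted digit
prediction = result.max(1, keepdim=True)[1][0][0].numpy()
print(prediction)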

7


## Option 2:  Create a SageMaker Endpoint and Perform REST-based Predictions

### Deploy the Trained Model to a SageMaker Endpoint (Approx. 10 mins)

After training, we use the `PyTorch` estimator object to build and deploy a `PyTorchPredictor`. This creates a SageMaker endpoint, a hosted prediction service that we can use to perform inference.

As mentioned above, the `mnist_pytorch.py` script provides the required implementation of `model_fn`. We use the default implementations of `input_fn`, `predict_fn`, `output_fn`, and `transform_fn` defined in [sagemaker-pytorch-containers](https://github.com/aws/sagemaker-pytorch-containers).

The arguments to the deploy function allow us to set the number and type of instances that will be used for the endpoint. These do not need to be the same as the values we used for the training job. For example, you can train a model on a set of GPU-based instances and then deploy the endpoint to a fleet of CPU-based instances, but you need to make sure that you return or save your model as a CPU model, similar to what we did in `mnist_pytorch.py`.
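To make the `model_fn` contract and the CPU requirement concrete, here is a minimal sketch. It is not the code from `mnist_pytorch.py`: `TinyNet` and `save_model` are hypothetical stand-ins, and only the `model_fn(model_dir)` signature and the `model.pth` file name come from this notebook. The idea is to move the model to the CPU before saving and to load it with `map_location='cpu'` so that it works on instances without a GPU.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the notebook's Net, kept small for illustration."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


def save_model(model, model_dir):
    # Move the weights to the CPU before saving so CPU-only hosts can load them.
    path = os.path.join(model_dir, "model.pth")
    torch.save(model.cpu().state_dict(), path)
    return path


def model_fn(model_dir):
    # Called by the serving container with the directory holding the model files.
    model = TinyNet()
    state = torch.load(os.path.join(model_dir, "model.pth"), map_location="cpu")
    model.load_state_dict(state)
    model.eval()  # inference mode: disables dropout
    return model


# Local round trip: save into a temporary directory, then load it back.
with tempfile.TemporaryDirectory() as model_dir:
    save_model(TinyNet(), model_dir)
    restored = model_fn(model_dir)
    output = restored(torch.zeros(1, 4))

print(output.shape)
```

Running the round trip locally like this is a cheap way to catch loading errors before paying for an endpoint deployment.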

In [None]:
predictor = mnist_estimator.deploy(initial_instance_count=1, instance_type='ml.c5.2xlarge')

-------------

### Invoke the Endpoint

We can now use this predictor to classify hand-written digits. Drawing into the image box loads the pixel data into a `data` variable in this notebook, which we can then pass to the `predictor`.

In [None]:
from IPython.display import HTML
HTML(open("input.html").read())

The value of `data` is retrieved from the HTML above.

In [None]:
print(data)

In [None]:
import numpy as np

image = np.array([data], dtype=np.float32)
response = predictor.predict(image)
prediction = response.argmax(axis=1)[0]
print(prediction)

### (Optional) Cleanup Endpoint

After you have finished with this example, remember to delete the prediction endpoint to release the instance(s) associated with it.

In [None]:
sagemaker.Session().delete_endpoint(predictor.endpoint)