# MNIST Training using PyTorch

In [None]:
import sagemaker

# S3 key prefix under which this demo's dataset is uploaded further down.
prefix = 'sagemaker/DEMO-pytorch-mnist'

# Session used for the S3 upload below; the bucket is whatever the session
# reports as its default.
sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()

# Role passed to the PyTorch estimator when the training job is configured.
role = sagemaker.get_execution_role()

Getting the data

In [None]:
from torchvision import datasets, transforms

# Preprocessing pipeline: convert images to tensors, then normalise with the
# given mean / std (presumably the MNIST training-set statistics; verify).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Download MNIST into the local 'data' directory. Kept as the last expression
# so the cell still displays the dataset summary.
datasets.MNIST('data', download=True, transform=mnist_transform)

Uploading the data to S3

In [None]:
# Push the local 'data' directory to S3 under the chosen bucket/prefix and keep
# the returned location for use as the training input.
inputs = sagemaker_session.upload_data(
    path='data',
    bucket=bucket,
    key_prefix=prefix,
)
print('input spec (in this case, just an S3 path): {}'.format(inputs))

We use a borrowed `mnist.py` script, since it provides all the code we need for training and hosting a SageMaker model.

In [None]:
# Show the syntax-highlighted source of mnist.py, the script used as the estimator's entry point below.
!pygmentize mnist.py

Run training job

In [None]:
from sagemaker.pytorch import PyTorch

# Hyperparameters handed to the entry-point script (mnist.py) by SageMaker.
# NOTE(review): how mnist.py interprets 'backend' is not visible here; it is
# presumably the distributed-training backend; confirm in the script.
training_hyperparameters = {
    'epochs': 6,
    'backend': 'gloo',
}

# Training job definition: mnist.py on two ml.c4.xlarge instances, using the
# PyTorch 1.2.0 framework version.
estimator = PyTorch(
    entry_point='mnist.py',
    role=role,
    framework_version='1.2.0',
    train_instance_count=2,
    train_instance_type='ml.c4.xlarge',
    hyperparameters=training_hyperparameters,
)

In [None]:
# Start the training job, exposing the uploaded S3 data as the 'training' channel.
estimator.fit(inputs={'training': inputs})

In [None]:
# Deploy the trained model behind an endpoint backed by a single ml.m4.xlarge
# instance; the returned predictor is used for inference below.
predictor = estimator.deploy(
    initial_instance_count=1,
    instance_type='ml.m4.xlarge',
)

The cell below opens an HTML page where you can draw a digit using a mouse pointer or a touch screen.

In [None]:
from IPython.display import HTML
# Read the drawing page inside a context manager so the file handle is closed
# deterministically instead of being left for the garbage collector.
with open("input.html") as html_file:
    input_page = html_file.read()

# Last expression of the cell: renders the page inline in the notebook.
HTML(input_page)

In [None]:
import numpy as np

# `data` is not assigned by any Python cell visible in this notebook; presumably
# the drawing page rendered from input.html sets it in the kernel once a digit
# has been drawn (TODO confirm against input.html). Fail with an actionable
# message rather than a bare NameError when it is missing.
if 'data' not in globals():
    raise NameError(
        "name 'data' is not defined: run the input.html cell above and draw a "
        "digit before running this cell"
    )

# Wrap the single drawn sample in a batch of one and send it as float32.
image = np.array([data], dtype=np.float32)
response = predictor.predict(image)

# Index of the largest value along axis 1 for the first (only) sample.
prediction = response.argmax(axis=1)[0]
print(prediction)

In [None]:
# SageMaker cleanup: delete the endpoint through the predictor returned by
# deploy(). Estimator.delete_endpoint() is deprecated in favour of the
# predictor's method, which by default also removes the endpoint configuration.
predictor.delete_endpoint()