# Using script mode training
This notebook shows how to train a Keras Sequential model on SageMaker using script mode, in which your own training script runs inside a SageMaker-managed TensorFlow container. The model is a simple deep CNN taken from [the Keras examples](https://github.com/keras-team/keras/blob/master/examples/cifar10_cnn.py).

## The dataset
The [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html) is a widely used image classification benchmark. It consists of 60,000 32x32 color images belonging to 10 different classes (6,000 images per class). Here are the classes in the dataset, as well as 10 random images from each:

![cifar10](https://maet3608.github.io/nuts-ml/_images/cifar10.png)

In this tutorial, we will train a deep CNN to recognize these images.


## Set up the environment

In [None]:
import os
import sagemaker
from sagemaker import get_execution_role
import boto3
sagemaker_session = sagemaker.Session()

role = get_execution_role()

In [None]:
bucket = sagemaker_session.default_bucket()
prefix = 'sagemaker/script-mode'

## Download the CIFAR-10 dataset
Downloading the test and training data takes around 5 minutes.

In [None]:
!pip install wget

In [None]:
# import wget # for TF2
# -p avoids an error if the directory already exists (for example, on a re-run)
!mkdir -p data
!python generate_cifar10_tfrecords_v2.py --data-dir data/
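
Before uploading, it can help to confirm that the generation script produced the layout the later cells rely on. The sketch below assumes the script writes `train`, `validation`, and `eval` subdirectories under `data/` (inferred from the channel paths passed to `fit` later in this notebook). The helper name `missing_channels` is hypothetical.

```python
from pathlib import Path

# Assumed layout: one subdirectory per channel, inferred from the channel
# paths passed to estimator.fit() later in this notebook.
EXPECTED_CHANNELS = ("train", "validation", "eval")


def missing_channels(data_dir, channels=EXPECTED_CHANNELS):
    """Return the channel subdirectories that are absent or empty."""
    root = Path(data_dir)
    missing = []
    for name in channels:
        channel_dir = root / name
        # A channel counts as present only if it is a directory with at least one entry
        if not channel_dir.is_dir() or not any(channel_dir.iterdir()):
            missing.append(name)
    return missing


print(missing_channels("data"))
```

An empty list means every assumed channel directory exists and contains at least one file.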

## Run on SageMaker cloud

### Uploading the data to S3

In [None]:
dataset_location = sagemaker_session.upload_data(bucket=bucket, path='data', key_prefix=prefix)
display(dataset_location)
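
For a directory upload, `upload_data` returns a URI of the form `s3://<bucket>/<key_prefix>`, and the contents of `data/` land beneath it. The following sketch shows the URIs the training channels are then expected to resolve to. It makes no AWS calls, the bucket name is a made-up placeholder, and `channel_uris` is a hypothetical helper.

```python
def channel_uris(bucket, key_prefix, channels=("train", "validation", "eval")):
    """Build the S3 URI expected for each channel subdirectory."""
    base = f"s3://{bucket}/{key_prefix.strip('/')}"
    return {name: f"{base}/{name}" for name in channels}


# Placeholder bucket name, for illustration only
uris = channel_uris("example-bucket", "sagemaker/script-mode")
for name, uri in uris.items():
    print(name, uri)
```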

### Configuring metrics from the job logs
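SageMaker can extract training metrics directly from the job logs using regular expressions and publish them to CloudWatch. Each metric definition pairs a metric name with a regex whose capture group holds the value.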

In [None]:
keras_metric_definition = [
    {'Name': 'train:loss', 'Regex': '.*loss: ([0-9\\.]+) - acc: [0-9\\.]+.*'},
    {'Name': 'train:accuracy', 'Regex': '.*loss: [0-9\\.]+ - acc: ([0-9\\.]+).*'},
    {'Name': 'validation:accuracy', 'Regex': '.*step - loss: [0-9\\.]+ - acc: [0-9\\.]+ - val_loss: [0-9\\.]+ - val_acc: ([0-9\\.]+).*'},
    {'Name': 'validation:loss', 'Regex': '.*step - loss: [0-9\\.]+ - acc: [0-9\\.]+ - val_loss: ([0-9\\.]+) - val_acc: [0-9\\.]+.*'},
    # Captures the per-step time that Keras reports (in ms or us, despite the metric name)
    {'Name': 'sec/steps', 'Regex': '.* - \\d+s (\\d+)[mu]s/step - loss: [0-9\\.]+ - acc: [0-9\\.]+ - val_loss: [0-9\\.]+ - val_acc: [0-9\\.]+'}
]
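
To see what these patterns capture, they can be tried locally with Python's `re` module before launching a job. This is only a sketch. The sample log lines are hypothetical, modeled on the Keras progress output of TF 1.x (which logs `acc` and `val_acc`), and it assumes that SageMaker's matching behaves like `re.search` for these patterns. The patterns are the same as in the cell above, written as raw strings.

```python
import re

# Same patterns as the metric definitions above, written as raw strings
patterns = {
    "train:loss": r".*loss: ([0-9\.]+) - acc: [0-9\.]+.*",
    "train:accuracy": r".*loss: [0-9\.]+ - acc: ([0-9\.]+).*",
    "validation:accuracy": r".*step - loss: [0-9\.]+ - acc: [0-9\.]+ - val_loss: [0-9\.]+ - val_acc: ([0-9\.]+).*",
    "validation:loss": r".*step - loss: [0-9\.]+ - acc: [0-9\.]+ - val_loss: ([0-9\.]+) - val_acc: [0-9\.]+.*",
    "sec/steps": r".* - \d+s (\d+)[mu]s/step - loss: [0-9\.]+ - acc: [0-9\.]+ - val_loss: [0-9\.]+ - val_acc: [0-9\.]+",
}

# Hypothetical log lines: one end-of-epoch summary and one mid-epoch update
epoch_end_line = (
    "195/195 [==============================] - 12s 60ms/step"
    " - loss: 1.2345 - acc: 0.5678 - val_loss: 1.1111 - val_acc: 0.6000"
)
mid_epoch_line = " 64/195 [=========>....................] - ETA: 8s - loss: 1.5000 - acc: 0.4500"


def extract_metrics(line):
    """Return {metric name: value} for every pattern that matches the line."""
    found = {}
    for name, regex in patterns.items():
        match = re.search(regex, line)
        if match:
            found[name] = float(match.group(1))
    return found


print(extract_metrics(epoch_end_line))
print(extract_metrics(mid_epoch_line))
```

Only the end-of-epoch line carries validation values and the per-step time, so the mid-epoch line yields just the two training metrics.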

### Train an image classifier on the CIFAR-10 dataset

In [None]:
hyperparameters = {'epochs': 10, 'batch-size' : 256}
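
In script mode, SageMaker passes each hyperparameter to the entry point as a command-line argument, here `--epochs 10 --batch-size 256`. The contents of `cifar10_keras_main.py` are not shown in this notebook, so the following is only a sketch of how such a script might read them. The function name and the default values are assumptions.

```python
import argparse


def parse_hyperparameters(argv=None):
    """Parse the hyperparameters that SageMaker forwards as CLI arguments."""
    parser = argparse.ArgumentParser()
    # Defaults are illustrative assumptions, used when a value is not supplied
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch-size", type=int, default=128)
    # parse_known_args ignores any extra arguments the script does not declare
    args, _unknown = parser.parse_known_args(argv)
    return args


# Simulate the arguments produced by the hyperparameters dict above
args = parse_hyperparameters(["--epochs", "10", "--batch-size", "256"])
print(args.epochs, args.batch_size)
```

`argparse` exposes `--batch-size` as `args.batch_size`. Using `parse_known_args` keeps the script tolerant of additional arguments it does not declare, such as the `--model_dir` argument the TensorFlow estimator typically adds.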

In [None]:
from sagemaker.tensorflow import TensorFlow


source_dir = os.path.join(os.getcwd(), 'source_dir')
estimator = TensorFlow(base_job_name='cifar10-tf',
                       entry_point='cifar10_keras_main.py',
                       source_dir=source_dir,
                       role=role,
                       framework_version='1.15.2',
                       py_version='py3',
                       hyperparameters=hyperparameters,
                       instance_count=1, 
                       instance_type='ml.p3.2xlarge',
                       metric_definitions=keras_metric_definition)

# Map each input channel name to the S3 prefix that holds its data
remote_inputs = {
    'train': dataset_location + '/train',
    'validation': dataset_location + '/validation',
    'eval': dataset_location + '/eval',
}
estimator.fit(remote_inputs, wait=True)
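
With the default File input mode, SageMaker downloads each channel to `/opt/ml/input/data/<channel>` inside the training container and exposes that path through an `SM_CHANNEL_<NAME>` environment variable. The sketch below shows how a training script might resolve those directories. `channel_dir` is a hypothetical helper, and the environment dictionary is faked so the example runs outside SageMaker.

```python
import os


def channel_dir(channel, environ=os.environ):
    """Resolve the local directory for an input channel inside the container."""
    default = f"/opt/ml/input/data/{channel}"
    return environ.get(f"SM_CHANNEL_{channel.upper()}", default)


# Fake environment for illustration; inside a training job, os.environ is used
fake_env = {"SM_CHANNEL_TRAIN": "/opt/ml/input/data/train"}
for name in ("train", "validation", "eval"):
    print(name, channel_dir(name, fake_env))
```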

### View the job training metrics
SageMaker used the regular expressions configured above to extract the job metrics from the logs and send them to CloudWatch.
You can also view the job metrics directly in SageMaker Studio. In the left sidebar, select the SageMaker Experiment List, right-click _Unassigned trial components_, open it in the trial component list, choose the latest training job, and open it in trial details. There you can see all the metrics that you defined to be sent to CloudWatch.  
In CloudWatch itself, you can also change the period and configure the statistics.

In [None]:
from IPython.display import Markdown, display

link = 'https://console.aws.amazon.com/cloudwatch/home?region='+sagemaker_session.boto_region_name+'#metricsV2:query=%7B/aws/sagemaker/TrainingJobs,TrainingJobName%7D%20'+estimator.latest_training_job.job_name
display(Markdown('CloudWatch metrics: [link]('+link+')'))
display(Markdown('After you choose a metric, change the period to 1 Minute (Graphed Metrics -> Period)'))