# Automate model training with Step Functions Data Science SDK

## Setup environment

In [59]:
import os
import sagemaker
from sagemaker.pytorch import PyTorch as PyTorchEstimator
from sagemaker.tuner import IntegerParameter, CategoricalParameter, ContinuousParameter, HyperparameterTuner

# Session object used below to resolve the default S3 bucket.
sagemaker_session = sagemaker.Session()
# Execution role; passed to the estimator and reused as the workflow role later on.
role = sagemaker.get_execution_role()
# The session's default bucket; handed to the training pipeline as `s3_bucket`.
bucket = sagemaker_session.default_bucket()

## Define data inputs from S3

In [60]:
# Replace with your S3 dataset path.
# The 'train' key is the input channel name; the `data_folder` hyperparameter
# defined in the training section points at /opt/ml/input/data/train.
inputs = {
    'train': 's3://sagemaker-us-east-1-175748383800/data-processing-2020-06-26-21-44-08-917/output/preprocessed/',
}
print(inputs)

{'train': 's3://sagemaker-us-east-1-175748383800/data-processing-2020-06-26-21-44-08-917/output/preprocessed/'}


## Train

In [61]:
# Hyperparameters handed to the estimator; the two folder entries are
# container-side paths (input channel 'train' and the model output directory).
hyperparameters = dict(
    model_name="roberta-base",
    data_folder='/opt/ml/input/data/train',
    output_folder='/opt/ml/model',
    epochs=2,
    learning_rate=2e-5,
    batch_size=64,
    seed=42,
    max_len=160,
)

# Metric scraped from the training logs: the capture group takes the number
# that follows "val_accuracy: ".
metric_definitions = [
    dict(Name='validation_accuracy', Regex='val_accuracy: ([0-9\\.]+)'),
]

In [62]:
# PyTorch estimator describing the training job: which script to run, on what
# hardware, and with which hyperparameters / metric definitions.
estimator = PyTorchEstimator(
    # Training code
    source_dir='source_dir',
    entry_point='run_training.py',
    # Framework / runtime versions
    framework_version='1.5.0',
    py_version='py3',
    # Permissions and compute resources
    role=role,
    train_instance_type='ml.p3.2xlarge',
    train_instance_count=1,
    train_volume_size=50,
    # Job configuration defined in the previous cell
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
)

## Step Functions Data Science SDK

In [63]:
import sys
# Install/upgrade the `stepfunctions` package with the kernel's own interpreter
# (sys.executable), so it lands in the environment this notebook runs in.
# NOTE(review): the version is unpinned; the recorded run resolved to
# stepfunctions 1.0.0.9 — consider pinning for reproducibility.
!{sys.executable} -m pip install --upgrade stepfunctions

Requirement already up-to-date: stepfunctions in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (1.0.0.9)
Requirement not upgraded as not directly required: sagemaker>=1.42.8 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from stepfunctions) (1.55.3)
Requirement not upgraded as not directly required: boto3>=1.9.213 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from stepfunctions) (1.12.39)
Requirement not upgraded as not directly required: pyyaml in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from stepfunctions) (5.3.1)
Requirement not upgraded as not directly required: protobuf3-to-dict>=0.1.5 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from sagemaker>=1.42.8->stepfunctions) (0.1.5)
Requirement not upgraded as not directly required: smdebug-rulesconfig==0.1.2 in /home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages (from sagemaker>=1.42

In [64]:
import logging
import stepfunctions
from stepfunctions.template.pipeline import TrainingPipeline
# Turn on the SDK's stream logger at INFO level.
stepfunctions.set_stream_logger(level=logging.INFO)

# The workflow reuses the SageMaker execution role obtained during setup.
# NOTE(review): presumably this role must also be usable by Step Functions —
# confirm its trust policy and permissions.
workflow_execution_role = role

In [65]:
# Build the SDK's training workflow template around the estimator defined above.
pipeline = TrainingPipeline(
    s3_bucket=bucket,
    inputs=inputs,
    role=workflow_execution_role,
    estimator=estimator,
)

In [66]:
# Show the generated workflow definition as pretty-printed JSON.
print(pipeline.workflow.definition.to_json(pretty=True))

{
    "StartAt": "Training",
    "States": {
        "Training": {
            "Resource": "arn:aws:states:::sagemaker:createTrainingJob.sync",
            "Parameters": {
                "AlgorithmSpecification.$": "$$.Execution.Input['Training'].AlgorithmSpecification",
                "OutputDataConfig.$": "$$.Execution.Input['Training'].OutputDataConfig",
                "StoppingCondition.$": "$$.Execution.Input['Training'].StoppingCondition",
                "ResourceConfig.$": "$$.Execution.Input['Training'].ResourceConfig",
                "RoleArn.$": "$$.Execution.Input['Training'].RoleArn",
                "InputDataConfig.$": "$$.Execution.Input['Training'].InputDataConfig",
                "HyperParameters.$": "$$.Execution.Input['Training'].HyperParameters",
                "TrainingJobName.$": "$$.Execution.Input['Training'].TrainingJobName",
                "DebugHookConfig.$": "$$.Execution.Input['Training'].DebugHookConfig"
            },
            "Type": "Task",
 

In [67]:
# Render the workflow graph inline (no output was captured in this export).
pipeline.render_graph()

In [68]:
# Create the workflow on AWS Step Functions; the cell's result is the ARN of
# the new state machine.
pipeline.create()

[INFO] Workflow created successfully on AWS Step Functions.


'arn:aws:states:us-east-1:175748383800:stateMachine:training-pipeline-2020-06-28-17-21-34'

In [69]:
# Start an execution of the workflow created above on AWS Step Functions.
pipeline.execute()

[INFO] Workflow execution started successfully on AWS Step Functions.
