In [None]:
# 下载step functions的python库
!pip install stepfunctions

In [None]:
import sys
import uuid
import logging
import stepfunctions
import boto3
import sagemaker
import pandas as pd
import numpy as np
from time import gmtime, strftime, sleep         
import os
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker import s3
from sagemaker.s3 import S3Uploader
from stepfunctions import steps
from stepfunctions.steps import TrainingStep, ModelStep
from stepfunctions.inputs import ExecutionInput
from stepfunctions.workflow import Workflow

In [None]:
# Create the IAM role that Step Functions assumes to run the workflow.
# The trust policy and the inline permissions policy are read from
# trust_policy.json and sfn_policy.json in the notebook's working directory.
# The cell is safe to re-run:
#   * if the role already exists, its trust policy is refreshed instead of
#     failing with EntityAlreadyExists;
#   * put_role_policy adds or replaces the named inline policy.
SFN_ROLE_NAME = 'StepFunctionsWorkflowExecutionRole'
client_iam = boto3.client('iam')

with open('trust_policy.json', 'r') as f_obj:
    policy = f_obj.read()
try:
    client_iam.create_role(RoleName=SFN_ROLE_NAME,
                           AssumeRolePolicyDocument=policy)
except client_iam.exceptions.EntityAlreadyExistsException:
    # Role left over from a previous run: keep it, but apply the current trust policy.
    client_iam.update_assume_role_policy(RoleName=SFN_ROLE_NAME,
                                         PolicyDocument=policy)

with open('sfn_policy.json', 'r') as f_obj:
    policy = f_obj.read()
client_iam.put_role_policy(
    RoleName=SFN_ROLE_NAME, PolicyName='test-sfn-policy', PolicyDocument=policy)

In [None]:
# Session-wide parameters for SageMaker and Step Functions.
# The Step Functions execution role ARN is built from the caller's own account
# id and AWS partition (e.g. `aws` or `aws-cn`), so no manual edit of an
# account-id placeholder is needed. The role name must match the role created
# by the IAM cell of this notebook.
session = sagemaker.Session()
stepfunctions.set_stream_logger(level=logging.INFO)

region = boto3.Session().region_name
bucket = session.default_bucket()
hex_id = uuid.uuid4().hex

caller_identity = boto3.client('sts').get_caller_identity()
account_id = caller_identity['Account']
# The caller ARN has the form arn:<partition>:sts::<account>:...; field 1 is the partition.
partition = caller_identity['Arn'].split(':')[1]
workflow_execution_role = 'arn:{}:iam::{}:role/StepFunctionsWorkflowExecutionRole'.format(
    partition, account_id)
sagemaker_execution_role = sagemaker.get_execution_role()

In [None]:
# Upload the local CIFAR-10 TFRecord splits to s3://<bucket>/<project_name>/
# and point the training channel at that prefix.
# The S3 key prefix is derived from `project_name` so that `train_data` always
# refers to the prefix the files were actually uploaded to.
project_name = 'cifar-10-data'
local_data_dir = 'cifar-10-data'  # local folder holding the *.tfrecords files

s3_bucket = boto3.Session().resource('s3').Bucket(bucket)
for split_name in ('eval', 'train', 'validation'):
    file_name = '{}.tfrecords'.format(split_name)
    s3_bucket.Object('{}/{}'.format(project_name, file_name)).upload_file(
        '{}/{}'.format(local_data_dir, file_name))

train_data = 's3://{}/{}/'.format(bucket, project_name)

In [None]:
# Training configuration for the SageMaker Estimator:
# the instance type used for the training job and the hyperparameters
# forwarded to the training container.
instance_type = 'ml.m5.large'
hyperparameters = {}
hyperparameters['train-steps'] = 100

In [None]:
# SageMaker Estimator for the custom (bring-your-own-container) training image.
# The ECR image URI is built from the caller's account id and the session
# `region` rather than from a hard-coded placeholder account and
# `cn-northwest-1`. China regions (cn-*) use the amazonaws.com.cn suffix.
# NOTE(review): assumes the image was pushed to this account's ECR in the same
# region as the SageMaker session; adjust `ecr_account_id` otherwise.
ecr_account_id = boto3.client('sts').get_caller_identity()['Account']
ecr_domain = 'amazonaws.com.cn' if region.startswith('cn-') else 'amazonaws.com'
image_uri = '{}.dkr.ecr.{}.{}/sagemaker-tf-cifar10-example:latest'.format(
    ecr_account_id, region, ecr_domain)

cif = sagemaker.estimator.Estimator(image_name=image_uri,
                                    role=sagemaker_execution_role,
                                    train_instance_count=1,
                                    train_instance_type=instance_type,
                                    hyperparameters=hyperparameters,
                                    output_path='s3://{}/{}/output'.format(bucket, project_name))

In [None]:
# Per-execution placeholders for the workflow: the names of the training job,
# the model and the endpoint, each supplied as a string when the workflow runs.
execution_input = ExecutionInput(schema=dict(
    TrainingJobName=str,
    ModelName=str,
    EndpointName=str,
))

In [None]:
# Step 1 - training: run the estimator on the S3 training data.
# The job name is taken from the execution input; the state blocks until the
# training job completes.
training_step = steps.TrainingStep(
    'Model Training',
    estimator=cif,
    job_name=execution_input['TrainingJobName'],
    data=train_data,
    wait_for_completion=True,
)

In [None]:
# Step 2 - save model: register the model produced by the training step under
# the name given in the execution input. The state's result is written to
# $.ModelStepResults in the state output.
trained_model = training_step.get_expected_model()
model_step = steps.ModelStep(
    'Save Model',
    model=trained_model,
    model_name=execution_input['ModelName'],
    result_path='$.ModelStepResults',
)

In [None]:
# Step 3 - endpoint configuration: reuses the model name as the configuration
# name and hosts the model on a single ml.m5.xlarge instance.
endpoint_config_step = steps.EndpointConfigStep(
    "Create Model Endpoint Config",
    model_name=execution_input['ModelName'],
    endpoint_config_name=execution_input['ModelName'],
    instance_type='ml.m5.xlarge',
    initial_instance_count=1,
)

In [None]:
# Step 4 - endpoint: deploy the configuration from step 3 under the endpoint
# name given in the execution input.
# With update=False this creates a NEW endpoint (despite the state's name);
# pass update=True to update an existing endpoint of the same name instead.
endpoint_step = steps.EndpointStep(
    'Update Model Endpoint',
    endpoint_config_name=execution_input['ModelName'],
    endpoint_name=execution_input['EndpointName'],
    update=False,
)

In [None]:
# Chain the states so they run sequentially:
# train -> save model -> endpoint config -> endpoint.
workflow_definition = steps.Chain(
    [training_step, model_step, endpoint_config_step, endpoint_step])

In [None]:
# Assemble the Step Functions workflow from the chained states.
# A random hex suffix gives every run a distinct state-machine name.
workflow = Workflow(
    name='MyBYOC_' + uuid.uuid4().hex,
    role=workflow_execution_role,
    definition=workflow_definition,
    execution_input=execution_input,
)

In [None]:
# Create the state machine in AWS Step Functions and display the returned ARN.
state_machine_arn = workflow.create()
state_machine_arn