In [None]:
import sagemaker
# SageMaker session for this notebook; also supplies the region used when
# looking up the training image further down.
session = sagemaker.Session()
# Default S3 bucket of the session: the CSV uploads, the training inputs and
# the training output path below are all built from this bucket name.
bucket = session.default_bucket()
# Execution role; passed to the estimator, the model and the pipeline upsert.
role = sagemaker.get_execution_role()

In [None]:
# Copy the local training and validation CSVs to the root of the default
# bucket. IPython interpolates {bucket} from the Python variable of that name,
# so the session-setup cell must have been run first.
!aws s3 cp churn_train.csv s3://{bucket}/churn_train.csv
!aws s3 cp churn_validate.csv s3://{bucket}/churn_validate.csv

In [None]:
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.pipeline_context import PipelineSession

# Session handed to the estimator and the model below. Their .fit() and
# .register() results are passed to the pipeline steps as step_args.
# NOTE(review): per the SageMaker Pipelines docs a PipelineSession defers job
# creation until the pipeline executes — confirm for the installed SDK version.
pipeline_session = PipelineSession()

from sagemaker.workflow.steps import TrainingStep
from sagemaker.workflow.fail_step import FailStep
from sagemaker.workflow.condition_step import ConditionStep

from sagemaker.workflow.conditions import ConditionLessThanOrEqualTo
from sagemaker.workflow.functions import Join



In [None]:
from sagemaker.workflow.parameters import (
    ParameterInteger,
    ParameterString,
    ParameterFloat,
    ParameterBoolean
)

In [None]:
# Execution-time parameters; both are listed in the pipeline's `parameters`.

# Tree depth forwarded to the estimator as the `max_depth` hyperparameter.
tree_max_depth_parameter = ParameterInteger(name='TreeMaxDepth', default_value=5)

# Threshold the registration gate compares the validation log loss against.
churn_validation_loss = ParameterFloat(name='ChurnValidationLoss', default_value=0.2)

In [None]:
from sagemaker.inputs import TrainingInput

def _csv_channel(filename):
    """Return a CSV TrainingInput for `filename` at the root of the default bucket."""
    return TrainingInput(s3_data=f's3://{bucket}/{filename}', content_type='csv')


# Channels consumed by the training step ('train' and 'validation').
s3_input_train = _csv_channel('churn_train.csv')
s3_input_validate = _csv_channel('churn_validate.csv')

In [None]:
# XGBoost 1.5-1 container image for the session's region; reused by the model
# registration cell so training and inference use the same image.
xgb_image = sagemaker.image_uris.retrieve('xgboost', session.boto_region_name, '1.5-1')

# Single-instance training job definition. It is bound to the pipeline session
# so that estimator.fit() can be passed to a TrainingStep as step arguments.
estimator = sagemaker.estimator.Estimator(
    xgb_image,
    role,
    instance_count=1,
    instance_type='ml.m5.large',
    output_path=f's3://{bucket}/output',
    sagemaker_session=pipeline_session,
)
estimator.set_hyperparameters(
    # Tunable per execution through the TreeMaxDepth pipeline parameter.
    max_depth=tree_max_depth_parameter,
    objective='binary:logistic',
    # The registration gate reads the 'validation:logloss' metric, so request
    # it explicitly instead of relying on the objective-dependent default.
    eval_metric='logloss',
    num_round=100,
)

In [None]:
# Channel name -> S3 input, as expected by estimator.fit().
training_channels = {
    'train': s3_input_train,
    'validation': s3_input_validate,
}

# The result of fit() is handed to the step as its arguments.
training_step_args = estimator.fit(inputs=training_channels)

# Training step; later cells read its properties (model artifacts, metrics).
churn_training_step = TrainingStep(name='ChurnTrainingStep', step_args=training_step_args)

In [None]:
from sagemaker.model import Model
from sagemaker.workflow.model_step import ModelStep

# Reference to the artifact location produced by the training step.
trained_artifacts = churn_training_step.properties.ModelArtifacts.S3ModelArtifacts

# Model built from the training output, served with the same XGBoost image.
model = Model(
    image_uri=xgb_image,
    model_data=trained_artifacts,
    sagemaker_session=pipeline_session,
    role=role,
)

# Values shared by the request/response and inference/transform settings.
csv_mime_type = 'text/csv'
hosting_instance_type = 'ml.m5.large'

# Arguments for registering a model package in the 'churn-model-group' group;
# the package is created with approval status 'PendingManualApproval'.
register_args = model.register(
    content_types=[csv_mime_type],
    response_types=[csv_mime_type],
    inference_instances=[hosting_instance_type],
    transform_instances=[hosting_instance_type],
    model_package_group_name='churn-model-group',
    approval_status='PendingManualApproval',
)

# Registration step; used as the "if" branch of the condition step.
register_model_step = ModelStep(name='ChurnRegisterModel', step_args=register_args)

In [None]:
# Terminal step for the rejection branch of the registration gate. The gate
# registers the model when validation log loss <= ChurnValidationLoss, so this
# step is reached only when the loss is strictly greater than the threshold;
# the message therefore says '>' rather than '>='.
# Join's ' ' separator already puts a space before the threshold value, so the
# literal must not carry a trailing space of its own.
fail_step = FailStep(
    name='ChurnFailStep',
    error_message=Join(
        on=' ',
        values=['Failed a pipeline due to log loss >', churn_validation_loss],
    ),
)

In [None]:
# Final value of the 'validation:logloss' metric reported by the training step.
validation_logloss = (
    churn_training_step.properties.FinalMetricDataList['validation:logloss'].Value
)

# Holds when the validation loss is at or below the ChurnValidationLoss parameter.
loss_within_threshold = ConditionLessThanOrEqualTo(
    left=validation_logloss,
    right=churn_validation_loss,
)

# Gate: register the model when the condition holds, otherwise fail the run.
condition_step = ConditionStep(
    name='ModelRegistrationConditionStep',
    conditions=[loss_within_threshold],
    if_steps=[register_model_step],
    else_steps=[fail_step],
)

In [None]:
# Top-level steps only: the register and fail steps are reached through the
# condition step's branches, so they are not listed here.
pipeline_steps = [churn_training_step, condition_step]

# Parameters that can be overridden per execution.
pipeline_parameters = [tree_max_depth_parameter, churn_validation_loss]

pipeline = Pipeline(
    name='churn-prediction-model-pipeline',
    steps=pipeline_steps,
    parameters=pipeline_parameters,
)

In [None]:
# Submit the pipeline definition under the notebook's execution role.
# NOTE(review): per the SDK docs, upsert creates the pipeline or updates an
# existing one with the same name — so re-running this cell should be safe.
pipeline.upsert(role_arn=role)

In [None]:
# Start an execution of the pipeline. No parameter overrides are passed, so
# TreeMaxDepth and ChurnValidationLoss take their declared default values.
# The call is the cell's last expression, so its result is shown as output.
pipeline.start(
    execution_display_name='conditional-model-registration',
    execution_description='Starting from the SageMaker Studio'
)