In [1]:
%load_ext autoreload
%autoreload 2

from sagemaker.processing import NetworkConfig
from sagemaker.workflow.parameters import ParameterString
from sagemaker_rightline.model import Configuration
from sagemaker_rightline.validations import (
    StepImagesExistOnEcr,
    StepKmsKeyId,
    PipelineParameters,
    ContainerImage,
    StepNetworkConfig,
)
from sagemaker_rightline.rules import Equals, Contains
from tests.utils import create_image
from tests.fixtures.pipeline import get_sagemaker_pipeline
from tests.fixtures.image_details import IMAGE_1_URI, IMAGE_2_URI
from moto import mock_ecr
import boto3

In [19]:
from sagemaker.processing import NetworkConfig, ScriptProcessor
from sagemaker.spark.processing import PySparkProcessor
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.steps import ProcessingInput, ProcessingOutput, ProcessingStep, TrainingStep
from sagemaker.sklearn.estimator import SKLearn

from tests.fixtures.constants import TEST_ROLE_ARN
from tests.fixtures.image_details import IMAGE_1_URI, IMAGE_2_URI


def arn_formatter(_type: str, _id: str, account_id: str, region_name: str) -> str:
    """Return the SageMaker ARN for resource ``_id`` of kind ``_type``.

    Example: ``arn_formatter("pipeline", "p1", "123456789012", "eu-west-1")`` yields
    ``"arn:aws:sagemaker:eu-west-1:123456789012:pipeline/p1"``.
    """
    template = "arn:aws:sagemaker:{region}:{account}:{kind}/{ident}"
    return template.format(region=region_name, account=account_id, kind=_type, ident=_id)


def get_sagemaker_pipeline(
    script_path: str = "tests/fixtures/fake_processing_script.py",
) -> Pipeline:
    """Build the three-step demo SageMaker Pipeline that the validations below inspect.

    Steps (each one depends on the previous one):
      1. ``sm_processing_step_sklearn``: ScriptProcessor on ``IMAGE_1_URI`` with a
         network config (isolation on, one security group, one subnet).
      2. ``sm_processing_step_spark``: PySparkProcessor on ``IMAGE_2_URI``
         (no network config).
      3. ``sm_training_step_sklearn``: SKLearn estimator on ``IMAGE_1_URI`` reusing the
         same network settings and an output KMS key.

    The pipeline also declares two string parameters. The definition is only
    constructed here; it is not upserted or started.

    NOTE(review): this cell redefines ``get_sagemaker_pipeline``, shadowing the version
    imported from ``tests.fixtures.pipeline`` in the first cell.

    Args:
        script_path: Path to the script used as processing code and as the training
            entry point.

    Returns:
        The assembled ``Pipeline``.
    """
    kms_key_alias = "some/kms-key-alias"
    dummy_bucket = "dummy-bucket"

    network_config = NetworkConfig(
        enable_network_isolation=True,
        security_group_ids=["sg-12345"],
        subnets=["subnet-12345"],
        encrypt_inter_container_traffic=True,
    )

    sm_processor_sklearn = ScriptProcessor(
        base_job_name="sm_processor",
        role=TEST_ROLE_ARN,
        image_uri=IMAGE_1_URI,
        network_config=network_config,
    )
    sm_processor_spark = PySparkProcessor(
        base_job_name="sm_processors",
        role=TEST_ROLE_ARN,
        image_uri=IMAGE_2_URI,
        instance_type="ml.m5.xlarge",
        instance_count=2,
    )

    sm_processing_step_sklearn = ProcessingStep(
        name="sm_processing_step_sklearn",
        code=script_path,
        processor=sm_processor_sklearn,
        kms_key=kms_key_alias,
        # Each input gets its own local directory (mirroring the outputs); sharing a
        # single destination would make both S3 downloads land in the same folder.
        inputs=[
            ProcessingInput(
                source=f"s3://{dummy_bucket}/input-1",
                destination="/opt/ml/processing/input/1",
                input_name="input-1",
            ),
            ProcessingInput(
                source=f"s3://{dummy_bucket}/input-2",
                destination="/opt/ml/processing/input/2",
                input_name="input-2",
            ),
        ],
        outputs=[
            ProcessingOutput(
                output_name="output-1",
                source="/opt/ml/processing/output/1",
                destination=f"s3://{dummy_bucket}/output-1",
            ),
            ProcessingOutput(
                output_name="output-2",
                source="/opt/ml/processing/output/2",
                destination=f"s3://{dummy_bucket}/output-2",
            ),
        ],
    )

    # Consumes the sklearn step's S3 outputs and writes output-3/output-4 for training.
    sm_processing_step_spark = ProcessingStep(
        name="sm_processing_step_spark",
        code=script_path,
        processor=sm_processor_spark,
        kms_key=kms_key_alias,
        inputs=[
            ProcessingInput(
                source=f"s3://{dummy_bucket}/output-1",
                destination="/opt/ml/processing/input/1",
                input_name="input-1",
            ),
            ProcessingInput(
                source=f"s3://{dummy_bucket}/output-2",
                destination="/opt/ml/processing/input/2",
                input_name="input-2",
            ),
        ],
        outputs=[
            ProcessingOutput(
                output_name="output-1",
                source="/opt/ml/processing/output/1",
                destination=f"s3://{dummy_bucket}/output-3",
            ),
            ProcessingOutput(
                output_name="output-2",
                source="/opt/ml/processing/output/2",
                destination=f"s3://{dummy_bucket}/output-4",
            ),
        ],
        depends_on=[sm_processing_step_sklearn.name],
    )

    # The estimator takes the network settings as individual keyword arguments,
    # so they are copied field by field from the shared NetworkConfig.
    sm_trainer_sklearn = SKLearn(
        entry_point=script_path,
        role=TEST_ROLE_ARN,
        image_uri=IMAGE_1_URI,
        instance_type="ml.c4.xlarge",
        output_kms_key=kms_key_alias,
        enable_network_isolation=network_config.enable_network_isolation,
        security_group_ids=network_config.security_group_ids,
        subnets=network_config.subnets,
        encrypt_inter_container_traffic=network_config.encrypt_inter_container_traffic,
    )
    sm_training_step_sklearn = TrainingStep(
        name="sm_training_step_sklearn",
        estimator=sm_trainer_sklearn,
        inputs={
            "train": f"s3://{dummy_bucket}/output-3",
            "test": f"s3://{dummy_bucket}/output-4",
        },
        depends_on=[sm_processing_step_spark.name],
    )

    sm_pipeline = Pipeline(
        name="dummy-pipeline",
        steps=[sm_processing_step_sklearn, sm_processing_step_spark, sm_training_step_sklearn],
        parameters=[
            ParameterString(
                name="parameter-1",
                default_value="some-value-1",
            ),
            ParameterString(
                name="parameter-2",
                default_value="some-value-2",
            ),
        ],
    )
    return sm_pipeline


## Building the SageMaker Pipeline

In [20]:
# The script path is resolved relative to the kernel's current working directory.
sm_pipeline = get_sagemaker_pipeline(
    script_path="../tests/fixtures/fake_processing_script.py",
)

## Defining a function that validates the SageMaker Pipeline

Note: to run this example offline (without AWS access), we mock ECR with `moto` and publish two mock images to it.

In [4]:
@mock_ecr
@create_image(
    # NOTE(review): this client is created when the cell runs, before ``mock_ecr`` is
    # active, so a default AWS region must be configured; confirm moto intercepts it.
    boto3.client("ecr"),
    [ContainerImage(uri=IMAGE_1_URI), ContainerImage(uri=IMAGE_2_URI)],
)
def validate_sm_pipeline(sm_pipeline: "sagemaker.workflow.pipeline.Pipeline") -> "pandas.DataFrame":
    """Run a set of sagemaker-rightline validations against ``sm_pipeline``.

    The call is wrapped in a moto ECR mock into which ``create_image`` presumably
    publishes the two fixture images, so ``StepImagesExistOnEcr`` can be evaluated
    without AWS access.

    Args:
        sm_pipeline: The SageMaker Pipeline to validate.

    Returns:
        The validation report as a DataFrame (``Configuration.run(return_df=True)``).
    """
    validations = [
        StepImagesExistOnEcr(),
        # The demo pipeline declares parameter-1 with default "some-value-1", so this
        # expectation ("some-value") does not match it.
        PipelineParameters(
            parameters_expected=[
                ParameterString(
                    name="parameter-1",
                    default_value="some-value",
                ),
            ],
            rule=Contains(),
        ),
        StepKmsKeyId(
            kms_key_id_expected="some/kms-key-alias",
            # Must name a pipeline *step* (not a processing output); the training step's
            # estimator sets ``output_kms_key``. Optional: if not set, all steps are
            # checked (this applies to all Step* validations).
            step_name="sm_training_step_sklearn",
            rule=Equals(),
        ),
        # Differs from the demo pipeline's network settings
        # (isolation enabled, sg-12345, subnet-12345).
        StepNetworkConfig(
            network_config_expected=NetworkConfig(
                enable_network_isolation=False,
                security_group_ids=["sg-1234567890"],
                subnets=["subnet-1234567890"],
            ),
            rule=Equals(),
        ),
    ]
    cm = Configuration(
        validations=validations,
        sagemaker_pipeline=sm_pipeline,
    )
    return cm.run(return_df=True)

In [15]:
# NOTE(review): scratch cell; `a` (a NetworkConfig built with default arguments) is not
# used anywhere else in this notebook. Consider removing it.
a = NetworkConfig()

In [26]:
# Look the training step up by name rather than by position, so reordering the
# pipeline's steps cannot silently change which estimator is inspected.
next(
    step for step in sm_pipeline.steps if step.name == "sm_training_step_sklearn"
).estimator.enable_network_isolation()

True

## Running the validation and displaying the report

In [5]:
# Run all validations; the decorators on validate_sm_pipeline wrap the call in the ECR mock.
report = validate_sm_pipeline(sm_pipeline)

In [6]:
# Display the validation report DataFrame (last expression in the cell).
report

Unnamed: 0,validation_name,subject,success,message
0,StepImagesExistOnEcr,['123456789012.dkr.ecr.eu-west-1.amazonaws.com...,True,Images ['123456789012.dkr.ecr.eu-west-1.amazon...
1,PipelineParameters,"[ParameterString(name='parameter-1', parameter...",False,"[ParameterString(name='parameter-1', parameter..."
2,StepKmsKeyId,['some/kms-key-alias'],True,['some/kms-key-alias'] does equal ['some/kms-k...
3,StepNetworkConfig,"[{'enable_network_isolation': False, 'security...",False,[] does not equal [{'enable_network_isolation'...


In [7]:
# NOTE(review): duplicate of the previous display cell (same variable, no side effects);
# candidate for removal.
report

Unnamed: 0,validation_name,subject,success,message
0,StepImagesExistOnEcr,['123456789012.dkr.ecr.eu-west-1.amazonaws.com...,True,Images ['123456789012.dkr.ecr.eu-west-1.amazon...
1,PipelineParameters,"[ParameterString(name='parameter-1', parameter...",False,"[ParameterString(name='parameter-1', parameter..."
2,StepKmsKeyId,['some/kms-key-alias'],True,['some/kms-key-alias'] does equal ['some/kms-k...
3,StepNetworkConfig,"[{'enable_network_isolation': False, 'security...",False,[] does not equal [{'enable_network_isolation'...
