In [1]:
%load_ext nb_mypy

Version 1.0.5


In [2]:
# to make imports and folder paths work
# todo: Instead create python package and install locally
import os, sys
os.chdir(
    f'{os.environ["HOME"]}/repos/sagemaker-pipelines-abstraction/src'
)

In [3]:
from abc import ABC, abstractmethod
from typing import Generic, TypeVar, TypedDict, Any
from pathlib import Path

from pydantic_settings import BaseSettings
import boto3
from sagemaker.processing import Processor, FrameworkProcessor
from sagemaker.base_predictor import Predictor
from sagemaker.workflow.steps import ConfigurableRetryStep, ProcessingStep
from sagemaker.workflow.pipeline_context import _JobStepArguments
from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.sklearn.estimator import SKLearn

from sm_pipelines_oo.shared_config_schema import SharedConfig
from sm_pipelines_oo.steps.step_utils import PathFactory
from sm_pipelines_oo.utils import load_pydantic_config_from_file
from sm_pipelines_oo.pipeline_wrapper import AWSConnectorInterface

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/thomas-22/.config/sagemaker/config.yaml


## Alternative 1: Make ProcessingStepFactoryInterface *generic* in ProcessorType 
**Problem: Different types of Processors require different run_args (and potentially different extra_kwargs), which is not possible using this architecture.**

We could try to solve this by adding two more generic type variables, RunArgsType and ExtraKwargsType, but this would not only overly complicate things, but also we would have to rely on the caller to pass a matching set of these three types variables. To solve this, we could define a concrete interface in a single place for each matching set of type variables, but it's probably easier to simply use subclassing instead of generics.

In [4]:

# *General* step factory interface
# ==============================
class StepFactoryInterface(ABC):
    @abstractmethod
    def create_step(self, shared_config) -> ConfigurableRetryStep:
        ...


# Factory interfaces for *specific* step types
# =============================================
ProcessorType = TypeVar("ProcessorType", bound=Processor)

class ProcessingStepFactoryInterface(StepFactoryInterface, Generic[ProcessorType]):
    @abstractmethod
    def create_step(self, shared_config) -> ProcessingStep:
        ...

    @abstractmethod
    def processor(self) -> ProcessorType:
        ...

    @abstractmethod
    def get_processor_run_args(self, shared_config: SharedConfig) -> dict:
        # todo: improve return type
        ...

    @abstractmethod
    def get_processor_extra_kwargs(self, shared_config: SharedConfig) -> dict:
        # todo: improve return type
        ...


# Use
# ====
framework_processing_step = ProcessingStepFactoryInterface[FrameworkProcessor]

## Alternative 2: Define interface for ProcessorType, and implement for each specific kind of processor

In [5]:
# *General* step factory interface
# ==============================
from sagemaker.workflow.steps import ProcessingStep


class BaseStepFactoryInterface(ABC):
    @abstractmethod
    def create_step(self, shared_config) -> ConfigurableRetryStep:
        ...

In [6]:
# Factory interfaces for *processing* step
# =============================================

class ProcessingStepFactoryInterface(BaseStepFactoryInterface):
    @abstractmethod
    def create_step(self, shared_config: SharedConfig) -> ProcessingStep:
        ...

    @abstractmethod
    def processor(self) -> Processor:
        ...

    @abstractmethod
    def get_processor_run_args(self) -> dict:
        # todo: improve return type
        ...

    @abstractmethod
    def get_processor_extra_kwargs(self, shared_config: SharedConfig) -> dict:
        # todo: improve return type
        ...

### Static analysis catches type mismatches
Note that type checking does not work for the Sagemaker SDK by default, but I enabled it with the fix described [here](../../../design_decisions/typing_sagemaker_sdk.ipynb), which I have included in the setup defined in the [project's Makefile](../../../Makefile).

In [7]:
# Use
# ====
class FrameworkProcessingStepFactoryInterface(ProcessingStepFactoryInterface):
    def __init__(self):
        pass

    # This is not ok - return type is not of (sub)type `ProcessingStep``. Mypy catches this by default.
    def create_step(self, shared_config) -> None:
        return None

    # This is not okay - return type is not of (sub)type `Processor`.
    # However, mypy doesn't catch it by default - but it does here since I manually enabled type
    # checking for Sagemaker-sdk.
    def processor(self) -> int:
        return 0

    # This is not ok - argument type is not of (super)type `SharedConfig``. Mypy catches this by default.
    def get_processor_run_args(self, shared_config: int) -> dict:
        return {}

    # This is ok
    def get_processor_extra_kwargs(self, shared_config: SharedConfig) -> dict:
        return {}


framework_processing_step = FrameworkProcessingStepFactoryInterface()
framework_processing_step.processor()
print('\nSuccessfully ran cell')

<cell>8: [1m[31merror:[m Return type [m[1m"None"[m of [m[1m"create_step"[m incompatible with return type [m[1m"ProcessingStep"[m in supertype [m[1m"ProcessingStepFactoryInterface"[m  [m[33m[override][m
<cell>8: [1m[31merror:[m Return type [m[1m"None"[m of [m[1m"create_step"[m incompatible with return type [m[1m"ConfigurableRetryStep"[m in supertype [m[1m"BaseStepFactoryInterface"[m  [m[33m[override][m
<cell>14: [1m[31merror:[m Return type [m[1m"int"[m of [m[1m"processor"[m incompatible with return type [m[1m"Processor"[m in supertype [m[1m"ProcessingStepFactoryInterface"[m  [m[33m[override][m
<cell>18: [1m[31merror:[m Signature of [m[1m"get_processor_run_args"[m incompatible with supertype [m[1m"ProcessingStepFactoryInterface"[m  [m[33m[override][m
<cell>18: [34mnote:[m      Superclass:[m
<cell>18: [34mnote:[m          def get_processor_run_args(self) -> dict[Any, Any][m
<cell>18: [34mnote:[m      Subclass:[m
<c


Successfully ran cell


# Mypy does accept type generalizations following the Liskov Substitution Principle
We are able to use more general or specific types in our implementation if they follow the Liskov Substitution Principle: Methods can *accept more general types* as arguments and *return a more specific types*.

In [8]:
# Use
# ====
class FrameworkProcessingStepFactoryInterface(ProcessingStepFactoryInterface):
    def __init__(self):
        pass

    def create_step(self, shared_config: SharedConfig) -> ProcessingStep:
        return ProcessingStep(name='preprocessing')

    # This is okay: *Return* more *specific* type.
    def processor(self) -> FrameworkProcessor:
        return FrameworkProcessor(
            estimator_cls=SKLearn,
            framework_version='0.23-1',
            role='dummy_role',
            instance_type='ml.m5.xlarge',
        )

    # This is ok: *Accept* more *general* argument type.
    def get_processor_run_args(self, shared_config: SharedConfig | dict) -> dict:
        return {}

    def get_processor_extra_kwargs(self, shared_config: SharedConfig) -> dict:
        return {}


framework_processing_step = FrameworkProcessingStepFactoryInterface()
framework_processing_step.processor()
print('\nSuccessfully ran cell')

<cell>20: [1m[31merror:[m Signature of [m[1m"get_processor_run_args"[m incompatible with supertype [m[1m"ProcessingStepFactoryInterface"[m  [m[33m[override][m
<cell>20: [34mnote:[m      Superclass:[m
<cell>20: [34mnote:[m          def get_processor_run_args(self) -> dict[Any, Any][m
<cell>20: [34mnote:[m      Subclass:[m
<cell>20: [34mnote:[m          def get_processor_run_args(self, shared_config: SharedConfig | dict[Any, Any]) -> dict[Any, Any][m


sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/thomas-22/.config/sagemaker/config.yaml

Successfully ran cell


# How to incorporate need for different configs?

In [None]:
# Run Args
# ========
class ProcessorRunArgs(TypedDict):
    inputs: list[ProcessingInput]
    outputs: list[ProcessingOutput]
    arguments: list[str] | None

class FrameworkProcessorRunArgs(ProcessorRunArgs):
    # Additional args for FrameworkProcessor:
    source_dir: str
    code: str


In [9]:
# Redefining interfaces without shared config, remove method to get extra kwargs
from sagemaker.workflow.steps import ProcessingStep


class BaseStepFactoryInterface(ABC):
    @abstractmethod
    def create_step(self) -> ConfigurableRetryStep:
        ...
class ProcessingStepFactoryInterface(BaseStepFactoryInterface):
    @abstractmethod
    def create_step(self) -> ProcessingStep:
        ...

    @abstractmethod
    def processor(self) -> Processor:
        ...

    @abstractmethod
    def get_processor_run_args(self) -> ProcessorRunArgs:
        ...

In [1]:
# Different Configs
# =================
class StepConfig(BaseSettings):
    input_filename: str
    output_filename: str
    output_train_filename: str
    output_val_filename: str
    output_test_filename: str
    instance_type: str
    instance_count: int
    step_name: str

class ProcessingConfig(StepConfig):
    """
    This class provides the schema for the step-specific config file.
    It is passed to step factory in the latter's constructor.
    """
    sklearn_framework_version: str
    # Override default field with more specific filenames
    output_filename: None = None  # type: ignore[assignment]
    output_train_filename: str
    output_val_filename: str
    output_test_filename: str

# class FrameworkProcessingConfig(ProcessingConfig):
#     """
#     So far no extra configs needed. (While it would be nice if we could set `estimator_cls=SKLearn`
#     in the config file, but we would have to use `eval()` to construct a python object from the
#     string, which is a potential security vulnerability.)
#     """
#     ...

# processing_config: FrameworkProcessingConfig = load_pydantic_config_from_file(  # type: ignore
#     config_cls=FrameworkProcessingConfig,
#     config_path="sm_pipelines_oo/configs/dev/.env_pre_process",
# )
# print(processing_config)

In [11]:
from typing import Literal

class FrameworkProcessingConfig(BaseSettings):
    estimator_cls: Literal['SKLearn']  # todo: add other available estimators

In [13]:
from functools import cached_property

# Realistic implementation
# ========================
class FrameworkProcessingStepFactory(ProcessingStepFactoryInterface):
    def __init__(
        self,
        processing_config: ProcessingConfig,  # subtype uses more *specific arg* type
        shared_config: SharedConfig,
        fw_processing_config: FrameworkProcessingConfig,
        aws_connector: AWSConnectorInterface,
    ):
        self.aws_connector = aws_connector
        self.step_config = processing_config
        self.shared_config = shared_config
        self.frameworkprocessor_config = fw_processing_config

        # This determines how to construct the estimator object from the string in the config file, avoiding the
        # use of `eval`, which is a potential security vulnerability.
        self._str_to_cls_mapping: dict[str, Any] = {  # todo:  find supertype
            'SKLearn': SKLearn,
        }

    def support_additional_estimators(self, additional_estimator_mapping: dict[str, Any]) -> None:
        """Allow user to add additional estimators (following the open-closed principle)."""
        self._str_to_cls_mapping.update(additional_estimator_mapping)


    @cached_property
    def _estimator_cls(self):
        class_as_string = self.frameworkprocessor_config.estimator_cls
        return self._str_to_cls_mapping[class_as_string]

    @cached_property
    def _path_factory(self) -> PathFactory:
        return PathFactory(
            # todo: separate general, step, and specific configs
            step_config=self.step_config,  # type: ignore
            shared_config=self.shared_config
        )

    def get_processor_run_args(self) -> FrameworkProcessorRunArgs:
        s3_input_folder: str = self._path_factory.s3_input_folder
        s3_output_folder: str = self._path_factory.s3_output_folder
        local_folderpath: str = self._path_factory.local_folderpath

        skl_run_args = FrameworkProcessorRunArgs(
            inputs = [
                ProcessingInput(
                    source=s3_input_folder,
                    destination=f"{local_folderpath}/input/"
                ),
            ],
            outputs = [
                ProcessingOutput(
                    output_name="train",
                    source=f"/{local_folderpath}/train",
                    destination=f"{s3_output_folder}/train",
                ),
                ProcessingOutput(
                    output_name="validation",
                    source=f"/{local_folderpath}/validation",
                    destination=f"{s3_output_folder}/validation",
                ),
                ProcessingOutput(
                    output_name="test",
                    source=f"/{local_folderpath}/test",
                    destination=f"{s3_output_folder}/test",
                ),
            ],
            source_dir=self._path_factory.source_dir,
            code=self._path_factory.step_code_file,
            arguments=None # Todo: Decide whether this should come from configuration. May depend on type of step.
        )
        return skl_run_args

    def create_step(self) -> ProcessingStep:
        """
        Note that this can only be run from the PipelineWrapper, because this factory does not have
        access to the shared configs.
        """
        run_args: ProcessorRunArgs = self.get_processor_run_args()
        step_args: _JobStepArguments = self.processor.run(**run_args)
        return ProcessingStep(
            name=self.step_config.step_name,
            step_args=step_args,  # type: ignore
        )

    def processor(self) -> FrameworkProcessor:
        return FrameworkProcessor(
            framework_version=self.step_config.sklearn_framework_version,
            instance_type=self.step_config.instance_type,
            instance_count=self.step_config.instance_count,
            base_job_name=self.step_config.step_name,
            sagemaker_session=self.aws_connector.sm_session,
            role=self.aws_connector.role_arn,
            estimator_cls=self._estimator_cls,
        )

<cell>42: [1m[31merror:[m Return type [m[1m"ProcessorRunArgs"[m of [m[1m"get_processor_run_args"[m incompatible with return type [m[1m"dict[Any, Any]"[m in supertype [m[1m"ProcessingStepFactoryInterface"[m  [m[33m[override][m
<cell>83: [1m[31merror:[m [m[1m"Callable[[], FrameworkProcessor]"[m has no attribute [m[1m"run"[m  [m[33m[attr-defined][m


In [14]:
from sm_pipelines_oo.aws_connector.implementation import create_aws_connector


class SharedConfig(BaseSettings):  # type: ignore
    """Defines configuration shared by all pipeline steps (for a given environment)."""
    project_name: str
    project_version: str  # Versions data (and probably more in the future)
    region: str
    # To do: consider which of these fields should be made required.
    role_name: str | None = None
    project_bucket_name: str

shared_config: SharedConfig = SharedConfig(
    project_name='design-decisions',
    project_version='0.0',  # Versions data (and probably more in the future)
    region='us-east-1',
    # To do: consider which of these fields should be made required.
    role_name=None,
    project_bucket_name='design-decisions',
)

class ProcessingConfig(BaseSettings):  # type: ignore
    input_filename: str
    instance_type: str
    instance_count: int
    sklearn_framework_version: str
    # Don't set in config. This needs to correspond to SM's convention for local folder structure.
    # todo: Make this not set-able. Use property instead?
    step_type: Literal['processing'] = "processing"
    step_name: str = "processing"

processing_config = ProcessingConfig(
    input_filename='input.parquet',
    instance_type='local',
    instance_count=1,
    sklearn_framework_version='0.23-1',
)

aws_connector: AWSConnectorInterface = create_aws_connector(
    environment='dev',
    shared_config=shared_config,
    run_as_pipeline=True,
)


In [15]:
framework_processing_step = FrameworkProcessingStepFactory(
    # processor_cls=FrameworkProcessor,
    processing_config=processing_config,
    fw_processing_config=FrameworkProcessingConfig(estimator_cls='SKLearn'),
    shared_config=shared_config,
    aws_connector=aws_connector,
)

framework_processing_step.processor()
print('\nSuccessfully ran cell')

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/thomas-22/.config/sagemaker/config.yaml


[32m2023-12-24 13:30:51.596[0m | [34m[1mDEBUG   [0m | [36msm_pipelines_oo.aws_connector.implementation[0m:[36mrole_arn[0m:[36m86[0m - [34m[1mrole: arn:aws:iam::338755209567:role/aws-reserved/sso.amazonaws.com/AWSReservedSSO_AdministratorAccess_7b40736629c71dd9[0m


sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/thomas-22/.config/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/thomas-22/.config/sagemaker/config.yaml

Successfully ran cell


In [None]:
from functools import cached_property
from datetime import datetime

from loguru import logger
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.steps import Step

from sm_pipelines_oo.shared_config_schema import SharedConfig, Environment
from sm_pipelines_oo.steps.interfaces import StepFactoryInterface
from sm_pipelines_oo.aws_connector.interface import AWSConnectorInterface


class PipelineWrapper:
    """
    Holds configs & code
    """

    def __init__(
        self,
        step_factories: dict[StepFactoryInterface, Any],
        environment: Environment,
        shared_config: SharedConfig,
    ) -> None:
        self.environment = environment
        self.shared_config = shared_config

        # Other setup
        self.steps: list[Step] = []
        self._create_steps(step_factories, shared_config)
        self._aws_connector: AWSConnectorInterface = create_aws_connector(
            environment=environment,
            shared_config=shared_config,
            run_as_pipeline=True,
        )


    def _create_steps(self, step_factories: list[StepFactoryInterface], shared_config: SharedConfig) -> None:
        for factory, configs in step_factories.items():
            step: Step = factory.create_step(**configs)
            self.steps.append(step)

    @cached_property
    def _pipeline(self) -> Pipeline:
        pipeline_name = f'{self.shared_config.project_name}-{datetime.now():%Y-%m-%d-%H-%M-%S}'
        pipeline = Pipeline(
            name=pipeline_name,
            steps=self.steps,
            sagemaker_session=self._aws_connector.sm_session,
        )
        pipeline.create(role_arn=self._aws_connector.role_arn)
        return pipeline


    # Public methods
    # ==============

    def run(self) -> None:
        logger.info(f"Starting pipeline run for project {self.shared_config.project_name}")
        execution = self._pipeline.start()
        execution.wait()
        execution.list_steps()


In [None]:
processing_config = ProcessingConfig(
    input_filename='input.parquet',
    instance_type='local',
    instance_count=1,
    sklearn_framework_version='0.23-1',
)
fw_processing_config = FrameworkProcessingConfig(estimator_cls='SKLearn')

pipeline = PipelineWrapper(
    step_factories={
        FrameworkProcessingStepFactory: [processing_config, fw_processing_config],
    },
    environment=ENVIRONMENT,
    shared_config=shared_config,
)
try:
    pipeline.run()
except Exception as e:
    logger.error(e)
