In [1]:
%load_ext nb_mypy

Version 1.0.5


In [2]:
# to make imports and folder paths work
# todo: Instead create python package and install locally
import os, sys
# Assumes the repository is checked out under $HOME/repos.
os.chdir(
    os.path.join(os.environ["HOME"], "repos", "sagemaker-pipelines-abstraction", "src")
)

In [3]:
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

import boto3
from sagemaker.session import Session, get_execution_role
from sagemaker.processing import Processor, FrameworkProcessor
from sagemaker.base_predictor import Predictor
from sagemaker.workflow.steps import ConfigurableRetryStep, ProcessingStep
from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.sklearn.estimator import SKLearn

from sm_pipelines_oo.shared_config_schema import SharedConfig

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/thomas-22/.config/sagemaker/config.yaml


## Alternative 1: Make ProcessingStepFactoryInterface *generic* in ProcessorType 
**Problem: Different types of processors require different `run_args` (and potentially different `extra_kwargs`), which this architecture cannot express.**

We could try to solve this by adding two more generic type variables, `RunArgsType` and `ExtraKwargsType`. However, this would not only overcomplicate things, it would also force us to rely on the caller to pass a matching set of these three type variables. To solve that, we could define a concrete interface in a single place for each matching set of type variables, but it's probably easier to simply use subclassing instead of generics.

In [4]:

# *General* step factory interface
# ================================
class StepFactoryInterface(ABC):
    """Interface for factories that build a single SageMaker pipeline step."""

    @abstractmethod
    def create_step(self, shared_config: SharedConfig) -> ConfigurableRetryStep:
        """Build the pipeline step from ``shared_config``.

        ``shared_config`` is annotated (rather than left implicitly ``Any``) so
        that mypy flags overrides which accept an incompatible argument type.
        """
        ...


# Factory interfaces for *specific* step types
# =============================================
# Type parameter: the concrete `Processor` subclass returned by `processor()`.
ProcessorType = TypeVar("ProcessorType", bound=Processor)

class ProcessingStepFactoryInterface(StepFactoryInterface, Generic[ProcessorType]):
    """Processing-step factory interface, generic in the processor class it returns.

    Limitation of this design: only ``processor()`` depends on ``ProcessorType``.
    Run args and extra kwargs are typed as plain ``dict`` for every processor
    type, even though different processors need different arguments.
    """

    @abstractmethod
    def create_step(self, shared_config: SharedConfig) -> ProcessingStep:
        """Build the step; narrows the base return type to ``ProcessingStep``."""
        ...

    @abstractmethod
    def processor(self) -> ProcessorType:
        """Return a processor instance whose class is ``ProcessorType``."""
        ...

    @abstractmethod
    def get_processor_run_args(self, shared_config: SharedConfig) -> dict:
        """Return the processor run arguments.

        TODO: improve return type (plain ``dict`` is not processor-specific).
        """
        ...

    @abstractmethod
    def get_processor_extra_kwargs(self, shared_config: SharedConfig) -> dict:
        """Return additional processor keyword arguments.

        TODO: improve return type (plain ``dict`` is not processor-specific).
        """
        ...


# Use
# ====
# NOTE: this binds a parametrized generic *type* (usable for annotations or
# subclassing), not a factory instance; later cells rebind this name to an instance.
framework_processing_step_interface = ProcessingStepFactoryInterface[FrameworkProcessor]

## Alternative 2: Define a (non-generic) interface for processing-step factories, and implement it for each processor type

In [5]:
# *General* step factory interface
# ==============================
from sagemaker.workflow.steps import ProcessingStep


class BaseStepFactoryInterface(ABC):
    """Interface for factories that build a single SageMaker pipeline step.

    Sub-interfaces narrow the return type to a specific step class
    (for example ``ProcessingStep``).
    """

    @abstractmethod
    def create_step(self, shared_config: SharedConfig) -> ConfigurableRetryStep:
        """Build the pipeline step from ``shared_config``.

        ``shared_config`` is annotated (rather than left implicitly ``Any``) so
        that mypy flags overrides which accept an incompatible argument type.
        """
        ...

In [6]:
# Factory interfaces for *processing* step
# =============================================

class ProcessingStepFactoryInterface(BaseStepFactoryInterface):
    """Interface for factories that build a SageMaker ``ProcessingStep``.

    Implementations may narrow return types (e.g. return a specific
    ``Processor`` subclass) and widen argument types, per the Liskov
    Substitution Principle.
    """

    @abstractmethod
    def create_step(self, shared_config: SharedConfig) -> ProcessingStep:
        """Build the step; narrows the base return type to ``ProcessingStep``."""
        ...

    @abstractmethod
    def processor(self) -> Processor:
        """Return the processor instance used by the step."""
        ...

    @abstractmethod
    def get_processor_run_args(self, shared_config: SharedConfig) -> dict:
        """Return the processor run arguments.

        TODO: improve return type.
        """
        ...

    @abstractmethod
    def get_processor_extra_kwargs(self, shared_config: SharedConfig) -> dict:
        """Return additional processor keyword arguments.

        TODO: improve return type.
        """
        ...

### Static analysis catches type mismatches
Note that type checking does not work for the SageMaker SDK by default, but I enabled it with the fix described [here](../../../design_decisions/typing_sagemaker_sdk.ipynb), which is included in the setup defined in the [project's Makefile](../../../Makefile).

In [7]:
# Use
# ====
class FrameworkProcessingStepFactoryInterface(ProcessingStepFactoryInterface):  # deliberately invalid overrides, to show what mypy reports
    def __init__(self):  # redundant: no base class in this notebook defines __init__
        pass  # NOTE: keep this cell's line count stable; the mypy output below cites cell line numbers

    # NOT ok: return type `None` is not a (sub)type of `ProcessingStep`. Mypy catches this by default.
    def create_step(self, shared_config) -> None:
        return None

    # NOT ok: return type `int` is not a (sub)type of `Processor`.
    # Type checking of the SageMaker SDK is off by default, so mypy would normally miss this;
    # it is reported here only because SDK type checking was enabled manually.
    def processor(self) -> int:
        return 0

    # NOT ok: argument type `int` is not a (super)type of `SharedConfig` (violates LSP). Mypy catches this by default.
    def get_processor_run_args(self, shared_config: int) -> dict:
        return {}

    # Ok: signature matches the interface exactly.
    def get_processor_extra_kwargs(self, shared_config: SharedConfig) -> dict:
        return {}


framework_processing_step_interface = FrameworkProcessingStepFactoryInterface()  # instantiable: every abstract method is overridden
framework_processing_step_interface.processor()  # executes fine: annotations are not enforced at runtime
print('\nSuccessfully ran cell')

<cell>8: [1m[31merror:[m Return type [m[1m"None"[m of [m[1m"create_step"[m incompatible with return type [m[1m"ProcessingStep"[m in supertype [m[1m"ProcessingStepFactoryInterface"[m  [m[33m[override][m
<cell>8: [1m[31merror:[m Return type [m[1m"None"[m of [m[1m"create_step"[m incompatible with return type [m[1m"ConfigurableRetryStep"[m in supertype [m[1m"BaseStepFactoryInterface"[m  [m[33m[override][m
<cell>14: [1m[31merror:[m Return type [m[1m"int"[m of [m[1m"processor"[m incompatible with return type [m[1m"Processor"[m in supertype [m[1m"ProcessingStepFactoryInterface"[m  [m[33m[override][m
<cell>18: [1m[31merror:[m Argument 1 of [m[1m"get_processor_run_args"[m is incompatible with supertype [m[1m"ProcessingStepFactoryInterface"[m; supertype defines the argument type as [m[1m"SharedConfig"[m  [m[33m[override][m
<cell>18: [34mnote:[m This violates the Liskov substitution principle[m
<cell>18: [34mnote:[m See [4


Successfully ran cell


### Mypy does accept type generalizations following the Liskov Substitution Principle
We can use more general or more specific types in our implementation as long as they follow the Liskov Substitution Principle: methods may return a more specific type (covariant return types), and method arguments may accept a more general type (contravariant argument types).

In [8]:
# Use
# ====
class FrameworkProcessingStepFactoryInterface(ProcessingStepFactoryInterface):
    """Processing-step factory that runs a scikit-learn ``FrameworkProcessor``.

    Every override below type-checks because it follows the Liskov Substitution
    Principle: return types are narrowed and argument types widened, never the
    reverse.
    """

    def create_step(self, shared_config: SharedConfig) -> ProcessingStep:
        """Build the 'preprocessing' step around this factory's processor."""
        # `ProcessingStep` requires exactly one of `step_args` / `processor`;
        # constructing it from `name` alone raises `ValueError`.
        # NOTE(review): newer SDK versions deprecate `processor=` in favour of
        # `step_args` produced under a PipelineSession; consider migrating.
        # TODO: wire in `get_processor_run_args` / `get_processor_extra_kwargs`
        # once it is settled which SageMaker call each set of arguments targets.
        return ProcessingStep(name='preprocessing', processor=self.processor())

    # This is okay: *Return* a more *specific* type (`FrameworkProcessor` is a `Processor`).
    def processor(self) -> FrameworkProcessor:
        """Create the scikit-learn framework processor executed by the step."""
        return FrameworkProcessor(
            estimator_cls=SKLearn,
            framework_version='0.23-1',
            role=get_execution_role(),
            instance_type='ml.m5.xlarge',
        )

    # This is ok: *Accept* a more *general* argument type (`SharedConfig | dict` includes `SharedConfig`).
    def get_processor_run_args(self, shared_config: SharedConfig | dict) -> dict:
        """Return processor run arguments (none are needed for this demo)."""
        return {}

    def get_processor_extra_kwargs(self, shared_config: SharedConfig) -> dict:
        """Return extra processor keyword arguments (none are needed for this demo)."""
        return {}


framework_processing_step_interface = FrameworkProcessingStepFactoryInterface()
framework_processing_step_interface.processor()
print('Successfully ran cell')

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/thomas-22/.config/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/thomas-22/.config/sagemaker/config.yaml
Successfully ran cell
