In [2]:
%load_ext nb_mypy

Version 1.0.5


In [3]:
from typing import  ClassVar, TypeVar, TypeAlias, Any, final
from abc import ABC, abstractmethod

from sagemaker.processing import Processor
from sagemaker.estimator import EstimatorBase

from sagemaker.workflow.steps import ProcessingStep, TrainingStep, CreateModelStep, TransformStep, \
    TuningStep, ConfigurableRetryStep


sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/thomas-22/.config/sagemaker/config.yaml


## Why no generics?
- The end goal is simply to have an object that satisfies the (ConfigurableRetry)StepInterface. From the perspective of the pipeline, we don't care what type of step it is.
- The initial reason for looking into leveraging generics was for making sure that we are passing the right config for a given type of step. However, after a lot of trial and error, I still did not find a good way to create a simple class hierarchy based on what the Sagemaker SDK makes available to us. Instead, it looks more promising to simply create a very minimal interface for step factories, and later specific implementations decide what the best way to create that kind of step is.
  - Downside: let's reuse of code between different step factories. This makes it somewhat harder to get started with creating new step factories, because there is less structure imposed for how exactly to do it.
  - Upside: more flexibility for grading step factories. This may actually make it easier to create new step factories, and it will make it easier to maintain given step factories as the interface of the Sagemaker SDK changes.
  - Note: neither of these points will affect a basic library user who only uses inbuilt step factories.

# Simple factory
This makes better use of factory, because depending on arg passed to it, it creates a different type of step. Otherwise, we may as well us strategy pattern (only use of factory is to construct step later when configs etc are known - but a given factory always produces same kind of step, except from configuration).

In [5]:
class StepFactoryInterface(ABC):
    @abstractmethod
    def __init__(self) -> None:
        """
        We need a shared interface for instantiating factories for specific steps, so that we can instantiate any specific step factory in the StepFactoryFaçade, without the façade knowing which kind of specific factory is used.
        """
        ...

    # todo: Consider making this a classmethod - unless this class needs to hold any state? That way, we simplify create_step, and we don't have to include the __init__ in the interface definition.
    @abstractmethod
    def create_step(self, step_type: str) -> ConfigurableRetryStep:
        ...

def load_step_config_dict_from_yaml(step_name: str) -> dict[str, Any]:
    ...

class StepFactoryFacade:
    """
    This class serves as a façade for creating steps that abstracts the following tasks from the user:
    - It receives the step name from the user, based on which it retrieves the associated config for that step.
    - From that config, it looks up what kind of step the user wants to create.
    - It looks up what's that factory it should use for creating that kind of step. (Note that this class is initialized with a lookup table that maps step names to factory classes. However, this library will also expose an instance of the StepFactoryFaçade that has already been initialized with a default lookup table, which will make the library even easier to use for less advanced users).
    - Great an instance of that specific step factory.
    - Finally, it will delegate the creation of the actual step to that specific factory, and then return the resulting step to the user.
    """
    def __init__(
        self,
        stepfactory_lookup_table: dict[str, StepFactoryInterface],
    ):
        self._stepfactory_lookup_table = stepfactory_lookup_table

    def create_step(self, step_name: str) -> ConfigurableRetryStep:
        # todo: Abstract configuration into a single config class which loads all config's it needs in the directory (even if this requires traversing). This will not only make the intent of this method more clear, but it will also make it easier to have a single config façade that abstracts what config's are global and what are step-specific (step config simply need a reference to the shared config, so they can fall back to that if necessary, but the concrete logic can be implemented differently for each step type). Also, having a config façade makes it easy to define methods that compute derived values.
        # step name identifies config location.
        step_config: dict[str, Any] = load_step_config_dict_from_yaml(step_name=step_name)
        # Before converting config into a pydantic class to validate it's types, we need to know for
        # what type of step it is (as this determines the pydantic class used).
        # todo: Make knowable for typechecker that `step_type` key exists. Use typeddict (with optional keys)?
        step_type = step_config['step_type']
        # Look up which factory to use, based on step_type speified in config
        SpecificFactory: type[StepFactoryInterface] = self._stepfactory_lookup_table[step_type]
        specific_factory = SpecificFactory()
        return specific_factory.create_step(step_config=step_config)


class _FrameworkProcessingStepFactory():
    def create_step(self, step_config: dict[str, Any]) -> ProcessingStep:
        ...

<cell>3: [34mnote:[m [m[1m"create_step"[m of [m[1m"StepFactoryInterface"[m defined here[m
<cell>6: [1m[31merror:[m Missing return statement  [m[33m[empty-body][m
<cell>39: [1m[31merror:[m Unexpected keyword argument [m[1m"step_config"[m for [m[1m"create_step"[m of [m[1m"StepFactoryInterface"[m  [m[33m[call-arg][m
<cell>43: [1m[31merror:[m Missing return statement  [m[33m[empty-body][m


Note that the StepFactoryWrapper is decoupled from the specific StepFactory that will be used to create the step. The latter is determined by a lookup table, which is injected into to the StepFactoryWrapper during instantiation.

The downside is that this is less convenient for simple use cases, where the user is content with choosing only from the default factories that ship with the library. To remediate this disadvantage, we can simply create a facade, which instantiates the StepFactoryWrapper with the default lookup table. More advanced users, by contrast, can directly import this default lookup table and customize it to point to custom StepFactory implementations. In a second step, they then initialize the StepFactoryWrapper directly, passing it the custom lookup table.

In [None]:
# higher-level-interface
# ======================

stepfactory_lookup_table: dict[str, StepFactoryInterface] = {
    'FrameworkProcessor': _FrameworkProcessingStepFactory,
}

# This is what user will import
stepfactory_wrapper = StepFactoryFacade(
    stepfactory_lookup_table=stepfactory_lookup_table,
)

In [None]:
# lower-level interface (if customization of factories is needed)
# ===============================================================

# Implement custom stepfactory
class  _CustomProcessingStepFactory():
    ...

# add it to the lookup table
stepfactory_lookup_table.update(
    {
        'CustomProcessor': _CustomProcessingStepFactory,
    },
)

# Instantiate StepFactory with customized lookup table
customized_step_factory = StepFactoryFacade(
    stepfactory_lookup_table=stepfactory_lookup_table
)