In [1]:
%load_ext nb_mypy

Version 1.0.5


In [2]:
# Run from the repo's `src` directory so project imports and relative folder paths work.
# TODO: instead, create a Python package and install it locally.
import os
import sys
from pathlib import Path

# Path.home() resolves the home directory portably (it also works when $HOME is not set,
# e.g. on Windows), whereas os.environ["HOME"] raises KeyError there.
SRC_DIR = Path.home() / 'repos' / 'sagemaker-pipelines-abstraction' / 'src'
os.chdir(SRC_DIR)

In [3]:
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Literal, Callable, TypeAlias
from pathlib import Path
from datetime import datetime
from typing import TypeVar, Generic
from dataclasses import dataclass

from pydantic_settings import BaseSettings
from loguru import logger
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.steps import Step
from sagemaker.sklearn.estimator import SKLearn
from sagemaker.processing import FrameworkProcessor
from sagemaker.workflow.steps import ConfigurableRetryStep, ProcessingStep
from sm_pipelines_oo.shared_config_schema import Environment

from sm_pipelines_oo.shared_config_schema import SharedConfig, Environment
# from sm_pipelines_oo.steps.interfaces import StepFactoryInterface
from sm_pipelines_oo.connector.interface import AWSConnectorInterface
from sm_pipelines_oo.utils import load_pydantic_config_from_file
from sm_pipelines_oo.connector.interface import AWSConnectorInterface
from sm_pipelines_oo.connector.implementation import create_aws_connector
from sm_pipelines_oo.pipeline_wrapper import PipelineWrapper


sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/thomas-22/.config/sagemaker/config.yaml


In [4]:
class SharedConfig(BaseSettings):  # type: ignore
    """Configuration shared by every pipeline step for a given environment.

    NOTE(review): this notebook-local class shadows the ``SharedConfig`` imported from
    ``sm_pipelines_oo.shared_config_schema``; presumably a temporary copy for experimenting.
    """

    project_name: str
    # Versions the data (and probably more in the future).
    project_version: str
    region: str
    # TODO: consider which of these fields should be made required.
    role_name: str | None = None
    project_bucket_name: str


shared_config: SharedConfig = SharedConfig(
    project_name='design-decisions',
    project_bucket_name='design-decisions',
    project_version='0.0',
    region='local',
    role_name=None,
)

In [5]:
class ProcessingConfig(BaseSettings):  # type: ignore
    """Settings for the processing step."""

    input_filename: str
    instance_type: str
    instance_count: int
    sklearn_framework_version: str
    # Not meant to be set from config: it has to match SageMaker's convention for the
    # local folder structure.
    # TODO: make this read-only (a property instead of a field?).
    step_type: Literal['processing'] = "processing"
    step_name: str = "processing"


# Settings used for running the processing step in local mode (instance_type='local').
processing_config = ProcessingConfig(
    sklearn_framework_version='0.23-1',
    instance_count=1,
    instance_type='local',
    input_filename='input.parquet',
)

# Current design: Shared config *not* accessible to step factories

In [6]:
class PipelineWrapper:
    """Turn a list of step factories into a SageMaker ``Pipeline`` and run it.

    Current design: every factory receives ``shared_config`` as an argument of
    ``create_step``.

    NOTE: this notebook-local class shadows the ``PipelineWrapper`` imported from
    ``sm_pipelines_oo.pipeline_wrapper`` (hence mypy's ``no-redef`` error).

    Annotations are written as strings so they are not evaluated when the class is
    defined: the ``StepFactoryInterface`` import is commented out in the imports cell,
    and an unquoted annotation would raise ``NameError`` on a fresh kernel.
    """

    def __init__(
        self,
        step_factories: "list[StepFactoryInterface]",
        environment: "Environment",
        shared_config: "SharedConfig",
        aws_connector: "AWSConnectorInterface",
    ) -> None:
        """Store the configuration and eagerly build one step per factory.

        Args:
            step_factories: Factories whose ``create_step(shared_config=...)`` yields the
                steps, appended in the order given.
            environment: Target environment (stored; not otherwise used by this class).
            shared_config: Configuration shared by all steps; also used for the pipeline name.
            aws_connector: Supplies the SageMaker session and role ARN for the pipeline.
        """
        self.environment = environment
        self.shared_config = shared_config
        self._aws_connector = aws_connector

        self.steps: list[Step] = []
        self._create_steps(step_factories, shared_config)

    def _create_steps(self, step_factories: "list[StepFactoryInterface]", shared_config: "SharedConfig") -> None:
        """Append one step per factory to ``self.steps``, passing ``shared_config`` to each."""
        for factory in step_factories:
            step: Step = factory.create_step(
                shared_config=shared_config,
            )
            self.steps.append(step)

    @cached_property
    def _pipeline(self) -> "Pipeline":
        """The SageMaker pipeline, built and registered (``Pipeline.create``) on first access.

        The name is the project name suffixed with the creation timestamp (second resolution).
        """
        pipeline_name = f'{self.shared_config.project_name}-{datetime.now():%Y-%m-%d-%H-%M-%S}'
        pipeline = Pipeline(
            name=pipeline_name,
            steps=self.steps,
            sagemaker_session=self._aws_connector.sm_session,
        )
        # Side effect: registers the pipeline definition with SageMaker.
        pipeline.create(role_arn=self._aws_connector.role_arn)
        return pipeline


    # Public methods
    # ==============

    def run(self) -> "list[dict]":
        """Start a pipeline execution, block until it finishes, and return its step details.

        Returns:
            The result of ``execution.list_steps()`` (per-step execution details reported by
            SageMaker), so a notebook cell can display or inspect it.
        """
        logger.info(f"Starting pipeline run for project {self.shared_config.project_name}")
        execution = self._pipeline.start()
        execution.wait()
        return execution.list_steps()


<cell>1: [1m[31merror:[m Name [m[1m"PipelineWrapper"[m already defined (possibly by an import)  [m[33m[no-redef][m
<cell>18: [1m[31merror:[m Unexpected keyword argument [m[1m"shared_config"[m for [m[1m"create_step"[m of [m[1m"StepFactoryInterface"[m  [m[33m[call-arg][m


In [7]:
# Run
# ===
# NOTE(review): the recorded mypy output and NameError show that `ENVIRONMENT`,
# `ProcessingStepFactory` and `config_path_pre_processing` are not defined at this point,
# so this cell fails on a fresh kernel. TODO: define/import them in the configuration cells.
aws_connector: AWSConnectorInterface = create_aws_connector(
    environment=ENVIRONMENT,
    shared_config=shared_config,
    run_as_pipeline=True,
)

pre_processing_step_factory = ProcessingStepFactory(
    processor_cls=FrameworkProcessor,
    processor_extra_kwargs={'estimator_cls': SKLearn},
    step_config_path=config_path_pre_processing,
    aws_connector=aws_connector,
)
pipeline = PipelineWrapper(
    step_factories=[
        pre_processing_step_factory,
    ],
    environment=ENVIRONMENT,
    shared_config=shared_config,
    aws_connector=aws_connector,
)
try:
    pipeline.run()
except Exception:
    # Best-effort run: log and keep the notebook going. logger.exception records the full
    # traceback at ERROR level; logger.error(e) kept only the exception message.
    logger.exception("Pipeline run failed")


<cell>4: [1m[31merror:[m Name [m[1m"ENVIRONMENT"[m is not defined  [m[33m[name-defined][m
<cell>9: [1m[31merror:[m Name [m[1m"ProcessingStepFactory"[m is not defined  [m[33m[name-defined][m
<cell>12: [1m[31merror:[m Name [m[1m"config_path_pre_processing"[m is not defined  [m[33m[name-defined][m
<cell>19: [1m[31merror:[m Name [m[1m"ENVIRONMENT"[m is not defined  [m[33m[name-defined][m


NameError: name 'ENVIRONMENT' is not defined

# Alternative: Shared config *is* accessible to step factories

In [None]:
from abc import ABC, abstractmethod
from sagemaker.workflow.steps import ConfigurableRetryStep, ProcessingStep

# *General* step factory interface
# ==============================
# class StepFactoryInterface(ABC):


#     @abstractmethod
#     def create_step(self) -> ConfigurableRetryStep:
#         ...

#     @property
#     @abstractmethod
#     def _step_name(self) -> str:
#         ...

#     @abstractmethod
#     def _get_step_config(self, environment: Environment) -> dict:  # todo: improve type
#         ...


class ProcessingStepFactoryInterface(StepFactoryInterface):
    """Step-factory interface specialised for processing steps.

    TODO: add processing-specific methods.
    """


class FrameworkProcessorFactory(ProcessingStepFactoryInterface):
    """Work-in-progress processing-step factory; so far it only stores ``step_name``."""

    def __init__(self, step_name: str) -> None:
        # step_name identifies the step's config (for a given environment).
        self._step_name = step_name


In [None]:
from sm_pipelines_oo.connector.interface import AWSConnectorInterface
from sm_pipelines_oo.connector.implementation import AWSConnector, LocalAWSConnector

class PipelineWrapper:
    """Turn a list of step factories into a SageMaker ``Pipeline`` and run it.

    Alternative design: the factories' ``create_step()`` takes no arguments, and the wrapper
    can build its own AWS connector from ``environment`` and ``shared_config``.

    Annotations are written as strings so defining the class does not require
    ``StepFactoryInterface`` (whose import is commented out in the imports cell) or the other
    project types to be resolvable at definition time.
    """

    def __init__(
        self,
        step_factories: "list[StepFactoryInterface]",
        environment: "Environment",
        shared_config: "SharedConfig",
        aws_connector: "AWSConnectorInterface | None" = None,
    ) -> None:
        """Store the configuration and eagerly build one step per factory.

        Args:
            step_factories: Factories whose ``create_step()`` yields the steps, appended in
                the order given.
            environment: Target environment; ``'local'`` selects ``LocalAWSConnector`` when
                no connector is injected.
            shared_config: Configuration shared by all steps; used for the pipeline name and,
                outside ``'local'``, to build the AWS connector.
            aws_connector: Optional pre-built connector. When ``None`` (the default), one is
                built lazily from ``environment`` and ``shared_config``.
        """
        self.environment: Environment = environment
        self.shared_config = shared_config
        # Do NOT assign ``self._aws_connector`` here: ``cached_property`` stores its value in
        # the instance ``__dict__``, so such an assignment pre-fills the cache and the
        # connector-building body of ``_aws_connector`` below could never run.
        self._injected_aws_connector = aws_connector

        self.steps: list[Step] = []
        self._create_steps(step_factories, shared_config)

    @cached_property
    def _aws_connector(self) -> "AWSConnectorInterface":
        """The AWS connector: the injected one if given, otherwise built for ``environment``.

        This code makes connector.implementation.create_aws_connector() redundant, except for
        use outside of a pipeline.
        Todo: decide where to put code for the latter case.
        """
        if self._injected_aws_connector is not None:
            return self._injected_aws_connector
        if self.environment == 'local':
            return LocalAWSConnector()
        return AWSConnector(
            environment=self.environment,
            # this error will resolve once we don't use SharedConfig from this notebook but
            # library's AWSConnector.
            shared_config=self.shared_config,  # type: ignore
            run_as_pipeline=True
        )


    def _create_steps(self, step_factories: "list[StepFactoryInterface]", shared_config: "SharedConfig") -> None:
        """Append one step per factory to ``self.steps``.

        ``shared_config`` is currently unused: in this design ``create_step()`` takes no
        arguments.
        """
        for factory in step_factories:
            step: Step = factory.create_step()
            self.steps.append(step)

    @cached_property
    def _pipeline(self) -> "Pipeline":
        """The SageMaker pipeline, built and registered (``Pipeline.create``) on first access.

        The name is the project name suffixed with the creation timestamp (second resolution).
        """
        pipeline_name = f'{self.shared_config.project_name}-{datetime.now():%Y-%m-%d-%H-%M-%S}'
        pipeline = Pipeline(
            name=pipeline_name,
            steps=self.steps,
            sagemaker_session=self._aws_connector.sm_session,
        )
        # Side effect: registers the pipeline definition with SageMaker.
        pipeline.create(role_arn=self._aws_connector.role_arn)
        return pipeline


    # Public methods
    # ==============

    def run(self) -> "list[dict]":
        """Start a pipeline execution, block until it finishes, and return its step details.

        Returns:
            The result of ``execution.list_steps()`` (per-step execution details reported by
            SageMaker), so a notebook cell can display or inspect it.
        """
        logger.info(f"Starting pipeline run for project {self.shared_config.project_name}")
        execution = self._pipeline.start()
        execution.wait()
        return execution.list_steps()
