In [31]:
import yaml
import numpy as np
from abc import ABC, abstractmethod

In [32]:
def read_yaml(path_to_file: str) -> dict:
    """
    read_yaml Read yaml file into dictionary

    Args:
        path_to_file (str): path to the file to read

    Raises:
        RuntimeError: When an error occurs during reading

    Returns:
        dict: Either the read yaml file in dictionary form, or an empty
              dictionary if an error occured.
    """
    cfg = {}
    try:
        with open(path_to_file, "r") as cfgfile:
            cfg = yaml.safe_load(cfgfile)
    except Exception as e:
        raise RuntimeError(
            f"Something went wrong while reading yaml file {str(path_to_file)}"
        ) from e
    return cfg

In [33]:
class Transformer(ABC):
    def _init__(self):
        pass

    @abstractmethod
    def create(self):
        pass

In [147]:
class Add(Transformer):
    def __init__(self, value: float = 0.0):
        self.s = value

    def create(self) -> np.array:
        def f(a):
            s = self.s
            return a + s

        return f


class Multiply(Transformer):
    def __init__(self, factor: float = 0):
        self.s = factor

    def create(self) -> np.array:
        def f(a):
            s = self.s
            return a * s

        return f


class Divide(Transformer):
    def __init__(self, divisor: float = 0):
        self.s = divisor

    def create(self) -> np.array:
        def f(a):
            s = self.s
            return a / s

        return f

In [237]:
class LinearTransformerPipeline:
    def __init__(self, cfg: dict):
        self.config = cfg
        self._pipeline = []
        self._names = []
        self._build_pipeline()
        self.expression = self._build_expression(self._pipeline)

    def _update(self, current_name):
        for key, node in self.config["Transformers"].items():
            if current_name == node["depends_on"]:
                self._pipeline.append(globals()[node["name"]](**node["args"]).create())
                self._names.append(key)
                self._update(key)
                
    def _build_pipeline(self):

        # find the starting point and look into corrections
        start_func = None
        start_name = None
        for key, node in self.config["Transformers"].items():

            if "name" not in node:
                raise ValueError(
                    "Error, each node of a pipeline must have a config node"
                )

            # if node["active"] is False:
            #     continue

            if node["depends_on"] is None and start_name is None:
                start_func = globals()[node["name"]](**node["args"]).create()
                start_name = key
            elif node["depends_on"] is None and start is not None:
                raise ValueError("There can only be one starting point.")
            else:
                continue

        self._pipeline = [start_func,]
        self._names = [start_name, ]

        self._update(start_name)

    def _build_expression(self, *pipeline):
        if len(pipeline) == 1: 
            return pipeline[0]
        else: 
            return pipeline[0](self._build_expression(*pipeline[1:]))

    @property 
    def pipeline(self):
        return dict(zip(self._names, self._pipeline))


    def add_transformer(self, name: str, node: dict): 
        func = globals()[node["name"]](**node["args"]).create()
        _update(name)

        self.config["Transformers"][name] = node


    def run_greedy(self, input: np.array): 
        result = input 
        for f in self._pipeline: 
            result = f(result)
        return result

    # doesn't really work
    def run(self, input: np.array):
        return self.expression(input)


In [239]:
cfg = {
    "Transformers": {
        "A": {
            "name": "Add",
            "depends_on": "B",
            "args": {
                "value": 3.0,
            },
        },
        "X": {
            "name": "Multiply",
            "depends_on": "A",
            "args": {
                "factor": 2,
            },
        },
        "Z": {
            "name": "Divide",
            "depends_on": "X",
            "args": {
                "divisor": 4,
            },
        },
        "B": {
            "name": "Multiply",
            "depends_on": "C",
            "args": {
                "factor": 2,
            },
        },
        "C": {
            "name": "Add",
            "depends_on": None,
            "args": {
                "value": 4,
            },
        },
    }
}

In [240]:
tp = LinearTransformerPipeline(cfg)

In [241]:
tp.pipeline

{'C': <function __main__.Add.create.<locals>.f(a)>,
 'B': <function __main__.Multiply.create.<locals>.f(a)>,
 'A': <function __main__.Add.create.<locals>.f(a)>,
 'X': <function __main__.Multiply.create.<locals>.f(a)>,
 'Z': <function __main__.Divide.create.<locals>.f(a)>}

In [242]:
x = np.array([3, 2, 1])

In [243]:
res = tp.run(x)

TypeError: 'list' object is not callable

In [231]:
 res_hand = ((((x + 4) * 2) + 3) * 2) / 4

In [233]:
all(res == res_hand)

True

In [None]:
class DAGTransformerPipeline(LinearTransformerPipeline):

    def _update(self, current_name: str):
        pass
