In [1]:
from abc import ABC, abstractmethod
from collections.abc import Callable

import numpy as np
import yaml

In [2]:
import jax
import jax.numpy as jnp

In [7]:
def read_yaml(path_to_file: str) -> dict:
    """
    read_yaml Read yaml file into dictionary

    The file is parsed with ``yaml.safe_load``, so only plain YAML data
    types are constructed (no arbitrary Python objects).

    Args:
        path_to_file (str): path to the file to read

    Raises:
        RuntimeError: When the file cannot be opened or parsed. The original
            exception is chained as ``__cause__``.

    Returns:
        dict: The parsed yaml document. An empty file yields an empty
            dictionary.
    """
    try:
        # YAML is UTF-8 by convention; don't depend on the platform locale.
        with open(path_to_file, "r", encoding="utf-8") as cfgfile:
            cfg = yaml.safe_load(cfgfile)
    except Exception as e:
        raise RuntimeError(
            f"Something went wrong while reading yaml file {str(path_to_file)}"
        ) from e
    # safe_load returns None for an empty document; honour the dict contract.
    return {} if cfg is None else cfg

In [8]:
class Transformer(ABC):
    """
    Abstract base for pipeline transformers.

    Subclasses implement ``create``, which returns the function a pipeline
    applies to its data.
    """

    def __init__(self):
        pass

    @abstractmethod
    def create(self) -> Callable:
        """Return the function this transformer applies to its input."""

    @classmethod
    def from_cfg(cls, node: dict):
        """Build an instance from a mapping of constructor keyword arguments."""
        return cls(**node)

In [10]:
class Add(Transformer):
    """Transformer that adds a constant offset to its input."""

    def __init__(self, value: float = 0.0):
        """
        Args:
            value (float): constant added to the input. Defaults to 0.0
                (identity).
        """
        super().__init__()
        self.s = value

    def create(self) -> Callable[[jax.Array], jax.Array]:
        """
        create Build a jitted function computing ``a + value``.

        The offset is read once here rather than inside the traced function.
        Otherwise jax.jit would bake in whatever ``self.s`` held at the first
        trace for each input shape/dtype, and later changes to ``self.s``
        would apply inconsistently.

        Returns:
            Callable: jitted function mapping ``a`` to ``a + value``.
        """
        offset = self.s

        @jax.jit
        def f(a):
            return a + offset

        return f

    @classmethod
    def from_cfg(cls, node: dict):
        """Construct from constructor kwargs, e.g. ``{"value": 3.0}``."""
        return cls(**node)


class Multiply(Transformer):
    """Transformer that multiplies its input by a constant factor."""

    def __init__(self, factor: float = 0):
        """
        Args:
            factor (float): multiplier applied to the input. Defaults to 0,
                which maps every input to zero.
        """
        super().__init__()
        self.s = factor

    def create(self) -> Callable[[jax.Array], jax.Array]:
        """
        create Build a jitted function computing ``a * factor``.

        The factor is read once here rather than inside the traced function.
        Otherwise jax.jit would bake in whatever ``self.s`` held at the first
        trace for each input shape/dtype.

        Returns:
            Callable: jitted function mapping ``a`` to ``a * factor``.
        """
        factor = self.s

        @jax.jit
        def f(a):
            return a * factor

        return f

    @classmethod
    def from_cfg(cls, node: dict):
        """Construct from constructor kwargs, e.g. ``{"factor": 2}``."""
        return cls(**node)


class Divide(Transformer):
    """Transformer that divides its input by a non-zero constant."""

    def __init__(self, divisor: float = 1.0):
        """
        Args:
            divisor (float): non-zero value the input is divided by.
                Defaults to 1.0 (identity).

        Raises:
            ValueError: if ``divisor`` is zero, which would turn every output
                into inf/nan.
        """
        super().__init__()
        if np.any(np.asarray(divisor) == 0):
            raise ValueError("Divide requires a non-zero divisor.")
        self.s = divisor

    def create(self) -> Callable[[jax.Array], jax.Array]:
        """
        create Build a jitted function computing ``a / divisor``.

        The divisor is read once here rather than inside the traced function.
        Otherwise jax.jit would bake in whatever ``self.s`` held at the first
        trace for each input shape/dtype.

        Returns:
            Callable: jitted function mapping ``a`` to ``a / divisor``.
        """
        divisor = self.s

        @jax.jit
        def f(a):
            return a / divisor

        return f

    @classmethod
    def from_cfg(cls, node: dict):
        """Construct from constructor kwargs, e.g. ``{"divisor": 4}``."""
        return cls(**node)

In [58]:
class LinearTransformerPipeline:
    """
    Linear chain of transformers assembled from a dependency config.

    ``cfg["Transformers"]`` maps a node key to a node dict with the entries:
      * ``name``: name of a ``Transformer`` subclass defined in this module;
      * ``depends_on``: key of the node whose output feeds this one, or
        ``None`` for the single starting node;
      * ``args``: keyword arguments for the transformer's constructor
        (may be omitted).

    The pipeline starts at the node with ``depends_on=None``. It then appends,
    depth-first and in config order, every node that depends on the last one
    added. Nodes not reachable from the start are ignored. This class does not
    check that a node has at most one dependent.
    """

    def __init__(self, cfg: dict):
        # NOTE: cfg is stored by reference; add_transformer mutates it.
        self.config = cfg
        self._pipeline = []
        self._names = []
        self._build_pipeline()
        self.expression = self._build_expression()

    @staticmethod
    def _instantiate(node: dict):
        """
        Create the jitted function described by ``node``.

        Raises:
            KeyError: if ``node["name"]`` is not defined in this module.
            ValueError: if ``node["name"]`` refers to something other than a
                Transformer subclass. A config must not be able to call
                arbitrary module-level objects.
        """
        cls = globals()[node["name"]]
        if not (isinstance(cls, type) and issubclass(cls, Transformer)):
            raise ValueError(f"'{node['name']}' is not a Transformer subclass.")
        return cls(**(node.get("args") or {})).create()

    def _update(self, current_name):
        """Append, depth-first, every node that depends on ``current_name``."""
        for key, node in self.config["Transformers"].items():
            if node["depends_on"] == current_name:
                self._pipeline.append(self._instantiate(node))
                self._names.append(key)
                self._update(key)

    def _build_expression(self):
        """
        Compose the current pipeline into a single (un-jitted) function.

        The function list is snapshotted, so each expression is a fixed
        composition. A rebuilt pipeline gets a new function object, so a
        trace JAX cached for the old composition cannot be reused.
        """
        funcs = tuple(self._pipeline)

        def expr(x):
            for f in funcs:
                x = f(x)
            return x

        return expr

    def _build_pipeline(self):
        """
        Locate the unique start node, then append its dependents in order.

        Raises:
            ValueError: if a node has no ``name``, if no node has
                ``depends_on=None``, or if more than one node does.
        """
        start_func = None
        start_name = None
        for key, node in self.config["Transformers"].items():
            if "name" not in node:
                raise ValueError(
                    f"Error, pipeline node '{key}' must have a 'name' entry"
                )
            # NOTE: an `active` flag is not supported; every node is used.
            if node["depends_on"] is not None:
                continue
            if start_name is not None:
                raise ValueError("There can only be one starting point.")
            start_func = self._instantiate(node)
            start_name = key

        if start_name is None:
            raise ValueError(
                "The pipeline needs a starting point (a node with depends_on=None)."
            )

        self._pipeline = [start_func]
        self._names = [start_name]
        self._update(start_name)

    @property
    def pipeline(self):
        """Mapping of node key -> transformer function, in execution order."""
        return dict(zip(self._names, self._pipeline))

    def add_transformer(self, name: str, node: dict):
        """
        Add (or replace) ``node`` under key ``name`` and rebuild the pipeline.

        ``self.config`` is updated in place. If rebuilding fails, both the
        config entry and the previous pipeline are restored before the error
        propagates.

        Args:
            name (str): key of the new node.
            node (dict): node description (``name``, ``depends_on``, ``args``).
        """
        transformers = self.config["Transformers"]
        had_previous = name in transformers
        previous = transformers.get(name)
        old_pipeline, old_names = self._pipeline, self._names

        transformers[name] = node
        try:
            self._build_pipeline()
        except Exception:
            # Roll back so a bad node leaves the pipeline untouched.
            if had_previous:
                transformers[name] = previous
            else:
                del transformers[name]
            self._pipeline, self._names = old_pipeline, old_names
            raise
        self.expression = self._build_expression()

    def apply(self, input):
        """Run the composed pipeline on ``input`` under ``jax.jit``."""
        return jax.jit(self.expression)(input)

In [59]:
# Demo config describing the chain C -> B -> A -> X -> Z, i.e.
# ((((x + 4) * 2) + 3) * 2) / 4. Entries need not be listed in execution
# order; the pipeline orders them via `depends_on`.
cfg = dict(
    Transformers=dict(
        A=dict(name="Add", depends_on="B", args=dict(value=3.0)),
        X=dict(name="Multiply", depends_on="A", args=dict(factor=2)),
        Z=dict(name="Divide", depends_on="X", args=dict(divisor=4)),
        B=dict(name="Multiply", depends_on="C", args=dict(factor=2)),
        C=dict(name="Add", depends_on=None, args=dict(value=4)),
    )
)

In [60]:
tp = LinearTransformerPipeline(cfg)  # build the chain from cfg["Transformers"]

In [61]:
tp.pipeline  # node key -> jitted transformer function, in execution order

{'C': <PjitFunction of <function Add.create.<locals>.f at 0x7ff064b7ecb0>>,
 'B': <PjitFunction of <function Multiply.create.<locals>.f at 0x7ff064be0280>>,
 'A': <PjitFunction of <function Add.create.<locals>.f at 0x7ff064be0670>>,
 'X': <PjitFunction of <function Multiply.create.<locals>.f at 0x7ff064be0a60>>,
 'Z': <PjitFunction of <function Divide.create.<locals>.f at 0x7ff064be0e50>>}

In [62]:
tp.expression  # plain composition of the pipeline functions; apply() jits it

<function __main__.LinearTransformerPipeline._build_expression.<locals>.expr(input)>

In [63]:
x = jnp.array([3, 2, 1])  # small integer test input

In [64]:
res = tp.apply(x)  # run the composed pipeline under jax.jit

In [56]:
res  # pipeline output

Array([8.5, 7.5, 6.5], dtype=float32, weak_type=True)

In [54]:
 res_hand = ((((x + 4) * 2) + 3) * 2) / 4  # same chain by hand: C(+4) B(*2) A(+3) X(*2) Z(/4)

In [57]:
res_hand  # hand-computed reference result

Array([8.5, 7.5, 6.5], dtype=float32)

In [55]:
# Compare with a tolerance: the jitted pipeline and the eager hand computation
# are not guaranteed to round identically in floating point.
bool(jnp.allclose(res, res_hand))

True

In [None]:
class DAGTransformerPipeline(LinearTransformerPipeline):
    """Placeholder for a pipeline whose nodes may form a DAG (not implemented)."""

    def _update(self, current_name: str):
        """
        Not implemented yet.

        Raises:
            NotImplementedError: always. A no-op here would leave only the
                start node in the pipeline, silently dropping every other
                transformer.
        """
        # TODO: implement DAG traversal (multiple dependents / inputs per node).
        raise NotImplementedError(
            "DAGTransformerPipeline._update is not implemented yet."
        )