In [None]:
from typing import Callable

import jax
import jax.numpy as jnp

In [None]:
from rubix.pipeline import linear_pipeline as ltp
from rubix.pipeline import transformer as rtr
from rubix.utils import read_yaml

In [None]:
class Add(rtr.TransformerFactoryBase):
    """Transformer factory that adds a constant to its input.

    Args:
        value: Constant added to every element of the input array.
    """

    def __init__(self, value: float = 0.0):
        self.s = value

    def create(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
        """Return a jitted function computing ``a + value``.

        The constant is read from ``self.s`` here, when the function is
        created. Reading it inside the jitted body would instead bake in
        whatever ``self.s`` holds at the first trace, and later changes to
        ``self.s`` would be silently ignored by the cached trace.
        """
        value = self.s

        @jax.jit
        def f(a):
            return a + value

        return f


class Multiply(rtr.TransformerFactoryBase):
    """Transformer factory that multiplies its input by a constant.

    Args:
        factor: Constant every element of the input array is multiplied by.
    """

    def __init__(self, factor: float = 0):
        self.s = factor

    def create(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
        """Return a jitted function computing ``a * factor``.

        The constant is read from ``self.s`` here, when the function is
        created. Reading it inside the jitted body would instead bake in
        whatever ``self.s`` holds at the first trace, and later changes to
        ``self.s`` would be silently ignored by the cached trace.
        """
        factor = self.s

        @jax.jit
        def f(a):
            return a * factor

        return f


class Divide(rtr.TransformerFactoryBase):
    """Transformer factory that divides its input by a constant.

    Args:
        divisor: Constant every element of the input array is divided by.
            NOTE(review): the default of 0 makes the created function return
            inf/nan instead of raising; callers should always pass a divisor.
    """

    def __init__(self, divisor: float = 0):
        self.s = divisor

    def create(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
        """Return a jitted function computing ``a / divisor``.

        The constant is read from ``self.s`` here, when the function is
        created. Reading it inside the jitted body would instead bake in
        whatever ``self.s`` holds at the first trace, and later changes to
        ``self.s`` would be silently ignored by the cached trace.
        """
        divisor = self.s

        @jax.jit
        def f(a):
            return a / divisor

        return f


In [None]:
# Inline pipeline configuration. Each entry gives the transformer class name
# ("name"), the key of the entry it depends on ("depends_on"), and the
# arguments for that transformer ("args"; presumably passed to the class
# constructor, matching the __init__ parameters above).
# Following the depends_on links gives the order C -> B -> A -> X -> Z,
# which is the order reproduced by the hand computation further down.
cfg = {
    "Transformers": {
        "A": {"name": "Add", "depends_on": "B", "args": {"value": 3.0}},
        "X": {"name": "Multiply", "depends_on": "A", "args": {"factor": 2}},
        "Z": {"name": "Divide", "depends_on": "X", "args": {"divisor": 4}},
        "B": {"name": "Multiply", "depends_on": "C", "args": {"factor": 2}},
        "C": {"name": "Add", "depends_on": None, "args": {"value": 4}},
    }
}

In [None]:
# Load the demo configuration from YAML so it can be displayed below.
# Note: the pipeline is built from the inline `cfg` dict, not from `read_cfg`.
read_cfg = read_yaml("./demo.yml")

In [None]:
# Display the configuration loaded from demo.yml.
read_cfg

In [None]:
# Build the linear pipeline from the inline `cfg` dict defined above.
tp = ltp.LinearTransformerPipeline(cfg)

In [None]:
# Register the transformer classes referenced by the "name" fields in `cfg`
# (Add, Multiply, Divide) on the pipeline. The loop variable holds a class
# object, not a name, hence `transformer_cls`.
for transformer_cls in (Add, Multiply, Divide):
    tp.register_transformer(transformer_cls)

In [None]:
# Inspect the pipeline's transformers attribute (presumably the registry filled
# by register_transformer above — confirm in LinearTransformerPipeline).
tp.transformers

In [None]:
# Assemble the pipeline. NOTE(review): presumably resolves the depends_on
# chain from cfg into an ordered sequence of transformers — confirm in assemble().
tp.assemble()

In [None]:
# Inspect the assembled pipeline.
tp.pipeline

In [None]:
# Inspect the composed expression produced by assembly.
tp.expression

In [None]:
# Wrap the composed expression with jax.jit and display the wrapper. The wrapper
# is not stored, and jax.jit is lazy: nothing is traced until it is called.
jax.jit(tp.expression)

In [None]:
# Small integer input array fed through the pipeline.
x = jnp.asarray((3, 2, 1))

In [None]:
# Run the assembled pipeline on the input array.
res = tp.apply(x)

In [None]:
# Display the pipeline result.
res

In [None]:
 res_hand = ((((x + 4) * 2) + 3) * 2) / 4

In [None]:
# Display the hand-computed reference result.
res_hand

In [None]:
# Check that the pipeline reproduces the hand computation.
# - The shape check catches silent broadcasting between mismatched arrays.
# - jnp.allclose uses a floating-point tolerance instead of exact equality,
#   since JIT-compiled and eager float math can in general differ in the last
#   bits.
# - It also evaluates on-device rather than iterating element-by-element with
#   Python's builtin all().
res.shape == res_hand.shape and bool(jnp.allclose(res, res_hand))