In [1]:
from typing import Callable

import jax
import jax.numpy as jnp
from jax import make_jaxpr


In [2]:
from rubix.pipeline import linear_pipeline as ltp
from rubix.pipeline import transformer as rtr
from rubix.utils import read_yaml

In [3]:
class Add(rtr.TransformerFactoryBase):
    """Transformer factory whose function adds a constant to its input.

    Parameters
    ----------
    value : float
        Constant added element-wise to the input array.
    """

    def __init__(self, value: float = 0.0):
        self.s = value

    def create(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
        """Return a function computing ``a + value``.

        The constant is captured when ``create`` is called. Later changes to
        ``self.s`` therefore cannot make the eager function and an already
        traced (jitted) copy of it disagree.
        """
        s = self.s

        def f(a):
            return a + s

        return f


class Multiply(rtr.TransformerFactoryBase):
    """Transformer factory whose function multiplies its input by a constant.

    Parameters
    ----------
    factor : float
        Constant the input array is multiplied by element-wise.
        NOTE(review): the default of 0 makes a default-constructed
        ``Multiply()`` zero its input; always pass ``factor`` explicitly.
    """

    def __init__(self, factor: float = 0):
        self.s = factor

    def create(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
        """Return a function computing ``a * factor``.

        The factor is captured when ``create`` is called. Later changes to
        ``self.s`` therefore cannot make the eager function and an already
        traced (jitted) copy of it disagree.
        """
        s = self.s

        def f(a):
            return a * s

        return f


class Divide(rtr.TransformerFactoryBase):
    """Transformer factory whose function divides its input by a constant.

    Parameters
    ----------
    divisor : float
        Constant the input array is divided by element-wise.
        NOTE(review): the default of 0 means a default-constructed
        ``Divide()`` produces inf/nan (floating-point division by zero);
        always pass ``divisor`` explicitly.
    """

    def __init__(self, divisor: float = 0):
        self.s = divisor

    def create(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
        """Return a function computing ``a / divisor``.

        The divisor is captured when ``create`` is called. Later changes to
        ``self.s`` therefore cannot make the eager function and an already
        traced (jitted) copy of it disagree.
        """
        s = self.s

        def f(a):
            return a / s

        return f


In [4]:
# Pipeline description: step key -> registered transformer class name, the step
# it runs after ("depends_on", None for the first step) and its "args" (these
# match the constructor parameters of Add / Multiply / Divide above).
# Following the depends_on links gives the chain C -> B -> A -> X -> Z.
cfg = {
    "Transformers": {
        step: {"name": transformer, "depends_on": parent, "args": kwargs}
        for step, transformer, parent, kwargs in (
            ("A", "Add", "B", {"value": 3.0}),
            ("X", "Multiply", "A", {"factor": 2}),
            ("Z", "Divide", "X", {"divisor": 4}),
            ("B", "Multiply", "C", {"factor": 2}),
            ("C", "Add", None, {"value": 4}),
        )
    }
}

In [5]:
# Load a pipeline description from YAML; judging by the displayed output it
# mirrors `cfg` above.
# NOTE(review): in the displayed output, C's depends_on is the *string* 'None',
# whereas `cfg` uses Python None. YAML spells null as `null` or `~`; an unquoted
# `None` loads as a string, so demo.yml probably needs fixing. `read_cfg` is not
# used to build the pipeline below (that uses `cfg`).
read_cfg = read_yaml("./demo.yml")

In [6]:
# Display the YAML-loaded config to compare it with `cfg`.
read_cfg

{'Transformers': {'A': {'name': 'Add',
   'depends_on': 'B',
   'args': {'value': 3.0}},
  'X': {'name': 'Multiply', 'depends_on': 'A', 'args': {'factor': 2}},
  'Z': {'name': 'Divide', 'depends_on': 'X', 'args': {'divisor': 4}},
  'B': {'name': 'Multiply', 'depends_on': 'C', 'args': {'factor': 2}},
  'C': {'name': 'Add', 'depends_on': 'None', 'args': {'value': 4}}}}

In [7]:
# Build the pipeline from the inline `cfg` dict (not from `read_cfg`).
tp = ltp.LinearTransformerPipeline(cfg)

In [8]:
# Register the factory classes with the pipeline; the registry is keyed by
# class name, which is what the config's "name" fields refer to.
for transformer_cls in (Add, Multiply, Divide):
    tp.register_transformer(transformer_cls)

In [9]:
# Registered transformer classes, keyed by class name.
tp.transformers

{'Add': __main__.Add, 'Multiply': __main__.Multiply, 'Divide': __main__.Divide}

In [10]:
# Presumably instantiates each configured transformer and orders the steps by
# their depends_on links; the result is inspected via `tp.pipeline` next.
tp.assemble()

In [11]:
# Step key -> transformer function, in execution order C -> B -> A -> X -> Z
# (the order given by the depends_on links in `cfg`).
tp.pipeline

{'C': <function __main__.Add.create.<locals>.f(a)>,
 'B': <function __main__.Multiply.create.<locals>.f(a)>,
 'A': <function __main__.Add.create.<locals>.f(a)>,
 'X': <function __main__.Multiply.create.<locals>.f(a)>,
 'Z': <function __main__.Divide.create.<locals>.f(a)>}

In [12]:
# The whole pipeline composed into a single callable `expr(input)`.
tp.expression

<function rubix.pipeline.linear_pipeline.LinearTransformerPipeline.build_expression.<locals>.expr(input)>

In [13]:
# jax.jit only wraps the function here; tracing and XLA compilation happen on
# the first call with concrete argument shapes/dtypes. Keep a handle so the
# compiled function can be reused instead of discarding it.
jitted_expression = jax.jit(tp.expression)
jitted_expression

<PjitFunction of <function LinearTransformerPipeline.build_expression.<locals>.expr at 0x7f41b4589ab0>>

In [17]:
# Small float32 input vector to push through the pipeline.
x = jnp.array((3.0, 2.0, 1.0), dtype=jnp.float32)

In [18]:
from jax import make_jaxpr


In [19]:
# The traced program applies the steps in dependency order
# C (+4) -> B (*2) -> A (+3) -> X (*2) -> Z (/4).
# Check the numbers against that closed form before showing the jaxpr.
expected_z = ((x + 4.0) * 2.0 + 3.0) * 2.0 / 4.0
assert jnp.allclose(tp.expression(x), expected_z), tp.expression(x)
make_jaxpr(tp.expression)(x)

{ lambda ; a:f32[3]. let
    b:f32[3] = add a 4.0
    c:f32[3] = mul b 2.0
    d:f32[3] = add c 3.0
    e:f32[3] = mul d 2.0
    f:f32[3] = div e 4.0
  in (f,) }

In [None]:
# Run the pipeline on `x`.
# NOTE(review): this cell has no execution count in the saved notebook, so it was
# never run. Presumably `tp.apply(x)` returns the same values as
# `tp.expression(x)`; confirm that and consider asserting it.
res = tp.apply(x)