In [1]:
from typing import Callable

import jax
import jax.numpy as jnp
from jax import make_jaxpr
from jax.tree_util import Partial


In [2]:
from rubix.pipeline import linear_pipeline as ltp
from rubix.pipeline import transformer as rtr
from rubix.utils import read_yaml

### These factory classes will probably be replaced by a decorator

In [3]:
class Add(rtr.TransformerFactoryBase):
    """Factory for a transformer that adds a constant to its input.

    Args:
        value: Constant added to the input (stored as ``self.s``).
    """

    def __init__(self, value: float = 0.0):
        self.s = value

    def create(self) -> Callable:
        """Return a function ``f(a) = a + value``.

        The constant is read once here and captured by the closure. Reassigning
        ``self.s`` afterwards therefore does not silently change an
        already-created transformer. This keeps eager calls consistent with
        ``jax.jit``, which inlines the constant at trace time (see the jaxpr
        ``add a 4.0``).
        """
        s = self.s

        def f(a):
            return a + s

        return f


class Multiply(rtr.TransformerFactoryBase):
    """Factory for a transformer that multiplies its input by a constant.

    Args:
        factor: Constant multiplier (stored as ``self.s``). NOTE(review): the
            default is the int ``0``, so the default transformer maps every
            input to zero.
    """

    def __init__(self, factor: float = 0):
        self.s = factor

    def create(self) -> Callable:
        """Return a function ``f(a) = a * factor``.

        The factor is captured at creation time, so later reassignment of
        ``self.s`` does not change an already-created transformer. This
        matches what ``jax.jit`` does when it inlines the constant during
        tracing.
        """
        s = self.s

        def f(a):
            return a * s

        return f


class Divide(rtr.TransformerFactoryBase):
    """Factory for a transformer that divides its input by a constant.

    Args:
        divisor: Constant divisor (stored as ``self.s``). NOTE(review): the
            default of ``0`` makes the created transformer divide by zero
            (inf/nan for float inputs). Always pass an explicit nonzero
            divisor.
    """

    def __init__(self, divisor: float = 0):
        self.s = divisor

    def create(self) -> Callable:
        """Return a function ``f(a) = a / divisor``.

        The divisor is captured at creation time, so later reassignment of
        ``self.s`` does not change an already-created transformer. This
        matches what ``jax.jit`` does when it inlines the constant during
        tracing.
        """
        s = self.s

        def f(a):
            return a / s

        return f


In [52]:
# Pipeline configuration: one entry per transformer node.
CONFIG_PATH = "./demo.yml"
read_cfg = read_yaml(CONFIG_PATH)

In [53]:
# Show the parsed config. Each node under 'Transformers' gives a transformer
# class 'name', the node it 'depends_on', and constructor 'args'.
read_cfg

{'Transformers': {'A': {'name': 'Add',
   'depends_on': 'B',
   'args': {'value': 3.0}},
  'X': {'name': 'Multiply', 'depends_on': 'A', 'args': {'factor': 3}},
  'Z': {'name': 'Divide', 'depends_on': 'X', 'args': {'divisor': 4}},
  'B': {'name': 'Multiply', 'depends_on': 'C', 'args': {'factor': 2}},
  'C': {'name': 'Add', 'depends_on': None, 'args': {'value': 4}}}}

In [57]:
# Build the pipeline object from the parsed configuration.
tp = ltp.LinearTransformerPipeline(read_cfg)

In [58]:
# Make the factory classes referenced by name in the config known to the pipeline.
for transformer_cls in (Add, Multiply, Divide):
    tp.register_transformer(transformer_cls)

In [59]:
# Registered factory classes, keyed by class name.
tp.transformers

{'Add': __main__.Add, 'Multiply': __main__.Multiply, 'Divide': __main__.Divide}

In [60]:
# Assemble the pipeline from the registered factories and the config.
tp.assemble()

### Inspect the assembled pipeline to check that it was built correctly. Stages are applied in order from top to bottom.

In [61]:
# Node name -> created transformer function, listed in application order.
tp.pipeline

{'C': <function __main__.Add.create.<locals>.f(a)>,
 'B': <function __main__.Multiply.create.<locals>.f(a)>,
 'A': <function __main__.Add.create.<locals>.f(a)>,
 'X': <function __main__.Multiply.create.<locals>.f(a)>,
 'Z': <function __main__.Divide.create.<locals>.f(a)>}

In [12]:
# Single callable that applies all pipeline stages (traced to a jaxpr further below).
tp.expression

<function rubix.pipeline.linear_pipeline.LinearTransformerPipeline.build_expression.<locals>.expr(input)>

In [15]:
# Small float32 test input shared by all comparisons below.
x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


### Check that the pipeline behaves the same as manually composing the transformers

In [19]:
# Short alias for the node -> function mapping, used for manual composition.
tpp = tp.pipeline

In [22]:
# Apply the pipeline stages by hand, in dependency order C -> B -> A -> X -> Z.
restest = x
for stage in ("C", "B", "A", "X", "Z"):
    restest = tpp[stage](restest)

In [23]:
# Evaluate the assembled pipeline on x here, so this cell does not depend on
# a variable left over in the kernel from a deleted cell.
res = tp.expression(x)
res

Array([12.75, 11.25,  9.75], dtype=float32)

In [24]:
# Result of the manual composition; compare with `res` above.
restest

Array([12.75, 11.25,  9.75], dtype=float32)

In [26]:
# Pipeline mapping again, after the calls above (same as the earlier display).
tp.pipeline

{'C': <function __main__.Add.create.<locals>.f(a)>,
 'B': <function __main__.Multiply.create.<locals>.f(a)>,
 'A': <function __main__.Add.create.<locals>.f(a)>,
 'X': <function __main__.Multiply.create.<locals>.f(a)>,
 'Z': <function __main__.Divide.create.<locals>.f(a)>}

In [27]:
# Build the same stages by hand, with the arguments given in demo.yml.
C = Add(value=4).create()
B = Multiply(factor=2).create()
A = Add(value=3).create()
X = Multiply(factor=3).create()
Z = Divide(divisor=4).create()


In [32]:
def func(x):
    """Manually compose the stages in pipeline order: C, B, A, X, then Z."""
    for stage in (C, B, A, X, Z):
        x = stage(x)
    return x

In [34]:
compiled_func = jax.jit(func)
# Sanity check: the jit-compiled composition agrees with eager execution.
assert jnp.allclose(compiled_func(x), func(x)), "jit result differs from eager result"

In [39]:
# Trace the manual composition into a jaxpr.
expr = make_jaxpr(func)(x)

In [37]:
# Trace the pipeline's composed expression into a jaxpr.
pexpr = make_jaxpr(tp.expression)(x)

In [40]:
# Jaxpr of the manual composition.
expr

{ lambda ; a:f32[3]. let
    b:f32[3] = add a 4.0
    c:f32[3] = mul b 2.0
    d:f32[3] = add c 3.0
    e:f32[3] = mul d 3.0
    f:f32[3] = div e 4.0
  in (f,) }

In [41]:
# Verify programmatically (not just by eye) that the pipeline traces to the
# same jaxpr as the manual composition.
assert str(pexpr) == str(expr), "pipeline jaxpr differs from manual composition"
pexpr

{ lambda ; a:f32[3]. let
    b:f32[3] = add a 4.0
    c:f32[3] = mul b 2.0
    d:f32[3] = add c 3.0
    e:f32[3] = mul d 3.0
    f:f32[3] = div e 4.0
  in (f,) }

### The two expressions are the same
- Therefore we can apparently keep the simple for-loop approach and do not need to operate at the expression level for the computation.
- At least for simple cases, we do not have to ensure laziness ourselves.
- **TODO:** What about more complex cases, e.g. transformers with static parameters?
- We might still need expression-level access for compatibility checking.
- **TODO:** Where will this approach break?
- **TODO:** Will it break at all?
- **TODO:** How can we ensure that it doesn't?
