# Concept
## Basic ideas
- rubix essentially implements a big data transformation pipeline. 

- a pipeline is composed of nodes stored in a list in execution order (more generally, the nodes would form a DAG, which is not supported currently). Each node is called a transformer. 

- each step in this pipeline (i.e., each transformer) can itself be seen as being composed of other, smaller transformers. This gives us a pattern that can guide the implementation of transformers.

- a simple implementation lives in `rubix.pipeline`
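
The two ideas above (a pipeline is an ordered list of transformers, and a transformer can itself be built from smaller ones) can be sketched in a few lines of plain Python. This is only an illustration and not the code in `rubix.pipeline`; `compose`, `shift` and `scale` are made-up names.

```python
from functools import reduce


def compose(*steps):
    """Chain steps left to right; the result is itself a transformer."""

    def composed(x):
        return reduce(lambda value, step: step(value), steps, x)

    return composed


# Toy transformers (hypothetical, for illustration only).
def shift(x):
    return x + 1.0


def scale(x):
    return x * 2.0


# A small transformer built from two smaller ones ...
shift_then_scale = compose(shift, scale)

# ... used as a single step inside a larger pipeline.
pipeline = compose(shift_then_scale, shift)

result = pipeline(3.0)  # ((3 + 1) * 2) + 1
```

Because `compose` returns something with the same call signature as its inputs, nesting works to any depth.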

## Restrictions
- jax is purely functional. Anything that is to be transformed with jax has to be a pure function. 
Anything that comes from the environment must be explicitly copied into the function or bound to it, such that the internal state of the function is self-contained. 

- It is irrelevant what builds these pure functions. Therefore, we use a factory pattern to do all configuration work, such as reading files, pulling data from the net, or providing any function arguments to be used in the pipeline. A factory then produces a pure function that carries all the data we need as static arguments and retains only what it computes on as traceable arguments. 

- we can leverage [`jax.tree_util.Partial`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.Partial.html) for this, which works like `functools.partial` but is compatible with jax transformations. Note that stateful objects can still be used internally as long as nothing from an outer scope (that may change over time) is read or written. This is the user's responsibility.
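
A minimal sketch of the factory idea, under the assumption that the configuration arrives as a plain dict. The names `make_scaler` and `kernel` and the keys `factor` and `offset` are hypothetical and only serve the illustration.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def make_scaler(config):
    """Hypothetical factory: does the impure setup, returns a pure function."""
    # Impure work (here: reading a dict; in practice files, downloads, ...)
    # happens once, outside of anything jax will trace.
    factor = float(config["factor"])
    offset = float(config["offset"])

    def kernel(x, factor, offset):
        # Pure: the result depends only on the arguments.
        return x * factor + offset

    # Bind the configuration; only `x` remains an open argument.
    return Partial(kernel, factor=factor, offset=offset)


scaler = make_scaler({"factor": 2.0, "offset": 1.0})
x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)

eager = scaler(x)            # plain call
jitted = jax.jit(scaler)(x)  # same function under jit
```

The factory can be as stateful as it likes; what matters is that the object it returns no longer reads anything from the outside.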


In [1]:
import jax
import jax.numpy as jnp
from jax import make_jaxpr
from jax.tree_util import Partial


In [2]:
from rubix.pipeline import linear_pipeline as ltp
from rubix.pipeline import transformer as rtr
from rubix.utils import read_yaml

## Build some simple decorators for function configuration
-  leverages jax.tree_util.Partial
-  builds a partial object to which jax transformations can be applied 
-  three cases: 
   -  build the pure function object: you have to take care of static args/kwargs yourself when calling jit. The decorator only builds the function object
   -  jit it right away: the usual case. Here you can tell it which args/kwargs not to trace with the `static_args` and `static_kwargs` keyword arguments
   -  build expression: mainly to check what comes out at the end or at intermediate steps. Can build a function that produces the jaxpr (with no arguments) or the jaxpr itself (when arguments are given as well). Note that `jax.make_jaxpr` does not have `static_argnames` like `jit` does. 
-  With these, we can configure our pipeline transformers. 
-  Not entirely sure right now which are useful or needed
-  these decorators/factory functions live in `rubix.pipeline.transformer`
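
One point worth keeping in mind when taking care of static arguments by hand: a `Partial` is a pytree, and its bound values are its leaves. The following sketch (with a cut-down `cond_add` and a hypothetical helper `apply`) shows that bound values stay plain Python values when the `Partial` itself is jitted, but get traced when the `Partial` is passed as an argument to a jitted function.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial, tree_leaves


def cond_add(x, y, k=0.0):
    # Python-level branch: needs a concrete value for `k`.
    if k < 0:
        return x + y + k
    return x + y + 2 * k


bound = Partial(cond_add, k=-3.14)
x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)

# The bound value is a pytree leaf of the Partial object.
leaves = tree_leaves(bound)

# Case 1: the Partial is the jitted callable. `k` stays a Python float,
# so the branch is resolved while tracing.
as_callable = jax.jit(bound)(x, x)


# Case 2: the Partial is an argument of a jitted function. Its leaves are
# traced, so `k < 0` cannot be converted to a Python bool.
@jax.jit
def apply(fn, a, b):
    return fn(a, b)


try:
    apply(bound, x, x)
    branch_failed = False
except jax.errors.ConcretizationTypeError:
    branch_failed = True
```

In the second case the kernel would have to avoid Python control flow on `k`, for example by using `jnp.where`.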

**simple transformer decorator that binds function to arguments** 

In [3]:
def transformer(**kwargs):

    def transformer_wrap(kernel):
        return Partial(kernel, **kwargs)

    return transformer_wrap

In [4]:
@transformer(z = 5, k = 3.14)
def add(x, y, z: float = 0, k: float = 0): 
    return x + y + z + k

In [5]:
type(add)

jax._src.tree_util.Partial

In [6]:
addjit = jax.jit(add)

In [7]:
x = jnp.array([3., 2., 1.], dtype = jnp.float32)


In [8]:
addjit(x, x)

Array([14.14, 12.14, 10.14], dtype=float32)

**transformer that compiles immediately**
can be used for the final pipeline, or for intermediate steps while debugging

In [9]:
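def compiled_transformer(static_args=(), static_kwargs=(), **kwargs):
    # static_args: positional indices jit should not trace (-> static_argnums)
    # static_kwargs: keyword names jit should not trace (-> static_argnames)
    # **kwargs: values bound to the kernel up front via Partial

    def transformer_wrap(kernel):
        return jax.jit(
            Partial(kernel, **kwargs),
            static_argnums=static_args,
            static_argnames=static_kwargs,
        )

    return transformer_wrap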

In [10]:
@compiled_transformer(z = 5, k = -3.14)
def cond_add(x, y, z: float = 0, k: float = 0):
    if k < 0: 
        return x + y + z + k 
    else: 
        return x + y + z + 2*k

In [11]:
cond_add

<PjitFunction of Partial(<function cond_add at 0x7f7f13156d40>, z=5, k=-3.14)>

In [12]:
cond_add(x,x)

Array([7.8599997, 5.8599997, 3.86     ], dtype=float32)

In [13]:
def cond_add(x, y, z: float = 0, k: float = 0):
    if k < 0: 
        return x + y + z + k 
    else: 
        return x + y + z + 2*k

Use on predefined functions without the decorator syntax:

In [14]:
cond_add_plus = compiled_transformer(z = 5, k = -3.14, static_kwargs = ["z", "k"])(cond_add)

In [15]:
cond_add_plus

<PjitFunction of Partial(<function cond_add at 0x7f7f131577f0>, z=5, k=-3.14)>

In [16]:
cond_add_plus(x,x)

Array([7.8599997, 5.8599997, 3.86     ], dtype=float32)

**Expression-based decorator for getting the intermediate `jaxpr` object for inspection** 
- `make_jaxpr` has no `static_argnames` parameter (unlike `jit`), so static arguments have to be marked by position via `static_argnums`.
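
One possible way around the positional-only restriction is to bind the values by name with `Partial` before handing the function to `make_jaxpr`. The sketch below assumes the same `cond_add` kernel as in the cells of this notebook; it is an alternative to the decorator, not part of `rubix`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def cond_add(x, y, z=0.0, k=0.0):
    if k < 0:
        return x + y + z + k
    return x + y + z + 2 * k


x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)

# Bind the configuration by name; only the positional x and y stay open.
bound = Partial(cond_add, z=5.0, k=-3.14)

jaxpr = jax.make_jaxpr(bound)(x, x)
print(jaxpr)

# Only x and y show up as traced inputs of the expression.
n_inputs = len(jaxpr.jaxpr.invars)
```

Since `k` is bound as a Python float, the `if k < 0` branch is resolved during tracing and only one branch appears in the printed expression.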

In [17]:
def expression_transformer(*args, static_args=()):
    # with example args: return the jaxpr itself
    # without args: return the function that builds the jaxpr

    def transformer_wrap(kernel):
        make_expr = jax.make_jaxpr(kernel, static_argnums=static_args)
        if len(args) > 0:
            return make_expr(*args)
        return make_expr

    return transformer_wrap

In [18]:
@expression_transformer(x, x, 5, 3.14, static_args = [2, 3])
def cond_add(x, y, z: float = 0, k: float = 0):
    if k < 0: 
        return x + y + z + k 
    else: 
        return x + y + z + 2*k

In [19]:
cond_add

{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[3] = add a b
    d:f32[3] = add c 5.0
    e:f32[3] = add d 6.28000020980835
  in (e,) }

In [20]:
@expression_transformer(x, x, 5, -3.14, static_args = [3])
def cond_add(x, y, z: float = 0, k: float = 0):
    if k < 0: 
        return x + y + z + k 
    else: 
        return x + y + z + 2*k

In [21]:
cond_add

{ lambda ; a:f32[3] b:f32[3] c:i32[]. let
    d:f32[3] = add a b
    e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
    f:f32[3] = add d e
    g:f32[3] = add f -3.140000104904175
  in (g,) }

In [22]:
@expression_transformer(static_args = [2, 3])
def cond_add(x, y, z: float = 0, k: float = 0):
    if k < 0: 
        return x + y + z + k 
    else: 
        return x + y + z + 2*k

In [23]:
cond_add

<function jax.make_jaxpr(cond_add)(x, y, z: float = 0, k: float = 0)>

### Define a number of simple, dumb transformers
- we pretend that their second argument is something we want to configure from the start and hence it should not be traced

- we can use the above decorators to bind their second arg to something we know

In [24]:
def add(x, s: float): 
    return x + s 

def mult(x, m: float): 
    return x * m 

def div(x, d: float): 
    return x / d 

def sub(x, s: float): 
    return x - s 

## Configuration files and pipeline building

- yaml format: dictionary 
- inside the dictionary one can arbitrarily nest lists, dicts. 
- customizable for node formats not provided 
- available in pretty much all languages
- the config file builds an adjacency list of a DAG essentially, but currently it's limited to only one child per node => linear
- the build algorithm is limited to linear pipelines for the moment 
- while a more general base class is provided, we only implement linear pipelines atm 
- the essential part is the `Transformers` node of the config. This is the actual DAG adjacency list 
- you can use other nodes to configure other parts of your system: data directories etc
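
To make the "adjacency list restricted to one child per node" idea concrete, here is a sketch of how an execution order could be derived from the `depends_on` entries. This is an assumption about one possible algorithm, not the build algorithm in `rubix.pipeline`; `linear_order` is a hypothetical helper.

```python
def linear_order(nodes):
    """Return step names in execution order for a one-child-per-node config."""
    # Map each parent to the single step that depends on it.
    child_of = {}
    for name, node in nodes.items():
        parent = node["depends_on"]
        if parent in child_of:
            raise ValueError(f"step {parent!r} has more than one child: not linear")
        child_of[parent] = name

    if None not in child_of:
        raise ValueError("no starting step (depends_on: null) found")

    # Walk from the starting step (depends_on is None) along the chain.
    order = []
    current = child_of[None]
    while current is not None:
        order.append(current)
        current = child_of.get(current)

    if len(order) != len(nodes):
        raise ValueError("some steps are not reachable from the starting step")
    return order


# Toy config: the order in the file does not have to be the execution order.
nodes = {
    "A": {"depends_on": "B"},
    "X": {"depends_on": "A"},
    "Z": {"depends_on": "X"},
    "B": {"depends_on": "C"},
    "C": {"depends_on": None},
}

order = linear_order(nodes)
print(order)
```

A general DAG would need a topological sort instead of this simple chain walk.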


#### Config node structure: 

    name_of_pipeline_step:
        name: name_of_function
        depends_on: name_of_step_immediately_prior_in_pipeline
        args:
            argument1: value1
            argument2: value2
            argumentN: valueN

**See the example file for details.**

The arguments in `args` are bound to the function to create the `Partial` object.
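
How a config node might be turned into a `Partial` can be sketched as follows. The `registry` dict and the step names `first` and `second` are hypothetical; the real assembly is done by the pipeline class.

```python
from jax.tree_util import Partial


def add(x, s):
    return x + s


def mult(x, m):
    return x * m


# Hypothetical registry: function name -> function.
registry = {"add": add, "mult": mult}

# A parsed config in the node structure described above.
config = {
    "first": {"name": "add", "depends_on": None, "args": {"s": 4}},
    "second": {"name": "mult", "depends_on": "first", "args": {"m": 3}},
}

# Look up each function by `name` and bind its `args` as keyword arguments.
steps = {
    step: Partial(registry[node["name"]], **node["args"])
    for step, node in config.items()
}

result = steps["second"](steps["first"](1.0))  # (1 + 4) * 3
```

Only the first positional argument of each function stays open; everything named in `args` is fixed at build time.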

In [25]:
read_cfg = read_yaml("./demo.yml")  # implemented in utils

In [26]:
read_cfg

{'Transformers': {'A': {'name': 'add', 'depends_on': 'B', 'args': {'s': 3.0}},
  'X': {'name': 'mult', 'depends_on': 'A', 'args': {'m': 3}},
  'Z': {'name': 'div', 'depends_on': 'X', 'args': {'d': 4}},
  'B': {'name': 'sub', 'depends_on': 'C', 'args': {'s': 2}},
  'C': {'name': 'add', 'depends_on': None, 'args': {'s': 4}}}}

In [27]:
type(read_cfg)

dict

In [28]:
read_cfg["Transformers"]

{'A': {'name': 'add', 'depends_on': 'B', 'args': {'s': 3.0}},
 'X': {'name': 'mult', 'depends_on': 'A', 'args': {'m': 3}},
 'Z': {'name': 'div', 'depends_on': 'X', 'args': {'d': 4}},
 'B': {'name': 'sub', 'depends_on': 'C', 'args': {'s': 2}},
 'C': {'name': 'add', 'depends_on': None, 'args': {'s': 4}}}

In [29]:
type(read_cfg["Transformers"])

dict

In [30]:
tp = ltp.LinearTransformerPipeline(read_cfg)

In [31]:
for name in [add, mult, div, sub]: 
    tp.register_transformer(name)

In [32]:
tp.transformers

{'add': <function __main__.add(x, s: float)>,
 'mult': <function __main__.mult(x, m: float)>,
 'div': <function __main__.div(x, d: float)>,
 'sub': <function __main__.sub(x, s: float)>}

The `transformers` member gives us a dict of `name: function` pairs for the registered transformers. 
Registration currently has to happen before the pipeline is assembled; otherwise the pipeline does not know what to assemble itself from.
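
Why registration has to come first can be seen from a small stand-in for the registration step. The class below is a hypothetical sketch and not the `rubix` implementation; in particular, rejecting duplicate names is an assumption made here for illustration.

```python
class TransformerRegistry:
    """Hypothetical stand-in for the registration step; not the rubix class."""

    def __init__(self):
        self.transformers = {}

    def register_transformer(self, fn):
        # Key by the function's own name, which the config's `name` field refers to.
        name = fn.__name__
        if name in self.transformers:
            raise ValueError(f"a transformer named {name!r} is already registered")
        self.transformers[name] = fn

    def lookup(self, name):
        try:
            return self.transformers[name]
        except KeyError:
            raise KeyError(f"{name!r} is not registered; register it before assembling") from None


def add(x, s):
    return x + s


def mult(x, m):
    return x * m


registry = TransformerRegistry()
for fn in [add, mult]:
    registry.register_transformer(fn)

names = list(registry.transformers)
```

Assembly can only resolve the `name` entries of the config against what is in this dict at that moment, so an unregistered name fails at lookup time.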

In [33]:
tp.assemble()

In [34]:
tp.pipeline

{'C': Partial(<function add at 0x7f7f13157d00>, s=4),
 'B': Partial(<function sub at 0x7f7f13188280>, s=2),
 'A': Partial(<function add at 0x7f7f13157d00>, s=3.0),
 'X': Partial(<function mult at 0x7f7f13156e60>, m=3),
 'Z': Partial(<function div at 0x7f7f131881f0>, d=4)}

Now we have an ordered collection of jax `Partial`s to which we can, in principle, apply all jax transformations, assuming the individual elements are well behaved. If this holds for the elements, it also holds for their composition, as long as the function used for composition is itself purely functional.
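
As a rough check of this claim, the sketch below composes the same five bound steps by hand and applies `jit`, `vmap` and `grad` to the result. The function `expr` is a guess at what a pure composition function could look like and is not copied from `rubix`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def add(x, s):
    return x + s


def sub(x, s):
    return x - s


def mult(x, m):
    return x * m


def div(x, d):
    return x / d


# The same five bound steps as in the assembled pipeline.
steps = [
    Partial(add, s=4),
    Partial(sub, s=2),
    Partial(add, s=3.0),
    Partial(mult, m=3),
    Partial(div, d=4),
]


def expr(x, pipeline):
    # Pure composition: no state besides the arguments.
    for step in pipeline:
        x = step(x)
    return x


composed = Partial(expr, pipeline=steps)
x = jnp.array([3.0, 2.0, 1.0], dtype=jnp.float32)

y_jit = jax.jit(composed)(x)
y_vmap = jax.vmap(composed)(x)  # maps over the elements of x
dy_dx = jax.grad(lambda v: composed(v).sum())(x)
```

The gradient is constant because the whole pipeline is affine in `x`: the only scaling is the factor 3/4.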

In [35]:
tp.expression

Partial(<function LinearTransformerPipeline.build_expression.<locals>.expr at 0x7f7f13188550>, pipeline=[Partial(<function add at 0x7f7f13157d00>, s=4), Partial(<function sub at 0x7f7f13188280>, s=2), Partial(<function add at 0x7f7f13157d00>, s=3.0), Partial(<function mult at 0x7f7f13156e60>, m=3), Partial(<function div at 0x7f7f131881f0>, d=4)])

In [36]:
func = tp.compile_expression()

In [37]:
x

Array([3., 2., 1.], dtype=float32)

In [38]:
func(x)

Array([6.  , 5.25, 4.5 ], dtype=float32)

In [39]:
div(mult(add(sub(add(x, s=4), s=2), s=3), m=3),d=4)

Array([6.  , 5.25, 4.5 ], dtype=float32)

In [40]:
func

<PjitFunction of Partial(_HashableCallableShim(Partial(<function LinearTransformerPipeline.build_expression.<locals>.expr at 0x7f7f13188550>, pipeline=[Partial(<function add at 0x7f7f13157d00>, s=4), Partial(<function sub at 0x7f7f13188280>, s=2), Partial(<function add at 0x7f7f13157d00>, s=3.0), Partial(<function mult at 0x7f7f13156e60>, m=3), Partial(<function div at 0x7f7f131881f0>, d=4)])))>

In [41]:
tp.get_jaxpr()(x)

{ lambda ; a:f32[3]. let
    b:f32[3] = add a 4.0
    c:f32[3] = sub b 2.0
    d:f32[3] = add c 3.0
    e:f32[3] = mul d 3.0
    f:f32[3] = div e 4.0
  in (f,) }

In [42]:
def func(x): 
    return div(mult(add(sub(add(x, s=4), s=2), s=3), m=3),d=4)

In [43]:
make_jaxpr(func)(x)

{ lambda ; a:f32[3]. let
    b:f32[3] = add a 4.0
    c:f32[3] = sub b 2.0
    d:f32[3] = add c 3.0
    e:f32[3] = mul d 3.0
    f:f32[3] = div e 4.0
  in (f,) }

## Summary 
- the pipeline produces the same jaxpr as the handwritten function. This seems encouraging.
- at which points do we still need to ensure purely functional behavior?
- how will we enforce transformer compatibility?
- this is a pathologically simple case 
- when does it break? 
- what use cases are not covered?
- what else do you need? 
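
As one concrete data point for the purity question, the following sketch shows a well-known way in which a jitted function silently goes stale when it reads a global, and how binding the value explicitly avoids it. The names `impure_add` and `offset` are made up for the example.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial

offset = 1.0


@jax.jit
def impure_add(x):
    # Reads a global: its value is baked in when the function is first traced.
    return x + offset


def add(x, s):
    return x + s


x = jnp.array([1.0, 2.0], dtype=jnp.float32)

before = impure_add(x)
offset = 10.0
after = impure_add(x)  # reuses the cached trace, the new offset is ignored

# Binding the value explicitly makes the dependency visible and rebuildable.
pure_add = jax.jit(Partial(add, s=offset))
rebuilt = pure_add(x)
```

Nothing warns about the stale value, which is why every transformer's external inputs should go through the factory and end up as bound arguments.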
