In [2]:
import functools
import inspect

import jax
import jax.numpy as jnp

In [45]:
def compose(*workflow):
    """Decorator factory that runs ``workflow`` functions in front of ``loss``.

    ``@compose(f1, f2, ...)`` applied to ``loss`` returns a function that calls
    ``f1, f2, ..., loss`` in order.  Each function in the chain is called with:

    * the dict returned by the previous function, unpacked as keywords
      (the first function starts from an empty dict);
    * every caller keyword argument whose name matches one of its parameters
      and is not already supplied by the previous result;
    * the caller's positional ``*args`` -- but only when at least one of its
      parameters is supplied by neither of the above.

    Every function except the last must therefore return a dict.  Positional
    arguments bind to the *leading* parameters, so a value to differentiate
    with ``jax.grad`` must be the first parameter of the function that uses it.

    Parameters
    ----------
    *workflow : callables
        Functions run, in order, before the decorated function.

    Returns
    -------
    callable
        A decorator.  The decorated function keeps ``loss``'s metadata
        (via ``functools.wraps``) and returns whatever ``loss`` returns.
    """
    def dec(loss):
        # Parameter names do not change between calls, so introspect once.
        steps = [(f, tuple(inspect.signature(f).parameters))
                 for f in workflow + (loss,)]

        @functools.wraps(loss)
        def called(*args, **kwargs):  # *args are for grad, **kwargs are the rest
            res = {}
            for f, names in steps:
                # Scan *all* parameters: stopping early (as soon as both a
                # positional and a keyword parameter were seen) would drop any
                # keyword parameters that come later in the signature.
                f_kwargs = {k: kwargs[k] for k in names
                            if k in kwargs and k not in res}
                needs_args = any(k not in kwargs and k not in res for k in names)
                res = f(*(args if needs_args else ()), **res, **f_kwargs)
            return res

        return called

    return dec
            
  

In [61]:

def comp(*workflow):
    """Build a pipeline that calls each function in ``workflow`` in order.

    Each function in the chain is called with:

    * the dict returned by the previous function, unpacked as keywords
      (the first function starts from an empty dict);
    * every caller keyword argument whose name matches one of its parameters
      and is not already supplied by the previous result;
    * the caller's positional ``*args`` -- but only when at least one of its
      parameters is supplied by neither of the above.

    Every function except the last must therefore return a dict.  Positional
    arguments bind to the *leading* parameters, so a value to differentiate
    with ``jax.grad`` must be the first parameter of the function that uses it.

    Parameters
    ----------
    *workflow : callables
        Functions to chain, in call order.

    Returns
    -------
    callable
        ``pipeline(*args, **kwargs)`` returning the last function's result
        (an empty dict if ``workflow`` is empty).
    """
    # Parameter names do not change between calls, so introspect once.
    steps = [(f, tuple(inspect.signature(f).parameters)) for f in workflow]

    def pipeline(*args, **kwargs):  # *args are for grad, **kwargs are the rest
        res = {}
        for f, names in steps:
            # Scan *all* parameters: stopping early (as soon as both a
            # positional and a keyword parameter were seen) would drop any
            # keyword parameters that come later in the signature.
            f_kwargs = {k: kwargs[k] for k in names
                        if k in kwargs and k not in res}
            needs_args = any(k not in kwargs and k not in res for k in names)
            res = f(*(args if needs_args else ()), **res, **f_kwargs)
        return res

    return pipeline
            
def data_gen(p):
    """Return ``{'data': jnp.array([3 * p**2, 4])}``."""
    return dict(data = jnp.array([3 * p**2,4]))

def preprocess(params, data):
    """Unpack ``data`` into ``s, b`` and return ``{'s': s + params, 'b': b}``."""
    s, b = data
    return dict(s = s + params, b = b)

def loss(yeet, s, b):
    """Return ``s / b - yeet``."""
    return s / b - yeet

# `comp` builds a callable pipeline from the functions directly.  `compose` is a
# decorator factory: `compose(data_gen, preprocess, loss)` would return a
# decorator, and `jax.grad` of that fails on the keyword arguments below.
pipeline = comp(data_gen, preprocess, loss)

jax.grad(pipeline)(5., params=1, p=3)    # grad wrt 'yeet'
jax.grad(pipeline)(3., params=1, yeet=5) # grad wrt 'p'
jax.grad(pipeline)(1., p=1, yeet=5)      # grad wrt 'params'

DeviceArray(0.25, dtype=float32)

In [46]:
import jax
import jax.numpy as jnp

In [50]:
# Conditions for chaining functions with `compose`:
# - Each function in the chain must return a dict whose keys match
#   keyword parameters of the next function.
# - For `jax.grad` to work, the argument you want to differentiate with
#   respect to must be the *first* parameter of the relevant function
#   in the chain:
#     - the first positional argument of a @compose-decorated function
#       is then that argument;
#     - everything else must be passed by keyword.

def data_gen(p):
    """Return ``{'data': jnp.array([3 * p**2, 4])}``."""
    return {"data": jnp.array([3 * p ** 2, 4])}

def preprocess(params, data):
    """Shift the first entry of ``data`` by ``params``; pass the second through."""
    first, second = data
    return {"s": first + params, "b": second}

@compose(data_gen, preprocess)
def loss(yeet, s, b):
    """Return ``s / b - yeet``."""
    return s / b - yeet

jax.grad(loss)(5., params=1, p=3)    # grad wrt 'yeet'
jax.grad(loss)(3., params=1, yeet=5) # grad wrt 'p'
jax.grad(loss)(1., p=1, yeet=5)      # grad wrt 'params'

DeviceArray(0.25, dtype=float32)

In [51]:
from numpy import random

TODO: Make the pipeline an object and resolve the arguments when it is called?

TODO: Could a decorator be used to append functions to the pipeline?

In [56]:
def f(a: int, b: list) -> int:
    """Return ``a`` unchanged; a toy target for signature rewriting."""
    return a

# Replace f's signature with one rebuilt from a name -> type mapping.
# The rebuilt signature has no return annotation, so `-> int` is dropped,
# and f.__annotations__ becomes the `args` dict itself.
args = {"a": int, "b": list}
params = [
    inspect.Parameter(name, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                      annotation=annotation)
    for name, annotation in args.items()
]
f.__signature__ = inspect.Signature(parameters=params)
f.__annotations__ = args

In [60]:
# Display f's annotations after the override above; they are the `args`
# mapping, so there is no 'return' entry.
f.__annotations__

{'a': int, 'b': list}

In [59]:
# inspect.signature honours an explicitly assigned __signature__ and returns it.
inspect.signature(f)

<Signature (a: int, b: list)>

In [30]:
# `!python` runs whichever interpreter is first on PATH, which need not be the
# one backing this kernel; ask the running interpreter directly instead.
import platform
platform.python_version()

Python 3.8.3


(0, 3, 3)

In [None]:
# TODO (unfinished sketch, kept as a comment so Restart & Run All does not hit
# a SyntaxError): presumably the idea was to attach a signature built from a
# list of argument names, along the lines of the `f.__signature__` cell above:
#     arglist = ["arg1", "arg2"]
#     func.__signature__ = inspect.Signature(
#         [inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD)
#          for name in arglist])