In [None]:
%pip install -r requirements.txt

In [5]:
from __future__ import annotations

import jax
import jax.numpy as jnp

from melody import compose

# for the case where you have a bunch of functions that you want to chain together
# when doing a pipeline that you may want to take the grad of, e.g.:
#
# data generation -> preprocess -> loss
#
# but you may not want to have to specify all the kwargs for each function
# (e.g. you want to specify the kwargs for the first function, and then the rest
# are inferred from the return values of the previous functions in the chain)
#
# saves you doing
# def pipeline(thing_for_grad, **kwargs):
#    data = data_gen(thing_for_grad, some_kwargs)
#    data = preprocess(data, some_other_kwargs)
#    loss = loss(data, some_other_other_kwargs)
#    return loss
#
# conditions:
# - each func in chain must return dict 
#   corresponding to some kwargs of next function
# - first arg is for grad, rest need to be specified as kwargs
#
# it's a bit hacky so never ended up using it, but it's a cool idea!

def data_gen(p: float) -> dict[str, float]:
    return dict(data = jnp.array([3 * p**2,4]))

def preprocess(param: float, data: jax.Array) -> dict[str, float]:
    s, b = data
    return dict(s = s + param, b = b)
    
def loss(yeet: float, s: float, b: float) -> float:
    return s / b - yeet

pipeline = compose([data_gen, preprocess, loss])
pipeline(1, p=4., param=343.)

Array(96.75, dtype=float32)

In [7]:
# can take grad of pipeline wrt any of the args if you keep it as the first arg
print(jax.grad(pipeline)(5., params=1, p=3))    # grad wrt 'yeet'
print(jax.grad(pipeline)(3., params=1, yeet=5)) # grad wrt 'p'
print(jax.grad(pipeline)(1., p=1, yeet=5))      # grad wrt 'params'

-0.75
4.75
0.25
