In [2]:
from adjax import *
from rich import inspect
# from rich import print
%load_ext rich

TypeError: 'type' object is not subscriptable

# New primitives

Semiring `add`/`mul` primitives, plus an `exp` primitive used by the examples below.

In [2]:
# Add primitives for the semiring operations used by the semiring JVP rules.

semiring_add_p = Primitive('semiring_add')
semiring_mul_p = Primitive('semiring_mul')


def semiring_add(x, y):
    """Semiring addition: binds `semiring_add_p` on (x, y)."""
    return bind1(semiring_add_p, x, y)


def semiring_mul(x, y):
    """Semiring multiplication: binds `semiring_mul_p` on (x, y)."""
    return bind1(semiring_mul_p, x, y)


# Temporary: evaluate both as ordinary add and mul (the real-number semiring).
# Abstract evaluation reuses the generic binary-op rule. Only impl and
# abstract-eval rules are registered here (no jvp rule).
impl_rules[semiring_add_p] = lambda x, y: x + y
abstract_eval_rules[semiring_add_p] = binop_abstract_eval

impl_rules[semiring_mul_p] = lambda x, y: x * y
abstract_eval_rules[semiring_mul_p] = binop_abstract_eval

In [3]:
# Extra primitive needed by the examples: elementwise exp.

exp_p = Primitive('exp')

def exp(x): return bind1(exp_p, x)

def exp_jvp(primals, tangents):
    """JVP rule for exp, using d/dx exp(x) = exp(x).

    The primal output is reused as the derivative, so tracing this rule
    emits a single `exp` instead of two.
    """
    (x,), (x_dot,) = primals, tangents
    y = exp(x)
    return [ y ], [ y * x_dot ]

def exp_srjvp(primals, tangents):
    """Semiring JVP rule for exp: tangent is semiring_mul(exp(x), x_dot).

    As in `exp_jvp`, exp(x) is computed once and shared.
    """
    (x,), (x_dot,) = primals, tangents
    y = exp(x)
    return [ y ], [ semiring_mul(y, x_dot) ]

impl_rules[exp_p] = lambda x: np.exp(x)
jvp_rules[exp_p] = exp_jvp
abstract_eval_rules[exp_p] = vectorized_unop_abstract_eval
# exp_srjvp is registered in srjvp_rules further down, once that table exists.

In [4]:
# One approach: a tracer that carries a (primal, tangent) pair, where every
# primitive's tangent is computed by a rule from `srjvp_rules`. Those rules
# combine tangents with the semiring primitives (semiring_add / semiring_mul)
# instead of ordinary + and *.

class SRJVPTracer(Tracer):
    """Tracer holding a primal value and its semiring tangent."""

    def __init__(self, trace, primal, tangent):
        self._trace = trace
        self.primal = primal
        self.tangent = tangent

    @property
    def aval(self):
        # The abstract value is taken from the primal.
        return get_aval(self.primal)


class SRJVPTrace(Trace):
    """Trace that applies the semiring JVP rule registered for each primitive."""

    # Non-tracer values are lifted with a `zeros_like` tangent.
    # NOTE(review): zeros_like is the zero of ordinary arithmetic; a semiring
    # whose additive identity is not 0 would need a different zero here.
    pure = lift = lambda self, val: SRJVPTracer(self, val, zeros_like(val))

    def process_primitive(self, primitive, tracers, params):
        primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
        srjvp_rule = srjvp_rules.get(primitive)
        if srjvp_rule is None:
            # Fail with the primitive's name rather than an opaque KeyError.
            raise NotImplementedError(
                f"no semiring JVP rule registered for primitive {primitive}")
        primal_outs, tangents_outs = srjvp_rule(primals_in, tangents_in, **params)
        return [SRJVPTracer(self, p, t) for p, t in zip(primal_outs, tangents_outs)]


# Primitive -> semiring JVP rule; populated further down.
srjvp_rules = {}


In [5]:

def add_srjvp(primals, tangents):
    """Semiring JVP of add: tangent is semiring_add(x_dot, y_dot)."""
    x, y = primals
    x_dot, y_dot = tangents
    return [x + y], [semiring_add(x_dot, y_dot)]

def mul_srjvp(primals, tangents):
    """Semiring JVP of mul (product rule).

    Tangent: semiring_add(semiring_mul(x_dot, y), semiring_mul(x, y_dot)).
    Ops are emitted in the same order as the plain expression would, so the
    traced jaxpr is unchanged.
    """
    x, y = primals
    x_dot, y_dot = tangents
    product = x * y
    left = semiring_mul(x_dot, y)
    right = semiring_mul(x, y_dot)
    return [product], [semiring_add(left, right)]

def sin_srjvp(primals, tangents):
    """Semiring JVP of sin: tangent is semiring_mul(cos(x), x_dot)."""
    (x,) = primals
    (x_dot,) = tangents
    out = sin(x)
    slope = cos(x)
    return [out], [semiring_mul(slope, x_dot)]

def cos_srjvp(primals, tangents):
    """Semiring JVP of cos: tangent is semiring_mul(-sin(x), x_dot)."""
    (x,) = primals
    (x_dot,) = tangents
    out = cos(x)
    slope = -sin(x)
    return [out], [semiring_mul(slope, x_dot)]

# TODO: not sure how to handle these rules in a general semiring
# (e.g. a semiring need not have additive inverses, which `neg` assumes).
def neg_srjvp(primals, tangents):
    """Semiring JVP of neg: negates both primal and tangent with ordinary `neg`."""
    (x,) = primals
    (x_dot,) = tangents
    out = neg(x)
    return [out], [neg(x_dot)]

def reduce_sum_srjvp(primals, tangents, *, axis):
    """Semiring JVP of reduce_sum: the tangent is summed over the same `axis`.

    NOTE(review): this uses the ordinary `reduce_sum`, not a semiring sum.
    """
    (x,) = primals
    (x_dot,) = tangents
    out = reduce_sum(x, axis)
    return [out], [reduce_sum(x_dot, axis)]

def greater_srjvp(primals, tangents):
    """Semiring JVP of greater: the tangent is `zeros_like` of the output.

    The input tangents are ignored.
    """
    x, y = primals
    out_primal = greater(x, y)
    return [out_primal], [zeros_like(out_primal)]

def less_srjvp(primals, tangents):
    """Semiring JVP of less: the tangent is `zeros_like` of the output.

    The input tangents are ignored.
    """
    x, y = primals
    out_primal = less(x, y)
    return [out_primal], [zeros_like(out_primal)]



In [6]:
# Table of semiring JVP rules, keyed by primitive.

srjvp_rules.update({
    add_p: add_srjvp,
    mul_p: mul_srjvp,
    sin_p: sin_srjvp,
    cos_p: cos_srjvp,
    neg_p: neg_srjvp,
    reduce_sum_p: reduce_sum_srjvp,
    greater_p: greater_srjvp,
    less_p: less_srjvp,
    exp_p: exp_srjvp,
})



In [7]:
def srjvp_flat(f, primals, tangents):
    """Evaluate flat function `f` under a fresh SRJVPTrace.

    Args:
        f: callable taking the flat primal inputs positionally and returning
            an iterable of outputs.
        primals, tangents: parallel sequences of input primals and tangents.

    Returns:
        (primals_out, tangents_out), unzipped from the output tracers.
    """
    with new_main(SRJVPTrace) as main:
        trace = SRJVPTrace(main)
        tracers_in = [SRJVPTracer(trace, p, t) for p, t in zip(primals, tangents)]
        # Raise every output into this trace so constants also get a tangent.
        tracers_out = [full_raise(trace, out) for out in f(*tracers_in)]
        primals_out, tangents_out = unzip2((t.primal, t.tangent) for t in tracers_out)
    return primals_out, tangents_out

def srjvp(f, primals, tangents):
    """Semiring JVP of `f` at `primals` along `tangents`.

    Args:
        f: function to differentiate; presumably called with the elements of
            `primals` as positional arguments (as `srjvp_traceable` and
            `srlinearize_flat` use it).
        primals: pytree of primal inputs.
        tangents: pytree of tangents with the same structure as `primals`.

    Returns:
        (primals_out, tangents_out), each unflattened to `f`'s output structure.

    Raises:
        TypeError: if `primals` and `tangents` have different tree structures.
    """
    primals_flat, in_tree = tree_flatten(primals)
    tangents_flat, in_tree2 = tree_flatten(tangents)
    if in_tree != in_tree2:
        raise TypeError(
            f"srjvp: primals and tangents must have the same tree structure, "
            f"got {in_tree} and {in_tree2}")
    f, out_tree = flatten_fun(f, in_tree)
    primals_out_flat, tangents_out_flat = srjvp_flat(f, primals_flat, tangents_flat)
    out_treedef = out_tree()
    primals_out = tree_unflatten(out_treedef, primals_out_flat)
    tangents_out = tree_unflatten(out_treedef, tangents_out_flat)
    return primals_out, tangents_out

def srjvp_jaxpr(jaxpr: Jaxpr) -> tuple[Jaxpr, list[Any]]:
    """Trace the semiring JVP of `jaxpr` into a new jaxpr.

    The new jaxpr takes every input of `jaxpr` twice: first all primals,
    then all tangents (with the same avals). Its outputs are those of
    `srjvp`: the primal outputs followed by the tangent outputs.

    Returns:
        (new_jaxpr, new_consts) as produced by `make_jaxpr`.
    """
    in_avals = [binder.aval for binder in jaxpr.in_binders]

    def srjvp_traceable(*primals_and_tangents):
        split = len(primals_and_tangents) // 2
        return srjvp(jaxpr_as_fun(jaxpr),
                     primals_and_tangents[:split],
                     primals_and_tangents[split:])

    new_jaxpr, new_consts, _ = make_jaxpr(srjvp_traceable, *in_avals, *in_avals)
    return new_jaxpr, new_consts



# Test examples

In [9]:
def f(x, y):
    """Example using exp, negation, add and mul: exp(x) + (x + -y) * y."""
    return exp(x) + (x + -y) * y


jaxpr, consts, _ = make_jaxpr(f, get_aval(1.0), get_aval(1.0))
print(jaxpr)

{ lambda a:float64[], b:float64[] .
  let c:float64[] = exp a
      d:float64[] = neg b
      e:float64[] = add a d
      f:float64[] = mul e b
      g:float64[] = add c f
  in ( g ) }


In [10]:
def f(x, y):
    """Example with a scalar constant: 2x + (x + -y) * y."""
    return 2.0 * x + (x + -y) * y


jaxpr, consts, _ = make_jaxpr(f, get_aval(1.0), get_aval(1.0))
print(jaxpr)


{ lambda a:float64[], b:float64[] .
  let c:float64[] = mul 2.0 a
      d:float64[] = neg b
      e:float64[] = add a d
      f:float64[] = mul e b
      g:float64[] = add c f
  in ( g ) }


In [12]:
# Ordinary JVP vs. semiring JVP of the same jaxpr, printed one after the other.
jaxpr_jvp, consts_jvp = jvp_jaxpr(jaxpr)
jaxpr_srjvp, consts_srjvp = srjvp_jaxpr(jaxpr)

print(jaxpr_jvp, jaxpr_srjvp, sep="\n")

{ lambda a:float64[], b:float64[], c:float64[], d:float64[] .
  let e:float64[] = mul 2.0 a
      f:float64[] = mul 0.0 a
      g:float64[] = mul 2.0 c
      h:float64[] = add f g
      i:float64[] = neg b
      j:float64[] = neg d
      k:float64[] = add a i
      l:float64[] = add c j
      m:float64[] = mul k b
      n:float64[] = mul l b
      o:float64[] = mul k d
      p:float64[] = add n o
      q:float64[] = add e m
      r:float64[] = add h p
  in ( q, r ) }
{ lambda a:float64[], b:float64[], c:float64[], d:float64[] .
  let e:float64[] = mul 2.0 a
      f:float64[] = semiring_mul 0.0 a
      g:float64[] = semiring_mul 2.0 c
      h:float64[] = semiring_add f g
      i:float64[] = neg b
      j:float64[] = neg d
      k:float64[] = add a i
      l:float64[] = semiring_add c j
      m:float64[] = mul k b
      n:float64[] = semiring_mul l b
      o:float64[] = semiring_mul k d
      p:float64[] = semiring_add n o
      q:float64[] = add e m
      r:float64[] = semiring_add h p


# Examples and test cases

In [None]:

def semiring_backprop(f, *primals_in, verbose=True):
    """Trace `f`, build its JVP and semiring-JVP jaxprs, and partially evaluate
    the JVP jaxpr with primals known and tangents unknown ("linearize").

    Args:
        f: function to differentiate.
        *primals_in: arguments forwarded to `make_jaxpr` (the example below
            passes avals from `get_aval`).
        verbose: when True (the default, matching the original behavior),
            print the intermediate jaxprs and partial-eval settings.

    Returns:
        (jaxpr_primal, jaxpr_tangent), the two halves from `partial_eval_jaxpr`.
    """
    # jaxpr for the function itself
    jaxpr_f, consts, _ = make_jaxpr(f, *primals_in)
    # NOTE(review): consts / jvp_consts / srjvp_consts are not threaded into
    # the partial evaluation below; this only matters if `f` captures
    # non-literal constants.

    # jaxpr for the ordinary JVP
    jaxpr_jvp, jvp_consts = jvp_jaxpr(jaxpr_f)

    # jaxpr for the semiring JVP
    # NOTE(review): built but not partially evaluated below (the JVP jaxpr is);
    # given this function's name, the semiring version may be the intended
    # target — confirm.
    jaxpr_srjvp, srjvp_consts = srjvp_jaxpr(jaxpr_f)

    if verbose:
        print("jaxpr jvp")
        print(jaxpr_jvp)

    # "Linearize" the JVP jaxpr: its inputs are all primals followed by all
    # tangents, so the first half is known and the second half unknown.
    # n is taken from the jaxpr actually being partially evaluated.
    n = len(jaxpr_jvp.in_binders) // 2
    in_unknowns = [False] * n + [True] * n
    if verbose:
        print(f"n = {n}")
        print(f"in_unknowns = {in_unknowns}")
    jaxpr_primal, jaxpr_tangent, _, _ = partial_eval_jaxpr(jaxpr_jvp, in_unknowns)

    if verbose:
        print("JVP jaxpr primal")
        print(jaxpr_primal)
        print("JVP jaxpr tangent")
        print(jaxpr_tangent)

    return jaxpr_primal, jaxpr_tangent


def f(x, y):
    """Example without constants: (x + -y) * y."""
    return (x + -y) * y

jaxpr_primal, jaxpr_tangent = semiring_backprop(f, get_aval(1.0), get_aval(2.0))


jaxpr jvp
{ lambda a:float64[], b:float64[], c:float64[], d:float64[] .
  let e:float64[] = neg b
      f:float64[] = neg d
      g:float64[] = add a e
      h:float64[] = add c f
      i:float64[] = mul g b
      j:float64[] = mul h b
      k:float64[] = mul g d
      l:float64[] = add j k
  in ( i, l ) }
n = 2
in_unknowns = [False, False, True, True]
JVP jaxpr primal
{ lambda a:float64[], b:float64[] .
  let c:float64[] = neg b
      d:float64[] = add a c
      e:float64[] = mul d b
  in ( e, d, b ) }
JVP jaxpr tangent
{ lambda a:float64[], b:float64[], c:float64[], d:float64[] .
  let e:float64[] = neg d
      f:float64[] = add c e
      g:float64[] = mul f b
      h:float64[] = mul a d
      i:float64[] = add g h
  in ( i ) }


In [None]:
def f(x):
    """Affine example: 3x + 2."""
    return 3.0 * x + 2.0


jaxpr, consts, _ = make_jaxpr(f, get_aval(1.0))
jaxpr_jvp, consts_jvp = jvp_jaxpr(jaxpr)
jaxpr_srjvp, consts_srjvp = srjvp_jaxpr(jaxpr)

# Inputs of jaxpr_srjvp are (primal, tangent). [True, False] marks the primal
# unknown and the tangent known -- the reverse of the linearize split used in
# semiring_backprop ([False]*n + [True]*n). NOTE(review): confirm intent.
in_unknowns = [True, False]
jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr_srjvp, in_unknowns)

print("jaxpr1", jaxpr1, "jaxpr2", jaxpr2, sep="\n")

jaxpr1
{ lambda a:float64[] .
  let b:float64[] = semiring_mul 3.0 a
  in ( b ) }
jaxpr2
{ lambda a:float64[], b:float64[] .
  let c:float64[] = mul 3.0 b
      d:float64[] = semiring_mul 0.0 b
      e:float64[] = semiring_add d a
      f:float64[] = add c 2.0
      g:float64[] = semiring_add e 0.0
  in ( f, g ) }


In [None]:

def srbp_demo(f, *primals_in):
    """Trace `f` on the avals of example inputs (semiring-backprop demo stub).

    NOTE(review): this cell was an unfinished stub whose body
    (`make_jaxpr(f, *map(get_aval,)`) was a SyntaxError that broke
    Restart & Run All. It is completed here to its evident intent: trace `f`
    on the avals of example inputs. Extend as the demo takes shape.

    Args:
        f: function to trace.
        *primals_in: example input values; only their avals are used.

    Returns:
        (f_jaxpr, f_consts) from `make_jaxpr`.
    """
    f_jaxpr, f_consts, _ = make_jaxpr(f, *map(get_aval, primals_in))
    return f_jaxpr, f_consts

In [13]:
# Simple test case: the affine function 3x + 2, and its ordinary JVP jaxpr.

def f(x):
    return 3.0*x + 2.0

jaxpr, consts, _ = make_jaxpr(f, get_aval(1.0))

# Plain string literals: these messages have no placeholders to format.
print("function jaxpr")
print(jaxpr)
print("-------\n")

jaxpr_jvp, consts_jvp = jvp_jaxpr(jaxpr)

print("JVP jaxpr")
print(jaxpr_jvp)
print("-------\n")




function jaxpr
{ lambda a:float64[] .
  let b:float64[] = mul 3.0 a
      c:float64[] = add b 2.0
  in ( c ) }
-------

JVP jaxpr
{ lambda a:float64[], b:float64[] .
  let c:float64[] = mul 3.0 a
      d:float64[] = mul 0.0 a
      e:float64[] = mul 3.0 b
      f:float64[] = add d e
      g:float64[] = add c 2.0
      h:float64[] = add f 0.0
  in ( g, h ) }
-------



In [14]:

# Semiring JVP jaxpr of the affine example above.
print("SR JVP jaxpr")

jaxpr_srjvp, consts_srjvp = srjvp_jaxpr(jaxpr)

print(jaxpr_srjvp)



SR JVP jaxpr
srjvp (<adjax.JaxprTracer object at 0x7f4d803ad2d0>,) (<adjax.JaxprTracer object at 0x7f4e7fe41690>,)
srjvp_flat enter
srjvp context manager entry
main SRJVP trace created
new srjvp tracer <__main__.SRJVPTrace object at 0x7f4db80c2d90> <adjax.JaxprTracer object at 0x7f4d803ad2d0> <adjax.JaxprTracer object at 0x7f4e7fe41690>
new srjvp tracer <__main__.SRJVPTrace object at 0x7f4db80ca250> 3.0 0.0
SRJVPTrace: process_primitive Primitive(name='mul') [<__main__.SRJVPTracer object at 0x7f4d80392290>, <__main__.SRJVPTracer object at 0x7f4db80c2310>] {}
new srjvp tracer <__main__.SRJVPTrace object at 0x7f4db80ca250> <adjax.JaxprTracer object at 0x7f4d80392490> <adjax.JaxprTracer object at 0x7f4d80393e50>
new srjvp tracer <__main__.SRJVPTrace object at 0x7f4d80391f10> 2.0 0.0
SRJVPTrace: process_primitive Primitive(name='add') [<__main__.SRJVPTracer object at 0x7f4d80392990>, <__main__.SRJVPTracer object at 0x7f4d803933d0>] {}
new srjvp tracer <__main__.SRJVPTrace object at 0x7f4d8

In [14]:
# Display the semiring JVP jaxpr via the notebook's rich repr (rich extension loaded at the top).
jaxpr_srjvp


[1m{[0m lambda [1;92ma:f[0mloat64[1m[[0m[1m][0m, [1;92mb:f[0mloat64[1m[[0m[1m][0m .
  let [1;92mc:f[0mloat64[1m[[0m[1m][0m = mul [1;36m3.0[0m a
      [1;92md:f[0mloat64[1m[[0m[1m][0m = semiring_mul [1;36m0.0[0m a
      [1;92me:f[0mloat64[1m[[0m[1m][0m = semiring_mul [1;36m3.0[0m b
      [1;92mf:f[0mloat64[1m[[0m[1m][0m = semiring_add d e
      g:float64[1m[[0m[1m][0m = add c [1;36m2.0[0m
      h:float64[1m[[0m[1m][0m = semiring_add f [1;36m0.0[0m
  in [1m([0m g, h [1m)[0m [1m}[0m

In [None]:
def srlinearize_flat(f, *primals_in):
    """Semiring-linearize flat function `f` at `primals_in`.

    Traces the semiring JVP of `f` with the primals known and the tangents
    unknown. Partial evaluation then leaves a jaxpr that maps input tangents
    to output tangents.

    Args:
        f: flat function taking `len(primals_in)` positional inputs.
        *primals_in: concrete primal inputs.

    Returns:
        (primals_out, f_lin, jaxpr): the primal outputs; a function that
        evaluates the linear jaxpr on flat input tangents; and that jaxpr.

    Raises:
        AssertionError: if some primal output is not known after partial
            evaluation.
    """
    print("linearize flat")

    # Known primals, followed by unknown tangents (vspace of each primal's aval).
    pvals_in = ([PartialVal.known(x) for x in primals_in] +
                [PartialVal.unknown(vspace(get_aval(x))) for x in primals_in])

    def f_srjvp(*primals_tangents_in):
        primals_out, tangents_out = srjvp(f, *split_half(primals_tangents_in))
        return [*primals_out, *tangents_out]

    jaxpr, pvals_out, consts = partial_eval_flat(f_srjvp, pvals_in)
    primal_pvals, _ = split_half(pvals_out)
    # Primal outputs should depend only on the known primals.
    assert all(pval.is_known for pval in primal_pvals), (
        "srlinearize: some primal outputs are unknown after partial evaluation "
        "(they depend on the tangent inputs)")
    primals_out = [pval.const for pval in primal_pvals]

    print(jaxpr)

    f_lin = lambda *tangents: eval_jaxpr(jaxpr, [*consts, *tangents])

    return primals_out, f_lin, jaxpr

def srlinearize(f, *primals_in):
    """Pytree wrapper around `srlinearize_flat`.

    Args:
        f: function to linearize.
        *primals_in: primal inputs (pytrees).

    Returns:
        (primals_out, f_lin, jaxpr): primal outputs with `f`'s output structure;
        `f_lin(*tangents_in)`, which takes tangents with the same structure as
        `primals_in` and returns output tangents; and the linear jaxpr.
        `f_lin` raises TypeError if the tangents' tree structure differs from
        the primals'.
    """
    print("srlinearize")
    primals_in_flat, in_tree = tree_flatten(primals_in)
    f, out_tree = flatten_fun(f, in_tree)
    primals_out_flat, f_lin_flat, jaxpr = srlinearize_flat(f, *primals_in_flat)
    primals_out = tree_unflatten(out_tree(), primals_out_flat)

    def f_lin(*tangents_in):
        tangents_in_flat, in_tree2 = tree_flatten(tangents_in)
        if in_tree != in_tree2:
            raise TypeError(
                f"srlinearize: tangents must have the same tree structure as "
                f"the primals, expected {in_tree} but got {in_tree2}")
        tangents_out_flat = f_lin_flat(*tangents_in_flat)
        return tree_unflatten(out_tree(), tangents_out_flat)

    return primals_out, f_lin, jaxpr

def f(x):
    """Affine example: 2x + 1."""
    return 2.0 * x + 1.0


x, f_lin, jaxpr = srlinearize(f, 1.0)
print(jaxpr)

srlinearize
linearize flat
