In [65]:
from jax import numpy as jnp, random, tree_util, nn
from jax import jacfwd, jacrev, jvp, vjp, jit, grad, vmap
from flax import linen as nn

In [66]:
# set seed, create key
key = random.PRNGKey(seed=1)

1. create model

In [67]:
# define model
d = 2
d_in, d_out = d, d
model = nn.Dense(d_out)

# initialize parameters
key, split = random.split(key)
params = model.init(key, jnp.ones(d_in))

2. parameters pytree

In [71]:
# print pytree representation
#type(params)
print(repr(params))

{'params': {'kernel': Array([[-0.67368186,  0.05756453],
       [-0.38660565, -0.07138228]], dtype=float32), 'bias': Array([0., 0.], dtype=float32)}}


In [72]:
# get leaves
leaves = tree_util.tree_leaves(params)
print(f"Number of leaves: {len(leaves)}")

Number of leaves: 2


In [73]:
# check shapes
tree_util.tree_map(lambda x: x.shape, params)

{'params': {'bias': (2,), 'kernel': (2, 2)}}

In [74]:
# get leaves and tree structure
leaves, params_treedef = tree_util.tree_flatten(params)
type(params_treedef)

jaxlib.xla_extension.pytree.PyTreeDef

3. modify pytrees

In [75]:
# zero pytree
updated_params = tree_util.tree_map(jnp.zeros_like, params)
tree_util.tree_map(lambda x: x.shape, updated_params)

{'params': {'bias': (2,), 'kernel': (2, 2)}}

In [77]:
# zeros pytree with batch
batch_size = 100
leaves_batched = list(map(lambda l: jnp.zeros((batch_size,) + l.shape), leaves))
updated_params = tree_util.tree_unflatten(params_treedef, leaves_batched)
tree_util.tree_map(lambda x: x.shape, updated_params)

{'params': {'bias': (100, 2), 'kernel': (100, 2, 2)}}

In [78]:
# constant pytree
a = 2.
updated_params = tree_util.tree_unflatten(params_treedef, list(map(lambda l: a * jnp.ones_like(l), leaves)))
#repr(updated_params)

In [79]:
# operations between pytrees
tree_util.tree_map(lambda x, y: x + y, params, updated_params)

{'params': {'bias': Array([2., 2.], dtype=float32),
  'kernel': Array([[1.3263181, 2.0575645],
         [1.6133944, 1.9286177]], dtype=float32)}}

4. Jacobian

In [98]:
x = jnp.ones(d)
u = model.apply(params, x)
#x.shape, y.shape, tree_util.tree_map(lambda x: x.shape, params)

# model with fixed input
f = lambda p: model.apply(p, x)

In [99]:
#jac_tree = jacfwd(f)(params)
jac_tree = jacfwd(model.apply)(params, x)
tree_util.tree_map(lambda x: x.shape, jac_tree)

{'params': {'bias': (2, 2), 'kernel': (2, 2, 2)}}

In [100]:
%time jac_tree = jacrev(model.apply)(params, x)

CPU times: user 10.1 ms, sys: 667 µs, total: 10.7 ms
Wall time: 10.1 ms


In [101]:
batch_size = int(1e3)
x = jnp.ones((batch_size, d))
f = lambda p: model.apply(p, x)

In [104]:
#jac_tree = jacfwd(model.apply)(params, x)
jac_tree = jacfwd(f)(params)
tree_util.tree_map(lambda x: x.shape, jac_tree)

{'params': {'bias': (1000, 2, 2), 'kernel': (1000, 2, 2, 2)}}

In [105]:
%time jac_tree = jacrev(model.apply)(params, x)

CPU times: user 46.4 ms, sys: 25.1 ms, total: 71.6 ms
Wall time: 44.8 ms


In [106]:
tree_util.tree_map(lambda x: x.shape, jac_tree)

{'params': {'bias': (1000, 2, 2), 'kernel': (1000, 2, 2, 2)}}

5. JVP and VJP

In [148]:
x = jnp.ones(d)
f = lambda p: model.apply(p, x)
y = f(params)
x.shape, y.shape, y

((2,), (2,), Array([-1.0602875 , -0.01381774], dtype=float32))

In [149]:
# jacobian vector product
jvp(f, (params,), (params,))

(Array([-1.0602875 , -0.01381774], dtype=float32),
 Array([-1.0602875 , -0.01381774], dtype=float32))

In [150]:
# vector jacobian product
y, vjp_fun = vjp(f, params)
vjp_tree = vjp_fun(y)[0]
print(tree_util.tree_map(lambda x: x.shape, vjp_tree))
print(repr(vjp_tree))

{'params': {'bias': (2,), 'kernel': (2, 2)}}
{'params': {'bias': Array([-1.0602875 , -0.01381774], dtype=float32), 'kernel': Array([[-1.0602875 , -0.01381774],
       [-1.0602875 , -0.01381774]], dtype=float32)}}


In [155]:
# vector jacobian product
y, vjp_fun = vjp(model.apply, params, x)
vjp_tree = vjp_fun(y)
print(tree_util.tree_map(lambda x: x.shape, vjp_tree))
print(repr(vjp_tree))

({'params': {'bias': (2,), 'kernel': (2, 2)}}, (2,))
({'params': {'bias': Array([-1.0602875 , -0.01381774], dtype=float32), 'kernel': Array([[-1.0602875 , -0.01381774],
       [-1.0602875 , -0.01381774]], dtype=float32)}}, Array([0.71350104, 0.41089946], dtype=float32))


with batch

In [157]:
batch_size = int(1e1)
x = jnp.ones((batch_size, d))
f = lambda p: model.apply(p, x)
f(params).shape

(10, 2)

In [161]:
# vector jacobian product
y, vjp_fun = vjp(f, params)#, reduce_axes=('batch',))
y.shape, vjp_fun(y)
#vmap(vjp_fun)(y)

((10, 2),
 ({'params': {'bias': Array([-10.602875  ,  -0.13817742], dtype=float32),
    'kernel': Array([[-10.602875  ,  -0.13817742],
           [-10.602875  ,  -0.13817742]], dtype=float32)}},))

In [160]:
# vector jacobian product
y, vjp_fun = vmap(vjp, in_axes=(None, 0))(f, params)

ScopeParamShapeError: Initializer expected to generate shape (2,) but got shape (2, 2) instead for parameter "kernel" in "/". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)

In [144]:
y, vjp_fun = vmap(vjp, in_axes=(0, None, 0))(model.apply, params, x)
vmap(vjp)(y)

reduce_axes=('batch',)

ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())