In [1]:
from jax import numpy as jnp, random, tree_util, nn
from jax import jacfwd, jacrev, jvp, vjp, jit, grad, vmap
import equinox as eqx

from playground_jax.models_equinox import *

In [2]:
# set seed, create key
key = random.PRNGKey(seed=1)

1. create model

In [3]:
d = 2
d_in, d_hiddens, d_out = d, [32, 16], d
key, split = random.split(key)
model = MLP(d_in, d_out, d_hiddens, nn.tanh, lambda x: x, key=key)

2. the model as a pytree

In [4]:
# split between parameters and the model structure
params, static = eqx.partition(model, eqx.is_array)

In [5]:
# print the pytree structure
print(type(params)), print(f"Pytree: {repr(model):<30}")

<class 'playground_jax.models_equinox.MLP'>
Pytree: MLP(
  layers=[
    Linear(
      weight=f32[32,2],
      bias=f32[32],
      in_features=2,
      out_features=32,
      use_bias=True
    ),
    <wrapped function <lambda>>,
    Linear(
      weight=f32[16,32],
      bias=f32[16],
      in_features=32,
      out_features=16,
      use_bias=True
    ),
    <wrapped function <lambda>>,
    Linear(
      weight=f32[2,16],
      bias=f32[2],
      in_features=16,
      out_features=2,
      use_bias=True
    ),
    <function <lambda>>
  ]
)


(None, None)

In [6]:
# get the pytree elaves

leaves = tree_util.tree_leaves(model)
print(f"Number of leaves: {len(leaves)}")

Number of leaves: 9


In [8]:
# check shapes
#tree_util.tree_map(lambda x: x.shape, params)

In [9]:
leaves, model_treedef = tree_util.tree_flatten(params)

In [18]:
#type(leaves), type(model_treedef)
#leaves
#params
#model_treedef

3. modify pytrees

In [15]:
updated_leaves = list(map(lambda l: l + jnp.ones(1), leaves))
#updated_params = tree_util.tree_unflatten(model_treedef, updated_leaves)

In [16]:
# zero pytree
%time updated_params = tree_util.tree_map(jnp.zeros_like, params)

CPU times: user 51.8 ms, sys: 0 ns, total: 51.8 ms
Wall time: 51 ms


In [21]:
# constant pytree
a = 2.
%time updated_params = tree_util.tree_unflatten(params_treedef, list(map(lambda l: a * jnp.ones_like(l), leaves)))
#repr(updated_params)

CPU times: user 15.6 ms, sys: 0 ns, total: 15.6 ms
Wall time: 14.9 ms


4. Jacobian of a model with respect to the parameters

In [54]:
x = jnp.ones(d)
def f(params):
    model = eqx.combine(params, static)
    return model(x)

In [55]:
%time jac = jacfwd(f)(params)
#%time jac = jacrev(model)(x)

CPU times: user 9 ms, sys: 465 µs, total: 9.47 ms
Wall time: 8.33 ms


In [56]:
type(jac)

playground_jax.models_equinox.MLP

batch input

In [61]:
batch_size = int(1e6)
x = jnp.ones((batch_size, d))
def f(params):
    model = eqx.combine(params, static)
    return vmap(model)(x)
y = f(params)
y.shape

(1000000, 2)

In [64]:
#%time jac = vmap(jacfwd(f))(params)

In [65]:
#%time jac_tree = vmap(jacrev(model))(x)

5. Jacobian vector products (forward-mode ad)

In [78]:
x = jnp.ones(d)
x.shape, u.shape

def f(params):
    model = eqx.combine(params, static)
    return model(x)

In [79]:
y = f(params)
y

Array([-0.2048847 , -0.00798638], dtype=float32)

In [80]:
# jacobian vector product
#y, u = eqx.filter_jvp(model, (leaves,), (u,))

# vector jacobian product
y, vjp_fun = eqx.filter_vjp(f, params)
vjp_1 = vjp_fun(y)

In [81]:
#print(y)
#print(vjp_fun)
print(vjp_1)

(MLP(
  layers=[
    Linear(
      weight=f32[32,2],
      bias=f32[32],
      in_features=2,
      out_features=32,
      use_bias=True
    ),
    None,
    Linear(
      weight=f32[16,32],
      bias=f32[16],
      in_features=32,
      out_features=16,
      use_bias=True
    ),
    None,
    Linear(
      weight=f32[2,16],
      bias=f32[2],
      in_features=16,
      out_features=2,
      use_bias=True
    ),
    None
  ]
),)


In [72]:
leaves = tree_util.tree_leaves(vjp_1)

In [26]:
#vmap(jnp.dot, in_axes=(None, 1))(jnp.ones(2), jnp.ones((2, 32, 2)))

In [27]:
new_leaves = [vmap(jnp.dot, in_axes=(None, 1))(u, leaf) for leaf in leaves]
new_tree = tree_util.tree_unflatten(params_treedef, new_leaves)

In [30]:
tree_util.tree_map(lambda x: x.shape, new_tree), tree_util.tree_map(lambda x: x.shape, params)

({'params': {'Dense_0': {'bias': (32,), 'kernel': (2, 32)},
   'Dense_1': {'bias': (32,), 'kernel': (32, 32)},
   'Dense_2': {'bias': (32,), 'kernel': (32, 32)},
   'Dense_3': {'bias': (2,), 'kernel': (32, 2)}}},
 {'params': {'Dense_0': {'bias': (32,), 'kernel': (2, 32)},
   'Dense_1': {'bias': (32,), 'kernel': (32, 32)},
   'Dense_2': {'bias': (32,), 'kernel': (32, 32)},
   'Dense_3': {'bias': (2,), 'kernel': (32, 2)}}})

5. operations between pytrees with the same structure

In [113]:
tree_util.tree_map(lambda x, y: x + y, jac, updated_jac)

{'params': {'bias': Array([[-1.486341 ,  0.       ],
         [ 0.       ,  1.5494415]], dtype=float32),
  'kernel': Array([[[-1.486341 ,  0.       ],
          [-1.486341 ,  0.       ]],
  
         [[ 0.       ,  1.5494415],
          [ 0.       ,  1.5494415]]], dtype=float32)}}

In [114]:
jac

{'params': {'bias': Array([[1., 0.],
         [0., 1.]], dtype=float32),
  'kernel': Array([[[1., 0.],
          [1., 0.]],
  
         [[0., 1.],
          [0., 1.]]], dtype=float32)}}

In [116]:
updated_jac

{'params': {'bias': Array([[-2.486341  ,  0.        ],
         [-0.        ,  0.54944146]], dtype=float32),
  'kernel': Array([[[-2.486341  ,  0.        ],
          [-2.486341  ,  0.        ]],
  
         [[-0.        ,  0.54944146],
          [-0.        ,  0.54944146]]], dtype=float32)}}