In [8]:
from jax import numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

In [9]:
def show_example(structured):
    flat, tree = tree_flatten(structured)
    unflattened = tree_unflatten(tree, flat)
    print("structured={}\n  flat={}\n  tree={}\n  unflattened={}".format(
        structured, flat, tree, unflattened))

In [21]:
class Fourier(object):
    def __init__(self, ext_params={"bias": 1.0}):
        self.params = jnp.array([1, 2, 3])
        self.aux = ext_params
        
    def __call__(self, x):
        return self.params + x + self.aux["bias"]

    def __repr__(self):
        return f"Fourier {self.params}"
    
my_field = Fourier()

In [28]:
my_field.params

DeviceArray([1, 2, 3], dtype=int32)

In [22]:
my_field

Fourier [1 2 3]

In [23]:
show_example(my_field)

structured=Fourier [1 2 3]
  flat=[Fourier [1 2 3]]
  tree=PyTreeDef(*)
  unflattened=Fourier [1 2 3]


In [81]:
from jax.tree_util import register_pytree_node
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class RegisteredFourier(Fourier):
    def __repr__(self):
        return f"Registered Fourier {self.params}"
    
    def tree_flatten(self):
        children = self.params
        print(children.__repr__())
        aux_data = self.aux
        return (children, aux_data)
    
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        a = cls(aux_data)
        print(a.params.__repr__())
        a.params = children
        print(a.params.__repr__())
        return a

In [82]:
my_field_reg = RegisteredFourier()

In [83]:
my_field_reg

Registered Fourier [1 2 3]

In [84]:
show_example(my_field_reg)

DeviceArray([1, 2, 3], dtype=int32)
DeviceArray([1, 2, 3], dtype=int32)
(DeviceArray(1, dtype=int32), DeviceArray(2, dtype=int32), DeviceArray(3, dtype=int32))
structured=Registered Fourier [1 2 3]
  flat=[DeviceArray(1, dtype=int32), DeviceArray(2, dtype=int32), DeviceArray(3, dtype=int32)]
  tree=PyTreeDef(CustomNode(<class '__main__.RegisteredFourier'>[{'bias': 1.0}], [*, *, *]))
  unflattened=Registered Fourier (DeviceArray(1, dtype=int32), DeviceArray(2, dtype=int32), DeviceArray(3, dtype=int32))


In [None]:
from jax import jit

def fun(u):
    return u + 1