<a href="https://colab.research.google.com/github/teruyuki-yamasaki/HelloBrax/blob/main/pytree.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

[Pytrees](https://jax.readthedocs.io/en/latest/pytrees.html)

## What is a pytree?
- a tree-like structure built out of container-like Python objects.

## Internal pytree handling
- JAX flattens pytrees into lists of leaves,
- while recording their original tree structure in a treedef,
- so that operations that take and return arrays can be applied directly to the leaves.


JAX flattens pytrees into lists of leaves at the api.py boundary (and also in control flow primitives). This keeps downstream JAX internals simpler: transformations like grad(), jit(), and vmap() can handle user functions that accept and return the myriad different Python containers, while all the other parts of the system can operate on functions that only take (multiple) array arguments and always return a flat list of arrays.
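
As a small illustration of transformations accepting Python containers, the sketch below differentiates a function whose first argument is a dict. The names `loss`, `params`, and `x` are hypothetical and chosen only for this example; the gradient comes back with the same dict structure as the input.

```python
import jax
import jax.numpy as jnp

def loss(params, x):
    # params is a dict pytree; JAX flattens it at the API boundary
    pred = params["w"] * x + params["b"]
    return jnp.sum(pred ** 2)

# Hypothetical parameters and input, chosen only for illustration
params = {"w": jnp.array(2.0), "b": jnp.array(1.0)}
x = jnp.array([1.0, 2.0])

# grad differentiates with respect to the first argument (the dict)
grads = jax.grad(loss)(params, x)

# The result mirrors the structure of params: one gradient per leaf
print(sorted(grads.keys()))  # ['b', 'w']
```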

When JAX flattens a pytree it will produce a list of leaves and a treedef object that encodes the structure of the original value. The treedef can then be used to construct a matching structured value after transforming the leaves. Pytrees are tree-like, rather than DAG-like or graph-like, in that we handle them assuming referential transparency and that they can’t contain reference cycles.

In [4]:
import jax 
import jax.numpy as jnp 
from jax.tree_util import tree_flatten, tree_unflatten

# The structured value to be transformed
value_structured = [1., (2., 3.)]

# The leaves in value_flat correspond to the `*` markers in value_tree
value_flat, value_tree = tree_flatten(value_structured)
print("value_flat={}\nvalue_tree={}".format(value_flat, value_tree))

# Transform the flat value list using an element-wise numeric transformer
transformed_flat = list(map(lambda v: v * 2., value_flat))
print("transformed_flat={}".format(transformed_flat))

# Reconstruct the structured output, using the original treedef
transformed_structured = tree_unflatten(value_tree, transformed_flat)
print("transformed_structured={}".format(transformed_structured))

value_flat=[1.0, 2.0, 3.0]
value_tree=PyTreeDef([*, (*, *)])
transformed_flat=[2.0, 4.0, 6.0]
transformed_structured=[2.0, (4.0, 6.0)]
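
The flatten, transform, unflatten sequence above can also be written in one call with `jax.tree_util.tree_map`. The following sketch compares the two routes on the same value:

```python
from jax.tree_util import tree_flatten, tree_unflatten, tree_map

value_structured = [1., (2., 3.)]

# Manual route: flatten, transform each leaf, then rebuild the structure
leaves, treedef = tree_flatten(value_structured)
manual = tree_unflatten(treedef, [v * 2. for v in leaves])

# One-call route: tree_map applies the function to every leaf
mapped = tree_map(lambda v: v * 2., value_structured)

print(manual)  # [2.0, (4.0, 6.0)]
print(mapped)  # [2.0, (4.0, 6.0)]
```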


By default, pytree containers can be lists, tuples, dicts, namedtuples, None, and OrderedDicts. Other types of values, including numeric and ndarray values, are treated as leaves:

In [7]:
from collections import namedtuple
Point = namedtuple('Point', ['x', 'y'])

example_containers = [
    (1., [2., 3.]),
    (1., {'b': 2., 'a': 3.}),
    1.,
    None,
    jnp.zeros(2),
    Point(1., 2.)
]
def show_example(structured):
  flat, tree = tree_flatten(structured)
  unflattened = tree_unflatten(tree, flat)
  print("structured={}\n  flat={}\n  tree={}\n  unflattened={}".\
        format(structured, flat, tree, unflattened))
  print()

for structured in example_containers:
  show_example(structured)

structured=(1.0, [2.0, 3.0])
  flat=[1.0, 2.0, 3.0]
  tree=PyTreeDef((*, [*, *]))
  unflattened=(1.0, [2.0, 3.0])

structured=(1.0, {'b': 2.0, 'a': 3.0})
  flat=[1.0, 3.0, 2.0]
  tree=PyTreeDef((*, {'a': *, 'b': *}))
  unflattened=(1.0, {'a': 3.0, 'b': 2.0})

structured=1.0
  flat=[1.0]
  tree=PyTreeDef(*)
  unflattened=1.0

structured=None
  flat=[]
  tree=PyTreeDef(None)
  unflattened=None

structured=[0. 0.]
  flat=[DeviceArray([0., 0.], dtype=float32)]
  tree=PyTreeDef(*)
  unflattened=[0. 0.]

structured=Point(x=1.0, y=2.0)
  flat=[1.0, 2.0]
  tree=PyTreeDef(CustomNode(namedtuple[<class '__main__.Point'>], [*, *]))
  unflattened=Point(x=1.0, y=2.0)
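
Two details in the output above are easy to miss: dict leaves come out in sorted key order rather than insertion order, and `None` is a node with no leaves rather than a leaf. A minimal sketch, using a hypothetical dict `tree`:

```python
from jax.tree_util import tree_leaves, tree_structure

# Keys are deliberately out of order, and one value is None
tree = {"b": 2., "a": 3., "c": None}

# Leaves follow sorted key order ('a' before 'b'); None contributes nothing
leaves = tree_leaves(tree)
print(leaves)  # [3.0, 2.0]

# The treedef reports the same leaf count
print(tree_structure(tree).num_leaves)  # 2
```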



## Extending pytrees
By default, any part of a structured value that is not recognized as an internal pytree node (i.e. container-like) is treated as a leaf:

In [10]:
class Special(object):
  def __init__(self, x, y):
    self.x = x
    self.y = y

show_example(Special(1., 2.))

class Special(object):
  def __init__(self, x, y):
    self.x = x
    self.y = y

  def __repr__(self):
    return "Special(x={}, y={})".format(self.x, self.y)

show_example(Special(1., 2.))

structured=<__main__.Special object at 0x7fb983a04cd0>
  flat=[<__main__.Special object at 0x7fb983a04cd0>]
  tree=PyTreeDef(*)
  unflattened=<__main__.Special object at 0x7fb983a04cd0>

structured=Special(x=1.0, y=2.0)
  flat=[Special(x=1.0, y=2.0)]
  tree=PyTreeDef(*)
  unflattened=Special(x=1.0, y=2.0)



In the above example, you can see that the instance of ``Special`` is treated as a leaf.

The set of Python types that are considered internal pytree nodes is extensible, through a global registry of types, and values of registered types are traversed recursively. To register a new type, you can use register_pytree_node():

In [12]:
from jax.tree_util import register_pytree_node

class RegisteredSpecial(Special):
  def __repr__(self):
    return "RegisteredSpecial(x={}, y={})".format(self.x, self.y)

def special_flatten(v):
  """Specifies a flattening recipe.

  Params:
    v: the value of registered type to flatten.
  Returns:
    a pair of an iterable with the children to be flattened recursively,
    and some opaque auxiliary data to pass back to the unflattening recipe.
    The auxiliary data is stored in the treedef for use during unflattening.
    The auxiliary data could be used, e.g., for dictionary keys.
  """
  children = (v.x, v.y)
  aux_data = None
  return (children, aux_data)

def special_unflatten(aux_data, children):
  """Specifies an unflattening recipe.

  Params:
    aux_data: the opaque data that was specified during flattening of the
      current treedef.
    children: the unflattened children

  Returns:
    a re-constructed object of the registered type, using the specified
    children and auxiliary data.
  """
  return RegisteredSpecial(*children)

# Global registration
register_pytree_node(
    RegisteredSpecial,  # tell JAX which class you want to register as an extended pytree node
    special_flatten,    # tell JAX what are the children nodes
    special_unflatten   # tell JAX how to pack back into a RegisteredSpecial
)

show_example(RegisteredSpecial(1., 2.))

structured=RegisteredSpecial(x=1.0, y=2.0)
  flat=[1.0, 2.0]
  tree=PyTreeDef(CustomNode(<class '__main__.RegisteredSpecial'>[None], [*, *]))
  unflattened=RegisteredSpecial(x=1.0, y=2.0)



In [20]:
class RegisteredSpecial3(Special):
  def __repr__(self):
    return "RegisteredSpecial3(x={}, y={})".format(self.x, self.y)

# The same registration written compactly with lambdas
register_pytree_node(
    RegisteredSpecial3,
    lambda obj: ((obj.x, obj.y), None),              # flatten: (children, aux_data)
    lambda _, children: RegisteredSpecial3(*children)  # unflatten: rebuild from children
)
show_example(RegisteredSpecial3(1., 2.))

structured=RegisteredSpecial3(x=1.0, y=2.0)
  flat=[1.0, 2.0]
  tree=PyTreeDef(CustomNode(<class '__main__.RegisteredSpecial3'>[None], [*, *]))
  unflattened=RegisteredSpecial3(x=1.0, y=2.0)



Now ``RegisteredSpecial`` is recognized by JAX as a pytree node, not a leaf.
Its attributes ``x`` and ``y`` are each recognized as a leaf at the same level.
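
Once a type is registered, the generic tree utilities recurse into it even when it is nested inside other containers. The sketch below uses a hypothetical class `Pair` (not part of the notebook above) to show this with `tree_leaves` and `tree_map`:

```python
from jax.tree_util import register_pytree_node, tree_leaves, tree_map

class Pair:
    """Hypothetical two-field container, used only for this example."""
    def __init__(self, x, y):
        self.x = x
        self.y = y

register_pytree_node(
    Pair,
    lambda p: ((p.x, p.y), None),            # flatten: children and aux_data
    lambda aux, children: Pair(*children),   # unflatten: rebuild a Pair
)

# Registered nodes can sit inside lists and dicts like any built-in container
nested = [Pair(1., 2.), {"k": Pair(3., 4.)}]
print(tree_leaves(nested))  # [1.0, 2.0, 3.0, 4.0]

# tree_map rebuilds the same structure, including the Pair instances
doubled = tree_map(lambda v: v * 2., nested)
print(doubled[0].x, doubled[0].y)            # 2.0 4.0
print(doubled[1]["k"].x, doubled[1]["k"].y)  # 6.0 8.0
```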

Alternatively, you can define appropriate tree_flatten and tree_unflatten methods on your class and decorate it with register_pytree_node_class():

In [16]:
from jax.tree_util import register_pytree_node_class 

@register_pytree_node_class 
class RegisteredSpecial2(Special):
    def __repr__(self):
        return "RegisteredSpecial2(x={},y={})".format(self.x, self.y)
    
    def tree_flatten(self):
        children = (self.x, self.y)
        aux_data = None
        return (children, aux_data)
    
    @classmethod 
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

show_example(RegisteredSpecial2(1.,2.))

structured=RegisteredSpecial2(x=1.0,y=2.0)
  flat=[1.0, 2.0]
  tree=PyTreeDef(CustomNode(<class '__main__.RegisteredSpecial2'>[None], [*, *]))
  unflattened=RegisteredSpecial2(x=1.0,y=2.0)



JAX sometimes needs to compare treedefs for equality. Therefore, care must be taken to ensure that the auxiliary data specified in the flattening recipe supports a meaningful equality comparison.
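
To make the equality requirement concrete, here is a minimal sketch with a hypothetical class `Tagged` that stores a string in its auxiliary data. Because strings compare by value, two treedefs are equal exactly when their tags match:

```python
from jax.tree_util import register_pytree_node_class, tree_structure

@register_pytree_node_class
class Tagged:
    """Hypothetical container: one leaf plus a static string tag."""
    def __init__(self, value, tag):
        self.value = value  # leaf, traversed by JAX
        self.tag = tag      # static metadata, stored in aux_data

    def tree_flatten(self):
        children = (self.value,)
        aux_data = self.tag
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)

a = tree_structure(Tagged(1., "metres"))
b = tree_structure(Tagged(2., "metres"))
c = tree_structure(Tagged(1., "feet"))

print(a == b)  # True: same structure and equal aux_data, leaf values ignored
print(a == c)  # False: aux_data differs
```

Plain hashable values such as strings, ints, and tuples of these are a safe choice for auxiliary data. A NumPy array is generally a poor choice, since `==` on arrays returns an elementwise result rather than a single boolean.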

The whole set of functions for operating on pytrees is in jax.tree_util.

