In [1]:
import jax
import jax.numpy as jnp
import numpy as np

# pytree 中的常用函数

In [2]:
# map 函数

# 我不太会写 lambda 函数，先快速上手一个
double = lambda x: 2*x
print(f"这是lambda语句的一个快速例子：{double(5)}")

list_of_lists = [
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4]
]

print(jax.tree.map(lambda x: double(x), list_of_lists))
#或者不用我的double函数，效果是一样的：
print(jax.tree.map(lambda x: 2*x, list_of_lists))

这是lambda语句的一个快速例子：10
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]
[[2, 4, 6], [2, 4], [2, 4, 6, 8]]


In [3]:
# 我对 zip 的用法不熟，首先快速上手一个
list1 = ['a', 'b', 'c', 'd']
list2 = [1, 2, 3, 4]

# zip 函数的作用是将两个列表中的元素一一对应地组合成一个元组，返回一个迭代器
for x in zip(list1, list2):
    print(x)

# zip 函数的返回值是一个迭代器，所以我们可以将其转换为列表
list_of_tuples = list(zip(list1, list2))
print(list_of_tuples)

# zip 函数的返回值是一个迭代器，所以我们也可以将其转换为字典
dict_of_tuples = dict(zip(list1, list2))
print(dict_of_tuples)

# zip 函数的返回值是一个迭代器，所以我们也可以将其转换为集合
set_of_tuples = set(zip(list1, list2))
print(set_of_tuples)


# 以及创建一个 dict 的两种方法
## 使用 dict() 创建字典
my_dict = dict(a=1, b=2, c=3)
## 使用大括号 {} 创建字典
my_dict_alt = {'a': 1, 'b': 2, 'c': 3}

print(my_dict == my_dict_alt)  # 输出: True

('a', 1)
('b', 2)
('c', 3)
('d', 4)
[('a', 1), ('b', 2), ('c', 3), ('d', 4)]
{'a': 1, 'b': 2, 'c': 3, 'd': 4}
{('a', 1), ('d', 4), ('c', 3), ('b', 2)}
True


In [4]:
# 初始化神经网络的参数并使用 jax.tree.map 打印参数的类型和形状

def init_mlp_params(layer_widths):
  params = []
  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
    params.append(
        dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
             biases=np.ones(shape=(n_out,))
            )
    )
  return params

params = init_mlp_params([1, 128, 128, 1])

jax.tree.map(lambda x: (type(x),x.shape), params)

[{'biases': (numpy.ndarray, (128,)), 'weights': (numpy.ndarray, (1, 128))},
 {'biases': (numpy.ndarray, (128,)), 'weights': (numpy.ndarray, (128, 128))},
 {'biases': (numpy.ndarray, (1,)), 'weights': (numpy.ndarray, (128, 1))}]

In [5]:
# 现在我们可以写出 MLP 的前向传播函数、损失函数、以及参数更新方法

# Define the forward pass.
def forward(params, x):
  *hidden, last = params
  for layer in hidden:
    x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
  return x @ last['weights'] + last['biases']

# Define the loss function.
def loss_fn(params, x, y):
  return jnp.mean((forward(params, x) - y) ** 2)

# Set the learning rate.
LEARNING_RATE = 0.0001

# Using the stochastic gradient descent, define the parameter update function.
# Apply `@jax.jit` for JIT compilation (speed).
@jax.jit
def update(params, x, y):
  # Calculate the gradients with `jax.grad`.
  grads = jax.grad(loss_fn)(params, x, y)
  # Note that `grads` is a pytree with the same structure as `params`.
  # `jax.grad` is one of many JAX functions that has
  # built-in support for pytrees.
  # This is useful - you can apply the SGD update using JAX pytree utilities.
  return jax.tree.map(
      lambda p, g: p - LEARNING_RATE * g, params, grads
  )

# 自定义 pytree 结点

In [6]:
class Special(object):
  def __init__(self, x, y):
    self.x = x
    self.y = y

# jax.tree.leaves 是 JAX 提供的一个函数，用于提取 pytree 中的所有叶子节点。叶子节点是指 pytree 中不可再分解的最小单元。
jax.tree.leaves(
    [Special(0, 1),Special(2, 4)]
)

[<__main__.Special at 0x15a5cdf6660>, <__main__.Special at 0x15a5ce3c190>]

In [7]:
# 因为没有将 special 注册为 pytree，所以我们可以看到它的类型是 <class '__main__.Special'>，所以下面的句子会报错
# TypeError: unsupported operand type(s) for +: 'Special' and 'int'

'''
jax.tree.map(lambda x: x + 1,
  [Special(0, 1),Special(2, 4)]
             )
'''

'\njax.tree.map(lambda x: x + 1,\n  [Special(0, 1),Special(2, 4)]\n             )\n'

In [8]:
# 注册自定义容器示例

# 继承：RegisteredSpecial 继承了 Special，因此它拥有 Special 的所有属性和方法。
# __repr__ 方法：重写了 __repr__ 方法，用于返回对象的可读字符串表示，便于调试和打印。
class RegisteredSpecial(Special):
  def __repr__(self):
    return f"RegisteredSpecial(x={self.x}, y={self.y})"

def special_flatten(v):
  """Specifies a flattening recipe.

  Params:
    v: The value of the registered type to flatten.
  Returns:
    A pair of an iterable with the children to be flattened recursively,
    and some opaque auxiliary data to pass back to the unflattening recipe.
    The auxiliary data is stored in the treedef for use during unflattening.
    The auxiliary data could be used, for example, for dictionary keys.
  """
  children = (v.x, v.y)
  aux_data = None
  return (children, aux_data)

def special_unflatten(aux_data, children):
  """Specifies an unflattening recipe.

  Params:
    aux_data: The opaque data that was specified during flattening of the
      current tree definition.
    children: The unflattened children

  Returns:
    A reconstructed object of the registered type, using the specified
    children and auxiliary data.
  """
  return RegisteredSpecial(*children)

from jax.tree_util import register_pytree_node
# Global registration
register_pytree_node(
    RegisteredSpecial,
    special_flatten,    # Instruct JAX what are the children nodes.
    special_unflatten   # Instruct JAX how to pack back into a `RegisteredSpecial`.
)
'''
register_pytree_node 需要三个输入参数：
需要注册的类：这里是 RegisteredSpecial，表示你要将哪个类注册为 pytree 节点。
flatten 方法：定义如何将对象分解为子节点（children）和辅助数据（auxiliary data）。
unflatten 方法：定义如何从子节点和辅助数据重新构造对象。
这两个方法（flatten 和 unflatten）共同指定了对象的序列化和反序列化方式，以便 JAX 能够正确处理该类的实例。

NB: flatten 输出的顺序和 unflatten 参数的顺序不同
'''

jax.tree.map(lambda x: x + 1,
  [RegisteredSpecial(0, 1),RegisteredSpecial(2, 4)]
             )

[RegisteredSpecial(x=1, y=2), RegisteredSpecial(x=3, y=5)]

In [9]:
# 我对 NamedTuple 和 Tuple 的用法不熟，首先快速上手一个
# NamedTuple 可以用来创建不可变的对象
from typing import NamedTuple

class Point(NamedTuple):
    x: int
    y: int

# 创建一个 Point 实例
p = Point(3, 4)

# 访问属性
print(f"x: {p.x}, y: {p.y}")

# NamedTuple 实例是不可变的
# p.x = 5  # 这会报错，因为 NamedTuple 是不可变的

a_tuple = (1, 2, 3)
print(f"a_tuple的类型是{type(a_tuple)}，其第一个元素是{a_tuple[0]}.")
# test[0] = 4 # 这会报错，因为元组是不可变的

x: 3, y: 4
a_tuple的类型是<class 'tuple'>，其第一个元素是1.


In [10]:
# 有些类型的对象已经与 pytree 兼容了，所以我们可以直接使用不需要注册它们
# 比如 NamedTuple

from typing import NamedTuple, Any

class MyOtherContainer(NamedTuple):
  name: str
  a: Any
  b: Any
  c: Any

# NamedTuple subclasses are handled as pytree nodes, so
# this will work out-of-the-box.
jax.tree.leaves([
    MyOtherContainer('Alice', 1, 2, 3),
    MyOtherContainer('Bob', 4, 3.4, 'test')
])

['Alice', 1, 2, 3, 'Bob', 4, 3.4, 'test']

In [11]:
# 我对 dataclass 的用法不熟，首先快速上手一个
# dataclass 是 Python 3.7 引入的一个装饰器，用于简化类的定义，特别是用于存储数据的类。

# dataclass 可以自动生成一些特殊方法，比如 __init__、__repr__ 和 __eq__ 等等。
# 这使得我们可以更方便地创建和使用数据类，而不需要手动编写这些方法。
# 我试过有没有类似的基类而不用 dataclass 装饰器的类，结果发现没有
from dataclasses import dataclass

@dataclass
class Point:
    x: int
    y: int


# 创建实例
p1 = Point(3, 4)
p2 = Point(3, 4)

# 自动生成的 __repr__ 方法
print(p1)  # 输出: Point(x=3, y=4)

# 自动生成的 __eq__ 方法
print(p1 == p2)  # 输出: True

Point(x=3, y=4)
True


# 吃饭之后，从这儿开始

In [12]:

from dataclasses import dataclass
import functools

@functools.partial(jax.tree_util.register_dataclass,
                   data_fields=['a', 'b', 'c'],
                   meta_fields=['name'])
@dataclass
class MyDataclassContainer(object):
  name: str
  a: Any
  b: Any
  c: Any

# MyDataclassContainer is now a pytree node.
jax.tree.leaves([
  MyDataclassContainer('apple', 5.3, 1.2, jnp.zeros([4])),
  MyDataclassContainer('banana', np.array([3, 4]), -1., 0.)
])

[5.3, 1.2, Array([0., 0., 0., 0.], dtype=float32), array([3, 4]), -1.0, 0.0]

In [13]:
@jax.jit
def f(x: MyDataclassContainer | MyOtherContainer):
  return x.a + x.b

# Works fine! `mdc.name` is static.
mdc = MyDataclassContainer('mdc', 1, 2, 3)
y = f(mdc)

In [14]:
moc = MyOtherContainer('moc', 1, 2, 3)
y = f(moc)

TypeError: Error interpreting argument to <function f at 0x0000015A5CE425C0> as an abstract array. The problematic value is of type <class 'str'> and was passed to the function at path x.name.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.