In [1]:
from typing import NamedTuple

import jax

# 为jax设置64位浮点数
from jax import config
config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np

In [2]:
# JAX arrays: jnp functions return jax.Array objects, which can report the
# device(s) they live on via .devices().
x = jnp.arange(5)
print(
    f"Is x of type 'jax.Array'?: {isinstance(x, jax.Array)}",
    f"The type of x is {type(x)}. It is on device {x.devices()}.",
    x,
    sep="\n",
)

Is x of type 'jax.Array'?: True
The type of x is <class 'jaxlib._jax.ArrayImpl'>. It is on device {CpuDevice(id=0)}.
[0 1 2 3 4]


In [3]:
# Tracers: inside a jax.jit-compiled function the Python body runs while JAX
# traces it, so a plain print() shows an abstract tracer (dtype and shape)
# instead of concrete values. A call with a new input shape triggers a fresh
# trace, so the print fires once for the length-5 input and once for the
# length-10 input. To print runtime values, use jax.debug.print("x: {}", x).

@jax.jit
def f(x):
  """Return x + 1; the print inside runs at trace time, not on every call."""
  print(x)  # shows a tracer such as JitTracer<int64[5]>, not the array contents
  return x + 1

x = jnp.arange(5)
result_x = f(x)

y = jnp.arange(10)
result_y = f(y)

JitTracer<int64[5]>
JitTracer<int64[10]>


In [4]:
# An example jaxpr: make_jaxpr traces selu and prints the primitive operations
# it is built from (gt, exp, mul, sub, select_n, ...).

def selu(x, alpha=1.67, lambda_=1.05):
  """Scaled exponential linear unit, used here as a small make_jaxpr demo.

  Returns ``lambda_ * x`` where ``x > 0`` and
  ``lambda_ * (alpha * exp(x) - alpha)`` elsewhere. The defaults 1.67 / 1.05
  are rounded versions of the standard SELU constants (~1.6733 / ~1.0507).

  Args:
    x: array-like input.
    alpha: scale of the negative (exponential) branch.
    lambda_: overall output scale.

  Returns:
    An array with the same shape as ``x``.
  """
  # Operations are kept in the same order as a single-expression version
  # (gt, then exp/mul/sub, then where, then mul) so the traced jaxpr matches.
  is_positive = x > 0
  # jnp.where evaluates both branches, so exp(x) is computed for every element.
  negative_branch = alpha * jnp.exp(x) - alpha
  return lambda_ * jnp.where(is_positive, x, negative_branch)

x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x)

{ [34;1mlambda [39;22m; a[35m:f64[5][39m. [34;1mlet
    [39;22mb[35m:bool[5][39m = gt a 0.0:f64[]
    c[35m:f64[5][39m = exp a
    d[35m:f64[5][39m = mul 1.67:f64[] c
    e[35m:f64[5][39m = sub d 1.67:f64[]
    f[35m:f64[5][39m = jit[
      name=_where
      jaxpr={ [34;1mlambda [39;22m; b[35m:bool[5][39m a[35m:f64[5][39m e[35m:f64[5][39m. [34;1mlet
          [39;22mf[35m:f64[5][39m = select_n b e a
        [34;1min [39;22m(f,) }
    ] b a e
    g[35m:f64[5][39m = mul 1.05:f64[] f
  [34;1min [39;22m(g,) }

In [5]:
# A simple pytree example.

# A NamedTuple of parameters: a tuple subclass with named, annotated fields.
class Params(NamedTuple):
  a: int
  b: float
  c: str = "default"

# Field annotations are not enforced at runtime: the int 67 is accepted for
# the `str` field `c` and stored as-is (it is not converted to "67").
params = Params(1, 5.0, 67)
print(params)

# A NamedTuple is a pytree node; each field becomes one leaf.
print(jax.tree.structure(params))
print(jax.tree.leaves(params))

Params(a=1, b=5.0, c=67)
PyTreeDef(CustomNode(namedtuple[Params], [*, *, *]))
[1, 5.0, 67]


In [6]:
# A more complex pytree example.

# Nested NamedTuple structures. Note: this redefines `Params` from the
# previous cell; from here on the name refers to this three-field version.
class SubParams(NamedTuple):
    d: jnp.ndarray
    e: float

class Params(NamedTuple):
    a: int
    b: SubParams
    c: list

# A nested instance: a NamedTuple inside a NamedTuple, plus a list that mixes
# an array leaf with a plain Python int leaf.
params = Params(
    42,
    SubParams(jnp.array([1.0, 2.0, 3.0]), 3.14),
    [jnp.array([4.0, 5.0]), 6],
)

# Show the tree structure (PyTreeDef) and the flattened leaves, in order.
print(jax.tree.structure(params))
print(jax.tree.leaves(params))

PyTreeDef(CustomNode(namedtuple[Params], [*, CustomNode(namedtuple[SubParams], [*, *]), [*, *]]))
[42, Array([1., 2., 3.], dtype=float64), 3.14, Array([4., 5.], dtype=float64), 6]
