In [1]:
# 导入必要的包
import jax
import jax.numpy as jnp

# JAX 转换如何工作？

In [2]:
global_list = []

def log2(x):
  """Return the base-2 logarithm of `x`, computed as ln(x) / ln(2).

  Appending `x` to `global_list` is an ordinary Python side effect. It runs while
  the function is being traced, but it is not recorded in the resulting jaxpr,
  which contains only the two `log` ops and the `div`. Under tracing, the value
  appended is a tracer object, not a number.
  """
  global_list.append(x)  # Python side effect: omitted from the jaxpr
  return jnp.log(x) / jnp.log(2.0)

print(jax.make_jaxpr(log2)(3.0))

def log2_with_print(x):
  """Return the base-2 logarithm of `x`, printing the argument first.

  When traced (jax.make_jaxpr / jax.jit), the print runs once during tracing and
  shows a tracer object rather than a concrete value. It does not appear in the
  jaxpr.
  """
  print("printed x:", x)
  return jnp.log(x) / jnp.log(2.0)

print(jax.make_jaxpr(log2_with_print)(3.))

{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0:f32[]
    d:f32[] = div b c
  in (d,) }
printed x: JitTracer<~float32[]>
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0:f32[]
    d:f32[] = div b c
  in (d,) }


In [3]:
# First call: JAX traces the Python body of log2_with_print, so the print runs and
# shows a tracer (`JitTracer<~int32[]>`) carrying only shape/dtype, not the value 4.
# The traced result is then compiled and cached for reuse.
jax.jit(log2_with_print)(4)

printed x: JitTracer<~int32[]>


Array(2., dtype=float32, weak_type=True)

In [4]:
# Second call with another int scalar: no "printed x" line appears, even though
# jax.jit(...) is called again. The earlier trace of log2_with_print was reused,
# so only the compiled computation runs and the Python print is skipped.
jax.jit(log2_with_print)(5)

Array(2.321928, dtype=float32, weak_type=True)

In [5]:
# How jax.jit interacts with standard Python control flow needs attention.
# A key thing to understand is that a jaxpr captures the function as executed on the
# parameters given to it. For example, if we have a Python conditional, the jaxpr
# will only know about the branch we take:

def log2_if_rank_2(x):
  """Return element-wise log2(x) if `x` is 2-D, otherwise return `x` unchanged.

  `x.ndim` is part of the traced input's shape, so this Python `if` is decided
  during tracing. The resulting jaxpr contains only the branch that was taken.
  """
  if x.ndim == 2:
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2
  else:
    return x

# Use the `jnp` alias consistently with the rest of the notebook.
print(jax.make_jaxpr(log2_if_rank_2)(jnp.array([1, 2, 3])))

{ lambda ; a:i32[3]. let  in (a,) }


# JIT compiling a function

In [6]:
# JIT-compile a function to get better performance.
def selu(x, alpha=1.67, lambda_=1.05):
  """Scaled exponential linear unit (SELU), applied element-wise.

  Returns ``lambda_ * x`` where ``x > 0`` and ``lambda_ * alpha * (exp(x) - 1)``
  elsewhere.

  NOTE(review): the defaults look like rounded versions of the published SELU
  constants (alpha ~ 1.6733, scale ~ 1.0507) -- confirm if exact values matter.

  Args:
    x: array-like input.
    alpha: multiplier of the negative (exponential) branch.
    lambda_: overall output scale.
  """
  positive = x > 0
  # Only exponentiate values the negative branch can actually select. exp() of the
  # raw input overflows to inf for large positive float32 entries. jnp.where hides
  # that in the forward pass, but the backward pass computes 0 * inf = NaN, so
  # jax.grad(selu) used to be NaN for large x. Replacing positive entries by 0
  # first keeps the unused branch finite; the gradient at x == 0 is unchanged.
  x_nonpos = jnp.where(positive, 0, x)
  # expm1(t) == exp(t) - 1, computed without cancellation error for t near 0.
  return lambda_ * jnp.where(positive, x, alpha * jnp.expm1(x_nonpos))

# Benchmark input: the integers 0 .. 999_999.
x = jnp.arange(1000000)
# Eager (un-jitted) timing. block_until_ready() makes each timed call wait for its
# result, so the measurement covers the computation rather than only its dispatch.
%timeit selu(x).block_until_ready()


selu_jit = jax.jit(selu)

# Pre-compile the function before timing...
# (the first call traces and compiles selu; doing it here keeps that one-off cost
# out of the %timeit measurement below)
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()

2.65 ms ± 147 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
778 μs ± 49.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [7]:
# Sometimes jitting a whole function is impractical (here the Python `while`
# condition depends on the value of `n`); in that case we can jit just a part of it.

def loop_body(prev_i):
  """Return `prev_i + 1`. Wrapped with jax.jit right after this definition."""
  return prev_i + 1

loop_body = jax.jit(loop_body)

def g_inner_jitted(x, n):
  """Return `x` plus the number of loop_body steps (from 0) needed to reach `n`.

  The loop itself is plain Python; only the body call is jitted.
  """
  counter = 0
  while counter < n:
    counter = loop_body(counter)
  return x + counter

g_inner_jitted(10, 20)

Array(30, dtype=int32, weak_type=True)

In [8]:
# Condition on value of x.

def f(x):
  """Return `x` when it is positive, otherwise `2 * x`."""
  return x if x > 0 else 2 * x
# Plain jax.jit(f) fails: while tracing, x is an abstract tracer, so the Python
# condition `x > 0` cannot be turned into a concrete True/False.

# Marking x as static avoids this: jit specializes on the concrete value of x
# (recompiling for each new value), so x must be hashable -- e.g. a Python number.
f_jit_correct = jax.jit(f, static_argnums=(0,))
print(f_jit_correct(10))

10


In [None]:
# Another static-argument example: a Python loop bounded by n.
def g(x, n):
  """Return `x + i`, where `i` counts up from 0 until it is no longer below `n`."""
  i = 0
  while i < n:
    i += 1
  return x + i

# With n traced (plain jax.jit), `i < n` is an abstract value that the Python
# `while` cannot convert to a bool, so JAX raises. Show that error explicitly
# instead of leaving the failing call commented out.
try:
  jax.jit(g)(10, 20)
except jax.errors.ConcretizationTypeError as err:
  print("jax.jit(g)(10, 20) fails:", type(err).__name__)

# Declaring n static via static_argnames fixes it: the loop runs in Python at
# trace time with the concrete n, while x remains a traced argument.
g_jit_correct = jax.jit(g, static_argnames=['n'])
print(g_jit_correct(114514, 20))

In [9]:
# Careful: jax.jit reuses compiled code per wrapped function object. A partial or a
# lambda created inside a loop is a brand-new function object every time, so jit
# cannot find a cached compilation and retraces/recompiles on every call.
from functools import partial

def unjitted_loop_body(prev_i):
  """Return `prev_i + 1` (plain Python; the callers below wrap it in jax.jit)."""
  return prev_i + 1

def g_inner_jitted_partial(x, n):
  """Anti-pattern: jits a freshly created partial object on every iteration."""
  step = 0
  while step < n:
    # Don't do this! Each iteration builds a new partial object with a different
    # hash, so no cached compilation is found.
    step = jax.jit(partial(unjitted_loop_body))(step)
  return x + step

def g_inner_jitted_lambda(x, n):
  """Anti-pattern: jits a freshly created lambda on every iteration."""
  step = 0
  while step < n:
    # Don't do this either: each evaluation of the lambda expression yields a new
    # function object with a different hash.
    step = jax.jit(lambda value: unjitted_loop_body(value))(step)
  return x + step

def g_inner_jitted_normal(x, n):
  """Recommended: jax.jit always receives the same function object."""
  step = 0
  while step < n:
    # This is OK: JAX finds the cached, compiled version of unjitted_loop_body.
    step = jax.jit(unjitted_loop_body)(step)
  return x + step

# Time the three variants on the same inputs. The partial/lambda versions cannot
# reuse a cached compilation across loop iterations; the last one can, which is
# what the large gap between the measured times reflects.
print("jit called in a loop with partials:")
%timeit g_inner_jitted_partial(10, 20).block_until_ready()

print("jit called in a loop with lambdas:")
%timeit g_inner_jitted_lambda(10, 20).block_until_ready()

print("jit called in a loop with caching:")
%timeit g_inner_jitted_normal(10, 20).block_until_ready()

jit called in a loop with partials:
249 ms ± 6.03 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with lambdas:
249 ms ± 12.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with caching:
1.3 ms ± 78.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
