<a href="https://colab.research.google.com/github/yipkingster/ml/blob/main/jax_pure_func_experiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**All necessary imports**

In [2]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit
from jax import random
from jax import lax

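**Be aware of how JAX caching interacts with side effects.**

Use pure functions with JAX. Under `jit`, a Python side effect such as `print` runs only while the function is being traced. Later calls with the same argument dtypes and shapes reuse the cached compiled version and skip the side effect.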

In [8]:
def impure_print_side_effect(x):
  print("Executing side effect!")
  return x

print("First call: ", jit(impure_print_side_effect)(1.0))
# Same dtype and shape: the cached compiled function is reused, so print does not run.
print("Second call: ", jit(impure_print_side_effect)(2.0))
# A new shape triggers a re-trace, so the side effect runs again.
print("Third call: ", jit(impure_print_side_effect)(jnp.array([5.0])))

Executing side effect!
First call:  1.0
Second call:  2.0
Executing side effect!
Third call:  [5.]
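If the goal is to see a message on every call rather than only at trace time, one option is `jax.debug.print`, which is staged into the compiled computation instead of running as ordinary Python during tracing. This is a minimal sketch; the function name `debug_print_side_effect` is hypothetical.

```python
import jax

@jax.jit
def debug_print_side_effect(x):
    # Plain print: runs only while JAX traces the function.
    print("Tracing!")
    # jax.debug.print: part of the compiled computation, runs on every call.
    jax.debug.print("Executing with x = {x}", x=x)
    return x

first = debug_print_side_effect(1.0)   # traces and compiles: both messages appear
second = debug_print_side_effect(2.0)  # cache hit: only the jax.debug.print message
```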


**Beware of global variables captured in the cache**



In [13]:
g = 0
def impure_global(x):
  return g+x

print("First call: ", jit(impure_global)(1.0))
g = 5
# Cached trace reused: it still uses g = 0, so the new value g = 5 is ignored.
print("Second call: ", jit(impure_global)(2.0))
print("Third call: ", jit(impure_global)(3.0))
# A new shape forces a re-trace, which reads the current value g = 5.
print("Fourth call:", jit(impure_global)(jnp.array([4.0])))

First call:  1.0
Second call:  2.0
Third call:  3.0
Fourth call: [9.]
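A pure alternative is to pass the value as an explicit argument so it becomes part of the traced computation. A minimal sketch, with the hypothetical name `pure_add`:

```python
import jax

def pure_add(g, x):
    # g is an explicit argument, so its current value flows into every call
    return g + x

jitted_add = jax.jit(pure_add)
print(jitted_add(0.0, 1.0))  # 1.0
print(jitted_add(5.0, 2.0))  # 7.0, with no re-trace since dtypes and shapes match
```

If `g` really should be a compile-time constant, `jax.jit`'s `static_argnums` could mark it static instead, at the cost of a recompilation for each distinct value.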


**Assigning to a global inside a jitted function leaks a tracer object**

In [14]:
g = 5
def impure_save_global(x):
  global g
  g += x
  return g

print("First call: ", jit(impure_save_global)(1.0))
# The cached computation (x + 5) is reused, so g is not updated again.
print("Second call: ", jit(impure_save_global)(2.0))
print("g=", g)

First call:  6.0
Second call:  7.0
g= JitTracer<~float32[]>
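One way to avoid the leaked tracer is to return the updated state and thread it through explicitly, rather than writing to a global. A minimal sketch; `pure_accumulate` is a hypothetical name:

```python
import jax

def pure_accumulate(g, x):
    # Return the updated state instead of assigning to a global
    return g + x

jitted_accumulate = jax.jit(pure_accumulate)
g = 5.0
g = jitted_accumulate(g, 1.0)  # g is now a concrete array holding 6.0
g = jitted_accumulate(g, 2.0)  # 8.0: the state is passed in and returned each call
print("g=", g)
```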


**Stateful objects are fine inside a pure function, as long as they don't read from or write to external state.**

In [20]:
def internal_stateful_func(x):
  state = dict(even=0, odd=0)
  # Use the constant 10, not x: under jit, x is a tracer, and range() needs a concrete Python int.
  for i in range(10):
    state['even' if i % 2 ==0 else 'odd'] += i
  return state['even'] + state['odd']

print("First call: ", jit(internal_stateful_func)(10))
print("Second call: ", jit(internal_stateful_func)(10))

First call:  45
Second call:  45
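If the loop bound should come from the argument, one option is to mark that argument static, so `range()` sees a concrete Python int at trace time. A sketch under that assumption; `internal_stateful_static` is a hypothetical name:

```python
from functools import partial
import jax

@partial(jax.jit, static_argnums=0)
def internal_stateful_static(n):
    # n is static, so range(n) receives a concrete Python int while tracing.
    # Each distinct value of n triggers a separate trace and compilation.
    state = dict(even=0, odd=0)
    for i in range(n):
        state['even' if i % 2 == 0 else 'odd'] += i
    return state['even'] + state['odd']

print(internal_stateful_static(10))  # 45
print(internal_stateful_static(4))   # 6
```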


**Be aware of control flow: avoid Python iterators inside jitted functions and control-flow primitives.**

In [21]:
a = jnp.arange(10)
print(a)
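# Correct: index the array with the traced loop counter i.
print(lax.fori_loop(0, 10, lambda i, x : x+a[i], 0))
it = iter(range(10))
# Wrong: next(it) runs only once, at trace time, so the body adds 0 on every iteration.
print(lax.fori_loop(0, 10, lambda i, x: x+next(it), 0))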

[0 1 2 3 4 5 6 7 8 9]
45
0
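If the data really does arrive as an iterator, one workaround is to consume it eagerly in Python and convert it to an array before entering the loop. A minimal sketch:

```python
import jax.numpy as jnp
from jax import lax

it = iter(range(10))
# Consume the iterator in Python, outside any traced code
values = jnp.array(list(it))
# Inside the loop body, only the traced counter i and the array are used
total = lax.fori_loop(0, 10, lambda i, acc: acc + values[i], 0)
print(total)  # 45
```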


**Iterators are not supported in `jax.lax.scan` either.**


In [25]:
def func_test(arr, extra):
  ones = jnp.ones(arr.shape)
  def body(carry, aelems):
    ae1, ae2 = aelems
    return (carry+ae1*ae2+extra, carry)
  return lax.scan(body, 0., (arr, ones))
jax.make_jaxpr(func_test)(jnp.arange(10), 10)
# Raises an error, because an iterator is not a valid JAX type:
#jax.make_jaxpr(func_test)(iter(range(10)), 10)

{ lambda ; a:i32[10] b:i32[]. let
    c:f32[10] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(10,)
      sharding=None
    ] 1.0:f32[]
    d:f32[] e:f32[10] = scan[
      _split_transpose=False
      jaxpr={ lambda ; f:i32[] g:f32[] h:i32[] i:f32[]. let
          j:f32[] = convert_element_type[new_dtype=float32 weak_type=False] h
          k:f32[] = mul j i
          l:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g
          m:f32[] = add l k
          n:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
          o:f32[] = add m n
        in (o, g) }
      length=10
      linear=(False, False, False, False)
      num_carry=1
      num_consts=1
      reverse=False
      unroll=1
    ] b 0.0:f32[] a c
  in (d, e) }

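To actually run the scan, the input can be given as an array; an iterator would first need to be materialized, as in this sketch (the iterator-to-array conversion is an assumption about how the data arrives, not part of the original cell):

```python
import jax.numpy as jnp
from jax import lax

def func_test(arr, extra):
    ones = jnp.ones(arr.shape)
    def body(carry, aelems):
        ae1, ae2 = aelems
        # New carry accumulates ae1 * ae2 + extra; the old carry is emitted as y
        return (carry + ae1 * ae2 + extra, carry)
    return lax.scan(body, 0., (arr, ones))

# Convert the iterator into an array before calling the scanned function
arr = jnp.array(list(iter(range(10))))
final_carry, ys = func_test(arr, 10)
print(final_carry)  # 145.0
print(ys)
```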
**No iterators in `lax.cond` either**

In [29]:
operand = jnp.array([5.0])
print(lax.cond(True, lambda x: x+1, lambda x: x-1, operand))
print(lax.cond(False, lambda x: x+1, lambda x: x-1, operand))
iter_operand = iter(range(10))
# Raises an error, because an iterator is not a valid JAX type:
# print(lax.cond(True, lambda x: x+1, lambda x: x-1, iter_operand))


[6.]
[4.]
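`lax.cond` is most useful when the predicate is itself traced, for example inside a jitted function where a Python `if` on the predicate would fail. A minimal sketch with the hypothetical name `branch`, which also materializes an iterator into an array before using it as the operand:

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def branch(pred, x):
    # pred is traced under jit, so `if pred:` would fail;
    # lax.cond selects the branch at run time instead.
    return lax.cond(pred, lambda v: v + 1, lambda v: v - 1, x)

print(branch(True, jnp.array([5.0])))   # [6.]
print(branch(False, jnp.array([5.0])))  # [4.]

# Materialize an iterator into an array before passing it as the operand
iter_operand = iter(range(3))
operand = jnp.array(list(iter_operand), dtype=jnp.float32)
print(branch(True, operand))  # [1. 2. 3.]
```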
