# JAX scratchpad

In [16]:
import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp

# Fail fast unless exactly one JAX device is visible to this process.
devices = jax.devices()
assert len(devices) == 1

In [14]:
# Machine-specific check: this pins the notebook to one exact GPU model, so
# the assertion fails on any other hardware (including CPU-only runs).
d0 = devices[0]
assert d0.device_kind == 'NVIDIA GeForce RTX 2070 SUPER'

In [None]:
# List every public attribute of the first JAX device next to its current value.
d0 = jax.devices()[0]
for x in filter(lambda name: not name.startswith('_'), dir(d0)):
    print(f'{x:40} : {getattr(d0, x)}')

# JAX Gotchas & FAQ
# https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
# https://jax.readthedocs.io/en/latest/faq.html


# JAX Tutorials
# https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html
# https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html
# https://jax.readthedocs.io/en/latest/jax-101/03-gradients.html
# https://jax.readthedocs.io/en/latest/jax-101/04-broadcasting.html
# https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html
# https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html

# FLAX Tutorials
# https://flax.readthedocs.io/en/latest/
# https://flax.readthedocs.io/en/latest/getting_started.html
# https://flax.readthedocs.io/en/latest/notebooks/Flax_basics.html

# FLAX NNX vs FLAX Linen

In [14]:
def pure_uses_internal_state(x):
  """Add `x` ten times, routed through a function-local tally dict.

  The dict never leaves the function, so the function is still pure from
  the caller's point of view even though it mutates state internally.
  """
  print(f'Running pure_uses_internal_state, x={x}')
  totals = {'even': 0, 'odd': 0}
  for step in range(10):
    bucket = 'odd' if step % 2 else 'even'
    totals[bucket] = totals[bucket] + x
  return totals['even'] + totals['odd']

# Each distinct input dtype is traced separately: the int call and the first
# float call both print the "Running ..." line, while the second float call
# reuses the already-traced float32 version and prints only its result.
print(jax.jit(pure_uses_internal_state)(5))
print(jax.jit(pure_uses_internal_state)(5.))
print(jax.jit(pure_uses_internal_state)(10.))

Running pure_uses_internal_state, x=Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
50
Running pure_uses_internal_state, x=Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
50.0
100.0


In [13]:
# def pure_fn_loop(n: int):
#   print(f'Running pure_fn_loop, n={n}')
#   total = 0
#   for i in range(n):
#     total += i
#   return total

# Not allowed under jit: range() cannot loop over a traced value, i.e. range(Traced[...])
# print(jax.jit(pure_fn_loop)(10))

In [20]:
# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45

# Gotcha: the loop body is traced once, so next(iterator) runs a single time
# (yielding 2) and that constant is then added on each of the 10 iterations.
iterator = iter(range(2, 10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # surprising result 20, not the sum of the iterator

# Gotcha: a Python list cannot be indexed with the traced loop counter i.
r_list = list(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+r_list[i], 0)) # raises TracerIntegerConversionError


45
20


TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[]
The error occurred while tracing the function scanned_fun at /home/vscode/.local/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py:1945 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError

In [21]:
# Gotcha: jnp.sum rejects a plain Python list and raises TypeError;
# convert first, e.g. jnp.sum(jnp.array([1, 2, 3])).
jnp.sum([1, 2, 3])

TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.

In [25]:
def f(x):
  """Piecewise function: 3 * x**2 when x is below 3, otherwise -4 * x."""
  return 3. * x ** 2 if x < 3 else -4 * x

# Without jit, grad works on concrete values, so the Python `if` can pick a
# branch on every call.
print(grad(f)(2.))  # ok!
print(grad(f)(4.))  # ok!

# Under jit the `if x < 3` test runs on a traced value, so tracing fails
# before the input value is ever inspected -- for either input.
# print(grad(jax.jit(f))(2.))  # bad
# print(grad(jax.jit(f))(4.))  # bad as well: same TracerBoolConversionError

12.0
-4.0


In [42]:
@jit
def f(x):
  """Evaluate the cubic 2*x**3 + 5*x + 999.

  The print only fires while JAX traces the function, so it shows when a new
  input type triggers a retrace.
  """
  print(f'Running f, x={x}')
  cubic_term = 2 * x ** 3
  linear_term = 5 * x
  return cubic_term + linear_term + 999

# f is traced once for the int32 input and once for float32 under grad; the
# later grad calls reuse the float32 trace, so "Running f" prints only twice.
# d/dx (2x^3 + 5x + 999) = 6x^2 + 5, hence 11, 59 and 101.
print(f(3))
print(grad(f)(1.0))
print(grad(f)(3.0))
print(grad(f)(4.0))


Running f, x=Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
1068
Running f, x=Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=3/0)>
11.0
59.0
101.0


In [45]:
@jit
def f(x):
  """Piecewise function whose branch depends on the value of x.

  Kept under @jit to show that Python control flow on a traced value cannot
  be compiled.
  """
  if not x < 3:
    return -4 * x
  return 3. * x ** 2

# This will fail! jit traces f with an abstract x, so `if x < 3` cannot be
# converted to a Python bool (TracerBoolConversionError).
f(2)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_11976/3402096563.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

In [55]:
def f(x, n):
  """Sum the first `n` entries of `x`, starting from 0.0.

  `n` drives a plain Python loop, so it must be a concrete integer.
  """
  total = 0.0
  for idx in range(n):
    total += x[idx]
  return total

# n (argument index 1) is marked static so the Python loop over range(n) can
# run at trace time.
f = jit(f, static_argnums=(1,))

# NOTE(review): x has only 3 elements, yet n goes up to 10. No error is
# raised: the recorded sums (13.0 for n=4, i.e. 2+3+4+4) suggest out-of-range
# reads return the last element, while the printed gradient stays [1. 1. 1.].
for n in [2,4,6,8,10]:
  print(f(jnp.array([2., 3., 4.]), n))
  print(grad(f)(jnp.array([2., 3., 4.]), n))




5.0
[1. 1. 0.]
13.0
[1. 1. 1.]
21.0
[1. 1. 1.]
29.0
[1. 1. 1.]
37.0
[1. 1. 1.]


In [63]:
# .at[...].set(...) returns an updated copy; z itself is never modified (the
# final print still shows all zeros). Because z is int32, 4.4 is stored as 4.
z =jnp.zeros((2,3), jnp.int32)
print(z)
print(z.at[1,2].set(4.4))
print(z.at[1,2].set(5))
print(z)


[[0 0 0]
 [0 0 0]]
[[0 0 0]
 [0 0 4]]
[[0 0 0]
 [0 0 5]]
[[0 0 0]
 [0 0 0]]




In [67]:
from jax import random

def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit.

  Returns lmbda * x where x > 0 and lmbda * (alpha * exp(x) - alpha) elsewhere.
  """
  is_positive = x > 0
  negative_branch = alpha * jnp.exp(x) - alpha
  return lmbda * jnp.where(is_positive, x, negative_branch)

# Benchmark input: one million normal samples drawn from a fixed PRNG key.
key = random.key(1701)
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready()

# Warm up once so the one-time compilation cost stays out of the timing below.
selu_jit = jit(selu)
_ = selu_jit(x)  # compiles on first call
%timeit selu_jit(x).block_until_ready()

389 μs ± 16.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
72.1 μs ± 843 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [68]:
from jax import grad

def sum_logistic(x):
  """Sum of the logistic sigmoid 1 / (1 + exp(-x)) over all entries of x."""
  sigmoid = 1.0 / (1.0 + jnp.exp(-x))
  return jnp.sum(sigmoid)

# The gradient of sum(sigmoid(x)) is sigmoid(x) * (1 - sigmoid(x)) per element,
# e.g. 0.25 at x = 0, which matches the first printed entry.
x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661194 0.10499357]


In [69]:
# Split the earlier `key` into two subkeys: one for the (150, 100) matrix,
# one for a batch of ten 100-element vectors.
key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))

def apply_matrix(x):
  """Multiply the module-level matrix `mat` by `x` using jnp.dot."""
  product = jnp.dot(mat, x)
  return product

In [73]:
# NOTE(review): this cell is an exact repeat of the previous one (same key
# split, same shapes, same apply_matrix), so it rebinds identical values.
key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))

def apply_matrix(x):
  # Product of the global `mat` with `x` via jnp.dot.
  return jnp.dot(mat, x)

In [75]:
# One row of the batch has 100 entries, so the (150, 100) matrix maps it to (150,).
apply_matrix(batched_x[0]).shape

(150,)

In [80]:
# Show where the array lives: the set of devices holding it, then its sharding.
print(mat.devices())
print(mat.sharding)


{CudaDevice(id=0)}
SingleDeviceSharding(device=CudaDevice(id=0), memory_kind=device)


In [81]:
# Show the jaxpr recorded when selu is traced on a 5-element float array.
# Note that this rebinds the global `x` used by the benchmark cell.
x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x)

{ lambda ; a:f32[5]. let
    b:bool[5] = gt a 0.0
    c:f32[5] = exp a
    d:f32[5] = mul 1.6699999570846558 c
    e:f32[5] = sub d 1.6699999570846558
    f:f32[5] = pjit[
      name=_where
      jaxpr={ lambda ; g:bool[5] h:f32[5] i:f32[5]. let
          j:f32[5] = select_n g i h
        in (j,) }
    ] b a e
    k:f32[5] = mul 1.0499999523162842 f
  in (k,) }

In [83]:
# (nested) list of parameters
# Leaves come out in order: the two ints first, then the two arrays that sit
# inside the tuple.
params = [1, 2, (jnp.arange(3), jnp.ones(2))]

print(jax.tree.structure(params))
print(jax.tree.leaves(params))
print(params)

PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)]
[1, 2, (Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32))]


In [84]:
# Dictionary of parameters
# The printed leaves follow sorted key order ('W', 'b', 'n'), not the
# insertion order used in the literal below.
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

print(jax.tree.structure(params))
print(jax.tree.leaves(params))

PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
       [1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5]


In [85]:
# Named tuple of parameters
from typing import NamedTuple

class Params(NamedTuple):
  """Minimal two-field record used to show a NamedTuple as a pytree."""
  a: int  # first field
  b: float  # second field

# The NamedTuple is treated as a pytree container: its fields are the leaves.
params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params))

PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0]


In [86]:
import jax
import jax.numpy as jnp

global_list = []

def log2(x):
  """Base-2 logarithm via the change-of-base identity ln(x) / ln(2).

  Also appends `x` to the module-level `global_list`; that side effect is
  presumably deliberate, to contrast it with what a jaxpr records.
  """
  global_list.append(x)
  return jnp.log(x) / jnp.log(2.0)

# The printed jaxpr contains only the log/div operations; the
# global_list.append call inside log2 does not appear in it.
print(jax.make_jaxpr(log2)(3.0))

{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }


In [95]:
#def f(x):
#    return 2 * x**5 + 10 * x**3 + 2 * x**2 + 3 * x + 4

def f(x):
    """Return x raised to the eighth power."""
    return pow(x, 8)

print(f'f(1.0) = {f(1.0)}')

# Differentiate repeatedly: after round i, g is the i-th derivative of x**8.
# At x = 1.0 that is 8*7*...*(9-i), and the 9th derivative is zero.
g = f
for i in range(1,10):
    g = grad(g)
    print(f'g_{i}(1.0) = {g(1.0)}')





f(1.0) = 1.0
g_1(1.0) = 8.0
g_2(1.0) = 56.0
g_3(1.0) = 336.0
g_4(1.0) = 1680.0
g_5(1.0) = 6720.0
g_6(1.0) = 20160.0
g_7(1.0) = 40320.0
g_8(1.0) = 40320.0
g_9(1.0) = 0.0


In [99]:

@jax.jit
def f(x):
  """Return sin(x), printing x and y at trace time and at run time.

  Plain `print` executes only while JAX traces the function and therefore
  shows tracers; `jax.debug.print` executes on every call and shows the
  runtime values.
  """
  print("print(x) ->", x)
  # The label names x, matching the value that is actually printed.
  jax.debug.print("jax.debug.print(x) -> {x}", x=x)
  y = jnp.sin(x)
  print("print(y) ->", y)
  jax.debug.print("jax.debug.print(y) -> {y}", y=y)
  return y

# In the recorded output, plain print() lines appear only during the first
# call and show tracers, while the jax.debug.print lines appear on both
# calls with the actual runtime values.
print('first call')
result = f(2.)
print('second call')
result = f(20.)

first call
print(x) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
print(y) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
jax.debug.print(y) -> 0.9092974662780762
jax.debug.print(y) -> 2.0
second call
jax.debug.print(y) -> 0.9129452705383301
jax.debug.print(y) -> 20.0


In [117]:
@jax.jit
def f(x):
  """Square x, reporting each (input, output) pair at run time.

  The bare print marks the moments when the function is actually traced.
  """
  print('compiling')
  squared = x * x
  jax.debug.print("(x, y) -> ({x}, {y})", x=x, y=squared)
  return squared

# In the recorded run, "compiling" prints on the first and third calls only:
# the switch from float to int input triggers a new trace, while the change in
# batch size between the first and second call does not.
print('first call')
jax.vmap(f)(jnp.arange(5.))
print('second call')
jax.vmap(f)(jnp.arange(2.))
print('third call')
jax.vmap(f)(jnp.arange(3))


first call
compiling
(x, y) -> (1.0, 1.0)
(x, y) -> (2.0, 4.0)
(x, y) -> (3.0, 9.0)
(x, y) -> (4.0, 16.0)
(x, y) -> (0.0, 0.0)
second call
(x, y) -> (1.0, 1.0)
(x, y) -> (0.0, 0.0)
third call
compiling
(x, y) -> (1, 1)
(x, y) -> (2, 4)
(x, y) -> (0, 0)


Array([0, 1, 4], dtype=int32)

In [119]:
# NOTE(review): the output recorded under this cell is an UnexpectedTracerError
# that names log2, not a debugger session -- it looks stale; re-run to confirm.
@jax.jit
def f(x):
  # Compute both values before the stop so they exist when execution pauses.
  y, z = jnp.sin(x), jnp.cos(x)
  # Pause point inside the jitted function.
  jax.debug.breakpoint()
  return y * z
f(2.) # ==> Pauses during execution

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was log2 at /tmp/ipykernel_11976/495203139.py:6 traced for jit.
------------------------------
The leaked intermediate value was created on line /tmp/ipykernel_11976/495203139.py:12 (<module>). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/home/vscode/.local/lib/python3.10/site-packages/IPython/core/async_helpers.py:128 (_pseudo_sync_runner)
/home/vscode/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3334 (run_cell_async)
/home/vscode/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3517 (run_ast_nodes)
/home/vscode/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3577 (run_code)
/tmp/ipykernel_11976/495203139.py:12 (<module>)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError