# JAX scratchpad

In [16]:
import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp

# Enumerate the devices visible to JAX; this scratchpad is written for a single-device setup.
devices = jax.devices()
assert len(devices) == 1, f'expected exactly one JAX device, found {len(devices)}: {devices}'

In [14]:
# Sanity check that JAX is running on the expected accelerator.
# The GPU model string is hard-coded; adjust it when running on other hardware.
d0 = devices[0]
assert d0.device_kind == 'NVIDIA GeForce RTX 2070 SUPER', (
    f'unexpected device kind: {d0.device_kind!r}')

In [None]:
# Print every public attribute of the first device next to its current value.
d0 = jax.devices()[0]
public_names = [name for name in dir(d0) if not name.startswith('_')]
for name in public_names:
    print(f'{name:40} : {getattr(d0, name)}')

# JAX Gotchas & FAQ
# https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
# https://jax.readthedocs.io/en/latest/faq.html


# JAX Tutorials
# https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html
# https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html
# https://jax.readthedocs.io/en/latest/jax-101/03-gradients.html
# https://jax.readthedocs.io/en/latest/jax-101/04-broadcasting.html
# https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html
# https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html

# FLAX Tutorials
# https://flax.readthedocs.io/en/latest/
# https://flax.readthedocs.io/en/latest/getting_started.html
# https://flax.readthedocs.io/en/latest/notebooks/Flax_basics.html

# FLAX NNX vs FLAX Linen

In [14]:
def pure_uses_internal_state(x):
  """Add x ten times, split across two internal tallies, and return the total.

  The dict is created and discarded inside the call, so the function stays
  pure from the caller's point of view even though it mutates local state.
  The print runs whenever the Python body executes (under jit: only while
  tracing).
  """
  print(f'Running pure_uses_internal_state, x={x}')
  tallies = {'even': 0, 'odd': 0}
  for step in range(10):
    bucket = 'odd' if step % 2 else 'even'
    tallies[bucket] += x
  return tallies['even'] + tallies['odd']

# The int and the first float input each trigger a trace (the function's print fires);
# the second float input reuses the cached float trace, so nothing extra is printed.
for arg in (5, 5., 10.):
  print(jax.jit(pure_uses_internal_state)(arg))

Running pure_uses_internal_state, x=Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
50
Running pure_uses_internal_state, x=Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
50.0
100.0


In [13]:
# def pure_fn_loop(n: int):
#   print(f'Running pure_fn_loop, n={n}')
#   total = 0
#   for i in range(n):
#     total += i
#   return total

# Not allowed: under jit, `n` is a tracer, and Python's range() cannot iterate over a traced value.
# print(jax.jit(pure_fn_loop)(10))

In [20]:
# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45

iterator = iter(range(2, 10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0

r_list = list(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+r_list[i], 0)) # expected result 45


45
20


TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[]
The error occurred while tracing the function scanned_fun at /home/vscode/.local/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py:1945 for scan. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError

In [21]:
# Gotcha: jax.numpy functions do not accept Python lists, so this call raises TypeError.
# Convert explicitly first, e.g. jnp.sum(jnp.array([1, 2, 3])).
jnp.sum([1, 2, 3])

TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.

In [25]:
def f(x):
  """Piecewise function: 3*x**2 when x < 3, otherwise -4*x.

  The branch is an ordinary Python conditional, so x must be usable in a
  boolean comparison when the body runs.
  """
  return 3. * x ** 2 if x < 3 else -4 * x

# Without jit, grad evaluates f on concrete values, so the Python `if` inside f works.
for point in (2., 4.):
  print(grad(f)(point))  # ok!

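# Under jit both calls fail: `x < 3` is evaluated on a tracer, whatever the input value.
# print(grad(jax.jit(f))(2.))  # bad: TracerBoolConversionError
# print(grad(jax.jit(f))(4.))  # bad too: same error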

12.0
-4.0


In [42]:
@jit
def f(x):
  """Cubic polynomial 2*x**3 + 5*x + 999, compiled with jit.

  The print only executes while JAX traces the function, which is why the
  recorded output shows a tracer rather than a number and appears once per
  input type.
  """
  print(f'Running f, x={x}')
  # Earlier experiments, kept for reference:
  #   for i in range(3): x = x * x   (then return x)
  #   return 2 * x * x + 5 * x + 999
  cubic_term = 2 * x ** 3
  linear_term = 5 * x
  return cubic_term + linear_term + 999

print(f(3))
# The three float inputs share one trace, so "Running f" is printed only for the first.
for x0 in (1.0, 3.0, 4.0):
  print(grad(f)(x0))


Running f, x=Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
1068
Running f, x=Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=3/0)>
11.0
59.0
101.0


In [45]:
@jit
def f(x):
  """Piecewise function (3*x**2 below 3, -4*x otherwise) wrapped in jit.

  Demonstrates a gotcha: the Python-level comparison needs a concrete bool,
  but under jit x is a tracer, so calling this raises
  TracerBoolConversionError.
  """
  return 3. * x ** 2 if x < 3 else -4 * x

# This will fail! Under jit, x is a tracer, so the Python `if x < 3` in f
# cannot be converted to a bool (TracerBoolConversionError).
f(2)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /tmp/ipykernel_11976/3402096563.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

In [55]:
def f(x, n):
  """Return 0.0 plus the first n entries of x, added in index order.

  n drives a Python-level loop, so it has to be a concrete integer when the
  body runs.
  """
  return sum((x[k] for k in range(n)), 0.0)

# n controls a Python loop inside f, so it is marked static (argument position 1).
f = jit(f, static_argnums=(1,))

# Note: for n > 3 the loop indexes past the end of the 3-element array; the
# recorded output shows that JAX does not raise in that case.
sample = jnp.array([2., 3., 4.])
for n in [2,4,6,8,10]:
  print(f(sample, n))
  print(grad(f)(sample, n))




5.0
[1. 1. 0.]
13.0
[1. 1. 1.]
21.0
[1. 1. 1.]
29.0
[1. 1. 1.]
37.0
[1. 1. 1.]
