**Structured Control Flow Primitives**

In [12]:
from jax import numpy as jnp
from jax import lax
from jax import grad
import jax

**lax.cond**

In [4]:
# lax.cond(pred, true_fun, false_fun, operand) runs exactly one of the two
# branch functions on the shared operand, chosen by the predicate.
def add_one(v):
  return v + 1

def subtract_one(v):
  return v - 1

operand = jnp.array([2.0])
print(lax.cond(True, add_one, subtract_one, operand))   # true branch:  2 + 1
print(lax.cond(False, add_one, subtract_one, operand))  # false branch: 2 - 1

[3.]
[1.]


**lax.while_loop**

In [6]:
def cond_func(x):
  """Loop predicate: keep iterating while the carry is below 10."""
  return x < 10

def body_func(x):
  """Loop body: advance the carry by one."""
  return x + 1

# Starting from 0, the carry is incremented until cond_func turns False.
print(lax.while_loop(cond_func, body_func, 0))

10


**lax.fori_loop**

In [8]:
# fori_loop threads the carry through body_func(i, carry) for i in [start, stop).
start, stop = 0, 10
initial_value = 0

def body_func(i, x):
  """Add the loop index ``i`` to the running total ``x``."""
  return x + i

# 0 + 1 + ... + 9
print(lax.fori_loop(start, stop, body_func, initial_value))

45


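**JAX logical operators don't short-circuit**

Unlike Python's `and`, `jnp.logical_and` is an ordinary function call, so both operands are evaluated before it runs.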


In [10]:
def python_check_positive_even(x: int) -> bool:
  """Return True when ``x`` is even and strictly positive, using plain Python.

  Python's ``and`` short-circuits: when ``is_even`` is False, ``x > 0`` is
  never evaluated.
  """
  is_even = x % 2 == 0
  return is_even and x > 0

def jax_check_positive_even(x: int):
  """Return True when ``x`` is even and strictly positive, via ``jnp.logical_and``.

  Returns a JAX boolean array rather than a Python ``bool``.
  """
  # Computed as an expression instead of a Python ``if`` so the function also
  # works on traced values (e.g. under ``jax.jit``), where Python cannot
  # branch on the value.
  is_even = x % 2 == 0
  # No short-circuit: ``jnp.logical_and`` is an ordinary function call, so
  # ``x > 0`` is evaluated even when ``is_even`` is False.
  return jnp.logical_and(is_even, x > 0)

print(python_check_positive_even(24))
print(jax_check_positive_even(24))


True
True


**Regular Python control flow works with `grad` as long as `jit` isn't used**

In [14]:
def func(x):
  """Clip ``x`` from above at 3 using plain Python control flow.

  Differentiable with ``grad``: the slope is 1.0 for ``x <= 3`` and 0.0 for
  ``x > 3``.
  """
  # NOTE: the Python ``if`` needs a concrete ``x``, so this works under
  # ``grad`` but not under ``jit`` as written.
  if x > 3:
    # Must be a float: ``grad`` rejects integer-valued outputs, so returning
    # the int ``3`` here would raise a TypeError for every x > 3.
    return 3.0
  return x

print(grad(func)(2.0))


1.0
