In [4]:
# jaxpr

from jax import make_jaxpr
import jax.numpy as jnp

def softmax(x, c=1):
    """Scale `x` by `c`, exponentiate, and normalise by the total.

    Args:
        x: Input array passed through `jnp.exp` after scaling.
        c: Multiplier applied to `x` before exponentiation (default 1).

    Returns:
        `exp(x * c)` divided by the sum of all of its elements
        (`jnp.sum` is called without an axis).

    Note:
        The scaled input is exponentiated as-is; no maximum is
        subtracted beforehand.
    """
    exps = jnp.exp(x * c)
    total = jnp.sum(exps)
    return exps / total

# Trace softmax on a 3-element float array with a float32 scalar passed as `c`,
# then print the resulting jaxpr (the intermediate representation of the trace).
jaxpr = make_jaxpr(softmax)(jnp.array([1., 2., 3.]), jnp.float32(10.))
print(jaxpr)


{ lambda ; a:f32[3] b:f32[]. let
    c:f32[3] = mul a b
    d:f32[3] = exp c
    e:f32[] = reduce_sum[axes=(0,)] d
    f:f32[3] = div d e
  in (f,) }


In [12]:
# functional transformations

# grad
from jax import grad

def f(x):
  """Evaluate the cubic 4x^3 + 3x^2 + 2x + 1 at `x`.

  The terms are accumulated from highest to lowest degree, in the same
  order as the single-expression form, so the sequence of arithmetic
  operations applied to `x` is unchanged.
  """
  acc = 4.0 * x ** 3
  acc = acc + 3.0 * x ** 2
  acc = acc + 2.0 * x
  return acc + 1.0

# Evaluate f and the function returned by grad(f) at x = 1.0.
print(f(1.0), grad(f)(1.0))

# Print the jaxprs traced for f itself and for grad(f) at the same input.
print(make_jaxpr(f)(1.0))
print(make_jaxpr(grad(f))(1.0))


10.0 20.0
{ lambda ; a:f32[]. let
    b:f32[] = integer_pow[y=3] a
    c:f32[] = mul 4.0 b
    d:f32[] = integer_pow[y=2] a
    e:f32[] = mul 3.0 d
    f:f32[] = add c e
    g:f32[] = mul 2.0 a
    h:f32[] = add f g
    i:f32[] = add h 1.0
  in (i,) }
{ lambda ; a:f32[]. let
    b:f32[] = integer_pow[y=3] a
    c:f32[] = integer_pow[y=2] a
    d:f32[] = mul 3.0 c
    e:f32[] = mul 4.0 b
    f:f32[] = integer_pow[y=2] a
    g:f32[] = integer_pow[y=1] a
    h:f32[] = mul 2.0 g
    i:f32[] = mul 3.0 f
    j:f32[] = add e i
    k:f32[] = mul 2.0 a
    l:f32[] = add j k
    _:f32[] = add l 1.0
    m:f32[] = mul 2.0 1.0
    n:f32[] = mul 3.0 1.0
    o:f32[] = mul n h
    p:f32[] = add_any m o
    q:f32[] = mul 4.0 1.0
    r:f32[] = mul q d
    s:f32[] = add_any p r
  in (s,) }


In [17]:
# jit
from jax import jit

def relu1(x):
  """Return `x` when it is greater than zero, otherwise the float 0.

  The choice is made with an ordinary Python conditional, so `x > 0`
  has to be usable as a Python truth value.
  """
  return x if x > 0 else 0.
  
def relu2(x):
  """Return `x` where it is greater than zero and 0 elsewhere.

  The selection is done with `jnp.where` rather than a Python `if`.
  """
  is_positive = x > 0
  return jnp.where(is_positive, x, 0)

# Negative sample input, so both ReLU variants should print 0.
x = -1.0

# Plain (un-jitted) calls of both implementations.
print(relu1(x))
print(relu2(x))

# Wrap both implementations with jit.
jrelu1 = jit(relu1)
jrelu2 = jit(relu2)

# NOTE(review): the jrelu1 call below is left disabled. relu1 branches on
# `x > 0` with a Python `if`; presumably that fails once x is traced by jit --
# confirm against the JAX control-flow documentation.
# print(jrelu1(x))
print(jrelu2(x))


0.0
0.0
0.0
