In [1]:
## Standard libraries
import os
import math
import numpy as np
import time

## Progress bar
from tqdm.auto import tqdm

import jax
import jax.numpy as jnp

In [2]:
# 2x5 array of float32 zeros; later cells index into it.
a = jnp.zeros(shape=(2, 5), dtype=jnp.float32)




In [3]:
# Show the list of devices JAX reports for this session.
jax.devices()

[CpuDevice(id=0)]

In [4]:
# `.at[0].set(1)` produces an array whose row 0 is filled with ones. The result
# is not bound to a name here, so this cell only displays it.
a.at[0].set(1)

DeviceArray([[1., 1., 1., 1., 1.],
             [0., 0., 0., 0., 0.]], dtype=float32)

In [5]:
# Explicit PRNG key built from a fixed seed; the random draws below take it as input.
rng = jax.random.PRNGKey(seed=42)

In [6]:
# Both draws are given the very same key `rng`, and the two printed values
# come out identical.
jax_random_number_1 = jax.random.normal(rng)
jax_random_number_2 = jax.random.normal(rng)
print('JAX - Random number 1:', jax_random_number_1)
print('JAX - Random number 2:', jax_random_number_2)

JAX - Random number 1: -0.18471177
JAX - Random number 2: -0.18471177


In [7]:
# Split the current key three ways: the first result replaces `rng`,
# the other two are each used for exactly one draw below.
rng, subkey1, subkey2 = jax.random.split(rng, 3)
jax_random_number_1 = jax.random.normal(key=subkey1)
jax_random_number_2 = jax.random.normal(key=subkey2)
print('JAX new - Random number 1:', jax_random_number_1)
print('JAX new - Random number 2:', jax_random_number_2)

JAX new - Random number 1: 0.107961535
JAX new - Random number 2: -1.2226542


In [8]:
def simple_graph(x):
    """Return the mean of ``(x + 2) ** 2 + 3`` over all elements of ``x``.

    The computation is written as a chain of single operations (add, square,
    add, mean), each stored under its own name.

    Args:
        x: Array-like object supporting ``+``, ``**`` and ``.mean()``.

    Returns:
        The scalar mean of the transformed input.
    """
    shifted = x + 2
    squared = shifted ** 2
    offset = squared + 3
    return offset.mean()

# Evaluate simple_graph on the float32 vector [0., 1., 2.] and print the
# input next to the resulting output.
inp = jnp.arange(3, dtype=jnp.float32)
print('Input', inp)
print('Output', simple_graph(inp))

Input [0. 1. 2.]
Output 12.666667


In [9]:
# Trace simple_graph with `inp` and display the resulting jaxpr.
jax.make_jaxpr(simple_graph)(inp)

{ lambda ; a:f32[3]. let
    b:f32[3] = add a 2.0
    c:f32[3] = integer_pow[y=2] b
    d:f32[3] = add c 3.0
    e:f32[] = reduce_sum[axes=(0,)] d
    f:f32[] = div e 3.0
  in (f,) }

In [10]:
# jax.grad wraps simple_graph into a function that returns the gradient of its
# output with respect to its input.
grad_function = jax.grad(simple_graph)
gradients = grad_function(inp)  # gradient evaluated at `inp`
print('Gradient', gradients)

Gradient [1.3333334 2.        2.6666667]


In [11]:
# Trace the gradient function with `inp` and display its jaxpr.
jax.make_jaxpr(grad_function)(inp)

{ lambda ; a:f32[3]. let
    b:f32[3] = add a 2.0
    c:f32[3] = integer_pow[y=2] b
    d:f32[3] = integer_pow[y=1] b
    e:f32[3] = mul 2.0 d
    f:f32[3] = add c 3.0
    g:f32[] = reduce_sum[axes=(0,)] f
    _:f32[] = div g 3.0
    h:f32[] = div 1.0 3.0
    i:f32[3] = broadcast_in_dim[broadcast_dimensions=() shape=(3,)] h
    j:f32[3] = mul i e
  in (j,) }

In [12]:
# jax.value_and_grad wraps simple_graph so that one call yields a
# (value, gradient) pair, as shown in the displayed tuple.
val_grad_function = jax.value_and_grad(simple_graph)
val_grad_function(inp)

(DeviceArray(12.666667, dtype=float32),
 DeviceArray([1.3333334, 2.       , 2.6666667], dtype=float32))