In [3]:
import numpy as np
import jax.numpy as jnp
from jax import jit

# Manipulating JAX arrays

## Declaring and manipulating

JAX looks very similar to NumPy in several respects (the methods are very similar); the main difference is that jnp arrays are *immutable*, as the following examples show:

In [16]:
# NumPy arrays are mutable: elements can be changed in place after instantiating the array
npArr    = np.arange(10)
npArr[0] = 10
print("Target array: {}\n".format(npArr))

# JAX arrays are immutable: in-place item assignment raises a TypeError
jnpArr = jnp.arange(10)
try:
    jnpArr[0] = 10
except TypeError as e:
    # Only the expected immutability error is caught, so an unrelated failure still surfaces
    print("This does not work, instead we could use .at[] and .set() to create a copy of the array.")
    # .at[0].set(10) returns an updated copy; jnpArr itself is left unchanged
    copiedJnpArr = jnpArr.at[0].set(10)
    print("Old array: {}".format(jnpArr))
    print("New array: {}\n".format(copiedJnpArr))

    print("Let's have a look at the risen exception: {}.".format(e))
    

Target array: [10  1  2  3  4  5  6  7  8  9]

This does not work, instead we could use .at[] and .set() to create a copy of the array.
Old array: [0 1 2 3 4 5 6 7 8 9]
New array: [10  1  2  3  4  5  6  7  8  9]

Let's have a look at the risen exception: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?.


In [17]:
# jax.ops.index_update / jax.ops.index_add (the helpers suggested by the error message of
# older JAX releases) have been removed from JAX. The `.at[]` property of JAX arrays is
# their replacement and, like them, returns a modified copy instead of mutating jnpArr.
# Integer update values are used so that they match the integer dtype of jnpArr.
copiedJnpArrUpdate = jnpArr.at[3:6].set(0)
copiedJnpArrAdd = jnpArr.at[3:6].add(2)
print("Old array: {}".format(jnpArr))
print("New array, update 3:6: {}".format(copiedJnpArrUpdate))
print("New array, add 2 to index in 3:6: {}\n".format(copiedJnpArrAdd))

Old array: [0 1 2 3 4 5 6 7 8 9]
New array, update 3:6: [0 1 2 0 0 0 6 7 8 9]
New array, add 2 to index in 3:6: [0 1 2 5 6 7 6 7 8 9]



## Out of bound indexing

In [51]:
# Reading index 20 of a 10-element JAX array does not raise; the value is fetched first
# and then inserted into the explanatory message.
jnpArr = jnp.arange(10)
outOfBoundValue = jnpArr[20]
print("The 20-th element of the array of shape (10,) is: {}. For JAX arrays, out of bound".format(outOfBoundValue) +
     " indexing does not throw an exception, and rather return the last element of the array.")

The 20-th element of the array of shape (10,) is: 9. For JAX arrays, out of bound indexing does not throw an exception, and rather return the last element of the array.


# JIT

JIT stands for Just-In-Time; `jit` acts as a decorator on a Python function. JAX executes operations sequentially, and JIT tries to optimize the operations' execution time. However, JIT has a limited scope, which we will expose in the next cells.

## JIT-able method

In [23]:
# First we give an example of a JIT-able method

def MatMul(X, Y):
    """Return the product of X and Y computed with jnp.dot (X @ Y is equivalent here)."""
    return jnp.dot(X, Y)

# jit(MatMul) returns a compiled counterpart of MatMul; MatMul itself is left as is
CompiledMatMul = jit(MatMul)

# Fixed seed so that the random inputs are the same on every run
np.random.seed(1)
X = jnp.array(np.random.rand(100, 1000))
Y = jnp.array(np.random.rand(1000, 100))

# Time the plain and the compiled versions on the same inputs.
# block_until_ready() makes each call wait for its result, so the timing covers the
# actual computation (JAX dispatches work asynchronously).
%timeit MatMul(X, Y).block_until_ready()
%timeit CompiledMatMul(X, Y).block_until_ready()


334 µs ± 26.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
153 µs ± 5.72 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [60]:
# We now inspect what's in the JITed function

@jit
def VerboseMatMul(X, Y):
    """Multiply X by Y with jnp.dot, printing both inputs and the result on the way.

    The prints are ordinary Python statements: they display whatever objects the
    function body receives when that body is executed.
    """
    print("Running VerboseMatMul")
    print("X is {}".format(X))
    print("Y is {}".format(Y))
    product = jnp.dot(X, Y)  # X @ Y is equivalent
    print("Z is {}".format(product))
    return product

# First call: the body of VerboseMatMul is executed, so its print statements run
Z = VerboseMatMul(X, Y)
# A space is needed at the end of the first literal, otherwise the concatenation
# prints "thefunction"
print("\nWe can see a bunch of traced arrays with fixed shapes. Let's try to re-run the "
      + "function that should now be compiled.")

# Second call with the same inputs: nothing is printed by the function this time
Z = VerboseMatMul(X, Y)
print("\nThe messages are not printed this time, as we now manipulate the compiled version of the method."
      + " We can have a look at the JAX expression of the function as follow:\n")

from jax import make_jaxpr

# The print-free MatMul is inspected so that only the jaxpr itself is displayed
print(make_jaxpr(MatMul)(X, Y))

print("\nOne last thing! If one of the input's shape changes, the function will be recompiled."+
     " This can be disastrous in case this shape varies a lot.\n")
# Same number of columns but one more row than X: the print statements run again
X1 = jnp.array(np.random.rand(101, 1000))
Z = VerboseMatMul(X1, Y)

Running VerboseMatMul
X is Traced<ShapedArray(float32[100,1000])>with<DynamicJaxprTrace(level=0/1)>
Y is Traced<ShapedArray(float32[1000,100])>with<DynamicJaxprTrace(level=0/1)>
Z is Traced<ShapedArray(float32[100,100])>with<DynamicJaxprTrace(level=0/1)>

We can see a bunch of traced arrays with fixed shapes. Let's try to re-run thefunction that should now be compiled.

The messages are not printed this time, as we now manipulate the compiled version of the method. We can have a look at the JAX expression of the function as follow:

{ lambda  ; a b.
  let c = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
                       precision=None
                       preferred_element_type=None ] a b
  in (c,) }

One last thing! If one of the input's shape changes, the function will be recompiled. This can be disastrous in case this shape varies a lot.

Running VerboseMatMul
X is Traced<ShapedArray(float32[101,1000])>with<DynamicJaxprTrace(level=0/1)>
Y is Traced<ShapedArray(flo

## Un-JIT-able methods

In [48]:
# JITed methods must only manipulate arrays with static shapes

def GetNegatives(x):
    """Return the entries of x that are strictly negative, selected with a boolean mask."""
    isNegative = x < 0
    return x[isNegative]

# Fixed seed so that the same 10 random values are drawn on every run
np.random.seed(1)
x = jnp.array(np.random.randn(10))
print("Using the non-compiled version of GetNegatives works, and returns: {}".format(GetNegatives(x)))

# Same function wrapped with jit
CompiledGetNegatives = jit(GetNegatives)

# The jitted call is expected to fail; the exception is caught so that it can be displayed
try:
    print(CompiledGetNegatives(x))
except Exception as e:
    print("\nHowever the compiled version does not accept indices having a shape that can vary " + 
          "(called non-concrete or abstract). The raised exception reads:\n")
    print(e)


Using the non-compiled version of GetNegatives works, and returns: [-0.6117564  -0.5281718  -1.0729686  -2.3015387  -0.7612069  -0.24937038]

However the compiled version does not accept indices having a shape that can vary (called non-concrete or abstract). The raised exception reads:

Array boolean indices must be concrete; got ShapedArray(bool[10])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError


In [49]:
def TransformNegative(x, thresh):
    """Return x unchanged when thresh is strictly positive, and -x otherwise."""
    if thresh > 0:
        return x
    return -x

# Fixed seed so that x and thresh are the same on every run
np.random.seed(1)
x = jnp.array(np.random.randn(10))
thresh = jnp.array(np.random.randn(1))
# A space follows the comma so that the message no longer reads "works,and returns";
# .format applies to the second literal, which is the one holding the placeholder
print("Using the non-compiled version of TransformNegative works, " +
      "and returns: \n{}".format(TransformNegative(x, thresh)))

# Same function wrapped with jit
CompileTransformNegative = jit(TransformNegative)

# The jitted call is expected to fail; the exception is caught so that it can be displayed
try:
    print(CompileTransformNegative(x, thresh))
except Exception as e:
    print("\nHowever the compiled version cannot have if statements that depend on the content of " + 
          "traced variables (which itself is not traced). Upon compilation, it should be clear"+ 
          " what if/else path is chosen. The raised exception reads:\n")
    print(e)

Using the non-compiled version of TransformNegative works,and returns: 
[ 1.6243454  -0.6117564  -0.5281718  -1.0729686   0.86540765 -2.3015387
  1.7448118  -0.7612069   0.3190391  -0.24937038]

However the compiled version cannot have if statements that depend on the content of traced variables (which itself is not traced). The raised exception reads:

Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[1])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function TransformNegative at <ipython-input-49-8a8424009ac0>:1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'thresh'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError


# Vectorization

We now show how functions can be vectorized using one of JAX's transforms: `vmap`. We reuse the example given in the documentation.

In [74]:
# We start with a simple convolution function

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

def convolve(x, w):
    """Slide the kernel w over x and return the dot product for every full window.

    Each window holds the three consecutive entries of x centred on an interior
    position, so w is expected to have shape (3,).
    """
    windowDots = [jnp.dot(x[center - 1:center + 2], w) for center in range(1, len(x) - 1)]
    return jnp.array(windowDots)

print("Result of the convolution: {}".format(convolve(x, w)))

Result of the convolution: [11. 20. 29.]


In [None]:
# We naively create a vectorized version by looping over the batch dimension


In [None]:
# Or we could use vmap!

# Autodiff

We now show some examples for taking derivatives of various orders.

## First order: gradients

In [70]:
from jax import grad, value_and_grad

x = jnp.arange(4, dtype=jnp.float32)

@jit
def SumSquaresLoss(x):
    """Half of the sum of the squared entries of x."""
    squares = x**2
    return .5 * jnp.sum(squares)

# grad builds a function that evaluates only the gradient of the loss
GradSumSquaresLoss = grad(SumSquaresLoss)
print("Gradient of the loss: {}".format(GradSumSquaresLoss(x)))

# value_and_grad builds a function that evaluates the loss and its gradient together
ValGradSumSquaresLoss = value_and_grad(SumSquaresLoss)
print("Value and gradient of the loss: {}".format(ValGradSumSquaresLoss(x)))

# To also output an auxiliary result, the loss returns a (loss, aux) pair and
# has_aux=True is passed to value_and_grad
@jit
def SumSquaresLossWithAux(x):
    """Same loss as SumSquaresLoss, returned together with the squared entries of x."""
    squares = x**2
    return .5 * jnp.sum(squares), squares

ValGradSumSquaresLossWithAux = value_and_grad(SumSquaresLossWithAux, has_aux=True)
print("Value and gradient of the loss with aux: {}".format(ValGradSumSquaresLossWithAux(x)))

Gradient of the loss: [0. 1. 2. 3.]
Value and gradient of the loss: (DeviceArray(7., dtype=float32), DeviceArray([0., 1., 2., 3.], dtype=float32))
Gradient of the loss with aux: ((DeviceArray(7., dtype=float32), DeviceArray([0., 1., 4., 9.], dtype=float32)), DeviceArray([0., 1., 2., 3.], dtype=float32))


In [67]:
# We now show an example of a scalar function that takes two inputs

x = jnp.arange(4, dtype=jnp.float32)
y = jnp.arange(2, dtype=jnp.float32) + 4.

@jit
def SumSquaresLoss2(x, y):
    """Half of the combined sum of the squared entries of x and of y."""
    xSquares = jnp.sum(x**2)
    ySquares = jnp.sum(y**2)
    return .5 * (xSquares + ySquares)

# Without argnums, grad differentiates with respect to the first argument only
GradSumSquaresLoss2 = grad(SumSquaresLoss2)
print("Gradient of the loss wrt the first argument: {}".format(GradSumSquaresLoss2(x, y)))

# argnums=(0, 1) requests one gradient per listed argument, returned as a tuple
GradSumSquaresLossBoth2 = grad(SumSquaresLoss2, argnums=(0, 1))
print("Gradient of the loss wrt both arguments: {}".format(GradSumSquaresLossBoth2(x, y)))

Gradient of the loss wrt the first argument: [0. 1. 2. 3.]
Gradient of the loss wrt both arguments: (DeviceArray([0., 1., 2., 3.], dtype=float32), DeviceArray([4., 5.], dtype=float32))


In [66]:
# Failure case: grad applied to a function whose output is not a scalar
A = jnp.arange(8, dtype=jnp.float32).reshape(2, 4)

@jit
def MulByA(x):
    """Return the product of the (2, 4) matrix A with x, computed with jnp.dot."""
    product = jnp.dot(A, x)
    return product

GradMulByA = grad(MulByA)

# The call is expected to fail; the exception is caught so that it can be displayed
try:
    GradMulByA(x)
except Exception as err:
    print("Using grad on a function that outputs multiple variables yields a error." + 
         " We get the following exception:\n")
    print(err)

Using grad on a function that outputs multiple variables yields a error. We get the following exception:

Gradient only defined for scalar-output functions. Output had shape: (2,).


## First order: Jacobian