### What is JAX
1. JAX is a framework targeted at machine learning research. It provides a NumPy-like API (**`jax.numpy`**) and uses the XLA compiler to compile that code, so the same program can run on accelerators (GPU/TPU).
2. JAX comes with useful features such as automatic differentiation (**`jax.grad`**), automatic vectorization (**`jax.vmap`**), automatic device parallelization (**`jax.pmap`**) and just-in-time compilation (**`jax.jit`**).
3. JAX expresses numerical programs as compositions of transformations, but these transformations come with a constraint: they are designed for pure functions. A pure function always returns the same value when invoked with the same arguments and has no side effects, i.e. it neither changes state nor reads or modifies non-local variables.
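
A minimal sketch of points 2 and 3, assuming a recent JAX install: `jax.grad`, `jax.vmap` and `jax.jit` are applied to (and composed over) a pure function, and a deliberately impure function shows that under `jit` a Python side effect runs only while the function is being traced, not on every call. The names `sum_of_squares`, `log` and `impure_double` are illustrative only. Device parallelization (`jax.pmap`) is left out because it needs more than one device.

```python
import jax
import jax.numpy as jnp

def sum_of_squares(x):
    # Pure: the output depends only on x, and nothing outside is modified
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])

# Automatic differentiation: d/dx sum(x^2) = 2x
grads = jax.grad(sum_of_squares)(x)
print(grads)  # [2. 4. 6.]

# Automatic vectorization: apply the per-vector function to each row of a batch
batch = jnp.stack([x, 2 * x])
print(jax.vmap(sum_of_squares)(batch))  # [14. 56.]

# Transformations compose: a jit-compiled gradient function
fast_grad = jax.jit(jax.grad(sum_of_squares))
print(fast_grad(x))  # [2. 4. 6.]

# Impure function: appends to a non-local list (a side effect)
log = []

def impure_double(v):
    log.append("called")
    return v * 2

jitted_double = jax.jit(impure_double)
jitted_double(x)
jitted_double(x)  # same shape/dtype, so the cached compiled version is reused
print(len(log))   # 1: the append ran once, during tracing
```

Because the second call reuses the compiled function, the side effect silently disappears; this is why JAX transformations expect pure functions.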

#### Device array
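1. The basic unit of JAX is the **`DeviceArray`**, an array type similar to **`numpy`**'s **`ndarray`** but backed by a memory buffer on a CPU, GPU or TPU device. (In recent JAX releases, **`DeviceArray`** has been superseded by the unified **`jax.Array`** type, so `type()` may report a different class name than the output shown here.)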
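2. A **`DeviceArray`** is device agnostic, i.e. the same code runs on all devices. JAX also uses asynchronous dispatch: an operation returns a **`DeviceArray`** before its computation has finished, and JAX only waits for the contents when they are actually needed (e.g. when the array is printed or converted to a NumPy array).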
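
A short sketch of the two properties above, assuming a recent JAX install: the same code runs on whichever backend JAX finds, and because of asynchronous dispatch it is `block_until_ready()` (or converting to NumPy) that actually waits for the result. The 500×500 matrix size is arbitrary.

```python
import jax
import jax.numpy as jnp
import numpy as np

# The same code runs on whatever backend is available (e.g. "cpu", "gpu", "tpu")
print(jax.default_backend())
print(jax.devices())

# Copy a host NumPy array to the default device
x = jax.device_put(np.ones((500, 500), dtype=np.float32))

# With asynchronous dispatch, this call can return before the matmul finishes
y = x @ x

# Wait until the result is actually computed (important when timing code)
y = y.block_until_ready()

# Converting to NumPy also waits for the result and copies it back to the host
host = np.asarray(y)
print(type(host), host[0, 0])  # each entry is the sum of 500 ones, i.e. 500.0
```

When benchmarking JAX code, calling `block_until_ready()` before stopping the timer avoids measuring only the dispatch time.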

```python
import time
import jax
import numpy as np
import jax.numpy as jnp

array_numpy = np.arange(10, dtype=np.int32)
array_jax = jnp.arange(10, dtype=jnp.int32)

print("Array created using numpy: ", array_numpy)
print("Array created using JAX: ", array_jax)

print("array_numpy is of type : {}".format(type(array_numpy)))
print("array_jax is of type : {}".format(type(array_jax)))
```

```
Array created using numpy:  [0 1 2 3 4 5 6 7 8 9]
Array created using JAX:  [0 1 2 3 4 5 6 7 8 9]
array_numpy is of type : <class 'numpy.ndarray'>
array_jax is of type : <class 'jaxlib.xla_extension.DeviceArray'>
```


3. JAX arrays are immutable, just like TensorFlow tensors: unlike an **`ndarray`**, they don't support in-place item assignment.

```python
array1 = np.arange(5, dtype=np.int32)
array2 = jnp.arange(5, dtype=jnp.int32)

print("Original ndarray: ", array1)
print("Original DeviceArray: ", array2)

# Item assignment
array1[4] = 10
print("\nModified ndarray: ", array1)
print("\nTrying to modify DeviceArray-> ", end=" ")

try:
    array2[4] = 10
    print("Modified DeviceArray: ", array2)
except Exception as ex:
    print("{}\n{}".format(type(ex).__name__, ex))

# Functional update: .at[].set() returns a new array and leaves array2 unchanged
# (the older jax.ops.index_update was deprecated and later removed from JAX)
array2_modified = array2.at[4].set(10)
print("Jax Modified DeviceArray: ", array2_modified)
```

```
Original ndarray:  [0 1 2 3 4]
Original DeviceArray:  [0 1 2 3 4]

Modified ndarray:  [ 0  1  2  3 10]

Trying to modify DeviceArray->  TypeError
'<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
Jax Modified DeviceArray:  [ 0  1  2  3 10]
```
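
The error message above points to "another .at[] method". A small sketch, assuming a recent JAX install, of a few of those methods (`set`, `add`, `multiply`, `min`) on the same 5-element array; each returns a new array, and the original is never modified. Results are printed with `.tolist()` to keep the output format simple.

```python
import jax.numpy as jnp

x = jnp.arange(5, dtype=jnp.int32)

# Each .at[...] method returns a new array; x itself is never modified
print(x.at[4].set(10).tolist())                      # [0, 1, 2, 3, 10]
print(x.at[1:3].add(100).tolist())                   # [0, 101, 102, 3, 4]
print(x.at[jnp.array([0, 4])].multiply(3).tolist())  # [0, 1, 2, 3, 12]
print(x.at[2].min(-1).tolist())                      # [0, 1, -1, 3, 4]

# Unlike NumPy's x[idx] += y, repeated indices apply every update
print(x.at[jnp.array([1, 1])].add(5).tolist())       # [0, 11, 2, 3, 4]

print(x.tolist())                                    # [0, 1, 2, 3, 4]
```

Under `jax.jit`, XLA can often perform these updates in place when the input buffer is no longer needed, so the functional style does not necessarily imply a copy.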
