<a href="https://colab.research.google.com/github/ritwikraha/A-guide-to-ML-Workflows-with-JAX/blob/main/Evolution-of-JAX-and-its-Power-Tools.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Evolution of JAX & its Power Tools

As we discussed, I will walk through the evolution of JAX and how it was conceived. Aritra will then cover some of the power tools of JAX, making way for Soumik to finally walk you through a training loop in JAX.




# Imports and Setup

In [None]:
# Install autograd
!pip install --quiet autograd

In [None]:
# Import the necessary packages
from autograd import numpy as anp
from autograd import grad
from autograd import elementwise_grad as egrad

# The ideation of JAX

So obviously JAX is a whole new library and in this introductory lesson we will be going through the chapters shown here, to get a sense of what JAX is and where it came from. 

![Summary of Part 1](https://imgur.com/QCF888w.png)

![What is JAX](https://imgur.com/UvKZmPz.png)

JAX is the combination of `autograd` and `XLA`.


![JAX equation](https://imgur.com/Au7ExM9.png)

Okay! That is great. We just explained one new term by introducing two more. To understand the essence of what JAX really is, we need to take a look at autograd and XLA individually.


![Jargons too many of them](https://imgur.com/nvXtpxr.png)

## Various ways of doing differentiation


![Various types of differentiation](https://imgur.com/rf0mrnZ.png)

And the best way to talk about autograd is to talk about gradients.

Gradients run the deep learning world, quite literally. We backpropagate gradients through our DL models, which is what lets them learn from their mistakes.

To compute these gradients we have 4 options.


*   Manual: We use our calculus knowledge and derive the derivatives by hand. The problem with this approach is that it is manual. It would take a lot of time for a Deep Learning researcher to derive the model's derivatives by hand.
*   Symbolic: We can obtain the derivatives via symbols and a program that mimics the manual process. The problem with this approach is termed expression swell: the derivative of an expression can be exponentially longer (think chain rule) than the expression itself. This becomes quite difficult to track.
*   Numeric: Here, we use the finite differences method to approximate the derivatives. The problem with this approach is that the result is only an approximation, subject to truncation and round-off errors.
*   Automatic: The star ⭐️ of the show.  
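To make the numeric option from the list above concrete, here is a minimal sketch of a central finite difference. The helper name `central_difference` and the step size `1e-5` are illustrative choices, not part of any library.

```python
def get_square(value):
    # The function we want to differentiate: f(x) = x**2
    return value ** 2


def central_difference(func, point, step=1e-5):
    # Hypothetical helper: approximate f'(point) from two nearby evaluations
    return (func(point + step) - func(point - step)) / (2 * step)


approx_grad = central_difference(get_square, 1.0)
print(f"Approximate gradient at 1.0 => {approx_grad:.6f}")
```

The result is close to the exact value of 2, but how close depends on the step size, which is one reason automatic differentiation is preferred.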










![Autograd](https://imgur.com/GNPL1Ym.png)


Automatic differentiation (autodiff) is the type of differentiation we all love and use when training our deep neural networks. `autograd` is a Python package that performs automatic differentiation on native Python and NumPy code. The code base is fairly simple.


Two points to note here:
*  There is a light wrapper `autograd.numpy` around the native NumPy codebase. This allows users to use NumPy-like semantics, harnessing the power of automatic differentiation.
*  `autograd.grad` and `autograd.elementwise_grad` help with the actual automatic differentiation.


## `autograd`

$$f(x)=x^2$$

In [None]:
def get_square(value):
    # Return the square of the input
    return value**2

# Build a scalar input and pass it to the
# square function
value = 4.0
value_squared = get_square(value)
print(f"value => {value}\nvalue**2 => {value_squared}")

value => 4.0
value**2 => 16.0


$$f'(x)=2x$$

In [None]:
# Compute the derivative of the square function
grad_func = grad(get_square)

point = 1.0
# Retrieve the gradient of the function at a particular point
print(f"Gradient of square func at {point} => {grad_func(point)}")

Gradient of square func at 1.0 => 2.0


## Vectorization in `autograd`?

In [None]:
# Let's pass a vector to the square function
vector = anp.arange(1, 10, dtype=anp.float32)
out_vector = get_square(vector)

# Iterate over the vector and its output
for v, o in zip(vector, out_vector):
    print(f"Value at point {v} => {o}")

Value at point 1.0 => 1.0
Value at point 2.0 => 4.0
Value at point 3.0 => 9.0
Value at point 4.0 => 16.0
Value at point 5.0 => 25.0
Value at point 6.0 => 36.0
Value at point 7.0 => 49.0
Value at point 8.0 => 64.0
Value at point 9.0 => 81.0


In [None]:
# Now let's try the grad function with vector inputs
try:
    out_vector = grad_func(vector)
except Exception as ex:
    print(f"Type of exception => {type(ex).__name__}")
    print(f"Exception => {ex}")

Type of exception => TypeError
Exception => Grad only applies to real scalar-output functions. Try jacobian, elementwise_grad or holomorphic_grad.


In [None]:
# Using element wise gradient
egrad_func = egrad(get_square)

try:
    out_vector = egrad_func(vector)
    for v, o in zip(vector, out_vector):
        print(f"Grad at point {v} => {o}")
except Exception as ex:
    print(f"Type of exception => {type(ex).__name__}")
    print(f"Exception => {ex}")

Grad at point 1.0 => 2.0
Grad at point 2.0 => 4.0
Grad at point 3.0 => 6.0
Grad at point 4.0 => 8.0
Grad at point 5.0 => 10.0
Grad at point 6.0 => 12.0
Grad at point 7.0 => 14.0
Grad at point 8.0 => 16.0
Grad at point 9.0 => 18.0


# What about the other parent?

It is safe to say that the fields of Deep Learning (DL) and Machine Learning (ML) consist of an enormous amount of Linear Algebra. All computations from start to finish are mostly Linear Algebra. 


What if we told you there is a compiler in town that can make Linear Algebra operations more efficient? 


Enter XLA: XLA stands for Accelerated Linear Algebra. It is a domain-specific compiler that accelerates linear algebra operations. 

XLA is device agnostic: the same source code can be compiled for the CPU, GPU, and TPU with no code change.



![XLA](https://imgur.com/08C1Lqt.png)

Let me ask you the question again.

What is JAX? Having a fair amount of knowledge about XLA and Autograd should help you with a holistic overview of JAX. 

It should also make you more at ease for the things that are about to come. 


![What is JAX again?](https://imgur.com/o4FwAIK.png)

Before we start multiplying matrices and backpropagating on them, let us take a moment to understand the various components of JAX. While starting with a library, knowing its basic API design is always a good practice. 


The API design of JAX is done in a way where we have the high-level abstraction of `jax.numpy` and the low-level abstraction of `jax.lax`.


Where `jax.numpy` is similar to the original NumPy package, `jax.lax` is a wrapper around Google's XLA compiler.


Note: Did you notice that lax is an anagram of xla? 🤯


If you head over to the official documentation of JAX API, you will see several sub-packages and sub-topics with their APIs listed. These would be:



*   Just-in-time compilation (`jit`)
*   Automatic differentiation (`grad`)
*   Vectorization (`vmap`)
*   Parallelization (`pmap`)





And if this sounds alien, don't worry, Aritra will make this a cakewalk in a second.




![Most used APIs](https://imgur.com/Kr4YrZF.png)

In [None]:
import numpy as np

import jax
from jax import numpy as jnp
# Note: this `grad` shadows the one imported from autograd earlier
from jax import grad, jit, vmap, pmap, make_jaxpr

Before diving in, we must note that JAX is not a Deep Learning (DL) framework. Instead, it is a numerical computation library. It is just that DL falls into the numerical computation paradigm. 


For the ease of numerical computation, it has a NumPy API that mirrors the API of yet another very powerful numerical computation library (yes, you guessed it, NumPy 😁).


The thing that makes JAX stand out is its wrapper for the XLA compiler, `jax.lax`. The `jax.numpy` functions are themselves written in terms of the `jax.lax` API. This makes JAX code not only device agnostic but also jit compilable.


Being device agnostic means that the same code can be run on different hardware (CPUs, GPUs, and TPUs). With JIT compilation, the same code can run much faster and more efficiently. This is why JAX is referred to as NumPy on steroids. Soumik will be going through this in detail.



## `jax.numpy`

In [None]:
# Build an array of 0 to 9 with the `jax.numpy` API
array = jnp.arange(0, 10, dtype=jnp.int32)

print(f"array => {array}")
print(type(array))



array => [0 1 2 3 4 5 6 7 8 9]
<class 'jaxlib.xla_extension.ArrayImpl'>


In [None]:
# SET AT
jax_array = jnp.arange(1, 10, dtype=jax.numpy.int64)
numpy_array = np.arange(1, 10).astype(np.int64)

try:
    numpy_array[2] = 2
except Exception as ex:
    print(f"Type of exception => {type(ex).__name__}")
    print(f"Exception => {ex}")
# JAX arrays are immutable, so in-place assignment raises a TypeError
try:
    jax_array[2] = 2
except Exception as ex:
    print(f"Type of exception => {type(ex).__name__}")
    print(f"Exception => {ex}")

Type of exception => TypeError
Exception => '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html




In [None]:
# SET AT
try:
    mutated_jax_array = jax_array.at[2].set(200)
except Exception as ex:
    print(f"Type of exception => {type(ex).__name__}")
    print(f"Exception => {ex}")
print(f"Original Array => {jax_array}")
print(f"Mutated Array => {mutated_jax_array}")

Original Array => [1 2 3 4 5 6 7 8 9]
Mutated Array => [  1   2 200   4   5   6   7   8   9]


In [None]:
# INDEX OUT OF BOUNDS
try:
    print("Indexing 1000th position of a NumPy array...")
    print(numpy_array[1000])
except Exception as ex:
    print(type(ex).__name__)
    print(ex)

Indexing 1000th position of a NumPy array...
IndexError
index 1000 is out of bounds for axis 0 with size 9


In [None]:
# INDEX OUT OF BOUNDS
try:
    print("Indexing 1000th position of a JAX array...")
    print(jax_array[1000])
except Exception as ex:
    print(type(ex).__name__)
    print(ex)

Indexing 1000th position of a JAX array...
9


What happened here?

In JAX, out-of-bounds indices are clamped to the bounds of the array, so indexing position 1000 of a 9-element array quietly returns the last element instead of raising an error. This is a little caveat that we need to take care of so that we do not see our code fail silently.

There is an [issue](https://github.com/google/jax/issues/9839) (9839) discussing a fix for this by returning NaN as a default.

Unfortunately, the issue is still open.
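If the silent clamping is a concern, one option is to ask for a different out-of-bounds behaviour explicitly. The sketch below assumes the `mode` and `fill_value` arguments of `.at[...].get()` as described in the JAX documentation; the sentinel value `-1` is an arbitrary choice for illustration.

```python
import jax.numpy as jnp

jax_array = jnp.arange(1, 10, dtype=jnp.int32)

# Default behaviour: the index is clamped, so we get the last element
clamped = jax_array[1000]
print(f"Default indexing => {clamped}")

# Explicit behaviour: return a sentinel value for out-of-bounds reads
filled = jax_array.at[1000].get(mode="fill", fill_value=-1)
print(f"Indexing with mode='fill' => {filled}")
```

A sentinel makes the out-of-bounds read visible in the result instead of hiding it behind a plausible-looking value.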



![What is next?](https://i.imgur.com/pppIMk1.png)

# The power tools of JAX

![What we cover?](https://imgur.com/Cvh33Xj.png)

![Important Functional Transformations](https://imgur.com/KEMRsgp.png)

![All of the important functional transformations](https://imgur.com/TjspCdA.png)

![What are functional transformations](https://imgur.com/WDJWWHu.png)

![Definition](https://imgur.com/KzNskIu.png)

![Definition](https://imgur.com/rxQkhIH.png)

![Pure Functions](https://imgur.com/Nf1pI48.png)

![Definition](https://imgur.com/Mo1ygwF.png)

![Checklist](https://imgur.com/rWdu8SZ.png)

## StateFul and StateLess

In [None]:
class StateFul:
    def __init__(self):
        self.state = 0
    
    def change_state(self):
        self.state = self.state + 1
        output = self.state ** 2
        return output

stateful = StateFul()
print(f"Initial state => {stateful.state}")
output = stateful.change_state()
print(f"Output => {output}")
print(f"Changed state => {stateful.state}")

Initial state => 0
Output => 1
Changed state => 1


In [None]:
class StateLess:    
    def change_state(self, state):
        changed_state = state + 1
        output = changed_state ** 2
        return output, changed_state

stateless = StateLess()
initial_state = 0
print(f"Initial state => {initial_state}")
output, changed_state = stateless.change_state(state=initial_state)
print(f"Output => {output}")
print(f"Changed state => {changed_state}")

Initial state => 0
Output => 1
Changed state => 1


In [None]:
from typing import NamedTuple, Any

In [None]:
class PureState(NamedTuple):
    state: Any

    def update_state(self, new_state):
        return PureState(new_state)

p1 = PureState(1)
p2 = p1.update_state(2)

print(p1)
print(p2)

PureState(state=1)
PureState(state=2)
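Why does the stateless style matter so much? JAX transformations trace a function once and reuse the traced result, so a Python side effect inside the function is not replayed on later calls. Here is a minimal sketch of that behaviour; the `counter` dictionary and the function name `impure_add_one` are made up for this illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical piece of external state that the function mutates
counter = {"calls": 0}


@jax.jit
def impure_add_one(x):
    # Side effect: runs while JAX traces the function, not on every call
    counter["calls"] += 1
    return x + 1


# Three calls with inputs of the same shape and dtype
results = [impure_add_one(jnp.array(float(i))) for i in range(1, 4)]

print(f"Results => {[float(r) for r in results]}")
print(f"Side effect ran {counter['calls']} time(s)")
```

The returned values are correct for all three calls, but the counter does not track them, which is exactly the kind of silent mismatch that pure functions avoid.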


## jaxpr

![jaxpr](https://imgur.com/VQuRKVP.png)

![Definition](https://imgur.com/xbmzea4.png)

![jaxpr](https://imgur.com/IiynmyJ.png)

![jaxpr illustrated](https://imgur.com/KDO7A2i.png)

In [None]:
def demo_function(arg1, arg2, arg3):
    temp = arg1 + arg2
    temp = temp * arg3
    return temp / 3.0

make_jaxpr(demo_function)(1.0, 1.0, 1.0)

{ lambda ; a:f32[] b:f32[] c:f32[]. let
    d:f32[] = add a b
    e:f32[] = mul d c
    f:f32[] = div e 3.0
  in (f,) }

![illustrated function](https://imgur.com/nOA1NSq.png)

## `grad`

![jax grad](https://imgur.com/LP5SppJ.png)

In [None]:
def equation(x):
    return 4*x**3 + 3*x**2 + 2*x + 1

make_jaxpr(equation)(2.0)

{ lambda ; a:f32[]. let
    b:f32[] = integer_pow[y=3] a
    c:f32[] = mul 4.0 b
    d:f32[] = integer_pow[y=2] a
    e:f32[] = mul 3.0 d
    f:f32[] = add c e
    g:f32[] = mul 2.0 a
    h:f32[] = add f g
    i:f32[] = add h 1.0
  in (i,) }

![jaxpr of the function](https://imgur.com/odXjT3D.png)

In [None]:
equation_first_der = grad(equation)
make_jaxpr(equation_first_der)(2.0)

{ lambda ; a:f32[]. let
    b:f32[] = integer_pow[y=3] a
    c:f32[] = integer_pow[y=2] a
    d:f32[] = mul 3.0 c
    e:f32[] = mul 4.0 b
    f:f32[] = integer_pow[y=2] a
    g:f32[] = integer_pow[y=1] a
    h:f32[] = mul 2.0 g
    i:f32[] = mul 3.0 f
    j:f32[] = add e i
    k:f32[] = mul 2.0 a
    l:f32[] = add j k
    _:f32[] = add l 1.0
    m:f32[] = mul 2.0 1.0
    n:f32[] = mul 3.0 1.0
    o:f32[] = mul n h
    p:f32[] = add_any m o
    q:f32[] = mul 4.0 1.0
    r:f32[] = mul q d
    s:f32[] = add_any p r
  in (s,) }

![jaxpr of the derivative of the function defined](https://imgur.com/TZBlWUF.png)

In [None]:
equation_first_der(2.0)

Array(62., dtype=float32, weak_type=True)

In [None]:
equation_second_der = grad(equation_first_der)
equation_second_der(2.0)

Array(54., dtype=float32, weak_type=True)

In [None]:
equation_third_der = grad(equation_second_der)
equation_third_der(2.0)

Array(24., dtype=float32, weak_type=True)
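
As a quick sanity check, we can derive the same three values by hand for the polynomial defined in `equation`:

$$f(x) = 4x^3 + 3x^2 + 2x + 1$$

$$f'(x) = 12x^2 + 6x + 2 \quad\Rightarrow\quad f'(2) = 48 + 12 + 2 = 62$$

$$f''(x) = 24x + 6 \quad\Rightarrow\quad f''(2) = 48 + 6 = 54$$

$$f'''(x) = 24 \quad\Rightarrow\quad f'''(2) = 24$$

These match the outputs of the nested `grad` calls above.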

## `jit`

![jit](https://imgur.com/QNB8mMN.png)

![steps for jit](https://imgur.com/eoRqRRM.png)

In [None]:
def matrix_mul(a, b):
    return jnp.matmul(a, b)

key = jax.random.PRNGKey(42)

a = jax.random.normal(key, shape=(1000, 5000))
b = jax.random.normal(key, shape=(5000, 1000))

make_jaxpr(matrix_mul)(a, b)

{ lambda ; a:f32[1000,5000] b:f32[5000,1000]. let
    c:f32[1000,1000] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a b
  in (c,) }

In [None]:
%timeit -n5 matrix_mul(a, b).block_until_ready()

252 ms ± 86.9 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


In [None]:
jit_matrix_mul = jit(matrix_mul)

make_jaxpr(jit_matrix_mul)(a, b)

{ lambda ; a:f32[1000,5000] b:f32[5000,1000]. let
    c:f32[1000,1000] = pjit[
      jaxpr={ lambda ; d:f32[1000,5000] e:f32[5000,1000]. let
          f:f32[1000,1000] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
          ] d e
        in (f,) }
      name=matrix_mul
    ] a b
  in (c,) }

![xla in the jaxpr](https://imgur.com/7mkyL2n.png)

In [None]:
# warmup
warmup_results = jit_matrix_mul(a, b)

# ⚡️ speed em up!
%timeit -n5 jit_matrix_mul(a, b).block_until_ready()

183 ms ± 25.8 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


## Can we JIT everything?

In [None]:
# Let's break JIT now
@jit
def f(x):
    if x > 0:
        return x+1
    else:
        return x

In [None]:
try:
    f(10)
except Exception as ex:
    print(f"Type of exception => {type(ex).__name__}")
    print(f"Exception => {ex}")

Type of exception => ConcretizationTypeError
Exception => Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
The problem arose with the `bool` function. 
The error occurred while tracing the function f at <ipython-input-46-e3be4ae0da4e>:2 for jit. This concrete value was not available in Python because it depends on the value of the argument x.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError


In [None]:
@jit
def f(x):
    return jnp.where(x > 0, x+1, x)

In [None]:
try:
    # Print the result so we can see that the rewritten function works
    print(f"f(10) => {f(10)}")
except Exception as ex:
    print(f"Type of exception => {type(ex).__name__}")
    print(f"Exception => {ex}")
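
The first version failed because a Python `if` needs a concrete `True` or `False`, while under `jit` the value of `x` is only an abstract tracer. Besides the `jnp.where` rewrite above, another option is structured control flow. The sketch below uses `jax.lax.cond`, which selects between two branch functions at run time; the name `f_cond` is only for illustration.

```python
import jax.numpy as jnp
from jax import jit, lax


@jit
def f_cond(x):
    # lax.cond(pred, true_fun, false_fun, *operands)
    # Both branches must return values of the same shape and dtype
    return lax.cond(x > 0, lambda v: v + 1, lambda v: v, x)


positive_result = f_cond(jnp.array(10))
negative_result = f_cond(jnp.array(-3))
print(f"f_cond(10) => {positive_result}")
print(f"f_cond(-3) => {negative_result}")
```

`jnp.where` is usually the simpler choice for cheap elementwise expressions like this one, while `lax.cond` is handy when the branches are larger functions.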

## `vmap`

In [None]:
a = jnp.array([1.0, 4.0, 0.5])
b = jnp.arange(5, 10, dtype=jnp.float32)

def weighted_mean(a, b):
    output = []
    for idx in range(1, b.shape[0]-1):
        output.append(jnp.mean(a + b[idx-1 : idx+2]))
    return jnp.array(output)

print(f"a => {a.shape}")
print(f"b => {b.shape}")
output = weighted_mean(a, b)
print(f"output => {output.shape}")

a => (3,)
b => (5,)
output => (3,)


In [None]:
# Let's include the batch dim to the inputs
batch_size = 8
batched_a = jnp.stack([a] * batch_size)
batched_b = jnp.stack([b] * batch_size)

print(f"batched_a => {batched_a.shape}")
print(f"batched_b => {batched_b.shape}")

batched_a => (8, 3)
batched_b => (8, 5)


In [None]:
batched_weighted_mean = vmap(weighted_mean)
batched_output = batched_weighted_mean(batched_a, batched_b)

print(f"batched output => {batched_output.shape}")

batched output => (8, 3)
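
In the cell above both inputs carry a batch dimension. Often only one of them does, for example when the same weights `a` are shared across the whole batch. A sketch of that case using the `in_axes` argument of `vmap` follows; the name `shared_a_weighted_mean` is made up for this illustration.

```python
import jax.numpy as jnp
from jax import vmap


def weighted_mean(a, b):
    # Same sliding-window function as in the cells above
    output = []
    for idx in range(1, b.shape[0] - 1):
        output.append(jnp.mean(a + b[idx - 1 : idx + 2]))
    return jnp.array(output)


a = jnp.array([1.0, 4.0, 0.5])
b = jnp.arange(5, 10, dtype=jnp.float32)
batched_b = jnp.stack([b] * 8)

# in_axes=(None, 0): do not map over `a`, map over axis 0 of `b`
shared_a_weighted_mean = vmap(weighted_mean, in_axes=(None, 0))
batched_output = shared_a_weighted_mean(a, batched_b)

print(f"a => {a.shape}")
print(f"batched_b => {batched_b.shape}")
print(f"batched output => {batched_output.shape}")
```

This avoids stacking copies of `a` just to satisfy the batch dimension.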


## `pmap`

For this section, you need to open the `Runtime` menu of the Colab notebook and change the runtime type to TPU.

In [None]:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
import jax
jax.devices()

In [None]:
from jax import numpy as jnp
from jax import pmap
from jax import random

In [None]:
key = random.PRNGKey(42)
a = random.normal(key, shape=(3000,5000))
b = random.normal(key, shape=(5000,3000))

matrix_mul = lambda a, b: jnp.matmul(a, b)
matrix_mul(a, b).shape

In [None]:
n_devices = jax.local_device_count()
a = random.normal(key, shape=(n_devices, 3000, 5000))
b = random.normal(key, shape=(n_devices, 5000, 3000))

parallel_matrix_mul = pmap(matrix_mul)
parallel_matrix_mul(a, b).shape
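
The key requirement in the cell above is that the leading axis of every input equals the number of devices, because `pmap` hands one slice to each device. Here is a smaller sketch of the same idea that should also run on a single-device machine; the names `x` and `row_sums` are only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import pmap

n_devices = jax.local_device_count()

# One row per device: the leading axis is what pmap splits across devices
x = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)

# Each device sums its own row
row_sums = pmap(lambda row: jnp.sum(row))(x)

print(f"devices => {n_devices}")
print(f"input => {x.shape}")
print(f"output => {row_sums.shape}")
```

With one device this behaves like an ordinary call on a single row; with eight TPU cores the eight rows are processed in parallel.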

## 🎲 Randomness

In [None]:
import numpy as np
# random number generation using numpy

np.random.seed(42)
rn1 = np.random.normal()
rn2 = np.random.normal()
print(f"NumPy Random Number Generation: {rn1: .2f} {rn2: .2f}")


In [None]:
from jax import random

key = random.PRNGKey(65)

print(key)

jrn1 = random.normal(key)
jrn2 = random.normal(key)

print(f"JAX Random Number Generation: {jrn1: .2f} {jrn2: .2f}")

In [None]:
print("JAX original key", key)
mod_key, subkey = random.split(key)

print("JAX modified key", mod_key)
print("JAX sub key", subkey)
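
Putting the two cells above together: reusing a key gives the same number again, while splitting the key before each draw gives fresh numbers. A minimal sketch of the usual pattern follows, with variable names chosen only for illustration.

```python
from jax import random

key = random.PRNGKey(65)

# Reusing the same key reproduces the same sample
same1 = random.normal(key)
same2 = random.normal(key)
print(f"Same key: {same1: .2f} {same2: .2f}")

# Split before every draw: keep `key` for later, consume `subkey` now
key, subkey = random.split(key)
new1 = random.normal(subkey)

key, subkey = random.split(key)
new2 = random.normal(subkey)
print(f"Split keys: {new1: .2f} {new2: .2f}")
```

Unlike NumPy's global seed, the random state here is passed around explicitly, which keeps functions that draw random numbers pure.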

![What is next?](https://i.imgur.com/pppIMk1.png)