# 01 - Why jax - Motivating the need/usage

https://www.youtube.com/watch?v=z-WSrQDXkuM

## How might you build a performant & scalable deep neural network from scratch in Python?

### Missing stuff in `numpy`

- Running on accelerated hardware
- Fast optimisation using some sort of automatic differentiation
- Fusing operations via compilation
- Parallelisation of data and computation

### Jax fills these gaps:

- Jax uses XLA to compile code that runs on various devices (CPU/GPU/TPU)
- Automatic differentiation via `grad` (exact rather than numerical derivatives, obtained by tracing the function with tracer objects) and vectorisation via `vmap`
- `jit` compilation can fuse operations by compiling them together so they run efficiently
- Can parallelise across devices easily using `pmap`

## Jax effectively is 
- an extensible system 
- for building composable function transformations 
- using Python + numpy.
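
To make "composable" concrete, here is a minimal sketch that stacks the three transformations named above. The function `loss` and the sample inputs are made up for illustration and are not part of the original talk.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

def loss(x):
    # Hypothetical scalar function, chosen only for illustration.
    return jnp.sin(x) * x**2

# Each transformation takes a function and returns a new function,
# so they can be stacked: differentiate, then batch, then compile.
dloss = grad(loss)                 # derivative for a single scalar input
batched_dloss = jit(vmap(dloss))   # derivative for a whole batch, compiled

xs = jnp.linspace(0.0, 1.0, 5)
print(batched_dloss(xs))
```

The order matters here: `grad` expects a scalar-valued function, so it is applied first and `vmap` then lifts the resulting derivative to a batch.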

# Some code now

In [1]:
import numpy as np

In [2]:
x = np.random.rand(2000,2000)
x

array([[0.04273037, 0.23183986, 0.48447575, ..., 0.61164107, 0.46734011,
        0.2769249 ],
       [0.6870176 , 0.13298142, 0.87604064, ..., 0.26202938, 0.32005995,
        0.80281536],
       [0.88467422, 0.87342925, 0.07306913, ..., 0.08519661, 0.88314336,
        0.62023012],
       ...,
       [0.18655909, 0.43413159, 0.71624317, ..., 0.21963128, 0.59678682,
        0.25249576],
       [0.55847249, 0.88173011, 0.18338242, ..., 0.18198948, 0.85956361,
        0.46423807],
       [0.39157313, 0.03832188, 0.06445656, ..., 0.67164811, 0.95078567,
        0.85916955]])

In [3]:
np.tan(x)

array([[0.04275639, 0.23608493, 0.52631301, ..., 0.70136439, 0.50462418,
        0.28422788],
       [0.82033445, 0.13377089, 1.19995753, ..., 0.26819571, 0.33145594,
        1.03545551],
       [1.22124363, 1.19360588, 0.07319945, ..., 0.08540334, 1.2174367 ,
        0.71425647],
       ...,
       [0.18875401, 0.46363114, 0.87044293, ..., 0.22323229, 0.67943005,
        0.25800211],
       [0.62482365, 1.21393482, 0.18546611, ..., 0.18402564, 1.1605312 ,
        0.50073829],
       [0.41289504, 0.03834065, 0.06454598, ..., 0.79494027, 1.40070716,
        1.15960683]])

## On the GPU via Jax

In [4]:
import jax.numpy as jnp

In [5]:
y = jnp.array(x)
y

DeviceArray([[0.04273036, 0.23183987, 0.48447576, ..., 0.61164105,
              0.4673401 , 0.2769249 ],
             [0.6870176 , 0.13298142, 0.87604064, ..., 0.26202938,
              0.32005996, 0.8028154 ],
             [0.8846742 , 0.87342924, 0.07306913, ..., 0.08519661,
              0.88314337, 0.62023014],
             ...,
             [0.1865591 , 0.4341316 , 0.71624315, ..., 0.21963128,
              0.5967868 , 0.25249577],
             [0.5584725 , 0.8817301 , 0.18338242, ..., 0.18198948,
              0.8595636 , 0.46423808],
             [0.39157313, 0.03832189, 0.06445657, ..., 0.6716481 ,
              0.95078564, 0.85916954]], dtype=float32)

Note that the array is actually a `DeviceArray` (newer Jax releases name this type `jax.Array`). This is the abstraction Jax uses to represent an array irrespective of the device it lives on.
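
The output above also shows `dtype=float32`, although the NumPy input was `float64`. A small sketch of that conversion and of checking which backend is in use follows; it assumes a default Jax configuration (64-bit mode not enabled), and the array size is arbitrary.

```python
import numpy as np
import jax
import jax.numpy as jnp

x = np.random.rand(4, 4)        # NumPy defaults to float64
y = jnp.asarray(x)              # Jax defaults to float32 unless 64-bit mode is enabled

print(x.dtype, y.dtype)
print(jax.default_backend())    # which backend Jax picked, e.g. "cpu" or "gpu"
print(jax.devices())            # the devices Jax can see

# Copy the data back to host memory as a plain NumPy array.
back = np.asarray(y)
```

If no accelerator is available, the same code runs on the CPU backend without changes.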

In [6]:
jnp.sin(y)

DeviceArray([[0.04271736, 0.22976856, 0.46574453, ..., 0.5742118 ,
              0.4505132 , 0.27339903],
             [0.6342342 , 0.13258982, 0.7682102 , ..., 0.2590412 ,
              0.31462348, 0.71931475],
             [0.77370864, 0.76653576, 0.07300413, ..., 0.08509358,
              0.7727379 , 0.5812225 ],
             ...,
             [0.1854788 , 0.4206227 , 0.6565556 , ..., 0.21786977,
              0.5619876 , 0.24982136],
             [0.52989143, 0.7718401 , 0.18235631, ..., 0.18098655,
              0.7575578 , 0.44774166],
             [0.38164294, 0.03831251, 0.06441195, ..., 0.62227696,
              0.8138723 , 0.7573005 ]], dtype=float32)

In [7]:
y[0]

DeviceArray([0.04273036, 0.23183987, 0.48447576, ..., 0.61164105,
             0.4673401 , 0.2769249 ], dtype=float32)

Now let's time `np` and `jnp`

In [8]:
%timeit np.dot(x, x)
%timeit jnp.dot(y, y)

80 ms ± 390 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
157 µs ± 34.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


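The Jax operation appears much faster (`ms` vs `µs`). Treat the size of the gap with caution: Jax dispatches work asynchronously, so `%timeit jnp.dot(y, y)` can return before the result has been computed, and `y` is `float32` while `x` is `float64`. Calling `.block_until_ready()` on the result gives a fairer measurement.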
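
Because Jax dispatches work asynchronously, a timing is only meaningful if the call waits for the result. The sketch below uses `time.perf_counter` instead of `%timeit` so it runs outside a notebook; the helper `time_once` and the matrix size are assumptions made for illustration, and the numbers will vary by machine.

```python
import time
import numpy as np
import jax.numpy as jnp

x = np.random.rand(500, 500).astype(np.float32)   # same dtype on both sides
y = jnp.asarray(x)

def time_once(fn):
    # Wall-clock time of a single call, in seconds.
    start = time.perf_counter()
    out = fn()
    return out, time.perf_counter() - start

jnp.dot(y, y).block_until_ready()   # warm-up call so one-off setup cost is not timed

# Without blocking, the call may return before the product is computed.
_, t_dispatch = time_once(lambda: jnp.dot(y, y))
# With blocking, the timing includes the actual computation.
out, t_full = time_once(lambda: jnp.dot(y, y).block_until_ready())
_, t_numpy = time_once(lambda: np.dot(x, x))

print(f"jnp (dispatch only):     {t_dispatch:.6f} s")
print(f"jnp (block_until_ready): {t_full:.6f} s")
print(f"numpy:                   {t_numpy:.6f} s")
```

A single call is a noisy measurement; in a notebook, `%timeit jnp.dot(y, y).block_until_ready()` is the more robust equivalent.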

## Now let's look into JIT

`jit` fuses operations together into compiled, efficient code. Without it, the code below performs poorly, because every iteration of the loop dispatches its operations separately.

In [9]:
def f(z):
    for i in range(10):
        z -= 0.1 * z
    return z  # return the result; without this line f returns None

In [10]:
%timeit f(y)

1.09 ms ± 2.57 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [11]:
from jax import jit

A better, compiled function now:

In [12]:
g = jit(f)

In [13]:
%timeit g(y)

11.5 µs ± 64.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


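The jitted function is a huge improvement over the un-jitted Jax function (`ms` vs `µs`). Such timings are only meaningful if the function returns its result and the call waits for it with `.block_until_ready()`; otherwise the compiler is free to skip work whose result is never used.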
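
A minimal sketch of the same idea that can be checked for correctness follows. The name `decay` is hypothetical; it is the same loop as `f`, written out-of-place and returning its result. `jax.make_jaxpr` prints the traced program, which shows that the Python loop is unrolled before compilation.

```python
import jax
import jax.numpy as jnp
from jax import jit

def decay(z):
    # Same decay loop as f, returning the result so the work is not discarded.
    for _ in range(10):
        z = z - 0.1 * z
    return z

fast_decay = jit(decay)

z0 = jnp.ones((1000, 1000))
fast_decay(z0).block_until_ready()   # first call traces and compiles

# Inspect the traced program: the ten loop iterations appear one after another.
print(jax.make_jaxpr(decay)(z0))
```

The first call to a jitted function pays for tracing and compilation, so it is worth excluding from any timing.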

### On to Automatic differentiation

In [14]:
def f(x):
    return x**3 + 5*x**2

This function can be differentiated by hand as follows:

In [15]:
def manual_df(x):
    return 3*x**2 + 10*x

But we don't want to differentiate things by hand, as there are going to be tons of these in a neural network.

Let's use Jax's `grad` function.

In [16]:
from jax import grad

In [17]:
grad_df = grad(f)

Now, let's compare those two.

In [18]:
manual_df(4.0)

88.0

In [19]:
grad_df(4.0)

DeviceArray(88., dtype=float32)
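
Since `grad` returns an ordinary function, it can be applied again to get higher derivatives. The sketch below reuses `f` from above; `d2f` is a name introduced here for illustration. Note that `grad` expects a floating-point input, which is why the calls use `4.0` rather than `4`.

```python
import jax.numpy as jnp
from jax import grad, value_and_grad

def f(x):
    return x**3 + 5 * x**2

# Second derivative by differentiating twice: f''(x) = 6x + 10.
d2f = grad(grad(f))
print(d2f(4.0))

# value_and_grad returns f(x) and f'(x) from a single call.
value, slope = value_and_grad(f)(4.0)
print(value, slope)
```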

Now, just imagine that we can do this on a neural network scale automatically, jit all the functions, then run it on accelerators.
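
As a small-scale illustration of that idea, here is a sketch of gradient descent on a linear model, with the update step differentiated by `grad` and compiled by `jit`. The model, the synthetic data, the learning rate `LR`, and all function names are assumptions made for this example, not something from the talk.

```python
import jax
import jax.numpy as jnp
from jax import grad, jit

LR = 0.1  # hypothetical learning rate

def predict(params, x):
    w, b = params
    return x @ w + b

def mse(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

@jit
def step(params, x, y):
    # grad differentiates w.r.t. the first argument; the gradients
    # have the same (w, b) structure as params.
    grads = grad(mse)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)

# Synthetic data generated from known weights.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (64, 3))
true_w = jnp.array([1.0, -2.0, 0.5])
y = x @ true_w + 0.3

params = (jnp.zeros(3), jnp.array(0.0))
loss_before = mse(params, x, y)
for _ in range(200):
    params = step(params, x, y)
loss_after = mse(params, x, y)

print(loss_before, loss_after)
```

The parameters are a plain tuple here; `grad` and `tree_map` handle nested containers of arrays the same way, which is what makes the pattern carry over to larger networks.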

### Let's now vectorize a function (similar to broadcasting in numpy)

In [20]:
def square_sum(x):
    return jnp.sum(x**2)

In [21]:
xs = jnp.arange(100).reshape(10,10)
xs

DeviceArray([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
             [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
             [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
             [30, 31, 32, 33, 34, 35, 36, 37, 38, 39],
             [40, 41, 42, 43, 44, 45, 46, 47, 48, 49],
             [50, 51, 52, 53, 54, 55, 56, 57, 58, 59],
             [60, 61, 62, 63, 64, 65, 66, 67, 68, 69],
             [70, 71, 72, 73, 74, 75, 76, 77, 78, 79],
             [80, 81, 82, 83, 84, 85, 86, 87, 88, 89],
             [90, 91, 92, 93, 94, 95, 96, 97, 98, 99]], dtype=int32)

In [22]:
[square_sum(x) for x in xs]

[DeviceArray(285, dtype=int32),
 DeviceArray(2185, dtype=int32),
 DeviceArray(6085, dtype=int32),
 DeviceArray(11985, dtype=int32),
 DeviceArray(19885, dtype=int32),
 DeviceArray(29785, dtype=int32),
 DeviceArray(41685, dtype=int32),
 DeviceArray(55585, dtype=int32),
 DeviceArray(71485, dtype=int32),
 DeviceArray(89385, dtype=int32)]

But this is inefficient. Instead, we can vectorize it.

In [23]:
from jax import vmap

In [24]:
vmap(square_sum)(xs)

DeviceArray([  285,  2185,  6085, 11985, 19885, 29785, 41685, 55585,
             71485, 89385], dtype=int32)

How this is optimized under the hood is fairly complex: `vmap` transforms multiple function calls into a single batched function call.
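
When only some arguments carry a batch dimension, `vmap` takes an `in_axes` argument that says which axis of each input to map over. A minimal sketch, with `predict`, `w`, and the sample batch invented for illustration:

```python
import jax.numpy as jnp
from jax import vmap

def predict(w, x):
    # Written for ONE example: w and x are both 1-D vectors.
    return jnp.dot(w, x)

w = jnp.array([1.0, 2.0, 3.0])
xs = jnp.arange(12.0).reshape(4, 3)     # a batch of 4 examples

# in_axes=(None, 0): do not map over w, map over axis 0 of xs.
batched_predict = vmap(predict, in_axes=(None, 0))
print(batched_predict(w, xs))
```

The function itself never mentions the batch; the batching is added from the outside, which keeps per-example code simple.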

## Tomorrow we'll dig deeper into how JAX works and pulls off these operations