<a href="https://colab.research.google.com/github/rahuldave/LearningJax/blob/main/Jax_first.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Jax.first

What is Jax? Its readme describes it as: "JAX is really an extensible system for composable function transformations", which does not tell us much.

A more germane question might be: why use it? The answer is that it enables many different kinds of interesting applications, all of which run very fast.

For example, you can make BERT 12 times faster than the default implementation.

Yes, neural networks. But also gradient boosting, solving differential equations, high performance computing, etc.

How?

![](https://i.imgur.com/COjvmnE.png)

This diagram, and a detailed discussion of why and when you should use jax, can be found at https://www.assemblyai.com/blog/why-you-should-or-shouldnt-be-using-jax-in-2022/#references .

Let us address its features one by one.

In [1]:
!cat /proc/cpuinfo

processor	: 0
vendor_id	: GenuineIntel
cpu family	: 6
model		: 79
model name	: Intel(R) Xeon(R) CPU @ 2.20GHz
stepping	: 0
microcode	: 0x1
cpu MHz		: 2199.998
cache size	: 56320 KB
physical id	: 0
siblings	: 2
core id		: 0
cpu cores	: 1
apicid		: 0
initial apicid	: 0
fpu		: yes
fpu_exception	: yes
cpuid level	: 13
wp		: yes
flags		: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities
bugs		: cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa
bogomips	: 4399.99
clflush size	: 64
cache_alignment	: 64
address sizes	: 46 bits physical, 48 b

In [2]:
USE_TPU = True

if USE_TPU:
  import jax
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()
else:
  # Simulate 2 CPU devices; XLA_FLAGS must be set before jax is imported
  import os
  os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=2'
  import jax

In [3]:
num_replicas = len(jax.devices())
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

## 1:  `jax.numpy` replaces `numpy` and makes it faster

`jax.numpy` provides an (almost) drop-in replacement for numpy that runs ops on whatever accelerator you have.

This section uses code from https://colab.research.google.com/github/probml/probml-notebooks/blob/main/notebooks/jax_intro.ipynb#scrollTo=D2dSgG-cnVIe

In [4]:
import numpy as np
import jax.numpy as jnp

Everything (almost) that you might call in numpy can be obtained using jax. The advantage is that there are a bunch of optimizations you can make use of...

In [5]:
size = int(1e3)
number_of_loops=int(1e2)

In [6]:
def f(x=None):
  if not isinstance(x, np.ndarray):
    x=np.ones((size, size), dtype=np.float32) 
  return np.dot(x, x.T)

In [7]:
%timeit -o -n $number_of_loops f()

100 loops, best of 5: 44.7 ms per loop


<TimeitResult : 100 loops, best of 5: 44.7 ms per loop>

JAX supports execution on XLA devices through a process called `jit`ting, in which your python code is first compiled down to an intermediate representation language, and then, using XLA, into machine code appropriate for a CPU, a GPU, or a TPU.

In [8]:
# JAX device execution
# https://github.com/google/jax/issues/1598

def jf(x=None): 
  if not isinstance(x, jnp.ndarray):
    x=jnp.ones((size, size), dtype=jnp.float32)
  return jnp.dot(x, x.T)

In [9]:
from jax import jit
f_tpu = jit(jf)
f_cpu = jit(jf, backend='cpu')

In [10]:
for i in range(3):
  %time f_cpu() 

CPU times: user 224 ms, sys: 28.9 ms, total: 253 ms
Wall time: 173 ms
CPU times: user 96.9 ms, sys: 10 ms, total: 107 ms
Wall time: 58.4 ms
CPU times: user 95 ms, sys: 0 ns, total: 95 ms
Wall time: 53.5 ms


Sometimes JAX `jit`ted CPU can be slower than numpy CPU, but if a function is complex, this is usually not the case.

In [11]:
for i in range(3):
  %time f_tpu() 

CPU times: user 198 ms, sys: 436 ms, total: 633 ms
Wall time: 1.11 s
CPU times: user 10.5 ms, sys: 10.5 ms, total: 20.9 ms
Wall time: 30.7 ms
CPU times: user 12.5 ms, sys: 18 ms, total: 30.5 ms
Wall time: 49.1 ms


In [12]:
%timeit -o -n $number_of_loops f_tpu().block_until_ready() 

100 loops, best of 5: 1.88 ms per loop


<TimeitResult : 100 loops, best of 5: 1.88 ms per loop>

Why is there a `block_until_ready` there? It is needed to get accurate timings. This is because JAX does not:

>wait for the operation to complete before returning control to the Python program. Instead, JAX returns a DeviceArray value, which is a future, i.e., a value that will be produced in the future on an accelerator device but isn’t necessarily available immediately. We can inspect the shape or type of a DeviceArray without waiting for the computation that produced it to complete, and we can even pass it to another JAX computation, as we do with the addition operation here. Only if we actually inspect the value of the array from the host, for example by printing it or by converting it into a plain old numpy.ndarray will JAX force the Python code to wait for the computation to complete.

(from https://jax.readthedocs.io/en/latest/async_dispatch.html)

The `DeviceArray` here replaces numpy's standard `ndarray`. Because of the design above it is also "lazy", though in a sense slightly different from the general concept of a lazy computation.

For example, computations in Spark are lazy because they can be composed, the idea being that a compiler can figure out how to "fuse" operations to make them more performant. In Spark, if only a portion of a dataframe is required, the compiler might figure out that the previous computation didn't need the whole dataframe either, and optimize accordingly.

Pretty much the same happens here, except that the result of a `DeviceArray` computation is kept on the device where it is being done, for example the TPU. Indeed, if you `jit` a function, the XLA compiler might be able to figure a faster code path and run that instead.
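
To make the role of `block_until_ready` concrete, here is a minimal sketch (the names `matmul`, `dispatch_time` and `total_time` are ours, not from the cells above). It times the same jitted call up to two points: when JAX hands back the array, and when the result is actually ready. On a CPU backend the two can be close; on an accelerator the gap is usually larger.

```python
import time

import jax
import jax.numpy as jnp

x = jnp.ones((500, 500), dtype=jnp.float32)
matmul = jax.jit(lambda m: jnp.dot(m, m.T))

# Warm-up call so that compilation is not part of the measurement.
matmul(x).block_until_ready()

start = time.perf_counter()
y = matmul(x)                      # returns once the work has been dispatched
dispatch_time = time.perf_counter() - start
y.block_until_ready()              # wait until the result really exists
total_time = time.perf_counter() - start

print(f"dispatch: {dispatch_time:.6f}s, total: {total_time:.6f}s")
```

Timing only the dispatch would under-report the cost of the computation, which is why the `%timeit` cell above blocks on the result.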

## 2: Function transformation

The JAX readme tells us that JAX is a function composition engine, and indeed this is true. Consider a simple quadratic function.

In [13]:
def f(x):
  return 2*x*x + 3*x + 3

f(3)

30

We can use `make_jaxpr` to trace the function and show us a jax expression (a jaxpr) of what the function does:

In [14]:
from jax import make_jaxpr

trace_f = make_jaxpr(f)

trace_f(3)

{ lambda ; a:i32[]. let
    b:i32[] = mul a 2
    c:i32[] = mul b a
    d:i32[] = mul a 3
    e:i32[] = add c d
    f:i32[] = add e 3
  in (f,) }

Notice that this trace is done at a particular value. This is an important point; keep it at the back of your mind.
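
A small sketch of what "at a particular value" means in practice (the name `quad` is ours; it is the same quadratic as above). Tracing with two different integers gives a jaxpr with the same number of operations, while tracing with a float changes the recorded dtype. This suggests that the trace depends on the type of the example argument rather than on the number itself.

```python
from jax import make_jaxpr

def quad(x):
    return 2 * x * x + 3 * x + 3

trace_at_3 = make_jaxpr(quad)(3)
trace_at_7 = make_jaxpr(quad)(7)
trace_at_3f = make_jaxpr(quad)(3.0)

# Same number of primitive operations whichever integer we trace with...
print(len(trace_at_3.jaxpr.eqns), len(trace_at_7.jaxpr.eqns))
# ...but the input dtype recorded in the trace follows the argument's type.
print(trace_at_3.in_avals[0].dtype, trace_at_3f.in_avals[0].dtype)
```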

Now this is just a function. But what is a function transformation? It is a function that takes in a function and pops out another function.

Such transformations are at the core of functional programming. They provide a simple way to compose functions into useful pieces of code.

For example, consider:

In [15]:
def sum_of_anything(f):
  def doit(a, b):
    return f(a) + f(b)
  return doit

In [18]:
sum_of_squares = sum_of_anything(lambda x: x*x)
sum_of_squares(3, 4)

25

In [19]:
sum_of_cubes = sum_of_anything(lambda x: x*x*x)
sum_of_cubes(3, 4)

91

By returning a function we have created a general system that can be applied to any function of one variable, and return a sum.

Jax implements four key function transformations for us, plus a fifth for introspection, which enable, in this functional style above, a lot of great functionality. These are (the first four from https://www.assemblyai.com/blog/why-you-should-or-shouldnt-be-using-jax-in-2022/#references):

1. `grad()` for evaluating the gradient function of the input function
2. `jit()` to transform functions into just-in-time compiled versions
3. `vmap()` for automatic vectorization of operations
4. `pmap()` for easy parallelization of computations
5. `make_jaxpr()` to get insight into what any jax function is doing

Let's tackle these one by one and see how they compose for us a lovely deep learning and scientific computation system...
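
Before going through them, here is a minimal sketch of what "compose" means here (the names `quad`, `d_quad`, `d2_quad` and `fast_d_quad` are ours). Each transformation takes a function and returns a function, so the output of one can be fed straight into another.

```python
import jax

def quad(x):
    return 2 * x * x + 3 * x + 3

d_quad = jax.grad(quad)        # derivative: 4x + 3
d2_quad = jax.grad(d_quad)     # derivative of the derivative: 4
fast_d_quad = jax.jit(d_quad)  # compiled version of the derivative

print(d_quad(3.0), d2_quad(3.0), fast_d_quad(3.0))
```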


##  3: Automatic Differentiation and  Gradients

There is a key difference between Jax and other deep learning frameworks like tensorflow and pytorch. Rather than computing the gradient of a loss function at a certain point by backpropagation through a computational graph, Jax implements a function transformation and returns the gradient function.

What is a computational graph? Consider the trace we saw above:




In [20]:
trace_f(3)

{ lambda ; a:i32[]. let
    b:i32[] = mul a 2
    c:i32[] = mul b a
    d:i32[] = mul a 3
    e:i32[] = add c d
    f:i32[] = add e 3
  in (f,) }

This computation can be described by the following graph:

![](https://i.imgur.com/V0y5oqF.jpg)

Now, basic rules of calculus can be used to differentiate this graph. The question is: how is this done?

`Pytorch`, for example, has its own `autograd` engine, inspired in part by the original `autograd` library. As very nicely described in Sabrina Mielke's post, the rough idea there is:

![](https://sjmielke.com/images/blog/jax-purify/comparison_small.png)

As she describes:

>PyTorch builds up a graph as you compute the forward pass, and one call to backward() on some “result” node then augments each intermediate node in the graph with the gradient of the result node with respect to that intermediate node. JAX on the other hand makes you express your computation as a Python function, and by transforming it with grad() gives you a gradient function that you can evaluate like your computation function—but instead of the output it gives you the gradient of the output with respect to (by default) the first parameter that your function took as input:

In [22]:
from jax import grad

g_f = grad(f)
print(type(g_f))
g_f(3.)

<class 'function'>


DeviceArray(15., dtype=float32, weak_type=True)

Currently (from the AssemblyAI article):

>With grad(), you can differentiate through native Python and NumPy functions, such as loops, branches, recursion, closures, and “PyTrees” (e.g. dictionaries).

`make_jaxpr` tells us what is happening in `grad`:

In [23]:
trace = make_jaxpr(grad(f))
trace(3.)

{ lambda ; a:f32[]. let
    b:f32[] = mul a 2.0
    c:f32[] = mul b a
    d:f32[] = mul a 3.0
    e:f32[] = add c d
    _:f32[] = add e 3.0
    f:f32[] = mul 1.0 3.0
    g:f32[] = mul b 1.0
    h:f32[] = mul 1.0 a
    i:f32[] = add_any f g
    j:f32[] = mul h 2.0
    k:f32[] = add_any i j
  in (k,) }
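
The trace can be checked by hand. For `f(x) = 2*x*x + 3*x + 3` the derivative is `4*x + 3`, which is 15 at `x = 3`. In the jaxpr, `f` is the 3 coming from `3*x`, while `g` and `j` each contribute `2*a`. The sketch below (names `quad`, `analytic`, `auto` and `numeric` are ours) compares `grad` against that formula and against a central finite difference.

```python
import jax

def quad(x):
    return 2 * x * x + 3 * x + 3

def analytic(x):
    # Derivative worked out by hand.
    return 4 * x + 3

x0 = 3.0
auto = float(jax.grad(quad)(x0))

# Central finite difference, evaluated in plain Python floats.
eps = 1e-3
numeric = (quad(x0 + eps) - quad(x0 - eps)) / (2 * eps)

print(auto, analytic(x0), numeric)
```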

`value_and_grad` gives us the value of the function as well as the gradient:

In [24]:
from jax import value_and_grad

vg_f = value_and_grad(f)

value, gradient = vg_f(3.)
value, gradient

(DeviceArray(30., dtype=float32, weak_type=True),
 DeviceArray(15., dtype=float32, weak_type=True))

## 4: `jit`ting

We have seen `jit`ting before, but the summary is that `jit` uses [xla](https://www.tensorflow.org/xla) to transform a python function into another one with the same signature, but optimised for the accelerator you are running on (e.g. GPU or TPU).

The code is compiled at runtime, so the first execution will be slower.

One of the key things that XLA will do for us is to fuse kernels: in other words, merge several operations into one so that intermediate results do not have to be written out to memory. This is done at the intermediate-representation (IR) level.

In [25]:
from jax import jit

jitted_f = jit(f)

jitted_f(3)

DeviceArray(30, dtype=int32, weak_type=True)

In [26]:
make_jaxpr(jitted_f)(3)

{ lambda ; a:i32[]. let
    b:i32[] = xla_call[
      call_jaxpr={ lambda ; c:i32[]. let
          d:i32[] = mul c 2
          e:i32[] = mul d c
          f:i32[] = mul c 3
          g:i32[] = add e f
          h:i32[] = add g 3
        in (h,) }
      name=f
    ] a
  in (b,) }

In [27]:
def ff(x,y):
  return x**2 + y

# Partial derivatives: argnums selects the argument to differentiate with respect to
x, y = 2.0, 3.0
v, gx = value_and_grad(ff, argnums=0)(x,y)
print(v)
print(gx)

gy = grad(ff, argnums=1)(x,y)
print(gy)

7.0
4.0
1.0


You can differentiate to any order. Higher-order optimization routines usually want Hessians, which are very easy to calculate (more later).
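
As a minimal sketch of going to second order (the function `bowl` and the variable names are ours): nesting `grad` gives a second partial derivative, and `jax.hessian` returns the full matrix of second derivatives for a function of a vector.

```python
import jax
import jax.numpy as jnp

def ff(x, y):
    return x**2 + y

# Second partial derivative with respect to x, by applying grad twice.
d2_dx2 = jax.grad(jax.grad(ff, argnums=0), argnums=0)(2.0, 3.0)

def bowl(v):
    # A small quadratic with a cross term, so the Hessian is not diagonal.
    return jnp.sum(v**2) + v[0] * v[1]

H = jax.hessian(bowl)(jnp.array([1.0, 2.0]))

print(d2_dx2)
print(H)
```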

One thing you have probably been wondering about (because we alluded to it) is why these traces always have an argument.

The answer lies in how `jit`ting is carried out. Remember that we first convert a function to the IR called jaxpr, and then the XLA JIT compiler will particularize the code for CPU, GPU, or TPU.

This jaxpr is created by tracing the function on a specific example input, and under `jit` the trace only sees that input's shape and dtype, not its concrete value. So functions with wildly different behavior for different types, or with conditional branches on values, need to be treated more carefully. I'll defer to the documentation and https://colab.research.google.com/github/probml/probml-notebooks/blob/main/notebooks/jax_intro.ipynb#scrollTo=6Ps1W8LhKKj9 for more details.



In [29]:
def f3(x):
  if x > 0:
    return x
  else:
    return 2 * x

print(f3(3.0))  # the plain Python version works

f3_jit = jit(f3)
print(f3_jit(3.0))  # fails: `x > 0` needs a concrete value, but x is a tracer

3.0


ConcretizationTypeError: ignored
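
Two common ways around this error are sketched below (the names `f3_where` and `f3_static` are ours). The first replaces the Python `if` with `jnp.where`, so that both branches live inside the traced computation. The second marks the argument as static via `static_argnums`, which lets the Python `if` run at trace time, at the cost of a recompilation for every new value of that argument.

```python
import jax.numpy as jnp
from jax import jit

def f3_where(x):
    # Both branches are computed; the condition selects between them.
    return jnp.where(x > 0, x, 2 * x)

f3_where_jit = jit(f3_where)

def f3(x):
    if x > 0:
        return x
    else:
        return 2 * x

# Treat argument 0 as a compile-time constant (it must be hashable).
f3_static = jit(f3, static_argnums=0)

print(f3_where_jit(3.0), f3_where_jit(-3.0))
print(f3_static(3.0), f3_static(-3.0))
```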

## 5: Vectorization with `vmap`

Jax provides you two places where vectorization is useful: first, taking a function that acts on one input and making it act on more than one, and second, having the vectorization apply on a batch axis, so that the same function can be executed across multiple samples in a batch.

Here is an example of the first kind of vectorization.
We want to carry out a logistic regression prediction in this example from https://colab.research.google.com/github/probml/probml-notebooks/blob/main/notebooks/jax_intro.ipynb#scrollTo=D2dSgG-cnVIe. Let us write the code to do this on a single example:

In [19]:
D = 2
N = 3

w = np.random.normal(size=(D,))
X = np.random.normal(size=(N,D))

def sigmoid(x): return 0.5 * (jnp.tanh(x / 2.) + 1)

def predict_single(x):
    return sigmoid(jnp.dot(w, x)) # <(D) , (D)> = (1) # inner product

In [20]:
print(predict_single(X[0,:])) # works

0.37815177


Clearly we can't use this code to predict on a whole batch of inputs...

In [21]:
print(predict_single(X)) 

TypeError: ignored

So we could do the matrix multiplication ourselves, taking care to match dimensions...

In [22]:
def predict_batch(X):
    return sigmoid(jnp.dot(X, w)) # (N,D) @ (D,) = (N,) # matrix-vector multiply

print(predict_batch(X)) 

[0.37815177 0.55771613 0.4303379 ]


Or let `jax` do it using `vmap`.

In [24]:
from jax import vmap
print(vmap(predict_single)(X))

[0.37815177 0.55771613 0.4303379 ]


`vmap` vectorizes over the first axis of each of its inputs. In the above example there is only one input, which is 3 samples of size 2. The first axis of this 2-D array is the 0th axis, or the sample axis, which is exactly what we want in a prediction system.
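
A minimal sketch of what this means, assuming a seeded NumPy generator instead of the unseeded draws above (`via_vmap`, `via_loop` and `via_matmul` are our names): `vmap` gives the same answer as looping over the rows and stacking the results, and the same answer as the hand-written batch version.

```python
import numpy as np
import jax.numpy as jnp
from jax import vmap

rng = np.random.default_rng(0)
D, N = 2, 3
w = rng.normal(size=(D,))
X = rng.normal(size=(N, D))

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2.0) + 1)

def predict_single(x):
    return sigmoid(jnp.dot(w, x))      # one sample of shape (D,)

def predict_batch(X):
    return sigmoid(jnp.dot(X, w))      # all samples, shape (N, D)

via_vmap = vmap(predict_single)(X)
via_loop = jnp.stack([predict_single(row) for row in X])
via_matmul = predict_batch(X)

print(via_vmap, via_loop, via_matmul)
```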

This way of thinking of the 0th axis as the sample axis, or the batch axis, leads to the second usage of `vmap`: carrying out ops on everything but the sample/batch axis. Consider the example below (from http://matpalm.com/blog/ymxb_pod_slice/):

In [25]:
import jax.numpy as jnp
import numpy as np

def f(x):    
  return jnp.array([jnp.min(x), jnp.max(x)])

Suppose our input has shape: `(2, 3)`. Then we'll get the min and the max in the whole array:

In [26]:
x = np.arange(6).reshape((2, 3))
x

array([[0, 1, 2],
       [3, 4, 5]])

In [27]:
f(x)

DeviceArray([0, 5], dtype=int32)

Now, what if we have a batch of 4 inputs; i.e. x is of shape `(4, 2, 3)` and thus we want to return shape `(4, 2)`. Let us set up some data for this:

In [28]:
bx = np.arange(24).reshape((4, 2, 3))
bx

array([[[ 0,  1,  2],
        [ 3,  4,  5]],

       [[ 6,  7,  8],
        [ 9, 10, 11]],

       [[12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23]]])

If you now call `f(bx)` it's going to give you the min and max of the whole array, and that's not what you want. Normally you might use an explicit axis argument in the `min` and `max`.

In [30]:
f(bx)

DeviceArray([ 0, 23], dtype=int32)

We can do it implicitly using `vmap`. You write the function for a single example, and `vmap` returns a version whose arguments carry an extra leading dimension. By default this is the first axis of each argument:

In [31]:
from jax import vmap

               # f  in:    (2, 3) out:    (2)
v_f = vmap(f)  # vf in: (B, 2, 3) out: (B, 2)
v_f(bx)

DeviceArray([[ 0,  5],
             [ 6, 11],
             [12, 17],
             [18, 23]], dtype=int32)

In [32]:
make_jaxpr(f)(x)

{ lambda ; a:i32[2,3]. let
    b:i32[] = reduce_min[axes=(0, 1)] a
    c:i32[] = reduce_max[axes=(0, 1)] a
    d:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] b
    e:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] c
    f:i32[2] = concatenate[dimension=0] d e
  in (f,) }

In [33]:
make_jaxpr(v_f)(bx)

{ lambda ; a:i32[4,2,3]. let
    b:i32[4] = reduce_min[axes=(1, 2)] a
    c:i32[4] = reduce_max[axes=(1, 2)] a
    d:i32[4,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(4, 1)] b
    e:i32[4,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(4, 1)] c
    f:i32[4,2] = concatenate[dimension=1] d e
  in (f,) }

You can now `jit` the function if you like to speed it up.

In [34]:
jv_f = jit(vmap(f))
jv_f(bx)

DeviceArray([[ 0,  5],
             [ 6, 11],
             [12, 17],
             [18, 23]], dtype=int32)

In [35]:
make_jaxpr(jv_f)(bx)

{ lambda ; a:i32[4,2,3]. let
    b:i32[4,2] = xla_call[
      call_jaxpr={ lambda ; c:i32[4,2,3]. let
          d:i32[4] = reduce_min[axes=(1, 2)] c
          e:i32[4] = reduce_max[axes=(1, 2)] c
          f:i32[4,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(4, 1)] d
          g:i32[4,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(4, 1)] e
          h:i32[4,2] = concatenate[dimension=1] f g
        in (h,) }
      name=f
    ] a
  in (b,) }

As `f` and your tensors get more complex it is harder to make `f` batch-aware. `vmap` does it for you.
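
For comparison, here is a sketch of the hand-written batch-aware version that `vmap` saves us from writing (the names `min_max` and `min_max_manual` are ours). For this small function the manual version is still easy, but the axis bookkeeping is exactly what becomes error-prone as functions grow.

```python
import numpy as np
import jax.numpy as jnp
from jax import vmap

def min_max(x):
    # Written for a single (2, 3) example.
    return jnp.array([jnp.min(x), jnp.max(x)])

def min_max_manual(bx):
    # Batch-aware by hand: reduce over every axis except the leading one.
    return jnp.stack([jnp.min(bx, axis=(1, 2)), jnp.max(bx, axis=(1, 2))], axis=1)

bx = np.arange(24).reshape((4, 2, 3))

print(vmap(min_max)(bx))
print(min_max_manual(bx))
```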

The default behavior of `vmap` is to vectorise across the first axis of every argument, but we have a lot of control over how we want to vectorise; e.g. consider:

In [39]:
def f(a, b, c):
  return a + b + c
a = np.arange(6).reshape((2, 3)) + 10
b = 2
c = np.arange(6).reshape((3, 2)) + 20
print(repr(a))
c

array([[10, 11, 12],
       [13, 14, 15]])


array([[20, 21],
       [22, 23],
       [24, 25]])

In [40]:
v_f = vmap(f, in_axes=(0, None, 1))

v_f(a, b, c)

DeviceArray([[32, 35, 38],
             [36, 39, 42]], dtype=int32)

Ok, so what happened here? Look at the incantation of `in_axes`. The tuple there has one entry per argument of `f`. `a` is 2x3 and its entry is axis 0, so we vmap over its rows. For `b` the entry is `None`, which means: don't vmap over this, just broadcast it. For `c`, which is 3x2, the entry is axis 1, so we vmap over its columns. This means that the 22 in `c` is added to the 11 in `a` and then to the broadcast 2 to create the 35 in the output. You can see these machinations in the jaxpr for this function.

In [41]:
make_jaxpr(v_f)(a, b, c)

{ lambda ; a:i32[2,3] b:i32[] c:i32[3,2]. let
    d:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    e:i32[2,3] = add a d
    f:i32[2,3] = transpose[permutation=(1, 0)] c
    g:i32[2,3] = add e f
  in (g,) }
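
The same `in_axes` semantics can be spelled out as an explicit loop. This is only a sketch to check our reading of the tuple (the names `add3`, `batched` and `looped` are ours): iteration `i` pairs row `i` of `a` with column `i` of `c`, and passes `b` through unchanged.

```python
import numpy as np
import jax.numpy as jnp
from jax import vmap

def add3(a, b, c):
    return a + b + c

a = np.arange(6).reshape((2, 3)) + 10
b = 2
c = np.arange(6).reshape((3, 2)) + 20

batched = vmap(add3, in_axes=(0, None, 1))(a, b, c)

# Explicit version: row i of a, the scalar b, column i of c.
looped = jnp.stack([add3(a[i], b, c[:, i]) for i in range(a.shape[0])])

print(batched)
print(looped)
```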

## 6: Parallelisation

`pmap` has pretty much the same interface as `vmap`. So why a new construct? Well, `pmap` uses XLA and provides automatic jitting. It does not produce a vectorized version of the function like `vmap` does, though. Instead, XLA is used to create an SPMD (Single Program Multiple Data) program that runs the function in parallel across all available devices: CPUs, GPUs or TPUs.

Here are our devices:

In [42]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [43]:
from jax import pmap

def f(a, b, c):
  return a + b + c
  
p_f = pmap(f, in_axes=(0, None, 0)) 

a = np.random.random(size=(num_replicas, 3))
b = 4
c = np.random.random(size=(num_replicas, 3))

a

array([[0.77718412, 0.00176833, 0.44951046],
       [0.27297925, 0.46497406, 0.91643829],
       [0.09127398, 0.59443344, 0.95430272],
       [0.00510437, 0.10622125, 0.29679499],
       [0.78516484, 0.09587237, 0.7798978 ],
       [0.50380788, 0.25051366, 0.61133905],
       [0.56083878, 0.44882798, 0.2974137 ],
       [0.96834715, 0.72368435, 0.94016945]])

Notice now that we have as many rows as we have devices. Now it is not the batch axis we are "vectorizing" (or more precisely "spmd"ing) over, but rather this device dimension.

In [44]:
p_f(a, b, c)

ShardedDeviceArray([[5.010401 , 4.144878 , 4.6788063],
                    [5.11807  , 4.5080056, 5.2585855],
                    [4.3505087, 4.9883356, 5.7644   ],
                    [4.076356 , 4.7490406, 4.4232683],
                    [4.9303174, 4.374606 , 5.433234 ],
                    [5.1140704, 4.5758166, 5.3989067],
                    [5.4698825, 5.1499176, 4.4592896],
                    [5.0957966, 5.681382 , 5.3702273]], dtype=float32)

Notice that an identical program has been shipped off to 8 devices, and a broadcasted addition has been carried out on each of these `num_replicas` devices. Notice also the output data type: it is a _Sharded_ device array. This means that the parts of the array are not actually on the main machine, but rather on the individual devices. We have moved the program to the data, and run it on each shard of the data.

jaxpr as usual tells us more:

In [45]:
make_jaxpr(pmap(f, in_axes=(0, None, 0)))(a, b, c)

{ lambda ; a:f32[8,3] b:i32[] c:f32[8,3]. let
    d:f32[8,3] = xla_pmap[
      axis_name=<axis 0x7f92facb9ef0>
      axis_size=8
      backend=None
      call_jaxpr={ lambda ; e:f32[3] f:i32[] g:f32[3]. let
          h:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
          i:f32[3] = add e h
          j:f32[3] = add i g
        in (j,) }
      devices=None
      donated_invars=(False, False, False)
      global_arg_shapes=(None, None, None)
      global_axis_size=None
      in_axes=(0, None, 0)
      name=f
      out_axes=(0,)
    ] a b c
  in (d,) }

Our output is stored as a `ShardedDeviceArray`, which is an abstraction over remote arrays constructed on each device. Imagine a larger calculation where we want both input and output to be sharded across devices so there is minimal data transfer. This is classic compute-at-data SPMD.
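
The cells above assume 8 TPU cores. A sketch that should run on whatever is available sizes the leading axis from `jax.local_device_count()` instead, which is 1 on a plain CPU machine (the names `add3`, `n` and `out` are ours).

```python
import numpy as np
import jax
from jax import pmap

def add3(a, b, c):
    return a + b + c

# The mapped axis must not be larger than the number of local devices.
n = jax.local_device_count()

a = np.arange(n * 3, dtype=np.float32).reshape((n, 3))
c = np.ones((n, 3), dtype=np.float32)

# Row i of a and c goes to device i; the scalar is broadcast to all of them.
out = pmap(add3, in_axes=(0, None, 0))(a, 4.0, c)

print(out.shape)
```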

## 7: Combining `vmap` and `pmap`: composition again

Here is a slightly different f: we matrix multiply and add:

In [48]:
def f(a, b, c):
  return a @ b + c

a = np.random.random(size=(2, 3))
b = np.random.random(size=(3, 4))
c = np.random.random(size=(2, 4))

f(a, b, c)

array([[0.76190696, 1.01631292, 0.83752259, 1.87883684],
       [1.28544108, 1.36528368, 0.84584105, 1.27333203]])

Let us add a batch dimension and vmap.

In [49]:
v_f = vmap(f)

a = np.random.random(size=(5, 2, 3))
b = np.random.random(size=(5, 3, 4))
c = np.random.random(size=(5, 2, 4))

v_f(a, b, c)

DeviceArray([[[1.0452026 , 1.3957614 , 0.8974083 , 0.2570575 ],
              [1.3226345 , 1.6668731 , 0.8013219 , 1.0267715 ]],

             [[0.42478555, 0.48035237, 1.0448409 , 1.054654  ],
              [1.2317226 , 1.3148572 , 0.54127437, 0.699142  ]],

             [[1.8477    , 1.4496236 , 1.5217985 , 2.1509185 ],
              [0.96027756, 0.20201734, 1.2771938 , 1.2706423 ]],

             [[1.3032501 , 1.8480506 , 1.5902729 , 0.78827995],
              [0.9036236 , 0.8725722 , 1.3766314 , 0.88759273]],

             [[0.6759734 , 1.483907  , 0.9024238 , 1.6253147 ],
              [0.5552775 , 2.0771217 , 0.7588869 , 1.6169556 ]]],            dtype=float32)

Now we have 5 2x4 arrays. Let's jit to make sure we can run on an accelerator...

In [50]:
jv_f = jit(vmap(f))

a = np.random.random(size=(5, 2, 3))
b = np.random.random(size=(5, 3, 4))
c = np.random.random(size=(5, 2, 4))

jv_f(a, b, c).shape

(5, 2, 4)

What happens if we wrap the `vmap` with a `pmap`? We have all these TPUs or CPUs we want to use...

Let's add a device dimension in front of the batch dimension, which makes it useful to do something like this...

In [52]:
pjv_f = pmap(vmap(f))

a = np.random.random(size=(num_replicas, 5, 2, 3))
b = np.random.random(size=(num_replicas, 5, 3, 4))
c = np.random.random(size=(num_replicas, 5, 2, 4))

pjv_f(a, b, c)

ShardedDeviceArray([[[[2.1330729 , 1.079713  , 1.6747804 , 2.4263163 ],
                      [0.72165257, 0.5211822 , 0.4666624 , 1.4015088 ]],

                     [[1.4648497 , 1.58633   , 1.068167  , 1.4730864 ],
                      [1.5538089 , 1.2625966 , 1.2007633 , 1.5472561 ]],

                     [[1.3659166 , 0.9606316 , 1.2130721 , 0.9296636 ],
                      [1.4485664 , 1.824628  , 0.9115788 , 0.98683596]],

                     [[1.2636752 , 0.94934475, 0.9281723 , 1.1574597 ],
                      [1.4386181 , 1.0023386 , 0.80278707, 0.8467031 ]],

                     [[0.7390131 , 1.171852  , 0.98206335, 0.5656377 ],
                      [0.91225004, 2.0936654 , 1.1719465 , 1.3942947 ]]],


                    [[[0.7920339 , 1.0567851 , 1.3067839 , 0.8583878 ],
                      [1.1420336 , 1.4154868 , 1.1066592 , 1.0678368 ]],

                     [[0.49421442, 1.4615567 , 1.4088886 , 1.3548925 ],
                      [0.5242963 , 1.0997252 , 1.3

Each vmap gave us 5 arrays. But because we had `num_replicas` of these batches, each got assigned to a compute device, and the calculation got carried out there. In this way we can reap the advantages of both vectorization and parallelism.
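
As a sanity check on the composition, the sketch below compares `pmap(vmap(...))` with a plain NumPy `einsum` over the device and batch axes. It assumes a seeded generator and sizes the device axis from `jax.local_device_count()`, so it should also run on a single CPU (the names `matmul_add`, `n` and `expected` are ours).

```python
import numpy as np
import jax
from jax import pmap, vmap

def matmul_add(a, b, c):
    return a @ b + c

n = jax.local_device_count()
rng = np.random.default_rng(0)

# Axes: (device, batch, rows, cols)
a = rng.random((n, 5, 2, 3)).astype(np.float32)
b = rng.random((n, 5, 3, 4)).astype(np.float32)
c = rng.random((n, 5, 2, 4)).astype(np.float32)

out = pmap(vmap(matmul_add))(a, b, c)

# Reference: one matrix product per (device, batch) pair.
expected = np.einsum('dbij,dbjk->dbik', a, b) + c

print(out.shape)
```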



## 8: State: random numbers in jax

Jax is a pure functional framework. It does not allow you to change `DeviceArray`s, as opposed to how numpy deals with `ndarray`s.

This means that jax never mutates arrays in place. If you are familiar with Spark dataframes, you will find the same restriction.

This restriction causes us to write functions with _referential transparency_. Any time you run these functions with the same inputs, you get the same result. State cannot play a role.

This is important as it is precisely what enables the Spark schedulers or XLA compilers to work. When you know that a piece of data or metadata is not going to change except by explicit change, rather than by some object-oriented hidden state, you can compose transformations and make other efficiency changes. As well, your computation can be analysed statically, free of state. It is a tradeoff: you will create new arrays rather than mutate them in place, requiring more memory. But it also means you can reason about and transform your computation.

One of the biggest epitomes of state in a python program is the random number generator.

Jax allows stateless random number generation, but it thus requires more care than in the numpy case.


In [76]:
key = jax.random.PRNGKey(1337)
print(jax.random.uniform(key))
print(jax.random.uniform(key))

0.02251327
0.02251327


The same key always gives the same number. When you want to generate more random numbers you need to explicitly split the key:

In [30]:
key = jax.random.PRNGKey(1337)
for _ in range(10):
  key, key2 = jax.random.split(key)
  print(jax.random.uniform(key2))

0.31418657
0.8795924
0.81679654
0.82754505
0.78958607
0.9006636
0.18430483
0.2509787
0.5041659
0.11195123
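
Two variations on the splitting pattern, as a sketch (the names `subkey`, `batch`, `subkeys` and `draws` are ours): a single subkey can produce a whole array through the `shape` argument, and `split` can hand back several subkeys at once.

```python
import jax

key = jax.random.PRNGKey(1337)

# One subkey, many numbers: ask for an array instead of a scalar.
key, subkey = jax.random.split(key)
batch = jax.random.uniform(subkey, shape=(3,))

# Reusing the same subkey reproduces exactly the same numbers.
again = jax.random.uniform(subkey, shape=(3,))

# Split into one carried-forward key and three fresh subkeys in one go.
key, *subkeys = jax.random.split(key, 4)
draws = [jax.random.uniform(k) for k in subkeys]

print(batch, again, draws)
```

The key that is carried forward plays the role that hidden generator state plays in numpy, except that here it is passed around explicitly.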


In [31]:
# You cannot assign directly to elements of a JAX array.

A = jnp.zeros((3,3), dtype=jnp.float32)

# An in-place update of a JAX array raises a TypeError
try:
  A[1, :] = 1.0
except TypeError:
  print('cannot update in place')

cannot update in place
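
The functional alternative is the `.at[...]` syntax, sketched below (the name `B` is ours). It returns a new array with the update applied and leaves the original untouched.

```python
import jax.numpy as jnp

A = jnp.zeros((3, 3), dtype=jnp.float32)

# Build a new array whose second row is set to 1; A itself is not modified.
B = A.at[1, :].set(1.0)

print(A)
print(B)
```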
