<a href="https://colab.research.google.com/github/yvrjsharma/JAX/blob/main/JAX_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# JAX: PMaps

JAX can evaluate your code in parallel across several devices. The style it uses is called SPMD, or Single-Program Multiple-Data: the same computation runs in parallel on different input data on different devices (e.g. TPU cores).

With pmap() you can write a function as if it runs on a single device, and then run it in parallel on one or many devices.
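A minimal sketch of this idea, assuming only that JAX is installed: the function below is written for a single input, and `jax.pmap` runs one copy of it per local device, using the leading axis of the input to split the work. On a machine with a single CPU device it still runs, just on that one device.

```python
import jax
import jax.numpy as jnp
import numpy as np

def square_plus_one(x):
    # Written for a single value, with no mention of devices.
    return x ** 2 + 1

n_devices = jax.local_device_count()
xs = jnp.arange(float(n_devices))       # leading axis: one entry per device
out = jax.pmap(square_plus_one)(xs)     # each device computes one entry
print(n_devices, np.asarray(out))
```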



Let's use TPUs for this one. If you are running this notebook in Colab, make sure you change the runtime type to TPU.

In [2]:
# Let's import the required libraries
import jax
import jax.numpy as jnp
import numpy as np

# Transformations: higher-order functions that take a function as input and return a transformed function
from jax import grad, jit, vmap, pmap
from jax import random

import matplotlib.pyplot as plt
from copy import deepcopy
from typing import Tuple, NamedTuple
import functools

In [3]:
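# Connect JAX to the Colab TPU (needed on older Colab TPU runtimes; recent JAX releases may not ship this helper)
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()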

In [4]:
import jax
jax.devices() #Eight devices

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

To utilize all the cores available to us, let's create an array whose leading (batch) dimension equals the number of cores.

Let's now perform a dummy computation: convolving a small kernel over this array, with each row placed on a different core. This toy example will help us understand later how pmap() speeds up bigger, more complex computations.

In [5]:
n_devices = jax.local_device_count()  # number of cores available to this process
xs = np.arange(5*n_devices).reshape(-1,5)  # one row of 5 elements per device
w = np.array([2.,3.,4.])  # convolution kernel
ws = np.stack([w] * n_devices)  # one copy of the kernel per device (8 here)
xs, ws

(array([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]]), array([[2., 3., 4.],
        [2., 3., 4.],
        [2., 3., 4.],
        [2., 3., 4.],
        [2., 3., 4.],
        [2., 3., 4.],
        [2., 3., 4.],
        [2., 3., 4.]]))
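Here each device gets exactly one row. With a larger batch, a common pattern is to reshape the leading axis into `(n_devices, per_device_batch, ...)` before calling pmap. The `shard` helper below is hypothetical (not part of JAX) and assumes the batch size divides evenly by the device count; treat it as a sketch of the pattern.

```python
import jax
import numpy as np

def shard(batch):
    # Hypothetical helper: split the leading (batch) axis into (n_devices, per_device, ...).
    n = jax.local_device_count()
    assert batch.shape[0] % n == 0, "batch size must be divisible by the device count"
    return batch.reshape(n, batch.shape[0] // n, *batch.shape[1:])

n = jax.local_device_count()
batch = np.arange(n * 4 * 5, dtype=np.float32).reshape(n * 4, 5)  # 4 examples per device, 5 features
sharded = shard(batch)

# Each device sums its own 4 examples; the result has one row per device.
per_device_sums = jax.pmap(lambda b: b.sum(axis=0))(sharded)
print(sharded.shape, per_device_sums.shape)
```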

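In [10]:
# 1-D "valid" convolution of x with the 3-tap kernel w, written for a single device
def convolve(w, x):
  output = []
  for i in range(1, len(x) - 1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

# pmap splits the leading axis: row i of ws and xs is processed on device i
pmap(convolve)(ws, xs)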

To train large neural networks we need to parallelize our computation across devices. This is where pmap() comes into the picture.

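Interestingly, when you shard (distribute) a computation on a *device array* across multiple cores or devices, the result is a **ShardedDeviceArray**: an array whose pieces live on different devices. (Recent JAX versions represent this as a `jax.Array` sharded across devices instead.)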
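A quick way to see this, as a sketch: pmap a function and inspect the result. The class name printed depends on your JAX version (`ShardedDeviceArray` in older releases, `Array`/`ArrayImpl` in newer ones), and the example assumes nothing beyond a working JAX install. Converting with `np.asarray` gathers the per-device pieces back into one host array.

```python
import jax
import jax.numpy as jnp
import numpy as np

n_devices = jax.local_device_count()
x = jnp.arange(n_devices * 3.0).reshape(n_devices, 3)

# Each device squares one row; the result stays split across devices.
y = jax.pmap(lambda row: row ** 2)(x)
print(type(y).__name__, y.shape)

# Gather the shards into a single NumPy array on the host.
y_host = np.asarray(y)
print(y_host)
```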

When parallelizing a computation, we can choose whether or not the cores communicate with each other. For example, we can train a neural network in a distributed (data-parallel) fashion: every core receives its own batch of data and computes gradients on it, and then the cores communicate to average their gradients before updating the model.
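The gradient-averaging step can be sketched with pmap's `axis_name` argument and the collective `jax.lax.pmean`, which averages a value across all devices mapped over that named axis. The linear model, mean-squared-error loss, and random data below are illustrative assumptions, not the model trained later in this notebook. Because every device receives the same number of examples, the averaged gradient equals the full-batch gradient computed on a single device.

```python
import functools
import jax
import jax.numpy as jnp
import numpy as np

def loss(w, x, y):
    # Mean squared error of a toy linear model y ≈ x @ w.
    return jnp.mean((x @ w - y) ** 2)

@functools.partial(jax.pmap, axis_name="batch")
def averaged_grads(w, x, y):
    g = jax.grad(loss)(w, x, y)                  # gradient on this device's shard
    return jax.lax.pmean(g, axis_name="batch")   # average across all devices

n = jax.local_device_count()
rng = np.random.default_rng(0)
x = rng.normal(size=(n * 4, 3)).astype(np.float32)
y = rng.normal(size=(n * 4,)).astype(np.float32)
w = np.zeros(3, dtype=np.float32)

ws = np.stack([w] * n)       # replicate the parameters: shape (n, 3)
xs = x.reshape(n, 4, 3)      # shard the data: 4 examples per device
ys = y.reshape(n, 4)

grads = averaged_grads(ws, xs, ys)   # shape (n, 3); every row is identical
full = jax.grad(loss)(w, x, y)       # single-device gradient on the full batch
```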

Also note that pmap() JIT-compiles the function it is given, so there is no need to wrap it in jit() as well.

#### **Working towards training a Neural Network**

If you want to log both your loss value and the gradient of the loss, you can use the built-in jax.value_and_grad() function, which returns both from a single call.
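A minimal sketch with a toy loss (the squared L2 norm, chosen only for illustration): `jax.value_and_grad` returns the loss value and its gradient together, so you can log the value without a second forward pass.

```python
import jax
import jax.numpy as jnp

def loss(w):
    # Toy loss: squared L2 norm of the parameters.
    return jnp.sum(w ** 2)

w = jnp.array([1.0, 2.0, 3.0])
value, grads = jax.value_and_grad(loss)(w)  # one call, two results
print(value)
print(grads)
```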

JAX's transformations (grad, jit, vmap, pmap) are implemented as composable program transformations, so you can apply one on top of another.
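As a sketch of composition, stacking three transformations gives per-example gradients: `grad` differentiates a single-example loss, `vmap` maps that over a batch, and `jit` compiles the whole thing. The toy loss and data are assumptions chosen so the gradients are easy to check by hand (the gradient is `2 * w * x**2`).

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Toy per-example loss for a single input vector x.
    return jnp.sum((w * x) ** 2)

# grad -> gradient for one example; vmap -> over a batch of x; jit -> compile the result
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
g = per_example_grads(w, xs)
print(g.shape)  # (3, 2)
```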

vmap automatically vectorizes (batches) your code: you write a function for a single example, and vmap maps it over a batch dimension.
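As a sketch, the convolution used earlier can be batched with `vmap` instead of `pmap`. The function is redefined here so the block runs on its own, and `in_axes=(None, 0)` (an illustrative choice) shares one kernel across the batch instead of copying it per row. Unlike pmap, vmap keeps the whole batch on one device and vectorizes it there.

```python
import jax
import jax.numpy as jnp

def convolve(w, x):
    # 1-D "valid" convolution of a single vector x with a 3-tap kernel w.
    return jnp.array([jnp.dot(x[i - 1:i + 2], w) for i in range(1, len(x) - 1)])

w = jnp.array([2.0, 3.0, 4.0])
xs = jnp.arange(10.0).reshape(2, 5)   # a batch of two vectors

# Map over the rows of xs; the same kernel w is reused for every row.
batched = jax.vmap(convolve, in_axes=(None, 0))(w, xs)
print(batched)
```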
