<a href="https://colab.research.google.com/github/rahuldave/LearningJax/blob/main/Jax_first.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Jax.first

Based on "pmap jit vmap oh my" by Mat Kelcey (_[solving y=mx+b... with jax on a tpu pod slice](http://matpalm.com/blog/ymxb_pod_slice/) blog series_)

Based on Sabrina Mielke

Based on Kevin Murphy

In [31]:
!cat /proc/cpuinfo

processor	: 0
vendor_id	: GenuineIntel
cpu family	: 6
model		: 79
model name	: Intel(R) Xeon(R) CPU @ 2.20GHz
stepping	: 0
microcode	: 0x1
cpu MHz		: 2199.998
cache size	: 56320 KB
physical id	: 0
siblings	: 2
core id		: 0
cpu cores	: 1
apicid		: 0
initial apicid	: 0
fpu		: yes
fpu_exception	: yes
cpuid level	: 13
wp		: yes
flags		: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities
bugs		: cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa
bogomips	: 4399.99
clflush size	: 64
cache_alignment	: 64
address sizes	: 46 bits physical, 48 b

In [32]:
USE_TPU = True

if USE_TPU:
  import jax
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()
else:
  # simulate 8 cpu devices to match the 8 TPU cores; XLA_FLAGS must be set before jax is imported
  import os
  os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
  import jax

In [33]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

## Value Proposition 1:  `jax.numpy` replaces `numpy`

`jax.numpy` provides a drop-in replacement for (most of) `numpy` that runs ops on whatever accelerator you have. the main caveats are that jax arrays are immutable and default to 32-bit dtypes.
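a minimal sketch of the drop-in idea, assuming only `numpy` and `jax` are installed: the same expression written against `np` and against `jnp` gives the same values. the last line shows one visible difference, the default integer dtype, which is 32-bit in jax unless 64-bit mode (`jax_enable_x64`) is switched on.

```python
import numpy as np
import jax.numpy as jnp

# the same computation, written once against each library
x_np = np.arange(6, dtype=np.float32).reshape((2, 3))
x_jnp = jnp.arange(6, dtype=jnp.float32).reshape((2, 3))

y_np = np.dot(x_np, x_np.T)      # runs on the host CPU
y_jnp = jnp.dot(x_jnp, x_jnp.T)  # runs on JAX's default device (CPU, GPU or TPU)

print(y_np)
# np.asarray copies the jax result back into a host numpy array
print(np.allclose(y_np, np.asarray(y_jnp)))

# default dtypes differ unless jax_enable_x64 is turned on
print(np.arange(3).dtype, jnp.arange(3).dtype)
```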

In [34]:
import numpy as np
import jax.numpy as jnp

In [35]:
size = int(1e3)
number_of_loops=int(1e2)

In [36]:
def f(x=None):
  if not isinstance(x, np.ndarray):
    x=np.ones((size, size), dtype=np.float32) 
  return np.dot(x, x.T)

In [37]:
%timeit -o -n $number_of_loops f()

100 loops, best of 5: 40 ms per loop


<TimeitResult : 100 loops, best of 5: 40 ms per loop>

In [38]:
# JAX device execution
# https://github.com/google/jax/issues/1598

def jf(x=None): 
  if not isinstance(x, jnp.ndarray):
    x=jnp.ones((size, size), dtype=jnp.float32)
  return jnp.dot(x, x.T)

In [39]:
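# note: no .block_until_ready() here, so with jax's asynchronous dispatch this can under-report the device time
%timeit -o -n $number_of_loops jf() 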

100 loops, best of 5: 6.42 ms per loop


<TimeitResult : 100 loops, best of 5: 6.42 ms per loop>

In [40]:
from jax import jit
f_tpu = jit(jf)
f_cpu = jit(jf, backend='cpu')

In [41]:
%timeit -o -n $number_of_loops f_cpu() 

100 loops, best of 5: 49.5 ms per loop


<TimeitResult : 100 loops, best of 5: 49.5 ms per loop>

In [42]:
%timeit -o -n $number_of_loops f_tpu().block_until_ready() 

100 loops, best of 5: 1.86 ms per loop


<TimeitResult : 100 loops, best of 5: 1.86 ms per loop>
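jax dispatches work asynchronously: a call can return as soon as the computation is queued, and the first call of a jitted function also pays for tracing and compilation. that is why the `f_tpu` timing calls `.block_until_ready()`. a minimal timing sketch with `time.perf_counter`, assuming one warm-up call is enough to keep compilation out of the measurement; `matmul_t` is a hypothetical name used only for illustration.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def matmul_t(x):  # hypothetical helper, same computation as jf above
    return jnp.dot(x, x.T)

x = jnp.ones((1000, 1000), dtype=jnp.float32)
matmul_t(x).block_until_ready()  # warm-up: the first call traces and compiles

start = time.perf_counter()
y = matmul_t(x)                  # may return before the device has finished (async dispatch)
dispatched = time.perf_counter() - start
y.block_until_ready()            # waits until y has actually been computed
finished = time.perf_counter() - start

print(f"dispatch only: {dispatched * 1e3:.2f} ms, dispatch + compute: {finished * 1e3:.2f} ms")
```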

# function transformation

we start with a simple function

In [43]:
def f(x):
  return 2*x*x + 3*x + 3

f(3)

30

we can use `make_jaxpr` to trace the function and show us a jax expression of what the function does

In [44]:
from jax import make_jaxpr

trace_f = make_jaxpr(f)

trace_f(3)

{ lambda ; a:i32[]. let
    b:i32[] = mul a 2
    c:i32[] = mul b a
    d:i32[] = mul a 3
    e:i32[] = add c d
    f:i32[] = add e 3
  in (f,) }
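tracing replaces the argument with an abstract value that only carries a shape and a dtype, so the same python function produces a different jaxpr for a different input type. a small sketch, assuming the `in_avals` attribute of the `ClosedJaxpr` that `make_jaxpr` returns to inspect those abstract inputs:

```python
import jax
import jax.numpy as jnp

def f(x):
    return 2 * x * x + 3 * x + 3

# same function, traced with an int and with a float argument
jaxpr_int = jax.make_jaxpr(f)(3)
jaxpr_float = jax.make_jaxpr(f)(3.0)

print(jaxpr_int.in_avals)    # abstract input: a 32-bit int scalar
print(jaxpr_float.in_avals)  # abstract input: a 32-bit float scalar
print(jaxpr_float)
```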

## gradients

a key thing we use jax for is calculating gradients with `grad`. by default it differentiates with respect to the first argument, which must be a float (hence `3.` rather than `3` below).

In [45]:
from jax import grad

g_f = grad(f)

g_f(3.)

DeviceArray(15., dtype=float32, weak_type=True)
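to differentiate with respect to other arguments, `grad` takes an `argnums` parameter. a short sketch with a hypothetical three-argument `loss` (not from the original notebook); it also checks that an integer input for a differentiated argument is rejected with a `TypeError`.

```python
import jax

def loss(w, b, x):
    # hypothetical toy loss, used only to illustrate argnums
    return (w * x + b) ** 2

# default: gradient with respect to the first argument only, 2 * (w*x + b) * x
dw = jax.grad(loss)(2.0, 1.0, 3.0)

# argnums picks which positional arguments to differentiate; returns a tuple
dw2, db = jax.grad(loss, argnums=(0, 1))(2.0, 1.0, 3.0)
print(dw, dw2, db)

# the differentiated argument must be floating point
grad_int_failed = False
try:
    jax.grad(loss)(2, 1.0, 3.0)
except TypeError:
    grad_int_failed = True
print(grad_int_failed)
```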

all these transformations are composable, so we can use `make_jaxpr` to get insight into what `grad` is doing

In [46]:
trace = make_jaxpr(grad(f))
trace(3.)

{ lambda ; a:f32[]. let
    b:f32[] = mul a 2.0
    c:f32[] = mul b a
    d:f32[] = mul a 3.0
    e:f32[] = add c d
    _:f32[] = add e 3.0
    f:f32[] = mul 1.0 3.0
    g:f32[] = mul b 1.0
    h:f32[] = mul 1.0 a
    i:f32[] = add_any f g
    j:f32[] = mul h 2.0
    k:f32[] = add_any i j
  in (k,) }
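since `grad` returns an ordinary python function, it composes with itself and with `jit`. a minimal sketch: for `f(x) = 2x² + 3x + 3` the first derivative is `4x + 3` and the second is the constant `4`.

```python
import jax

def f(x):
    return 2 * x * x + 3 * x + 3

df = jax.grad(f)                 # f'(x)  = 4x + 3
d2f = jax.grad(jax.grad(f))      # f''(x) = 4
fast_df = jax.jit(jax.grad(f))   # the same derivative, compiled with XLA

print(df(3.0), d2f(3.0), fast_df(3.0))
```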

there's also a small helper function called `value_and_grad` which runs the function and calculates the gradient in a single call

In [47]:
from jax import value_and_grad

vg_f = value_and_grad(f)

value, gradient = vg_f(3.)
value, gradient

(DeviceArray(30., dtype=float32, weak_type=True),
 DeviceArray(15., dtype=float32, weak_type=True))
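when the value is a loss, `value_and_grad` pairs naturally with a gradient-descent step. a minimal sketch, assuming plain gradient descent with a hand-picked learning rate of 0.1 (an illustrative choice, not something the original specifies); `f` has its minimum at `x = -3/4`, where `f = 1.875`.

```python
import jax

def f(x):
    return 2 * x * x + 3 * x + 3   # minimum at x = -0.75, where f = 1.875

vg_f = jax.jit(jax.value_and_grad(f))

x = 3.0
learning_rate = 0.1   # illustrative value
for step in range(50):
    value, gradient = vg_f(x)        # one call gives both the value and its slope
    x = x - learning_rate * gradient

print(x, f(x))
```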

# jitting

jitting compiles an entire function, using [xla](https://www.tensorflow.org/xla), into a function with the same signature that is optimised for the accelerator you are running on (e.g. GPU or TPU).

In [48]:
from jax import jit

jitted_f = jit(f)

jitted_f(3)

DeviceArray(30, dtype=int32, weak_type=True)

In [49]:
make_jaxpr(jitted_f)(3)

{ lambda ; a:i32[]. let
    b:i32[] = xla_call[
      call_jaxpr={ lambda ; c:i32[]. let
          d:i32[] = mul c 2
          e:i32[] = mul d c
          f:i32[] = mul c 3
          g:i32[] = add e f
          h:i32[] = add g 3
        in (h,) }
      name=f
    ] a
  in (b,) }
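`jit` traces and compiles once per input shape and dtype, then reuses the compiled version; python side effects inside the function only run while tracing. a minimal sketch that records each trace in a hypothetical `traced_shapes` list:

```python
import jax
import jax.numpy as jnp

traced_shapes = []  # hypothetical log: one entry per trace

@jax.jit
def g(x):
    traced_shapes.append(x.shape)  # python side effect: only runs while tracing
    return 2 * x * x + 3 * x + 3

g(jnp.float32(3.0))   # first call: trace + compile for a float32 scalar
g(jnp.float32(4.0))   # same shape and dtype: reuses the compiled version
g(jnp.ones(3))        # new shape (3,): traced and compiled again

print(traced_shapes)
```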

## vectorisation

a key aspect of speeding things up is vectorising calls. consider the following call

In [78]:
D = 2
N = 3

w = np.random.normal(size=(D,))
X = np.random.normal(size=(N,D))

def sigmoid(x): return 0.5 * (jnp.tanh(x / 2.) + 1)

def predict_single(x):
    return sigmoid(jnp.dot(w, x)) # <(D) , (D)> = (1) # inner product

In [79]:
print(predict_single(X[0,:])) # works

0.4431823


In [80]:
print(predict_single(X)) # fails: jnp.dot(w, X) tries to contract w's length D=2 against X's N=3 rows

TypeError

In [81]:
def predict_batch(X):
    return sigmoid(jnp.dot(X, w)) # (N,D) @ (D,) = (N,) # matrix-vector multiply

print(predict_batch(X)) 

[0.4431823  0.44772643 0.23593867]


In [82]:
from jax import vmap
print(vmap(predict_single)(X))

[0.4431823  0.44772643 0.23593867]
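a self-contained check that the three approaches agree: an explicit python loop over rows, the hand-written batch version, and `vmap`. the fixed seed via `np.random.default_rng(0)` is an assumption added only for reproducibility.

```python
import numpy as np
import jax
import jax.numpy as jnp

D, N = 2, 3
rng = np.random.default_rng(0)   # fixed seed, for reproducibility only
w = rng.normal(size=(D,))
X = rng.normal(size=(N, D))

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2.) + 1)

def predict_single(x):
    return sigmoid(jnp.dot(w, x))   # x: (D,) -> scalar

looped = jnp.stack([predict_single(x) for x in X])   # N separate calls
batched = sigmoid(jnp.dot(X, w))                     # hand-written batch version
vmapped = jax.vmap(predict_single)(X)                # vmap adds the batch axis for us

print(np.allclose(looped, vmapped), np.allclose(batched, vmapped))
```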


In [50]:
import jax.numpy as jnp
import numpy as np

def f(x):    
  return jnp.array([jnp.min(x), jnp.max(x)])

let's run it with an input of shape `(2, 3)`

In [51]:
x = np.arange(6).reshape((2, 3))
x

array([[0, 1, 2],
       [3, 4, 5]])

In [52]:
f(x)

DeviceArray([0, 5], dtype=int32)

now say we want to batch this call and run it on 4 inputs at once; i.e. x has shape `(4, 2, 3)` and we want back 4 `(min, max)` pairs, i.e. shape `(4, 2)`

In [53]:
bx = np.arange(24).reshape((4, 2, 3))
bx

array([[[ 0,  1,  2],
        [ 3,  4,  5]],

       [[ 6,  7,  8],
        [ 9, 10, 11]],

       [[12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23]]])

the code as it is now doesn't do what we want, since `min` and `max` reduce over the whole array rather than per example

In [54]:
f(bx)

DeviceArray([ 0, 23], dtype=int32)

we could fix this directly in the original function by being explicit about what axis we want to do the min and max over but let's do it more implicitly using `vmap`.

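`vmap` transforms a function so that it keeps the same signature but expects its arguments to carry an extra leading (batch) dimension. by default every positional argument is mapped over its first axis, and the outputs are stacked along a new first axis.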

In [55]:
from jax import vmap

               # f  in:    (2, 3) out:    (2)
v_f = vmap(f)  # vf in: (B, 2, 3) out: (B, 2)
v_f(bx)

DeviceArray([[ 0,  5],
             [ 6, 11],
             [12, 17],
             [18, 23]], dtype=int32)
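for comparison, a sketch of the "fix it directly" route: pass explicit axes to `min` and `max`. `f_batched_by_hand` is a hypothetical name; note it hard-codes that each example is 2-D (`axis=(1, 2)`), which is exactly the bookkeeping `vmap` takes off our hands.

```python
import numpy as np
import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([jnp.min(x), jnp.max(x)])

def f_batched_by_hand(bx):  # hypothetical helper, for comparison only
    # reduce over every axis except the leading batch axis
    return jnp.stack([jnp.min(bx, axis=(1, 2)), jnp.max(bx, axis=(1, 2))], axis=1)

bx = np.arange(24).reshape((4, 2, 3))
print(f_batched_by_hand(bx))
print(np.array_equal(f_batched_by_hand(bx), jax.vmap(f)(bx)))
```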

In [56]:
make_jaxpr(f)(x)

{ lambda ; a:i32[2,3]. let
    b:i32[] = reduce_min[axes=(0, 1)] a
    c:i32[] = reduce_max[axes=(0, 1)] a
    d:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] b
    e:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] c
    f:i32[2] = concatenate[dimension=0] d e
  in (f,) }

In [57]:
make_jaxpr(v_f)(bx)

{ lambda ; a:i32[4,2,3]. let
    b:i32[4] = reduce_min[axes=(1, 2)] a
    c:i32[4] = reduce_max[axes=(1, 2)] a
    d:i32[4,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(4, 1)] b
    e:i32[4,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(4, 1)] c
    f:i32[4,2] = concatenate[dimension=1] d e
  in (f,) }

In [58]:
jv_f = jit(vmap(f))
jv_f(bx)

DeviceArray([[ 0,  5],
             [ 6, 11],
             [12, 17],
             [18, 23]], dtype=int32)

the main point is that as f gets more complex, making it batch aware by hand gets harder and harder; with vmap it's not your problem.

and since jax transformations are composable, jitting the vmapped function (`jv_f` above) lets XLA compile the whole batched computation; its jaxpr is just the vmapped ops wrapped in an `xla_call`

In [59]:
make_jaxpr(jv_f)(bx)

{ lambda ; a:i32[4,2,3]. let
    b:i32[4,2] = xla_call[
      call_jaxpr={ lambda ; c:i32[4,2,3]. let
          d:i32[4] = reduce_min[axes=(1, 2)] c
          e:i32[4] = reduce_max[axes=(1, 2)] c
          f:i32[4,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(4, 1)] d
          g:i32[4,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(4, 1)] e
          h:i32[4,2] = concatenate[dimension=1] f g
        in (h,) }
      name=f
    ] a
  in (b,) }

vectorising in the examples above relied on the default behaviour of mapping over the first axis of every argument, but `in_axes` gives us a lot of control over how we vectorise. e.g. below `a` is mapped over axis 0, `b` is not mapped at all (`None`, so the same value is used for every slice) and `c` is mapped over axis 1; again the whole thing can be jitted


In [37]:
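def f(a, b, c):
  return a + b + c

# example inputs matching the jaxpr below: a is (2, 3), b is a scalar, c is (3, 2)
a = np.arange(6).reshape((2, 3))
b = 4
c = np.arange(6).reshape((3, 2))

v_f = jit(vmap(f, in_axes=(0, None, 1)))

make_jaxpr(v_f)(a, b, c)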

{ lambda ; a:i32[2,3] b:i32[] c:i32[3,2]. let
    d:i32[2,3] = xla_call[
      call_jaxpr={ lambda ; e:i32[2,3] f:i32[] g:i32[3,2]. let
          h:i32[] = convert_element_type[new_dtype=int32 weak_type=False] f
          i:i32[2,3] = add e h
          j:i32[2,3] = transpose[permutation=(1, 0)] g
          k:i32[2,3] = add i j
        in (k,) }
      name=f
    ] a b c
  in (d,) }
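a quick sanity check of what `in_axes=(0, None, 1)` means: slice `i` of the output is `a[i] + b + c[:, i]`, which is the same as `a + b + c.T` written by hand. the concrete input values here are an assumption chosen so the result is easy to verify.

```python
import numpy as np
import jax

def f(a, b, c):
    return a + b + c

a = np.arange(6).reshape((2, 3))   # mapped over axis 0: two slices of shape (3,)
b = 4                              # None: the same scalar is used for every slice
c = np.arange(6).reshape((3, 2))   # mapped over axis 1: two slices of shape (3,)

v_f = jax.jit(jax.vmap(f, in_axes=(0, None, 1)))
out = v_f(a, b, c)                 # results are stacked along axis 0 (default out_axes=0)

print(out)
print(np.array_equal(out, a + b + c.T))   # the same result written by hand
```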

# parallelisation

`pmap` has the same behaviour as `vmap`, except that instead of rewriting the ops of the function into vectorised versions, it uses xla to compile an SPMD program that runs the function in parallel across the available devices, one slice of the mapped axis per device.

here we're running on a Colab TPU, which gives us 8 devices; the `USE_TPU = False` branch of the setup cell instead makes the host expose several cpu devices so the same code can run without a TPU

In [60]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [61]:
from jax import pmap

def f(a, b, c):
  return a + b + c
  
p_f = pmap(f, in_axes=(0, None, 0)) 

a = np.random.random(size=(8, 3))
b = 4
c = np.random.random(size=(8, 3))

p_f(a, b, c)

ShardedDeviceArray([[4.610702 , 5.027648 , 5.1398377],
                    [5.3299513, 5.7859383, 4.8855596],
                    [5.386921 , 5.0821943, 5.6099157],
                    [5.2989388, 4.456516 , 5.5993443],
                    [4.8503604, 4.570438 , 5.9344563],
                    [4.9746323, 4.9166603, 5.000825 ],
                    [4.8715687, 4.8265533, 5.4546843],
                    [4.97842  , 5.034208 , 4.9314623]], dtype=float32)

In [62]:
make_jaxpr(pmap(f, in_axes=(0, None, 0)))(a, b, c)

{ lambda ; a:f32[8,3] b:i32[] c:f32[8,3]. let
    d:f32[8,3] = xla_pmap[
      axis_name=<axis 0x7fe26549ce60>
      axis_size=8
      backend=None
      call_jaxpr={ lambda ; e:f32[3] f:i32[] g:f32[3]. let
          h:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
          i:f32[3] = add e h
          j:f32[3] = add i g
        in (j,) }
      devices=None
      donated_invars=(False, False, False)
      global_arg_shapes=(None, None, None)
      global_axis_size=None
      in_axes=(0, None, 0)
      name=f
      out_axes=(0,)
    ] a b c
  in (d,) }

also note that the return type is `ShardedDeviceArray`. 

this result is sharded across the 8 devices: we can still manipulate it like any other array, but the data itself stays spread across the 8. this is super important for larger calculations, where we want inputs and outputs to stay sharded across the devices so that there is minimal data transfer. it's the classic move-the-compute-to-the-data idea
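keeping data on the devices means the shards talk to each other through collectives rather than by gathering everything on the host. a minimal sketch using `axis_name` and `jax.lax.psum` to normalise values by a total computed across all devices; `normalise` and the axis name `'devices'` are hypothetical names, and the code uses `jax.local_device_count()` so it also runs on a single cpu device.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()   # 8 on the Colab TPU above, possibly 1 on a plain cpu

# one row per device; pmap needs the leading axis to be at most the device count
x = jnp.arange(n_devices * 3, dtype=jnp.float32).reshape((n_devices, 3))

def normalise(row):  # hypothetical helper
    # psum adds its argument across every device in the pmap named 'devices'
    total = jax.lax.psum(jnp.sum(row), axis_name='devices')
    return row / total

out = jax.pmap(normalise, axis_name='devices')(x)
print(out.sum())   # every entry is divided by the global total, so this is (about) 1
```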

# composition

recall again that all these things are composable!

consider this matrix multiply and add function

In [69]:
def f(a, b, c):
  return a @ b + c

a = np.random.random(size=(2, 3))
b = np.random.random(size=(3, 4))
c = np.random.random(size=(2, 4))

f(a, b, c)

array([[0.38923554, 1.6374275 , 1.47928243, 0.81759826],
       [0.75177693, 1.39819926, 1.88215504, 1.90658164]])

we start by vmapping over a leading dimension; not that it buys us much here, since both the matmul (`@`) and the add already handle a leading batch dimension natively.

In [70]:
v_f = vmap(f)

a = np.random.random(size=(5, 2, 3))
b = np.random.random(size=(5, 3, 4))
c = np.random.random(size=(5, 2, 4))

v_f(a, b, c).shape

(5, 2, 4)
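`vmap` becomes more useful once the arguments are not all batched the same way. a sketch, assuming we want one weight matrix `b` shared by the whole batch: `in_axes=(0, None, 0)` broadcasts `b` instead of slicing it, and the result matches an explicit `jnp.einsum`.

```python
import numpy as np
import jax
import jax.numpy as jnp

def f(a, b, c):
    return a @ b + c

rng = np.random.default_rng(0)
a = rng.random(size=(5, 2, 3))   # batch of 5 left matrices
b = rng.random(size=(3, 4))      # one right matrix shared by the whole batch
c = rng.random(size=(5, 2, 4))   # batch of 5 offsets

# in_axes=None for b: it is passed unchanged to every batch element
shared_b_f = jax.vmap(f, in_axes=(0, None, 0))
out = shared_b_f(a, b, c)

by_hand = jnp.einsum('nij,jk->nik', a, b) + c
print(out.shape, np.allclose(out, by_hand, atol=1e-5))
```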

In [71]:
v_f(a, b, c)

DeviceArray([[[0.5705424 , 1.0822833 , 0.3329847 , 1.2430818 ],
              [2.0149455 , 0.5200294 , 1.2903602 , 2.2827947 ]],

             [[0.28845692, 0.34275436, 1.1540034 , 0.39978826],
              [0.74911225, 1.3497909 , 1.1284459 , 0.8178835 ]],

             [[1.0889641 , 1.3528488 , 0.30811077, 0.4808404 ],
              [2.2658257 , 1.8742003 , 1.844905  , 1.925225  ]],

             [[1.4738675 , 1.1812127 , 0.63808006, 1.3177973 ],
              [1.9129457 , 1.702003  , 1.4550859 , 2.5734744 ]],

             [[1.7625363 , 1.8102849 , 1.3327197 , 0.4980368 ],
              [1.9050145 , 1.9309726 , 1.2529042 , 0.92324704]]],            dtype=float32)

to have XLA compile the whole vectorised function into one optimised computation for whatever accelerator we have, we can jit it

In [72]:
jv_f = jit(vmap(f))

a = np.random.random(size=(5, 2, 3))
b = np.random.random(size=(5, 3, 4))
c = np.random.random(size=(5, 2, 4))

jv_f(a, b, c).shape

(5, 2, 4)

furthermore we could wrap the vmap with a pmap, which spreads the leading axis across the devices using `xla_pmap` (at most the 8 we have; the example below has a leading axis of 3, so it uses 3 of them).

this makes a function that
* ships the function and shards of the params to the devices  (pmap)
* each of which runs an xla optimised version of the function  (implied by pmap)
* which is a vectorised version of the original f  (vmap)

In [75]:
pjv_f = pmap(vmap(f))

a = np.random.random(size=(3, 5, 2, 3))
b = np.random.random(size=(3, 5, 3, 4))
c = np.random.random(size=(3, 5, 2, 4))

pjv_f(a, b, c)

ShardedDeviceArray([[[[0.9433532 , 0.7800323 , 0.9148023 , 0.8643985 ],
                      [1.2193415 , 1.4516981 , 1.0177786 , 0.9542628 ]],

                     [[1.1579235 , 0.5196863 , 0.68611336, 0.9083721 ],
                      [1.242131  , 0.5495486 , 1.0628719 , 1.039368  ]],

                     [[0.49707088, 1.200995  , 1.3639457 , 1.0757515 ],
                      [0.92943966, 0.6963442 , 2.1469278 , 1.4274914 ]],

                     [[1.6816928 , 1.8286586 , 1.5609622 , 2.277726  ],
                      [0.8542969 , 0.6514341 , 0.5881252 , 0.44545168]],

                     [[1.5182525 , 0.99001217, 1.5513334 , 1.5588503 ],
                      [1.5047882 , 1.4383408 , 1.6903621 , 1.7581959 ]]],


                    [[[2.7925072 , 2.59694   , 1.9662273 , 1.248892  ],
                      [1.1679978 , 1.1228945 , 1.2898302 , 0.38388687]],

                     [[0.9602714 , 0.6349373 , 0.5677942 , 1.0699091 ],
                      [1.0451251 , 1.8112588 , 2.5
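since pmap's leading axis can't exceed the device count, a flat batch is usually reshaped to `(n_devices, per_device, ...)` first. a minimal sketch, assuming an illustrative `per_device = 5` and a hypothetical `shard` helper; it checks that pmap-of-vmap gives the same answer as a single vmap over the flat batch.

```python
import numpy as np
import jax

def f(a, b, c):
    return a @ b + c

n_devices = jax.local_device_count()
per_device = 5                       # illustrative per-device batch size
batch = n_devices * per_device

rng = np.random.default_rng(0)
a = rng.random(size=(batch, 2, 3))
b = rng.random(size=(batch, 3, 4))
c = rng.random(size=(batch, 2, 4))

def shard(x):  # hypothetical helper
    # (batch, ...) -> (n_devices, per_device, ...): pmap splits axis 0, vmap axis 1
    return x.reshape((n_devices, per_device) + x.shape[1:])

pjv_f = jax.pmap(jax.vmap(f))
out = pjv_f(shard(a), shard(b), shard(c))   # shape (n_devices, per_device, 2, 4)
flat_out = np.asarray(out).reshape((batch, 2, 4))

print(np.allclose(flat_out, np.asarray(jax.vmap(f)(a, b, c)), atol=1e-5))
```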

in the next colab we'll use these functions to do a simple optimisation

# random numbers in jax

as a last piece, let's talk quickly about random numbers, which require a little more work in jax: there is no stateful random number generator, so every random call takes an explicit key, and the same key always gives the same numbers.

In [76]:
key = jax.random.PRNGKey(1337)
print(jax.random.uniform(key))
print(jax.random.uniform(key))

0.02251327
0.02251327


when you want to generate more random numbers you need to explicitly split the key

In [77]:
key = jax.random.PRNGKey(1337)
for _ in range(10):
  key, key2 = jax.random.split(key)
  print(jax.random.uniform(key2))

0.31418657
0.8795924
0.81679654
0.82754505
0.78958607
0.9006636
0.18430483
0.2509787
0.5041659
0.11195123
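a loop isn't required: `jax.random.split` can produce many keys in one call, and `vmap` can draw one sample per key. a minimal sketch; for a batch of numbers from a single key, passing `shape` to `jax.random.uniform` is simpler still.

```python
import jax

key = jax.random.PRNGKey(1337)

# split one key into 10 independent keys in a single call
keys = jax.random.split(key, 10)

# one uniform sample per key; vmap maps over the leading axis of `keys`
samples = jax.vmap(jax.random.uniform)(keys)

# the same key always reproduces the same numbers
again = jax.vmap(jax.random.uniform)(jax.random.split(key, 10))
print(samples)
print(bool((samples == again).all()))

# a batch of numbers straight from one key
print(jax.random.uniform(key, shape=(10,)))
```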


In [83]:
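# You cannot assign directly to elements of a JAX array: JAX arrays are immutable.

A = jnp.zeros((3,3), dtype=np.float32)

# In-place update of a JAX array raises a TypeError!
try:
  A[1, :] = 1.0
except TypeError:
  # older JAX versions offered jax.ops.index_update; it has since been removed in favour of A.at[...]
  print('must use A.at[...].set(...)')

must use A.at[...].set(...)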
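the functional replacement is the `.at` property: `.at[idx].set(v)`, `.at[idx].add(v)` and friends return a new array and leave the original untouched. a minimal sketch:

```python
import jax.numpy as jnp

A = jnp.zeros((3, 3), dtype=jnp.float32)

B = A.at[1, :].set(1.0)   # new array with row 1 set to 1.0
C = B.at[0, 0].add(5.0)   # functional equivalent of C[0, 0] += 5.0

print(A)   # unchanged: still all zeros
print(C)
```

inside a jitted function XLA can often turn these updates into genuine in-place writes, so the functional style need not cost an extra copy.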
