This notebook is an exploration of strategies for scaling neural nets to _very large_ models across multiple devices.

We'll start by looking at the basic types of parallelism listed below, then we might explore a more complex strategy that combines elements of them, such as DeepSpeed!

- Data parallelism
- Model parallelism
- Pipeline parallelism
- Tensor parallelism 
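
To make the first item in the list concrete, here is a minimal sketch of data parallelism with `pmap`: every device holds a full copy of the parameters, computes gradients on its own shard of the batch, and the gradients are averaged across devices before the update. The toy linear model, the learning rate, and the names `loss_fn` and `train_step` are assumptions made for illustration, not part of any particular library or training setup.

```python
import functools

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()  # 8 on a single TPU host, 1 on a plain CPU


def loss_fn(w, x, y):
    # Mean squared error of a toy linear model (stand-in for a real network).
    return jnp.mean((x @ w - y) ** 2)


@functools.partial(jax.pmap, axis_name="batch")
def train_step(w, x, y):
    # Each device computes gradients on its own shard of the batch...
    grads = jax.grad(loss_fn)(w, x, y)
    # ...then gradients are averaged across devices, so every replica
    # applies the same update and the parameter copies stay in sync.
    grads = jax.lax.pmean(grads, axis_name="batch")
    return w - 0.1 * grads


key_x = jax.random.PRNGKey(0)
true_w = jnp.array([1.0, -2.0, 3.0])

# Leading axis = device axis: one shard of 32 examples per device.
x = jax.random.normal(key_x, (n_dev, 32, 3))
y = x @ true_w

# Parameters are replicated: one identical copy per device.
w = jnp.zeros((n_dev, 3))
for _ in range(100):
    w = train_step(w, x, y)

print(w[0])  # should end up close to true_w
```

The only cross-device communication here is the `pmean` call; everything else runs locally on each device.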

A note on hardware: this notebook uses a TPU because the underlying hardware makes scaling much, much easier (as your needs grow, you can move to larger and larger TPU pods without having to deal with inter-machine communication issues). Later on, we'll look at the classic approach (Kubernetes clusters of individual devices), but I believe that in the long term most large-model training will happen on mesh networks of devices (like TPUs, or Tesla's Dojo).

A couple of resources that I've leant on:

- [This excellent series on deep learning hardware](https://blog.inten.to/hardware-for-deep-learning-current-state-and-trends-51c01ebbb6dc)
- [Lilian Weng's superb notes on training large models](https://lilianweng.github.io/lil-log/2021/09/24/train-large-neural-networks.html)
- [Ben Wang's GPT-J - to my knowledge the main published ...](https://github.com/kingoflolz/mesh-transformer-jax)


In [4]:
import jax
import jax.numpy as jnp
import jax.profiler

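def func1(x):
  # Tile x 10 times along its last axis, then scale:
  # (10000, 1000) -> (10000, 10000)
  return jnp.tile(x, 10) * 0.5

def func2(x):
  y = func1(x)
  # Return a second tiled array so that two large buffers are live at once
  return y, jnp.tile(x, 10) + 1

x = jax.random.normal(jax.random.PRNGKey(42), (10000, 1000))
y, z = func2(x)

# JAX dispatches work asynchronously; wait for it to finish before profiling
z.block_until_ready()

# Write a pprof-format snapshot of the current device memory usage
jax.profiler.save_device_memory_profile("memory.prof")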
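
As a rough sanity check on what the memory profile should show, the sizes of `x`, `y` and `z` can be estimated without allocating anything. The sketch below uses `jax.eval_shape` to trace `func2` abstractly; the helper `nbytes` is a hypothetical name introduced here, and the input is assumed to be `float32`.

```python
import math

import jax
import jax.numpy as jnp


def func1(x):
    return jnp.tile(x, 10) * 0.5


def func2(x):
    y = func1(x)
    return y, jnp.tile(x, 10) + 1


def nbytes(spec):
    # Hypothetical helper: bytes needed for an array with this shape and dtype.
    return math.prod(spec.shape) * jnp.dtype(spec.dtype).itemsize


# Describe the input abstractly; no device memory is allocated here.
x_spec = jax.ShapeDtypeStruct((10000, 1000), jnp.float32)
y_spec, z_spec = jax.eval_shape(func2, x_spec)

for name, spec in [("x", x_spec), ("y", y_spec), ("z", z_spec)]:
    print(f"{name}: shape={spec.shape}, {nbytes(spec) / 1e6:.0f} MB")
```

Under these assumptions `x` takes 40 MB while `y` and `z` take 400 MB each, so the two tiled outputs should dominate the profile.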


In [26]:
!go tool pprof memory.prof

Main binary filename not available.
Type: space
Entering interactive mode (type "help" for commands, "o" for options)
(pprof) 

In [13]:
from jax import pmap
result = pmap(lambda x: x ** 2)(jnp.arange(7))
print(result)


[ 0  1  4  9 16 25 36]
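
`pmap` maps over the leading axis of its input and places one slice on each device, so that axis cannot be larger than the number of available devices. `jnp.arange(7)` works above presumably because the TPU host exposes 8 cores; on a machine with fewer devices the same call would fail. A small sketch that sizes the input from the device count instead (the variable names are just for illustration):

```python
import jax
import jax.numpy as jnp

# Number of devices attached to this host (8 TPU cores, or 1 on a plain CPU).
n_dev = jax.local_device_count()
print(jax.devices())

# One element per device, so the mapped axis always fits.
xs = jnp.arange(n_dev)
squares = jax.pmap(lambda x: x ** 2)(xs)
print(squares)
```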


In [11]:
# drop the references to the large arrays so their device memory can be freed
del y, z

In [22]:
from jax import random

# create 8 random keys
keys = random.split(random.PRNGKey(0), 8)
# create a 5000 x 6000 matrix on each device by mapping over keys
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)
# the stack of matrices is represented logically as a single array
mats.shape

(8, 5000, 6000)

In [23]:
# run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: jnp.dot(x, x.T))(mats)
result.shape
result.block_until_ready()
jax.profiler.save_device_memory_profile("memory.prof")

In [17]:
# compute the mean on each device in parallel and print the results
print(pmap(jnp.mean)(result))

[1.1566595 1.1805978 1.2052746 1.2045677 1.1876795 1.2037715 1.2321935
 1.2015157]
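
The means above are computed independently on each device. To combine them across devices, `pmap` can be given an `axis_name`, which collectives such as `jax.lax.pmean` then refer to. This is a minimal sketch using much smaller matrices than above; the function name `global_mean` and the axis name `"devices"` are assumptions chosen for illustration.

```python
import functools

import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

# One small random matrix per device (kept small so this runs anywhere).
keys = jax.random.split(jax.random.PRNGKey(0), n_dev)
mats = jax.pmap(lambda key: jax.random.normal(key, (50, 60)))(keys)


@functools.partial(jax.pmap, axis_name="devices")
def global_mean(x):
    # Local work: matmul and mean on this device only.
    local = jnp.mean(jnp.dot(x, x.T))
    # Collective: average the per-device means across all devices.
    return jax.lax.pmean(local, axis_name="devices")


out = global_mean(mats)
print(out)  # one entry per device, all holding the same global mean
```

Unlike the purely local `pmap(jnp.mean)` call, the `pmean` step does require communication between devices.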
