In [1]:
import os
# Ask XLA to expose the host CPU as 8 virtual devices. This is set here,
# before `jax` is imported in the next cell.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

In [2]:
import jax

In [3]:
jax.devices()

[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]

In [4]:
jax.device_count()

8

In [10]:
import jax.numpy as jnp
import numpy as np

def linear_layer(x, w):
    """Bias-free linear map: print both operand shapes, then return ``jnp.dot(x, w)``."""
    print(*(operand.shape for operand in (x, w)))
    product = jnp.dot(x, w)
    return product

In [11]:
# Toy problem: n samples with d features each, to be spread over all visible devices.
n, d = 16, 3
devices = jax.device_count()

In [12]:
# Draw the toy data with NumPy (inputs first, then weights) and convert to JAX arrays:
# xs has shape (n, d), ws has shape (d,).
xs, ws = (jnp.array(np.random.rand(*shape)) for shape in ((n, d), (d,)))

In [13]:
# Shard the inputs: split the n rows into one equal chunk per device and stack
# the chunks along a new leading axis, which is the axis pmap maps over.
x_parts = jnp.stack(jnp.split(xs, devices))
# Replicate the weights so every device receives its own copy along the leading axis.
# `jax.tree_util.tree_map` is used because the top-level `jax.tree_map` alias is deprecated.
w_parts = jax.tree_util.tree_map(lambda x: jnp.stack([x] * devices), ws)

out = jax.pmap(linear_layer)(x_parts, w_parts)
print(out.shape) # (8, 2), out is a matrix of shape (n_devices, n_data // n_devices)

(2, 3) (3,)
(8, 2)


In [14]:
x_parts.shape

(8, 2, 3)

In [15]:
# in_axes=(0, None): map over axis 0 of x_parts, while ws is passed whole
# (un-split) to every device, so no manual replication is needed.
pmapped_layer = jax.pmap(linear_layer, in_axes=(0, None))
out = pmapped_layer(x_parts, ws)
print(out.shape) # (8, 2), out is a matrix of shape (n_devices, n_data // n_devices)

(2, 3) (3,)
(8, 2)


In [16]:
import numpy as np

import jax
import jax.random as random
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax import jit, vmap, pmap, grad, value_and_grad

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

from typing import List, Tuple, Dict

In [17]:
# Number of devices local to this process; used further down to replicate
# the model parameters (one copy per device).
n_devices = jax.local_device_count()
print(f'Number of available devices: {n_devices}')

Number of available devices: 8


In [18]:
SEED = 42                                # seed for the JAX PRNG key that initialises the MLP
MNIST_IMG_SIZE = (28, 28)                # height x width of one image
NUM_INP_NODES = np.prod(MNIST_IMG_SIZE)  # flattened image length (784), i.e. the MLP input width
NUM_CLASSES = 10                         # number of output classes

In [19]:
def init_MLP(layer_widths: List, parent_key, scale: float=0.01):
    """Create randomly initialised parameters for a fully connected network.

    Args:
        layer_widths: Node counts per layer, input layer first.
        parent_key: JAX PRNG key; split into one sub-key per weight layer.
        scale: Multiplier applied to every standard-normal sample.

    Returns:
        A list with one ``[weights, biases]`` pair per consecutive pair of
        widths; weights have shape ``(n_in, n_out)``, biases ``(n_out,)``.
    """
    n_layers = len(layer_widths) - 1
    layer_keys = random.split(parent_key, num=n_layers)
    fan_pairs = zip(layer_widths[:-1], layer_widths[1:])

    def init_layer(layer_key, n_in, n_out):
        # Independent sub-keys for the weight matrix and the bias vector.
        w_key, b_key = random.split(layer_key)
        weights = scale * random.normal(w_key, shape=(n_in, n_out))
        biases = scale * random.normal(b_key, shape=(n_out,))
        return [weights, biases]

    return [
        init_layer(layer_key, n_in, n_out)
        for (n_in, n_out), layer_key in zip(fan_pairs, layer_keys)
    ]

In [20]:
# Smoke test: build a 784-512-256-10 MLP and print the shape of every parameter.
key = random.PRNGKey(SEED)
MLP_params = init_MLP([NUM_INP_NODES, 512, 256, NUM_CLASSES], key)
# `jax.tree_util.tree_map` is used because the top-level `jax.tree_map` alias is deprecated.
print(jax.tree_util.tree_map(lambda x: x.shape, MLP_params))

[[(784, 512), (512,)], [(512, 256), (256,)], [(256, 10), (10,)]]


In [21]:
def MLP_predict(params, x):
    """Run the MLP forward pass and return log-probabilities over the classes.

    Args:
        params: Sequence of ``(weights, biases)`` pairs, one per layer. Every
            layer except the last is followed by a ReLU.
        x: A single flattened input of shape ``(n_in,)`` or a batch of
            shape ``(batch, n_in)``.

    Returns:
        Log-softmax of the final-layer logits, normalised over the last
        (class) axis: shape ``(n_classes,)`` or ``(batch, n_classes)``.
    """
    hidden_layers = params[:-1]

    activation = x
    for w, b in hidden_layers:
        # (..., n_in) . (n_in, n_out) -> (..., n_out)
        activation = jax.nn.relu(jnp.dot(activation, w) + b)

    w_last, b_last = params[-1]
    logits = jnp.dot(activation, w_last) + b_last

    # Normalise over the class axis. axis=-1 is the same as axis=1 for a batch,
    # but also works for a single un-batched example (1-D logits), which is
    # what the function sees when it is called directly or wrapped in vmap.
    return logits - logsumexp(logits, axis=-1, keepdims=True)

In [22]:
# Shape checks for MLP_predict on random (unseeded) NumPy inputs.

# 1) A single flattened image of length NUM_INP_NODES.

dummy_img_flat = np.random.randn(NUM_INP_NODES)
print(dummy_img_flat.shape)

prediction = MLP_predict(MLP_params, dummy_img_flat)
print(prediction.shape)

# 2) A batch of 16 images, method 1: vmap over axis 0 of the images while
#    the parameters are not mapped (in_axes=None for the first argument).

dummy_imgs_flat = np.random.randn(16, NUM_INP_NODES)
print(dummy_imgs_flat.shape)

batched_MLP_predict = vmap(MLP_predict, in_axes=(None, 0))
predictions = batched_MLP_predict(MLP_params, dummy_imgs_flat)
print(predictions.shape)

(784,)
(10,)
(16, 784)
(16, 10)


In [23]:
# print(predictions)

In [24]:
from functools import partial

# Batched prediction, method 2: bind the parameters with functools.partial so
# that vmap only sees (and maps over) the image argument.

batched_MLP_predict = vmap(partial(MLP_predict, MLP_params))

predictions = batched_MLP_predict(dummy_imgs_flat)
print(predictions.shape)

(16, 10)


In [25]:
# print(predictions)

In [27]:
from jax import random, pmap
import jax.numpy as jnp

# Create 8 random 5000 x 6000 matrices, one per device (in this notebook: the 8 virtual CPU devices)
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: jnp.dot(x, x.T))(mats)  # result.shape is (8, 5000, 5000)

# Compute the mean on each device in parallel and print the result
print(pmap(jnp.mean)(result))
# prints an array with one mean per device (8 values)

[1.1566511 1.1805044 1.2052325 1.20457   1.1876906 1.203792  1.2322344
 1.2015207]


In [28]:
random.PRNGKey(0)

Array([0, 0], dtype=uint32)

In [29]:
def custom_transform(x):
    """Convert ``x`` to a float32 NumPy array and flatten it to one dimension."""
    as_float32 = np.array(x, dtype=np.float32)
    return as_float32.ravel()

def custom_collate_fn(batch):
    """Collate ``(image, label)`` samples into per-device shards for ``pmap``.

    Replaces the DataLoader's default collation so that the batch comes back
    as JAX arrays with a leading device axis instead of torch tensors.

    Args:
        batch: Sequence of samples whose first two fields are the image (all
            images share one shape) and its label.

    Returns:
        ``(imgs, labels)`` as JAX arrays of shape
        ``(n_devices, batch_size // n_devices, *image_shape)`` and
        ``(n_devices, batch_size // n_devices)``.

    Raises:
        ValueError: If the batch size is not a multiple of the device count.
    """
    # pmap maps over the devices local to this process.
    n_devices = jax.local_device_count()

    transposed_data = list(zip(*batch))

    # Stack the samples on the host with a single NumPy call, then hand the
    # whole batch to JAX at once.
    imgs = jnp.asarray(np.stack(transposed_data[0]))
    labels = jnp.array(transposed_data[1])

    batch_size = imgs.shape[0]
    if batch_size % n_devices != 0:
        raise ValueError(
            f"batch size {batch_size} is not divisible by the number of devices {n_devices}"
        )
    per_device = batch_size // n_devices

    # Equivalent to splitting into n_devices equal chunks and stacking them:
    # device i receives samples [i * per_device, (i + 1) * per_device).
    imgs = imgs.reshape(n_devices, per_device, *imgs.shape[1:])
    labels = labels.reshape(n_devices, per_device, *labels.shape[1:])

    return imgs, labels

In [30]:
# Global batch size; custom_collate_fn splits every batch evenly across the devices.
BATCH_SIZE = 512
# custom_transform turns each image into a flat float32 vector.
# NOTE(review): train and test use different roots, so the MNIST archives are
# downloaded twice (see the cell output) - consider sharing one root.
train_dataset = MNIST(root='train_mnist', train=True, download=True, transform=custom_transform)
test_dataset = MNIST(root='test_mnist', train=False, download=True, transform=custom_transform)

# drop_last=True skips a final incomplete batch, so every batch has exactly
# BATCH_SIZE samples and can be split evenly by custom_collate_fn.
train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True, collate_fn=custom_collate_fn, drop_last=True)
test_loader = DataLoader(test_dataset, BATCH_SIZE, shuffle=False, collate_fn=custom_collate_fn, drop_last=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to train_mnist/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 138223098.88it/s]

Extracting train_mnist/MNIST/raw/train-images-idx3-ubyte.gz to train_mnist/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to train_mnist/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 70633057.62it/s]


Extracting train_mnist/MNIST/raw/train-labels-idx1-ubyte.gz to train_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to train_mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 225957833.06it/s]


Extracting train_mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to train_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to train_mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 7788441.85it/s]

Extracting train_mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to train_mnist/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to test_mnist/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 130012262.20it/s]


Extracting test_mnist/MNIST/raw/train-images-idx3-ubyte.gz to test_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to test_mnist/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 29210439.79it/s]


Extracting test_mnist/MNIST/raw/train-labels-idx1-ubyte.gz to test_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to test_mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 93983793.07it/s]

Extracting test_mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to test_mnist/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to test_mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 7862372.58it/s]

Extracting test_mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to test_mnist/MNIST/raw






In [31]:
# Smoke test: pull one batch from the training loader and print the shapes and
# dtypes of the sharded images and labels (leading axis = device axis).
batch_data = next(iter(train_loader))
imgs = batch_data[0]
lbls = batch_data[1]
print(imgs.shape, imgs[0].dtype, lbls.shape, lbls[0].dtype)

(8, 64, 784) float32 (8, 64) int32


In [32]:
# for batch in iter(train_loader):
#     imgs = batch_data[0]
#     lbls = batch_data[1]
#     print(imgs.shape, imgs[0].dtype, lbls.shape, lbls[0].dtype)

In [33]:
# Optimization: load the whole dataset into memory as JAX arrays so that accuracy
# can be evaluated in one call per split. Each image is flattened into one row.
# NOTE(review): this reads the raw `.data` tensors directly, i.e. it does not go
# through custom_transform - confirm the dtype difference is acceptable.
train_images = jnp.array(train_dataset.data).reshape(len(train_dataset), -1)
train_lbls = jnp.array(train_dataset.targets)

test_images = jnp.array(test_dataset.data).reshape(len(test_dataset), -1)
test_lbls = jnp.array(test_dataset.targets)

In [34]:
def loss_fn(params, images, labels):
    """Mean negative log-likelihood of ``labels`` under the MLP's predictions.

    ``labels`` is multiplied element-wise with the predicted log-probabilities
    and summed over axis 1; the training loop passes one-hot rows, so only the
    true-class log-probability contributes for each example.
    """
    log_probs = MLP_predict(params, images)
    per_example_ll = jnp.sum(log_probs * labels, axis=1)
    return -jnp.mean(per_example_ll)


@jit
def accuracy(params, dataset_imgs, dataset_lbls):
    """Fraction of examples whose arg-max prediction equals ``dataset_lbls``.

    ``dataset_lbls`` holds class indices (not one-hot rows): it is compared
    directly with the arg-max over axis 1 of the predictions.
    """
    log_probs = MLP_predict(params, dataset_imgs)
    hits = jnp.argmax(log_probs, axis=1) == dataset_lbls
    return jnp.mean(hits)



import functools

@functools.partial(pmap, axis_name="batch", static_broadcasted_argnums=3)
def update(params, images, labels, step_size):
    """Run one data-parallel SGD step; the body executes once per device under ``pmap``.

    Args:
        params: This device's replica of the parameter pytree.
        images: This device's shard of the input batch.
        labels: This device's shard of the labels (one-hot in the training loop).
        step_size: Learning rate; passed as a static, broadcast argument
            (``static_broadcasted_argnums=3``).

    Returns:
        ``(loss, new_params)``: the loss averaged over all devices, and the
        updated parameters with the same pytree structure as ``params``.
    """
    loss_value, grads = value_and_grad(loss_fn)(params, images, labels)

    # Average (not sum) the gradients across the device-mapped axis with the
    # `lax.pmean` collective, so every replica applies the same update and the
    # parameter copies stay identical. `pmean` accepts the whole gradient pytree.
    grads = jax.lax.pmean(grads, axis_name="batch")

    # Average the loss as well. Seen from outside pmap the result is still an
    # array with one entry per device; after pmean all entries are the same value.
    loss_value_aggr = jax.lax.pmean(loss_value, axis_name="batch")

    # tree_map keeps the container structure of `params` unchanged between
    # calls and works for any parameter pytree, not only (w, b) pairs.
    new_params = jax.tree_util.tree_map(lambda p, g: p - step_size * g, params, grads)

    return loss_value_aggr, new_params

In [35]:
NUM_EPOCHS = 5  # number of full passes over the training loader
LR = 1e-2       # SGD step size passed to update()

In [36]:
# Create a MLP (NUM_CLASSES is the same constant used for the one-hot labels in the training loop)
MLP_params = init_MLP([NUM_INP_NODES, 512, 256, NUM_CLASSES], key)
# `jax.tree_util.tree_map` is used because the top-level `jax.tree_map` alias is deprecated.
print(jax.tree_util.tree_map(lambda p: p.shape, MLP_params))

# Replicate to devices: add a leading axis of length n_devices to every
# parameter so that pmap hands one full copy to each device.
MLP_params = jax.tree_util.tree_map(lambda x: jnp.stack([x] * n_devices), MLP_params)
print(jax.tree_util.tree_map(lambda p: p.shape, MLP_params))

[[(784, 512), (512,)], [(512, 256), (256,)], [(256, 10), (10,)]]
[[(8, 784, 512), (8, 512)], [(8, 512, 256), (8, 256)], [(8, 256, 10), (8, 10)]]


In [37]:
for epoch in range(NUM_EPOCHS):

    for cnt, (images, labels) in enumerate(train_loader):
        # Integer class labels -> one-hot rows, as expected by loss_fn.
        labels = jax.nn.one_hot(labels, NUM_CLASSES)
        loss, MLP_params = update(MLP_params, images, labels, LR)

        if cnt % 50 == 0:
            # `loss` has one entry per device, all equal after pmean; show the first.
            print(loss[0])

    # The replicas are kept identical by the averaged gradients, so take the
    # copy on device 0 for evaluation on the full in-memory dataset.
    # `jax.tree_util.tree_map` is used because the `jax.tree_map` alias is deprecated.
    MLP_params_single = jax.tree_util.tree_map(lambda x: x[0], MLP_params)
    print(f'Epoch {epoch}, train acc = {accuracy(MLP_params_single, train_images, train_lbls)} test acc = {accuracy(MLP_params_single, test_images, test_lbls)}')

4.1816835
2.3606865
2.315687
Epoch 0, train acc = 0.9383666515350342 test acc = 0.9414999485015869
2.2659068
2.2056766
2.1791258
Epoch 1, train acc = 0.9557000398635864 test acc = 0.9535999894142151
2.1463342
2.1553676
2.1014066
Epoch 2, train acc = 0.9635666608810425 test acc = 0.9613999724388123
2.1307936
2.07494
2.1368456
Epoch 3, train acc = 0.9695833325386047 test acc = 0.9648999571800232
2.1357253
2.1224568
2.0870843
Epoch 4, train acc = 0.9738166928291321 test acc = 0.9679999947547913


In [None]:
# https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py


Additional Resources:
1.   [/docs/jax-101](https://github.com/google/jax/tree/main/docs/jax-101)
2.   [/cloud_tpu_colabs](https://github.com/google/jax/tree/main/cloud_tpu_colabs)
3.   [/examples](https://github.com/google/jax/tree/main/examples)
4.   [/notebooks](https://github.com/google/jax/tree/main/docs/notebooks)
