## Import libraries

In [1]:
import numpy as np
import jax
import time
import jax.numpy as jnp
from jax.scipy.special import logsumexp

from jax import random, vmap, pmap, jit
from jax.tree_util import tree_map


from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import functools

  from .autonotebook import tqdm as notebook_tqdm


## Check for TPU devices

In [2]:
n_devices = jax.local_device_count()
print(f"Number of available devices: {n_devices}")
jax.devices()

E0419 20:20:27.064041705     231 oauth2_credentials.cc:238]            oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {created_time:"2024-04-19T20:20:27.064025181+00:00", grpc_status:2}


Number of available devices: 8


[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [3]:
seed = 42
num_classes = 10
image_size = (28, 28)
batch_size_per_device = 128
batch_size = batch_size_per_device * n_devices

In [4]:
def custom_transform(x):
    x = np.resize(x, new_shape=image_size)
    return np.ravel(np.array(x, dtype=np.float32))

def custom_collate_fn(batch):
    transposed_data = list(zip(*batch))

    labels = np.array(transposed_data[1])
    imgs = np.stack(transposed_data[0])

    return imgs, labels

train_dataset = MNIST(root='train_mnist', 
                      train=True, 
                      download=True, 
                      transform=custom_transform)
test_dataset = MNIST(root='test_mnist', 
                     train=False, 
                     download=True,
                     transform=custom_transform)

train_loader = DataLoader(train_dataset, 
                          batch_size, 
                          collate_fn=custom_collate_fn,
                          shuffle=True, 
                          drop_last=True)
test_loader = DataLoader(test_dataset, 
                         batch_size, 
                         collate_fn=custom_collate_fn,
                         shuffle=False, 
                         drop_last=True)

# test
batch_data = next(iter(train_loader))
imgs = batch_data[0]
lbls = batch_data[1]
print(imgs.shape, imgs[0].dtype, lbls.shape, lbls[0].dtype)

# # optimization - loading the whole dataset into memory
train_images = jnp.array(train_dataset.data).reshape(len(train_dataset), -1)
train_lbls = jnp.array(train_dataset.targets)

test_images = jnp.array(test_dataset.data).reshape(len(test_dataset), -1)
test_lbls = jnp.array(test_dataset.targets)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to train_mnist/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

100%|██████████| 9912422/9912422 [00:00<00:00, 111431963.32it/s]




Extracting train_mnist/MNIST/raw/train-images-idx3-ubyte.gz to train_mnist/MNIST/raw



Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to train_mnist/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

100%|██████████| 28881/28881 [00:00<00:00, 49706891.19it/s]




Extracting train_mnist/MNIST/raw/train-labels-idx1-ubyte.gz to train_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to train_mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

100%|██████████| 1648877/1648877 [00:00<00:00, 33308568.55it/s]




Extracting train_mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to train_mnist/MNIST/raw



Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to train_mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

100%|██████████| 4542/4542 [00:00<00:00, 9789583.13it/s]




Extracting train_mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to train_mnist/MNIST/raw



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to test_mnist/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

 99%|█████████▊| 9764864/9912422 [00:00<00:00, 95954886.76it/s]

100%|██████████| 9912422/9912422 [00:00<00:00, 95851341.88it/s]

Extracting test_mnist/MNIST/raw/train-images-idx3-ubyte.gz to test_mnist/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to test_mnist/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

100%|██████████| 28881/28881 [00:00<00:00, 63488309.13it/s]




Extracting test_mnist/MNIST/raw/train-labels-idx1-ubyte.gz to test_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to test_mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

100%|██████████| 1648877/1648877 [00:00<00:00, 33277795.99it/s]




Extracting test_mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to test_mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz


Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to test_mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

100%|██████████| 4542/4542 [00:00<00:00, 12483963.81it/s]




Extracting test_mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to test_mnist/MNIST/raw



(1024, 784) float32 (1024,) int64


In [5]:
def forward(params, inputs):
    activations = inputs
    for w, b in params[:-1]:
        outputs = jnp.dot(activations, w) + b
        activations = jax.nn.relu(outputs)

    final_w, final_b = params[-1]
    logits = jnp.dot(activations, final_w) + final_b
    return logits - logsumexp(logits, axis=1, keepdims=True)

In [6]:
def init_MLP(layer_widths, key, scale=0.001):
    params = []
    keys = jax.random.split(key, num=len(layer_widths)-1)
    for in_width, out_width, key in zip(layer_widths[:-1], layer_widths[1:], keys):
        weight_key, bias_key = jax.random.split(key)
        params.append(
            [scale * random.normal(weight_key, shape=(in_width, out_width)),
             scale * random.normal(bias_key, shape=(out_width, ))]
        )

    return params

key = jax.random.PRNGKey(seed)
MLP_params = init_MLP([image_size[0] * image_size[1], 1024, 1024, num_classes], key)


In [7]:
batch_data = np.stack(jnp.split(imgs, n_devices))
replicated_params = jax.tree_map(lambda x: np.stack([x for _ in range(n_devices)]), MLP_params)

In [8]:
parallel_forward = jax.pmap(forward)
print(parallel_forward(replicated_params, batch_data).shape) # (8, 2), out is a matrix of shape (n_devices, n_data // n_devices)

(8, 128, 10)


In [9]:
def loss_fn(params, images, labels):
    preds = forward(params, images)
    return -jnp.mean(jnp.sum(preds * labels, axis=1))

@jit
def accuracy(params, batch):
    inputs, targets = batch
    predicted_class = jnp.argmax(forward(params, inputs), axis=1)
    return jnp.mean(predicted_class == targets)

# function for performing one SGD update step (fwd & bwd pass)
@functools.partial(jax.pmap, axis_name='num_devices')
def update(params, xs, ys, lr=0.01):
    loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)
    grads = jax.lax.pmean(grads, axis_name='num_devices')
    loss = jax.lax.pmean(loss, axis_name='num_devices')
    new_params = jax.tree_map(
      lambda param, g: param - g * lr, params, grads)
 
    return loss / n_devices, new_params

num_epochs = 10
MLP_params = init_MLP([image_size[0] * image_size[1], 1024, 1024, num_classes], key)
replicated_params = jax.tree_map(lambda x: np.stack([x for _ in range(n_devices)]), MLP_params)

for epoch in range(num_epochs):
    loss_l = []
    start_time = time.time()
    for images, labels in train_loader:
        labels = jax.nn.one_hot(labels, num_classes)
        images, labels = np.stack(jnp.split(images, n_devices)), np.stack(jnp.split(labels, n_devices))
        loss, replicated_params = update(replicated_params, images, labels)
        loss_l.append(loss)
        epoch_time = time.time() - start_time
    params = jax.tree_map(lambda x: x[0], replicated_params)
    print(f"Epoch {epoch+1} in {epoch_time:0.2f} sec, loss = {jnp.mean(jnp.array(loss_l))} " \
          f"train_acc = {accuracy(params, (train_images, train_lbls))} test_acc = {accuracy(params, (test_images, test_lbls))}")

     

Epoch 1 in 4.52 sec, loss = 0.26211103796958923 train_acc = 0.6533499956130981 test_acc = 0.6635000109672546


Epoch 2 in 3.80 sec, loss = 0.10136502236127853 train_acc = 0.8719000220298767 test_acc = 0.8769999742507935


Epoch 3 in 3.91 sec, loss = 0.047795940190553665 train_acc = 0.9121833443641663 test_acc = 0.9138000011444092


Epoch 4 in 3.84 sec, loss = 0.036304499953985214 train_acc = 0.9282000064849854 test_acc = 0.9305999875068665


Epoch 5 in 3.78 sec, loss = 0.02930382266640663 train_acc = 0.9400833249092102 test_acc = 0.9406999945640564


Epoch 6 in 3.95 sec, loss = 0.0243099182844162 train_acc = 0.9508166909217834 test_acc = 0.9496999979019165


Epoch 7 in 3.76 sec, loss = 0.021001029759645462 train_acc = 0.9553333520889282 test_acc = 0.9540999531745911


Epoch 8 in 3.84 sec, loss = 0.018768660724163055 train_acc = 0.9600333571434021 test_acc = 0.9569000005722046


Epoch 9 in 3.84 sec, loss = 0.016259359195828438 train_acc = 0.9660833477973938 test_acc = 0.9625999927520752


Epoch 10 in 3.88 sec, loss = 0.01457468792796135 train_acc = 0.9695000052452087 test_acc = 0.9637999534606934
