In [3]:
import jax
from jax import numpy as jnp

import flax
from flax import linen as nn 
from flax import optim

import numpy as np
import tensorflow_datasets as tfds


In [4]:
visible_devices = jax.devices()  # devices of JAX's default backend
print(visible_devices)

[TpuDevice(id=0, task=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, task=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, task=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, task=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, task=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, task=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, task=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, task=0, coords=(1,1,0), core_on_chip=1)]


In [5]:
class mnist(nn.Module):
    """Small CNN classifier for MNIST-style images.

    Layers: conv(32, 3x3) -> ReLU -> 2x2 avg-pool (stride 2)
            -> conv(64, 2x2) -> ReLU -> 2x2 avg-pool (stride 2)
            -> flatten -> dense(256) -> ReLU -> dense(num_classes)
            -> log-softmax.

    The output is log-probabilities rather than raw logits, matching what
    ``cross_entropy_loss`` assumes ("logits are assumed to be log(p^)").

    Attributes:
        num_classes: Width of the final dense layer. Defaults to 10 (the
            MNIST digits), so ``mnist()`` builds exactly the original model.
    """
    num_classes: int = 10

    def setup(self):
        self.conv1 = nn.Conv(features=32, kernel_size=(3, 3))
        self.conv2 = nn.Conv(features=64, kernel_size=(2, 2))
        self.dens1 = nn.Dense(features=256)
        self.dens2 = nn.Dense(features=self.num_classes)

    def __call__(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: Image batch. Presumably NHWC, e.g. (batch, 28, 28, 1) as used in
                ``get_initial_params`` (TODO: confirm for other callers).

        Returns:
            Array of shape (batch, num_classes) holding log-softmax outputs.
        """
        x = self.conv1(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = self.conv2(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten everything but the batch axis
        x = self.dens1(x)
        x = nn.relu(x)
        x = self.dens2(x)
        x = nn.log_softmax(x)
        return x

In [6]:
def cross_entropy_loss(logits, labels, num_classes=None):
    """Mean negative log-likelihood of integer labels.

    Args:
        logits: Log-probabilities (e.g. the log-softmax output of ``mnist``).
            Despite the name, these are assumed to already be log(p^); no
            softmax is applied here. The last axis indexes classes.
        labels: Integer class ids, one per row of ``logits``.
        num_classes: Width of the one-hot encoding. Defaults to
            ``logits.shape[-1]`` so the loss follows the model's output width
            (10 for MNIST) instead of a hard-coded constant.

    Returns:
        Scalar: the mean over rows of -log p^[label].
    """
    if num_classes is None:
        # Static under jit: shapes are known at trace time.
        num_classes = logits.shape[-1]
    one_hot_labels = jax.nn.one_hot(labels, num_classes=num_classes)
    return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))

In [7]:
def get_optimizer(params, **kwopts):
    """Build a flax Adam optimizer wrapping ``params``.

    All keyword options (e.g. ``learning_rate``) are forwarded unchanged to
    ``optim.Adam``. The returned optimizer is what ``train_step`` consumes and
    returns.
    """
    return optim.Adam(**kwopts).create(params)

In [8]:
def get_initial_params(key, input_shape=(1, 28, 28, 1)):
    """Initialise ``mnist`` parameters by tracing the model on a dummy batch.

    Args:
        key: JAX PRNG key used for parameter initialisation.
        input_shape: Shape of the all-ones dummy input. Only its shape and
            dtype matter, because they determine the parameter shapes.
            Defaults to a single 28x28x1 image, as before.

    Returns:
        The ``'params'`` collection produced by ``mnist().init``.
    """
    # This is a dummy *array* (the old name ``init_shape`` suggested a shape).
    dummy_batch = jnp.ones(input_shape, jnp.float32)
    return mnist().init(key, dummy_batch)['params']

In [9]:
def compute_metrics(logits, labels):
    """Summarise a batch as a ``{'loss': ..., 'accuracy': ...}`` dict.

    ``logits`` are the model's log-probabilities. Accuracy is the fraction of
    rows whose arg-max over the last axis equals the integer label.
    """
    predicted = jnp.argmax(logits, -1)
    return {
        'loss': cross_entropy_loss(logits, labels),
        'accuracy': jnp.mean(predicted == labels),
    }

In [10]:
def get_datasets():
    """Download MNIST through TFDS and return ``(train_ds, test_ds)``.

    Each split is loaded whole (``batch_size=-1``) as a dict of NumPy arrays.
    Its ``'image'`` entry is then replaced by a float32 JAX array divided by
    255.0 (presumably uint8 pixels, giving values in [0, 1]).
    """
    builder = tfds.builder('mnist')
    builder.download_and_prepare()

    def _load(split):
        split_ds = tfds.as_numpy(builder.as_dataset(split=split, batch_size=-1))
        split_ds['image'] = jnp.float32(split_ds['image']) / 255.0
        return split_ds

    return _load('train'), _load('test')

In [21]:
@jax.jit
def train_step(optimizer, batch):
    """Apply one jitted optimizer update on ``batch`` and report its metrics.

    Returns ``(updated_optimizer, metrics)``. The metrics are computed from the
    forward pass that produced the gradients, i.e. before the update.
    """
    images, labels = batch['image'], batch['label']

    def loss_and_log_probs(params):
        log_probs = mnist().apply({'params': params}, images)
        return cross_entropy_loss(log_probs, labels), log_probs

    value_and_grad_fn = jax.value_and_grad(loss_and_log_probs, has_aux=True)
    (_, log_probs), grads = value_and_grad_fn(optimizer.target)
    updated = optimizer.apply_gradient(grads)
    return updated, compute_metrics(log_probs, labels)

In [22]:
@jax.jit
def eval_step(params, batch):
    """Jitted forward pass on ``batch``; returns its loss/accuracy dict."""
    log_probs = mnist().apply({'params': params}, batch['image'])
    batch_metrics = compute_metrics(log_probs, batch['label'])
    return batch_metrics

In [23]:
def train_epoch(optimizer, train_ds, batch_size, epoch, rng):
    """Train for one epoch over shuffled mini-batches and print the averages.

    Args:
        optimizer: Optimizer passed to and returned by ``train_step``.
        train_ds: Dict of equally long arrays (e.g. ``'image'``, ``'label'``).
            Every entry is indexed with the same permutation.
        batch_size: Examples per step. The trailing partial batch is dropped.
        epoch: Epoch number; used only in the printed summary.
        rng: PRNG key that controls the shuffle.

    Returns:
        ``(optimizer, train_epoch_metrics)``. The metrics dict maps each
        metric name to its mean over the epoch's batches.

    Raises:
        ValueError: if ``batch_size`` is not positive, or if ``train_ds`` has
            fewer than ``batch_size`` examples (there would be no batch to
            average over).
    """
    train_ds_size = len(train_ds['image'])
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    steps_per_epoch = train_ds_size // batch_size
    if steps_per_epoch == 0:
        raise ValueError(
            f"train_ds has {train_ds_size} examples, fewer than "
            f"batch_size={batch_size}; no full batch to train on"
        )

    # Shuffle, drop the incomplete final batch, and lay out one row per step.
    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[:steps_per_epoch * batch_size]
    perms = perms.reshape((steps_per_epoch, batch_size))

    batch_metrics = []
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        optimizer, metrics = train_step(optimizer, batch)
        batch_metrics.append(metrics)

    # Bring every per-batch metric to the host with a single device_get call
    # after the loop.
    train_batch_metrics = jax.device_get(batch_metrics)
    train_epoch_metrics = {
        k: np.mean([m[k] for m in train_batch_metrics])
        for k in train_batch_metrics[0]
    }
    print(f"Training - epoch: {epoch}, loss: {train_epoch_metrics['loss']}, accuracy: {train_epoch_metrics['accuracy']}")
    return optimizer, train_epoch_metrics

In [24]:
def eval_model(model, test_ds):
    """Evaluate model parameters on a whole dataset in one jitted call.

    Args:
        model: The model *parameters* (the ``'params'`` collection, e.g.
            ``optimizer.target``), not an ``nn.Module``. The name is kept for
            compatibility with existing callers.
        test_ds: Dict with ``'image'`` and ``'label'`` arrays. The full split
            is evaluated as a single batch.

    Returns:
        ``(loss, accuracy)`` as Python scalars.
    """
    metrics = jax.device_get(eval_step(model, test_ds))
    # jax.tree_util.tree_map is the stable API. The top-level jax.tree_map
    # alias is deprecated and removed in newer JAX releases.
    eval_summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
    return eval_summary['loss'], eval_summary['accuracy']

In [25]:
# Load the MNIST train and test splits via get_datasets().
train_ds, test_ds = get_datasets()

In [26]:
SEED = 0
# Split the root key: `rng` is carried forward for later splits (per-epoch
# shuffles), and `init_rng` seeds parameter initialisation.
rng, init_rng = jax.random.split(jax.random.PRNGKey(SEED))

In [27]:
# Initialise model parameters from the dedicated init key.
params = get_initial_params(init_rng)

In [32]:
# Training hyperparameters.
learning_rate = 1e-3  # Adam step size
beta = 0.9            # intended as Adam's beta1; only takes effect if forwarded to get_optimizer
num_epochs = 10
batch_size = 32       # examples per train_step; the partial last batch is dropped

In [33]:
# `beta` from the hyperparameter cell is Adam's first-moment decay
# (flax.optim.Adam's `beta1`; its default is also 0.9).
optimizer = get_optimizer(params, learning_rate=learning_rate, beta1=beta)

In [34]:
for epoch in range(1, num_epochs + 1):
    # A fresh subkey each epoch gives a new shuffle order; `rng` keeps the
    # remaining randomness for later epochs.
    rng, input_rng = jax.random.split(rng)
    optimizer, train_metrics = train_epoch(
        optimizer, train_ds, batch_size, epoch, input_rng
    )
    # Evaluate the just-updated parameters on the full test split.
    test_loss, test_accuracy = eval_model(optimizer.target, test_ds)
    print(f"Testing - epoch: {epoch}, loss: {test_loss}, accuracy: {test_accuracy}")
    

Training - epoch: 1, loss: 0.1260634958744049, accuracy: 0.9616333246231079
Testing - epoch: 1, loss: 0.050780244171619415, accuracy: 0.9829999804496765
Training - epoch: 2, loss: 0.04492035135626793, accuracy: 0.9859166741371155
Testing - epoch: 2, loss: 0.03572174161672592, accuracy: 0.988099992275238
Training - epoch: 3, loss: 0.030871640890836716, accuracy: 0.9904666543006897
Testing - epoch: 3, loss: 0.02965500019490719, accuracy: 0.9889999628067017
Training - epoch: 4, loss: 0.0235484316945076, accuracy: 0.9924666881561279
Testing - epoch: 4, loss: 0.02583594061434269, accuracy: 0.9918999671936035
Training - epoch: 5, loss: 0.01658724993467331, accuracy: 0.9947166442871094
Testing - epoch: 5, loss: 0.031633954495191574, accuracy: 0.9907999634742737
Training - epoch: 6, loss: 0.01394949946552515, accuracy: 0.9955833554267883
Testing - epoch: 6, loss: 0.03818190097808838, accuracy: 0.9888999462127686
Training - epoch: 7, loss: 0.010369012132287025, accuracy: 0.9966333508491516
Test