In [2]:
import os
import requests
os.environ['TPU_ADDR'] = "10.20.137.146"

if 'TPU_DRIVER_MODE' not in globals():
    url = 'http://' + os.environ['TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
    resp = requests.post(url)
    TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['TPU_ADDR']
print('Registered TPU:', config.FLAGS.jax_backend_target)

Registered TPU: grpc://10.20.137.146


In [3]:
import jax
from jax import numpy as jnp

import flax
from flax import linen as nn 
from flax import optim

import numpy as np
import tensorflow_datasets as tfds


In [5]:
print(jax.devices())

RuntimeError: Deadline exceeded: Failed to connect to remote server at address: grpc://10.20.137.146. Error from gRPC: Deadline Exceeded. Details: 

In [46]:
class mnist(nn.Module):
    """."""
    def setup(self):
        self.conv1 = nn.Conv(features=32, kernel_size=(3,3))
        self.conv2 = nn.Conv(features=64, kernel_size=(2,2))
        self.dens1 = nn.Dense(features=256)
        self.dens2 = nn.Dense(features=10)
        
    def __call__(self, x):
        x = self.conv1(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2,2), strides=(2,2))
        x = self.conv2(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2,2), strides=(2,2))
        x = x.reshape((x.shape[0], -1))
        x = self.dens1(x)
        x = nn.relu(x)
        x = self.dens2(x)
        x = nn.log_softmax(x)
        return x

In [47]:
def cross_entropy_loss(logits, labels):
    """ logits are assumed to be log(p^) """
    one_hot_labels = jax.nn.one_hot(labels, num_classes=10)
    return -jnp.mean(jnp.sum(one_hot_labels * logits, axis=-1))

In [48]:
def get_optimizer(params, **kwopts):
    optimizer_def = optim.Adam(**kwopts)
    return optimizer_def.create(params)

In [49]:
def get_initial_params(key):
    init_shape = np.ones((1, 28, 28, 1), jnp.float32)
    initial_params = mnist().init(key, init_shape)['params']
    return initial_params

In [108]:
def compute_metrics(logits, labels):
    loss = cross_entropy_loss(logits, labels)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {
        'loss': loss,
        'accuracy': accuracy
    }
    return metrics

In [109]:
def get_datasets():
    ds_builder = tfds.builder('mnist')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.0
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.0
    return train_ds, test_ds

In [110]:
def train_step(optimizer, batch):
    def loss_fn(params):
        logits = mnist().apply({'params': params}, batch['image'])
        loss = cross_entropy_loss(logits, batch['label'])
        return loss, logits
    
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grad = grad_fn(optimizer.target)
    optimizer = optimizer.apply_gradient(grad)
    metrics = compute_metrics(logits, batch['label'])
    return optimizer, metrics

In [111]:
def eval_step(params, batch):
    logits = mnist().apply({'params': params}, batch['image'])
    return compute_metrics(logits, batch['label'])

In [151]:
def train_epoch(optimizer, train_ds, batch_size, epoch, rng):
    # Compute number of steps
    train_ds_size = len(train_ds['image'])
    steps_per_epoch = train_ds_size // batch_size
    
    #Shuffle the data
    perms = jax.random.permutation(rng, train_ds_size)
    perms = perms[:steps_per_epoch * batch_size] # drop incomplete batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    batch_metrics = []

    
    for perm in perms:
        batch = {k: v[perm, ...] for k, v in train_ds.items()}
        optimizer, metrics = train_step(optimizer, batch)
        batch_metrics.append(metrics)
    
    train_batch_metrics = jax.device_get(batch_metrics)
    train_epoch_metrics = {
        k: np.mean([metrics[k] for metrics in train_batch_metrics])
        for k in train_batch_metrics[0]
    }
    print(f"Training - epoch: {epoch}, loss: {train_epoch_metrics['loss']}, accuracy: {train_epoch_metrics['accuracy']}")
    return optimizer, train_epoch_metrics

In [152]:
def eval_model(model, test_ds):
    """ model / params ?????"""
    metrics = eval_step(model, test_ds)
    metrics = jax.device_get(metrics)
    eval_summary = jax.tree_map(lambda x: x.item(), metrics)
    return eval_summary['loss'], eval_summary['accuracy']

In [153]:
train_ds, test_ds = get_datasets()

In [154]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

In [155]:
params = get_initial_params(init_rng)

In [156]:
learning_rate = 0.1
beta = 0.9
num_epochs = 2
batch_size = 32

In [157]:
optimizer = get_optimizer(params, learning_rate=learning_rate)

In [158]:
for epoch in range(1, num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    optimizer, train_metrics = train_epoch(optimizer, train_ds, batch_size, epoch, input_rng)
    test_loss, test_accuracy = eval_model(optimizer.target, test_ds)
    print(f"Testing - epoch: {epoch}, loss: {test_loss}, accuracy: {test_accuracy}")
    

Training - epoch: 1, loss: 2.71403169631958, accuracy: 0.10456666350364685
Testing - epoch: 1, loss: 2.3091375827789307, accuracy: 0.10090000182390213
Training - epoch: 2, loss: 2.3141160011291504, accuracy: 0.1029166653752327
Testing - epoch: 2, loss: 2.3095853328704834, accuracy: 0.10279999673366547
