In [1]:
# JAX is imported and exercised before TensorFlow is imported further down.
import jax
from jax import numpy as jnp

# Build a throwaway array (the result is discarded) and print the devices JAX sees.
# NOTE(review): presumably this makes JAX initialise its backend before
# TensorFlow is loaded -- confirm before reordering these lines.
jax.numpy.array(10)
print(jax.devices())
import tensorflow as tf

# Hide every GPU from TensorFlow so that it does not see the GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type="GPU")
import tensorflow_datasets as tfds
import optax
import equinox as eqx
from energax import nns
import time

[gpu(id=0)]


  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [2]:
# adapted from: https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html
data_dir = "/tmp/tfds"

# Fetch full datasets for evaluation.
# With batch_size=-1, tfds.load returns each split as tf.Tensors (it would return
# tf.data.Datasets otherwise); tfds.as_numpy converts them to NumPy arrays.
mnist_data, info = tfds.load(
    name="mnist", batch_size=-1, data_dir=data_dir, with_info=True
)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data["train"], mnist_data["test"]
num_labels = info.features["label"].num_classes
h, w, c = info.features["image"].shape
num_pixels = h * w * c


def _prepare_split(split):
    """Turn one split dict into ``(images, labels)`` arrays for the model.

    Images are stored channels-last, ``(N, h, w, c)``. They are moved to
    channels-first with an explicit transpose: a plain reshape only yields the
    right pixels when ``c == 1``. The channel axis is then repeated three
    times, giving ``(N, 3 * c, h, w)``. Labels are cast to int8.

    Args:
        split: mapping with ``"image"`` and ``"label"`` arrays.

    Returns:
        Tuple ``(images, labels)`` of JAX arrays.
    """
    images = jnp.transpose(split["image"], (0, 3, 1, 2))
    images = jnp.tile(images, [1, 3, 1, 1])
    labels = jnp.array(split["label"]).astype("int8")
    return images, labels


# Full train set
train_images, train_labels = _prepare_split(train_data)
# Full test set
test_images, test_labels = _prepare_split(test_data)

In [3]:
# Sanity check: show the image and label array shapes of each split.
print("Train:", *(arr.shape for arr in (train_images, train_labels)))
print("Test:", *(arr.shape for arr in (test_images, test_labels)))

Train: (60000, 3, 28, 28) (60000,)
Test: (10000, 3, 28, 28) (10000,)


In [4]:
# ResNet-18 from energax, initialised from a fixed PRNG key. The size of the
# output layer comes from the dataset metadata (`num_labels`) instead of a
# hard-coded literal, so it cannot drift from the data loaded above.
model = nns.resnet18(key=jax.random.PRNGKey(0), num_classes=num_labels)



In [6]:
# adapted from https://docs.kidger.site/equinox/examples/mnist/
def loss_fn(model, x, y):
    """Mean negative log-likelihood of the labels ``y`` for the batch ``x``.

    The model is vmapped over the leading axis of ``x``; the same fixed PRNG
    key (seed 42) is passed to every example rather than being mapped. The
    outputs go through ``log_softmax`` and are then handed to
    ``cross_entropy``.

    Args:
        model: callable taking ``(example, key)``.
        x: batch of inputs, batched along axis 0.
        y: integer class labels, one per example.

    Returns:
        Scalar loss.
    """
    shared_key = jax.random.PRNGKey(42)
    batched_model = eqx.filter_vmap(model, in_axes=(0, None))
    log_probs = jax.nn.log_softmax(batched_model(x, shared_key))
    return cross_entropy(y, log_probs)


def cross_entropy(y, pred_y):
    """Negated mean of the entries of ``pred_y`` selected by the labels ``y``.

    Args:
        y: integer labels; entry ``i`` is the column to read from row ``i``.
        pred_y: 2-D array of per-class scores (log-probabilities when called
            from ``loss_fn``), one row per example.

    Returns:
        Scalar: minus the average of ``pred_y[i, y[i]]`` over all rows.
    """
    # Give the labels a trailing axis so they can index along axis 1.
    label_column = jnp.expand_dims(y, axis=1)
    selected = jnp.take_along_axis(pred_y, label_column, axis=1)
    return -selected.mean()


@eqx.filter_jit
def compute_accuracy(model, x, y):
    """Fraction of examples in ``x`` whose arg-max prediction equals ``y``.

    The model is vmapped over the leading axis of ``x``, with one fixed PRNG
    key (seed 42) shared by every example.

    Args:
        model: callable taking ``(example, key)``.
        x: batch of inputs, batched along axis 0.
        y: integer class labels, one per example.

    Returns:
        Scalar mean of the element-wise comparison ``y == prediction``.
    """
    shared_key = jax.random.PRNGKey(42)
    scores = eqx.filter_vmap(model, in_axes=(0, None))(x, shared_key)
    predicted_class = jnp.argmax(scores, axis=1)
    is_correct = y == predicted_class
    return jnp.mean(is_correct)


# Optimiser: AdamW with learning rate 0.1. Its state is built from the model's
# array leaves only (non-array leaves are filtered out).
optim = optax.adamw(learning_rate=0.1)
opt_state = optim.init(eqx.filter(model, eqx.is_array))
# JIT-compiled wrapper around loss_fn.
loss = eqx.filter_jit(loss_fn)


@eqx.filter_jit
def make_step(model, opt_state, x, y):
    """Run one optimisation step on a single batch.

    Args:
        model: the Equinox model to update.
        opt_state: current state of the module-level ``optim``.
        x: batch of inputs, forwarded to ``loss`` unchanged.
        y: integer labels for ``x``.

    Returns:
        Tuple ``(model, opt_state, loss_value)``: the updated model, the
        updated optimiser state, and the loss computed before the update.
    """
    loss_value, grads = eqx.filter_value_and_grad(loss)(model, x, y)
    # AdamW's weight decay reads the parameters. Pass the same array-only view
    # of the model that `optim.init` was given, so non-array leaves of the
    # module never reach optax.
    params = eqx.filter(model, eqx.is_array)
    updates, opt_state = optim.update(grads, opt_state, params)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss_value


def get_train_batches(batch_size=32):
    """Return an iterable of ``(image, label)`` NumPy batches of the MNIST train split.

    The split is read from the module-level ``data_dir``. No shuffling is
    applied in this function.

    Args:
        batch_size: number of examples per batch. Defaults to 32, the value
            that was previously hard-coded.

    Returns:
        An iterable yielding ``(images, labels)`` tuples of NumPy arrays.
    """
    # as_supervised=True gives us the (image, label) as a tuple instead of a dict
    ds = tfds.load(name="mnist", split="train", as_supervised=True, data_dir=data_dir)
    # You can build up an arbitrary tf.data input pipeline
    ds = ds.batch(batch_size).prefetch(1)
    # tfds.as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
    return tfds.as_numpy(ds)


# Number of passes over the training set. Kept at 0 because GitHub Actions
# can't run this; leave it there until GPU testing is set up, and raise it to
# train locally.
num_epochs = 0

for epoch in range(num_epochs):
    start_time = time.perf_counter()
    for x, y in get_train_batches():
        # Batches arrive channels-last (N, h, w, c): move channels first, repeat
        # the channel axis three times, and scale the pixels to [0, 1].
        x = jnp.tile(jnp.transpose(x, (0, 3, 1, 2)), [1, 3, 1, 1]) / 255.0
        y = jnp.array(y).astype("int8")
        model, opt_state, train_loss = make_step(model, opt_state, x, y)
    # JAX dispatches work asynchronously, so wait for the last update to finish
    # before stopping the clock; otherwise queued steps are not timed.
    jax.block_until_ready(model)
    epoch_time = time.perf_counter() - start_time

    train_acc = compute_accuracy(model, train_images / 255.0, train_labels)
    test_acc = compute_accuracy(model, test_images / 255.0, test_labels)
    print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
    print(f"Training set accuracy {train_acc}")
    print(f"Test set accuracy {test_acc}")
    print(loss(model, train_images / 255.0, train_labels))

Epoch 0 in 24.67 sec
Training set accuracy 0.9485499858856201
Test set accuracy 0.955299973487854
0.25354156
Epoch 1 in 18.66 sec
Training set accuracy 0.9228833317756653
Test set accuracy 0.9253999590873718
0.28521124
Epoch 2 in 18.23 sec
Training set accuracy 0.9639833569526672
Test set accuracy 0.963699996471405
0.17131698
Epoch 3 in 18.44 sec
Training set accuracy 0.9675999879837036
Test set accuracy 0.9692999720573425
0.24755667
Epoch 4 in 18.39 sec
Training set accuracy 0.9816666841506958
Test set accuracy 0.9809999465942383
0.1749611
Epoch 5 in 18.47 sec
Training set accuracy 0.9700000286102295
Test set accuracy 0.9692999720573425
0.30764982
Epoch 6 in 18.35 sec
Training set accuracy 0.9630500078201294
Test set accuracy 0.9630999565124512
0.23849201
Epoch 7 in 18.38 sec
Training set accuracy 0.9869666695594788
Test set accuracy 0.9833999872207642
0.14752926
Epoch 8 in 18.39 sec
Training set accuracy 0.9830499887466431
Test set accuracy 0.983199954032898
0.15914722
Epoch 9 in 18.