In [7]:
import jax
import jax.numpy as jnp
from jax import random, jit, grad
import flax.linen as nn
import optax
import tensorflow_datasets as tfds
import time

In [8]:
class Net(nn.Module):
    def setup(self):
        self.conv1 = nn.Conv(features=32, kernel_size=(3, 3))
        self.conv2 = nn.Conv(features=64, kernel_size=(3, 3))
        self.fc1 = nn.Dense(features=128)
        self.fc2 = nn.Dense(features=10)
        self.dropout = nn.Dropout(rate=0.5)

    def __call__(self, x, training=True, dropout_key=None):
        x = self.conv1(x)
        x = nn.relu(x)
        x = self.conv2(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = self.fc1(x)
        x = nn.relu(x)
        x = self.dropout(x, deterministic=not training, rng=dropout_key)
        x = self.fc2(x)
        return x

In [9]:
key = random.PRNGKey(0)

# Load the MNIST dataset
ds_builder = tfds.builder('mnist')
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(
ds_builder.as_dataset(split="train", batch_size=-1))
test_ds = tfds.as_numpy(ds_builder.as_dataset(split="test", batch_size=-1))

# Normalize data
train_images, train_labels = train_ds['image'], train_ds['label']
test_images, test_labels = test_ds['image'], test_ds['label']

train_images = jnp.float32(train_images) / 255.0
test_images = jnp.float32(test_images) / 255.0

Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/mnist/3.0.1...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/2 [00:00<?, ? splits/s]

Generating train examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/mnist/incomplete.S2II7R_3.0.1/mnist-train.tfrecord*...:   0%|          | 0…

Generating test examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/mnist/incomplete.S2II7R_3.0.1/mnist-test.tfrecord*...:   0%|          | 0/…

Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.


In [10]:
model = Net()
rngs = {'params': key, 'dropout': key}
params = model.init(
    rngs, train_images[0:1])
tx = optax.adam(0.001)
opt_state = tx.init(params)

In [11]:
num_epochs = 5
batch_size = 32
start_time = time.time()
# Training loop
for epoch in range(num_epochs):
    for i in range(0, len(train_images), batch_size):
        batch = {
            'image': train_images[i:i+batch_size],
            'label': train_labels[i:i+batch_size],
        }

        # Split key for dropout
        key, dropout_key = random.split(key)

        # Compute loss
        logits = model.apply(params, batch['image'],
                  training=True, dropout_key=dropout_key)

        one_hot = jax.nn.one_hot(batch['label'], 10)
        loss = jnp.mean(optax.softmax_cross_entropy(
            logits=logits, labels=one_hot))

        # Compute gradients
        grads = grad(lambda p, i, l, k: jnp.mean(optax.softmax_cross_entropy(logits=model.apply(
            p, i, training=True, dropout_key=k), labels=jax.nn.one_hot(l, 10))))(params, batch['image'], batch['label'], dropout_key)

        # Update parameters
        updates, opt_state = tx.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

        if i % (batch_size * 100) == 0:
            print(f"Epoch {epoch + 1}, Loss: {loss}")

end_time = time.time()
print(f"Training time: {end_time - start_time} seconds")

Epoch 1, Loss: 2.3075971603393555
Epoch 1, Loss: 0.4712601602077484
Epoch 1, Loss: 0.15664419531822205
Epoch 1, Loss: 0.10837571322917938
Epoch 1, Loss: 0.21009168028831482
Epoch 1, Loss: 0.24175487458705902
Epoch 1, Loss: 0.027780739590525627
Epoch 1, Loss: 0.05333751440048218
Epoch 1, Loss: 0.2942686080932617
Epoch 1, Loss: 0.07526043057441711
Epoch 1, Loss: 0.1585114449262619
Epoch 1, Loss: 0.1525411456823349
Epoch 1, Loss: 0.017257004976272583
Epoch 1, Loss: 0.0840107873082161
Epoch 1, Loss: 0.18193233013153076
Epoch 1, Loss: 0.03528980165719986
Epoch 1, Loss: 0.3647249937057495
Epoch 1, Loss: 0.12711209058761597
Epoch 1, Loss: 0.07941868156194687
Epoch 2, Loss: 0.01175493374466896
Epoch 2, Loss: 0.15552851557731628
Epoch 2, Loss: 0.011898327618837357
Epoch 2, Loss: 0.02475767210125923
Epoch 2, Loss: 0.014384630136191845
Epoch 2, Loss: 0.12825539708137512
Epoch 2, Loss: 0.007483629044145346
Epoch 2, Loss: 0.010263720527291298
Epoch 2, Loss: 0.012538439594209194
Epoch 2, Loss: 0.051

In [12]:
test_logits = model.apply(params, test_images, training=False)
test_predictions = jnp.argmax(test_logits, axis=-1)
accuracy = jnp.mean(test_predictions == test_labels)
print(f"Test accuracy: {accuracy * 100:.2f}%")

Test accuracy: 99.25%
