In [1]:
import time

# torch
import torch
import torch.nn as nn
import torch.optim as optim

# jax
import jax
import jax.numpy as jnp
import flax.linen as fnn
from flax.training import train_state
import optax

# dataset
import tensorflow_datasets as tfds
import tensorflow as tf


In [2]:
# Dataloader (common)

def preprocessing(x, y):
    x = tf.reshape(tf.cast(x, tf.float32), (-1, 28 * 28)) / 255.
    y = tf.one_hot(y, 10)
    
    return x, y

ds = tfds.load("mnist", as_supervised=True, shuffle_files=False, download=True)
train_set = ds["train"]
train_set = train_set.shuffle(len(train_set), seed=0, reshuffle_each_iteration=True).batch(32).map(preprocessing).prefetch(1)
val_set = ds["test"]
val_set = val_set.batch(32).map(preprocessing).prefetch(1)

In [3]:
# jax

class MLP(fnn.Module):
    @fnn.compact
    def __call__(self, x):
        x = fnn.Dense(64)(x)
        x = fnn.relu(x)
        x = fnn.Dense(10)(x)
        x = fnn.log_softmax(x)
        
        return x

def step(state, batch, is_train):
    x, y = [jnp.array(v) for v in batch]
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, x)
        loss = optax.softmax_cross_entropy(logits=logits, labels=y).mean()
        return loss, logits
    if is_train:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, logits), grads = grad_fn(state.params)
        state = state.apply_gradients(grads=grads)
    else:
        loss, logits = loss_fn(state.params)
    acc = jnp.mean(jnp.argmax(logits, axis=-1) == jnp.argmax(y, axis=-1))
    
    return loss, acc, state

@jax.jit
def train_step(state, batch):
    return step(state, batch, is_train=True)

@jax.jit
def eval_step(state, batch):
    return step(state, batch, is_train=False)

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones([1, 28 * 28]))['params']
tx = optax.adam(0.001)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

for e in range(5):
    tic = time.time()
    train_loss, val_loss, val_acc = 0, 0, 0
    for batch in train_set.as_numpy_iterator():
        loss, acc, state = train_step(state, batch)
        train_loss += loss
    train_loss /= len(train_set)
    for batch in val_set.as_numpy_iterator():
        loss, acc, state = eval_step(state, batch)
        val_loss += loss
        val_acc += acc
    val_loss /= len(val_set)
    val_acc /= len(val_set)
    elapsed = time.time() - tic
    print(f"epoch: {e} | train_loss: {train_loss:0.2f}, val_loss: {val_loss:0.2f}, val_acc: {val_acc:0.2f}, elapsed: {elapsed:0.2f}")



epoch: 0 | train_loss: 0.32, val_loss: 0.18, val_acc: 0.95, elapsed: 4.52
epoch: 1 | train_loss: 0.16, val_loss: 0.14, val_acc: 0.96, elapsed: 1.66
epoch: 2 | train_loss: 0.11, val_loss: 0.11, val_acc: 0.97, elapsed: 1.64
epoch: 3 | train_loss: 0.09, val_loss: 0.09, val_acc: 0.97, elapsed: 1.60
epoch: 4 | train_loss: 0.07, val_loss: 0.09, val_acc: 0.97, elapsed: 1.71


In [4]:
# pytorch

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(),
                                 nn.Linear(64, 10), nn.LogSoftmax(-1))
        
    def forward(self, x):
        return self.net(x)


model = MLP()
opt = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.NLLLoss()

def step(model, batch, is_train):
    x, y = [torch.from_numpy(v) for v in batch]
    y = y.max(-1)[1]
    y_pred = model(x)
    loss = criterion(y_pred, y)
    if is_train:
        opt.zero_grad()
        loss.backward()
        opt.step()
    acc = torch.mean((y_pred.max(-1)[1] == y).float())
    return loss, acc

def train_step(model, batch):
    model.train()
    return step(model, batch, is_train=True)

def eval_step(model, batch):
    model.eval()
    with torch.no_grad():
        return step(model, batch, is_train=False)

for e in range(5):
    tic = time.time()
    train_loss, val_loss, val_acc = 0, 0, 0
    for batch in train_set.as_numpy_iterator():
        loss, acc = train_step(model, batch)
        train_loss += loss
    train_loss /= len(train_set)
    for batch in val_set.as_numpy_iterator():
        loss, acc = eval_step(model, batch)
        val_loss += loss
        val_acc += acc
    val_loss /= len(val_set)
    val_acc /= len(val_set)
    elapsed = time.time() - tic
    print(f"epoch: {e} | train_loss: {train_loss:0.2f}, val_loss: {val_loss:0.2f}, val_acc: {val_acc:0.2f}, elapsed: {elapsed:0.2f}")

  x, y = [torch.from_numpy(v) for v in batch]


epoch: 0 | train_loss: 0.35, val_loss: 0.18, val_acc: 0.95, elapsed: 2.45
epoch: 1 | train_loss: 0.16, val_loss: 0.13, val_acc: 0.96, elapsed: 2.33
epoch: 2 | train_loss: 0.12, val_loss: 0.11, val_acc: 0.97, elapsed: 2.44
epoch: 3 | train_loss: 0.09, val_loss: 0.10, val_acc: 0.97, elapsed: 2.48
epoch: 4 | train_loss: 0.08, val_loss: 0.09, val_acc: 0.97, elapsed: 2.33
