In [1]:
import time
from tqdm import tqdm

# torch
import torch
import torch.nn as nn
import torch.optim as optim

# jax
import jax
import jax.numpy as jnp
import flax.linen as fnn
from flax.training import train_state
import optax

# dataset
import tensorflow_datasets as tfds
import tensorflow as tf

physical_devices = tf.config.experimental.list_physical_devices("GPU")
tf.config.experimental.set_memory_growth(physical_devices[0], True)
tf.config.experimental.set_memory_growth(physical_devices[1], True)

In [2]:
# Dataloader (common)

def preprocessing(x, y):
    x = tf.cast(x, tf.float32) / 255.
    y = tf.one_hot(y, 10)
    
    return x, y

ds = tfds.load("cifar10", as_supervised=True, shuffle_files=False, download=True)
train_set = ds["train"]
train_set = train_set.shuffle(len(train_set), seed=0, reshuffle_each_iteration=True).batch(32).map(preprocessing).prefetch(1)
val_set = ds["test"]
val_set = val_set.batch(32).map(preprocessing).prefetch(1)

In [3]:
# jax

class CNN(fnn.Module):
    @fnn.compact
    def __call__(self, x):
        x = fnn.Conv(features=32, kernel_size=(3, 3))(x)
        x = fnn.relu(x)
        x = fnn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = fnn.Conv(features=64, kernel_size=(3, 3))(x)
        x = fnn.relu(x)
        x = fnn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = fnn.Dense(features=256)(x)
        x = fnn.relu(x)
        x = fnn.Dense(features=10)(x)
        x = fnn.log_softmax(x)
        
        return x

def step(state, batch, is_train):
    x, y = [jnp.array(v) for v in batch]
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, x)
        loss = optax.softmax_cross_entropy(logits=logits, labels=y).mean()
        return loss, logits
    if is_train:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, logits), grads = grad_fn(state.params)
        state = state.apply_gradients(grads=grads)
    else:
        loss, logits = loss_fn(state.params)
    acc = jnp.mean(jnp.argmax(logits, axis=-1) == jnp.argmax(y, axis=-1))
    
    return loss, acc, state

@jax.jit
def train_step(state, batch):
    return step(state, batch, is_train=True)

@jax.jit
def eval_step(state, batch):
    return step(state, batch, is_train=False)

model = CNN()
params = model.init(jax.random.PRNGKey(0), jnp.ones([1, 32, 32, 3]))['params']
tx = optax.adam(0.001)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

for e in range(5):
    tic = time.time()
    train_loss, val_loss, val_acc = 0, 0, 0
    for batch in tqdm(train_set.as_numpy_iterator()):
        loss, acc, state = train_step(state, batch)
        train_loss += loss
    train_loss /= len(train_set)
    for batch in tqdm(val_set.as_numpy_iterator()):
        loss, acc, state = eval_step(state, batch)
        val_loss += loss
        val_acc += acc
    val_loss /= len(val_set)
    val_acc /= len(val_set)
    elapsed = time.time() - tic
    print(f"epoch: {e} | train_loss: {train_loss:0.2f}, val_loss: {val_loss:0.2f}, val_acc: {val_acc:0.2f}, elapsed: {elapsed:0.2f}")

1563it [00:07, 215.43it/s]
313it [00:01, 198.16it/s]


epoch: 0 | train_loss: 1.32, val_loss: 1.05, val_acc: 0.63, elapsed: 8.99


1563it [00:03, 519.37it/s]
313it [00:00, 711.09it/s]


epoch: 1 | train_loss: 0.94, val_loss: 0.92, val_acc: 0.68, elapsed: 3.51


1563it [00:03, 517.89it/s]
313it [00:00, 735.16it/s]


epoch: 2 | train_loss: 0.79, val_loss: 0.86, val_acc: 0.71, elapsed: 3.70


1563it [00:03, 510.90it/s]
313it [00:00, 715.82it/s]


epoch: 3 | train_loss: 0.66, val_loss: 0.85, val_acc: 0.72, elapsed: 3.75


1563it [00:03, 516.97it/s]
313it [00:00, 731.86it/s]

epoch: 4 | train_loss: 0.55, val_loss: 0.87, val_acc: 0.72, elapsed: 5.60





In [4]:
# pytorch

device = "cuda:1" if torch.cuda.is_available() else "cpu"

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(4096, 256)
        self.fc2 = nn.Linear(256, 10)
        
    def forward(self, x):
        h = self.conv1(x)
        h = torch.relu(h)
        h = torch.max_pool2d(h, (2, 2))
        h = self.conv2(h)
        h = torch.relu(h)
        h = torch.max_pool2d(h, (2, 2))
        h = h.reshape(len(h), -1)
        h = self.fc1(h)
        h = torch.relu(h)
        h = self.fc2(h)
        y = torch.log_softmax(h, -1)
        return y


model = CNN().to(device)
opt = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.NLLLoss()

def step(model, batch, is_train):
    x, y = [torch.from_numpy(v) for v in batch]
    x = x.permute(0, 3, 1, 2).to(device)
    y = y.max(-1)[1].to(device)
    y_pred = model(x)
    loss = criterion(y_pred, y)
    if is_train:
        opt.zero_grad()
        loss.backward()
        opt.step()
    acc = torch.mean((y_pred.max(-1)[1] == y).float())
    return loss, acc

def train_step(model, batch):
    model.train()
    return step(model, batch, is_train=True)

def eval_step(model, batch):
    model.eval()
    with torch.no_grad():
        return step(model, batch, is_train=False)

for e in range(5):
    tic = time.time()
    train_loss, val_loss, val_acc = 0, 0, 0
    for batch in tqdm(train_set.as_numpy_iterator()):
        loss, acc = train_step(model, batch)
        train_loss += loss
    train_loss /= len(train_set)
    for batch in tqdm(val_set.as_numpy_iterator()):
        loss, acc = eval_step(model, batch)
        val_loss += loss
        val_acc += acc
    val_loss /= len(val_set)
    val_acc /= len(val_set)
    elapsed = time.time() - tic
    print(f"epoch: {e} | train_loss: {train_loss:0.2f}, val_loss: {val_loss:0.2f}, val_acc: {val_acc:0.2f}, elapsed: {elapsed:0.2f}")

  x, y = [torch.from_numpy(v) for v in batch]
1563it [00:04, 358.42it/s]
313it [00:00, 802.69it/s]


epoch: 0 | train_loss: 1.38, val_loss: 1.10, val_acc: 0.61, elapsed: 5.04


1563it [00:04, 356.24it/s]
313it [00:00, 762.60it/s]


epoch: 1 | train_loss: 1.01, val_loss: 1.00, val_acc: 0.65, elapsed: 5.08


1563it [00:04, 366.97it/s]
313it [00:00, 781.58it/s]


epoch: 2 | train_loss: 0.85, val_loss: 0.92, val_acc: 0.68, elapsed: 4.96


1563it [00:04, 363.06it/s]
313it [00:00, 901.76it/s]


epoch: 3 | train_loss: 0.72, val_loss: 0.87, val_acc: 0.70, elapsed: 5.81


1563it [00:04, 347.23it/s]
313it [00:00, 726.20it/s]


epoch: 4 | train_loss: 0.61, val_loss: 0.89, val_acc: 0.70, elapsed: 5.20
