In [1]:
import jax
import jax.numpy as jnp
from jax import nn
from jax import random
from jax import tree_util
import torchvision

In [2]:
from mlax.losses import categorical_crossentropy
from mlax.nn.blocks import Linear 
from mlax.optim import apply, sgd

In [3]:
mnist_train = torchvision.datasets.MNIST(
    root="data",
    train=True,
    download=False,
    transform=torchvision.transforms.ToTensor()
)
mnist_test = torchvision.datasets.MNIST(
    root="data",
    train=False,
    download=False,
    transform=torchvision.transforms.ToTensor()
)

In [4]:
X_train, y_train = jnp.asarray(mnist_train.data), jnp.asarray(mnist_train.targets)
X_test, y_test = jnp.asarray(mnist_test.data), jnp.asarray(mnist_test.targets)
print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)



(60000, 28, 28) (60000,)
(10000, 28, 28) (10000,)


In [5]:
def batch(x, y, batch_size, k, dtype="float32"):
  x = x.astype(dtype)
  y = y.astype(dtype) 
  batched_x, batched_y = [], []
  for i in range(0, len(x), batch_size):
      batched_x.append(x[i:i+batch_size])
      batched_y.append(y[i:i+batch_size])
  return batched_x, batched_y

batch_size = 64
X_train, y_train = batch(X_train, y_train, 64, 10)
X_test, y_test = batch(X_test, y_test, 64, 10)
print(len(X_train), len(y_train))
print(len(X_test), len(y_test))

938 938
157 157


In [6]:
def model_init(key):
    key1, key2, key3 = random.split(key, 3)
    w1 = Linear.init(key1, 28*28, 512)
    w2 = Linear.init(key2, 512, 512)
    w3 = Linear.init(key3, 512, 10)
    return [w1, w2, w3]

model_weights = model_init(random.PRNGKey(43))

In [13]:
def model_fwd(x, weights):
    x = jnp.reshape(x, (-1, ))
    w1, w2, w3 = weights
    x = Linear.fwd(x, w1, nn.relu)
    x = Linear.fwd(x, w2, nn.relu)
    x = Linear.fwd(x, w3, nn.softmax)
    return x

batched_model_fwd = jax.jit(jax.vmap(model_fwd, in_axes=[0, None]))

@jax.jit
def batched_loss(predictions, targets):
    predictions = jnp.clip(predictions, 1e-9, 1 - 1e-9)
    targets = nn.one_hot(targets, 10)
    losses = jax.vmap(categorical_crossentropy)(
        predictions, targets
    )
    return losses.mean()

def batched_model_loss(x, y, weights):
    return batched_loss(batched_model_fwd(x, weights), y)

In [23]:
optim_state = sgd.init(model_weights)

@jax.jit
def train_step(x_batch, y_batch, model_weights, optim_state, lr):
    loss = batched_model_loss(x_batch, y_batch, model_weights)
    gradients = jax.grad(batched_model_loss, argnums=2)(
        x_batch,
        y_batch,
        model_weights
    )
    gradients, optim_state = sgd.step(gradients, optim_state, lr, momentum=0.7)
    model_weights = apply(gradients, model_weights)
    return loss, model_weights, optim_state

In [24]:
def train_epoch(X_train, y_train, model_weights, optim_state, lr):
    num_batches = len(X_train)
    train_loss = 0.0
    for i in range(num_batches):
        x_batch, y_batch = X_train[i], y_train[i]
        loss, model_weights, optim_state = train_step(
            x_batch,
            y_batch,
            model_weights,
            optim_state,
            lr
        )
        train_loss += loss

    print(f"Train loss: {train_loss / num_batches}") 
    return model_weights, optim_state

In [25]:
def test(X_test, y_test, model_weights):
    num_batches = len(X_test)
    test_loss, accuracy = 0, 0.0
    for i in range(num_batches):
        x_batch, y_batch = X_test[i], y_test[i]
        preds = batched_model_fwd(
            x_batch, model_weights
        )
        loss = batched_loss(preds, y_batch)
        test_loss += loss
        accuracy += (jnp.argmax(preds, axis=1) == y_batch).sum() / len(x_batch)
    
    print(f"Test loss: {test_loss / num_batches}, accuracy: {accuracy / num_batches}")

In [26]:
def train_loop(
    X_train, y_train,
    X_test, y_test,
    model_weights, optim_state, lr,
    epochs, test_every
):
    for i in range(epochs):
        epoch = i + 1
        print(f"Epoch {epoch}\n----------------")
        model_weights, optim_state = train_epoch(
            X_train, y_train,
            model_weights, optim_state, lr
        )
        if (epoch % test_every == 0):
            test(X_test, y_test, model_weights)
        print(f"----------------")
    
    return model_weights, optim_state

In [27]:
new_model_weights, new_optim_state = train_loop(
    X_train, y_train,
    X_test, y_test,
    model_weights, optim_state,
    8e-4,
    50, 5
)

Epoch 1
----------------
Train loss: 11.780515670776367
----------------
Epoch 2
----------------
Train loss: 7.189499855041504
----------------
Epoch 3
----------------
Train loss: 5.246441841125488
----------------
Epoch 4
----------------
Train loss: 4.975132465362549
----------------
Epoch 5
----------------
Train loss: 4.049658298492432
Test loss: 2.898390531539917, accuracy: 0.8543989062309265
----------------
Epoch 6
----------------
Train loss: 2.8120436668395996
----------------
Epoch 7
----------------
Train loss: 2.6920533180236816
----------------
Epoch 8
----------------
Train loss: 2.640648126602173
----------------
Epoch 9
----------------
Train loss: 2.5953640937805176
----------------
Epoch 10
----------------
Train loss: 1.625260353088379
Test loss: 0.8335891962051392, accuracy: 0.9519307613372803
----------------
Epoch 11
----------------
Train loss: 0.5922964215278625
----------------
Epoch 12
----------------
Train loss: 0.46516886353492737
----------------
Epoch 1