In [64]:
import kagglehub
import pandas
import numpy
import jax
import jax.numpy as jnp
import jax.random as jrnd
import flax
import flax.linen as nn
import optax
import jax.nn as jnn
from flax.training.train_state import TrainState

path = kagglehub.dataset_download("oddrationale/mnist-in-csv")

path_train = path + "/mnist_train.csv"
path_test = path + "/mnist_test.csv"

df_train = pandas.read_csv(path_train)
df_test = pandas.read_csv(path_test)

TRAIN_SIZE = 60000
TEST_SIZE = 10000
IMG_SIZE = (28, 28)
LR = 0.003
NUM_CLASSES = 10
KEY = jrnd.key(0)

In [65]:
X_train = jnp.array(df_train.iloc[:, 1:].to_numpy()).reshape((TRAIN_SIZE,)+IMG_SIZE+(1,))
Y_train = jnn.one_hot(jnp.array(df_train.iloc[:, 0].to_numpy()), num_classes=NUM_CLASSES)

X_test = jnp.array(df_test.iloc[:, 1:].to_numpy()).reshape((TEST_SIZE,)+IMG_SIZE+(1,))
Y_test = jnn.one_hot(jnp.array(df_test.iloc[:, 0].to_numpy()), num_classes=NUM_CLASSES)

In [66]:
class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.leaky_relu(nn.Conv(features=10, kernel_size=(5,5))(x))
        x = nn.leaky_relu(nn.Conv(features=40, kernel_size=(3,3))(x))
        x = nn.max_pool(x, window_shape=(2,2), strides=(2,2))
        x = nn.leaky_relu(nn.Conv(features=20, kernel_size=(3,3))(x))
        x = nn.max_pool(x, window_shape=(2,2), strides=(2,2))
        x = nn.leaky_relu(nn.Conv(features=4, kernel_size=(3,3))(x))
        x = x.reshape(x.shape[0], -1)
        x = nn.leaky_relu(nn.Dense(features=45)(x))
        x = nn.leaky_relu(nn.Dense(features=NUM_CLASSES)(x))
        x = nn.softmax(x, axis=1)
        return x

In [67]:
def create_train_state(model, subkey, lr=LR):
    optim = optax.adam(learning_rate=lr)
    params = model.init(subkey, jnp.zeros((1,)+IMG_SIZE+(1,)))
    return TrainState.create(apply_fn=jax.jit(model.apply), params=params, tx = optim)

@staticmethod
@jax.jit
def step(model_state: TrainState, inputs: jnp.ndarray, targets: jnp.ndarray):
    def loss_fn(params):
        predictions = model_state.apply_fn(params, inputs)
        return ((predictions-targets)**2).mean()
    loss, grads = jax.value_and_grad(loss_fn)(model_state.params)
    model_state = model_state.apply_gradients(grads=grads)
    return model_state, loss

KEY, subkey = jrnd.split(KEY)
model = Model()
model_state = create_train_state(model, subkey)

In [68]:
for i in range(10):
    model_state, loss = step(model_state, X_train[0:100], Y_train[0:100])
    print(loss)

0.16497976
0.1583861
0.16236266
0.16789934
0.15452114
0.1554834
0.16067591
0.16123646
0.1575496
0.15792121
