In [34]:
from collections.abc import Mapping

import kagglehub
import pandas
import numpy
import jax
import jax.numpy as jnp
import jax.random as jrnd
import flax
import flax.linen as nn
import optax
import jax.nn as jnn
from flax.training.train_state import TrainState

# Download the MNIST-in-CSV dataset through kagglehub; `path` is the value the
# download call returns and is used below as the directory holding the CSVs.
path = kagglehub.dataset_download("oddrationale/mnist-in-csv")

# Locations of the two CSV splits inside the downloaded dataset.
path_train = f"{path}/mnist_train.csv"
path_test = f"{path}/mnist_test.csv"

# Raw tables: later cells treat column 0 as the label and the rest as pixels.
df_train = pandas.read_csv(path_train)
df_test = pandas.read_csv(path_test)

# --- Dataset geometry -------------------------------------------------------
TRAIN_SIZE = 60_000    # row count the training table is reshaped to
TEST_SIZE = 10_000     # row count the test table is reshaped to
IMG_SIZE = (28, 28)    # spatial size of one image
NUM_CLASSES = 10       # width of the one-hot labels and of the model output

# --- Optimization settings --------------------------------------------------
LR = 0.006             # default learning rate handed to the optimizer
BATCH_SIZE = 500       # examples per training step

# Root PRNG key (seed 0); it is split whenever fresh randomness is needed.
KEY = jrnd.key(0)

In [42]:
def _load_split(frame, num_rows):
    """Turn one MNIST CSV table into an (images, one-hot labels) pair.

    Column 0 of `frame` is read as the class label and every remaining column
    as pixel data. Pixels are reshaped to (num_rows, *IMG_SIZE, 1) and divided
    by 255.0; labels are one-hot encoded over NUM_CLASSES classes.
    """
    pixels = jnp.array(frame.iloc[:, 1:].to_numpy())
    images = pixels.reshape((num_rows,) + IMG_SIZE + (1,))
    labels = jnn.one_hot(jnp.array(frame.iloc[:, 0].to_numpy()), num_classes=NUM_CLASSES)
    return images / 255.0, labels


X_train, Y_train = _load_split(df_train, TRAIN_SIZE)

X_test, Y_test = _load_split(df_test, TEST_SIZE)

In [44]:
class Model(nn.Module):
    """Convolutional classifier: four conv layers, two max-pools, two dense layers.

    Every layer except the last is followed by a leaky ReLU. The final dense
    layer has NUM_CLASSES outputs and no activation, so the module returns raw
    scores of shape (batch, NUM_CLASSES).
    """

    @nn.compact
    def __call__(self, x):
        def halve(feature_map):
            # 2x2 max-pooling with a stride of 2 in both spatial directions.
            return nn.max_pool(feature_map, window_shape=(2, 2), strides=(2, 2))

        # Convolutional stack; layers are created in the same order as before
        # so Flax assigns them the same automatic names.
        hidden = nn.Conv(features=10, kernel_size=(5, 5))(x)
        hidden = nn.leaky_relu(hidden)

        hidden = nn.Conv(features=40, kernel_size=(3, 3))(hidden)
        hidden = halve(nn.leaky_relu(hidden))

        hidden = nn.Conv(features=20, kernel_size=(3, 3))(hidden)
        hidden = halve(nn.leaky_relu(hidden))

        hidden = nn.Conv(features=4, kernel_size=(3, 3))(hidden)
        hidden = nn.leaky_relu(hidden)

        # Flatten everything but the leading (batch) axis for the dense head.
        flat = hidden.reshape(hidden.shape[0], -1)
        flat = nn.leaky_relu(nn.Dense(features=45)(flat))
        return nn.Dense(features=NUM_CLASSES)(flat)

In [None]:
def create_train_state(model, subkey, lr=LR):
    """Build the initial TrainState for `model`.

    Args:
        model: module whose `init` and `apply` methods are used.
        subkey: PRNG key passed to `model.init`.
        lr: learning rate for the Adam optimizer (defaults to LR).

    Returns:
        A TrainState holding the initialized variables, a jitted `model.apply`
        as `apply_fn`, and the Adam optimizer as `tx`.
    """
    # A single all-zero image is enough for `init` to create the variables.
    dummy_batch = jnp.zeros((1, *IMG_SIZE, 1))
    initial_variables = model.init(subkey, dummy_batch)
    adam = optax.adam(learning_rate=lr)
    return TrainState.create(
        apply_fn=jax.jit(model.apply),
        params=initial_variables,
        tx=adam,
    )

@jax.jit
def step(model_state: TrainState, inputs: jnp.ndarray, targets: jnp.ndarray):
    """Run one optimization step on a single batch.

    Args:
        model_state: current training state (variables, apply_fn, optimizer).
        inputs: batch fed to `model_state.apply_fn`.
        targets: labels compared against the predictions with softmax
            cross-entropy (the notebook passes one-hot rows).

    Returns:
        (new_state, loss): the state after `apply_gradients` and the mean
        cross-entropy of the batch evaluated at the incoming parameters.
    """
    def mean_cross_entropy(variables):
        logits = model_state.apply_fn(variables, inputs)
        per_example = optax.softmax_cross_entropy(logits, targets)
        return per_example.mean()

    loss_and_grad = jax.value_and_grad(mean_cross_entropy)
    loss, grads = loss_and_grad(model_state.params)
    updated_state = model_state.apply_gradients(grads=grads)
    return updated_state, loss

# Split the root key: KEY is replaced by a fresh key and `subkey` is set aside
# for parameter initialization (it is passed to create_train_state below).
# NOTE(review): re-running this cell splits KEY again, so the initialization
# depends on how many times the cell has been executed since KEY was created.
KEY, subkey = jrnd.split(KEY)

In [47]:
# Instantiate the network and build its initial training state from the
# key that was reserved for parameter initialization.
model = Model()
model_state = create_train_state(model=model, subkey=subkey)
def param_dist(params1, params2):
    """Return the L1 distance between two nested parameter trees.

    The trees are walked recursively using the keys of `params1`; for every
    leaf, the absolute element-wise difference to the entry with the same key
    in `params2` is summed. Absolute values are used so that positive and
    negative deviations cannot cancel each other out: the result is zero only
    when all compared leaves are equal, and it is symmetric in its arguments
    when both trees share the same keys.

    Args:
        params1: mapping (possibly nested) from names to arrays.
        params2: mapping with at least the same keys and matching leaf shapes.

    Returns:
        The summed absolute difference; the integer 0 for an empty tree.

    Raises:
        KeyError: if a key of `params1` is missing from `params2`.
    """
    total = 0
    for name, value in params1.items():
        counterpart = params2[name]
        # Recurse on any mapping (plain dict or an immutable mapping type),
        # not just `dict`, so nested containers are never subtracted directly.
        if isinstance(value, Mapping):
            total = total + param_dist(value, counterpart)
        else:
            total = total + jnp.abs(value - counterpart).sum()
    return total
# Number of full passes over the training set.
EPOCHS = 10


def train():
    """Train the global `model_state` for EPOCHS epochs.

    Each epoch draws a fresh permutation of the training set (advancing the
    global KEY), walks through it in consecutive chunks of BATCH_SIZE examples
    and applies one `step` per chunk. The loss is printed every 10th batch.
    Both `model_state` and `KEY` are updated in place at module level.
    """
    global model_state, KEY
    for epoch in range(EPOCHS):
        # New random ordering of the examples for this epoch.
        KEY, shuffle_key = jrnd.split(KEY)
        order = jrnd.permutation(shuffle_key, TRAIN_SIZE)
        images, labels = X_train[order], Y_train[order]

        # Integer division: a trailing partial batch, if any, is skipped.
        num_batches = TRAIN_SIZE // BATCH_SIZE
        for batch_idx in range(num_batches):
            lo = batch_idx * BATCH_SIZE
            hi = lo + BATCH_SIZE
            model_state, loss = step(model_state, images[lo:hi], labels[lo:hi])

            report_now = (batch_idx + 1) % 10 == 0
            if report_now:
                print(f"Epoch {epoch+1}, Batch {batch_idx+1}, Loss: {loss:.4f}")


train()

Epoch 1, Batch 10, Loss: 0.9563
Epoch 1, Batch 20, Loss: 0.5635
Epoch 1, Batch 30, Loss: 0.3307
Epoch 1, Batch 40, Loss: 0.2079
Epoch 1, Batch 50, Loss: 0.1551
Epoch 1, Batch 60, Loss: 0.1227
Epoch 1, Batch 70, Loss: 0.1628
Epoch 1, Batch 80, Loss: 0.1578
Epoch 1, Batch 90, Loss: 0.0967
Epoch 1, Batch 100, Loss: 0.1172
Epoch 1, Batch 110, Loss: 0.1112
Epoch 1, Batch 120, Loss: 0.1112


KeyboardInterrupt: 