In [None]:
%cd ..

In [None]:
from functools import partial

import jax
import jax.numpy as jnp
from flax import linen as nn
import optax
import ml_collections
from tensorflow_datasets import Split
import wandb

from src.dataset import load_mnist
from src.models import MnistCNN
from src import training_utils as utils

In [None]:
# Hyper-parameters (values taken from the paper).
cfg = ml_collections.ConfigDict(
    {
        "learning_rate": 1.07e-1,
        "batch_size": 256,
        "epochs": 40,
        # Keyword arguments that are unpacked into the model constructor.
        "model_params": {
            "k_filters": 32,
            "activation": nn.relu,
        },
    }
)

In [None]:
# Training data: materialise every batch once as NumPy values in a Python list,
# so the same batches can be iterated again on each epoch.
train_batches = load_mnist(cfg.batch_size)
train_ds = list(train_batches.as_numpy_iterator())

# Test data: loaded without shuffling, then flattened and re-batched so that
# the first (and only requested) batch holds up to 60_000 examples at once.
test_batches = load_mnist(cfg.batch_size, split=Split.TEST, shuffle=False)
test_single_batch = test_batches.unbatch().batch(60_000)
test_ds = next(test_single_batch.as_numpy_iterator())

In [None]:
# Root PRNG key for the notebook, split immediately into the key carried
# forward (`rng`) and a dedicated key for parameter initialisation (`init_rng`).
seed = 42
rng, init_rng = jax.random.split(jax.random.PRNGKey(seed))

In [None]:
# Instantiate the model from the hyper-parameters stored in the config.
model = MnistCNN(**cfg.model_params)

# Plain SGD with the configured learning rate.
tx = optax.sgd(learning_rate=cfg.learning_rate)

# Use the dedicated `init_rng` subkey for initialisation. The original passed
# the parent `rng` here and left `init_rng` unused, so any later use of `rng`
# would have reused a key that was already consumed.
# NOTE(review): this changes the initial parameters relative to earlier runs.
state = utils.create_train_state(model, tx, init_rng)

In [None]:
# Hyper-parameters are attached to the run through `wandb.init(config=...)`.
# Rebinding the module attribute (`wandb.config = ...`) only replaces a Python
# name after the run has been created; it does not record anything on the run.
#
# `to_dict()` converts the (nested) ConfigDict into plain dicts, and callables
# such as the activation function are logged by name so that every config
# value is a plain, JSON-friendly object.
run_config = cfg.to_dict()
run_config["model_params"] = {
    key: getattr(value, "__name__", value)
    for key, value in run_config["model_params"].items()
}

wandb.init(
    project="making_the_shoe_fit",
    entity="shpotes",
    name="test",
    config=run_config,
)

In [None]:
# Smoke-test switch. When True, only the first training batch of a single
# epoch is run, which is exactly what the previous unconditional `break`
# statements did. Set it to False to train on every batch for `cfg.epochs`
# epochs.
SMOKE_TEST = True

n_epochs = min(cfg.epochs, 1) if SMOKE_TEST else cfg.epochs
# `train_ds` is a list of batches, so slicing is a cheap way to truncate it.
epoch_batches = train_ds[:1] if SMOKE_TEST else train_ds

for epoch in range(1, n_epochs + 1):
    for batch in epoch_batches:
        # One optimisation step; the updated state replaces the previous one.
        state, training_metrics = utils._train_step(model, state, batch)
        wandb.log(training_metrics)

    # Evaluate the current parameters on `test_ds` once per epoch.
    test_metrics = utils._test_step(model, state.params, test_ds)
    wandb.log(test_metrics)