In [None]:
import jax
from flax import linen as nn
import jax.numpy as jnp
import optax
from flaxmodels import ResNet18
from typing import Callable
import tensorflow
import tqdm

from optimizers import SGD
from dataloader import get_cifar10_dataloaders, tf_to_jax

In [None]:
def init_model(rng: jax.random.PRNGKey, init_batch: tuple[jnp.ndarray, jnp.ndarray]) -> tuple[nn.Module, dict]:
    """Build an untrained 10-class ResNet18 and initialise its variables.

    Args:
        rng: PRNG key passed to ``model.init``.
        init_batch: A batch whose first element (the images) is used as the
            sample input for ``init``; any other elements are ignored.

    Returns:
        ``(model, variables)``: the flax module and the variables produced
        by ``model.init``.
    """
    network = ResNet18(output="logits", pretrained=False, num_classes=10)
    sample_images = init_batch[0]
    initial_variables = network.init(rng, sample_images)
    return network, initial_variables

In [None]:
def get_cross_entropy_loss_fn(model: nn.Module) -> Callable:
    """Create a mean softmax cross-entropy loss that closes over ``model``.

    The returned ``loss_fn(variables, batch)`` expects ``batch`` to be
    ``(images, labels)`` with integer class labels. It runs ``model.apply``
    with ``mutable=["batch_stats"]`` and returns ``(loss, updated_collections)``.
    The second element is the mutated-collections object that ``apply``
    returns next to the logits.
    """
    def loss_fn(variables: dict, batch: tuple[jnp.ndarray, jnp.ndarray]) -> tuple:
        images, labels = batch
        # With `mutable` set, `apply` returns (output, updated collections).
        logits, updated_collections = model.apply(variables, images, mutable=["batch_stats"])
        per_example_loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
        return per_example_loss.mean(), updated_collections

    return loss_fn

In [None]:
def get_compute_accuracy_fn(model: nn.Module) -> Callable:
    """Create a top-1 accuracy function that closes over ``model``.

    The returned ``compute_accuracy(variables, batch)`` expects ``batch`` to be
    ``(images, labels)``. It runs ``model.apply`` with ``train=False`` and
    returns the fraction of rows whose arg-max logit equals the label.
    """
    def compute_accuracy(variables: dict, batch: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
        images, labels = batch
        predictions = jnp.argmax(model.apply(variables, images, train=False), axis=1)
        is_correct = predictions == labels
        return is_correct.mean()

    return compute_accuracy

In [None]:
def _current_variables(state, variables: dict) -> dict:
    """Assemble the variable dict (params + batch_stats) from the optimiser state."""
    # NOTE(review): assumes the SGD state exposes updated weights as `state.params`
    # (only `state.batch_stats` is visible in this notebook). If it does not, this
    # falls back to the initial params, matching the previous behaviour.
    return {"params": getattr(state, "params", variables["params"]),
            "batch_stats": state.batch_stats}


def _evaluate(eval_variables: dict, val_loader, compute_accuracy, loss_fn=None) -> tuple[float, float]:
    """Return the example-weighted mean (accuracy, loss) over ``val_loader``.

    The loss is ``nan`` when ``loss_fn`` is None. Both values are ``nan`` when the
    loader yields no examples, instead of raising ZeroDivisionError.
    """
    acc_sum, loss_sum, count = 0.0, 0.0, 0
    for batch in val_loader:
        jax_batch = tf_to_jax(batch)
        batch_size = batch[0].shape[0]
        acc_sum += float(compute_accuracy(eval_variables, jax_batch)) * batch_size
        if loss_fn is not None:
            # loss_fn is expected to return (loss, aux); only the loss is used.
            # NOTE(review): if loss_fn runs the model in training mode (mutable
            # batch_stats), this loss uses batch statistics, not running averages.
            batch_loss, _ = loss_fn(eval_variables, jax_batch)
            loss_sum += float(batch_loss) * batch_size
        count += batch_size

    if count == 0:
        return float("nan"), float("nan")
    mean_loss = loss_sum / count if loss_fn is not None else float("nan")
    return acc_sum / count, mean_loss


def train(variables: dict,
          optimizer: SGD,
          train_loader,
          val_loader,
          num_epochs: int,
          compute_accuracy,
          loss_fn: "Callable | None" = None) -> dict:
    """Train with ``optimizer`` and report validation metrics after each epoch.

    Args:
        variables: Initial flax variables; must contain ``"params"``.
        optimizer: Project optimiser. ``init(variables)`` returns a state and
            ``update(state, batch)`` returns ``(loss, new_state)``.
        train_loader: Iterable of batches, each converted with ``tf_to_jax``.
        val_loader: Iterable of validation batches in the same format.
        num_epochs: Number of passes over ``train_loader``.
        compute_accuracy: ``(variables, batch) -> accuracy`` for one batch.
        loss_fn: Optional ``(variables, batch) -> (loss, aux)``. When given,
            validation loss is reported; otherwise it prints as ``nan``.

    Returns:
        The final variables ``{"params": ..., "batch_stats": ...}``.
    """
    # The file-level `import tqdm` binds the module, which is not callable;
    # tqdm.auto renders a widget in notebooks and a text bar elsewhere.
    from tqdm.auto import tqdm as progress_bar

    state = optimizer.init(variables)
    for epoch in range(num_epochs):
        progress = progress_bar(train_loader, desc=f"epoch {epoch + 1}/{num_epochs}")
        for batch in progress:
            loss, state = optimizer.update(state, tf_to_jax(batch))
            # Show the running loss on the bar instead of printing every step.
            progress.set_postfix(loss=float(loss))

        # Evaluate with the *current* weights, not the ones passed in.
        val_acc, val_loss = _evaluate(_current_variables(state, variables),
                                      val_loader, compute_accuracy, loss_fn)
        print(f"Epoch {epoch + 1}: val_acc={val_acc}, val_loss={val_loss}")

    return _current_variables(state, variables)
    

In [None]:
# Run configuration: hyperparameters are gathered here so they are easy to tune.
SEED = 42
BATCH_SIZE = 128
LEARNING_RATE = 1e-4
NUM_EPOCHS = 10

rng = jax.random.PRNGKey(SEED)

train_loader, val_loader = get_cifar10_dataloaders(batch_size=BATCH_SIZE)
# One batch is pulled only to give `init_model` a sample input.
init_batch = tf_to_jax(next(iter(train_loader)))
resnet18, variables = init_model(rng, init_batch)

loss_fn = jax.jit(get_cross_entropy_loss_fn(resnet18))
compute_accuracy = jax.jit(get_compute_accuracy_fn(resnet18))

optimizer = SGD(loss_fn, lr=LEARNING_RATE)

# Keep whatever `train` returns so the trained weights remain available.
trained_variables = train(variables=variables,
                          optimizer=optimizer,
                          train_loader=train_loader,
                          val_loader=val_loader,
                          num_epochs=NUM_EPOCHS,
                          loss_fn=loss_fn,
                          compute_accuracy=compute_accuracy)