In [None]:
import jax
from flax import linen as nn
import jax.numpy as jnp
import optax
from flaxmodels import ResNet18
from typing import Callable
from tensorflow

from optimizers import SGD
from dataloader import get_cifar10_dataloaders, tf_to_jax

In [None]:
def init_model(init_batch: tuple[jnp.ndarray, jnp.ndarray],
               dropout_rate: float = 0.0,
               rng=None) -> tuple[nn.Module, dict]:
    """Construct a 10-class ResNet18 and initialise its parameters.

    Args:
        init_batch: ``(images, labels)`` pair. Only ``images`` is used, as the
            example input that Flax traces to create the variables.
        dropout_rate: Forwarded unchanged to the ``ResNet18`` constructor.
        rng: Optional JAX PRNG key used for initialisation. When omitted,
            ``jax.random.PRNGKey(0)`` is used so the result is deterministic.

    Returns:
        ``(model, params)`` where ``params`` is the ``"params"`` collection of
        the variables produced by ``model.init``.
    """
    if rng is None:
        rng = jax.random.PRNGKey(0)

    # Labels are irrelevant for initialisation; only the inputs are traced.
    images, _ = init_batch

    model = ResNet18(output="logits",
                     pretrained=False,
                     num_classes=10,
                     dropout_rate=dropout_rate)

    # Flax's Module.init takes the PRNG key first and the example inputs after
    # it. jnp.asarray leaves JAX arrays untouched and converts array-likes.
    variables = model.init(rng, jnp.asarray(images))

    # NOTE(review): only the "params" collection is returned; any other
    # collection created by init (e.g. normalisation statistics, if the model
    # has them) is dropped here — confirm that is intended.
    return model, variables["params"]

In [None]:
def get_cross_entropy_loss_fn(model: nn.Module) -> Callable:
    """Build a mean softmax cross-entropy loss function bound to ``model``.

    Args:
        model: Flax module whose ``apply`` maps images to class logits.

    Returns:
        ``loss_fn(params, batch)`` which returns the cross-entropy between the
        model's logits and the integer labels, averaged over the batch.
    """
    def loss_fn(params: dict, batch: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
        images, labels = batch

        # Module.apply expects a variables dict keyed by collection name, but
        # init_model hands back only the "params" collection. Wrap a bare
        # collection; pass an already-complete variables dict through as-is.
        # The membership test is on dict keys, so it is static under jit/grad.
        variables = params if "params" in params else {"params": params}

        # NOTE(review): if the model keeps non-parameter state (e.g. BatchNorm
        # statistics) it must also be present in `variables` — confirm.
        logits = model.apply(variables, images)
        return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

    return loss_fn

In [None]:
def get_compute_accuracy_fn(model: nn.Module) -> Callable:
    """Build a top-1 accuracy function bound to ``model``.

    Args:
        model: Flax module whose ``apply`` maps images to class logits.

    Returns:
        ``compute_accuracy(params, batch)`` which returns the fraction of
        examples in ``batch`` whose arg-max logit (along axis 1) equals the
        label.
    """
    def compute_accuracy(params: dict, batch: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
        images, labels = batch

        # Module.apply expects a variables dict keyed by collection name, but
        # init_model hands back only the "params" collection. Wrap a bare
        # collection; pass an already-complete variables dict through as-is.
        variables = params if "params" in params else {"params": params}

        # NOTE(review): if the model keeps non-parameter state (e.g. BatchNorm
        # statistics) it must also be present in `variables` — confirm.
        logits = model.apply(variables, images)
        predictions = logits.argmax(axis=1)
        return jnp.mean(predictions == labels)

    return compute_accuracy

In [None]:
def train(params: dict,
          optimizer: SGD,
          train_loader: "DataLoader",
          val_loader: "DataLoader",
          num_epochs: int,
          loss_fn,
          compute_accuracy,
          opt_state=None) -> dict:
    """Run the training loop and report validation metrics after every epoch.

    Args:
        params: Initial model parameters.
        optimizer: Object whose ``update(params, batch, opt_state)`` returns the
            new ``(params, opt_state)`` pair.
        train_loader: Iterable of training batches accepted by ``tf_to_jax``.
        val_loader: Iterable of validation batches accepted by ``tf_to_jax``.
        num_epochs: Number of full passes over ``train_loader``.
        loss_fn: ``loss_fn(params, batch)``; assumed to return the mean loss
            over the batch, as the loss built elsewhere in this notebook does.
        compute_accuracy: ``compute_accuracy(params, batch)`` returning the
            mean accuracy over the batch.
        opt_state: Optional initial optimizer state. Defaults to ``None``.
            NOTE(review): pass an explicit state if ``optimizer`` does not
            accept ``None`` on the first step — its API is not visible here.

    Returns:
        The parameters after the final update.
    """
    # The loader annotations are quoted because no `DataLoader` name is
    # imported in this notebook; unquoted they raise NameError at def time.
    for epoch in range(num_epochs):
        for batch in train_loader:
            jax_batch = tf_to_jax(batch)
            params, opt_state = optimizer.update(params, jax_batch, opt_state)

        val_acc, val_loss, val_count = 0, 0, 0
        for batch in val_loader:
            jax_batch = tf_to_jax(batch)
            batch_size = jax_batch[0].shape[0]
            # Both metrics are per-batch means; weight them by batch size so
            # dividing by the sample count below yields true per-sample means
            # even when the last batch is smaller.
            val_loss += loss_fn(params, jax_batch) * batch_size
            val_acc += compute_accuracy(params, jax_batch) * batch_size
            val_count += batch_size

        if val_count == 0:
            # Avoid ZeroDivisionError when the validation loader yields nothing.
            print(f"Epoch {epoch + 1}: validation loader produced no batches")
            continue

        print(f"Epoch {epoch + 1}: val_acc={val_acc / val_count}, val_loss={val_loss / val_count}")

    return params
    

In [None]:
# NOTE(review): this key is created but not consumed anywhere in this cell —
# confirm whether it was meant to seed model initialisation.
rng = jax.random.PRNGKey(42)

# Hyperparameters, collected in one place so they are easy to tune.
BATCH_SIZE = 128
LEARNING_RATE = 1e-4
NUM_EPOCHS = 10

train_loader, val_loader, _ = get_cifar10_dataloaders(batch_size=BATCH_SIZE)

# Convert the first batch exactly as train() converts every batch, so the
# model is initialised from the same kind of arrays it is later trained on.
init_batch = tf_to_jax(next(iter(train_loader)))
resnet18, params = init_model(init_batch, dropout_rate=0.0)

loss_fn = jax.jit(get_cross_entropy_loss_fn(resnet18))
compute_accuracy = jax.jit(get_compute_accuracy_fn(resnet18))

optimizer = SGD(loss_fn, lr=LEARNING_RATE)

# Keep whatever train() returns; binding it to a name also stops the notebook
# from echoing the value as the cell's output.
trained_params = train(params=params,
                       optimizer=optimizer,
                       train_loader=train_loader,
                       val_loader=val_loader,
                       num_epochs=NUM_EPOCHS,
                       loss_fn=loss_fn,
                       compute_accuracy=compute_accuracy)