In [None]:
# Installs
!pip install --upgrade -q pip jax jaxlib
!pip install --upgrade -q git+https://github.com/google/flax.git    
!pip install --upgrade -q git+https://github.com/rolandgvc/flaxvision.git

In [None]:
# General imports
import logging
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np

from flax import linen as nn
from flax import optim
from flax.metrics import tensorboard
from flax.training import checkpoints
from flaxvision import models
from torchvision import datasets

In [None]:
# Load dataset from torchvision into memory
def _mnist_to_arrays(split):
  """Convert a torchvision MNIST split into a dict of in-memory arrays.

  Args:
    split: object exposing `.data` and `.targets` tensors.

  Returns:
    A dict with key 'image' (float32 array with an extra axis appended at
    position 3, values divided by 255) and key 'label' (numpy array of
    targets).
  """
  images = np.expand_dims(split.data.numpy(), 3)
  return {
      'image': jnp.float32(images) / 255.,
      'label': split.targets.numpy(),
  }


train_ds = _mnist_to_arrays(datasets.MNIST('./data', train=True, download=True))
test_ds = _mnist_to_arrays(datasets.MNIST('./data', train=False, download=True))

In [None]:
# Build the pretrained VGG16 module together with its parameters.
RNG = jax.random.PRNGKey(0)
vgg, vgg_params = models.vgg16(RNG, pretrained=True)

# Smoke-test the forward pass on an all-ones batch.
# TODO: test with an image from dataset
batch = jnp.ones(shape=(1, 224, 224, 3))
out = vgg.apply(vgg_params, batch, mutable=False)

In [None]:
# Define backbone instantiation as a zero-argument factory function
def vgg_backbone():
  """Return the backbone that `models.VGG.make_backbone` builds from `vgg`."""
  return models.VGG.make_backbone(vgg)

In [None]:
# Define new model
from flax import linen as nn

class Classifier(nn.Module):
  """Fully connected classification head.

  Two hidden blocks of Dense(2048) -> ReLU -> Dropout(0.5), followed by a
  Dense(10) output layer. No softmax is applied, so the returned values are
  unnormalised scores.

  Attributes:
    dtype: dtype passed to every Dense layer.
  """
  # `Any` is provided by `from typing import Any` in the notebook's import cell.
  dtype: Any = jnp.float32

  @nn.compact
  def __call__(self, inputs, train: bool = False):
    """Apply the head to `inputs`.

    Args:
      inputs: array fed to the first Dense layer.
      train: when True dropout is active; when False (default) the Dropout
        layers are deterministic.

    Returns:
      The output of the final Dense(10) layer.
    """
    x = inputs
    # Submodules are created in the same order as the unrolled version:
    # Dense, Dropout, Dense, Dropout, Dense.
    for _ in range(2):
      x = nn.Dense(2048, dtype=self.dtype)(x)
      x = nn.relu(x)
      x = nn.Dropout(rate=0.5)(x, deterministic=not train)
    return nn.Dense(10, dtype=self.dtype)(x)


class MyModel(nn.Module):
  """Backbone returned by `vgg_backbone()` followed by a `Classifier` head."""

  def setup(self):
    """Create the backbone and classifier submodules."""
    self.backbone = vgg_backbone()
    self.classifier = Classifier()

  def __call__(self, inputs, train: bool = False):
    """Run the backbone, flatten its output and classify it.

    Args:
      inputs: batch passed to the backbone.
      train: forwarded to the classifier only.

    Returns:
      The classifier output for the batch.
    """
    # The backbone is always called with train=False, regardless of `train`.
    x = self.backbone(inputs, train=False)
    # Move axis 3 in front of axes 1 and 2 before flattening.
    # NOTE(review): presumably NHWC -> NCHW so the flattened feature order
    # matches the pretrained weights' layout — confirm.
    x = x.transpose((0, 3, 1, 2))
    x = x.reshape((x.shape[0], -1))
    return self.classifier(x, train)

In [None]:
# Setup training loop
def get_initial_params(key):
  """Initialise `MyModel` and return its 'params' collection.

  Args:
    key: PRNG key passed to `MyModel().init`.

  Returns:
    The 'params' entry of the variables produced by `init`.
  """
  # Dummy all-ones float32 batch of shape (1, 224, 224, 3) used for init.
  dummy_input = jnp.ones((1, 224, 224, 3), jnp.float32)
  variables = MyModel().init(key, dummy_input)
  return variables['params']


def create_optimizer(params, learning_rate, beta):
  """Build a `optim.Momentum` optimizer wrapping `params`.

  Args:
    params: parameters handed to the optimizer definition's `create`.
    learning_rate: learning rate for `optim.Momentum`.
    beta: `beta` argument for `optim.Momentum`.

  Returns:
    The optimizer returned by `optim.Momentum(...).create(params)`.
  """
  momentum_def = optim.Momentum(learning_rate=learning_rate, beta=beta)
  return momentum_def.create(params)


def onehot(labels, num_classes=10):
  """One-hot encode integer labels.

  Args:
    labels: integer array of class indices.
    num_classes: size of the appended one-hot axis.

  Returns:
    float32 array shaped `labels.shape + (num_classes,)`.
  """
  x = (labels[..., None] == jnp.arange(num_classes)[None])
  return x.astype(jnp.float32)


def cross_entropy_loss(logits, labels):
  """Mean softmax cross-entropy between `logits` and integer `labels`.

  The scores are normalised with `log_softmax` before the negative
  log-likelihood is taken. Without this, raw (unnormalised) scores make the
  objective unbounded below. `log_softmax` leaves inputs that are already
  log-probabilities unchanged, so callers passing those get the same result
  as before.

  Args:
    logits: array whose last axis holds per-class scores.
    labels: integer class indices, one per row of `logits`.

  Returns:
    Scalar mean loss over the batch.
  """
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  return -jnp.mean(jnp.sum(onehot(labels) * log_probs, axis=-1))


def compute_metrics(logits, labels):
  """Compute loss and accuracy for a batch.

  Args:
    logits: per-class scores; the prediction is the argmax over the last axis.
    labels: integer class indices compared against the predictions.

  Returns:
    Dict with 'loss' (from `cross_entropy_loss`) and 'accuracy' (mean of
    prediction == label).
  """
  batch_loss = cross_entropy_loss(logits, labels)
  predictions = jnp.argmax(logits, -1)
  return {
      'loss': batch_loss,
      'accuracy': jnp.mean(predictions == labels),
  }


@jax.jit
def train_step(optimizer, batch):
  """Train for a single step.

  Args:
    optimizer: optimizer whose `target` holds the current parameters.
    batch: dict with 'image' and 'label' entries.

  Returns:
    A tuple `(optimizer, metrics)` with the updated optimizer and the metrics
    dict from `compute_metrics` for this batch.
  """
  def loss_fn(params):
    # The model defined in this notebook is `MyModel` (there is no `CNN`).
    # NOTE(review): `train` is left at its default (False), so dropout is
    # inactive here; enabling it would also require a dropout PRNG key.
    logits = MyModel().apply({'params': params}, batch['image'])
    loss = cross_entropy_loss(logits, batch['label'])
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grad = grad_fn(optimizer.target)
  optimizer = optimizer.apply_gradient(grad)
  metrics = compute_metrics(logits, batch['label'])
  return optimizer, metrics


@jax.jit
def eval_step(params, batch):
  """Evaluate `MyModel` on one batch.

  Args:
    params: model parameters passed as the 'params' collection.
    batch: dict with 'image' and 'label' entries.

  Returns:
    The metrics dict from `compute_metrics`.
  """
  # The model defined in this notebook is `MyModel` (there is no `CNN`).
  logits = MyModel().apply({'params': params}, batch['image'])
  return compute_metrics(logits, batch['label'])


def train_epoch(optimizer, train_ds, batch_size, epoch, rng):
  """Train for a single epoch.

  Args:
    optimizer: optimizer to update, passed through `train_step`.
    train_ds: dict of arrays; every entry is indexed with the same
      permutation, and `train_ds['image']` defines the dataset size.
    batch_size: number of examples per step.
    epoch: epoch number, used only in the log message.
    rng: PRNG key used to shuffle the dataset.

  Returns:
    A tuple `(optimizer, epoch_metrics_np)` where `epoch_metrics_np` maps each
    metric name to its mean over the epoch's batches.

  Raises:
    ValueError: if the dataset holds fewer than `batch_size` examples, so no
      complete batch can be formed.
  """
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size
  if steps_per_epoch == 0:
    raise ValueError(
        f'batch_size ({batch_size}) exceeds the dataset size '
        f'({train_ds_size}); no complete batch can be formed.')

  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  for perm in perms:
    batch = {k: v[perm] for k, v in train_ds.items()}
    optimizer, metrics = train_step(optimizer, batch)
    batch_metrics.append(metrics)

  # Compute the mean of each metric across all batches in the epoch.
  batch_metrics_np = jax.device_get(batch_metrics)
  epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]}

  logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
               epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100)

  return optimizer, epoch_metrics_np


def eval_model(model, test_ds):
  """Evaluate on `test_ds` and return plain Python scalars.

  Args:
    model: forwarded unchanged as the first argument of `eval_step`
      (the caller in this notebook passes `optimizer.target`).
    test_ds: dict with 'image' and 'label' entries, evaluated as one batch.

  Returns:
    A tuple `(loss, accuracy)`.
  """
  metrics = eval_step(model, test_ds)
  metrics = jax.device_get(metrics)
  # `jax.tree_util.tree_map` is the stable location of this function; the
  # top-level `jax.tree_map` alias is not available in every JAX release.
  summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
  return summary['loss'], summary['accuracy']

In [None]:
# Run training loop

# Hyperparameters — starting values, tune as needed.
learning_rate = 0.01
momentum = 0.9
batch_size = 32
num_epochs = 10
workdir = './workdir'  # TensorBoard summaries and checkpoints are written here

# NOTE(review): `get_initial_params` initialises the model with a
# (1, 224, 224, 3) input, while the MNIST arrays built earlier have a trailing
# axis of size 1 — confirm the images are resized/replicated to what the
# backbone expects before running this cell.

summary_writer = tensorboard.SummaryWriter(workdir)

# `RNG` is the key created when the pretrained model was instantiated.
rng, init_rng = jax.random.split(RNG)
params = get_initial_params(init_rng)
optimizer = create_optimizer(params, learning_rate, momentum)

for epoch in range(1, num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    optimizer, train_metrics = train_epoch(optimizer, train_ds, batch_size,
                                           epoch, input_rng)
    loss, accuracy = eval_model(optimizer.target, test_ds)

    summary_writer.scalar('train_loss', train_metrics['loss'], epoch)
    summary_writer.scalar('train_accuracy', train_metrics['accuracy'], epoch)
    summary_writer.scalar('eval_loss', loss, epoch)
    summary_writer.scalar('eval_accuracy', accuracy, epoch)

    summary_writer.flush()

    # Checkpoint the optimizer (which carries the parameters) once per epoch,
    # keeping the three most recent checkpoints.
    checkpoints.save_checkpoint(workdir, optimizer, epoch, keep=3)

In [None]:
# Load from checkpoint and inference