In [None]:
# Clone only when the checkout is missing, so re-running this cell after a kernel
# restart (working directory back at its starting point) does not fail on an
# existing directory.
!test -d famous_model_with_jax_flax || git clone https://github.com/umepy/famous_model_with_jax_flax.git
%cd famous_model_with_jax_flax
# -B creates the local branch, or resets it to origin/alexnet if it already exists
# (plain -b errors on a second run). Any local commits on `alexnet` are discarded.
!git checkout -B alexnet origin/alexnet

In [None]:
# %pip (rather than !pip) installs into the environment the running kernel uses.
# NOTE(review): versions are unpinned and flax is taken from git HEAD, so a later
# re-run may pick up breaking API changes — pin known-good versions once identified.
%pip install -U -q pip optax
%pip install -U -q git+https://github.com/google/flax.git
%pip install -U torch torchvision

In [1]:
import os,sys
from pathlib import Path

# `__file__` is not defined inside a notebook, so this path is derived from the
# current working directory instead: two directory levels above it.
# NOTE(review): presumably the project root when the notebook is run from its own
# sub-folder — TODO confirm against the repo layout.
sys.path.append(str(Path.cwd().parent.parent))

import jax
import jax.numpy as jnp
import jax.tools.colab_tpu

from tqdm import tqdm

from flax import linen as nn
from flax.training import train_state

import numpy as np
import optax

from torch.utils.data import DataLoader
from CVFlax.models import AlexNet
from CVFlax.utils.preprocess import alexnet_dataloader, download_food101

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fetch the Food-101 dataset via the project helper. As the recorded output below
# shows, it detects an already-split copy and skips the split step on re-runs.
download_food101()

seems already splitted, so skipping to split dataset


In [3]:
def create_train_state(key, learning_rate, model=None, input_shape=(1, 227, 227, 3)):
    """Initialise model parameters and wrap them with Adam in a TrainState.

    Args:
      key: JAX PRNG key used for parameter initialisation.
      learning_rate: Adam learning rate.
      model: optional pre-built Flax module. Defaults to the Food-101 AlexNet
        configuration used by this notebook.
      input_shape: shape of the dummy batch used to trace the model in `init`.

    Returns:
      A `train_state.TrainState` whose `apply_fn` is `model.apply`, so the train and
      eval steps can run the exact model built here.
    """
    if model is None:
        # NOTE(review): the recorded run of this notebook raised
        # "TypeError: __init__() got an unexpected keyword argument 'output_dim'"
        # at this call, so CVFlax's AlexNet apparently names its class-count field
        # differently (or has none). TODO confirm the field name in
        # CVFlax.models.AlexNet and fix this line, or pass `model=` explicitly.
        model = AlexNet(output_dim=101)
    params = model.init(key, jnp.ones(input_shape))['params']
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

In [21]:
def compute_accuracy(logits, y):
  """Return the fraction of rows whose highest-scoring class equals the label.

  Args:
    logits: per-class scores; the class axis is the last one.
    y: integer labels, one per row of `logits`.
  """
  hits = jnp.argmax(logits, axis=-1) == y
  return hits.mean()


@jax.jit
def train_step(state, x, y):
  """Run one optimisation step on a single batch.

  Args:
    state: `TrainState` holding params, optimizer state and `apply_fn`.
    x: batch of input images.
    y: integer class labels for the batch.

  Returns:
    `(new_state, metrics)`, where `metrics` holds the batch 'loss' and 'accuracy'.
  """
  def loss_fn(params):
    # Use the apply_fn the state was created with, so training always runs the
    # same model configuration as create_train_state (and as evaluation).
    logits = state.apply_fn({'params': params}, x)
    # Take the class count from the model output instead of hard-coding it.
    one_hot_labels = jax.nn.one_hot(y, num_classes=logits.shape[-1])
    # -sum(one_hot * v) is cross-entropy only when v are log-probabilities.
    # log_softmax makes that hold for raw logits, and is a no-op (up to rounding)
    # on values that already are log-probabilities. The loss is therefore correct
    # whichever of the two AlexNet returns; that is not visible in this notebook.
    log_probs = jax.nn.log_softmax(logits)
    loss = -jnp.mean(jnp.sum(one_hot_labels * log_probs, axis=-1))
    return loss, logits

  (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
  state = state.apply_gradients(grads=grads)
  metrics = {
      'loss': loss,
      'accuracy': compute_accuracy(logits, y),
  }
  return state, metrics


@jax.jit
def eval_step(state, x, y):
  """Return the accuracy of the model held in `state` on one batch.

  Uses `state.apply_fn`, i.e. the exact model `state` was created with. This keeps
  evaluation from running a differently configured network than training did.
  """
  logits = state.apply_fn({'params': state.params}, x)
  return compute_accuracy(logits, y)

In [22]:
def train_epoch(state, dataloader, epoch, log_every=20):
  """Run one training epoch over `dataloader` and return the updated state.

  Args:
    state: the current `TrainState`.
    dataloader: iterable of `(x, y)` batches that supports `len()` (used for the
      progress bar).
    epoch: epoch number; used only in log messages.
    log_every: print the running mean loss/accuracy every `log_every` batches,
      starting with the first batch. Pass 0 to disable periodic logging.

  Returns:
    `(state, epoch_metrics)`, where `epoch_metrics` maps 'loss' and 'accuracy' to
    their unweighted mean over all batches of the epoch.

  Raises:
    ValueError: if `dataloader` yields no batches.
  """
  batch_metrics = []
  window = {'loss': [], 'accuracy': []}
  with tqdm(dataloader, total=len(dataloader)) as progress:
    for step, (x, y) in enumerate(progress):
      # NOTE(review): batches presumably come from a torch DataLoader. np.asarray
      # converts CPU torch tensors (which jax.jit does not accept) and is a no-op
      # for NumPy arrays.
      state, metrics = train_step(state, np.asarray(x), np.asarray(y))
      batch_metrics.append(metrics)
      window['loss'].append(metrics['loss'])
      window['accuracy'].append(metrics['accuracy'])

      if log_every and step % log_every == 0:
        # progress.write prints without corrupting the progress bar. The epoch and
        # the step within the epoch are reported separately.
        progress.write(
            f'Epoch: {epoch}\tstep: {step}\tloss: {np.mean(window["loss"])}'
            f'\taccuracy: {np.mean(window["accuracy"])}')
        window = {'loss': [], 'accuracy': []}

  if not batch_metrics:
    raise ValueError('dataloader yielded no batches; cannot compute epoch metrics')

  # One host transfer for the whole epoch instead of one per batch.
  batch_metrics_np = jax.device_get(batch_metrics)
  epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]
  }
  return state, epoch_metrics_np


def evaluate_model(state, x, y):
  """Evaluate one batch and return the result as plain Python scalar(s).

  Args:
    state: `TrainState` to evaluate.
    x: batch of input images.
    y: integer class labels for the batch.

  Returns:
    The output of `eval_step` (batch accuracy), converted leaf-by-leaf from NumPy
    to Python scalars.
  """
  metrics = jax.device_get(eval_step(state, x, y))
  # jax.tree_util.tree_map is the stable spelling; the top-level jax.tree_map
  # alias is deprecated and removed in newer JAX releases.
  return jax.tree_util.tree_map(lambda leaf: leaf.item(), metrics)  # np.ndarray -> scalar

In [23]:
# NOTE(review): 0.01 is high for Adam on AlexNet (1e-3 or lower is more usual);
# consider tuning if the loss diverges.
learning_rate = 0.01
num_epochs = 10
batch_size = 128
key = jax.random.PRNGKey(0)
state = create_train_state(key, learning_rate)
train_loader, test_loader = alexnet_dataloader(batch_size)

for epoch in range(1, num_epochs + 1):
  state, train_metrics = train_epoch(state, train_loader, epoch)
  print(f"Train epoch: {epoch}, loss: {train_metrics['loss']:.4}, accuracy: {train_metrics['accuracy'] * 100:.4}")

  # Accuracy over the whole test split after every epoch. Each batch's accuracy is
  # weighted by its size, so a smaller final batch does not skew the mean.
  n_correct, n_seen = 0.0, 0
  for x_test, y_test in test_loader:
    y_test = np.asarray(y_test)
    n_correct += float(eval_step(state, np.asarray(x_test), y_test)) * len(y_test)
    n_seen += len(y_test)
  if n_seen:
    print(f"Test epoch: {epoch}, accuracy: {n_correct / n_seen * 100:.4}")

TypeError: __init__() got an unexpected keyword argument 'output_dim'