In [3]:
!git clone https://github.com/umepy/famous_model_with_jax_flax.git
%cd famous_model_with_jax_flax
!git checkout -b alexnet origin/alexnet

Cloning into 'famous_model_with_jax_flax'...
remote: Enumerating objects: 134, done.[K
remote: Counting objects: 100% (134/134), done.[K
remote: Compressing objects: 100% (77/77), done.[K
remote: Total 134 (delta 48), reused 119 (delta 36), pack-reused 0[K
Receiving objects: 100% (134/134), 17.65 KiB | 1.04 MiB/s, done.
Resolving deltas: 100% (48/48), done.
/workspaces/famous_model_with_jax_flax/CVFlax/colab/famous_model_with_jax_flax/famous_model_with_jax_flax
Branch 'alexnet' set up to track remote branch 'alexnet' from 'origin'.
Switched to a new branch 'alexnet'


In [4]:
!pip install -U -q pip
!pip install -U -q jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
!pip install -U -q jaxlib optax
!pip install -U -q git+https://github.com/google/flax.git
!pip install -U -q torch torchvision

[31mERROR: Could not find a version that satisfies the requirement jaxlib==0.3.15+cuda11.cudnn805; extra == "cuda11_cudnn805" (from jax[cuda11_cudnn805]) (from versions: 0.1.32, 0.1.40, 0.1.41, 0.1.42, 0.1.43, 0.1.44, 0.1.46, 0.1.50, 0.1.51, 0.1.52, 0.1.55, 0.1.56, 0.1.57, 0.1.58, 0.1.59, 0.1.60, 0.1.61, 0.1.62, 0.1.63, 0.1.64, 0.1.65, 0.1.66, 0.1.67, 0.1.68, 0.1.69, 0.1.70, 0.1.71, 0.1.72, 0.1.73, 0.1.74, 0.1.75, 0.1.76, 0.3.0, 0.3.2, 0.3.5, 0.3.7, 0.3.8, 0.3.10, 0.3.14, 0.3.15)[0m
[31mERROR: No matching distribution found for jaxlib==0.3.15+cuda11.cudnn805; extra == "cuda11_cudnn805" (from jax[cuda11_cudnn805])[0m
[31m  ERROR: Command errored out with exit status 128:
   command: git clone -q https://github.com/google/flax.git /tmp/pip-req-build-fo20z9sy
       cwd: None
  Complete output (2 lines):
  fatal: Unable to read current working directory: No such file or directory
  fatal: remote did not send all necessary objects
  ----------------------------------------[0m
[31mER

In [1]:
import os,sys
from pathlib import Path

sys.path.append(str(Path(os.path.abspath("__file__")).parent.parent.parent))

import jax
import jax.numpy as jnp
import jax.tools.colab_tpu

from tqdm import tqdm

from flax import linen as nn
from flax.training import train_state, checkpoints

import numpy as np
import optax

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from CVFlax.models import AlexNet
from CVFlax.utils.preprocess import alexnet_dataloader, download_food101

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
download_food101()

seems already being downloaded, so skipping to split dataset


In [3]:
def create_train_state(key, learning_rate):
    params = AlexNet(output_dim=101).init(key, jnp.ones([1,227,227,3]))['params']
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=AlexNet(output_dim=101).apply, params=params, tx=tx)

In [38]:
@jax.jit
def compute_accuracy(logits, y):
  accuracy = jnp.mean(jnp.argmax(logits, -1) == y)
  return accuracy


@jax.jit
def train_step(state, x, y):
  def loss_fn(params):
    logits = AlexNet(output_dim=101).apply({'params':params}, x)
    one_hot_labels = jax.nn.one_hot(y, num_classes=101)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels))
    return loss, logits
  
  (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
  state = state.apply_gradients(grads=grads)
  metrics = {
      'loss': loss,
      'accuracy': compute_accuracy(logits, y),
  }
  return state, metrics


@jax.jit
def eval_step(state, x, y):
  # transform x because we used FiveCrops
  bs,ncrops,h,w,c = x.shape
  x = x.reshape((-1,h,w,c))
  logits = AlexNet(output_dim=101).apply({'params':state.params}, x)
  return jnp.mean(jnp.argmax(logits.reshape(bs,ncrops, -1).mean(1), -1) == y)

In [42]:
def train_epoch(state, dataloader, global_step, writer):
  batch_metrics = []
  track_metrics = {'loss':[], 'accuracy':[]}
  with tqdm(total=len(dataloader)) as tq:
    for cnt, (x, y) in enumerate(dataloader):
      tq.update(1)
      x,y=jnp.array(x),jnp.array(y)
      state, metrics = train_step(state, x, y) # update state 
      batch_metrics.append(metrics)
      track_metrics['loss'].append(metrics['loss'])
      track_metrics['accuracy'].append(metrics['accuracy'])

      writer.add_scalar('train/loss', metrics['loss'].item(), global_step)
      writer.add_scalar('train/accuracy', metrics['accuracy'].item(), global_step)
      global_step += 1
  
  batch_metrics_np = jax.device_get(batch_metrics)
  epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]
  }
  return state, epoch_metrics_np

def evaluate_epoch(state, dataloader, global_step, writer):
  all_score = []
  for x, y in dataloader:
    precision = eval_step(state, x, y)
    all_score.append(precision.item())
  all_precision = np.mean(all_score)
  writer.add_scalar('test/accuracy', all_precision, global_step)
  return all_precision

In [4]:
from google.colab import drive
drive.mount('/content/drive')

# paste path where you want to save model checkpoints
save_path = '/content/drive/My Drive/'

# difine tensorboard writer
writer = SummaryWriter()

ModuleNotFoundError: No module named 'google.colab'

In [43]:
learning_rate = 0.001
num_epochs = 100
batch_size = 128
global_step = 1
key = jax.random.PRNGKey(20220319)
state = create_train_state(key, learning_rate)
train_loader, test_loader = alexnet_dataloader(batch_size=batch_size)
test_scores = []

for epoch in range(1, num_epochs + 1):
  state, train_metrics = train_epoch(state, train_loader, global_step, writer)
  print(f"Train epoch: {epoch}, loss: {train_metrics['loss']:.4}, accuracy: {train_metrics['accuracy'] * 100:.4}")

  test_accuracy = evaluate_epoch(state, test_loader, global_step, writer)
  print(f"Test epoch: {epoch}, accuracy: {test_accuracy * 100:.4}")

  if test_scores==[] or test_accuracy > test_scores[-1]:
    checkpoints.save_checkpoint(ckpt_dir=save_path, target=state, step=epoch, prefix='Alexnet_checkpoint_epoch_')
  test_scores.append(test_accuracy)

Test epoch: 1, accuracy: 1.006


KeyboardInterrupt: 

In [6]:
learning_rate = 0.001
num_epochs = 100
batch_size = 128
global_step = 1
key = jax.random.PRNGKey(20220319)
state = create_train_state(key, learning_rate)
train_loader, test_loader = alexnet_dataloader(batch_size=batch_size)
for x,y in test_loader:
    break