<a href="https://colab.research.google.com/github/sbl1996/hanser/blob/master/examples/TPU_CIFAR10_Trainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Install (or upgrade, via -U) hanser straight from its GitHub repository.
# The install is unpinned, so every run picks up whatever the repository currently serves.
!pip install -U git+https://github.com/sbl1996/hanser.git
# Disabled: replaces Pillow with pillow-simd when the two lines below are uncommented.
# !pip uninstall pillow -y
# !pip install -U --force-reinstall pillow-simd

In [0]:
import sys
import os

import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

import tensorflow as tf
from hanser.tpu import get_colab_tpu, auth

# Look up the TPU attached to this Colab runtime; `get_colab_tpu()` returns
# None when there is none, which is what `use_tpu` records.
tpu = get_colab_tpu()
use_tpu = tpu is not None
if use_tpu:
    strategy = tf.contrib.distribute.TPUStrategy(tpu)
    # NOTE(review): presumably authenticates the session so GCS buckets are
    # reachable -- confirm against hanser.tpu.auth.
    auth()
else:
    # Without a TPU, `strategy` used to stay unbound and the later
    # `with strategy.scope():` cell failed with a NameError. Fall back to
    # TensorFlow's default distribution strategy so `strategy` always exists.
    # NOTE(review): confirm hanser's Trainer accepts the default strategy.
    strategy = tf.distribute.get_strategy()

In [0]:
from datetime import datetime
import math
import time
from toolz import curry

from tensorflow.keras.metrics import Mean, SparseCategoricalAccuracy

from hanser.datasets import prepare
from hanser.train import Trainer
from hanser.model.cifar import pyramidnet
from hanser.train.callbacks import cosine_lr
from hanser.transform import random_crop, cutout, normalize, to_tensor
from hanser.transform.autoaugment import autoaugment


In [0]:
def decode(example_proto):
    """Parse one serialized example into a raw image tensor and its label.

    Args:
        example_proto: Scalar string tensor holding a serialized example with
            an 'image' bytes feature and a 'label' int64 feature. Missing
            features fall back to '' and 0 respectively.

    Returns:
        A tuple ``(x, y)`` where ``x`` is the image bytes decoded as uint8 and
        reshaped to (32, 32, 3), and ``y`` is the int64 label.
    """
    features = {
        'image': tf.io.FixedLenFeature((), tf.string, default_value=''),
        'label': tf.io.FixedLenFeature((), tf.int64, default_value=0),
    }
    example = tf.io.parse_single_example(example_proto, features)
    # Use the `tf.io` endpoint, consistent with the parsing calls above; the
    # top-level `tf.decode_raw` alias no longer exists in TensorFlow 2.
    x = tf.io.decode_raw(example['image'], tf.uint8)
    x = tf.reshape(x, (32, 32, 3))
    y = example['label']
    return x, y


@curry
def preprocess(example, training):
    """Decode one serialized example and turn it into a model-ready pair.

    Curried, so ``preprocess(training=True)`` yields a one-argument function
    that can be handed to the dataset pipeline.

    Args:
        example: Serialized example, forwarded to ``decode``.
        training: When truthy, the augmentation steps are applied.

    Returns:
        The ``(image, label)`` pair after conversion and normalization
        (and augmentation, in training mode).
    """
    channel_mean = [0.4914, 0.4822, 0.4465]
    channel_std = [0.2470, 0.2435, 0.2616]

    img, target = decode(example)

    if training:
        # These augmentations run on the decoded image, before conversion.
        # NOTE(review): (32, 32) / (4, 4) are presumably crop size and padding -- confirm.
        img = random_crop(img, (32, 32), (4, 4))
        img = tf.image.random_flip_left_right(img)
        img = autoaugment(img, "CIFAR10")

    img, target = to_tensor(img, target)
    img = normalize(img, channel_mean, channel_std)

    # Cutout is applied last, on the normalized image, and only when training.
    return (cutout(img, 16) if training else img), target


In [0]:
# List the objects under the CIFAR10 train*/test* prefixes of the GCS bucket.
# IPython's `name = !cmd` form captures the command's output lines as a list;
# both lists are later passed to tf.data.TFRecordDataset.
# NOTE(review): the trailing `| cat` looks like a workaround for gsutil's
# terminal-aware output -- confirm it is still needed.
train_files = !gsutil ls -r gs://hrvvi-datasets/CIFAR10/train* | cat
test_files = !gsutil ls -r gs://hrvvi-datasets/CIFAR10/test* | cat

In [0]:
# Split sizes: the first `num_train_examples` records read from `train_files`
# are used for training and everything after them for validation.
num_train_examples, num_val_examples, num_test_examples = 45000, 5000, 10000

# NOTE(review): the factor 8 presumably matches the number of TPU cores -- confirm.
batch_size = 128 * 8
eval_batch_size = batch_size * 2

# Training drops the final partial batch (floor); evaluation keeps it (ceil).
steps_per_epoch = num_train_examples // batch_size
val_steps, test_steps = (
    math.ceil(count / eval_batch_size)
    for count in (num_val_examples, num_test_examples)
)

ds = tf.data.TFRecordDataset(train_files)

train_records = ds.take(num_train_examples)
ds_train = prepare(
    train_records,
    preprocess(training=True),
    batch_size,
    training=True,
    buffer_size=10000,
)

val_records = ds.skip(num_train_examples)
ds_val = prepare(
    val_records,
    preprocess(training=False),
    eval_batch_size,
    training=False,
)

test_records = tf.data.TFRecordDataset(test_files)
ds_test = prepare(
    test_records,
    preprocess(training=False),
    eval_batch_size,
    training=False,
)


In [7]:
# Model, loss, optimizer, schedule and metrics are all created inside the
# distribution strategy's scope.
with strategy.scope():
    input_shape = (32, 32, 3)
    # NOTE(review): positional args are presumably num_classes=10, start
    # channels=16, widening alpha=270 and blocks per stage -- confirm against
    # hanser.model.cifar.pyramidnet.
    model = pyramidnet(input_shape, 10, 16, 270, [18, 18, 18])

    # Per-example losses (reduction='none') computed from raw logits.
    criterion = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True,
        reduction='none',
    )

    # The same base learning rate feeds both the optimizer and the schedule.
    base_lr = 0.1 * 8
    optimizer = tf.keras.optimizers.SGD(
        learning_rate=base_lr,
        momentum=0.9,
        nesterov=True,
    )
    lr_schedule = cosine_lr(
        base_lr=base_lr,
        total_epochs=600,
        warmup=10,
        gamma=1/8,
    )

    # Metrics are created in the original order: train loss/acc, then test loss/acc.
    metric_dtype = tf.float32
    training_loss = Mean('loss', dtype=metric_dtype)
    training_accuracy = SparseCategoricalAccuracy('acc', dtype=metric_dtype)
    test_loss = Mean('loss', dtype=metric_dtype)
    test_accuracy = SparseCategoricalAccuracy('acc', dtype=metric_dtype)


Instructions for updating:
If using Keras pass *_constraint arguments to layers.


In [0]:
# Bundle the model, loss, optimizer, LR schedule and metrics into a Trainer.
# `model_dir` is the GCS location handed to the Trainer.
trainer = Trainer(
    model,
    criterion,
    optimizer,
    lr_schedule,
    metrics=[training_loss, training_accuracy],
    test_metrics=[test_loss, test_accuracy],
    tpu=tpu,
    strategy=strategy,
    model_dir="gs://hrvvi-models/checkpoints/cifar10-pyramidnet",
    weight_decay=1e-4,
)

In [0]:
# NOTE(review): presumably must match `total_epochs` passed to cosine_lr -- confirm.
num_epochs = 600
trainer.train_and_evaluate(num_epochs, ds_train, steps_per_epoch, ds_val, val_steps)




Initializing from scratch.
Epoch 1
Train 	cost: 96s, loss: 1.942, acc: 0.287
Val 	cost: 36s, loss: 1.742, acc: 0.406
Epoch 2
Train 	cost: 15s, loss: 1.566, acc: 0.436
Val 	cost: 0s, loss: 1.645, acc: 0.497
Epoch 3
Train 	cost: 15s, loss: 1.335, acc: 0.522
Val 	cost: 0s, loss: 1.106, acc: 0.629
Epoch 4
Train 	cost: 15s, loss: 1.171, acc: 0.586
Val 	cost: 0s, loss: 1.424, acc: 0.586
Epoch 5
Train 	cost: 15s, loss: 1.050, acc: 0.633
Val 	cost: 0s, loss: 0.849, acc: 0.725
Epoch 6
Train 	cost: 15s, loss: 0.953, acc: 0.667
Val 	cost: 0s, loss: 0.697, acc: 0.774
Epoch 7
Train 	cost: 15s, loss: 0.885, acc: 0.690
Val 	cost: 0s, loss: 0.785, acc: 0.749
Epoch 8
Train 	cost: 15s, loss: 0.853, acc: 0.702
Val 	cost: 0s, loss: 0.712, acc: 0.765
Epoch 9
Train 	cost: 15s, loss: 0.795, acc: 0.722
Val 	cost: 0s, loss: 0.547, acc: 0.818
Epoch 10
Train 	cost: 15s, loss: 0.772, acc: 0.733
Val 	cost: 0s, loss: 0.434, acc: 0.856
Epoch 11
Train 	cost: 15s, loss: 0.734, acc: 0.744
Val 	cost: 0s, loss: 0.681,

In [0]:
# Evaluate on the test dataset built from `test_files`, passing `test_steps`
# (ceil(num_test_examples / eval_batch_size)) as the step count.
trainer.evaluate(ds_test, test_steps)