<a href="https://colab.research.google.com/github/tranlethaison/60_Days_RL_Challenge/blob/master/super_convergence/cifar_10_super_convergence.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Summary
Train a model to at least 0.94 validation accuracy on the CIFAR-10 dataset within 18 epochs.

# References
- https://www.fast.ai/2018/07/02/adam-weight-decay/
- https://keras.io/guides/transfer_learning/
- https://keras.io/examples/vision/image_classification_efficientnet_fine_tuning/
- https://blog.tensorflow.org/2020/05/bigtransfer-bit-state-of-art-transfer-learning-computer-vision.html

# To-do
- Try https://github.com/szagoruyko/wide-residual-networks

# Utils

In [None]:
%%writefile requirements.txt
plotly
snoop

# tensorflow
tensorboard_plugin_profile
tensorflow_addons
tensorflow_hub

Writing requirements.txt


In [None]:
!pip install -qUr requirements.txt

In [None]:
import os
import subprocess
import json
import pickle
import datetime
import enum

import numpy as np
# import numba
# from numba import njit, prange

import pandas as pd

import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow.keras as tk
import tensorflow.keras.backend as K
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.mixed_precision import experimental as mixed_precision
import tensorflow_hub as hub

import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
from PIL import Image

from tqdm import tqdm
from IPython.display import display, HTML

import snoop
snoop.install()

%load_ext tensorboard
%load_ext snoop


pp(tf.__version__)
pp(tfa.__version__)
pp(hub.__version__)

04:31:47.87 LOG:
04:31:47.88 .... tf.__version__ = '2.3.1'
04:31:47.88 LOG:
04:31:47.89 .... tfa.__version__ = '0.11.2'
04:31:47.89 LOG:
04:31:47.89 .... hub.__version__ = '0.10.0'


'0.10.0'

In [None]:
# Uncomment these lines if you run into a Conv2D error (enables GPU memory growth).
# gpus = tf.config.list_physical_devices("GPU")
# for gpu in gpus:
#     tf.config.experimental.set_memory_growth(gpu, True)
# gpus

# Config

In [None]:
# Notebook-wide configuration: behaviour switches, working directories,
# dataset geometry and the training budget used by every later cell.
ds_name = "cifar-10"

# Behaviors
do_make_dataset = True  #@param {type:"boolean"}
do_use_mixed_precision = False  #@param {type:"boolean"}
do_augmentation = True  #@param {type:"boolean"}
do_find_lr = False  #@param {type:"boolean"}
do_train = True  #@param {type:"boolean"}
do_load_model = False  #@param {type:"boolean"}

on_colab = True  #@param {type:"boolean"}
on_kaggle = False  #@param {type:"boolean"}
# The two hosting environments are mutually exclusive.
assert not (on_colab and on_kaggle)
# << Behaviors

# Directory
home_dir = os.path.expanduser("~")
# Timestamp used later to name this run's checkpoint and TensorBoard folders.
now_str = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

# Persistent working directory; None keeps the current directory.
work_dp = (
    r"/content/drive/My Drive/AITrainingRecipe/super_convergence" if on_colab
    else r"/kaggle/working/AITrainingRecipe/super_convergence" if on_kaggle
    else None
)

if work_dp is not None:
    os.makedirs(work_dp, exist_ok=True)
    os.chdir(work_dp)
# !pwd && ls -lh && du -h

# The dataset and the TF-Hub cache live under the home directory,
# independent of the working directory chosen above.
dataset_dp = os.path.join(home_dir, "datasets", ds_name)
os.makedirs(dataset_dp, exist_ok=True)

tfhub_cache_dir = os.path.join(home_dir, "tfhub_modules")
os.environ["TFHUB_CACHE_DIR"] = tfhub_cache_dir
os.makedirs(tfhub_cache_dir, exist_ok=True)

# Relative path: resolved against the working directory selected above.
lr_find_result_dir = "lr_find_result"
os.makedirs(lr_find_result_dir, exist_ok=True)
# << Directory

# Dataset info
# input_shape = [32, 32, 3]
# input_shape = [71, 71, 3]
# input_shape = [75, 75, 3]
# input_shape = [96, 96, 3]
input_shape = [128, 128, 3]
# input_shape = [224, 224, 3]
n_classes = 10
# << Dataset info

# Training
# GPU Tensor Cores (XLA) requires batch_size to be a multiple of 8
batch_size = 128  # 128 256 512 1024
n_epochs = 18
# << Training

# Training Optimization

In [None]:
# Input prefetch.
# Show the CPU topology of the host (IPython shell escape, informational only).
!lscpu -e 

# Number of CPU threads (nproc --all)
workers = int(subprocess.check_output("nproc --all", shell=True))

# Keyword arguments splatted into every model.fit / model.evaluate call below:
# queue depth and number of worker threads used to feed batches.
prefetch_cfg = dict(
    max_queue_size=10,
    workers=workers,  
)
pp(prefetch_cfg)
# << Input prefetch.

# Reduce 'Kernel Launch' time.
# NOTE(review): presumably gives GPU kernel launches their own threads, per the
# TensorFlow GPU performance guide - confirm. It is set through the environment,
# so it has to be in place before TensorFlow initialises the GPU.
os.environ["TF_GPU_THREAD_MODE"] = "gpu_private"

# Mixed precision (use of both 16-bit and 32-bit).
# If using "Mixed precision", remember to cast output of last layer to "float32" for numeric stability.
# For GPU: "mixed_float16"; for TPU: "mixed_bfloat16"
if do_use_mixed_precision:
    # Install the policy globally so layers created afterwards pick it up.
    policy = mixed_precision.Policy("mixed_float16")
    mixed_precision.set_policy(policy)

    pp(policy.compute_dtype)
    pp(policy.variable_dtype)
    pp(policy.loss_scale)
# List the attached GPU(s) (IPython shell escape, informational only).
!nvidia-smi -L
# << Mixed precision.


def reset_env():
    """Clear the global Keras session state and enable XLA JIT compilation.

    Call this before creating a new model.
    """
    K.clear_session()
    tf.config.optimizer.set_jit(True)  # Enable XLA

CPU NODE SOCKET CORE L1d:L1i:L2:L3 ONLINE
0   0    0      0    0:0:0:0       yes
1   0    0      0    0:0:0:0       yes


04:32:49.67 LOG:
04:32:49.68 .... prefetch_cfg = {'max_queue_size': 10, 'workers': 2}


GPU 0: Tesla P100-PCIE-16GB (UUID: GPU-72982998-4d8c-0b67-3619-ed640a491dfa)


# Dataset

In [None]:
%%time
# https://www.cs.toronto.edu/~kriz/cifar.html

def unpickle(file):
    """Load and return the object pickled in *file*.

    The file is opened in binary mode and decoded with ``encoding='bytes'``,
    so byte strings from a Python 2 pickle (the CIFAR batch files) come back
    as ``bytes``, which is why callers index the result with ``b"..."`` keys.

    NOTE: ``pickle.load`` can execute arbitrary code; only use this on
    trusted files.
    """
    with open(file, mode="rb") as stream:
        return pickle.load(stream, encoding="bytes")


def data_to_img(data):
    """Convert one flat, channel-major image record into an HWC array.

    The record is read as three consecutive channel planes, each stored row
    by row in 32 rows (for a CIFAR-10 record: 3 x 32 x 32 = 3072 values).

    Args:
        data: 1-D array-like whose length is a multiple of 3 * 32.

    Returns:
        A new C-contiguous array of shape ``(32, width, 3)`` with the same
        dtype as ``data``, where ``out[r, c, ch] == data[ch * 32 * width +
        r * width + c]``. For CIFAR-10 records ``width`` is 32.

    Raises:
        ValueError: if the record cannot be split into 3 planes of 32 rows.
    """
    # A single reshape + transpose produces the same result as splitting the
    # record into channels and rows and stacking them, without creating
    # roughly a hundred intermediate arrays per image.
    planes = np.asarray(data).reshape(3, 32, -1)
    # Move the channel axis last (CHW -> HWC) and return an independent copy.
    return np.ascontiguousarray(planes.transpose(1, 2, 0))


def save_images(images, dir, labels, filenames, target_size=None, do_override=False):
    """Write a batch of images to ``dir/<label>/<filename>``.

    Args:
        images: array indexed as (image, height, width, ...); axes 1 and 2
            are taken as the image size.
        dir: root output directory (the name shadows the builtin but is kept
            for caller compatibility). The per-label sub-directories must
            already exist.
        labels: per-image sub-directory name.
        filenames: per-image file name; presumably the extension selects the
            format PIL writes - confirm against the caller's data.
        target_size: optional ``(height, width)``. Images are resized with
            padding only when this differs from their current size.
        do_override: when False, files that already exist are left untouched.
    """
    # Resize image
    # Compare like with like: ``shape[1:3]`` is a tuple, so comparing it with
    # ``list(target_size)`` was always unequal and forced a needless resize.
    image_size = tuple(images.shape[1:3])
    if target_size is not None and tuple(target_size) != image_size:
        target_height, target_width = target_size
        images = tf.image.resize_with_pad(images, target_height, target_width).numpy()
        print("Resized images from {} to {}.".format(image_size, target_size))

    for image, label, filename in tqdm(zip(images, labels, filenames)):
        image_fp = os.path.join(dir, label, filename)

        # Skip existing files unless the caller explicitly asks to rewrite them.
        if not os.path.isfile(image_fp) or do_override:
            pil_im = Image.fromarray(image.astype(np.uint8))
            pil_im.save(image_fp)


def make_dataset(pickle_path, dir):
    """Export every image of one pickled CIFAR batch file below *dir*.

    Reads the batch with ``unpickle``, converts each record with
    ``data_to_img`` and hands the images to ``save_images`` together with
    their class names and file names. Class names are looked up in the
    module-level ``meta`` mapping, which must be defined before calling.

    Args:
        pickle_path: path of the pickled batch file.
        dir: root output directory passed through to ``save_images``.
    """
    batch = unpickle(pickle_path)

    x = np.array(list(map(data_to_img, batch[b"data"])))
    y = np.array(batch[b"labels"])
    pp(x.shape, y.shape)

    # Index the class-name table with the integer labels to get one name per image.
    class_names = np.array(meta[b"label_names"], dtype=str)

    save_images(
        x,
        dir,
        class_names[y],
        np.array(batch[b"filenames"], dtype=str),
        target_size=None,
        do_override=False,
    )


train_dp = os.path.join(dataset_dp, "cifar-10_32x32/train/")
test_dp = os.path.join(dataset_dp, "cifar-10_32x32/test/")

if do_make_dataset:
    # Download and extract raw data
    raw_data_fp = "cifar-10-python.tar.gz"
    !cd $dataset_dp && wget -nc https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz -O $raw_data_fp
    !cd $dataset_dp && tar --skip-old-files -xf $raw_data_fp

    # Find top-level directory(-ies) of an archive.
    raw_data_dp = subprocess.check_output(
        "cd {} && tar -tf {} | sed -e 's@/.*@@' | uniq".format(dataset_dp, raw_data_fp), 
        shell=True,
    )
    raw_data_dp = raw_data_dp.decode().replace("\n", "")
    # Raw data dir path
    raw_data_dp = os.path.join(dataset_dp, raw_data_dp)

    # Make dataset
    meta = unpickle(os.path.join(raw_data_dp, "batches.meta"))

    train_pickle_paths = [
        os.path.join(raw_data_dp, "data_batch_1"),
        os.path.join(raw_data_dp, "data_batch_2"),
        os.path.join(raw_data_dp, "data_batch_3"),
        os.path.join(raw_data_dp, "data_batch_4"),
        os.path.join(raw_data_dp, "data_batch_5"),
    ]
    test_pickle_paths = [os.path.join(raw_data_dp, "test_batch")]

    os.makedirs(train_dp, exist_ok=True)
    os.makedirs(test_dp, exist_ok=True)

    for cls in meta[b"label_names"]:
        cls = cls.decode("utf-8")
        os.makedirs(os.path.join(train_dp, cls), exist_ok=True)
        os.makedirs(os.path.join(test_dp, cls), exist_ok=True)
    
    for train_pickle_path in train_pickle_paths:
        pp(train_pickle_path, train_dp)
        make_dataset(train_pickle_path, train_dp)
    
    for test_pickle_path in test_pickle_paths:
        pp(test_pickle_path, test_dp)
        make_dataset(test_pickle_path, test_dp)

!cd $dataset_dp && pwd && ls -lh && du -h

# CPU times: user 6min 41s, sys: 26.9 s, total: 7min 8s
# Wall time: 6min 37s

--2020-12-08 04:32:58--  https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
Resolving www.cs.toronto.edu (www.cs.toronto.edu)... 128.100.3.30
Connecting to www.cs.toronto.edu (www.cs.toronto.edu)|128.100.3.30|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 170498071 (163M) [application/x-gzip]
Saving to: ‘cifar-10-python.tar.gz’


2020-12-08 04:33:05 (27.7 MB/s) - ‘cifar-10-python.tar.gz’ saved [170498071/170498071]



04:33:10.75 LOG:
04:33:10.76 .... <argument 1> = '/root/datasets/cifar-10/cifar-10-batches-py/data_batch_1'
04:33:10.76 .... <argument 2> = '/root/datasets/cifar-10/cifar-10_32x32/train/'
04:33:15.47 LOG:
04:33:15.47 .... <argument 1> = (10000, 32, 32, 3)
04:33:15.47 .... <argument 2> = (10000,)
10000it [00:04, 2142.30it/s]
04:33:20.15 LOG:
04:33:20.15 .... <argument 1> = '/root/datasets/cifar-10/cifar-10-batches-py/data_batch_2'
04:33:20.15 .... <argument 2> = '/root/datasets/cifar-10/cifar-10_32x32/train/'
04:33:25.06 LOG:
04:33:25.06 .... <argument 1> = (10000, 32, 32, 3)
04:33:25.06 .... <argument 2> = (10000,)
10000it [00:04, 2491.38it/s]
04:33:29.09 LOG:
04:33:29.09 .... <argument 1> = '/root/datasets/cifar-10/cifar-10-batches-py/data_batch_3'
04:33:29.09 .... <argument 2> = '/root/datasets/cifar-10/cifar-10_32x32/train/'
04:33:34.37 LOG:
04:33:34.37 .... <argument 1> = (10000, 32, 32, 3)
04:33:34.37 .... <argument 2> = (10000,)
10000it [00:03, 2556.18it/s]
04:33:38.30 LOG:
04:33

/root/datasets/cifar-10
total 163M
drwxr-xr-x 2 2156 1103 4.0K Jun  4  2009 cifar-10-batches-py
-rw-r--r-- 1 root root 163M Jun  4  2009 cifar-10-python.tar.gz
drwxr-xr-x 4 root root 4.0K Dec  8 04:33 cifar-10_32x32
4.0M	./cifar-10_32x32/test/horse
4.0M	./cifar-10_32x32/test/ship
4.0M	./cifar-10_32x32/test/frog
4.0M	./cifar-10_32x32/test/automobile
4.0M	./cifar-10_32x32/test/dog
4.0M	./cifar-10_32x32/test/bird
4.0M	./cifar-10_32x32/test/deer
4.0M	./cifar-10_32x32/test/truck
4.0M	./cifar-10_32x32/test/airplane
4.0M	./cifar-10_32x32/test/cat
40M	./cifar-10_32x32/test
20M	./cifar-10_32x32/train/horse
20M	./cifar-10_32x32/train/ship
20M	./cifar-10_32x32/train/frog
20M	./cifar-10_32x32/train/automobile
20M	./cifar-10_32x32/train/dog
20M	./cifar-10_32x32/train/bird
20M	./cifar-10_32x32/train/deer
20M	./cifar-10_32x32/train/truck
20M	./cifar-10_32x32/train/airplane
20M	./cifar-10_32x32/train/cat
198M	./cifar-10_32x32/train
238M	./cifar-10_32x32
178M	./cifar-10-batches-py
578M	.
CPU times: use

# Model

In [None]:
%%time
# Note: GPU Tensor Cores (XLA) require units and filters to be multiples of 8.

# tensorflow_hub
def cache_tfhub_model(tfhub_cache_dir, hub_model_link, net_name):
    """Manually download and unpack a TF-Hub model into a local cache folder.

    Use this in case automatic caching fails, especially when on Colab.

    Args:
        tfhub_cache_dir: directory under which the model folder is created.
        hub_model_link: model URL; ``?tf-hub-format=compressed`` is appended
            to request the tarball.
        net_name: name of the sub-directory that receives the model.

    Returns:
        Path of the model directory (``tfhub_cache_dir/net_name``).

    Raises:
        subprocess.CalledProcessError: if the download/extract pipeline
            exits with a non-zero status.
    """
    import shlex  # local import: only this helper needs it

    tfhub_module_dir = os.path.join(tfhub_cache_dir, net_name)

    # The presence of saved_model.pb marks a previously completed download.
    if not os.path.isfile(os.path.join(tfhub_module_dir, "saved_model.pb")):
        os.makedirs(tfhub_module_dir, exist_ok=True)
        # Quote both operands before handing them to the shell: the URL
        # contains "?" (a glob character) and the path may contain spaces.
        url = shlex.quote(f"{hub_model_link}?tf-hub-format=compressed")
        target_dir = shlex.quote(tfhub_module_dir)
        subprocess.run(
            f"curl -L {url} | tar -zxvC {target_dir}",
            shell=True,
            check=True,
        )

    return tfhub_module_dir


# hub_model_link = "https://tfhub.dev/google/imagenet/mobilenet_v2_100_96/feature_vector/4"
# hub_model_link = "https://tfhub.dev/google/imagenet/mobilenet_v2_075_96/feature_vector/4"

# Active backbone: the TF-Hub module URL and the short name used for the
# Keras layer, the model name, and the output folders/files derived from it.
hub_model_link = "https://tfhub.dev/google/bit/m-r50x1/1"
# preprocessing_function = lambda x : tf.image.convert_image_dtype(x, tf.float32)
net_name = "bit_m-r50x1"

# ``handle`` is what get_model() passes to hub.KerasLayer: the remote URL here,
# or a local directory when the manual cache line below is enabled instead.
handle = hub_model_link
# handle = cache_tfhub_model(tfhub_cache_dir, hub_model_link, net_name)
# << tensorflow_hub


# # tf.keras.applications
# # base_model = tk.applications.MobileNetV2
# # preprocessing_function = tk.applications.mobilenet_v2.preprocess_input

# # base_model = tk.applications.EfficientNetB0
# # base_model = tk.applications.EfficientNetB4
# # preprocessing_function = tk.applications.efficientnet.preprocess_input

# # base_model = tk.applications.InceptionV3
# # preprocessing_function = tk.applications.inception_v3.preprocess_input

# base_model = tf.keras.applications.Xception
# preprocessing_function = tk.applications.xception.preprocess_input

# # base_model = tf.keras.applications.InceptionResNetV2
# # preprocessing_function = tk.applications.inception_resnet_v2.preprocess_input

# # base_model = tf.keras.applications.ResNet50V2
# # preprocessing_function = tk.applications.resnet_v2.preprocess_input

# # base_model = tf.keras.applications.NASNetLarge
# # base_model = tf.keras.applications.NASNetMobile
# # preprocessing_function = tk.applications.nasnet.preprocess_input

# net_name = base_model.__name__
# # << tf.keras.applications


def get_model(input_shape, n_classes):
    """Build the classifier: a trainable TF-Hub backbone plus a softmax head.

    The backbone is loaded from the module-level ``handle`` and named
    ``net_name``; the model is named ``"{ds_name}_{net_name}"``.

    Args:
        input_shape: shape of one input image, passed to ``Input``.
        n_classes: number of units of the zero-initialised logits layer.

    Returns:
        The uncompiled Keras model.
    """
    # tf.keras.applications
    # base = base_model(
    #     include_top=False,
    #     weights="imagenet",
    #     pooling="avg",
    #     name=net_name,
    # )
    # tensorflow_hub
    backbone = hub.KerasLayer(handle, name=net_name, trainable=True)

    # base.trainable = True
    # for layer in base.layers:
    #     if isinstance(layer, tk.layers.BatchNormalization):
    #         layer.trainable = False

    inputs = tk.layers.Input(input_shape)
    features = backbone(inputs)
    logits = tk.layers.Dense(
        n_classes,
        kernel_initializer="zeros",
        name="logits",
    )(features)
    # The softmax is pinned to float32 for numeric stability (matters under
    # a mixed-precision policy).
    outputs = tk.layers.Activation("softmax", name="pred", dtype=tf.float32)(logits)

    model = tk.models.Model(inputs, outputs, name=f"{ds_name}_{net_name}")
    print("Layers' computations dtype ", logits.dtype)
    print("Outputs' dtype ", outputs.dtype)
    return model


# Start from a clean Keras session (and XLA enabled) before building the model.
reset_env()
model = get_model(input_shape, n_classes)

# Save an architecture diagram next to the notebook and print the layer summary.
model_plot_fp = f"{ds_name}_{net_name}.png"
tk.utils.plot_model(model, show_shapes=True, to_file=model_plot_fp)
model.summary()

# Data Generator

In [None]:
# Settings shared by the train and test generators.
common_generator_cfg = {
    "rescale": 1 / 255.,
    # preprocessing_function=preprocessing_function,
    "dtype": tf.float32,  # On CPU, float32 operations are faster.
}

# Random augmentation, applied to the training generator only.
if do_augmentation:
    data_aug_cfg = {
        "rotation_range": 20,
        "width_shift_range": 0.2,
        "height_shift_range": 0.2,
        # shear_range=0.2,
        "zoom_range": 0.2,
        "fill_mode": "nearest",
        "horizontal_flip": True,
    }
else:
    data_aug_cfg = {}

# How images are read from the class sub-directories: resized to the model
# input size, RGB, one-hot labels, ``batch_size`` images per batch.
flow_cfg = {
    "target_size": input_shape[:2],
    "color_mode": "rgb",
    "class_mode": "categorical",
    "batch_size": batch_size,
}

train_gen = ImageDataGenerator(**common_generator_cfg, **data_aug_cfg)
train_data = train_gen.flow_from_directory(train_dp, shuffle=True, **flow_cfg)

# Test data: no augmentation and a fixed order.
test_gen = ImageDataGenerator(**common_generator_cfg)
test_data = test_gen.flow_from_directory(test_dp, shuffle=False, **flow_cfg)

# HP Scheduler

In [None]:
class DecayType(enum.IntEnum):
    """Data class, each decay type is assigned a number."""
    LINEAR = 0
    COSINE = 1
    EXPONENTIAL = 2
    POLYNOMIAL = 3


class DecayScheduler():
    """Callable schedule that moves a value from ``start_val`` to ``end_val``.

    Calling the instance with a step (a scalar or a NumPy array of steps)
    returns the scheduled value(s) for the configured decay type.

    Args:
        start_val: value at step 0.
        end_val: value at step ``decay_steps``.
        decay_steps: number of steps over which the transition happens.
        decay_type: a ``DecayType`` member (or its integer value).
        extra: exponent used by ``DecayType.POLYNOMIAL`` only.
    """

    def __init__(self, start_val, end_val, decay_steps, decay_type, extra=1.0):
        self.start_val = start_val
        self.end_val = end_val
        self.decay_steps = decay_steps
        self.decay_type = decay_type
        self.extra = extra

    def __call__(self, step):
        """Return the scheduled value at ``step``.

        Raises:
            ValueError: if ``decay_type`` is not a known ``DecayType``
                (previously this silently returned ``None``).
        """
        if self.decay_type == DecayType.LINEAR:
            pct = step / self.decay_steps
            return self.start_val + pct * (self.end_val - self.start_val)
        elif self.decay_type == DecayType.COSINE:
            # cos_out goes from 2 (step 0) down to 0 (step == decay_steps).
            cos_out = np.cos(np.pi * step / self.decay_steps) + 1
            return self.end_val + (self.start_val - self.end_val) / 2 * cos_out
        elif self.decay_type == DecayType.EXPONENTIAL:
            # Geometric interpolation; requires a non-zero start_val.
            ratio = self.end_val / self.start_val
            return self.start_val * ratio ** (step / self.decay_steps)
        elif self.decay_type == DecayType.POLYNOMIAL:
            return self.end_val + (self.start_val - self.end_val) * (1 - step / self.decay_steps) ** self.extra

        raise ValueError(f"Unsupported decay_type: {self.decay_type!r}")

In [None]:
def get_decay_steps(cycle_len, anneal_pct):
    """Split a cycle into two equal ramp phases and a final annealing phase.

    Args:
        cycle_len: total number of steps in the cycle.
        anneal_pct: fraction of the cycle reserved for annealing.

    Returns:
        Tuple ``(ramp_len, ramp_len, anneal_len)`` that sums to ``cycle_len``;
        each ramp is truncated to an integer and the annealing phase takes
        whatever remains.
    """
    ramp_len = int(cycle_len * (1 - anneal_pct) / 2)
    return ramp_len, ramp_len, cycle_len - 2 * ramp_len

def onecyle_learning_rate(
    init_lr,
    end_lr,
    train_steps,
    decay_type=DecayType.LINEAR,
    anneal_pct=0.075,
):
    """Build the per-step learning rates of a OneCycle schedule.

    The schedule has three consecutive phases whose lengths come from
    ``get_decay_steps``: ``init_lr`` -> ``end_lr``, ``end_lr`` -> ``init_lr``,
    then an annealing phase from ``init_lr`` towards ``init_lr / 100``.

    Args:
        init_lr: learning rate at the start of the cycle.
        end_lr: learning rate reached at the end of the first phase.
        train_steps: total number of training steps.
        decay_type: ``DecayType`` used for all three phases.
        anneal_pct (float):
            Percentage to leave for the annealing at the end.

    Returns:
        1-D array with one learning rate per training step.
    """
    up_len, down_len, anneal_len = get_decay_steps(train_steps, anneal_pct)

    # (start value, end value, number of steps) for each phase, in order.
    phases = (
        (init_lr, end_lr, up_len),
        (end_lr, init_lr, down_len),
        (init_lr, init_lr / 100., anneal_len),
    )

    return np.concatenate(
        [
            DecayScheduler(start, stop, n_steps, decay_type)(np.arange(n_steps))
            for start, stop, n_steps in phases
        ],
        0,
    )


def onecyle_momentum(
    init_mom,
    end_mom,
    train_steps,
    decay_type=DecayType.LINEAR,
    anneal_pct=0.075,
):
    """Build the per-step momentum values of a OneCycle schedule.

    Momentum moves ``init_mom`` -> ``end_mom`` in the first phase and back to
    ``init_mom`` in the second; phase lengths come from ``get_decay_steps``.

    Args:
        init_mom: momentum at the start of the cycle and during annealing.
        end_mom: momentum reached at the end of the first phase.
        train_steps: total number of training steps.
        decay_type: ``DecayType`` used for the two ramp phases.
        anneal_pct (float):
            Percentage to leave for the annealing at the end.
            The annealing phase holds momentum constant at ``init_mom``.

    Returns:
        1-D array with one momentum value per training step.
    """
    up_len, down_len, anneal_len = get_decay_steps(train_steps, anneal_pct)

    ramps = (
        (init_mom, end_mom, up_len),
        (end_mom, init_mom, down_len),
    )
    moms = [
        DecayScheduler(start, stop, n_steps, decay_type)(np.arange(n_steps))
        for start, stop, n_steps in ramps
    ]
    # Annealing phase: momentum stays at its initial value.
    moms.append(np.array([init_mom] * anneal_len))

    return np.concatenate(moms, 0)

In [None]:
class OneCycleScheduler(tk.callbacks.Callback):
    """Keras callback applying a OneCycle policy at the start of every batch.

    The learning-rate table (and, when both ``init_mom`` and ``end_mom`` are
    given, the momentum table) is precomputed for ``train_steps`` steps. Each
    training batch then writes the next entry into the optimizer's
    ``learning_rate`` (and ``beta_1``).
    """

    def __init__(
        self,
        init_lr,
        end_lr,
        train_steps,
        decay_type=DecayType.LINEAR,
        anneal_pct=0.075,
        init_mom=None,
        end_mom=None,
    ):
        super().__init__()
        self.train_steps = train_steps

        self.learning_rates = onecyle_learning_rate(
            init_lr=init_lr,
            end_lr=end_lr,
            train_steps=train_steps,
            decay_type=decay_type,
            anneal_pct=anneal_pct,
        )

        # Momentum is scheduled only when both of its end points are provided.
        self.moms = None
        if init_mom is not None and end_mom is not None:
            self.moms = onecyle_momentum(
                init_mom=init_mom,
                end_mom=end_mom,
                train_steps=train_steps,
                decay_type=decay_type,
                anneal_pct=anneal_pct,
            )

    def on_train_begin(self, logs=None):
        """Restart the global step counter at the beginning of ``fit``."""
        self.train_step = 0

    def on_train_batch_begin(self, batch, logs=None):
        """Apply the values scheduled for the current step, then advance."""
        step = self.train_step

        # Past the end of the precomputed tables the optimizer is left untouched.
        if step < self.train_steps:
            optimizer = self.model.optimizer
            K.set_value(optimizer.learning_rate, self.learning_rates[step])

            if self.moms is not None:
                # If using "tf.keras.mixed_precision.experimental", then set HP this way (except for learning_rate)
                # K.set_value(self.model.optimizer._optimizer.beta_1, self.moms[self.train_step])
                K.set_value(optimizer.beta_1, self.moms[step])

        self.train_step = step + 1

    def plot(self):
        """Show the schedule(s) as stacked line plots, one row per quantity."""
        steps = np.arange(self.train_steps)

        series = [("learning_rate", self.learning_rates)]
        if self.moms is not None:
            series.append(("momentum", self.moms))

        fig = make_subplots(len(series), 1)
        for row, (name, values) in enumerate(series, start=1):
            fig.add_trace(go.Scatter(x=steps, y=values, name=name), row=row, col=1)
        fig.show()

# LR Finder

In [None]:
class LRFinder(tk.callbacks.Callback):
    """Learning Rate Finder Callback.

    Sweeps the learning rate from ``init_lr`` to ``end_lr`` over the whole
    ``fit`` call (one value per batch), records every logged metric plus an
    exponentially smoothed, bias-corrected loss, and stops training once the
    smoothed loss exceeds four times the best value seen.

    Args:
        init_lr: learning rate used for the first batch.
        end_lr: learning rate the sweep heads towards.
        decay_type: ``DecayType`` shaping the sweep.
        beta: smoothing factor of the exponential moving average of the loss.
    """

    def __init__(
        self,
        init_lr,
        end_lr,
        decay_type=DecayType.EXPONENTIAL,
        beta=0.98
    ):
        super().__init__()
        self.init_lr = init_lr
        self.end_lr = end_lr
        self.decay_type = decay_type
        self.beta = beta

    def on_train_begin(self, logs=None):
        """Precompute the sweep and reset all running statistics."""
        # pp(self.params)
        # Total number of batches of this fit() call.
        self.train_steps = self.params["epochs"] * self.params["steps"]

        self.scheduler = DecayScheduler(
            start_val=self.init_lr,
            end_val=self.end_lr,
            decay_steps=self.train_steps,
            decay_type=self.decay_type,
        )
        self.learning_rates = self.scheduler(np.arange(self.train_steps))

        self.history = {}
        self.train_step = 0
        self.avg_loss = 0.
        self.best_loss = 0.
        self.best_learning_rate = 0.

    def on_train_batch_begin(self, batch, logs=None):
        """Set the learning rate scheduled for the current step."""
        K.set_value(self.model.optimizer.learning_rate, self.learning_rates[self.train_step])

    def on_train_batch_end(self, batch, logs=None):
        """Update the smoothed loss, best values, history and stop flag."""
        # Compute the smoothed loss
        # NOTE(review): in TF 2.x, logs["loss"] at batch end is presumably the
        # running epoch average rather than the raw batch loss - confirm.
        self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * logs["loss"]
        # Bias correction for the zero-initialised moving average.
        smoothed_loss = self.avg_loss / (1 - self.beta ** (self.train_step + 1))

        # Stop if the loss is exploding
        # ``train_step`` is 0-based: the very first batch only seeds
        # ``best_loss`` below, so the check starts with the second batch.
        if self.train_step > 0 and smoothed_loss > 4 * self.best_loss:
            self.model.stop_training = True
            print("Stop training because loss is exploding.")

        # Record the best loss, learning_rate
        # Seed from the first batch (step 0) so ``best_loss`` never keeps its
        # 0.0 placeholder, then keep the minimum.
        if self.train_step == 0 or smoothed_loss < self.best_loss:
            self.best_loss = smoothed_loss
            self.best_learning_rate = self.learning_rates[self.train_step]

        if not self.model.stop_training:
            # History
            for key, val in logs.items():
                self.history.setdefault(key, []).append(val)
            self.history.setdefault("smoothed_loss", []).append(smoothed_loss)

            self.train_step += 1

    def plot_learning_rate(self):
        """Plot the learning rates actually used against the step index."""
        fig = px.line(
            x=np.arange(self.train_step),
            y=self.learning_rates[:self.train_step],
            labels=dict(x="train_step", y="learning_rate"),
        )
        fig.show()

    def _plot_metric(self, metric, metric_name):
        """Plot one recorded metric against the learning rate (log x-axis)."""
        fig = px.line(
            x=self.learning_rates[:self.train_step],
            y=metric,
            log_x=True,
            labels=dict(x="learning_rate(log)", y=metric_name),
        )
        fig.show()

    def plot_accuracy(self):
        """Plot the recorded ``accuracy`` metric."""
        self._plot_metric(self.history["accuracy"], "accuracy")

    def plot_loss(self):
        """Plot the recorded ``loss``."""
        self._plot_metric(self.history["loss"], "loss")

    def plot_smoothed_loss(self):
        """Plot the smoothed, bias-corrected loss."""
        self._plot_metric(self.history["smoothed_loss"], "smoothed_loss")


def plot_metric(learning_rates, metrics, metric_name, optimizer_cfgs):
    """Overlay one metric from several LR-finder runs on a log-scaled LR axis.

    Args:
        learning_rates: x values shared by all runs.
        metrics: one sequence of metric values per run; runs shorter than
            ``learning_rates`` are padded with NaN.
        metric_name: y-axis title.
        optimizer_cfgs: one optimizer configuration per run; its string form
            is used as the trace name.

    Returns:
        The plotly figure (it is also shown).
    """
    fig = go.Figure()
    n_points = len(learning_rates)

    for series, cfg in zip(metrics, optimizer_cfgs):
        missing = n_points - len(series)
        if missing > 0:
            series = np.append(series, [np.nan] * missing)
        fig.add_trace(go.Scatter(x=learning_rates, y=series, name=f"{cfg}"))

    fig.update_xaxes(type="log")
    fig.update_layout(xaxis_title="learning_rate(log)", yaxis_title=metric_name)
    fig.show()
    return fig

In [None]:
# Loss and optimizer shared by the LR finder and the training run.
loss = "categorical_crossentropy"

optimizer_class = tfa.optimizers.AdamW
# One LR-finder run is performed per weight-decay value.
weight_decays = [1e-2, 1e-3, 1e-4]
optimizer_cfgs = [{"weight_decay": wd} for wd in weight_decays]

# optimizer_class = tk.optimizers.RMSprop
# optimizer_class = tk.optimizers.Adam
# optimizer_cfg = dict()


In [None]:
%%time
# Learning-rate range test: one single-epoch sweep per optimizer configuration.
if do_find_lr:
    init_lr = 1e-4
    end_lr = 1

    # Results collected across runs, one entry per optimizer configuration.
    learning_rates = None
    accuracies = []
    losses = []
    smoothed_losses = []
    best_losses = []
    best_learning_rates = []

    for optimizer_cfg in optimizer_cfgs:
        pp(optimizer_cfg)

        # Every run starts from a fresh session and a freshly built model.
        reset_env()
        model = get_model(input_shape, n_classes)

        cb_lrfinder = LRFinder(init_lr, end_lr, decay_type=DecayType.EXPONENTIAL, beta=0.98)

        optimizer = optimizer_class(**optimizer_cfg)
        model.compile(loss=loss, optimizer=optimizer, metrics=["accuracy"])
        # ``train_data`` already yields batches of ``batch_size`` images, so
        # ``batch_size`` is not passed to fit() (same as the training cell).
        model.fit(
            train_data,
            epochs=1,
            callbacks=[cb_lrfinder],
            verbose=1,
            **prefetch_cfg
        )

        # cb_lrfinder.plot_learning_rate()
        # cb_lrfinder.plot_accuracy()
        # cb_lrfinder.plot_loss()
        # cb_lrfinder.plot_smoothed_loss()
        # pp(cb_lrfinder.best_loss, cb_lrfinder.best_learning_rate)

        # The sweep is identical for every run; keep the first copy only.
        if learning_rates is None:
            learning_rates = cb_lrfinder.learning_rates

        accuracies.append(cb_lrfinder.history["accuracy"])
        losses.append(cb_lrfinder.history["loss"])
        smoothed_losses.append(cb_lrfinder.history["smoothed_loss"])
        best_losses.append(cb_lrfinder.best_loss)
        best_learning_rates.append(cb_lrfinder.best_learning_rate)

In [None]:
# Persist the LR-finder results of this run, or reload those of an earlier
# run, then plot every recorded metric against the learning rate.
lr_find_result_f = os.path.join(lr_find_result_dir, f"{net_name}.pickle")

if do_find_lr:
    lr_find_result = dict(
        accuracies=accuracies,
        losses=losses,
        smoothed_losses=smoothed_losses,
        learning_rates=learning_rates,
        best_losses=best_losses,
        best_learning_rates=best_learning_rates,
    )
    with open(lr_find_result_f, "wb") as pickle_fo:
        pickle.dump(lr_find_result, pickle_fo)
elif os.path.isfile(lr_find_result_f):
    # The file was written by this notebook; do not point it at untrusted data.
    with open(lr_find_result_f, "rb") as pickle_fo:
        lr_find_result = pickle.load(pickle_fo)
else:
    # Nothing computed and nothing saved: there is nothing to plot.
    lr_find_result = {}

for key, val in lr_find_result.items():
    if key not in ("accuracies", "losses", "smoothed_losses"):
        continue
    fig = plot_metric(
        lr_find_result["learning_rates"],
        metrics=val,
        metric_name=key,
        optimizer_cfgs=optimizer_cfgs,
    )

pp(optimizer_cfgs, lr_find_result.get("best_losses"), lr_find_result.get("best_learning_rates"))

# HP

In [None]:
# Total number of optimizer steps of the training run; the OneCycle tables
# below are precomputed for exactly this many steps.
n_batches_per_epoch = len(train_data)
n_train_steps = n_epochs * n_batches_per_epoch
pp(n_train_steps)

# Keyword arguments for ``optimizer_class`` in the training cell.
# Commented-out values record the settings of earlier experiments.
optimizer_cfg = dict(
#     weight_decay=0.0001,  # bit_m-r50x fine tune 512 batch_size, 30 epochs
    weight_decay=0.0001,  # bit_m-r50x fine tune 128 batch_size, 18 epochs
)

# LR
# ``end_lr`` is the peak of the cycle; training starts at a tenth of it.
# end_lr = 0.001  # bit_m-r50x fine tune 512 batch_size, 30 epochs
end_lr = 0.001  # bit_m-r50x fine tune 128 batch_size, 18 epochs
init_lr = end_lr / 10.

# Momentum
# Moves from ``init_mom`` to ``end_mom`` while the LR rises, then back.
init_mom = 0.95
end_mom = 0.85

cb_onecycle = OneCycleScheduler(
    init_lr, 
    end_lr,
    n_train_steps,
    decay_type=DecayType.LINEAR,
    anneal_pct=0.075,
    init_mom=init_mom,
    end_mom=end_mom,
)
# Visual sanity check of the schedules before training.
cb_onecycle.plot()

# Callbacks

In [None]:
# Metric that drives early stopping and checkpoint selection.
monitor = "val_accuracy"
mode = "max"
# goal = 0.97
min_delta = 1e-3

# Validation runs every ``validation_freq`` epochs (passed to fit()).
validation_freq = 3

cb_early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor=monitor,
    mode=mode,
    min_delta=min_delta,
    patience=9,
    restore_best_weights=True,
)

# Model checkpoint dir
model_root_dir = os.path.join("models/", net_name)
os.makedirs(model_root_dir, exist_ok=True)

# Checkpoints are named by timestamp, so the lexicographically largest entry
# is the most recent one.
checkpoint_fp = os.path.join(model_root_dir, now_str)
model_fps = os.listdir(model_root_dir)
lastest_model_fp = (
    os.path.join(model_root_dir, max(model_fps))
    if model_fps
    else None
)

# NOTE(review): ``save_format`` / ``include_optimizer`` are not documented
# ModelCheckpoint arguments and are presumably ignored - confirm.
cb_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_fp,
    monitor=monitor,
    mode=mode,
    save_weights_only=False,
    save_best_only=True,
    save_format="tf",
    include_optimizer=True,
)

# Tensorboard logs dir
tb_logs_root_dir = os.path.join("tb_logs/", net_name)
# Create the root up front (as for the checkpoint dir) so that listing it
# below cannot fail when no run has been logged yet.
os.makedirs(tb_logs_root_dir, exist_ok=True)
if do_train:
    tb_logs_dp = os.path.join(tb_logs_root_dir, now_str)
else:
    # Not training: point TensorBoard at the most recent existing run, if any.
    tb_logs_dps = os.listdir(tb_logs_root_dir)
    tb_logs_dp = (
        os.path.join(tb_logs_root_dir, max(tb_logs_dps))
        if tb_logs_dps
        else None
    )

cb_tensorboard = tf.keras.callbacks.TensorBoard(
    log_dir=tb_logs_dp,
    histogram_freq=1,
    update_freq='epoch',
    profile_batch="2,22",
)


# class Goal(tf.keras.callbacks.Callback):
#     def __init__(self, monitor, mode, goal):
#         super().__init__()
#         self.monitor = monitor
#         self.mode = mode
#         self.goal = goal

#     def on_epoch_end(self, epoch, logs={}):
#         if self.mode == "min":
#             goal_achieved = logs[self.monitor] <= self.goal
#         elif self.mode == "max":
#             goal_achieved = logs[self.monitor] >= self.goal

#         if goal_achieved:
#             print("Goal {}: {} achieved. Stop training.".format(self.monitor, self.goal))
#             self.model.stop_training = True


# cb_goal = Goal(monitor, mode, goal)

# cb_reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
#     monitor=monitor,
#     mode=mode,
#     min_delta=min_delta,
#     patience=5,
#     factor=0.2,
#     min_lr=1e-6,
# )

# Callbacks handed to model.fit() in the training cell.
callbacks = [
    cb_checkpoint,
    cb_early_stopping,
    cb_tensorboard,
    # cb_goal,
    # cb_reduce_lr,
    cb_onecycle,
]

# Train

In [None]:
%%time
# Train a new (or resumed) model, or just load the latest checkpoint when
# training is disabled.
reset_env()
tf.get_logger().setLevel('ERROR')

# For test run
# n_epochs = 3

if not do_train:
    model = tk.models.load_model(lastest_model_fp, compile=True)
else:
    # Optionally resume from the latest checkpoint; any failure falls back
    # to a freshly built model.
    model_loaded = False
    if do_load_model:
        try:
            model = tk.models.load_model(lastest_model_fp, compile=True)
            model_loaded = True
            print(f"Model loaded: {lastest_model_fp}")
        except Exception as ex:
            print(f"ERROR load_model: {ex}")

    if not model_loaded:
        model = get_model(input_shape, n_classes)
        model.compile(
            loss=loss,
            optimizer=optimizer_class(**optimizer_cfg),
            metrics=["accuracy"],
        )
        print("Created new model.")

    history = model.fit(
        train_data,
        epochs=n_epochs,
        validation_data=test_data,
        validation_freq=validation_freq,
        callbacks=callbacks,
        verbose=1,
        **prefetch_cfg,
    )


# Evaluation

In [None]:
# Evaluate the trained (or loaded) model on the test generator;
# ``prefetch_cfg`` supplies the queue size and worker count for feeding batches.
model.evaluate(test_data, verbose=1, **prefetch_cfg)

In [None]:
%tensorboard --logdir $tb_logs_dp