# Load libraries

In [0]:
!git clone https://github.com/koshian2/OctConv-TFKeras
!mv OctConv-TFKeras/*.py ./

fatal: destination path 'OctConv-TFKeras' already exists and is not an empty directory.
mv: cannot stat 'OctConv-TFKeras/*.py': No such file or directory


# Train OctConv Wide ResNet
* alpha = 0 -> standard Wide ResNet (no octave convolution)
* alpha > 0 -> OctConv Wide ResNet

Training takes about 2 hours.

In [0]:
from tensorflow.keras import layers
from oct_conv2d import OctConv2D
from tensorflow.keras.models import Model

def _create_normal_residual_block(inputs, ch, N):
    """Build one stage of plain 3x3 convolutions with residual shortcuts.

    A stem conv/BN/ELU first maps ``inputs`` to ``ch`` channels. It is followed
    by ``N - 1`` residual units, each made of two conv/BN/ELU triples whose
    output is added to the unit's input.

    Returns the output tensor of the last unit (or of the stem when N == 1).
    """
    def conv_bn_elu(tensor):
        # 3x3 "same" convolution -> batch norm -> ELU
        tensor = layers.Conv2D(ch, 3, padding="same")(tensor)
        tensor = layers.BatchNormalization()(tensor)
        return layers.Activation("elu")(tensor)

    # Stem: adjust the channel count to `ch`
    out = conv_bn_elu(inputs)
    # Residual units with identity shortcuts
    for _ in range(N - 1):
        shortcut = out
        out = conv_bn_elu(conv_bn_elu(out))
        out = layers.Add()([out, shortcut])
    return out

def _create_octconv_residual_block(inputs, ch, N, alpha):
    """Build one stage of octave convolutions with residual shortcuts.

    ``inputs`` is a ``[high, low]`` pair of tensors. A stem OctConv2D (with
    batch norm and ELU applied to each branch) maps the pair to ``ch`` total
    filters split by ``alpha``. It is followed by ``N - 1`` residual units of
    two such OctConv layers, with the unit's input added back per branch.

    Returns the resulting ``[high, low]`` pair.
    """
    def octconv_bn_elu(pair):
        # OctConv2D, then BN + ELU on the high branch, then on the low branch
        hi, lo = OctConv2D(filters=ch, alpha=alpha)(pair)
        hi = layers.BatchNormalization()(hi)
        hi = layers.Activation("elu")(hi)
        lo = layers.BatchNormalization()(lo)
        lo = layers.Activation("elu")(lo)
        return [hi, lo]

    # Stem: adjust the channel count
    pair = octconv_bn_elu(inputs)
    # Residual units with per-branch identity shortcuts
    for _ in range(N - 1):
        shortcut_hi, shortcut_lo = pair
        hi, lo = octconv_bn_elu(octconv_bn_elu(pair))
        hi = layers.Add()([hi, shortcut_hi])
        lo = layers.Add()([lo, shortcut_lo])
        pair = [hi, lo]
    return pair

def create_normal_wide_resnet(N=4, k=10):
    """
    Assemble a Wide ResNet from ordinary convolutions.

    The network takes 32x32x3 images, runs three residual stages of widths
    16*k, 32*k and 64*k (2x2 average pooling between stages), then global
    average pooling and a 10-way softmax layer. ``N`` is forwarded to each
    residual stage; ``k`` is the widening factor.
    """
    image = layers.Input((32,32,3))

    features = image
    for stage, base_width in enumerate((16, 32, 64)):
        # every stage after the first starts by halving the resolution
        if stage > 0:
            features = layers.AveragePooling2D(2)(features)
        features = _create_normal_residual_block(features, base_width*k, N)

    # classifier head
    pooled = layers.GlobalAveragePooling2D()(features)
    probabilities = layers.Dense(10, activation="softmax")(pooled)

    return Model(image, probabilities)

def create_octconv_wide_resnet(alpha, N=4, k=10):
    """
    Assemble a Wide ResNet built from octave convolutions.

    The 32x32x3 input feeds the high branch directly and, after 2x2 average
    pooling, the low branch. Three OctConv residual stages of widths 16*k,
    32*k and 64*k follow, with both branches average-pooled between stages.
    The high branch is then pooled once more to the low branch's resolution,
    the two are concatenated, and a 1x1 conv / BN / ELU precedes global
    average pooling and a 10-way softmax layer.
    """
    image = layers.Input((32,32,3))
    # the low branch starts from a downsampled copy of the image
    low_branch = layers.AveragePooling2D(2)(image)
    high_branch = image

    for stage, base_width in enumerate((16, 32, 64)):
        # every stage after the first starts by halving both resolutions
        if stage > 0:
            high_branch = layers.AveragePooling2D(2)(high_branch)
            low_branch = layers.AveragePooling2D(2)(low_branch)
        high_branch, low_branch = _create_octconv_residual_block(
            [high_branch, low_branch], base_width*k, N, alpha)

    # bring the high branch down to the low resolution and merge the two
    high_branch = layers.AveragePooling2D(2)(high_branch)
    merged = layers.Concatenate()([high_branch, low_branch])
    merged = layers.Conv2D(64*k, 1)(merged)
    merged = layers.BatchNormalization()(merged)
    merged = layers.Activation("elu")(merged)

    # classifier head
    pooled = layers.GlobalAveragePooling2D()(merged)
    probabilities = layers.Dense(10, activation="softmax")(pooled)

    return Model(image, probabilities)

In [0]:
from tensorflow.keras import backend as K
from tensorflow.keras.regularizers import l2

class OctConv2D(layers.Layer):
    """Octave convolution layer operating on a [high, low] pair of tensors.

    The layer owns four kernels (high->high, high->low, low->high, low->low)
    and returns a new [high, low] pair. It relies on the Keras backend being
    available under the module-level name ``K``.
    """

    def __init__(self, filters, alpha, kernel_size=(3,3), strides=(1,1),
                    padding="same", kernel_initializer='he_normal',
                    kernel_regularizer=l2(1e-4), kernel_constraint=None,
                    **kwargs):
        """
        OctConv2D : Octave Convolution for image( rank 4 tensors)
        filters: # output channels for low + high
        alpha: Low channel ratio (alpha=0 -> High only, alpha=1 -> Low only)
        kernel_size : 3x3 by default, padding : same by default
        """
        assert 0 <= alpha <= 1
        assert filters > 0 and isinstance(filters, int)
        super().__init__(**kwargs)

        self.alpha = alpha
        self.filters = filters
        # optional values
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.kernel_initializer = kernel_initializer
        self.kernel_regularizer = kernel_regularizer
        self.kernel_constraint = kernel_constraint
        # -> Low channels: floor(filters * alpha)
        self.low_channels = int(self.filters * self.alpha)
        # -> High channels: whatever is left of `filters`
        self.high_channels = self.filters - self.low_channels

    def build(self, input_shape):
        """Validate the [high, low] input shapes and create the four kernels."""
        assert len(input_shape) == 2
        assert len(input_shape[0]) == 4 and len(input_shape[1]) == 4
        # Assertion for high inputs
        assert input_shape[0][1] // 2 >= self.kernel_size[0]
        assert input_shape[0][2] // 2 >= self.kernel_size[1]
        # Assertion for low inputs: high must be twice the low resolution
        assert input_shape[0][1] // input_shape[1][1] == 2
        assert input_shape[0][2] // input_shape[1][2] == 2
        # channels last for TensorFlow
        assert K.image_data_format() == "channels_last"
        # input channels
        high_in = int(input_shape[0][3])
        low_in = int(input_shape[1][3])

        # High -> High
        self.high_to_high_kernel = self.add_weight(name="high_to_high_kernel",
                                    shape=(*self.kernel_size, high_in, self.high_channels),
                                    initializer=self.kernel_initializer,
                                    regularizer=self.kernel_regularizer,
                                    constraint=self.kernel_constraint)
        # High -> Low
        self.high_to_low_kernel  = self.add_weight(name="high_to_low_kernel",
                                    shape=(*self.kernel_size, high_in, self.low_channels),
                                    initializer=self.kernel_initializer,
                                    regularizer=self.kernel_regularizer,
                                    constraint=self.kernel_constraint)
        # Low -> High
        self.low_to_high_kernel  = self.add_weight(name="low_to_high_kernel",
                                    shape=(*self.kernel_size, low_in, self.high_channels),
                                    initializer=self.kernel_initializer,
                                    regularizer=self.kernel_regularizer,
                                    constraint=self.kernel_constraint)
        # Low -> Low
        self.low_to_low_kernel   = self.add_weight(name="low_to_low_kernel",
                                    shape=(*self.kernel_size, low_in, self.low_channels),
                                    initializer=self.kernel_initializer,
                                    regularizer=self.kernel_regularizer,
                                    constraint=self.kernel_constraint)
        super().build(input_shape)

    def call(self, inputs):
        """Apply the four convolutions and return [high_out, low_out]."""
        # Input = [X^H, X^L]
        assert len(inputs) == 2
        high_input, low_input = inputs
        # High -> High conv
        high_to_high = K.conv2d(high_input, self.high_to_high_kernel,
                                strides=self.strides, padding=self.padding,
                                data_format="channels_last")
        # High -> Low conv: 2x2 average pooling first, then convolution
        high_to_low  = K.pool2d(high_input, (2,2), strides=(2,2), pool_mode="avg")
        high_to_low  = K.conv2d(high_to_low, self.high_to_low_kernel,
                                strides=self.strides, padding=self.padding,
                                data_format="channels_last")
        # Low -> High conv, then upsample back to the high resolution
        low_to_high  = K.conv2d(low_input, self.low_to_high_kernel,
                                strides=self.strides, padding=self.padding,
                                data_format="channels_last")
        low_to_high = K.repeat_elements(low_to_high, 2, axis=1) # Nearest Neighbor Upsampling
        low_to_high = K.repeat_elements(low_to_high, 2, axis=2)
        # Low -> Low conv
        low_to_low   = K.conv2d(low_input, self.low_to_low_kernel,
                                strides=self.strides, padding=self.padding,
                                data_format="channels_last")
        # Cross Add
        high_add = high_to_high + low_to_high
        low_add = high_to_low + low_to_low
        return [high_add, low_add]

    def compute_output_shape(self, input_shapes):
        """Return the output shapes: input shapes with the channel axis replaced."""
        high_in_shape, low_in_shape = input_shapes
        high_out_shape = (*high_in_shape[:3], self.high_channels)
        low_out_shape = (*low_in_shape[:3], self.low_channels)
        return [high_out_shape, low_out_shape]

    def get_config(self):
        """Return the base layer config extended with this layer's constructor arguments."""
        base_config = super().get_config()
        out_config = {
            **base_config,
            "filters": self.filters,
            "alpha": self.alpha,
            "kernel_size": self.kernel_size,
            "strides": self.strides,
            "padding": self.padding,
            "kernel_initializer": self.kernel_initializer,
            "kernel_regularizer": self.kernel_regularizer,
            "kernel_constraint": self.kernel_constraint,
        }
        return out_config

In [0]:
import numpy as np

class MixupGenerator():
    """Endless generator of mixup batches.

    Each batch is built from ``2 * batch_size`` samples: the first half and the
    second half are blended with per-sample weights drawn from
    ``Beta(alpha, alpha)``, and the labels are blended with the same weights.

    Args:
        X_train: 4-D numpy array of images (the first axis indexes samples).
        y_train: numpy array of labels, or a list of such arrays (each list
            entry is mixed separately and a list is yielded).
        batch_size: number of mixed samples per yielded batch.
        alpha: both shape parameters of the Beta distribution.
        shuffle: whether to reshuffle the sample order on every pass.
        datagen: optional object providing ``random_transform`` and
            ``standardize``, applied to every mixed image.

    Raises:
        ValueError: if there are fewer than ``2 * batch_size`` samples, since
            no batch could ever be produced.
    """

    def __init__(self, X_train, y_train, batch_size=32, alpha=0.2, shuffle=True, datagen=None):
        self.X_train = X_train
        self.y_train = y_train
        self.batch_size = batch_size
        self.alpha = alpha
        self.shuffle = shuffle
        self.sample_num = len(X_train)
        self.datagen = datagen

        # Without this check __call__ would loop forever without yielding.
        if self.sample_num < 2 * self.batch_size:
            raise ValueError(
                "MixupGenerator needs at least 2 * batch_size samples "
                f"(got {self.sample_num} samples, batch_size={self.batch_size})")

    def __call__(self):
        """Yield ``(X, y)`` mixup batches forever."""
        pair_size = self.batch_size * 2
        while True:
            indexes = self.__get_exploration_order()
            # leftover samples that do not fill a whole pair of batches are dropped
            itr_num = len(indexes) // pair_size

            for i in range(itr_num):
                batch_ids = indexes[i * pair_size:(i + 1) * pair_size]
                yield self.__data_generation(batch_ids)

    def __get_exploration_order(self):
        """Return the sample indices for one pass, shuffled when requested."""
        indexes = np.arange(self.sample_num)

        if self.shuffle:
            np.random.shuffle(indexes)

        return indexes

    def __data_generation(self, batch_ids):
        """Blend the two halves of ``batch_ids`` into one mixup batch."""
        # The mixing weights are broadcast as (batch, 1, 1, 1), so the images
        # must be rank 4.
        if self.X_train.ndim != 4:
            raise ValueError(
                f"X_train must be a 4-D array, got shape {self.X_train.shape}")

        first_ids = batch_ids[:self.batch_size]
        second_ids = batch_ids[self.batch_size:]

        # one mixing weight per output sample
        l = np.random.beta(self.alpha, self.alpha, self.batch_size)
        X_l = l.reshape(self.batch_size, 1, 1, 1)
        y_l = l.reshape(self.batch_size, 1)

        X = self.X_train[first_ids] * X_l + self.X_train[second_ids] * (1 - X_l)

        if self.datagen:
            for i in range(self.batch_size):
                X[i] = self.datagen.random_transform(X[i])
                X[i] = self.datagen.standardize(X[i])

        if isinstance(self.y_train, list):
            y = [y_train_[first_ids] * y_l + y_train_[second_ids] * (1 - y_l)
                 for y_train_ in self.y_train]
        else:
            y = self.y_train[first_ids] * y_l + self.y_train[second_ids] * (1 - y_l)

        return X, y

In [0]:
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import LearningRateScheduler, History
from tensorflow.contrib.tpu.python.tpu import keras_support
from models import *

from keras.datasets import cifar10
from keras.utils import to_categorical
import pickle, os, time

def lr_scheduler(epoch):
    """Return the learning rate for ``epoch``.

    Starts at 0.1 and is divided by 5 once for each of the milestones
    100, 150 and 200 that ``epoch`` has reached.
    """
    rate = 0.1
    for milestone in (100, 150, 200):
        if epoch >= milestone:
            rate /= 5.0
    return rate

def train(alpha):
    """Train a Wide ResNet on CIFAR-10 on a Colab TPU and pickle the history.

    alpha <= 0 builds the plain-convolution network, alpha > 0 the OctConv
    variant with that low-frequency channel ratio. The Keras training history
    plus the elapsed wall-clock time (seconds) is written to
    ``octconv_alpha_{alpha}.pkl`` in the current directory.

    Requires the COLAB_TPU_ADDR environment variable (KeyError otherwise).
    """
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    # Augmentation is applied by MixupGenerator to each mixed training image.
    train_gen = ImageDataGenerator(rescale=1.0/255, horizontal_flip=True, 
                                    width_shift_range=4.0/32.0, height_shift_range=4.0/32.0,
                                  rotation_range=40, shear_range=0.05, zoom_range=0.05)

    # Test images are only rescaled.
    test_gen = ImageDataGenerator(rescale=1.0/255)
    # one-hot labels, needed both for mixup blending and the crossentropy loss
    y_train = to_categorical(y_train)
    y_test = to_categorical(y_test)

    tf.logging.set_verbosity(tf.logging.FATAL)

    if alpha <= 0:
        model = create_normal_wide_resnet()
    else:
        model = create_octconv_wide_resnet(alpha)
    # NOTE(review): the LearningRateScheduler callback below returns 0.1 for the
    # first epochs, which presumably overrides lr=0.001 given here - confirm
    # that this is intended for Adam.
    model.compile(Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False), "categorical_crossentropy", ["acc"])
    model.summary()

    # convert to tpu model
    tpu_grpc_url = "grpc://"+os.environ["COLAB_TPU_ADDR"]
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
    strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
    model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)

    batch_size = 128
    scheduler = LearningRateScheduler(lr_scheduler)
    hist = History()

    start_time = time.time()
    # MixupGenerator is callable; calling it returns the endless batch generator.
    training_generator = MixupGenerator(X_train, y_train, batch_size=batch_size, alpha=1.0, datagen=train_gen)()
    # The CIFAR-10 test split is used as validation data, so the recorded
    # val_* metrics are test-set metrics.
    model.fit_generator(training_generator,
                        steps_per_epoch=X_train.shape[0]//batch_size,
                        validation_data=test_gen.flow(X_test, y_test, batch_size, shuffle=False),
                        validation_steps=X_test.shape[0]//batch_size,
                        callbacks=[scheduler, hist], max_queue_size=5, epochs=200)
    elapsed = time.time() - start_time
    print(elapsed)

    # Store the training time alongside the per-epoch metrics.
    history = hist.history
    history["elapsed"] = elapsed

    with open(f"octconv_alpha_{alpha}.pkl", "wb") as fp:
        pickle.dump(history, fp)

# Entry point: train the OctConv variant (alpha > 0) with alpha = 0.25.
if __name__ == "__main__":
    train(0.25)

# Check Test Accuracy

In [0]:
# Read back the pickled training history and report the best validation accuracy.
with open("octconv_alpha_0.25.pkl", "rb") as history_file:
    data = pickle.load(history_file)
print(f"Max test accuracy = {max(data['val_acc']):.04}")

In [0]:
import matplotlib.pyplot as plt

# Plot the loss and accuracy curves from the history dict that train() pickled.
# The pickle holds the metrics themselves (keys 'loss', 'val_loss', 'acc',
# 'val_acc', 'elapsed'), not model weights, so it is plotted directly.
with open("octconv_alpha_0.25.pkl", "rb") as fp:
    data = pickle.load(fp)

epochs = len(data['loss'])
epoch_axis = range(epochs)
# train() validates on the CIFAR-10 test split, so the final validation
# metrics serve as the test loss / accuracy shown in the annotations.
metrics = [data['val_loss'][-1], data['val_acc'][-1]]

plt.figure(figsize=(14, 9))

# Left panel: cross-entropy loss
plt.subplot(1, 2, 1)
plt.plot(epoch_axis, data['loss'], label='Train Loss')
plt.plot(epoch_axis, data['val_loss'], label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Cross entropy')
plt.text(0.6 * epochs, 0.75 * max(data['loss']),
         'Test loss = %.4f' % (metrics[0]))
plt.legend(loc='best')

# Right panel: accuracy
plt.subplot(1, 2, 2)
plt.plot(epoch_axis, data['acc'], label='Train Accuracy')
plt.plot(epoch_axis, data['val_acc'], label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.text(0.6 * epochs, 0.8 * max(data['acc']),
         'Test Accuracy = %.4f' % (metrics[1]))
plt.legend(loc='best')

plt.savefig('training_metrics_alpha_0.25.pdf')
plt.show()