In [None]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Conv2D, BatchNormalization, Dense, MaxPooling2D, \
    Flatten, Dropout, GlobalAveragePooling2D, Layer, Input, add, ReLU, Dropout
from tensorflow.keras.callbacks import ModelCheckpoint
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from tensorflow.image import random_flip_left_right, random_flip_up_down, random_brightness

In [None]:
# Sentinel telling tf.data to choose parallelism / prefetch buffer sizes at runtime;
# used as `num_parallel_calls` and as the prefetch size in the pipelines below.
AUTO = tf.data.experimental.AUTOTUNE

### Data pipeline

In [None]:
# Load Fashion-MNIST through TFDS: `ds` is indexed by split name ('train', 'test'),
# `ds_info` holds the dataset metadata.
ds, ds_info = tfds.load(name='fashion_mnist', with_info=True)

In [None]:
# Show the TFDS metadata for the dataset (rendered as this cell's output).
ds_info

In [None]:
# Pull the two raw splits out of `ds`. Iterating either one yields a single dict
# per example, holding 'image' and 'label' keys.
ds_train, ds_test = ds['train'], ds['test']

In [None]:
"""
Code for making Tensors have slice-able assignment from StackOverflow user: Tensorflow Support
https://stackoverflow.com/questions/54086836/how-to-fix-sliced-assignment-is-only-supported-for-variables-for-tensors

This custom class and methods are needed for the application of random erasing augmentation
"""

def replace_slice(input_, replacement, begin, size=None):
    """Return a copy of `input_` whose box starting at `begin` holds `replacement`.

    Tensors are immutable, so the "assignment" works in three steps. `replacement` is
    zero-padded out to the full shape of `input_`. A boolean mask is padded the same
    way. tf.where then picks the new value inside the box and the old value elsewhere.

    Args:
        input_: tensor to update.
        replacement: values to write. Broadcast to `size` when `size` is given.
        begin: 1-D per-dimension start offsets, with one entry per dimension of `input_`.
        size: 1-D per-dimension extent of the box. Defaults to the shape of `replacement`.
    """
    if size is not None:
        replacement = tf.broadcast_to(replacement, size)
    else:
        size = tf.shape(replacement)
    full_shape = tf.shape(input_)
    # [rank, 2] paddings: `begin` elements before the box, the remainder after it.
    paddings = tf.stack([begin, full_shape - (begin + size)], axis=1)
    padded_values = tf.pad(replacement, paddings)
    inside_box = tf.pad(tf.ones_like(replacement, dtype=tf.bool), paddings)
    return tf.where(inside_box, padded_values, input_)

def replace_slice_in(tensor):
    """Entry point for slice "assignment" on an immutable tensor.

    Usage: ``replace_slice_in(t)[a:b, c:d, :].with_value(v)`` returns a new tensor equal
    to ``t`` except inside the indexed box, which holds ``v`` (broadcast to the box
    size). ``t`` itself is left untouched.
    """
    return _SliceReplacer(tensor)

class _SliceReplacer:
    """Wraps a tensor so that indexing it captures a slice spec for a later `with_value`."""

    def __init__(self, tensor):
        self._target = tensor

    def __getitem__(self, slices):
        # Defer all work until the replacement value is known.
        return self._Inner(self._target, slices)

    def with_value(self, replacement):
        """Shortcut for ``self[...].with_value(replacement)``, i.e. replace the whole tensor."""
        return self[...].with_value(replacement)

    class _Inner:
        """A (tensor, slice spec) pair awaiting the value to write into that slice."""

        def __init__(self, tensor, slices):
            self._target = tensor
            self._spec = slices

        def with_value(self, replacement):
            start, extent = _make_slices_begin_size(self._target, self._spec)
            return replace_slice(self._target, replacement, start, extent)

# Translate a Python slice spec into full-rank `begin` / `size` vectors for `input_`.
def _make_slices_begin_size(input_, slices):
    """Compute the (begin, size) box described by `slices` over `input_`.

    `slices` may be a single item or a tuple/list of items. Each item is an integer
    index, a `slice` without a step, or an Ellipsis (at most one). Items after the
    Ellipsis are aligned with the trailing dimensions of `input_`. An integer index `s`
    is treated as the size-1 range [s, s+1), so the dimension is kept. A negative
    start/stop is wrapped by adding the dimension size once. No clamping of
    out-of-range values is done here.

    Returns:
        (begin_full, size_full): 1-D tensors with one entry per dimension of `input_`.
        Dimensions not referenced by `slices` get begin 0 and their full size.

    Raises:
        ValueError: if more than one Ellipsis appears, or a slice has a step.
    """
    if not isinstance(slices, (tuple, list)):
        slices = (slices,)
    inp_rank = tf.rank(input_)
    inp_shape = tf.shape(input_)
    # Did we see a ellipsis already?
    before_ellipsis = True
    # Sliced dimensions
    dim_idx = []
    # Slice start points
    begins = []
    # Slice sizes
    sizes = []
    for i, s in enumerate(slices):
        if s is Ellipsis:
            if not before_ellipsis:
                raise ValueError('Cannot use more than one ellipsis in slice spec.')
            before_ellipsis = False
            continue
        if isinstance(s, slice):
            start = s.start
            stop = s.stop
            if s.step is not None:
                raise ValueError('Step value not supported.')
        else:  # Assumed to be a single integer value
            start = s
            stop = s + 1
        # Dimension this slice refers to: counted from the front before the Ellipsis,
        # from the back (relative to the number of spec items) after it.
        i_dim = i if before_ellipsis else inp_rank - (len(slices) - i)
        dim_size = inp_shape[i_dim]
        # Default slice values
        start = start if start is not None else 0
        stop = stop if stop is not None else dim_size
        # Fix negative indices (tf.cond so this also works when start/stop are tensors)
        start = tf.cond(tf.convert_to_tensor(start >= 0), lambda: start, lambda: start + dim_size)
        stop = tf.cond(tf.convert_to_tensor(stop >= 0), lambda: stop, lambda: stop + dim_size)
        dim_idx.append([i_dim])
        begins.append(start)
        sizes.append(stop - start)
    # For empty slice specs like [...]: the box is the whole tensor
    if not dim_idx:
        return tf.zeros_like(inp_shape), inp_shape
    # Make full begin and size array (including omitted dimensions).
    # Omitted dims: begin defaults to 0 via scatter_nd; size falls back to the full dim via the mask.
    begin_full = tf.scatter_nd(dim_idx, begins, [inp_rank])
    size_mask = tf.scatter_nd(dim_idx, tf.ones_like(sizes, dtype=tf.bool), [inp_rank])
    size_full = tf.where(size_mask,
                          tf.scatter_nd(dim_idx, sizes, [inp_rank]),
                          inp_shape)
    return begin_full, size_full

def random_erasing(img, probability = 0.08, sl = 0.02, sh = 0.4, r1 = 0.3):
    '''
    Randomly overwrite a rectangular patch of an image with uniform noise.

    With probability `probability`, a rectangle of random height and width is placed
    inside `img` and its pixels are replaced by values drawn uniformly from [0, 1).
    Otherwise `img` is returned unchanged.

    Args:
        img: 3-D float tensor in HWC order. Pixel values are expected to be roughly in
            [0, 1], since the noise is drawn from that range. Height, width and channel
            count are read from the tensor, so any image size works (the notebook
            uses 28x28x1).
        probability: chance of applying the erasure.
        sl, sh: lower / upper fractions of the image area. They bound the patch side
            lengths to round(sqrt(sl * area * r1)) and round(sqrt(sh * area / r1)).
        r1: ratio factor used in those two side-length bounds.

    Returns:
        A tensor with the same shape and dtype as `img`.
    '''
    # HWC order; taken from the tensor instead of being hard-coded to 28x28x1.
    img_shape = tf.shape(img)
    height, width, channel = img_shape[0], img_shape[1], img_shape[2]
    area = tf.cast(height * width, tf.float32)

    erase_area_low_bound = tf.cast(tf.round(tf.sqrt(sl * area * r1)), tf.int32)
    erase_area_up_bound = tf.cast(tf.round(tf.sqrt((sh * area) / r1)), tf.int32)
    h_upper_bound = tf.minimum(erase_area_up_bound, height)
    w_upper_bound = tf.minimum(erase_area_up_bound, width)

    # Side lengths are sampled independently from [low_bound, upper_bound).
    h = tf.random.uniform([], erase_area_low_bound, h_upper_bound, tf.int32)
    w = tf.random.uniform([], erase_area_low_bound, w_upper_bound, tf.int32)

    # Top-left corner chosen so the h x w patch lies entirely inside the image.
    x1 = tf.random.uniform([], 0, height + 1 - h, tf.int32)
    y1 = tf.random.uniform([], 0, width + 1 - w, tf.int32)

    # Float noise in [0, 1) matching the image dtype. The previous version sampled
    # int32 values in [0, 1) -- always 0 -- so every patch was silently filled with black.
    erase_area = tf.random.uniform([h, w, channel], 0.0, 1.0, dtype=img.dtype)
    erasing_img = replace_slice_in(img)[x1:x1+h, y1:y1+w, :].with_value(erase_area)

    # Erase when the draw is <= probability, i.e. with probability `probability`.
    return tf.cond(tf.random.uniform([], 0, 1) > probability, lambda: img, lambda: erasing_img)

def _scale_to_unit(data):
    """Return (image, label) from a TFDS example dict, with the image cast to float32 and divided by 255."""
    return tf.cast(data['image'], tf.float32) / 255.0, data['label']


def _flip_and_brighten(image):
    """Random left-right flip, then random up-down flip, then random brightness shift (max_delta=0.3)."""
    for flip in (random_flip_left_right, random_flip_up_down):
        image = flip(image)
    return random_brightness(image, 0.3)


def augment_train(data):
    """Training map function: scale, random flips + brightness, then random erasing.

    NOTE: random_brightness can push pixel values outside [0, 1]; no clipping is done.
    """
    image, label = _scale_to_unit(data)
    return random_erasing(_flip_and_brighten(image)), label


def augment_test(data):
    """Like augment_train but without random erasing.

    NOTE(review): this is still stochastic (random flips and brightness), so metrics
    computed on data mapped with it vary between runs.
    """
    image, label = _scale_to_unit(data)
    return _flip_and_brighten(image), label

In [None]:
batch_size = 1024

# Build from the raw splits in `ds` rather than from ds_train/ds_test. This keeps the
# cell idempotent: re-running it no longer maps an already-mapped (tuple) dataset.
ds_train = (
    ds['train']
    .map(augment_train, num_parallel_calls=AUTO)
    .shuffle(10000)
    .batch(batch_size)
    .prefetch(AUTO)
)

# The evaluation split uses augment_test (no random erasing). It previously used
# augment_train, leaving augment_test unused and erasing patches of validation images.
# NOTE(review): augment_test still applies random flips/brightness to the test data.
ds_test = (
    ds['test']
    .map(augment_test, num_parallel_calls=AUTO)
    .batch(batch_size)
    .prefetch(AUTO)
)

### ResNet Architecture

In [None]:
class ResidualBlock(Layer):
    """
    Basic (two 3x3 convolutions) residual block from He et al. (2015): https://arxiv.org/pdf/1512.03385.pdf

    The block computes ReLU(main(x) + shortcut(x)):
        main:     conv3x3 -> BN -> ReLU -> conv3x3 -> BN
        shortcut: identity, or (when `dim_increase`) a strided 1x1 convolution -> BN
    As in the original implementation, batch normalization is applied BEFORE the activation
    (hence activation=None on every convolution). The only nonlinearity after the second
    convolution is the ReLU after the addition. Identity shortcuts are used for
    same-dimension blocks, as in the smaller (18, 34) ResNets of the paper; this is
    explained in the 1st paragraph of the "Residual Network." subsection on page 5.

    Attributes:
        filters (int): the number of filters in the convolutional layers of this block
            (and in the projection shortcut).
        dim_increase (bool): whether this block increases the filter dimension relative to
            the block before it. For such blocks the paper proposes two options (see the
            "Residual Network." subsection on page 4): (a) identity mapping with zero-padding
            for the extra channels, or (b) a projection shortcut with a Network-in-Network
            (1x1 convolution). We use (b). When True, both the first 3x3 convolution and the
            projection use stride 2.
    """

    def __init__(self, filters, dim_increase, **kwargs):
        # Forward **kwargs (e.g. name, trainable, dtype) to Layer. get_config() returns the
        # base-layer config plus filters/dim_increase, and from_config() passes all of it
        # back to __init__; without **kwargs that round trip raised a TypeError.
        super(ResidualBlock, self).__init__(**kwargs)

        self.filters = filters
        self.dim_increase = dim_increase

        # Downsample in the first conv exactly when the projection shortcut does (stride 2).
        # This previously also required `filters != 64`. ResidualBlock(64, True) then had a
        # stride-2 shortcut but a stride-1 main branch, and the addition failed on shape.
        first_stride = 2 if dim_increase else 1
        self.conv1 = Conv2D(kernel_size=(3, 3), filters=filters, activation=None,
                            padding='same', strides=first_stride)
        self.conv2 = Conv2D(kernel_size=(3, 3), filters=filters,
                            activation=None, padding='same')

        self.bn1 = BatchNormalization()
        self.bn2 = BatchNormalization()
        # Only called on the projection shortcut (dim_increase=True).
        self.bn_resid = BatchNormalization()

        if dim_increase:
            self.residual_connection = Conv2D(kernel_size=(1, 1), filters=filters, strides=2, activation=None)
        else:
            self.residual_connection = lambda x: x

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'filters': self.filters,
            'dim_increase': self.dim_increase
        })

        return config

    def call(self, x):
        """Apply the block and return ReLU(main(x) + shortcut(x))."""
        # Shortcut branch: identity, or conv1x1 -> BN. No ReLU here; as in the paper, the
        # nonlinearity is applied only after the addition.
        resid = self.residual_connection(x)
        if self.dim_increase:
            resid = self.bn_resid(resid)

        # Main branch: conv -> BN -> ReLU -> conv -> BN. The second BN output is NOT passed
        # through a ReLU before the addition. The previous version did that, which made the
        # residual branch non-negative. In projection blocks, both summands were then
        # non-negative, so the final ReLU was a no-op.
        x = self.conv1(x)
        x = self.bn1(x)
        x = tf.nn.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)

        # Plain tensor addition; keras.layers.add() would build a new Add layer per call.
        return tf.nn.relu(resid + x)
        

In [None]:
INPUT_SHAPE = (28, 28, 1)

# ResNet-18 layout: a 7x7/2 conv + 3x3/2 max-pool stem, then four stages of two basic
# residual blocks (64, 128, 256, 512 filters). Spatial size for a 28x28 input:
# 28 -> 14 (stem conv) -> 6 (max-pool) -> 6 -> 3 -> 2 -> 1 (one value per stage).
resnet18 = tf.keras.Sequential([
    Input(shape=INPUT_SHAPE),
    Conv2D(kernel_size=(7, 7), filters=64, strides=2, activation=None, padding='same'),
    BatchNormalization(),
    ReLU(),
    MaxPooling2D(pool_size=(3, 3), strides=2),
    ResidualBlock(filters=64, dim_increase=False),
    ResidualBlock(filters=64, dim_increase=False),
    ResidualBlock(filters=128, dim_increase=True),
    ResidualBlock(filters=128, dim_increase=False),
    ResidualBlock(filters=256, dim_increase=True),
    ResidualBlock(filters=256, dim_increase=False),
    Dropout(0.10),
    ResidualBlock(filters=512, dim_increase=True),
    Dropout(0.10),
    ResidualBlock(filters=512, dim_increase=False),
    Dropout(0.10),
    GlobalAveragePooling2D(),
    # The softmax layer is the output layer. A Dropout(0.10) used to follow it. During
    # training it zeroed random class probabilities (sometimes the true class) and
    # rescaled the rest by 1/0.9, so the loss saw vectors that were not valid
    # distributions. Dropout has no weights, so weight files are unaffected.
    Dense(10, activation='softmax'),
])

resnet18.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                 optimizer=tf.keras.optimizers.Adam(),
                 metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

In [None]:
# Print the layer-by-layer summary (output shapes and parameter counts) of the model.
resnet18.summary()

In [None]:
callbacks = [
    # With patience=150, training can run up to 150 epochs past the best epoch, so the
    # final weights can be much worse than the best ones. restore_best_weights rolls the
    # model back to the best epoch, so the later resnet18.save(...) stores those weights.
    # On older TF/Keras versions this restore only happens when early stopping actually
    # triggers -- check your version.
    # NOTE(review): validation_data is the test split, so this selects the model on test
    # data. Hold out part of the training split for an unbiased test accuracy.
    tf.keras.callbacks.EarlyStopping(monitor='val_sparse_categorical_accuracy',
                                     patience=150,
                                     restore_best_weights=True),
]

train_hist = resnet18.fit(
    ds_train,
    epochs=500,
    validation_data=ds_test,
    callbacks=callbacks
)

In [None]:
# Per-epoch accuracy. "Train" is measured on the augmented training batches during each
# epoch; "Test" is the validation pass over the test split.
fig, ax = plt.subplots()
ax.plot(train_hist.history['sparse_categorical_accuracy'], label='Train')
ax.plot(train_hist.history['val_sparse_categorical_accuracy'], label='Test')
ax.set(xlabel='Epoch', ylabel='Sparse categorical accuracy', title='ResNet-18 on Fashion-MNIST')
ax.legend();

In [None]:
# Persist the trained model in HDF5 format. The "Loading the model" cell below rebuilds
# the architecture in code and reads only the weights back from this file.
resnet18.save(filepath='resnet18.h5')

## Looking into accuracy

##### Loading the model (if needed)

In [None]:
INPUT_SHAPE = (28, 28, 1)

# Rebuild the same architecture as the training cell so the saved weights can be loaded.
# The Dropout that used to follow the softmax output is gone. It was a no-op at inference
# and has no weights, so files saved with either layout load into this model.
resnet18 = tf.keras.Sequential([
    Input(shape=INPUT_SHAPE),
    Conv2D(kernel_size=(7, 7), filters=64, strides=2, activation=None, padding='same'),
    BatchNormalization(),
    ReLU(),
    MaxPooling2D(pool_size=(3, 3), strides=2),
    ResidualBlock(filters=64, dim_increase=False),
    ResidualBlock(filters=64, dim_increase=False),
    ResidualBlock(filters=128, dim_increase=True),
    ResidualBlock(filters=128, dim_increase=False),
    ResidualBlock(filters=256, dim_increase=True),
    ResidualBlock(filters=256, dim_increase=False),
    Dropout(0.10),
    ResidualBlock(filters=512, dim_increase=True),
    Dropout(0.10),
    ResidualBlock(filters=512, dim_increase=False),
    Dropout(0.10),
    GlobalAveragePooling2D(),
    Dense(10, activation='softmax'),
])


resnet18.load_weights('./resnet18.h5')

In [None]:
# Rebuild both dataset pipelines from the raw TFDS splits with deterministic
# preprocessing only (no flips, brightness jitter or random erasing), so the accuracies
# below are measured on the unmodified images.
batch_size = 1024


def augment_inference(data):
    """Return (image cast to float32 and divided by 255, label); no augmentation."""
    return tf.cast(data['image'], tf.float32) / 255.0, data['label']


def _inference_pipeline(split):
    """Map `augment_inference` over a raw split, then batch and prefetch it."""
    mapped = split.map(augment_inference, num_parallel_calls=AUTO)
    return mapped.batch(batch_size).prefetch(AUTO)


ds_train = _inference_pipeline(ds['train'])
ds_test = _inference_pipeline(ds['test'])

In [None]:
# Run the model over ds_train batch by batch, keeping the softmax outputs and the
# labels as NumPy arrays (one entry per batch).
preds, labels = [], []
for batch_images, batch_labels in ds_train:
    preds.append(resnet18(batch_images).numpy())
    labels.append(batch_labels.numpy())

In [None]:
# Flatten the per-batch arrays into per-example lists: the predicted class (argmax over
# the 10 softmax outputs) and the ground-truth label.
unrolled_preds = [np.argmax(row) for batch in preds for row in batch]
unrolled_labels = [y for batch in labels for y in batch]

In [None]:
# Fraction of examples whose predicted class equals the label.
acc = np.mean(np.asarray(unrolled_preds) == np.asarray(unrolled_labels))
print(f'Training set accuracy: {acc}')

In [None]:
# Same as for the training split: collect per-batch softmax outputs and labels for ds_test.
preds, labels = [], []
for batch_images, batch_labels in ds_test:
    preds.append(resnet18(batch_images).numpy())
    labels.append(batch_labels.numpy())

In [None]:
# Flatten the per-batch arrays into per-example predicted classes and labels.
unrolled_preds = [np.argmax(row) for batch in preds for row in batch]
unrolled_labels = [y for batch in labels for y in batch]

In [None]:
# Fraction of test examples whose predicted class equals the label.
acc = np.mean(np.asarray(unrolled_preds) == np.asarray(unrolled_labels))
print(f'Test set accuracy: {acc}')