In [None]:
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from scipy.io import loadmat
from sklearn.model_selection import train_test_split
from tensorflow.keras.datasets import mnist

from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import confusion_matrix
from keras.preprocessing.image import ImageDataGenerator
import tensorflow
import tensorflow as tf
from time import time
from collections import Counter
import keras
from keras.layers import Dense, Conv2D, BatchNormalization, Activation, Dropout
from keras.layers import AveragePooling2D, Input, Flatten
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras import backend as K
from keras.models import Model
%matplotlib inline

In [None]:
# Training hyperparameters and model configuration.
batch_size, epochs = 32, 200
num_classes = 10  # one class per digit label
depth = 20  # resnet_v1 requires a depth of the form 6n+2
subtract_pixel_mean = True

# Display name of the architecture, derived from its depth.
model_type = f'ResNet{depth}'

In [None]:
# Load the MNIST train and test splits, then unpack each into images and labels.
mnist_train_split, mnist_test_split = mnist.load_data()
mnist_train_images, mnist_train_labels = mnist_train_split
mnist_test_images, mnist_test_labels = mnist_test_split

In [None]:
# Class distribution of the training labels: first the distinct labels,
# then the number of occurrences of each (same order).
raw_label_counts = Counter(mnist_train_labels)
print(raw_label_counts.keys())
print(raw_label_counts.values())

In [None]:
# Build a small class-balanced training subset: keep the first
# `class_threshold` examples of every digit, in their original dataset order.
new_images = []
new_labels = []
class_threshold = 50

# A single count table for the digit classes 0-9 replaces ten separate
# counter variables and ten copy-pasted branches.
digit_classes = range(10)
per_class_counts = dict.fromkeys(digit_classes, 0)
subset_size = class_threshold * len(digit_classes)

for image, label in zip(mnist_train_images, mnist_train_labels):
    digit = int(label)
    # Labels outside 0-9 are ignored (the default makes them look "full"),
    # as are classes that have already reached the threshold.
    if per_class_counts.get(digit, class_threshold) >= class_threshold:
        continue
    new_images.append(image)
    new_labels.append(label)
    per_class_counts[digit] += 1
    # Once every class is full nothing later in the dataset can be selected,
    # so there is no need to scan the remaining examples.
    if len(new_images) == subset_size:
        break

In [None]:
#Renaming back to original name - 
mnist_train_images = np.array(new_images)
mnist_train_labels = np.array(new_labels)

In [None]:
# Class distribution after subsampling: the distinct labels, then the number
# of examples kept for each (same order).
subset_label_counts = Counter(mnist_train_labels)
print(subset_label_counts.keys())
print(subset_label_counts.values())

In [None]:
from tensorflow.keras import backend as K

# The position of the single channel axis depends on the backend's image
# data format.
if K.image_data_format() == 'channels_first':
    input_shape = (1, 28, 28)  # (channels, rows, cols)
else:
    input_shape = (28, 28, 1)  # (rows, cols, channels)

# Add the channel axis, convert to float32 and divide the pixel values by 255.
train_images = mnist_train_images.reshape((mnist_train_images.shape[0],) + input_shape).astype('float32') / 255
test_images = mnist_test_images.reshape((mnist_test_images.shape[0],) + input_shape).astype('float32') / 255

In [None]:
# One-hot encode the integer digit labels. The number of classes comes from
# the `num_classes` configuration constant instead of a repeated literal, so
# the encoding and the model's output layer cannot drift apart.
train_labels = tensorflow.keras.utils.to_categorical(mnist_train_labels, num_classes)
test_labels = tensorflow.keras.utils.to_categorical(mnist_test_labels, num_classes)

In [None]:
# Sanity check: the number of images and labels must match within each split.
for description, collection in (
    ("Count of training images -", train_images),
    ("Count of testing images -", test_images),
    ("Count of training labels -", train_labels),
    ("Count of testing labels -", test_labels),
):
    print(description, len(collection))

In [None]:
# NOTE(review): train_labels/test_labels were already one-hot encoded with
# to_categorical in an earlier cell, so this step presumably leaves the
# encoding unchanged - confirm, and consider removing it.
lb = LabelBinarizer()
# Fit the encoder on the training labels only ...
train_labels = lb.fit_transform(train_labels)
# ... and reuse that fitted encoding for the test labels. Re-fitting on the
# test set could yield a class mapping that differs from the training one.
test_labels = lb.transform(test_labels)

In [None]:
# The code in the following cells was adapted from:
# Arvind Singh, Y. (2019) Yasharvindsingh/Resnet20: A ResNet architecture implemented in Keras, GitHub.
# Available at: https://github.com/yasharvindsingh/ResNet20
# Strictly speaking, only the network architecture was taken from that source.

In [None]:
def lr_schedule(epoch):
    """Return the learning rate to use for the given epoch.

    The rate starts at 1e-3 and is scaled down once the epoch number exceeds
    80, 120, 160 and 180. The selected rate is printed before it is returned.

    Args:
        epoch: Epoch number the rate is requested for.

    Returns:
        The learning rate as a float.
    """
    lr = 1e-3
    # (epoch that must be exceeded, multiplier) pairs, latest milestone first
    # so that the first match is the one that applies.
    milestones = ((180, 0.5e-3), (160, 1e-3), (120, 1e-2), (80, 1e-1))
    for boundary, factor in milestones:
        if epoch > boundary:
            lr *= factor
            break
    print('Learning rate: ', lr)
    return lr

In [None]:
def resnet_layer(inputs, num_filters=16, kernel_size=3, strides=1, activation='relu', batch_normalization=True, conv_first=True):
    """Apply a 2D convolution combined with optional batch norm and activation.

    Args:
        inputs: Input tensor.
        num_filters: Number of convolution filters.
        kernel_size: Convolution kernel size.
        strides: Convolution strides.
        activation: Activation name, or None to skip the activation.
        batch_normalization: Whether to include a BatchNormalization layer.
        conv_first: If True the order is conv -> BN -> activation, otherwise
            BN -> activation -> conv.

    Returns:
        The output tensor of the last layer applied.
    """
    # The convolution is instantiated first in both orderings.
    convolution = Conv2D(num_filters, kernel_size=kernel_size, strides=strides, padding='same')

    def normalize_and_activate(tensor):
        # Optional batch normalization followed by the optional activation.
        if batch_normalization:
            tensor = BatchNormalization()(tensor)
        if activation is not None:
            tensor = Activation(activation)(tensor)
        return tensor

    if conv_first:
        return normalize_and_activate(convolution(inputs))
    return convolution(normalize_and_activate(inputs))

In [None]:
def resnet_v1(input_shape, depth, num_classes=10):
    """Build a ResNet v1 classification model.

    The network is an initial conv layer followed by three stacks of residual
    blocks (16, 32 and 64 filters), average pooling and a softmax classifier.
    Dropout (rate 0.25) is applied after every residual block.

    Args:
        input_shape: Shape of a single input image, without the batch axis.
        depth: Network depth; must be of the form 6n+2.
        num_classes: Number of output classes.

    Returns:
        The (uncompiled) Keras Model.

    Raises:
        ValueError: If depth is not of the form 6n+2.
    """
    blocks_per_stack, remainder = divmod(depth - 2, 6)
    if remainder != 0:
        raise ValueError('depth should be 6n+2 (eg 20, 32, 44 in [a])')
    blocks_per_stack = int(blocks_per_stack)

    inputs = Input(shape=input_shape)
    x = resnet_layer(inputs=inputs)

    # Three stacks of residual blocks; the filter count doubles per stack.
    for stack, stack_filters in enumerate((16, 32, 64)):
        for block in range(blocks_per_stack):
            # The first block of every stack except the first one downsamples.
            downsample = stack > 0 and block == 0
            block_strides = 2 if downsample else 1

            residual = resnet_layer(inputs=x, num_filters=stack_filters, strides=block_strides)
            residual = resnet_layer(inputs=residual, num_filters=stack_filters, activation=None)
            if downsample:
                # 1x1 linear projection of the shortcut so that its
                # dimensions match the downsampled residual branch.
                x = resnet_layer(inputs=x, num_filters=stack_filters, kernel_size=1, strides=block_strides, activation=None, batch_normalization=False)

            x = keras.layers.add([x, residual])
            x = Activation('relu')(x)
            x = Dropout(rate=0.25)(x)

    # Classifier head; no batch normalization follows the last shortcut ReLU.
    x = AveragePooling2D(pool_size=7)(x)
    flattened = Flatten()(x)
    outputs = Dense(num_classes, activation='softmax')(flattened)

    return Model(inputs=inputs, outputs=outputs)

In [None]:
# Build the network from the `input_shape` computed for the backend's image
# data format and the `num_classes` configuration constant, instead of
# repeating the literals (a hard-coded (28, 28, 1) is wrong for channels_first).
model = resnet_v1(input_shape=input_shape, depth=depth, num_classes=num_classes)

In [None]:
# Adam optimizer starting at the schedule's epoch-0 learning rate.
# `learning_rate` is the supported argument name (`lr` is a deprecated alias
# that newer Keras releases reject), and the stray chained assignment to
# `optimizer` has been removed.
opt = Adam(learning_rate=lr_schedule(0))

model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])

# Stop training when validation accuracy has not improved for 20 epochs.
# restore_best_weights=False: the model keeps the weights of the last epoch
# that was run, not those of the best epoch.
earlyStopping = tensorflow.keras.callbacks.EarlyStopping(
    monitor='val_accuracy',
    min_delta=0,
    patience=20,
    verbose=1,
    mode='auto',
    baseline=None,
    restore_best_weights=False
)

# Apply lr_schedule at the start of every epoch.
lr_scheduler = LearningRateScheduler(lr_schedule)

cb = [earlyStopping, lr_scheduler]


In [None]:
# Print the Keras summary of the model.
model.summary()

In [None]:
# Train the network and keep the per-epoch metrics in `history`.
# NOTE(review): batch_size and epochs are hard-coded here and differ from the
# `batch_size` / `epochs` values of the configuration cell - confirm which
# values are intended.
# NOTE(review): the test split doubles as validation data, so early stopping
# is driven by test-set accuracy - confirm this is acceptable.
history = model.fit(
    x=train_images,
    y=train_labels,
    batch_size=16,
    epochs=100,
    shuffle=True,
    validation_data=(test_images, test_labels),
    callbacks=cb,
)