In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras 
from keras.models import Model, Sequential
from keras.layers import Dense, Conv2D, Flatten, Dropout, BatchNormalization, Add, AveragePooling2D, Input, ReLU
from keras.utils import to_categorical
from tensorflow import Tensor
from keras.callbacks import ModelCheckpoint

In [2]:
# CIFAR-10 comes back as a (train, test) pair; unpack each into images and labels.
train_split, test_split = tf.keras.datasets.cifar10.load_data()
train_x, train_y = train_split
test_x, test_y = test_split

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [3]:
# Inspect the training-image array layout (displayed as the cell output).
np.shape(train_x)

(50000, 32, 32, 3)

In [4]:
# The images are already (N, height, width, channels), which is the layout the
# model's Input(shape=(32, 32, 3)) expects, so no reshape is needed.
# Check it explicitly instead of reshaping each array to its own shape.
assert train_x.shape[1:] == (32, 32, 3), train_x.shape
assert test_x.shape[1:] == (32, 32, 3), test_x.shape

In [5]:
# Scale pixels to [0, 1] and one-hot encode the labels.
# - float32 rather than astype('float') (= float64): Keras computes in float32,
#   and float64 doubles memory (~1.2 GB for the 50000x32x32x3 training images).
# - The integer-dtype guard keeps this cell idempotent: re-running it would
#   otherwise divide by 255 a second time and one-hot encode the already
#   one-hot labels (turning (N, 10) into (N, 10, 2)).
if np.issubdtype(train_x.dtype, np.integer):
  train_x = train_x.astype('float32') / 255
  test_x = test_x.astype('float32') / 255
  # Explicit num_classes so the encoding width does not depend on which labels
  # happen to be present in the array.
  train_y = to_categorical(train_y, num_classes=10)
  test_y = to_categorical(test_y, num_classes=10)

In [6]:
def relu_bn(inputs:Tensor) -> Tensor:
  """Apply a fresh ReLU layer, then a fresh BatchNormalization layer, to `inputs`.

  Note the order: the activation runs *before* the normalization.

  Args:
    inputs: symbolic Keras tensor.

  Returns:
    The output tensor of the BatchNormalization layer.
  """
  result = inputs
  # Layers are instantiated left to right (ReLU first) and applied in that order.
  for layer in (ReLU(), BatchNormalization()):
    result = layer(result)
  return result

In [7]:
def residual_block(x:Tensor, filters:int, downsampling:bool, kernel_size:int=3) -> Tensor:
  """Two-convolution residual block returning relu_bn(shortcut(x) + F(x)).

  Main path F: Conv2D(filters, kernel_size, stride 2 if `downsampling` else 1)
  -> relu_bn -> Conv2D(filters, kernel_size, stride 1), all with 'same' padding.
  Shortcut: `x` unchanged, or a 1x1 Conv2D projection with the main path's
  stride whenever the identity could not be added to F(x). That is the case
  when downsampling, or when `x` does not already have `filters` channels.

  Args:
    x: input feature map. Assumes channels-last (the last axis holds channels);
      TODO confirm if a channels_first data format is ever used.
    filters: number of output channels of both convolutions.
    downsampling: if True, the block halves the spatial size via stride 2.
    kernel_size: kernel size of the two main-path convolutions.

  Returns:
    The block's output tensor with `filters` channels.
  """
  strides = 2 if downsampling else 1
  y = Conv2D(filters=filters,
             kernel_size=kernel_size,
             strides=strides,
             padding='same'
             )(x)
  y = relu_bn(y)
  y = Conv2D(filters=filters,
             kernel_size=kernel_size,
             strides=1,
             padding='same'
             )(y)
  # Project the shortcut whenever its shape would not match y. Without this, a
  # channel change without downsampling makes Add() fail on mismatched shapes.
  if downsampling or x.shape[-1] != filters:
    x = Conv2D(kernel_size=1,
               strides=strides,
               filters=filters,
               padding='same')(x)
  out = tf.keras.layers.Add()([x, y])
  out = relu_bn(out)
  return out

In [8]:
def create_res_net(num_filters:int=64,
                   num_blocks_list=(2, 5, 5, 2),
                   learning_rate:float=0.01,
                   input_shape=(32, 32, 3),
                   num_classes:int=10):
  """Build and compile a ResNet-style image classifier (CIFAR-10 sized by default).

  Architecture:
  - BatchNormalization on the raw input, then a 3x3 Conv2D with `num_filters`
    filters, then relu_bn.
  - One stage per entry of `num_blocks_list`, each made of that many
    `residual_block`s. Every stage after the first starts with a stride-2
    (downsampling) block, and the filter count doubles after each stage
    (64, 128, 256, 512 with the defaults).
  - The final feature map is average-pooled over its whole remaining spatial
    extent, flattened, and fed to a softmax Dense(num_classes).

  Args:
    num_filters: filters in the stem convolution and the first stage.
    num_blocks_list: number of residual blocks in each stage.
    learning_rate: Adam learning rate. NOTE(review): 0.01 is 10x Keras' Adam
      default of 1e-3. It is kept as the default for backward compatibility;
      consider tuning.
    input_shape: (height, width, channels) of one input image. Spatial dims
      must be static, since the pooling size is derived from them.
    num_classes: number of output classes (softmax units).

  Returns:
    A Keras Model compiled with categorical cross-entropy and an accuracy
    metric, so labels must be one-hot encoded.
  """
  inputs = Input(shape=input_shape)
  t = BatchNormalization()(inputs)
  t = Conv2D(kernel_size=3,
             filters=num_filters,
             padding='same',
             strides=1)(t)
  t = relu_bn(t)

  for stage, num_blocks in enumerate(num_blocks_list):
    for block in range(num_blocks):
      t = residual_block(t,
                         downsampling=(block == 0 and stage != 0),
                         kernel_size=3,
                         filters=num_filters)
    num_filters *= 2

  # Pool over the full remaining spatial extent. With the defaults this is 4x4
  # (32 -> 16 -> 8 -> 4 after three stride-2 stages), the same as the former
  # AveragePooling2D(4), but it stays global for other layouts too.
  t = AveragePooling2D(pool_size=tuple(t.shape[1:3]))(t)
  t = Flatten()(t)
  outputs = Dense(num_classes, activation='softmax')(t)
  model = Model(inputs, outputs)

  model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate),
                loss = 'categorical_crossentropy',
                metrics = ['accuracy'])

  return model

In [9]:
# Build the ResNet with its default configuration and print the layer table.
model = create_res_net()
model.summary(line_length=None)

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 32, 32, 3)    12          input_1[0][0]                    
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 32, 32, 64)   1792        batch_normalization[0][0]        
__________________________________________________________________________________________________
re_lu (ReLU)                    (None, 32, 32, 64)   0           conv2d[0][0]                     
_______________________________________________________________________________________

In [10]:
# Save the model to myfile.h5 each time val_loss reaches a new best.
# The callback only takes effect if it is passed to fit() via `callbacks`.
checkpoint = ModelCheckpoint('myfile.h5', monitor='val_loss', save_best_only=True)
# NOTE(review): validation_data is the CIFAR-10 *test* split. Selecting the best
# checkpoint on it leaks test information into model selection. Consider
# holding out part of the training set (e.g. validation_split) instead.
history = model.fit(train_x, train_y,
                    epochs=20,
                    batch_size=500,
                    validation_data=(test_x, test_y),
                    callbacks=[checkpoint])

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<tensorflow.python.keras.callbacks.History at 0x7f0d2e5a5fd0>