In [0]:
from keras.datasets import cifar10, cifar100
from keras.models import Model
from keras.optimizers import RMSprop, Adam
from keras import layers
from keras.utils import to_categorical

In [0]:
def residual_block(x, channels, strides=(1,1), first_block=False):
  """Bottleneck residual unit: 1x1 -> 3x3 -> 1x1 convolutions plus a skip path.

  Args:
    x: input tensor. Batch normalization is applied on axis 3, so the
      channel axis is assumed to be last -- TODO confirm against callers.
    channels: indexable with three filter counts, used in order by the
      first 1x1, the 3x3 and the last 1x1 convolution.
    strides: strides of the first 1x1 convolution and, when `first_block`
      is set, of the shortcut projection.
    first_block: if True, the skip path is passed through its own 1x1
      convolution (channels[2] filters, `strides`) and batch normalization;
      if False, the unmodified input is added to the main path.

  Returns:
    Tensor obtained by adding main path and skip path and applying ReLU.
  """
  def conv_bn(tensor, filters, kernel_size, conv_strides, padding):
    # Convolution immediately followed by batch normalization over axis 3.
    tensor = layers.Conv2D(filters, kernel_size, strides=conv_strides, padding=padding)(tensor)
    return layers.BatchNormalization(axis=3)(tensor)

  skip = x

  # Main path: only the first convolution uses `strides`; the other two keep
  # the spatial size (3x3 is 'same'-padded, 1x1 needs no padding).
  main = conv_bn(x, channels[0], (1,1), strides, 'valid')
  main = layers.Activation('relu')(main)

  main = conv_bn(main, channels[1], (3,3), (1,1), 'same')
  main = layers.Activation('relu')(main)

  main = conv_bn(main, channels[2], (1,1), (1,1), 'valid')

  # Skip path: projected only on request. With first_block=False the input is
  # added as-is, so its shape has to agree with the main path output already.
  if first_block:
    skip = conv_bn(skip, channels[2], (1,1), strides, 'same')

  merged = layers.Add()([main, skip])
  return layers.Activation('relu')(merged)

In [0]:
def ResNet(input_shape, num_classes):
  """Build a deep bottleneck-ResNet style image classifier (uncompiled).

  The input is zero-padded up to 224x224, passed through a 7x7/2 convolution
  stem and a 3x3/2 max pooling, six stages of `residual_block`s, a 2x2
  average pooling and a three-layer dense head ending in a softmax.

  Args:
    input_shape: (height, width, channels) of a single input image.
    num_classes: number of units of the final softmax layer.

  Returns:
    A keras `Model` mapping an image batch to class probabilities.
  """
  target_size = 224

  # ZeroPadding2D takes per-side amounts: a single int per axis is applied to
  # BOTH sides of that axis. Split the missing rows/columns between the two
  # sides (extra pixel on the bottom/right when the deficit is odd) so the
  # padded image is exactly target_size in each dimension. Inputs that are
  # already at least target_size along an axis are left unpadded on that axis.
  missing_rows = max(0, target_size - input_shape[0])
  missing_cols = max(0, target_size - input_shape[1])
  padding = ((missing_rows // 2, missing_rows - missing_rows // 2),
             (missing_cols // 2, missing_cols - missing_cols // 2))

  x_input = layers.Input(input_shape)

  # Shape comments below are for inputs no larger than 224x224.
  x = layers.ZeroPadding2D(padding)(x_input) # output: 224,224,C

  x = layers.Conv2D(64, (7,7), strides=(2,2), padding='same')(x) # output: 112,112,64
  x = layers.BatchNormalization(axis=3)(x)
  x = layers.Activation('relu')(x)

  x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x) # output: 56,56,64

  # (bottleneck channels, number of blocks, strides of the first block).
  # The first block of every stage projects the shortcut (first_block=True);
  # the remaining blocks use the identity shortcut with default strides.
  stages = [
      ([64, 64, 256], 3, (1,1)),       # output: 56,56,256
      ([128, 128, 512], 4, (2,2)),     # output: 28,28,512
      ([256, 256, 1024], 6, (2,2)),    # output: 14,14,1024
      ([512, 512, 2048], 8, (2,2)),    # output: 7,7,2048
      ([512, 512, 2048], 10, (2,2)),   # output: 4,4,2048
      ([1024, 1024, 4096], 3, (2,2)),  # output: 2,2,4096
  ]
  for channels, num_blocks, strides in stages:
    x = residual_block(x, channels=channels, strides=strides, first_block=True)
    for _ in range(num_blocks - 1):
      x = residual_block(x, channels=channels)

  x = layers.AveragePooling2D(pool_size=(2,2))(x) # output: 1,1,4096

  x = layers.Flatten()(x)
  x = layers.Dense(512, activation='relu')(x)
  x = layers.Dense(512, activation='relu')(x)
  x = layers.Dense(num_classes, activation='softmax')(x)

  model = Model(inputs=[x_input], outputs=[x])

  return model

In [0]:
# Load the train/test split; swap in the commented line to use CIFAR-100.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# (x_train, y_train), (x_test, y_test) = cifar100.load_data()

# Rescale the images by 1/255. Cast to float32 first: dividing the integer
# image arrays directly yields float64, which doubles the memory footprint
# of the dataset for no benefit.
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)

# One-hot encode the labels. The number of classes is inferred from the
# largest label present in each array.
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

In [0]:
# Build the network for 32x32 RGB images and 10 output classes.
model = ResNet(input_shape=(32,32,3), num_classes=10)

# The learning rate is passed positionally: it is Adam's first parameter,
# which older Keras releases name `lr` and newer ones `learning_rate`
# (the `lr` keyword is rejected by recent versions).
model.compile(optimizer=Adam(0.000001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Hold out the last 10% of the training arrays for validation.
model.fit(x_train, y_train, epochs=100, batch_size=8, validation_split=0.1)