In [1]:
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, BatchNormalization, Activation, add, Flatten, AveragePooling2D, concatenate, Dense
from tensorflow.keras.models import Model

from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

width = 32
height = 32

In [2]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [3]:
num_classes = 10
y_train = tf.keras.utils.to_categorical(y_train)
y_test = tf.keras.utils.to_categorical(y_test)

In [4]:
validation_images, validation_labels = x_train[:500], y_train[:500]
train_images, validation = x_train[500:], y_train[500:]

In [5]:
train_images.shape

(49500, 32, 32, 3)

In [21]:
input_shape = x_train[0].shape
inputs = Input(shape=input_shape)

x = Conv2D(64, kernel_size=(7,7), strides=2, padding='same', activation='relu')(inputs)
x = BatchNormalization()(x)
x = Conv2D(192, kernel_size=(3,3), padding='same', activation='relu')(x)
x = BatchNormalization()(x)

x = inception(x, [64, 128, 32, 32])
print(x)
#x = inception(x, [128, 192, 96, 64])

None


In [22]:
def inception(x, filters):
    pre_layer = x
    f1, f2, f3, f4 = filters

    conv1 = Conv2D(f1, kernel_size=(1,1), padding='same', activation='relu')(pre_layer)

    conv2 = Conv2D(f4, kernel_size=(1,1), padding='same', activation='relu')(pre_layer)
    conv2 = Conv2D(f2, kernel_size=(3,3), padding='same', activation='relu')(conv2)

    conv3 = Conv2D(f4, kernel_size=(1,1), padding='same', activation='relu')(pre_layer)
    conv3 = Conv2D(f3, kernel_size=(5,5), padding='same', activation='relu')(conv3)

    max_pool = MaxPooling2D(pool_size=(3,3), strides=1, padding='same')(pre_layer)
    max_pool = Conv2D(f4, kernel_size=(1,1), padding='same')(max_pool)

    concat = concatenate([conv1, conv2, conv3, max_pool], axis=-1)
    return concat

In [27]:
input_shape = x_train[0].shape
inputs = Input(shape=input_shape)

x = Conv2D(64, kernel_size=(7,7), strides=2, padding='same', activation='relu')(inputs)
x = BatchNormalization()(x)
x = Conv2D(192, kernel_size=(3,3), padding='same', activation='relu')(x)
x = BatchNormalization()(x)

x = inception(x, [64, 128, 32, 32])
x = inception(x, [128, 192, 96, 64])
x = MaxPooling2D(pool_size=(3,3), strides=2, padding='same')(x)

x = inception(x, [192, 208, 48, 64])
aux1 = AveragePooling2D(pool_size=(5,5), strides=3, padding='valid')(x)
aux1 = Conv2D(128, kernel_size=(1,1), padding='same', activation='relu')(aux1)
Flatten()(aux1)
aux1 = Dense(512, activation='relu')(aux1)
aux1 = Dense(10, activation='softmax')(aux1)

x = inception(x,[160,224,64,64])
x = inception(x,[128,256,64,64])
x = inception(x, [112, 288, 64, 64])

aux2 = AveragePooling2D(pool_size=(5,5), strides=3, padding='valid')(x)
aux2 = Conv2D(128, kernel_size=(1,1), padding='same', activation='relu')(aux2)
Flatten()(aux2)
aux2 = Dense(512, activation='relu')(aux1)
aux2 = Dense(10, activation='softmax')(aux1)

x = inception(x, [256, 320, 128, 128])
x = inception(x, [256, 320, 128, 128])
x = inception(x, [384, 384, 128, 128])

x = AveragePooling2D(pool_size=(4,4), padding='valid')(x)
x = Dropout(0.4)(x)
x = Flatten()(x)

outputs = Dense(10, activation='softmax')(x)

model = Model(inputs = inputs, outputs = [aux1, aux2, outputs])
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_17 (InputLayer)          [(None, 32, 32, 3)]  0           []                               
                                                                                                  
 conv2d_206 (Conv2D)            (None, 16, 16, 64)   9472        ['input_17[0][0]']               
                                                                                                  
 batch_normalization_28 (BatchN  (None, 16, 16, 64)  256         ['conv2d_206[0][0]']             
 ormalization)                                                                                    
                                                                                                  
 conv2d_207 (Conv2D)            (None, 16, 16, 192)  110784      ['batch_normalization_28[0][0

 )                                                                                                
                                                                                                  
 conv2d_227 (Conv2D)            (None, 8, 8, 160)    82080       ['concatenate_30[0][0]']         
                                                                                                  
 conv2d_229 (Conv2D)            (None, 8, 8, 224)    129248      ['conv2d_228[0][0]']             
                                                                                                  
 conv2d_231 (Conv2D)            (None, 8, 8, 64)     102464      ['conv2d_230[0][0]']             
                                                                                                  
 conv2d_232 (Conv2D)            (None, 8, 8, 64)     32832       ['max_pooling2d_36[0][0]']       
                                                                                                  
 concatena

                                                                                                  
 conv2d_254 (Conv2D)            (None, 8, 8, 320)    368960      ['conv2d_253[0][0]']             
                                                                                                  
 conv2d_256 (Conv2D)            (None, 8, 8, 128)    409728      ['conv2d_255[0][0]']             
                                                                                                  
 conv2d_257 (Conv2D)            (None, 8, 8, 128)    106624      ['max_pooling2d_40[0][0]']       
                                                                                                  
 concatenate_35 (Concatenate)   (None, 8, 8, 832)    0           ['conv2d_252[0][0]',             
                                                                  'conv2d_254[0][0]',             
                                                                  'conv2d_256[0][0]',             
          