In [1]:
import tensorflow as tf
from tensorflow.keras import datasets, models, layers
from sklearn.model_selection import train_test_split

In [7]:
# A function to return a CNN model based on input hyperparameters
def my_cnn(m, k, n, input_shape, n_output,
           dropout_conv=0.0, dropout_dense=0.0, batch_normalize=False,
           n_conv=2):
    """Build an (uncompiled) Keras Sequential CNN from a few hyperparameters.

    Layer stack (bracketed layers are only added when enabled):

        [BatchNormalization]
        n_conv x ( Conv2D(m, k x k, relu) -> MaxPooling2D(2 x 2) -> [Dropout] )
        Flatten -> Dense(n, relu) -> [Dropout] -> Dense(n_output)

    Args:
        m: Number of filters in every Conv2D layer.
        k: Side length of the square convolution kernel.
        n: Number of units in the hidden Dense layer.
        input_shape: Shape of a single input sample, e.g. ``(32, 32, 3)``.
        n_output: Number of units in the final Dense layer. That layer has no
            activation, so its outputs are linear (presumably logits -- pair
            with a loss using ``from_logits=True``; confirm at compile site).
        dropout_conv: Dropout rate after each conv/pool stage; 0.0 disables it.
        dropout_dense: Dropout rate after the hidden Dense layer; 0.0 disables it.
        batch_normalize: If True, prepend a BatchNormalization layer that is
            applied directly to the inputs.
        n_conv: Number of conv/pool stages (default 2, the previously fixed
            value).

    Returns:
        A ``models.Sequential`` instance; ``compile`` is not called here.
    """
    model = models.Sequential()

    # A Sequential model only needs `input_shape` on its first layer, so hand
    # it to whichever layer ends up first and then stop passing it.
    first_layer_kwargs = {'input_shape': input_shape}

    if batch_normalize:
        # Must use the `input_shape` parameter; there is no `input_size` in scope.
        model.add(layers.BatchNormalization(**first_layer_kwargs))
        first_layer_kwargs = {}

    for _ in range(n_conv):
        model.add(layers.Conv2D(m, (k, k), activation='relu', **first_layer_kwargs))
        first_layer_kwargs = {}
        model.add(layers.MaxPooling2D((2, 2)))
        if dropout_conv != 0.0:
            model.add(layers.Dropout(dropout_conv))

    # Receives `input_shape` only when no earlier layer was added (n_conv == 0
    # and no batch normalization).
    model.add(layers.Flatten(**first_layer_kwargs))
    model.add(layers.Dense(n, activation='relu'))
    if dropout_dense != 0.0:
        model.add(layers.Dropout(dropout_dense))
    model.add(layers.Dense(n_output))

    return model

In [4]:
# Load CIFAR-10 and scale pixel intensities from [0, 255] to [0, 1].
# Cast to float32 before dividing: uint8 / 255.0 would produce float64,
# doubling memory for no benefit since Keras computes in float32 anyway.
VAL_FRACTION = 0.1   # fraction of the training set held out for validation
SPLIT_SEED = 69      # fixed seed so the train/validation split is reproducible

(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0

X_train, X_val, y_train, y_val = train_test_split(
    train_images, train_labels, test_size=VAL_FRACTION, random_state=SPLIT_SEED)

# Sanity check: the split partitions the training set without dropping samples.
assert len(X_train) + len(X_val) == len(train_images)

print(X_train.shape, y_train.shape)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
(45000, 32, 32, 3) (45000, 1)


In [8]:
# Baseline network: 16 filters with 3x3 kernels, a 64-unit hidden layer,
# CIFAR-10 sized inputs (32x32 RGB) and 10 outputs.
my_model = my_cnn(
    m=16,
    k=3,
    n=64,
    input_shape=(32, 32, 3),
    n_output=10,
)
my_model.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_4 (Conv2D)            (None, 30, 30, 16)        448       
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 15, 15, 16)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 13, 13, 16)        2320      
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 6, 6, 16)          0         
_________________________________________________________________
flatten (Flatten)            (None, 576)               0         
_________________________________________________________________
dense (Dense)                (None, 64)                36928     
_________________________________________________________________
dense_1 (Dense)              (None, 10)               