In [1]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

from keras.datasets import cifar10
from keras.applications import vgg16
from keras.utils import to_categorical
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D

Using TensorFlow backend.


In [0]:
(x_train, y_train),(x_test, y_test) = cifar10.load_data()

In [0]:
# Upscale x_train and y_train to fit minimum input_shape of VGG16 model that is (48,48)
x_train_resized = np.zeros(shape = (x_train.shape[0], 48, 48, 3))
x_test_resized = np.zeros(shape = (x_test.shape[0], 48, 48, 3))

for i, img in enumerate(x_train):
  resized_img = cv2.resize(img, dsize=(48, 48), interpolation=cv2.INTER_CUBIC)
  x_train_resized[i] = resized_img
  
for i, img in enumerate(x_test):
  resized_img = cv2.resize(img, dsize=(48, 48), interpolation=cv2.INTER_CUBIC)
  x_test_resized[i] = resized_img


In [0]:
x_train = x_train_resized
x_test  = x_test_resized

x_train = x_train.astype('float32')
x_test  = x_test.astype('float32')

x_train = vgg16.preprocess_input(x_train)
x_test  = vgg16.preprocess_input(x_test)

y_train = to_categorical(y_train)
y_test  = to_categorical(y_test)

In [0]:
base_model = vgg16.VGG16(include_top = False, input_shape = (48, 48, 3), weights = 'imagenet')

x = base_model.output
x = GlobalAveragePooling2D()(x)

x = Dense(1024, activation = 'relu')(x)

prediction = Dense(10, activation = 'softmax')(x)

model = Model(inputs=base_model.input, outputs=prediction)

In [0]:
for layer in base_model.layers:
  layer.trainable = False

In [0]:
model.compile(optimizer = 'adam', loss='categorical_crossentropy', metrics = ['accuracy'])


In [8]:
model.fit(x_train, y_train, batch_size = 128, epochs = 10, shuffle = True, validation_split = 0.2)

Train on 40000 samples, validate on 10000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7f4fb9394f60>

In [9]:
model.evaluate(x_test, y_test)



[1.6849238495826722, 0.7087]