In [None]:
import tensorflow as tf

# tf.test.is_gpu_available() is deprecated; list the visible GPU devices instead.
# An empty list means no GPU is visible to TensorFlow and training runs on the CPU.
gpu_devices = tf.config.list_physical_devices('GPU')
gpu_devices

In [None]:
# Mount Google Drive into the Colab filesystem (Colab-only; asks for authorisation).
# NOTE(review): no visible cell reads from this mount point -- the data and weights
# are downloaded into the working directory -- so this mount may be unnecessary.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
!wget -q --no-check-certificate https://storage.googleapis.com/laurencemoroney-blog.appspot.com/horse-or-human.zip
!unzip -q horse-or-human.zip -d train_data
!wget -q --no-check-certificate https://storage.googleapis.com/laurencemoroney-blog.appspot.com/validation-horse-or-human.zip
!unzip -q validation-horse-or-human.zip -d validation_data
!wget -q --no-check-certificate https://storage.googleapis.com/mledu-datasets/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5

In [None]:
# Local path of the InceptionV3 "notop" weights file fetched by the download cell.
local_weights_file = "./inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5"

In [None]:
from tensorflow.keras import Model, layers
from tensorflow.keras.applications.inception_v3 import InceptionV3

# InceptionV3 convolutional base without its classification head, sized for
# 300x300 RGB images. weights=None builds the architecture with no pretrained
# weights; the weights file downloaded earlier is then loaded explicitly.
pre_trained_model = InceptionV3(
    input_shape=(300, 300, 3),
    include_top=False,
    weights=None,
)
pre_trained_model.load_weights(local_weights_file)

In [None]:
# Freeze every layer of the pretrained base so training only updates the
# layers added on top of it.
for base_layer in pre_trained_model.layers:
    base_layer.trainable = False

In [None]:
# The pretrained network is cut at this layer; its output tensor feeds the new head.
CUT_LAYER_NAME = 'mixed7'
last_layer = pre_trained_model.get_layer(name=CUT_LAYER_NAME)
last_output = last_layer.output

In [None]:
from tensorflow.keras.optimizers import RMSprop

# Classification head on top of the frozen 'mixed7' features:
# flatten -> 1024-unit ReLU layer -> dropout -> single sigmoid unit for the
# binary label produced by the class_mode='binary' iterators.
new_layers = layers.Flatten()(last_output)
new_layers = layers.Dense(1024, activation='relu')(new_layers)
new_layers = layers.Dropout(0.2)(new_layers)   # drops 20% of units during training
new_layers = layers.Dense(1, activation='sigmoid')(new_layers)

model = Model(pre_trained_model.input, new_layers)
# `lr` is a deprecated alias (warned about in tf.keras 2.x, removed in newer
# Keras); `learning_rate` is the supported keyword.
model.compile(optimizer=RMSprop(learning_rate=0.001),
              loss='binary_crossentropy',
              metrics=['acc'])

In [None]:
# Directories the image archives were extracted into (the unzip `-d` targets).
train_folder, valid_folder = './train_data', './validation_data'

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Training images: pixel values scaled by 1/255 plus random geometric augmentation.
train_data_gen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)
# Validation data should not be augmented -- only the same 1/255 scaling.
valid_data_gen = ImageDataGenerator(rescale=1./255)

# Both iterators share batching, label mode, and resize target; 300x300
# matches the input_shape given to InceptionV3.
flow_options = dict(batch_size=20, class_mode='binary', target_size=(300, 300))
train_generator = train_data_gen.flow_from_directory(train_folder, **flow_options)
valid_generator = valid_data_gen.flow_from_directory(valid_folder, **flow_options)

In [None]:
# model.fit_generator() is deprecated (removed in newer Keras); model.fit()
# accepts the directory iterators directly.
# len(iterator) is the number of batches in one pass over that iterator's
# images, so every epoch covers the whole training set and every validation
# run covers the whole validation set. This holds even if batch_size or the
# data change. The former hard-coded 32 / 8 steps fixed each epoch at
# 32*20 training and 8*20 validation images, regardless of how many images
# flow_from_directory actually found.
history = model.fit(
    train_generator,
    validation_data=valid_generator,
    steps_per_epoch=len(train_generator),
    epochs=20,
    validation_steps=len(valid_generator),
    verbose=1)

In [None]:
import matplotlib.pyplot as plt

#-----------------------------------------------------------
# Retrieve the per-epoch training and validation results
# recorded in `history` during training
#-----------------------------------------------------------
metric_log = history.history
# The accuracy entry is named after the metric passed to compile():
# 'acc' for metrics=['acc'], 'accuracy' for metrics=['accuracy'].
acc_key  = 'acc' if 'acc' in metric_log else 'accuracy'
acc      = metric_log[acc_key]
val_acc  = metric_log['val_' + acc_key]
loss     = metric_log['loss']
val_loss = metric_log['val_loss']

epochs   = range(1, len(acc) + 1)  # 1-based epoch numbers for the x-axis


def plot_train_val(ax, x, train_values, val_values, title, ylabel):
    """Plot training vs. validation curves for one metric on `ax`.

    Args:
        ax: matplotlib Axes to draw on.
        x: x-axis values (epoch numbers).
        train_values: per-epoch training values.
        val_values: per-epoch validation values.
        title: axes title.
        ylabel: y-axis label.
    """
    ax.plot(x, train_values, label='Training')
    ax.plot(x, val_values, label='Validation')
    ax.set(title=title, xlabel='Epoch', ylabel=ylabel)
    ax.legend()


#------------------------------------------------
# Plot training and validation accuracy and loss
# per epoch, side by side, with legends and labels
#------------------------------------------------
fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))
plot_train_val(acc_ax,  epochs, acc,  val_acc,  'Training and validation accuracy', 'Accuracy')
plot_train_val(loss_ax, epochs, loss, val_loss, 'Training and validation loss',     'Loss')
plt.show()