# Transfer Learning using Keras Functional API in TensorFlow 2.0

Create the model architecture using Transfer Learning method in TensorFlow 2.0 Keras API.

Note: We have extracted features from pretrained model mobile-net and then added custom classification layers.

In [3]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Dense
from tensorflow.keras.applications import MobileNet


def create_model(width, height, depth, classes):

	# initialize the input shape
	inputShape = (height, width, depth)

	#extract features from pretrained model mobile-net
	base_model = MobileNet(weights='imagenet',
						   include_top=False, input_shape=inputShape)
	x = base_model.output

	# define classification layers
	x = Flatten()(x)
	x = Dense(classes)(x)
	preds = Activation("softmax")(x)

	# create the model
	model = Model(inputs=base_model.input, outputs=preds)

	# return the constructed model
	return model

In [13]:
import tensorflow as tf
import datetime

# Clear any logs from previous runs
!rm -rf ./logs/ 

Intilalize Model Parameters and the dataset here. 

Note: For easy to run dummy example, I have used EPOCHS=1. Use EPOCHS=100 for quality results. 

In [5]:
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report



# initialize the learning rate, batch size and epochs to train the model
LR = 1e-2
BATCH_SIZE = 128

#CHANGE IT TO 100.
EPOCHS = 1

# initialize the number of classes
num_classes = 10

# load the CIFAR-10 dataset
print("[LOGGING] loading CIFAR-10 dataset...")
training_tuple, testing_tuple = cifar10.load_data()

((trainX, trainY), (testX, testY)) = training_tuple, testing_tuple

# normalize the data into the range [0, 1]
trainX = trainX.astype("float32") / 255.0
testX = testX.astype("float32") / 255.0

# convert the labels from integers to vectors
label_bin = LabelBinarizer()
trainY = label_bin.fit_transform(trainY)
testY = label_bin.transform(testY)

[LOGGING] loading CIFAR-10 dataset...


Call the ImageDataGenerator for data augmentation on CIFAR-10 image dataset

In [7]:
# declare the image generator for data augmentation
data_aug = ImageDataGenerator(rotation_range=15,
                               width_shift_range=0.1,
                               height_shift_range=0.1,
                               shear_range=0.01,
                               zoom_range=[0.9, 1.25],
                               horizontal_flip=True,
                               vertical_flip=False,
                               fill_mode='reflect',
                               data_format='channels_last',
                               brightness_range=[0.5, 1.5])

In [19]:
# Call model created using Keras Functional API in TensorFlow 2.0
print("[LOGGING] using transfer learning & functional model api...")
model = create_model(32, 32, 3, num_classes)

# Display the model's architecture
model.summary()

[LOGGING] using transfer learning & functional model api...




_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         (None, 32, 32, 3)         0         
_________________________________________________________________
conv1_pad (ZeroPadding2D)    (None, 33, 33, 3)         0         
_________________________________________________________________
conv1 (Conv2D)               (None, 16, 16, 32)        864       
_________________________________________________________________
conv1_bn (BatchNormalization (None, 16, 16, 32)        128       
_________________________________________________________________
conv1_relu (ReLU)            (None, 16, 16, 32)        0         
_________________________________________________________________
conv_dw_1 (DepthwiseConv2D)  (None, 16, 16, 32)        288       
_________________________________________________________________
conv_dw_1_bn (BatchNormaliza (None, 16, 16, 32)        128       
__________

In [20]:
# initialize the optimizer compile the model and
opt = SGD(lr=LR, momentum=0.9, decay=LR / EPOCHS)
print("[LOGGING] training network...")
model.compile(loss="categorical_crossentropy", optimizer=opt,
	metrics=["accuracy"])

log_dir="logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)


[LOGGING] training network...


Using TensorBoard with Keras Model

In [16]:
# train the network
model_history = model.fit_generator(
	data_aug.flow(trainX, trainY, batch_size=BATCH_SIZE),
	validation_data=(testX, testY),
	steps_per_epoch=trainX.shape[0] // BATCH_SIZE,
	epochs=EPOCHS,
    callbacks=[tensorboard_callback],
	verbose=1)



In [21]:
# Save the entire model to a HDF5 file.
model.save('CIFAR_model.h5') 

See all the metrics and results on Tensorboard loaded in notebook using TensorFlow 2.0

In [17]:
%tensorboard --logdir logs/fit

Evaluation of the deep neural model , we have just trained.

Note: These results are dummy using EPOCHS=1. Use EPOCHS=100 for quality results. 

In [10]:
# evaluate the network
print("[LOGGING] evaluating network...")
predictions = model.predict(testX, batch_size=BATCH_SIZE)
print(classification_report(testY.argmax(axis=1),
	predictions.argmax(axis=1)))

[LOGGING] evaluating network...
              precision    recall  f1-score   support

           0       0.00      0.00      0.00      1000
           1       0.00      0.00      0.00      1000
           2       0.00      0.00      0.00      1000
           3       0.00      0.00      0.00      1000
           4       0.10      1.00      0.18      1000
           5       0.00      0.00      0.00      1000
           6       0.00      0.00      0.00      1000
           7       0.00      0.00      0.00      1000
           8       0.00      0.00      0.00      1000
           9       0.00      0.00      0.00      1000

   micro avg       0.10      0.10      0.10     10000
   macro avg       0.01      0.10      0.02     10000
weighted avg       0.01      0.10      0.02     10000



  'precision', 'predicted', average, warn_for)
