In [None]:
import numpy as np
import pandas as pd
from tensorflow.keras.layers import Dense, Dropout, Input
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.layers import Flatten, concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import cifar10
#from keras.datasets import cifar100 #Replace use
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers import RMSprop

# load CIFAR dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
#(x_train, y_train), (x_test, y_test) = cifar100.load_data() #Replace use

num_labels = len(np.unique(y_train))
# Pass num_classes explicitly so the train and test one-hot matrices always have
# exactly num_labels columns, matching the Dense(num_labels) output layer built
# in modl(). Without it, each split's width is inferred from its own max label.
y_train = to_categorical(y_train, num_classes=num_labels)
y_test = to_categorical(y_test, num_classes=num_labels)

# modl() builds Input(shape=(image_size, image_size, 3)), so images must be
# square RGB. Check that instead of reshaping: a reshape is a no-op for CIFAR
# and would silently scramble pixels for any other layout.
image_size = x_train.shape[1]
expected_image_shape = (image_size, image_size, 3)
assert x_train.shape[1:] == expected_image_shape, x_train.shape
assert x_test.shape[1:] == expected_image_shape, x_test.shape

# scale uint8 pixels to [0, 1]
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

def _conv_branch(inputs, noconvlyrs, dr, dilation_rate):
    """Stack `noconvlyrs` Conv2D -> Dropout -> MaxPooling2D groups on `inputs`.

    Filters start at 32 and double after each group; every Conv2D uses a 3x3
    kernel, 'same' padding, ReLU and the given `dilation_rate`.
    Returns the output tensor of the last group (or `inputs` if noconvlyrs == 0).
    """
    x = inputs
    filters = 32
    for _ in range(noconvlyrs):
        x = Conv2D(filters=filters, kernel_size=3, padding='same', activation='relu', dilation_rate=dilation_rate)(x)
        x = Dropout(dr)(x)
        x = MaxPooling2D((2, 2))(x)
        filters *= 2
    return x


def _make_optimizer(optm):
    """Return the optimizer named by `optm` ('SGD', 'ADAM' or 'RMSProp').

    Raises:
        ValueError: if `optm` is not one of the supported names.
    """
    if optm == 'SGD':
        return SGD(learning_rate=0.01, momentum=0.9)
    if optm == 'ADAM':
        return Adam(learning_rate=0.01)
    if optm == 'RMSProp':
        return RMSprop(learning_rate=0.01, momentum=0.9)
    raise ValueError(
        f"Unknown optimizer {optm!r}; expected 'SGD', 'ADAM' or 'RMSProp'")


def modl(noconvlyrs, dr, optm):
    """Build and compile the two-branch Y-network classifier.

    Both branches take an (image_size, image_size, 3) input and stack
    `noconvlyrs` conv groups; the left branch uses dilation_rate=1 and the right
    dilation_rate=2. Their outputs are concatenated, flattened, passed through
    dropout and a softmax Dense(num_labels) layer. Reads the module-level
    globals `image_size` and `num_labels`.

    Args:
        noconvlyrs: number of Conv2D/Dropout/MaxPooling2D groups per branch.
        dr: dropout rate used after every conv and before the output layer.
        optm: optimizer name, one of 'SGD', 'ADAM', 'RMSProp'.

    Returns:
        A compiled `Model` taking `[left_input, right_input]`.

    Raises:
        ValueError: if `optm` is not a supported optimizer name.
    """
    # Resolve the optimizer first so a typo fails fast, before any layers are
    # built or the summary is printed.
    opt = _make_optimizer(optm)

    # left branch of Y network
    left_inputs = Input(shape=(image_size, image_size, 3))
    x = _conv_branch(left_inputs, noconvlyrs, dr, dilation_rate=1)

    # right branch of Y network
    right_inputs = Input(shape=(image_size, image_size, 3))
    y = _conv_branch(right_inputs, noconvlyrs, dr, dilation_rate=2)

    y = concatenate([x, y])

    y = Flatten()(y)
    y = Dropout(dr)(y)
    outputs = Dense(num_labels, activation='softmax')(y)

    model = Model([left_inputs, right_inputs], outputs)

    model.summary()

    model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model

import math

BATCH_SIZE = 64


def dual_input_batches(generator, x, y, batch_size):
    """Yield `([x_aug, x_aug], y)` batches for the two-input Y-network.

    `ImageDataGenerator.flow()` treats a list `[a, b]` as (images, extra data):
    only `a` is augmented and `b` is passed through unmodified. So
    `flow([x_train, x_train], ...)` would feed the left branch augmented images
    and the right branch the originals. Augmenting once and reusing that batch
    for both inputs keeps the two branches looking at the same image.
    The underlying iterator never terminates, so `fit` needs `steps_per_epoch`.
    """
    for x_batch, y_batch in generator.flow(x, y, batch_size=batch_size):
        yield [x_batch, x_batch], y_batch


datagen = ImageDataGenerator(width_shift_range=0.1, height_shift_range=0.1, horizontal_flip=True)
it_train = dual_input_batches(datagen, x_train, y_train, BATCH_SIZE)
# ceil, not floor: the iterator emits a final partial batch each pass, and
# skipping it would make epochs drift against the data's pass boundaries.
steps = math.ceil(x_train.shape[0] / BATCH_SIZE)

model = modl(3, 0.2, 'SGD')
#model = modl(4, 0.2, 'SGD') #Replace model parameters as per the required architecture
history = model.fit(it_train, steps_per_epoch=steps, epochs=10, validation_data=([x_test, x_test], y_test))

score = model.evaluate([x_test, x_test], y_test, batch_size=BATCH_SIZE, verbose=0)
print("\nTest accuracy: %.1f%%" % (100.0 * score[1]))

In [None]:
from matplotlib import pyplot

def observe_plot(history):
    """Plot train vs. validation loss (top) and accuracy (bottom) per epoch.

    Reads `history.history['loss']`, `['val_loss']`, `['accuracy']` and
    `['val_accuracy']`. The validation curves are labelled 'test' because the
    training cell passes the CIFAR test split as `validation_data`.
    """
    # plot loss
    pyplot.subplot(211)
    pyplot.title('Categorical Cross Entropy Loss')
    pyplot.plot(history.history['loss'], color='blue', label='train')
    pyplot.plot(history.history['val_loss'], color='orange', label='test')
    # the `label=` arguments are only shown once a legend is drawn
    pyplot.legend()
    # plot accuracy
    pyplot.subplot(212)
    pyplot.title('Classification Accuracy')
    pyplot.plot(history.history['accuracy'], color='blue', label='train')
    pyplot.plot(history.history['val_accuracy'], color='orange', label='test')
    pyplot.xlabel('Epoch')
    pyplot.legend()
    # keep the top plot's tick labels from colliding with the lower title
    pyplot.tight_layout()

observe_plot(history)

## Feature Map Computation

In [None]:
import matplotlib.pyplot as plt

print('Input image for feature map extraction is as follows:')
# Use this cell's own `plt` import rather than the `pyplot` name that only
# exists because an earlier cell imported it.
plt.imshow(x_test[10])
#plt.imshow(x_test[36]) #for CIFAR-100
plt.axis('off')
plt.show()


In [None]:
# Import Model from tf.keras, as the first cell does. Importing it from
# standalone `keras` rebinds the notebook-wide `Model` name to a possibly
# different Keras implementation than the one the layers were built with.
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.models import Model
from matplotlib import pyplot
from numpy import expand_dims

# Tap every Conv2D layer of the trained Y-network, in model.layers order.
# Selecting by type replaces the hard-coded indices ([2, 3, 8, 9, 14, 15], or
# [..., 20, 21] for 4 conv pairs), so this works for any `noconvlyrs`.
conv_layers = [layer for layer in model.layers if isinstance(layer, Conv2D)]
# Build a separate probe model instead of rebinding `model`, so the trained
# network stays available and this cell can be re-run.
feature_model = Model(inputs=model.inputs, outputs=[layer.output for layer in conv_layers])

ip_img = expand_dims(x_test[10], axis=0)
# the same image feeds both branches, as during training/evaluation
feature_maps = feature_model.predict([ip_img, ip_img])

# Show the first 9 channels of each feature map in a 3x3 grid; every Conv2D
# built by modl() has at least 32 filters.
for layer, fmap in zip(conv_layers, feature_maps):
    fig, axes = pyplot.subplots(3, 3)
    for channel, ax in enumerate(axes.flat):
        ax.set_xticks([])
        ax.set_yticks([])
        ax.imshow(fmap[0, :, :, channel], cmap='gray')
    fig.suptitle(layer.name)
    pyplot.show()