Import necessary libraries

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
import numpy as np

Callback class that plots the convolution kernels after each epoch

In [None]:
class PlotKernelCallback(tf.keras.callbacks.Callback):
    """Keras callback that plots the kernels of every Conv2D layer after each epoch.

    For each Conv2D layer of the model being trained, one figure is drawn with one
    subplot per output filter. Each subplot shows only the 2-D kernel slice that acts
    on a single input channel (``input_channel``, default 0). The remaining input
    channels of multi-channel layers are not shown.

    Args:
        input_channel: Index of the input channel whose kernel slice is drawn.
        max_cols: Maximum number of subplots per figure row.
    """

    def __init__(self, input_channel=0, max_cols=8):
        super().__init__()
        self.input_channel = input_channel
        self.max_cols = max_cols

    def on_epoch_end(self, epoch, logs=None):
        """Draw one kernel figure per Conv2D layer; ``epoch`` is the 0-based index Keras reports."""
        conv_layers = [layer for layer in self.model.layers if isinstance(layer, tf.keras.layers.Conv2D)]
        for i, layer in enumerate(conv_layers):
            # Kernel tensor: the last axis indexes output filters, the third axis input channels.
            weights = layer.get_weights()[0]
            num_kernels = weights.shape[3]

            # Lay the kernels out in a grid of at most max_cols columns.
            num_cols = min(self.max_cols, num_kernels)
            num_rows = (num_kernels + num_cols - 1) // num_cols  # ceiling division

            fig = plt.figure(figsize=(15, num_rows * 3))
            for j in range(num_kernels):
                ax = fig.add_subplot(num_rows, num_cols, j + 1)
                ax.imshow(weights[:, :, self.input_channel, j], cmap='viridis')
                ax.axis('off')
                ax.set_title(f'Kernel {j + 1}')

            fig.suptitle(f'Epoch {epoch + 1}, Layer {i + 1}')
            plt.show()
            # One figure is created per layer per epoch; close it explicitly so figures
            # do not accumulate under backends where show() leaves them open.
            plt.close(fig)

Load and preprocess the MNIST dataset

In [None]:
# MNIST digits: add a trailing channel axis, scale pixels to [0, 1] as float32,
# and one-hot encode the 10 digit classes.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = (
    images.reshape(-1, 28, 28, 1).astype('float32') / 255.0
    for images in (x_train, x_test)
)
y_train, y_test = (to_categorical(labels, 10) for labels in (y_train, y_test))

Build and compile the fully connected (dense) model

In [None]:
# Baseline model: flatten the 28x28x1 image and classify it with two hidden dense layers.
input_layer_fc = Input(shape=(28, 28, 1))
flatten_fc = Flatten()(input_layer_fc)
dense1_fc = Dense(units=128, activation='relu')(flatten_fc)
dense2_fc = Dense(units=64, activation='relu')(dense1_fc)
output_layer_fc = Dense(units=10, activation='softmax')(dense2_fc)

fc_model = tf.keras.Model(inputs=input_layer_fc, outputs=output_layer_fc)
fc_model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

Build and compile the convolutional neural network (CNN)

In [None]:
# Two conv/pool stages followed by a dense classifier head.
input_layer_cnn = Input(shape=(28, 28, 1))
# The conv layers are named explicitly, using the names Keras auto-assigns them in a
# fresh kernel. Auto-generated names are numbered per session. Without explicit names,
# re-running this cell yields 'conv2d_2', 'conv2d_3', ..., and lookups such as
# cnn_model.get_layer('conv2d_1') no longer find a layer of the rebuilt model.
conv1 = Conv2D(32, (3, 3), activation='relu', name='conv2d')(input_layer_cnn)
pool1 = MaxPooling2D((2, 2))(conv1)
conv2 = Conv2D(64, (3, 3), activation='relu', name='conv2d_1')(pool1)
pool2 = MaxPooling2D((2, 2))(conv2)
flatten_cnn = Flatten()(pool2)
dense1_cnn = Dense(128, activation='relu')(flatten_cnn)
output_layer_cnn = Dense(10, activation='softmax')(dense1_cnn)
cnn_model = tf.keras.Model(inputs=input_layer_cnn, outputs=output_layer_cnn)
cnn_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

Parameters for training and plotting

In [None]:
# Number of epochs each model (FC and CNN) is trained for.
epoch_num = 3
# Set to True to draw every CNN kernel at the end of each training epoch.
plot_every_kernel = False

Train both models

In [None]:

# Per-epoch kernel plots for the CNN are opt-in (see plot_every_kernel).
callbacks = [PlotKernelCallback()] if plot_every_kernel else []

fc_train_accuracy_history = []
fc_val_accuracy_history = []

cnn_train_accuracy_history = []
cnn_val_accuracy_history = []

# Train the two models alternately, one epoch per fit() call, and collect one accuracy
# value per epoch. Passing initial_epoch/epochs instead of epochs=1 makes each call
# report the real epoch index to callbacks. With epochs=1, every call reports epoch 0,
# so every kernel plot would be titled "Epoch 1".
# Note: the models are not rebuilt here, so re-running this cell continues training them.
for epoch in range(epoch_num):
    fit_kwargs = dict(
        initial_epoch=epoch,
        epochs=epoch + 1,
        batch_size=32,
        validation_data=(x_test, y_test),
    )
    fc_history = fc_model.fit(x_train, y_train, **fit_kwargs)
    cnn_history = cnn_model.fit(x_train, y_train, callbacks=callbacks, **fit_kwargs)

    # Each fit() call covers exactly one epoch, so each history list holds one entry.
    fc_train_accuracy_history.append(fc_history.history['accuracy'][0])
    cnn_train_accuracy_history.append(cnn_history.history['accuracy'][0])

    fc_val_accuracy_history.append(fc_history.history['val_accuracy'][0])
    cnn_val_accuracy_history.append(cnn_history.history['val_accuracy'][0])

Compare training and validation accuracy

In [None]:
# Side-by-side training and validation accuracy curves for the FC and CNN models.
# The x-values come from each history's own length, so the plot stays valid even if
# epoch_num was edited after training.
accuracy_panels = [
    ('Train', 'Training', fc_train_accuracy_history, cnn_train_accuracy_history),
    ('Validation', 'Validation', fc_val_accuracy_history, cnn_val_accuracy_history),
]

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for ax, (label_word, title_word, fc_values, cnn_values) in zip(axes, accuracy_panels):
    ax.plot(range(1, len(fc_values) + 1), fc_values, label=f'FC {label_word} Accuracy')
    ax.plot(range(1, len(cnn_values) + 1), cnn_values, label=f'CNN {label_word} Accuracy')
    # Keep the zoomed-in 0.92-1 view, but lower the bottom edge when a point would
    # otherwise be clipped (e.g. an early-epoch accuracy below 0.92).
    lowest = min([*fc_values, *cnn_values], default=0.92)
    ax.set_ylim(min(0.92, lowest - 0.005), 1)
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Accuracy')
    ax.set_title(f'{title_word} Accuracy with FC and CNN')
    ax.legend()

fig.tight_layout()
plt.show()

Parameters for the visualization plots

In [None]:
sample_index = 100  # Index into x_train of the image to visualize; change as desired

# Conv2D layer to inspect. The name is resolved from cnn_model rather than hard-coding
# Keras' auto-generated 'conv2d_1'. Auto-generated names are numbered per kernel
# session, so re-running the model-building cell renames the layers and a hard-coded
# name stops matching. In a fresh kernel, 'conv2d_1' is the second conv layer (index 1).
# Change the index, or assign a literal layer name, as desired.
cnn_conv_layer_names = [layer.name for layer in cnn_model.layers if isinstance(layer, Conv2D)]
layer_name = cnn_conv_layer_names[1]

Apply the selected layer's kernels to a sample image

In [None]:
def visualize_conv_layer_output(model, image, layer_name, num_cols):
    """Plot every output channel (feature map) of one layer of ``model`` for a single image.

    Args:
        model: Keras model containing a layer named ``layer_name``.
        image: One input sample without a batch axis; a batch axis is added here.
        layer_name: Name of the layer whose output is visualized.
        num_cols: Number of subplot columns in the grid.
    """
    # Truncated model whose output is the requested layer's output.
    feature_extractor = tf.keras.models.Model(
        inputs=model.inputs,
        outputs=model.get_layer(layer_name).output,
    )
    batch = np.expand_dims(image, axis=0)
    feature_maps = feature_extractor.predict(batch)

    channel_count = feature_maps.shape[-1]
    row_count = (channel_count + num_cols - 1) // num_cols  # ceiling division

    fig = plt.figure(figsize=(num_cols * 2, row_count * 2))
    for idx in range(channel_count):
        ax = fig.add_subplot(row_count, num_cols, idx + 1)
        ax.imshow(feature_maps[0, :, :, idx], cmap='viridis')
        ax.axis('off')
        ax.set_title(f'Filter {idx + 1}')

    plt.show()


# Number of subplot columns in the feature-map grid; change as desired.
num_cols = 8

visualize_conv_layer_output(
    model=cnn_model,
    image=x_train[sample_index],
    layer_name=layer_name,
    num_cols=num_cols,
)

In [None]:
def visualize_kernel_and_result(model, image, layer_name, filter_index=None):
    """Show conv kernels of ``layer_name`` next to the feature maps they produce for ``image``.

    One figure is drawn per filter. The left panel shows the filter's kernel slice for
    input channel 0 only. The right panel shows that filter's output map for ``image``.

    Args:
        model: Keras model containing a convolutional layer named ``layer_name``.
        image: One input sample without a batch axis; a batch axis is added here.
        layer_name: Name of the convolutional layer to inspect.
        filter_index: Index of the single filter to show, or ``None`` (default) to
            show every filter of the layer.
    """
    layer = model.get_layer(layer_name)
    sub_model = tf.keras.models.Model(inputs=model.inputs, outputs=layer.output)
    activations = sub_model.predict(np.expand_dims(image, axis=0))

    # Loop-invariant: fetch the kernel tensor once. The last axis indexes filters and
    # the third axis indexes input channels.
    kernel_weights = layer.get_weights()[0]

    if filter_index is None:
        filter_indices = range(activations.shape[-1])
    else:
        filter_indices = [filter_index]

    for idx in filter_indices:
        # Create a new figure per filter so every figure gets the intended 8x4 size.
        fig = plt.figure(figsize=(8, 4))

        ax_kernel = fig.add_subplot(1, 2, 1)
        ax_kernel.imshow(kernel_weights[:, :, 0, idx], cmap='viridis')
        ax_kernel.axis('off')
        ax_kernel.set_title(f'Kernel {idx + 1}')

        ax_result = fig.add_subplot(1, 2, 2)
        ax_result.imshow(activations[0, :, :, idx], cmap='gray')
        ax_result.axis('off')
        ax_result.set_title('Filtered Image')

        plt.show()
        # Up to one figure per filter is created; release each one explicitly.
        plt.close(fig)

# Kernels of the selected layer next to their outputs on the sample image.
visualize_kernel_and_result(
    model=cnn_model,
    image=x_train[sample_index],
    layer_name=layer_name,
)