In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense

# Load MNIST, then prepare both image splits the same way: cast to float32,
# scale by 1/255, and append a trailing channel axis so each image matches the
# (28, 28, 1) input shape the CNN below is built for.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = (
    np.expand_dims(images.astype('float32') / 255.0, axis=-1)
    for images in (x_train, x_test)
)

# Functional-API CNN: three 3x3 ReLU convolutions (the first two each followed
# by 2x2 max-pooling), then a 64-unit ReLU dense layer and a 10-way softmax.
# conv1/conv2/conv3 are kept as module-level names because the feature-map
# visualization model later in this script is built from them.
input_layer = Input(shape=(28, 28, 1))

conv1 = Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(input_layer)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
conv2 = Conv2D(filters=64, kernel_size=(3, 3), activation='relu')(pool1)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
conv3 = Conv2D(filters=64, kernel_size=(3, 3), activation='relu')(pool2)

flatten = Flatten()(conv3)
dense1 = Dense(units=64, activation='relu')(flatten)
output_layer = Dense(units=10, activation='softmax')(dense1)

model = Model(inputs=input_layer, outputs=output_layer)

# Sparse categorical cross-entropy takes integer class labels directly, so
# y_train needs no one-hot encoding. validation_split=0.2 holds out a fraction
# of the training arrays for per-epoch validation (Keras takes it from the end).
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
model.fit(
    x_train,
    y_train,
    epochs=5,
    batch_size=64,
    validation_split=0.2,
)

# Pick one test image to inspect.
sample_index = 0
sample_input = x_test[sample_index]

# A second model sharing the trained graph, whose outputs are the three
# convolutional layers; predict() therefore returns their activations as a
# list in the order [conv1, conv2, conv3]. The sample gets a leading batch axis
# of size 1 because predict() operates on batches.
visualization_model = Model(inputs=model.input, outputs=[conv1, conv2, conv3])
activations = visualization_model.predict(sample_input[np.newaxis, ...])

# Visualize the feature maps learned by the convolutional layers.
# `activations` comes from `visualization_model.predict`, whose outputs were
# declared as [conv1, conv2, conv3], so it is iterated in that same order to
# pair each layer with its own title. (Iterating `activations[::-1]` against
# these names titled conv3's maps "Convolutional Layer 1" and vice versa.)
layer_names = ['Convolutional Layer 1', 'Convolutional Layer 2', 'Convolutional Layer 3']
rows = 8  # maximum number of grid rows per figure
for layer_name, layer_activation in zip(layer_names, activations):
    num_features = layer_activation.shape[-1]
    # Never use more rows than there are channels. Ceiling division gives every
    # channel a cell even when the count isn't a multiple of the row count; the
    # previous `num_features // 8` yielded 0 columns below 8 channels and
    # silently dropped the remainder otherwise.
    grid_rows = min(rows, num_features)
    cols = -(-num_features // grid_rows)

    # squeeze=False keeps `axes` 2-D even for a single row/column, so the
    # row-major `.flat` walk below always works.
    fig, axes = plt.subplots(grid_rows, cols, figsize=(cols, grid_rows), squeeze=False)
    for feature_idx, ax in enumerate(axes.flat):
        ax.axis('off')
        if feature_idx >= num_features:
            continue  # trailing cell with no channel to show
        # Batch index 0: the activations were computed for a single sample.
        ax.matshow(layer_activation[0, :, :, feature_idx], cmap='viridis')
        ax.set_title(f'Feature {feature_idx + 1}')

    fig.suptitle(layer_name, fontsize=16)
    plt.show()


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
Epoch 1/5
Epoch 2/5