# In [None]:
# Import necessary libraries
import tensorflow as tf
from tensorflow import keras # High-level API for building and training models

# ---- Training configuration ----
num_epochs = 5         # full passes over the training set
num_classes = 10       # one output class per digit, 0 through 9
batch_size = 100       # samples consumed per gradient update
learning_rate = 0.001  # optimizer step size

# ---- Data ----
# Keras ships MNIST as (images, labels) pairs for the train and test splits.
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Apply the same preprocessing to both splits:
#   1. cast to float32 and scale pixel values into [0, 1];
#   2. append a trailing single-channel axis, giving (N, 28, 28, 1) for the CNN.
train_images, test_images = (
    (split.astype('float32') / 255.0).reshape(-1, 28, 28, 1)
    for split in (train_images, test_images)
)

# Define a simple CNN with two convolutional stages
class ConvNet(keras.Model):
    """Two-stage convolutional classifier for 28x28 single-channel images.

    Each stage is Conv2D(5x5, stride 1, 'same') -> BatchNormalization -> ReLU
    -> 2x2 max-pool, so the spatial size goes 28 -> 14 -> 7. The 7*7*32
    pooled features are flattened and passed to a Dense head with a softmax
    activation. The model therefore outputs class probabilities, not logits.

    Args:
        num_classes: Width of the final Dense layer (number of classes).
        **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
    """

    def __init__(self, num_classes=10, **kwargs):
        super().__init__(**kwargs)
        # Stage 1: 16 filters, 5x5 kernel, stride 1. 'same' padding keeps the
        # 28x28 size (the PyTorch equivalent is padding=2 for a 5x5 kernel).
        # Pooling then halves it to 14x14.
        self.layer1 = keras.Sequential([
            keras.layers.Conv2D(16, kernel_size=5, strides=1, padding='same', input_shape=(28, 28, 1)),
            keras.layers.BatchNormalization(),
            keras.layers.ReLU(),
            keras.layers.MaxPooling2D(pool_size=(2, 2))
        ])

        # Stage 2: 32 filters, 5x5 kernel, stride 1, 'same' padding.
        # Pooling halves 14x14 to 7x7.
        self.layer2 = keras.Sequential([
            keras.layers.Conv2D(32, kernel_size=5, strides=1, padding='same'),
            keras.layers.BatchNormalization(),
            keras.layers.ReLU(),
            keras.layers.MaxPooling2D(pool_size=(2, 2))
        ])

        # Create the Flatten layer once here rather than inside call(). That
        # way Keras tracks it as a sublayer, and a fresh layer object is not
        # built on every forward pass.
        self.flatten = keras.layers.Flatten()

        # Dense head over the 7*7*32 flattened features. Softmax makes the
        # output a probability distribution over num_classes.
        self.fc = keras.layers.Dense(num_classes, activation='softmax')

    def call(self, x, training=None):
        """Run the forward pass.

        Args:
            x: Input batch. Assumed to be (batch, 28, 28, 1), matching the
                first Conv2D's ``input_shape``.
            training: Forwarded to the conv stages so their BatchNormalization
                layers run in training or inference mode. ``None`` lets Keras
                infer the mode from the outer call.

        Returns:
            A (batch, num_classes) tensor of softmax probabilities.
        """
        out = self.layer1(x, training=training)
        out = self.layer2(out, training=training)
        out = self.flatten(out)
        return self.fc(out)

# Instantiate the model
model = ConvNet(num_classes)

# Loss function and optimizer.
# ConvNet's final Dense layer already applies softmax, so the model outputs
# probabilities, not logits. from_logits must therefore be False. With True,
# Keras would apply softmax a second time and produce a wrong loss and wrong
# gradients.
# The sparse variant is used because the labels are integer class indices;
# they are not one-hot encoded above.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

