In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    """Small two-block convolutional classifier that returns raw logits.

    Architecture: two ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2, 2)``
    blocks (32 then 64 filters), followed by a two-layer fully-connected head.
    No softmax is applied in ``forward``; callers that want probabilities must
    apply it themselves.

    The defaults reproduce the original fixed configuration: single-channel
    28x28 input, a 128-unit hidden layer and 10 output logits.

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Number of output logits.
        input_size: Height/width of the (square) input images. The 3x3 convs
            with ``padding=1`` keep the spatial size. Each pool floors it by
            half, so the flattened feature vector has
            ``64 * (input_size // 4) ** 2`` elements.
        hidden_features: Width of the hidden fully-connected layer.

    Raises:
        ValueError: If ``input_size`` is smaller than 4, which would leave no
            spatial features after the two pooling stages.
    """

    def __init__(
        self,
        *,
        in_channels: int = 1,
        num_classes: int = 10,
        input_size: int = 28,
        hidden_features: int = 128,
    ) -> None:
        super().__init__()
        if input_size < 4:
            raise ValueError(
                f"input_size must be >= 4 (two 2x2 poolings), got {input_size}"
            )
        # Spatial size after two floor-halving 2x2 max-pools
        # (7 for the default 28).
        pooled = input_size // 4

        # First conv block: in_channels -> 32 filters.
        # [B, C, S, S] -> conv [B, 32, S, S] -> pool [B, 32, S//2, S//2]
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)

        # Second conv block: 32 -> 64 filters.
        # [B, 32, S//2, S//2] -> conv [B, 64, S//2, S//2] -> pool [B, 64, S//4, S//4]
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Fully connected head on the flattened feature map
        # (64 * 7 * 7 = 3136 inputs for the default 28x28).
        self.fc1 = nn.Linear(64 * pooled * pooled, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits.

        Args:
            x: Batch of images, shaped ``[B, in_channels, input_size, input_size]``
               to match the configured ``fc1`` input size.

        Returns:
            Logits of shape ``[B, num_classes]`` (no softmax applied).
        """
        # Convolutional feature extraction.
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))

        # Flatten everything but the batch dimension. Unlike ``x.view``, this
        # also works when the activations are non-contiguous (e.g. a model
        # or input in channels_last memory format).
        x = torch.flatten(x, start_dim=1)

        # Fully connected classifier head.
        x = F.relu(self.fc1(x))
        return self.fc2(x)
