In [None]:
import copy

import torch
import torch.nn as nn
class SimpleCNN(nn.Module):
    """Configurable CNN classifier: conv stages, global average pool, MLP head.

    Each feature stage is ``Conv2d -> [BatchNorm2d] -> activation ->
    [Dropout2d] -> MaxPool2d(2, 2)``, so every stage halves the spatial size.
    ``AdaptiveAvgPool2d((1, 1))`` then collapses whatever spatial extent is
    left, so the classifier head only depends on the last stage's channel
    count, not on the input height/width.

    Args:
        num_classes: Number of output logits.
        num_filters: Output channels of each conv stage.
        filter_sizes: Kernel size of each conv stage; must have the same
            length as ``num_filters``.
        activation_fn: Activation module used as a template. A fresh deep
            copy is inserted at every position, so stateful activations
            (e.g. ``nn.PReLU``) get independent parameters per layer.
        num_neurons_dense: Width of the hidden fully connected layer.
        dropout_prob: Probability for ``Dropout2d`` in each feature stage and
            ``Dropout`` in the head; values ``<= 0`` disable both.
        use_batch_norm: If True, insert ``BatchNorm2d`` after each conv.
        in_channels: Number of channels of the input images (default 3).

    Raises:
        ValueError: If ``num_filters`` and ``filter_sizes`` differ in length.
    """

    def __init__(
        self,
        num_classes: int,
        num_filters: list[int],
        filter_sizes: list[int],
        activation_fn: nn.Module,
        num_neurons_dense: int,
        dropout_prob: float = 0.0,
        use_batch_norm: bool = False,
        *,
        in_channels: int = 3,
    ):
        super().__init__()
        if len(num_filters) != len(filter_sizes):
            # zip() would silently drop the surplus entries and build a
            # shallower network than the caller asked for.
            raise ValueError(
                f"num_filters ({len(num_filters)}) and filter_sizes "
                f"({len(filter_sizes)}) must have the same length"
            )

        layers: list[nn.Module] = []
        channels = in_channels  # running channel count through the stages

        # Build convolutional feature extractor
        for out_ch, k in zip(num_filters, filter_sizes):
            # padding=k//2 preserves H/W for odd k; an even k grows each
            # spatial dim by one before the pooling step.
            layers.append(
                nn.Conv2d(channels, out_ch, kernel_size=k, padding=k // 2)
            )
            if use_batch_norm:
                layers.append(nn.BatchNorm2d(out_ch))
            # Fresh copy per position: reusing one instance would tie the
            # parameters of stateful activations across layers.
            layers.append(copy.deepcopy(activation_fn))
            if dropout_prob > 0.0:
                layers.append(nn.Dropout2d(dropout_prob))
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            channels = out_ch

        self.features = nn.Sequential(*layers)
        # Global pooling to flatten spatial dims to 1x1
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # Classification head. nn.Identity stands in for disabled dropout so
        # the layer indices (and thus state_dict keys) do not depend on it.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels, num_neurons_dense),
            copy.deepcopy(activation_fn),
            nn.Dropout(dropout_prob) if dropout_prob > 0.0 else nn.Identity(),
            nn.Linear(num_neurons_dense, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits of shape ``(N, num_classes)``.

        Expects a batched 4-D tensor ``(N, in_channels, H, W)``; H and W must
        be large enough to survive one 2x max-pool per conv stage.
        No softmax is applied.
        """
        x = self.features(x)
        x = self.pool(x)
        return self.classifier(x)


SimpleCNN(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (act1): ReLU()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (act2): ReLU()
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc): Linear(in_features=2048, out_features=128, bias=True)
  (out): Linear(in_features=128, out_features=10, bias=True)
)
Output shape: torch.Size([4, 10])
