In [2]:
import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self,
                 input_channels=3,     # Number of input channels, e.g., 3 for RGB images
                 num_classes=10,       # Number of classes in the dataset
                 num_filters=32,       # m: Number of filters in each convolutional layer
                 kernel_size=3,        # k: Kernel size (k×k)
                 dense_neurons=128):   # n: Number of neurons in the dense (fully connected) layer
        super(SimpleCNN, self).__init__()

        # Convolution Block 1: Convolution -> Activation -> MaxPool
        self.conv1 = nn.Conv2d(in_channels=input_channels, out_channels=num_filters, kernel_size=kernel_size, padding=1)
        self.act1 = nn.ReLU()  # Activation function (can be replaced with GELU, SiLU, Mish, etc.)
        self.pool1 = nn.MaxPool2d(kernel_size=2)  # Halves the spatial dimensions

        # Convolution Block 2: You may extend the network with more blocks as needed
        self.conv2 = nn.Conv2d(in_channels=num_filters, out_channels=num_filters, kernel_size=kernel_size, padding=1)
        self.act2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)


        # After two pooling operations (with kernel size 2), the spatial size reduces to 8x8.
        flattened_size = num_filters * 8 * 8

        # Fully Connected (Dense) Layers:
        # First dense layer with configurable number of neurons
        self.fc = nn.Linear(in_features=flattened_size, out_features=dense_neurons)

        # Output layer: number of neurons equal to number of classes
        self.out = nn.Linear(in_features=dense_neurons, out_features=num_classes)

    def forward(self, x):
        # Apply first conv block
        x = self.conv1(x)   # Convolution
        x = self.act1(x)    # Activation
        x = self.pool1(x)   # Max pooling

        # Apply second conv block
        x = self.conv2(x)
        x = self.act2(x)
        x = self.pool2(x)

        # Flatten the output feature maps into a vector
        x = x.view(x.size(0), -1)

        # Pass through the fully connected layer then the output layer
        x = self.fc(x)
        x = self.out(x)
        return x

# Example usage:
if __name__ == "__main__":
    # Instantiate the network with example parameters
    model = SimpleCNN(input_channels=3, num_classes=10, num_filters=32, kernel_size=3, dense_neurons=128)
    print(model)

    # Create a sample input tensor (e.g., a batch of 4 RGB images of size 32x32)
    sample_input = torch.randn(4, 3, 32, 32)
    sample_output = model(sample_input)
    print("Output shape:", sample_output.shape)


SimpleCNN(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (act1): ReLU()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (act2): ReLU()
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc): Linear(in_features=2048, out_features=128, bias=True)
  (out): Linear(in_features=128, out_features=10, bias=True)
)
Output shape: torch.Size([4, 10])
