In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    """Configurable CNN: a stack of conv blocks followed by a two-layer dense head.

    Each conv block is ``Conv2d(padding=1) -> activation_conv -> MaxPool2d(2, 2)``.
    The head is ``Linear -> activation_dense -> Linear`` and produces one
    logit per class.

    Args:
        input_shape: Shape of one input sample as ``(channels, height, width)``
            (no batch dimension). It sets the first conv's input channels and
            the flattened feature size fed to the dense head.
        num_classes: Number of output logits.
        num_filters: Output channels of every conv layer.
        filter_size: Kernel size of every conv layer. Padding is fixed at 1,
            so spatial size is preserved by the conv only when this is 3.
        activation_conv: Activation module placed after every conv layer.
            The same instance is reused in every block, so an activation with
            learnable parameters (e.g. ``nn.PReLU``) would share them.
        activation_dense: Activation module between the two linear layers.
        num_neurons_dense: Width of the hidden linear layer.
        num_conv_layers: Number of conv blocks (default 5, the original depth).
    """

    def __init__(self, input_shape, num_classes, num_filters, filter_size,
                 activation_conv, activation_dense, num_neurons_dense,
                 num_conv_layers=5):
        super().__init__()
        self.conv_layers = self._create_conv_layers(
            input_shape[0], num_filters, filter_size, activation_conv, num_conv_layers
        )
        # The flattened feature count depends on input resolution, kernel
        # size/padding and the number of pooling stages, so measure it instead
        # of hard-coding it. The former hard-coded 256 * 7 * 7 did not match
        # the conv stack's real output (num_filters * 7 * 7 for 224x224 input),
        # which made forward() fail with a shape mismatch.
        flat_features = self._flattened_size(input_shape)
        self.fc_layers = nn.Sequential(
            nn.Linear(flat_features, num_neurons_dense),
            activation_dense,
            nn.Linear(num_neurons_dense, num_classes),
        )

    def _create_conv_layers(self, input_channels, num_filters, filter_size,
                            activation_conv, num_conv_layers=5):
        """Build ``num_conv_layers`` blocks of Conv2d -> activation -> MaxPool2d(2, 2)."""
        layers = []
        in_channels = input_channels
        for _ in range(num_conv_layers):
            layers += [
                nn.Conv2d(in_channels, num_filters, filter_size, padding=1),
                activation_conv,
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]
            # Every block after the first consumes the previous block's filters.
            in_channels = num_filters
        return nn.Sequential(*layers)

    def _flattened_size(self, input_shape):
        """Return the per-sample feature count produced by ``self.conv_layers``."""
        # A single zero-filled sample is enough; no gradients are needed.
        with torch.no_grad():
            dummy = torch.zeros(1, *input_shape)
            return self.conv_layers(dummy).numel()

    def forward(self, x):
        """Map a ``(batch, C, H, W)`` tensor to ``(batch, num_classes)`` logits."""
        x = self.conv_layers(x)
        # Keep the batch dimension and flatten channels/height/width together.
        x = torch.flatten(x, start_dim=1)
        return self.fc_layers(x)

# ---- Example hyperparameters ---------------------------------------------
# One input sample as (channels, height, width): 224x224 RGB images, a size
# suitable for iNaturalist-style photos.
input_shape = (3, 224, 224)
num_classes = 10          # number of output classes (presumably the 10-class iNaturalist subset)
num_filters = 32          # output channels of every convolutional layer
filter_size = 3           # square kernel size of every convolutional layer
num_neurons_dense = 1024  # width of the hidden fully-connected layer

# Activation modules: one instance reused by every conv block, one for the
# dense head.
activation_conv = nn.ReLU(inplace=True)
activation_dense = nn.ReLU(inplace=True)

# ---- Build the network and print its layer structure ---------------------
model = CNN(
    input_shape=input_shape,
    num_classes=num_classes,
    num_filters=num_filters,
    filter_size=filter_size,
    activation_conv=activation_conv,
    activation_dense=activation_dense,
    num_neurons_dense=num_neurons_dense,
)
print(model)




CNN(
  (conv_layers): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (9): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (10): ReLU(inplace=True)
    (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (12): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc_layers): Sequenti