# Question 1

In [None]:
class CNN(nn.Module):
    """Five-stage convolutional classifier followed by a two-layer dense head.

    Each stage is ``Conv2d -> (BatchNorm2d) -> activation -> MaxPool2d(2, 2)``.
    The convolutions use no padding and stride 1, so every stage shrinks the
    feature map; the flattened size fed to the dense head is therefore computed
    at construction time from ``input_size`` and ``size_filters``.

    Args:
        n_filters: Number of filters in the first convolution.
        size_filters: Square kernel size shared by all convolutions.
        size_fc: Width of the hidden fully connected layer.
        activation: One of ``'relu'``, ``'gelu'``, ``'silu'``, ``'mish'``;
            applied after every convolution stage.
        filter_organization: ``'same'`` keeps ``n_filters`` in every stage,
            ``'double'`` doubles it per stage, ``'halve'`` halves it per stage
            (clamped so that no stage ends up with fewer than one filter).
        batch_normalization: Batch norm is enabled only for the exact string
            ``'yes'``; any other value disables it.
        dropout: Dropout probability applied to the last pooled feature map.
        num_classes: Number of output units of the final linear layer.
        input_size: Height/width of the (square) input images the network is
            sized for.
        in_channels: Number of channels of the input images.

    Raises:
        ValueError: If ``activation`` or ``filter_organization`` is unknown,
            or if the feature map becomes smaller than a kernel before all
            stages have been applied.
    """

    # Number of conv/pool stages; layers are registered as conv1..conv5 etc.
    _NUM_CONV_BLOCKS = 5

    def __init__(
        self,
        n_filters,
        size_filters,
        size_fc,
        activation='relu',
        filter_organization='double',
        batch_normalization='yes',
        dropout=0.0,
        num_classes=10,
        input_size=244,
        in_channels=3,
    ):
        super().__init__()

        if activation == 'mish':
            self.activation = self.mish
        elif activation in ('relu', 'gelu', 'silu'):
            self.activation = getattr(F, activation)
        else:
            raise ValueError("Invalid activation function. Choose from 'relu', 'gelu', 'silu', 'mish'.")

        self.use_batch_norm = batch_normalization == 'yes'
        self.dropout_rate = dropout

        stages = range(self._NUM_CONV_BLOCKS)
        if filter_organization == 'same':
            filter_sizes = [n_filters for _ in stages]
        elif filter_organization == 'double':
            filter_sizes = [n_filters * (2 ** i) for i in stages]
        elif filter_organization == 'halve':
            # Integer halving reaches 0 for small n_filters, which would ask
            # Conv2d for zero channels; keep at least one filter per stage.
            filter_sizes = [max(1, n_filters // (2 ** i)) for i in stages]
        else:
            raise ValueError("Invalid filter organization. Choose from 'same', 'double', 'halve'.")

        # Register layers under the names conv{i}/bn{i}/pool{i}, in this exact
        # order, so parameter names and initialisation order stay stable.
        channels = [in_channels] + filter_sizes
        for idx in range(1, self._NUM_CONV_BLOCKS + 1):
            setattr(self, f'conv{idx}', nn.Conv2d(channels[idx - 1], channels[idx], kernel_size=size_filters))
            if self.use_batch_norm:
                setattr(self, f'bn{idx}', nn.BatchNorm2d(channels[idx]))
            setattr(self, f'pool{idx}', nn.MaxPool2d(kernel_size=2, stride=2))

        self.dropout = nn.Dropout(p=self.dropout_rate)

        self.fc_input_size = self._get_fc_input_size(
            size_filter_last=filter_sizes[-1],
            kernel_conv=size_filters,
            input_size=input_size,
        )

        self.fc1 = nn.Linear(self.fc_input_size, size_fc)
        self.fc2 = nn.Linear(size_fc, num_classes)

    def _get_fc_input_size(self, size_filter_last, stride_conv=1, stride_pool=2, kernel_conv=5, kernel_pool=2, input_size=244):
        """Return the flattened feature count after all conv/pool stages.

        Applies the unpadded output-size formula ``(size - kernel) // stride + 1``
        once for the convolution and once for the pooling of every stage.

        Raises:
            ValueError: If the feature map is smaller than the convolution or
                pooling kernel at some stage (the formula would otherwise go
                non-positive and squaring it would yield a bogus size).
        """
        for stage in range(1, self._NUM_CONV_BLOCKS + 1):
            if input_size < kernel_conv:
                raise ValueError(
                    f"Feature map of size {input_size} is smaller than the convolution "
                    f"kernel ({kernel_conv}) at stage {stage}; use a smaller kernel or a larger input."
                )
            input_size = ((input_size - kernel_conv) // stride_conv) + 1
            if input_size < kernel_pool:
                raise ValueError(
                    f"Feature map of size {input_size} is smaller than the pooling "
                    f"kernel ({kernel_pool}) at stage {stage}; use a smaller kernel or a larger input."
                )
            input_size = ((input_size - kernel_pool) // stride_pool) + 1
        return size_filter_last * input_size * input_size

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Tensor of shape ``(batch, in_channels, H, W)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding the raw outputs
            of the final linear layer (no softmax is applied here).
        """
        for idx in range(1, self._NUM_CONV_BLOCKS + 1):
            x = getattr(self, f'conv{idx}')(x)
            if self.use_batch_norm:
                x = getattr(self, f'bn{idx}')(x)
            x = getattr(self, f'pool{idx}')(self.activation(x))

        if self.dropout_rate > 0:
            x = self.dropout(x)

        # Flatten per sample so the batch dimension is always preserved; an
        # input whose resolution does not match the one the model was built
        # for then fails loudly at fc1 instead of silently changing the
        # number of rows.
        x = torch.flatten(x, 1)
        # The dense head always uses ReLU, independent of `activation`.
        x = F.relu(self.fc1(x))
        return self.fc2(x)

    def mish(self, x):
        """Mish activation: ``x * tanh(softplus(x))``."""
        return x * torch.tanh(F.softplus(x))