In [None]:
from typing import Optional

import torch
import torch.nn as nn

In [None]:
class TestModel(nn.Module):
    """Small two-stage CNN classifier followed by a batch-normalised MLP head.

    Each conv stage is ``conv 3x3 -> ReLU -> max-pool 2 -> dropout``::

        conv1/pool1/drop1 -> conv2/pool2/drop2 -> flatten -> BatchNorm1d
        -> linear1 (1024) -> linear2 (512) -> linear3 (128) -> output

    Args:
        input_shape: Shape of an input batch, ``(N, channels, height, width)``.
            Only the last three entries are used; the batch size is ignored.
        dropout_rate: Dropout probability applied after each pooling stage.
        num_classes: Number of output classes. When omitted, falls back to the
            number of unique values in the global ``folder_df["folder"]``
            column (the original behaviour).

    Raises:
        ValueError: If ``input_shape`` does not have exactly four entries,
            ``num_classes`` is smaller than 1, or the spatial size is too small
            to pass through both conv/pool stages.
    """

    def __init__(
        self,
        input_shape: torch.Size,
        dropout_rate: float = 0,
        num_classes: Optional[int] = None,
    ):
        super().__init__()

        # Must *raise*: returning a value from __init__ makes Python throw an
        # unrelated TypeError and the intended message is lost.
        if len(input_shape) != 4:
            raise ValueError(
                f"Input shape is not AxBxCxD (got {tuple(input_shape)})."
            )
        _, channels, height, width = (int(dim) for dim in input_shape)

        if num_classes is None:
            # Backward-compatible fallback: derive the class count from the
            # notebook-level DataFrame, exactly as the original did.
            num_classes = len(folder_df["folder"].unique())
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1 (got {num_classes}).")

        self.relu = nn.ReLU()  # no trainable parameters, so it can be reused

        self.conv1 = nn.Conv2d(in_channels=channels, out_channels=10, kernel_size=(3, 3))
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.drop1 = nn.Dropout(p=dropout_rate)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=10, kernel_size=(3, 3))
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.drop2 = nn.Dropout(p=dropout_rate)
        self.flat = nn.Flatten()

        # Derived from the actual layers (previously hand-computed for three
        # conv/pool stages while only two exist, which broke linear1).
        flatten_nodes = self._flattened_size(channels, height, width)

        self.norm = nn.BatchNorm1d(num_features=flatten_nodes)
        self.linear1 = nn.Linear(in_features=flatten_nodes, out_features=1024)
        self.linear2 = nn.Linear(in_features=1024, out_features=512)
        self.linear3 = nn.Linear(in_features=512, out_features=128)
        self.output = nn.Linear(in_features=128, out_features=num_classes)

    def _flattened_size(self, channels: int, height: int, width: int) -> int:
        """Return the per-sample feature count produced by the conv stages.

        A zero tensor is pushed through the conv/ReLU/pool layers, so the result
        stays correct if those layers are changed. Dropout is skipped: it never
        changes the shape and would otherwise consume RNG state.
        """
        probe = torch.zeros(1, channels, height, width)
        try:
            with torch.no_grad():
                probe = self.pool1(self.relu(self.conv1(probe)))
                probe = self.pool2(self.relu(self.conv2(probe)))
        except RuntimeError as err:
            raise ValueError(
                f"Spatial size {height}x{width} is too small for two "
                "conv(3x3) + max-pool(2) stages."
            ) from err
        return probe.numel()

    def forward(self, x):
        """Map a batch ``x`` shaped like ``input_shape`` to class logits.

        Returns raw, unnormalised logits of shape ``(N, num_classes)``. Apply
        softmax yourself, or use a loss that does it internally
        (e.g. ``nn.CrossEntropyLoss``).
        """
        x = self.drop1(self.pool1(self.relu(self.conv1(x))))
        x = self.drop2(self.pool2(self.relu(self.conv2(x))))

        x = self.norm(self.flat(x))

        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.relu(self.linear3(x))

        # No ReLU on the output layer: clamping logits at 0 would discard
        # negative scores and zero their gradients.
        return self.output(x)