In [None]:
from torch import nn, Tensor, concatenate, norm, sigmoid
from torch.nn import ModuleList, TripletMarginLoss, BCEWithLogitsLoss
from torch.nn.functional import relu, normalize


class InceptionModule(nn.Module):
    """One inception-style block of parallel 1-D convolutions.

    The input is (optionally) squeezed through a 1x1 bottleneck convolution and
    then fed to ``NUM_FILTER_SETS`` convolutions whose kernel sizes are
    ``base_kernel_size``, ``base_kernel_size // 2``, ``base_kernel_size // 4``,
    and so on. In parallel, the *un-bottlenecked* input goes through a max-pool
    followed by a 1x1 convolution. All branch outputs are concatenated along
    dim 1, batch-normalised, optionally summed with a projected residual of the
    input, and passed through ReLU.

    The output therefore has ``hidden_dim * (NUM_FILTER_SETS + 1)`` channels.
    All convolutions and the pooling layer use stride 1 with length-preserving
    padding.

    Attributes:
        bottleneck: 1x1 ``Conv1d`` applied before the filter sets, or ``None``
            when the bottleneck is disabled.
        residual: ``Sequential`` (1x1 ``Conv1d`` + ``BatchNorm1d``) projecting
            the input to the output width, or ``None`` when disabled.
    """

    NUM_FILTER_SETS = 3

    def __init__(self, in_dim: int, hidden_dim: int, bottleneck_dim: int, base_kernel_size: int,
                 residual: bool):
        """Build the module.

        Args:
            in_dim: Number of input channels.
            hidden_dim: Number of output channels produced by *each* branch.
            bottleneck_dim: Channels of the bottleneck convolution. The
                bottleneck is skipped when this is ``<= 0`` or ``in_dim == 1``.
            base_kernel_size: Largest kernel size; each further filter set
                halves it (integer division). Must be at least
                ``2 ** (NUM_FILTER_SETS - 1)`` so the smallest kernel is >= 1.
            residual: Whether to add a projected skip connection from the input.

        Raises:
            AssertionError: If ``base_kernel_size`` is too small.
        """
        min_base_kernel_size = 2 ** (self.NUM_FILTER_SETS - 1)
        assert base_kernel_size >= min_base_kernel_size, f'base kernel size must be {min_base_kernel_size} or greater'
        super().__init__()

        # The outputs from the filter sets will be concatenated feature-wise along with the parallel low pass filter.
        out_dim = hidden_dim * (self.NUM_FILTER_SETS + 1)

        # Always define ``self.bottleneck`` so ``forward`` can test it against None; nn.Module raises
        # AttributeError on attributes that were never assigned.
        if bottleneck_dim > 0 and in_dim > 1:
            self.bottleneck = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1, bias=False, padding='same')
            filter_in_dim = bottleneck_dim
        else:
            self.bottleneck = None
            filter_in_dim = in_dim

        kernel_sizes = [base_kernel_size // (2 ** i) for i in range(self.NUM_FILTER_SETS)]

        self.filter_sets = ModuleList([
            nn.Conv1d(filter_in_dim, hidden_dim, kernel_size=ks, padding='same', bias=False) for ks in kernel_sizes])

        self.bn = nn.BatchNorm1d(out_dim)

        # This branch reads the original input (``in_dim`` channels), not the bottleneck output.
        self.parallel_low_pass_filter = nn.Sequential(
            nn.MaxPool1d(kernel_size=3, stride=1, padding=1),
            nn.Conv1d(in_dim, hidden_dim, kernel_size=1, padding='same', bias=False),
        )

        # Not sure if residual should be used at every module, but from the paper it doesn't seem to have a significant
        # effect anyway.
        if residual:
            self.residual = nn.Sequential(
                nn.Conv1d(in_dim, out_dim, kernel_size=1, bias=False, padding='same'),
                # Not sure if it's important that the residual has its own batch norm. Will keep it just in case.
                nn.BatchNorm1d(out_dim),
            )
        else:
            # Same reasoning as for the bottleneck: keep the attribute defined when the feature is off.
            self.residual = None

    def forward(self, x: Tensor) -> Tensor:
        """Apply the module.

        Args:
            x: Input whose dim 1 holds the ``in_dim`` channels (the branches
                are concatenated along dim 1), i.e. ``(batch, in_dim, length)``.

        Returns:
            Non-negative tensor with ``hidden_dim * (NUM_FILTER_SETS + 1)``
            channels on dim 1 and the same length as the input.
        """
        org_x = x
        if self.bottleneck is not None:
            x = self.bottleneck(x)

        filter_outputs = [filter_set(x) for filter_set in self.filter_sets]
        # The pooled branch bypasses the bottleneck and sees the raw input.
        filter_outputs.append(self.parallel_low_pass_filter(org_x))

        x = concatenate(filter_outputs, dim=1)
        x = self.bn(x)

        if self.residual is not None:
            x = x + self.residual(org_x)

        return relu(x)

