In [1]:
from python_lib.ModelGenerator import *

## Parameters I wish to set

Here, I'm only concerned with the DIMENSIONS of the inputs/outputs of the layers.

In [2]:
# Shape handed to the first block's forward(); presumably (channels, length)
# -- TODO confirm against python_lib.ModelGenerator.
input_shape = (2, 64)

# One row per layer: (channels, kernel_size, dilation).
# Keeping the three settings side by side makes it obvious that the derived
# lists below always have the same length.
layer_specs = [
    (8, 5, 1),   # initial TDNN block
    (8, 3, 2),   # SE-Res2Net block
    (8, 3, 3),   # SE-Res2Net block
    (8, 3, 4),   # SE-Res2Net block
    (16, 1, 1),  # final TDNN block, fed with the concatenated SE-Res2Net outputs
]
channels, kernel_sizes, dilations = (list(column) for column in zip(*layer_specs))

res2net_scale = 8          # passed to every SERes2NetBlock
se_channels = 128          # passed to every SERes2NetBlock as se_channel
attention_channels = 128   # passed to the ASP layer
lin_neurons = 6            # passed to the final Dense layer

## Initialisation of the various layers

Note: This is the ECAPA-TDNN model as reflected in ./modules

In [3]:
# Build the layer objects in the order the forward cell applies them.
blocks = []

# Initial TDNN block. Its input channel count is read from input_shape instead
# of being hard-coded, so the first layer cannot drift out of sync with the
# parameter cell when input_shape is edited.
blocks.append(
    TDNNBlock(
        channel_in=input_shape[0],
        channel_out=channels[0],
        kernel=kernel_sizes[0],
        dilation=dilations[0],
    )
)

# One SE-Res2Net block for every entry of `channels` strictly between the
# first and the last one; each block consumes the previous block's channels.
for i in range(1, len(channels) - 1):
    blocks.append(
        SERes2NetBlock(
            res2net_scale=res2net_scale,
            se_channel=se_channels,
            channel_in=channels[i - 1],
            channel_out=channels[i],
            kernel=kernel_sizes[i],
            dilation=dilations[i],
        )
    )

# TDNN block applied to the concatenated SE-Res2Net outputs (Cat(xl[1:]) in
# the forward cell). The input width below assumes every SE-Res2Net block
# emits channels[-2] channels, which holds for the parameters in this notebook.
mfa = TDNNBlock(
    channels[-2] * (len(channels) - 2),
    channels[-1],
    kernel_sizes[-1],
    dilations[-1],
)

# Attentive Statistical Pooling
asp = ASP(
    channels[-1],
    attention_channels=attention_channels,
)
# The normalisation is created with twice the ASP channel count.
asp_bn = BatchNorm1d(channel=channels[-1] * 2)

# Final layer producing lin_neurons outputs.
dense = Dense(lin_neurons)

## Forward Pass

In [4]:
# Push the input shape through the stacked blocks, keeping the result of every
# block so the later ones can be concatenated afterwards.
x = input_shape
xl = []
for block in blocks:
    xl.append(block.forward(x))
    x = xl[-1]

# Everything except the first block's result is merged, then run through the
# aggregation block, the pooling layer and its normalisation, in that order.
x = Cat(xl[1:])
for stage in (mfa, asp, asp_bn):
    x = stage.forward(x)

# flatten: collapse the first two entries of x into their product
x = x[0] * x[1]

y = dense.forward(x)

## Print the various layers

In [5]:
# Emit the generated declaration of every layer, in the order they are applied:
# the stacked blocks first, then aggregation, pooling, normalisation and dense.
for layer in [*blocks, mfa, asp, asp_bn, dense]:
    print(layer.initialisedstring())

TDNNBlock<5, 1, 2, 8, 1, 64, 64, 4, float> TDNNBlock_0; 
float x0[8][64];
SERes2NetBlock<3, 8, 8, 2, 64, 64, 4, 8, 128, float> SERes2NetBlock_1; 
float x1[8][64];
SERes2NetBlock<3, 8, 8, 3, 64, 64, 6, 8, 128, float> SERes2NetBlock_2; 
float x2[8][64];
SERes2NetBlock<3, 8, 8, 4, 64, 64, 8, 8, 128, float> SERes2NetBlock_3; 
float x3[8][64];
TDNNBlock<1, 1, 24, 16, 1, 64, 64, 0, float> TDNNBlock_4; 
float x4[16][64];
ASP<16, 128, 64, 1, float> ASP_5; 
float x5[32][1];
BatchNorm1d<32, 1, float> BatchNorm1d_6; 
float x6[32][1];
Dense<32, 6, float> Dense_0; 
float y0[6];
