In [22]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

### Gated 1-D convolution: tanh(first half) * sigmoid(second half)
class Conv1d(nn.Module):
    """Gated 1-D convolution computing ``tanh(a) * sigmoid(b)``.

    A single bias-free ``nn.Conv1d`` produces ``2 * out_channels`` channels.
    The first ``out_channels`` channels (``a``) go through tanh, the remaining
    ``out_channels`` channels (``b``) go through sigmoid, and the two halves
    are multiplied elementwise.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels after gating.
        kernel_size, stride, dilation, padding: forwarded unchanged to the
            inner ``nn.Conv1d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding=0):
        super().__init__()
        # Channel index that splits the conv output into the tanh / sigmoid halves.
        self.c = out_channels
        self.filter_out = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels * 2,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            bias=False,
        )
        # Draw weights from N(0, 0.02). nn.init.normal_ performs the same
        # in-place draw under no_grad, replacing the legacy `.weight.data` access.
        nn.init.normal_(self.filter_out.weight, mean=0.0, std=0.02)

    def forward(self, x):
        """Apply the convolution and the tanh/sigmoid gate.

        Args:
            x: input of shape (B, C, T). The gate slices dim 1, so a batch
                dimension is required.

        Returns:
            Tensor with ``out_channels`` channels on dim 1.
        """
        y = self.filter_out(x)
        filt, gate = y[:, :self.c], y[:, self.c:]
        # torch.tanh matches torch.sigmoid below; F.tanh is the legacy alias.
        return torch.tanh(filt) * torch.sigmoid(gate)
    
### Stack of dilated gated convolutions with residual updates and summed skip outputs
class NonLinear1d(nn.Module):
    """Stack of dilated gated convolutions whose skip outputs are summed.

    Pipeline:
      1. A 1x1 ``in_layer`` maps ``channels`` -> ``hidden_channels``.
      2. ``n_layers`` blocks follow; block ``l`` (0-based) applies a kernel-3
         gated ``Conv1d`` with dilation and padding ``2**(l+1)``, then a 1x1
         gated ``Conv1d``. That 1x1 output is accumulated as a skip and is
         also added back to the running hidden stream as a residual.
      3. The summed skips pass through ReLU -> 1x1 conv -> ReLU -> 1x1 conv,
         giving ``out_channels`` channels.

    Padding equals dilation for every kernel-3 conv, and the other convs are
    1x1, so the time length T is preserved end to end.

    Args:
        channels: number of input channels.
        hidden_channels: width of the internal residual stream.
        out_channels: number of output channels.
        n_layers: number of dilated blocks; must be >= 1.

    Raises:
        ValueError: if ``n_layers < 1``. With no blocks the skip sum would be
            empty, and forward() would fail.
    """

    def __init__(self, channels, hidden_channels, out_channels, n_layers):
        super().__init__()
        if n_layers < 1:
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")
        self.in_layer = nn.Conv1d(channels, hidden_channels, kernel_size=1)
        self.main_convs = nn.ModuleList([
            Conv1d(hidden_channels, hidden_channels, kernel_size=3,
                   dilation=2 ** (l + 1), padding=2 ** (l + 1))
            for l in range(n_layers)
        ])
        self.skip_convs = nn.ModuleList([
            Conv1d(hidden_channels, hidden_channels, kernel_size=1)
            for _ in range(n_layers)
        ])
        self.out_layer = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(hidden_channels, out_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(out_channels, out_channels, kernel_size=1),
        )

    def forward(self, x):
        """Run the dilated stack and project the summed skips.

        Args:
            x: input of shape (B, channels, T).

        Returns:
            Tensor of shape (B, out_channels, T).
        """
        h = self.in_layer(x)

        skip_total = None
        for main_conv, skip_conv in zip(self.main_convs, self.skip_convs):
            skip = skip_conv(main_conv(h))
            skip_total = skip if skip_total is None else skip_total + skip
            # The skip output doubles as the residual update.
            h = h + skip

        return self.out_layer(skip_total)
        

In [23]:
# Seed so the random input and the weight initialisation are reproducible.
torch.manual_seed(0)
x = torch.randn(2, 128, 100)

# (B, C_in, T) -> (B, C_out, T): the dilated stack must preserve T.
y = NonLinear1d(128, 256, 128, 4)(x)
assert y.shape == (2, 128, 100), y.shape
print(y.shape)

torch.Size([2, 128, 100])
