In [1]:
import torch
from torch import nn

In [2]:
# Scratch check: a Conv1d weight can be scaled in place (here: zeroed by a
# broadcast (1, 1, 1) tensor) as long as autograd tracking is switched off.
test_conv = nn.Conv1d(in_channels=1, out_channels=2, kernel_size=3)
print(test_conv.weight)
with torch.no_grad():
    zero_scale = torch.zeros(1, 1, 1)
    test_conv.weight.mul_(zero_scale)
# Optional follow-up (disabled): inspect the zeroed weight and run a forward pass.
# print(test_conv.weight)
# input = torch.ones(1,1,10)
# test_conv(input)

Parameter containing:
tensor([[[-0.2988,  0.0702,  0.0528]],

        [[ 0.1779, -0.4344, -0.2638]]], requires_grad=True)


In [59]:
class CausalConv(nn.Module):
    '''Gated, masked 1-D convolution block with residual and skip outputs.

    CausalConv resembling pixelCNN a bit more than some implementations:
    instead of shifting/padding the signal, the convolution kernels are
    multiplied by a fixed mask that zeroes the centre tap and every tap after
    it. Combined with ``padding='same'`` this means the convolution output at
    position ``t`` only depends on inputs at positions strictly before ``t``.

    Two such masked convolutions run in parallel; one goes through a sigmoid,
    the other through a tanh, and their elementwise product is projected back
    to ``through_channels`` by a kernel-size-1 convolution.

    Args:
        through_channels: channel count of the block input and of both outputs.
        inter_channels: channel count produced by the two masked convolutions.
        dilation: dilation of the two masked convolutions.
        kernel_size: kernel width of the masked convolutions; must be odd and
            at least 5.

    Raises:
        AssertionError: if ``kernel_size`` is even or smaller than 5.
    '''
    def __init__(self,through_channels, inter_channels, dilation=1, kernel_size = 5):
        super().__init__()

        assert kernel_size % 2 == 1, 'kernel_size must be odd'
        assert kernel_size >= 5, 'kernel_size must be at least 5'

        # First kernel_size // 2 taps (the past) are kept; the centre tap and
        # all following taps are zeroed.
        past_taps = kernel_size // 2
        filter_mask = torch.tensor([1] * past_taps + [0] * (kernel_size - past_taps))
        self.register_buffer('filter_mask', filter_mask)

        self.conv_sig = nn.Conv1d(through_channels,
                                  inter_channels,
                                  kernel_size,
                                  dilation=dilation,
                                  padding='same'
                                  )
        self.conv_tanh = nn.Conv1d(through_channels,
                                   inter_channels,
                                   kernel_size,
                                   dilation=dilation,
                                   padding='same'
                                   )
        self.one_by_one = nn.Conv1d(inter_channels,
                                    through_channels,
                                    kernel_size=1
                                    )

        # Zero the masked taps of the stored weights once, in place, so the
        # parameters themselves reflect the mask. Done here (before any graph
        # exists) rather than in forward, and without replacing the Parameter
        # objects.
        with torch.no_grad():
            self.conv_sig.weight.mul_(filter_mask)
            self.conv_tanh.weight.mul_(filter_mask)

    def _masked_conv(self, conv, inputs):
        '''Apply ``conv`` to ``inputs`` using its kernel multiplied by the mask.

        The mask multiplication is part of the autograd graph, so masked taps
        receive zero gradient, and ``conv.weight`` stays the same Parameter
        object that optimizers were given.
        '''
        masked_weight = conv.weight * self.filter_mask
        return nn.functional.conv1d(inputs,
                                    masked_weight,
                                    conv.bias,
                                    stride=conv.stride,
                                    padding=conv.padding,
                                    dilation=conv.dilation,
                                    groups=conv.groups)

    def forward(self,inputs):
        '''Run the gated block.

        Args:
            inputs: tensor accepted by ``nn.Conv1d`` whose channel dimension
                equals ``through_channels`` (required for the residual sum).

        Returns:
            Tuple ``(res, x)`` where ``x`` is the output of the 1x1 projection
            and ``res`` is ``x + inputs``.
        '''
        sig_a = torch.sigmoid(self._masked_conv(self.conv_sig, inputs))
        tanh_a = torch.tanh(self._masked_conv(self.conv_tanh, inputs))

        x = self.one_by_one(sig_a * tanh_a)
        res = x + inputs

        return res, x



In [65]:
# Smoke test: push a random (1, 3, 20) tensor through one block and display
# the second element of the returned tuple (the 1x1-projection output).
layer = CausalConv(through_channels=3, inter_channels=5)
input = torch.rand(1, 3, 20)
# print(input)
layer(input)[1]

torch.Size([1, 3, 20])


tensor([[[ 0.0968,  0.0684,  0.0750,  0.1298,  0.0374,  0.0773,  0.1119,
           0.1175,  0.0786,  0.1402,  0.1527,  0.1313,  0.1370,  0.1197,
           0.1183,  0.1229,  0.1656,  0.1269,  0.1543,  0.1618],
         [-0.0642, -0.0615, -0.0510, -0.0075, -0.0370, -0.0304, -0.0321,
          -0.0373, -0.0418, -0.0307, -0.0200, -0.0246, -0.0232, -0.0245,
          -0.0412, -0.0543, -0.0361, -0.0441, -0.0374, -0.0257],
         [-0.2599, -0.2539, -0.1947, -0.1878, -0.1781, -0.1888, -0.1548,
          -0.1918, -0.1858, -0.1306, -0.1698, -0.1474, -0.1481, -0.1660,
          -0.1992, -0.1817, -0.1976, -0.1906, -0.1811, -0.1604]]],
       grad_fn=<SqueezeBackward1>)

In [None]:
class WaveNet(nn.Module):
    def __init__(self):
        
