In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader

In [2]:
class CausalConv(nn.Module):

    def __init__(self, through_channels, inter_channels, dilation=1, kernel_size=5):
        """Gated causal convolutional layer with residual and skip outputs.

        This implementation is in the flavor of PixelCNN: a mask over the kernel
        taps ensures the model can't peer forwards in time, whereas other
        implementations often make use of filters with a kernel size of 2.
        Those are probably more efficient, but this one does it with masks.

        The mask keeps only the taps strictly left of the kernel centre. With
        'same' padding and an odd kernel the centre tap sits on the current
        timestep, so each output position only sees strictly earlier inputs.

        Args:
            through_channels (int): size of input and output channels
            inter_channels (int): size of representation within the cell
            dilation (int, optional): dilation level. Defaults to 1.
            kernel_size (int, optional): kernel size; must be odd and >= 5.
                Defaults to 5.

        Raises:
            AssertionError: if kernel_size is even or smaller than 5.
        """
        super(CausalConv, self).__init__()

        assert(kernel_size % 2 == 1)
        assert(kernel_size >= 5)

        # 1 for the kernel_size//2 "past" taps, 0 for the centre tap and every
        # tap to its right. Stored as a buffer so it follows .to()/.cuda() and
        # is saved in the state dict, but is never trained.
        filter_mask = torch.ones(kernel_size)
        filter_mask[kernel_size // 2:] = 0
        self.register_buffer('filter_mask', filter_mask)

        self.conv_sig = nn.Conv1d(through_channels,
                                  inter_channels,
                                  kernel_size,
                                  dilation=dilation,
                                  padding='same'
                                  )
        self.conv_tanh = nn.Conv1d(through_channels,
                                   inter_channels,
                                   kernel_size,
                                   dilation=dilation,
                                   padding='same'
                                   )
        self.one_by_one = nn.Conv1d(inter_channels,
                                    through_channels,
                                    kernel_size=1
                                    )

    def _masked_conv(self, conv, inputs):
        """Run `conv` on `inputs` using its weights multiplied by the causal mask."""
        # The mask is applied to a temporary product instead of overwriting
        # conv.weight. Re-assigning conv.weight to a new Parameter on each call
        # would detach the layer from any optimizer created beforehand, because
        # the optimizer keeps references to the original Parameter objects.
        masked_weight = conv.weight * self.filter_mask
        return nn.functional.conv1d(inputs,
                                    masked_weight,
                                    conv.bias,
                                    stride=conv.stride,
                                    padding=conv.padding,
                                    dilation=conv.dilation,
                                    groups=conv.groups)

    def forward(self, inputs):
        """Apply the gated, masked convolution.

        Args:
            inputs (Tensor): input with `through_channels` channels on the
                channel dimension, as accepted by nn.Conv1d.

        Returns:
            tuple: (res, skip) where `skip` is the 1x1-projected gated
            activation and `res` is `skip + inputs`.
        """
        sig_a = torch.sigmoid(self._masked_conv(self.conv_sig, inputs))
        tanh_a = torch.tanh(self._masked_conv(self.conv_tanh, inputs))

        # Gated activation unit: the sigmoid branch gates the tanh branch.
        x = sig_a * tanh_a
        skip = self.one_by_one(x)
        res = skip + inputs

        return res, skip



In [8]:
class WaveNet(nn.Module):
    def __init__(self, num_layers,in_channels, channels):
        """WaveNet model: a stack of dilated CausalConv layers with summed skips.

        Args:
            num_layers (int): number of CausalConv layers; layer i is built
                with dilation 2**i
            in_channels (int): number of channels of the input signal
            channels (int): number of channels the input is projected to
                before entering the causal stack
        """
        super().__init__()
        # 1x1 convolution lifting the input to the working channel width.
        self.init_embed = nn.Conv1d(in_channels, channels, 1)

        # Dilation doubles with depth; each layer uses 2*channels internally.
        dilated_blocks = [
            CausalConv(channels, channels * 2, 2 ** depth)
            for depth in range(num_layers)
        ]
        self.causal_layers = nn.ModuleList(dilated_blocks)

        # Output head: two 1x1 convolutions ending in 256 output channels.
        self.one_by_one_1 = nn.Conv1d(channels, 512, 1)
        self.one_by_one_2 = nn.Conv1d(512, 256, 1)

    def forward(self,inputs):
        """Run the network.

        Returns:
            tuple: (logits, odds) where `logits` has 256 channels on dim 1 and
            `odds` is the softmax of `logits` over that dimension.
        """
        hidden = self.init_embed(inputs)

        # Collect the skip output of every layer while threading the residual
        # output through the stack.
        skip_outputs = []
        for block in self.causal_layers:
            hidden, skip_out = block(hidden)
            skip_outputs.append(skip_out)

        # Sum the skip connections across layers.
        summed = torch.stack(skip_outputs, dim=0).sum(dim=0)

        features = self.one_by_one_1(torch.relu(summed))
        logits = self.one_by_one_2(torch.relu(features))

        odds = torch.softmax(logits, dim=1)
        return logits, odds
        


In [None]:
class trainWrapper():
    """Bundles a WaveNet model with an Adam optimizer and a step-based training loop."""

    def __init__(self,num_layers,in_channels,channels,lr):
        """Build the model and its optimizer.

        Args:
            num_layers (int): forwarded to WaveNet
            in_channels (int): forwarded to WaveNet
            channels (int): forwarded to WaveNet
            lr (float): learning rate for Adam
        """
        self.waveNet = WaveNet(num_layers,in_channels,channels)
        # parameters must be *called*: Adam needs an iterable of tensors, not
        # the bound method object.
        self.optim = torch.optim.Adam(self.waveNet.parameters(), lr=lr)

    def __train_on_batch(self, x):
        """Run one optimisation step on batch `x` and return the detached loss."""
        self.optim.zero_grad()
        logits, _ = self.waveNet(x)

        # NOTE(review): the input batch is also used as the target. This only
        # fits CrossEntropyLoss if x matches the logits' shape (per-class
        # probabilities / one-hot over the 256 output channels) or holds class
        # indices -- confirm against the dataset format.
        loss = nn.CrossEntropyLoss()(logits, x)
        loss.backward()
        self.optim.step()

        # Detach so callers holding the loss do not keep the autograd graph alive.
        return loss.detach()

    def train(self,steps: int, batch_size: int, dataset):
        """Train for a fixed number of optimisation steps.

        Args:
            steps (int): number of batches to train on
            batch_size (int): batch size passed to the DataLoader
            dataset: dataset passed to the DataLoader

        Returns:
            The trained WaveNet model.
        """
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        batches = iter(loader)
        for step in range(steps):
            try:
                x = next(batches)
            except StopIteration:
                # One pass over the data is finished: start a new (reshuffled)
                # epoch so `steps` may exceed the number of batches.
                batches = iter(loader)
                x = next(batches)
            loss = self.__train_on_batch(x)
            print(loss)

        return self.waveNet
        
        
        