<a href="https://colab.research.google.com/github/pbrandl/aNN_Audio/blob/master/WaveNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# WaveNet Implementation

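A modified WaveNet implementation that keeps a memory of the most recent receptive field of input samples, so that each new chunk of a sequence is processed together with the samples that immediately preceded it.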

In [None]:
import torch
from torch import nn
from functools import reduce


class GatedConv1d(nn.Module):
    """Causal, gated 1-D convolution.

    Two parallel convolutions with identical hyper-parameters are applied to
    the (left-padded) input; the output is ``conv_f(x) * sigmoid(conv_g(x))``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size of both convolutions.
        stride: Stride of both convolutions.
        padding: Symmetric padding forwarded to both convolutions. This is
            applied *in addition* to the causal left padding done in
            ``forward``.
        dilation: Dilation of both convolutions.
        groups: Number of groups of both convolutions.
        bias: Whether the convolutions have a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=True):
        super().__init__()
        self.dilation = dilation
        self.conv_f = nn.Conv1d(in_channels, out_channels, kernel_size,
                                stride=stride, padding=padding, dilation=dilation,
                                groups=groups, bias=bias)
        self.conv_g = nn.Conv1d(in_channels, out_channels, kernel_size,
                                stride=stride, padding=padding, dilation=dilation,
                                groups=groups, bias=bias)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Apply the gated convolution to ``x``.

        The last dimension of ``x`` is zero-padded on the left by
        ``(kernel_size - 1) * dilation`` so that, with the default
        ``stride=1`` and ``padding=0``, the output has the same length as the
        input and position ``t`` only depends on inputs at positions ``<= t``.

        Args:
            x: Input tensor accepted by ``nn.Conv1d``.

        Returns:
            The element-wise product of the filter convolution and the
            sigmoid of the gate convolution.
        """
        # nn.Conv1d normalises kernel_size / dilation to 1-tuples, so reading
        # them back from the layer works for both int and tuple arguments.
        left_pad = (self.conv_f.kernel_size[0] - 1) * self.conv_f.dilation[0]
        x = nn.functional.pad(x, (left_pad, 0))
        return torch.mul(self.conv_f(x), self.sig(self.conv_g(x)))


class GatedResidualBlock(nn.Module):
    """Gated convolution followed by a 1x1 convolution, with residual and skip outputs.

    Args:
        in_channels: Number of channels of the block input.
        out_channels: Number of channels produced by the block.
        kernel_size: Kernel size of the gated convolution.
        receptive_field: Number of leading time steps dropped from the skip
            output.
        stride: Stride of the gated convolution.
        padding: Padding forwarded to the gated convolution.
        dilation: Dilation of the gated convolution.
        groups: Number of groups of the gated convolution.
        bias: Whether the convolutions have a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size, receptive_field, stride=1, padding=0,
                 dilation=1, groups=1, bias=True):
        super().__init__()
        self.receptive_field = receptive_field
        # Sub-modules are created in this order on purpose: it fixes the order
        # in which their parameters are initialised and registered.
        self.gatedconv = GatedConv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.conv_1 = nn.Conv1d(
            out_channels,
            out_channels,
            1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=bias,
        )

    def forward(self, x):
        """Run the block on ``x``.

        Returns:
            A ``(residual, skip)`` tuple. ``residual`` is the 1x1-convolved
            gated output added to ``x``; ``skip`` is the same 1x1-convolved
            gated output with the first ``receptive_field`` positions of the
            last dimension removed.
        """
        mixed = self.conv_1(self.gatedconv(x))
        residual = mixed + x
        trimmed = mixed[:, :, self.receptive_field:]
        return residual, trimmed


class WaveNet(nn.Module):
    """Stack of gated residual blocks with a memory of the previous input tail.

    Every call to ``forward`` prepends the last ``receptive_field`` input
    samples seen by the previous call (zeros on the first call), so a long
    sequence can be fed chunk by chunk. Call ``reset_memory`` before starting
    an unrelated sequence.

    Args:
        num_time_samples: Stored on the instance; not otherwise used by this
            class.
        num_channels: Number of channels of the input signal.
        num_classes: Output channels of ``h_class`` (a layer that ``forward``
            does not use).
        num_blocks: Number of repetitions of the dilation stack.
        num_layers: Layers per block; layer ``i`` uses dilation ``2 ** i``.
        num_hidden: Number of channels inside the network.
        kernel_size: Kernel size of the gated convolutions.
    """

    def __init__(self, num_time_samples, num_channels=1, num_classes=2 ** 16, num_blocks=2, num_layers=14,
                 num_hidden=32, kernel_size=2):
        super().__init__()
        self.previous = None
        self.num_time_samples = num_time_samples
        self.num_channels = num_channels
        self.num_classes = num_classes
        self.num_blocks = num_blocks
        self.num_layers = num_layers
        self.num_hidden = num_hidden
        self.kernel_size = kernel_size
        # Number of past samples carried over between forward calls and
        # trimmed from every skip output.
        self.receptive_field = (kernel_size - 1) * num_blocks * (1 + sum(2 ** k for k in range(num_layers)))

        print('receptive_field: {}'.format(self.receptive_field))

        self.device = self.set_device()
        print(self.device)

        hs = []
        batch_norms = []

        # Only the very first layer sees the raw input channels; every later
        # layer maps num_hidden -> num_hidden.
        in_channels = num_channels
        for b in range(num_blocks):
            for i in range(num_layers):
                h = GatedResidualBlock(in_channels, num_hidden, kernel_size,
                                       self.receptive_field, dilation=2 ** i)
                h.name = 'b{}-l{}'.format(b, i)

                hs.append(h)
                batch_norms.append(nn.BatchNorm1d(num_hidden))
                in_channels = num_hidden

        self.hs = nn.ModuleList(hs)
        self.batch_norms = nn.ModuleList(batch_norms)
        # NOTE: relu_1, conv_1_1, relu_2, conv_1_2 and h_class are not used by
        # forward. They are kept so the set of registered parameters (and
        # therefore the state_dict keys) stays the same.
        self.relu_1 = nn.ReLU()
        self.conv_1_1 = nn.Conv1d(num_hidden, num_hidden, 1)
        self.relu_2 = nn.ReLU()
        self.conv_1_2 = nn.Conv1d(num_hidden, num_hidden, 1)
        self.h_class = nn.Conv1d(num_hidden, num_classes, 2)

        self.linear_mix = nn.Conv1d(
            in_channels=num_hidden,
            out_channels=1,
            kernel_size=1,
        )

    def reset_memory(self):
        """Forget the stored input tail; the next ``forward`` starts from zeros."""
        self.previous = None

    def forward(self, x):
        """Process one chunk of the input signal.

        Args:
            x: Tensor with three dimensions; the stored memory is concatenated
                in front of it along dimension 2.

        Returns:
            ``linear_mix`` applied to the sum of all skip outputs (a
            single-channel tensor).
        """
        if self.previous is None or self.previous.shape[:2] != x.shape[:2]:
            # First call, or the batch/channel layout changed (e.g. a smaller
            # final batch): the old memory cannot be concatenated, start over.
            # new_zeros allocates on x's device with x's dtype.
            self.previous = x.new_zeros(x.shape[0], x.shape[1], self.receptive_field)
        else:
            # Follow the input if the model/data were moved or cast in between.
            self.previous = self.previous.to(device=x.device, dtype=x.dtype)

        x = torch.cat((self.previous, x), dim=2)
        # Explicit start index instead of a negative one: `-0:` would select
        # the whole tensor when receptive_field is 0. Detach so the memory
        # never drags the autograd graph of an earlier call into the next one.
        self.previous = x[:, :, x.shape[2] - self.receptive_field:].detach()

        skips = []
        for layer, batch_norm in zip(self.hs, self.batch_norms):
            x, skip = layer(x)
            x = batch_norm(x)
            skips.append(skip)
        x = reduce(torch.add, skips)
        return self.linear_mix(x)

    @staticmethod
    def set_device(device=None):
        """Return ``device`` if given, else ``cuda:0`` when CUDA is available and ``cpu`` otherwise."""
        if device is None:
            return torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        return device
