<a href="https://colab.research.google.com/github/pbrandl/aNN_Audio/blob/master/WaveNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Training Method

In [3]:
import time
import torch
import torchaudio
from torch import nn


def pre_emphasis_filter(x, coeff=0.95):
    """Apply the first-order pre-emphasis filter ``y[t] = x[t] - coeff * x[t-1]``.

    Filtering runs along dim 2 of a 3-D tensor. The first step along that dim
    has no predecessor and is passed through unchanged.
    """
    leading = x[:, :, :1]
    emphasized = x[:, :, 1:] - coeff * x[:, :, :-1]
    return torch.cat([leading, emphasized], dim=2)


def error_to_signal(y, y_pred):
    """
    Error-to-signal ratio (ESR) computed on pre-emphasised signals:
    https://www.mdpi.com/2076-3417/10/3/766/htm

    Both tensors are filtered with ``pre_emphasis_filter``. The squared error
    and the target energy are then summed over dim 2. A tiny epsilon in the
    denominator avoids division by zero for silent targets. The ratio is
    returned with dim 2 reduced away.
    """
    target = pre_emphasis_filter(y)
    estimate = pre_emphasis_filter(y_pred)
    error_energy = torch.sum((target - estimate) ** 2, dim=2)
    signal_energy = torch.sum(target ** 2, dim=2) + 1e-10
    return error_energy / signal_energy


def train(infile_x, infile_y, n_samples=2000, s_samples=18000, s_batch=5, lr=1e-3, epochs=4):
    """Fit a WaveNet that maps the audio in ``infile_x`` onto the audio in ``infile_y``.

    Only the first channel of each file is used. Each epoch consumes the first
    ``n_samples * s_samples`` audio samples as ``n_samples`` chunks of
    ``s_samples`` samples. The chunks are grouped into mini-batches of
    ``s_batch`` chunks. If ``n_samples`` is not a multiple of ``s_batch``, the
    trailing partial batch is dropped.

    The audio is laid out as ``s_batch`` contiguous streams, and batch row
    ``j`` always comes from stream ``j``. So row ``j`` of one batch directly
    continues row ``j`` of the previous batch. WaveNet prepends the tail of its
    previous input as context, and that context only lines up under this
    layout.

    Args:
        infile_x: Path of the input audio file.
        infile_y: Path of the target audio file. It must match ``infile_x`` in
            sample rate and shape.
        n_samples: Number of ``s_samples``-long chunks used per epoch.
        s_samples: Length of one chunk, in audio samples.
        s_batch: Number of chunks per mini-batch.
        lr: AdamW learning rate.
        epochs: Number of passes over the data.

    Returns:
        List with the summed-MSE loss (as float) of every optimizer step.

    Raises:
        AssertionError: if the two files differ in sample rate or shape, if
            the audio is shorter than ``n_samples * s_samples``, or if
            ``n_samples < s_batch``.

    Side effects:
        Prints per-step losses and the total duration. Saves the model
        ``state_dict`` to ``Models/second_try<yymmdd-HHMM>``, creating
        ``Models/`` if needed.
    """
    import os

    x, sr_x = torchaudio.load(infile_x, normalization=True)
    y, sr_y = torchaudio.load(infile_y, normalization=True)

    assert sr_x == sr_y, "Expected audio data to be equal in sample rate."
    assert x.shape == y.shape, "Expected audio data to be equal in shape."
    # One epoch reads exactly n_samples * s_samples samples (the batch size
    # only groups them), so that is the amount the audio must provide.
    assert n_samples * s_samples <= x.shape[1], "Samples must not exceed audio data length."

    n_batches = n_samples // s_batch
    assert n_batches > 0, "Expected n_samples to be at least s_batch."

    # Split the used audio into s_batch contiguous streams: (s_batch, 1, stream_length).
    stream_length = n_batches * s_samples
    n_used = s_batch * stream_length
    x_streams = x[0, :n_used].reshape(s_batch, 1, stream_length)
    y_streams = y[0, :n_used].reshape(s_batch, 1, stream_length)

    model = WaveNet(s_samples, num_layers=11, num_hidden=4)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss(reduction='sum')
    # Alternative objective: error_to_signal(y_batch, prediction).mean()

    start_time = time.time()
    loss_history = []

    for epoch in range(epochs):
        # Every epoch restarts at the beginning of the streams. Drop the context
        # left over from the previous epoch's last batch so the model starts from silence.
        model.previous = None
        for k in range(n_batches):
            optimizer.zero_grad()

            start = k * s_samples
            # Batches of shape (Batch, Features, Input)
            x_batch = x_streams[:, :, start:start + s_samples].contiguous()
            y_batch = y_streams[:, :, start:start + s_samples].contiguous()

            prediction = model(x_batch)
            loss = loss_fn(prediction, y_batch)
            loss.backward()
            optimizer.step()

            step_loss = loss.item()
            loss_history.append(step_loss)
            print("Epoch", epoch, "Loss:", step_loss)

    print("Duration:", time.time() - start_time)

    # torch.save does not create missing directories.
    os.makedirs("Models", exist_ok=True)
    torch.save(model.state_dict(), "Models/second_try" + time.strftime("%y%m%d-%H%M"))

    return loss_history


# Train on the preprocessed input/target pair; the returned loss history is not kept here.
train(
    infile_x='Preprocessed/trimmed_x.wav',
    infile_y='Preprocessed/trimmed_y.wav',
    n_samples=2000,
    s_samples=12000,
    s_batch=5,
)


OSError: ignored

# Requirements


In [1]:
# Install torchaudio (it was not preinstalled in this runtime; see the log below).
# NOTE: per the log, torchaudio 0.7.0 pulled in torch 1.7.0 and replaced the
# preinstalled torch 1.6.0+cu101. The preinstalled torchvision 0.7.0 requires
# torch 1.6.0, so pip reports it as incompatible.
!pip install torchaudio


Collecting torchaudio
[?25l  Downloading https://files.pythonhosted.org/packages/3f/23/6b54106b3de029d3f10cf8debc302491c17630357449c900d6209665b302/torchaudio-0.7.0-cp36-cp36m-manylinux1_x86_64.whl (7.6MB)
[K     |████████████████████████████████| 7.6MB 4.6MB/s 
[?25hCollecting torch==1.7.0
[?25l  Downloading https://files.pythonhosted.org/packages/80/2a/58f8078744e0408619c63148f7a2e8e48cf007e4146b74d4bb67c56d161b/torch-1.7.0-cp36-cp36m-manylinux1_x86_64.whl (776.7MB)
[K     |████████████████████████████████| 776.7MB 23kB/s 
[31mERROR: torchvision 0.7.0+cu101 has requirement torch==1.6.0, but you'll have torch 1.7.0 which is incompatible.[0m
Installing collected packages: torch, torchaudio
  Found existing installation: torch 1.6.0+cu101
    Uninstalling torch-1.6.0+cu101:
      Successfully uninstalled torch-1.6.0+cu101
Successfully installed torch-1.7.0 torchaudio-0.7.0



# WaveNet Implementation

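Modified WaveNet implementation that remembers the most recent receptive field of its input. Each call prepends the stored samples to the new chunk, so consecutive chunks of a sequence are processed as one continuous signal.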

In [2]:
import torch
from torch import nn
from functools import reduce


class GatedConv1d(nn.Module):
    """Causal, dilated 1-D convolution with a sigmoid gate.

    Computes ``conv_f(x) * sigmoid(conv_g(x))``. The filter branch is linear
    (no tanh). Before convolving, the input is left-padded with
    ``dilation * (kernel_size - 1)`` zeros along the last dim. With the
    default ``stride=1`` and ``padding=0``, the output therefore has the same
    length as the input, and output step ``t`` depends only on input steps
    ``<= t``.

    A nonzero ``padding`` is forwarded to both convolutions. It adds symmetric
    padding on top of the causal left pad, so the layer is no longer causal
    or length-preserving.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=True):
        super(GatedConv1d, self).__init__()
        self.dilation = dilation
        self.conv_f = nn.Conv1d(in_channels, out_channels, kernel_size,
                                stride=stride, padding=padding, dilation=dilation,
                                groups=groups, bias=bias)
        self.conv_g = nn.Conv1d(in_channels, out_channels, kernel_size,
                                stride=stride, padding=padding, dilation=dilation,
                                groups=groups, bias=bias)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        # A dilated kernel spans dilation * (kernel_size - 1) + 1 steps, so this
        # much left padding keeps the output length equal to the input length.
        # It equals `dilation` for kernel_size == 2. Derive it from the conv
        # so it always matches the layer's configuration.
        causal_pad = self.conv_f.dilation[0] * (self.conv_f.kernel_size[0] - 1)
        x = nn.functional.pad(x, (causal_pad, 0))
        return torch.mul(self.conv_f(x), self.sig(self.conv_g(x)))


class GatedResidualBlock(nn.Module):
    """One residual WaveNet layer: a gated dilated conv followed by a 1x1 conv.

    ``forward`` returns ``(residual, skip)``:

    - ``residual`` is the 1x1 output added to the block input.
    - ``skip`` is the same 1x1 output with its first ``receptive_field`` time
      steps removed.

    When ``in_channels == 1`` and ``out_channels > 1``, the residual add relies
    on broadcasting the single input channel across all output channels. Other
    mismatched channel counts raise in the add.
    """

    def __init__(self, in_channels, out_channels, kernel_size, receptive_field, stride=1, padding=0,
                 dilation=1, groups=1, bias=True):
        super().__init__()
        self.receptive_field = receptive_field
        self.gatedconv = GatedConv1d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            groups=groups, bias=bias,
        )
        self.conv_1 = nn.Conv1d(
            out_channels, out_channels, kernel_size=1,
            stride=1, padding=0, dilation=1, groups=1, bias=bias,
        )

    def forward(self, x):
        mixed = self.conv_1(self.gatedconv(x))
        residual = mixed + x
        # Drop the leading steps that only exist because of prepended context.
        trimmed_skip = mixed[:, :, self.receptive_field:]
        return residual, trimmed_skip


class WaveNet(nn.Module):
    """Stateful WaveNet-style network with a single output channel.

    The network stacks ``num_blocks * num_layers`` gated residual layers, each
    followed by batch norm. Within a block the dilations are
    ``1, 2, ..., 2 ** (num_layers - 1)``. The trimmed skip outputs of all
    layers are summed and mixed down to one channel by a 1x1 convolution.

    Between calls, ``forward`` stores the last ``receptive_field`` input steps
    and prepends them to the next input. Consecutive chunks of a signal can
    therefore be fed one after another. With the default ``kernel_size=2``,
    the output length equals the chunk length. Call ``reset_state()`` (or set
    ``previous = None``) before feeding an unrelated sequence.

    ``num_time_samples`` is only stored. ``num_classes`` only sizes
    ``h_class``. ``relu_1``, ``conv_1_1``, ``relu_2``, ``conv_1_2`` and
    ``h_class`` are not used by ``forward``; they are kept so existing
    ``state_dict`` files still load.
    """

    def __init__(self, num_time_samples, num_channels=1, num_classes=2 ** 16, num_blocks=2, num_layers=14,
                 num_hidden=32, kernel_size=2):
        super(WaveNet, self).__init__()
        self.previous = None
        self.num_time_samples = num_time_samples
        self.num_channels = num_channels
        self.num_classes = num_classes
        self.num_blocks = num_blocks
        self.num_layers = num_layers
        self.num_hidden = num_hidden
        self.kernel_size = kernel_size
        # Length of the stored context and of the prefix trimmed from every skip.
        # Equal to (kernel_size - 1) * num_blocks * (1 + sum(2**k for k < num_layers)).
        self.receptive_field = (kernel_size - 1) * num_blocks * 2 ** num_layers

        print('receptive_field: {}'.format(self.receptive_field))

        self.device = self.set_device()
        print(self.device)

        hs = []
        batch_norms = []

        # Gated residual layers; only the very first one sees the raw input channels.
        for b in range(num_blocks):
            for i in range(num_layers):
                in_channels = num_channels if (b == 0 and i == 0) else num_hidden
                h = GatedResidualBlock(in_channels, num_hidden, kernel_size,
                                       self.receptive_field, dilation=2 ** i)
                h.name = 'b{}-l{}'.format(b, i)

                hs.append(h)
                batch_norms.append(nn.BatchNorm1d(num_hidden))

        self.hs = nn.ModuleList(hs)
        self.batch_norms = nn.ModuleList(batch_norms)
        # Unused by forward(); retained for state_dict compatibility.
        self.relu_1 = nn.ReLU()
        self.conv_1_1 = nn.Conv1d(num_hidden, num_hidden, 1)
        self.relu_2 = nn.ReLU()
        self.conv_1_2 = nn.Conv1d(num_hidden, num_hidden, 1)
        self.h_class = nn.Conv1d(num_hidden, num_classes, 2)

        self.linear_mix = nn.Conv1d(
            in_channels=num_hidden,
            out_channels=1,
            kernel_size=1,
        )

    def reset_state(self):
        """Forget the stored context so the next ``forward`` starts from silence."""
        self.previous = None

    def forward(self, x):
        """Process one chunk ``x`` laid out as ``(batch, channels, time)`` (Conv1d layout)."""
        if self.previous is None or self.previous.shape[:2] != x.shape[:2]:
            # No context yet, or the batch/channel layout changed: start from
            # silence instead of failing in torch.cat.
            self.previous = x.new_zeros(x.shape[0], x.shape[1], self.receptive_field)

        # .to(x) keeps the context on x's device and dtype.
        x = torch.cat((self.previous.to(x), x), dim=2)
        # detach() so the stored context can never hold an autograd graph
        # across optimizer steps.
        self.previous = x[:, :, -self.receptive_field:].detach()

        skips = []
        for layer, batch_norm in zip(self.hs, self.batch_norms):
            x, skip = layer(x)
            x = batch_norm(x)
            skips.append(skip)
        x = reduce(torch.add, skips)
        return self.linear_mix(x)

    @staticmethod
    def set_device(device=None):
        """Return ``device`` if given, else ``cuda:0`` when available, else ``cpu``."""
        if device is None:
            return torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        else:
            return device
