<a href="https://colab.research.google.com/github/pbrandl/aNN_Audio/blob/master/WaveNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# WaveNet Implementation

Modified WaveNet implementation that keeps a memory of the most recent receptive field of the input sequence and prepends it to the next input.

In [None]:
import torch
from torch import nn
from functools import reduce


class GatedConv1d(nn.Module):
    """Causal gated 1-D convolution: ``conv_f(x) * sigmoid(conv_g(x))``.

    Two parallel ``nn.Conv1d`` layers with identical hyper-parameters are
    applied to the same left-padded input; the first produces the features
    and the second, after a sigmoid, the gate.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by each convolution.
        kernel_size: Width of the convolution kernel.
        stride, padding, dilation, groups, bias: Forwarded unchanged to both
            ``nn.Conv1d`` layers.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=True):
        super(GatedConv1d, self).__init__()
        self.dilation = dilation
        self.conv_f = nn.Conv1d(in_channels, out_channels, kernel_size,
                                stride=stride, padding=padding, dilation=dilation,
                                groups=groups, bias=bias)
        self.conv_g = nn.Conv1d(in_channels, out_channels, kernel_size,
                                stride=stride, padding=padding, dilation=dilation,
                                groups=groups, bias=bias)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Apply the gated convolution to ``x`` of shape (batch, channels, time).

        The input is zero-padded on the left only, so an output sample never
        depends on later input samples. With the default ``stride=1`` and
        ``padding=0`` the output has the same length as the input.
        """
        # nn.Conv1d stores kernel_size and dilation as 1-tuples.
        (kernel_size,) = self.conv_f.kernel_size
        (dilation,) = self.conv_f.dilation
        # A dilated kernel spans (kernel_size - 1) * dilation past samples;
        # padding by exactly that amount keeps the length for any kernel size
        # (for kernel_size == 2 this equals the dilation).
        left_pad = (kernel_size - 1) * dilation
        x = nn.functional.pad(x, (left_pad, 0))
        return torch.mul(self.conv_f(x), self.sig(self.conv_g(x)))


class GatedResidualBlock(nn.Module):
    """Gated convolution followed by a 1x1 convolution, with a residual sum.

    ``forward`` returns two tensors: the residual output (1x1-conv result
    plus the block input) and the skip output (the 1x1-conv result alone).

    Args:
        in_channels: Channels of the block input.
        out_channels: Channels produced by the gated and the 1x1 convolution.
        kernel_size: Kernel width of the gated convolution.
        receptive_field: Stored on the instance; ``forward`` does not read it.
        stride, padding, dilation, groups: Forwarded to ``GatedConv1d``.
        bias: Forwarded to ``GatedConv1d`` and to the 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, receptive_field, stride=1, padding=0,
                 dilation=1, groups=1, bias=True):
        super().__init__()
        self.receptive_field = receptive_field
        self.gatedconv = GatedConv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # Point-wise (kernel size 1) mixing of the gated features.
        self.conv_1 = nn.Conv1d(out_channels, out_channels, kernel_size=1, bias=bias)

    def forward(self, x):
        """Return ``(residual, skip)`` for the input ``x``."""
        gated = self.gatedconv(x)
        skip = self.conv_1(gated)
        # The sum requires x to be broadcastable against the skip tensor.
        return skip + x, skip


class WaveNet(nn.Module):
    """Stack of dilated gated residual blocks with a memory of past input.

    The network consists of ``num_blocks`` repetitions of ``num_layers``
    ``GatedResidualBlock`` layers whose dilation doubles with every layer
    (``2 ** i``). The skip outputs of all layers are summed and mixed down to
    a single channel by a 1x1 convolution.

    Between calls the model remembers the last ``receptive_field`` samples of
    its input and prepends them to the next input.

    Args:
        num_time_samples: Stored on the instance; not otherwise used here.
        num_channels: Channels of the input (``in_channels`` of the first layer).
        num_classes: Output channels of ``h_class``, a layer ``forward`` does
            not use.
        num_blocks: Number of repetitions of the dilation stack.
        num_layers: Number of layers per repetition.
        num_hidden: Channel width of the hidden layers.
        kernel_size: Kernel width of every gated convolution.
        device: Stored on the instance for backwards compatibility. The memory
            is allocated on the device of the incoming tensor instead, so the
            model also works when this value does not match where it runs.
    """

    def __init__(self, num_time_samples, num_channels=1, num_classes=2 ** 16, num_blocks=2, num_layers=14,
                 num_hidden=32, kernel_size=2, device='cuda'):
        super(WaveNet, self).__init__()
        # Last `receptive_field` input samples of the previous call (or None).
        self.previous = None
        self.num_time_samples = num_time_samples
        self.num_channels = num_channels
        self.num_classes = num_classes
        self.num_blocks = num_blocks
        self.num_layers = num_layers
        self.num_hidden = num_hidden
        self.kernel_size = kernel_size
        self.device = device

        # Length of the remembered context that is prepended to every input.
        self.receptive_field = (kernel_size - 1) * num_blocks * (1 + sum(2 ** k for k in range(num_layers)))

        print('Receptive Field: {}'.format(self.receptive_field))

        hs = []
        # Only the very first layer sees the raw input channels.
        in_channels = num_channels
        for b in range(num_blocks):
            for i in range(num_layers):
                h = GatedResidualBlock(in_channels, num_hidden, kernel_size,
                                       self.receptive_field, dilation=2 ** i)
                h.name = 'b{}-l{}'.format(b, i)
                hs.append(h)
                in_channels = num_hidden

        self.hs = nn.ModuleList(hs)
        # The following layers are not used by forward(); they are kept so that
        # parameter names and initialisation order match existing checkpoints.
        self.relu_1 = nn.ReLU()
        self.conv_1_1 = nn.Conv1d(num_hidden, num_hidden, 1)
        self.relu_2 = nn.ReLU()
        self.conv_1_2 = nn.Conv1d(num_hidden, num_hidden, 1)
        self.h_class = nn.Conv1d(num_hidden, num_classes, 2)

        # 1x1 convolution that mixes the summed skip outputs into one channel.
        self.linear_mix = nn.Conv1d(
            in_channels=num_hidden,
            out_channels=1,
            kernel_size=1,
        )

    def reset_memory(self):
        """Forget the stored context; the next call starts from zeros."""
        self.previous = None

    def forward(self, x):
        """Run the network on ``x`` of shape (batch, channels, time).

        The stored context is prepended along the time axis, the last
        ``receptive_field`` samples of the combined input are stored for the
        next call, and the first ``receptive_field`` output positions (those
        belonging to the prepended context) are cut off again.

        The memory is re-initialised with zeros when it does not exist yet or
        when its batch size, channel count, device or dtype no longer matches
        ``x`` (for example after changing the batch size or moving the model).

        Returns:
            The single-channel output of ``linear_mix`` without the context
            positions.
        """
        memory = self.previous
        stale = (
            memory is None
            or memory.shape[:2] != x.shape[:2]
            or memory.device != x.device
            or memory.dtype != x.dtype
        )
        if stale:
            # Allocate on the input's device/dtype rather than self.device.
            memory = x.new_zeros((x.shape[0], x.shape[1], self.receptive_field))

        x = torch.cat((memory, x), dim=2)
        # Detach so the memory never keeps an autograd graph alive across calls.
        self.previous = x[:, :, -self.receptive_field:].detach()

        skips = []
        for layer in self.hs:
            x, skip = layer(x)
            skips.append(skip)
        x = reduce(torch.add, skips)
        return self.linear_mix(x)[:, :, self.receptive_field:]


# Training Method

In [None]:
import os
import time
import torch
import torchaudio
from torch import nn

# Prefer the GPU, but fall back to the CPU so the cell also runs without CUDA.
device = cuda = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)


# Number of audio samples per training window; also passed to the model.
s_samples = 12000
model = WaveNet(s_samples, num_layers=11, num_hidden=4, device=device).to(device)

def train(model, infile_x, infile_y, epochs, n_samples, s_samples, s_batch, lr, device='cpu',
          save_path=None):
    """Fit ``model`` to map the audio in ``infile_x`` onto the audio in ``infile_y``.

    Both files are loaded with ``torchaudio`` and must agree in sample rate
    and shape. Only the first channel of each file is used. Every epoch walks
    over the first ``s_samples * n_samples`` samples in steps of
    ``s_batch * s_samples``; each step is reshaped into ``s_batch`` windows of
    ``s_samples`` samples and optimised with AdamW on a summed squared error.

    Args:
        model: Module to train; called with a (s_batch, 1, s_samples) tensor.
        infile_x: Path of the input audio file.
        infile_y: Path of the target audio file.
        epochs: Number of passes over the data.
        n_samples: Number of ``s_samples``-long windows covered per epoch.
        s_samples: Length of one window in audio samples.
        s_batch: Number of windows per optimisation step.
        lr: Learning rate for AdamW.
        device: Device the audio tensors are moved to.
        save_path: Where to store the trained ``state_dict``. Defaults to the
            original Google Drive location with a timestamp suffix.

    Returns:
        A list with the loss value of every optimisation step.

    Raises:
        AssertionError: If the two files differ in sample rate or shape, or
            if the requested amount of data exceeds the audio length.
    """
    x, sr_x = torchaudio.load(infile_x, normalization=True)
    y, sr_y = torchaudio.load(infile_y, normalization=True)

    x = x.to(device)
    y = y.to(device)
    print(type(x))

    assert sr_x == sr_y, "Expected audio data to be eqaul in sample rate."
    assert x.shape == y.shape, "Expected audio data to be eqaul in shape."

    b_length = s_batch * s_samples

    # NOTE(review): this bound includes the factor s_batch although the loop
    # below only walks over s_samples * n_samples samples - confirm intent.
    assert n_samples * s_samples * s_batch <= x.shape[1], "Samples must not exceed audio data length."

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss(reduction='sum')

    start_time = time.time()

    loss_history = []

    for epoch in range(epochs):
        for i in range(0, s_samples * n_samples, b_length):
            optimizer.zero_grad()

            # Cut one step out of channel 0 and shape it as
            # (batch, channel, samples) = (s_batch, 1, s_samples).
            x_batch, y_batch = x[0, i:i + b_length], y[0, i:i + b_length]

            x_batch = x_batch.reshape(s_batch, 1, s_samples)
            y_batch = y_batch.reshape(s_batch, 1, s_samples)

            prediction = model(x_batch)

            print("pred", prediction.shape, "y", y_batch.shape)
            loss = loss_fn(prediction, y_batch)
            loss.backward()
            optimizer.step()

            loss_history.append(loss.item())
            print("Epoch", epoch, "Loss:", loss.item())

    print("Duration:", time.time() - start_time)

    if save_path is None:
        save_path = "/content/drive/My Drive/Colab Notebooks/Models/second_try" + time.strftime("%y%m%d-%H%M")

    # Store CPU copies of the weights without moving the caller's model off
    # its device as a side effect.
    cpu_state = {name: tensor.detach().cpu() for name, tensor in model.state_dict().items()}
    torch.save(cpu_state, save_path)

    return loss_history

# Audio pair used for training: x_file is fed to the model, y_file is the target.
x_file = "/content/drive/My Drive/Colab Notebooks/trimmed_x.wav"
y_file = "/content/drive/My Drive/Colab Notebooks/trimmed_y.wav"
train(
    model=model,
    infile_x=x_file,
    infile_y=y_file,
    epochs=20,
    n_samples=2000,
    s_samples=s_samples,
    s_batch=5,
    lr=1e-3,
    device=device,
)


# Predict Sequence Method

In [None]:
import torch.nn.functional as F

# Rebuild the network with the same hyper-parameters as the training cell so
# the stored weights fit, then load them.
s_samples = 12000
model = WaveNet(num_time_samples=s_samples, num_layers=11, num_hidden=4, device=device)
model = model.to(device)

model.load_state_dict(
    torch.load("/content/drive/My Drive/Colab Notebooks/Models/second_try201029-2126")
)

# Audio file the trained model is applied to.
x_file = "/content/drive/My Drive/Colab Notebooks/Audio/beat_test_raw_l.wav"

x_test, sr_x = torchaudio.load(x_file, normalization=True)
x_test = x_test.to(device)
print("input file shape", x_test.shape)


def predict_sequence(model, x, pred_length):
    """Run ``model`` chunk by chunk over ``x`` and return the joined output.

    ``x`` is zero-padded at the end up to the next multiple of
    ``pred_length`` (no padding is added when it already is a multiple),
    split into chunks of ``pred_length`` samples, and each chunk is passed to
    ``model`` as a (1, 1, pred_length) tensor. The padding is removed again
    before returning, so the result has as many samples as ``x``.

    Args:
        model: Callable mapping a (1, 1, pred_length) tensor to a tensor with
            ``pred_length`` elements.
        x: Tensor of shape (1, time).
        pred_length: Number of samples per chunk; must be positive.

    Returns:
        A 1-D CPU tensor with ``x.shape[1]`` samples that carries no autograd
        history.

    Raises:
        ValueError: If ``pred_length`` is not positive.
    """
    if pred_length <= 0:
        raise ValueError("pred_length must be a positive integer.")

    orig_length = x.shape[1]
    # Distance to the next multiple of pred_length (0 if already aligned).
    pad_size = (-orig_length) % pred_length
    x_padded = F.pad(x, (0, pad_size), mode='constant', value=0)

    seq_length = x_padded.shape[1]
    seq = torch.zeros(seq_length)

    # Inference only: without no_grad every chunk's graph would be kept alive
    # through the in-place writes into `seq`.
    with torch.no_grad():
        for start in range(0, seq_length, pred_length):
            chunk = x_padded[:, start:start + pred_length].reshape(1, 1, pred_length)
            seq[start:start + pred_length] = model(chunk).reshape(-1)

    # Drop the samples that only exist because of the padding.
    return seq[:orig_length]


# Apply the model to the test audio in chunks of s_samples samples.
output = predict_sequence(model=model, x=x_test, pred_length=s_samples)

print("Now writing to wav", output.shape)
# Store the prediction with the sample rate of the input file.
torchaudio.save(
    "/content/drive/My Drive/Colab Notebooks/Predictions/test_pred.wav",
    output,
    sr_x,
)


# Requirements


In [None]:
# Mount Google Drive at /content/drive; the audio and model paths used in the
# other cells live below that directory.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Pin torchaudio to 0.6.0; the other cells call torchaudio.load with the
# `normalization` keyword.
!pip install torchaudio==0.6.0

In [None]:
# Install cloud-tpu-client 0.10 and the torch_xla 1.6 wheel for CPython 3.6.
# NOTE(review): no visible cell imports torch_xla - confirm this is still needed.
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.6-cp36-cp36m-linux_x86_64.whl

In [None]:
# Download the env-setup script from the pytorch/xla repository and run it
# with the version selected below (the #@param comment renders a Colab form).
VERSION = "20200325"  #@param ["1.5" , "20200325", "nightly"]
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION