<a href="https://colab.research.google.com/github/pbrandl/aNN_Audio/blob/master/WaveNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Global Variables and Requirements

In [None]:
import os
import sys
import glob
import time
import torch
from torch import nn
import torch.nn.functional as F

from google.colab import drive
# Mount Google Drive so the project folder below is reachable from this runtime.
drive.mount('/content/drive')

# Set Working Directory
project_path = '/content/drive/My Drive/aNN_Colab'
print("Working in {}.".format(project_path))
models_path = os.path.join(project_path, 'Models')
preds_path = os.path.join(project_path, 'Predictions')
pyenv = os.path.join(project_path, 'pyenv')
# Make modules stored in the project's pyenv folder importable.
sys.path.append(pyenv)

# Select the Processing Device
# Fall back to the CPU when no CUDA device is available, so that the
# `.to(device)` calls further down do not fail on a runtime without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Working on {}.".format(device))



# WaveNet Implementation

A modified WaveNet implementation that remembers the most recent receptive field of the previous input and prepends it to the next input of the sequence.

In [None]:
class GatedConv1d(nn.Module):
    """Causal 1-D convolution with a gated activation.

    Two parallel convolutions with identical hyper-parameters are applied to
    the input: a "filter" convolution and a "gate" convolution. The output is
    ``conv_f(x) * sigmoid(conv_g(x))``.

    Before convolving, the input is zero-padded on the left only, by
    ``dilation * (kernel_size - 1)`` samples. With ``stride=1`` and
    ``padding=0`` this keeps the output exactly as long as the input and
    ensures an output sample never depends on later input samples.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels of both convolutions.
        kernel_size: Kernel size of both convolutions.
        stride, padding, dilation, groups, bias: Forwarded unchanged to both
            ``nn.Conv1d`` layers. ``padding`` is applied in addition to the
            causal left padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=True):
        super(GatedConv1d, self).__init__()
        self.dilation = dilation
        self.conv_f = nn.Conv1d(in_channels, out_channels, kernel_size,
                                stride=stride, padding=padding, dilation=dilation,
                                groups=groups, bias=bias)
        self.conv_g = nn.Conv1d(in_channels, out_channels, kernel_size,
                                stride=stride, padding=padding, dilation=dilation,
                                groups=groups, bias=bias)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Return the gated convolution of ``x`` (last dimension is time)."""
        # A dilated kernel spans dilation * (kernel_size - 1) past samples.
        # Padding by `dilation` alone is only correct for kernel_size == 2.
        # nn.Conv1d stores its hyper-parameters as 1-tuples, hence the [0].
        causal_padding = self.conv_f.dilation[0] * (self.conv_f.kernel_size[0] - 1)
        x = nn.functional.pad(x, (causal_padding, 0))
        return torch.mul(self.conv_f(x), self.sig(self.conv_g(x)))


class GatedResidualBlock(nn.Module):
    """Gated convolution followed by a 1x1 convolution, with a residual path.

    ``forward`` returns two tensors: the residual output (1x1-convolved gated
    activation plus the block input) and the skip output (the 1x1-convolved
    gated activation on its own).

    Args:
        in_channels: Channels of the block input.
        out_channels: Channels produced by the gated and the 1x1 convolution.
        kernel_size: Kernel size of the gated convolution.
        stride, padding, dilation, groups: Forwarded to the gated convolution.
        bias: Forwarded to the gated convolution and to the 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=True):
        super().__init__()
        conv_options = dict(stride=stride, padding=padding, dilation=dilation,
                            groups=groups, bias=bias)
        self.gatedconv = GatedConv1d(in_channels, out_channels, kernel_size,
                                     **conv_options)
        # Pointwise (kernel size 1) convolution applied to the gated activation.
        self.conv_1 = nn.Conv1d(out_channels, out_channels, kernel_size=1, bias=bias)

    def forward(self, x):
        """Return ``(residual, skip)`` for input ``x``."""
        gated = self.gatedconv(x)
        skip = self.conv_1(gated)
        # The addition relies on broadcasting when x has fewer channels than skip.
        return torch.add(skip, x), skip


class WaveNet(nn.Module):
    """Stack of gated residual blocks with a memory of the previous input tail.

    The network consists of ``num_blocks * num_layers`` ``GatedResidualBlock``
    layers whose dilation doubles per layer within a block (1, 2, 4, ...).
    The skip outputs of all layers are summed and mixed down to a single
    channel by a 1x1 convolution.

    Between calls the model keeps the last ``length_rf`` samples of its
    (history-extended) input in ``previous_rf``. On the next call this history
    is prepended to the new input and the corresponding leading part of the
    output is dropped, so the returned tensor is as long as the new input.

    Args:
        num_time_samples: Stored on the instance; not used by this class.
        num_channels: Channels of the input signal.
        num_classes: Output channels of the (unused) ``h_class`` layer.
        num_blocks: Number of dilation stacks.
        num_layers: Number of layers per stack.
        num_hidden: Channels inside the residual layers.
        kernel_size: Kernel size of the gated convolutions.
        device: Device on which the initial, empty ``previous_rf`` is created.
    """

    def __init__(self, num_time_samples, num_channels=1, num_classes=2 ** 16, num_blocks=2, num_layers=14,
                 num_hidden=32, kernel_size=2, device='cuda'):
        super(WaveNet, self).__init__()

        self.num_time_samples = num_time_samples
        self.num_channels = num_channels
        self.num_classes = num_classes
        self.num_blocks = num_blocks
        self.num_layers = num_layers
        self.num_hidden = num_hidden
        self.kernel_size = kernel_size
        self.device = device

        # Number of past samples remembered between forward calls.
        self.length_rf = (kernel_size - 1) * num_blocks * (1 + sum(2 ** k for k in range(num_layers)))
        # Empty until the first forward call (or until the caller sets it).
        self.previous_rf = torch.tensor([]).to(device)

        hs = []
        # Only the very first layer sees the raw input channels; every later
        # layer maps num_hidden -> num_hidden.
        first = True
        for b in range(num_blocks):
            for i in range(num_layers):
                rate = 2 ** i
                if first:
                    h = GatedResidualBlock(num_channels, num_hidden, kernel_size, dilation=rate)
                    first = False
                else:
                    h = GatedResidualBlock(num_hidden, num_hidden, kernel_size, dilation=rate)
                h.name = 'b{}-l{}'.format(b, i)
                hs.append(h)

        self.hs = nn.ModuleList(hs)
        # The layers below are not used by forward(). They are kept so that the
        # parameter set (and therefore saved state dicts) stays unchanged.
        self.relu_1 = nn.ReLU()
        self.conv_1_1 = nn.Conv1d(num_hidden, num_hidden, 1)
        self.relu_2 = nn.ReLU()
        self.conv_1_2 = nn.Conv1d(num_hidden, num_hidden, 1)
        self.h_class = nn.Conv1d(num_hidden, num_classes, 2)

        # Mixes the summed skip connections down to one output channel.
        self.linear_mix = nn.Conv1d(
            in_channels=num_hidden,
            out_channels=1,
            kernel_size=1,
        )

    def forward(self, x):
        """Process ``x`` and return a single-channel tensor of the same length.

        ``x`` is indexed as (batch, channels, time); the stored history is
        concatenated along dimension 2.
        """
        if self.previous_rf.numel() == 0:
            # No history yet: start from silence. Without this the output of
            # the first call would be `length_rf` samples shorter than `x`.
            self.previous_rf = torch.zeros(x.shape[0], x.shape[1], self.length_rf,
                                           dtype=x.dtype, device=x.device)

        x = torch.cat((self.previous_rf, x), dim=2)
        # Keep the tail for the next call. The explicit start index (instead of
        # `-length_rf:`) also works when length_rf is 0; detach() keeps the
        # stored history out of any autograd graph.
        tail_start = max(x.shape[2] - self.length_rf, 0)
        self.previous_rf = x[:, :, tail_start:].detach()

        # Sum the skip outputs of all layers, in layer order.
        skip_sum = None
        for layer in self.hs:
            x, skip = layer(x)
            skip_sum = skip if skip_sum is None else torch.add(skip_sum, skip)

        # Drop the part of the output that belongs to the prepended history.
        return self.linear_mix(skip_sum)[:, :, self.length_rf:]


# Training Method

In [None]:
def load_data_set(file_x, file_y, channels=1, device='cuda'):
    """Load an input/target pair of audio files and move them to ``device``.

    Args:
        file_x: Path of the input audio file.
        file_y: Path of the target audio file.
        channels: Number of leading channels to keep from each file.
        device: Device the returned tensors are moved to.

    Returns:
        Tuple ``(x, y)`` holding the first ``channels`` rows of each file.

    Raises:
        AssertionError: If the two files differ in sample rate or shape.
    """
    # NOTE(review): torchaudio is not imported in the visible part of this
    # file; it has to be imported before this function is called.
    # Read from the function's own parameters (the previous body referred to
    # the unrelated names `x_file` / `y_file`).
    x, sr_x = torchaudio.load(file_x, normalization=True)
    y, sr_y = torchaudio.load(file_y, normalization=True)
    assert sr_x == sr_y, "Expected audio data to be eqaul in sample rate."
    assert x.shape == y.shape, "Expected audio data to be equal in shape."
    return x[0:channels, :].to(device), y[0:channels, :].to(device)

def batchify(x, y, batch_size, length_sample):
    """Shape data to batches of (Sample, Batch, Channels, SampleLength).

    Both signals are indexed as (channels, time). They are cut to the largest
    length that is a multiple of ``batch_size * length_sample`` and split into
    consecutive, non-overlapping chunks of ``length_sample`` samples.

    Args:
        x: Input signal.
        y: Target signal; it is cut to the same length as ``x``.
        batch_size: Number of chunks per batch.
        length_sample: Number of time samples per chunk.

    Returns:
        Tuple ``(x, y)`` of tensors indexed as
        (sample, batch, channel, time within chunk).

    Raises:
        AssertionError: If ``x`` is too short for a single batch.
    """
    assert batch_size * length_sample <= x.shape[1], "Summed duration length must be less than train data."

    n_samples = x.shape[1] // (length_sample * batch_size)
    sum_length = n_samples * batch_size * length_sample

    def _to_batches(signal):
        # Chunk every channel separately, then move the channel axis behind
        # the batch axis. For one channel this equals the plain reshape to
        # (n_samples, batch_size, 1, length_sample); for several channels the
        # plain reshape would fail because the element counts differ.
        n_channels = signal.shape[0]
        chunks = signal[:, :sum_length].reshape(n_channels, n_samples, batch_size, length_sample)
        return chunks.permute(1, 2, 0, 3)

    x = _to_batches(x)
    y = _to_batches(y)
    print("Generated a max. number of samples {} by batch size {}  and sample length {}".format(
        n_samples, batch_size, length_sample))
    return x, y

# Files
# Input and target recordings. Both names follow the `file_*` pattern that the
# load_data_set call below expects (the target used to be bound to `y_file`,
# leaving `file_y` undefined).
file_x = os.path.join(project_path, "trimmed_x.wav")
file_y = os.path.join(project_path, "trimmed_y.wav")
x, y = load_data_set(file_x, file_y, channels=1, device=device)

In [None]:
# Training data layout: chunk length in samples, chunks per batch, channels.
length_sample, batch_size, num_channels = 9000, 20, 1
x, y = batchify(x, y, batch_size, length_sample)
print("Train Data Shape: {}.".format(x.shape))

# Network
model = WaveNet(length_sample, num_layers=11, num_hidden=4, num_blocks=1, device=device)
model = model.to(device)
# Start with a zero-filled history: one row per batch element, one channel.
model.previous_rf = torch.zeros(batch_size, 1, model.length_rf, device=device)
print("Receptive Field Length: {}".format(model.length_rf))


def train(model, x, y, epochs, n_samples, length_sample, batch_size, lr):
    """Fit ``model`` to the batched data with AdamW and a summed L1 loss.

    Args:
        model: Module that is called on each input batch.
        x: Iterable of input batches; iterated in lockstep with ``y``.
        y: Iterable of target batches.
        epochs: Number of passes over all batches.
        n_samples: Not used by this function.
        length_sample: Not used by this function.
        batch_size: Not used by this function.
        lr: Learning rate handed to the optimizer.

    Returns:
        List with the loss value of every processed batch, in order.
    """
    optimiser = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.L1Loss(reduction='sum')
    history = []

    for epoch_idx in range(epochs):
        for inputs, targets in zip(x, y):
            optimiser.zero_grad()
            batch_loss = criterion(model(inputs), targets)
            batch_loss.backward()
            optimiser.step()
            history.append(batch_loss.item())
        # Reports the loss of the last batch of this epoch.
        print("Epoch", epoch_idx, "Loss:", batch_loss.item())

    return history

start_time = time.time()
# `n_samples` only exists as a local inside batchify. After batching, the
# first dimension of `x` is that number, so read it from there.
loss_history = train(model, x.to(device), y.to(device), epochs=20,
      n_samples=x.shape[0], length_sample=length_sample, batch_size=batch_size, lr=1e-3)
print("Duration:", time.time() - start_time)

# Store the trained weights under a timestamp file name (yymmdd-HHMM).
torch.save(model.state_dict(), os.path.join(models_path, time.strftime("%y%m%d-%H%M")))

# Predict Sequence Method

In [None]:
def get_latest_model(path):
    """Return the most recently modified entry directly inside ``path``.

    Raises:
        IndexError: If ``path`` contains no entries.
    """
    candidates = glob.glob(os.path.join(path, '*'))
    # Newest modification time first.
    candidates.sort(key=os.path.getmtime, reverse=True)
    return candidates[0]

def get_model(filename, path=None):
    """Return the most recently modified file matching ``filename``.

    Args:
        filename: File name or glob pattern, resolved relative to ``path``.
        path: Directory to search. When omitted, the module-level
            ``models_path`` is used (the previous body referred to an
            undefined name ``path``).

    Returns:
        Path of the newest matching file.

    Raises:
        FileNotFoundError: If nothing matches.
    """
    if path is None:
        path = models_path
    pattern = os.path.join(path, filename)
    model_files = glob.glob(pattern)
    if not model_files:
        raise FileNotFoundError("No model file matches {}".format(pattern))
    return max(model_files, key=os.path.getmtime)

In [None]:
# Load Model by File
model_file = get_latest_model(models_path)
model = WaveNet(length_sample, num_layers=11, num_hidden=4, num_blocks=1, device=device).to(device)
# Prediction runs on one sequence at a time, so the history has batch size 1.
model.previous_rf = torch.zeros(1, 1, model.length_rf).to(device)

# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved on another device (e.g. a GPU) can still be loaded here.
model.load_state_dict(torch.load(model_file, map_location=device))

x_file = os.path.join(project_path, "Audio", "beat_test_raw_l.wav")

x_test, sr_x = torchaudio.load(x_file, normalization=True)
x_test = x_test.to(device)
print("input file shape", x_test.shape)


def predict_sequence(model, x, pred_length):
    """Run ``model`` over ``x`` in consecutive chunks of ``pred_length`` samples.

    ``x`` is indexed as (channels, time). It is zero-padded at the end up to
    the next multiple of ``pred_length``, cut into chunks that are reshaped to
    ``(1, 1, pred_length)`` and fed to the model one after another. The
    predictions are written into one buffer, which is finally cut back to the
    original length.

    Args:
        model: Callable mapping a ``(1, 1, pred_length)`` tensor to a tensor
            with ``pred_length`` elements.
        x: Input signal; each chunk must hold exactly ``pred_length`` values,
            so the reshape only works for a single channel.
        pred_length: Number of samples handed to the model per call.

    Returns:
        1-D tensor (allocated with ``torch.zeros``, i.e. on the default
        device) with ``x.shape[1]`` predicted samples.
    """
    n_samples = x.shape[1]
    # Pad only up to the next multiple of pred_length; an input that already
    # is a multiple gets no padding (and no extra all-zero chunk).
    pad_size = (-n_samples) % pred_length
    x_padded = F.pad(x, (0, pad_size), mode='constant', value=0)

    seq_length = x_padded.shape[1]
    seq = torch.zeros(seq_length)
    print(seq.shape)

    # Inference only: without no_grad every chunk would add to an autograd
    # graph and the result would require gradients.
    with torch.no_grad():
        for i in range(0, seq_length, pred_length):
            x_slice = x_padded[:, i:i+pred_length].reshape(1, 1, pred_length)
            print(x_slice.shape)
            seq[i:i+pred_length] = model(x_slice).reshape(-1)

    # seq is already 1-D; squeezing it would break a length-1 sequence.
    return seq[:n_samples]


# Predict the whole test file in chunks of 9000 samples.
output = predict_sequence(model, x_test, 9000)

print("Now writing to wav", output.shape)
prediction_file = os.path.join(project_path, "Predictions", "test_pred.wav")
torchaudio.save(prediction_file, output, sr_x)


# Extra Stuff


In [None]:
# Install the Cloud TPU client (0.10) and a prebuilt torch_xla 1.6 wheel (CPython 3.6, Linux x86_64).
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.6-cp36-cp36m-linux_x86_64.whl

In [None]:
# Version string passed to the PyTorch/XLA environment setup script below.
VERSION = "20200325"  #@param ["1.5" , "20200325", "nightly"]
# Download the setup script from the pytorch/xla repository, then run it with the chosen version.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION