<a href="https://colab.research.google.com/github/pbrandl/aNN_Audio/blob/master/WaveNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Global Variables and Requirements

---

The following code requires Google Drive and Colab in order to access some of the listed imports. Please also make sure that torchaudio is installed locally in your Google Drive.


In [None]:
# Install torchaudio into the folder referenced by "$pyenv".
# NOTE(review): `pyenv` is only assigned in the following cell; presumably that
# cell has to be run once before this install command - confirm.
!pip install --target="$pyenv" --upgrade torchaudio

In [None]:
# Import from Global Python Environment
import os
import sys
import glob
import time
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from functools import reduce
import matplotlib.pyplot as plt
from google.colab import drive
from IPython.display import Audio
from IPython.core.display import display


# Set Working Directories
# Mount Google Drive; every path below is located inside the mounted drive.
drive.mount('/content/drive')
project_path = '/content/drive/My Drive/aNN_Colab'
print("Working in {}.".format(project_path))
# Sub-folders of the project directory for models, predictions and logs.
models_path = os.path.join(project_path, 'Models')
preds_path = os.path.join(project_path, 'Predictions')
logger_path = os.path.join(project_path, 'log')

# Import from Local Python Environment
# `pyenv` is the target folder of the `pip install --target="$pyenv"` cells;
# appending it to sys.path makes the packages installed there importable.
pyenv = os.path.join(project_path, 'pyenv')
sys.path.append(pyenv)

import torchaudio
# NOTE(review): set_audio_backend is presumably deprecated or removed in recent
# torchaudio releases; since the install cell uses --upgrade, confirm that this
# call still exists in the installed version.
torchaudio.set_audio_backend("sox_io")

# Select the Processing Device
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Working on {}.".format(device))


# Logger

The logger class aims to write the model's parameters to a file during training to compensate for run-time disruptions.

In [None]:
class Logger():
    """Writes model checkpoints and validation losses to disk during training.

    Every run gets its own folder ``<path>/<id>/``. Each call to :meth:`log`
    stores the given object in a file named after the epoch and appends one
    ``"<epoch> <loss>"`` line to the text file ``val_loss`` in that folder.
    """

    def __init__(self, path):
        """Remember the root directory below which the run folders are created."""
        self.logger_path = path

    def log(self, id, epoch, model, loss):
        """Persist ``model`` for ``epoch`` and append ``loss`` to the loss history.

        Args:
            id: Identifier of the run; converted with ``str`` to the folder name.
            epoch: Epoch label; converted with ``str`` to the checkpoint file name.
            model: Object handed to ``torch.save``.
            loss: Value written next to the epoch in the ``val_loss`` file.
        """
        run_dir = os.path.join(self.logger_path, str(id))
        # exist_ok avoids failing when the run folder was created by an earlier call.
        os.makedirs(run_dir, exist_ok=True)

        with open(os.path.join(run_dir, str(epoch)), 'wb') as log_file:
            torch.save(model, log_file)

        # Append mode keeps the history of all epochs logged so far.
        with open(os.path.join(run_dir, 'val_loss'), 'a+') as log_file:
            log_file.write("{} {}\n".format(epoch, loss))

    def clean_log(self):
        """Placeholder without functionality (not implemented yet)."""
        pass

    def read_log(self, id):
        """Return the lines of the ``val_loss`` file of run ``id`` as a list.

        Raises:
            FileNotFoundError: If nothing has been logged for ``id`` yet.
        """
        loss_file = os.path.join(self.logger_path, str(id), 'val_loss')
        with open(loss_file, 'r') as log_file:
            return log_file.readlines()



# WaveNet Implementation

Modified WaveNet implementation with a memory of the latest receptive field in a sequence.

In [None]:
class AdaptiveActivation(nn.Module):
    """
        This is an adaptive activation function according to Em Karniadakis 2020.
        Title: "On the convergence of physics informed neural networks for linear 
        second-order elliptic and parabolic type PDEs"

        The wrapped activation is evaluated on ``n * a * x`` where ``n`` is the
        constant 10 and ``a`` is a trainable scalar initialised uniformly in [0, 1).
    """
    def __init__(self, activation_fun):
        """
            Args:
                activation_fun: Callable without arguments that returns the
                    activation module, e.g. ``nn.Tanh``.
        """
        super().__init__()
        # A buffer follows the module through .to()/.cuda()/.double() instead of
        # being pinned to a global device. It is non-persistent so that the
        # state dict still only contains the trainable slope ``a``.
        self.register_buffer('n', torch.tensor([10.0]), persistent=False)
        self.a = nn.Parameter(torch.rand(1))
        self.activ_f = activation_fun()

    def forward(self, x):
        """Apply the activation to the input scaled by the fixed and trainable factor."""
        return self.activ_f(self.n * self.a * x)
        

class GatedConv1d(nn.Module):
    """
        Gated dilation layer used by WaveNet class.

        Two parallel dilated convolutions are applied to the left-padded input;
        the sigmoid-activated output of ``conv_f`` is multiplied element-wise
        with the tanh-activated output of ``conv_g``.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=True):
        """All arguments are forwarded unchanged to both ``nn.Conv1d`` layers."""
        super().__init__()
        self.dilation = dilation
        self.conv_f = nn.Conv1d(in_channels, out_channels, kernel_size,
                                stride=stride, padding=padding, dilation=dilation,
                                groups=groups, bias=bias)
        self.conv_g = nn.Conv1d(in_channels, out_channels, kernel_size,
                                stride=stride, padding=padding, dilation=dilation,
                                groups=groups, bias=bias)

        # Attribute names are kept for compatibility with stored state dicts:
        # ``tanh1`` wraps a tanh, ``tanh2`` wraps a sigmoid (the gate).
        self.tanh1 = AdaptiveActivation(nn.Tanh)
        self.tanh2 = AdaptiveActivation(nn.Sigmoid)

    def forward(self, x):
        """
            Pads on the left only, so an output sample never depends on later
            input samples. With the defaults ``stride=1`` and ``padding=0`` the
            output has the same length as ``x`` for every kernel size.
        """
        # A dilated kernel spans dilation * (kernel_size - 1) past samples. The
        # former fixed padding of ``dilation`` was only correct for kernel_size=2.
        left_pad = self.dilation * (self.conv_f.kernel_size[0] - 1)
        x = nn.functional.pad(x, (left_pad, 0))
        gate = self.tanh2(self.conv_f(x))
        signal = self.tanh1(self.conv_g(x))
        return torch.mul(gate, signal)


class GatedResidualBlock(nn.Module):
    """
        Residual wrapper around a gated convolution, used by the WaveNet class.

        The gated convolution is followed by a convolution with kernel size 1.
        The result is returned as skip output and, added to the block input,
        as residual output.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=True):
        super().__init__()
        self.gatedconv = GatedConv1d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            groups=groups, bias=bias,
        )
        # Mixing convolution with kernel size 1; only ``bias`` is configurable.
        self.conv_1 = nn.Conv1d(out_channels, out_channels, kernel_size=1, bias=bias)

    def forward(self, x):
        """Return the tuple ``(residual, skip)`` for the input ``x``."""
        gated = self.gatedconv(x)
        skip = self.conv_1(gated)
        return skip + x, skip


class WaveNet(nn.Module):
    """
        Stack of gated residual blocks with exponentially growing dilation.

        The network keeps the last ``length_rf`` input samples of the previous
        call (``previous_rf``) and prepends them to the next input, so that
        consecutive chunks of one signal are processed with context. Call
        ``reset_previous_rf`` before feeding an unrelated signal.
    """
    def __init__(self, num_time_samples, num_channels=1, num_blocks=2, max_dilation=14,
                 num_hidden=32, kernel_size=2, device='cuda'):
        """
            Args:
                num_time_samples: Chunk length; used by ``predict_sequence``
                    when the network has not seen any input yet.
                num_channels: Number of input channels of the first layer.
                num_blocks: Number of repetitions of the dilation stack.
                max_dilation: Layers per block; layer ``i`` uses dilation ``2 ** i``.
                num_hidden: Number of channels of all hidden layers.
                kernel_size: Kernel size of the gated convolutions.
                device: Device the module is moved to after construction.
        """
        super().__init__()

        self.input_length = 0
        self.num_time_samples = num_time_samples
        self.num_channels = num_channels
        self.num_blocks = num_blocks
        self.max_dilation = max_dilation
        self.num_hidden = num_hidden
        self.kernel_size = kernel_size
        self.device = device

        # Number of past samples that are remembered between two forward calls.
        self.length_rf = (kernel_size - 1) * num_blocks * (1 + sum(2 ** k for k in range(max_dilation)))
        self.previous_rf = None  # Initial Receptive Field
        self.x_shape = None  # Remember the input shape

        stacked_dilation = []
        in_channels = num_channels  # only the very first layer sees the raw input
        for b in range(num_blocks):
            for i in range(max_dilation):
                hidden = GatedResidualBlock(in_channels, num_hidden, kernel_size, dilation=2 ** i)
                hidden.name = 'b{}-l{}'.format(b, i)
                stacked_dilation.append(hidden)
                in_channels = num_hidden

        self.stacked_dilation = nn.ModuleList(stacked_dilation)

        # Not used in forward; kept so that parameter count and state dict
        # stay compatible with previously stored models.
        self.atanh = AdaptiveActivation(nn.Tanh)

        self.linear_mix = nn.Conv1d(
            in_channels=num_hidden,
            out_channels=1,
            kernel_size=1,
        )

        self.to(device)

    @property
    def n_param(self):
        # Returns the number of parameters within the net.
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def reset_previous_rf(self):
        # Resets the receptive field.
        self.previous_rf = None

    def forward(self, x):
        """
            Process one chunk ``x`` of shape (batch, channels, length) and
            return the mixed skip connections plus ``x`` with the same length.
        """
        self.x_shape = x.shape

        context = self.previous_rf
        # Start from silence when there is no memory yet or when the stored
        # memory belongs to a differently shaped/placed input.
        if (context is None or context.shape[:2] != x.shape[:2]
                or context.device != x.device):
            context = torch.zeros((x.shape[0], x.shape[1], self.length_rf),
                                  dtype=x.dtype, device=x.device)

        # Concat the last receptive field from x_(i-1) to x_i
        x_tended = torch.cat((context, x), dim=2)
        # Taking the tail of the extended signal keeps the memory exactly
        # length_rf samples long, also for inputs shorter than length_rf and
        # for length_rf == 0 (where a negative-zero slice would keep all of x).
        self.previous_rf = x_tended[:, :, x_tended.shape[2] - self.length_rf:].detach()

        skips = []
        for layer in self.stacked_dilation:
            x_tended, skip = layer(x_tended)
            skips.append(skip)

        mixed = self.linear_mix(reduce(torch.add, skips))

        # Drop the part that belongs to the prepended context.
        return mixed[:, :, self.length_rf:] + x

    def predict_sequence(self, x_seq):
        """
            Predicts a whole sequence of audio material. x_seq has to be two-dimensional, i.e.,
            (channels, lengths).

            The sequence is zero-padded on the left to a multiple of the chunk
            length and processed chunk by chunk. The chunk length is the length
            of the last input seen by ``forward`` or ``num_time_samples`` if
            there was none. Every channel is handled as an independent batch
            entry, so each one has its own receptive field memory.

            Returns a tensor with the same shape as ``x_seq``.
        """
        assert x_seq.dim() == 2, "Expected two-dimensional input shape (channels, lengths)."

        channels, x_seq_length = x_seq.shape
        if x_seq_length == 0:
            return torch.zeros_like(x_seq)

        x_length = self.x_shape[-1] if self.x_shape is not None else self.num_time_samples
        if x_length <= 0:
            raise ValueError("Chunk length must be positive.")

        # Initialize
        self.reset_previous_rf()
        lanes = x_seq.reshape(channels, 1, x_seq_length)

        # Pad the input, s.t. it fits to the model's input expectations
        pad_size = (-x_seq_length) % x_length
        lanes = F.pad(lanes, (pad_size, 0), mode='constant', value=0)

        # Iterate over the padded length so that the final chunk is never skipped.
        chunks = [self(lanes[:, :, start:start + x_length])
                  for start in range(0, lanes.shape[-1], x_length)]
        y_seq = torch.cat(chunks, dim=2)[:, :, pad_size:]

        # Do not leak the memory of this sequence into later calls.
        self.reset_previous_rf()

        assert y_seq.shape == (channels, 1, x_seq_length), "Expected input and output to be equal in shape."
        return y_seq.reshape(channels, x_seq_length)



Test that the WaveNet has to pass:

In [None]:
# Test WaveNet with Random Tensor
n_dilations = 11
# A small network (one block, four hidden channels) keeps the check fast.
model = WaveNet(5000, max_dilation=n_dilations, num_hidden=4, num_blocks=1, kernel_size=2, device=device)
assert len(model.stacked_dilation) == n_dilations, "Num of layers do not match number of dilations."
# Batch of 5 mono signals with 6000 samples each (differs from the 5000 given above).
x = torch.rand((5, 1, 6000)).to(device)
y = model(x)
print(model.n_param)
assert y.shape == x.shape, "In- and output do not match in size. Check your previous receptive field."

# Mono
model.predict_sequence(torch.rand(1, 8000).to(device))
# Stereo
model.predict_sequence(torch.rand(2, 8000).to(device))

# Data Set

Because the dataset is huge, a Loader is implemented that allows the dataset to be preloaded. Once loaded, the Loader object stores the data, so the rest of the notebook code can be modified without reloading. The class AudioDataset splits the data into train, test and validation parts and creates batches.

In [None]:
class Loader():
    """Reads an input/target pair of audio files once and keeps both in memory.

    The two files must agree in sample rate and in tensor shape. The arguments
    ``channels`` and ``device`` are accepted but not used by this class.
    """

    def __init__(self, file_x, file_y, channels=1, device=device):
        self.file_x = file_x
        self.file_y = file_y

        waveform_x, rate_x = torchaudio.load(file_x, normalize=True)
        self.__x, self.__sr_x = waveform_x, rate_x
        waveform_y, rate_y = torchaudio.load(file_y, normalize=True)
        self.__y, self.__sr_y = waveform_y, rate_y

        assert rate_x == rate_y, "Expected audio data to be eqaul in sample rate."
        assert waveform_x.shape == waveform_y.shape, "Expected audio data to be equal in shape."

    @property
    def data(self):
        """Tuple ``(x, y)`` of the loaded input and target tensors."""
        return self.__x, self.__y

    @property
    def input_files(self):
        """Tuple ``(file_x, file_y)`` of the paths the data was read from."""
        return self.file_x, self.file_y

    @property
    def sample_rate(self):
        """Sample rate reported for the input file."""
        return self.__sr_x

# Load the data from Google Drive
# del preloader # if already loaded to avoid RAM crash
file_x = os.path.join(project_path, "Preprocessed_Data", "210421_Dataset_mono_trim_x.wav")
# NOTE(review): file_y points to the same "..._x.wav" file as file_x, so input
# and target are identical; presumably a "..._y.wav" file was intended - confirm.
file_y = os.path.join(project_path, "Preprocessed_Data", "210421_Dataset_mono_trim_x.wav")
preloader = Loader(file_x, file_y)

In [None]:
from torch.utils.data import Dataset

class AudioDataset(Dataset):
    """Holds an input/target audio pair and reshapes it into training batches.

    The data comes either from an existing ``preloader`` object (anything that
    offers ``data``, ``input_files`` and ``sample_rate``) or is read from the
    given files through a new ``Loader``.
    """

    def __init__(self, file_x='', file_y='', channels=1, preloader=None, device='cuda'):
        """
        Args:
            file_x: Path of the input audio file; ``''`` to use ``preloader``.
            file_y: Path of the target audio file; ``''`` to use ``preloader``.
            channels: Forwarded to ``Loader`` when files are given.
            preloader: Already loaded data; required when no files are given.
            device: Forwarded to ``Loader`` when files are given.

        Raises:
            ValueError: If neither files nor a preloader are specified.
        """
        # Load Train Data
        if file_x == '' and file_y == '':
            if preloader is None:
                raise ValueError("Either preloader or input files need to be specified.")
        else:
            # Explicit files take precedence over a given preloader.
            preloader = Loader(file_x, file_y, channels, device)

        self.__x, self.__y = preloader.data
        self.file_x, self.file_y = preloader.input_files

        self.x_train = self.__x
        self.y_train = self.__y
        self.x_valid = None
        self.y_valid = None
        self.x_test = None
        self.y_test = None

        self.x_batches = self.__x
        self.y_batches = self.__y
        self.sample_rate = preloader.sample_rate

    def __len__(self):
        """Return the size of the second dimension of the input tensor."""
        return self.__x.shape[1]

    def __getitem__(self, channel, idx=None):
        """Return input data for ``dataset[channel, idx]`` or ``dataset[idx]``.

        Subscripting with a pair arrives as one tuple in ``channel``. A single
        index selects position ``idx`` of every channel, matching ``__len__``.
        A direct call with two arguments keeps working as well.
        """
        if idx is not None:
            return self.__x[channel, idx]
        if isinstance(channel, tuple):
            return self.__x[channel]
        return self.__x[:, channel]

    def print_file_names(self):
        """Print the paths of the input and the target file."""
        print("Input File:  {}".format(self.file_x))
        print("Target File: {}".format(self.file_y))

    def batchify(self, batch_dim, length_sample, n_samples='max'):
        """Shape data to batches of (Sample, Batch, Channels, SampleLength).

        Args:
            batch_dim: Number of segments per batch.
            length_sample: Number of audio samples per segment.
            n_samples: Number of batches, or ``'max'`` to use as many complete
                batches as the data provides.

        Returns:
            Tuple ``(x_batches, y_batches)``; both are also stored on the object.
        """
        if batch_dim <= 0 or length_sample <= 0:
            raise ValueError("batch_dim and length_sample must be positive.")

        if n_samples == 'max':
            n_samples = self.__x.shape[1] // (length_sample * batch_dim)

        sum_length = n_samples * length_sample * batch_dim
        assert sum_length <= self.__x.shape[1], "Summed duration length must be less than train data."

        # The channel dimension is fixed to 1, i.e. the reshape fits mono data.
        # NOTE(review): consecutive segments in time end up in the same batch,
        # not at the same batch position of consecutive batches - confirm that
        # this is the intended order for models that carry state between batches.
        batch_shape = (n_samples, batch_dim, 1, length_sample)
        self.x_batches = self.__x[:, :sum_length].reshape(batch_shape)
        self.y_batches = self.__y[:, :sum_length].reshape(batch_shape)
        print("Reshaped to {}.".format(self.x_batches.shape))
        return self.x_batches, self.y_batches

    def split_data(self, xs, *args):
        """Yield consecutive slices of every tensor in ``xs``.

        For each tensor (in the given order) one slice per fraction in ``args``
        is yielded, e.g. ``split_data((x, y), 0.8, 0.2)`` yields
        ``x[:80%], x[80%:], y[:80%], y[80%:]``. Slice sizes are the rounded
        products of fraction and first-dimension length of ``xs[0]``.
        """
        # The tolerance avoids false alarms such as 0.1 + 0.2 + 0.7 > 1 in floats.
        assert sum(args) <= 1 + 1e-9, "Splits must sum to at most 1."
        if len(xs) == 0:
            return
        n_samples = xs[0].shape[0]

        for x in xs:
            start = 0
            for arg in args:
                end = start + int(np.rint(n_samples * arg))
                yield x[start:end]
                start = end

    def isnan(self):
        """Print a message for each of input and target that contains NaN values."""
        if torch.isnan(self.x_train).any():
            print("X contains nan.")
        if torch.isnan(self.y_train).any():
            print("Y contains nan.")
        

# Reuse the audio that is already held by the preloader instead of reading the files again.
f_x, f_y = preloader.input_files
dataset = AudioDataset(preloader=preloader, device=device)


In [None]:
# Listen to DataSet
# Batches of 5 segments with 400000 samples each.
x, y = dataset.batchify(5, 400000)
# split_data yields all splits of x first, then all splits of y.
x1, x2, x3, y1, y2, y3 = dataset.split_data((x,y), 0.5, 0.4, 0.1)
print(x1.shape, y1.shape)
print(x2.shape, y2.shape)
print(x3.shape, y3.shape)
# Bare expression in the middle of the cell; its value is not printed.
x1.shape[0] + x2.shape[0] + x3.shape[0]

sample_rate = dataset.sample_rate
# Index 9 requires the first split to contain at least ten batches.
display(Audio(x1[9, 0, 0, :].cpu().numpy(), rate=sample_rate))
# NOTE(review): this second Audio object is neither displayed nor stored - presumably a leftover.
Audio(x1[9, 0, 0, :].cpu().numpy(), rate=sample_rate)
# Test Data Set for nans
dataset.isnan()

# Training Method

This section includes the train, validate and testing methods.

In [None]:
# RunTime Listening Files
# Two audio files that are run through the model after every epoch for listening.
listening_test_file = os.path.join(project_path, "Audio", "beat_test_raw_l.wav")
listening_test_drums = torchaudio.load(listening_test_file, normalize=True)[0].to(device)
print("Listening tensor shape", listening_test_drums.shape)
listening_test_file = os.path.join(project_path, "Audio", "test_full_mix.wav")
listening_test_fmix = torchaudio.load(listening_test_file, normalize=True)[0].to(device)
print("Listening tensor shape", listening_test_fmix.shape)
# NOTE(review): the playback rate is fixed to 44100 instead of the rate returned
# by torchaudio.load - confirm that both files use this rate.
display(Audio(listening_test_fmix.cpu().numpy(), rate=44100))
display(Audio(listening_test_drums.cpu().numpy(), rate=44100))

# Hyperparameter Settings
length_sample = 4096*2
batch_dim = 4
# num_channels is not passed to the WaveNet constructor below.
num_channels = 1

dataset = AudioDataset(preloader=preloader, device=device)
x, y = dataset.batchify(batch_dim=batch_dim, length_sample=length_sample, n_samples=5000)
# Order of the results: all splits of x first, then all splits of y.
x_train, x_valid, y_train, y_valid = dataset.split_data((x, y), 0.8, 0.2)

model = WaveNet(length_sample, max_dilation=11, num_hidden=1, num_blocks=1, device=device)
print("Train Data Shape: {}.".format(x.shape))
print("Receptive Field Length: {}.".format(model.length_rf))

logger = Logger(logger_path)

assert x_train.shape ==  y_train.shape and x_valid.shape == y_valid.shape, "Expected equal shape."

def train(model, x, y, loss_fn, optimizer):
    """Run one optimisation pass over all batches.

    Args:
        model: Module to optimise; it is switched to training mode.
        x: Iterable of input batches.
        y: Iterable of target batches, aligned with ``x``.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar tensor.
        optimizer: Optimizer that owns the parameters of ``model``.

    Returns:
        The loss of the last batch as a float, or ``nan`` if there was no batch.
    """
    model.train()

    # Batches are moved to the device of the model so that data kept on the
    # CPU can be used with a model on the GPU; this is a no-op if they match.
    first_param = next(model.parameters(), None)
    target_device = None if first_param is None else first_param.device

    loss = None
    for x_batch, y_batch in zip(x, y):
        if target_device is not None:
            x_batch = x_batch.to(target_device)
            y_batch = y_batch.to(target_device)
        optimizer.zero_grad()
        prediction = model(x_batch)
        loss = loss_fn(prediction, y_batch)
        loss.backward()
        optimizer.step()

    return float('nan') if loss is None else loss.item()

@torch.no_grad()
def validate(model, x, y, loss_fn):
    """Compute the mean loss over all validation samples without gradients.

    Args:
        model: Module to evaluate; it is switched to evaluation mode.
        x: Iterable of input samples.
        y: Iterable of target samples, aligned with ``x``.
        loss_fn: Callable ``loss_fn(prediction, target)`` returning a scalar tensor.

    Returns:
        Scalar tensor with the mean loss, or a ``nan`` tensor if there was no sample.
    """
    model.eval()

    # Same device handling as during training: follow the model's parameters.
    first_param = next(model.parameters(), None)
    target_device = None if first_param is None else first_param.device

    loss_history = []
    for x_sample, y_sample in zip(x, y):
        if target_device is not None:
            x_sample = x_sample.to(target_device)
            y_sample = y_sample.to(target_device)
        loss_history.append(loss_fn(model(x_sample), y_sample))

    if not loss_history:
        return torch.tensor(float('nan'))
    return torch.stack(loss_history).mean()

@torch.no_grad()
def pred_listening_test(model):
    """Run both listening files through ``model`` and show an audio player for each.

    Reads the module-level tensors ``listening_test_drums`` and
    ``listening_test_fmix``; playback uses a fixed rate of 44100.
    """
    def _predict_and_play(clip):
        # Predict the complete clip and hand the result to the notebook player.
        predicted = model.predict_sequence(clip)
        display(Audio(predicted.cpu().numpy(), rate=44100))

    print("now predicting listening_test")
    _predict_and_play(listening_test_drums)
    _predict_and_play(listening_test_fmix)

def fit(model, x_train, y_train, x_valid, y_valid, epochs, config, logger=None):
    """Train ``model`` with AdamW and a mean-squared-error loss.

    After every epoch the validation loss is printed, optionally logged, and
    the listening files are predicted.

    Args:
        model: Network offering ``reset_previous_rf`` and ``predict_sequence``.
        x_train, y_train: Training inputs and targets of equal shape.
        x_valid, y_valid: Validation inputs and targets of equal shape.
        epochs: Number of epochs; converted with ``int``.
        config: Mapping that provides the learning rate under ``'lr'``.
        logger: Optional object with a ``log(id, epoch, model, loss)`` method.

    Returns:
        Tuple ``(loss_valid, loss_train)`` of the last epoch; both entries are
        ``None`` when no epoch was run.
    """
    assert x_train.shape == y_train.shape, "Expected data in equal shape."
    assert x_valid.shape == y_valid.shape, "Expected data in equal shape."

    optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'])
    loss_fn = nn.MSELoss(reduction='mean')

    # Defined up front so that zero epochs return a value instead of failing.
    loss_valid = loss_train = None

    for epoch in range(int(epochs)):
        # Every epoch restarts at the beginning of the data, so the memory of
        # the previous pass must not be used as context.
        model.reset_previous_rf()
        loss_train = train(model, x_train, y_train, loss_fn, optimizer)
        loss_valid = validate(model, x_valid, y_valid, loss_fn)
        print("Epoch:", epoch, "\nAvg. valid. loss:", loss_valid.item())
        if logger is not None:
            # Log the plain number; a tensor would be written as "tensor(...)".
            logger.log('test', epoch, model, loss_valid.item())

        pred_listening_test(model)

    return loss_valid, loss_train

# Train for 50 epochs without logging and print the tuple returned by fit.
print(fit(model, x_train, y_train, x_valid, y_valid, epochs=50, config={'lr': 1e-3}, logger=None))


In [None]:
# Store only the parameters (state dict) of the current model in the models folder.
torch.save(model.state_dict(), os.path.join(models_path, "model_RN_silk_512"))

# Hyperparameter Optimization



## BOHB

In [None]:
import ConfigSpace as CS
import ConfigSpace.hyperparameters as CSH
import hpbandster.core.nameserver as hpns
import hpbandster.core.result as hpres
from hpbandster.optimizers import BOHB as BOHB
from hpbandster.core.worker import Worker
import logging
logging.basicConfig(level=logging.DEBUG)


class MyWorker(Worker):
    """BOHB worker that trains one WaveNet per sampled configuration."""

    def __init__(self, id, dataset, sleep_interval=0, *args, **kwargs):
        """
        Args:
            id: Run id handed to the ``Worker`` base class as ``run_id``.
            dataset: Object providing ``batchify`` and ``split_data``.
            sleep_interval: Stored on the worker; not used by ``compute``.
            *args, **kwargs: Forwarded to the ``Worker`` base class.
        """
        super().__init__(run_id=id, *args, **kwargs)
        self.sleep_interval = sleep_interval
        self.dataset = dataset

    def compute(self, config, budget, config_id, working_directory='.'):
        """Train a model for ``budget`` epochs and report its validation loss.

        Args:
            config: Sampled configuration with the keys defined in
                ``get_configspace``.
            budget: Number of epochs passed to ``fit``.
            config_id: Identifier of the configuration (unused).
            working_directory: Unused.

        Returns:
            Dict with the validation loss as float under ``'loss'`` and the
            constant 1 under ``'info'``.
        """
        input_dim = config['input_dim']
        batch_dim = config['batch_dim']
        n_layers = config['n_layers']
        n_hidden = config['n_hidden']
        kernel_size = config['kernel_size']

        # Initialise Dataset (use the dataset of this worker, not a global one)
        x, y = self.dataset.batchify(batch_dim, input_dim, n_samples=100)
        x_train, x_valid, y_train, y_valid = self.dataset.split_data((x, y), 0.8, 0.2)

        assert x_train.shape == y_train.shape and x_valid.shape == y_valid.shape, "Expected equal shapes."

        # Initialise Model
        # One block with one input channel; the number of layers of the block
        # is WaveNet's ``max_dilation`` argument.
        model = WaveNet(input_dim, 1, 1, max_dilation=n_layers, num_hidden=n_hidden,
                        kernel_size=kernel_size, device=device)

        # fit returns (validation loss, training loss) of the last epoch.
        loss_valid, _ = fit(model, x_train, y_train, x_valid, y_valid, budget, config)
        loss = float(loss_valid)

        # The checkpoint is labelled with the number of trained epochs.
        # NOTE(review): configurations with equal budget share this label within
        # one run id and overwrite each other - confirm whether that is acceptable.
        logger.log(id=self.run_id, epoch=int(budget), model=model, loss=loss)

        # The optimizer needs a plain number, not a tensor.
        return {'loss': loss, 'info': 1}

    @staticmethod
    def get_configspace():
        """Return the search space of learning rate, data shape and model size."""
        cs = CS.ConfigurationSpace()

        # Define Parameter Search Space
        lr = CSH.UniformFloatHyperparameter('lr', lower=1e-6, upper=1e-1, default_value=1e-2, log=True)
        batch_dim = CSH.UniformIntegerHyperparameter('batch_dim', lower=1, upper=15, default_value=2)
        input_dim = CSH.UniformIntegerHyperparameter('input_dim', lower=1000, upper=20000, default_value=18000)
        n_layers = CSH.UniformIntegerHyperparameter('n_layers', lower=1, upper=16, default_value=9)
        n_hidden = CSH.UniformIntegerHyperparameter('n_hidden', lower=1, upper=10, default_value=1)
        kernel_size = CSH.UniformIntegerHyperparameter('kernel_size', lower=1, upper=2, default_value=1)
        cs.add_hyperparameters([lr, batch_dim, input_dim, n_layers, n_hidden, kernel_size])

        return cs


dataset = AudioDataset(preloader=preloader, device=device)
# Start a local name server; worker and optimizer below connect to it via the same run id.
NS = hpns.NameServer(run_id='0', host='127.0.0.1', port=None)
NS.start()
w = MyWorker(dataset=dataset, sleep_interval = 0, nameserver='127.0.0.1', id='0')
# Run the worker in the background so that this cell can continue with the optimizer.
w.run(background=True)
# The budget is handed to MyWorker.compute, which passes it to fit as number of epochs.
bohb = BOHB(configspace = w.get_configspace(),
            run_id = '0', nameserver='127.0.0.1',
            min_budget=1, max_budget=2)

res = bohb.run(n_iterations=1)



In [None]:
# Stop the optimizer together with its workers, then the name server.
bohb.shutdown(shutdown_workers=True)
NS.shutdown()

In [None]:
# Manual check of the Logger with a freshly created model.
# NOTE: this rebinds the global `model`; WaveNet's default device is 'cuda'.
model = WaveNet(2)
logger = Logger(logger_path)
logger.log(id=4, epoch=1, model=model, loss=8)


# Extra Stuff


In [None]:
# Install zounds into the folder referenced by "$pyenv".
!pip install --target="$pyenv" zounds

In [None]:
# Hyperparameter optimisation packages, installed into the "$pyenv" folder.
!pip install --target="$pyenv" ray[tune]
!pip install --target="$pyenv" hpbandster ConfigSpace
# NOTE(review): hpbandster is already part of the previous command.
!pip install --target="$pyenv" hpbandster

In [None]:
# Install cdpam into the folder referenced by "$pyenv".
!pip install --target="$pyenv" cdpam

In [None]:
# Set permissions 755 on the ray binaries inside the Drive-hosted pyenv folder.
# NOTE(review): presumably needed because the files are not executable after
# installation to Google Drive - confirm.
!chmod -R 755 "/content/drive/My Drive/aNN_Colab/pyenv/ray/core/src/ray/thirdparty/redis/src/redis-server"
!chmod -R 755 "/content/drive/My Drive/aNN_Colab/pyenv/ray/core/src/ray/gcs/gcs_server"
!chmod -R 755 "/content/drive/My Drive/aNN_Colab/pyenv/ray/core/src/plasma/plasma_store_server"
!chmod -R 755 "/content/drive/My Drive/aNN_Colab/pyenv/ray/core/src/ray/raylet/raylet"

In [None]:
# Install the TPU client and a torch_xla 1.6 wheel; the wheel name targets CPython 3.6 (cp36).
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.6-cp36-cp36m-linux_x86_64.whl

In [None]:
# Download the PyTorch/XLA environment setup script and run it for the selected version.
VERSION = "20200325"  #@param ["1.5" , "20200325", "nightly"]
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION