<a href="https://colab.research.google.com/github/rezaghasemi/GenAI-audio-module/blob/main/assignement%201.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## WaveNet Model Explanation

This notebook implements a WaveNet model, which is a deep generative model for raw audio waveforms. Key components and concepts in this implementation include:

- **Dilated Causal Convolutions:** The core of WaveNet. These convolutions skip input samples at a fixed step (the dilation), so the receptive field grows exponentially with depth while the number of parameters grows only linearly. This is crucial for capturing long-range dependencies in sequences like audio. The `dilate` function implements the dilation, and the `DilatedQueue` class caches past activations for fast generation.
- **Gated Activation Units:** Similar to gates in LSTMs, these units control the flow of information through the network, allowing it to selectively remember or forget information. The `filter_convs` and `gate_convs` in the `WaveNetModel` class implement this.
- **Residual and Skip Connections:** These connections help to train deeper networks by providing alternative paths for gradients to flow. Residual connections add the output of a dilated convolution block to its input, while skip connections contribute to the final output layer.
- **Softmax Output:** The model outputs a probability distribution over possible next values in the audio waveform. This allows for generating diverse and realistic audio.
- **Mu-law Companding:** A non-linear transform applied before quantization that gives finer resolution to low-amplitude signals. The `mu_law_compand` and `mu_law_expand` functions implement this transform and its inverse.
- **Fast Generation:** The `generate_fast` method uses one `DilatedQueue` per layer to cache past activations, which speeds up generation by avoiding redundant computations.

The `WaveNetModel` class combines these components into a neural network architecture capable of learning the underlying structure of audio data and generating new samples.
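
To make the exponential growth of the receptive field concrete, here is a minimal sketch that counts how many input samples influence one output sample. It assumes the dilation doubles with every layer and restarts at 1 in every block, as in this notebook's model; the helper name `receptive_field` is illustrative only.

```python
def receptive_field(layers, blocks, kernel_size=2):
    """Number of input samples that influence one output sample."""
    field = 1
    for _ in range(blocks):
        dilation = 1
        for _ in range(layers):
            # a kernel of size k with dilation d reaches (k - 1) * d samples further back
            field += (kernel_size - 1) * dilation
            dilation *= 2  # dilation doubles with every layer inside a block
    return field


# closed form for comparison: 1 + blocks * (kernel_size - 1) * (2**layers - 1)
for layers in (1, 2, 4, 10):
    print(layers, receptive_field(layers, blocks=1))
print(receptive_field(layers=10, blocks=3))
```

With 10 layers and 3 blocks the count is 3070 samples, which is roughly 0.19 seconds of audio at 16 kHz.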

In this assignment, you will learn how to:

1. **Manipulate data:** Prepare and process audio data for the WaveNet model.
2. **Train the model:** Understand the training process and train the WaveNet model on your data.
3. **Generate Audio:** Use the trained model to generate new audio samples.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from torch.autograd import Variable, Function
from scipy.io import wavfile
import time
import os
import librosa as lr
from torch.utils.data import Dataset

The `mu_law_compand` and `quantize` functions apply µ-law companding and quantization to an audio signal.


- **µ-law companding:** This is a non-linear process that compresses the dynamic range of an audio signal. It gives more resolution to lower amplitude values and less resolution to higher amplitude values. This is particularly useful for audio as the human ear is more sensitive to changes in quiet sounds than loud sounds.
- **Quantization:** After companding, the function quantizes the signal to a specified number of levels (defaulting to 256). This converts the continuous audio signal into a discrete representation, which is necessary for the WaveNet model's output layer (which predicts the probability of the next discrete audio value).

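Essentially, these functions prepare the raw audio data for the WaveNet model by applying a transformation that is perceptually motivated and converts the data into a format suitable for the model's output layer. The `dequantize` and `mu_law_expand` functions reverse the two steps to turn generated class indices back into audio.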

In [None]:

def mu_law_compand(x, mu=256):
    """µ-law companding: [-1,1] -> [-1,1]."""
    mu = mu - 1
    safe_x = np.clip(x, -1.0, 1.0)
    fx = np.sign(safe_x) * np.log1p(mu * np.abs(safe_x)) / np.log1p(mu)
    return fx

def quantize(fx, mu=256):
    """Quantize companded signal: [-1,1] -> [0, mu-1]."""
    mu = mu - 1
    return ((fx + 1) / 2 * mu + 0.5).astype(np.int32)

def dequantize(q, mu=256):
    """Inverse quantization: [0, mu-1] -> [-1,1]."""
    mu = mu - 1
    return 2 * q.astype(np.float32) / mu - 1

def mu_law_expand(fx, mu=256):
    """µ-law expansion: [-1,1] -> [-1,1] (approx inverse of compand)."""
    mu = mu - 1
    return np.sign(fx) * (np.exp(np.abs(fx) * np.log(mu + 1)) - 1) / mu
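
The benefit of companding for quiet signals can be checked numerically. The sketch below is an illustration under stated assumptions, not part of the assignment code: it restates the companding and quantization formulas in compact form (the helper names `compand`, `expand`, `to_levels` and `from_levels` are hypothetical, and `mu=255` is passed directly instead of `256 - 1`), then compares the round-trip error of plain 8-bit uniform quantization with 8-bit µ-law quantization on a low-amplitude test tone.

```python
import numpy as np


def compand(x, mu=255):
    return np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)


def expand(fx, mu=255):
    return np.sign(fx) * np.expm1(np.abs(fx) * np.log1p(mu)) / mu


def to_levels(v, levels=256):
    """Round values in [-1, 1] to integer codes in [0, levels - 1]."""
    return np.floor((v + 1) / 2 * (levels - 1) + 0.5).astype(np.int32)


def from_levels(q, levels=256):
    """Map integer codes back to values in [-1, 1]."""
    return 2 * q.astype(np.float64) / (levels - 1) - 1


t = np.linspace(0, 1, 1000, endpoint=False)
quiet = 0.01 * np.sin(2 * np.pi * 5 * t)  # a quiet test tone

uniform = from_levels(to_levels(quiet))
mu_law = expand(from_levels(to_levels(compand(quiet))))

err_uniform = np.sqrt(np.mean((uniform - quiet) ** 2))
err_mu_law = np.sqrt(np.mean((mu_law - quiet) ** 2))
print(f"uniform RMS error: {err_uniform:.6f}")
print(f"mu-law  RMS error: {err_mu_law:.6f}")
```

For a tone this quiet, the µ-law round trip should show a clearly smaller error, because companding spends more of the 256 codes near zero amplitude.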

## WavenetDataset Class

The `WavenetDataset` class is a custom PyTorch `Dataset` designed to handle audio data for the WaveNet model. It prepares the data in a format suitable for training.

- **Initialization (`__init__`)**:
    - Takes the folder path containing WAV files, `item_length` (receptive field size), `target_length` (number of samples to predict), sampling rate (`sr`), number of quantization classes (`classes`), and a `normalize` flag.
    - Loads all WAV files from the specified folder.
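    - Applies µ-law companding and quantization (using the `mu_law_compand` and `quantize` functions) to each audio file.
    - Stores the processed files as a list of NumPy arrays (`self.data`) and builds an index (`self.index_map`) of `(file_index, start_position)` pairs, so that no training example spans two files.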

- **Length (`__len__`)**:
    - Returns the total number of training examples in the dataset, which is the number of entries in `self.index_map`. Each file contributes its length minus `item_length` and `target_length` examples.

- **Get Item (`__getitem__`)**:
    - Takes an index `idx` and returns a single training example.
    - Extracts a segment of length `item_length` as the input (`x`) and the subsequent `target_length` segment as the target (`y`).
    - Converts the input `x` into a one-hot encoded tensor, which is the format expected by the WaveNet model's input layer.
    - Converts the target `y` into a PyTorch LongTensor.
    - Returns the one-hot encoded input and the target tensor.
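
The windowing and one-hot steps above can be sketched on a toy signal. This is a NumPy-only illustration with a hypothetical helper `make_example`; the dataset class itself builds the one-hot tensor with PyTorch's `scatter_`, but the resulting layout (classes along rows, time steps along columns) is the same.

```python
import numpy as np


def make_example(codes, start, item_length, target_length, classes):
    """Return (one-hot input, target) for the window beginning at `start`."""
    x = codes[start : start + item_length]
    y = codes[start + item_length : start + item_length + target_length]
    one_hot = np.zeros((classes, item_length), dtype=np.float32)
    one_hot[x, np.arange(item_length)] = 1.0  # row = class, column = time step
    return one_hot, y


codes = np.array([3, 0, 2, 1, 3, 2, 0, 1])  # toy quantized signal with 4 classes
x_onehot, y = make_example(codes, start=2, item_length=4, target_length=1, classes=4)
print(x_onehot.shape)           # (4, 4)
print(x_onehot.argmax(axis=0))  # [2 1 3 2]
print(y)                        # [0]
```

Taking the `argmax` over the class axis recovers the original codes, which is a quick way to check that the encoding is oriented correctly.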

In [None]:
class WavenetDataset(Dataset):
    def __init__(self,
                 folder,
                 item_length=1600,
                 target_length=1,
                 sr=16000,
                 classes=256,
                 normalize=False
                 ):
        """
        Args:
            folder (str): Path to folder with audio files (.wav).
            item_length (int): Input receptive field size.
            target_length (int): Number of future samples to predict.
            sr (int): Sampling rate for audio.
            classes (int): µ-law quantization classes.
            normalize (bool): Normalize audio amplitude.
        """
        self.item_length = item_length
        self.target_length = target_length
        self.classes = classes
        self.sr = sr
        self.normalize = normalize

        # load and preprocess each file into a list of arrays
        self.data = []
        for fname in os.listdir(folder):
            if fname.endswith(".wav"):
                y, _ = lr.load(os.path.join(folder, fname), sr=self.sr, mono=True)
                if self.normalize:
                    y = lr.util.normalize(y)
                q = mu_law_compand(y, classes)
                q = quantize(q, classes)
                self.data.append(q)

        # for indexing: store (file_index, start_position) pairs
        self.index_map = []
        for f_idx, arr in enumerate(self.data):
            length = len(arr)
            max_start = length - (self.item_length + self.target_length)
            for start in range(max_start):
                self.index_map.append((f_idx, start))

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        f_idx, start = self.index_map[idx]
        arr = self.data[f_idx]

        x = arr[start : start + self.item_length]
        y = arr[start + self.item_length : start + self.item_length + self.target_length]

        # one-hot input (classes x item_length)
        x_onehot = torch.zeros(self.classes, self.item_length)
        x_tensor = torch.from_numpy(x).long()
        x_onehot.scatter_(0, x_tensor.unsqueeze(0), 1.0)

        y = torch.from_numpy(y).long()
        return x_onehot, y


In [None]:

def dilate(x, dilation, init_dilation=1, pad_start=True):
    """
    :param x: Tensor of size (N, C, L), where N is the input dilation, C is the number of channels, and L is the input length
    :param dilation: Target dilation. Will be the size of the first dimension of the output tensor.
    :param init_dilation: Dilation currently applied to x.
    :param pad_start: If the input length is not compatible with the specified dilation, zero padding is used. This parameter determines whether the zeros are added at the start or at the end.
    :return: The dilated tensor of size (dilation, C, L*N / dilation). The output might be zero padded at the start
    """

    [n, c, l] = x.size()
    dilation_factor = dilation / init_dilation
    if dilation_factor == 1:
        return x

    # zero padding for reshaping
    new_l = int(np.ceil(l / dilation_factor) * dilation_factor)
    if new_l != l:
        l = new_l
        x = constant_pad_1d(x, new_l, dimension=2, pad_start=pad_start)

    l_old = int(round(l / dilation_factor))
    n_old = int(round(n * dilation_factor))
    l = math.ceil(l * init_dilation / dilation)
    n = math.ceil(n * dilation / init_dilation)

    # reshape according to dilation
    x = x.permute(1, 2, 0).contiguous()  # (n, c, l) -> (c, l, n)
    x = x.view(c, l, n)
    x = x.permute(2, 0, 1).contiguous()  # (c, l, n) -> (n, c, l)

    return x


class DilatedQueue:
    def __init__(self, max_length, data=None, dilation=1, num_deq=1, num_channels=1, dtype=torch.FloatTensor):
        self.in_pos = 0
        self.out_pos = 0
        self.num_deq = num_deq
        self.num_channels = num_channels
        self.dilation = dilation
        self.max_length = max_length
        self.data = data
        self.dtype = dtype
        if data is None:
            self.data = Variable(dtype(num_channels, max_length).zero_())

    def enqueue(self, input):
        self.data[:, self.in_pos] = input.view(-1)
        self.in_pos = (self.in_pos + 1) % self.max_length

    def dequeue(self, num_deq=1, dilation=1):
        #       |
        #  |6|7|8|1|2|3|4|5|
        #         |
        start = self.out_pos - ((num_deq - 1) * dilation)
        if start < 0:
            t1 = self.data[:, start::dilation]
            t2 = self.data[:, self.out_pos % dilation:self.out_pos + 1:dilation]
            t = torch.cat((t1, t2), 1)
        else:
            t = self.data[:, start:self.out_pos + 1:dilation]

        self.out_pos = (self.out_pos + 1) % self.max_length
        return t

    def reset(self):
        self.dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
        self.data = Variable(self.dtype(self.num_channels, self.max_length).zero_())
        self.in_pos = 0
        self.out_pos = 0


class ConstantPad1d(Function):
    def __init__(self, target_size, dimension=0, value=0, pad_start=False):
        super(ConstantPad1d, self).__init__()
        self.target_size = target_size
        self.dimension = dimension
        self.value = value
        self.pad_start = pad_start

    def forward(self, input):
        self.num_pad = self.target_size - input.size(self.dimension)
        assert self.num_pad >= 0, 'target size has to be greater than input size'

        self.input_size = input.size()

        size = list(input.size())
        size[self.dimension] = self.target_size
        output = input.new(*tuple(size)).fill_(self.value)
        c_output = output

        # crop output
        if self.pad_start:
            c_output = c_output.narrow(self.dimension, self.num_pad, c_output.size(self.dimension) - self.num_pad)
        else:
            c_output = c_output.narrow(self.dimension, 0, c_output.size(self.dimension) - self.num_pad)

        c_output.copy_(input)
        return output

    def backward(self, grad_output):
        grad_input = grad_output.new(*self.input_size).zero_()
        cg_output = grad_output

        # crop grad_output
        if self.pad_start:
            cg_output = cg_output.narrow(self.dimension, self.num_pad, cg_output.size(self.dimension) - self.num_pad)
        else:
            cg_output = cg_output.narrow(self.dimension, 0, cg_output.size(self.dimension) - self.num_pad)

        grad_input.copy_(cg_output)
        return grad_input


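def constant_pad_1d(input,
                    target_size,
                    dimension=0,
                    value=0,
                    pad_start=False):
    # Legacy autograd Functions (non-static forward) raise an error in current
    # PyTorch, so pad with F.pad, which is differentiable out of the box.
    num_pad = target_size - input.size(dimension)
    assert num_pad >= 0, 'target size has to be greater than input size'
    # F.pad lists (left, right) pairs starting from the last dimension
    pad = [0] * (2 * input.dim())
    pad[2 * (input.dim() - 1 - dimension) + (0 if pad_start else 1)] = num_pad
    return F.pad(input, pad, mode='constant', value=value)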
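
The reshaping trick inside `dilate` is easier to see on a small array. The following NumPy sketch is a simplified stand-in (the name `dilate_np` is hypothetical): it assumes an input with first dimension 1 and a length divisible by the dilation, whereas the notebook's `dilate` also handles zero padding and an arbitrary `init_dilation`.

```python
import numpy as np


def dilate_np(x, dilation):
    """Fold a (1, C, L) array into (dilation, C, L // dilation).

    Assumes L is divisible by `dilation`; the notebook's version zero-pads instead.
    """
    n, c, l = x.shape
    x = np.ascontiguousarray(x.transpose(1, 2, 0))  # (n, c, l) -> (c, l, n)
    x = x.reshape(c, l // dilation, n * dilation)
    return x.transpose(2, 0, 1)                     # -> (n * dilation, c, l // dilation)


x = np.arange(8).reshape(1, 1, 8)
out = dilate_np(x, dilation=2)
print(out[0, 0])  # [0 2 4 6]
print(out[1, 0])  # [1 3 5 7]
```

Each row now holds every second sample, so an ordinary kernel-size-2 convolution applied along a row combines samples that are two steps apart in the original signal. That is how a dilated convolution is obtained from a standard `nn.Conv1d`.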


In [None]:
class WaveNetModel(nn.Module):
    """
    A Complete Wavenet Model

    Args:
        layers (Int):               Number of layers in each block
        blocks (Int):               Number of wavenet blocks of this model
        dilation_channels (Int):    Number of channels for the dilated convolution
        residual_channels (Int):    Number of channels for the residual connection
        skip_channels (Int):        Number of channels for the skip connections
        classes (Int):              Number of possible values each sample can have
        output_length (Int):        Number of samples that are generated for each input
        kernel_size (Int):          Size of the dilation kernel
        dtype:                      Parameter type of this model

    Shape:
        - Input: :math:`(N, C_{in}, L_{in})`
        - Output: (N * output_length, classes), one row of logits per predicted sample
        L should be the length of the receptive field
    """
    def __init__(self,
                 layers=10,
                 blocks=4,
                 dilation_channels=32,
                 residual_channels=32,
                 skip_channels=256,
                 end_channels=256,
                 classes=256,
                 output_length=32,
                 kernel_size=2,
                 dtype=torch.FloatTensor,
                 bias=False):

        super(WaveNetModel, self).__init__()

        self.layers = layers
        self.blocks = blocks
        self.dilation_channels = dilation_channels
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        self.classes = classes
        self.kernel_size = kernel_size
        self.dtype = dtype

        # build model
        receptive_field = 1
        init_dilation = 1

        self.dilations = []
        self.dilated_queues = []
        # self.main_convs = nn.ModuleList()
        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()

        # 1x1 convolution to create channels
        self.start_conv = nn.Conv1d(in_channels=self.classes,
                                    out_channels=residual_channels,
                                    kernel_size=1,
                                    bias=bias)

        for b in range(blocks):
            additional_scope = kernel_size - 1
            new_dilation = 1
            for i in range(layers):
                # dilations of this layer
                self.dilations.append((new_dilation, init_dilation))

                # dilated queues for fast generation
                self.dilated_queues.append(DilatedQueue(max_length=(kernel_size - 1) * new_dilation + 1,
                                                        num_channels=residual_channels,
                                                        dilation=new_dilation,
                                                        dtype=dtype))

                # dilated convolutions
                self.filter_convs.append(nn.Conv1d(in_channels=residual_channels,
                                                   out_channels=dilation_channels,
                                                   kernel_size=kernel_size,
                                                   bias=bias))

                self.gate_convs.append(nn.Conv1d(in_channels=residual_channels,
                                                 out_channels=dilation_channels,
                                                 kernel_size=kernel_size,
                                                 bias=bias))

                # 1x1 convolution for residual connection
                self.residual_convs.append(nn.Conv1d(in_channels=dilation_channels,
                                                     out_channels=residual_channels,
                                                     kernel_size=1,
                                                     bias=bias))

                # 1x1 convolution for skip connection
                self.skip_convs.append(nn.Conv1d(in_channels=dilation_channels,
                                                 out_channels=skip_channels,
                                                 kernel_size=1,
                                                 bias=bias))

                receptive_field += additional_scope
                additional_scope *= 2
                init_dilation = new_dilation
                new_dilation *= 2

        self.end_conv_1 = nn.Conv1d(in_channels=skip_channels,
                                  out_channels=end_channels,
                                  kernel_size=1,
                                  bias=True)

        self.end_conv_2 = nn.Conv1d(in_channels=end_channels,
                                    out_channels=classes,
                                    kernel_size=1,
                                    bias=True)

        # self.output_length = 2 ** (layers - 1)
        self.output_length = output_length
        self.receptive_field = receptive_field

    def wavenet(self, input, dilation_func):

        x = self.start_conv(input)
        skip = 0

        # WaveNet layers
        for i in range(self.blocks * self.layers):

            #            |----------------------------------------|     *residual*
            #            |                                        |
            #            |    |-- conv -- tanh --|                |
            # -> dilate -|----|                  * ----|-- 1x1 -- + -->	*input*
            #                 |-- conv -- sigm --|     |
            #                                         1x1
            #                                          |
            # ---------------------------------------> + ------------->	*skip*

            (dilation, init_dilation) = self.dilations[i]

            residual = dilation_func(x, dilation, init_dilation, i)

            # dilated convolution
            filter = self.filter_convs[i](residual)
            filter = torch.tanh(filter)
            gate = self.gate_convs[i](residual)
            gate = torch.sigmoid(gate)
            x = filter * gate

            # parametrized skip connection
            s = x
            if x.size(2) != 1:
                 s = dilate(x, 1, init_dilation=dilation)
            s = self.skip_convs[i](s)
            if torch.is_tensor(skip):
                # keep only the most recent steps so both tensors align in time
                skip = skip[:, :, -s.size(2):]
            skip = s + skip

            x = self.residual_convs[i](x)
            x = x + residual[:, :, (self.kernel_size - 1):]

        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)

        return x

    def wavenet_dilate(self, input, dilation, init_dilation, i):
        x = dilate(input, dilation, init_dilation)
        return x

    def queue_dilate(self, input, dilation, init_dilation, i):
        queue = self.dilated_queues[i]
        queue.enqueue(input.data[0])
        x = queue.dequeue(num_deq=self.kernel_size,
                          dilation=dilation)
        x = x.unsqueeze(0)

        return x

    def forward(self, input):
        x = self.wavenet(input,
                         dilation_func=self.wavenet_dilate)

        # reshape output
        [n, c, l] = x.size()
        l = self.output_length
        x = x[:, :, -l:]
        x = x.transpose(1, 2).contiguous()
        x = x.view(n * l, c)
        return x

    def generate(self,
                 num_samples,
                 first_samples=None,
                 temperature=1.):
        self.eval()
        if first_samples is None:
            first_samples = torch.zeros(1, dtype=torch.long)
        # class indices must be int64 to serve as scatter_ indices
        generated = first_samples.long()

        # left-pad with class 0 so the first window spans the receptive field
        num_pad = self.receptive_field - generated.size(0)
        if num_pad > 0:
            generated = F.pad(generated, (num_pad, 0))
            print("pad zero")

        # torch.no_grad() replaces the removed `volatile=True` flag
        with torch.no_grad():
            for i in range(num_samples):
                input = torch.zeros(1, self.classes, self.receptive_field)
                input = input.scatter_(1, generated[-self.receptive_field:].view(1, 1, -1), 1.)

                x = self.wavenet(input,
                                 dilation_func=self.wavenet_dilate)[:, :, -1].squeeze()

                if temperature > 0:
                    # sample from the softmax distribution
                    x = x / temperature
                    prob = F.softmax(x, dim=0)
                    np_prob = prob.cpu().numpy()
                    x = np.random.choice(self.classes, p=np_prob)
                    x = torch.LongTensor([x])
                else:
                    # greedy decoding: pick the most likely class
                    x = torch.argmax(x, dim=0, keepdim=True)

                generated = torch.cat((generated, x), 0)

        generated = (generated.float() / self.classes) * 2. - 1
        mu_gen = mu_law_expand(generated.numpy(), self.classes)

        self.train()
        return mu_gen

    def generate_fast(self,
                      num_samples,
                      first_samples=None,
                      temperature=1.,
                      regularize=0.,
                      progress_callback=None,
                      progress_interval=100):
        self.eval()
        if first_samples is None:
            first_samples = torch.LongTensor(1).zero_() + (self.classes // 2)
        first_samples = Variable(first_samples)

        # reset queues
        for queue in self.dilated_queues:
            queue.reset()

        num_given_samples = first_samples.size(0)
        total_samples = num_given_samples + num_samples

        input = Variable(torch.FloatTensor(1, self.classes, 1).zero_())
        input = input.scatter_(1, first_samples[0:1].view(1, -1, 1), 1.)

        # fill queues with given samples
        for i in range(num_given_samples - 1):
            x = self.wavenet(input,
                             dilation_func=self.queue_dilate)
            input.zero_()
            input = input.scatter_(1, first_samples[i + 1:i + 2].view(1, -1, 1), 1.).view(1, self.classes, 1)

            # progress feedback
            if i % progress_interval == 0:
                if progress_callback is not None:
                    progress_callback(i, total_samples)

        # generate new samples
        generated = np.array([])
        regularizer = torch.pow(Variable(torch.arange(self.classes)) - self.classes / 2., 2)
        regularizer = regularizer.squeeze() * regularize
        tic = time.time()
        for i in range(num_samples):
            x = self.wavenet(input,
                             dilation_func=self.queue_dilate).squeeze()

            x -= regularizer

            if temperature > 0:
                # sample from softmax distribution
                x /= temperature
                prob = F.softmax(x, dim=0)
                prob = prob.cpu()
                np_prob = prob.data.numpy()
                x = np.random.choice(self.classes, p=np_prob)
                x = np.array([x])
            else:
                # convert to sample value (greedy choice of the most likely class)
                x = torch.argmax(x, dim=0, keepdim=True)
                x = x.cpu()
                x = x.data.numpy()

            o = (x / self.classes) * 2. - 1
            generated = np.append(generated, o)

            # set new input
            x = Variable(torch.from_numpy(x).type(torch.LongTensor))
            input.zero_()
            input = input.scatter_(1, x.view(1, -1, 1), 1.).view(1, self.classes, 1)

            if (i+1) == 100:
                toc = time.time()
                print("one generating step does take approximately " + str((toc - tic) * 0.01) + " seconds)")

            # progress feedback
            if (i + num_given_samples) % progress_interval == 0:
                if progress_callback is not None:
                    progress_callback(i + num_given_samples, total_samples)

        self.train()
        mu_gen = mu_law_expand(generated, self.classes)
        wavfile.write("generated.wav", 16000, mu_gen)
        return mu_gen


    def parameter_count(self):
        par = list(self.parameters())
        s = sum([np.prod(list(d.size())) for d in par])
        return s

    def cpu(self, type=torch.FloatTensor):
        self.dtype = type
        for q in self.dilated_queues:
            q.dtype = self.dtype
        # return the module so calls can be chained like nn.Module.cpu()
        return super().cpu()
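
Both generation methods divide the logits by `temperature` before sampling. The sketch below isolates that step in NumPy to show its effect; the helper `sample_with_temperature` and the toy logits are illustrative assumptions, not part of the model.

```python
import numpy as np


def sample_with_temperature(logits, temperature=1.0, rng=None):
    """Draw one class index from softmax(logits / temperature)."""
    if temperature <= 0:
        return int(np.argmax(logits))  # greedy choice, no randomness
    if rng is None:
        rng = np.random.default_rng()
    scaled = logits / temperature
    scaled = scaled - scaled.max()     # stabilise the exponentials
    prob = np.exp(scaled)
    prob /= prob.sum()
    return int(rng.choice(len(logits), p=prob))


logits = np.array([2.0, 1.0, 0.1])
rng = np.random.default_rng(0)
cold = [sample_with_temperature(logits, 0.1, rng) for _ in range(1000)]
hot = [sample_with_temperature(logits, 5.0, rng) for _ in range(1000)]
print(np.bincount(cold, minlength=3))  # counts per class at low temperature
print(np.bincount(hot, minlength=3))   # counts per class at high temperature
```

A low temperature concentrates almost all draws on the most likely class, while a high temperature spreads them out. For audio this trades stability against variety in the generated waveform.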


In [None]:
model = WaveNetModel(layers=10,
                     blocks=3,
                     dilation_channels=32,
                     residual_channels=32,
                     skip_channels=1024,
                     end_channels=512,
                     output_length=16,
                     dtype=torch.FloatTensor,
                     bias=True)

In [None]:
checkpoint_path = "checkpoints/trained_model.pth"

# Load checkpoint (it might contain 'state_dict' or the whole model)
torch.serialization.add_safe_globals([WaveNetModel])
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))

<All keys matched successfully>
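
Since the checkpoint may hold either a bare state dict or a wrapper around one, a small helper can normalise the loaded object before it is passed to `load_state_dict`. This is a sketch under assumptions: the helper name `extract_state_dict` is hypothetical, and the `"state_dict"` key is a common convention rather than something this checkpoint is known to use.

```python
def extract_state_dict(checkpoint):
    """Return the parameter dictionary from a few common checkpoint layouts."""
    # layout 1 (assumed convention): a wrapper dict such as {"state_dict": {...}, "epoch": 3}
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        return checkpoint["state_dict"]
    # layout 2: a pickled module, which exposes a .state_dict() method
    if hasattr(checkpoint, "state_dict") and callable(checkpoint.state_dict):
        return checkpoint.state_dict()
    # layout 3: already a plain mapping of parameter names to tensors
    return checkpoint


# usage sketch (not executed here):
# state = extract_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))
# model.load_state_dict(state)
```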

In [None]:
model.generate_fast(10000)

one generating step does take approximately 0.006373369693756104 seconds)


array([ 0.00266565,  0.00295683, -0.00054255, ...,  0.09697959,
        0.10611724,  0.08467614], shape=(10000,))