This notebook defines and tests an implementation of the WaveNet model.

The original paper is here: https://arxiv.org/pdf/1609.03499.pdf

An additional reference for the code: https://github.com/ibab/tensorflow-wavenet/blob/master/wavenet/model.py

In [None]:
import librosa
from IPython.display import Audio
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Render matplotlib figures inline, directly below the cell that draws them.
%matplotlib inline

## Loading the music track

In [None]:
# Load the mp3 as a single-channel signal resampled to 220 Hz, then expose it as a player widget.
track_path = "../../data/music/kalpol_introl.mp3"
music_track, sample_rate = librosa.load(track_path, sr=220, mono=True)
Audio(data=music_track, rate=sample_rate, autoplay=False)

In [None]:
# Waveform of the loaded track: sample index on the x axis, amplitude on the y axis.
sample_index = np.arange(len(music_track))
plt.plot(sample_index, music_track)
plt.show()

## WaveNet model

As outlined in the paper, the model takes a non-linear quantization of the input waveform as input and predicts the quantized value of the next sample.

The exact implementation was taken from the GitHub link above.

In [None]:
def mu_law_encode(audio, quantization_channels=256):
    """Compress a waveform with mu-law companding and quantize it to integer levels.

    Args:
        audio: Array-like of amplitudes, nominally in [-1, 1]; magnitudes
            above 1 are clipped to 1 before companding.
        quantization_channels: Number of discrete output levels.

    Returns:
        Integer array with the same shape as ``audio`` and values in
        [0, quantization_channels - 1].
    """
    top_level = float(quantization_channels) - 1
    # Resampling occasionally overshoots +/-1, so the magnitude is capped first.
    clipped = np.minimum(np.abs(audio), 1.0)
    # Mu-law companding (ITU-T, 1988): log-compress the magnitude, keep the sign.
    companded = np.sign(audio) * (np.log1p(top_level * clipped) / np.log1p(top_level))
    # Shift [-1, 1] onto [0, mu]; the +0.5 turns the integer truncation into rounding.
    levels = (companded + 1) / 2 * top_level + 0.5
    return levels.astype(int)

def mu_law_decode(output, quantization_channels=256):
    """Map quantized mu-law levels back to waveform amplitudes (inverse of ``mu_law_encode``).

    Args:
        output: NumPy array of integer levels in [0, quantization_channels - 1].
        quantization_channels: Number of discrete levels used for encoding.

    Returns:
        Float array with the same shape as ``output`` and values in [-1, 1].
    """
    top_level = quantization_channels - 1
    # Rescale the integer levels from [0, mu] back onto [-1, 1].
    companded = 2 * (output.astype(float) / top_level) - 1
    # Undo the logarithmic compression of the magnitude, then restore the sign.
    expanded = (1 / top_level) * ((1 + top_level) ** np.abs(companded) - 1)
    return np.sign(companded) * expanded

In [None]:
class WaveNet:
    """Small WaveNet-style autoregressive model over quantized audio (TensorFlow 1.x graph API).

    The graph embeds every quantized sample, passes it through
    ``n_dilation_blocks`` stacks of gated dilated convolutions (kernel size 2,
    dilations 1, 2, ..., 2 ** (max_dilation_base - 1), left-padded so the output
    at step t only depends on inputs up to t) and maps the accumulated layer
    outputs to a softmax over the quantization levels of the *next* sample.

    Each instance owns its own ``tf.Session`` and runs the global variable
    initializer of the default graph on construction.
    """

    def __init__(self, model_name, n_dilation_blocks, max_dilation_base,
                 n_dilation_filters, n_residual_channels=1, learning_rate=0.1,
                 quantization_channels=256):
        """Build the graph, the training op and a session for one model.

        Args:
            model_name: Variable scope under which all model variables are created.
            n_dilation_blocks: Number of times the full dilation stack is repeated.
            max_dilation_base: Each stack uses dilations 2**0 .. 2**(max_dilation_base - 1).
            n_dilation_filters: Number of filters in every gate / filter convolution.
            n_residual_channels: Width of the embedding and of the residual flow.
            learning_rate: Step size handed to the Adam optimizer.
            quantization_channels: Number of discrete signal levels (embedding rows
                and softmax outputs); 256 matches the default of ``mu_law_encode``.
        """
        with tf.variable_scope(model_name) as scope:
            self.model_name = model_name
            # Full variable-name prefix of this model; used to select exactly
            # this model's variables (and not those of a model whose name merely
            # contains ``model_name``).
            self._variable_prefix = scope.name + "/"
            self.n_dilation_blocks = n_dilation_blocks
            self.max_dilation_base = max_dilation_base
            self.n_dilation_filters = n_dilation_filters
            self.n_residual_channels = n_residual_channels
            self.quantization_channels = quantization_channels
            self.learning_rate = learning_rate
            self.session = tf.Session()

            # The model / prediction starts from quantized signals, i.e. categorical
            # values that have to be mapped to a continuous representation first.
            # The embedding width equals the number of residual channels.
            self.input_layer = tf.placeholder(shape=(1, None), dtype=tf.int32)
            self.embeddings = tf.Variable(
                tf.random_uniform([quantization_channels, n_residual_channels], -1, 1),
                name="embeddings")
            self.embedded_input = tf.nn.embedding_lookup(self.embeddings, self.input_layer)

            # Residual flow: additively accumulated through the network.
            residual_flow = self.embedded_input

            # Skip-connection accumulator: sums the outputs of all intermediate layers.
            # NOTE(review): it starts from the same tensor as residual_flow and receives
            # the same additions, so both currently hold identical values.
            skip_connections_layer = self.embedded_input

            # Stacks of dilated convolution layers.
            dilation_steps = [2 ** i for i in range(max_dilation_base)]
            for block_num in range(1, n_dilation_blocks + 1):
                for d in dilation_steps:
                    layer_name = "dilated_block" + str(block_num) + "_" + str(d)

                    # Left-pad by the dilation so the "valid" kernel-2 convolution
                    # keeps the sequence length and never looks at future samples.
                    conv_padded = tf.pad(residual_flow, [[0, 0], [d, 0], [0, 0]], "constant")

                    # Gate of the gated activation unit.
                    conv_gate = tf.sigmoid(tf.layers.conv1d(
                        conv_padded, filters=n_dilation_filters, kernel_size=2,
                        padding="valid", use_bias=False, dilation_rate=d,
                        name=layer_name + "_gate"))

                    # Filter of the gated activation unit.
                    conv_filter = tf.tanh(tf.layers.conv1d(
                        conv_padded, filters=n_dilation_filters, kernel_size=2,
                        padding="valid", use_bias=False, dilation_rate=d,
                        name=layer_name + "_filter"))

                    # 1x1 convolution squeezes the dilation filters back to the
                    # residual width so the result can be added to both flows.
                    conv_residual = tf.layers.conv1d(
                        conv_filter * conv_gate, filters=n_residual_channels,
                        kernel_size=1, strides=1, use_bias=False,
                        name=layer_name + "_residual")

                    # `+=` on tensors creates new graph nodes; the two flows stay separate.
                    skip_connections_layer += conv_residual
                    residual_flow += conv_residual

            # Squash the residual channels of the skip-connection sum into one channel.
            trunc_skip_layer = tf.layers.conv1d(tf.nn.relu(skip_connections_layer),
                                                filters=1, kernel_size=1, use_bias=False,
                                                name="skip_connections_conv")

            # Logits over the quantization levels for every time step.
            softmax_base = tf.layers.conv1d(tf.nn.relu(trunc_skip_layer),
                                            filters=quantization_channels, kernel_size=1,
                                            use_bias=False, name="softmax_base")

            self.pred_proba = tf.nn.softmax(softmax_base)
            # Distribution for the sample following the input; built once here so
            # that generate() does not add a new slicing op to the graph per call.
            self.next_proba = self.pred_proba[:, -1, :]

            # WaveNet is generative: the output at step t predicts the sample at t + 1,
            # so the logits (minus the last step) are scored against the input shifted by one.
            loss_logits = softmax_base[:, :-1, :]
            loss_labels = self.input_layer[:, 1:]
            # Sparse cross-entropy takes the integer labels directly; it equals the
            # dense version on one-hot labels without materializing them.
            self.loss = tf.reduce_sum(
                tf.nn.sparse_softmax_cross_entropy_with_logits(labels=loss_labels, logits=loss_logits))

            self.adam = tf.train.AdamOptimizer(self.learning_rate).minimize(
                self.loss, var_list=self.model_variables())

            self.session.run(tf.global_variables_initializer())

    def model_variables(self):
        """Return the trainable variables created under this model's variable scope."""
        return [x for x in tf.trainable_variables() if x.name.startswith(self._variable_prefix)]

    def model_size(self):
        """Return the total number of scalar parameters in the model's trainable variables."""
        # Static shapes are enough here; no graph ops or session run needed.
        return sum(x.get_shape().num_elements() for x in self.model_variables())

    def receptive_field(self):
        """Return how many trailing input samples influence the prediction of the next one.

        Every kernel-2 layer with dilation d reaches d samples further back, so
        the whole network reaches ``sum(dilations) * n_dilation_blocks`` samples
        back from the current one; the current sample itself adds the ``+ 1``.
        """
        dilation_sum = sum(2 ** i for i in range(self.max_dilation_base))
        return dilation_sum * self.n_dilation_blocks + 1

    def predict(self, track):
        """Return next-sample probabilities for every step of ``track``.

        Args:
            track: Integer array of shape (1, n_steps) with quantized samples.

        Returns:
            Array of shape (1, n_steps, quantization_channels).
        """
        return self.session.run(self.pred_proba, feed_dict={self.input_layer: track})

    def show_loss(self, track):
        """Return the summed next-sample cross-entropy of the model on ``track``."""
        return self.session.run(self.loss, feed_dict={self.input_layer: track})

    def train_op(self, track):
        """Run one Adam update on ``track`` (integer array of shape (1, n_steps))."""
        self.session.run(self.adam, feed_dict={self.input_layer: track})

    def generate(self, track):
        """Sample the quantized value that follows ``track`` from the model's prediction.

        Args:
            track: Integer array of shape (1, n_steps) used as context.

        Returns:
            Integer level drawn from the predicted distribution for the next sample.
        """
        next_proba = self.session.run(self.next_proba, feed_dict={self.input_layer: track})
        # The float32 softmax may not sum to exactly 1; renormalize in float64 so
        # np.random.choice's sum-to-one check cannot reject it.
        proba = np.asarray(next_proba[0], dtype=np.float64)
        proba = proba / proba.sum()
        return np.random.choice(np.arange(proba.shape[0]), p=proba)

## Testing WaveNet implementation

In [None]:
# Start from an empty graph so re-running this cell does not collide with old variables.
tf.reset_default_graph()

test_hparams = dict(
    n_dilation_blocks=4,
    max_dilation_base=8,
    n_dilation_filters=10,
    n_residual_channels=5,
)
wv_model = WaveNet("wv_test", **test_hparams)

# Sanity figures: parameter count first, then the receptive field in samples.
for report in (wv_model.model_size, wv_model.receptive_field):
    print(report())

In [None]:
# First 1000 samples as a single-row batch, quantized to integer mu-law levels.
track = mu_law_encode(music_track[np.newaxis, :1000])

In [None]:
# 100 optimizer steps on the same excerpt, printing the loss after every step.
n_train_steps = 100
for step in range(n_train_steps):
    wv_model.train_op(track)
    step_loss = wv_model.show_loss(track)
    print(step_loss)

In [None]:
# Autoregressive sampling: repeatedly feed the most recent context to the model
# and append the sampled level to the end of the track.
continued_track = np.array(track)
context_len = wv_model.receptive_field()
for step in range(1000):
    # Progress output: the first ten steps, then every hundredth.
    if step < 10 or step % 100 == 0:
        print(step)
    context = continued_track[:, -context_len:]
    sampled_level = wv_model.generate(context)
    continued_track = np.concatenate(
        (continued_track, np.reshape(sampled_level, (1, 1))), axis=1)

In [None]:
# Reference playback: the original (un-quantized) excerpt the model was trained on.
training_excerpt = music_track[:1000]
Audio(data=training_excerpt, rate=sample_rate, autoplay=False)

In [None]:
# Playback of the training excerpt plus the generated continuation, decoded back to amplitudes.
decoded_audio = mu_law_decode(continued_track)
Audio(data=decoded_audio, rate=sample_rate, autoplay=False)

In [None]:
# Waveform of the decoded track (training excerpt followed by the generated samples).
decoded_waveform = mu_law_decode(continued_track)[0]
plt.plot(np.arange(decoded_waveform.shape[0]), decoded_waveform)
plt.show()