# Generation script for the WaveNet network on the VCTK corpus.

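This notebook generates audio from a trained WaveNet checkpoint, optionally
guided by local-condition features (e.g. MFCCs) and compared against a reference recording.
The models can be trained using data from the VCTK corpus,
which can be freely downloaded at the following site (~10 GB):
http://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html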


In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [2]:
from __future__ import division 
from __future__ import print_function

import argparse
from datetime import datetime
import json
import os
import sys
import time
import librosa
import numpy as np
import numpy.random as random
import pandas as pd
from IPython.display import Audio
import IPython.display

import tensorflow as tf
from tensorflow.python.client import timeline

from wavenet import WaveNetModel, AudioReader, optimizer_factory, mu_law_encode, mu_law_decode, audio_reader

In [3]:
!nvidia-smi

Wed May  9 16:21:59 2018       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 384.90                 Driver Version: 384.90                    |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce GTX 108...  Off  | 00000000:02:00.0 Off |                  N/A |
| 48%   82C    P2   244W / 250W |  11003MiB / 11172MiB |     86%      Default |
+-------------------------------+----------------------+----------------------+
|   1  GeForce GTX 108...  Off  | 00000000:03:00.0 Off |                  N/A |
| 23%   31C    P8    16W / 250W |  10136MiB / 11172MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  GeForce GTX 108...  Off  | 00000000:81:00.0 Off |                  N/A |
| 23%   

In [4]:
# Expose only the GPU with index "0" to CUDA for this kernel process.
# NOTE(review): this runs after `import tensorflow`; presumably it still takes
# effect because no TensorFlow session has been created yet -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [5]:
# Defaults for the command-line style options parsed by get_arguments().
# Default for --samples: how many waveform samples to generate.
SAMPLES = 16000
# Default for --temperature: sampling temperature.
TEMPERATURE = 1.0
# Default for --logdir: root directory for the TensorBoard logs.
LOGDIR = './logdir'
# Default for --wavenet_params: JSON file with the network parameters.
WAVENET_PARAMS = './wavenet_params.json'
# Default for --save_every: number of samples between in-progress wav saves.
SAVE_EVERY = 8000
# Default silence_threshold that create_seed() passes to trim_silence().
SILENCE_THRESHOLD = 0.1

In [6]:
def mfcc_dist(sp1, sp2):
    """Distance between two feature matrices over their common leading rows.

    Both inputs are truncated along axis 0 to the shorter of the two, the
    shapes and the common length are printed, and the norm
    (``np.linalg.norm`` with default arguments) of the element-wise
    difference is returned.

    Args:
        sp1: 2-D array indexed as ``[row, column]``.
        sp2: 2-D array with the same number of columns as ``sp1``.

    Returns:
        The norm of ``sp1[:m] - sp2[:m]`` where ``m`` is the smaller row count.
    """
    common_rows = min(sp1.shape[0], sp2.shape[0])
    print(sp1.shape, sp2.shape, common_rows)
    difference = sp1[:common_rows, :] - sp2[:common_rows, :]
    return np.linalg.norm(difference)

In [7]:
def get_arguments(args):
    """Parse the generation options from a list of argument strings.

    Args:
        args: list of command-line style strings, e.g.
            ``['--samples', '100', 'path/to/model.ckpt']``. The single
            positional argument is the checkpoint to generate from.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.

    Raises:
        ValueError: if ``--gc_channels`` is given without ``--gc_cardinality``
            or without ``--gc_id``.
        SystemExit: raised by argparse when an option value is invalid
            (for example a non-positive ``--temperature``).
    """
    def _str_to_bool(s):
        """Convert string to bool (in argparse context)."""
        lowered = s.lower()
        if lowered not in ('true', 'false'):
            # ArgumentTypeError (rather than ValueError) makes argparse show
            # this message instead of a generic "invalid value" one.
            raise argparse.ArgumentTypeError(
                'Argument needs to be a boolean, got {}'.format(s))
        return lowered == 'true'

    def _ensure_positive_float(f):
        """Ensure argument is a strictly positive float."""
        value = float(f)
        # `not value > 0` rejects zero, negatives and NaN. The temperature is
        # later used as a divisor, so zero must not be accepted.
        if not value > 0:
            raise argparse.ArgumentTypeError(
                    'Argument must be greater than zero')
        return value

    parser = argparse.ArgumentParser(description='WaveNet generation script')
    parser.add_argument(
        'checkpoint', type=str, help='Which model checkpoint to generate from')
    parser.add_argument(
        '--samples',
        type=int,
        default=SAMPLES,
        help='How many waveform samples to generate')
    parser.add_argument(
        '--temperature',
        type=_ensure_positive_float,
        default=TEMPERATURE,
        help='Sampling temperature')
    parser.add_argument(
        '--logdir',
        type=str,
        default=LOGDIR,
        help='Directory in which to store the logging '
        'information for TensorBoard.')
    parser.add_argument(
        '--wavenet_params',
        type=str,
        default=WAVENET_PARAMS,
        help='JSON file with the network parameters')
    parser.add_argument(
        '--wav_out_path',
        type=str,
        default=None,
        help='Path to output wav file')
    parser.add_argument(
        '--save_every',
        type=int,
        default=SAVE_EVERY,
        help='How many samples before saving in-progress wav')
    parser.add_argument(
        '--fast_generation',
        type=_str_to_bool,
        default=True,
        help='Use fast generation')
    parser.add_argument(
        '--wav_seed',
        type=str,
        default=None,
        help='The wav file to start generation from')
    parser.add_argument(
        '--gc_channels',
        type=int,
        default=None,
        help='Number of global condition embedding channels. Omit if no '
             'global conditioning.')
    parser.add_argument(
        '--gc_cardinality',
        type=int,
        default=None,
        help='Number of categories upon which we globally condition.')
    parser.add_argument(
        '--gc_id',
        type=int,
        default=None,
        help='ID of category to generate, if globally conditioned.')
    parser.add_argument(
        '--lc_channels',
        type=int,
        default=0,
        help='Number of local condition channels. Should be consistent with the local condition file provided. Default: 0')
    parser.add_argument(
        '--lc_path',
        type=str,
        default=None,
        help='The path to the local condition csv file (no header). If not provided, assume no local condition.')
    parser.add_argument(
        '--compare_path',
        type=str,
        default=None,
        help='The path to the wave file for comparison with generation.')
    parser.add_argument(
        '--lower_bound',
        type=int,
        default=None,
        help='The lowerbound for log probability of sample. Any probability below e^lower_bound will be cleared to 0. '
             'If not provided, assume no lower bound.')
    arguments = parser.parse_args(args)

    # Global conditioning needs both the number of categories and the
    # category to generate.
    if arguments.gc_channels is not None:
        if arguments.gc_cardinality is None:
            raise ValueError("Globally conditioning but gc_cardinality not "
                             "specified. Use --gc_cardinality=377 for full "
                             "VCTK corpus.")

        if arguments.gc_id is None:
            raise ValueError("Globally conditioning, but global condition was "
                              "not specified. Use --gc_id to specify global "
                              "condition.")

    return arguments


def write_wav(waveform, sample_rate, filename):
    """Write ``waveform`` to ``filename`` as a wav file and report the path.

    Args:
        waveform: sequence of audio samples; converted with ``np.array``.
        sample_rate: sample rate passed to ``librosa.output.write_wav``.
        filename: destination path of the wav file.
    """
    samples_array = np.array(waveform)
    librosa.output.write_wav(filename, samples_array, sample_rate)
    message = 'Updated wav file at {}'.format(filename)
    print(message)


def create_seed(filename,
                sample_rate,
                quantization_channels,
                window_size,
                silence_threshold=SILENCE_THRESHOLD):
    """Build a quantized seed tensor from the start of a wav file.

    The file is loaded as mono audio at ``sample_rate``, passed through
    ``audio_reader.trim_silence`` with ``silence_threshold``, mu-law encoded
    with ``quantization_channels`` and cut to at most ``window_size`` entries.

    Returns:
        The encoded audio sliced to its first ``min(size, window_size)``
        entries; the cut index is a TensorFlow tensor.
    """
    trimmed, _ = librosa.load(filename, sr=sample_rate, mono=True)
    trimmed = audio_reader.trim_silence(trimmed, silence_threshold)
    encoded = mu_law_encode(trimmed, quantization_channels)

    # Keep the whole signal when it is shorter than the window.
    is_short = tf.size(encoded) < tf.constant(window_size)
    seed_length = tf.cond(is_short,
                          lambda: tf.size(encoded),
                          lambda: tf.constant(window_size))

    return encoded[:seed_length]



def main(argv):
    """Generate a waveform from a trained WaveNet checkpoint.

    Parses ``argv`` with ``get_arguments``, loads the optional local-condition
    CSV and the network parameter JSON, then builds the model, restores the
    checkpoint and samples ``--samples`` audio samples. The result is written
    to ``--wav_out_path`` (when given), logged as TensorBoard summaries and
    displayed inline.

    Args:
        argv: list of command-line style argument strings.

    Raises:
        ValueError: if local conditioning is enabled (``--lc_channels`` > 0)
            but no ``--lc_path`` was given, or if ``get_arguments`` rejects
            the global-conditioning options.
    """
    args = get_arguments(argv)

    if args.lc_channels > 0:
        if args.lc_path is None:
            # Fail early; printing the error and continuing would only lead
            # to a NameError on the missing local-condition data.
            raise ValueError('Local condition is enabled, '
                             'and a local condition file must be provided.')
        raw_lc = pd.read_csv(args.lc_path, sep=',', header=None).values
        lc = audio_reader.align_local_condition(raw_lc, args.samples)
    else:
        raw_lc = None
        lc = None

    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)

    sess = tf.Session()
    try:
        _run_generation(sess, args, wavenet_params, logdir, raw_lc, lc)
    finally:
        # Release the session even when generation fails part-way.
        sess.close()


def _run_generation(sess, args, wavenet_params, logdir, raw_lc, lc):
    """Build the graph, restore the checkpoint and sample the waveform.

    Args:
        sess: open ``tf.Session`` bound to the default graph.
        args: namespace returned by ``get_arguments``.
        wavenet_params: dict loaded from the network parameter JSON.
        logdir: directory for the TensorBoard summaries of this run.
        raw_lc: local-condition matrix as read from the CSV, or None when
            local conditioning is disabled.
        lc: ``raw_lc`` aligned to ``args.samples`` by
            ``audio_reader.align_local_condition``, or None.
    """
    sample_rate = wavenet_params['sample_rate']
    quantization_channels = wavenet_params['quantization_channels']
    use_lc = args.lc_channels > 0

    lc_placeholder = tf.placeholder(tf.float32) if use_lc else None

    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=quantization_channels,
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'],
        global_condition_channels=args.gc_channels,
        global_condition_cardinality=args.gc_cardinality,
        local_condition_channels=args.lc_channels)

    samples = tf.placeholder(tf.int32)

    if args.fast_generation:
        next_sample = net.predict_proba_incremental(
            samples, args.gc_id, local_condition=lc_placeholder)
    else:
        next_sample = net.predict_proba(
            samples, args.gc_id, local_condition=lc_placeholder)

    if args.fast_generation:
        sess.run(tf.global_variables_initializer())
        sess.run(net.init_ops)

    # The incremental-generation buffers are not restored from the checkpoint.
    variables_to_restore = {
        var.name[:-2]: var for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)

    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, quantization_channels)

    if args.wav_seed:
        seed = create_seed(args.wav_seed,
                           sample_rate,
                           quantization_channels,
                           net.receptive_field)
        waveform = sess.run(seed).tolist()
    else:
        # Silence with a single random sample at the end. Floor division keeps
        # the class index an int even under `from __future__ import division`.
        waveform = [quantization_channels // 2] * (net.receptive_field - 1)
        waveform.append(np.random.randint(quantization_channels))

    if use_lc:
        # Zero local condition for the initial padding samples.
        waveform_lc = [[0] * args.lc_channels] * (net.receptive_field - 1)

    probs = []

    # Tensors fetched at every step; fast generation also advances the queues.
    outputs = [next_sample]
    if args.fast_generation:
        outputs.extend(net.push_ops)

    if args.fast_generation and args.wav_seed:
        # When using the incremental generation, we need to
        # feed in all priming samples one by one before starting the
        # actual generation.
        # TODO This could be done much more efficiently by passing the waveform
        # to the incremental generator as an optional argument, which would be
        # used to fill the queues initially.
        # NOTE(review): no local condition is fed while priming; this looks
        # like it cannot work together with --lc_channels > 0 -- confirm.
        print('Priming generation...')
        for i, x in enumerate(waveform[-net.receptive_field: -1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(outputs, feed_dict={samples: x})
        print('Done.')

    true_mfcc = None
    if args.compare_path:
        true_wav, _ = librosa.load(args.compare_path, sr=sample_rate)
        if use_lc:
            # The MFCC order follows the local-condition width, so it is only
            # defined when local conditioning is enabled.
            true_mfcc = librosa.feature.mfcc(
                y=true_wav, sr=sample_rate, n_mfcc=args.lc_channels).T
            print(true_wav.shape, true_mfcc.shape)
        else:
            print(true_wav.shape)
        encoded_true_wav = mu_law_encode(true_wav, quantization_channels)
        encoded_true_wav = net._one_hot(encoded_true_wav)
        # Log ground truth
        tf.summary.image(
            "Ground_Truth",
            tf.reshape(
                encoded_true_wav,
                [1, -1, quantization_channels, 1]),
            max_outputs=1
        )
        IPython.display.display(Audio(true_wav, rate=sample_rate))

    def _report_mfcc_distances(decoded):
        """Print MFCC distances of the generated audio to lc / ground truth."""
        if not use_lc:
            # Without local conditioning there is no raw_lc to compare to.
            return
        out_mfcc = librosa.feature.mfcc(
            y=decoded[net.receptive_field:], sr=sample_rate,
            n_mfcc=args.lc_channels).T
        print(mfcc_dist(out_mfcc, raw_lc))
        if args.compare_path:
            print(mfcc_dist(out_mfcc, true_mfcc), mfcc_dist(raw_lc, true_mfcc))

    last_sample_timestamp = datetime.now()
    for step in range(args.samples):
        if use_lc:
            waveform_lc.append(list(lc[step, :]))
        if args.fast_generation:
            window = waveform[-1]
            if use_lc:
                lc_window = np.reshape(lc[step], (1, -1))
        else:
            if len(waveform) > net.receptive_field:
                window = waveform[-net.receptive_field:]
                if use_lc:
                    lc_window = np.reshape(waveform_lc[-net.receptive_field:],
                                           (1, -1, args.lc_channels))
            else:
                window = waveform
                if use_lc:
                    lc_window = np.reshape(waveform_lc,
                                           (1, -1, args.lc_channels))

        # Run the WaveNet to predict the next sample.
        feed = {samples: window}
        if use_lc:
            feed[lc_placeholder] = lc_window
        prediction = sess.run(outputs, feed_dict=feed)[0]

        # Scale prediction distribution using temperature. log(0) is expected
        # for zero-probability classes, so the divide warning is silenced only
        # for this block and the previous settings are restored afterwards.
        # NOTE: --lower_bound is parsed but currently not applied here.
        with np.errstate(divide='ignore'):
            scaled_prediction = np.log(prediction) / args.temperature
            scaled_prediction = (scaled_prediction -
                                 np.logaddexp.reduce(scaled_prediction))
            scaled_prediction = np.exp(scaled_prediction)

        # Prediction distribution at temperature=1.0 should be unchanged after
        # scaling.
        if args.temperature == 1.0:
            np.testing.assert_allclose(
                    prediction, scaled_prediction, atol=1e-5,
                    err_msg='Prediction scaling at temperature=1.0 '
                            'is not working as intended.')

        sample = np.random.choice(
            np.arange(quantization_channels), p=scaled_prediction)
        waveform.append(sample)
        probs.append(scaled_prediction)

        # Show progress only once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:3<d}/{:3<d}'.format(step + 1, args.samples), end='\r')
            last_sample_timestamp = current_sample_timestamp

        # If we have partial writing, save the result so far.
        if (args.wav_out_path and args.save_every and
                (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, sample_rate, args.wav_out_path)

            IPython.display.display(Audio(out, rate=sample_rate))
            # NOTE: a cross-entropy loss against the ground truth is
            # deliberately not computed; it was marked as not meaningful here.
            _report_mfcc_distances(out)

    # Introduce a newline to clear the carriage return from the progress.
    print()

    # Save the result as an audio summary.
    writer = tf.summary.FileWriter(logdir)
    try:
        tf.summary.audio('generated', decode, sample_rate)

        # Log prediction
        tf.summary.image(
            "Predicted_Probabilities",
            tf.reshape(
                np.array(probs),
                [1, -1, quantization_channels, 1]),
            max_outputs=1
        )

        summaries = tf.summary.merge_all()
        summary_out = sess.run(
            summaries,
            feed_dict={samples: np.reshape(waveform, [-1, 1])})
        writer.add_summary(summary_out)
    finally:
        # Closing flushes the pending events to disk.
        writer.close()

    out = sess.run(decode, feed_dict={samples: waveform})

    # Save the result as a wav file.
    if args.wav_out_path:
        write_wav(out, sample_rate, args.wav_out_path)

    print('Finished generating. The result can be viewed in TensorBoard.')
    IPython.display.display(Audio(out, rate=sample_rate))

    _report_mfcc_distances(out)

## Args for p225, mfcc local conditions and small architecture

In [8]:
# tf.reset_default_graph()
# argv =[]
# argv.extend(['--temperature', "1"])
# argv.extend(['--compare_path', "../VCTK-Corpus/wav48/p225/p225_001.wav"])
# argv.extend(['--lc_channels', "20"])
# argv.extend(['--lc_path', "../VCTK-Corpus/mfcc/p225/p225_001.csv"])
# argv.extend(['--samples', "32825"])
# argv.extend(['--fast_generation', "False"])
# argv.extend(['--wavenet_params', "./wavenet_params2.json"])
# argv.extend(['--wav_out_path', "./results/gen_p225_mfcc_fix_146950_p225_001.wav"])
# argv.extend(["logdir/train/p225_mfcc_fix/model.ckpt-146950"])
# print(argv)
# main(argv)

In [9]:
# tf.reset_default_graph()
# argv =[]
# argv.extend(['--temperature', "1"])
# argv.extend(['--compare_path', "../VCTK-Corpus/wav48/p225/p225_366.wav"])
# argv.extend(['--lc_channels', "20"])
# argv.extend(['--lc_path', "../VCTK-Corpus/mfcc/p225/p225_366.csv"])
# argv.extend(['--samples', "84799"])
# argv.extend(['--fast_generation', "False"])
# argv.extend(['--wavenet_params', "./wavenet_params2.json"])
# argv.extend(['--wav_out_path', "./results/gen_p225_mfcc_fix_149200_p225_366.wav"])
# argv.extend(["logdir/train/p225_mfcc_fix/model.ckpt-149200"])
# print(argv)
# main(argv)

### Old

In [10]:
# tf.reset_default_graph()
# argv =[]
# argv.extend(['--temperature', "1"])
# argv.extend(['--compare_path', "../cmu_us_slt_arctic/wav/arctic_a0164.wav"])
# argv.extend(['--lc_channels', "40"])
# argv.extend(['--lc_path', "../cmu_us_slt_arctic/mfcc40/arctic_a0164.csv"])
# argv.extend(['--samples', "53681"])
# argv.extend(['--fast_generation', "False"])
# argv.extend(['--wavenet_params', "./wavenet_params2.json"])
# argv.extend(['--wav_out_path', "./results/gen_p225_cmu_mfcc40_99999_arctic_a0164.wav"])
# argv.extend(["logdir/train/cmu_mfcc40/model.ckpt-99999"])
# print(argv)
# main(argv)

In [11]:
# tf.reset_default_graph()
# argv =[]
# argv.extend(['--temperature', "1"])
# argv.extend(['--compare_path', "../cmu_us_slt_arctic/wav/arctic_a0164.wav"])
# argv.extend(['--lc_channels', "40"])
# argv.extend(['--lc_path', "../cmu_us_slt_arctic/mfcc40_48k/arctic_a0164.csv"])
# argv.extend(['--samples', "53681"])
# argv.extend(['--fast_generation', "False"])
# argv.extend(['--wavenet_params', "./wavenet_params2.json"])
# argv.extend(['--wav_out_path', "./results/gen_cmu_mfcc40_48k_93200_arctic_a0164.wav"])
# argv.extend(["logdir/train/cmu_mfcc40_48k/model.ckpt-93200"])
# print(argv)
# main(argv)

## Args for p225, mgc local conditions and small architecture

In [12]:
# tf.reset_default_graph()
# argv =[]
# argv.extend(['--temperature', "1"])
# argv.extend(['--compare_path', "../VCTK-Corpus/wav48/p225/p225_366.wav"])
# argv.extend(['--lc_channels', "26"])
# argv.extend(['--lc_path', "../VCTK-Corpus/mgc/p225/p225_366.csv"])
# argv.extend(['--samples', "84799"])
# argv.extend(['--fast_generation', "False"])
# argv.extend(['--wavenet_params', "./wavenet_params2.json"])
# argv.extend(['--wav_out_path', "./results/gen_p225_mgc_fix_3900_p225_366.wav"])
# argv.extend(["logdir/train/p225_mgc_fix/model.ckpt-3900"])
# print(argv)
# main(argv)

# argv =[]
# argv.extend(['--temperature', "1"])
# argv.extend(['--compare_path', "../VCTK-Corpus/wav48/p225/p225_001.wav"])
# argv.extend(['--lc_channels', "26"])
# argv.extend(['--lc_path', "../VCTK-Corpus/mgc/p225/p225_001.csv"])
# argv.extend(['--samples', "32825"])
# argv.extend(['--fast_generation', "False"])
# argv.extend(['--wavenet_params', "./wavenet_params2.json"])
# argv.extend(['--wav_out_path', "./results/gen_p225_mgc_fix_3850_p225_001.wav"])
# argv.extend(["logdir/train/p225_mgc_fix/model.ckpt-3850"])
# print(argv)
# main(argv)

## Args for mfcc local conditions (25+1 coeff, 16k, hop_length = 80, frame_length = 1024) and small architecture, CMU Dataset
* Output probabilities are lower-bounded and raised to the power of 2

### With noise

In [13]:
# tf.reset_default_graph()
# argv =[]
# argv.extend(['--temperature', "1"])
# argv.extend(['--compare_path', "../cmu_us_slt_arctic/wav/arctic_a0164.wav"])
# argv.extend(['--lc_channels', "26"])
# argv.extend(['--lc_path', "../cmu_us_slt_arctic/mfcc25+1_16k/arctic_a0164.csv"])
# argv.extend(['--samples', "53681"])
# argv.extend(['--fast_generation', "False"])
# argv.extend(['--wavenet_params', "./wavenet_params2.json"])
# argv.extend(['--wav_out_path', "./results/gen_CMU_alt_mfcc25+1_16k_99999_arctic_a0164.wav"])
# argv.extend(["logdir/train/CMU_alt_mfcc25+1_16k/model.ckpt-99999"])
# print(argv)
# main(argv)

### Without noise

In [None]:
# Start from an empty default graph so repeated runs of this cell do not
# accumulate operations from earlier runs.
tf.reset_default_graph()

# Option/value pairs for get_arguments(); the last entry is the positional
# checkpoint path.
argv = [
    '--temperature', "1",
    '--compare_path', "../cmu_us_slt_arctic/wav/arctic_a0164.wav",
    '--lc_channels', "26",
    '--lc_path', "../cmu_us_slt_arctic/mfcc25+1_16k/arctic_a0164.csv",
    '--samples', "53681",
    '--fast_generation', "False",
    '--wavenet_params', "./wavenet_params2.json",
    '--wav_out_path', "./results/gen_CMU_alt_no_noise_retrain_mfcc25+1_16k_84356_arctic_a0164.wav",
    "logdir/train/CMU_alt_no_noise_retrain_mfcc25+1_16k/model.ckpt-84356",
]
print(argv)
main(argv)

['--temperature', '1', '--compare_path', '../cmu_us_slt_arctic/wav/arctic_a0164.wav', '--lc_channels', '26', '--lc_path', '../cmu_us_slt_arctic/mfcc25+1_16k/arctic_a0164.csv', '--samples', '53681', '--fast_generation', 'False', '--wavenet_params', './wavenet_params2.json', '--wav_out_path', './results/gen_CMU_alt_no_noise_retrain_mfcc25+1_16k_84356_arctic_a0164.wav', 'logdir/train/CMU_alt_no_noise_retrain_mfcc25+1_16k/model.ckpt-84356']
Restoring model from logdir/train/CMU_alt_no_noise_retrain_mfcc25+1_16k/model.ckpt-84356
INFO:tensorflow:Restoring parameters from logdir/train/CMU_alt_no_noise_retrain_mfcc25+1_16k/model.ckpt-84356
(53681,) (105, 26)


Sample 98/53681