# WaveNet Sample Generation
Fast generation of samples from a pretrained WaveNet model

In [1]:
from model import Optimizer, WaveNetData
from WaveNetModel2 import WaveNetModel2

import torch
import numpy as np
import time

from IPython.display import Audio
from matplotlib import pyplot as plt
from matplotlib import pylab as pl
from IPython import display

%matplotlib notebook

## Load Model

In [2]:
# Audio clip that seeds generation, and the checkpoint holding the trained weights.
train_sample = "train_samples/sapiens_11025.wav"
parameters = "model_parameters/sapiens_12-3-256-32-32-128-2"

# Architecture hyperparameters. The checkpoint file name lists the same seven
# values in this order (12-3-256-32-32-128-2).
layers, blocks, classes = 12, 3, 256
dilation_channels, residual_channels, skip_channels = 32, 32, 128
kernel_size = 2

# Use the GPU whenever PyTorch reports that CUDA is available.
use_cuda = torch.cuda.is_available()

In [3]:
# Build the network from the hyperparameters in the configuration cell.
# NOTE(review): these presumably have to match the values the checkpoint was
# trained with, otherwise load_state_dict below will reject the weights.
model = WaveNetModel2(layers=layers,
                      blocks=blocks,
                      dilation_channels=dilation_channels,
                      residual_channels=residual_channels,
                      skip_channels=skip_channels,
                      classes=classes)

if use_cuda:
    model.cuda()
    print("use cuda")

print("scope: ", model.scope)

# Always deserialize the checkpoint onto the CPU, whatever device it was saved
# from. load_state_dict then copies the values into the model's own parameters,
# so this works on both the CPU and the CUDA path and does not depend on the
# GPU index stored in the file.
model.load_state_dict(torch.load(parameters,
                                 map_location=lambda storage, loc: storage))

# Dataset wrapper around the training clip, sized to the model's input and
# target lengths; later used to fetch the seed audio for generation.
data = WaveNetData(train_sample,
                   input_length=model.scope,
                   target_length=model.last_block_scope,
                   num_classes=model.classes,
                   cuda=use_cuda)

scope:  14333


In [8]:
# Fetch one example from the training clip to seed generation.
# NOTE(review): the list presumably holds a position in the clip and [0] the
# input half of an (input, target) pair -- confirm against WaveNetData.get_minibatch.
start_data = data.get_minibatch([model.scope+15000])[0].squeeze()

# `data` was built with cuda=use_cuda, so the tensor may live on the GPU;
# Tensor.numpy() only works on CPU tensors, and .cpu() is a no-op if it is
# already there. The trailing semicolon hides the matplotlib object repr.
plt.plot(start_data.cpu().numpy())
plt.title("Seed audio (start_data)");

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x110f220b8>]

## Generate Samples


In [13]:
# Generation settings: 11025 samples at a rate of 11025 samples per second
# gives one second of new audio.
sample_rate = 11025  # rate used for playback and for the written WAV file
num_samples = 11025  # number of samples that will be generated
out_file = "generated_samples/sapiens_12-3-256-32-32-128-2.wav"

In [43]:
from ipywidgets import FloatProgress
from IPython.display import display

# Progress bar on a 0-100 (percent) scale, shown before generation starts.
progress = FloatProgress(min=0, max=100)
display(progress)

def p_callback(i, total):
    """Move the progress bar to the percentage of work completed.

    Assumes `i` is the number of steps done so far and `total` the overall
    step count -- TODO confirm against WaveNetModel2.generate_fast. Deriving
    the value from the arguments keeps the bar correct however often the
    callback is invoked.
    """
    if total:  # guard against division by zero
        progress.value = 100.0 * i / total

# perf_counter is monotonic, so the measured duration is unaffected by
# system clock adjustments during the (multi-minute) run.
tic = time.perf_counter()
generated_sample = model.generate_fast(num_samples,
                                       first_samples=start_data,
                                       progress_callback=p_callback,
                                       sampled_generation=True,
                                       temperature=1.5)
toc = time.perf_counter()

# The callback may not fire on the very last step; mark the bar as complete.
progress.value = progress.max
print('Generating took {} seconds.'.format(toc-tic))

Generating took 273.9111421108246 seconds.


In [44]:
# Plot the first 1000 values of the generated sequence on a fresh figure.
fig, ax = plt.subplots()
ax.plot(generated_sample[0:1000])
ax.set(title="Generated audio (first 1000 samples)",
       xlabel="sample index",
       ylabel="amplitude")

# Audio is already imported in the notebook's import cell. As the last
# expression of the cell, it renders an inline player for the whole sequence.
Audio(np.array(generated_sample), rate=sample_rate)

<IPython.core.display.Javascript object>

In [45]:
import os
from scipy.io import wavfile

# Convert once, with an explicit dtype. scipy derives the WAV sample format
# from the array dtype: float32 gives a standard 32-bit float file, whereas an
# inferred float64 would give a 64-bit float WAV that many players cannot read.
generated_audio = np.asarray(generated_sample, dtype=np.float32)
print(generated_audio)

# wavfile.write does not create missing directories, so make sure the target
# folder exists (skipped when out_file has no directory component).
out_dir = os.path.dirname(out_file)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)

wavfile.write(out_file, sample_rate, generated_audio)

[ 0.         0.         0.        ..., -0.0078125 -0.0078125 -0.0078125]
