In [1]:
import os
import torch
from torch import nn, optim
import numpy as np
import matplotlib.pyplot as plt

from wavenet.audiodata import AudioData, AudioLoader
from wavenet.models_torch import Model, Generator

%matplotlib inline

# --- Reproducibility ---------------------------------------------------------
# Seed the NumPy and PyTorch global RNGs so that anything in this notebook that
# draws from them gives the same result after Restart Kernel -> Run All.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# --- Network geometry (passed to Model and to get_ylen) ----------------------
x_len = 2**10        # length of one input window, in samples
num_classes = 256    # also passed to AudioData
num_layers = 8       # layers per block; layer k contributes 2**k in get_ylen
num_blocks = 2
num_hidden = 32
kernel_size = 2

# --- Optimisation (Adam + StepLR) --------------------------------------------
learn_rate = 0.001   # Adam learning rate
step_size = 50       # StepLR step_size
gamma = 0.8          # StepLR decay factor

# --- Data loading (AudioLoader) ----------------------------------------------
batch_size = 8
num_workers = 1

# --- Training run -------------------------------------------------------------
num_epochs = 5
model_file = 'model.pt'   # checkpoint: loaded if present, otherwise written after training
use_visdom = True         # forwarded to wave_model.train
disp_interval = 1         # forwarded to wave_model.train

# --- Generation ---------------------------------------------------------------
n_new_samples = 1000      # number of samples to predict after the seed window

device = torch.device('cpu')

## Create dataset and dataloader

In [2]:
# Audio files handed to AudioData in this cell (relative path).
filelist = ['assets/classical.wav']

def get_ylen(x_len, num_layers, num_blocks, kernel_size):
    """Return ``x_len`` minus the receptive field of the dilated stack.

    The receptive field is
    ``1 + (kernel_size - 1) * num_blocks * (2**0 + 2**1 + ... + 2**(num_layers - 1))``.

    Args:
        x_len: length of the input window.
        num_layers: number of layers per block; layer ``k`` contributes ``2**k``.
        num_blocks: number of blocks.
        kernel_size: kernel size.

    Returns:
        ``x_len - receptive_field``.
    """
    # Sum of the per-layer terms 2**k within one block.
    dilation_total = 0
    for layer in range(num_layers):
        dilation_total += 2 ** layer

    receptive_field = 1 + (kernel_size - 1) * num_blocks * dilation_total
    return x_len - receptive_field

# y_len is x_len minus the receptive field computed by get_ylen.
y_len = get_ylen(
    x_len=x_len,
    num_layers=num_layers,
    num_blocks=num_blocks,
    kernel_size=kernel_size,
)
print('y_len: {}'.format(y_len))

# Dataset over the files in `filelist`; store_tracks=True because later cells
# read dataset.tracks.
dataset = AudioData(
    filelist,
    x_len,
    y_len=y_len,
    num_classes=num_classes,
    store_tracks=True,
)

# Batching wrapper around the dataset.
dataloader = AudioLoader(
    dataset,
    batch_size=batch_size,
    num_workers=num_workers,
)

y_len: 513


## Define and train model

In [3]:
# Build the network from the configuration constants defined in the first cell.
wave_model = Model(
    x_len,
    num_channels=1,
    num_classes=num_classes,
    num_blocks=num_blocks,
    num_layers=num_layers,
    num_hidden=num_hidden,
    kernel_size=kernel_size,
)

receptive_field: 511
Output width: 514


In [4]:
# Hand the configured device to the model (project API) before loading/training.
wave_model.set_device(device)

if os.path.isfile(model_file):
    # A checkpoint exists: restore the weights instead of retraining.
    print('Loading model data from file: {}'.format(model_file))
    # map_location remaps the stored tensors onto `device`, so a checkpoint that
    # was written on a different device (e.g. a GPU) still loads here.
    wave_model.load_state_dict(torch.load(model_file, map_location=device))
else:
    print('Model data not found: {}'.format(model_file))
    print('Training new model.')

    # Loss, optimiser and learning-rate schedule are attached to the model
    # object; wave_model.train is expected to pick them up from there.
    wave_model.criterion = nn.CrossEntropyLoss()
    wave_model.optimizer = optim.Adam(wave_model.parameters(), 
                                      lr=learn_rate)
    wave_model.scheduler = optim.lr_scheduler.StepLR(
        wave_model.optimizer, step_size=step_size, gamma=gamma)
    
    wave_model.train(dataloader, num_epochs=num_epochs, 
                     disp_interval=disp_interval, 
                     use_visdom=use_visdom)

    # Persist the trained weights so the next run takes the loading branch.
    print('Saving model data to file: {}'.format(model_file))
    torch.save(wave_model.state_dict(), model_file)

Model data not found: model.pt
Training new model.
Epoch 1 / 5
Learning Rate: [0.001]
Training Loss: 5.186990846063673
----------

Sample 0 / 44100
Sample 513 / 44100
Sample 1026 / 44100
Sample 1539 / 44100
Sample 2052 / 44100
Sample 2565 / 44100
Sample 3078 / 44100
Sample 3591 / 44100
Sample 4104 / 44100
Sample 4617 / 44100
Sample 5130 / 44100
Sample 5643 / 44100
Sample 6156 / 44100
Sample 6669 / 44100
Sample 7182 / 44100
Sample 7695 / 44100
Sample 8208 / 44100
Sample 8721 / 44100
Sample 9234 / 44100
Sample 9747 / 44100
Sample 10260 / 44100
Sample 10773 / 44100
Sample 11286 / 44100
Sample 11799 / 44100
Sample 12312 / 44100
Sample 12825 / 44100
Sample 13338 / 44100
Sample 13851 / 44100
Sample 14364 / 44100
Sample 14877 / 44100
Sample 15390 / 44100
Sample 15903 / 44100
Sample 16416 / 44100
Sample 16929 / 44100
Sample 17442 / 44100
Sample 17955 / 44100
Sample 18468 / 44100
Sample 18981 / 44100
Sample 19494 / 44100
Sample 20007 / 44100
Sample 20520 / 44100
Sample 21033 / 44100
Sample 2154

Sample 18981 / 44100
Sample 19494 / 44100
Sample 20007 / 44100
Sample 20520 / 44100
Sample 21033 / 44100
Sample 21546 / 44100
Sample 22059 / 44100
Sample 22572 / 44100
Sample 23085 / 44100
Sample 23598 / 44100
Sample 24111 / 44100
Sample 24624 / 44100
Sample 25137 / 44100
Sample 25650 / 44100
Sample 26163 / 44100
Sample 26676 / 44100
Sample 27189 / 44100
Sample 27702 / 44100
Sample 28215 / 44100
Sample 28728 / 44100
Sample 29241 / 44100
Sample 29754 / 44100
Sample 30267 / 44100
Sample 30780 / 44100
Sample 31293 / 44100
Sample 31806 / 44100
Sample 32319 / 44100
Sample 32832 / 44100
Sample 33345 / 44100
Sample 33858 / 44100
Sample 34371 / 44100
Sample 34884 / 44100
Sample 35397 / 44100
Sample 35910 / 44100
Sample 36423 / 44100
Sample 36936 / 44100
Sample 37449 / 44100
Sample 37962 / 44100
Sample 38475 / 44100
Sample 38988 / 44100
Sample 39501 / 44100
Sample 40014 / 44100
Sample 40527 / 44100
Sample 41040 / 44100
Sample 41553 / 44100
Sample 42066 / 44100
Sample 42579 / 44100
Sample 43092 

## Predict sequence

In [None]:
# Generator built from the model and the dataset.
wave_generator = Generator(wave_model, dataset)

# Total span = one seed window of x_len samples + the samples to predict.
n_predictions = n_new_samples
n_total_samples = x_len + n_predictions

# Take the first stored track: its sample rate, and the leading slice of audio.
sample_rate = dataset.tracks[0]['sample_rate']
audio = dataset.tracks[0]['audio'][:n_total_samples]

# The first x_len samples are the input to the generator.
x = audio[:x_len]

In [None]:
# Announce how many samples will be generated.
print('Predicting {} samples'.format(n_predictions))
# Run the generator on the seed window x for n_predictions samples; the result
# y is plotted in the next cell. disp_interval=100 is forwarded to run() —
# presumably the progress-print interval (TODO confirm against Generator.run).
y = wave_generator.run(x, n_predictions, disp_interval=100)

In [None]:
# Time axis: sample k sits at k / sample_rate (seconds, assuming sample_rate is
# in samples per second). The index must be divided by the rate, not multiplied.
idxs = np.arange(n_total_samples) / sample_rate

# Pass the reference audio through the dataset's encoder (normalize, then
# expand) before plotting it alongside the generator output.
enc = dataset.encoder
reencoded_audio = enc.expand(enc.normalize(audio, span='minmax'))

# Blue: reference audio over the whole span. Red: generated samples, which
# start right after the x_len-sample seed window.
plt.plot(idxs, reencoded_audio, 'b', label='original')
plt.plot(idxs[x_len:], y, 'r', label='predicted')
plt.xlabel('Time (s)')
plt.ylabel('Audio signal')
plt.legend();