In [1]:
import time
from wavenet_model import *
from audio_data import WavenetDataset
from wavenet_training import *
from model_logging import *
from scipy.io import wavfile

# Pick the tensor constructors once, up front: CUDA variants when a GPU is
# visible to torch, CPU variants otherwise. Later cells pass `dtype` on to
# the model.
use_cuda = torch.cuda.is_available()
if use_cuda:
    print('use gpu')
    dtype, ltype = torch.cuda.FloatTensor, torch.cuda.LongTensor
else:
    dtype, ltype = torch.FloatTensor, torch.LongTensor

In [2]:
# Build the network. Every channel width is 16 and the tensor type chosen in
# the setup cell is forwarded via `dtype`.
model = WaveNetModel(
    layers=8,
    blocks=4,
    dilation_channels=16,
    residual_channels=16,
    skip_channels=16,
    output_length=8,
    dtype=dtype)

# Summarise the freshly built model: its module tree, receptive field and
# parameter count.
print('model: ', model)
print('receptive field: ', model.receptive_field)
print('parameter count: ', model.parameter_count())

model:  WaveNetModel(
  (filter_convs): ModuleList(
    (0): Conv1d (16, 16, kernel_size=(2,), stride=(1,), bias=False)
    (1): Conv1d (16, 16, kernel_size=(2,), stride=(1,), bias=False)
    (2): Conv1d (16, 16, kernel_size=(2,), stride=(1,), bias=False)
    (3): Conv1d (16, 16, kernel_size=(2,), stride=(1,), bias=False)
    (4): Conv1d (16, 16, kernel_size=(2,), stride=(1,), bias=False)
    (5): Conv1d (16, 16, kernel_size=(2,), stride=(1,), bias=False)
    (6): Conv1d (16, 16, kernel_size=(2,), stride=(1,), bias=False)
    (7): Conv1d (16, 16, kernel_size=(2,), stride=(1,), bias=False)
    (8): Conv1d (16, 16, kernel_size=(2,), stride=(1,), bias=False)
    (9): Conv1d (16, 16, kernel_size=(2,), stride=(1,), bias=False)
    (10): Conv1d (16, 16, kernel_size=(2,), stride=(1,), bias=False)
    (11): Conv1d (16, 16, kernel_size=(2,), stride=(1,), bias=False)
    (12): Conv1d (16, 16, kernel_size=(2,), stride=(1,), bias=False)
    (13): Conv1d (16, 16, kernel_size=(2,), stride=(1,), bias

In [3]:
# Item and target lengths are derived from the model so that the dataset
# always matches the network built above.
data = WavenetDataset(
    dataset_file='train_samples/saber/dataset.npz',
    item_length=model.receptive_field + model.output_length - 1,
    target_length=model.output_length,
    file_location='train_samples/saber',
    test_stride=20)

# Quick sanity check on how many items the dataset exposes.
print('the dataset has ' + str(len(data)) + ' items')

the dataset has 17945 items


In [4]:
def generate_and_log_samples(step,
                             sample_length=4000,
                             temperatures=(0, 0.5),
                             snapshot_dir='snapshots',
                             sample_rate=16000):
    """Generate one audio clip per temperature and log each to the logger.

    The model is loaded once through ``load_latest_model_from(snapshot_dir)``
    rather than taken from the notebook's live ``model``. Each clip is
    produced by ``generate_audio``, converted to a float32 TensorFlow tensor
    and passed to ``logger.audio_summary`` under the tag
    ``'temperature <value>'``.

    The module-level ``logger`` is looked up at call time, so it only has to
    exist before the first call, not before this definition.

    Args:
        step: Step value forwarded unchanged to ``logger.audio_summary``.
        sample_length: Value passed to ``generate_audio`` as ``length``.
        temperatures: Iterable of temperatures; one clip is generated and
            logged for each entry, in order.
        snapshot_dir: Location handed to ``load_latest_model_from``.
        sample_rate: Value passed to ``logger.audio_summary`` as ``sr``.

    With only ``step`` supplied, the defaults reproduce the previous
    hard-coded behaviour (4000 samples, temperatures 0 and 0.5, 16000 Hz).
    """
    gen_model = load_latest_model_from(snapshot_dir)
    print("start generating...")
    for temperature in temperatures:
        # generate_audio is given a single-element list so each temperature
        # yields its own separately tagged clip.
        samples = generate_audio(gen_model,
                                 length=sample_length,
                                 temperatures=[temperature])
        tf_samples = tf.convert_to_tensor(samples, dtype=tf.float32)
        logger.audio_summary('temperature ' + str(temperature),
                             tf_samples,
                             step,
                             sr=sample_rate)
    print("audio clips generated")

In [5]:
# Logger used during training. `generate_and_log_samples` is registered as
# the generation callback, and that callback reads this module-level `logger`
# when it runs.
logger = TensorboardLogger(
    log_interval=200,
    validation_interval=200,
    generate_interval=500,
    generate_function=generate_and_log_samples,
    log_dir="logs")

In [7]:
# Set up the trainer on the model and dataset built above. Snapshots go to
# 'snapshots', the same directory the sample-generation callback loads from.
trainer = WavenetTrainer(model=model,
                         dataset=data,
                         lr=0.0001,
                         weight_decay=0.1,
                         snapshot_path='snapshots',
                         snapshot_name='saber_model',
                         snapshot_interval=500)

print('start training...')
# perf_counter() is a monotonic, high-resolution clock, so the measured
# duration cannot be skewed by system clock adjustments during a long run
# (time.time() can be).
tic = time.perf_counter()
trainer.train(batch_size=8,
              epochs=20)
toc = time.perf_counter()
print('Training took {} seconds.'.format(toc - tic))

start training...
epoch 0
loss at step 50: 5.551240663528443


Process Process-15:
Process Process-12:
Process Process-14:
Process Process-11:
Process Process-10:
Process Process-9:
Process Process-13:
Process Process-16:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/Users/vincentherrmann/anaconda3/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/Users/vincentherrmann/anaconda3/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/Users/vincentherrmann/anaconda3/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/Users/vincentherrmann/anaconda3/lib/python3.5/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/Users/vincentherrmann/anaconda3/lib/python3.5/multiprocessin

KeyboardInterrupt: 

  File "/Users/vincentherrmann/anaconda3/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 36, in _worker_loop
    r = index_queue.get()
  File "/Users/vincentherrmann/anaconda3/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 36, in _worker_loop
    r = index_queue.get()
  File "/Users/vincentherrmann/anaconda3/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 36, in _worker_loop
    r = index_queue.get()
  File "/Users/vincentherrmann/anaconda3/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 36, in _worker_loop
    r = index_queue.get()
  File "/Users/vincentherrmann/anaconda3/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 36, in _worker_loop
    r = index_queue.get()
  File "/Users/vincentherrmann/anaconda3/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 36, in _worker_loop
    r = index_queue.get()
  File "/Users/vincentherrmann/anaconda3/lib/python3.5/multiprocessing/queues.py",