In [1]:
import os
import time

import numpy as np
import torch
from IPython import display
from IPython.core.debugger import Tracer
from IPython.display import Audio
from matplotlib import pylab as pl
from matplotlib import pyplot as plt

from model import WaveNetModel, Optimizer, WaveNetData

%matplotlib notebook

In [2]:
# Network hyperparameters, gathered in one mapping so they are easy to spot
# and tweak before re-running the notebook.
wavenet_config = {
    "num_layers": 10,
    "num_blocks": 2,
    "num_classes": 128,
    "hidden_channels": 64,
}
model = WaveNetModel(**wavenet_config)

# Show the layer structure and the model's `scope` attribute for a quick
# sanity check of the configuration.
print("model: ", model)
print("scope: ", model.scope)

# The dataset's input length, target length and class count are all read
# from the model, so the data always matches the network built above.
data = WaveNetData(
    'train_samples/violin.wav',
    input_length=model.scope,
    target_length=model.last_block_scope,
    num_classes=model.num_classes,
)

model:  WaveNetModel (
  (main): Sequential (
    (b0-l0.wavenet_layer-d1): WaveNetLayer (
      (dil_conv): Conv1d(1, 128, kernel_size=(2,), stride=(1,), bias=False)
      (onexone_conv): Conv1d(128, 128, kernel_size=(1,), stride=(1,), bias=False)
    )
    (b0-l1.wavenet_layer-d2): WaveNetLayer (
      (dil_conv): Conv1d(128, 128, kernel_size=(2,), stride=(1,), bias=False)
      (onexone_conv): Conv1d(128, 128, kernel_size=(1,), stride=(1,), bias=False)
    )
    (b0-l2.wavenet_layer-d4): WaveNetLayer (
      (dil_conv): Conv1d(128, 128, kernel_size=(2,), stride=(1,), bias=False)
      (onexone_conv): Conv1d(128, 128, kernel_size=(1,), stride=(1,), bias=False)
    )
    (b0-l3.wavenet_layer-d8): WaveNetLayer (
      (dil_conv): Conv1d(128, 128, kernel_size=(2,), stride=(1,), bias=False)
      (onexone_conv): Conv1d(128, 128, kernel_size=(1,), stride=(1,), bias=False)
    )
    (b0-l4.wavenet_layer-d16): WaveNetLayer (
      (dil_conv): Conv1d(128, 128, kernel_size=(2,), stride=(1,), 

In [3]:
# Request a minibatch for the index list [model.scope], keep its first
# element and drop singleton dimensions.
# NOTE(review): assumes element 0 of the minibatch is the input tensor —
# confirm against WaveNetData.get_minibatch.
start_tensor = data.get_minibatch([model.scope])[0].squeeze()

# Draw on a dedicated figure/axes instead of the implicit "current" figure,
# so re-running this cell after training cannot draw into the live loss plot.
# The names deliberately differ from the notebook-level `fig`/`ax` that the
# training hook reads.
plt.ion()
start_fig, start_ax = plt.subplots()
start_ax.plot(start_tensor[-200:].numpy())
start_ax.set_title("Last 200 values of start_tensor")
start_ax.set_xlabel("position within the slice")
start_ax.set_ylabel("value")
plt.ioff()

<IPython.core.display.Javascript object>

In [9]:
def hook(losses):
    """Redraw the live plot from the values in ``losses``.

    The function reads the notebook-level names ``fig`` and ``ax`` at call
    time; they are not defined in this cell, so both must exist before the
    hook is first invoked.
    """
    ax.clear()          # drop the previous curve so only the latest one shows
    ax.plot(losses)
    fig.canvas.draw()   # push the updated axes to the displayed figure


# Training settings kept together so they can be adjusted in one place.
optimizer_settings = dict(
    learning_rate=0.01,
    stop_threshold=0.1,
    avg_length=4,
    mini_batch_size=2,
)
optimizer = Optimizer(model, **optimizer_settings)

# Attach the plotting callback.
# NOTE(review): presumably Optimizer calls `hook` with the loss history
# during training — confirm in the model module.
optimizer.hook = hook

In [11]:
# Create the figure and single axes that `hook` (defined in an earlier cell)
# looks up by the notebook-level names `fig` and `ax`.
fig, ax = plt.subplots()
plt.ion()

# Show the empty figure up front so the hook has a live canvas to redraw.
fig.show()
fig.canvas.draw()

print('start training...')
# perf_counter() is a monotonic, high-resolution clock, so the measured
# duration is not affected by system clock adjustments during a long run.
tic = time.perf_counter()
optimizer.train(data, epochs=100)
toc = time.perf_counter()
print('Training took {} seconds.'.format(toc-tic))

<IPython.core.display.Javascript object>

start training...
epoch  0
epoch  1
epoch  2
epoch  3
epoch  4
epoch  5
epoch  6
epoch  7
epoch  8
epoch  9
epoch  10
epoch  11
epoch  12
epoch  13
epoch  14
epoch  15
epoch  16
epoch  17
epoch  18
epoch  19
epoch  20
epoch  21
epoch  22
epoch  23
epoch  24
epoch  25
epoch  26
epoch  27
epoch  28
epoch  29
epoch  30
epoch  31
epoch  32
epoch  33
epoch  34
epoch  35
epoch  36
epoch  37
epoch  38
epoch  39
epoch  40
epoch  41
epoch  42
epoch  43
epoch  44
epoch  45
epoch  46
epoch  47
epoch  48
epoch  49
epoch  50
epoch  51
epoch  52
epoch  53
epoch  54
epoch  55
epoch  56
epoch  57
epoch  58
epoch  59
epoch  60
epoch  61
epoch  62
epoch  63
epoch  64
epoch  65
epoch  66
epoch  67
epoch  68
epoch  69
epoch  70
epoch  71
epoch  72
epoch  73
epoch  74
epoch  75
epoch  76
epoch  77
epoch  78
epoch  79
epoch  80
epoch  81
epoch  82
epoch  83
epoch  84
epoch  85
epoch  86
epoch  87
epoch  88
epoch  89
epoch  90
epoch  91
epoch  92
epoch  93
epoch  94
epoch  95
epoch  96
epoch  97
epoch  98
ep

In [13]:
# NOTE(review): the file name says "saber_11-2-256-128", while this notebook
# trains on 'train_samples/violin.wav' with num_layers=10, num_blocks=2,
# num_classes=128, hidden_channels=64 — confirm the name is intentional.
checkpoint_path = "model_parameters/saber_11-2-256-128"

# Make sure the target directory exists so a finished training run is not
# lost to a missing-directory error at save time.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# Persist only the parameters (state dict), not the whole model object.
torch.save(model.state_dict(), checkpoint_path)