In [1]:
import time
from wavenet_model import *
from audio_data import WavenetDataset
from wavenet_training import *
from model_logging import *
#from optimizers import SGDNormalized
from scipy.io import wavfile

# Pick the tensor types used throughout the notebook: CUDA-backed types when
# torch reports an available GPU, CPU types otherwise.
use_cuda = torch.cuda.is_available()
if use_cuda:
    print('use gpu')
    dtype, ltype = torch.cuda.FloatTensor, torch.cuda.LongTensor
else:
    dtype, ltype = torch.FloatTensor, torch.LongTensor

In [2]:
# Build a WaveNet with this experiment's architecture.
# NOTE(review): `model` is rebound to a loaded snapshot right below, so this
# freshly constructed instance is discarded — confirm that is intended.
model = WaveNetModel(
    layers=6,
    blocks=4,
    dilation_channels=16,
    residual_channels=16,
    skip_channels=32,
    output_length=8,
    dtype=dtype,
    bias=False,
)

# Replace the fresh model with one loaded from the 'snapshots' directory.
model = load_latest_model_from('snapshots', use_cuda=use_cuda)
#model = torch.load('snapshots/saber_model_2017-12-18_20-47-36', map_location=lambda storage, loc: storage)
model.dtype = dtype

# Keep the model on the same device family as `dtype`.
if not use_cuda:
    model.cpu()
else:
    model.cuda()

# Summarise what was loaded.
print('model: ', model)
print('receptive field: ', model.receptive_field)
print('parameter count: ', model.parameter_count())

relu network
load model snapshots/saber_model_2017-12-20_19-43-52
model:  WaveNetModel(
  (main_convs): ModuleList(
    (0): Conv1d (32, 32, kernel_size=(2,), stride=(1,))
    (1): Conv1d (32, 32, kernel_size=(2,), stride=(1,))
    (2): Conv1d (32, 32, kernel_size=(2,), stride=(1,))
    (3): Conv1d (32, 32, kernel_size=(2,), stride=(1,))
    (4): Conv1d (32, 32, kernel_size=(2,), stride=(1,))
    (5): Conv1d (32, 32, kernel_size=(2,), stride=(1,))
    (6): Conv1d (32, 32, kernel_size=(2,), stride=(1,))
    (7): Conv1d (32, 32, kernel_size=(2,), stride=(1,))
    (8): Conv1d (32, 32, kernel_size=(2,), stride=(1,))
    (9): Conv1d (32, 32, kernel_size=(2,), stride=(1,))
    (10): Conv1d (32, 32, kernel_size=(2,), stride=(1,))
    (11): Conv1d (32, 32, kernel_size=(2,), stride=(1,))
    (12): Conv1d (32, 32, kernel_size=(2,), stride=(1,))
    (13): Conv1d (32, 32, kernel_size=(2,), stride=(1,))
    (14): Conv1d (32, 32, kernel_size=(2,), stride=(1,))
    (15): Conv1d (32, 32, kernel_size=(



In [3]:
# Each item spans one receptive field plus `output_length` targets, minus one
# because the two windows overlap by a sample.
data = WavenetDataset(
    dataset_file='train_samples/saber/dataset.npz',
    item_length=model.receptive_field + model.output_length - 1,
    target_length=model.output_length,
    file_location='train_samples/saber',
    test_stride=20,
)
print('the dataset has {} items'.format(len(data)))

one hot input
the dataset has 22095 items


In [None]:
def generate_and_log_samples(step):
    """Generate two short audio clips and send them to the logger.

    A model is loaded from the 'snapshots' directory, then one clip of 4000
    samples is generated per temperature (0 and 0.5). Each clip is converted
    to a float32 TensorFlow tensor and passed to `logger.audio_summary`.

    Relies on the module-level names `logger`, `tf`, `generate_audio` and
    `load_latest_model_from` being defined by the time it is called.

    Args:
        step: Value forwarded to `logger.audio_summary` as the step of the
            logged clips.
    """
    sample_length = 4000
    sample_rate = 16000
    # (summary tag, sampling temperature) for every clip to produce, in order.
    clip_specs = (
        ('temperature 0', 0),
        ('temperature 0.5', 0.5),
    )

    gen_model = load_latest_model_from('snapshots')
    print("start generating...")
    for tag, temperature in clip_specs:
        samples = generate_audio(gen_model,
                                 length=sample_length,
                                 temperatures=[temperature])
        tf_samples = tf.convert_to_tensor(samples, dtype=tf.float32)
        logger.audio_summary(tag, tf_samples, step, sr=sample_rate)
    print("audio clips generated")

In [None]:
# Logger used during training; `generate_and_log_samples` is registered as the
# callback that produces audio clips.
# NOTE(review): the three intervals are presumably counted in training steps — confirm in model_logging.
logger = TensorboardLogger(
    log_interval=200,
    validation_interval=200,
    generate_interval=500,
    generate_function=generate_and_log_samples,
    log_dir="logs",
)

In [None]:
# Configure the trainer; snapshots are written under 'snapshots' with the
# 'saber_model' name prefix.
# NOTE(review): `logger` is not passed to the trainer here — confirm how WavenetTrainer obtains it.
trainer = WavenetTrainer(
    model=model,
    dataset=data,
    lr=0.001,
    weight_decay=0.0,
    gradient_clipping=None,
    snapshot_path='snapshots',
    snapshot_name='saber_model',
    snapshot_interval=100000,
)

# Run training and report the wall-clock time it took.
print('start training...')
tic = time.time()
trainer.train(batch_size=8, epochs=20)
toc = time.time()
print('Training took {} seconds.'.format(toc - tic))

In [None]:
# Bare attribute access: it is not the last expression in the cell, so Jupyter
# displays nothing for it.
data.start_samples
# Clear the `train` flag on the dataset and on the dataset object held by the
# trainer's dataloader.
# NOTE(review): presumably this switches the dataset to its test split — confirm in WavenetDataset.
data.train = False
trainer.dataloader.dataset.train = False

In [None]:
# Report the length of the trainer's dataloader, the length of the dataset,
# and the dataset's private `_length` attribute for comparison.
print("dataloader length: ", len(trainer.dataloader))
print("test length:", len(data))
print("sample length:", data._length)

In [None]:
# Re-apply the notebook-level tensor type to the model and show it.
model.dtype = dtype
print(model.dtype)

In [12]:
# Take the first element of dataset item 100 as the seed for generation.
start_data = data[100][0]
# torch.max along dim 0 returns (values, indices); keep only the indices,
# i.e. the argmax over the first dimension.
# NOTE(review): presumably this turns a one-hot encoding into class indices — confirm.
start_data = torch.max(start_data, 0)[1]
print(start_data)


 217
 207
 197
 ⋮  
   0
   0
   1
[torch.LongTensor of size 4108]



In [13]:
def prog_callback(step, total_steps):
    """Print generation progress as a whole-number percentage.

    Args:
        step: Number of steps completed so far.
        total_steps: Total number of steps to run.
    """
    percent_done = 100 * step // total_steps  # floor division: whole percent
    print(str(percent_done) + "% generated")
# Give every dilated queue on the model the notebook-level tensor type before generating.
for q in model.dilated_queues:
    q.dtype = dtype
    
# Generate 160000 samples seeded with `start_data`; `prog_callback` is passed
# as the progress callback with an interval of 1000.
generated1 = model.generate_fast(num_samples=160000, 
                                 first_samples=start_data,
                                 progress_callback=prog_callback,
                                 progress_interval=1000,
                                 temperature=1.0)

0% generated
0% generated
1% generated
1% generated
2% generated
one generating step does take approximately 0.008517191410064698 seconds)
3% generated
3% generated
4% generated
4% generated
5% generated
6% generated
6% generated
7% generated
7% generated
8% generated
9% generated
9% generated
10% generated
10% generated
11% generated
12% generated
12% generated
13% generated
14% generated
14% generated
15% generated
15% generated
16% generated
17% generated
17% generated
18% generated
18% generated
19% generated
20% generated
20% generated
21% generated
21% generated
22% generated
23% generated
23% generated
24% generated
24% generated
25% generated
26% generated
26% generated
27% generated
28% generated
28% generated
29% generated
29% generated
30% generated
31% generated
31% generated
32% generated
32% generated
33% generated
34% generated
34% generated
35% generated
35% generated
36% generated
37% generated
37% generated
38% generated
38% generated
39% generated
40% generated
40% g

In [14]:
import IPython.display as ipd

# Render an inline audio player for the generated samples at a 16000 Hz rate.
ipd.Audio(generated1, rate=16000)

In [None]:
%matplotlib inline
from matplotlib import pyplot as plt
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(generated1); ax1.set_title('Raw audio signal')
ax2.specgram(generated1); ax2.set_title('Spectrogram');

In [None]:
# Signal and spectrogram of the seed sequence (`start_data`) that primed generation.
# No cell defines a `start_sample`, so the seed tensor is plotted instead; it is
# converted to a float NumPy array so matplotlib receives plain array data.
# NOTE(review): `start_data` holds argmax indices rather than decoded amplitudes, so the
# 'Raw audio signal' panel shows the index sequence — confirm this is the intended view.
seed_signal = start_data.cpu().numpy().astype('float32')
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(seed_signal); ax1.set_title('Raw audio signal')
ax2.specgram(seed_signal); ax2.set_title('Spectrogram');

In [None]:
start training...
epoch 0
loss at step 50: 5.601520707905292
one training step does take approximately 0.19061788082122802 seconds)
loss at step 100: 5.552221384048462
loss at step 150: 5.544267034530639
loss at step 200: 5.531772727966309
validation loss: 5.490092188517252
validation accuracy: 1.0451505016722409%
loss at step 250: 5.524168214797974
loss at step 300: 5.501959915161133
loss at step 350: 5.440135269165039
loss at step 400: 5.332433271408081
validation loss: 4.427264089584351
validation accuracy: 24.91638795986622%
loss at step 450: 5.3144251728057865
loss at step 500: 5.112221040725708
loss at step 550: 5.0085866355896
loss at step 600: 4.963179755210876
validation loss: 3.9374835840861
validation accuracy: 25.22993311036789%
loss at step 650: 4.930944819450378
loss at step 700: 4.881737098693848
loss at step 750: 4.742912063598633
loss at step 800: 4.686616892814636
validation loss: 3.8073496564229328
validation accuracy: 25.68979933110368%
loss at step 850: 4.577959842681885
loss at step 900: 4.399262466430664
loss at step 950: 4.375202603340149
loss at step 1000: 4.3079585933685305
validation loss: 3.5753333044052122
validation accuracy: 24.498327759197323%
loss at step 1050: 4.244945120811463
loss at step 1100: 4.123299965858459
loss at step 1150: 4.103064022064209
loss at step 1200: 4.082510600090027
validation loss: 3.413504378000895
validation accuracy: 26.00334448160535%
loss at step 1250: 3.939071798324585
loss at step 1300: 3.9508083343505858
loss at step 1350: 3.8663349866867067
loss at step 1400: 3.8707763385772704
validation loss: 3.2716021649042766
validation accuracy: 25.020903010033447%
epoch 1
loss at step 1450: 3.7944415807724
loss at step 1500: 3.82066180229187
loss at step 1550: 3.8355930709838866
loss at step 1600: 3.7929911947250368
validation loss: 3.1106809441248577
validation accuracy: 27.38294314381271%
loss at step 1650: 3.761087512969971
loss at step 1700: 3.7161417627334594
loss at step 1750: 3.68661922454834
loss at step 1800: 3.5772906827926634
validation loss: 2.9680276489257813
validation accuracy: 28.38628762541806%
loss at step 1850: 3.653769178390503
loss at step 1900: 3.8210517024993895
loss at step 1950: 3.4200775527954104
loss at step 2000: 3.5994531393051146
validation loss: 2.997499696413676
validation accuracy: 28.365384615384613%
loss at step 2050: 3.5013914012908938
loss at step 2100: 3.3859068155288696
loss at step 2150: 3.4870605945587156
loss at step 2200: 3.382463240623474
validation loss: 2.9953096040089924
validation accuracy: 28.010033444816052%
loss at step 2250: 3.2740977144241334
loss at step 2300: 3.3375968599319457
loss at step 2350: 3.33543728351593
loss at step 2400: 3.311717290878296
validation loss: 2.741686725616455
validation accuracy: 29.15969899665552%
loss at step 2450: 3.3888323879241944
loss at step 2500: 3.2774668455123903
loss at step 2550: 3.2909540367126464
loss at step 2600: 3.156819558143616
validation loss: 2.644340982437134
validation accuracy: 29.38963210702341%
loss at step 2650: 3.1362243604660036
loss at step 2700: 3.1809526824951173
loss at step 2750: 3.1044933462142943
loss at step 2800: 3.2104168224334715
validation loss: 2.710980224609375
validation accuracy: 28.511705685618725%
epoch 2
loss at step 2850: 3.1645427131652832
loss at step 2900: 3.086708178520203
loss at step 2950: 3.1935667037963866
loss at step 3000: 3.065649948120117
validation loss: 2.599242707888285
validation accuracy: 29.95401337792642%
loss at step 3050: 2.9623973870277407
loss at step 3100: 2.977948703765869
loss at step 3150: 3.039284749031067
loss at step 3200: 3.1032708168029783
validation loss: 2.5787479861577354
validation accuracy: 30.0376254180602%
loss at step 3250: 3.020424065589905
loss at step 3300: 2.9368478298187255
loss at step 3350: 3.011261811256409
loss at step 3400: 2.936244683265686
validation loss: 2.510010568300883
validation accuracy: 30.56020066889632%
loss at step 3450: 2.92849506855011
loss at step 3500: 2.903533215522766
loss at step 3550: 2.835393509864807
loss at step 3600: 2.875207557678223
validation loss: 2.5806426111857097
validation accuracy: 29.995819397993312%
loss at step 3650: 2.982465934753418
loss at step 3700: 2.8224086570739746
loss at step 3750: 2.773958697319031
loss at step 3800: 2.933848671913147
validation loss: 2.429751847585042
validation accuracy: 31.47993311036789%
loss at step 3850: 2.935438051223755
loss at step 3900: 2.8551607513427735
loss at step 3950: 2.7788655376434326
loss at step 4000: 2.7510599946975707
validation loss: 2.3318386379877727
validation accuracy: 31.25%
loss at step 4050: 2.7630084943771362
loss at step 4100: 2.784786548614502
loss at step 4150: 2.823610978126526
loss at step 4200: 2.74433349609375
validation loss: 2.3619025961558022
validation accuracy: 31.08277591973244%
loss at step 4250: 2.7720167875289916
epoch 3
loss at step 4300: 2.722008581161499
loss at step 4350: 2.683127827644348
loss at step 4400: 2.7036391639709474
validation loss: 2.3295965019861855
validation accuracy: 31.709866220735787%
loss at step 4450: 2.5949549078941345
loss at step 4500: 2.6527379083633424
loss at step 4550: 2.6835867977142334
loss at step 4600: 2.6377884101867677
validation loss: 2.428244962692261
validation accuracy: 31.438127090301005%
loss at step 4650: 2.682296323776245
loss at step 4700: 2.6830776596069335
loss at step 4750: 2.7608815002441407
loss at step 4800: 2.5994027352333067
validation loss: 2.2691842166582745
validation accuracy: 32.002508361204015%
loss at step 4850: 2.6003666806221006
loss at step 4900: 2.7449550104141234
loss at step 4950: 2.6577998113632204
loss at step 5000: 2.593499083518982
validation loss: 2.2588222297032674
validation accuracy: 31.960702341137125%
loss at step 5050: 2.6504480028152466
loss at step 5100: 2.692755765914917
loss at step 5150: 2.646983962059021
loss at step 5200: 2.5553077936172484
validation loss: 2.235258067448934
validation accuracy: 33.61204013377927%
loss at step 5250: 2.5953399658203127
loss at step 5300: 2.77093816280365
loss at step 5350: 2.628749816417694
loss at step 5400: 2.558472900390625
validation loss: 2.253657941818237
validation accuracy: 33.27759197324415%
loss at step 5450: 2.6695416879653933
loss at step 5500: 2.6403193950653074
loss at step 5550: 2.6906979990005495
loss at step 5600: 2.632576594352722
validation loss: 2.2271932284037272
validation accuracy: 32.29515050167224%
loss at step 5650: 2.6107604622840883
epoch 4
loss at step 5700: 2.582443132400513
loss at step 5750: 2.6650914669036867
loss at step 5800: 2.8158610439300538
validation loss: 2.228590728441874
validation accuracy: 33.34030100334448%
loss at step 5850: 2.6931549406051634
loss at step 5900: 2.651780562400818
loss at step 5950: 2.750603561401367
loss at step 6000: 2.722158169746399
validation loss: 3.0212835629781085
validation accuracy: 30.748327759197323%
loss at step 6050: 2.6879207038879396
loss at step 6100: 2.6904709482192994
loss at step 6150: 2.6776280212402344
loss at step 6200: 2.7096633672714234
validation loss: 2.599953867594401
validation accuracy: 32.98494983277592%
loss at step 6250: 2.634800329208374
loss at step 6300: 2.6009347152709963
loss at step 6350: 2.628697416782379
loss at step 6400: 2.7100318574905398
validation loss: 2.356964473724365
validation accuracy: 33.570234113712374%
loss at step 6450: 2.782756748199463
loss at step 6500: 2.7148419046401977
loss at step 6550: 2.64122682094574