This notebook gives a quick overview of this WaveNet implementation, i.e. creating the model and the data set, training the model and generating samples from it.

In [1]:
from model_logging import *
import torch
from wavenet_model import *
from audio_data import WavenetDataset
from wavenet_training import *
#from model_logging import *

## Model
This is an implementation of WaveNet as it was described in the original paper (https://arxiv.org/abs/1609.03499). Each layer looks like this:

```
            |----------------------------------------|      *residual*
            |                                        |
            |    |-- conv -- tanh --|                |
 -> dilate -|----|                  * ----|-- 1x1 -- + -->  *input*
                 |-- conv -- sigm --|     |
                                         1x1
                                          |
 ---------------------------------------> + ------------->  *skip*
```
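The data flow in the diagram can be sketched for a single channel in plain Python. This is only a toy illustration, not the code used by ``WaveNetModel``: the function name ``gated_layer`` and its scalar weights are hypothetical stand-ins for the real multi-channel convolutions.

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def gated_layer(x, dilation, w_filter, w_gate, w_res, w_skip):
    """Single-channel toy version of one layer from the diagram (hypothetical helper).

    x             : list of floats, the layer input
    w_filter      : (w0, w1) weights of the kernel-size-2 'tanh' convolution
    w_gate        : (w0, w1) weights of the kernel-size-2 'sigm' convolution
    w_res, w_skip : scalars standing in for the two 1x1 convolutions
    """
    residual_out, skip_out = [], []
    for t in range(dilation, len(x)):
        past, now = x[t - dilation], x[t]       # the two dilated, causal taps
        f = math.tanh(w_filter[0] * past + w_filter[1] * now)
        g = sigmoid(w_gate[0] * past + w_gate[1] * now)
        z = f * g                               # gated activation
        residual_out.append(now + w_res * z)    # 1x1, then add the layer input
        skip_out.append(w_skip * z)             # this layer's contribution to the skip sum
    return residual_out, skip_out

res, skip = gated_layer([0.1, -0.2, 0.3, 0.0, 0.5], dilation=2,
                        w_filter=(0.5, 0.5), w_gate=(0.1, 0.2),
                        w_res=1.0, w_skip=1.0)
print(len(res), len(skip))
```

The residual output becomes the input of the next layer, while the skip contributions of all layers are summed before the final 1x1 convolutions.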

Each layer dilates its input by a factor of two relative to the previous layer. After each block the dilation is reset and starts again from one. You can define the number of layers in each block (``layers``) and the number of blocks (``blocks``). The blocks are followed by two 1x1 convolutions and a softmax output function.
Because of the dilation operation, the independent outputs for multiple successive samples can be calculated efficiently. With ``output_length``, you can define the number of these outputs. Empirically, it seems that a large number of skip channels is required.
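As a rough check of how ``layers`` and ``blocks`` determine the receptive field, the dilation schedule can be written out in plain Python. This sketch assumes the usual formula for stacked dilated convolutions with kernel size 2 (the kernel size shown in the printed model); the helper names are hypothetical, and ``model.receptive_field`` remains the authoritative value.

```python
def dilation_schedule(layers, blocks):
    """Dilation of every layer: doubles within a block, resets to 1 for each new block."""
    return [2 ** i for _ in range(blocks) for i in range(layers)]

def receptive_field(layers, blocks, kernel_size=2):
    """Number of input samples that can influence a single output sample."""
    return 1 + sum((kernel_size - 1) * d for d in dilation_schedule(layers, blocks))

print(dilation_schedule(layers=3, blocks=2))   # [1, 2, 4, 1, 2, 4]
print(receptive_field(layers=10, blocks=3))    # 3070
```

Under these assumptions, adding a block increases the receptive field linearly, while adding a layer to every block roughly doubles it.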

In [2]:
# initialize cuda option
dtype = torch.FloatTensor # data type
ltype = torch.LongTensor # label type

use_cuda = torch.cuda.is_available()
#use_cuda = False
if use_cuda:
    print('use gpu')
    dtype = torch.cuda.FloatTensor
    ltype = torch.cuda.LongTensor

use gpu


In [3]:
model = WaveNetModel(layers=10,
                     blocks=3,
                     dilation_channels=32,
                     residual_channels=32,
                     skip_channels=1024,
                     end_channels=512, 
                     output_length=16,
                     dtype=dtype, 
                     bias=True)
# model = load_latest_model_from('snapshots', use_cuda=use_cuda)

if use_cuda:
    model.cuda()  # move the parameters to the GPU only if one is available
print('model: ', model)
print('receptive field: ', model.receptive_field)
print('parameter count: ', model.parameter_count())

model:  WaveNetModel (
  (filter_convs): ModuleList (
    (0): Conv1d(32, 32, kernel_size=(2,), stride=(1,))
    (1): Conv1d(32, 32, kernel_size=(2,), stride=(1,))
    (2): Conv1d(32, 32, kernel_size=(2,), stride=(1,))
    (3): Conv1d(32, 32, kernel_size=(2,), stride=(1,))
    (4): Conv1d(32, 32, kernel_size=(2,), stride=(1,))
    (5): Conv1d(32, 32, kernel_size=(2,), stride=(1,))
    (6): Conv1d(32, 32, kernel_size=(2,), stride=(1,))
    (7): Conv1d(32, 32, kernel_size=(2,), stride=(1,))
    (8): Conv1d(32, 32, kernel_size=(2,), stride=(1,))
    (9): Conv1d(32, 32, kernel_size=(2,), stride=(1,))
    (10): Conv1d(32, 32, kernel_size=(2,), stride=(1,))
    (11): Conv1d(32, 32, kernel_size=(2,), stride=(1,))
    (12): Conv1d(32, 32, kernel_size=(2,), stride=(1,))
    (13): Conv1d(32, 32, kernel_size=(2,), stride=(1,))
    (14): Conv1d(32, 32, kernel_size=(2,), stride=(1,))
    (15): Conv1d(32, 32, kernel_size=(2,), stride=(1,))
    (16): Conv1d(32, 32, kernel_size=(2,), stride=(1,))
    

## Data Set
To create the data set, you have to specify a path to a data set file. If this file already exists, it will be used; if not, it will be generated. To generate the data set file (a ``.npz`` file), you have to specify the directory (``file_location``) that contains all the audio files you want to use. The attribute ``target_length`` specifies the number of successive samples that are used as a target and corresponds to the output length of the model. The ``item_length`` defines the number of samples in each item of the data set and should always be ``model.receptive_field + model.output_length - 1``.

```
          |----receptive_field----|
                                |--output_length--|
example:  | | | | | | | | | | | | | | | | | | | | |
target:                           | | | | | | | | | |  
```
To create a test set, you should define a ``test_stride``. Then each ``test_stride``th item will be assigned to the test set.
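The relation between ``item_length``, the target and the test split can be illustrated with a small sketch. The helper names are hypothetical, and two details are assumptions rather than facts about ``WavenetDataset``: that an item is cut from ``item_length + 1`` consecutive samples, and that the test items are those at positions ``test_stride``, ``2 * test_stride``, ... (counting from one).

```python
def item_length(receptive_field, output_length):
    """Input samples needed to produce output_length successive outputs."""
    return receptive_field + output_length - 1

def split_item(chunk, target_length):
    """Split item_length + 1 consecutive samples into (example, target)."""
    example = chunk[:-1]              # everything except the newest sample
    target = chunk[-target_length:]   # the next sample for each output position
    return example, target

def train_test_indices(num_items, test_stride):
    """Assumed convention: every test_stride-th item (1-based) is a test item."""
    test = [i for i in range(num_items) if (i + 1) % test_stride == 0]
    train = [i for i in range(num_items) if (i + 1) % test_stride != 0]
    return train, test

# Tiny example: receptive field 5, output length 3
n = item_length(5, 3)                                    # 7
example, target = split_item(list(range(n + 1)), 3)
print(example, target)                                   # [0, 1, 2, 3, 4, 5, 6] [5, 6, 7]
print(train_test_indices(10, 5)[1])                      # [4, 9]
```

In the tiny example, the first output sees samples 0 to 4 and should predict sample 5, the second sees 1 to 5 and should predict 6, and so on.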

In [4]:
data = WavenetDataset(dataset_file='train_samples/bach_chaconne/dataset.npz',
                      item_length=model.receptive_field + model.output_length - 1,
                      target_length=model.output_length,
                      file_location='train_samples/bach_chaconne',
                      test_stride=500)
print('the dataset has ' + str(len(data)) + ' items')

one hot input
the dataset has 598277 items


## Training and Logging
This implementation supports logging with TensorBoard (you need to have TensorFlow installed). You can even generate audio samples from the current snapshot of the model during training. This happens in a background thread on the CPU, so it does not interfere with the actual training, but it is rather slow. If you don't have TensorFlow, you can use the standard logger, which prints to the console.
The trainer uses Adam as its default optimizer.
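The generation function in the next cell passes a ``temperatures`` argument to ``generate_audio``. A common way to apply a temperature is to divide the logits by it before the softmax. Whether ``generate_audio`` does exactly this is an assumption, and the helper names below are hypothetical.

```python
import math
import random

def softmax_with_temperature(logits, temperature):
    """Probabilities of softmax(logits / temperature), computed in a numerically stable way."""
    scaled = [v / temperature for v in logits]
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def sample_with_temperature(logits, temperature, rng=random):
    """Draw one class index from the temperature-scaled distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]

logits = [1.0, 2.0, 3.0]
print(softmax_with_temperature(logits, 1.0))
print(softmax_with_temperature(logits, 0.5))   # sharper: more mass on the largest logit
print(sample_with_temperature(logits, 0.5))
```

Temperatures below one concentrate the distribution on the most likely values, which tends to give less noisy but less varied output; a temperature of one samples from the unmodified model distribution.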

In [9]:
def generate_and_log_samples(step):
    sample_length=32000
    gen_model = load_latest_model_from('snapshots', use_cuda=False)
    print("start generating...")
    samples = generate_audio(gen_model,
                             length=sample_length,
                             temperatures=[0.5])
    tf_samples = tf.convert_to_tensor(samples, dtype=tf.float32)
    logger.audio_summary('temperature_0.5', tf_samples, step, sr=16000)

    samples = generate_audio(gen_model,
                             length=sample_length,
                             temperatures=[1.])
    tf_samples = tf.convert_to_tensor(samples, dtype=tf.float32)
    logger.audio_summary('temperature_1.0', tf_samples, step, sr=16000)
    print("audio clips generated")


# logger = TensorboardLogger(log_interval=200,
#                            validation_interval=400,
#                            generate_interval=1000,
#                            generate_function=generate_and_log_samples,
#                            log_dir="logs/chaconne_model")

logger = Logger(log_interval=20,
                validation_interval=400,
                generate_interval=1000)

In [None]:
trainer = WavenetTrainer(model=model,
                         dataset=data,
                         lr=0.001,
                         snapshot_path='snapshots',
                         snapshot_name='chaconne_model',
                         snapshot_interval=1000,
                         logger=logger,
                         dtype=dtype,
                         ltype=ltype)

print('start training...')
trainer.train(batch_size=8,
              epochs=10)

start training...
epoch 0
loss at step 20: 5.162780165672302
loss at step 40: 5.089954590797424
loss at step 60: 5.054309391975403
loss at step 80: 4.960115432739258
one training step does take approximately 0.5236522936820984 seconds)
loss at step 100: 4.897372150421143
loss at step 120: 4.790813255310058
loss at step 140: 4.636266350746155
loss at step 160: 4.484929537773132
loss at step 180: 4.268103325366974
loss at step 200: 4.15226628780365
loss at step 220: 4.086933124065399
loss at step 240: 4.068137288093567
loss at step 260: 4.022809708118439
loss at step 280: 4.026028096675873
loss at step 300: 3.888850724697113
loss at step 320: 3.8549909591674805
loss at step 340: 3.834302806854248
loss at step 360: 3.79020094871521
loss at step 380: 3.7681735157966614
loss at step 400: 3.810551917552948
validation loss: 4.034210497538249
validation accuracy: 6.787353923205343%
loss at step 420: 3.756807065010071
loss at step 440: 3.706703209877014
loss at step 460: 3.831369996070862

validation loss: 3.4052259667714435
validation accuracy: 9.24457429048414%
loss at step 4020: 3.1789023518562316
loss at step 4040: 3.2641913294792175
loss at step 4060: 3.177927255630493
loss at step 4080: 3.2050485610961914
loss at step 4100: 3.2564444184303283
loss at step 4120: 3.162487733364105
loss at step 4140: 3.192570388317108
loss at step 4160: 3.2074580430984496
loss at step 4180: 3.2347574710845945
loss at step 4200: 3.1917534470558167
loss at step 4220: 3.2261887431144713
loss at step 4240: 3.1985196232795716
loss at step 4260: 3.2420174837112428
loss at step 4280: 3.220473349094391
loss at step 4300: 3.1718783020973205
loss at step 4320: 3.1213488936424256
loss at step 4340: 3.2046050906181334
loss at step 4360: 3.2713900446891784
loss at step 4380: 3.225289046764374
loss at step 4400: 3.1465413928031922
validation loss: 3.3603880818684897
validation accuracy: 9.307178631051753%
loss at step 4420: 3.1918266654014587
loss at step 4440: 3.1917948484420777

loss at step 8000: 3.114830756187439
validation loss: 3.235477124849955
validation accuracy: 10.079298831385643%
loss at step 8020: 3.128624749183655
loss at step 8040: 3.111064302921295
loss at step 8060: 3.221082878112793
loss at step 8080: 3.1077369809150697
loss at step 8100: 3.012094271183014
loss at step 8120: 3.03189754486084
loss at step 8140: 3.1080143094062804
loss at step 8160: 3.0648762583732605
loss at step 8180: 3.0929367184638976
loss at step 8200: 3.075855779647827
loss at step 8220: 3.095041048526764
loss at step 8240: 3.0323700189590452
loss at step 8260: 3.1083167552948
loss at step 8280: 3.100862979888916
loss at step 8300: 3.156915545463562
loss at step 8320: 3.057618248462677
loss at step 8340: 3.0737389087677003
loss at step 8360: 3.158937060832977
loss at step 8380: 3.063744866847992
loss at step 8400: 3.063803267478943
validation loss: 3.210866157213847
validation accuracy: 10.929674457429048%
loss at step 8420: 3.16995393037796
loss at step 8440: 3.14454472064

loss at step 11920: 3.0050369024276735
loss at step 11940: 3.0233883261680603
loss at step 11960: 3.042835235595703
loss at step 11980: 3.0677093982696535
loss at step 12000: 3.0230388879776
validation loss: 3.129176624615987
validation accuracy: 11.675709515859767%
loss at step 12020: 3.055820894241333
loss at step 12040: 3.0321760416030883
loss at step 12060: 3.080505061149597
loss at step 12080: 2.967938709259033
loss at step 12100: 3.0617630958557127
loss at step 12120: 3.060823452472687
loss at step 12140: 3.043993330001831
loss at step 12160: 2.953952968120575
loss at step 12180: 3.0513290762901306
loss at step 12200: 3.071247625350952
loss at step 12220: 2.9987076997756956
loss at step 12240: 3.024428296089172
loss at step 12260: 3.0340832114219665
loss at step 12280: 3.039249229431152
loss at step 12300: 3.025560128688812
loss at step 12320: 2.966801941394806
loss at step 12340: 3.074874687194824
loss at step 12360: 2.9879834771156313
loss at step 12380: 3.0438493847846986

loss at step 15780: 3.0605921387672423
loss at step 15800: 2.914723575115204
loss at step 15820: 2.9062338352203367
loss at step 15840: 3.02288738489151
loss at step 15860: 3.0138417363166807
loss at step 15880: 3.025317406654358
loss at step 15900: 3.053657019138336
loss at step 15920: 2.9703461170196532
loss at step 15940: 2.924605929851532
loss at step 15960: 2.962989556789398
loss at step 15980: 2.967585599422455
loss at step 16000: 2.9600739002227785
validation loss: 3.0751457357406617
validation accuracy: 12.087854757929883%
loss at step 16020: 2.9348230719566346
loss at step 16040: 2.980185353755951
loss at step 16060: 2.9786244869232177
loss at step 16080: 2.979057991504669
loss at step 16100: 2.991746699810028
loss at step 16120: 2.9444692492485047
loss at step 16140: 2.9237182974815368
loss at step 16160: 2.987958514690399
loss at step 16180: 2.978611671924591
loss at step 16200: 2.98181414604187
loss at step 16220: 2.967973160743713
loss at step 16240: 2.931735026836395

loss at step 19660: 2.8733541488647463
loss at step 19680: 3.0222919821739196
loss at step 19700: 2.969448411464691
loss at step 19720: 2.92838397026062
loss at step 19740: 2.961627697944641
loss at step 19760: 2.9649990439414977
loss at step 19780: 2.9534488201141356
loss at step 19800: 3.004178774356842
loss at step 19820: 2.952987849712372
loss at step 19840: 2.8952455878257752
loss at step 19860: 2.927228546142578
loss at step 19880: 2.8792880892753603
loss at step 19900: 2.9473291993141175
loss at step 19920: 2.9616614103317263
loss at step 19940: 2.8822675824165342
loss at step 19960: 2.9603505730628967
loss at step 19980: 2.8559457778930666
loss at step 20000: 3.005399703979492
validation loss: 3.0161498181025186
validation accuracy: 12.588689482470786%
loss at step 20020: 2.950360369682312
loss at step 20040: 2.921996772289276
loss at step 20060: 2.8846943974494934
loss at step 20080: 2.944383442401886
loss at step 20100: 3.00750390291214
loss at step 20120: 2.8378545522689818


loss at step 23560: 2.8911163449287414
loss at step 23580: 2.934093940258026
loss at step 23600: 2.887026333808899
validation loss: 2.9883971865971883
validation accuracy: 12.93823038397329%
loss at step 23620: 2.939064919948578
loss at step 23640: 3.0214253067970276
loss at step 23660: 2.8852784276008605
loss at step 23680: 2.818788480758667
loss at step 23700: 2.9638047575950623
loss at step 23720: 2.8777089595794676
loss at step 23740: 2.9331770658493044
loss at step 23760: 2.9146593809127808
loss at step 23780: 2.924949753284454
loss at step 23800: 2.856704044342041
loss at step 23820: 2.9359554052352905
loss at step 23840: 2.894316041469574
loss at step 23860: 2.8419658541679382
loss at step 23880: 2.907114124298096
loss at step 23900: 2.804286277294159
loss at step 23920: 2.9262362241744997
loss at step 23940: 2.7953301668167114
loss at step 23960: 2.7992517828941343
loss at step 23980: 2.8759032845497132
loss at step 24000: 2.879504954814911
validation loss: 2.974695094426473

loss at step 27440: 2.745735466480255
loss at step 27460: 2.828411340713501
loss at step 27480: 2.890606963634491
loss at step 27500: 2.8152225971221925
loss at step 27520: 2.892170810699463
loss at step 27540: 2.7564960598945616
loss at step 27560: 2.8761111736297607
loss at step 27580: 2.878416621685028
loss at step 27600: 2.875919210910797
validation loss: 2.9683462317784626
validation accuracy: 12.979966611018364%
loss at step 27620: 2.8221802949905395
loss at step 27640: 2.8912249207496643
loss at step 27660: 2.860300874710083
loss at step 27680: 2.9220999479293823
loss at step 27700: 2.8966709852218626
loss at step 27720: 2.7727614998817445
loss at step 27740: 2.8241966009140014
loss at step 27760: 2.865434455871582
loss at step 27780: 2.798726940155029
loss at step 27800: 2.8246052503585815
loss at step 27820: 2.846036970615387
loss at step 27840: 2.8893505573272704
loss at step 27860: 2.8836658239364623
loss at step 27880: 2.7954781532287596
loss at step 27900: 2.81930708885192

loss at step 31320: 2.7807181239128114
loss at step 31340: 2.81596497297287
loss at step 31360: 2.847038757801056
loss at step 31380: 2.8042336463928224
loss at step 31400: 2.829120969772339
loss at step 31420: 2.8319985270500183
loss at step 31440: 2.8390973567962647
loss at step 31460: 2.816698944568634
loss at step 31480: 2.815987539291382
loss at step 31500: 2.823652780056
loss at step 31520: 2.817155909538269
loss at step 31540: 2.8306984424591066
loss at step 31560: 2.8561968207359314
loss at step 31580: 2.800739860534668
loss at step 31600: 2.8324445605278017
validation loss: 2.9354896847407024
validation accuracy: 13.230383973288815%
loss at step 31620: 2.8472894072532653
loss at step 31640: 2.7653544545173645
loss at step 31660: 2.8218127250671388
loss at step 31680: 2.8751335978507995
loss at step 31700: 2.8020596861839295
loss at step 31720: 2.8508884072303773
loss at step 31740: 2.763071596622467
loss at step 31760: 2.794885849952698
loss at step 31780: 2.8570738434791565

validation loss: 2.918790752092997
validation accuracy: 13.52253756260434%
loss at step 35220: 2.8136191725730897
loss at step 35240: 2.823182666301727
loss at step 35260: 2.8808452129364013
loss at step 35280: 2.843585479259491
loss at step 35300: 2.77946195602417
loss at step 35320: 2.8024762868881226
loss at step 35340: 2.8228626608848573
loss at step 35360: 2.8412224531173704
loss at step 35380: 2.7441245317459106
loss at step 35400: 2.764326107501984
loss at step 35420: 2.8242959499359133
loss at step 35440: 2.83495637178421
loss at step 35460: 2.8213343739509584
loss at step 35480: 2.8211952209472657
loss at step 35500: 2.7946226358413697
loss at step 35520: 2.7875479102134704
loss at step 35540: 2.772036814689636
loss at step 35560: 2.8317281007766724
loss at step 35580: 2.7961588978767393
loss at step 35600: 2.8505170941352844
validation loss: 2.9317983134587604
validation accuracy: 12.865191986644408%
loss at step 35620: 2.766214406490326
loss at step 35640: 2.7851132988929748

loss at step 39100: 2.792754077911377
loss at step 39120: 2.8294769644737245
loss at step 39140: 2.7910500168800354
loss at step 39160: 2.8229071259498597
loss at step 39180: 2.770267367362976
loss at step 39200: 2.7944493889808655
validation loss: 2.894195751349131
validation accuracy: 13.606010016694492%
loss at step 39220: 2.7460877299308777
loss at step 39240: 2.7500099539756775
loss at step 39260: 2.84627822637558
loss at step 39280: 2.776406693458557
loss at step 39300: 2.8199816942214966
loss at step 39320: 2.753622353076935
loss at step 39340: 2.8161106944084167
loss at step 39360: 2.7645013093948365
loss at step 39380: 2.821846973896027
loss at step 39400: 2.8638893485069277
loss at step 39420: 2.7836749792099
loss at step 39440: 2.805334508419037
loss at step 39460: 2.7995888590812683
loss at step 39480: 2.858773875236511
loss at step 39500: 2.7547016620635985
loss at step 39520: 2.7877488493919373
loss at step 39540: 2.7915931820869444
loss at step 39560: 2.8331016659736634


loss at step 42980: 2.8076645851135256
loss at step 43000: 2.7098663091659545
loss at step 43020: 2.7693687319755553
loss at step 43040: 2.7816125750541687
loss at step 43060: 2.7404900550842286
loss at step 43080: 2.7885075330734255
loss at step 43100: 2.686580407619476
loss at step 43120: 2.8048353791236877
loss at step 43140: 2.7701401472091676
loss at step 43160: 2.8033422350883486
loss at step 43180: 2.8665018677711487
loss at step 43200: 2.7691388010978697
validation loss: 2.8685682741800944
validation accuracy: 13.757303839732888%
loss at step 43220: 2.8172378301620484
loss at step 43240: 2.795344150066376
loss at step 43260: 2.792289471626282
loss at step 43280: 2.8147246122360228
loss at step 43300: 2.7371333718299864
loss at step 43320: 2.7172937631607055
loss at step 43340: 2.755329728126526
loss at step 43360: 2.7495023727416994
loss at step 43380: 2.765028154850006
loss at step 43400: 2.826498103141785
loss at step 43420: 2.7623222947120665
loss at step 43440: 2.8129791378

loss at step 46860: 2.7825340747833254
loss at step 46880: 2.8455321907997133
loss at step 46900: 2.7569782376289367
loss at step 46920: 2.7729106545448303
loss at step 46940: 2.807039797306061
loss at step 46960: 2.800514554977417
loss at step 46980: 2.661225140094757
loss at step 47000: 2.7731064796447753
loss at step 47020: 2.757568371295929
loss at step 47040: 2.86082661151886
loss at step 47060: 2.7160537600517274
loss at step 47080: 2.744175112247467
loss at step 47100: 2.7443116784095762
loss at step 47120: 2.758276093006134
loss at step 47140: 2.8226725697517394
loss at step 47160: 2.7191230773925783
loss at step 47180: 2.75035582780838
loss at step 47200: 2.833278167247772
validation loss: 2.876226096153259
validation accuracy: 13.757303839732888%
loss at step 47220: 2.770278573036194
loss at step 47240: 2.6442154169082643
loss at step 47260: 2.7812323331832887
loss at step 47280: 2.8133098602294924
loss at step 47300: 2.7786020398139955
loss at step 47320: 2.7778807520866393


loss at step 50780: 2.773786163330078
loss at step 50800: 2.7333736062049865
validation loss: 2.833524515628815
validation accuracy: 14.284223706176963%
loss at step 50820: 2.7207476258277894
loss at step 50840: 2.7761268734931948
loss at step 50860: 2.7427863836288453
loss at step 50880: 2.8087934017181397
loss at step 50900: 2.788203001022339
loss at step 50920: 2.8196742296218873
loss at step 50940: 2.7619823813438416
loss at step 50960: 2.739240312576294
loss at step 50980: 2.6895206689834597
loss at step 51000: 2.7420029759407045
loss at step 51020: 2.7467488169670107
loss at step 51040: 2.753857409954071
loss at step 51060: 2.776341104507446
loss at step 51080: 2.7315117478370667
loss at step 51100: 2.7882624745368956
loss at step 51120: 2.799879002571106
loss at step 51140: 2.7458465814590456
loss at step 51160: 2.7568896651268004
loss at step 51180: 2.7715800046920775
loss at step 51200: 2.746755313873291
validation loss: 2.845983176231384
validation accuracy: 13.89294657762938

loss at step 54660: 2.7001577377319337
loss at step 54680: 2.6975210428237917
loss at step 54700: 2.626429057121277
loss at step 54720: 2.617084431648254
loss at step 54740: 2.779800498485565
loss at step 54760: 2.7971866846084597
loss at step 54780: 2.7710905194282534
loss at step 54800: 2.731093239784241
validation loss: 2.8952632919947305
validation accuracy: 13.637312186978297%
loss at step 54820: 2.764820158481598
loss at step 54840: 2.8281885981559753
loss at step 54860: 2.7724265456199646
loss at step 54880: 2.8015897274017334
loss at step 54900: 2.8398864030838014
loss at step 54920: 2.768544614315033
loss at step 54940: 2.8096962094306948
loss at step 54960: 2.764060711860657
loss at step 54980: 2.7460445165634155
loss at step 55000: 2.8249533414840697
loss at step 55020: 2.695324456691742
loss at step 55040: 2.7881848931312563
loss at step 55060: 2.743749499320984
loss at step 55080: 2.782752740383148
loss at step 55100: 2.6948406100273132
loss at step 55120: 2.77072526216506

loss at step 58520: 2.7327016830444335
loss at step 58540: 2.6798805356025697
loss at step 58560: 2.738019955158234
loss at step 58580: 2.745971703529358
loss at step 58600: 2.7260540962219237
loss at step 58620: 2.815381479263306
loss at step 58640: 2.7682741045951844
loss at step 58660: 2.7829195857048035
loss at step 58680: 2.790846061706543
loss at step 58700: 2.6524888277053833
loss at step 58720: 2.791000282764435
loss at step 58740: 2.772694158554077
loss at step 58760: 2.7792198419570924
loss at step 58780: 2.797192931175232
loss at step 58800: 2.7804934859275816
validation loss: 2.8259252373377484
validation accuracy: 14.352045075125208%
loss at step 58820: 2.7067964911460876
loss at step 58840: 2.7651066422462462
loss at step 58860: 2.7699437737464905
loss at step 58880: 2.6713834047317504
loss at step 58900: 2.723837506771088
loss at step 58920: 2.750230872631073
loss at step 58940: 2.643436551094055
loss at step 58960: 2.7456119656562805
loss at step 58980: 2.74765638113021

validation loss: 2.8294899702072143
validation accuracy: 15.071994991652755%
loss at step 62420: 2.8431493520736693
loss at step 62440: 2.6620311379432677
loss at step 62460: 2.7294252276420594
loss at step 62480: 2.639150249958038
loss at step 62500: 2.7677031993865966
loss at step 62520: 2.752448832988739
loss at step 62540: 2.7343505144119264
loss at step 62560: 2.8077401518821716
loss at step 62580: 2.6673295855522157
loss at step 62600: 2.6994274139404295
loss at step 62620: 2.7986891627311707
loss at step 62640: 2.714406096935272
loss at step 62660: 2.7308116436004637
loss at step 62680: 2.758282554149628
loss at step 62700: 2.705794095993042
loss at step 62720: 2.808923804759979
loss at step 62740: 2.829856288433075
loss at step 62760: 2.7670055747032167
loss at step 62780: 2.7374179124832154
loss at step 62800: 2.7568812370300293
validation loss: 2.850922118028005
validation accuracy: 14.085976627712855%
loss at step 62820: 2.6964955925941467
loss at step 62840: 2.7683202505111

loss at step 66280: 2.683622121810913
loss at step 66300: 2.745735454559326
loss at step 66320: 2.719000542163849
loss at step 66340: 2.6693244218826293
loss at step 66360: 2.715834951400757
loss at step 66380: 2.7228469014167787
loss at step 66400: 2.6910349249839784
validation loss: 2.8239078839619953
validation accuracy: 14.456385642737896%
loss at step 66420: 2.8501644372940063
loss at step 66440: 2.7390034914016725
loss at step 66460: 2.7048932790756224
loss at step 66480: 2.736556112766266
loss at step 66500: 2.6898212313652037
loss at step 66520: 2.615202176570892
loss at step 66540: 2.760747218132019
loss at step 66560: 2.7362848043441774
loss at step 66580: 2.761297607421875
loss at step 66600: 2.7430568814277647
loss at step 66620: 2.673439145088196
loss at step 66640: 2.658943462371826
loss at step 66660: 2.8573553562164307
loss at step 66680: 2.6952586650848387
loss at step 66700: 2.7244382739067077
loss at step 66720: 2.7591853499412538
loss at step 66740: 2.71591054201126

loss at step 70160: 2.66344039440155
loss at step 70180: 2.68296355009079
loss at step 70200: 2.641629385948181
loss at step 70220: 2.7020328521728514
loss at step 70240: 2.7160440325737
loss at step 70260: 2.755428671836853
loss at step 70280: 2.6769171595573424
loss at step 70300: 2.6899892687797546
loss at step 70320: 2.753960978984833
loss at step 70340: 2.660026800632477
loss at step 70360: 2.5964502811431887
loss at step 70380: 2.741564762592316
loss at step 70400: 2.699671483039856
validation loss: 2.785056522687276
validation accuracy: 15.087646076794659%
loss at step 70420: 2.7176315069198607
loss at step 70440: 2.742523527145386
loss at step 70460: 2.6984883904457093
loss at step 70480: 2.7235883474349976
loss at step 70500: 2.734596681594849
loss at step 70520: 2.7392531275749206
loss at step 70540: 2.709845983982086
loss at step 70560: 2.7243436217308044
loss at step 70580: 2.7171441435813906
loss at step 70600: 2.7232242226600647
loss at step 70620: 2.647429716587067

loss at step 74040: 2.7906698107719423
loss at step 74060: 2.693540632724762
loss at step 74080: 2.7451900601387025
loss at step 74100: 2.669507610797882
loss at step 74120: 2.732143783569336
loss at step 74140: 2.7503705620765686
loss at step 74160: 2.63346072435379
loss at step 74180: 2.743454360961914
loss at step 74200: 2.661134052276611
loss at step 74220: 2.6852021336555483
loss at step 74240: 2.7577125906944273
loss at step 74260: 2.7353310227394103
loss at step 74280: 2.7367393136024476
loss at step 74300: 2.7299183011054993
loss at step 74320: 2.7396024227142335
loss at step 74340: 2.7395118713378905
loss at step 74360: 2.6789526104927064
loss at step 74380: 2.748567795753479
loss at step 74400: 2.7364168524742127
validation loss: 2.7966662502288817
validation accuracy: 15.27024207011686%
loss at step 74420: 2.7225094079971313
loss at step 74440: 2.8276709079742433
loss at step 74460: 2.6744786381721495
loss at step 74480: 2.761888802051544
loss at step 74500: 2.78363599777221

loss at step 77960: 2.6565163016319273
loss at step 77980: 2.6262449383735658
loss at step 78000: 2.64651243686676
validation loss: 2.8214937893549603
validation accuracy: 14.560726210350584%
loss at step 78020: 2.6724722504615785
loss at step 78040: 2.6888559103012084
loss at step 78060: 2.7622055768966676
loss at step 78080: 2.6918702244758608
loss at step 78100: 2.727749490737915
loss at step 78120: 2.6655391693115233
loss at step 78140: 2.758989930152893
loss at step 78160: 2.715466928482056
loss at step 78180: 2.7344421267509462
loss at step 78200: 2.7127715587615966
loss at step 78220: 2.706951367855072
loss at step 78240: 2.6767958283424376
loss at step 78260: 2.642867422103882
loss at step 78280: 2.6565951466560365
loss at step 78300: 2.664384496212006
loss at step 78320: 2.6224735140800477
loss at step 78340: 2.6470692634582518
loss at step 78360: 2.6815388202667236
loss at step 78380: 2.7179011344909667
loss at step 78400: 2.704615664482117
validation loss: 2.797864165306091


loss at step 81840: 2.6102961778640745
loss at step 81860: 2.7354508876800536
loss at step 81880: 2.693737292289734
loss at step 81900: 2.6816327452659605
loss at step 81920: 2.7075544357299806
loss at step 81940: 2.7244885206222533
loss at step 81960: 2.626539480686188
loss at step 81980: 2.7327697396278383
loss at step 82000: 2.6473228573799132
validation loss: 2.7735563770929974
validation accuracy: 14.832011686143574%
loss at step 82020: 2.684140205383301
loss at step 82040: 2.741627836227417
loss at step 82060: 2.7919574618339538
loss at step 82080: 2.6352982878685
loss at step 82100: 2.6950711488723753
loss at step 82120: 2.76212557554245
loss at step 82140: 2.7342135667800904
loss at step 82160: 2.693541634082794
loss at step 82180: 2.6095203399658202
loss at step 82200: 2.6525866389274597
loss at step 82220: 2.7564761877059936
loss at step 82240: 2.704667329788208
loss at step 82260: 2.7064175486564634
loss at step 82280: 2.750186729431152
loss at step 82300: 2.717523765563965


loss at step 85700: 2.7244537472724915
loss at step 85720: 2.7095914006233217
loss at step 85740: 2.6857770919799804
loss at step 85760: 2.754808723926544
loss at step 85780: 2.6399945616722107
loss at step 85800: 2.6930785298347475
loss at step 85820: 2.717161762714386
loss at step 85840: 2.701657009124756
loss at step 85860: 2.6453122735023498
loss at step 85880: 2.6920691609382628
loss at step 85900: 2.679655838012695
loss at step 85920: 2.686196005344391
loss at step 85940: 2.6726969122886657
loss at step 85960: 2.6998019337654116
loss at step 85980: 2.73535498380661
loss at step 86000: 2.6932485580444334
validation loss: 2.7636483828226726
validation accuracy: 15.071994991652755%
loss at step 86020: 2.5815449595451354
loss at step 86040: 2.642139995098114
loss at step 86060: 2.635317826271057
loss at step 86080: 2.6790221452713014
loss at step 86100: 2.6624902844429017
loss at step 86120: 2.783533239364624
loss at step 86140: 2.658139336109161
loss at step 86160: 2.77686607837677


validation loss: 2.824626355965932
validation accuracy: 14.132929883138564%
loss at step 89620: 2.6821656465530395
loss at step 89640: 2.5945727705955504
loss at step 89660: 2.7084940314292907
loss at step 89680: 2.6272295475006104
loss at step 89700: 2.643423342704773
loss at step 89720: 2.6210603594779966
loss at step 89740: 2.6940165638923643
loss at step 89760: 2.693687653541565
loss at step 89780: 2.6580484390258787
loss at step 89800: 2.6608741879463196
loss at step 89820: 2.7335798859596254
loss at step 89840: 2.679896867275238
loss at step 89860: 2.6682975769042967
loss at step 89880: 2.6780964970588683
loss at step 89900: 2.6442212104797362
loss at step 89920: 2.708136761188507
loss at step 89940: 2.644916272163391
loss at step 89960: 2.715713953971863
loss at step 89980: 2.670178508758545
loss at step 90000: 2.7517282485961916
validation loss: 2.768266417980194
validation accuracy: 15.348497495826377%
loss at step 90020: 2.7632113337516784
loss at step 90040: 2.65085319280624

loss at step 93500: 2.6895387172698975
loss at step 93520: 2.6374205112457276
loss at step 93540: 2.652469575405121
loss at step 93560: 2.7325408935546873
loss at step 93580: 2.6788623452186586
loss at step 93600: 2.6965091586112977
validation loss: 2.7779923772811888
validation accuracy: 14.941569282136896%
loss at step 93620: 2.595069396495819
loss at step 93640: 2.724934732913971
loss at step 93660: 2.7170013904571535
loss at step 93680: 2.6191623210906982
loss at step 93700: 2.6277590155601502
loss at step 93720: 2.6751591086387636
loss at step 93740: 2.6369277238845825
loss at step 93760: 2.7097522616386414
loss at step 93780: 2.6904122591018678
loss at step 93800: 2.7778671622276305
loss at step 93820: 2.6560313940048217
loss at step 93840: 2.6721355319023132
loss at step 93860: 2.6570173621177675
loss at step 93880: 2.6860291719436646
loss at step 93900: 2.612747514247894
loss at step 93920: 2.752856731414795
loss at step 93940: 2.71970739364624
loss at step 93960: 2.63289878368

loss at step 97360: 2.672602152824402
loss at step 97380: 2.685239005088806
loss at step 97400: 2.6621739149093626
loss at step 97420: 2.7698076248168944
loss at step 97440: 2.6420738458633424
loss at step 97460: 2.778098261356354
loss at step 97480: 2.700210988521576
loss at step 97500: 2.6513081192970276
loss at step 97520: 2.6943679213523866
loss at step 97540: 2.74864581823349
loss at step 97560: 2.7187570691108705
loss at step 97580: 2.6432628870010375
loss at step 97600: 2.7239083886146545
validation loss: 2.7779286924997963
validation accuracy: 15.625%
loss at step 97620: 2.646199607849121
loss at step 97640: 2.6900200247764587
loss at step 97660: 2.6529792666435243
loss at step 97680: 2.64248206615448
loss at step 97700: 2.721072030067444
loss at step 97720: 2.7067981243133543
loss at step 97740: 2.7177141308784485
loss at step 97760: 2.656858777999878
loss at step 97780: 2.746215748786926
loss at step 97800: 2.7380526185035707
loss at step 97820: 2.702978730201721
loss at step

loss at step 101220: 2.6691025853157044
loss at step 101240: 2.625395894050598
loss at step 101260: 2.5753268122673036
loss at step 101280: 2.6050233483314513
loss at step 101300: 2.6523296356201174
loss at step 101320: 2.6847665786743162
loss at step 101340: 2.6948291182518007
loss at step 101360: 2.6853113412857055
loss at step 101380: 2.759034752845764
loss at step 101400: 2.7054874420166017
loss at step 101420: 2.7358849167823793
loss at step 101440: 2.6349490761756895
loss at step 101460: 2.7603220343589783
loss at step 101480: 2.6653544425964357
loss at step 101500: 2.7170209527015685
loss at step 101520: 2.664554464817047
loss at step 101540: 2.6926305890083313
loss at step 101560: 2.637912356853485
loss at step 101580: 2.612354552745819
loss at step 101600: 2.7167787671089174
validation loss: 2.753519817988078
validation accuracy: 15.880634390651085%
loss at step 101620: 2.6247223615646362
loss at step 101640: 2.7726691365242004
loss at step 101660: 2.6700878500938416
loss at s

loss at step 105020: 2.6956525564193727
loss at step 105040: 2.7128366589546205
loss at step 105060: 2.680810809135437
loss at step 105080: 2.642867636680603
loss at step 105100: 2.643260860443115
loss at step 105120: 2.719448482990265
loss at step 105140: 2.6530309557914733
loss at step 105160: 2.657451891899109
loss at step 105180: 2.720752990245819
loss at step 105200: 2.6286710619926454
validation loss: 2.7473398129145306
validation accuracy: 15.588480801335558%
loss at step 105220: 2.6763653635978697
loss at step 105240: 2.54657701253891
loss at step 105260: 2.668469715118408
loss at step 105280: 2.633170962333679
loss at step 105300: 2.679142141342163
loss at step 105320: 2.6609315633773805
loss at step 105340: 2.6678587079048155
loss at step 105360: 2.617349588871002
loss at step 105380: 2.6016167521476747
loss at step 105400: 2.684893453121185
loss at step 105420: 2.635358726978302
loss at step 105440: 2.6918542861938475
loss at step 105460: 2.651079475879669
loss at step 10548

loss at step 108820: 2.6926472544670106
loss at step 108840: 2.618320369720459
loss at step 108860: 2.7179187536239624
loss at step 108880: 2.7108777165412903
loss at step 108900: 2.6983333587646485
loss at step 108920: 2.546515429019928
loss at step 108940: 2.7496787309646606
loss at step 108960: 2.8116978645324706
loss at step 108980: 2.688618516921997
loss at step 109000: 2.690885639190674
loss at step 109020: 2.682310688495636
loss at step 109040: 2.6846598863601683
loss at step 109060: 2.705125463008881
loss at step 109080: 2.701354277133942
loss at step 109100: 2.633997344970703
loss at step 109120: 2.6749712347984316
loss at step 109140: 2.682475709915161
loss at step 109160: 2.64698805809021
loss at step 109180: 2.570146417617798
loss at step 109200: 2.680081069469452
validation loss: 2.7673524030049643
validation accuracy: 15.44762103505843%
loss at step 109220: 2.677096796035767
loss at step 109240: 2.654175853729248
loss at step 109260: 2.6872895479202272
loss at step 109280

loss at step 112640: 2.6645275473594667
loss at step 112660: 2.691979467868805
loss at step 112680: 2.6795437693595887
loss at step 112700: 2.669618558883667
loss at step 112720: 2.6545138120651246
loss at step 112740: 2.7186007261276246
loss at step 112760: 2.6112676441669462
loss at step 112780: 2.667857086658478
loss at step 112800: 2.727476692199707
validation loss: 2.7420052909851074
validation accuracy: 15.661519198664442%
loss at step 112820: 2.6203613758087156
loss at step 112840: 2.6709481835365296
loss at step 112860: 2.647661340236664
loss at step 112880: 2.662482261657715
loss at step 112900: 2.720310890674591
loss at step 112920: 2.618266224861145
loss at step 112940: 2.6424999117851256
loss at step 112960: 2.6215763092041016
loss at step 112980: 2.8250272393226625
loss at step 113000: 2.6372488379478454
loss at step 113020: 2.621204102039337
loss at step 113040: 2.7044130206108092
loss at step 113060: 2.6996842861175536
loss at step 113080: 2.566091442108154
loss at step 

loss at step 116420: 2.756659984588623
loss at step 116440: 2.714279556274414
loss at step 116460: 2.7071182131767273
loss at step 116480: 2.664304792881012
loss at step 116500: 2.614995503425598
loss at step 116520: 2.551398479938507
loss at step 116540: 2.6691449165344237
loss at step 116560: 2.6349333763122558
loss at step 116580: 2.664230000972748
loss at step 116600: 2.6055002450942992
loss at step 116620: 2.640992558002472
loss at step 116640: 2.664368748664856
loss at step 116660: 2.6880857586860656
loss at step 116680: 2.6151037931442263
loss at step 116700: 2.625714099407196
loss at step 116720: 2.6638700008392333
loss at step 116740: 2.630499076843262
loss at step 116760: 2.6483338952064512
loss at step 116780: 2.6879353761672973
loss at step 116800: 2.651603376865387
validation loss: 2.7430141480763752
validation accuracy: 15.625%
loss at step 116820: 2.6489794492721557
loss at step 116840: 2.6270684123039247
loss at step 116860: 2.6639470219612122
loss at step 116880: 2.627

## Generating
This model implements the Fast WaveNet Generation Algorithm (https://arxiv.org/abs/1611.09482). Generation might run faster on the CPU. You can provide some starting data (at least as long as the receptive field) or let the model generate from zero. In my experience, a temperature between 0.5 and 1.0 yields the best results, but this may depend on the data set.
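
To make the role of the temperature concrete, here is a minimal standard-library sketch of temperature sampling from a vector of logits. It is an illustration only, not the code used inside `generate_fast`; the names `temperature_probs` and `sample_with_temperature` are hypothetical, and treating a temperature of zero as greedy (argmax) decoding is an assumption of this sketch.

```python
import math
import random


def temperature_probs(logits, temperature=1.0):
    """Return softmax(logits / temperature) as a list of probabilities."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the maximum for numerical stability
    weights = [math.exp(x - peak) for x in scaled]
    total = sum(weights)
    return [w / total for w in weights]


def sample_with_temperature(logits, temperature=1.0, rng=random):
    """Draw one class index; temperature <= 0 falls back to argmax (assumption)."""
    if temperature <= 0:
        return max(range(len(logits)), key=lambda i: logits[i])
    probs = temperature_probs(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.1]
sharp = temperature_probs(logits, temperature=0.5)  # more peaked
plain = temperature_probs(logits, temperature=1.0)  # unmodified softmax
```

Lower temperatures concentrate probability on the most likely class, so the output becomes more deterministic; a temperature of 1.0 samples from the unmodified softmax distribution.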

In [None]:
start_data = data[250000][0] # use start data from the data set
start_data = torch.max(start_data, 0)[1] # convert one hot vectors to integers

def prog_callback(step, total_steps):
    print(str(100 * step // total_steps) + "% generated")

generated = model.generate_fast(num_samples=160000,
                                 first_samples=start_data,
                                 progress_callback=prog_callback,
                                 progress_interval=1000,
                                 temperature=1.0,
                                 regularize=0.)
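
The starting data passed as `first_samples` should cover at least the receptive field. As a rough check, the sketch below computes the receptive field of a stack of dilated convolutions whose dilation doubles in each layer and resets at the start of each block. The kernel size of 2 is an assumption taken from the original WaveNet paper, and `receptive_field` is a hypothetical helper; the model itself may compute this value slightly differently.

```python
def receptive_field(layers, blocks, kernel_size=2):
    """Receptive field (in samples) of stacked dilated causal convolutions."""
    # Dilations 1, 2, 4, ... within a block, repeated for every block.
    dilations = [2 ** i for i in range(layers)] * blocks
    # Each layer widens the receptive field by (kernel_size - 1) * dilation.
    return sum((kernel_size - 1) * d for d in dilations) + 1


samples = receptive_field(layers=10, blocks=3)
print(samples)          # 3070
print(samples / 16000)  # length in seconds at a 16 kHz sample rate
```

Under these assumptions, the configuration used in this notebook (10 layers, 3 blocks) needs about 3070 samples of context, which is roughly 0.19 seconds at 16 kHz.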

In [None]:
import IPython.display as ipd

ipd.Audio(generated, rate=16000)
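
To keep the result outside the notebook, the samples can be written to a WAV file. The following is a minimal sketch using only the Python standard library. It assumes that `generated` is a one-dimensional sequence of floats in the range [-1, 1], which is not confirmed by this notebook; `write_wav` is a hypothetical helper.

```python
import io
import struct
import wave


def write_wav(target, samples, sample_rate=16000):
    """Write mono float samples in [-1, 1] as 16-bit PCM to a path or file object."""
    frames = bytearray()
    for s in samples:
        clipped = max(-1.0, min(1.0, float(s)))  # guard against out-of-range values
        frames += struct.pack('<h', int(round(clipped * 32767)))
    with wave.open(target, 'wb') as wav_file:
        wav_file.setnchannels(1)   # mono
        wav_file.setsampwidth(2)   # 2 bytes = 16 bit
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(bytes(frames))


# Example with an in-memory buffer; pass a file name such as
# 'generated.wav' instead to write to disk.
buffer = io.BytesIO()
write_wav(buffer, [0.0, 1.0, -1.0])
```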