In [None]:
import torch
from torch import nn, optim

from wavenet.audiodata import AudioData, AudioLoader
from wavenet.models_torch import Model

from IPython.display import Audio

%matplotlib inline

# Hyperparameters and loader settings, collected in one cell so they can be
# tuned together.
seed = 0              # seed for torch's global RNG (applied at the end of this cell)

# Data / model dimensions
x_len = 2**10         # 1024; handed to both AudioData and Model
num_classes = 256     # handed to AudioData (classes=) and Model (num_classes=)
num_blocks = 2        # Model(num_blocks=)
num_layers = 9        # Model(num_layers=)
num_hidden = 64       # Model(num_hidden=)
kernel_size = 2       # Model(kernel_size=)

# Optimisation
learn_rate = 0.001    # Adam learning rate
step_size = 50        # StepLR: number of scheduler steps between decays
gamma = 0.5           # StepLR: multiplicative learning-rate decay factor

# Data loading
num_workers = 1       # AudioLoader(num_workers=)
batch_size = 8        # AudioLoader(batch_size=)

# Seed torch's global RNG so torch-driven randomness is repeatable across
# "Restart Kernel -> Run All". Randomness drawn from other generators
# (Python's `random`, NumPy) is not covered by this call.
# The trailing semicolon keeps the returned Generator out of the cell output.
torch.manual_seed(seed);

## Create dataset

In [None]:
# Build the dataset from the bundled WAV file.
# NOTE(review): store_tracks=True presumably keeps the loaded audio on
# `dataset.tracks`, which the player below reads - confirm against AudioData.
filelist = ['assets/classical.wav']
dataset = AudioData(
    filelist,
    x_len,
    classes=num_classes,
    store_tracks=True,
)

# Render an inline audio player for the first stored track.
first_track = dataset.tracks[0]
Audio(first_track['audio'], rate=first_track['sample_rate'])

## Create dataloader

In [None]:
# Wrap the dataset in the project's loader, using the batch size and worker
# count chosen in the configuration cell.
dataloader = AudioLoader(
    dataset,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [None]:
# Sanity-check a single sample. Item 0 is fetched with ordinary subscript
# syntax rather than by calling the __getitem__ dunder directly.
ins, outs = dataset[0]
print(dataset.datarange)
print(ins)
print(outs)
print(dataset.encoder.decode(ins))   # inputs run back through the encoder's decode()
print(dataset.label2value(outs))     # targets mapped through label2value()

## Define model and training parameters

In [None]:
# Instantiate the network from the hyperparameters in the configuration cell.
wave_model = Model(
    x_len,
    num_channels=1,
    num_classes=num_classes,
    num_blocks=num_blocks,
    num_layers=num_layers,
    num_hidden=num_hidden,
    kernel_size=kernel_size,
)

# Attach the loss, optimizer and learning-rate schedule to the model object;
# the scheduler is built on the optimizer that was just attached.
wave_model.criterion = nn.CrossEntropyLoss()
wave_model.optimizer = optim.Adam(
    wave_model.parameters(),
    lr=learn_rate,
)
wave_model.scheduler = optim.lr_scheduler.StepLR(
    wave_model.optimizer,
    step_size=step_size,
    gamma=gamma,
)

## Train model

In [None]:
# Run the project-defined training routine over the dataloader.
# NOTE(review): wave_model.parameters() is called earlier, which suggests Model
# is a torch.nn.Module. If so, this train(dataloader) signature shadows
# nn.Module.train(mode=True) - confirm that train/eval mode switching still
# behaves as intended inside wavenet.models_torch.
wave_model.train(dataloader)