### TODOs:
* Remember to wrap input sequences as `volatile` Variables (`Variable(x, volatile=True)`) at inference time, to reduce memory usage and prevent backpropagation.

In [106]:
# import essentials:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
import numpy as np
import h5py

In [107]:
# Change into the project root; `%pwd` echoes it for confirmation.
# NOTE(review): this is a hardcoded, user-specific path. The `modules.*` imports
# and the relative ./data and ./runs paths used in later cells all assume this
# working directory, so adjust it when running on another machine.
%cd ~/Desktop/pytorch_models/wavenet-speech
%pwd

/home/ptang/Desktop/pytorch_models/wavenet-speech


'/home/ptang/Desktop/pytorch_models/wavenet-speech'

In [108]:
from modules.wavenet import WaveNet
from modules.classifier import WaveNetClassifier
from warpctc_pytorch import CTCLoss

In [118]:
# --- model architecture & configuration settings ---
# WaveNet dilations: the sequence 1, 2, 4, ..., 512 repeated for three stacks.
wavenet_dils = [2 ** exponent for exponent in range(10)] * 3
# Classifier layer specs. The tuple layout is defined by WaveNetClassifier; the
# last entry is drawn from a dilation list (presumably channels-in, channels-out,
# kernel size, dilation -- confirm against modules/classifier.py).
classifier_layers = [(256, 256, 2, dilation) for dilation in (1, 2, 4)]
downsample_rate = 1
num_labels = 5  # == |{A,G,C,T,-}|
out_dim = 128
num_levels = 256
# checkpoints to restore (must match the architecture above):
wavenet_model_restore_path = "./runs/artificial/wavenet_model.3.pth"
classifier_model_restore_path = "./runs/artificial/classifier_model.3.pth"

In [119]:
# Open the HDF5 dataset read-only. The handle is deliberately left open because
# later cells (e.g. run_example) keep reading examples from `dataset`.
dataset_path = "./data/artificial.large.hdf5"
dataset = h5py.File(dataset_path, mode='r')

In [120]:
# Show the member names stored in the first bucket group (iterating an HDF5 group yields its keys):
print([member_name for member_name in dataset['bucket_0']])

['read_lengths', 'reads', 'signal_lengths', 'signals']


In [121]:
# reconstruct wavenet & classifier; restore weights:
# NOTE(review): the bare `2` / `2, d` arguments are architecture settings whose
# meaning is defined in modules/wavenet.py and modules/classifier.py -- they must
# match the values used when the checkpoints were trained.
wavenet = WaveNet(num_levels, 2, [(num_levels, num_levels, 2, d) for d in wavenet_dils],
                  num_levels, softmax=False)
classifier = WaveNetClassifier(num_levels, num_labels, classifier_layers, out_dim,
                               pool_kernel_size=downsample_rate,
                               input_kernel_size=2, input_dilation=1,
                               softmax=False)

# This notebook only runs inference, so put both modules in evaluation mode:
# nn.Module defaults to training mode, in which dropout / batch-norm layers (if
# any) behave stochastically or update running statistics.
wavenet.eval()
classifier.eval()

### load their saved values:
# map_location keeps every restored tensor on the CPU, so a checkpoint written
# from a GPU run still loads on a CPU-only machine (inference here uses CPU tensors).
wavenet.load_state_dict(torch.load(wavenet_model_restore_path,
                                   map_location=lambda storage, loc: storage))
print("Restored WaveNet weights from: {}".format(wavenet_model_restore_path))
classifier.load_state_dict(torch.load(classifier_model_restore_path,
                                      map_location=lambda storage, loc: storage))
print("Restored classifier weights from: {}".format(classifier_model_restore_path))

Restored WaveNet weights from: ./runs/artificial/wavenet_model.3.pth
Restored classifier weights from: ./runs/artificial/classifier_model.3.pth


In [141]:
def predict(signal):
    """Run the WaveNet stack and then the classifier on `signal`.

    `wavenet` and `classifier` are looked up in the notebook namespace at call
    time, so re-running the model-restore cell is picked up automatically.
    """
    return classifier(wavenet(signal))


def _ctc_greedy_decode(frame_labels, blank):
    """Best-path CTC decoding of per-timestep argmax labels.

    Consecutive repeats are merged first, then blanks are dropped, so a blank
    between two identical labels yields two separate emissions.

    Args:
        frame_labels: iterable of int label indices, one per timestep.
        blank: index of the CTC blank label.

    Returns:
        list of int: the decoded label sequence.
    """
    decoded = []
    previous = None
    for label in frame_labels:
        # Emit only when the label changes and is not the blank.
        if label != previous and label != blank:
            decoded.append(label)
        previous = label
    return decoded


def run_example(ex_num, bucket='bucket_0', blank=4, num_input_levels=256,
                print_logits=True):
    """Run the WaveNet + classifier stack on one stored example and print results.

    Reads `dataset` (the open HDF5 file) and the restored models via `predict`
    from the notebook namespace.

    In order, it prints:
    - the one-hot input size;
    - the model output size;
    - the greedy CTC-decoded prediction;
    - the target read;
    - optionally, the logits at every timestep.

    Args:
        ex_num: example id; looked up under the string key ``str(ex_num)``.
        bucket: HDF5 group holding the 'signals' and 'reads' datasets.
        blank: label index treated as the CTC blank. The default 4 assumes the
            blank is the last of the 5 labels {A,G,C,T,-} -- confirm against
            the training code (warp-ctc itself defaults to blank=0).
        num_input_levels: one-hot width for the quantised signal; presumably must
            equal the `num_levels` the WaveNet was built with.
        print_logits: if True, also print the raw logits for every timestep.
    """
    # load data:
    signal = torch.from_numpy(dataset[bucket]['signals'][str(ex_num)][:]).long()
    target = torch.from_numpy(dataset[bucket]['reads'][str(ex_num)][:])

    # one-hot encoding into shape (1, num_input_levels, T):
    one_hot_signal = torch.zeros(1, num_input_levels, signal.size(0)).scatter_(
        1, signal.view(1, 1, -1), 1.)
    print("Signal dimensions: {}".format(one_hot_signal.size()))

    # prediction (volatile: inference only, no autograd graph is recorded):
    sequence = predict(Variable(one_hot_signal, volatile=True))
    print("Predicted sequence dimensions: {}".format(sequence.size()))

    # Greedy (best-path) decoding: argmax over the label axis, then collapse
    # repeats and drop blanks -- dropping blanks alone over-counts repeated labels.
    _, ixs = torch.max(sequence, dim=1)
    print(_ctc_greedy_decode(ixs[0].data.tolist(), blank))

    # print target:
    print([int(t) for t in target])

    # print logits:
    if print_logits:
        for k in range(sequence.size(2)):
            print("Logits @ timestep {0}: {1}".format(k, sequence[0, :, k].data[:]))

In [150]:
# Run the full pipeline on stored example 32 and print its diagnostics:
run_example(ex_num=32)

Signal dimensions: torch.Size([1, 256, 60])
Predicted sequence dimensions: torch.Size([1, 5, 60])
[3, 1, 1, 1, 1, 1, 2]
[2, 2, 0, 2, 0, 2, 1, 2, 2, 0, 1, 1, 1, 3, 0, 2, 3, 3, 1, 0]
Logits @ timestep 0: 
-1.3892
-0.5489
-1.0567
 0.4431
 2.2213
[torch.FloatTensor of size 5]

Logits @ timestep 1: 
-2.4674
-1.1456
-0.9009
-0.1568
 1.3085
[torch.FloatTensor of size 5]

Logits @ timestep 2: 
-2.8280
-0.5344
-1.0201
-0.3894
 1.6938
[torch.FloatTensor of size 5]

Logits @ timestep 3: 
-1.0470
-0.3755
-1.3838
-1.0431
 1.7214
[torch.FloatTensor of size 5]

Logits @ timestep 4: 
-1.3484
-0.1497
-1.0086
-1.0034
 0.5022
[torch.FloatTensor of size 5]

Logits @ timestep 5: 
-1.6486
-0.8104
-1.0608
-1.1816
 1.0426
[torch.FloatTensor of size 5]

Logits @ timestep 6: 
-1.9605
-0.3543
-0.9730
-0.7464
 0.5455
[torch.FloatTensor of size 5]

Logits @ timestep 7: 
-1.5585
-0.6893
-0.7677
-0.6477
 1.1811
[torch.FloatTensor of size 5]

Logits @ timestep 8: 
-1.3335
-0.3750
-1.2926
-0.3458
 1.3040
[torch.FloatT