### Let's examine the results of the pore model on pore-width 1. First, let's get set up by importing and switching directories to top-level:

In [1]:
# switch to top-level dir:
%cd ..

/home/ptang/Desktop/pytorch_models/wavenet-speech


In [3]:
# import all the essentials:
import torch
import numpy as np
from modules.wavenet import WaveNet
from modules.classifier import WaveNetClassifier
from utils.loaders import PoreModelLoader

### Now let's build a pore model with the same settings as was in the training loop:

In [7]:
# Low noise and small pore width:
num_levels = 256
num_iterations = 100000
num_epochs = 50
batch_size = 16
epoch_size = 2500
nt_sample_lengths = (90,110)
pore_width = 1
srate = 4
noise = 1.
nt_to_pa = { 0: 51., 1: 22., 2: 103., 3: 115. }
dataloader = PoreModelLoader(num_iterations, num_epochs, epoch_size,
                             batch_size=batch_size, num_levels=num_levels, lengths=nt_sample_lengths,
                             pore_width=pore_width, sample_rate=srate, currents_dict=nt_to_pa,
                             sample_noise=noise)

### Now let's instantiate a model and restore the model weights:

In [12]:
num_labels = 5
out_dim = 256
downsample_rate = 1
wavenet_dils = [1, 2, 3, 4,
                1, 2, 3, 4,
                1, 2, 3, 4]
classifier_dils = [1, 2, 3, 4,
                   1, 2, 3, 4,
                   1, 2, 3, 4]

wavenet = WaveNet(num_levels, 2, [(num_levels, num_levels, 2, d) for d in wavenet_dils], num_levels, softmax=False)
classifier = WaveNetClassifier(num_levels, num_labels, [(num_levels, num_levels, 3, d) for d in classifier_dils],
                               out_dim, pool_kernel_size=downsample_rate, input_kernel_size=2, input_dilation=1,
                               softmax=False)

# restore model weights (`map_location` moves weights from CUDA to CPU):
wavenet.load_state_dict(torch.load("./runs/artificial/noiseless_porewidth_1/wavenet_model.loss0_5014.pth",
                                   map_location=lambda storage, loc: storage))
classifier.load_state_dict(torch.load("./runs/artificial/noiseless_porewidth_1/classifier_model.loss0_5014.pth",
                                      map_location=lambda storage, loc: storage))

### Define a closured function to run the WaveNet-CTC model on inputs:

In [13]:
def run_model(signal):
    intermediate_signal = wavenet(signal)
    transcription = classifier(intermediate_signal)
    return transcription

In [45]:
signal, seq, seq_lengths = dataloader.fetch()

In [46]:
ctc_preds = run_model(torch.autograd.Variable(signal.data, volatile=True))

In [47]:
# print outputs:
_lookup_ = {0: '<BLANK>', 1: 'A', 2: 'G', 3: 'C', 4: 'T'}
batch_ix = 1 # (choose which sequence of the batch you want to look at)
print_blanks = False
pred_labels = []
for k in range(ctc_preds.size(2)):
    logit, label = torch.max(torch.nn.functional.softmax(ctc_preds[batch_ix,:,k]), dim=0)
    logit_py = float(logit.data[0])
    label_py = _lookup_[int(label.data[0])]
    if (not print_blanks) and (label_py == '<BLANK>'): continue
    print("Called: {0} | Proba: {1:1.4f}".format(label_py, logit_py))
    pred_labels.append(label_py)

Called: G | Proba: 0.9826
Called: G | Proba: 0.9998
Called: G | Proba: 0.9930
Called: C | Proba: 0.9407
Called: C | Proba: 0.9853
Called: T | Proba: 0.5158
Called: T | Proba: 0.9984
Called: T | Proba: 0.7218
Called: C | Proba: 0.9644
Called: C | Proba: 0.9669
Called: C | Proba: 0.9767
Called: C | Proba: 0.8775
Called: A | Proba: 0.7157
Called: A | Proba: 0.9828
Called: A | Proba: 0.9037
Called: G | Proba: 0.9540
Called: G | Proba: 0.7187
Called: A | Proba: 0.5102
Called: A | Proba: 0.7881
Called: A | Proba: 0.8361
Called: A | Proba: 0.5533
Called: G | Proba: 0.9060
Called: G | Proba: 0.9976
Called: G | Proba: 0.7879
Called: T | Proba: 0.5262
Called: T | Proba: 0.9461
Called: A | Proba: 0.8637
Called: G | Proba: 0.7460
Called: G | Proba: 0.7094
Called: G | Proba: 0.7752
Called: G | Proba: 0.9541
Called: C | Proba: 0.8204
Called: C | Proba: 0.9691
Called: A | Proba: 0.7739
Called: A | Proba: 0.9724
Called: A | Proba: 0.9733
Called: A | Proba: 0.9731
Called: A | Proba: 0.8689
Called: T | 

In [48]:
print("".join(pred_labels))

GGGCCTTTCCCCAAAGGAAAAGGGTTAGGGGCCAAAAATTTTGGGCCCCTGGCCAATTAAGGCCGGGAAAGGGTTAAATTTGGGGTAACCCAAAAATTCCCAAAGGCCTTTGGGGGTTCCCGGGAAAGGCCTGGGCCCCCGGGTTTAAAAACCTTTCCGGGGTTTTCCCAAAACCCTTTTAAGGGTTCCCCCACCTTTGGGGTTTTTAAAGGTTTGGCCCCCTTGGAAATT


In [52]:
# look up target sequence by using the sequence lengths as an index:
def batch_index_lookup(bix, seq_lens):
    start = torch.sum(seq_lens[0:bix])
    stop = start+seq_lens[bix]
    return (start,stop)
_s0,_s1 = batch_index_lookup(batch_ix, seq_lengths)
s0 = int(_s0.data[0])
s1 = int(_s1.data[0])
print("".join(_lookup_[ix] for ix in list(seq[s0:s1].data + torch.ones(seq[s0:s1].size()).int())))

GCTCCAGAAGTAGCAATTGCTGCATAGCGGAGTATGGTACAATCAGCTGGTCGAGCTGGCCGTAACTCGGTTCACTTAGTCCACTGGTTAGTGCCCTGATT
