### RawCTCNet Benchmark/Eval with trained model: CTCLoss of approx. 1.2 (best: 1.03)

In [2]:
# switch to toplevel dir:
%cd ~/Desktop/pytorch_models/wavenet-speech/
!pwd
%load_ext autoreload
%autoreload 2

/home/ptang/Desktop/pytorch_models/wavenet-speech
/home/ptang/Desktop/pytorch_models/wavenet-speech


In [3]:
# imports:
import torch
from torch.autograd import Variable
import numpy as np
from warpctc_pytorch import CTCLoss

In [4]:
# import gaussian model, RawCTCNet, sequential decoder:
from utils.gaussian_kmer_model import RawGaussianModelLoader
from utils.pore_model import PoreModelLoader
from modules.raw_ctcnet import RawCTCNet
from modules.sequence_decoders import argmax_decode, labels2strings, BeamSearchDecoder

### Construct data generator from gaussian model using the same parameters as we did during training:

In [5]:
# create artificial data model:
max_iterations = 1000000 # 1 million examples
num_epochs = 100
epoch_size = 10000
kmer_model_path = "utils/r9.4_450bps.5mer.template.npz"
batch_size = 8
upsample_rate = 3
min_sample_len = 80
max_sample_len = 90
dataset = RawGaussianModelLoader(max_iterations, num_epochs, epoch_size, kmer_model_path, batch_size=batch_size,
                                 upsampling=upsample_rate, lengths=(min_sample_len,max_sample_len))

In [6]:
# inspect dataset:
dataset.fetch()

(Variable containing:
   73.8920   73.8920   73.8920  ...   103.5742  103.5742  103.5742
  100.0079  100.0079  100.0079  ...     0.0000    0.0000    0.0000
  100.7313  100.7313  100.7313  ...     0.0000    0.0000    0.0000
              ...                ⋱                ...             
  107.1002  107.1002  107.1002  ...     0.0000    0.0000    0.0000
  100.4840  100.4840  100.4840  ...     0.0000    0.0000    0.0000
   89.7458   89.7458   89.7458  ...     0.0000    0.0000    0.0000
 [torch.FloatTensor of size 8x243], Variable containing:
  4
  3
  2
  1
  3
  3
  1
  2
  2
  1
  1
  2
  1
  2
  3
  4
  2
  3
  3
  4
  1
  3
  3
  1
  2
  3
  3
  1
  1
  3
  2
  4
  1
  1
  3
  2
  1
  2
  4
  1
  4
  4
  3
  1
  2
  4
  3
  4
  1
  2
  1
  4
  4
  2
  1
  3
  2
  4
  3
  3
  2
  2
  4
  1
  4
  4
  1
  4
  4
  4
  2
  2
  4
  2
  4
  2
  3
  2
  3
  3
  2
  2
  4
  3
  2
  2
  1
  2
  2
  2
  3
  4
  3
  2
  4
  1
  1
  4
  3
  3
  1
  3
  2
  1
  1
  1
  4
  1
  4
  4
  3
  1
  4


### Construct model with same parameters as during training and load saved models:

In [7]:
# build model:
nfeats = 2048
nhid = 512
feature_kwidth = 3
num_labels = 5
num_dilation_blocks = 10
dilations = [1, 2, 4, 8, 16] * num_dilation_blocks
layers = [(nhid, nhid, 2, d) for d in dilations] + [(nhid, nhid, 3, d) for d in dilations]
out_dim = 512
is_causal = False
ctcnet = RawCTCNet(nfeats, feature_kwidth, num_labels, layers, out_dim, input_kernel_size=2, input_dilation=1,
                   softmax=False, causal=is_causal)
batch_norm = torch.nn.BatchNorm1d(1)

In [8]:
# load saved model parameters:
ctcnet_save_path = "./runs/gaussian-model/raw_ctc_net.model.adam_lr1e_5.pth"
batchnorm_save_path = "./runs/gaussian-model/raw_ctc_net.batch_norm.adam_lr1e_5.pth"
map_cpu = lambda storage, loc: storage
ctcnet.load_state_dict(torch.load(ctcnet_save_path, map_location=map_cpu))
batch_norm.load_state_dict(torch.load(batchnorm_save_path, map_location=map_cpu))

In [9]:
# CTCLoss:
ctc_loss_fn = CTCLoss()

### Helper function to fetch & evaluate model on data:

In [10]:
def eval_model():
    # use volatile variables for better execution speed/memory usage:
    signals, sequences, lengths = dataset.fetch()
    signals_var = Variable(signals.data, volatile=True)
    sequences_var = Variable(sequences.data, volatile=True)
    lengths_var = Variable(lengths.data, volatile=True)
    # run networks:
    probas = ctcnet(batch_norm(signals_var.unsqueeze(1)))
    transcriptions = probas.permute(2,0,1) # need seq x batch x dim
    transcription_lengths = Variable(torch.IntTensor([transcriptions.size(0)] * batch_size))
    ctc_loss = ctc_loss_fn(transcriptions, sequences_var, transcription_lengths, lengths_var)
    avg_ctc_loss = (ctc_loss / transcriptions.size(0))
    return (transcriptions, ctc_loss, avg_ctc_loss, sequences.data, lengths.data)

In [11]:
def split_target_seqs(seqs, lengths):
    """Split a flattened array of target sequences into their constituents."""
    split_seqs = []
    labels_parsed = 0
    for ll in lengths:
        split_seqs.append( seqs[labels_parsed:labels_parsed+ll] )
        labels_parsed += ll
    return split_seqs

### evaluate results against true sequences with argmax and beam search (run these commands in sequence a few times):

In [17]:
logits, loss, avg_loss, true_seqs, true_seq_lengths = eval_model()
print("CTC Loss on whole sequence: {}".format(loss.data[0]))
print("CTC Loss, averaged per-logit: {}".format(avg_loss.data[0]))

CTC Loss on whole sequence: 279.40179443359375
CTC Loss, averaged per-logit: 1.1404154300689697


In [18]:
# normalize probabilities with a softmax operation:
for k in range(len(logits)):
    logits[k,:,:] = torch.nn.functional.softmax(logits[k,:,:])

In [19]:
# print true sequences:
true_base_sequences = split_target_seqs(true_seqs, true_seq_lengths)
for k in range(len(true_base_sequences)):
    print(labels2strings(true_base_sequences[k].unsqueeze(0))[0])

CTAAAACGAGCAGCCTAGCCTCTACATACGAAATTGAAACAGCGATCCAGCTCACAGCGACCACCGCTAGGACAAGGCAT
AGTTCACCCTTCAATGCTTAATGGGATTCCAGTTAGACGGGCTATTACGTCCGCGGATTCTACGTCCATATAACCGATTTCGGATAATT
CACTGTATAGCACTGTTGACTGCCACGACCATGGATAGTTGAGAGTAGCCCCCAAGGGCTGGAACAAGGAAGGAGCTTCCGAA
AGACGGTATCATGTCATCTGTCAAAGCAAACGACGGTGGTGACCTAAAAGATGACGAGTTAAGCTTGTGGAGCACATACTG
TCTTCACGCAAAGAGGCCGATTTGTGCCTATTGTACACGGGCCTATACCTCACACTTTAATGTCTGCCCAGCTCTACTTATTAGACCC
ACAATCGCAAACTGGACTAAAGGAATACCTCAAATACCCATAGTGGTAAGTACCGTCCTCGTTGGTGCAGCTGAGATCAATACAGC
TATCCCCTCCTGAATGCTAAACAGTAGGATTCCGGTCGGACCTTGACTGCGTTTCTAGCGGTATCCTTTTTGAGCATCCCGGTGTCGTC
CTGGAGACGCGTCTACGAGGATTCCGGCACCCTCCCGACGCCCGTTCATAATAGTGATCTACGGATGACTGCTGTCAGAACCAG


In [20]:
# argmax decoding: expects (batch, seq, dim) and returns (batch, seq)
argmax_decoded = argmax_decode(logits.permute(1,0,2).contiguous().data)
argmax_basecalls = labels2strings(argmax_decoded)
for k in range(len(argmax_decoded)):
    print(argmax_basecalls[k])

CGAGAAACGAGCAGCATTGCCTATCCATACTACATTGAAACCGAGATCCAGCTCACAGCGCCCACCGCTAGGACAAGCTT
TGAAAACCTTCCATTGCTTAATTGGGATTCCAGTTAGCCGGGCTATTACGTACGAGGATTCTAGTTCCATATAACCGATTTCGTAGGCG
GAGTCGAGCACTTGTTGACTGCCACGCCCATTGGCGAGTTGAGAGTAGCCACCAAGGGCTTGGAAAAAGGAATGAGATTCCCCCT
CTGGGTTTATAATTGGCATATGTCAACGCAAACGACGGTGGTGACATAAAAGATTGACGAGTTAAGCTTGTGGAGCACATCTT
TTACCCGCAAAGAGGCAGATTTGTGACTATTGTCCACCGGGCATATACCTCCCACTTTAATTTGTATGCCCAGCTATACTTTATTTTGGACG
GTTTCGCAAAATTGGCATAAAGGAATACCTCAAATACCCAGGGTGGTAATTACAGTCCTAGTTGGTGCAGATTGAGAGCCATTGTTCT
TTACCCCTCATGACTTGCTAAACCGTAGGATTCAGGTGGGCCCTTGACTTGCGGTTCTCGCGGTATCCTTTTTGAGCATACCGGGTGTCTTA
GCTGACGCGTCTACGAGGATTAAGGCACCCTCACGCAGCCCGTTCAGACTAGTGATATACGGATGAATTGCTTGTAATACTTTT


In [21]:
# beam search decoded: expects (batch, dim, seq)
beam_search_decoder = BeamSearchDecoder(batch_size=batch_size, num_labels=5, beam_width=6)
probas, hyp_seqs = beam_search_decoder.decode(logits.permute(1, 2, 0))

In [22]:
print("Normalized probabilities:")
for k in range(len(probas)):
    print(probas[k] / logits.size(0))

Normalized probabilities:
0.2951114888093909
0.29567472496811226
0.29745726293447067
0.29748092962771044
0.2956850947165976
0.2951805581851881
0.29634259282326214
0.2959955643634407


In [23]:
lookup_dict = {0: '', 1: 'A', 2: 'G', 3: 'C', 4: 'T', 5: '<SOS>', 6: '<EOS>'}
for ll in range(len(hyp_seqs)):
    print("".join([lookup_dict[lbl] for lbl in hyp_seqs[ll]]))

<SOS>CGAGAAACGAGCAGCATTGCCTATCCATACTACATTGAAACCGAGATCCAGCTCACAGCGCCCACCGCTAGGACAAGCTT<EOS>
<SOS>TGAAAACCTTCCATTGCTTAATTGGGATTCCAGTTAGCCGGGCTATTACGTACGAGGATTCTAGTTCCATATAACCGATTTCGTAGGCG<EOS>
<SOS>GAGTCGAGCACTTGTTGACTGCCACGCCCATTGGCGAGTTGAGAGTAGCCACCAAGGGCTTGGAAAAAGGAATGAGATTCCCCCT<EOS>
<SOS>CTGGGTTTATAATTGGCATATGTCAACGCAAACGACGGTGGTGACATAAAAGATTGACGAGTTAAGCTTGTGGAGCACATCTT<EOS>
<SOS>TTACCCGCAAAGAGGCAGATTTGTGACTATTGTCCACCGGGCATATACCTCCCACTTTAATTTGTATGCCCAGCTATACTTTATTTTGGACG<EOS>
<SOS>GTTTCGCAAAATTGGCATAAAGGAATACCTCAAATACCCAGGGTGGTAATTACAGTCCTAGTTGGTGCAGATTGAGAGCCATTGTTCT<EOS>
<SOS>TTACCCCTCATGACTTGCTAAACCGTAGGATTCAGGTGGGCCCTTGACTTGCGGTTCTCGCGGTATCCTTTTTGAGCATACCGGGTGTCTTA<EOS>
<SOS>GCTGACGCGTCTACGAGGATTAAGGCACCCTCACGCAGCCCGTTCAGACTAGTGATATACGGATGAATTGCTTGTAATACTTTT<EOS>
