### Let's try to overfit the whole network (WaveNet core network + CTC layer) on a single (signal, sequence) pair.

#### First, import the necessities:

In [1]:
# set cwd:
%cd ~/Desktop/pytorch_models/wavenet-speech/
!pwd

/home/ptang/Desktop/pytorch_models/wavenet-speech
/home/ptang/Desktop/pytorch_models/wavenet-speech


In [2]:
# imports:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.autograd import Variable

from modules.wavenet import WaveNet
from modules.classifier import WaveNetClassifier
from warpctc_pytorch import CTCLoss

import os
import h5py
import numpy as np
from tqdm import tqdm

#### Now load a pair of data seq's as variables:

In [3]:
hf = h5py.File("data/bucketed_data.hdf5")
read_np = hf['bucket_0']['reads']['12'][:]
signal_np = hf['bucket_0']['signals']['12'][:]
hf.close()

In [4]:
# load read as target:
target_seq = Variable(torch.from_numpy(read_np).int())
target_length = Variable(torch.IntTensor([target_seq.size(0)]))
type(target_seq.data), target_seq.size(), target_length.data

(torch.IntTensor, torch.Size([429]), 
  429
 [torch.IntTensor of size 1])

In [5]:
# one-hot encode the signal:
signal_pt = torch.from_numpy(signal_np).long().view(1,-1)
signal_pt.size()

torch.Size([1, 4505])

In [6]:
signal = torch.zeros(1, 256, signal_pt.size(1)).scatter_(1, signal_pt.unsqueeze(0), 1.)

In [7]:
# quick look to make sure the encoding makes sense (you don't always need to run this):
#signal[0,:,123]

In [8]:
data_seq = Variable(signal)
type(data_seq.data), data_seq.size()

(torch.FloatTensor, torch.Size([1, 256, 4505]))

In [9]:
# define the cross-entropy target sequence as the dense signal:
xe_target_seq = Variable(signal_pt)
type(xe_target_seq.data), xe_target_seq.size()

(torch.LongTensor, torch.Size([1, 4505]))

#### Now let's construct the model, the optimizers, and the loss functions.

In [10]:
# model construction:
wavenet_dils = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512,
                1, 2, 4, 8, 16, 32, 64, 128, 256, 512,
                1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
classifier_dils = [1, 2, 4, 8, 16, 32,
                   1, 2, 4, 8, 16, 32,
                   1, 2, 4, 8, 16, 32,]
downsample_rate = 4
num_labels = 5 # == |{A,G,C,T,-}|
out_dim = 256
num_levels = 256
wavenet_layers = [(num_levels, num_levels, 2, d) for d in wavenet_dils]
classifier_layers = [(num_levels, num_levels, 2, d) for d in classifier_dils ]
wavenet = WaveNet(num_levels, 2, wavenet_layers, num_levels, softmax=False)
classifier = WaveNetClassifier(num_levels, num_labels, classifier_layers, out_dim, pool_kernel_size=downsample_rate,
                               input_kernel_size=2, input_dilation=1, softmax=False)

In [11]:
# loss functions:
ctc_loss_fn = CTCLoss()
xe_loss_fn = nn.CrossEntropyLoss()

In [12]:
# optimizers:
#wavenet_optimizer = optim.Adam(wavenet.parameters())
#ctc_optimizer = optim.Adam(classifier.parameters())
#joint_optimizer = optim.Adam(list(wavenet.parameters()) + list(classifier.parameters()),
#                             lr=0.0001, weight_decay=0.001)
#scheduler = ReduceLROnPlateau(joint_optimizer, 'min')
joint_optimizer = optim.Adagrad(list(wavenet.parameters())+list(classifier.parameters()),
                                lr=0.00003)

#### Now let's define convenient shorthand closures to train the network:

In [13]:
def predict():
    _xe_prediction = wavenet(data_seq)
    _probs = classifier(_xe_prediction)
    return (_xe_prediction, _probs)

def train(log=False):
    # clear gradients:
    #wavenet_optimizer.zero_grad()
    #ctc_optimizer.zero_grad()
    joint_optimizer.zero_grad()

    # make a prediction:
    xe_prediction, probs = predict()

    # compute cross-entropy loss against *shifted* dense signal:
    xe_loss = 0.
    for t in range(xe_prediction.size(2)-1):
        xe_loss = xe_loss + xe_loss_fn(xe_prediction[:,:,t], xe_target_seq[:,t+1])

    # compute CTC loss against labels:
    probs_rearranged = probs.permute(2,0,1).contiguous()
    prob_lengths = Variable(torch.IntTensor([probs.size(2)]))
    target_seq_shifted = target_seq + Variable(torch.ones(target_seq.size()).int()) # because label 0 == <BLANK>
    ctc_loss = ctc_loss_fn(probs_rearranged, target_seq_shifted, prob_lengths, target_length)
    
    # backprop (choose one):
    total_loss = xe_loss + ctc_loss
    avg_xe_loss = xe_loss / data_seq.size(2)
    avg_ctc_loss = ctc_loss / probs.size(2)
    avg_loss = avg_xe_loss + avg_ctc_loss
    #avg_loss.backward()
    avg_ctc_loss.backward()
    
    # apply gradient descent updates:
    #ctc_optimizer.step()
    #wavenet_optimizer.step()
    joint_optimizer.step()
    #scheduler(ctc_loss / probs.size(2))
    
    # return all values of interest for logging:
    if log: return (total_loss.data[0], xe_loss.data[0], ctc_loss.data[0])

#### Main training loop (run this cell and the next block of prediction/view cells multiple times until convergence)

In [14]:
# train loop:
num_iterations = 1000
log_every = 10
for k in tqdm(range(num_iterations)):
    if k % log_every != 0:
        train(log=False)
    else:
        tot_, xe_, ctc_ = train(log=True)
        xe_pc = xe_ / data_seq.size(2)
        ctc_pc = ctc_ / target_seq.size(0)
        print(("XE+CTC Loss Tot: {0:07.4f}={1:07.4f}+{2:07.4f} | " + \
               "Per-Sample XE: {3:07.4f} | Per-Label CTC: {4:07.4f}").format(tot_, xe_, ctc_, xe_pc, ctc_pc))
        if (ctc_pc < 0.3):
            print("Early stopping!")
            break

  0%|          | 1/1000 [00:34<9:33:45, 34.46s/it]

XE+CTC Loss Tot: 39974.9805=38791.2070+1183.7747 | Per-Sample XE: 08.6107 | Per-Label CTC: 02.7594


  1%|          | 11/1000 [05:51<8:41:24, 31.63s/it]

XE+CTC Loss Tot: 39400.8438=38830.5078+570.3347 | Per-Sample XE: 08.6194 | Per-Label CTC: 01.3295


  2%|▏         | 21/1000 [11:05<8:32:36, 31.42s/it]

XE+CTC Loss Tot: 39255.3242=38775.4570+479.8654 | Per-Sample XE: 08.6072 | Per-Label CTC: 01.1186


  3%|▎         | 31/1000 [16:14<8:14:06, 30.60s/it]

XE+CTC Loss Tot: 39156.5352=38764.1680+392.3681 | Per-Sample XE: 08.6047 | Per-Label CTC: 00.9146


  4%|▍         | 41/1000 [21:17<8:08:17, 30.55s/it]

XE+CTC Loss Tot: 39125.9453=38788.3633+337.5819 | Per-Sample XE: 08.6101 | Per-Label CTC: 00.7869


  5%|▌         | 51/1000 [26:23<8:08:32, 30.89s/it]

XE+CTC Loss Tot: 39043.8086=38776.1250+267.6847 | Per-Sample XE: 08.6074 | Per-Label CTC: 00.6240


  6%|▌         | 61/1000 [31:42<8:22:56, 32.14s/it]

XE+CTC Loss Tot: 38997.3203=38805.2852+192.0343 | Per-Sample XE: 08.6138 | Per-Label CTC: 00.4476


  7%|▋         | 71/1000 [36:56<8:05:36, 31.36s/it]

XE+CTC Loss Tot: 38983.9688=38831.0078+152.9609 | Per-Sample XE: 08.6195 | Per-Label CTC: 00.3566


  8%|▊         | 80/1000 [41:41<8:02:42, 31.48s/it]

XE+CTC Loss Tot: 38923.9492=38823.3047+100.6434 | Per-Sample XE: 08.6178 | Per-Label CTC: 00.2346
Early stopping!


In [16]:
# save the model to disk:
torch.save(wavenet.state_dict(), "./ipynbs/wavenet_model.overfit.pth")
torch.save(classifier.state_dict(), "./ipynbs/classifier_model.overfit.pth")

#### These next few cells inspect predictions; run these after each run of the training block above:

In [17]:
# generate fresh predictions:
_, ctc_preds = predict()

In [23]:
# print outputs:
_lookup_ = {0: '<BLANK>', 1: 'A', 2: 'G', 3: 'C', 4: 'T'}
print_blanks = False
pred_labels = []
for k in range(ctc_preds.size(2)):
    logit, label = torch.max(torch.nn.functional.softmax(ctc_preds[0,:,k]), dim=0)
    logit_py = float(logit.data[0])
    label_py = _lookup_[int(label.data[0])]
    if (not print_blanks) and (label_py == '<BLANK>'): continue
    print("Called: {0} | Proba: {1:1.4f}".format(label_py, logit_py))
    pred_labels.append(label_py)

Called: A | Proba: 0.8391
Called: C | Proba: 0.8950
Called: A | Proba: 0.9490
Called: T | Proba: 0.6770
Called: T | Proba: 0.7874
Called: A | Proba: 0.9319
Called: C | Proba: 0.9179
Called: T | Proba: 0.8514
Called: T | Proba: 0.7239
Called: T | Proba: 0.6321
Called: C | Proba: 0.9145
Called: G | Proba: 0.8421
Called: T | Proba: 0.8787
Called: T | Proba: 0.9059
Called: G | Proba: 0.9116
Called: A | Proba: 0.6849
Called: T | Proba: 0.7350
Called: T | Proba: 0.7878
Called: A | Proba: 0.9633
Called: C | Proba: 0.9220
Called: G | Proba: 0.7141
Called: T | Proba: 0.7282
Called: T | Proba: 0.7633
Called: A | Proba: 0.8934
Called: T | Proba: 0.9409
Called: T | Proba: 0.7898
Called: G | Proba: 0.8589
Called: C | Proba: 0.8013
Called: T | Proba: 0.8448
Called: G | Proba: 0.8460
Called: A | Proba: 0.9286
Called: A | Proba: 0.7578
Called: A | Proba: 0.7427
Called: T | Proba: 0.8852
Called: C | Proba: 0.8919
Called: C | Proba: 0.8216
Called: T | Proba: 0.9485
Called: C | Proba: 0.8678
Called: G | 

In [24]:
print("".join(pred_labels))

ACATTACTTTCGTTGATTACGTTATTGCTGAAATCCTCGAAAGCGATATTCCTCTTTTGCAGATTTTTAACAAAAGTGGTTTTCAAAACTGCTCTATTCAAAAGAAAGGTTCCAGCTCTCTATTTTAGTTGAGGGCACATCACAAATAAATAAACCTTTAACAGAATGGCCTTCTGTCTAGTTTTCCACGGGAAGATCGTATTTCCTTTTCACCCAATACGCCTGAAAGCGCTCAAATGTCCATATTTCAGATACCTGCAAAAGAGTTGTTTCCCAAGCCTGCTCTATGAGGAATGCTCCAGCTCTGTAGATTGAATAGACGTTACAAAAGTTTCCGGATGCTTGCTTGTATCTCCTTTTTATAATTAATTAGTCCGGTTTCCAAGCAGAATCCTTCAAAGCTTGGCAAATATCCACTTTGCAGATTCCACGAAAAACGGTGTTTCAGAACTGCTCCTTCAAAGGCAGTAGTTTTTCAAAATTTTC


In [20]:
print("".join(_lookup_[ix] for ix in list(target_seq.data + torch.ones(target_seq.size()).int())))

ACATACTTCGTTGATTACGTATTGCTGAAATCCTCGAAGCGATATCCTCTTGCAGATTTTACAAAGTGTTTCAAAACTGCTCTATCAAAGAAAGGTTCCAGCTCTCTATTAGTTGAGGCACATCACAAATAATAAACTACAGAATGCTTCTGTCTAGTTTTCACGGGAAGATCGTATTTCCTTTTCACCCATACGCCTGAAAGCGCTCAAATGTCATATTCAGATACTGCAAAGAGTGTTTCCAGCCTGCTCTATGAGGAATGCTCAGCTCTGTAGATTGAATAGACGTACAAAGTTTCTGAGATGCTGCTGTATCTCCTTTTATATATAGTCCGTTTCCAGCAGAATCCTCAAAGCTGGCAAATATCCACTTGCAGATTCACGAAAACGGTGTTTCAGAACTGCTCCTTCAAGGCAGTAGTTTCAATC


#### CTC debugging:

In [16]:
# taken from the WarpCTC tests:
_activs = Variable(torch.FloatTensor([[[-10., -9., -8., -7., -6.]]]).transpose(0, 1).contiguous(), requires_grad=True)
_activ_sizes = Variable(torch.IntTensor([1]))
_labels = Variable(torch.IntTensor([3]))
_label_sizes = Variable(torch.IntTensor([1]))
print(ctc_loss_fn(_activs, _labels, _activ_sizes, _label_sizes))

Variable containing:
 1.4519
[torch.FloatTensor of size 1]

