In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adadelta

# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
from tqdm import tqdm_notebook as tqdm
from utils.dataloader import DataLoader

In [4]:
def data_gen(data_loader: DataLoader, size: int = 32, loops: int = -1, device=None):
    """Yield batches from ``data_loader`` as ``torch.long`` tensors.

    Every item produced by ``data_loader.gen_batch(size, loops)`` is unpacked
    into four NumPy arrays ``(x, xi, y, yi)``. Each array is wrapped with
    ``torch.from_numpy``, cast to ``torch.long`` and moved to ``device``.

    Args:
        data_loader: object exposing ``gen_batch(size, loops)``.
        size: batch size, forwarded unchanged to ``gen_batch``.
        loops: forwarded unchanged to ``gen_batch`` (its meaning is defined by
            the DataLoader).
        device: target device for the yielded tensors; with the default
            ``None`` the tensors stay where ``torch.from_numpy`` created them.

    Yields:
        A 4-tuple ``(x, xi, y, yi)`` of ``torch.long`` tensors.
    """
    def to_long_tensor(array):
        # NumPy -> torch, cast to int64, then move to the requested device.
        return torch.from_numpy(array).to(torch.long).to(device)

    for x_np, xi_np, y_np, yi_np in data_loader.gen_batch(size, loops):
        yield (to_long_tensor(x_np), to_long_tensor(xi_np),
               to_long_tensor(y_np), to_long_tensor(yi_np))

In [5]:
# Build the project's DataLoader with its defaults; its repr (recorded below)
# summarises the data file, the number of series and the input/target
# vocabulary sizes.
data_loader = DataLoader()
data_loader

data:./data/zh.tsv
            series size:154988
            input data max length:51 words:1153
            target data max length:51 words:4461
            

In [13]:
class Tnet(nn.Module):
    """Bidirectional-GRU model with a CTC head and a per-position label head.

    ``forward(x)`` takes a ``(batch, seq_len)`` tensor of token indices in
    ``[0, chars_count)`` and returns ``(out, labels)``:

    * ``out``    - ``(batch, seq_len, target_count)`` log-probabilities;
      ``out.transpose(0, 1)`` is the ``(T, N, C)`` layout ``F.ctc_loss`` expects.
    * ``labels`` - ``(batch, target_count, seq_len)`` log-probabilities, the
      ``(N, C, d1)`` layout ``F.nll_loss`` accepts directly.

    Args:
        chars_count: input vocabulary size (rows of the embedding table).
        target_count: number of output classes for both heads.
        embedding_dim: embedding width (default 512, the original hard-coded value).
        hidden_size: GRU hidden size per direction (default 1024).
        num_layers: number of stacked GRU layers (default 2).
    """

    def __init__(self, chars_count, target_count, embedding_dim=512, hidden_size=1024, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(chars_count, embedding_dim)
        self.r1 = nn.GRU(embedding_dim, hidden_size, num_layers=num_layers,
                         batch_first=True, bidirectional=True)

        # Pointwise (kernel_size=1) projection of the 2*hidden_size bidirectional
        # features onto the target classes.
        # NOTE(review): Tanh bounds these logits to (-1, 1), and forward() feeds
        # them straight into log_softmax for the CTC head, so no class can get
        # more than e^2 times the probability of another. With thousands of
        # classes this caps CTC confidence - consider removing the Tanh or
        # adding a scale. Kept unchanged here.
        self.linear_rate = nn.Sequential(
            nn.Conv1d(2 * hidden_size, target_count, kernel_size=1),
            nn.Tanh(),
        )

        # Label head; the trailing LogSoftmax(dim=1) already yields
        # log-probabilities over the class dimension.
        # NOTE(review): ReLU right before LogSoftmax clamps every negative logit
        # to 0 - confirm this is intended.
        self.decoding = nn.Sequential(
            nn.Conv1d(target_count, target_count, kernel_size=1),
            nn.ReLU(),
            nn.LogSoftmax(dim=1),
        )

        # Apply Xavier init eagerly: a bare map(...) is lazy and would never run.
        # Xavier needs fan-in/fan-out, which only exist for tensors with >= 2
        # dims, so 1-D biases keep their default initialisation.
        for param in self.parameters():
            if param.dim() >= 2:
                nn.init.xavier_normal_(param)

    def forward(self, x: torch.Tensor):
        x = self.embedding(x)                    # (N, T, embedding_dim)
        x, _ = self.r1(x)                        # (N, T, 2 * hidden_size)
        x = self.linear_rate(x.transpose(1, 2))  # (N, target_count, T)
        # decoding ends in LogSoftmax(dim=1); a second log_softmax would be a no-op.
        labels = self.decoding(x)                # (N, target_count, T)
        out = F.log_softmax(x, dim=1).transpose(1, 2)  # (N, T, target_count)
        return out, labels
    
# Smoke test: run one default-sized batch (data_gen's default size=32, no
# device, so CPU) through a freshly built, not-yet-moved model and display the
# output shapes: out is (batch, time, classes), labels is (batch, classes, time).
g = data_gen(data_loader)
x,xi,y,yi = next(g)
# NOTE(review): pinyin_numbers is used as the input vocabulary and char_numbers
# as the target vocabulary (4461 classes per the output below) - confirm these
# match the DataLoader's x/y encodings.
model = Tnet(data_loader.pinyin_numbers,data_loader.char_numbers)
out,labels = model(x)
out.shape,labels.shape

(torch.Size([32, 50, 4461]), torch.Size([32, 4461, 50]))

In [14]:
# Switch to training mode, move the parameters to the chosen device as
# float32, then optimise all of them with Adadelta at lr=1.
model.train()
model = model.to(device=device, dtype=torch.float32)
optimizer = Adadelta(params=model.parameters(), lr=1)

In [15]:
# Joint training: CTC loss on `out` plus a per-position NLL loss on `labels`.
# NOTE(review): F.ctc_loss uses blank=0 by default, so class 0 must never appear
# as a real character in `y` - confirm against the DataLoader's vocabulary.
# NOTE(review): every input length is set to the padded length x.shape[1]; `xi`
# may hold the true per-sample input lengths - verify and consider using it.
# NOTE(review): padded positions of `y` count towards the label loss and the
# accuracy; pass ignore_index=<pad id> to nll_loss if the DataLoader pads with a
# reserved id.
TRAIN_STEPS = 5000
BATCH_SIZE = 128
LOG_EVERY = 20

bar = tqdm(data_gen(data_loader, size=BATCH_SIZE, loops=TRAIN_STEPS, device=device), total=TRAIN_STEPS)
for loop, (x, xi, y, yi) in enumerate(bar):
    optimizer.zero_grad()
    out, labels = model(x)

    # ctc_loss expects (T, N, C) log-probs and integer lengths; pin int64
    # explicitly instead of relying on torch.full's dtype inference.
    input_lengths = torch.full((x.shape[0],), x.shape[1], dtype=torch.long, device=device)
    loss_ctc = F.ctc_loss(out.transpose(0, 1), y, input_lengths, yi)

    # labels is (N, C, T) and y is (N, T); nll_loss accepts this layout as-is.
    # Do NOT flatten with labels.view(-1, C): on an (N, C, T) tensor that view
    # groups consecutive time steps into rows, pairing wrong scores with targets.
    loss_label = F.nll_loss(labels, y)

    # A single backward over the summed losses accumulates the same gradients
    # as two backward calls, without retaining the graph in between.
    (loss_ctc + loss_label).backward()
    optimizer.step()

    if loop % LOG_EVERY == 0:
        # Accuracy is only displayed here, so only compute (and sync) it here.
        with torch.no_grad():
            ypred = labels.argmax(dim=1)  # (N, T) predicted class per position
            acc = ypred.eq(y.view_as(ypred)).float().mean().item()
        bar.set_postfix(ctc=f"{loss_ctc.item():0.6f}", label=f"{loss_label.item():0.6f}", acc=f"{acc:0.4f}")

HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))

KeyboardInterrupt: 

In [41]:
# Debug check of batch and CTC-input shapes.
# NOTE(review): the recorded output below shows 1153 classes, which does not
# match the 4461-class Tnet output shown earlier in this notebook - it is stale
# and should be re-run (or this cell removed) before relying on it.
tuple(t.shape for t in (x, xi, y, yi)) + (out.transpose(0, 1).shape,)

(torch.Size([32, 50]),
 torch.Size([32]),
 torch.Size([32, 50]),
 torch.Size([32]),
 torch.Size([50, 32, 1153]))

In [38]:
# Sanity check of nn.CTCLoss on random data (mirrors the PyTorch docs example):
# 50 input steps, batch of 16, 20 classes with class 0 reserved as the blank
# (targets are drawn from [1, 20)), targets padded to length 30.
ctc_loss = nn.CTCLoss()
log_probs = torch.randn(50, 16, 20).log_softmax(dim=2).detach().requires_grad_()
targets = torch.randint(low=1, high=20, size=(16, 30), dtype=torch.long)
input_lengths = torch.full(size=(16,), fill_value=50, dtype=torch.long)
target_lengths = torch.randint(low=10, high=30, size=(16,), dtype=torch.long)
loss = ctc_loss(log_probs, targets, input_lengths=input_lengths, target_lengths=target_lengths)
loss

tensor(7.5836, grad_fn=<MeanBackward1>)