In [1]:
%load_ext autoreload
%autoreload 2

import time
from modules.transformer import *
from utils import *

In [2]:
def run_epoch(data_iter, model, loss_compute):
    "Standard Training and Logging Function"
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    for i, batch in enumerate(data_iter):
        out = model(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
        loss = loss_compute(out, batch.trg_y, batch.ntokens)
        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        if i % 50 == 1:
            elapsed = time.time() - start
            print("Epoch Step: %d Loss: %f Tokens per Sec: %f" % 
                  (i, loss / batch.ntokens, tokens / elapsed))
            start = time.time()
            tokens = 0
    return total_loss / total_tokens

In [3]:
global max_src_in_batch, max_tgt_in_batch
def batch_size_fn(new, count, sofar):
    "Keep augmenting batch and calculate total number of tokens + padding."
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = 0
        max_tgt_in_batch = 0
    max_src_in_batch = max(max_src_in_batch,  len(new.src))
    max_tgt_in_batch = max(max_tgt_in_batch,  len(new.trg) + 2)
    src_elements = count * max_src_in_batch
    tgt_elements = count * max_tgt_in_batch
    return max(src_elements, tgt_elements)

In [4]:
class NoamOpt:
    "Optim wrapper that implements rate."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0
        
    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()
        
    def rate(self, step = None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * (self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5)))
        
def get_std_opt(model):
    return NoamOpt(model.src_embed[0].d_model, 2, 4000,
            torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

In [5]:
class LabelSmoothing(nn.Module):
    "Implement label smoothing."
    def __init__(self, size, padding_idx, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None
        
    def forward(self, x, target):
        assert x.size(1) == self.size
        true_dist = torch.zeros_like(x)  # batch size, vocab size
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        mask = torch.nonzero(target.data == self.padding_idx)
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        self.true_dist = true_dist
        return self.criterion(x, true_dist)

In [6]:
def data_gen(V, batch, nbatches):
    "Generate random data for a src-tgt copy task."
    for i in range(nbatches):
        data = torch.from_numpy(np.random.randint(1, V, size=(batch, 10)))
        data[:, 0] = 1
        src = data
        tgt = data
        yield Batch(src, tgt, 0)

class SimpleLossCompute:
    "A simple loss compute and train function."
    def __init__(self, generator, criterion, opt=None):
        self.generator = generator
        self.criterion = criterion
        self.opt = opt
        
    def __call__(self, x, y, norm):
        x = self.generator(x)
        loss = self.criterion(x.contiguous().view(-1, x.size(-1)), 
                              y.contiguous().view(-1)) / norm
        loss.backward()
        if self.opt is not None:
            self.opt.step()
            self.opt.optimizer.zero_grad()
        return loss.item() * norm

In [7]:
# Train the simple copy task.
V = 11
criterion = LabelSmoothing(size=V, padding_idx=0, smoothing=0.0)
model = make_model(V, V, N=2)
model_opt = NoamOpt(model.src_embed[0].d_model, 1, 400,
        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

for epoch in range(10):
    model.train()
    run_epoch(data_gen(V, 30, 20), model, SimpleLossCompute(model.generator, criterion, model_opt))
    model.eval()
    print(run_epoch(data_gen(V, 30, 5), model, SimpleLossCompute(model.generator, criterion, None)))

Epoch Step: 1 Loss: 3.060419 Tokens per Sec: 944.641787
Epoch Step: 1 Loss: 1.880822 Tokens per Sec: 1648.200498
1.9212539196014404
Epoch Step: 1 Loss: 1.982388 Tokens per Sec: 1082.014248
Epoch Step: 1 Loss: 1.888684 Tokens per Sec: 1776.635450
1.843814730644226
Epoch Step: 1 Loss: 1.906977 Tokens per Sec: 1061.449515
Epoch Step: 1 Loss: 1.679079 Tokens per Sec: 1719.850288
1.7075483560562135
Epoch Step: 1 Loss: 1.862169 Tokens per Sec: 1025.468770
Epoch Step: 1 Loss: 1.311242 Tokens per Sec: 1712.583456
1.3185369491577148
Epoch Step: 1 Loss: 1.531866 Tokens per Sec: 1017.535514
Epoch Step: 1 Loss: 1.072929 Tokens per Sec: 1717.915738
1.138299822807312
Epoch Step: 1 Loss: 1.348824 Tokens per Sec: 1067.400799
Epoch Step: 1 Loss: 0.744014 Tokens per Sec: 1749.416579
0.670381736755371
Epoch Step: 1 Loss: 0.825399 Tokens per Sec: 1092.575956
Epoch Step: 1 Loss: 0.315368 Tokens per Sec: 1698.637490
0.3549260199069977
Epoch Step: 1 Loss: 0.495665 Tokens per Sec: 1042.232375
Epoch Step: 1 Lo

In [8]:
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
    for i in range(max_len-1):
        out = model.decode(memory, src_mask, ys, subsequent_mask(ys.size(1)).type_as(src))
        print(out.shape)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim = 1)
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word.item())], dim=1)
    return ys

model.eval()
src = torch.LongTensor([[1,3,4,5,6,6,9,4,5,2]])
src_mask = torch.ones(1, 1, 10)
print(greedy_decode(model, src, src_mask, max_len=10, start_symbol=1))

torch.Size([1, 1, 512])
torch.Size([1, 2, 512])
torch.Size([1, 3, 512])
torch.Size([1, 4, 512])
torch.Size([1, 5, 512])
torch.Size([1, 6, 512])
torch.Size([1, 7, 512])
torch.Size([1, 8, 512])
torch.Size([1, 9, 512])
tensor([[1, 3, 4, 5, 6, 6, 9, 4, 5, 2]])
