In [37]:
from __future__ import unicode_literals, print_function, division

import torch
from torch import optim
import torch.nn as nn
import numpy as np
import string
import re
from torch.utils.data import Dataset
import pickle as pk

from tqdm import tqdm
import unicodedata

In [38]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [39]:
PAD_idx = 0
SOS_idx = 1
EOS_idx = 2
UNK_idx= 3
batch_size = 64
MAX_SENTENCE_LENGTH = 40

LOG = "Trial_Dec14_2"
EPOCH = 50
label_smoothing = True
SAVE_MODEL = 'best'

In [40]:
class language(object):
    def __init__(self, name, i2t, t2i, embedding_matrix, train, test, val):
        self.name = name
        self.idx2token = i2t
        self.token2idx = t2i
        self.embedding_mat = embedding_matrix
        self.train_idx = train
        self.test_idx = test
        self.val_idx = val

In [41]:
data = pk.load(open('../data/vi1.1w-en6k.p', 'rb'))

In [42]:
def parser(sent, max_len):
    if len(sent) > max_len-2:
        return [SOS_idx]+sent[:max_len-2]+[EOS_idx]
    else:
        return [SOS_idx]+sent+[EOS_idx] + [PAD_idx]*(max_len-2-len(sent))

def paired_collate_fn(insts):
    src_insts, tgt_insts = list(zip(*insts))
    src_insts = collate_fn(src_insts)
    tgt_insts = collate_fn(tgt_insts)
    return (*src_insts, *tgt_insts)

def collate_fn(insts):
    ''' Pad the instance to the max seq length in batch '''

    max_len = MAX_SENTENCE_LENGTH
    batch_seq = np.array([parser(inst, max_len) for inst in insts])
    batch_pos = np.array([
        [pos_i+1 if w_i != PAD_idx else 0
         for pos_i, w_i in enumerate(inst)] for inst in batch_seq])

    batch_seq = torch.LongTensor(batch_seq)
    batch_pos = torch.LongTensor(batch_pos)

    return batch_seq, batch_pos

class TranslationDataset(torch.utils.data.Dataset):
    def __init__(
        self, src_word2idx, tgt_word2idx,
        src_insts=None, tgt_insts=None):

        assert src_insts
        assert not tgt_insts or (len(src_insts) == len(tgt_insts))

        src_idx2word = {idx:word for word, idx in src_word2idx.items()}
        self._src_word2idx = src_word2idx
        self._src_idx2word = src_idx2word
        self._src_insts = src_insts

        tgt_idx2word = {idx:word for word, idx in tgt_word2idx.items()}
        self._tgt_word2idx = tgt_word2idx
        self._tgt_idx2word = tgt_idx2word
        self._tgt_insts = tgt_insts

    @property
    def n_insts(self):
        ''' Property for dataset size '''
        return len(self._src_insts)

    @property
    def src_vocab_size(self):
        ''' Property for vocab size '''
        return len(self._src_word2idx)

    @property
    def tgt_vocab_size(self):
        ''' Property for vocab size '''
        return len(self._tgt_word2idx)

    @property
    def src_word2idx(self):
        ''' Property for word dictionary '''
        return self._src_word2idx

    @property
    def tgt_word2idx(self):
        ''' Property for word dictionary '''
        return self._tgt_word2idx

    @property
    def src_idx2word(self):
        ''' Property for index dictionary '''
        return self._src_idx2word

    @property
    def tgt_idx2word(self):
        ''' Property for index dictionary '''
        return self._tgt_idx2word

    def __len__(self):
        return self.n_insts

    def __getitem__(self, idx):
        if self._tgt_insts:
            return self._src_insts[idx], self._tgt_insts[idx]
        return self._src_insts[idx]

In [43]:
source_lan = {"name": data['src'].name,
              "token2id": data['src'].token2idx,
              "id2token": data['src'].idx2token,
              "train_instance": data['src'].train_idx,
              "val_instance": data['src'].val_idx,
              "test_instance": data['src'].test_idx
             }

target_lan = {
              "name": data['tgt'].name,
              "token2id": data['tgt'].token2idx,
              "id2token": data['tgt'].idx2token,
              "train_instance": data['tgt'].train_idx,
              "val_instance": data['tgt'].val_idx,
              "test_instance": data['tgt'].test_idx
             }

In [44]:
train_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['src'].token2idx,
            tgt_word2idx=data['tgt'].token2idx,
            src_insts=data['src'].train_idx,
            tgt_insts=data['tgt'].train_idx),
        num_workers=2,
        batch_size=batch_size,
        collate_fn=paired_collate_fn,
        shuffle=True)

val_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['src'].token2idx,
            tgt_word2idx=data['tgt'].token2idx,
            src_insts=data['src'].val_idx,
            tgt_insts=data['tgt'].val_idx),
        num_workers=2,
        batch_size=batch_size,
        collate_fn=paired_collate_fn)

test_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['src'].token2idx,
            tgt_word2idx=data['tgt'].token2idx,
            src_insts=data['src'].test_idx,
            tgt_insts=data['tgt'].test_idx),
        num_workers=2,
        batch_size=batch_size,
        collate_fn=paired_collate_fn,
        shuffle=False)

In [45]:
src_vocab_size = train_loader.dataset.src_vocab_size
tgt_vocab_size = train_loader.dataset.tgt_vocab_size
print ("Source: {}, vocab_size: {} \nTarget: {}, vocab_size: {}"
               .format(source_lan['name'], src_vocab_size, target_lan['name'], tgt_vocab_size))

Source: Vietnam, vocab_size: 11004 
Target: English, vocab_size: 6604


In [16]:
# The function perform dot product attention, similar to the one in the dot-product attention model

class ScaledDotProductAttention(nn.Module):
    def __init__(self, sqrt_dim, attn_dropout=0.1):
        super().__init__()
        self.sqrt_dim = sqrt_dim
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)
        
    def forward(self, query, key, value, mask=None):
        attn = torch.bmm(query, key.transpose(1, 2))
        attn = attn / self.sqrt_dim
#         Mask out the padded token
        if mask is not None:
            attn = attn.masked_fill(mask, -np.inf)
#         Convert to probability
        attn = self.softmax(attn)
#         import pdb
#         pdb.set_trace()
        attn = self.dropout(attn)
#         Apply the attention to the value
        output = torch.bmm(attn, value)
        return output, attn

In [17]:

class MultiHeadAttention(nn.Module):

    def __init__(self, num_head, hid_dim, key_dim, value_dim, dropout=0.1):
        super().__init__()

        self.num_head = num_head
        self.key_dim = key_dim
        self.value_dim = value_dim
        
#         Mapping to the desired dimension to ensure compatiability
        self.w_qs = nn.Linear(hid_dim, num_head * key_dim)
        self.w_ks = nn.Linear(hid_dim, num_head * key_dim)
        self.w_vs = nn.Linear(hid_dim, num_head * value_dim)
        
        nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (hid_dim + key_dim)))
        nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (hid_dim + key_dim)))
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (hid_dim + value_dim)))

        self.attention = ScaledDotProductAttention(sqrt_dim=np.power(key_dim, 0.5))
#         Layer normalization
        self.layer_norm = nn.LayerNorm(hid_dim)

        self.fc = nn.Linear(num_head * value_dim, hid_dim)
        nn.init.xavier_normal_(self.fc.weight)

        self.dropout = nn.Dropout(dropout)


    def forward(self, query, key, value, mask=None):

        d_k, d_v, n_head = self.key_dim, self.value_dim, self.num_head

#         Batch size, length of query, key and value
        sz_b, len_q, _ = query.size()
        sz_b, len_k, _ = key.size()
        sz_b, len_v, _ = value.size()

        residual = query

        q = self.w_qs(query).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(key).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(value).view(sz_b, len_v, n_head, d_v)
        
#         Swithch the head to the front
        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv

#         Weighted combination (attention)
        mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
        output, attn = self.attention(q, k, v, mask=mask)
        
#         Reshape 
        output = output.view(n_head, sz_b, len_q, d_v)
#         pdb.set_trace()
        output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)

        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)

        return output, attn

# The simple 2 layer feed forward layer on top of self attention layer
class FeedForward(nn.Module):
    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
#         Fully connected layer 1
        self.w_1 = nn.Conv1d(d_in, d_hid, 1) 
#         Fully connected layer 2
        self.w_2 = nn.Conv1d(d_hid, d_in, 1) 
        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        output = x.transpose(1, 2)
        output = self.w_2(F.relu(self.w_1(output)))
        output = output.transpose(1, 2)
        output = self.dropout(output)
        output = self.layer_norm(output + residual)
        return output

In [18]:
# Base layer for the full encoder (In original paper, 6 layers of this)
class EncoderLayer(nn.Module):
    
    def __init__(self, hid_dim, d_inner, num_head, key_dim, value_dim, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(
            num_head, hid_dim, key_dim, value_dim, dropout=dropout)
        self.pos_ffn = FeedForward(hid_dim, d_inner, dropout=dropout)

    def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
        enc_output *= non_pad_mask

        enc_output = self.pos_ffn(enc_output)
        enc_output *= non_pad_mask

        return enc_output, enc_slf_attn

# Base layer for the full decoder 2*self_attention+feed_forward(In original paper, 6 layers of this)
class DecoderLayer(nn.Module):
    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = FeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None):
        dec_output, dec_slf_attn = self.slf_attn(
            dec_input, dec_input, dec_input, mask=slf_attn_mask)
        dec_output *= non_pad_mask

        dec_output, dec_enc_attn = self.enc_attn(
            dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
        dec_output *= non_pad_mask

        dec_output = self.pos_ffn(dec_output)
        dec_output *= non_pad_mask

        return dec_output, dec_slf_attn, dec_enc_attn

In [19]:
# Get the mask for the one are not mask
def get_non_pad_mask(seq):
    return seq.ne(PAD_idx).type(torch.float).unsqueeze(-1)

# Get the the position encoding table
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    
    def cal_angle(position, hid_idx):
        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)

    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
    
    sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
    


    

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    if padding_idx is not None:
        # zero vector for padding dimension
        sinusoid_table[padding_idx] = 0.

    return torch.FloatTensor(sinusoid_table)


def get_attn_key_pad_mask(seq_k, seq_q):
    key_sequence = seq_k
    # Expand to fit the shape of key query attention matrix.
    len_q = seq_q.size(1)
    padding_mask = seq_k.eq(PAD_idx)
    padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1)  # b x lq x lk

    return padding_mask

def get_subsequent_mask(seq):
    ''' For masking out the subsequent info. '''

    sz_b, len_s = seq.size()
    subsequent_mask = torch.triu(
        torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1)
    subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1)  # b x ls x ls

    return subsequent_mask

class Encoder(nn.Module):
    ''' A encoder model with self attention mechanism. '''

    def __init__(
            self,
            n_src_vocab, len_max_seq, d_word_vec,
            src_embedding,
            n_layers, n_head, d_k, d_v,
            d_model, d_inner, dropout=0.1):

        super().__init__()

        n_position = len_max_seq + 1

        self.src_word_emb = nn.Embedding.from_pretrained(torch.from_numpy(src_embedding), freeze=True)

        self.position_enc = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
            freeze=True)

        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])

    def forward(self, src_seq, src_pos, return_attns=False):

        enc_slf_attn_list = []

        # -- Prepare masks
        slf_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=src_seq)
        non_pad_mask = get_non_pad_mask(src_seq)

        # -- Forward
        enc_output = self.src_word_emb(src_seq).float() + self.position_enc(src_pos).float()

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(
                enc_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask)
            if return_attns:
                enc_slf_attn_list += [enc_slf_attn]

        if return_attns:
            return enc_output, enc_slf_attn_list
        return enc_output,

class Decoder(nn.Module):
    ''' A decoder model with self attention mechanism. '''

    def __init__(
            self,
            n_tgt_vocab, len_max_seq, d_word_vec,
            tgt_embedding,
            n_layers, n_head, d_k, d_v,
            d_model, d_inner, dropout=0.1):

        super().__init__()
        n_position = len_max_seq + 1

        self.tgt_word_emb = nn.Embedding.from_pretrained(torch.from_numpy(tgt_embedding), freeze=True)

        self.position_enc = nn.Embedding.from_pretrained(
            get_sinusoid_encoding_table(n_position, d_word_vec, padding_idx=0),
            freeze=True)

        self.layer_stack = nn.ModuleList([
            DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])

    def forward(self, tgt_seq, tgt_pos, src_seq, enc_output, return_attns=False):

        dec_slf_attn_list, dec_enc_attn_list = [], []

        # -- Prepare masks
        non_pad_mask = get_non_pad_mask(tgt_seq)

        slf_attn_mask_subseq = get_subsequent_mask(tgt_seq)
        slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq)
        slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)

        dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=src_seq, seq_q=tgt_seq)

        # -- Forward
        dec_output = self.tgt_word_emb(tgt_seq).float() + self.position_enc(tgt_pos).float()

        for dec_layer in self.layer_stack:
            dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
                dec_output, enc_output,
                non_pad_mask=non_pad_mask,
                slf_attn_mask=slf_attn_mask,
                dec_enc_attn_mask=dec_enc_attn_mask)

            if return_attns:
                dec_slf_attn_list += [dec_slf_attn]
                dec_enc_attn_list += [dec_enc_attn]

        if return_attns:
            return dec_output, dec_slf_attn_list, dec_enc_attn_list
        return dec_output,

class Transformer(nn.Module):
    ''' A sequence to sequence model with attention mechanism. '''

    def __init__(
            self,
            n_src_vocab, n_tgt_vocab, len_max_seq, 
            src_embedding, tgt_embedding,
            d_word_vec=512, d_model=512, d_inner=2048,
            n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1,
            tgt_emb_prj_weight_sharing=True,
            emb_src_tgt_weight_sharing=True):

        super().__init__()

        self.encoder = Encoder(
            n_src_vocab=n_src_vocab, len_max_seq=len_max_seq, src_embedding = src_embedding,
            d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
            n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
            dropout=dropout)

        self.decoder = Decoder(
            n_tgt_vocab=n_tgt_vocab, len_max_seq=len_max_seq, tgt_embedding = tgt_embedding,
            d_word_vec=d_word_vec, d_model=d_model, d_inner=d_inner,
            n_layers=n_layers, n_head=n_head, d_k=d_k, d_v=d_v,
            dropout=dropout)

        self.tgt_word_prj = nn.Linear(d_model, n_tgt_vocab, bias=False)
        nn.init.xavier_normal_(self.tgt_word_prj.weight)

        assert d_model == d_word_vec, \
        'To facilitate the residual connections, \
         the dimensions of all module outputs shall be the same.'

        if tgt_emb_prj_weight_sharing:
            # Share the weight matrix between target word embedding & the final logit dense layer
            self.tgt_word_prj.weight = self.decoder.tgt_word_emb.weight
            self.x_logit_scale = (d_model ** -0.5)
        else:
            self.x_logit_scale = 1.

        if emb_src_tgt_weight_sharing:
            # Share the weight matrix between source & target word embeddings
            assert n_src_vocab == n_tgt_vocab, \
            "To share word embedding table, the vocabulary size of src/tgt shall be the same."
            self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight

    def forward(self, src_seq, src_pos, tgt_seq, tgt_pos):

        tgt_seq, tgt_pos = tgt_seq[:, :-1], tgt_pos[:, :-1]

        enc_output, *_ = self.encoder(src_seq, src_pos)
        dec_output, *_ = self.decoder(tgt_seq, tgt_pos, src_seq, enc_output)
        seq_logit = self.tgt_word_prj(dec_output) * self.x_logit_scale

        return seq_logit.view(-1, seq_logit.size(2))

In [20]:
class ScheduledOptim():
    '''A simple wrapper class for learning rate scheduling'''

    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = 0
        self.init_lr = np.power(d_model, -0.5)*1

    def step_and_update_lr(self):
        "Step with the inner optimizer"
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        "Zero out the gradients by the inner optimizer"
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        return np.min([
            np.power(self.n_current_steps, -0.5),
            np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])

    def _update_learning_rate(self):
        ''' Learning rate scheduling per step '''

        self.n_current_steps += 1
        lr = self.init_lr * self._get_lr_scale()

        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr

In [21]:
import math
import time

import torch.nn.functional as F
import torch.utils.data


def cal_performance(pred, gold, smoothing=False):
    ''' Apply label smoothing if needed '''

    loss = cal_loss(pred, gold, smoothing)

    pred = pred.max(1)[1]
    gold = gold.contiguous().view(-1)
    non_pad_mask = gold.ne(PAD_idx)
    n_correct = pred.eq(gold)
    n_correct = n_correct.masked_select(non_pad_mask).sum().item()

    return loss, n_correct


def cal_loss(pred, gold, smoothing):
    ''' Calculate cross entropy loss, apply label smoothing if needed. '''

    gold = gold.contiguous().view(-1)
#     pdb.set_trace()
    if smoothing:
        eps = 0.1
        n_class = pred.size(1)

        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)

        non_pad_mask = gold.ne(PAD_idx)
        loss = -(one_hot * log_prb).sum(dim=1)
        loss = loss.masked_select(non_pad_mask).sum()  # average later
    else:
        loss = F.cross_entropy(pred, gold, ignore_index=PAD_idx, reduction='sum')

    return loss


def train_epoch(model, training_data, optimizer, device, smoothing):
    ''' Epoch operation in training phase'''

    model.train()

    total_loss = 0
    n_word_total = 0
    n_word_correct = 0
    count = 0
    for batch in tqdm(
            training_data, mininterval=2,
            desc='  - (Training)   ', leave=False):

        # prepare data
        src_seq, src_pos, tgt_seq, tgt_pos = map(lambda x: x.to(device), batch)
        gold = tgt_seq[:, 1:]
        
        # forward
        optimizer.zero_grad()

        pred = model(src_seq, src_pos, tgt_seq, tgt_pos)
        
        # backward
        loss, n_correct = cal_performance(pred, gold, smoothing=smoothing)
        loss.backward()

        # update parameters
        optimizer.step_and_update_lr()

        # note keeping
        total_loss += loss.item()

        non_pad_mask = gold.ne(PAD_idx)
        n_word = non_pad_mask.sum().item()
        n_word_total += n_word
        n_word_correct += n_correct
        
    loss_per_word = total_loss/n_word_total
    accuracy = n_word_correct/n_word_total
    return loss_per_word, accuracy

def eval_epoch(model, validation_data, device):
    ''' Epoch operation in evaluation phase '''

    model.eval()

    total_loss = 0
    n_word_total = 0
    n_word_correct = 0

    with torch.no_grad():
        for batch in tqdm(
                validation_data, mininterval=2,
                desc='  - (Validation) ', leave=False):

            # prepare data
            src_seq, src_pos, tgt_seq, tgt_pos = map(lambda x: x.to(device), batch)
            gold = tgt_seq[:, 1:]

            # forward
            pred = model(src_seq, src_pos, tgt_seq, tgt_pos)
            loss, n_correct = cal_performance(pred, gold, smoothing=False)

            # note keeping
            total_loss += loss.item()

            non_pad_mask = gold.ne(PAD_idx)
            n_word = non_pad_mask.sum().item()
            n_word_total += n_word
            n_word_correct += n_correct

    loss_per_word = total_loss/n_word_total
    accuracy = n_word_correct/n_word_total
    return loss_per_word, accuracy

def train(model, training_data, validation_data, optimizer, device, EPOCH):
    ''' Start training '''

    log_train_file = None
    log_valid_file = None

    if LOG:
        log_train_file = LOG + '.train.log'
        log_valid_file = LOG + '.valid.log'

        print('[Info] Training performance will be written to file: {} and {}'.format(
            log_train_file, log_valid_file))
        

        with open(log_train_file, 'w') as log_tf, open(log_valid_file, 'w') as log_vf:
            log_tf.write('epoch,loss,ppl,accuracy\n')
            log_vf.write('epoch,loss,ppl,accuracy\n')

    valid_accus = []
    for epoch_i in range(EPOCH):
        print('[ Epoch', epoch_i, ']')

        start = time.time()
        train_loss, train_accu = train_epoch(
            model, training_data, optimizer, device, smoothing=label_smoothing)
        print('  - (Training)   ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, '\
              'elapse: {elapse:3.3f} min'.format(
                  ppl=math.exp(min(train_loss, 100)), accu=100*train_accu,
                  elapse=(time.time()-start)/60))

        start = time.time()
        valid_loss, valid_accu = eval_epoch(model, validation_data, device)
        print('  - (Validation) ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, '\
                'elapse: {elapse:3.3f} min'.format(
                    ppl=math.exp(min(valid_loss, 100)), accu=100*valid_accu,
                    elapse=(time.time()-start)/60))

        valid_accus += [valid_accu]

        model_state_dict = model.state_dict()
        checkpoint = {
            'model': model_state_dict,
#             'settings': opt,
            'settings': None,   
            'epoch': epoch_i}

        if SAVE_MODEL:
            if SAVE_MODEL == 'all':
                model_name = './chkpt_Dec14/'+SAVE_MODEL + '_accu_{accu:3.3f}.chkpt'.format(accu=100*valid_accu)
                torch.save(checkpoint, model_name)
            elif SAVE_MODEL == 'best':
                model_name = './chkpt_Dec14/'+SAVE_MODEL + '6_layer.chkpt'
                if valid_accu >= max(valid_accus):
                    torch.save(checkpoint, model_name)
                    print('    - [Info] The checkpoint file has been updated.')
        
        
        
        if log_train_file and log_valid_file:
            with open(log_train_file, 'a') as log_tf, open(log_valid_file, 'a') as log_vf:
                log_tf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format(
                    epoch=epoch_i, loss=train_loss,
                    ppl=math.exp(min(train_loss, 100)), accu=100*train_accu))
                log_vf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format(
                    epoch=epoch_i, loss=valid_loss,
                    ppl=math.exp(min(valid_loss, 100)), accu=100*valid_accu))

In [22]:
class Beam():
    ''' Beam search '''

    def __init__(self, size, device=False):

        self.size = size
        self._done = False

        # The score for each translation on the beam.
        self.scores = torch.zeros((size,), dtype=torch.float, device=device)
        self.all_scores = []

        # The backpointers at each time-step.
        self.prev_ks = []

        # The outputs at each time-step.
        self.next_ys = [torch.full((size,), PAD_idx, dtype=torch.long, device=device)]
        self.next_ys[0][0] = SOS_idx

    def get_current_state(self):
        "Get the outputs for the current timestep."
        return self.get_tentative_hypothesis()

    def get_current_origin(self):
        "Get the backpointers for the current timestep."
        return self.prev_ks[-1]

    @property
    def done(self):
        return self._done

    def advance(self, word_prob):
        "Update beam status and check if finished or not."
        num_words = word_prob.size(1)

        # Sum the previous scores.
        if len(self.prev_ks) > 0:
            beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob)
        else:
            beam_lk = word_prob[0]

        flat_beam_lk = beam_lk.view(-1)

        best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 1st sort
        best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 2nd sort

        self.all_scores.append(self.scores)
        self.scores = best_scores

        # bestScoresId is flattened as a (beam x word) array,
        # so we need to calculate which word and beam each score came from
        prev_k = best_scores_id / num_words
        self.prev_ks.append(prev_k)
        self.next_ys.append(best_scores_id - prev_k * num_words)

        # End condition is when top-of-beam is EOS.
        if self.next_ys[-1][0].item() == EOS_idx:
            self._done = True
            self.all_scores.append(self.scores)

        return self._done

    def sort_scores(self):
        "Sort the scores."
        return torch.sort(self.scores, 0, True)

    def get_the_best_score_and_idx(self):
        "Get the score of the best in the beam."
        scores, ids = self.sort_scores()
        return scores[1], ids[1]

    def get_tentative_hypothesis(self):
        "Get the decoded sequence for the current timestep."

        if len(self.next_ys) == 1:
            dec_seq = self.next_ys[0].unsqueeze(1)
        else:
            _, keys = self.sort_scores()
            hyps = [self.get_hypothesis(k) for k in keys]
            hyps = [[SOS_idx] + h for h in hyps]
            dec_seq = torch.LongTensor(hyps)

        return dec_seq

    def get_hypothesis(self, k):
        """ Walk back to construct the full hypothesis. """
        hyp = []
        for j in range(len(self.prev_ks) - 1, -1, -1):
            hyp.append(self.next_ys[j+1][k])
            k = self.prev_ks[j][k]

        return list(map(lambda x: x.item(), hyp[::-1]))

In [23]:
class Translator(object):
# The abstract for testing the function, implemented the beam search inside.
    def __init__(self, model_path, beam_size, n_best, device):
#         self.opt = opt
        self.beam_size = beam_size
        self.n_best = n_best
        self.model_path = model_path
        
        self.device = device

        checkpoint = torch.load(model_path)

        model = Transformer(
        src_vocab_size,
        tgt_vocab_size,
        MAX_SENTENCE_LENGTH,
        src_embedding = data['src'].embedding_mat,
        tgt_embedding = data['tgt'].embedding_mat,
        tgt_emb_prj_weight_sharing=False,
        emb_src_tgt_weight_sharing=False,
        d_word_vec=300,
        d_model = 300,
        n_layers=3,
        n_head=8,
        dropout=0.1).to(device)

        model.load_state_dict(checkpoint['model'])
        print('[Info] Trained model state loaded.')

        model.word_prob_prj = nn.LogSoftmax(dim=1)

        model = model.to(self.device)

        self.model = model
        self.model.eval()

    def translate_batch(self, src_seq, src_pos):
        ''' Translation work in one batch '''

        def get_inst_idx_to_tensor_position_map(inst_idx_list):
            ''' Indicate the position of an instance in a tensor. '''
            return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)}

        def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm):
            ''' Collect tensor parts associated to active instances. '''

            _, *d_hs = beamed_tensor.size()
            n_curr_active_inst = len(curr_active_inst_idx)
            new_shape = (n_curr_active_inst * n_bm, *d_hs)

            beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
            beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
            beamed_tensor = beamed_tensor.view(*new_shape)

            return beamed_tensor

        def collate_active_info(
                src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list):
            # Sentences which are still active are collected,
            # so the decoder will not run on completed sentences.
            n_prev_active_inst = len(inst_idx_to_position_map)
            active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list]
            active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device)

            active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm)
            active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm)
            active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)

            return active_src_seq, active_src_enc, active_inst_idx_to_position_map

        def beam_decode_step(
                inst_dec_beams, len_dec_seq, src_seq, enc_output, inst_idx_to_position_map, n_bm):
            ''' Decode and update beam status, and then return active beam idx '''

            def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
                dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done]
                dec_partial_seq = torch.stack(dec_partial_seq).to(self.device)
                dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
                return dec_partial_seq

            def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm):
                dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=self.device)
                dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1)
                return dec_partial_pos

            def predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm):
                dec_output, *_ = self.model.decoder(dec_seq, dec_pos, src_seq, enc_output)
                dec_output = dec_output[:, -1, :]  # Pick the last step: (bh * bm) * d_h
                word_prob = F.log_softmax(self.model.tgt_word_prj(dec_output), dim=1)
                word_prob = word_prob.view(n_active_inst, n_bm, -1)

                return word_prob

            def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map):
                active_inst_idx_list = []
                for inst_idx, inst_position in inst_idx_to_position_map.items():
                    is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position])
                    if not is_inst_complete:
                        active_inst_idx_list += [inst_idx]

                return active_inst_idx_list

            n_active_inst = len(inst_idx_to_position_map)

            dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
            dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm)
            word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm)

            # Update the beam with predicted word prob information and collect incomplete instances
            active_inst_idx_list = collect_active_inst_idx_list(
                inst_dec_beams, word_prob, inst_idx_to_position_map)

            return active_inst_idx_list

        def collect_hypothesis_and_scores(inst_dec_beams, n_best):
            all_hyp, all_scores = [], []
            for inst_idx in range(len(inst_dec_beams)):
                scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
                all_scores += [scores[:n_best]]

                hyps = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:n_best]]
                all_hyp += [hyps]
            return all_hyp, all_scores

        with torch.no_grad():
            #-- Encode
            src_seq, src_pos = src_seq.to(self.device), src_pos.to(self.device)
            src_enc, *_ = self.model.encoder(src_seq, src_pos)

            #-- Repeat data for beam search
            n_bm = self.beam_size
            n_inst, len_s, d_h = src_enc.size()
            src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s)
            src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)

            #-- Prepare beams
            inst_dec_beams = [Beam(n_bm, device=self.device) for _ in range(n_inst)]

            #-- Bookkeeping for active or not
            active_inst_idx_list = list(range(n_inst))
            inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)

            #-- Decode
            for len_dec_seq in range(1, MAX_SENTENCE_LENGTH + 1):

                active_inst_idx_list = beam_decode_step(
                    inst_dec_beams, len_dec_seq, src_seq, src_enc, inst_idx_to_position_map, n_bm)

                if not active_inst_idx_list:
                    break  # all instances have finished their path to <EOS>

                src_seq, src_enc, inst_idx_to_position_map = collate_active_info(
                    src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list)

        batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, self.n_best)

        return batch_hyp, batch_scores

In [None]:
import pdb

transformer = Transformer(
        src_vocab_size,
        tgt_vocab_size,
        MAX_SENTENCE_LENGTH,
        src_embedding = data['src'].embedding_mat,
        tgt_embedding = data['tgt'].embedding_mat,
        tgt_emb_prj_weight_sharing=False,
        emb_src_tgt_weight_sharing=False,
        d_word_vec=300,
        d_model = 300,
        n_layers=6,
        n_head=8,
        dropout=0.1).to(device)

checkpoint = torch.load("./chkpt_Dec14/best6_layer.chkpt")
transformer.load_state_dict(checkpoint['model'])
print('[Info] Trained model state loaded.')

optimizer = ScheduledOptim(
        optim.Adam(
            filter(lambda x: x.requires_grad, transformer.parameters()),
            betas=(0.9, 0.98), eps=1e-09),
        300, 4000)

train(transformer, train_loader, val_loader, optimizer, device, 60)

  - (Training)   :   0%|          | 0/2084 [00:00<?, ?it/s]

[Info] Trained model state loaded.
[Info] Training performance will be written to file: Trial_Dec14_2.train.log and Trial_Dec14_2.valid.log
[ Epoch 0 ]


  - (Validation) :   0%|          | 0/20 [00:00<?, ?it/s]             

  - (Training)   ppl:  108.04744, accuracy: 30.037 %, elapse: 18.662 min


                                                                  

  - (Validation) ppl:  59.17836, accuracy: 29.008 %, elapse: 0.063 min


  - (Training)   :   0%|          | 0/2084 [00:00<?, ?it/s]

    - [Info] The checkpoint file has been updated.
[ Epoch 1 ]


  - (Validation) :   0%|          | 0/20 [00:00<?, ?it/s]             

  - (Training)   ppl:  121.90687, accuracy: 28.491 %, elapse: 18.583 min


  - (Training)   :   0%|          | 0/2084 [00:00<?, ?it/s]       

  - (Validation) ppl:  67.86104, accuracy: 26.959 %, elapse: 0.064 min
[ Epoch 2 ]


  - (Training)   :  42%|████▏     | 884/2084 [07:53<10:42,  1.87it/s]

In [59]:
from sacrebleu import raw_corpus_bleu
from sacrebleu import corpus_bleu
import pdb

def compute_bleu(translator_path):
    pred_sent = []
    ref_sent = []

    translator = Translator(translator_path, 1, 1, device)
#     translator.eval()
    for batch in tqdm(test_loader, mininterval=2, desc='  - (Test)', leave=False):
        src_seq, src_pos, tgt_seq, tgt_pos = batch
        all_hyp, all_scores = translator.translate_batch(src_seq, src_pos)
    #     pdb.set_trace()
        ref_sent+=[[test_loader.dataset.tgt_idx2word[index] for index in sent] for sent in tgt_seq.numpy()]

        for idx_seqs in all_hyp:
            for idx_seq in idx_seqs:
                pred_line = ' '.join([test_loader.dataset.tgt_idx2word[idx] for idx in idx_seq])
                pred_sent.append(pred_line)
#         break
#     pdb.set_trace()
    _ref_sent = [[pad_parser(i) for i in ref_sent]]
    _pred_sent = [pad_parser(i.split(' ')) for i in pred_sent]
    print ("Corpus bleu:", corpus_bleu(_pred_sent, _ref_sent).score)

In [60]:
def pad_parser(word_sequence):
    '''
    Remove the pad index after the eos index
    :param word_sequence: list of the predicted words
    :return: list of predicted words before <eos>
    '''
    eos = [i for i,word in enumerate(word_sequence) if word=='.' or word=='?' or word=='<EOS>']
    word_sequence = word_sequence[:eos[-1]] if len(eos)!= 0 else word_sequence
    return ' '.join(word_sequence)

In [61]:
compute_bleu('./chkpt_Dec14/best.chkpt')

  - (Test):   0%|          | 0/25 [00:00<?, ?it/s]

[Info] Trained model state loaded.


                                                           

> <ipython-input-59-c92daa3fc3c6>(23)compute_bleu()
-> _ref_sent = [[pad_parser(i) for i in ref_sent]]
(Pdb) c
Corpus bleu: 4.3963106233489455
