dataset : https://github.com/multi30k/dataset
<br>
https://github.com/kh-kim/simple-nmt/blob/master/simple_nmt/models/seq2seq.py<br><br>
https://pytorch.org/tutorials/beginner/torchtext_translation_tutorial.html

In [1]:
import torch
import torch.nn as nn
from torchtext import data, datasets
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack

# import simple_nmt.data_loader as data_loader
# from simple_nmt.search import SigleBeamSearchBoard

import os
# Print the full path of every file found under the Kaggle input directory.
for root, _subdirs, names in os.walk('/kaggle/input'):
    for name in names:
        print(os.path.join(root, name))

/kaggle/input/multi30k-en-fr/val.en
/kaggle/input/multi30k-en-fr/train.fr
/kaggle/input/multi30k-en-fr/val.fr
/kaggle/input/multi30k-en-fr/train.en


In [2]:
# NOTE: the corpus is read from a Kaggle input dataset because downloading it
# directly from the repository link was not figured out.
multi30k_path = "/kaggle/input/multi30k-en-fr/"

# Each list holds the French file name first and the English file name second.
train_list = ["train." + lang for lang in ("fr", "en")]
valid_list = ["val." + lang for lang in ("fr", "en")]
# test_list = ["test_2016_flickr.fr.gz","test_2016_flickr.en.gz"]

In [3]:
# Load the training split.
# The corpus files can simply be read like ordinary text files; binary mode
# keeps every line as a raw bytes object.
with open(multi30k_path + train_list[0], 'rb') as fr_path, \
        open(multi30k_path + train_list[1], 'rb') as en_path:
    fr_train, en_train = fr_path.readlines(), en_path.readlines()

# Load the validation split the same way.
with open(multi30k_path + valid_list[0], 'rb') as fr_path, \
        open(multi30k_path + valid_list[1], 'rb') as en_path:
    fr_val, en_val = fr_path.readlines(), en_path.readlines()

# The test split is left until training is finished.

# Report the number of lines per file as a quick sanity check.
print(f"training_french len is: {len(fr_train)}")
print(f"training_english len is: {len(en_train)}")
print()
print(f"validate_french len is: {len(fr_val)}")
print(f"validate_english len is: {len(en_val)}")

training_french len is: 29000
training_english len is: 29000

validate_french len is: 1014
validate_english len is: 1014


### Search

In [4]:
from operator import itemgetter
# Default exponent and length offset used by the length penalty below.
LENGTH_PENALTY = .2
MIN_LENGTH = 5


class SingleBeamSearchBoard():
    """Beam-search bookkeeping for one source sentence.

    The board records, per time-step, the selected word indices, the beam
    each selection came from, the cumulative log-probabilities and a
    "finished" mask. Only the most recent decoder status tensors are kept.

    ``BOS`` and ``EOS`` are the module-level special-token indices defined
    in this file; they are looked up when the methods run.
    """

    def __init__(self,
                 device,
                 prev_status_config,
                 beam_size=5,
                 max_length=255):
        """Initialise the board at time-step 0.

        Args:
            device: device on which all bookkeeping tensors are created.
            prev_status_config: dict mapping a status name to a dict with
                'init_status' (a tensor for a single sample, or None) and
                'batch_dim_index' (the dimension that indexes the batch).
            beam_size: number of hypotheses kept per time-step.
            max_length: stored on the instance; not consulted by the board.
        """
        self.beam_size = beam_size
        self.max_length = max_length

        # All bookkeeping tensors live on this device.
        self.device = device
        # Inferred word index for each time-step; every beam starts from BOS.
        self.word_indice = [torch.full((beam_size,), BOS,
                                       dtype=torch.long, device=self.device)]
        # Beam index each selected word came from; -1 marks "no parent".
        self.beam_indice = [torch.full((beam_size,), -1,
                                       dtype=torch.long, device=self.device)]
        # Cumulative log-probability per beam. Only the first beam is alive
        # at the start so that the first expansion does not pick duplicates.
        self.cumulative_probs = [torch.tensor(
            [.0] + [-float('inf')] * (beam_size - 1),
            dtype=torch.float, device=self.device)]
        # True where a beam has emitted EOS at that time-step.
        self.masks = [torch.zeros(beam_size, dtype=torch.bool,
                                  device=self.device)]

        # Only the latest status (e.g. hidden, cell, h_t_tilde) is kept,
        # replicated beam_size times along its batch dimension.
        self.prev_status = {}
        self.batch_dims = {}
        for prev_status_name, each_config in prev_status_config.items():
            init_status = each_config['init_status']
            batch_dim_index = each_config['batch_dim_index']
            if init_status is not None:
                self.prev_status[prev_status_name] = torch.cat(
                    [init_status] * beam_size, dim=batch_dim_index)
            else:
                self.prev_status[prev_status_name] = None
            self.batch_dims[prev_status_name] = batch_dim_index

        self.current_time_step = 0
        self.done_cnt = 0

    def get_length_penalty(self,
                           length,
                           alpha=LENGTH_PENALTY,
                           min_length=MIN_LENGTH):
        """Return ``((min_length + 1) / (min_length + length)) ** alpha``.

        Scores are negative log-probabilities sums, so multiplying by a
        larger factor for shorter hypotheses penalises them.
        """
        return ((min_length + 1) / (min_length + length)) ** alpha

    def is_done(self):
        """Return 1 once at least ``beam_size`` EOS tokens were seen, else 0."""
        if self.done_cnt >= self.beam_size:
            return 1
        return 0

    def get_batch(self):
        """Return the latest word indices as a column and the latest status."""
        y_hat = self.word_indice[-1].unsqueeze(-1)
        return y_hat, self.prev_status

    def collect_result(self, y_hat, prev_status):
        """Record one decoding step.

        Args:
            y_hat: log-probabilities; reshaped below as
                (beam_size, 1, output_size).
            prev_status: dict of status tensors produced by this step, keyed
                like the config given to ``__init__``.
        """
        output_size = y_hat.size(-1)

        self.current_time_step += 1

        # Finished beams must not be expanded. Use the out-of-place
        # masked_fill so the stored history keeps the real score of finished
        # hypotheses, which get_n_best reads later.
        cumulative_prob = self.cumulative_probs[-1].masked_fill(
            self.masks[-1], -float('inf'))
        cumulative_prob = y_hat + cumulative_prob.view(-1, 1, 1).expand(
            self.beam_size, 1, output_size)

        # Pick the beam_size best candidates over all (beam, word) pairs;
        # the flattened index encodes both.
        top_log_prob, top_indice = cumulative_prob.view(-1).sort(descending=True)
        top_log_prob = top_log_prob[:self.beam_size]
        top_indice = top_indice[:self.beam_size]

        # Recover the word index and the originating beam index.
        self.word_indice += [top_indice.fmod(output_size)]
        self.beam_indice += [top_indice.div(float(output_size)).long()]

        # Append to the history boards.
        self.cumulative_probs += [top_log_prob]
        self.masks += [torch.eq(self.word_indice[-1], EOS)]
        # Keep the counter a plain int so is_done compares Python numbers.
        self.done_cnt += int(self.masks[-1].sum().item())

        # Re-order every status tensor so that it follows the selected beams.
        # Each status may use a different dimension for the batch.
        for status_name, status in prev_status.items():
            self.prev_status[status_name] = torch.index_select(
                status,
                dim=self.batch_dims[status_name],
                index=self.beam_indice[-1]).contiguous()

    def get_n_best(self, n=1, length_penalty=.2):
        """Return the ``n`` best hypotheses and their penalised scores.

        Hypotheses that ended with EOS are ranked by cumulative
        log-probability times the length penalty. If no beam has emitted EOS,
        the live beams of the last time-step are ranked instead so that a
        result is still returned.

        Returns:
            (sentences, probs): ``sentences`` is a list of lists of word-index
            tensors (BOS excluded), ``probs`` the matching scores.
        """
        probs, founds = [], []

        for t in range(len(self.word_indice)):  # for each time-step
            for b in range(self.beam_size):  # for each beam
                if self.masks[t][b] == 1:  # EOS at this time-step and beam
                    probs += [self.cumulative_probs[t][b] *
                              self.get_length_penalty(t, alpha=length_penalty)]
                    founds += [(t, b)]

        if not founds:
            # Nothing finished: fall back to the unfinished beams at the end.
            last = len(self.cumulative_probs) - 1
            for b in range(self.beam_size):
                if self.cumulative_probs[last][b] != -float('inf'):
                    probs += [self.cumulative_probs[last][b] *
                              self.get_length_penalty(last, alpha=length_penalty)]
                    founds += [(last, b)]

        sorted_founds_with_probs = sorted(
            zip(founds, probs),
            key=itemgetter(1),
            reverse=True)[:n]

        sentences, probs = [], []
        for (end_index, b), prob in sorted_founds_with_probs:
            sentence = []

            # Walk the back-pointers from the end to the first real token.
            for t in range(end_index, 0, -1):
                sentence = [self.word_indice[t][b]] + sentence
                b = self.beam_indice[t][b]

            sentences += [sentence]
            probs += [prob]

        return sentences, probs

### Dataloader

In [5]:
# Token indices of the special symbols: padding, beginning and end of sentence.
PAD = 1
BOS = 2
EOS = 3

class DataLoader():
    """Builds torchtext fields, datasets and bucketed iterators for a corpus.

    After construction ``self.src`` / ``self.tgt`` hold the two fields. When
    ``train_fn``, ``valid_fn`` and ``exts`` are all given, ``self.train_iter``
    and ``self.valid_iter`` are created and both vocabularies are built from
    the training set.
    """

    def __init__(self,
                 train_fn=None,
                 valid_fn=None,
                 exts=None,
                 batch_size=64,
                 device='cpu',
                 max_vocab=999999999,
                 max_length=255,
                 fix_length=None,
                 use_bos=True,
                 use_eos=True,
                 shuffle=True,
                 dsl=False):
        """Create the fields and, if file names are given, the iterators.

        Args:
            train_fn: path prefix of the training files (without extension).
            valid_fn: path prefix of the validation files.
            exts: pair of file extensions, source first and target second.
            batch_size: mini-batch size of both iterators.
            device: kept for interface compatibility; the iterators place
                batches on CUDA whenever it is available, otherwise on CPU.
            max_vocab: maximum vocabulary size per field.
            max_length: sentence pairs with more tokens on either side are
                dropped; also used as the weight in the sort key.
            fix_length: forwarded to both fields.
            use_bos: add "<BOS>" to target sentences.
            use_eos: add "<EOS>" to target sentences.
            shuffle: shuffle the training iterator.
            dsl: when true, source sentences also get "<BOS>"/"<EOS>".
        """
        super(DataLoader, self).__init__()

        self.src = data.Field(
            sequential=True,
            use_vocab=True,
            batch_first=True,
            include_lengths=True,
            fix_length=fix_length,
            init_token="<BOS>" if dsl else None,
            eos_token="<EOS>" if dsl else None)

        self.tgt = data.Field(
            sequential=True,
            use_vocab=True,
            batch_first=True,
            include_lengths=True,
            fix_length=fix_length,
            init_token="<BOS>" if use_bos else None,
            eos_token="<EOS>" if use_eos else None)

        if train_fn is not None and valid_fn is not None and exts is not None:
            fields = [('src', self.src), ('tgt', self.tgt)]

            train = TranslationDataset(
                path=train_fn,
                exts=exts,
                fields=fields,
                max_length=max_length)

            valid = TranslationDataset(
                path=valid_fn,
                exts=exts,
                fields=fields,
                max_length=max_length)

            # `cuda` alone is not a name in this file; query torch.cuda.
            iter_device = 'cuda' if torch.cuda.is_available() else 'cpu'

            def sort_key(example):
                # Order primarily by source length, then by target length.
                return len(example.tgt) + (max_length * len(example.src))

            self.train_iter = data.BucketIterator(
                train,
                batch_size=batch_size,
                device=iter_device,
                shuffle=shuffle,
                sort_key=sort_key,
                sort_within_batch=True)

            self.valid_iter = data.BucketIterator(
                valid,
                batch_size=batch_size,
                device=iter_device,
                shuffle=False,
                sort_key=sort_key,
                sort_within_batch=True)

            self.src.build_vocab(train, max_size=max_vocab)
            self.tgt.build_vocab(train, max_size=max_vocab)

    def load_vocab(self, src_vocab, tgt_vocab):
        """Attach previously built vocabularies to the two fields."""
        self.src.vocab = src_vocab
        self.tgt.vocab = tgt_vocab

In [6]:
class TranslationDataset(data.Dataset):
    """Defines a dataset for machine translation"""

    @staticmethod
    def sort_key(ex):
        """Interleave source and target lengths into one sort key.

        The target field is called ``trg`` by default but ``tgt`` when the
        fields are built by the DataLoader in this file, so accept both.
        """
        target = ex.trg if hasattr(ex, 'trg') else ex.tgt
        return data.interleave_keys(len(ex.src), len(target))

    def __init__(self, path, exts, fields, max_length=None, **kwargs):
        """Read a parallel corpus from two line-aligned UTF-8 text files.

        Args:
            path: common path prefix of both files; a trailing "." is added
                when missing.
            exts: pair of extensions appended to ``path`` (source, target).
            fields: either two fields, or two (name, field) pairs.
            max_length: if set, pairs where either side has more
                whitespace-separated tokens than this are skipped.
            **kwargs: forwarded to ``data.Dataset``.
        """
        if not isinstance(fields[0], (tuple, list)):
            fields = [("src", fields[0]), ("trg", fields[1])]

        if not path.endswith("."):
            path += "."

        src_path, trg_path = tuple(os.path.expanduser(path + x) for x in exts)
        examples = []

        with open(src_path, encoding='utf-8') as src_file, \
                open(trg_path, encoding='utf-8') as trg_file:
            for src_line, trg_line in zip(src_file, trg_file):
                src_line, trg_line = src_line.strip(), trg_line.strip()
                if max_length and max_length < max(len(src_line.split()),
                                                   len(trg_line.split())):
                    continue
                # Pairs with an empty side carry no training signal.
                if src_line != '' and trg_line != '':
                    examples.append(
                        data.Example.fromlist([src_line, trg_line], fields))

        super().__init__(examples, fields, **kwargs)

## Utils

In [7]:
from operator import itemgetter

@torch.no_grad()
def get_grad_norm(parameters, norm_type=2):
    """Return the ``norm_type``-norm over the gradients of ``parameters``.

    Parameters without a gradient are ignored. Returns the int 0 when no
    parameter has a gradient, otherwise a 0-dim tensor. Errors raised while
    accumulating are printed and the value computed so far is returned
    (best-effort, as this is only used for monitoring).
    """
    with_grad = [p for p in parameters if p.grad is not None]
    total_norm = 0

    try:
        for p in with_grad:
            # abs() is required for a valid p-norm when norm_type is odd.
            total_norm += (p.grad.data.abs() ** norm_type).sum()
        total_norm = total_norm ** (1. / norm_type)
    except Exception as e:
        print(e)

    return total_norm


@torch.no_grad()
def get_paramter_norm(parameters, norm_type=2):
    """Return the ``norm_type``-norm over the values of ``parameters``.

    Returns the int 0 for an empty iterable, otherwise a 0-dim tensor.
    Errors raised while accumulating are printed and the value computed so
    far is returned (best-effort, as this is only used for monitoring).
    The misspelled function name is kept so existing callers keep working.
    """
    total_norm = 0

    try:
        for p in parameters:
            # abs() is required for a valid p-norm when norm_type is odd.
            total_norm += (p.data.abs() ** norm_type).sum()
        total_norm = total_norm ** (1. / norm_type)
    except Exception as e:
        print(e)

    return total_norm


def sort_by_length(x, lengths):
    """Sort the rows of ``x`` by descending length.

    Args:
        x: tensor whose first dimension is the batch.
        lengths: per-row lengths (tensor or sequence of numbers).

    Returns:
        (sorted_x, sorted_lengths, sorted_orders): the re-ordered rows, their
        lengths as a tensor, and for each new position the original row index
        (which ``sort_by_order`` can use to undo the sort).
    """
    batch_size = x.size(0)
    rows = [x[i] for i in range(batch_size)]
    row_lengths = [lengths[i] for i in range(batch_size)]
    orders = list(range(batch_size))

    # Python's sort is stable, so rows of equal length keep their order.
    sorted_tuples = sorted(zip(rows, row_lengths, orders),
                           key=itemgetter(1), reverse=True)

    sorted_x = torch.stack([row for row, _, _ in sorted_tuples])
    # torch.stack needs a list of tensors; as_tensor also accepts plain ints.
    sorted_lengths = torch.stack(
        [torch.as_tensor(length) for _, length, _ in sorted_tuples])
    sorted_orders = [order for _, _, order in sorted_tuples]

    return sorted_x, sorted_lengths, sorted_orders



def sort_by_order(x, orders):
    """Re-arrange the rows of ``x`` by ascending value of ``orders``.

    ``orders[i]`` is the sort key of row ``i``; rows with equal keys keep
    their relative order. Returns the re-ordered rows stacked into a tensor.
    """
    batch_size = x.size(0)
    # Rank the row positions by their key instead of sorting (row, key) pairs.
    ranked_rows = sorted(range(batch_size), key=lambda row: orders[row])
    return torch.stack([x[row] for row in ranked_rows])

## Seq2Seq

In [8]:
class Attention(nn.Module):
    """Dot-product attention with a bias-free linear map on the query."""

    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size, bias=False)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, h_src, h_t_tgt, mask=None):
        """Return the context vector for the given decoder output.

        Args:
            h_src: encoder outputs; batched 3-D tensor.
            h_t_tgt: decoder output(s) used as the query; batched 3-D tensor.
            mask: optional bool tensor, true at source positions that must
                receive no attention.
        """
        # Score every source position against the projected query.
        scores = torch.bmm(self.linear(h_t_tgt), h_src.transpose(1, 2))

        if mask is not None:
            # A score of -inf turns into a weight of exactly 0 after the
            # softmax, so the padded time-steps of shorter samples in the
            # mini-batch do not contribute to the context vector.
            scores.masked_fill_(mask.unsqueeze(1), -float('inf'))

        # Weighted sum of the encoder outputs.
        return torch.bmm(self.softmax(scores), h_src)

In [9]:
class Encoder(nn.Module):
    """Bidirectional multi-layer LSTM encoder (batch-first)."""

    def __init__(self, word_vec_size, hidden_size, n_layers=4, dropout_p=.2):
        super(Encoder, self).__init__()
        # Be aware of the value of the 'batch_first' parameter.
        # Each direction gets half of hidden_size: the LSTM concatenates the
        # forward and backward outputs, so the encoder output ends up with
        # hidden_size features again.
        self.rnn = nn.LSTM(word_vec_size,
                           hidden_size // 2,
                           num_layers=n_layers,
                           dropout=dropout_p,
                           bidirectional=True,
                           batch_first=True)

    def forward(self, emb):
        """Encode embedded source sentences.

        Args:
            emb: either a tensor of |emb| = (batch_size, length,
                word_vec_size), or a tuple ``(tensor, lengths)``. With
                lengths the batch is packed before the LSTM and padded again
                afterwards; ``lengths`` may be None, which is treated like
                passing the bare tensor.

        Returns:
            (y, h): the LSTM outputs and its final (hidden, cell) states.
        """
        lengths = None
        if isinstance(emb, tuple):
            x, lengths = emb
        else:
            x = emb

        if lengths is not None:
            # A PackedSequence stores the data mini-batch-wise per time-step,
            # so padded positions are never fed to the LSTM.
            x = pack(x, lengths.tolist(), batch_first=True)

        y, h = self.rnn(x)

        if lengths is not None:
            y, _ = unpack(y, batch_first=True)

        return y, h

In [10]:
class Decoder(nn.Module):
    """Unidirectional multi-layer LSTM decoder with input feeding."""

    def __init__(self, word_vec_size, hidden_size, n_layers=4, dropout_p=.2):
        super().__init__()

        # The input is the word embedding concatenated with the previous
        # attentional output, hence word_vec_size + hidden_size features.
        # Note the 'batch_first' and 'bidirectional' settings.
        self.rnn = nn.LSTM(
            input_size=word_vec_size + hidden_size,
            hidden_size=hidden_size,
            num_layers=n_layers,
            dropout=dropout_p,
            bidirectional=False,
            batch_first=True)

    def forward(self, emb_t, h_t_1_tilde, h_t_1):
        """Run one decoding time-step.

        Args:
            emb_t: embedding of the current input token, one step per sample.
            h_t_1_tilde: previous attentional output, or None on the first
                time-step (an all-zero tensor is used instead).
            h_t_1: previous (hidden, cell) state of the LSTM.

        Returns:
            The LSTM output and its new (hidden, cell) state.
        """
        if h_t_1_tilde is None:
            # First time-step: nothing to feed back yet.
            n_samples = emb_t.size(0)
            n_features = h_t_1[0].size(-1)
            h_t_1_tilde = emb_t.new_zeros(n_samples, 1, n_features)

        # Input feeding: append the previous attentional output.
        rnn_input = torch.cat((emb_t, h_t_1_tilde), dim=-1)

        # Unlike the encoder, the decoder consumes its input step by step.
        return self.rnn(rnn_input, h_t_1)

In [11]:
class Generator(nn.Module):
    """Linear projection followed by a log-softmax over the last dimension."""

    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.output = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """Return log-probabilities (not plain probabilities) for ``x``."""
        logits = self.output(x)
        return self.softmax(logits)

In [12]:
class Seq2Seq(nn.Module):
    """LSTM encoder-decoder with attention and input feeding.

    Provides teacher-forced training (``forward``), greedy / sampling
    inference (``search``) and mini-batch beam search
    (``batch_beam_search``).

    ``PAD``, ``BOS`` and ``EOS`` are the module-level special-token indices
    defined in this file; they are looked up when the methods run.
    """

    def __init__(self, input_size, word_vec_size, hidden_size, output_size,
                 n_layers=4, dropout_p=.2):
        """Build the sub-modules.

        Args:
            input_size: source vocabulary size.
            word_vec_size: embedding dimension of both vocabularies.
            hidden_size: hidden size of decoder and (merged) encoder states.
            output_size: target vocabulary size.
            n_layers: number of LSTM layers in encoder and decoder.
            dropout_p: dropout probability of both LSTMs.
        """
        super(Seq2Seq, self).__init__()

        self.input_size = input_size
        self.word_vec_size = word_vec_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout_p = dropout_p

        self.emb_src = nn.Embedding(input_size, word_vec_size)
        self.emb_dec = nn.Embedding(output_size, word_vec_size)

        self.encoder = Encoder(word_vec_size, hidden_size,
                               n_layers=n_layers, dropout_p=dropout_p)

        self.decoder = Decoder(word_vec_size, hidden_size,
                               n_layers=n_layers, dropout_p=dropout_p)

        self.attn = Attention(hidden_size)

        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.tanh = nn.Tanh()
        self.generator = Generator(hidden_size, output_size)

    def generate_mask(self, x, length):
        """Build the attention padding mask.

        Args:
            x: source tensor; only its device is used.
            length: per-sample source lengths (tensor or sequence of ints).

        Returns:
            Bool tensor of shape (batch_size, max(length)) on the device of
            ``x``; True marks the positions at or beyond a sample's length,
            i.e. the positions whose attention weight must be removed.
        """
        if not torch.is_tensor(length):
            length = torch.as_tensor([int(each) for each in length])
        length = length.to(x.device)

        positions = torch.arange(int(length.max()), device=x.device)
        return positions.unsqueeze(0) >= length.unsqueeze(1)

    def merge_encoder_hiddens(self, encoder_hiddens):
        """Merge bidirectional encoder states layer by layer.

        Rows i and i+1 (for even i) of the state tensors are the two
        directions of one layer, each with half the hidden size; they are
        concatenated into a single state of full hidden size. Non-parallel
        counterpart of ``fast_merge_encoder_hiddens``.
        """
        new_hiddens = []
        new_cells = []

        hiddens, cells = encoder_hiddens

        for i in range(0, hiddens.size(0), 2):
            new_hiddens += [torch.cat([hiddens[i], hiddens[i + 1]], dim=-1)]
            new_cells += [torch.cat([cells[i], cells[i + 1]], dim=-1)]

        new_hiddens, new_cells = torch.stack(new_hiddens), torch.stack(new_cells)
        return (new_hiddens, new_cells)

    def fast_merge_encoder_hiddens(self, encoder_hiddens):
        """Merge bidirectional encoder states with tensor reshapes only.

        Converts (n_layers*2, batch_size, hidden_size/2) into
        (n_layers, batch_size, hidden_size). A plain ``view`` would mix up
        samples, so the batch dimension is moved to the front first.
        """
        h_0_tgt, c_0_tgt = encoder_hiddens
        batch_size = h_0_tgt.size(1)

        def merge(state):
            return state.transpose(0, 1).contiguous().view(
                batch_size, -1, self.hidden_size).transpose(0, 1).contiguous()

        return merge(h_0_tgt), merge(c_0_tgt)

    def _encode(self, x, x_length):
        """Embed and encode ``x``; return encoder outputs and merged states."""
        emb_src = self.emb_src(x)
        # Without lengths the encoder must get the bare tensor, not a tuple.
        encoder_input = emb_src if x_length is None else (emb_src, x_length)
        h_src, h_0_tgt = self.encoder(encoder_input)
        return h_src, self.fast_merge_encoder_hiddens(h_0_tgt)

    def _decode_step(self, emb_t, h_t_tilde, decoder_hidden, h_src, mask):
        """Run decoder + attention for one step; return (h_t_tilde, state)."""
        decoder_output, decoder_hidden = self.decoder(emb_t,
                                                      h_t_tilde,
                                                      decoder_hidden)
        context_vector = self.attn(h_src, decoder_output, mask)
        h_t_tilde = self.tanh(self.concat(torch.cat([decoder_output,
                                                     context_vector],
                                                    dim=-1)))
        return h_t_tilde, decoder_hidden

    def forward(self, src, tgt):
        """Teacher-forced pass; return log-probabilities for every step.

        Args:
            src: source index tensor, or a tuple ``(tensor, lengths)``.
            tgt: target index tensor, or a tuple whose first item is it.
        """
        mask = None
        x_length = None
        if isinstance(src, tuple):
            x, x_length = src
            # Based on the lengths, mask padded positions so that shorter
            # samples do not waste attention on them.
            mask = self.generate_mask(x, x_length)
        else:
            x = src

        if isinstance(tgt, tuple):
            tgt = tgt[0]

        # The last encoder state becomes the initial decoder state.
        h_src, decoder_hidden = self._encode(x, x_length)
        emb_tgt = self.emb_dec(tgt)

        h_tilde = []
        h_t_tilde = None
        # Teacher forcing: each input comes from the training target, not
        # from the previous output, so training and inference differ. The
        # decoder still has to run sequentially, which is a bottleneck.
        for t in range(tgt.size(1)):
            emb_t = emb_tgt[:, t, :].unsqueeze(1)
            h_t_tilde, decoder_hidden = self._decode_step(
                emb_t, h_t_tilde, decoder_hidden, h_src, mask)
            h_tilde += [h_t_tilde]

        h_tilde = torch.cat(h_tilde, dim=1)
        y_hat = self.generator(h_tilde)
        return y_hat

    def search(self, src, is_greedy=True, max_length=255):
        """Decode by greedy search or by sampling.

        Args:
            src: source index tensor, or a tuple ``(tensor, lengths)``.
            is_greedy: take the arg-max token if true, otherwise sample from
                the predicted distribution.
            max_length: maximum number of decoding steps.

        Returns:
            (y_hats, indice): log-probabilities and chosen token indices,
            each concatenated over the time dimension.
        """
        if isinstance(src, tuple):
            x, x_length = src
            mask = self.generate_mask(x, x_length)
        else:
            x, x_length = src, None
            mask = None
        batch_size = x.size(0)

        # Same encoding procedure as with teacher forcing.
        h_src, decoder_hidden = self._encode(x, x_length)

        # Every sample starts from the BOS token.
        y = x.new_full((batch_size, 1), BOS)

        is_decoding = x.new_ones(batch_size, 1).bool()
        h_t_tilde, y_hats, indice = None, [], []

        # Loop while any sample is still decoding and the step limit is not
        # reached.
        while is_decoding.sum() > 0 and len(indice) < max_length:
            # Unlike training, feed the previous step's own output.
            emb_t = self.emb_dec(y)

            h_t_tilde, decoder_hidden = self._decode_step(
                emb_t, h_t_tilde, decoder_hidden, h_src, mask)
            y_hat = self.generator(h_t_tilde)
            y_hats += [y_hat]

            if is_greedy:
                y = y_hat.argmax(dim=-1)
            else:
                # Sample from the predicted categorical distribution.
                y = torch.multinomial(y_hat.exp().view(batch_size, -1), 1)

            # Emit PAD for samples that already finished.
            y = y.masked_fill_(~is_decoding, PAD)
            # A sample stops decoding once it produced EOS.
            is_decoding = is_decoding & torch.ne(y, EOS)

            indice += [y]

        y_hats = torch.cat(y_hats, dim=1)
        indice = torch.cat(indice, dim=1)
        return y_hats, indice

    def batch_beam_search(self,
                          src,
                          beam_size=5,
                          max_length=255,
                          n_best=1,
                          length_penalty=.2):
        """Beam search over a whole mini-batch.

        One SingleBeamSearchBoard is kept per sample; the beams of all
        unfinished samples are stacked into one fabricated mini-batch per
        time-step.

        Args:
            src: source index tensor, or a tuple ``(tensor, lengths)``.
            beam_size: number of hypotheses kept per sample.
            max_length: step limit of the search loop.
            n_best: number of hypotheses returned per sample.
            length_penalty: exponent passed to the boards' length penalty.

        Returns:
            (batch_sentences, batch_probs): per sample, the ``n_best``
            hypotheses and their scores as returned by the boards.
        """
        mask, x_length = None, None
        if isinstance(src, tuple):
            x, x_length = src
            mask = self.generate_mask(x, x_length)
        else:
            x = src
        batch_size = x.size(0)

        h_src, h_0_tgt = self._encode(x, x_length)

        # One board per sample in the mini-batch.
        boards = [SingleBeamSearchBoard(
            h_src.device,
            {
                'hidden_state': {
                    'init_status': h_0_tgt[0][:, i, :].unsqueeze(1),
                    'batch_dim_index': 1,
                },  # |hidden_state| = (n_layers, batch_size, hidden_size)
                'cell_state': {
                    'init_status': h_0_tgt[1][:, i, :].unsqueeze(1),
                    'batch_dim_index': 1,
                },  # |cell_state| = (n_layers, batch_size, hidden_size)
                'h_t_1_tilde': {
                    'init_status': None,
                    'batch_dim_index': 0,
                },  # |h_t_1_tilde| = (batch_size, 1, hidden_size)
            },
            beam_size=beam_size, max_length=max_length) for i in range(batch_size)]
        is_done = [board.is_done() for board in boards]

        length = 0
        # Loop while some sample is unfinished and the step limit holds.
        while sum(is_done) < batch_size and length <= max_length:
            # The fabricated mini-batch is beam_size times bigger than the
            # number of unfinished samples.
            fab_input, fab_hidden, fab_cell, fab_h_t_tilde = [], [], [], []
            fab_h_src, fab_mask = [], []

            # Built sample by sample; this may be a bottleneck.
            for i, board in enumerate(boards):
                if board.is_done() == 0:
                    y_hat_i, prev_status = board.get_batch()

                    fab_input += [y_hat_i]
                    fab_hidden += [prev_status['hidden_state']]
                    fab_cell += [prev_status['cell_state']]
                    fab_h_src += [h_src[i, :, :]] * beam_size
                    if mask is not None:
                        fab_mask += [mask[i, :]] * beam_size
                    # None only on the first step, where no board has one.
                    if prev_status['h_t_1_tilde'] is not None:
                        fab_h_t_tilde += [prev_status['h_t_1_tilde']]

            # Concatenate the collected tensors along their batch dimension.
            fab_input = torch.cat(fab_input, dim=0)
            fab_hidden = torch.cat(fab_hidden, dim=1)
            fab_cell = torch.cat(fab_cell, dim=1)
            fab_h_src = torch.stack(fab_h_src)
            fab_mask = torch.stack(fab_mask) if mask is not None else None
            fab_h_t_tilde = (torch.cat(fab_h_t_tilde, dim=0)
                             if fab_h_t_tilde else None)

            emb_t = self.emb_dec(fab_input)
            fab_h_t_tilde, (fab_hidden, fab_cell) = self._decode_step(
                emb_t, fab_h_t_tilde, (fab_hidden, fab_cell),
                fab_h_src, fab_mask)
            y_hat = self.generator(fab_h_t_tilde)

            cnt = 0
            for board in boards:
                if board.is_done() == 0:
                    # Slice of the fabricated batch belonging to this sample.
                    begin = cnt * beam_size
                    end = begin + beam_size

                    # Let the board pick its k-best continuations.
                    board.collect_result(
                        y_hat[begin:end],
                        {
                            'hidden_state': fab_hidden[:, begin:end, :],
                            'cell_state': fab_cell[:, begin:end, :],
                            'h_t_1_tilde': fab_h_t_tilde[begin:end],
                        })
                    cnt += 1

            is_done = [board.is_done() for board in boards]
            length += 1

        # Collect the n-best hypotheses of every sample.
        batch_sentences, batch_probs = [], []
        for board in boards:
            sentences, probs = board.get_n_best(n_best,
                                                length_penalty=length_penalty)
            batch_sentences += [sentences]
            batch_probs += [probs]

        return batch_sentences, batch_probs