Implementation of the Transformer Architecture

<img src="image.png" width="500" height="600">



In [4]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
import os
import time

import torch
from tensorboardX import SummaryWriter
from tqdm import tqdm
import torch.optim as optim
from dataset import problem





UTILS

In [5]:
import shutil


class utils:
    """Stateless helpers for loss/accuracy, checkpointing and attention masks."""

    @staticmethod
    def get_loss(pred, ans, vocab_size, label_smoothing, pad):
        """Label-smoothed cross-entropy averaged over non-padding tokens.

        Args:
            pred: logits with classes on dim 1 (one row per token).
            ans: gold token ids, one per row of ``pred``.
            vocab_size: number of classes; the smoothing mass is spread
                uniformly over the ``vocab_size - 1`` non-gold classes.
            label_smoothing: probability mass moved off the gold class.
            pad: token id whose positions are excluded from the mean.

        Returns:
            Scalar tensor. The entropy of the smoothed target distribution
            (``normalizing``) is subtracted, so a prediction equal to that
            distribution scores about 0; the constant does not affect
            gradients.
        """
        # took this "normalizing" from tensor2tensor. We subtract it for
        # readability. This makes no difference on learning.
        confidence = 1.0 - label_smoothing
        low_confidence = (1.0 - confidence) / float(vocab_size - 1)
        normalizing = -(
            confidence * math.log(confidence) + float(vocab_size - 1) *
            low_confidence * math.log(low_confidence + 1e-20))
        one_hot = torch.zeros_like(pred).scatter_(1, ans.unsqueeze(1), 1)
        one_hot = one_hot * confidence + (1 - one_hot) * low_confidence
        log_prob = F.log_softmax(pred, dim=1)
        xent = -(one_hot * log_prob).sum(dim=1)
        xent = xent.masked_select(ans != pad)
        loss = (xent - normalizing).mean()
        return loss

    @staticmethod
    def get_accuracy(pred, ans, pad):
        """Fraction of non-padding positions where argmax(pred) equals ``ans``.

        Returns 0.0 when every position is padding, instead of dividing by
        zero.
        """
        pred = pred.max(1)[1]
        n_correct = pred.eq(ans).masked_select(ans != pad)
        total = n_correct.numel()
        if total == 0:
            return 0.0
        return n_correct.sum().item() / total

    @staticmethod
    def save_checkpoint(model, filepath, global_step, is_best):
        """Pickle ``model`` and ``global_step`` into directory ``filepath``.

        Writes ``last_model.pt`` and ``global_step.pt``; when ``is_best`` is
        true, ``last_model.pt`` is also copied to ``best_model.pt``.
        """
        model_save_path = filepath + '/last_model.pt'
        torch.save(model, model_save_path)
        torch.save(global_step, filepath + '/global_step.pt')
        if is_best:
            best_save_path = filepath + '/best_model.pt'
            shutil.copyfile(model_save_path, best_save_path)

    @staticmethod
    def _load_trusted(path, device):
        """``torch.load`` a whole pickled object, mapping tensors onto ``device``.

        Checkpoints written by ``save_checkpoint`` are full pickled modules,
        which PyTorch >= 2.6 refuses under its default ``weights_only=True``.
        SECURITY: this unpickles arbitrary code, so only use it on
        checkpoints you produced yourself.
        """
        try:
            return torch.load(path, map_location=device, weights_only=False)
        except TypeError:
            # Older PyTorch without the ``weights_only`` keyword.
            return torch.load(path, map_location=device)

    @staticmethod
    def load_checkpoint(model_path, device, is_eval=True):
        """Load a checkpoint written by ``save_checkpoint`` onto ``device``.

        Returns:
            With ``is_eval`` true: ``best_model.pt`` in eval mode.
            Otherwise: ``(last_model.pt, global_step)`` for resuming training.
        """
        if is_eval:
            model = utils._load_trusted(model_path + '/best_model.pt', device)
            model.eval()
            return model.to(device=device)
        model = utils._load_trusted(model_path + '/last_model.pt', device)
        global_step = utils._load_trusted(model_path + '/global_step.pt',
                                          device)
        return model.to(device=device), global_step

    @staticmethod
    def create_pad_mask(t, pad):
        """Boolean mask (True at ``pad``) with a singleton dim inserted at -2."""
        mask = (t == pad).unsqueeze(-2)
        return mask

    @staticmethod
    def create_trg_self_mask(target_len, device=None):
        """Causal mask of shape ``[1, target_len, target_len]``.

        Entry ``[0, i, j]`` is True (masked) when ``j > i``, which prevents
        leftward information flow in decoder self-attention. The mask is
        boolean because newer PyTorch releases reject uint8 masks in
        ``masked_fill_``.
        """
        ones = torch.ones(target_len, target_len, device=device)
        t_self_mask = torch.triu(ones, diagonal=1).bool().unsqueeze(0)
        return t_self_mask

def forward(self, inputs, targets):
    """Module-level duplicate of ``Transformer.forward``; not bound to a class.

    Encodes ``inputs`` (only when ``self.has_inputs``), builds the target
    padding mask and the causal self-attention mask, and returns
    ``self.decode(...)``.
    """
    i_mask = None
    enc_output = None
    if self.has_inputs:
        i_mask = utils.create_pad_mask(inputs, self.src_pad_idx)
        enc_output = self.encode(inputs, i_mask)

    t_mask = utils.create_pad_mask(targets, self.trg_pad_idx)
    seq_len = targets.size()[1]
    causal_mask = utils.create_trg_self_mask(seq_len, device=targets.device)
    return self.decode(targets, enc_output, i_mask, causal_mask, t_mask)

In [6]:
class LRScheduler:
    """Adam wrapper that applies the Transformer warm-up/decay LR schedule.

    lr(step) = 2 / sqrt(hidden_size)
               * min(1, step / warmup) * max(step, warmup) ** -0.5

    This is a linear warm-up over ``warmup`` steps followed by
    inverse-square-root decay. With ``warmup <= 0`` there is no warm-up and
    the rate is ``2 / sqrt(hidden_size) * max(step, 1) ** -0.5``.

    Args:
        parameters: iterable of parameters handed to ``optim.Adam``.
        hidden_size: model width, used for the ``hidden_size ** -0.5`` factor.
        warmup: number of warm-up steps.
        step: starting step count (e.g. when resuming from a checkpoint).
    """

    def __init__(self, parameters, hidden_size, warmup, step=0):
        self.constant = 2.0 * (hidden_size ** -0.5)
        self.cur_step = step
        self.warmup = warmup
        self.optimizer = optim.Adam(parameters, lr=self.learning_rate(),
                                    betas=(0.9, 0.997), eps=1e-09)

    def step(self):
        """Advance one step, write the new rate into every group, then step."""
        self.cur_step += 1
        rate = self.learning_rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self.optimizer.step()

    def zero_grad(self):
        """Clear gradients of the wrapped optimizer."""
        self.optimizer.zero_grad()

    def learning_rate(self):
        """Return the schedule's learning rate at ``self.cur_step``."""
        lr = self.constant
        if self.warmup > 0:
            lr *= min(1.0, self.cur_step / self.warmup)
            lr *= max(self.cur_step, self.warmup) ** -0.5
        else:
            # No warm-up: clamp to 1 so step 0 doesn't evaluate 0 ** -0.5.
            lr *= max(self.cur_step, 1) ** -0.5
        return lr

In [7]:
def initialize_weight(x):
    """Apply Xavier-uniform init to ``x.weight`` and zero ``x.bias`` if present.

    ``x`` must expose ``weight`` and ``bias`` attributes (as ``nn.Linear``
    does); a ``None`` bias is left untouched.
    """
    weight, bias = x.weight, x.bias
    nn.init.xavier_uniform_(weight)
    if bias is None:
        return
    nn.init.constant_(bias, 0)
        
        
class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward sublayer: Linear -> ReLU -> Dropout -> Linear.

    Args:
        hidden_size: input and output feature size.
        filter_size: inner (expanded) feature size.
        dropout_rate: dropout probability applied after the ReLU.
    """

    def __init__(self, hidden_size, filter_size, dropout_rate):
        super(FeedForwardNetwork, self).__init__()

        self.layer1 = nn.Linear(hidden_size, filter_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)
        self.layer2 = nn.Linear(filter_size, hidden_size)

        initialize_weight(self.layer1)
        initialize_weight(self.layer2)

    def forward(self, x):
        """Map features of size ``hidden_size`` (last dim) back to ``hidden_size``."""
        x = self.layer1(x)
        x = self.relu(x)
        x = self.dropout(x)
        # Project back down with the second linear layer (there is no
        # ``self.layer`` attribute).
        x = self.layer2(x)
        return x

In [8]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention without projection biases.

    Queries, keys and values are projected into ``head_size`` heads of
    ``hidden_size // head_size`` features each, attended per head, then
    concatenated and projected back to ``hidden_size``.

    Args:
        hidden_size: feature size of the inputs and of the output.
        dropout_rate: dropout applied to the attention weights.
        head_size: number of attention heads (despite the name).
    """

    def __init__(self, hidden_size, dropout_rate, head_size=8):
        super(MultiHeadAttention, self).__init__()

        self.head_size = head_size
        self.att_size = att_size = hidden_size // head_size
        self.scale = att_size ** -0.5

        self.linear_q = nn.Linear(hidden_size, head_size * att_size, bias=False)
        self.linear_k = nn.Linear(hidden_size, head_size * att_size, bias=False)
        self.linear_v = nn.Linear(hidden_size, head_size * att_size, bias=False)

        initialize_weight(self.linear_q)
        initialize_weight(self.linear_k)
        initialize_weight(self.linear_v)

        self.att_dropout = nn.Dropout(dropout_rate)
        self.output_layer = nn.Linear(head_size * att_size, hidden_size,
                                      bias=False)

        initialize_weight(self.output_layer)

    def forward(self, q, k, v, mask, cache=None):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q, k, v: batch-first ``[batch, len, hidden_size]`` tensors; ``k``
                and ``v`` must have the same length.
            mask: positions to block (True / nonzero = masked out). After a
                head dimension is inserted at dim 1 it must broadcast against
                the ``[batch, heads, q_len, k_len]`` score tensor. ``None``
                disables masking.
            cache: optional dict. If it already holds ``'encdec_k'`` and
                ``'encdec_v'``, those projected keys/values are reused and
                ``k``/``v`` are ignored. Otherwise the freshly projected
                keys/values are stored in it.

        Returns:
            Tensor with the same size as ``q``.
        """
        orig_q_size = q.size()

        d_k = self.att_size
        d_v = self.att_size
        batch_size = q.size(0)

        # head_i = Attention(Q(W^Q)_i, K(W^K)_i, V(W^V)_i)
        q = self.linear_q(q).view(batch_size, -1, self.head_size, d_k)
        if cache is not None and 'encdec_k' in cache:
            k, v = cache['encdec_k'], cache['encdec_v']
        else:
            k = self.linear_k(k).view(batch_size, -1, self.head_size, d_k)
            v = self.linear_v(v).view(batch_size, -1, self.head_size, d_v)

            if cache is not None:
                cache['encdec_k'], cache['encdec_v'] = k, v

        q = q.transpose(1, 2)                  # [b, h, q_len, d_k]
        v = v.transpose(1, 2)                  # [b, h, v_len, d_v]
        k = k.transpose(1, 2).transpose(2, 3)  # [b, h, d_k, k_len]

        # Scaled Dot-Product Attention.
        # Attention(Q, K, V) = softmax((QK^T)/sqrt(d_k))V
        q.mul_(self.scale)
        x = torch.matmul(q, k)  # [b, h, q_len, k_len]
        if mask is not None:
            # Newer PyTorch releases reject non-bool masks in masked_fill_;
            # coerce so uint8/int masks still work (nonzero = masked).
            x.masked_fill_(mask.unsqueeze(1).bool(), -1e9)
        x = torch.softmax(x, dim=3)
        x = self.att_dropout(x)
        x = x.matmul(v)  # [b, h, q_len, attn]

        x = x.transpose(1, 2).contiguous()  # [b, q_len, h, attn]
        x = x.view(batch_size, -1, self.head_size * d_v)

        x = self.output_layer(x)

        assert x.size() == orig_q_size
        return x

        
    
        
        

In [9]:
class EncoderLayer(nn.Module):
    """Pre-norm encoder layer: self-attention sublayer, then feed-forward sublayer.

    Each sublayer LayerNorms its input, runs, applies dropout, and is added
    back onto the residual stream.
    """

    def __init__(self, hidden_size, filter_size, dropout_rate):
        super(EncoderLayer, self).__init__()

        # Self-attention sublayer.
        self.self_attention_norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.self_attention = MultiHeadAttention(hidden_size, dropout_rate)
        self.self_attention_dropout = nn.Dropout(dropout_rate)

        # Feed-forward sublayer.
        self.ffn_norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.ffn = FeedForwardNetwork(hidden_size, filter_size, dropout_rate)
        self.ffn_dropout = nn.Dropout(dropout_rate)

    def forward(self, x, mask):
        normed = self.self_attention_norm(x)
        attended = self.self_attention(normed, normed, normed, mask)
        x = x + self.self_attention_dropout(attended)

        transformed = self.ffn(self.ffn_norm(x))
        return x + self.ffn_dropout(transformed)


In [10]:
class DecoderLayer(nn.Module):
    """Pre-norm decoder layer: masked self-attention, encoder-decoder attention, FFN.

    The encoder-decoder attention sublayer runs only when ``enc_output`` is
    not None. ``cache`` is forwarded to that sublayer only.
    """

    def __init__(self, hidden_size, filter_size, dropout_rate):
        super(DecoderLayer, self).__init__()

        # Masked self-attention sublayer.
        self.self_attention_norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.self_attention = MultiHeadAttention(hidden_size, dropout_rate)
        self.self_attention_dropout = nn.Dropout(dropout_rate)

        # Encoder-decoder attention sublayer.
        self.enc_dec_attention_norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.enc_dec_attention = MultiHeadAttention(hidden_size, dropout_rate)
        self.enc_dec_attention_dropout = nn.Dropout(dropout_rate)

        # Feed-forward sublayer.
        self.ffn_norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.ffn = FeedForwardNetwork(hidden_size, filter_size, dropout_rate)
        self.ffn_dropout = nn.Dropout(dropout_rate)

    def forward(self, x, enc_output, self_mask, i_mask, cache):
        normed = self.self_attention_norm(x)
        attended = self.self_attention(normed, normed, normed, self_mask)
        x = x + self.self_attention_dropout(attended)

        if enc_output is not None:
            normed = self.enc_dec_attention_norm(x)
            attended = self.enc_dec_attention(normed, enc_output, enc_output,
                                              i_mask, cache)
            x = x + self.enc_dec_attention_dropout(attended)

        transformed = self.ffn(self.ffn_norm(x))
        return x + self.ffn_dropout(transformed)


In [11]:
class Encoder(nn.Module):
    """Stack of ``n_layers`` EncoderLayer modules followed by a final LayerNorm."""

    def __init__(self, hidden_size, filter_size, dropout_rate, n_layers):
        super(Encoder, self).__init__()

        self.layers = nn.ModuleList(
            EncoderLayer(hidden_size, filter_size, dropout_rate)
            for _ in range(n_layers))

        self.last_norm = nn.LayerNorm(hidden_size, eps=1e-6)

    def forward(self, inputs, mask):
        hidden = inputs
        for layer in self.layers:
            hidden = layer(hidden, mask)
        return self.last_norm(hidden)


class Decoder(nn.Module):
    """Stack of ``n_layers`` DecoderLayer modules followed by a final LayerNorm.

    ``cache``, when given, is a dict keyed by layer index. Each layer gets
    its own sub-dict, which is created on first use.
    """

    def __init__(self, hidden_size, filter_size, dropout_rate, n_layers):
        super(Decoder, self).__init__()

        self.layers = nn.ModuleList(
            DecoderLayer(hidden_size, filter_size, dropout_rate)
            for _ in range(n_layers))

        self.last_norm = nn.LayerNorm(hidden_size, eps=1e-6)

    def forward(self, targets, enc_output, i_mask, t_self_mask, cache):
        hidden = targets
        for idx, layer in enumerate(self.layers):
            per_layer = None if cache is None else cache.setdefault(idx, {})
            hidden = layer(hidden, enc_output, t_self_mask, i_mask, per_layer)
        return self.last_norm(hidden)

In [12]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer whose output projection reuses the target embedding.

    Logits are computed as ``decoder_output @ t_vocab_embedding.weight.T``.
    With ``share_target_embedding`` true, the input embedding is the very
    same module as the target embedding. With ``has_inputs`` false, no
    encoder is built and the decoder runs without encoder-decoder attention.

    Args:
        i_vocab_size: input vocabulary size (used only for a separate input
            embedding).
        t_vocab_size: target vocabulary size.
        n_layers: number of encoder and of decoder layers.
        hidden_size: model width.
        filter_size: inner width of the feed-forward sublayers.
        dropout_rate: dropout probability used throughout.
        share_target_embedding: reuse the target embedding for inputs.
        has_inputs: whether an encoder/input embedding is built.
        src_pad_idx, trg_pad_idx: padding ids used to build masks.
    """

    def __init__(self, i_vocab_size, t_vocab_size,
                 n_layers=6,
                 hidden_size=512,
                 filter_size=2048,
                 dropout_rate=0.1,
                 share_target_embedding=True,
                 has_inputs=True,
                 src_pad_idx=None,
                 trg_pad_idx=None):
        super(Transformer, self).__init__()

        self.hidden_size = hidden_size
        self.emb_scale = hidden_size ** 0.5
        self.has_inputs = has_inputs
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx

        self.t_vocab_embedding = nn.Embedding(t_vocab_size, hidden_size)
        nn.init.normal_(self.t_vocab_embedding.weight, mean=0,
                        std=hidden_size**-0.5)
        self.t_emb_dropout = nn.Dropout(dropout_rate)
        self.decoder = Decoder(hidden_size, filter_size,
                               dropout_rate, n_layers)

        if has_inputs:
            if not share_target_embedding:
                self.i_vocab_embedding = nn.Embedding(i_vocab_size,
                                                      hidden_size)
                nn.init.normal_(self.i_vocab_embedding.weight, mean=0,
                                std=hidden_size**-0.5)
            else:
                self.i_vocab_embedding = self.t_vocab_embedding

            self.i_emb_dropout = nn.Dropout(dropout_rate)

            self.encoder = Encoder(hidden_size, filter_size,
                                   dropout_rate, n_layers)

        # For positional encoding: geometric sequence of inverse timescales
        # from 1/min_timescale down to 1/max_timescale.
        num_timescales = self.hidden_size // 2
        max_timescale = 10000.0
        min_timescale = 1.0
        log_timescale_increment = (
            math.log(float(max_timescale) / float(min_timescale)) /
            max(num_timescales - 1, 1))
        inv_timescales = min_timescale * torch.exp(
            torch.arange(num_timescales, dtype=torch.float32) *
            -log_timescale_increment)
        self.register_buffer('inv_timescales', inv_timescales)

    def forward(self, inputs, targets):
        """Encode ``inputs`` (if any) and return target logits (teacher forcing)."""
        enc_output, i_mask = None, None
        if self.has_inputs:
            i_mask = utils.create_pad_mask(inputs, self.src_pad_idx)
            enc_output = self.encode(inputs, i_mask)

        t_mask = utils.create_pad_mask(targets, self.trg_pad_idx)
        target_size = targets.size()[1]
        t_self_mask = utils.create_trg_self_mask(target_size,
                                                 device=targets.device)
        return self.decode(targets, enc_output, i_mask, t_self_mask, t_mask)

    def encode(self, inputs, i_mask):
        """Embed ``inputs`` and run the encoder.

        Padding positions (from ``i_mask``) get zero embeddings. The
        embeddings are then scaled by sqrt(hidden_size), get the sinusoidal
        position signal added, and pass through dropout.
        """
        # Input embedding
        input_embedded = self.i_vocab_embedding(inputs)
        input_embedded.masked_fill_(i_mask.squeeze(1).unsqueeze(-1), 0)
        input_embedded *= self.emb_scale
        input_embedded += self.get_position_encoding(inputs)
        input_embedded = self.i_emb_dropout(input_embedded)

        return self.encoder(input_embedded, i_mask)

    def decode(self, targets, enc_output, i_mask, t_self_mask, t_mask,
               cache=None):
        """Run the decoder on ``targets`` and return vocabulary logits.

        Target embeddings are shifted right by one position: the last
        position is dropped and a zero vector is prepended. Position ``i``
        is therefore predicted from targets before ``i``.
        """
        # target embedding
        target_embedded = self.t_vocab_embedding(targets)
        target_embedded.masked_fill_(t_mask.squeeze(1).unsqueeze(-1), 0)

        # Shifting
        target_embedded = target_embedded[:, :-1]
        target_embedded = F.pad(target_embedded, (0, 0, 1, 0))

        target_embedded *= self.emb_scale
        target_embedded += self.get_position_encoding(targets)
        target_embedded = self.t_emb_dropout(target_embedded)

        # decoder
        decoder_output = self.decoder(target_embedded, enc_output, i_mask,
                                      t_self_mask, cache)
        # linear (weights tied with the target embedding)
        output = torch.matmul(decoder_output,
                              self.t_vocab_embedding.weight.transpose(0, 1))

        return output

    def get_position_encoding(self, x):
        """Sinusoidal position signal of shape ``[1, x.size(1), hidden_size]``.

        Channels ``[0, hidden_size // 2)`` hold sines and the next
        ``hidden_size // 2`` hold cosines. For odd ``hidden_size``, one
        trailing zero channel is appended.
        """
        max_length = x.size()[1]
        position = torch.arange(max_length, dtype=torch.float32,
                                device=x.device)
        scaled_time = position.unsqueeze(1) * self.inv_timescales.unsqueeze(0)
        signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)],
                           dim=1)
        # Pad the channel (last) dimension for odd hidden sizes. A 4-tuple
        # pad would hit the length dimension and break the view below.
        signal = F.pad(signal, (0, self.hidden_size % 2))
        signal = signal.view(1, max_length, self.hidden_size)
        return signal

TRAIN

In [13]:
def summarize_train(writer, global_step, last_time, model, opt,
                    inputs, targets, optimizer, loss, pred, ans):
    """Write training diagnostics for the current step to ``writer``.

    Logs:
        * per-parameter gradient norms (only with ``opt.summary_grad``);
        * batch size, input/target lengths and non-padding fractions;
        * the current learning rate (``optimizer.learning_rate()``);
        * loss, token accuracy and steps per second.

    Args:
        writer: summary writer exposing ``add_scalar(tag, value, step)``.
        global_step: step index used as the x-axis.
        last_time: ``time.time()`` value at the previous report.
        pred, ans: flattened logits and gold ids, as passed to
            ``utils.get_accuracy``.
    """
    if opt.summary_grad:
        for name, param in model.named_parameters():
            # A trainable parameter that took no part in this step's graph
            # has no gradient; skip it rather than crash on ``None``.
            if not param.requires_grad or param.grad is None:
                continue

            norm = torch.norm(param.grad.detach().reshape(-1))
            writer.add_scalar('gradient_norm/' + name, norm,
                              global_step)

    writer.add_scalar('input_stats/batch_size',
                      targets.size(0), global_step)

    if inputs is not None:
        writer.add_scalar('input_stats/input_length',
                          inputs.size(1), global_step)
        i_nonpad = (inputs != opt.src_pad_idx).view(-1).type(torch.float32)
        writer.add_scalar('input_stats/inputs_nonpadding_frac',
                          i_nonpad.mean(), global_step)

    writer.add_scalar('input_stats/target_length',
                      targets.size(1), global_step)
    t_nonpad = (targets != opt.trg_pad_idx).view(-1).type(torch.float32)
    writer.add_scalar('input_stats/target_nonpadding_frac',
                      t_nonpad.mean(), global_step)

    writer.add_scalar('optimizer/learning_rate',
                      optimizer.learning_rate(), global_step)

    writer.add_scalar('loss', loss.item(), global_step)

    acc = utils.get_accuracy(pred, ans, opt.trg_pad_idx)
    writer.add_scalar('training/accuracy',
                      acc, global_step)

    # Assumes 100 steps since ``last_time`` (train() reports every 100
    # steps); the first report of an epoch may cover fewer steps.
    steps_per_sec = 100.0 / (time.time() - last_time)
    writer.add_scalar('global_step/sec', steps_per_sec,
                      global_step)


def train(train_data, model, opt, global_step, optimizer, t_vocab_size,
          label_smoothing, writer):
    """Run one epoch of training over ``train_data`` and return the new step.

    For each batch: forward pass, label-smoothed loss, backward, optimizer
    step. Every 100 global steps, diagnostics are written via
    ``summarize_train``. The progress bar is advanced by the batch size and
    is closed even if training raises. At the end,
    ``train_data.reload_examples()`` is called.

    Returns:
        ``global_step`` incremented once per batch.
    """
    model.train()
    last_time = time.time()
    with tqdm(total=len(train_data.dataset), ascii=True) as pbar:
        for batch in train_data:
            inputs = batch.src if opt.has_inputs else None
            targets = batch.trg
            pred = model(inputs, targets)

            # Flatten to one row per target token for the loss.
            pred = pred.view(-1, pred.size(-1))
            ans = targets.view(-1)

            loss = utils.get_loss(pred, ans, t_vocab_size,
                                  label_smoothing, opt.trg_pad_idx)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if global_step % 100 == 0:
                summarize_train(writer, global_step, last_time, model, opt,
                                inputs, targets, optimizer, loss, pred, ans)
                last_time = time.time()

            pbar.set_description('[Loss: {:.4f}]'.format(loss.item()))

            global_step += 1
            pbar.update(targets.size(0))

    train_data.reload_examples()
    return global_step


def validation(validation_data, model, global_step, t_vocab_size, val_writer,
               opt):
    """Evaluate ``model`` on ``validation_data``, log and return the mean loss.

    The loss uses ``label_smoothing=0``, so it is unsmoothed cross-entropy
    over non-padding tokens. Each batch's loss is weighted by ``len(batch)``
    (presumably the number of examples in the batch; verify against the
    dataset's batch type).
    """
    model.eval()
    loss_sum, example_count = 0.0, 0
    for batch in validation_data:
        inputs = batch.src if opt.has_inputs else None
        targets = batch.trg

        with torch.no_grad():
            logits = model(inputs, targets)
            logits = logits.view(-1, logits.size(-1))
            gold = targets.view(-1)
            loss = utils.get_loss(logits, gold, t_vocab_size, 0,
                                  opt.trg_pad_idx)
        batch_len = len(batch)
        loss_sum += loss.item() * batch_len
        example_count += batch_len

    val_loss = loss_sum / example_count
    print("Validation Loss", val_loss)
    val_writer.add_scalar('loss', val_loss, global_step)
    return val_loss


def main():
    """Command-line training entry point.

    Parses options and prepares the data with ``problem.prepare``. It then
    builds a new model, or resumes from
    ``<output_dir>/last/models/last_model.pt`` when that file exists. It
    trains for ``--train_step`` epochs, validating and checkpointing every
    ``--val_every`` epochs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--problem', required=True)
    parser.add_argument('--train_step', type=int, default=200)
    parser.add_argument('--batch_size', type=int, default=4096)
    parser.add_argument('--max_length', type=int, default=100)
    parser.add_argument('--n_layers', type=int, default=6)
    parser.add_argument('--hidden_size', type=int, default=512)
    parser.add_argument('--filter_size', type=int, default=2048)
    parser.add_argument('--warmup', type=int, default=16000)
    parser.add_argument('--val_every', type=int, default=5)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--label_smoothing', type=float, default=0.1)
    # Restrict to the models handled below; any other name would leave
    # ``model_fn`` unbound.
    parser.add_argument('--model', type=str, default='transformer',
                        choices=['transformer', 'fast_transformer'])
    parser.add_argument('--output_dir', type=str, default='./output')
    parser.add_argument('--data_dir', type=str, default='./data')
    parser.add_argument('--no_cuda', action='store_true')
    parser.add_argument('--parallel', action='store_true')
    parser.add_argument('--summary_grad', action='store_true')
    opt = parser.parse_args()

    device = torch.device('cpu' if opt.no_cuda else 'cuda')

    models_dir = opt.output_dir + '/last/models'
    os.makedirs(models_dir, exist_ok=True)
    os.makedirs(opt.data_dir, exist_ok=True)

    # ``problem.prepare`` returns an updated ``opt``. Fields read below but
    # not defined by the parser (has_inputs, share_target_embedding,
    # src_pad_idx, trg_pad_idx) are expected to come from it.
    train_data, validation_data, i_vocab_size, t_vocab_size, opt = \
        problem.prepare(opt.problem, opt.data_dir, opt.max_length,
                        opt.batch_size, device, opt)
    if i_vocab_size is not None:
        print("# of vocabs (input):", i_vocab_size)
    print("# of vocabs (target):", t_vocab_size)

    if opt.model == 'transformer':
        from model.transformer import Transformer
        model_fn = Transformer
    else:  # 'fast_transformer', the only other value argparse accepts
        from model.fast_transformer import FastTransformer
        model_fn = FastTransformer

    if os.path.exists(models_dir + '/last_model.pt'):
        print("Load a checkpoint...")
        model, global_step = utils.load_checkpoint(models_dir, device,
                                                   is_eval=False)
    else:
        model = model_fn(i_vocab_size, t_vocab_size,
                         n_layers=opt.n_layers,
                         hidden_size=opt.hidden_size,
                         filter_size=opt.filter_size,
                         dropout_rate=opt.dropout,
                         share_target_embedding=opt.share_target_embedding,
                         has_inputs=opt.has_inputs,
                         src_pad_idx=opt.src_pad_idx,
                         trg_pad_idx=opt.trg_pad_idx)
        model = model.to(device=device)
        global_step = 0

    if opt.parallel:
        print("Use", torch.cuda.device_count(), "GPUs")
        model = torch.nn.DataParallel(model)

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("# of parameters: {}".format(num_params))

    optimizer = LRScheduler(
        filter(lambda x: x.requires_grad, model.parameters()),
        opt.hidden_size, opt.warmup, step=global_step)

    writer = SummaryWriter(opt.output_dir + '/last')
    val_writer = SummaryWriter(opt.output_dir + '/last/val')
    # NOTE: not persisted across runs, so the first validation after a
    # resume always overwrites best_model.pt.
    best_val_loss = float('inf')

    for t_step in range(opt.train_step):
        print("Epoch", t_step)
        start_epoch_time = time.time()
        global_step = train(train_data, model, opt, global_step,
                            optimizer, t_vocab_size, opt.label_smoothing,
                            writer)
        print("Epoch Time: {:.2f} sec".format(time.time() - start_epoch_time))

        if t_step % opt.val_every != 0:
            continue

        val_loss = validation(validation_data, model, global_step,
                              t_vocab_size, val_writer, opt)
        # Checkpoint the bare module, not the DataParallel wrapper, so a
        # resumed run with --parallel does not nest wrappers.
        if isinstance(model, torch.nn.DataParallel):
            model_to_save = model.module
        else:
            model_to_save = model
        utils.save_checkpoint(model_to_save, models_dir,
                              global_step, val_loss < best_val_loss)
        best_val_loss = min(val_loss, best_val_loss)


Decode

In [None]:
# import argparse
# import time

# import torch
# import torch.nn.functional as F



# # pylint: disable=not-callable


# def encode_inputs(sentence, model, src_data, beam_size, device):
#     inputs = src_data['field'].preprocess(sentence)
#     inputs.append(src_data['field'].eos_token)
#     inputs = [inputs]
#     inputs = src_data['field'].process(inputs, device=device)
#     with torch.no_grad():
#         src_mask = utils.create_pad_mask(inputs, src_data['pad_idx'])
#         enc_output = model.encode(inputs, src_mask)
#         enc_output = enc_output.repeat(beam_size, 1, 1)
#     return enc_output, src_mask


# def update_targets(targets, best_indices, idx, vocab_size):
#     best_tensor_indices = torch.div(best_indices, vocab_size)
#     best_token_indices = torch.fmod(best_indices, vocab_size)
#     new_batch = torch.index_select(targets, 0, best_tensor_indices)
#     new_batch[:, idx] = best_token_indices
#     return new_batch


# def get_result_sentence(indices_history, trg_data, vocab_size):
#     result = []
#     k = 0
#     for best_indices in indices_history[::-1]:
#         best_idx = best_indices[k]
#         # TODO: get this vocab_size from target.pt?
#         k = best_idx // vocab_size
#         best_token_idx = best_idx % vocab_size
#         best_token = trg_data['field'].vocab.itos[best_token_idx]
#         result.append(best_token)
#     return ' '.join(result[::-1])


# def main():
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--data_dir', type=str, required=True)
#     parser.add_argument('--model_dir', type=str, required=True)
#     parser.add_argument('--max_length', type=int, default=100)
#     parser.add_argument('--beam_size', type=int, default=4)
#     parser.add_argument('--alpha', type=float, default=0.6)
#     parser.add_argument('--no_cuda', action='store_true')
#     parser.add_argument('--translate', action='store_true')
#     args = parser.parse_args()

#     beam_size = args.beam_size

#     # Load fields.
#     if args.translate:
#         src_data = torch.load(args.data_dir + '/source.pt')
#     trg_data = torch.load(args.data_dir + '/target.pt')

#     # Load a saved model.
#     device = torch.device('cpu' if args.no_cuda else 'cuda')
#     model = utils.load_checkpoint(args.model_dir, device)

#     pads = torch.tensor([trg_data['pad_idx']] * beam_size, device=device)
#     pads = pads.unsqueeze(-1)

#     # We'll find a target sequence by beam search.
#     scores_history = [torch.zeros((beam_size,), dtype=torch.float,
#                                   device=device)]
#     indices_history = []
#     cache = {}

#     eos_idx = trg_data['field'].vocab.stoi[trg_data['field'].eos_token]

#     if args.translate:
#         sentence = input('Source? ')

#     # Encoding inputs.
#     if args.translate:
#         start_time = time.time()
#         enc_output, src_mask = encode_inputs(sentence, model, src_data,
#                                              beam_size, device)
#         targets = pads
#         start_idx = 0
#     else:
#         enc_output, src_mask = None, None
#         sentence = input('Target? ').split()
#         for idx, _ in enumerate(sentence):
#             sentence[idx] = trg_data['field'].vocab.stoi[sentence[idx]]
#         sentence.append(trg_data['pad_idx'])
#         targets = torch.tensor([sentence], device=device)
#         start_idx = targets.size(1) - 1
#         start_time = time.time()

#     with torch.no_grad():
#         for idx in range(start_idx, args.max_length):
#             if idx > start_idx:
#                 targets = torch.cat((targets, pads), dim=1)
#             t_self_mask = utils.create_trg_self_mask(targets.size()[1],
#                                                      device=targets.device)

#             t_mask = utils.create_pad_mask(targets, trg_data['pad_idx'])
#             pred = model.decode(targets, enc_output, src_mask,
#                                 t_self_mask, t_mask, cache)
#             pred = pred[:, idx].squeeze(1)
#             vocab_size = pred.size(1)

#             pred = F.log_softmax(pred, dim=1)
#             if idx == start_idx:
#                 scores = pred[0]
#             else:
#                 scores = scores_history[-1].unsqueeze(1) + pred
#             length_penalty = pow(((5. + idx + 1.) / 6.), args.alpha)
#             scores = scores / length_penalty
#             scores = scores.view(-1)

#             best_scores, best_indices = scores.topk(beam_size, 0)
#             scores_history.append(best_scores)
#             indices_history.append(best_indices)

#             # Stop searching when the best output of beam is EOS.
#             if best_indices[0].item() % vocab_size == eos_idx:
#                 break

#             targets = update_targets(targets, best_indices, idx, vocab_size)

#     result = get_result_sentence(indices_history, trg_data, vocab_size)
#     print("Result: {}".format(result))

#     print("Elapsed Time: {:.2f} sec".format(time.time() - start_time))


# if __name__ == '__main__':
#     main()

usage: ipykernel_launcher.py [-h] --data_dir DATA_DIR --model_dir MODEL_DIR
                             [--max_length MAX_LENGTH] [--beam_size BEAM_SIZE]
                             [--alpha ALPHA] [--no_cuda] [--translate]
ipykernel_launcher.py: error: the following arguments are required: --data_dir, --model_dir


SystemExit: 2