# Example: translation

In [1]:
# import packages
import torch
import pickle
import pandas as pd
import random
from pathlib import Path
from torch.utils.data import DataLoader
from model.split import Stemmer
from model.net import BidiEncoder, AttnDecoder
from model.utils import SourceProcessor, TargetProcessor
from model.data import NMTCorpus, batchify
from utils import Config, CheckpointManager, SummaryManager

In [2]:
# Load configuration: model hyper-parameters and dataset file locations.
model_config, dataset_config = (
    Config("conf/model/luongattn.json"),
    Config("conf/dataset/sample.json"),
)

In [3]:
# get processors
def _build_processor(vocab_path, language, processor_cls):
    """Load a pickled vocabulary and wrap it in a stemming processor.

    Args:
        vocab_path: path to the pickled vocabulary file.
        language: language code passed to ``Stemmer`` (e.g. "ko", "en").
        processor_cls: processor class to instantiate; called as
            ``processor_cls(vocab, stemmer.extract_stem)``.

    Returns:
        The constructed processor instance.
    """
    # SECURITY: pickle.load can execute arbitrary code. Only load vocab files
    # produced by this project, never files from untrusted sources.
    with open(vocab_path, mode="rb") as vocab_file:
        vocab = pickle.load(vocab_file)
    stemmer = Stemmer(language=language)
    return processor_cls(vocab, stemmer.extract_stem)


def get_processor(dataset_config):
    """Build the source ("ko") and target ("en") processors.

    Args:
        dataset_config: config exposing ``source_vocab`` and ``target_vocab``
            paths to pickled vocabularies.

    Returns:
        tuple: ``(src_processor, tgt_processor)``, i.e. a ``SourceProcessor``
        and a ``TargetProcessor``.
    """
    src_processor = _build_processor(dataset_config.source_vocab, "ko", SourceProcessor)
    tgt_processor = _build_processor(dataset_config.target_vocab, "en", TargetProcessor)
    return src_processor, tgt_processor


src_processor, tgt_processor = get_processor(dataset_config)

In [4]:
# restore model
# Runs are stored under experiments/<model type>/<run directory>.
# Path.iterdir() yields entries in arbitrary, filesystem-dependent order, so sort
# them and keep only directories to make the chosen run reproducible.
exp_root = Path('experiments') / model_config.type
run_dirs = sorted(entry for entry in exp_root.iterdir() if entry.is_dir())
if not run_dirs:
    raise FileNotFoundError(f"no experiment run directory found under {exp_root}")
# NOTE(review): when several runs exist this takes the lexicographically first;
# set `exp_dir` to a specific run directory if another one is wanted.
exp_dir = run_dirs[0]
checkpoint_manager = CheckpointManager(exp_dir)

# Rebuild the architecture from the model config so the saved state dicts can be
# loaded into it (presumably the same config used for training — confirm).
encoder = BidiEncoder(
    src_processor.vocab, model_config.encoder_hidden_dim, model_config.drop_ratio
)
decoder = AttnDecoder(
    tgt_processor.vocab,
    model_config.method,
    # presumably the bidirectional encoder concatenates both directions' outputs
    model_config.encoder_hidden_dim * 2,
    model_config.decoder_hidden_dim,
    model_config.drop_ratio
)

checkpoint = checkpoint_manager.load_checkpoint("best.tar")
encoder.load_state_dict(checkpoint["encoder_state_dict"])
decoder.load_state_dict(checkpoint["decoder_state_dict"])

# Switch to evaluation mode so dropout is disabled during translation.
encoder.eval()
decoder.eval()

AttnDecoder(
  (_emb): Embedding(
    (_ops): Embedding(5085, 300, padding_idx=1)
  )
  (_ops): LSTM(300, 512, num_layers=2, batch_first=True, dropout=0.3)
  (_attn): GlobalAttn()
  (_concat): Linear(in_features=1024, out_features=300, bias=False)
  (_dropout): Dropout(p=0.3, inplace=False)
)

## Translation

In [5]:
# prepare an example pair: one random row of the training split
# (call random.seed(...) beforehand for a reproducible pick)
tr = pd.read_csv(dataset_config.train, sep='\t')
# randrange excludes its upper bound. random.randint(0, len(tr)) is inclusive and
# could return len(tr), which is out of range for .iloc.
example_pair = tr.iloc[random.randrange(len(tr))]

In [6]:
# Show how each side of the example pair is tokenized by its own processor.
for side, processor in (("ko", src_processor), ("en", tgt_processor)):
    token_ids = processor.process(getattr(example_pair, side))
    print(processor.vocab.to_tokens(token_ids))

['나', '의', '장래', '희망', '은', '자동차', '디자이너', '입니다', '.']
['my', 'dream', 'is', 'to', 'be', 'a', 'car', 'design', '.', '<eos>']


In [7]:
def translate(source_sentence, max_len=50, src_processor=src_processor, tgt_processor=tgt_processor):
    """Greedily translate one source sentence with the notebook's ``encoder``/``decoder``.

    Args:
        source_sentence: raw source-language sentence; converted to token
            indices by ``src_processor.process``.
        max_len: maximum number of decoding steps.
        src_processor: processor for the source side (defaults to the one built
            earlier in this notebook).
        tgt_processor: processor whose vocab supplies the BOS/EOS tokens and maps
            predicted indices back to tokens.

    Returns:
        list of str: target tokens starting with the BOS token. The list ends with
        the EOS token only if it was produced within ``max_len`` steps.
    """
    tgt_vocab = tgt_processor.vocab
    with torch.no_grad():
        # add a leading (batch) dimension of size 1
        src = torch.tensor(src_processor.process(source_sentence)).unsqueeze(0)

        # NOTE(review): the encoder's final state is not used to initialise the
        # decoder (dec_hc starts as None); confirm this matches the training loop.
        enc_outputs, src_length, _enc_hc = encoder(src)
        bos_idx = tgt_vocab.to_indices(tgt_vocab.bos_token)
        dec_input = torch.full((1, 1), bos_idx, dtype=torch.long)
        dec_hc = None

        translation = [tgt_vocab.bos_token]

        for _ in range(max_len):
            dec_output, dec_hc = decoder(dec_input, dec_hc, enc_outputs, src_length)
            # greedy decoding: feed the highest-scoring token back in as the next input
            dec_input = dec_output.topk(1).indices
            token = tgt_vocab.to_tokens(dec_input.item())
            translation.append(token)

            if token == tgt_vocab.eos_token:
                break

    return translation

In [8]:
# Greedy-decode the sampled source sentence and show the resulting tokens.
translation = translate(source_sentence=example_pair.ko)
print(translation)

['<bos>', 'my', 'dream', 'is', 'a', 'car', 'design', '.', '<eos>']
