In [1]:
import os
import torch
import numpy as np

os.environ['CUDA_VISIBLE_DEVICES'] = '1'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from word_dictionary import WordDictionary
from my_dataset import MyDataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

special_token = {'<pad>': 0, '<bos>': 1, '<eos>': 2, '<unk>': 3}

BATCH_SIZE = 64

# id化
word_dict = WordDictionary()
word_dict.create_dict()

en_id2w_dict = word_dict.get_dict("en", "id2w")

# データローダーに使う関数
def collate_func(batch):
    src_t = []
    dst_t = []

    for src, dst in batch:
        src_t.append(torch.tensor(src))
        dst_t.append(torch.tensor(dst))

    return pad_sequence(src_t, batch_first=True), pad_sequence(dst_t, batch_first=True)


# データローダー作成
dataset_test = MyDataset(word_dict, "test")
dataloader_test = DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_func)

In [3]:
from seq2seq import Seq2Seq

hidden_size = 256
embed_size = 256
padding_idx = special_token["<pad>"]
vocab_size_src, vocab_size_dst = dataset_test.get_vocab_size()

lr = 0.001

model = Seq2Seq(hidden_size, vocab_size_src, vocab_size_dst, padding_idx, embed_size, device).to(device)
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr)
criterion = torch.nn.CrossEntropyLoss(ignore_index=0)

Seq2Seq(
  (encoder): LSTM_Encoder(
    (embedding): Embedding(16134, 256, padding_idx=0)
    (lstm_cell): LSTMCell(256, 256)
  )
  (decoder): LSTM_Decoder(
    (embedding): Embedding(17260, 256, padding_idx=0)
    (lstm_cell): LSTMCell(256, 256)
    (fc): Linear(in_features=256, out_features=17260, bias=True)
  )
)


In [4]:
WEITGHT_PATH = "/home/morioka/workspace/git_projects/lab-tutorial-nmt/data/model_weight/lstm_s2s_24_0.02667330542932727.pth"

model.load_state_dict(torch.load(WEITGHT_PATH))

<All keys matched successfully>

In [5]:
def test(model, test_dataloader):
    model.train(False)
    pred_text_clean = []
    
    with torch.no_grad():
        for src, dst in test_dataloader:
            src_tensor = src.clone().detach().to(device)
            dst_tensor = dst.clone().detach().to(device)
        
            pred = model(src_tensor, dst_tensor)
            
            pred_text = []
            en_id2w = np.vectorize(lambda id: en_id2w_dict[id])
            for sentence in pred:
                pred_text.append(en_id2w(sentence))

            for sentence in pred_text:
                tmp_list = []
                for word in sentence:
                    if word != "<bos>" and word != "<pad>" and word != "<eos>":
                        tmp_list.append(word)
                pred_text_clean.append(" ".join(tmp_list))
    return pred_text_clean

In [6]:
ouput_text = test(model, dataloader_test)

print(ouput_text)

['the system of the ddjb computer which is a simple model of the information which is a component of the mechanism analysis was examined .', 'the absorption edge of kmf single crystals and rna is secreted from the cell of the cell surface at the time of the cluster of the cluster are measured .', 'the use of antipyretics should be avoided fundamentally to the petroleum .', 'the system constitution of the titled broadcast was shown in the case of the system which is shown in the experiment .', 'in the calculation , the rss method is applied to the square of the asymmetric device for the temperature and the temperature dependency of the wavelength shape for the temperature dependency of the silk to the surface .', 'the authors have developed a charcoal board generated by a built-in magnetic field generated by a built-in light and a voltage , and a freezer to make a strong core of the high pressure rotating disk has been developed , and it was shown to be a knowledge on the case of which 

In [7]:
SAVE_TEXT_PATH = "/home/morioka/workspace/git_projects/lab-tutorial-nmt/data/model_weight/lstm_s2s_24_0.02667330542932727.pth.en"

with open(SAVE_TEXT_PATH, "w") as f:
    f.write("\n".join(ouput_text))