In [1]:
import os
import torch
import numpy as np

# Expose only the GPU with physical index 1 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Use the first visible CUDA device when one is available, otherwise the CPU.
# NOTE(review): with the mask above, "cuda:0" should map to physical GPU 1,
# assuming CUDA has not been initialized before this cell runs.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from my_dataset import MyDataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from word_dictionary import WordDictionary

# Ids reserved for the special tokens.
# NOTE(review): the vocabulary printed by this cell stores these tokens in
# lowercase ('<pad>', '<bos>', ...), so only the ids - not these key strings -
# line up with the dictionary entries.
special_token = {
  '<PAD>': 0,
  '<BOS>': 1,
  '<EOS>': 2,
  '<UNK>': 3,
}

BATCH_SIZE = 64

# Build the word <-> id dictionaries
word_dict = WordDictionary()
word_dict.create_dict()

# Dump both lookup directions for each language (debug output)
for lang in ("en", "ja"):
  for direction in ("w2id", "id2w"):
    print(word_dict.get_dict(lang, direction))

# id -> word table for the "en" side; read later by the training loop
en_id2w_dict = word_dict.get_dict("en", "id2w")

# Dump the id-converted corpora for each language (debug output)
for lang in ("ja", "en"):
  for corpus in ("train-1.short", "dev"):
    print(word_dict.get_id(corpus, lang))

# Collate function handed to the data loaders
def collate_func(batch):
  """Turn a list of (src, dst) id sequences into two padded batch tensors.

  Each sequence is converted with ``torch.tensor`` and the sequences of one
  side are padded to a common length by ``pad_sequence`` using its default
  padding value.

  Args:
    batch: iterable of ``(src, dst)`` pairs of id sequences.

  Returns:
    Tuple ``(src_batch, dst_batch)``; both are batch-first tensors.
  """
  pairs = [(torch.tensor(src), torch.tensor(dst)) for src, dst in batch]
  src_tensors = [src for src, _ in pairs]
  dst_tensors = [dst for _, dst in pairs]

  padded_src = pad_sequence(src_tensors, batch_first=True)
  padded_dst = pad_sequence(dst_tensors, batch_first=True)
  return padded_src, padded_dst


# Build the datasets and their data loaders.
# Options shared by both loaders; only the shuffle flag differs between them.
_loader_kwargs = dict(batch_size=BATCH_SIZE, collate_fn=collate_func)

# Training split: reshuffled by the loader
dataset_train = MyDataset(word_dict, "train")
dataloader_train = DataLoader(dataset_train, shuffle=True, **_loader_kwargs)

# Dev split: kept in dataset order
dataset_dev = MyDataset(word_dict, "dev")
dataloader_dev = DataLoader(dataset_dev, shuffle=False, **_loader_kwargs)

{'<pad>': 0, '<bos>': 1, '<eos>': 2, '<unk>': 3, '現在': 4, '，': 5, '筋': 6, 'ジストロフィー': 7, '患者': 8, 'の': 9, '移動': 10, '介助': 11, 'に': 12, 'お': 13, 'い': 14, 'て': 15, '文書': 16, 'マニュアル': 17, 'を': 18, '使用': 19, 'し': 20, 'る': 21, '。': 22, '最後': 23, '将来': 24, '展望': 25, 'つ': 26, '記述': 27, '次': 28, 'インバータ': 29, '共通': 30, 'する': 31, '重要': 32, 'な': 33, '技術': 34, '整理': 35, 'た': 36, 'ｄｎａ': 37, '解析': 38, '解説': 39, '疾患': 40, '関連': 41, '遺伝': 42, '子': 43, '医療': 44, 'へ': 45, '活用': 46, '説明': 47, 'その': 48, 'うち': 49, '６': 50, '分野': 51, '１２': 52, '紹介': 53, '液晶': 54, '特性': 55, '評価': 56, '装置': 57, 'で': 58, 'あ': 59, 'これ': 60, 'ら': 61, 'よ': 62, 'り': 63, '診断': 64, '指針': 65, 'や': 66, '治療': 67, '方針': 68, '確立': 69, '貢献': 70, 'こと': 71, 'が': 72, '期待': 73, 'さ': 74, 'れ': 75, 'ｅｂｍ': 76, 'ため': 77, '精度': 78, '高': 79, '臨床': 80, '疫学': 81, '研究': 82, '求め': 83, 'られ': 84, 'メタン': 85, '部分': 86, '酸化': 87, '反応': 88, '例': 89, 'は': 90, '死亡': 91, 'ヨーロッパ': 92, 'も': 93, '同様': 94, '動き': 95, '単板': 96, 'ガラス': 97, 'と': 98, '複合': 99, '日射': 100, 

In [3]:
from seq2seq import Seq2Seq

# Model hyper-parameters
hidden_size = 256
embed_size = 256
padding_idx = special_token["<PAD>"]
vocab_size_src, vocab_size_dst = dataset_train.get_vocab_size()

# Optimizer learning rate
lr = 0.001

# Build the network, then move it to the selected device
model = Seq2Seq(
    hidden_size,
    vocab_size_src,
    vocab_size_dst,
    padding_idx,
    embed_size,
    device,
)
model = model.to(device)
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Target id 0 is excluded from the loss
criterion = torch.nn.CrossEntropyLoss(ignore_index=0)

Seq2Seq(
  (encoder): LSTM_Encoder(
    (embedding): Embedding(16134, 256, padding_idx=0)
    (lstm_cell): LSTMCell(256, 256)
  )
  (decoder): LSTM_Decoder(
    (embedding): Embedding(17260, 256, padding_idx=0)
    (lstm_cell): LSTMCell(256, 256)
    (fc): Linear(in_features=256, out_features=17260, bias=True)
  )
)


In [4]:
# Sanity check: print the first two batches produced by the training loader.
# Stop right after them; without the break the loop kept walking every
# remaining batch of the training set while printing nothing.
for i, (src, dst) in enumerate(dataloader_train):
  if i >= 2:
    break
  print(i)
  print(src)
  print(dst)

0
tensor([[   1,  402, 7818,  ...,    0,    0,    0],
        [   1, 1472,   12,  ...,    0,    0,    0],
        [   1,  839,    9,  ...,    0,    0,    0],
        ...,
        [   1,   56,   12,  ...,    0,    0,    0],
        [   1,  544,  182,  ...,   14,   22,    2],
        [   1, 4966, 2121,  ...,    0,    0,    0]])
tensor([[   1,   47,  283,  ...,    0,    0,    0],
        [   1,    7, 3222,  ...,    0,    0,    0],
        [   1,    7, 5104,  ...,    0,    0,    0],
        ...,
        [   1,   12,    7,  ...,    0,    0,    0],
        [   1,   56,   10,  ...,    0,    0,    0],
        [   1,   32, 2576,  ...,    0,    0,    0]])
1
tensor([[    1, 16114,     9,  ...,     0,     0,     0],
        [    1,  5251,  4720,  ...,     0,     0,     0],
        [    1,  2628,     5,  ...,     0,     0,     0],
        ...,
        [    1, 12100,  3396,  ...,     0,     0,     0],
        [    1,  7536, 14245,  ...,     0,     0,     0],
        [    1,  3358,  2955,  ...,     0

In [5]:
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import SmoothingFunction

# Toy reference / candidate pair used to sanity-check the BLEU call
ref = [
    "I", "have", "pan", "cake", "and",
    "apple", "pen", "<EOS>", "<PAD>",
]
can = [
    "I", "have", "pen", "case", "and",
    "apple", "pan", "<EOS>", "<PAD>",
]

# sentence_bleu takes a list of references and a single candidate
_smoothing = SmoothingFunction()
_demo_score = sentence_bleu([ref], can, smoothing_function=_smoothing.method1)
print(_demo_score)

0.09438595268231116


In [6]:
from statistics import mean

EPOCH_NUM = 24


def _rows_of_ids(output):
    """Convert a model output into rows of token ids.

    A floating-point tensor is treated as per-position vocabulary scores and
    reduced with argmax over its last axis; an integer tensor is taken to hold
    ids already. A non-tensor output is treated as an iterable of rows, each
    converted the same way.
    """
    if torch.is_tensor(output):
        if output.is_floating_point():
            output = output.argmax(dim=-1)
        return output.detach().cpu().tolist()
    return [_rows_of_ids(row) if torch.is_tensor(row) else list(row) for row in output]


def _ids_to_tokens(ids, id2w, skip_ids, eos_id):
    """Map ids to words, dropping ``skip_ids`` and stopping after the first ``eos_id``."""
    tokens = []
    for raw_id in ids:
        token_id = int(raw_id)
        if token_id in skip_ids:
            continue
        tokens.append(id2w[token_id])
        if token_id == eos_id:
            # Everything after the end-of-sentence marker is padding / noise.
            break
    return tokens


def train(model, train_dataloader, dev_dataloader, optimizer, criterion, epoch_num=None):
    """Train ``model`` and report a smoothed sentence-level BLEU on the dev set.

    Every epoch runs one pass over ``train_dataloader`` followed by one pass
    over ``dev_dataloader``. Weights are saved every second epoch under
    ``../../data/model_weight/``.

    Relies on these notebook-level names: ``device``, ``special_token``,
    ``en_id2w_dict``, ``sentence_bleu`` and ``SmoothingFunction``.

    Args:
        model: called as ``model(src, dst)``. In training mode its output is
            iterated sentence by sentence and each item is passed to
            ``criterion`` together with the shifted target ids.
        train_dataloader: iterable of ``(src, dst)`` id tensors.
        dev_dataloader: iterable of ``(src, dst)`` id tensors.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(sentence_output, sentence_target)``.
        epoch_num: number of epochs; ``None`` (default) uses ``EPOCH_NUM``.

    Returns:
        None. Progress is printed.
    """
    if epoch_num is None:
        epoch_num = EPOCH_NUM

    pad_id = special_token["<PAD>"]
    bos_id = special_token["<BOS>"]
    eos_id = special_token["<EOS>"]
    # Filter special tokens by id: the vocabulary spells them in lowercase, so
    # comparing words against "<BOS>" / "<PAD>" never matched anything.
    skip_ids = {pad_id, bos_id}
    smoothing = SmoothingFunction().method1

    for epoch in range(1, epoch_num + 1):
        model.train()
        epoch_loss = 0.0

        for src, dst in train_dataloader:
            optimizer.zero_grad()

            src_tensor = src.to(device)
            dst_tensor = dst.to(device)

            pred = model(src_tensor, dst_tensor)

            # Teacher side: drop the leading <BOS> and append one <PAD> column
            # so the length still matches the prediction. Built from
            # dst_tensor so it lives on the same device (and has the same
            # dtype) as the model output.
            pad_column = dst_tensor.new_full((dst_tensor.size(0), 1), pad_id)
            target = torch.cat((dst_tensor[:, 1:], pad_column), dim=1)

            # Sum of the per-sentence losses, accumulated on the model's device.
            loss = torch.stack(
                [criterion(s_pred, s_target) for s_pred, s_target in zip(pred, target)]
            ).sum()

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        model.eval()
        sentence_scores = []
        with torch.no_grad():
            for src, dst in dev_dataloader:
                src_tensor = src.to(device)
                dst_tensor = dst.to(device)

                # NOTE(review): the reference is passed to the model here as
                # well; whether it is used for teacher forcing in eval mode
                # depends on the Seq2Seq implementation - confirm.
                output = model(src_tensor, dst_tensor)

                # Clean hypotheses and references the same way: no <BOS> /
                # <PAD>, nothing after the first <EOS>.
                hypotheses = [
                    _ids_to_tokens(row, en_id2w_dict, skip_ids, eos_id)
                    for row in _rows_of_ids(output)
                ]
                references = [
                    _ids_to_tokens(row, en_id2w_dict, skip_ids, eos_id)
                    for row in dst.tolist()
                ]

                batch_scores = [
                    sentence_bleu([reference], hypothesis, smoothing_function=smoothing)
                    for hypothesis, reference in zip(hypotheses, references)
                ]
                if not batch_scores:
                    continue

                sentence_scores.extend(batch_scores)
                # Average over the sentences actually in this batch; the last
                # batch can be smaller than the configured batch size.
                print(f"bleu: {mean(batch_scores)}")

        # An empty dev set would make mean() raise; report 0.0 instead.
        epoch_bleu = mean(sentence_scores) if sentence_scores else 0.0

        if epoch % 2 == 0:
            weight_path = f"../../data/model_weight/lstm_s2s_{epoch}_{epoch_bleu}.pth"
            # Make sure the target directory exists before saving.
            os.makedirs(os.path.dirname(weight_path), exist_ok=True)
            torch.save(model.state_dict(), weight_path)

        print(f"epoch {epoch} in {epoch_num} ---- epoch loss:{epoch_loss}, bleu score:{epoch_bleu}")
        
    

In [7]:
# Start training with the objects built in the cells above
train(
    model=model,
    train_dataloader=dataloader_train,
    dev_dataloader=dataloader_dev,
    optimizer=optimizer,
    criterion=criterion,
)

pred for loss: tensor([[ 0.0257, -0.0541, -0.0505,  ...,  0.0398, -0.1287, -0.0891],
        [ 0.0090, -0.0723, -0.0327,  ...,  0.0316, -0.1464, -0.0139],
        [ 0.0204, -0.0366, -0.0915,  ..., -0.0032, -0.0858,  0.0338],
        ...,
        [ 0.0369, -0.0616, -0.0561,  ...,  0.0217, -0.0462, -0.0334],
        [ 0.0369, -0.0616, -0.0561,  ...,  0.0218, -0.0462, -0.0334],
        [ 0.0368, -0.0617, -0.0561,  ...,  0.0218, -0.0463, -0.0333]],
       grad_fn=<UnbindBackward0>)
ref for loss: tensor([  50, 1137,   15, 7711,  561,   11,   59,    7, 1523,  134,  307, 1007,
         528,    7, 1689, 8625,   19,    2,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0])
pred for loss: tensor([[ 0.0250, -0.0541, -0.0506,  ...,  0.0411, -0.1289, -0.0894],
        [-0.0392, -0.1038, -0.1372,  ..., -0.0195,  0.0192, -0.0125],
        [-0.0466, -0.0953, -0.2311,  ..., -0.0588,  0.0346, -0.1947],
        ...,
        [ 0.0368, -0.0616, -0.0562,  ..., 

KeyboardInterrupt: 