In [1]:
import os
import torch
import numpy as np

os.environ['CUDA_VISIBLE_DEVICES'] = '1'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from my_dataset import MyDataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from word_dictionary import WordDictionary

special_token = {'<pad>': 0, '<pad>': 1, '<pad>': 2, '<pad>': 3}

BATCH_SIZE = 64

# id化
word_dict = WordDictionary()
word_dict.create_dict()

en_id2w_dict = word_dict.get_dict("en", "id2w")

# データローダーに使う関数
def collate_func(batch):
  src_t = []
  dst_t = []
  
  for src, dst in batch:
    src_t.append(torch.tensor(src))
    dst_t.append(torch.tensor(dst))
  
  return pad_sequence(src_t, batch_first=True), pad_sequence(dst_t, batch_first=True)


# データローダー作成
dataset_train = MyDataset(word_dict, "train")
dataloader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_func)

dataset_dev = MyDataset(word_dict, "dev")
dataloader_dev = DataLoader(dataset_dev, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_func)

In [10]:
from seq2seq import Seq2Seq

hidden_size = 256
embed_size = 256
padding_idx = special_token["<pad>"]
vocab_size_src, vocab_size_dst = dataset_train.get_vocab_size()

lr = 0.001

model = Seq2Seq(hidden_size, vocab_size_src, vocab_size_dst, padding_idx, embed_size, device).to(device)
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr)
criterion = torch.nn.CrossEntropyLoss(ignore_index=0)

Seq2Seq(
  (encoder): LSTM_Encoder(
    (embedding): Embedding(16134, 256, padding_idx=3)
    (lstm_cell): LSTMCell(256, 256)
  )
  (decoder): LSTM_Decoder(
    (embedding): Embedding(17260, 256, padding_idx=3)
    (lstm_cell): LSTMCell(256, 256)
    (fc): Linear(in_features=256, out_features=17260, bias=True)
  )
)


In [11]:
i=0
for src, dst in dataloader_train:
  if i < 2:
    print(i)
    print(src)
    print(dst)
    i+=1

0
tensor([[    1,   117,     5,  ...,     0,     0,     0],
        [    1,  3461,   399,  ...,     0,     0,     0],
        [    1,   648,     9,  ...,     0,     0,     0],
        ...,
        [    1,   117,   830,  ...,     0,     0,     0],
        [    1,  1522,    18,  ...,     0,     0,     0],
        [    1, 13516,     9,  ...,    36,    22,     2]])
tensor([[   1,    7, 1096,  ...,    0,    0,    0],
        [   1,    7, 1132,  ...,    0,    0,    0],
        [   1, 4814,   15,  ...,    0,    0,    0],
        ...,
        [   1,    7,  842,  ...,    0,    0,    0],
        [   1, 9436,   28,  ...,    0,    0,    0],
        [   1,  695,    7,  ...,    2,    0,    0]])
1
tensor([[   1, 2916,  477,  ...,    0,    0,    0],
        [   1, 1859,   98,  ...,    0,    0,    0],
        [   1,  435,    9,  ...,    0,    0,    0],
        ...,
        [   1,  117,    5,  ...,    0,    0,    0],
        [   1,    4,   58,  ...,    0,    0,    0],
        [   1,   28,   12,  ...,   

In [12]:
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import SmoothingFunction

ref = ["I","have","pan","cake","and","apple", "pen","<EOS>","<PAD>"]
can = ["I","have","pen","case","and","apple", "pan","<EOS>","<PAD>"]

print(sentence_bleu([ref], can,  smoothing_function=SmoothingFunction().method1))

0.09438595268231116


In [13]:
from statistics import mean

EPOCH_NUM = 24

def train(model, train_dataloader, dev_dataloader, optimizer, criterion):
    for epoch in range(1, EPOCH_NUM+1):
        model.train()
        epoch_loss = 0
        bleu_list = []
        
        for src, dst in train_dataloader:
            optimizer.zero_grad()
            
            src_tensor = src.clone().detach().to(device)
            dst_tensor = dst.clone().detach().to(device)

            pred = model(src_tensor, dst_tensor)

            loss = torch.tensor(0, dtype=torch.float)
            for s_pred, s_dst in zip(pred, dst):
                # 教師側は<BOS>を削除し、後ろに<PAD>を挿入
                loss += criterion(s_pred, torch.cat((s_dst[1:], torch.zeros(1, dtype=torch.int32))))

            epoch_loss += loss.to("cpu").detach().numpy().copy()

            loss.backward()
            optimizer.step()

        model.train(False)
        for src, dst in dev_dataloader:            
            with torch.no_grad():
                src_tensor = src.clone().detach().to(device)
                dst_tensor = dst.clone().detach().to(device)
                
                pred = model(src_tensor, dst_tensor)
                
                pred_text = []
                en_id2w = np.vectorize(lambda id: en_id2w_dict[id])
                for sentence in pred:
                    pred_text.append(en_id2w(sentence)) 
                
                dst_text = en_id2w(dst.to("cpu").detach().numpy().copy())
                dst_text_clean = []
                
                for sentence in dst_text:
                    tmp_list = []
                    for word in sentence:
                        if word != "<bos>" and word != "<pad>":
                            tmp_list.append(word)
                    dst_text_clean.append(tmp_list)
                
                bleu = 0
                for pred, dst in zip(pred_text, dst_text_clean):
                    bleu += sentence_bleu([dst], pred,  smoothing_function=SmoothingFunction().method1)
                bleu = bleu / BATCH_SIZE
                bleu_list.append(bleu)
                print(f"bleu: {bleu}")
        
        if epoch % 2 == 0:
            torch.save(model.state_dict(), f"../../data/model_weight/lstm_s2s_{epoch}_{mean(bleu_list)}.pth")
        
        print(f"epoch {epoch} in {EPOCH_NUM} ---- epoch loss:{epoch_loss}, bleu score:{mean(bleu_list)}")
        
    

In [14]:
train(model, dataloader_train, dataloader_dev, optimizer, criterion)

bleu: 0.01627569276532695
bleu: 0.019901474016532358
bleu: 0.019309124507309833
bleu: 0.016746722175924607
bleu: 0.020745475259073495
bleu: 0.020119679150728834
bleu: 0.021740957078582247
bleu: 0.018596844243892793
bleu: 0.01382337901059476
bleu: 0.017691621020579876
bleu: 0.012057785767108411
bleu: 0.01732552746190705
bleu: 0.01856287614695845
bleu: 0.01839161193652485
bleu: 0.01825646174164372
bleu: 0.01866476823310106
bleu: 0.018179459775171294
bleu: 0.014570519058516723
bleu: 0.018210170002489773
bleu: 0.019061285761899662
bleu: 0.021590391138450193
bleu: 0.018577894701261276
bleu: 0.01934284230408907
bleu: 0.01809561067260788
bleu: 0.01890109780094523
bleu: 0.01509086087012893
bleu: 0.019410714697052096
bleu: 0.014134439890392713
epoch 1 in 24 ---- epoch loss:113910.33076477051, bleu score:0.017977688828171218
bleu: 0.02575121332240382
bleu: 0.026048919871733432
bleu: 0.021581607629881094
bleu: 0.029796890524142333
bleu: 0.02694869653217876
bleu: 0.025211536346232282
bleu: 0.02145