In [None]:
#!/usr/bin/env python
# coding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# from torch.nn import TransformerEncoder, TransformerEncoderLayer

import torchtext
from torchtext import data, datasets
from torchtext.vocab import FastText

import os
import math
import MeCab

# Run on the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Seed the CPU and CUDA generators and ask cuDNN for deterministic kernels
# so that repeated runs of the notebook are comparable.
SEED = 1
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# MeCab tagger created with the "-Owakati" output option.
mecab = MeCab.Tagger("-Owakati")

# BLEU

from Util.selfbleu import CalcSelfBLEU
from Util.bleu import calc_all_bleu

from matplotlib import pyplot as plt
%matplotlib inline

In [None]:
# Hyperparameters shared by the cells below.
BATCH_SIZE=100  # examples per mini-batch; train() stops at a batch of any other size
EMBED_SIZE=300  # word-embedding dimension (input size of the encoder/decoder LSTMs)
LSTM_HIDDEN_SIZE=300  # hidden size of the encoder/decoder LSTMs
AE_EPOCHS=100  # number of auto-encoder training epochs
REPORT_INTERVAL=5  # train() prints a progress report every this many epochs
ae_lr=1e-2  # learning rate handed to the Adam optimizer

In [None]:
# Dataset labels. They are joined into the TSV file names
# "<split>_<data_source>_<feature>.tsv" that the loading cell reads,
# and `data_source` also selects which loading branch runs.
data_source="wiki"
# feature="only_ja_small"
feature="only_ja_sample"

In [None]:
if data_source=="wiki":
    # Pad every example to one fixed length, independent of the batch it lands in.
    MAX_LEN=20

    # Lower-cased text field that wraps each example in <SOS>/<EOS>, pads to
    # MAX_LEN, yields batch-first tensors and also returns the lengths.
    TEXT = data.Field(
        sequential=True,
        lower=True,
        include_lengths=True,
        init_token="<SOS>",
        eos_token="<EOS>",
        fix_length=MAX_LEN,
        batch_first=True,
    )

    # File names follow "<split>_<data_source>_<feature>.tsv".
    split_files = {
        split: "_".join([split, data_source, feature]) + '.tsv'
        for split in ("train", "eval", "test")
    }

    # NOTE: the name `train` bound here is rebound later in the notebook by
    # the `train()` function definition, so use the dataset before that cell.
    train, val, test = data.TabularDataset.splits(
        path='./data/foruse',
        train=split_files["train"],
        validation=split_files["eval"],
        test=split_files["test"],
        format='tsv',
        fields=[('Text', TEXT),],
    )

    # The vocabulary is built from the training split with Japanese FastText vectors attached.
    TEXT.build_vocab(train, vectors=FastText(language="ja"))

    train_iter, val_iter, test_iter = data.Iterator.splits(
        (train, val, test),
        batch_sizes=(BATCH_SIZE,) * 3,
        sort=False,
        device=torch.device(device),
    )

    # Number of rows in the vocabulary's vector table.
    vocab_size = TEXT.vocab.vectors.size(0)

In [None]:
# Embedding table shared by AutoEncoder (input lookup) and Decoder (feedback of
# predicted tokens); both store a reference to this same module.
# NOTE(review): the table is randomly initialised. The FastText vectors attached
# to TEXT.vocab are only used for `vocab_size` and are not copied in here;
# confirm whether pretrained initialisation was intended.
share_conv = nn.Embedding(num_embeddings=vocab_size, embedding_dim=EMBED_SIZE)
share_conv = share_conv.to(device)

In [None]:
class Encoder(nn.Module):
    """Single-layer, unidirectional LSTM that summarises an embedded sentence.

    Args:
        embed_size: size of each input embedding vector.
        n_hid: LSTM hidden size.
    """
    def __init__(self, embed_size, n_hid):
        super().__init__()
        self.lstm = nn.LSTM(embed_size, n_hid, batch_first=True)

    def forward(self, x):
        """Encode a batch of embedded sentences.

        Args:
            x: tensor of shape (batch, seq_len, embed_size).

        Returns:
            Tuple ``(h_n, c_n)`` with the final hidden and cell state of the
            LSTM, each of shape (1, batch, n_hid). The per-step outputs are
            discarded.

        NOTE(review): ``x`` is fed to the LSTM as-is (no packing), so any
        padded time steps are processed as well.
        """
        _step_outputs, final_state = self.lstm(x)
        h_n, c_n = final_state
        return (h_n, c_n)

In [None]:
class Decoder(nn.Module):
    """LSTM decoder that emits exactly ``max_len`` tokens, one per step.

    Args:
        embed_size: size of each input embedding vector.
        n_hid: LSTM hidden size.
        max_len: number of decoding steps (length of the produced sentence).
        vocab_size: number of classes produced by the output projection.
        embedding: optional ``nn.Embedding`` used to re-embed the previously
            predicted token when decoding without teacher forcing. When
            omitted, the notebook-level ``share_conv`` table is used, which
            keeps the original behaviour.
    """
    def __init__(self,embed_size,n_hid,max_len,vocab_size,embedding=None):
        super().__init__()
        self.lstm = nn.LSTM(input_size = embed_size,
                            hidden_size = n_hid,
                            batch_first=True)
        self.fc = nn.Linear(n_hid, vocab_size)
        # Not applied in forward(); kept so the module's parameters and
        # state_dict keys stay the same as before.
        self.lnorm=nn.LayerNorm(n_hid)
        self.max_len=max_len
        self.embedding=share_conv if embedding is None else embedding

    def forward(self, hidden, x, teacher):
        """Decode ``max_len`` steps starting from ``hidden``.

        Args:
            hidden: initial LSTM state ``(h_0, c_0)``, each of shape
                (1, batch, n_hid).
            x: embedded reference sentence, (batch, seq_len, embed_size).
                With teacher forcing every step ``i < max_len`` reads
                ``x[:, i]``; otherwise only ``x[:, 0]`` is read.
            teacher: truthy to feed the reference embedding at every step,
                falsy to feed back the embedding of the previous argmax token.

        Returns:
            logits: (batch, max_len, vocab_size) unnormalised scores.
            sentence: (batch, max_len) long tensor of argmax token indices.
        """
        step_logits = []
        step_words = []
        word = None

        for i in range(self.max_len):
            if teacher or i==0:
                step_input = torch.unsqueeze(x[:,i,:],1) # (batch, 1, embed_size)
            else:
                # word = (batch, 1) -> (batch, 1, embed_size)
                step_input = self.embedding(word)
            # out = (batch, 1, n_hid); hidden = two tensors of (1, batch, n_hid)
            out, hidden = self.lstm(step_input, hidden)
            logit = self.fc(out) # (batch, 1, vocab_size)
            word = torch.argmax(logit, dim=2) # (batch, 1)

            step_logits.append(logit)
            step_words.append(word)

        if not step_logits:
            # max_len == 0: return empty tensors like the original accumulators,
            # placed on the input's device instead of a global one.
            return (torch.tensor([], device=x.device),
                    torch.tensor([], dtype=torch.long, device=x.device))

        # Concatenate once at the end instead of re-allocating on every step.
        return torch.cat(step_logits, 1), torch.cat(step_words, 1)

In [None]:
class AutoEncoder(nn.Module):
    """Sequence auto-encoder: embedding -> LSTM encoder -> teacher-forced LSTM decoder.

    The constructor reads the notebook-level ``EMBED_SIZE``,
    ``LSTM_HIDDEN_SIZE``, ``vocab_size`` and ``share_conv``.

    Args:
        maxlen: fixed sentence length (including the start token). It is used
            both as the number of decoding steps and for the loss window.
    """

    # Token id excluded from the reconstruction loss.
    # NOTE(review): hard-coded to 1 as in the original code; presumably the
    # `<pad>` index of the torchtext vocabulary - confirm against TEXT.vocab.stoi.
    PAD_IDX = 1

    def __init__(self,maxlen):
        super().__init__()
        self.maxlen=maxlen
        self.encoder=Encoder(embed_size=EMBED_SIZE,n_hid=LSTM_HIDDEN_SIZE)
        # Use the constructor argument (not the global MAX_LEN) so the decoder
        # length always agrees with the loss window below.
        self.decoder=Decoder(embed_size=EMBED_SIZE,n_hid=LSTM_HIDDEN_SIZE,max_len=maxlen,vocab_size=vocab_size)
        self.embedding=share_conv

    def forward(self, x): # x = (batch, maxlen) token ids
        """Reconstruct ``x`` and return the reconstruction loss.

        Args:
            x: long tensor of token ids, (batch, maxlen).

        Returns:
            loss: scalar tensor. For each target position the cross entropy is
                averaged over the non-padding targets of the batch, and these
                per-position means are averaged over the ``maxlen - 1``
                positions. A position whose targets are all padding
                contributes 0 (a plain per-position mean would be NaN there).
            sentence: (batch, maxlen) argmax token ids from the decoder.
        """
        x_emb=self.embedding(x)  # (batch, maxlen, embed_dim)

        hidden = self.encoder(x_emb) # (h, c), each (1, batch, hidden_size)
        logits, sentence = self.decoder(hidden, x_emb, teacher=True)

        # The prediction at step i is scored against token i+1, i.e. the
        # targets are shifted by one so the start token is never a target.
        n_steps = self.maxlen - 1
        pred = logits[:, :n_steps, :]
        target = x[:, 1:n_steps + 1]

        token_loss = F.cross_entropy(
            pred.reshape(-1, pred.size(-1)),
            target.reshape(-1),
            ignore_index=self.PAD_IDX,
            reduction="none",
        ).view(target.shape)                      # (batch, n_steps), 0 at padding

        # Per-position mean over valid targets; clamp avoids 0/0 when a whole
        # position is padding.
        valid_per_step = target.ne(self.PAD_IDX).sum(dim=0).clamp(min=1)
        per_step = token_loss.sum(dim=0) / valid_per_step.to(token_loss.dtype)
        loss = per_step.sum() / max(n_steps, 1)

        # Disabled KL term (kept for reference, not executed):
        #   Assuming a standard normal prior and a N(mu, sigma^2) posterior,
        #   kl_loss = sum(0.5 * sum(mu^2 + exp(ln(sigma^2)) - ln(sigma^2) - 1))
        #   (nn.KLDivLoss was considered too loose for this assumption.)
        #   kl_loss = torch.sum(0.5 * torch.sum((H_mean**2 + torch.exp(H_log_sigma_sq) - H_log_sigma_sq - 1),dim=1))
        #   loss += epoch/ae_epoch_number*kl_loss

        return loss, sentence


In [None]:
def train():
    """Run one training epoch of the notebook-level ``auto_encoder``.

    Reads the notebook globals ``auto_encoder``, ``optimizer``, ``train_iter``,
    ``BATCH_SIZE``, ``TEXT``, ``epoch``, ``AE_EPOCHS`` and ``REPORT_INTERVAL``
    and appends the mean batch loss of the epoch to ``history_train``.
    Iteration stops at the first batch whose size differs from ``BATCH_SIZE``.
    Every ``REPORT_INTERVAL`` epochs the loss is printed together with one
    source sentence and its reconstruction.

    Raises:
        RuntimeError: if no batch of size ``BATCH_SIZE`` was processed.
    """
    auto_encoder.train()
    epoch_loss = 0
    count=0
    sample_x = None
    sample_out = None
    for batch in train_iter:
        # x holds the token ids, (batch, max_len); the second item holds the lengths.
        x, _ = batch.Text
        if len(x)!=BATCH_SIZE:break
        optimizer.zero_grad()
        loss, syn_sents=auto_encoder(x)
        # Anomaly mode makes backward raise as soon as a NaN gradient appears.
        with torch.autograd.detect_anomaly():
            loss.backward()
        optimizer.step()
        epoch_loss+=loss.item()
        count+=1
        # Remember the sample from the batch that actually went through the
        # model, so the printed source and reconstruction belong together
        # (after `break`, `x` refers to the skipped batch).
        sample_x = x[0][1:]  # drop the leading start token
        sample_out = syn_sents[0]

    if count == 0:
        raise RuntimeError(
            "train(): no batch of size BATCH_SIZE=" + str(BATCH_SIZE) + " was processed"
        )

    # Token id 1 is filtered out before printing (see the loss' ignore_index).
    source_sentence=' '.join([TEXT.vocab.itos[int(i)] for i in sample_x if i != 1])
    gen_sentence=' '.join([TEXT.vocab.itos[int(i)] for i in sample_out if i != 1])
    history_train.append(epoch_loss/count)
    if (epoch+1) % REPORT_INTERVAL==0:
        print("epoch: "+str(epoch+1)+'/'+str(AE_EPOCHS)+' ')
        # Read the value just appended instead of relying on index alignment.
        print("training loss: "+str(history_train[-1]))
        print("source(test): "+str(source_sentence))
        print("result(test): "+str(gen_sentence))


In [None]:
'''
学習を始める
'''
# Start training: build the model on the selected device, create its Adam
# optimizer and the loss histories, then run train() once per epoch.
print("start train...")
auto_encoder = AutoEncoder(maxlen=MAX_LEN).to(device)
optimizer = optim.Adam(auto_encoder.parameters(), lr=ae_lr)
history_train, history_eval = [], []

# `epoch` has to stay a notebook-level name: train() reads it for reporting.
for epoch in range(AE_EPOCHS):
    train()