In [18]:
#!/usr/bin/env python
# coding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# from torch.nn import TransformerEncoder, TransformerEncoderLayer

import torchtext
from torchtext import data, datasets
from torchtext.vocab import FastText

import os
import math
import MeCab
import shutil

device="cuda:0" if torch.cuda.is_available() else "cpu"
SEED = 1
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
mecab = MeCab.Tagger("-Owakati")

# BLEU

from Util.selfbleu import CalcSelfBLEU
from Util.bleu import calc_all_bleu

from matplotlib import pyplot as plt
%matplotlib inline

In [19]:
# HYPER PARAMETER
BATCH_SIZE=100
EMBED_SIZE=300
LSTM_HIDDEN_SIZE=300
AE_EPOCHS=100
REPORT_INTERVAL=1
ae_lr=1e-2

In [44]:
# LABEL NAME
data_source="wiki"
feature="only_ja_small"
feature="only_ja_sample"
# feature="all"

log_dir='./log/{}/'.format("_".join(["train",data_source,feature]))
if os.path.isdir(log_dir):
    shutil.rmtree(log_dir)
os.mkdir(log_dir)

In [21]:
if data_source=="wiki" or data_source=="orphans":
    MAX_LEN=20 #paddingをバッチにかかわらず固定長にする
    TEXT = data.Field(sequential=True, lower=True, include_lengths=True, init_token="<SOS>",eos_token="<EOS>", fix_length=MAX_LEN, batch_first=True)
    train, val, test = data.TabularDataset.splits(
        path='./data/foruse', train="_".join(["train",data_source,feature])+'.tsv',
        validation="_".join(["eval",data_source,feature])+'.tsv',test="_".join(["test",data_source,feature])+'.tsv', 
        format='tsv',
        fields=[('Text', TEXT),])
    TEXT.build_vocab(train, vectors=FastText(language="ja"))
    train_iter, val_iter, test_iter = data.Iterator.splits((train, val, test), batch_sizes=(BATCH_SIZE, BATCH_SIZE, BATCH_SIZE),sort = False, device=torch.device(device))
    vocab_size=TEXT.vocab.vectors.size()[0]

In [22]:
share_conv=nn.Embedding(vocab_size, EMBED_SIZE).to(device)

In [23]:
class Encoder(nn.Module):
    def __init__(self,embed_size,n_hid):
        super().__init__()
        self.lstm = nn.LSTM(input_size = embed_size,
                            hidden_size = n_hid,
                            batch_first=True)
        
    def forward(self, x):
        """
        input: 
            x = (batch, max_len, embed_dim)
        output:
            x = (batch, hidden_size)
        """
        _, (h,c)=self.lstm(x) # h=(max_len, batch, n_hid)
        return (h,c)

In [24]:
class vae_classifier_2layer(nn.Module):
    def __init__(self,n_hid):
        super().__init__()
        self.fc = nn.Linear(n_hid*2, n_hid)
        self.fcmean = nn.Linear(n_hid, n_hid)
        self.fcvar = nn.Linear(n_hid, n_hid)
        self.ReLU=nn.ReLU()
        self.lnorm=nn.LayerNorm(n_hid)

    def forward(self, hidden):# h,c = (batch, hidden_size)
        h, c = hidden
        tmp = self.ReLU(self.lnorm(self.fc(torch.cat([h,c],dim=-1))))
        mean = self.fcmean(tmp)
        log_sigma_sq = self.fcvar(tmp)
        return mean, log_sigma_sq


In [25]:
class Decoder(nn.Module):
    def __init__(self,embed_size,n_hid,max_len,vocab_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size = embed_size,
                            hidden_size = n_hid,
                            batch_first=True)
        self.fc = nn.Linear(n_hid, vocab_size)
        self.max_len=max_len
        self.embedding=share_conv
        
    def forward(self, hidden, x, teacher):
        """
        input:
            x_emb=(batch, hidden_size, embedding_size)
        output:
            logits = (batch, max_len, vocab_size)
            sentence = (batch, max_len) : 中身はindex
        """
        logits = torch.tensor([]).to(device)
        sentence = torch.tensor([],dtype=torch.long).to(device)
        
        for i in range(self.max_len):
            if teacher or i==0:
                tmp = torch.unsqueeze(x[:,i,:],1) # tmp = (batch, 1, embed_dim)
            else:
                # word = (batch, 1, 1)
                tmp = self.embedding(word) # tmp = (batch, 1, embed_dim)
                x_emb = F.normalize(tmp)
            x_input = tmp # x_input = (batch, 1, (embed_size + n_hid))
            out, hidden = self.lstm(x_input, hidden)
                # out = (batch, 1, n_hid)
                # hidden = ((batch, 1, n_hid),(batch, 1, n_hid))
            logit = self.fc(out) # logit = (batch, 1, vocab_size)
            word = torch.argmax(logit, dim=2) # word = (batch, 1)

            sentence = torch.cat([sentence, word],1)
            logits = torch.cat([logits,logit],1)   
        return logits, sentence

In [26]:
class AutoEncoder(nn.Module):
    def __init__(self,maxlen):
        super().__init__()
        self.n_hid=LSTM_HIDDEN_SIZE
        self.maxlen=maxlen
        self.encoder=Encoder(embed_size=EMBED_SIZE,n_hid=self.n_hid)
        self.vae_classifer=vae_classifier_2layer(n_hid=self.n_hid)
        self.decoder=Decoder(embed_size=EMBED_SIZE,n_hid=self.n_hid,max_len=MAX_LEN,vocab_size=vocab_size)
        self.embedding=share_conv
        self.fc1 = nn.Linear(self.n_hid, self.n_hid)
        self.fc2 = nn.Linear(self.n_hid, self.n_hid)
        
    def forward(self, x): # x=(batch, max_len)
        x_emb=self.embedding(x)  # x_emb = (batch, maxlen, embed_dim)
        x_emb = F.normalize(x_emb)
        
        hidden = self.encoder(x_emb) # h,c = (1, batch, hidden_size)
        mean, log_sigma_sq=self.vae_classifer(hidden)
        eps = torch.empty(len(x), self.n_hid).normal_(mean=0,std=1).to(device) # N(0, 1)
        h = mean + eps * torch.sqrt(torch.exp(log_sigma_sq)) # H_dec = (1, batch, n_gan)
        h,c=self.fc1(h), self.fc2(h)
        logits, sentence = self.decoder((h,c), x_emb, teacher=True)
        
        # loss
        criterion = nn.CrossEntropyLoss(ignore_index=1)
        loss = 0
        for i in range(self.maxlen-1):
            # <SOS>を除くためindexをずらす
            loss += criterion(torch.squeeze(logits[:,i,:]), torch.squeeze(x[:,i+1]))
        loss/=(self.maxlen-1)
        
        # KL loss
        # 標準正規分布と(μ,σ^2)正規分布を仮定しているので以下の計算式になる
        # nn.klDivLossを使うと仮定が甘い
        # kl_loss = Σ0.5Σ(μ^2+exp(ln(σ^2))-ln(σ^2)-1)を使う
        kl_loss = torch.sum(0.5 * torch.sum((mean**2 + torch.exp(log_sigma_sq) - log_sigma_sq - 1),dim=1))
        loss += epoch/AE_EPOCHS*kl_loss
        
        return loss, sentence


In [39]:
def change_to_sent(sents,j):
    word_list=[]
    for i in sents[j]:
        if i==TEXT.vocab.stoi["<EOS>"]:
            break
        if i!=TEXT.vocab.stoi["<PAD>"]:
            word_list.append(TEXT.vocab.itos[int(i)])
    return word_list

def write_out(url, origin_sents, syn_sents):
    with open(url, "a") as f:
        for j in range(len(syn_sents)):
            f.write("input : "+" ".join(change_to_sent(origin_sents,j))+"\n")
            f.write("output: "+" ".join(change_to_sent(syn_sents,j))+"\n")

In [45]:
def train():
    auto_encoder.train()
    epoch_loss = 0
    count=0
    for idx, batch in enumerate(train_iter):
        (x, x_l) = batch.Text
            # xには文章のID表現が、x_lにはxの単語数が入る
            # x=(batch, max_len)
        if len(x)!=BATCH_SIZE:break
        optimizer.zero_grad()
        loss, syn_sents=auto_encoder(x)
        with torch.autograd.detect_anomaly():
            loss.backward()
            optimizer.step()
        epoch_loss+=loss.item()
        count+=1
        
    sample_x=x[0][1:]
    source_sentence=' '.join([TEXT.vocab.itos[int(i)] for i in sample_x if i != 1])
    gen_sentence=' '.join([TEXT.vocab.itos[int(i)] for i in syn_sents[0] if i != 1])

    write_out(log_dir+"{:03}.txt".format(epoch), x[:,1:], syn_sents)

    history_train.append(epoch_loss/count)
    if (epoch+1) % REPORT_INTERVAL==0:
        print("epoch: "+str(epoch+1)+'/'+str(AE_EPOCHS)+' ')
        print("training loss: "+str(history_train[epoch]))
#         print("kl_loss: "+str(kl_loss))
        print("source(train): "+str(source_sentence))
        print("result(train): "+str(gen_sentence))


In [46]:
'''
学習を始める
'''
print("start train...")
auto_encoder=AutoEncoder(maxlen=MAX_LEN)
auto_encoder.to(device)
optimizer = optim.Adam(auto_encoder.parameters(), lr=ae_lr)
history_train=[]
history_eval=[]

for epoch in range(AE_EPOCHS):
    train()

start train...
epoch: 1/100 
training loss: 5.864137172698975
source(train): 逆 に マルタ は 長い 間 アフリカ に 属する 島 と 受け止め られ て い た <EOS>
result(train): 性 性 あり 予算 その後 し いう これ 石材 街 語族 また また 述べる 分け 今日 し し 今日 今日
epoch: 2/100 
training loss: 397.9302978515625
source(train): 数 多く の 病院 が パリ に 設置 さ れ て いる <EOS>
result(train): この この は は は <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS>
epoch: 3/100 
training loss: 140.99566650390625
source(train): 自立 語 は 活用 の ない もの と 活用 の ある もの と に 分け られる <EOS>
result(train): この の の <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS>
epoch: 4/100 
training loss: 490.1822814941406
source(train): この ギャル 文字 を 練習 する ため の 本 も 現れ た <EOS>
result(train): 尊敬 手法 変わっ 発音 母語 母語 <EOS> 属する <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS>
epoch: 5/100 
training loss: 274.2007751464844
source(train): 日本 で 生まれ 育っ た ほとんど の 人 は 日本語 を 母語 と する <EOS>
result(train): <EOS> <EOS> ある 

In [47]:
history_train

[5.864137172698975,
 397.9302978515625,
 140.99566650390625,
 490.1822814941406,
 274.2007751464844,
 167.2931671142578,
 195.38124084472656,
 191.14199829101562,
 123.7173080444336,
 112.40652465820312,
 104.27486419677734,
 79.95535278320312,
 82.34085083007812,
 72.95103454589844,
 57.52729797363281,
 60.771663665771484,
 50.45553970336914,
 43.89338302612305,
 43.9288444519043,
 38.214881896972656,
 33.027042388916016,
 31.795289993286133,
 28.54407501220703,
 26.149019241333008,
 23.388418197631836,
 20.770788192749023,
 21.38357162475586,
 16.788061141967773,
 16.18769645690918,
 15.893067359924316,
 14.711787223815918,
 17.838897705078125,
 28.916244506835938,
 36.587608337402344,
 22.129924774169922,
 9.178131103515625,
 18.452518463134766,
 24.651037216186523,
 12.983352661132812,
 6.618736743927002,
 16.245880126953125,
 16.10665512084961,
 6.281676292419434,
 6.945587635040283,
 13.212850570678711,
 10.889521598815918,
 5.155546188354492,
 4.321680545806885,
 9.1200160980224

## fasttext確認用

In [14]:
nihongo_id=0

In [125]:
nihongo_id+=1
norm_list=[]
word=TEXT.vocab.itos[nihongo_id]
nihongo_vector=share_conv.weight[nihongo_id]
for i, val in enumerate(share_conv.weight):
    diff=nihongo_vector-val
    norm=torch.norm(diff)
    norm_list.append((norm,i))
norm_list.sort()
print("「"+word+"」の似ている単語")
for i in range(1,11):
    val=norm_list[i][0]
    id_=norm_list[i][1]
    print(TEXT.vocab.itos[id_]+":"+str(val))

「方言」の似ている単語
方法:tensor(21.5570, device='cuda:0', grad_fn=<NormBackward0>)
ひとつ:tensor(21.7216, device='cuda:0', grad_fn=<NormBackward0>)
うち:tensor(21.8818, device='cuda:0', grad_fn=<NormBackward0>)
独自:tensor(21.9748, device='cuda:0', grad_fn=<NormBackward0>)
県庁:tensor(22.0037, device='cuda:0', grad_fn=<NormBackward0>)
品詞:tensor(22.0942, device='cuda:0', grad_fn=<NormBackward0>)
ゆれ:tensor(22.1198, device='cuda:0', grad_fn=<NormBackward0>)
日干し:tensor(22.1313, device='cuda:0', grad_fn=<NormBackward0>)
にくい:tensor(22.2012, device='cuda:0', grad_fn=<NormBackward0>)
生活:tensor(22.2619, device='cuda:0', grad_fn=<NormBackward0>)
