In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.normal import Normal as Gaussian

import torchtext
from torchtext import data, datasets
from torchtext.vocab import Vectors

cuda = True if torch.cuda.is_available() else False #CUDA使用可能でTrue
device="cuda:0" if cuda else "cpu"
# randomを固定
SEED = 1
# random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

import matplotlib.pyplot as plt
%matplotlib inline
UNK_token=0
PAD_token=1
SOS_token=2
EOS_token=3

In [2]:
def tokenizer(text):
    malist = text.split()
    return malist

In [3]:
# 定数
MAXLEN=20 #paddingをバッチにかかわらず固定長にする
NUM_EPOCHS = 100
KL_LOSS_ZERO_EPOCH=15
EMBED_SIZE=300
BATCH_SIZE=512
LSTM_HIDDEN_DIM=100
Z_DIM=100
WORD_DROP_RATE=0.5
# data_type="wiki_mecab"+"_fortest"
# data_type="wiki_mecab"+"_small"
data_source="wiki"
splited_type="mecab"
feature="hiragana_kanji_only"
data_type=data_source+"_"+splited_type+"_"+feature
vector="fasttext"
out_dir="./result/lstmvae"

In [4]:
TEXT = data.Field(sequential=True, tokenize=tokenizer,lower=True, include_lengths=True, init_token="<SOS>",eos_token="<EOS>", fix_length=MAXLEN, batch_first=True)
train, val, test = data.TabularDataset.splits(
        path='./data/{}/'.format(data_source), train=data_type+'_train.tsv',
        validation=data_type+'_valid.tsv', test=data_type+'_test.tsv', format='tsv',
        fields=[('Text', TEXT),])


TEXT = data.Field(sequential=True, tokenize=tokenizer,lower=True, include_lengths=True, init_token="<SOS>",eos_token="<EOS>", fix_length=MAXLEN, batch_first=True)
train, val, test = data.TabularDataset.splits(
        path='./data/{}/'.format(data_source), train=data_type+'_train.tsv',
        validation=data_type+'_valid.tsv', test=data_type+'_test.tsv', format='tsv',
        fields=[('Text', TEXT),])
train_iter, val_iter, test_iter = data.Iterator.splits((train, val, test), batch_sizes=(BATCH_SIZE, BATCH_SIZE, BATCH_SIZE),sort = False, device=torch.device(device))

if vector=="fasttext":
    TEXT.build_vocab(train,val,test, vectors=Vectors(name="fasttext.vec"))
    vocab_size=TEXT.vocab.vectors.size()[0]
else:
    TEXT.build_vocab(train,val,test)
    vocab_size=len(TEXT.vocab.itos)

In [5]:
share_emb=nn.Embedding(vocab_size, EMBED_SIZE).to(device)
share_emb.weight.data.copy_(TEXT.vocab.vectors) # set learned vector
# share_emb.weight.requires_grad = False # embedding固定

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.1041,  0.0847, -0.2219,  ...,  0.5101, -0.0855, -0.1055],
        [ 0.1450, -0.3459, -0.1778,  ...,  0.3307,  0.3743,  0.0613],
        [ 0.1914, -0.1240, -0.1963,  ..., -0.2575, -0.1688, -0.1909]])

In [6]:
class RNNEncoder(nn.Module):
    def __init__(self, hidden_dim, z_dim):
        super(RNNEncoder, self).__init__()
        self.lstm = nn.LSTM(EMBED_SIZE, hidden_dim, batch_first=True,bidirectional=True)
        self.fc1=nn.Linear(hidden_dim*4, hidden_dim*2)
        self.fc2=nn.Linear(hidden_dim*2, z_dim)
        self.ReLU=nn.ReLU()

    def forward(self, x_emb):
        _, (h,c) = self.lstm(x_emb) # out, (h=(1,batch,hidden_dim), c)
        hf,hb,cf,cb=h[0],h[1],c[0],c[1] # bidirectionalのそれぞれの最後の隠れ状態をとってくる
        state=torch.cat([hf,hb,cf,cb],dim=-1)
        output=self.fc1(state) #  #(batch, hidden*4) -> (batch, z_dim)
        output=self.fc2(self.ReLU(output))
        return output

In [7]:
class vae_classifier_2layer(nn.Module):
    def __init__(self,z_dim):
        super().__init__()
        self.fc11 = nn.Linear(z_dim, z_dim)
        self.fc12 = nn.Linear(z_dim, z_dim)
        self.ReLU=nn.ReLU()

    def forward(self, h):
        mean = self.fc11(self.ReLU(h))
        log_sigma_sq = self.fc12(self.ReLU(h))
        return mean, log_sigma_sq

In [8]:
class RNNDecoder(nn.Module):
    def __init__(self, hidden_dim, z_dim):
        super(RNNDecoder, self).__init__()
        self.maxlen=MAXLEN
        self.embedding=share_emb
        self.fc11=nn.Linear(z_dim, hidden_dim*2)
        self.fc12=nn.Linear(z_dim, hidden_dim*2)
        self.lstm = nn.LSTM(EMBED_SIZE, hidden_dim*2, batch_first=True)
        self.fc2=nn.Linear(hidden_dim*2, vocab_size)
        self.tanh=nn.Tanh()
        
    def forward(self, x_emb, h,is_prob=False): #h=(1,batch,hidden_dim)
        bsize=len(x_emb)
        logits = torch.tensor([]).to(device)
        sentence = torch.tensor([],dtype=torch.long).to(device)
        h0=self.fc11(h)
        c0=self.fc12(h)
        hidden=(h0,c0)
        for i in range(self.maxlen):
            _teacher=False if is_prob or torch.rand(1)<WORD_DROP_RATE else True
            if _teacher or i==0: # 教師入力するならここに
                lstm_input = torch.unsqueeze(x_emb[:,i,:],1)
            else:
                lstm_input=self.embedding(word)
#                 lstm_input=torch.zeros(bsize,1,EMBED_SIZE).to(device) # word drop
            lstm_input=F.normalize(lstm_input,dim=-1)
            out,hidden=self.lstm(lstm_input, hidden)
                # out = (batch, 1, n_hid)
                # hidden = ((batch, 1, n_hid),(batch, 1, n_hid))
            logit = self.fc2(out) # logit = (batch, 1, vocab_size)
            word = torch.argmax(logit, dim=2)# word = (batch, 1)
        
            sentence = torch.cat([sentence, word],1)
            logits = torch.cat([logits,logit],1)   
        return logits, sentence

In [19]:
class VAE(nn.Module):
    def __init__(self, hidden_dim, z_dim):
        super(VAE,self).__init__()
        self.z_dim=z_dim
        self.embedding=share_emb
        self.maxlen=MAXLEN
        self.vae_classifier=vae_classifier_2layer(z_dim)
        self.RNNEncoder = RNNEncoder(hidden_dim, z_dim)
        self.RNNDecoder = RNNDecoder(hidden_dim, z_dim)

    def forward(self, x,is_prob=False): # x = (batch, max_len)
        bsize=len(x)
        x_emb = self.embedding(x)
        x_emb=F.normalize(x_emb,dim=-1)
        H=self.RNNEncoder(x_emb)
        
        H_mean, H_log_sigma_sq = self.vae_classifier(H) # H_mean, H_log_sigma_sq= (batch, z_dim)
        eps = torch.empty(bsize, self.z_dim).normal_(mean=0,std=1).to(device) # N(0, 1)
        H_dec = H_mean + eps * torch.sqrt(torch.exp(H_log_sigma_sq)) # H_dec = (batch, z_dim)
        H_dec=torch.unsqueeze(H_dec,0)
        logits, sentence=self.RNNDecoder(x_emb,H_dec,is_prob=is_prob)
            # logits = (batch, max_len, vocab_size)
            # sentence = (batch, max_len)
            # 中身はindex

        # loss
        seq_logits = logits[:,:-1,:] # 一番最後の出力は無意味
        target = x[:,1:] # 一番最初の<SOS>を取り除く
        criterion = nn.CrossEntropyLoss(ignore_index=1)
        loss = 0
        for i in range(self.maxlen-1):
            loss += criterion(torch.squeeze(seq_logits[:,i,:]), torch.squeeze(target[:,i]))
        loss/=(self.maxlen-1)
        # KL loss
        # 標準正規分布と(μ,σ^2)正規分布を仮定しているので以下の計算式になる
        # nn.klDivLossを使うと仮定が甘い
        # kl_loss = Σ0.5Σ(μ^2+exp(ln(σ^2))-ln(σ^2)-1)を使う
        kl_loss = torch.sum(0.5 * torch.sum((H_mean**2 + torch.exp(H_log_sigma_sq) - H_log_sigma_sq - 1),dim=1))
        
        return loss, kl_loss, sentence


In [20]:
def train(epoch):
    model.train()
    epoch_loss = 0
    count=0
    for idx, batch in enumerate(train_iter):
        (x, x_l) = batch.Text
            # xには文章のID表現が、x_lにはxの単語数が入る
            # x=(batch, max_len)
        optimizer.zero_grad()
        loss, kl_loss, syn_sents=model(x)
        all_loss=loss
        if epoch>=KL_LOSS_ZERO_EPOCH:
            all_loss+=kl_loss*(epoch-KL_LOSS_ZERO_EPOCH)/(NUM_EPOCHS-KL_LOSS_ZERO_EPOCH)
        with torch.autograd.detect_anomaly(): #NaN検出
            all_loss.backward()
            optimizer.step()
        epoch_loss+=loss.item()
        count+=len(x)
        with open(out_dir+"/train_sentences/{:0>3}.txt".format(epoch),"w") as f:
            for j in range(len(syn_sents)):
                word_list=[]
                for i in syn_sents[j]:
                    if i==EOS_token:
                        break
                    if i!=PAD_token:
                        word_list.append(TEXT.vocab.itos[int(i)])
                sent=' '.join(word_list)
                sent+="\n"
                f.write(sent)
    return epoch_loss/count

def val(epoch):
    model.eval()
    epoch_loss = 0
    count=0
    for idx, batch in enumerate(val_iter):
        (x, x_l) = batch.Text
            # xには文章のID表現が、x_lにはxの単語数が入る
            # x=(batch, max_len)
        loss,kl_loss, syn_sents=model(x)
        all_loss=loss

        epoch_loss+=loss.item()
        count+=len(x)
        with open(out_dir+"/validation_sentences/{:0>3}.txt".format(epoch),"w") as f:
            for j in range(len(syn_sents)):
                word_list=[]
                for i in syn_sents[j]:
                    if i==EOS_token:
                        break
                    if i!=PAD_token:
                        word_list.append(TEXT.vocab.itos[int(i)])
                sent=' '.join(word_list)
                sent+="\n"
                f.write(sent)
    return epoch_loss/count

In [21]:
model=VAE(hidden_dim=LSTM_HIDDEN_DIM,z_dim=Z_DIM).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-2)
train_loss_history = []
valid_loss_history = []
for epoch in range(1, NUM_EPOCHS + 1):
    train_loss = train(epoch)
    valid_loss = val(epoch)
    if epoch%1==0:
        print('epoch [{}/{}], loss: {:.4f} test_loss: {:.4f}'.format(
            epoch,
            NUM_EPOCHS,
            train_loss,
            valid_loss))

    # logging
    train_loss_history.append(train_loss)
    valid_loss_history.append(valid_loss)

# save the training model
torch.save(model.state_dict(), out_dir+'/lstmvae.model')

epoch [1/100], loss: 0.0002 test_loss: 0.0002
epoch [2/100], loss: -0.0000 test_loss: -0.0006
epoch [3/100], loss: -0.0020 test_loss: -0.0124
epoch [4/100], loss: -0.0111 test_loss: -0.0292
epoch [5/100], loss: -0.0223 test_loss: -0.0488
epoch [6/100], loss: -0.0342 test_loss: -0.0805
epoch [7/100], loss: -0.0538 test_loss: -0.1014




RuntimeError: Function 'ExpBackward' returned nan values in its 0th output.

In [None]:
x = [i for i in range(len(train_loss_history))]
plt.plot(x, train_loss_history, label="train_loss")
plt.plot(x, valid_loss_history, label="valid_loss")
plt.ylim(0,0.02)
plt.legend()
plt.show()

In [None]:
def test():
    model.eval()
    batch =next(iter(val_iter))
    (x, x_l) = batch.Text
    bsize=len(x)
    x_emb = model.embedding(x)
    x_emb=F.normalize(x_emb,dim=-1)

    H_dec = torch.empty(bsize, model.z_dim).normal_(mean=0,std=3).to(device) # N(0, 1)
    H_dec=torch.unsqueeze(H_dec,0)
    logits, sentence=model.RNNDecoder(x_emb,H_dec,is_prob=True)
        # logits = (batch, max_len, vocab_size)
        # sentence = (batch, max_len)
        # 中身はindex
    for j in range(len(sentence)):
        word_list=[]
        for i in sentence[j]:
            if i==EOS_token:
                break
            if i!=PAD_token:
                word_list.append(TEXT.vocab.itos[int(i)])
        sent=' '.join(word_list)
        if j<100:
            print(sent)
test()

## TODO
lr_schedulerでkl_loss追加前後だけlrを下げる

In [81]:
ans=0
for idx, batch in enumerate(train_iter):
    (x, x_l) = batch.Text
    ans+=len(x)
print(ans)

10000
