# Data Processing

## 오픈 도메인 대화 테스크를 위한 데이터 처리
### dataset : https://github.com/songys/Chatbot_data

In [8]:
import sentencepiece as spm
import pandas as pd
import numpy as np

train_data = pd.read_csv('./data/dataset/ChatbotData.csv')
train_data.head()

Unnamed: 0,Q,A,label
0,12시 땡!,하루가 또 가네요.,0
1,1지망 학교 떨어졌어,위로해 드립니다.,0
2,3박4일 놀러가고 싶다,여행은 언제나 좋죠.,0
3,3박4일 정도 놀러가고 싶다,여행은 언제나 좋죠.,0
4,PPL 심하네,눈살이 찌푸려지죠.,0


## sentence piece 로 vocab생성
## SentencePiece: A simple and language independent subword tokenizer and detokenizer for Neural Text Processing
### * Taku Kudo, John Richardson, Google

RNN은 기본적으로 vocab의 크기가 계산량에 영향을 주고 있습니다.
그래서 적당한 크기의 vocab을 사용하게 됩니다. 문제는 여기서 많이 발생합니다.
우리는 vocab을 만들때 미등록 단어가 발생하게 되고 실제로 입력으로 들어왔을때 UNK토큰으로 대체하게 됩니다.
이 과정에서 정보의 손실이 발생하고 성능의 문제를 일으킬수 있습니다.
그런 점을 보완하고자 sentencepiece를 tokenizer로 사용하려고 합니다.
sentencepiece의 기본 아이디어는 단어(word)의 부분단어(subword)로 모든 단어를 표현하고자 하는게 아이디어입니다.
이때 사용하는게 단어들의 빈도수를 사용하여 subword로 나눌지 말지를 판단하게 됩니다.

In [9]:
corpus = "data/dataset/chit-chat_corpus.txt"
prefix = "chatbot"
vocab_size = 16000
spm.SentencePieceTrainer.train(
    f"--input={corpus} --model_prefix={prefix} --vocab_size={vocab_size + 7}" + 
    " --model_type=bpe" +
    " --max_sentence_length=999999" + # 문장 최대 길이
    " --pad_id=0 --pad_piece=[PAD]" + # pad (0)
    " --unk_id=1 --unk_piece=[UNK]" + # unknown (1)
    " --bos_id=2 --bos_piece=[BOS]" + # begin of sequence (2)
    " --eos_id=3 --eos_piece=[EOS]" + # end of sequence (3)
    " --user_defined_symbols=[SEP],[CLS],[MASK]") # 사용자 정의 토큰

# Load & Test

In [10]:
vocab_file = "chatbot.model"
vocab = spm.SentencePieceProcessor()
vocab.load(vocab_file)
line = "3박4일 정도 놀러가고 싶다"
pieces = vocab.encode_as_pieces(line)
ids = vocab.encode_as_ids(line)


print(line)
print(pieces)
print(ids)

3박4일 정도 놀러가고 싶다
['▁3', '박', '4', '일', '▁정도', '▁놀러가고', '▁싶다']
[473, 15432, 15399, 14972, 982, 3503, 201]


In [11]:
import os
import sys
import json
import torch
import random
import torch.utils.data as data
import numpy as np
import pandas as pd

from torch.autograd import Variable 
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

from tqdm import tqdm
from tqdm import trange
import torch.nn.functional as F
#from torch.utils.tensorboard import SummaryWriter

from src.model import save

In [12]:
class Preprocessing:
    '''
    데이터의 최대 token길이가 10이지만
    실제 환경에서는 얼마의 길이가 들어올지 몰라 적당한 길이 부여
    '''
    
    def __init__(self, max_len = 20):
        self.max_len = max_len
        self.PAD = 0
    
    def pad_idx_sequencing(self, q_vec):
        q_len = len(q_vec)
        diff_len = q_len - self.max_len
        if(diff_len>0):
            q_vec = q_vec[:self.max_len]
            q_len = self.max_len
        else:
            pad_vac = [0] * abs(diff_len)
            q_vec += pad_vac

        return q_vec
    
    def make_batch(self):
        pass

class ChitChatDataset(data.Dataset):
    def __init__(self, x_tensor, y_tensor, labels):
        super(ChitChatDataset, self).__init__()

        self.x = x_tensor
        self.y = y_tensor
        self.labels = labels
        
    def __getitem__(self, index):
        return self.x[index], self.y[index], self.labels[index]

    def __len__(self):
        return len(self.x)
    
class MakeDataset:
    def __init__(self):
        
        self.chitchat_data_dir = "./data/dataset/ChatbotData.csv"
        
        self.prep = Preprocessing()
        vocab_file = "chatbot.model"
        self.transformers_tokenizer = spm.SentencePieceProcessor()
        self.transformers_tokenizer.load(vocab_file)
    
    def encode_dataset(self, dataset):
        token_dataset = []
        for data in dataset:
            token_dataset.append( [2] + self.transformers_tokenizer.encode_as_ids(data) + [3])
        return token_dataset

    def make_chitchat_dataset(self, train_ratio = 0.8):
        chitchat_dataset = pd.read_csv(self.chitchat_data_dir)
        Qs = chitchat_dataset["Q"].tolist()
        As = chitchat_dataset["A"].tolist()
        label = chitchat_dataset["label"].tolist()
        
        Qs = self.encode_dataset(Qs)
        As = self.encode_dataset(As)
        
        self.prep.max_len = 40
        x, y = [], []
        for q, a in zip(Qs,As):
            x.append(self.prep.pad_idx_sequencing(q))
            y.append(self.prep.pad_idx_sequencing(a))
        x = torch.tensor(x)
        y = torch.tensor(y)
        x_len = x.size()[0]
        train_size = int(x_len*train_ratio)
        
        if(train_ratio == 1.0):
            train_x = x[:train_size]
            train_y = y[:train_size]
            train_label = label[:train_size]
            train_dataset = ChitChatDataset(train_x,train_y,train_label)
            return train_dataset, None
        else:
            train_x = x[:train_size]
            train_y = y[:train_size]
            train_label = label[:train_size]

            test_x = x[train_size+1:]
            test_y = y[train_size+1:]
            test_label = label[train_size+1:]

            train_dataset = ChitChatDataset(train_x,train_y,train_label)
            test_dataset = ChitChatDataset(test_x,test_y,test_label)

            return train_dataset, test_dataset

In [14]:
dataset = MakeDataset()

train_dataset, test_dataset = dataset.make_chitchat_dataset(1.0)

train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
#test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=True)

# Attention Is All You Need
## * Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin
### tensorflow transformer chatbot code : https://blog.tensorflow.org/2019/05/transformer-chatbot-tutorial-with-tensorflow-2.html

In [15]:
from torch.nn import Transformer
from torch import nn
import torch
import math
from tqdm import tqdm
class Tformer(nn.Module):
    def __init__(self, num_tokens, dim_model, num_heads, dff, num_layers, dropout_p=0.5):
        super(Tformer, self).__init__()
        self.transformer = Transformer(dim_model, num_heads, dim_feedforward=dff, num_encoder_layers=num_layers, num_decoder_layers=num_layers,dropout=dropout_p)
        self.pos_encoder = PositionalEncoding(dim_model, dropout_p)
        self.encoder = nn.Embedding(num_tokens, dim_model)

        self.pos_encoder_d = PositionalEncoding(dim_model, dropout_p)
        self.encoder_d = nn.Embedding(num_tokens, dim_model)

        self.dim_model = dim_model
        self.num_tokens = num_tokens

        self.linear = nn.Linear(dim_model, num_tokens)

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, src, tgt, srcmask, tgtmask, srcpadmask, tgtpadmask):
        src = self.encoder(src) * math.sqrt(self.dim_model)
        src = self.pos_encoder(src)

        tgt = self.encoder_d(tgt) * math.sqrt(self.dim_model)
        tgt = self.pos_encoder_d(tgt)

        output = self.transformer(src.transpose(0,1), tgt.transpose(0,1), srcmask, tgtmask, src_key_padding_mask=srcpadmask, tgt_key_padding_mask=tgtpadmask)
        output = self.linear(output)
        return output

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

def gen_attention_mask(x):
    mask = torch.eq(x, 0)
    return mask

In [76]:
model = Tformer(
     num_tokens=vocab_size+7, dim_model=256, num_heads=8, dff=512, num_layers=2, dropout_p=0.1
 ).cuda()

In [77]:
lr = 1e-4
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
MAX_LENGTH = 40

In [78]:
epoch = 70
save_dir = "./data/pretraining/4_chitchat_transformer_model/"
save_prefix = "chitchat_transformer"
prev_loss_all = float("inf")
train_steps = 0
test_steps = 0
model.train()
for i in range(epoch):
    batchloss = 0.0
    progress = tqdm(train_dataloader)
    for (inputs, y, _) in progress:
        optimizer.zero_grad()

        dec_inputs = y[:,:-1]
        outputs = y[:,1:]
        
        src_mask = model.generate_square_subsequent_mask(MAX_LENGTH).cuda()
        src_padding_mask = gen_attention_mask(inputs).cuda()
        tgt_mask = model.generate_square_subsequent_mask(MAX_LENGTH-1).cuda()
        tgt_padding_mask = gen_attention_mask(dec_inputs).cuda()

        result = model(inputs.long().cuda(), dec_inputs.long().cuda(), src_mask, tgt_mask, src_padding_mask,tgt_padding_mask)
        loss = criterion(result.permute(1,2,0), outputs.long().cuda())
        progress.set_description("{:0.3f}".format(loss))

        train_steps += 1
        loss.backward()
        optimizer.step()
        batchloss += loss
    
    print("train epoch:",i+1,"|","loss:",batchloss.cpu().item() / len(train_dataloader))

#     model.eval()
#     test_batchloss = 0.0
#     progress_test = tqdm(test_dataloader)
#     for (inputs, y, _) in progress_test:

#         dec_inputs = y[:,:-1]
#         outputs = y[:,1:]
        
#         src_mask = model.generate_square_subsequent_mask(MAX_LENGTH).cuda()
#         src_padding_mask = gen_attention_mask(inputs).cuda()
#         tgt_mask = model.generate_square_subsequent_mask(MAX_LENGTH-1).cuda()
#         tgt_padding_mask = gen_attention_mask(dec_inputs).cuda()

#         result = model(inputs.long().cuda(), dec_inputs.long().cuda(), src_mask, tgt_mask, src_padding_mask,tgt_padding_mask)
 
#         loss = criterion(result.permute(1,2,0), outputs.long().cuda())
#         progress_test.set_description("{:0.3f}".format(loss.cpu().item()))

#         test_steps += 1
#         test_batchloss += loss.cpu().item()
#     loss_all = test_batchloss/len(test_dataloader)
#     print("test epoch:",i+1,"|","loss:",loss_all)
#     model.train()
#     if(loss_all<prev_loss_all):
#         prev_loss_all = loss_all
#         save(model, save_dir, save_prefix + "_" + str(round(loss_all,6)), i)

1.281: 100%|██████████| 93/93 [00:11<00:00,  8.12it/s]
1.356:   1%|          | 1/93 [00:00<00:11,  8.33it/s]

train epoch: 1 | loss: 2.3234646704889115


0.911: 100%|██████████| 93/93 [00:11<00:00,  8.11it/s]
1.025:   1%|          | 1/93 [00:00<00:11,  8.20it/s]

train epoch: 2 | loss: 1.1157909106182795


1.004: 100%|██████████| 93/93 [00:11<00:00,  8.30it/s]
0.916:   1%|          | 1/93 [00:00<00:11,  8.26it/s]

train epoch: 3 | loss: 0.9776405416509156


0.880: 100%|██████████| 93/93 [00:11<00:00,  8.27it/s]
0.935:   1%|          | 1/93 [00:00<00:11,  8.13it/s]

train epoch: 4 | loss: 0.9354389970020581


0.837: 100%|██████████| 93/93 [00:11<00:00,  8.29it/s]
0.884:   1%|          | 1/93 [00:00<00:11,  8.13it/s]

train epoch: 5 | loss: 0.9105411652595766


0.967: 100%|██████████| 93/93 [00:11<00:00,  8.28it/s]
0.831:   1%|          | 1/93 [00:00<00:11,  8.26it/s]

train epoch: 6 | loss: 0.8927132391160534


0.932: 100%|██████████| 93/93 [00:11<00:00,  8.28it/s]
0.826:   1%|          | 1/93 [00:00<00:11,  8.26it/s]

train epoch: 7 | loss: 0.876556888703377


0.970: 100%|██████████| 93/93 [00:11<00:00,  8.27it/s]
0.832:   1%|          | 1/93 [00:00<00:11,  8.20it/s]

train epoch: 8 | loss: 0.8623861497448336


0.807: 100%|██████████| 93/93 [00:11<00:00,  8.26it/s]
0.868:   1%|          | 1/93 [00:00<00:11,  8.20it/s]

train epoch: 9 | loss: 0.848002197921917


0.884: 100%|██████████| 93/93 [00:11<00:00,  8.22it/s]
0.801:   1%|          | 1/93 [00:00<00:11,  8.20it/s]

train epoch: 10 | loss: 0.8365371868174564


0.815: 100%|██████████| 93/93 [00:11<00:00,  8.15it/s]
0.783:   1%|          | 1/93 [00:00<00:11,  8.00it/s]

train epoch: 11 | loss: 0.8242870864047799


0.709: 100%|██████████| 93/93 [00:11<00:00,  8.18it/s]
0.754:   1%|          | 1/93 [00:00<00:11,  8.00it/s]

train epoch: 12 | loss: 0.8122200094243531


0.828: 100%|██████████| 93/93 [00:11<00:00,  8.17it/s]
0.745:   1%|          | 1/93 [00:00<00:11,  8.00it/s]

train epoch: 13 | loss: 0.8009871616158434


0.734: 100%|██████████| 93/93 [00:11<00:00,  8.09it/s]
0.789:   1%|          | 1/93 [00:00<00:11,  8.26it/s]

train epoch: 14 | loss: 0.7890248452463458


0.727: 100%|██████████| 93/93 [00:11<00:00,  8.13it/s]
0.788:   1%|          | 1/93 [00:00<00:11,  7.94it/s]

train epoch: 15 | loss: 0.7768357799899194


0.869: 100%|██████████| 93/93 [00:11<00:00,  8.09it/s]
0.744:   1%|          | 1/93 [00:00<00:11,  8.20it/s]

train epoch: 16 | loss: 0.766856942125546


0.680: 100%|██████████| 93/93 [00:11<00:00,  8.21it/s]
0.746:   1%|          | 1/93 [00:00<00:11,  8.26it/s]

train epoch: 17 | loss: 0.7539601479807208


0.867: 100%|██████████| 93/93 [00:11<00:00,  8.12it/s]
0.726:   1%|          | 1/93 [00:00<00:11,  8.06it/s]

train epoch: 18 | loss: 0.743398933000462


0.809: 100%|██████████| 93/93 [00:11<00:00,  8.10it/s]
0.761:   1%|          | 1/93 [00:00<00:11,  8.20it/s]

train epoch: 19 | loss: 0.7318354780955981


0.716: 100%|██████████| 93/93 [00:11<00:00,  8.17it/s]
0.657:   1%|          | 1/93 [00:00<00:11,  8.20it/s]

train epoch: 20 | loss: 0.7197605256111391


0.796: 100%|██████████| 93/93 [00:11<00:00,  8.05it/s]
0.713:   1%|          | 1/93 [00:00<00:11,  7.75it/s]

train epoch: 21 | loss: 0.7089048816311744


0.640: 100%|██████████| 93/93 [00:11<00:00,  8.01it/s]
0.665:   1%|          | 1/93 [00:00<00:11,  8.13it/s]

train epoch: 22 | loss: 0.6971188617008989


0.611: 100%|██████████| 93/93 [00:11<00:00,  7.97it/s]
0.720:   1%|          | 1/93 [00:00<00:11,  8.13it/s]

train epoch: 23 | loss: 0.6860599722913516


0.663: 100%|██████████| 93/93 [00:11<00:00,  8.03it/s]
0.656:   1%|          | 1/93 [00:00<00:11,  8.20it/s]

train epoch: 24 | loss: 0.6752831243699596


0.607: 100%|██████████| 93/93 [00:11<00:00,  8.02it/s]
0.620:   1%|          | 1/93 [00:00<00:11,  7.75it/s]

train epoch: 25 | loss: 0.6642456464870001


0.601: 100%|██████████| 93/93 [00:11<00:00,  8.03it/s]
0.629:   1%|          | 1/93 [00:00<00:11,  8.20it/s]

train epoch: 26 | loss: 0.653590335640856


0.577: 100%|██████████| 93/93 [00:11<00:00,  8.01it/s]
0.583:   1%|          | 1/93 [00:00<00:11,  7.81it/s]

train epoch: 27 | loss: 0.6430304742628529


0.703: 100%|██████████| 93/93 [00:11<00:00,  8.17it/s]
0.652:   1%|          | 1/93 [00:00<00:11,  8.20it/s]

train epoch: 28 | loss: 0.6326892401582451


0.637: 100%|██████████| 93/93 [00:11<00:00,  8.15it/s]
0.581:   1%|          | 1/93 [00:00<00:11,  8.20it/s]

train epoch: 29 | loss: 0.6222906625399025


0.590: 100%|██████████| 93/93 [00:11<00:00,  8.16it/s]
0.612:   1%|          | 1/93 [00:00<00:11,  8.07it/s]

train epoch: 30 | loss: 0.6117681277695523


0.673: 100%|██████████| 93/93 [00:11<00:00,  8.17it/s]
0.597:   1%|          | 1/93 [00:00<00:11,  7.75it/s]

train epoch: 31 | loss: 0.602153080765919


0.613: 100%|██████████| 93/93 [00:11<00:00,  8.12it/s]
0.554:   1%|          | 1/93 [00:00<00:11,  8.26it/s]

train epoch: 32 | loss: 0.5921946699901294


0.547: 100%|██████████| 93/93 [00:11<00:00,  8.21it/s]
0.630:   1%|          | 1/93 [00:00<00:11,  8.33it/s]

train epoch: 33 | loss: 0.5819651285807291


0.558: 100%|██████████| 93/93 [00:11<00:00,  8.12it/s]
0.572:   1%|          | 1/93 [00:00<00:11,  8.06it/s]

train epoch: 34 | loss: 0.5722229250015751


0.612: 100%|██████████| 93/93 [00:11<00:00,  8.17it/s]
0.557:   1%|          | 1/93 [00:00<00:11,  7.94it/s]

train epoch: 35 | loss: 0.563101327547463


0.564: 100%|██████████| 93/93 [00:11<00:00,  8.02it/s]
0.558:   1%|          | 1/93 [00:00<00:11,  7.94it/s]

train epoch: 36 | loss: 0.5528981608729209


0.542: 100%|██████████| 93/93 [00:11<00:00,  8.15it/s]
0.547:   1%|          | 1/93 [00:00<00:11,  8.20it/s]

train epoch: 37 | loss: 0.5437782451670657


0.498: 100%|██████████| 93/93 [00:11<00:00,  8.17it/s]
0.527:   1%|          | 1/93 [00:00<00:11,  8.33it/s]

train epoch: 38 | loss: 0.5337357059601815


0.461: 100%|██████████| 93/93 [00:11<00:00,  8.11it/s]
0.536:   1%|          | 1/93 [00:00<00:11,  8.06it/s]

train epoch: 39 | loss: 0.5236127709829679


0.439: 100%|██████████| 93/93 [00:11<00:00,  8.10it/s]
0.557:   1%|          | 1/93 [00:00<00:12,  7.64it/s]

train epoch: 40 | loss: 0.5144784168530536


0.535: 100%|██████████| 93/93 [00:11<00:00,  8.20it/s]
0.499:   1%|          | 1/93 [00:00<00:11,  7.94it/s]

train epoch: 41 | loss: 0.5061038437710014


0.543: 100%|██████████| 93/93 [00:11<00:00,  8.20it/s]
0.486:   1%|          | 1/93 [00:00<00:11,  8.20it/s]

train epoch: 42 | loss: 0.4971876246954805


0.535: 100%|██████████| 93/93 [00:11<00:00,  8.13it/s]
0.474:   1%|          | 1/93 [00:00<00:11,  8.13it/s]

train epoch: 43 | loss: 0.48777098296790994


0.442: 100%|██████████| 93/93 [00:11<00:00,  8.14it/s]
0.472:   1%|          | 1/93 [00:00<00:12,  7.58it/s]

train epoch: 44 | loss: 0.47800023068663894


0.538: 100%|██████████| 93/93 [00:11<00:00,  8.13it/s]
0.474:   1%|          | 1/93 [00:00<00:11,  8.00it/s]

train epoch: 45 | loss: 0.4703238292406964


0.363: 100%|██████████| 93/93 [00:11<00:00,  8.18it/s]
0.444:   1%|          | 1/93 [00:00<00:11,  8.20it/s]

train epoch: 46 | loss: 0.45990626017252606


0.474: 100%|██████████| 93/93 [00:11<00:00,  8.18it/s]
0.424:   1%|          | 1/93 [00:00<00:11,  8.06it/s]

train epoch: 47 | loss: 0.45153775779149863


0.534: 100%|██████████| 93/93 [00:11<00:00,  8.04it/s]
0.432:   1%|          | 1/93 [00:00<00:11,  8.20it/s]

train epoch: 48 | loss: 0.4442020539314516


0.464: 100%|██████████| 93/93 [00:11<00:00,  8.17it/s]
0.391:   1%|          | 1/93 [00:00<00:11,  7.94it/s]

train epoch: 49 | loss: 0.43474070231119794


0.430: 100%|██████████| 93/93 [00:11<00:00,  8.14it/s]
0.396:   1%|          | 1/93 [00:00<00:11,  8.20it/s]

train epoch: 50 | loss: 0.4257750562442246


0.447: 100%|██████████| 93/93 [00:11<00:00,  8.18it/s]
0.417:   1%|          | 1/93 [00:00<00:11,  8.26it/s]

train epoch: 51 | loss: 0.41748797508978075


0.401: 100%|██████████| 93/93 [00:11<00:00,  7.94it/s]
0.401:   1%|          | 1/93 [00:00<00:11,  8.33it/s]

train epoch: 52 | loss: 0.4087578558152722


0.411: 100%|██████████| 93/93 [00:11<00:00,  8.22it/s]
0.387:   1%|          | 1/93 [00:00<00:12,  7.63it/s]

train epoch: 53 | loss: 0.40111566358996975


0.353: 100%|██████████| 93/93 [00:11<00:00,  8.12it/s]
0.380:   1%|          | 1/93 [00:00<00:11,  8.13it/s]

train epoch: 54 | loss: 0.3915512331070439


0.381: 100%|██████████| 93/93 [00:11<00:00,  8.19it/s]
0.375:   1%|          | 1/93 [00:00<00:11,  8.13it/s]

train epoch: 55 | loss: 0.3840788974556872


0.436: 100%|██████████| 93/93 [00:11<00:00,  8.18it/s]
0.346:   1%|          | 1/93 [00:00<00:11,  8.13it/s]

train epoch: 56 | loss: 0.3764425298219086


0.339: 100%|██████████| 93/93 [00:11<00:00,  8.11it/s]
0.384:   1%|          | 1/93 [00:00<00:11,  8.13it/s]

train epoch: 57 | loss: 0.36801959622290825


0.325: 100%|██████████| 93/93 [00:11<00:00,  8.15it/s]
0.346:   1%|          | 1/93 [00:00<00:11,  8.26it/s]

train epoch: 58 | loss: 0.3607729019657258


0.380: 100%|██████████| 93/93 [00:11<00:00,  8.13it/s]
0.325:   1%|          | 1/93 [00:00<00:11,  8.20it/s]

train epoch: 59 | loss: 0.35287639658938175


0.380: 100%|██████████| 93/93 [00:11<00:00,  8.22it/s]
0.357:   1%|          | 1/93 [00:00<00:11,  8.26it/s]

train epoch: 60 | loss: 0.3451576232910156


0.372: 100%|██████████| 93/93 [00:11<00:00,  8.08it/s]
0.328:   1%|          | 1/93 [00:00<00:11,  8.06it/s]

train epoch: 61 | loss: 0.3381924988121115


0.299: 100%|██████████| 93/93 [00:11<00:00,  8.24it/s]
0.297:   1%|          | 1/93 [00:00<00:11,  8.06it/s]

train epoch: 62 | loss: 0.3298845598774572


0.312: 100%|██████████| 93/93 [00:11<00:00,  8.10it/s]
0.308:   1%|          | 1/93 [00:00<00:11,  8.13it/s]

train epoch: 63 | loss: 0.322441654820596


0.272: 100%|██████████| 93/93 [00:11<00:00,  8.13it/s]
0.298:   1%|          | 1/93 [00:00<00:11,  8.00it/s]

train epoch: 64 | loss: 0.3148705267137097


0.367: 100%|██████████| 93/93 [00:11<00:00,  8.16it/s]
0.281:   1%|          | 1/93 [00:00<00:11,  8.13it/s]

train epoch: 65 | loss: 0.3082339379095262


0.321: 100%|██████████| 93/93 [00:11<00:00,  8.04it/s]
0.301:   1%|          | 1/93 [00:00<00:11,  8.20it/s]

train epoch: 66 | loss: 0.3013299921507476


0.317: 100%|██████████| 93/93 [00:11<00:00,  8.24it/s]
0.295:   1%|          | 1/93 [00:00<00:11,  8.33it/s]

train epoch: 67 | loss: 0.29473634945449007


0.252: 100%|██████████| 93/93 [00:11<00:00,  8.10it/s]
0.274:   1%|          | 1/93 [00:00<00:11,  8.20it/s]

train epoch: 68 | loss: 0.28686312193511637


0.267: 100%|██████████| 93/93 [00:11<00:00,  8.19it/s]
0.259:   1%|          | 1/93 [00:00<00:11,  7.75it/s]

train epoch: 69 | loss: 0.2803343906197497


0.278: 100%|██████████| 93/93 [00:11<00:00,  8.12it/s]

train epoch: 70 | loss: 0.2735347952893985





In [79]:
loss

tensor(0.2783, device='cuda:0', grad_fn=<NllLoss2DBackward>)

In [80]:
save(model, save_dir, save_prefix + "_" + str(round(loss.cpu().item(),6)), i)

In [81]:
def preprocess_sentence(sentence):
    sentence = re.sub(r"([?.!,])", r" \1 ", sentence)
    sentence = sentence.strip()
    return sentence

def evaluate(sentence):
    sentence = preprocess_sentence(sentence)
    input = torch.tensor([[2] + vocab.encode_as_ids(sentence) + [3]]).cuda()
    output = torch.tensor([[2]]).cuda()

    # 디코더의 예측 시작
    model.eval()
    for i in range(MAX_LENGTH):
        src_mask = model.generate_square_subsequent_mask(input.shape[1]).cuda()
        tgt_mask = model.generate_square_subsequent_mask(output.shape[1]).cuda()

        src_padding_mask = gen_attention_mask(input).cuda()
        tgt_padding_mask = gen_attention_mask(output).cuda()

        predictions = model(input, output, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask).transpose(0,1)
        # 현재(마지막) 시점의 예측 단어를 받아온다.
        predictions = predictions[:, -1:, :]
        predicted_id = torch.LongTensor(torch.argmax(predictions.cpu(), axis=-1))


        # 만약 마지막 시점의 예측 단어가 종료 토큰이라면 예측을 중단
        if torch.equal(predicted_id[0][0], torch.tensor(3)):
            break

        # 마지막 시점의 예측 단어를 출력에 연결한다.
        # 이는 for문을 통해서 디코더의 입력으로 사용될 예정이다.
        output = torch.cat([output, predicted_id.cuda()], axis=1)

    return torch.squeeze(output, axis=0).cpu().numpy()

def predict(sentence):
    prediction = evaluate(sentence)
    predicted_sentence = vocab.Decode(list(map(int,[i for i in prediction if i < vocab_size+7])))

    print('Input: {}'.format(sentence))
    print('Output: {}'.format(predicted_sentence))

    return predicted_sentence

In [58]:
model.load_state_dict(torch.load("./data/pretraining/save/4_chitchat_transformer_model/chitchat_transformer_1.215381_steps_81.pt"))

model.eval()

Tformer(
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): _LinearWithBias(in_features=256, out_features=256, bias=True)
          )
          (linear1): Linear(in_features=256, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=512, out_features=256, bias=True)
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
        (1): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): _LinearWithBias(in_features=256, out_features=256, bias=True)
          )
          (linear1): Linear(in_features=256, out_features=512, bias=True)
       

In [82]:
result = predict("난 뭘 해야 할까?")

Input: 난 뭘 해야 할까?
Output: 정말 힘드신가봐요. 본인의 의사를 확실히 밝혀보세요 대한 두려움을 가지고 손해배상을 청구하세요.


In [83]:
result = predict("힘들다")

Input: 힘들다
Output: 이제 회사와 자신에 대해서 더 공부해서 자신감을 가지세요.


In [84]:
result = predict("난 혼자인게 좋아")

Input: 난 혼자인게 좋아
Output: 스트레스 받으시는 말고 적극적으로 장점을 찾아서 인정하고 호의를 보여보세요.


In [85]:
result = predict("결혼해줘")

Input: 결혼해줘
Output: 더 힘들 겠지만 못해요하는 어제 더 면밀히 더 면밀히 더 면밀히 걸리겠지만 해낼 수도 있어요. 작은하는 어제 더 힘들 겠지만요하는 어제 더 힘들 겠지만요하는 어제 더 힘들 겠지만요군요
