In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import re
import io
import torch
import random
import unicodedata
import numpy as np

from tqdm import tqdm
from konlpy.tag import Mecab, Okt
from torch import nn
from torch.nn import functional as F
from torch.utils.data import random_split, Dataset, DataLoader, SubsetRandomSampler

from torch.nn.utils.rnn import pad_sequence
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

from data.utils import get_total_data
from sklearn.model_selection import train_test_split

In [2]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

DATA_DIR = "/home/pervinco/Datasets/KORENG"
SAVE_DIR = "/home/pervinco/Models/KORENG"
SRC_LANG, TRG_LANG = "ko", "en"
NUM_SAMPLES = 10000
MAX_SEQ_LEN = 10

EPOCHS = 50
BATCH_SIZE = 32
LEARNING_RATE = 0.001

PAD_TOKEN = 0
SOS_TOKEN = 1
EOS_TOKEN = 2


In [3]:
if SRC_LANG == "en":
    dataset = get_total_data(DATA_DIR, reverse=True)
else:
    dataset = get_total_data(DATA_DIR)

src_sentences, trg_sentences = dataset[0][:NUM_SAMPLES], dataset[1][:NUM_SAMPLES]
print(len(src_sentences), len(trg_sentences))

total_data.csv exist.
10000 10000


In [4]:
# ko_tokenizer = Mecab()
ko_tokenizer = Okt()
en_tokenizer = get_tokenizer("spacy", language="en_core_web_sm")

In [5]:
class WordVocab():
    def __init__(self):
        self.word2index = {
            '<PAD>': PAD_TOKEN,
            '<SOS>': SOS_TOKEN, 
            '<EOS>': EOS_TOKEN,
        }
        self.word2count = {}
        self.index2word = {
            PAD_TOKEN: '<PAD>', 
            SOS_TOKEN: '<SOS>', 
            EOS_TOKEN: '<EOS>'
        }
        
        self.n_words = 3  # PAD, SOS, EOS 포함

    def add_sentence(self, sentence):
        for word in sentence.split(' '):
            self.add_word(word)

    def add_word(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

    def indices_to_words(self, indices):
        words = [self.index2word[index] for index in indices]
        
        return words

In [6]:
class TextDataset(Dataset):
    def __init__(self, data_path, src_lang, trg_lang, max_seq_len=100, num_samples=10000):
        super(TextDataset, self).__init__()
        self.normalizer = re.compile(r'[^ ?,.!A-Za-z0-9가-힣+]')

        self.max_seq_len = max_seq_len

        self.PAD_TOKEN = 0
        self.SOS_TOKEN = 1
        self.EOS_TOKEN = 2
        
        if src_lang == "en":
            dataset = get_total_data(data_path, reverse=True)
            src_tokenizer = get_tokenizer("spacy", language="en_core_web_sm")
            trg_tokenizer = Okt()
        else:
            dataset = get_total_data(data_path)
            src_tokenizer = Okt()
            trg_tokenizer = get_tokenizer("spacy", language="en_core_web_sm")

        self.src_sentences, self.trg_sentences = [], []
        self.src_vocab, self.trg_vocab = WordVocab(), WordVocab()
        for src_sentence, trg_sentence in zip(dataset[0][:num_samples], dataset[1][:num_samples]):
            src_sentence = self.clean_text(src_sentence, src_tokenizer, src_lang)
            trg_sentence = self.clean_text(trg_sentence, trg_tokenizer, trg_lang)

            self.src_vocab.add_sentence(src_sentence)
            self.src_sentences.append(src_sentence)
            
            self.trg_vocab.add_sentence(trg_sentence)
            self.trg_sentences.append(trg_sentence)

    def __len__(self):
        return len(self.src_sentences)


    def normalize(self, sentence):
        return self.normalizer.sub("", sentence)


    def clean_text(self, sentence, tokenizer, lang):
        sentence = self.normalize(sentence)

        if lang == "ko":
            sentence = tokenizer.morphs(sentence)
        else:
            sentence = tokenizer(sentence)

        sentence = ' '.join(sentence)
        sentence = sentence.lower()
        
        return sentence
    

    def texts_to_sequences(self, vocab, sentence):
        return [vocab.word2index[w] for w in sentence.split()]
    

    def pad_sequence(self, sentence_tokens):
        sentence_tokens = sentence_tokens[:(self.max_seq_len - 1)]
        token_length = len(sentence_tokens)
        
        sentence_tokens.append(self.EOS_TOKEN)
        for i in range(token_length, (self.max_seq_len - 1)):
                sentence_tokens.append(self.PAD_TOKEN)

        return sentence_tokens
    

    def __getitem__(self, idx):
        src, trg = self.src_sentences[idx], self.trg_sentences[idx]
        
        src_sequences = self.texts_to_sequences(self.src_vocab, src)
        trg_sequences = self.texts_to_sequences(self.trg_vocab, trg)

        src_padded = self.pad_sequence(src_sequences)
        trg_padded = self.pad_sequence(trg_sequences)

        return torch.tensor(src_padded), torch.tensor(trg_padded)

In [7]:
dataset = TextDataset(DATA_DIR, SRC_LANG, TRG_LANG, MAX_SEQ_LEN, NUM_SAMPLES)

train_size = int(len(dataset) * 0.8)
valid_size = len(dataset) - train_size
print(train_size, valid_size)

train_dataset, valid_dataset = random_split(dataset, [train_size, valid_size])

total_data.csv exist.
8000 2000


In [8]:
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=16)

In [9]:
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, embedding_dim, num_layers):
        super(Encoder, self).__init__()    
        self.embedding = nn.Embedding(input_size, embedding_dim)
        self.gru = nn.GRU(embedding_dim, hidden_size, num_layers=num_layers, bidirectional=False)
        
    def forward(self, x):
        x = self.embedding(x).permute(1, 0, 2) ## (batch_size, max_seq_len, embedd_dim) -> (max_seq_len, batch_size, embedd_dim)
        output, hidden = self.gru(x) ## output : (sequence_length, batch_size, hidden_size x bidirectional) hidden_state: (bidirectional x number of layers, batch_size, hidden_size)

        return output, hidden
    
    def print_parameters(self):
        for name, param in self.gru.named_parameters():
            print(f"Param : {name}, Shape : {param.shape}")

In [10]:
class Decoder(nn.Module):
    def __init__(self, input_size, hidden_size, embedding_dim, num_layers=1, dropout=0.2):
        super(Decoder, self).__init__()
        self.input_size = input_size
        self.embedding = nn.Embedding(input_size, embedding_dim)
        self.dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(embedding_dim, hidden_size, num_layers=num_layers, bidirectional=False)
        
        self.fc = nn.Linear(hidden_size, input_size)
        
    def forward(self, x, hidden_state):
        x = x.unsqueeze(0) # (1, batch_size)
        embedded = F.relu(self.embedding(x))
        embedded = self.dropout(embedded)
        output, hidden = self.gru(embedded, hidden_state)
        output = self.fc(output.squeeze(0)) # (sequence_length, batch_size, hidden_size x bidirectional)
        
        return output, hidden
    
    def print_parameters(self):
        for name, param in self.gru.named_parameters():
            print(f"Param : {name}, Shape : {param.shape}")

In [11]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        
    def forward(self, inputs, outputs, teacher_forcing_ratio=0.5):
        ## inputs : (batch_size, sequence_length)
        ## outputs: (batch_size, sequence_length)
        batch_size, output_length = outputs.shape
        output_num_vocabs = self.decoder.input_size
        
        predicted_outputs = torch.zeros(output_length, batch_size, output_num_vocabs).to(self.device) ## 예측을 저장할 변수. (sequence_length, batch_size, num_vocabs)
        
        _, decoder_hidden = self.encoder(inputs) ## output은 사용하지 않고 마지막 hidden_state(context_vector)를 사용.
        
        decoder_input = torch.full((batch_size,), SOS_TOKEN, device=self.device) ## (batch_size) shape의 SOS TOKEN으로 채워진 디코더 입력 생성
        
        ## 순회하면서 출력 단어를 생성.
        ## 0번째는 SOS TOKEN이 위치하므로, 1번째 인덱스부터 순회.
        for t in range(0, output_length):
            ## decoder_output: (batch_size, num_vocabs),  decoder_hidden: (Bidirectional x num layers, batch_size, hidden_size)
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)

            predicted_outputs[t] = decoder_output ## t번째 단어로 decoder_output을 저장.
            
            ## teacher forcing 적용 여부 확률로 결정
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = decoder_output.argmax(1) 
            
            ## teacher forcing 인 경우 ground truth 값을, 그렇지 않은 경우, 예측 값을 다음 input으로 지정
            decoder_input = outputs[:, t] if teacher_force else top1
        
        return predicted_outputs.permute(1, 0, 2) # (batch_size, sequence_length, num_vocabs)로 변경

In [12]:
SRC_SIZE = dataset.src_vocab.n_words
TRG_SIZE = dataset.trg_vocab.n_words
HIDDEN_SIZE = 512
EMBEDDING_DIM = 256
NUM_LAYERS = 1

encoder = Encoder(SRC_SIZE, HIDDEN_SIZE, EMBEDDING_DIM, NUM_LAYERS)
decoder = Decoder(TRG_SIZE, HIDDEN_SIZE, EMBEDDING_DIM, NUM_LAYERS)
model = Seq2Seq(encoder.to(DEVICE), decoder.to(DEVICE), DEVICE)

In [13]:
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_func = nn.CrossEntropyLoss()

In [14]:
def train(model, dataloader, optimizer, loss_fn, device):
    model.train()
    
    train_loss = 0
    for x, y in tqdm(dataloader, desc='Training', leave=False):
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        
        output = model(x, y) ## output: (batch_size, sequence_length, num_vocabs)
        output_dim = output.size(2)
        
        # 1번 index 부터 슬라이싱한 이유는 0번 index가 SOS TOKEN 이기 때문
        # (batch_size*sequence_length, num_vocabs) 로 변경
        output = output.reshape(-1, output_dim)
        
        # (batch_size*sequence_length) 로 변경
        y = y.view(-1)
        
        # Loss 계산
        loss = loss_fn(output, y)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item() * x.size(0)
        
    return train_loss / len(dataloader)

In [15]:
def evaluate(model, data_loader, loss_fn, device):
    model.eval()
    
    eval_loss = 0
    
    with torch.no_grad():
        for x, y in tqdm(data_loader, desc='Evaluating', leave=False):
            x, y = x.to(device), y.to(device)
            output = model(x, y)
            output_dim = output.size(2)
            output = output.reshape(-1, output_dim)
            y = y.view(-1)
            
            # Loss 계산
            loss = loss_fn(output, y)
            
            eval_loss += loss.item() * x.size(0)
            
    return eval_loss / len(data_loader)

In [18]:
best_loss = np.inf

for epoch in range(EPOCHS):
    loss = train(model, train_dataloader, optimizer, loss_func, DEVICE)
    
    val_loss = evaluate(model, valid_dataloader, loss_func, DEVICE)
    
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), f"{SAVE_DIR}/best.pt")
    
    print(f'epoch: {epoch+1}, loss: {loss:.4f}, val_loss: {val_loss:.4f}')

                   
model.load_state_dict(torch.load(f"{SAVE_DIR}/best.pt"))
torch.save(model.state_dict(), f"{SAVE_DIR}/best.pt")

                                                             

epoch: 1, loss: 96.5346, val_loss: 90.8689


                                                             

epoch: 2, loss: 82.8937, val_loss: 89.2483


                                                             

epoch: 3, loss: 70.2675, val_loss: 88.9715


                                                             

epoch: 4, loss: 55.4620, val_loss: 91.2322


                                                             

epoch: 5, loss: 40.7641, val_loss: 94.0173


                                                             

epoch: 6, loss: 30.2688, val_loss: 98.3399


                                                             

epoch: 7, loss: 23.5668, val_loss: 101.9385


                                                             

epoch: 8, loss: 18.1712, val_loss: 104.9359


                                                             

epoch: 9, loss: 14.0083, val_loss: 108.7348


                                                             

epoch: 10, loss: 10.6166, val_loss: 111.1065


                                                             

epoch: 11, loss: 8.2793, val_loss: 114.6534


                                                             

epoch: 12, loss: 6.7008, val_loss: 116.6550


                                                             

epoch: 13, loss: 5.5192, val_loss: 120.9443


                                                             

epoch: 14, loss: 4.3951, val_loss: 122.1172


                                                             

epoch: 15, loss: 3.7047, val_loss: 123.9714


                                                             

epoch: 16, loss: 3.2202, val_loss: 126.4316


                                                             

epoch: 17, loss: 2.9037, val_loss: 129.4315


                                                             

epoch: 18, loss: 2.8508, val_loss: 130.3208


                                                             

epoch: 19, loss: 2.7425, val_loss: 133.5095


                                                             

epoch: 20, loss: 2.6825, val_loss: 134.1253


                                                             

epoch: 21, loss: 2.7304, val_loss: 137.3765


                                                             

epoch: 22, loss: 2.6223, val_loss: 138.3438


                                                             

epoch: 23, loss: 2.3814, val_loss: 139.8089


                                                             

epoch: 24, loss: 2.4827, val_loss: 141.3422


                                                             

epoch: 25, loss: 2.5424, val_loss: 142.7616


                                                             

epoch: 26, loss: 2.0838, val_loss: 145.4718


                                                             

epoch: 27, loss: 1.6170, val_loss: 146.1205


                                                             

epoch: 28, loss: 1.3984, val_loss: 148.0793


                                                             

epoch: 29, loss: 1.2781, val_loss: 147.9606


                                                             

epoch: 30, loss: 1.3433, val_loss: 149.8410


                                                             

epoch: 31, loss: 1.8416, val_loss: 150.8937


                                                             

epoch: 32, loss: 2.8740, val_loss: 151.2245


                                                             

epoch: 33, loss: 2.8864, val_loss: 151.6571


                                                             

epoch: 34, loss: 2.5421, val_loss: 152.9570


                                                             

epoch: 35, loss: 1.7736, val_loss: 155.2676


                                                             

epoch: 36, loss: 1.3158, val_loss: 155.7565


                                                             

epoch: 37, loss: 1.0850, val_loss: 155.9357


                                                             

epoch: 38, loss: 1.3198, val_loss: 159.7298


                                                             

epoch: 39, loss: 1.5683, val_loss: 159.8485


                                                             

epoch: 40, loss: 1.9097, val_loss: 159.5841


                                                             

epoch: 41, loss: 2.1813, val_loss: 160.9427


                                                             

epoch: 42, loss: 2.2998, val_loss: 161.1937


                                                             

epoch: 43, loss: 2.1756, val_loss: 162.5742


                                                             

epoch: 44, loss: 1.9453, val_loss: 163.6903


                                                             

epoch: 45, loss: 1.5301, val_loss: 164.0657


                                                             

epoch: 46, loss: 1.3211, val_loss: 164.0737


                                                             

epoch: 47, loss: 1.4855, val_loss: 165.1812


                                                             

epoch: 48, loss: 1.9220, val_loss: 165.0474


                                                             

epoch: 49, loss: 1.9793, val_loss: 167.6479


                                                             

epoch: 50, loss: 1.8119, val_loss: 167.8461
