데이터셋: http://www.manythings.org/anki fra-eng


In [1]:
import sys
import pathlib
import os

sys.path.append(str(pathlib.Path(os.getcwd()).parent))

import copy
import re
import unicodedata
import numpy as np
import pandas as pd
import torch
from collections import Counter
from tqdm import tqdm
from torch.utils.data import TensorDataset, DataLoader

from src.base_module import *
from src.transformer import Transformer

In [2]:
# Pick the fastest available backend: a CUDA GPU first, then Apple-silicon MPS,
# and finally the CPU. Previously CUDA machines silently fell back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f'device: {device}')

device: mps


In [3]:
# Number of leading lines of the corpus file that load_preprocess_data() reads.
num_samples = 33_000

In [4]:
def unicode_to_ascii(s):
    """Remove combining marks (accents) from ``s``.

    The string is NFD-decomposed so that an accented letter becomes a base
    character followed by its combining marks; every character whose Unicode
    category is ``Mn`` (nonspacing mark) is then dropped.

    Args:
        s: Input string.

    Returns:
        The string with all ``Mn`` characters removed.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)

def preprocess_sentence(sent):
    """Normalize a raw sentence into space-separated lowercase tokens.

    Steps, applied in order:
      1. lowercase and strip accents via ``unicode_to_ascii``;
      2. put a space in front of each of ``? . ! , ¿``;
      3. replace every run of characters other than ASCII letters and
         ``! . ?`` with a single space (this also removes the commas and
         ``¿`` padded in step 2, plus digits, hyphens and apostrophes);
      4. collapse whitespace runs into one space.

    The result is not stripped, so it may start or end with a single space.

    Args:
        sent: Raw sentence string.

    Returns:
        The normalized sentence string.
    """
    text = unicode_to_ascii(sent.lower())

    # (pattern, replacement) pairs; order matters because each pass feeds the next.
    substitutions = (
        (r"([?.!,¿])", r" \1"),
        (r"[^a-zA-Z!.?]+", r" "),
        (r"\s+", " "),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)

    return text

def load_preprocess_data():
    """Read the parallel corpus and build tokenized encoder/decoder sequences.

    Reads at most ``num_samples`` leading lines of ``./fra.txt``. Each line must
    hold exactly three tab-separated fields: source sentence, target sentence,
    and a third field that is ignored. Both sentences go through
    ``preprocess_sentence`` and are split on whitespace.

    Returns:
        A tuple ``(encoder_input, decoder_input, decoder_target)`` of three
        equally long lists of token lists:
          * ``encoder_input``  - source tokens;
          * ``decoder_input``  - target tokens prefixed with ``<sos>``;
          * ``decoder_target`` - target tokens followed by ``<eos>``.

    Raises:
        ValueError: If a line does not split into exactly three tab-separated
            fields.
        OSError: If ``./fra.txt`` cannot be opened.
    """
    encoder_input, decoder_input, decoder_target = [], [], []

    # The corpus contains accented text, so decode it explicitly as UTF-8 instead
    # of relying on the platform's locale-dependent default encoding.
    with open('./fra.txt', 'r', encoding='utf-8') as f:
        # Iterate lazily and stop early rather than loading the whole file just
        # to keep its first ``num_samples`` lines.
        for line_no, line in enumerate(f):
            if line_no >= num_samples:
                break

            src, tar, _ = line.strip().split('\t')
            src_tokens = preprocess_sentence(src).split()
            tar_tokens = preprocess_sentence(tar).split()

            encoder_input.append(src_tokens)
            decoder_input.append(['<sos>'] + tar_tokens)
            decoder_target.append(tar_tokens + ['<eos>'])

    return encoder_input, decoder_input, decoder_target

In [5]:
en_sent = u"Have you had dinner?"
fr_sent = u"Avez-vous déjà diné?"

# Show each raw sentence next to its normalized form.
for raw_sent in (en_sent, fr_sent):
    print(f'{raw_sent} -> {preprocess_sentence(raw_sent)}')

Have you had dinner? -> have you had dinner ?
Avez-vous déjà diné? -> avez vous deja dine ?


In [6]:
sents_en_in, sents_fra_in, sents_fra_out = load_preprocess_data()

# The three lists are built in lockstep, so their lengths should match.
for sent_list in (sents_en_in, sents_fra_in, sents_fra_out):
    print(len(sent_list))

33000
33000
33000


In [7]:
# Peek at the first five tokenized examples of each sequence list.
for token_lists in (sents_en_in, sents_fra_in, sents_fra_out):
    print(token_lists[:5])

[['go', '.'], ['go', '.'], ['go', '.'], ['go', '.'], ['hi', '.']]
[['<sos>', 'va', '!'], ['<sos>', 'marche', '.'], ['<sos>', 'en', 'route', '!'], ['<sos>', 'bouge', '!'], ['<sos>', 'salut', '!']]
[['va', '!', '<eos>'], ['marche', '.', '<eos>'], ['en', 'route', '!', '<eos>'], ['bouge', '!', '<eos>'], ['salut', '!', '<eos>']]


In [8]:
def build_vocab(sents):
    """Build a token -> integer id mapping from tokenized sentences.

    Ids 0 and 1 are reserved for ``<PAD>`` and ``<UNK>``. Every distinct token
    then receives an id starting at 2, in order of decreasing frequency; tokens
    with equal counts keep their order of first appearance.

    Args:
        sents: Iterable of token sequences.

    Returns:
        dict mapping each token to its integer id.
    """
    word_counts = Counter(word for sent in sents for word in sent)

    word2index = {'<PAD>': 0, '<UNK>': 1}
    # most_common() sorts by count, descending, with a stable sort.
    for index, (word, _count) in enumerate(word_counts.most_common(), start=2):
        word2index[word] = index

    return word2index

In [9]:
# Source vocabulary comes from the encoder-side token lists only.
src_vocab = build_vocab(sents_en_in)
# Target vocabulary is built from decoder inputs and targets together, so that
# both '<sos>' (inputs only) and '<eos>' (targets only) receive an id.
tgt_vocab = build_vocab(sents_fra_in + sents_fra_out)

# Sizes include the two reserved entries '<PAD>' and '<UNK>'.
src_vocab_size = len(src_vocab)
tgt_vocab_size = len(tgt_vocab)

print(f'src vocab size: {src_vocab_size}')
print(f'tar vocab size: {tgt_vocab_size}')

src vocab size: 4486
tar vocab size: 7879


In [10]:
# Reverse lookups (integer id -> token) used to turn id sequences back into text.
index2src = {idx: token for token, idx in src_vocab.items()}
index2tar = {idx: token for token, idx in tgt_vocab.items()}

In [11]:
def encode_sentences(sents, word2index):
    """Convert tokenized sentences into lists of integer ids.

    Tokens missing from ``word2index`` are encoded as ``word2index['<UNK>']``;
    that entry is looked up only when an unknown token is actually encountered.
    Progress is reported through ``tqdm``.

    Args:
        sents: Sequence of token sequences.
        word2index: Mapping from token to integer id.

    Returns:
        List of id lists, one per input sentence, in the same order.
    """
    encoded_data = []
    for tokens in tqdm(sents):
        ids = [
            word2index[token] if token in word2index else word2index['<UNK>']
            for token in tokens
        ]
        encoded_data.append(ids)

    return encoded_data

In [12]:
# Map tokens to ids: the encoder side uses the source vocabulary, while decoder
# inputs and decoder targets share the target vocabulary.
encoder_input = encode_sentences(sents_en_in, src_vocab)
decoder_input = encode_sentences(sents_fra_in, tgt_vocab)
decoder_target = encode_sentences(sents_fra_out, tgt_vocab)

100%|██████████| 33000/33000 [00:00<00:00, 537589.26it/s]
100%|██████████| 33000/33000 [00:00<00:00, 2220597.01it/s]
100%|██████████| 33000/33000 [00:00<00:00, 2283273.38it/s]


In [13]:
# First five id sequences of each set, before padding.
for encoded in (encoder_input, decoder_input, decoder_target):
    print(encoded[:5])

[[27, 2], [27, 2], [27, 2], [27, 2], [736, 2]]
[[3, 68, 11], [3, 204, 2], [3, 26, 491, 11], [3, 561, 11], [3, 954, 11]]
[[68, 11, 4], [204, 2, 4], [26, 491, 11, 4], [561, 11, 4], [954, 11, 4]]


In [14]:
def pad_sentences(sents, max_len=None):
    """Right-pad (or truncate) id sequences into a 2-D integer array.

    Args:
        sents: Sequence of id sequences (lists or 1-D arrays).
        max_len: Width of the result. When ``None``, the length of the longest
            sequence is used (0 if ``sents`` is empty).

    Returns:
        ``numpy.ndarray`` of shape ``(len(sents), max_len)`` and dtype ``int``.
        Row ``i`` holds the first ``max_len`` ids of ``sents[i]``; any remaining
        positions are filled with 0.
    """
    if max_len is None:
        # default=0 keeps an empty batch from raising ValueError in max().
        max_len = max((len(s) for s in sents), default=0)

    features = np.zeros((len(sents), max_len), dtype=int)
    for i, sent in enumerate(sents):
        # Truncate first so the slice width always matches the assigned values.
        trimmed = sent[:max_len]
        features[i, :len(trimmed)] = trimmed

    return features

In [15]:
# Pad every set to the length of its own longest sequence (max_len is left at
# None). decoder_input and decoder_target end up equally wide because each
# target sentence gains exactly one extra token ('<sos>' or '<eos>').
encoder_input = pad_sentences(encoder_input)
decoder_input = pad_sentences(decoder_input)
decoder_target = pad_sentences(decoder_target)

In [16]:
# First three padded rows of each array; 0 is the padding id.
for padded in (encoder_input, decoder_input, decoder_target):
    print(padded[:3])

[[27  2  0  0  0  0  0]
 [27  2  0  0  0  0  0]
 [27  2  0  0  0  0  0]]
[[  3  68  11   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  3 204   2   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  3  26 491  11   0   0   0   0   0   0   0   0   0   0   0   0]]
[[ 68  11   4   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [204   2   4   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [ 26 491  11   4   0   0   0   0   0   0   0   0   0   0   0   0]]


In [17]:
# (number of sentences, padded length) for each array.
for padded in (encoder_input, decoder_input, decoder_target):
    print(padded.shape)

(33000, 7)
(33000, 16)
(33000, 16)


In [18]:
# Shuffle with a fixed seed so the train/validation split is identical on every
# run. Without it, a checkpoint reloaded after a kernel restart would be scored
# on a "validation" set that overlaps the rows it was trained on.
SEED = 42
rng = np.random.default_rng(SEED)

indices = np.arange(encoder_input.shape[0])
rng.shuffle(indices)
print('랜덤 시퀀스 :',indices)

랜덤 시퀀스 : [30083 30064  7979 ...  4385  3419  5667]


In [19]:
# Apply the same permutation to all three arrays so that each source row stays
# aligned with its decoder input and decoder target.
encoder_input = encoder_input[indices]
decoder_input = decoder_input[indices]
decoder_target = decoder_target[indices]

In [20]:
# Decode one shuffled row back to tokens to check that the three arrays are
# still aligned with each other.
sample_idx = 30997
print([index2src[word] for word in encoder_input[sample_idx]])
print([index2tar[word] for word in decoder_input[sample_idx]])
print([index2tar[word] for word in decoder_target[sample_idx]])

['slow', 'down', '.', '<PAD>', '<PAD>', '<PAD>', '<PAD>']
['<sos>', 'calmos', '!', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']
['calmos', '!', '<eos>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']


In [21]:
# Hold out the last 10% of the shuffled rows for validation. The size is derived
# from the array itself rather than from num_samples, so the split stays correct
# when the corpus file holds fewer lines than requested.
n_of_val = int(len(encoder_input) * 0.1)
print('검증 데이터의 개수 :',n_of_val)

# Slice at an explicit boundary: x[:-n] would yield an empty training set
# whenever n happens to be 0.
split_at = len(encoder_input) - n_of_val

encoder_input_train = encoder_input[:split_at]
decoder_input_train = decoder_input[:split_at]
decoder_target_train = decoder_target[:split_at]

encoder_input_test = encoder_input[split_at:]
decoder_input_test = decoder_input[split_at:]
decoder_target_test = decoder_target[split_at:]

검증 데이터의 개수 : 3300


In [22]:
# Report the shape of every train/test array.
for label, split in (
    ('훈련 source 데이터의 크기 :', encoder_input_train),
    ('훈련 target 데이터의 크기 :', decoder_input_train),
    ('훈련 target 레이블의 크기 :', decoder_target_train),
    ('테스트 source 데이터의 크기 :', encoder_input_test),
    ('테스트 target 데이터의 크기 :', decoder_input_test),
    ('테스트 target 레이블의 크기 :', decoder_target_test),
):
    print(label, split.shape)

훈련 source 데이터의 크기 : (29700, 7)
훈련 target 데이터의 크기 : (29700, 16)
훈련 target 레이블의 크기 : (29700, 16)
테스트 source 데이터의 크기 : (3300, 7)
테스트 target 데이터의 크기 : (3300, 16)
테스트 target 레이블의 크기 : (3300, 16)


In [23]:
# Convert the NumPy id arrays to int64 tensors.
encoder_input_train_tensor = torch.tensor(encoder_input_train, dtype=torch.long)
decoder_input_train_tensor = torch.tensor(decoder_input_train, dtype=torch.long)
decoder_target_train_tensor = torch.tensor(decoder_target_train, dtype=torch.long)

encoder_input_test_tensor = torch.tensor(encoder_input_test, dtype=torch.long)
decoder_input_test_tensor = torch.tensor(decoder_input_test, dtype=torch.long)
decoder_target_test_tensor = torch.tensor(decoder_target_test, dtype=torch.long)

# Build the datasets and data loaders
batch_size = 512

# Each item is an (encoder input, decoder input, decoder target) triple.
# Training batches are reshuffled every epoch; validation order stays fixed.
train_dataset = TensorDataset(encoder_input_train_tensor, decoder_input_train_tensor, decoder_target_train_tensor)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

valid_dataset = TensorDataset(encoder_input_test_tensor, decoder_input_test_tensor, decoder_target_test_tensor)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [24]:
# Hyper-parameters. max_len is handed to PositionalEncoding and is also the cap
# on decoding steps in decode_sequence().
# NOTE(review): presumably max_len must be >= the longest padded sequence fed to
# the model (7 and 16 in the recorded run) - confirm against PositionalEncoding.
max_len = 20
d_embed = 128
d_model = 128
n_layer = 4
h = 4
d_ff = 256

# Separate token embeddings per language (different vocab sizes); each side gets
# its own deep copy of the positional encoding object.
src_token_embed = TokenEmbedding(d_embed=d_embed, vocab_size=src_vocab_size)
tgt_token_embed = TokenEmbedding(d_embed=d_embed, vocab_size=tgt_vocab_size)
pos_embed = PositionalEncoding(d_embed=d_embed, max_len=max_len)
src_embed = TransformerEmbedding(token_embed=src_token_embed, pos_embed=copy.deepcopy(pos_embed))
tgt_embed = TransformerEmbedding(token_embed=tgt_token_embed, pos_embed=copy.deepcopy(pos_embed))

# Prototype sub-layers. They are deep-copied wherever they are used below, so
# the prototypes themselves never become part of the model.
attention = MultiHeadAttentionLayer(d_model=d_model, h=h, qkv_fc=nn.Linear(d_embed, d_model), out_fc=nn.Linear(d_model, d_embed))
position_ff = PositionWiseFeedForwardLayer(fc1=nn.Linear(d_embed, d_ff), fc2=nn.Linear(d_ff, d_embed))

# One encoder block and one decoder block; the decoder block receives separate
# copies for self-attention and cross-attention.
encoder_block = EncoderBlock(self_attention=copy.deepcopy(attention), position_ff=copy.deepcopy(position_ff), d_model=d_model)
decoder_block = DecoderBlock(self_attention=copy.deepcopy(attention), cross_attention=copy.deepcopy(attention), position_ff=copy.deepcopy(position_ff), d_model=d_model)

# NOTE(review): Encoder/Decoder presumably replicate the block n_layer times with
# independent weights - verify in src.base_module.
encoder = Encoder(encoder_block=encoder_block, n_layer=n_layer)
decoder = Decoder(decoder_block=decoder_block, n_layer=n_layer)

# Final projection from d_embed features to target-vocabulary logits.
generator = nn.Linear(d_embed, tgt_vocab_size)

model = Transformer(
    src_embed=src_embed,
    tgt_embed=tgt_embed,
    encoder=encoder,
    decoder=decoder,
    generator=generator
).to(device)
# ignore_index=0 leaves '<PAD>' positions (id 0 in build_vocab) out of the loss.
loss_function = nn.CrossEntropyLoss(ignore_index=0).to(device)
# AdamW with its default hyper-parameters.
optimizer = torch.optim.AdamW(model.parameters())

In [25]:
def evaluation(model, dataloader, loss_function, device):
    """Score ``model`` on ``dataloader`` without computing gradients.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        model: Callable taking ``(encoder_inputs, decoder_inputs)`` and returning
            a pair whose first element holds the logits; must provide ``eval()``.
        dataloader: Sized iterable of ``(encoder_inputs, decoder_inputs,
            decoder_targets)`` tensor triples.
        loss_function: Callable applied to flattened logits and flattened targets.
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean_loss, accuracy)`` where ``mean_loss`` is the sum of the per-batch
        losses divided by ``len(dataloader)`` and ``accuracy`` is the share of
        correctly predicted positions among those whose target id is not 0.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_tokens = 0

    with torch.no_grad():
        for batch in dataloader:
            src_ids, tgt_in_ids, tgt_out_ids = (tensor.to(device) for tensor in batch)

            logits, _ = model(src_ids, tgt_in_ids)

            # Flatten all leading dimensions so every position is one sample.
            vocab_dim = logits.size(-1)
            batch_loss = loss_function(logits.view(-1, vocab_dim), tgt_out_ids.view(-1))
            loss_sum += batch_loss.item()

            # Only positions with a non-zero target id count towards accuracy.
            not_pad = tgt_out_ids != 0
            hits = (logits.argmax(dim=-1) == tgt_out_ids) & not_pad
            n_correct += hits.sum().item()
            n_tokens += not_pad.sum().item()

    return loss_sum / len(dataloader), n_correct / n_tokens

In [26]:
# Number of full passes over train_dataloader performed by the training loop.
num_epochs = 50

In [27]:
# Sanity check: fetch a single training batch and print the first decoder input
# next to its target; the target should be the input shifted left by one token.
for encoder_inputs, decoder_inputs, decoder_targets in tqdm(train_dataloader):
    encoder_inputs = encoder_inputs.to(device)
    decoder_inputs = decoder_inputs.to(device)
    decoder_targets = decoder_targets.to(device)
    print(decoder_inputs[0])
    print(decoder_targets[0])
    # Only the first batch is needed.
    break

  0%|          | 0/59 [00:00<?, ?it/s]

tensor([  3,  14,   6, 200,   2,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0], device='mps:0')
tensor([ 14,   6, 200,   2,   4,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0], device='mps:0')





In [28]:
# Lowest validation loss seen so far; starts at +inf so the first epoch always
# writes a checkpoint.
best_val_loss = float('inf')

for epoch in range(num_epochs):
    # evaluation() leaves the model in eval mode, so re-enable train mode at the
    # start of every epoch.
    model.train()

    for encoder_inputs, decoder_inputs, decoder_targets in tqdm(train_dataloader):
        encoder_inputs = encoder_inputs.to(device)
        decoder_inputs = decoder_inputs.to(device)
        decoder_targets = decoder_targets.to(device)
        optimizer.zero_grad()

        # Teacher forcing: the decoder is fed the '<sos>'-prefixed target.
        outputs, _ = model(encoder_inputs, decoder_inputs)

        # Flatten the logits and the targets so every position is one sample.
        loss = loss_function(outputs.view(-1, outputs.size(-1)), decoder_targets.view(-1))
        loss.backward()
        optimizer.step()

    # Metrics are recomputed in eval mode over both sets after the epoch; this
    # costs one extra full pass over the training data.
    train_loss, train_acc = evaluation(model, train_dataloader, loss_function, device)
    valid_loss, valid_acc = evaluation(model, valid_dataloader, loss_function, device)

    print(f'Epoch: {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Valid Loss: {valid_loss:.4f} | Valid Acc: {valid_acc:.4f}')

    # Keep only the weights with the best validation loss. The comparison uses
    # full precision, so the log may show two equal values rounded to 4 digits.
    if valid_loss < best_val_loss:
        print(f'Validation loss improved from {best_val_loss:.4f} to {valid_loss:.4f}. 체크포인트를 저장합니다.')
        best_val_loss = valid_loss
        torch.save(model.state_dict(), 'best_model_checkpoint.pth')


100%|██████████| 59/59 [00:09<00:00,  6.07it/s]


Epoch: 1/50 | Train Loss: 5.1982 | Train Acc: 0.4573 | Valid Loss: 5.6552 | Valid Acc: 0.4486
Validation loss improved from inf to 5.6552. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:09<00:00,  6.55it/s]


Epoch: 2/50 | Train Loss: 3.3235 | Train Acc: 0.5239 | Valid Loss: 3.9800 | Valid Acc: 0.5081
Validation loss improved from 5.6552 to 3.9800. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:09<00:00,  6.47it/s]


Epoch: 3/50 | Train Loss: 2.6583 | Train Acc: 0.5593 | Valid Loss: 3.3764 | Valid Acc: 0.5390
Validation loss improved from 3.9800 to 3.3764. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:09<00:00,  6.42it/s]


Epoch: 4/50 | Train Loss: 2.2992 | Train Acc: 0.5785 | Valid Loss: 3.0723 | Valid Acc: 0.5510
Validation loss improved from 3.3764 to 3.0723. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:09<00:00,  6.36it/s]


Epoch: 5/50 | Train Loss: 2.0457 | Train Acc: 0.5979 | Valid Loss: 2.8548 | Valid Acc: 0.5685
Validation loss improved from 3.0723 to 2.8548. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.56it/s]


Epoch: 6/50 | Train Loss: 1.7998 | Train Acc: 0.6292 | Valid Loss: 2.6705 | Valid Acc: 0.5949
Validation loss improved from 2.8548 to 2.6705. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.72it/s]


Epoch: 7/50 | Train Loss: 1.6410 | Train Acc: 0.6467 | Valid Loss: 2.5663 | Valid Acc: 0.6004
Validation loss improved from 2.6705 to 2.5663. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:09<00:00,  6.51it/s]


Epoch: 8/50 | Train Loss: 1.4749 | Train Acc: 0.6698 | Valid Loss: 2.4467 | Valid Acc: 0.6152
Validation loss improved from 2.5663 to 2.4467. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.64it/s]


Epoch: 9/50 | Train Loss: 1.3340 | Train Acc: 0.6904 | Valid Loss: 2.3471 | Valid Acc: 0.6240
Validation loss improved from 2.4467 to 2.3471. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.79it/s]


Epoch: 10/50 | Train Loss: 1.2071 | Train Acc: 0.7127 | Valid Loss: 2.2630 | Valid Acc: 0.6357
Validation loss improved from 2.3471 to 2.2630. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.86it/s]


Epoch: 11/50 | Train Loss: 1.1119 | Train Acc: 0.7213 | Valid Loss: 2.2156 | Valid Acc: 0.6363
Validation loss improved from 2.2630 to 2.2156. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.86it/s]


Epoch: 12/50 | Train Loss: 1.0813 | Train Acc: 0.7302 | Valid Loss: 2.2142 | Valid Acc: 0.6331
Validation loss improved from 2.2156 to 2.2142. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:09<00:00,  6.48it/s]


Epoch: 13/50 | Train Loss: 0.9361 | Train Acc: 0.7536 | Valid Loss: 2.1092 | Valid Acc: 0.6585
Validation loss improved from 2.2142 to 2.1092. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:09<00:00,  6.45it/s]


Epoch: 14/50 | Train Loss: 0.8828 | Train Acc: 0.7677 | Valid Loss: 2.0738 | Valid Acc: 0.6624
Validation loss improved from 2.1092 to 2.0738. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:09<00:00,  6.45it/s]


Epoch: 15/50 | Train Loss: 0.8178 | Train Acc: 0.7788 | Valid Loss: 2.0516 | Valid Acc: 0.6707
Validation loss improved from 2.0738 to 2.0516. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.85it/s]


Epoch: 16/50 | Train Loss: 0.7772 | Train Acc: 0.7854 | Valid Loss: 2.0425 | Valid Acc: 0.6719
Validation loss improved from 2.0516 to 2.0425. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.96it/s]


Epoch: 17/50 | Train Loss: 0.7496 | Train Acc: 0.7909 | Valid Loss: 2.0282 | Valid Acc: 0.6727
Validation loss improved from 2.0425 to 2.0282. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.70it/s]


Epoch: 18/50 | Train Loss: 0.6984 | Train Acc: 0.8005 | Valid Loss: 1.9978 | Valid Acc: 0.6776
Validation loss improved from 2.0282 to 1.9978. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.75it/s]


Epoch: 19/50 | Train Loss: 0.6755 | Train Acc: 0.8052 | Valid Loss: 1.9939 | Valid Acc: 0.6824
Validation loss improved from 1.9978 to 1.9939. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.76it/s]


Epoch: 20/50 | Train Loss: 0.6696 | Train Acc: 0.8092 | Valid Loss: 1.9927 | Valid Acc: 0.6840
Validation loss improved from 1.9939 to 1.9927. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.69it/s]


Epoch: 21/50 | Train Loss: 0.6279 | Train Acc: 0.8155 | Valid Loss: 1.9763 | Valid Acc: 0.6855
Validation loss improved from 1.9927 to 1.9763. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.63it/s]


Epoch: 22/50 | Train Loss: 0.6253 | Train Acc: 0.8125 | Valid Loss: 1.9713 | Valid Acc: 0.6806
Validation loss improved from 1.9763 to 1.9713. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.78it/s]


Epoch: 23/50 | Train Loss: 0.6030 | Train Acc: 0.8199 | Valid Loss: 1.9713 | Valid Acc: 0.6908
Validation loss improved from 1.9713 to 1.9713. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.79it/s]


Epoch: 24/50 | Train Loss: 0.5776 | Train Acc: 0.8247 | Valid Loss: 1.9738 | Valid Acc: 0.6961


100%|██████████| 59/59 [00:08<00:00,  6.80it/s]


Epoch: 25/50 | Train Loss: 0.5660 | Train Acc: 0.8313 | Valid Loss: 1.9419 | Valid Acc: 0.6978
Validation loss improved from 1.9713 to 1.9419. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.87it/s]


Epoch: 26/50 | Train Loss: 0.5496 | Train Acc: 0.8227 | Valid Loss: 1.9507 | Valid Acc: 0.6897


100%|██████████| 59/59 [00:08<00:00,  6.95it/s]


Epoch: 27/50 | Train Loss: 0.5192 | Train Acc: 0.8399 | Valid Loss: 1.9377 | Valid Acc: 0.7016
Validation loss improved from 1.9419 to 1.9377. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  7.00it/s]


Epoch: 28/50 | Train Loss: 0.5009 | Train Acc: 0.8432 | Valid Loss: 1.9283 | Valid Acc: 0.7020
Validation loss improved from 1.9377 to 1.9283. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.73it/s]


Epoch: 29/50 | Train Loss: 0.4956 | Train Acc: 0.8450 | Valid Loss: 1.9194 | Valid Acc: 0.7047
Validation loss improved from 1.9283 to 1.9194. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:09<00:00,  6.44it/s]


Epoch: 30/50 | Train Loss: 0.4993 | Train Acc: 0.8441 | Valid Loss: 1.9236 | Valid Acc: 0.7052


100%|██████████| 59/59 [00:09<00:00,  6.41it/s]


Epoch: 31/50 | Train Loss: 0.4804 | Train Acc: 0.8497 | Valid Loss: 1.9213 | Valid Acc: 0.7074


100%|██████████| 59/59 [00:09<00:00,  6.49it/s]


Epoch: 32/50 | Train Loss: 0.4722 | Train Acc: 0.8500 | Valid Loss: 1.9063 | Valid Acc: 0.7089
Validation loss improved from 1.9194 to 1.9063. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.90it/s]


Epoch: 33/50 | Train Loss: 0.4514 | Train Acc: 0.8557 | Valid Loss: 1.9059 | Valid Acc: 0.7089
Validation loss improved from 1.9063 to 1.9059. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.99it/s]


Epoch: 34/50 | Train Loss: 0.4460 | Train Acc: 0.8583 | Valid Loss: 1.9147 | Valid Acc: 0.7052


100%|██████████| 59/59 [00:08<00:00,  7.00it/s]


Epoch: 35/50 | Train Loss: 0.4322 | Train Acc: 0.8610 | Valid Loss: 1.8986 | Valid Acc: 0.7119
Validation loss improved from 1.9059 to 1.8986. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.62it/s]


Epoch: 36/50 | Train Loss: 0.4260 | Train Acc: 0.8614 | Valid Loss: 1.9150 | Valid Acc: 0.7121


100%|██████████| 59/59 [00:08<00:00,  6.77it/s]


Epoch: 37/50 | Train Loss: 0.4007 | Train Acc: 0.8676 | Valid Loss: 1.9004 | Valid Acc: 0.7146


100%|██████████| 59/59 [00:08<00:00,  6.83it/s]


Epoch: 38/50 | Train Loss: 0.4185 | Train Acc: 0.8642 | Valid Loss: 1.9186 | Valid Acc: 0.7120


100%|██████████| 59/59 [00:08<00:00,  6.92it/s]


Epoch: 39/50 | Train Loss: 0.3955 | Train Acc: 0.8694 | Valid Loss: 1.8961 | Valid Acc: 0.7133
Validation loss improved from 1.8986 to 1.8961. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.86it/s]


Epoch: 40/50 | Train Loss: 0.4116 | Train Acc: 0.8634 | Valid Loss: 1.9275 | Valid Acc: 0.7024


100%|██████████| 59/59 [00:08<00:00,  6.93it/s]


Epoch: 41/50 | Train Loss: 0.3908 | Train Acc: 0.8690 | Valid Loss: 1.9102 | Valid Acc: 0.7139


100%|██████████| 59/59 [00:08<00:00,  6.96it/s]


Epoch: 42/50 | Train Loss: 0.3865 | Train Acc: 0.8712 | Valid Loss: 1.9228 | Valid Acc: 0.7138


100%|██████████| 59/59 [00:08<00:00,  6.79it/s]


Epoch: 43/50 | Train Loss: 0.3775 | Train Acc: 0.8744 | Valid Loss: 1.8944 | Valid Acc: 0.7168
Validation loss improved from 1.8961 to 1.8944. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.77it/s]


Epoch: 44/50 | Train Loss: 0.3697 | Train Acc: 0.8740 | Valid Loss: 1.9021 | Valid Acc: 0.7150


100%|██████████| 59/59 [00:08<00:00,  6.97it/s]


Epoch: 45/50 | Train Loss: 0.3528 | Train Acc: 0.8772 | Valid Loss: 1.8913 | Valid Acc: 0.7196
Validation loss improved from 1.8944 to 1.8913. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.92it/s]


Epoch: 46/50 | Train Loss: 0.3398 | Train Acc: 0.8816 | Valid Loss: 1.8814 | Valid Acc: 0.7201
Validation loss improved from 1.8913 to 1.8814. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:08<00:00,  6.95it/s]


Epoch: 47/50 | Train Loss: 0.3399 | Train Acc: 0.8835 | Valid Loss: 1.8869 | Valid Acc: 0.7205


100%|██████████| 59/59 [00:08<00:00,  6.89it/s]


Epoch: 48/50 | Train Loss: 0.3334 | Train Acc: 0.8855 | Valid Loss: 1.8887 | Valid Acc: 0.7210


100%|██████████| 59/59 [00:08<00:00,  6.93it/s]


Epoch: 49/50 | Train Loss: 0.3265 | Train Acc: 0.8857 | Valid Loss: 1.8837 | Valid Acc: 0.7206


100%|██████████| 59/59 [00:08<00:00,  6.97it/s]


Epoch: 50/50 | Train Loss: 0.3044 | Train Acc: 0.8922 | Valid Loss: 1.8626 | Valid Acc: 0.7276
Validation loss improved from 1.8814 to 1.8626. 체크포인트를 저장합니다.


In [29]:
# Restore the best weights saved during training. map_location remaps the stored
# tensors onto the current device, so a checkpoint written on one backend (for
# example MPS) still loads on a machine that only has another one (for example CPU).
best_state = torch.load('best_model_checkpoint.pth', map_location=device, weights_only=True)
model.load_state_dict(best_state)
model.to(device)

# Re-score the restored model on the validation set.
val_loss, val_accuracy = evaluation(model, valid_dataloader, loss_function, device)

print(f'Best model validation loss: {val_loss:.4f}')
print(f'Best model validation accuracy: {val_accuracy:.4f}')

Best model validation loss: 1.8626
Best model validation accuracy: 0.7276


In [30]:
def seq_to_src(input_seq):
    """Turn a source id sequence back into text.

    Ids equal to 0 (padding) are skipped. Every remaining id is looked up in
    ``index2src`` and followed by a single space, so a non-empty result ends
    with a trailing space.

    Args:
        input_seq: Iterable of source token ids.

    Returns:
        The reconstructed sentence string.
    """
    pieces = [index2src[token_id] + ' ' for token_id in input_seq if token_id != 0]
    return ''.join(pieces)

def seq_to_tar(input_seq):
    """Turn a target id sequence back into text.

    Padding (id 0) and the ``<sos>`` / ``<eos>`` markers are skipped. Every
    remaining id is looked up in ``index2tar`` and followed by a single space,
    so a non-empty result ends with a trailing space.

    Args:
        input_seq: Iterable of target token ids.

    Returns:
        The reconstructed sentence string.
    """
    pieces = []
    for token_id in input_seq:
        if token_id == 0 or token_id == tgt_vocab['<sos>'] or token_id == tgt_vocab['<eos>']:
            continue
        pieces.append(index2tar[token_id] + ' ')
    return ''.join(pieces)

In [31]:
def decode_sequence(input_seq, model):
    """Greedily translate one encoded source sentence.

    The source is encoded once. The decoder is then re-run on the growing prefix,
    starting from ``<sos>``, and the highest-scoring token at the last position
    is appended each step. Decoding stops when ``<eos>`` is produced or after
    ``max_len`` steps.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        input_seq: 1-D sequence of source token ids (list or NumPy array).
        model: Model exposing ``make_src_mask``, ``make_tgt_mask``,
            ``make_src_tgt_mask``, ``encode``, ``decode`` and ``generator``.

    Returns:
        The generated target tokens joined by single spaces, without ``<sos>``
        and ``<eos>``.
    """
    model.eval()
    sos_id = tgt_vocab['<sos>']
    eos_id = tgt_vocab['<eos>']
    decoded_tokens = [sos_id]

    # The encoder pass is inference-only as well, so it runs under no_grad too;
    # otherwise an autograd graph is built and kept alive for encoder_out.
    with torch.no_grad():
        # as_tensor with an explicit dtype accepts any integer array, not only
        # arrays that already are int64.
        encoder_inputs = torch.as_tensor(input_seq, dtype=torch.long).unsqueeze(0).to(device)
        src_mask = model.make_src_mask(encoder_inputs)
        encoder_out = model.encode(encoder_inputs, src_mask)

        for _ in range(max_len):
            decoder_input = torch.tensor(decoded_tokens, dtype=torch.long).unsqueeze(0).to(device)
            tgt_mask = model.make_tgt_mask(decoder_input)
            src_tgt_mask = model.make_src_tgt_mask(encoder_inputs, decoder_input)
            output = model.decode(decoder_input, encoder_out, tgt_mask, src_tgt_mask)
            output = model.generator(output)

            # Greedy choice: best-scoring token at the last decoded position.
            output_token = output.argmax(dim=2)[:, -1].item()

            if output_token == eos_id:
                break

            decoded_tokens.append(output_token)

    return ' '.join(index2tar[token] for token in decoded_tokens if token != sos_id)


In [32]:
# Spot-check greedy translations on a few rows of the training set.
for seq_index in (3, 50, 100, 300, 1001):
    input_seq = encoder_input_train[seq_index]
    reference_seq = decoder_input_train[seq_index]
    translated_text = decode_sequence(input_seq, model)

    print("입력문장: ", seq_to_src(input_seq))
    print("정답문장: ", seq_to_tar(reference_seq))
    print("번역문장: ", translated_text)
    print("-" * 50)

입력문장:  keep back . 
정답문장:  restez en retrait . 
번역문장:  restez en arriere .
--------------------------------------------------
입력문장:  have you finished ? 
정답문장:  avez vous fini ? 
번역문장:  avez vous termine ?
--------------------------------------------------
입력문장:  please sing . 
정답문장:  s il vous plait chantez ! 
번역문장:  s il vous plait chantez !
--------------------------------------------------
입력문장:  are you surprised ? 
정답문장:  es tu surprise ? 
번역문장:  es tu surpris ?
--------------------------------------------------
입력문장:  quit hassling me . 
정답문장:  arretez de m embeter ! 
번역문장:  arrete de m embeter !
--------------------------------------------------
