데이터셋: http://www.manythings.org/anki fra-eng


In [1]:
import sys
import pathlib
import os

sys.path.append(str(pathlib.Path(os.getcwd()).parent))

import copy
import re
import unicodedata
import numpy as np
import pandas as pd
import torch
from collections import Counter
from tqdm import tqdm
from torch.utils.data import TensorDataset, DataLoader

from src.base_module import *

In [2]:
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
print(f'device: {device}')

device: mps


In [3]:
num_samples = 33000

In [4]:
def unicode_to_ascii(s):
    """Return ``s`` with combining diacritical marks removed.

    The string is decomposed with NFD, which splits an accented letter into a
    base character followed by combining marks; every character whose Unicode
    category is ``'Mn'`` is then dropped and the rest is joined back together.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)

def preprocess_sentence(sent):
    """Normalize a raw sentence into lowercase, space-separated ASCII text.

    Steps, in order: lowercase and strip accents, insert a space in front of
    punctuation marks, replace every run of characters other than letters and
    ``! . ?`` with one space, then collapse repeated whitespace.

    The result is not stripped, so it can start or end with a single space.
    """
    normalized = unicode_to_ascii(sent.lower())
    # Detach punctuation from the preceding word so it becomes its own token.
    spaced = re.sub(r"([?.!,¿])", r" \1", normalized)
    # Commas and the inverted question mark are outside the keep-set, so they
    # are replaced by a space here together with digits, apostrophes, etc.
    letters_only = re.sub(r"[^a-zA-Z!.?]+", r" ", spaced)
    return re.sub(r"\s+", " ", letters_only)

def load_preprocess_data(path='./fra.txt', n_samples=None):
    """Read the tab-separated parallel corpus and tokenize each sentence pair.

    Args:
        path: Location of the corpus file. Every non-blank line must hold the
            source sentence and the target sentence separated by a tab; any
            further tab-separated fields are ignored.
        n_samples: Maximum number of lines to read from the start of the file.
            Defaults to the notebook-level ``num_samples`` setting.

    Returns:
        Three parallel lists of token lists: ``encoder_input`` (source tokens),
        ``decoder_input`` (target tokens prefixed with ``<sos>``) and
        ``decoder_target`` (target tokens followed by ``<eos>``).

    Raises:
        ValueError: If a non-blank line has fewer than two tab-separated fields.
    """
    from itertools import islice

    if n_samples is None:
        n_samples = num_samples

    encoder_input, decoder_input, decoder_target = [], [], []

    # The corpus contains accented characters, so decode it explicitly as UTF-8
    # instead of relying on the platform's default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        # islice stops after n_samples lines without loading the whole file.
        for line_no, line in enumerate(islice(f, n_samples), start=1):
            line = line.strip()
            if not line:
                continue

            fields = line.split('\t')
            if len(fields) < 2:
                raise ValueError(
                    f'{path}:{line_no}: expected at least 2 tab-separated fields, got {len(fields)}'
                )

            src = preprocess_sentence(fields[0]).split()
            tar = preprocess_sentence(fields[1])
            tgt_in = f'<sos> {tar}'.split()
            tgt_out = f'{tar} <eos>'.split()

            encoder_input.append(src)
            decoder_input.append(tgt_in)
            decoder_target.append(tgt_out)

    return encoder_input, decoder_input, decoder_target

In [5]:
en_sent = u"Have you had dinner?"
fr_sent = u"Avez-vous déjà diné?"

print(f'{en_sent} -> {preprocess_sentence(en_sent)}')
print(f'{fr_sent} -> {preprocess_sentence(fr_sent)}')

Have you had dinner? -> have you had dinner ?
Avez-vous déjà diné? -> avez vous deja dine ?


In [6]:
sents_en_in, sents_fra_in, sents_fra_out = load_preprocess_data()

print(len(sents_en_in))
print(len(sents_fra_in))
print(len(sents_fra_out))

33000
33000
33000


In [7]:
print(sents_en_in[:5])
print(sents_fra_in[:5])
print(sents_fra_out[:5])

[['go', '.'], ['go', '.'], ['go', '.'], ['go', '.'], ['hi', '.']]
[['<sos>', 'va', '!'], ['<sos>', 'marche', '.'], ['<sos>', 'en', 'route', '!'], ['<sos>', 'bouge', '!'], ['<sos>', 'salut', '!']]
[['va', '!', '<eos>'], ['marche', '.', '<eos>'], ['en', 'route', '!', '<eos>'], ['bouge', '!', '<eos>'], ['salut', '!', '<eos>']]


In [8]:
def build_vocab(sents):
    """Build a word -> index mapping from tokenized sentences.

    Index 0 is reserved for ``<PAD>`` and index 1 for ``<UNK>``. The remaining
    words are numbered from 2 in order of decreasing frequency; words with the
    same frequency keep the order in which they were first seen.
    """
    frequencies = Counter(token for sentence in sents for token in sentence)

    word2index = {'<PAD>': 0, '<UNK>': 1}
    # most_common() is a stable descending sort on the counts.
    for rank, (token, _) in enumerate(frequencies.most_common(), start=2):
        word2index[token] = rank

    return word2index

In [9]:
src_vocab = build_vocab(sents_en_in)
tgt_vocab = build_vocab(sents_fra_in + sents_fra_out)

src_vocab_size = len(src_vocab)
tgt_vocab_size = len(tgt_vocab)

print(f'src vocab size: {src_vocab_size}')
print(f'tar vocab size: {tgt_vocab_size}')

src vocab size: 4486
tar vocab size: 7879


In [10]:
index2src = {v: k for k, v in src_vocab.items()}
index2tar = {v: k for k, v in tgt_vocab.items()}

In [11]:
def encode_sentences(sents, word2index):
    """Convert tokenized sentences into lists of vocabulary indices.

    Tokens missing from ``word2index`` are mapped to the index stored under
    ``'<UNK>'``. Progress over the sentences is reported through ``tqdm``.
    """
    def lookup(token):
        try:
            return word2index[token]
        except KeyError:
            return word2index['<UNK>']

    return [[lookup(token) for token in sentence] for sentence in tqdm(sents)]

In [12]:
encoder_input = encode_sentences(sents_en_in, src_vocab)
decoder_input = encode_sentences(sents_fra_in, tgt_vocab)
decoder_target = encode_sentences(sents_fra_out, tgt_vocab)

100%|██████████| 33000/33000 [00:00<00:00, 536445.40it/s]
100%|██████████| 33000/33000 [00:00<00:00, 2197051.25it/s]
100%|██████████| 33000/33000 [00:00<00:00, 2177659.41it/s]


In [13]:
print(encoder_input[:5])
print(decoder_input[:5])
print(decoder_target[:5])

[[27, 2], [27, 2], [27, 2], [27, 2], [736, 2]]
[[3, 68, 11], [3, 204, 2], [3, 26, 491, 11], [3, 561, 11], [3, 954, 11]]
[[68, 11, 4], [204, 2, 4], [26, 491, 11, 4], [561, 11, 4], [954, 11, 4]]


In [14]:
def pad_sentences(sents, max_len=None):
    """Right-pad (or truncate) encoded sentences into a 2-D integer array.

    Args:
        sents: Sequence of index sequences (lists or 1-D arrays).
        max_len: Width of the output. When ``None`` the length of the longest
            sentence is used (0 if ``sents`` is empty).

    Returns:
        An integer array of shape ``(len(sents), max_len)``. Sentences longer
        than ``max_len`` are cut off; shorter ones are filled with 0 on the
        right.
    """
    if max_len is None:
        # default=0 keeps an empty batch from raising ValueError in max().
        max_len = max((len(s) for s in sents), default=0)

    features = np.zeros((len(sents), max_len), dtype=int)
    for i, sent in enumerate(sents):
        n_kept = min(len(sent), max_len)
        if n_kept:
            features[i, :n_kept] = sent[:n_kept]

    return features

In [15]:
encoder_input = pad_sentences(encoder_input)
decoder_input = pad_sentences(decoder_input)
decoder_target = pad_sentences(decoder_target)

In [16]:
print(encoder_input[:3])
print(decoder_input[:3])
print(decoder_target[:3])

[[27  2  0  0  0  0  0]
 [27  2  0  0  0  0  0]
 [27  2  0  0  0  0  0]]
[[  3  68  11   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  3 204   2   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  3  26 491  11   0   0   0   0   0   0   0   0   0   0   0   0]]
[[ 68  11   4   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [204   2   4   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [ 26 491  11   4   0   0   0   0   0   0   0   0   0   0   0   0]]


In [17]:
print(encoder_input.shape)
print(decoder_input.shape)
print(decoder_target.shape)

(33000, 7)
(33000, 16)
(33000, 16)


In [18]:
# Shuffle with a fixed seed so that the train/validation split derived from
# `indices` is identical on every run of the notebook.
shuffle_rng = np.random.default_rng(42)
indices = shuffle_rng.permutation(encoder_input.shape[0])
print('랜덤 시퀀스 :',indices)

랜덤 시퀀스 : [ 8594 15783 24996 ... 22517 20878 17362]


In [19]:
encoder_input = encoder_input[indices]
decoder_input = decoder_input[indices]
decoder_target = decoder_target[indices]

In [20]:
print([index2src[word] for word in encoder_input[30997]])
print([index2tar[word] for word in decoder_input[30997]])
print([index2tar[word] for word in decoder_target[30997]])

['i', 'm', 'ticklish', '.', '<PAD>', '<PAD>', '<PAD>']
['<sos>', 'je', 'suis', 'chatouilleuse', '.', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']
['je', 'suis', 'chatouilleuse', '.', '<eos>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']


In [21]:
# Size the validation split from the data that was actually loaded rather than
# from the requested num_samples, which can exceed the number of lines read.
n_total = len(encoder_input)
n_of_val = int(n_total * 0.1)
print('검증 데이터의 개수 :',n_of_val)

# Split on a positive index: slicing with -n_of_val would leave the training
# set empty (and put every row in the test set) whenever n_of_val is 0.
n_of_train = n_total - n_of_val

encoder_input_train = encoder_input[:n_of_train]
decoder_input_train = decoder_input[:n_of_train]
decoder_target_train = decoder_target[:n_of_train]

encoder_input_test = encoder_input[n_of_train:]
decoder_input_test = decoder_input[n_of_train:]
decoder_target_test = decoder_target[n_of_train:]

검증 데이터의 개수 : 3300


In [22]:
print('훈련 source 데이터의 크기 :',encoder_input_train.shape)
print('훈련 target 데이터의 크기 :',decoder_input_train.shape)
print('훈련 target 레이블의 크기 :',decoder_target_train.shape)
print('테스트 source 데이터의 크기 :',encoder_input_test.shape)
print('테스트 target 데이터의 크기 :',decoder_input_test.shape)
print('테스트 target 레이블의 크기 :',decoder_target_test.shape)

훈련 source 데이터의 크기 : (29700, 7)
훈련 target 데이터의 크기 : (29700, 16)
훈련 target 레이블의 크기 : (29700, 16)
테스트 source 데이터의 크기 : (3300, 7)
테스트 target 데이터의 크기 : (3300, 16)
테스트 target 레이블의 크기 : (3300, 16)


In [23]:
encoder_input_train_tensor = torch.tensor(encoder_input_train, dtype=torch.long)
decoder_input_train_tensor = torch.tensor(decoder_input_train, dtype=torch.long)
decoder_target_train_tensor = torch.tensor(decoder_target_train, dtype=torch.long)

encoder_input_test_tensor = torch.tensor(encoder_input_test, dtype=torch.long)
decoder_input_test_tensor = torch.tensor(decoder_input_test, dtype=torch.long)
decoder_target_test_tensor = torch.tensor(decoder_target_test, dtype=torch.long)

# 데이터셋 및 데이터로더 생성
batch_size = 512

train_dataset = TensorDataset(encoder_input_train_tensor, decoder_input_train_tensor, decoder_target_train_tensor)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

valid_dataset = TensorDataset(encoder_input_test_tensor, decoder_input_test_tensor, decoder_target_test_tensor)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [24]:
# Model hyperparameters.
# max_len is handed to PositionalEncoding and also caps the number of greedy
# decoding steps in decode_sequence.
max_len = 20
d_embed = 128
d_model = 128
n_layer = 4
h = 4  # passed to MultiHeadAttentionLayer as `h` (presumably the head count — confirm in src.base_module)
d_ff = 256

# Separate token embeddings per language (the vocabularies differ in size);
# each side receives its own deep copy of the positional encoding.
src_token_embed = TokenEmbedding(d_embed=d_embed, vocab_size=src_vocab_size)
tgt_token_embed = TokenEmbedding(d_embed=d_embed, vocab_size=tgt_vocab_size)
pos_embed = PositionalEncoding(d_embed=d_embed, max_len=max_len, device=device)
src_embed = TransformerEmbedding(token_embed=src_token_embed, pos_embed=copy.deepcopy(pos_embed))
tgt_embed = TransformerEmbedding(token_embed=tgt_token_embed, pos_embed=copy.deepcopy(pos_embed))

# Prototype sub-layers; they are deep-copied below so no two uses share weights.
# NOTE(review): `nn` is not imported explicitly in this notebook — presumably it
# comes in through `from src.base_module import *`; confirm.
attention = MultiHeadAttentionLayer(d_model=d_model, h=h, qkv_fc=nn.Linear(d_embed, d_model), out_fc=nn.Linear(d_model, d_embed)).to(device)
position_ff = PositionWiseFeedForwardLayer(fc1=nn.Linear(d_embed, d_ff), fc2=nn.Linear(d_ff, d_embed)).to(device)

encoder_block = EncoderBlock(self_attention=copy.deepcopy(attention), position_ff=copy.deepcopy(position_ff), d_model=d_model, device=device)
decoder_block = DecoderBlock(self_attention=copy.deepcopy(attention), cross_attention=copy.deepcopy(attention), position_ff=copy.deepcopy(position_ff), d_model=d_model, device=device)

# NOTE(review): how Encoder/Decoder turn one block into n_layer layers is
# defined in src.base_module and not visible here.
encoder = Encoder(encoder_block=encoder_block, n_layer=n_layer)
decoder = Decoder(decoder_block=decoder_block, n_layer=n_layer)

# Projects decoder outputs of width d_embed onto the target vocabulary.
generator = nn.Linear(d_embed, tgt_vocab_size)

model = Transformer(
    src_embed=src_embed,
    tgt_embed=tgt_embed,
    encoder=encoder,
    decoder=decoder,
    generator=generator
).to(device)
# Index 0 is '<PAD>' in both vocabularies (see build_vocab), so padded target
# positions are excluded from the loss.
loss_function = nn.CrossEntropyLoss(ignore_index=0).to(device)
# Adam with its default hyperparameters.
optimizer = torch.optim.Adam(model.parameters())

In [25]:
def evaluation(model, dataloader, loss_function, device):
    """Compute mean loss and token accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode and run without gradient tracking.
    Each batch must unpack into ``(encoder_inputs, decoder_inputs,
    decoder_targets)``.

    Returns:
        A tuple ``(loss, accuracy)`` where ``loss`` is the average of the
        per-batch loss values and ``accuracy`` is the fraction of target
        positions with a non-zero index whose arg-max prediction matches.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_tokens = 0

    with torch.no_grad():
        for batch in dataloader:
            src, tgt_in, tgt_out = (tensor.to(device) for tensor in batch)

            logits, _ = model(src, tgt_in)

            # Flatten all leading dimensions so each position is one sample.
            n_classes = logits.size(-1)
            batch_loss = loss_function(logits.view(-1, n_classes), tgt_out.view(-1))
            loss_sum += batch_loss.item()

            # Positions whose target index is 0 are left out of the accuracy.
            not_pad = tgt_out != 0
            hits = (logits.argmax(dim=-1) == tgt_out) & not_pad
            n_correct += hits.sum().item()
            n_tokens += not_pad.sum().item()

    return loss_sum / len(dataloader), n_correct / n_tokens

In [26]:
num_epochs = 40

In [27]:
for encoder_inputs, decoder_inputs, decoder_targets in tqdm(train_dataloader):
    encoder_inputs = encoder_inputs.to(device)
    decoder_inputs = decoder_inputs.to(device)
    decoder_targets = decoder_targets.to(device)
    print(decoder_inputs[0])
    print(decoder_targets[0])
    break

  0%|          | 0/59 [00:00<?, ?it/s]

tensor([  3, 224,  16, 872,   2,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0], device='mps:0')
tensor([224,  16, 872,   2,   4,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0], device='mps:0')





In [28]:
# Lowest validation loss seen so far; decides when a checkpoint is written.
best_val_loss = float('inf')

for epoch in range(num_epochs):
    # evaluation() leaves the model in eval mode, so training mode has to be
    # re-enabled at the start of every epoch.
    model.train()

    for encoder_inputs, decoder_inputs, decoder_targets in tqdm(train_dataloader):
        encoder_inputs = encoder_inputs.to(device)
        decoder_inputs = decoder_inputs.to(device)
        decoder_targets = decoder_targets.to(device)
        optimizer.zero_grad()

        # The model returns a pair; only the first element (the logits) is used.
        outputs, _ = model(encoder_inputs, decoder_inputs)

        # Flatten to (positions, vocab) vs. (positions,) for the loss.
        loss = loss_function(outputs.view(-1, outputs.size(-1)), decoder_targets.view(-1))
        loss.backward()
        optimizer.step()

    # Training metrics come from a separate full pass in eval mode, not from
    # the losses accumulated during the optimisation steps above.
    train_loss, train_acc = evaluation(model, train_dataloader, loss_function, device)
    valid_loss, valid_acc = evaluation(model, valid_dataloader, loss_function, device)

    print(f'Epoch: {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Valid Loss: {valid_loss:.4f} | Valid Acc: {valid_acc:.4f}')

    # Keep only the weights with the best validation loss; the file is
    # overwritten on every improvement.
    if valid_loss < best_val_loss:
        print(f'Validation loss improved from {best_val_loss:.4f} to {valid_loss:.4f}. 체크포인트를 저장합니다.')
        best_val_loss = valid_loss
        torch.save(model.state_dict(), 'best_model_checkpoint.pth')


100%|██████████| 59/59 [00:07<00:00,  7.55it/s]


Epoch: 1/40 | Train Loss: 5.4994 | Train Acc: 0.4404 | Valid Loss: 5.9098 | Valid Acc: 0.4348
Validation loss improved from inf to 5.9098. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:06<00:00,  8.43it/s]


Epoch: 2/40 | Train Loss: 3.6005 | Train Acc: 0.5078 | Valid Loss: 4.1833 | Valid Acc: 0.4924
Validation loss improved from 5.9098 to 4.1833. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:06<00:00,  8.51it/s]


Epoch: 3/40 | Train Loss: 2.8520 | Train Acc: 0.5426 | Valid Loss: 3.5429 | Valid Acc: 0.5253
Validation loss improved from 4.1833 to 3.5429. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:06<00:00,  8.51it/s]


Epoch: 4/40 | Train Loss: 2.4394 | Train Acc: 0.5641 | Valid Loss: 3.2080 | Valid Acc: 0.5441
Validation loss improved from 3.5429 to 3.2080. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:07<00:00,  8.28it/s]


Epoch: 5/40 | Train Loss: 2.1807 | Train Acc: 0.5836 | Valid Loss: 2.9929 | Valid Acc: 0.5561
Validation loss improved from 3.2080 to 2.9929. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:06<00:00,  8.49it/s]


Epoch: 6/40 | Train Loss: 1.9272 | Train Acc: 0.6086 | Valid Loss: 2.7914 | Valid Acc: 0.5813
Validation loss improved from 2.9929 to 2.7914. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:07<00:00,  8.33it/s]


Epoch: 7/40 | Train Loss: 1.7340 | Train Acc: 0.6322 | Valid Loss: 2.6533 | Valid Acc: 0.5931
Validation loss improved from 2.7914 to 2.6533. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:06<00:00,  8.61it/s]


Epoch: 8/40 | Train Loss: 1.6052 | Train Acc: 0.6451 | Valid Loss: 2.5651 | Valid Acc: 0.5988
Validation loss improved from 2.6533 to 2.5651. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:06<00:00,  8.66it/s]


Epoch: 9/40 | Train Loss: 1.4422 | Train Acc: 0.6702 | Valid Loss: 2.4631 | Valid Acc: 0.6137
Validation loss improved from 2.5651 to 2.4631. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:06<00:00,  8.61it/s]


Epoch: 10/40 | Train Loss: 1.3228 | Train Acc: 0.6889 | Valid Loss: 2.3813 | Valid Acc: 0.6237
Validation loss improved from 2.4631 to 2.3813. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:06<00:00,  8.68it/s]


Epoch: 11/40 | Train Loss: 1.2305 | Train Acc: 0.7034 | Valid Loss: 2.3249 | Valid Acc: 0.6326
Validation loss improved from 2.3813 to 2.3249. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:06<00:00,  8.68it/s]


Epoch: 12/40 | Train Loss: 1.1537 | Train Acc: 0.7144 | Valid Loss: 2.2847 | Valid Acc: 0.6350
Validation loss improved from 2.3249 to 2.2847. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:06<00:00,  8.63it/s]


Epoch: 13/40 | Train Loss: 1.0769 | Train Acc: 0.7283 | Valid Loss: 2.2481 | Valid Acc: 0.6391
Validation loss improved from 2.2847 to 2.2481. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:06<00:00,  8.72it/s]


Epoch: 14/40 | Train Loss: 1.0202 | Train Acc: 0.7390 | Valid Loss: 2.2084 | Valid Acc: 0.6460
Validation loss improved from 2.2481 to 2.2084. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:06<00:00,  8.69it/s]


Epoch: 15/40 | Train Loss: 0.9509 | Train Acc: 0.7488 | Valid Loss: 2.1860 | Valid Acc: 0.6549
Validation loss improved from 2.2084 to 2.1860. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:06<00:00,  8.55it/s]


Epoch: 16/40 | Train Loss: 0.9093 | Train Acc: 0.7577 | Valid Loss: 2.1524 | Valid Acc: 0.6559
Validation loss improved from 2.1860 to 2.1524. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:06<00:00,  8.47it/s]


Epoch: 17/40 | Train Loss: 0.8647 | Train Acc: 0.7677 | Valid Loss: 2.1313 | Valid Acc: 0.6638
Validation loss improved from 2.1524 to 2.1313. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:06<00:00,  8.51it/s]


Epoch: 18/40 | Train Loss: 0.8414 | Train Acc: 0.7712 | Valid Loss: 2.1280 | Valid Acc: 0.6657
Validation loss improved from 2.1313 to 2.1280. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:06<00:00,  8.44it/s]


Epoch: 19/40 | Train Loss: 0.8221 | Train Acc: 0.7754 | Valid Loss: 2.1303 | Valid Acc: 0.6652


100%|██████████| 59/59 [00:06<00:00,  8.48it/s]


Epoch: 20/40 | Train Loss: 0.8011 | Train Acc: 0.7774 | Valid Loss: 2.1239 | Valid Acc: 0.6652
Validation loss improved from 2.1280 to 2.1239. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:07<00:00,  8.40it/s]


Epoch: 21/40 | Train Loss: 0.7678 | Train Acc: 0.7849 | Valid Loss: 2.1105 | Valid Acc: 0.6708
Validation loss improved from 2.1239 to 2.1105. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:06<00:00,  8.53it/s]


Epoch: 22/40 | Train Loss: 0.7377 | Train Acc: 0.7907 | Valid Loss: 2.0937 | Valid Acc: 0.6761
Validation loss improved from 2.1105 to 2.0937. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:06<00:00,  8.54it/s]


Epoch: 23/40 | Train Loss: 0.7188 | Train Acc: 0.7936 | Valid Loss: 2.0693 | Valid Acc: 0.6800
Validation loss improved from 2.0937 to 2.0693. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:06<00:00,  8.56it/s]


Epoch: 24/40 | Train Loss: 0.6986 | Train Acc: 0.7980 | Valid Loss: 2.0744 | Valid Acc: 0.6801


100%|██████████| 59/59 [00:07<00:00,  8.34it/s]


Epoch: 25/40 | Train Loss: 0.6976 | Train Acc: 0.7976 | Valid Loss: 2.0935 | Valid Acc: 0.6785


100%|██████████| 59/59 [00:07<00:00,  8.33it/s]


Epoch: 26/40 | Train Loss: 0.6792 | Train Acc: 0.8016 | Valid Loss: 2.0886 | Valid Acc: 0.6791


100%|██████████| 59/59 [00:06<00:00,  8.44it/s]


Epoch: 27/40 | Train Loss: 0.6574 | Train Acc: 0.8047 | Valid Loss: 2.0742 | Valid Acc: 0.6814


100%|██████████| 59/59 [00:07<00:00,  8.25it/s]


Epoch: 28/40 | Train Loss: 0.6428 | Train Acc: 0.8111 | Valid Loss: 2.0440 | Valid Acc: 0.6828
Validation loss improved from 2.0693 to 2.0440. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:07<00:00,  8.39it/s]


Epoch: 29/40 | Train Loss: 0.6219 | Train Acc: 0.8163 | Valid Loss: 2.0336 | Valid Acc: 0.6865
Validation loss improved from 2.0440 to 2.0336. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:06<00:00,  8.49it/s]


Epoch: 30/40 | Train Loss: 0.6171 | Train Acc: 0.8167 | Valid Loss: 2.0474 | Valid Acc: 0.6869


100%|██████████| 59/59 [00:07<00:00,  8.37it/s]


Epoch: 31/40 | Train Loss: 0.6017 | Train Acc: 0.8205 | Valid Loss: 2.0378 | Valid Acc: 0.6859


100%|██████████| 59/59 [00:07<00:00,  8.34it/s]


Epoch: 32/40 | Train Loss: 0.5875 | Train Acc: 0.8230 | Valid Loss: 2.0226 | Valid Acc: 0.6894
Validation loss improved from 2.0336 to 2.0226. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:07<00:00,  8.40it/s]


Epoch: 33/40 | Train Loss: 0.5697 | Train Acc: 0.8278 | Valid Loss: 2.0281 | Valid Acc: 0.6895


100%|██████████| 59/59 [00:07<00:00,  8.40it/s]


Epoch: 34/40 | Train Loss: 0.5510 | Train Acc: 0.8316 | Valid Loss: 2.0146 | Valid Acc: 0.6925
Validation loss improved from 2.0226 to 2.0146. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:07<00:00,  8.40it/s]


Epoch: 35/40 | Train Loss: 0.5388 | Train Acc: 0.8339 | Valid Loss: 2.0174 | Valid Acc: 0.6954


100%|██████████| 59/59 [00:07<00:00,  8.25it/s]


Epoch: 36/40 | Train Loss: 0.5361 | Train Acc: 0.8348 | Valid Loss: 2.0157 | Valid Acc: 0.6942


100%|██████████| 59/59 [00:07<00:00,  8.36it/s]


Epoch: 37/40 | Train Loss: 0.5264 | Train Acc: 0.8379 | Valid Loss: 2.0077 | Valid Acc: 0.6987
Validation loss improved from 2.0146 to 2.0077. 체크포인트를 저장합니다.


100%|██████████| 59/59 [00:06<00:00,  8.44it/s]


Epoch: 38/40 | Train Loss: 0.5204 | Train Acc: 0.8387 | Valid Loss: 2.0147 | Valid Acc: 0.6953


100%|██████████| 59/59 [00:07<00:00,  8.34it/s]


Epoch: 39/40 | Train Loss: 0.5316 | Train Acc: 0.8338 | Valid Loss: 2.0366 | Valid Acc: 0.6956


100%|██████████| 59/59 [00:07<00:00,  8.22it/s]


Epoch: 40/40 | Train Loss: 0.5223 | Train Acc: 0.8363 | Valid Loss: 2.0369 | Valid Acc: 0.6957


In [54]:
# Restore the weights that achieved the lowest validation loss during training.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on one backend (e.g. MPS) still loads on another (e.g. CPU).
model.load_state_dict(
    torch.load('best_model_checkpoint.pth', map_location=device, weights_only=True)
)
model.to(device)

val_loss, val_accuracy = evaluation(model, valid_dataloader, loss_function, device)

print(f'Best model validation loss: {val_loss:.4f}')
print(f'Best model validation accuracy: {val_accuracy:.4f}')

Best model validation loss: 2.0075
Best model validation accuracy: 0.6981


In [55]:
def seq_to_src(input_seq):
    """Turn a padded sequence of source indices back into text.

    Indices equal to 0 (padding) are skipped. Every remaining index is looked
    up in ``index2src`` and followed by one space, so a non-empty result ends
    with a trailing space.
    """
    words = [index2src[token_id] for token_id in input_seq if token_id != 0]
    return ''.join(word + ' ' for word in words)

def seq_to_tar(input_seq):
    """Turn a padded sequence of target indices back into text.

    Padding (index 0) and the ``<sos>`` / ``<eos>`` markers are skipped. Every
    remaining index is looked up in ``index2tar`` and followed by one space,
    so a non-empty result ends with a trailing space.
    """
    pieces = []
    for token_id in input_seq:
        if token_id == 0:
            continue
        if token_id == tgt_vocab['<sos>'] or token_id == tgt_vocab['<eos>']:
            continue
        pieces.append(index2tar[token_id] + ' ')
    return ''.join(pieces)

In [58]:
def decode_sequence(input_seq, model):
    """Greedily translate one encoded source sentence.

    Args:
        input_seq: 1-D sequence of source vocabulary indices (padding included).
        model: Model exposing ``make_src_mask``, ``encode``, ``make_tgt_mask``,
            ``make_src_tgt_mask``, ``decode`` and ``generator``.

    Returns:
        The predicted target words joined by single spaces, without the
        ``<sos>`` marker. Decoding stops at the first ``<eos>`` prediction or
        after ``max_len`` steps (notebook-level setting).
    """
    model.eval()

    sos_id = tgt_vocab['<sos>']
    eos_id = tgt_vocab['<eos>']
    decoded_tokens = [sos_id]

    # This is pure inference, so the encoder pass is kept inside no_grad as
    # well; otherwise an autograd graph would be built and held for nothing.
    with torch.no_grad():
        encoder_inputs = torch.LongTensor(input_seq).unsqueeze(0).to(device)
        src_mask = model.make_src_mask(encoder_inputs)
        encoder_out = model.encode(encoder_inputs, src_mask)

        for _ in range(max_len):
            # Re-run the decoder on everything generated so far.
            decoder_input = torch.LongTensor(decoded_tokens).unsqueeze(0).to(device)
            tgt_mask = model.make_tgt_mask(decoder_input)
            src_tgt_mask = model.make_src_tgt_mask(encoder_inputs, decoder_input)
            output = model.decode(decoder_input, encoder_out, tgt_mask, src_tgt_mask)
            output = model.generator(output)

            # Only the prediction at the last position is new.
            output_token = output.argmax(dim=2)[:, -1].item()

            if output_token == eos_id:
                break

            decoded_tokens.append(output_token)

    return ' '.join(index2tar[token] for token in decoded_tokens if token != sos_id)


In [59]:
for seq_index in [3, 50, 100, 300, 1001]:
    input_seq = encoder_input_train[seq_index]
    translated_text = decode_sequence(input_seq, model)

    print("입력문장: ", seq_to_src(encoder_input_train[seq_index]))
    print("정답문장: ", seq_to_tar(decoder_input_train[seq_index]))
    print("번역문장: ", translated_text)
    print("-"*50)

입력문장:  you weren t ready . 
정답문장:  vous n etiez pas pret . 
번역문장:  tu ne t aime pas pretes .
--------------------------------------------------
입력문장:  i saw you . 
정답문장:  je vous vis . 
번역문장:  je t ai vu .
--------------------------------------------------
입력문장:  i am a shy boy . 
정답문장:  je suis un garcon timide . 
번역문장:  je suis un garcon timide .
--------------------------------------------------
입력문장:  i loved you . 
정답문장:  je t aimais . 
번역문장:  je t aimais .
--------------------------------------------------
입력문장:  please help tom . 
정답문장:  aidez tom s il vous plait . 
번역문장:  aidez tom s il vous plait .
--------------------------------------------------
