In [None]:
import torch

from data import tokenize, batch_decode
from model import Transformer

# Special-token ids used to start and stop target-side decoding.
# NOTE(review): 101/102 look like BERT's [CLS]/[SEP] ids and 119547 looks like
# the multilingual BERT vocab size — confirm against data.tokenize.
sos_token_id = 101
eos_token_id = 102

# Passed to tokenize() as the source length and used as the decode-length
# budget below; also given to the Transformer constructor.
num_sequences = 24
h_dim = 768
vocab_size = 119547
checkpoint_path = "checkpoint_model.pt"

# Prefer the first GPU, but fall back to CPU so the script still runs on
# machines without CUDA instead of crashing at load time.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# weights_only=True restricts unpickling to tensors/primitive containers,
# which is the safe way to load a state_dict from disk.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
model = Transformer(h_dim, num_sequences, vocab_size).to(device)
model.load_state_dict(checkpoint)
# Switch to inference mode. Presumably Transformer is a torch.nn.Module (it
# supports .to/.load_state_dict); without eval(), any dropout/batch-norm
# layers would stay in training mode and make decoding nondeterministic.
model.eval()

input_text = "Thank you for gladly accepting the invitation."
# Reference translation, printed after the model output for comparison.
answer = "당신이 초청에 흔쾌히 응해 주셔서 감사합니다."

with torch.no_grad():
    # NOTE(review): assumes tokenize returns a batched LongTensor of source
    # ids (batch of 1 here) — confirm against data.tokenize.
    src_token = tokenize(input_text, num_sequences).to(device)
    # Decoding starts from a single SOS token: shape (batch=1, seq=1).
    tgt_token = torch.tensor([[sos_token_id]], dtype=torch.long, device=device)

    # Greedy decoding: append at most num_sequences - 1 tokens so the target
    # never grows beyond num_sequences tokens in total.
    for _ in range(num_sequences - 1):
        # Model output is indexed as (batch, seq, vocab); keep only the
        # distribution for the last position.
        logits = model(src_token, tgt_token)[:, -1, :]
        # argmax over raw logits equals argmax over their softmax, so the
        # softmax is skipped. keepdim gives (batch, 1), matching the
        # sequence axis used by the concat below.
        next_token = logits.argmax(dim=-1, keepdim=True)
        tgt_token = torch.cat([tgt_token, next_token], dim=-1)
        # Single-sequence batch: stop as soon as EOS is produced.
        if next_token.item() == eos_token_id:
            break

print(tgt_token)
# tgt_token was built under no_grad, so no detach() is needed before decoding.
output_texts = batch_decode(tgt_token)
print(output_texts)
print(answer)