In [1]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from model import Encoder, Decoder, Seq2Seq
from data import ReverseDataset

OSError: [WinError 1114] Error en una rutina de inicialización de biblioteca de vínculos dinámicos (DLL). Error loading "C:\Users\ralme\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\lib\c10.dll" or one of its dependencies.

In [None]:
# Use the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Input files: the trained checkpoint and the data file the vocabularies are built from.
MODEL_PATH = "seq2seq_bahdanau.pt"
DATA_PATH = "data/reverse_dataset.txt"

# Architecture sizes. These must match the shapes stored in the checkpoint,
# otherwise load_state_dict (strict by default) rejects the weights.
EMB_DIM = 64
ENC_HID_DIM, DEC_HID_DIM = 128, 128
ATTN_DIM = 64

# Upper bound on the number of decoding steps at inference time.
MAX_LEN = 30


In [None]:
# ReverseDataset exposes the source/target vocabularies and special-token
# indices that the model and the decoding loop below rely on.
dataset = ReverseDataset(DATA_PATH)

# Special-token indices shared by model construction and decoding.
PAD_IDX, SOS_IDX, EOS_IDX = dataset.pad_idx, dataset.sos_idx, dataset.eos_idx

# Vocabulary size = number of entries in each index-to-string table.
SRC_VOCAB_SIZE = len(dataset.src_vocab.itos)
TRG_VOCAB_SIZE = len(dataset.trg_vocab.itos)

print("SRC vocab size:", SRC_VOCAB_SIZE)
print("TRG vocab size:", TRG_VOCAB_SIZE)

In [None]:
# Rebuild the architecture; every size here must match the checkpoint's tensors.
encoder = Encoder(
    pad_idx=PAD_IDX,
    vocab_size=SRC_VOCAB_SIZE,
    emb_dim=EMB_DIM,
    hid_dim=ENC_HID_DIM,
    bidir=True,
)

# The decoder is told the encoder output width is 2 * ENC_HID_DIM, consistent
# with bidir=True (forward + backward states) — confirm against model.py.
decoder = Decoder(
    pad_idx=PAD_IDX,
    vocab_size=TRG_VOCAB_SIZE,
    emb_dim=EMB_DIM,
    attn_dim=ATTN_DIM,
    enc_hid_dim=ENC_HID_DIM * 2,
    dec_hid_dim=DEC_HID_DIM,
)

model = Seq2Seq(
    device=DEVICE,
    encoder=encoder,
    decoder=decoder,
    sos_idx=SOS_IDX,
    eos_idx=EOS_IDX,
).to(DEVICE)

# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
# If the checkpoint holds only tensors and plain containers, passing
# weights_only=True is safer on PyTorch versions where it is not the default.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # put submodules (e.g. dropout) into evaluation behavior

print("Modelo cargado correctamente.")


In [None]:
@torch.no_grad()
def translate_with_attention(sentence, max_len=None):
    """Greedily decode ``sentence`` with the loaded model, collecting attention weights.

    Args:
        sentence: Whitespace-separated source text. It is split on whitespace
            and each token is mapped through ``dataset.src_vocab.stoi``;
            unknown tokens map to the ``"<unk>"`` index.
        max_len: Maximum number of decoding steps. ``None`` (default) uses the
            notebook-level ``MAX_LEN``, read at call time.

    Returns:
        A tuple ``(outputs, attention, tokens)``:
        - ``outputs``: predicted target tokens (the terminating ``<eos>`` is
          not included).
        - ``attention``: NumPy array with one row per entry of ``outputs``,
          stacked from the decoder's per-step attention weights. If nothing is
          generated it is an empty array of shape ``(0, len(tokens))``, so it
          is always 2-D.
        - ``tokens``: the source tokens that were encoded.

    Raises:
        ValueError: if ``sentence`` contains no tokens (an empty sequence
            cannot be encoded).
    """
    if max_len is None:
        max_len = MAX_LEN

    tokens = sentence.strip().split()
    if not tokens:
        raise ValueError("sentence must contain at least one token")

    stoi = dataset.src_vocab.stoi
    unk_idx = stoi["<unk>"]  # resolved once rather than once per token
    src_ids = [stoi.get(tok, unk_idx) for tok in tokens]

    src_tensor = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0).to(DEVICE)  # batch of one
    # NOTE(review): if Encoder feeds src_len to pack_padded_sequence without
    # calling .cpu(), a CUDA lengths tensor raises at runtime — confirm in model.py.
    src_len = torch.tensor([len(src_ids)], dtype=torch.long).to(DEVICE)
    # No padding is added here, so this is all ones unless an input token
    # itself maps to PAD_IDX; passed for parity with batched decoding.
    src_mask = (src_tensor != PAD_IDX).int()

    enc_outputs, (h, c) = model.encoder(src_tensor, src_len)
    hidden = model._init_decoder_hidden(h, c)

    y = torch.tensor([SOS_IDX], device=DEVICE)
    outputs = []
    attentions = []

    for _ in range(max_len):
        logits, hidden, alpha = model.decoder(y, hidden, enc_outputs, mask=src_mask)
        next_token = logits.argmax(dim=-1).item()

        if next_token == EOS_IDX:
            break

        outputs.append(dataset.trg_vocab.itos[next_token])
        # assumes alpha is (1, src_len) so squeeze(0) yields one row — TODO confirm vs Decoder
        attentions.append(alpha.squeeze(0).cpu().numpy())

        # Feed the greedy prediction back in as the next decoder input.
        y = torch.tensor([next_token], device=DEVICE)

    if not attentions:
        # Keep the result 2-D so plotting/indexing code does not break.
        return outputs, np.empty((0, len(src_ids))), tokens
    return outputs, np.array(attentions), tokens


In [None]:
# Run one example sentence through the model and keep the pieces the
# heatmap needs: predicted tokens, attention rows, and the source tokens.
sentence = "i train models data"
(
    output_tokens,
    attention_matrix,
    input_tokens,
) = translate_with_attention(sentence)

print("Input :", sentence)
print("Output:", " ".join(output_tokens))


In [None]:
# One heatmap row per generated token; columns are labeled with the source
# tokens. Each row holds the decoder's attention weights for that step.
if len(output_tokens) == 0:
    # sns.heatmap cannot draw an empty matrix; report instead of crashing.
    print("The model produced no output tokens; there is no attention to plot.")
else:
    fig, ax = plt.subplots(figsize=(8, 5))
    sns.heatmap(
        attention_matrix,
        xticklabels=input_tokens,
        yticklabels=output_tokens,
        cmap="viridis",
        ax=ax,
    )
    ax.set(
        xlabel="Input tokens",
        ylabel="Output tokens",
        title="Attention weights (Bahdanau)",
    )
    fig.tight_layout()
    plt.show()
