In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

# ======================
# 1. Data toy example
# ======================
# Assume we have a handful of very short (input_text, summary) pairs.
# Each tuple is (source sentence, reference summary); this tiny set is the
# entire training corpus for the toy model below.
pairs = [
    ("the cat sat on the mat", "cat on mat"),
    ("dogs are playing in the park", "dogs playing"),
    ("a man is eating food", "man eating"),
    ("a woman is reading a book", "woman reading"),
]

# Build a simple vocabulary
from collections import Counter

# Count every whitespace-separated token across both the source sentences
# and their summaries.
word_counts = Counter(
    word for pair in pairs for text in pair for word in text.split()
)

# Create the vocabulary: the four special tokens take indices 0-3, followed by
# the corpus words in sorted order.
vocab = ["<PAD>", "<SOS>", "<EOS>", "<UNK>", *sorted(word_counts)]
word2idx = {word: index for index, word in enumerate(vocab)}
idx2word = {index: word for word, index in word2idx.items()}

def encode(sentence):
    """Convert a whitespace-separated sentence into a list of vocabulary ids.

    Words missing from ``word2idx`` are mapped to the ``<UNK>`` id.
    """
    unk_id = word2idx["<UNK>"]
    return [word2idx.get(token, unk_id) for token in sentence.split()]

def decode(indices):
    """Turn vocabulary ids back into a space-joined string.

    The ``<PAD>``, ``<SOS>`` and ``<EOS>`` ids are dropped; every other id is
    looked up in ``idx2word``.
    """
    # Built once per call instead of once per element.
    special_ids = tuple(word2idx[tok] for tok in ("<PAD>", "<SOS>", "<EOS>"))
    return " ".join(idx2word[i] for i in indices if i not in special_ids)

# Encoded training examples: (source ids, <SOS> + summary ids + <EOS>).
data = [
    (encode(src), [word2idx["<SOS>"], *encode(tgt), word2idx["<EOS>"]])
    for src, tgt in pairs
]

# Pad function
def pad(seq, max_len):
    """Return a new list: ``seq`` right-padded with ``<PAD>`` ids to ``max_len``.

    A sequence already at or beyond ``max_len`` is returned without padding
    and without truncation (the filler list is empty in that case).
    """
    filler = [word2idx["<PAD>"]] * (max_len - len(seq))
    return seq + filler

# ======================
# 2. Model components
# ======================
class Encoder(nn.Module):
    """Embed source token ids and summarise them with a batch-first LSTM."""

    def __init__(self, vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embed_size
        )
        self.lstm = nn.LSTM(
            input_size=embed_size, hidden_size=hidden_size, batch_first=True
        )

    def forward(self, x):
        """Return the LSTM's final ``(hidden, cell)`` state for id batch ``x``.

        The per-step LSTM outputs are discarded; only the final state is
        handed on.
        """
        _, final_state = self.lstm(self.embedding(x))
        hidden, cell = final_state
        return hidden, cell

class Decoder(nn.Module):
    """Single-step decoder: embed one token per example, advance the LSTM,
    and project the result to vocabulary logits."""

    def __init__(self, vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embed_size
        )
        self.lstm = nn.LSTM(
            input_size=embed_size, hidden_size=hidden_size, batch_first=True
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, x, h, c):
        """Run one decoding step.

        ``x`` holds one token id per batch element; a length-1 time axis is
        inserted before the LSTM and removed again before the projection.
        Returns ``(logits, h, c)`` where ``h``/``c`` are the updated state.
        """
        step_input = self.embedding(x.unsqueeze(1))
        step_output, state = self.lstm(step_input, (h, c))
        h, c = state
        logits = self.fc(step_output.squeeze(1))
        return logits, h, c

class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes the target one step at a time.

    The encoder's final ``(h, c)`` state seeds the decoder, which is then
    called once per target position with optional teacher forcing.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        """Return per-step vocabulary logits for ``tgt``.

        Args:
            src: Source token ids, passed unchanged to the encoder.
            tgt: 2-D tensor of target token ids, ``(batch_size, tgt_len)``.
                Column 0 is used as the first decoder input (the ``<SOS>``
                token in this file's data).
            teacher_forcing_ratio: Probability, drawn once per step for the
                whole batch, of feeding the ground-truth token rather than
                the decoder's own argmax prediction as the next input.

        Returns:
            Tensor of shape ``(batch_size, tgt_len, vocab_size)``. Position 0
            is never predicted and stays all zeros.
        """
        batch_size, tgt_len = tgt.shape
        # Take the output width from the decoder's projection layer instead of
        # a module-level vocabulary, so the model does not depend on globals
        # and always matches the logits the decoder actually produces.
        vocab_size = self.decoder.fc.out_features

        outputs = torch.zeros(
            batch_size, tgt_len, vocab_size, device=self.device
        )
        h, c = self.encoder(src)

        input_token = tgt[:, 0]  # <SOS>
        for t in range(1, tgt_len):
            output, h, c = self.decoder(input_token, h, c)
            outputs[:, t] = output
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            top1 = output.argmax(1)
            input_token = tgt[:, t] if teacher_force else top1

        return outputs

# ======================
# 3. Training
# ======================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Source and summary share a single vocabulary, so both sizes are identical.
INPUT_DIM = OUTPUT_DIM = len(vocab)
EMB_DIM, HID_DIM = 32, 64

enc = Encoder(INPUT_DIM, EMB_DIM, HID_DIM).to(device)
dec = Decoder(OUTPUT_DIM, EMB_DIM, HID_DIM).to(device)
model = Seq2Seq(enc, dec, device).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.01)
# Padded target positions do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=word2idx["<PAD>"])

for epoch in range(100):
    epoch_loss = 0
    for src, tgt in data:
        # Wrapping the padded ids in a list yields a batch of size 1.
        src_tensor = torch.tensor(
            [pad(src, max_len=6)], dtype=torch.long, device=device
        )
        tgt_tensor = torch.tensor(
            [pad(tgt, max_len=6)], dtype=torch.long, device=device
        )

        optimizer.zero_grad()
        output = model(src_tensor, tgt_tensor)

        # Step 0 is only the decoder's start input and is never predicted, so
        # drop it from both logits and labels before flattening for the loss.
        step_logits = output[:, 1:].reshape(-1, output.size(-1))
        step_labels = tgt_tensor[:, 1:].reshape(-1)

        loss = criterion(step_logits, step_labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # Report the summed loss every 20th epoch.
    if epoch % 20 == 19:
        print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")

# ======================
# 4. Inference
# ======================
def summarize(sentence, max_len=6):
    """Greedily decode a summary for ``sentence`` with the global ``model``.

    The sentence is encoded and padded to ``max_len``; the decoder then emits
    at most ``max_len`` tokens, stopping early at ``<EOS>``. The model is
    switched to eval mode and gradients are disabled while decoding.

    Returns the generated tokens as a space-joined string.
    """
    generated = []

    model.eval()
    with torch.no_grad():
        src_ids = pad(encode(sentence), max_len)
        src = torch.tensor([src_ids], dtype=torch.long, device=device)
        h, c = model.encoder(src)

        input_token = torch.tensor(
            [word2idx["<SOS>"]], dtype=torch.long, device=device
        )
        for _ in range(max_len):
            logits, h, c = model.decoder(input_token, h, c)
            # The most likely token becomes the next decoder input.
            input_token = logits.argmax(1)
            token_id = input_token.item()
            if token_id == word2idx["<EOS>"]:
                break
            generated.append(token_id)
    return decode(generated)

# Run the trained model on two of the training sentences.
demo_sentences = (
    "the cat sat on the mat",
    "dogs are playing in the park",
)

print("\n--- Demo inference ---")
for sentence in demo_sentences:
    print(f"Input: {sentence}")
    print("Summary:", summarize(sentence))

Epoch 20, Loss: 0.0233
Epoch 40, Loss: 0.0095
Epoch 60, Loss: 0.0054
Epoch 80, Loss: 0.0035
Epoch 100, Loss: 0.0025

--- Demo inference ---
Input: the cat sat on the mat
Summary: cat on mat
Input: dogs are playing in the park
Summary: dogs playing
