In [None]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

# Toy parallel corpus: (English source, pseudo-French target) sentence pairs.
pairs = [
    ("i love you",  "je t'aime"),
    ("hello",       "salut"),
    ("how are you", "comment ça va"),
    ("thank you",   "merci"),
    ("yes",         "oui"),
    ("no",          "non"),
]

# Special tokens shared by both vocabularies: start/end-of-sequence markers and
# the padding symbol used when batching sentences of different lengths.
SOS, EOS, PAD = "<sos>", "<eos>", "<pad>"

def tokenize(sentence):
    """Lower-case ``sentence`` and split it on runs of whitespace."""
    lowered = sentence.lower()
    return lowered.split()

def build_vocab(sentences):
    """Map every distinct token in ``sentences`` to a contiguous integer id.

    Ids 0-3 are reserved for the special tokens ``<pad>``, ``<sos>``, ``<eos>``
    and ``<unk>``; corpus words follow in order of first appearance.

    Args:
        sentences: iterable of raw sentences (tokenized with ``tokenize``).

    Returns:
        dict mapping token -> id.
    """
    # "<unk>" gives inference a valid id for words never seen in training;
    # translate_sentence looks it up via src_vocab["<unk>"].
    vocab = {PAD: 0, SOS: 1, EOS: 2, "<unk>": 3}
    for sent in sentences:
        for word in tokenize(sent):
            if word not in vocab:
                # Ids stay contiguous, so the next free id is the current size.
                vocab[word] = len(vocab)
    return vocab

# Build one vocabulary per language, plus id -> token reverse maps for decoding.
src_sentences = [english for english, _french in pairs]
tgt_sentences = [french for _english, french in pairs]

src_vocab, tgt_vocab = build_vocab(src_sentences), build_vocab(tgt_sentences)

inv_src_vocab = dict(zip(src_vocab.values(), src_vocab.keys()))
inv_tgt_vocab = dict(zip(tgt_vocab.values(), tgt_vocab.keys()))

# Sentence -> id tensor
def encode(sentence, vocab):
    """Convert ``sentence`` into a 1-D LongTensor of token ids looked up in ``vocab``.

    Raises KeyError for any token missing from ``vocab``.
    """
    ids = []
    for token in tokenize(sentence):
        ids.append(vocab[token])
    return torch.tensor(ids, dtype=torch.long)

# Map-style dataset over the sentence pairs
class Seq2SeqDataset(Dataset):
    """Dataset yielding ``(src_ids, tgt_ids)`` LongTensors for each sentence pair.

    Sources are encoded with the module-level ``src_vocab``. Targets are encoded
    with ``tgt_vocab`` and framed as ``<sos> w_1 ... w_n <eos>``.
    """

    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        source_text, target_text = self.pairs[idx]
        src_ids = encode(source_text, src_vocab)

        start = torch.tensor([tgt_vocab[SOS]], dtype=torch.long)
        end = torch.tensor([tgt_vocab[EOS]], dtype=torch.long)
        tgt_ids = torch.cat([start, encode(target_text, tgt_vocab), end])
        return src_ids, tgt_ids

# Collate function: pad variable-length sequences within a batch
def collate_fn(batch):
    """Turn a list of ``(src, tgt)`` id tensors into two right-padded [B, T_max] tensors."""
    sources, targets = zip(*batch)
    padded_src = pad_sequence(sources, batch_first=True, padding_value=src_vocab[PAD])
    padded_tgt = pad_sequence(targets, batch_first=True, padding_value=tgt_vocab[PAD])
    return padded_src, padded_tgt

# Wrap the pairs in a Dataset and batch them; collate_fn pads each batch.
dataset = Seq2SeqDataset(pairs)
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=collate_fn,
)

print("Sample vocab:")
print(f"SRC: {src_vocab}")
print(f"TGT: {tgt_vocab}")


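Step 2: Model architecture. An LSTM encoder reads the source sentence. At every output step, an LSTM decoder attends over the encoder outputs and predicts the next target token:

[Encoder LSTM] ---> [Attention + Decoder LSTM] ---> Target sequence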


🔧 Step 2.1: Encoder

In [None]:
class Encoder(nn.Module):
    """Embed a batch of source token ids and run them through a single-layer LSTM.

    Args:
        input_dim: source vocabulary size (rows of the embedding table).
        emb_dim: embedding width E.
        hidden_dim: LSTM hidden width H.
    """

    def __init__(self, input_dim, emb_dim, hidden_dim):
        super().__init__()
        # Registration order (embedding, then LSTM) fixes parameter-init order
        # and state_dict layout; keep it stable.
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, batch_first=True)

    def forward(self, src):
        """Encode ``src`` of shape [B, T].

        Returns:
            (outputs [B, T, H], hidden [1, B, H], cell [1, B, H]): the LSTM output
            at every time step plus the final state, which initialises the decoder.
        """
        states, (h_n, c_n) = self.lstm(self.embedding(src))
        return states, h_n, c_n


Step 2.2: Attention

In [None]:
class Attention(nn.Module):
    """Additive (concat) attention of a decoder state over encoder outputs.

    score_t = v^T tanh(W [s ; h_t]), normalised with a softmax over source positions.

    Args:
        hidden_dim: width H shared by the decoder state and the encoder outputs.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.attn = nn.Linear(hidden_dim * 2, hidden_dim)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs, mask=None):
        """Compute attention weights for one decoder step.

        Args:
            hidden: decoder state [num_layers, B, H]; only the top layer
                (``hidden[-1]``) is used as the query.
            encoder_outputs: [B, T, H].
            mask: optional tensor [B, T], truthy for real source tokens and falsy
                for padding. Masked positions receive exactly zero weight. A row
                with no valid position falls back to uniform weights (not NaN).

        Returns:
            weights [B, T], each row summing to 1.
        """
        _, seq_len, _ = encoder_outputs.shape
        # Broadcast the query across all source positions (a view, no copy).
        query = hidden[-1].unsqueeze(1).expand(-1, seq_len, -1)  # [B, T, H]

        energy = torch.tanh(self.attn(torch.cat((query, encoder_outputs), dim=2)))  # [B, T, H]
        scores = self.v(energy).squeeze(2)  # [B, T]
        if mask is not None:
            # Most-negative finite value instead of -inf so a fully masked row
            # yields uniform weights rather than NaN.
            scores = scores.masked_fill(mask.logical_not(), torch.finfo(scores.dtype).min)
        return torch.softmax(scores, dim=1)


Step 2.3: Decoder with Attention

In [None]:
class Decoder(nn.Module):
    """Single-step LSTM decoder that attends over the encoder outputs.

    Each call consumes the previous target token and returns logits for the next
    one together with the updated LSTM state.

    Args:
        output_dim: target vocabulary size V (embedding rows and logits width).
        emb_dim: token embedding width E.
        hidden_dim: LSTM hidden width H. Must equal the encoder's hidden width,
            because the attention context is a weighted sum of encoder outputs.
    """

    def __init__(self, output_dim, emb_dim, hidden_dim):
        super().__init__()
        # Keep this registration order: it determines parameter initialisation
        # order and state_dict layout.
        self.embedding = nn.Embedding(output_dim, emb_dim)
        # Per-step input = [token embedding ; attention context] -> E + H features.
        self.lstm = nn.LSTM(hidden_dim + emb_dim, hidden_dim, batch_first=True)
        # Logits come from [LSTM output ; context] -> 2H features.
        self.fc_out = nn.Linear(hidden_dim * 2, output_dim)
        self.attention = Attention(hidden_dim)

    def forward(self, input_token, hidden, cell, encoder_outputs):
        """Decode one step.

        Args:
            input_token: [B] ids of the previous target token.
            hidden, cell: LSTM state, each [1, B, H].
            encoder_outputs: [B, T, H].

        Returns:
            (logits [B, V], hidden, cell); the new state feeds the next step.
        """
        token_emb = self.embedding(input_token.unsqueeze(1))  # [B, 1, E]

        # The attention query is the state *before* this step's LSTM update.
        weights = self.attention(hidden, encoder_outputs).unsqueeze(1)  # [B, 1, T]
        context = torch.bmm(weights, encoder_outputs)  # [B, 1, H]

        lstm_in = torch.cat((token_emb, context), dim=2)  # [B, 1, E+H]
        step_out, (hidden, cell) = self.lstm(lstm_in, (hidden, cell))  # [B, 1, H]

        features = torch.cat((step_out, context), dim=2).squeeze(1)  # [B, 2H]
        return self.fc_out(features), hidden, cell


Step 2.4: Seq2Seq Wrapper

In [None]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that unrolls the decoder over the target length.

    Args:
        encoder: module mapping src [B, S] -> (outputs, hidden, cell).
        decoder: module called as ``decoder(token, hidden, cell, encoder_outputs)
            -> (logits, hidden, cell)``. It must expose a ``fc_out`` Linear layer
            whose ``out_features`` is the target vocabulary size.
        device: device on which the output buffer is allocated.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, tgt, teacher_forcing_ratio=0.5):
        """Run the full encode/decode pass.

        Args:
            src: source ids [B, S].
            tgt: target ids [B, T]; column 0 must hold the <sos> token.
            teacher_forcing_ratio: probability, drawn independently at every step,
                of feeding the ground-truth token instead of the model's own
                greedy prediction.

        Returns:
            logits [B, T, V]. Position 0 (the <sos> slot) is never predicted and
            stays zero.
        """
        batch_size, tgt_len = tgt.shape
        vocab_size = self.decoder.fc_out.out_features
        # Allocate directly on the target device rather than on the host followed
        # by a copy on every forward pass.
        outputs = torch.zeros(batch_size, tgt_len, vocab_size, device=self.device)

        encoder_outputs, hidden, cell = self.encoder(src)

        input_token = tgt[:, 0]  # <sos>
        for t in range(1, tgt_len):
            logits, hidden, cell = self.decoder(input_token, hidden, cell, encoder_outputs)
            outputs[:, t] = logits

            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            input_token = tgt[:, t] if teacher_force else logits.argmax(1)

        return outputs


Step 3: Training the Seq2Seq Model

In [None]:
import torch.nn.functional as F

def train(model, data_loader, optimizer, loss_fn, device, clip=None):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: seq2seq model called as ``model(src, tgt)`` returning logits
            [B, T, vocab_size]. Position 0 is skipped: <sos> is an input, not a
            prediction.
        data_loader: iterable of ``(src, tgt)`` id batches.
        optimizer: optimizer over ``model.parameters()``.
        loss_fn: criterion taking ``(logits [N, V], targets [N])``.
        device: device the batches are moved to.
        clip: optional max global gradient norm. When given, gradients are
            clipped before each optimizer step, which helps keep LSTM training
            stable.

    Returns:
        Mean of the batch losses, or 0.0 if ``data_loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0

    for src, tgt in data_loader:
        src, tgt = src.to(device), tgt.to(device)
        optimizer.zero_grad()

        logits = model(src, tgt)  # [B, T, vocab_size]
        vocab_size = logits.shape[-1]

        # Drop position 0 (<sos>) and flatten to [B*(T-1), V] / [B*(T-1)].
        loss = loss_fn(logits[:, 1:].reshape(-1, vocab_size), tgt[:, 1:].reshape(-1))
        loss.backward()
        if clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    # Guard against an empty loader instead of dividing by zero.
    return total_loss / n_batches if n_batches else 0.0


Step 3.2: Evaluation / Inference

In [None]:
def translate_sentence(model, sentence, src_vocab, tgt_vocab, max_len=10, device='cpu'):
    """Greedily decode a translation of ``sentence`` with a trained seq2seq model.

    Args:
        model: object exposing ``encoder(src) -> (outputs, hidden, cell)`` and
            ``decoder(token, hidden, cell, outputs) -> (logits, hidden, cell)``.
        sentence: list of source tokens, or a raw string. Tokens are lower-cased,
            and strings are split on whitespace, to match how the vocabularies
            were built.
        src_vocab: source token -> id map. Unknown words map to its ``"<unk>"``
            entry if present; otherwise they are dropped.
        tgt_vocab: target token -> id map containing ``"<sos>"`` and ``"<eos>"``.
        max_len: maximum number of target tokens to generate.
        device: device to run on.

    Returns:
        List of generated target tokens, excluding the terminating ``"<eos>"``.

    Raises:
        ValueError: if no token of ``sentence`` can be encoded (nothing to feed
            the encoder).
    """
    if isinstance(sentence, str):
        sentence = sentence.split()

    # Resolve <unk> lazily: src_vocab may not contain it, and an eager
    # src_vocab["<unk>"] lookup would raise KeyError even for known words.
    unk_id = src_vocab.get("<unk>")
    token_ids = []
    for tok in sentence:
        tok = tok.lower()
        if tok in src_vocab:
            token_ids.append(src_vocab[tok])
        elif unk_id is not None:
            token_ids.append(unk_id)
    if not token_ids:
        raise ValueError(f"no encodable tokens in {sentence!r}")

    model.eval()
    eos_id = tgt_vocab["<eos>"]
    id_to_word = {idx: word for word, idx in tgt_vocab.items()}
    generated = []

    with torch.no_grad():
        src_tensor = torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(0)  # [1, S]
        encoder_outputs, hidden, cell = model.encoder(src_tensor)

        input_token = torch.tensor([tgt_vocab["<sos>"]], dtype=torch.long, device=device)  # [1]
        for _ in range(max_len):
            output, hidden, cell = model.decoder(input_token, hidden, cell, encoder_outputs)
            top1 = output.argmax(1)  # greedy choice, [1]
            next_id = top1.item()
            if next_id == eos_id:
                break
            generated.append(next_id)
            input_token = top1

    return [id_to_word.get(idx, "<unk>") for idx in generated]


Step 3.3: Example Setup for Training

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG so Restart & Run All is reproducible. This covers
# weight initialisation, the teacher-forcing coin flips (torch.rand) and, since
# no generator is passed to the DataLoader, its shuffling.
SEED = 42
torch.manual_seed(SEED)

INPUT_DIM = len(src_vocab)   # source vocabulary size (encoder embedding rows)
OUTPUT_DIM = len(tgt_vocab)  # target vocabulary size (decoder logits per step)
EMB_DIM = 64
HID_DIM = 128
LEARNING_RATE = 0.001

enc = Encoder(INPUT_DIM, EMB_DIM, HID_DIM)
dec = Decoder(OUTPUT_DIM, EMB_DIM, HID_DIM)

model = Seq2Seq(enc, dec, device).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Padded target positions (filled by collate_fn) must not contribute to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=tgt_vocab[PAD])


🌀 Loop Over Epochs

In [None]:
N_EPOCHS = 10

# train() returns the mean per-batch loss over `dataloader` for one epoch.
for epoch in range(1, N_EPOCHS + 1):
    loss = train(model, dataloader, optimizer, loss_fn, device)
    print(f"Epoch {epoch}, Loss: {loss:.4f}")


🔍 Try Inference

In [None]:
# Note: "am" and "learning" never occur in the toy training corpus, so they are
# out-of-vocabulary and the translation will be poor. Try a training sentence
# (e.g. ["i", "love", "you"]) to see what the model has actually learned.
test_sentence = ["i", "am", "learning"]
output_words = translate_sentence(model, test_sentence, src_vocab, tgt_vocab, device=device)
print("Output:", " ".join(output_words))
