Architecture Overview

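Encoder: Bidirectional LSTM that reads the source sequence in both directions, so each position's output carries context from both the left and the right.
Attention: Dot-product scores between the decoder hidden state and every encoder output, normalized with softmax. The decoder state must have size 2H to match the bidirectional encoder outputs.
Decoder: LSTM that combines its hidden state with the attention context to predict the next word.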

In [None]:
import torch

import torch.nn as nn 

import torch.nn.functional as F

Encoder with Bi-LSTM

In [None]:
class Encoder(nn.Module):
    """Embed token ids and encode them with a single-layer bidirectional LSTM.

    Args:
        vocab_size: Number of entries in the embedding table.
        embed_dim: Size E of each token embedding.
        hidden_dim: Hidden size H of the LSTM in each direction.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super(Encoder, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # Assign (not call) the LSTM; bidirectional=True doubles the output width to 2H.
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)

    def forward(self, x):
        """Encode a batch of token-id sequences.

        Args:
            x: Integer token ids of shape [B, T] (batch_first layout).

        Returns:
            outputs: Per-position encoder states of shape [B, T, 2H].
            (h_n, c_n): Final hidden and cell states, each of shape [2, B, H]
                (one slice per direction).
        """
        embedded = self.embedding(x)  # [B, T, E]

        outputs, (h_n, c_n) = self.lstm(embedded)  # outputs: [B, T, 2H]

        return outputs, (h_n, c_n)

Attention Mechanism

In [None]:
class Attention(nn.Module):
    """Parameter-free dot-product attention over encoder outputs.

    Args:
        hidden_dim: Stored on the module for reference; it is not used in the
            computation, which has no learnable weights.
    """

    def __init__(self, hidden_dim):
        super(Attention, self).__init__()

        self.hidden_dim = hidden_dim

    def forward(self, decoder_hidden, encoder_outputs):
        """Attend over ``encoder_outputs`` using ``decoder_hidden`` as the query.

        Args:
            decoder_hidden: Query tensor of shape [B, D]. D must equal the
                last dimension of ``encoder_outputs`` (2H for a bidirectional
                encoder). Pass a single tensor, not an LSTM ``(h, c)`` tuple.
            encoder_outputs: Keys/values of shape [B, T, D].

        Returns:
            context: Attention-weighted sum of encoder outputs, shape [B, D].
            attn_weights: Softmax-normalized weights over time, shape [B, T].
        """
        # [B, T, D] @ [B, D, 1] -> [B, T, 1] -> [B, T]
        scores = torch.bmm(encoder_outputs, decoder_hidden.unsqueeze(2)).squeeze(2)

        attn_weights = F.softmax(scores, dim=1)  # [B, T], rows sum to 1

        # [B, 1, T] @ [B, T, D] -> [B, 1, D] -> [B, D]
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)

        return context, attn_weights