Architecture Overview

Encoder: Bi-directional LSTM to get rich context from both directions.
Attention: Dot-product between decoder hidden state and all encoder outputs.
Decoder: LSTM + Attention context → predicts next word.

In [1]:
import torch

import torch.nn as nn 

import torch.nn.functional as F

Encoder with Bi-LSTM

In [2]:
class Encoder(nn.Module):
    """Embed source token ids and encode them with a bidirectional LSTM.

    Args:
        vocab_size: Number of entries in the embedding table.
        embed_dim: Size of each token embedding (E).
        hidden_dim: LSTM hidden size per direction (H).
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super(Encoder, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # Must be an attribute assignment so nn.Module registers the LSTM as
        # a submodule; calling ``self.lstm(...)`` here would fail because the
        # attribute does not exist yet.
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)

    def forward(self, x):
        """Encode a batch of token ids.

        Args:
            x: Integer tensor of token ids, shape [B, T].

        Returns:
            A tuple ``(outputs, (h_n, c_n))`` where ``outputs`` is [B, T, 2H]
            (forward and backward features concatenated) and ``h_n`` / ``c_n``
            are the final LSTM states, one slice per direction along dim 0.
        """
        embedded = self.embedding(x)  # [B, T, E]

        outputs, (h_n, c_n) = self.lstm(embedded)  # outputs: [B, T, 2H]

        return outputs, (h_n, c_n)

Attention Mechanism

In [3]:
class Attention(nn.Module):
    """Parameter-free dot-product attention over encoder outputs.

    Args:
        hidden_dim: Stored on the module for reference; the dot-product
            scoring itself has no learnable parameters.
    """

    def __init__(self, hidden_dim):
        super(Attention, self).__init__()

        self.hidden_dim = hidden_dim

    def forward(self, decoder_hidden, encoder_outputs):
        """Compute an attention context vector for one decoding step.

        Args:
            decoder_hidden: Query tensor of shape [B, D].
            encoder_outputs: Encoder features of shape [B, T, D]. The last
                dimension must equal the query's feature size, otherwise the
                batched matrix product below raises.

        Returns:
            A tuple ``(context, attn_weights)``: ``context`` is the weighted
            sum of encoder outputs, shape [B, D]; ``attn_weights`` are the
            softmax-normalised scores over source positions, shape [B, T].
        """
        # [B, T, D] x [B, D, 1] -> [B, T, 1] -> [B, T]
        scores = torch.bmm(encoder_outputs, decoder_hidden.unsqueeze(2)).squeeze(2)

        # Normalise over the source-time axis so each row sums to 1.
        attn_weights = F.softmax(scores, dim=1)  # [B, T]

        # [B, 1, T] x [B, T, D] -> [B, 1, D] -> [B, D]
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)

        return context, attn_weights

Decoder with Attention

In [4]:
class Decoder(nn.Module):
    """Single-step LSTM decoder that attends over bidirectional encoder outputs.

    Args:
        vocab_size: Size of the target vocabulary (embedding rows and number
            of output logits).
        embed_dim: Size of each target token embedding (E).
        enc_hidden_dim: Encoder LSTM hidden size per direction; the encoder
            outputs are expected to have ``2 * enc_hidden_dim`` features.
        dec_hidden_dim: Hidden size of the decoder LSTM (H).
    """

    def __init__(self, vocab_size, embed_dim, enc_hidden_dim, dec_hidden_dim):
        super(Decoder, self).__init__()

        enc_output_dim = enc_hidden_dim * 2  # forward + backward features

        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # The LSTM consumes the token embedding concatenated with the context.
        self.lstm = nn.LSTM(embed_dim + enc_output_dim, dec_hidden_dim, batch_first=True)

        # The decoder LSTM is unidirectional, so logits come from H features.
        self.fc = nn.Linear(dec_hidden_dim, vocab_size)

        self.attention = Attention(dec_hidden_dim)

        # Dot-product attention needs the query and the encoder outputs to
        # share a feature size. When they differ, project the query; when
        # they already match, use a parameter-free identity so no extra
        # weights are introduced.
        if dec_hidden_dim == enc_output_dim:
            self.query_proj = nn.Identity()
        else:
            self.query_proj = nn.Linear(dec_hidden_dim, enc_output_dim, bias=False)

    def forward(self, input_token, decoder_hidden, encoder_outputs):
        """Run one decoding step.

        Args:
            input_token: Previous target token ids, shape [B].
            decoder_hidden: LSTM state tuple ``(h, c)``, each [1, B, H].
            encoder_outputs: Encoder features, shape [B, T, 2 * enc_hidden_dim].

        Returns:
            A tuple ``(logits, decoder_hidden, attn_weights)``: ``logits`` is
            [B, vocab_size], ``decoder_hidden`` is the updated LSTM state and
            ``attn_weights`` is [B, T].
        """
        embedded = self.embedding(input_token).unsqueeze(1)  # [B, 1, E]

        # Use the current hidden state (not the cell state) as the query.
        hidden_state = decoder_hidden[0].squeeze(0)  # [B, H]
        query = self.query_proj(hidden_state)  # [B, 2 * enc_hidden_dim]

        context, attn_weights = self.attention(query, encoder_outputs)

        # Concatenate embedding and context along the feature axis.
        combined = torch.cat((embedded, context.unsqueeze(1)), dim=2)

        output, decoder_hidden = self.lstm(combined, decoder_hidden)  # [B, 1, H]

        logits = self.fc(output.squeeze(1))  # [B, vocab_size]

        return logits, decoder_hidden, attn_weights

Seq2Seq with Attention

In [5]:
class Seq2Seq(nn.Module):
    """Tie an encoder and an attention decoder into one trainable module.

    Args:
        encoder: Module called as ``encoder(src)`` that returns
            ``(outputs, (h_n, c_n))``.
        decoder: Module called as ``decoder(token, hidden, encoder_outputs)``
            that returns ``(logits, hidden, attn_weights)`` and exposes
            ``fc.out_features`` as the vocabulary size.
        device: Device on which the output buffer is placed.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder
        self.device = device

    def forward(self, src, tgt, teacher_forcing_ratio = 0.5):
        """Decode ``tgt`` step by step, optionally with teacher forcing.

        Args:
            src: Source batch handed to the encoder.
            tgt: Target token ids, shape [B, L]; column 0 is the start token.
            teacher_forcing_ratio: Probability, drawn once per step, of
                feeding the ground-truth token instead of the model's own
                prediction as the next input.

        Returns:
            Tensor of shape [B, L, vocab_size] holding the logits of every
            step; position 0 is never written and stays zero.
        """
        n_batch, n_steps = tgt.shape
        n_classes = self.decoder.fc.out_features

        logits_buf = torch.zeros(n_batch, n_steps, n_classes).to(self.device)

        memory, enc_state = self.encoder(src)

        # Keep only the first slice of the final encoder (h, c) as the
        # decoder's initial state.
        state = tuple(part[:1] for part in (enc_state[0], enc_state[1]))

        prev_token = tgt[:, 0]
        for step in range(1, n_steps):
            step_logits, state, _ = self.decoder(prev_token, state, memory)
            logits_buf[:, step, :] = step_logits

            predicted = step_logits.argmax(1)
            use_teacher = torch.rand(1).item() < teacher_forcing_ratio
            if use_teacher:
                prev_token = tgt[:, step]
            else:
                prev_token = predicted

        return logits_buf

Summary

This model builds a powerful attention-based decoder.
It anticipates the idea that Transformers later generalized: querying the encoder outputs at every decoding timestep.
Great for machine translation, summarization, and any sequence generation.