In [19]:
import torch

import torch.nn as nn

import torch.nn.functional as F

In [20]:
class BahdanauAttention(nn.Module):

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(BahdanauAttention,self).__init__()

        self.encoder_attn = nn.Linear(encoder_dim, attention_dim)

        self.decoder_attn = nn.Linear(decoder_dim, attention_dim)

        self.score_layer = nn.Linear(attention_dim, 1)

    
    def forward(self, encoder_outputs, decoder_hidden):

        # encoder_outputs: [batch, seq_len, enc_dim]
            
        # decoder_hidden: [batch, dec_dim]

        seq_len = encoder_outputs.size(1)

        # Repeat decoder hidden to seq_len

        dec_hidden_exp = decoder_hidden.unsqueeze(1).repeat(1, seq_len, 1)


        # Score calculation

        energy = torch.tanh(self.encoder_attn(encoder_outputs) + self.decoder_attn(dec_hidden_exp)) # [B, T, attn_dim]

        score  = self.score_layer(energy).squeeze(2) # [B, T]


        attn_weights = F.softmax(score, dim=1)  # [B, T]

        context_vector = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)  # [B, enc_dim]

        return context_vector, attn_weights

1. Encoder LSTM

In [21]:
class Encoder(nn.Module):

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super(Encoder, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)

        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)

    def forward(self, x):

        embed = self.embedding(x) # [B, T, E]

        outputs, (h,c) = self.lstm(embed)  # outputs: [B, T, H]

        return outputs, (h, c)

2. Decoder with Bahdanau Attention

In [22]:
class Decoder(nn.Module):

    def __init__(self, vocab_size, embed_dim, enc_hidden_dim, dec_hidden_dim, attn_dim):
        super(Decoder, self).__init__()

        self.embedding =  nn.Embedding(vocab_size, embed_dim)

        self.attention = BahdanauAttention(enc_hidden_dim, dec_hidden_dim, attn_dim)

        self.lstm = nn.LSTMCell(embed_dim + enc_hidden_dim , dec_hidden_dim)

        self.fc = nn.Linear(dec_hidden_dim, vocab_size)

    
    def forward(self, input_token, hidden,cell, encoder_outputs):

        embed = self.embedding(input_token)  # [B, E]

        context , attn_weights = self.attention(encoder_outputs, hidden) # [B, H], [B, T]

        lstm_input = torch.cat([embed, context], dim=1) # [B, E+H]

        hidden, cell = self.lstm(lstm_input, (hidden, cell))  # [B, H], [B, H]

        output = self.fc(hidden)

        return output, hidden, cell, attn_weights

3. Seq2Seq Wrapper

In [23]:
class Seq2Seq(nn.Module):

    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()

        self.encoder = encoder

        self.decoder = decoder

        self.device = device

    
    def forward(self, src, trg, teacher_forcing_ratio = 0.5):

        batch_size, trg_len = trg.size()

        vocab_size = self.decoder.fc.out_features


        outputs = torch.zeros(batch_size, trg_len, vocab_size).to(self.device)

        encoder_outputs , (h,c) = self.encoder(src)


        input_token = trg[:, 0] # <sos>


        for t in range(1, trg_len):

            output, h, c ,  _ = self.decoder(input_token, h[0], c[0], encoder_outputs)

            outputs[:, t] = output

            top1 = output.argmax(1)

            input_token = trg[:, t] if torch.rand(1).item() < teacher_forcing_ratio else top1

        return outputs

Final Step:  Testing with example

In [24]:
# Hyperparams

vocab_size = 10

embed_dim = 16

hidden_dim = 32

attn_dim = 32

device  = torch.device("mps" if torch.backends.mps.is_available() else 'cpu')

print(f'Using Device :{device}')


# Instantiate

encoder = Encoder(vocab_size, embed_dim, hidden_dim).to(device)

decoder = Decoder(vocab_size, embed_dim, hidden_dim, hidden_dim, attn_dim).to(device)

model = Seq2Seq(encoder, decoder, device).to(device)


# Dummy input

src = torch.randint(0, vocab_size, (2,5)).to(device)

trg = torch.randint(0, vocab_size, (2,7)).to(device)


# Run forward

output = model(src, trg)

print("Output shape:", output.shape)  # [batch, trg_len, vocab_size]


Using Device :mps


RuntimeError: linear(): input and weight.T shapes cannot be multiplied (160x1 and 32x32)