## RNN with Attention Model

In [None]:
import torch
import torch.nn as nn

In [2]:
# Toy corpus. Each sentence is later split on whitespace; every word except
# the last becomes model input and the last word is the prediction target.
sample_text = [
    "the animal didn't cross the street because it was too tired",
    "the cat sat on the mat",
]

In [7]:
# Vocabulary and word <-> index lookup tables.
#
# Index 0 is reserved for a dedicated padding token: pad_sequences fills
# short sequences with zeros, so without this entry the padding would be
# indistinguishable from whichever real word happened to get index 0.
#
# The words are sorted so the mapping is reproducible from run to run;
# iterating a bare set of strings yields an order that depends on string
# hashing and can differ between interpreter sessions.
vocab = ['<pad>'] + sorted(set(' '.join(sample_text).split()))
word_to_idx = {word: idx for idx, word in enumerate(vocab)}
ix_to_word = {idx: word for word, idx in word_to_idx.items()}

In [15]:
# Next-word prediction data: for every tokenised sentence the model sees all
# words but the last (as indices) and must predict the index of the last one.
pairs = list(map(str.split, sample_text))
input_data = [
    [word_to_idx[token] for token in tokens[:-1]]
    for tokens in pairs
]
target_data = [word_to_idx[tokens[-1]] for tokens in pairs]

# Inputs stay a list of 1-D tensors because the sentences differ in length;
# the targets are one index per sentence and fit in a single tensor.
targets = torch.tensor(target_data, dtype=torch.long)
inputs = [torch.tensor(indices, dtype=torch.long) for indices in input_data]

In [33]:
# Model hyper-parameters: embedding width, RNN hidden width, and the number of
# output classes (one per vocabulary entry).
embedding_dim, hidden_dim = 10, 16
vocab_size = len(vocab)

class RNNWithAttentionModel(nn.Module):
    """Predict a word from a sequence: embedding -> RNN -> attention pooling -> linear.

    Args:
        vocab_size: number of distinct token indices; also the size of the
            output logit vector.
        embedding_dim: width of the token embeddings.
        hidden_dim: width of the RNN hidden state.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()

        # Embedding layer translates word indexes to vectors.
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)

        # RNN layer for sequential processing (inputs are batch-first).
        self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)

        # Attention layer computes word significances: a linear map from
        # hidden_dim to 1 yields a single attention score per time step.
        self.attention = nn.Linear(hidden_dim, 1)

        # Final layer maps the pooled context vector to one logit per
        # vocabulary entry.
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, mask=None):
        """Compute vocabulary logits for a batch of index sequences.

        Args:
            x: integer tensor of token indices, shape (batch, seq_len).
            mask: optional tensor of shape (batch, seq_len); truthy entries
                mark real tokens, falsy entries mark padding that must get
                zero attention weight. With the default ``None`` every
                position takes part in the softmax. Each row needs at least
                one truthy entry, otherwise its softmax is undefined (NaN).

        Returns:
            Tensor of shape (batch, vocab_size) holding unnormalised logits.
        """
        # Word indexes are embedded.
        embedded = self.embeddings(x)

        # The RNN produces one hidden vector per time step.
        out, _ = self.rnn(embedded)

        # One raw score per time step; squeeze(2) drops the trailing size-1
        # dimension left by the linear layer.
        scores = self.attention(out).squeeze(2)

        if mask is not None:
            # -inf scores become exactly zero after softmax, so padded steps
            # cannot contribute to the context vector.
            scores = scores.masked_fill(~mask.bool(), float("-inf"))

        # Normalise the scores over the time dimension.
        att_weights = torch.nn.functional.softmax(scores, dim=1)

        # Context vector: attention-weighted sum of the RNN outputs.
        # unsqueeze(2) lets the weights broadcast across the hidden dimension.
        context = torch.sum(att_weights.unsqueeze(2) * out, dim=1)

        return self.fc(context)

def pad_sequences(batch, pad_value=0):
    """Right-pad 1-D index tensors to a common length and stack them.

    Each sequence is extended with ``pad_value`` up to the length of the
    longest sequence in ``batch``. The padding is created as a long tensor on
    the same device as the sequence it extends, so the concatenation also
    works for sequences that do not live on the default device.

    Args:
        batch: non-empty iterable of 1-D tensors (it is traversed twice).
        pad_value: integer used to fill the padded positions (default 0).

    Returns:
        Tensor of shape (len(batch), max_len).

    Raises:
        ValueError: if ``batch`` is empty.
    """
    max_len = max(len(seq) for seq in batch)
    return torch.stack([
        torch.cat([
            seq,
            torch.full(
                (max_len - len(seq),),
                pad_value,
                dtype=torch.long,
                device=seq.device,
            ),
        ])
        for seq in batch
    ])

In [34]:
# Build the network first, then the objective and the optimiser that updates
# the network's parameters.
model = RNNWithAttentionModel(vocab_size, embedding_dim, hidden_dim)
loss_finction = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

In [37]:
# training
# The padded batch never changes between epochs, so build it once instead of
# re-padding on every iteration; likewise switch to training mode once.
padded_inputs = pad_sequences(inputs)
model.train()
for epoch in range(300):
    # Full-batch gradient step: clear old gradients, forward, backward, update.
    optimizer.zero_grad()
    output = model(padded_inputs)
    loss = loss_finction(output, targets)
    loss.backward()
    optimizer.step()

    # Report progress every 100 epochs.
    if (epoch % 100) == 0:
        print(f'epoch {epoch}, loss: {loss.item()}')

epoch 0, loss: 0.0005866951541975141
epoch 100, loss: 0.0004257845284882933
epoch 200, loss: 0.00032556717633269727


In [40]:
# testing
# Evaluation mode is set once, and autograd is disabled because no gradients
# are needed for prediction.
model.eval()
with torch.no_grad():
    for seq, target in zip(input_data, target_data):
        # unsqueeze(0) adds the batch dimension the model expects.
        input_test = torch.tensor(seq, dtype=torch.long).unsqueeze(0)
        output = model(input_test)
        # The batch holds a single row, so the flat argmax is the word index.
        predictions = ix_to_word[torch.argmax(output).item()]
        print(f'\nInput: {" ".join([ix_to_word[i] for i in seq])}')
        print(f'\nTarget: {ix_to_word[target]}')
        print(f'\nOutput: {predictions}')


Input: the animal didn't cross the street because it was too

Target: tired

Output: tired

Input: the cat sat on the

Target: mat

Output: mat
