[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/xiptos/is_notes/blob/main/attention_lstm.ipynb)

# Attention Mechanism

The Attention mechanism is inspired by how humans focus on different parts of an input when performing a task. In the context of neural networks, it allows the model to dynamically assign weights to different parts of the input sequence, highlighting the most relevant information.

The general steps of the Attention mechanism are:

- Calculate the alignment scores between the query (usually the current hidden state) and each element in the key sequence (usually the sequence of hidden states from an encoder).

- Apply a softmax function to the alignment scores to get the attention weights.

- Compute the weighted sum of the values (usually the same as the keys) using the attention weights. This weighted sum is the context vector.

Based on [https://www.codegenes.net/blog/attention-lstm-pytorch/](https://www.codegenes.net/blog/attention-lstm-pytorch/)

# Attention LSTM

An Attention LSTM combines an LSTM, which processes sequential data, with an Attention mechanism that focuses on the relevant parts of the sequence. In an Attention LSTM, the Attention mechanism is typically applied to the LSTM outputs, so the model can learn how important each time step is. In the model below, every LSTM output is scored and the scores are turned into weights with a softmax. The weighted sum of the outputs (the context vector) is then fed to a linear classifier.

# The model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionLSTM(nn.Module):
    """LSTM text classifier that pools timesteps with learned attention weights.

    Token ids are embedded and run through a unidirectional LSTM. Each
    per-timestep output is scored by a linear layer, and the outputs are
    pooled into one context vector using the softmax of those scores. A
    final linear layer maps the context vector to logits. Positions holding
    ``pad_idx`` are excluded from the attention softmax, so padding never
    contributes to the context vector.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: embedding size (LSTM input size).
        hidden_dim: LSTM hidden size.
        output_dim: number of output logits (e.g. number of classes).
        pad_idx: token id used for padding. It is forwarded to
            ``nn.Embedding(padding_idx=...)`` and masked out of the
            attention. ``None`` disables masking.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim, pad_idx):
        super().__init__()
        self.pad_idx = pad_idx  # kept so forward() can mask padding positions
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.attn = nn.Linear(hidden_dim, 1)   # one scalar score per timestep
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, return_attention=False):
        """Classify a batch of token-id sequences.

        Args:
            x: (batch, seq_len) tensor of token ids; padding uses ``pad_idx``.
            return_attention: if True, also return the attention weights.

        Returns:
            Logits of shape (batch, output_dim). If ``return_attention`` is
            True, returns the tuple (logits, weights), where the weights have
            shape (batch, seq_len), are 0 at padding positions, and sum to 1
            along each row.
        """
        emb = self.embedding(x)             # (B, T, E)
        lstm_out, _ = self.lstm(emb)        # (B, T, H)
        scores = self.attn(lstm_out)        # (B, T, 1)

        if self.pad_idx is not None:
            # Exclude padding from the softmax so it gets exactly zero weight.
            # A row made entirely of padding is left unmasked, because
            # softmax over all -inf would produce NaN.
            pad_mask = x == self.pad_idx                              # (B, T)
            pad_mask = pad_mask & ~pad_mask.all(dim=1, keepdim=True)
            scores = scores.masked_fill(pad_mask.unsqueeze(-1), float("-inf"))

        attn_weights = F.softmax(scores, dim=1)              # (B, T, 1)
        context = torch.sum(attn_weights * lstm_out, dim=1)  # (B, H)
        out = self.fc(context)                               # (B, output_dim)

        if return_attention:
            return out, attn_weights.squeeze(-1)     # (B, T)
        return out

In [None]:
from collections import Counter
from torch.utils.data import Dataset, DataLoader

sentences = [
    "i love cats",
    "i like dogs",
    "i hate homework",
    "homework is bad",
    "cats are great",
]

labels = [1, 1, 0, 0, 1]  # 1=positive, 0=negative (toy)

# --- build vocab ---
PAD = "<pad>"
UNK = "<unk>"


def build_vocab(words, specials=(PAD, UNK)):
    """Assign a unique integer id to every distinct word.

    Args:
        words: iterable of word strings; repeated words are ignored.
        specials: reserved tokens, given ids 0 .. len(specials) - 1 in order.

    Returns:
        dict mapping token -> id. Ordinary words get consecutive ids in order
        of first appearance. A word equal to an already-present token (e.g. a
        literal "<unk>" in the corpus) keeps its existing id instead of being
        re-assigned, so ids never collide.
    """
    mapping = {tok: i for i, tok in enumerate(specials)}
    for w in words:
        if w not in mapping:
            mapping[w] = len(mapping)
    return mapping


tokens = [w for s in sentences for w in s.split()]
freqs = Counter(tokens)   # keys iterate in order of first appearance
vocab = build_vocab(freqs)

pad_idx = vocab[PAD]

def encode(sent, max_len=None):
    """Turn a whitespace-separated sentence into a list of vocabulary ids.

    Words absent from ``vocab`` are mapped to the id of ``UNK``. If
    ``max_len`` is given, the ids are cut to at most ``max_len`` entries and
    then right-padded with ``pad_idx`` so the result has length ``max_len``.
    """
    token_ids = [vocab.get(word, vocab[UNK]) for word in sent.split()]
    if max_len is None:
        return token_ids
    clipped = token_ids[:max_len]
    # Multiplying a list by a non-positive count yields [], so no guard needed.
    return clipped + [pad_idx] * (max_len - len(clipped))

# Pad every sentence to the length of the longest one so rows stack into one tensor.
max_len = max(map(len, map(str.split, sentences)))
X = torch.tensor([encode(sentence, max_len) for sentence in sentences])   # (N, T)
y = torch.tensor(labels)                                                  # (N,)

class TextDataset(Dataset):
    """Map-style dataset that pairs each input row with its label.

    Args:
        X: indexable inputs (in this notebook, an (N, T) tensor of token ids).
        y: indexable labels aligned index-by-index with ``X``.
    """

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        # One example per input row.
        return len(self.X)

    def __getitem__(self, idx):
        # Return the (input, label) pair at position ``idx``.
        sample, target = self.X[idx], self.y[idx]
        return sample, target

# Wrap the encoded tensors and serve them as shuffled mini-batches of two.
dataset = TextDataset(X, y)
loader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)

# Training the model

In [None]:
# Prefer Apple MPS, then CUDA, and fall back to the CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available() and torch.backends.mps.is_built()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print("Using device:", device)

# Model size / hyper-parameters.
vocab_size = len(vocab)
embed_dim, hidden_dim = 16, 32
output_dim = 2  # two classes: 0=negative, 1=positive

model = AttentionLSTM(vocab_size, embed_dim, hidden_dim, output_dim, pad_idx).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(20):
    model.train()
    total_loss = 0.0
    for batch_x, batch_y in loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        # Standard step: clear grads, forward, loss, backprop, update.
        optimizer.zero_grad()
        logits = model(batch_x)                # (B, 2)
        loss = criterion(logits, batch_y)
        loss.backward()
        optimizer.step()

        # criterion returns a batch mean; scale it back up by the batch size
        # so the epoch figure is a true per-example average even when the
        # last batch is smaller.
        total_loss += loss.item() * batch_x.size(0)

    print(f"Epoch {epoch+1:02d} - loss: {total_loss/len(dataset):.4f}")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Index of the example to visualize. X rows and sentences share the same order.
example_idx = 0

model.eval()
with torch.no_grad():
    x_ex = X[example_idx].unsqueeze(0).to(device)       # (1, T)
    logits, attn = model(x_ex, return_attention=True)   # attn: (1, T)

attn = attn[0].cpu().numpy()             # (T,)

# encode() right-pads with pad_idx, so the real tokens are the prefix of
# non-pad ids. Counting them from the encoded row (not the raw sentence)
# keeps labels and weights aligned even if encode() truncated the sentence.
T = int((X[example_idx] != pad_idx).sum())
attn = attn[:T]                          # drop pad positions
# NOTE: if the model's attention does not mask padding, the kept weights may
# sum to less than 1.

# Separate name so this cell does not overwrite the vocab-building `tokens`.
words = sentences[example_idx].split()[:T]

attn_2d = attn[np.newaxis, :]           # (1, T)

plt.figure(figsize=(4, 1.5))
plt.imshow(attn_2d, cmap="viridis", aspect="auto")
plt.colorbar(label="Attention weight")
plt.yticks([])
plt.xticks(range(T), words, rotation=45, ha="right")
plt.title("Attention over words")
plt.tight_layout()
plt.show()