### Attention Mechanisms
In this notebook, we will sequentially implement different variants of attention mechanisms. These variants will build on each other, with the goal of finally creating a compact, efficient implementation of an attention mechanism, which we can then plug into our LLM architecture.

**Simple Self-Attention**: Introduces the broader idea behind attention.

**Self-Attention**: Adds trainable weights, forming the basis of the mechanisms used in LLMs.

**Causal Attention**: A self-attention variant that lets a model consider only previous and current inputs in a sequence, preserving temporal order during text generation.

**Multi-head Attention**: An extension of self-attention and causal attention that enables the model to simultaneously attend to information from different representation subspaces.

#### Why Attention?
In machine translation, translating word by word is not enough. The translation process requires contextual understanding and grammatical alignment.

- "Kannst du mir helfen diesen Satz zu uebersetzen" should not be translated to "Can you me help this sentence to translate", but rather to "Can you help me translate this sentence".
- Certain words require access to words appearing before or later in the original sentence. For instance, the verb "to translate" should be used in the context of "this sentence", and not independently.

Typically, to overcome this challenge, deep neural networks with two submodules are used:

- **encoder**: reads in and processes the entire text (already done in the `preprocessing.ipynb` notebook).

- **decoder**: produces the translated text.

Pre-LLM architectures typically involved recurrent neural networks (RNNs), a type of neural network where outputs from previous steps are fed as inputs to the current step, making them well suited for sequential data. In this many-to-one RNN architecture, the input text is fed token by token into the encoder, which processes it sequentially. The final state of the encoder, known as the hidden state, acts as a memory cell that encodes the entire input. This hidden state is then fed to a decoder, which generates the translated sentence one word at a time.

- While the encoder is many-to-one, the decoder is a one-to-many architecture, since the hidden state is passed at every step of the decoding process.

**Encoder-decoder RNNs had shortcomings that motivated the design of attention mechanisms**, namely that it was not possible to access earlier hidden states from the encoder during the decoding phase, since the decoder relies on a single hidden state containing all the relevant information. Context was lost, especially in complex sentences where dependencies span long distances.
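
The bottleneck described above can be made concrete with a small sketch. The sizes, token ids, and the helper name `encode` below are illustrative assumptions, not part of the notebook's pipeline. The point is only that a plain GRU encoder hands the decoder a final hidden state of the same fixed size whether the input has 3 tokens or 30, even though it computed one output per input token along the way.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumptions, not the notebook's real vocabulary).
vocab_size, hidden_size = 100, 8
embedding = torch.nn.Embedding(vocab_size, hidden_size)
gru = torch.nn.GRU(hidden_size, hidden_size)

def encode(token_ids):
    """Hypothetical helper: run a list of token ids through the GRU encoder."""
    # (seq_len,) -> (seq_len, batch=1, hidden_size)
    embedded = embedding(torch.tensor(token_ids)).unsqueeze(1)
    # When no initial hidden state is passed, nn.GRU starts from zeros.
    outputs, hidden = gru(embedded)
    return outputs, hidden

short_outputs, short_hidden = encode([5, 17, 42])
long_outputs, long_hidden = encode(list(range(30)))

print(short_outputs.shape, short_hidden.shape)  # torch.Size([3, 1, 8]) torch.Size([1, 1, 8])
print(long_outputs.shape, long_hidden.shape)    # torch.Size([30, 1, 8]) torch.Size([1, 1, 8])
```

The per-token `outputs` are exactly what a plain encoder-decoder discards: only the final `hidden` tensor is passed on, so a 30-token sentence has to fit into the same 8 numbers as a 3-token one.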

### *Non-Modified RNN Encoders*




In [1]:
import torch
import torch.autograd
from torch import optim
import torch.nn.functional as F

class EncoderRNN(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1, dropout = None):
        super(EncoderRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        if dropout:
            self.dropout = torch.nn.Dropout(dropout)
        self.embedding = torch.nn.Embedding(input_size, hidden_size)
        self.gru = torch.nn.GRU(hidden_size, hidden_size, num_layers)

    def forward(self, word_inputs, hidden):
        seq_len = len(word_inputs)
        embedded = self.embedding(word_inputs).view(seq_len, 1, -1)
        output, hidden = self.gru(embedded, hidden)
        return output, hidden

    def init_hidden(self):
        # torch.autograd.Variable is deprecated; plain tensors track gradients directly.
        hidden = torch.zeros(self.num_layers, 1, self.hidden_size)
        return hidden

The class `EncoderRNN` above shows how a non-modified encoder RNN works. The encoder receives the sequence of input tokens together with an initial hidden state; the GRU then processes the embedded tokens one at a time, updating the hidden state at every step. The relevant building blocks are:

- `torch.nn.Embedding`: As described before, this is simply a lookup table that stores the embeddings of a fixed-size dictionary. The input to this module is a list of indices, and the output is the corresponding word embeddings.

    - Note that the module has a learnable weight tensor of shape (num_embeddings, embedding_dim). By default, PyTorch initializes it from $\mathcal{N}(0, 1)$.


- `torch.nn.GRU`: Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence. More specifically, this means that, for each element in an input, the following are computed at each layer:
$$
\begin{align*}
r_t &= \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
z_t &= \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
n_t &= \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)} + b_{hn})) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{(t-1)}
\end{align*}
$$

    - where $h_{t}$ is the hidden state at time $t$, $x_{t}$ is the input at time $t$, $h_{(t-1)}$ is the hidden state at time $t-1$ (or the initial hidden state at time 0), and $r_{t}$, $z_{t}$, and $n_{t}$ are the reset, update, and new gates, respectively. $\sigma$ is the sigmoid function and $\odot$ is the Hadamard product (or element-wise product).
        - the reset gate determines how much of the previous hidden state $h_{(t-1)}$ to forget when computing the candidate activation vector $n_{t}$, while the update gate determines how the new hidden state $h_{t}$ is blended: $(1 - z_{t})$ weights the candidate $n_{t}$ and $z_{t}$ weights the previous hidden state.
        - in a multi-layer GRU, the input $x_{t}^{(l)}$ of the $l$-th layer ($l \ge 2$) is the hidden state $h_{t}^{(l-1)}$ of the previous layer multiplied by a dropout mask $d_{t}^{(l-1)}$, where each entry is a Bernoulli random variable that is 0 with probability `dropout`.
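
Both building blocks can be checked with a short sketch. The sizes and the helper name `manual_gru_step` are assumptions made for illustration. The first part shows that an embedding lookup is just row indexing into the weight matrix. The second writes one GRU step out from the four equations above and compares it with `torch.nn.GRUCell`, which stores the three gate matrices stacked row-wise in the order (r, z, n).

```python
import torch

torch.manual_seed(0)

# --- Embedding: a lookup table indexed by token id ---
vocab_size, embedding_dim = 10, 4  # illustrative sizes
embedding = torch.nn.Embedding(vocab_size, embedding_dim)
token_ids = torch.tensor([1, 5, 1])
embedded = embedding(token_ids)  # shape: (3, 4)
# Each output row is the matching row of the weight matrix.
same_as_indexing = torch.equal(embedded, embedding.weight[token_ids])

# --- GRU: one time step written out from the equations ---
input_size, hidden_size = 4, 3
cell = torch.nn.GRUCell(input_size, hidden_size)

def manual_gru_step(x_t, h_prev, cell):
    """Hypothetical helper: recompute a GRUCell step by hand."""
    # Split the stacked parameters into reset (r), update (z), and new (n) parts.
    W_ir, W_iz, W_in = cell.weight_ih.chunk(3, dim=0)
    W_hr, W_hz, W_hn = cell.weight_hh.chunk(3, dim=0)
    b_ir, b_iz, b_in = cell.bias_ih.chunk(3)
    b_hr, b_hz, b_hn = cell.bias_hh.chunk(3)
    r_t = torch.sigmoid(x_t @ W_ir.T + b_ir + h_prev @ W_hr.T + b_hr)
    z_t = torch.sigmoid(x_t @ W_iz.T + b_iz + h_prev @ W_hz.T + b_hz)
    n_t = torch.tanh(x_t @ W_in.T + b_in + r_t * (h_prev @ W_hn.T + b_hn))
    return (1 - z_t) * n_t + z_t * h_prev

x_t = embedded[:1]                    # one embedded token, shape (1, 4)
h_prev = torch.zeros(1, hidden_size)  # initial hidden state
with torch.no_grad():
    h_manual = manual_gru_step(x_t, h_prev, cell)
    h_torch = cell(x_t, h_prev)

print(same_as_indexing)
print(torch.allclose(h_manual, h_torch, atol=1e-6))
```

`torch.nn.GRU` applies this same step repeatedly over the sequence (and over layers), so agreement on a single step is enough to confirm the reading of the equations.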

### *The Bahdanau Attention Mechanism*
As a result of the major shortcomings of traditional RNN encoder-decoders, researchers developed the **Bahdanau attention mechanism** for RNNs in 2014. In this modification, the decoder can selectively access different parts of the input sequence at each decoding step.

- When generating an output token, the model can access all input tokens. Each input token is assigned a weight that measures how important it is for the output token being generated.

- The new module computes an 'energy' score for each encoder output given the current decoder hidden state. For the $i$-th encoder output $h_{i}$ and the decoder hidden state $s_{t-1}$, the score is computed as $\text{score}(s_{t-1}, h_{i}) = V_{a}^{\top} \tanh(W_{a} s_{t-1} + U_{a} h_{i})$. These scores are then normalized with a softmax to produce attention weights.
    - This module is then used by the `AttnDecoderRNN`, which embeds the input token, applies dropout, and uses the attention module to compute a context vector (the attention-weighted sum of the encoder outputs). The context vector and the embedded input are then concatenated and fed into a GRU, whose output is projected to log-probabilities over the target dictionary.
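
The score, softmax, and context-vector steps above can be traced by hand on a toy example. This is a minimal sketch in plain Python (no tensors) so that each term of the formula is visible. The matrices, vectors, and helper names are made-up illustrative values, not trained parameters.

```python
import math

def matvec(M, v):
    """Multiply matrix M (list of rows) by vector v."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]

def bahdanau_score(s_prev, h_i, W_a, U_a, V_a):
    """score = V_a^T tanh(W_a s_prev + U_a h_i), for one encoder output h_i."""
    pre = [a + b for a, b in zip(matvec(W_a, s_prev), matvec(U_a, h_i))]
    return sum(v * math.tanh(p) for v, p in zip(V_a, pre))

def softmax(scores):
    """Normalize scores into weights that sum to 1."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Toy parameters with hidden_size = 2 (assumed values for illustration).
W_a = [[1.0, 0.0], [0.0, 1.0]]
U_a = [[1.0, 0.0], [0.0, 1.0]]
V_a = [1.0, -1.0]

s_prev = [0.5, -0.5]                                       # decoder hidden state
encoder_outputs = [[0.5, -0.5], [-0.5, 0.5], [0.0, 0.0]]   # h_1, h_2, h_3

scores = [bahdanau_score(s_prev, h, W_a, U_a, V_a) for h in encoder_outputs]
weights = softmax(scores)
# Context vector: attention-weighted sum of the encoder outputs.
context = [sum(w * h[d] for w, h in zip(weights, encoder_outputs))
           for d in range(len(s_prev))]

print(scores)
print(weights)
print(context)
```

With these values the first encoder output gets the highest score and therefore the largest attention weight, so the context vector is pulled toward it.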

In [3]:
class BadhanauAttention(torch.nn.Module):
    def __init__(self, hidden_size):
        super(BadhanauAttention, self).__init__()
        self.hidden_size = hidden_size
        # linear layers to transform the decoder hidden state and encoder outputs.
        self.Wa = torch.nn.Linear(hidden_size, hidden_size, bias=False)
        self.Ua = torch.nn.Linear(hidden_size, hidden_size, bias=False)
        # to produce scalar score, a parameter vector
        self.Va = torch.nn.Linear(hidden_size, 1, bias=False)

    def forward(self, decoder_hidden, encoder_outputs):
        """
        decoder_hidden: (1, batch_size, hidden_size) - current decoder hidden state.
        encoder_outputs: (seq_len, batch_size, hidden_size) - all encoder outputs
        """
        # squeeze time dimension from the decoder hidden state (batch_size, hidden_size)
        decoder_hidden = decoder_hidden.squeeze(0)
        encoder_outputs = encoder_outputs.transpose(0, 1)
        # expand decoder states to (batch_size, seq_len, hidden_size)
        decoder_hidden_expanded = decoder_hidden.unsqueeze(1).expand_as(encoder_outputs)
        nrg = torch.tanh(self.Wa(decoder_hidden_expanded) + self.Ua(encoder_outputs))
        attention_scores = self.Va(nrg).squeeze(2)
        # normalize the scores to probs
        attention_weights = F.softmax(attention_scores, dim=1)
        #compute context vector as weighted sum of encoder outputs:
        # (batch_size, 1, seq_len) x (batch_size, seq_len, hidden_size) --> (batch_size, 1, hidden_size)
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
        context = context.squeeze(1)
        return context, attention_weights

class AttnDecoderRNN(torch.nn.Module):
    def __init__(self, hidden_size, output_size, dropout=None):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.embedding = torch.nn.Embedding(output_size, hidden_size)
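        # Dropout(None) raises a TypeError, so fall back to p=0.0 (no dropout) when unset.
        self.dropout = torch.nn.Dropout(dropout if dropout is not None else 0.0)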
        self.attention = BadhanauAttention(hidden_size)
        #GRU now takes the concatenated [embedded; context] vector as input.
        self.gru = torch.nn.GRU(hidden_size * 2, hidden_size)
        self.out = torch.nn.Linear(hidden_size, output_size)

    def forward(self, input_token, hidden, encoder_outputs):
        """
        input_token: (batch,) token indices for the current decoder input.
        hidden: (1, batch, hidden_size) current decoder hidden state.
        encoder_outputs: (seq_len, batch, hidden_size) outputs from the encoder.
        """
        # Get embedding of current token and apply dropout
        embedded = self.embedding(input_token).unsqueeze(0)  # shape: (1, batch, hidden_size)
        embedded = self.dropout(embedded)
        # Compute the context vector using attention
        context, attn_weights = self.attention(hidden, encoder_outputs)
        # Prepare context for concatenation: (1, batch, hidden_size)
        context = context.unsqueeze(0)
        # Concatenate embedded input and context vector
        rnn_input = torch.cat((embedded, context), dim=2)  # shape: (1, batch, 2*hidden_size)
        # Pass through the GRU
        output, hidden = self.gru(rnn_input, hidden)
        # Prepare output for final linear layer
        output = output.squeeze(0)  # shape: (batch, hidden_size)
        output = self.out(output)   # shape: (batch, output_size)
        # Return log probabilities and attention weights
        output = F.log_softmax(output, dim=1)
        return output, hidden, attn_weights