import torch
import torch.nn as nn
import torch.nn.functional as F

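class Attention(nn.Module):

    def __init__(self, d_model=2,
                 row_dim=0,
                 col_dim=1):
        ## d_model = the number of values in each token's encoding.
        ## row_dim and col_dim = the tensor dimensions that index tokens (rows)
        ## and encoding values (columns); 0 and 1 suit unbatched 2-D inputs.

        super().__init__()

        ## One weight matrix each for creating the queries, keys, and values
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        self.row_dim = row_dim
        self.col_dim = col_dim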


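    ## The only change from SelfAttention to Attention is that
    ## now we expect 3 sets of encodings to be passed in...
    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        ## ...and we pass those sets of encodings to the various weight matrices.
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)

        ## Similarities between every query (row) and every key (column)
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        ## Scale by the square root of the number of values in each key
        scaled_sims = sims / torch.tensor(k.size(self.col_dim)**0.5)

        ## Optionally block query-key pairs (e.g. future tokens or padding);
        ## True entries in mask are replaced with a very large negative value
        if mask is not None:
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)

        ## Turn each row of scaled similarities into percentages that sum to 1
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)

        ## Weighted sum of the values, using those percentages
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores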

## create matrices of token encodings
encodings_for_q = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

encodings_for_k = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

encodings_for_v = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

## set the seed for the random number generator
torch.manual_seed(42)

## create an attention object
attention = Attention(d_model=2,
                      row_dim=0,
                      col_dim=1)

## calculate encoder-decoder attention
attention(encodings_for_q, encodings_for_k, encodings_for_v)

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)
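For reference, the `forward` method above computes standard scaled dot-product attention. Writing $X_q, X_k, X_v$ (notation introduced here) for the three encoding matrices, with one row per token, and $W_q, W_k, W_v$ for the weights of the three `nn.Linear` layers, which compute $xW^{\top}$:

$$
Q = X_q W_q^{\top}, \qquad K = X_k W_k^{\top}, \qquad V = X_v W_v^{\top}
$$

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
$$

Here $d_k$ is `k.size(self.col_dim)`, i.e. `d_model` = 2 in this notebook, and the softmax is applied to each row (`dim=self.col_dim`). When a `mask` is passed, the masked entries of $QK^{\top}/\sqrt{d_k}$ are set to $-10^9$ before the softmax, so they receive (numerically) zero weight.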

## Big Disclaimer

In the standard self-attention mechanism, `Q`, `K`, and `V` are all derived from the same input. For example, in the `SelfAttention` class from `1-simplest_selfattn.ipynb`, the `forward` method takes a single input and generates `Q`, `K`, and `V` from it using linear transformations. That is self-attention: the same encoding source $\text{(1)}$ is used for all three.

$$
\text{encoding source} = \text{input text} \rightarrow \text{embeddings} \rightarrow \text{positional encodings} \qquad \text{(1)}
$$

But in the `Attention` class of this notebook, the `forward` method accepts three separate encodings (`encodings_for_q`, `encodings_for_k`, and `encodings_for_v`) because it is meant for encoder-decoder attention, <span style="color:lightgreen; font-weight:bold; font-size:1em;">where the queries come from the decoder and the keys/values come from the encoder.</span>

For simplicity, this notebook passes tensors with identical values for all three encodings. In practice, the encoder processes the input sequence, and its output is projected by `W_k` and `W_v` into `K` and `V`, while the decoder's own representations are projected by `W_q` into `Q`. So the three encodings would come from different sources.

The usual, simplified version of my steps in $\text{(1)}$ results in a single set of embeddings with positional information. In encoder-decoder attention, however, the decoder's `Q` is based on its own representations (in the original Transformer, the output of its masked self-attention layer), while `K` and `V` come from the encoder's output. The three separate inputs therefore let the `Attention` class handle cases where `Q`, `K`, and `V` come from different sequences, possibly of different lengths, unlike self-attention, where they all come from the same one.
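To see how the query and key/value inputs can differ in length, here is a minimal sketch assuming a hypothetical decoder that has produced 2 tokens and an encoder output for 3 source tokens. Random tensors stand in for real encodings, and the notebook's `Attention` class is repeated in condensed form so the cell runs on its own. Each query row gets one similarity per key, so the similarity matrix is 2 x 3, while the output keeps one row per decoder token.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    """Condensed copy of the notebook's Attention class."""
    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.row_dim, self.col_dim = row_dim, col_dim

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
        scaled_sims = sims / torch.tensor(k.size(self.col_dim) ** 0.5)
        if mask is not None:
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
        return torch.matmul(attention_percents, v)

torch.manual_seed(42)
cross_attention = Attention(d_model=2)

# Hypothetical stand-ins for real encoder/decoder outputs
encoder_output = torch.randn(3, 2)   # 3 source tokens
decoder_states = torch.randn(2, 2)   # 2 target tokens so far

out = cross_attention(decoder_states, encoder_output, encoder_output)

# Recompute the intermediate steps by hand with the same weights
q = cross_attention.W_q(decoder_states)      # (2, 2): one query per decoder token
k = cross_attention.W_k(encoder_output)      # (3, 2): one key per encoder token
v = cross_attention.W_v(encoder_output)      # (3, 2): one value per encoder token
sims = q @ k.T                               # (2, 3): queries x keys
weights = F.softmax(sims / k.size(1) ** 0.5, dim=1)

print(sims.shape)  # torch.Size([2, 3])
print(out.shape)   # torch.Size([2, 2])
```

Note that the number of output rows follows the queries (the decoder), not the keys and values (the encoder).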

So the key point is that this `Attention` class is designed to support cross-attention between encoder and decoder; self-attention is simply the special case where all three arguments are the same tensor. The three separate encoding parameters allow different sources for `Q`, `K`, and `V`, which is essential for tasks like translation, where the decoder needs to attend to the encoder's output.

---

### So, in summary

The Attention class accepts three separate encodings to handle different transformer architecture scenarios:

#### 1. Self-Attention (single sequence):

- All three encodings (those used for `Q`, `K`, and `V`) come from the same source
- Example: encoder self-attention, where all three are the same positionally encoded embeddings

```python
attention(same_encodings, same_encodings, same_encodings)
```
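The `row_dim` and `col_dim` arguments also let the same self-attention call run on a batch of sequences. Below is a minimal sketch, assuming inputs shaped (batch, tokens, `d_model`), so tokens sit on dimension 1 and encoding values on dimension 2. It checks that the batched result matches running each sequence on its own with the same weights. The class is a condensed copy of the notebook's `Attention`, and the random inputs are placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    """Condensed copy of the notebook's Attention class."""
    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.row_dim, self.col_dim = row_dim, col_dim

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
        scaled_sims = sims / torch.tensor(k.size(self.col_dim) ** 0.5)
        if mask is not None:
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
        return torch.matmul(attention_percents, v)

torch.manual_seed(0)
batched = Attention(d_model=2, row_dim=1, col_dim=2)  # tokens on dim 1, values on dim 2
single = Attention(d_model=2, row_dim=0, col_dim=1)   # the notebook's 2-D setting
single.load_state_dict(batched.state_dict())          # give both the same weights

x = torch.randn(4, 3, 2)  # placeholder batch: 4 sequences, 3 tokens, d_model = 2

# Self-attention: the same tensor is passed for Q, K, and V
out_batched = batched(x, x, x)
out_single = torch.stack([single(x[i], x[i], x[i]) for i in range(x.size(0))])

print(out_batched.shape)  # torch.Size([4, 3, 2])
```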

#### 2. Cross-Attention (encoder-decoder):

- `Q` comes from the decoder's embeddings
- `K/V` come from the encoder's final output

```python
attention(decoder_embeddings, encoder_outputs, encoder_outputs)
```
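In real cross-attention, the encoder output often contains padding. One way to handle it with this class, sketched here under the assumption that the last of 3 source positions is padding, is a boolean mask with one row per decoder token and one column per encoder token, where `True` means "do not attend". The check confirms that overwriting the padded encoder row with arbitrary values leaves the output unchanged. `pad_mask` and the random tensors are hypothetical, and the class is a condensed copy of the notebook's.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    """Condensed copy of the notebook's Attention class."""
    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.row_dim, self.col_dim = row_dim, col_dim

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
        scaled_sims = sims / torch.tensor(k.size(self.col_dim) ** 0.5)
        if mask is not None:
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
        return torch.matmul(attention_percents, v)

torch.manual_seed(42)
cross_attention = Attention(d_model=2)

decoder_states = torch.randn(2, 2)   # 2 target tokens (placeholder)
encoder_output = torch.randn(3, 2)   # 3 source positions; assume the last is padding

# Shape (decoder tokens, encoder tokens); True = "do not attend to this key"
pad_mask = torch.tensor([False, False, True]).expand(2, 3)

out = cross_attention(decoder_states, encoder_output, encoder_output, mask=pad_mask)

# Replace the padded encoder row with junk: the result should not change
junk = encoder_output.clone()
junk[2] = torch.tensor([100.0, -100.0])
out_junk = cross_attention(decoder_states, junk, junk, mask=pad_mask)

print(torch.allclose(out, out_junk))  # True
```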

#### 3. Hybrid Scenarios:

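- Could mix different sources (e.g., `Q` from one modality, `K/V` from another), as long as both are first mapped to size `d_model`, since `W_q`, `W_k`, and `W_v` all expect inputs of that size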

<span style="color:lightblue; font-weight:bold; font-size:1em;">Where Q and K/V come from in each scenario:</span>

| <span style="color:lightblue">Scenario</span> | <span style="color:lightblue">Q Source</span> | <span style="color:lightblue">K/V Source</span> | <span style="color:lightblue">Code Example</span> |
|-----------------|-------------------|-------------------|---------------------------------------|
| Encoder self-attention | Encoder embeddings | Encoder embeddings | `attention(enc, enc, enc)` |
| Decoder self-attention (masked) | Decoder embeddings | Decoder embeddings | `attention(dec, dec, dec, mask=causal_mask)` |
| Encoder-decoder (cross) attention | Decoder embeddings | Encoder outputs | `attention(dec, enc_out, enc_out)` |
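For decoder self-attention, the mask is typically causal: token $i$ may only attend to tokens $0, \dots, i$. A minimal sketch, assuming a hypothetical 3-token decoder input, builds that mask with `torch.triu` and checks two consequences: the first token attends only to itself, and changing the last token does not affect the outputs for earlier tokens. The class is a condensed copy of the notebook's `Attention`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    """Condensed copy of the notebook's Attention class."""
    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.row_dim, self.col_dim = row_dim, col_dim

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
        scaled_sims = sims / torch.tensor(k.size(self.col_dim) ** 0.5)
        if mask is not None:
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
        return torch.matmul(attention_percents, v)

torch.manual_seed(42)
decoder_self_attention = Attention(d_model=2)

dec = torch.randn(3, 2)  # placeholder decoder inputs: 3 target tokens
n = dec.size(0)

# True above the diagonal: token i is blocked from tokens j > i
causal_mask = torch.triu(torch.ones(n, n), diagonal=1).bool()

out = decoder_self_attention(dec, dec, dec, mask=causal_mask)

# Change only the last token; outputs for tokens 0 and 1 should be identical
dec_changed = dec.clone()
dec_changed[2] = torch.tensor([5.0, -5.0])
out_changed = decoder_self_attention(dec_changed, dec_changed, dec_changed, mask=causal_mask)

print(torch.allclose(out[:2], out_changed[:2]))  # True
```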

This implementation mirrors the original Transformer architecture, where:
- Encoder outputs become `K/V` for the decoder's cross-attention
- The decoder's representations (its positionally encoded inputs after masked self-attention) become `Q`
- Separate encoding arguments allow flexible attention between different sequences
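The wiring described above can be sketched with three separate `Attention` instances. This is a simplification under stated assumptions: random tensors stand in for positionally encoded embeddings, the class is a condensed copy of the notebook's, and the residual connections, layer normalization, feed-forward layers, and multiple heads of a full Transformer are left out.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    """Condensed copy of the notebook's Attention class."""
    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.row_dim, self.col_dim = row_dim, col_dim

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
        scaled_sims = sims / torch.tensor(k.size(self.col_dim) ** 0.5)
        if mask is not None:
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
        return torch.matmul(attention_percents, v)

torch.manual_seed(42)
d_model = 2
encoder_self_attention = Attention(d_model)
decoder_self_attention = Attention(d_model)
cross_attention = Attention(d_model)

# Placeholders for positionally encoded embeddings
src = torch.randn(4, d_model)  # 4 source tokens (e.g. the sentence to translate)
tgt = torch.randn(2, d_model)  # 2 target tokens produced so far

# 1. Encoder self-attention: Q, K, V all come from the source
enc_out = encoder_self_attention(src, src, src)

# 2. Masked decoder self-attention: Q, K, V all come from the target
causal_mask = torch.triu(torch.ones(2, 2), diagonal=1).bool()
dec_states = decoder_self_attention(tgt, tgt, tgt, mask=causal_mask)

# 3. Encoder-decoder attention: Q from the decoder, K/V from the encoder output
cross_out = cross_attention(dec_states, enc_out, enc_out)

print(enc_out.shape, dec_states.shape, cross_out.shape)
# torch.Size([4, 2]) torch.Size([2, 2]) torch.Size([2, 2])
```

Using three instances means each attention step learns its own `W_q`, `W_k`, and `W_v`, which matches the three distinct attention layers listed in the table.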