In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("..")

In [3]:
from __future__ import annotations

## Standard library
import math
from typing import Optional, Tuple

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

## Torchvision
import torchvision
from torchvision.datasets import CIFAR100
from torchvision import transforms

## PyTorch Lightning
import pytorch_lightning as pl

  from .autonotebook import tqdm as notebook_tqdm


Attention Mechanism

$ \alpha_i=\frac{\exp \left(f_{\text {attn }}\left(\text { key }_i, \text { query }\right)\right)}{\sum_j \exp \left(f_{\text {attn }}\left(\text { key }_j, \text { query }\right)\right)}, \quad \text { out }=\sum_i \alpha_i \cdot \text { value }_i$

## Implement Function: Scaled Dot Product

Implement the Scaled Dot Product. The mechanism of Scaled Dot Product is as follows:

We have a set of queries, keys and values:

Queries: $Q \in \mathbb{R}^{T \times d_k}$
Keys: $K \in \mathbb{R}^{T \times d_k}$
Values: $V \in \mathbb{R}^{T \times d_v}$

where $T$ is the sequence length, and $d_k$ and $d_v$ are the hidden dimensionality for queries/keys and values respectively

Our goals here are:

1) Measure the similarity between Q and K. We achieve this by taking the dot product between Q and K, which gives $QK^\top \in \mathbb R^{T \times T}$. Each row in $QK^\top$ holds the attention logits of a specific element $i$ with respect to all elements in the sequence.
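1) Scale $QK^\top$ by $\frac{1}{\sqrt{d_k}}$. If the entries of the queries and keys are independent with zero mean and variance $\sigma^2$, each dot product has variance $\sigma^4 d_k$; dividing by $\sqrt{d_k}$ scales the variance back down to $\sigma^4$, which is ~ $\sigma^2$ when $\sigma \approx 1$.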
1) Apply the mask fill. If a mask is given, replace the attention logits at the masked positions with a very large negative number (we use -9e15), so that those positions receive (almost) zero weight after the softmax. We will use binary masks, where 0 means the position is masked out (apply the mask fill) and 1 means the position is kept and should be attended to. Use [masked_fill_](https://pytorch.org/docs/stable/generated/torch.Tensor.masked_fill_.html#torch.Tensor.masked_fill_). Note that the mask fill is applied to the attention logits, which have dimension $\mathbb{R}^{T \times T}$.
1) Apply softmax function to $\frac{Q K^T}{\sqrt{d_k}}$ to get **attentions**
1) Multiply $\operatorname{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right)$ by $V$ to get **attention values**

Return the **attention values** and **attentions**, respectively.

Very important:

Make sure to account for batch dimension. 

In [22]:
def scaled_dot_product(q: torch.tensor, k: torch.tensor, v: torch.tensor, mask: torch.tensor = None) -> Tuple[torch.tensor, torch.tensor]:
    """Exercise template for scaled dot-product attention.

    Fill in the five steps below. Until they are implemented, every
    placeholder is ``None`` and the function returns ``(None, None)``.

    Args:
        q (torch.tensor): Queries; per the exercise text, shape (..., T, d_k).
        k (torch.tensor): Keys; per the exercise text, shape (..., T, d_k).
        v (torch.tensor): Values; per the exercise text, shape (..., T, d_v).
        mask (torch.tensor, optional): Binary mask over the (T, T) attention
            logits, where 0 marks positions to be filled. Defaults to None.

    Returns:
        Tuple[torch.tensor, torch.tensor]: ``(values, attention)``, in that order.
    """
    # Step 1: compute the attention logits Q @ K^T.
    # Transpose only the last two dimensions so leading batch dimensions survive.
    attention_logits = None

    # Step 2: divide the logits by sqrt(d_k).

    # Step 3: if a mask is given, fill the positions where mask == 0 with -9e15.

    # Step 4: softmax over the last dimension of the logits gives the attention.
    attention = None

    # Step 5: matrix-multiply the attention with V to get the output values.
    values = None

    return values, attention

In [44]:
# my implementation

def scaled_dot_product(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute scaled dot-product attention.

    Evaluates ``softmax(Q K^T / sqrt(d_k)) V``. Only the last two dimensions
    take part in the matrix products, so any leading (batch / head)
    dimensions are carried through unchanged.

    Args:
        q (torch.Tensor): Queries of shape (..., T, d_k).
        k (torch.Tensor): Keys of shape (..., T, d_k).
        v (torch.Tensor): Values of shape (..., T, d_v).
        mask (torch.Tensor, optional): Mask broadcastable to the (..., T, T)
            attention logits. Positions where ``mask == 0`` are filled with a
            large negative number before the softmax; all other positions are
            attended to. Defaults to None (no masking).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: ``(values, attention)``, where
        ``attention`` is the softmax over the last dimension of the scaled
        logits and ``values`` is ``attention @ v``.
    """
    # step 1: perform Q@K.T
    # transposing only the last two dims keeps the batch dimensions intact
    attention_logits = q @ k.transpose(-2, -1)

    # step 2: scale QK.T by 1/sqrt(d_k)
    d_k = q.shape[-1]
    attention_logits = attention_logits / math.sqrt(d_k)

    # step 3: apply mask fill. Fill with -9e15
    if mask is not None:
        # -9e15 cannot be represented in float16 and makes masked_fill raise an
        # overflow error there, so fall back to the dtype's most negative
        # finite value. For float32/float64 the fill is still exactly -9e15.
        fill_value = max(-9e15, torch.finfo(attention_logits.dtype).min)
        attention_logits = attention_logits.masked_fill(mask == 0, fill_value)

    # step 4: softmax function on the logits to get the attention
    attention = F.softmax(attention_logits, dim=-1)

    # step 5: do a matmul with V to get the values
    values = attention @ v

    return values, attention

In [42]:
# the standard implementation

def scaled_dot_product_test(q, k, v, mask=None):
    """Reference scaled dot-product attention used to check ``scaled_dot_product``.

    Args:
        q: Queries; the last dimension is d_k.
        k: Keys; transposed over the last two dimensions before the product with q.
        v: Values; multiplied by the attention weights.
        mask: Optional mask on the attention logits. Positions where
            ``mask == 0`` are filled with -9e15 before the softmax.

    Returns:
        Tuple ``(values, attention)``.
    """
    scale = math.sqrt(q.shape[-1])
    scores = (q @ k.transpose(-2, -1)) / scale
    if mask is not None:
        # A huge negative logit drives the softmax weight of masked positions to zero.
        scores = scores.masked_fill(mask == 0, -9e15)
    attention = F.softmax(scores, dim=-1)
    return attention @ v, attention

In [48]:
# Check my implementation against the reference implementation.
# Each comparison is asserted: a bare torch.allclose(...) only returns a bool,
# so without the asserts a mismatch would go unnoticed and "Passed" would
# still be printed.

# A local generator makes the random inputs reproducible without touching
# the global RNG state.
rng = torch.Generator().manual_seed(0)

batch, seq_len, d_k = 3, 3, 2

# test only one batch
q = torch.randn(seq_len, d_k, generator=rng)
k = torch.randn(seq_len, d_k, generator=rng)
v = torch.randn(seq_len, d_k, generator=rng)
values, attention = scaled_dot_product(q, k, v)
val, att = scaled_dot_product_test(q, k, v)
assert torch.allclose(val, values), "values differ from reference (single sequence)"
assert torch.allclose(att, attention), "attention differs from reference (single sequence)"

# test multiple batches
q = torch.randn(batch, seq_len, d_k, generator=rng)
k = torch.randn(batch, seq_len, d_k, generator=rng)
v = torch.randn(batch, seq_len, d_k, generator=rng)
values, attention = scaled_dot_product(q, k, v)
val, att = scaled_dot_product_test(q, k, v)
assert torch.allclose(val, values), "values differ from reference (batched)"
assert torch.allclose(att, attention), "attention differs from reference (batched)"

# test that masking works right
mask = torch.eye(seq_len)
values, attention = scaled_dot_product(q, k, v, mask)
val, att = scaled_dot_product_test(q, k, v, mask)
assert torch.allclose(val, values), "values differ from reference (masked)"
assert torch.allclose(att, attention), "attention differs from reference (masked)"
# With an identity mask every position may only attend to itself,
# so the attention must be the identity matrix and the values equal v.
assert torch.allclose(attention, torch.eye(seq_len).expand(batch, seq_len, seq_len))
assert torch.allclose(values, v)

print ("Passed")

Passed


## Multi-Head Attention Layer

Now that we know what is going on with the Scaled Dot Product Attention, we turn our attention (hah!) to the Multi-Head Attention Network. In a way, we are giving head to our data (woah!).

Let's leave the `Multi` part aside and talk about what giving head to our data means. A head contains the Q, K and V weights. We first turn our sequence of inputs into a matrix where each word (element) of our sequence is a vector. We then take these vectors and apply the linear projections given by those weights to get the Q, K and V outputs. Finally, we run the Scaled Dot Product on them to get the output Z.

So instead of just having one head, we are giving our data multiple heads! In the original paper, there were 8 heads, like a Hydra. The unique thing is that each of these heads' weights are totally different in their initialisation. However, within each head, the Q K and V weights should be initialised the same.

In [None]:
# Helper function to support different mask shapes.
# Output shape supports (batch_size, number of heads, seq length, seq length)
# If 2D: broadcasted over batch size and number of heads
# If 3D: broadcasted over number of heads
# If 4D: leave as is
def expand_mask(mask):
    """Expand an attention mask to 4 dimensions.

    The target layout is (batch_size, number of heads, seq length, seq length).

    Args:
        mask: Tensor with at least 2 dimensions.
            - 2D: two leading singleton dimensions are added, so the mask
              broadcasts over batch size and number of heads.
            - 3D: a singleton dimension is inserted at position 1, so the
              mask broadcasts over the number of heads.
            - 4D (or more): returned unchanged.

    Returns:
        The mask with singleton dimensions added as described above.

    Raises:
        AssertionError: If ``mask`` has fewer than 2 dimensions.
    """
    # ">= 2" (not "> 2"): a plain (seq_length, seq_length) mask is valid input.
    assert mask.ndim >= 2, "Mask must be at least 2-dimensional with seq_length x seq_length"
    if mask.ndim == 3:
        mask = mask.unsqueeze(1)
    while mask.ndim < 4:
        mask = mask.unsqueeze(0)
    return mask

In [None]:
class MultiheadAttention(nn.Module):
    """Multi-head self-attention with one fused Q/K/V projection.

    A single linear layer produces queries, keys and values for all heads at
    once. Its output is split into ``num_heads`` heads of width
    ``3 * head_dim``, and each head's slice is chunked into q, k and v.
    The per-head results of ``scaled_dot_product`` are concatenated and
    passed through an output projection.
    """

    def __init__(self, input_dim, embed_dim, num_heads):
        """Create the projection layers.

        Args:
            input_dim: Size of the last dimension of the input.
            embed_dim: Total embedding size across all heads; must be
                divisible by ``num_heads``.
            num_heads: Number of attention heads.

        Raises:
            AssertionError: If ``embed_dim`` is not divisible by ``num_heads``.
        """
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # The weight matrices of all heads are stacked into one layer for efficiency.
        # Many implementations pass "bias=False" here; that is optional.
        self.qkv_proj = nn.Linear(input_dim, 3 * embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform weights and zero biases for both projections.

        This is the original Transformer initialization; see the PyTorch documentation.
        """
        for layer in (self.qkv_proj, self.o_proj):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x, mask=None, return_attention=False):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Input of shape (batch, seq_length, input_dim).
            mask: Optional attention mask; passed through ``expand_mask``
                before being handed to ``scaled_dot_product``.
            return_attention: If True, also return the attention weights.

        Returns:
            The projected output of shape (batch, seq_length, embed_dim), or
            the tuple ``(output, attention)`` when ``return_attention`` is
            True, with ``attention`` as returned by ``scaled_dot_product``.
        """
        batch_size, seq_length, _ = x.shape
        attn_mask = None if mask is None else expand_mask(mask)

        # Project once, then split the features into heads: [Batch, SeqLen, Head, 3*Dims]
        fused = self.qkv_proj(x).reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
        # Move heads in front of the sequence axis ([Batch, Head, SeqLen, 3*Dims]) and separate Q, K, V.
        q, k, v = fused.transpose(1, 2).chunk(3, dim=-1)

        per_head, attention = scaled_dot_product(q, k, v, mask=attn_mask)

        # Back to [Batch, SeqLen, Head, Dims], then concatenate the heads.
        merged = per_head.transpose(1, 2).reshape(batch_size, seq_length, self.embed_dim)
        o = self.o_proj(merged)

        return (o, attention) if return_attention else o