Implementing Attention from Scratch, specifically multi-head attention

In [3]:
# libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from typing import List, Optional

Multi Headed Attention:

Multi-head attention runs several attention heads in parallel and combines their results into one output. The outputs of the individual heads are concatenated and then passed through a linear transformation. This allows the model to jointly attend to information from different representation subspaces at different positions.

In [3]:
class PrepareForMultiHeadAttention(nn.Module):
    """
    ## Projection + head split for multi-head attention

    Applies one linear projection to the last dimension of the input and then
    reshapes that dimension into `heads` separate vectors of size `d_k`.
    The same module type is used for the **query**, **key** and **value**
    projections.
    """

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        """
        * `d_model` is the size of the input's last dimension.
        * `heads` is the number of attention heads.
        * `d_k` is the number of features per head.
        * `bias` toggles the bias term of the linear projection.
        """
        super().__init__()
        # How the projected features are divided up
        self.heads = heads
        self.d_k = d_k
        # A single projection produces the features for all heads at once
        self.linear = nn.Linear(in_features=d_model, out_features=heads * d_k, bias=bias)

    def forward(self, x: torch.Tensor):
        """
        `x` has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
        Returns a tensor of shape `[seq_len, batch_size, heads, d_k]` or
        `[batch_size, heads, d_k]` respectively.
        """
        # Every dimension except the feature dimension is carried through untouched
        leading_dims = x.shape[:-1]

        # Project, then carve the last dimension into `(heads, d_k)`
        projected = self.linear(x)
        return projected.view(*leading_dims, self.heads, self.d_k)

$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$

In [10]:
class MultiHeadAttention(nn.Module):
    """
    ## Multi-head attention

    This computes scaled multi-headed attention for given `query`, `key` and `value` vectors.
    i.e., it finds keys that match the query, and gets the values of those keys.
    It uses the dot-product of query and key as the indicator of how well they match.
    Softmax is calculated along the axis of the key sequence (or time).
    """

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
        """
        * `heads` is the number of heads.
        * `d_model` is the number of features in the `query`, `key` and `value` vectors.
          It is expected to be divisible by `heads`: each head gets `d_model // heads`
          features and the output layer takes `d_model` input features.
        * `dropout_prob` is the dropout probability applied to the attention weights.
        * `bias` toggles the bias of the query and key projections.
        """

        super().__init__()

        # Number of features per head
        self.d_k = d_model // heads
        # Number of heads
        self.heads = heads

        # These transform the `query`, `key` and `value` vectors for multi-headed attention.
        # Note that the value projection always has a bias, independent of the `bias` argument.
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

        # Softmax for attention along the time dimension of `key`
        # (dimension 1 of the `[seq_len_q, seq_len_k, batch_size, heads]` scores)
        self.softmax = nn.Softmax(dim=1)

        # Output layer
        self.output = nn.Linear(d_model, d_model)
        # Dropout
        self.dropout = nn.Dropout(dropout_prob)
        # Scaling factor before the softmax
        self.scale = 1 / math.sqrt(self.d_k)

        # We store attentions so that it can be used for logging, or other computations if needed
        self.attn = None

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        """
        ### Calculate scores between queries and keys
        This method can be overridden for other variations like relative attention.

        `query` and `key` have shape `[seq_len, batch_size, heads, d_k]`;
        the result has shape `[seq_len_q, seq_len_k, batch_size, heads]`.
        """

        # Calculate $Q K^\top$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
        return torch.einsum('ibhd,jbhd->ijbh', query, key)

    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
        """
        `mask` has shape `[seq_len_q, seq_len_k, batch_size]`, where first dimension is the query dimension.
        If the query dimension is equal to $1$ it will be broadcasted.
        The batch dimension may likewise be $1$.
        """

        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]

        # Same mask applied to all heads: add a trailing singleton dimension
        # that broadcasts over the head dimension of the scores.
        mask = mask.unsqueeze(-1)

        # resulting mask has shape `[seq_len_q, seq_len_k, batch_size, 1]`
        return mask

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        """
        `query`, `key` and `value` are the tensors that store
        collection of *query*, *key* and *value* vectors.
        They have shape `[seq_len, batch_size, d_model]`.
        `mask` has shape `[seq_len, seq_len, batch_size]` and
        `mask[i, j, b]` indicates whether for batch `b`,
        query at position `i` has access to key-value at position `j`.
        Positions where `mask == 0` are excluded from attention.

        Returns a tensor of shape `[seq_len, batch_size, d_model]`.
        """

        # `query`, `key` and `value`  have shape `[seq_len, batch_size, d_model]`
        seq_len, batch_size, _ = query.shape

        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)

        # Prepare `query`, `key` and `value` for attention computation.
        # These will then have shape `[seq_len, batch_size, heads, d_k]`.
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # Compute attention scores $Q K^\top$.
        # This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`.
        scores = self.get_scores(query, key)

        # Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
        scores *= self.scale

        # Apply mask: masked positions get $-\infty$ so that softmax assigns them zero weight
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # $softmax$ attention along the key sequence dimension
        attn = self.softmax(scores)

        # Apply dropout
        attn = self.dropout(attn)

        # Multiply by values
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

        # Save attentions (after dropout, detached from the graph) for any other calculations
        self.attn = attn.detach()

        # Concatenate multiple heads
        x = x.reshape(seq_len, batch_size, -1)

        # Output layer
        return self.output(x)

> Example with data for Simple Self attention (single head)

In [4]:
sentence = 'Life is short, eat dessert first'

# Toy vocabulary: drop the comma, split on whitespace, sort the words and
# number them from 0 in that sorted order.
dc = {
    word: index
    for index, word in enumerate(sorted(sentence.replace(',', '').split()))
}
print(dc)

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}


In [5]:
# Encode the sentence as vocabulary indices, keeping the original word order.
sentence_int = torch.tensor(
    [dc[word] for word in sentence.replace(',', '').split()]
)
print(sentence_int)


tensor([0, 4, 5, 2, 1, 3])


In [6]:
torch.manual_seed(123)
# Embedding table with 6 rows (one per vocabulary word) of 16 dimensions each.
embed = nn.Embedding(num_embeddings=6, embedding_dim=16)
# Look up each token's vector; `detach()` drops the autograd history so the
# rest of the walkthrough works on plain values.
embedded_sentence = embed(sentence_int).detach()

In [7]:
# Display the embedded sentence: one row per token of the input sentence.
embedded_sentence

tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
          0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],
        [ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,
          0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465],
        [ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168, -1.7472,
         -1.6025, -1.0764,  0.9031, -0.7218, -0.5951, -0.7112,  0.6230, -1.3729],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,
          0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753, -0.2010,
          0.4965, -1.5723,  0.9666, -1.1481, -1.1589,  0.3255, -0.6315, -2.8400],
        [ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,
          2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293]])

In [8]:
# Shape check: (number of tokens, embedding dimension).
embedded_sentence.shape

torch.Size([6, 16])

In [9]:
torch.manual_seed(123)

# Width of each token embedding; this is the input size of every projection.
d = embedded_sentence.size(1)

# Output sizes of the query, key and value projections.
# Query and key sizes match because the two are compared with a dot product.
d_q = 24
d_k = 24
d_v = 28

# Random projection matrices drawn uniformly from [0, 1).
# The draw order (query, key, value) is kept so the seeded values are reproducible.
W_query = torch.rand((d_q, d))
W_key = torch.rand((d_k, d))
W_value = torch.rand((d_v, d))

In [10]:
# Shape check: each weight matrix is (projection size, embedding dimension).
W_query.shape, W_key.shape, W_value.shape

(torch.Size([24, 16]), torch.Size([24, 16]), torch.Size([28, 16]))

In [22]:
# Use the second token's embedding (row index 1) as the running example.
x_2 = embedded_sentence[1] # second word
x_2.shape

torch.Size([16])

In [15]:
# Project the second token's embedding into query, key and value vectors.
query_2 = W_query @ x_2
key_2 = W_key @ x_2
value_2 = W_value @ x_2

In [21]:
# `@` is the operator form of `matmul`, so this recomputes the same product as `query_2`.
W_query@x_2 # Matrix multiplication

tensor([ 0.8982,  0.1030,  0.4428,  0.6328, -1.7003,  1.3489, -0.3082, -0.5900,
        -0.9257, -0.7688,  1.8828, -1.6065, -0.8011, -0.4114, -0.6116,  1.3902,
        -0.1460,  0.0244, -0.5577,  1.5972, -2.2190, -0.0214,  0.2002,  1.3752])

In [17]:
# Display the query vector of the second token, for comparison with the `@` result.
query_2

tensor([ 0.8982,  0.1030,  0.4428,  0.6328, -1.7003,  1.3489, -0.3082, -0.5900,
        -0.9257, -0.7688,  1.8828, -1.6065, -0.8011, -0.4114, -0.6116,  1.3902,
        -0.1460,  0.0244, -0.5577,  1.5972, -2.2190, -0.0214,  0.2002,  1.3752])

In [16]:
# Shape check: the projected vectors have sizes d_q, d_k and d_v respectively.
query_2.shape, key_2.shape, value_2.shape

(torch.Size([24]), torch.Size([24]), torch.Size([28]))

In [27]:
# Compare shapes before projecting all tokens at once: the embeddings must be
# transposed so that their embedding dimension lines up with `W_key`'s columns.
W_key.shape, embedded_sentence.shape

(torch.Size([24, 16]), torch.Size([6, 16]))

In [29]:
# Project every token at once. Multiplying by the transposed embeddings gives
# one column per token; transposing back yields one row per token.
keys = torch.matmul(W_key, embedded_sentence.T).T
values = torch.matmul(W_value, embedded_sentence.T).T

In [30]:
# Shape check: one key row and one value row per token.
keys.shape, values.shape

(torch.Size([6, 24]), torch.Size([6, 28]))

In [31]:
# Unnormalized attention score between the second token's query and the key at row index 4.
query_2.dot(keys[4])

tensor(11.1466)

In [34]:
# Score the second token's query against every key in one product:
# one unnormalized score per token.
omega_2 = torch.matmul(query_2, keys.T)
omega_2.shape

torch.Size([6])

In [38]:
# Raw dot-product scores of the second token's query against all keys,
# before scaling and softmax.
omega_2 # unnormalized attention weights

tensor([ 8.5808, -7.6597,  3.2558,  1.0395, 11.1466, -0.4800])

In [39]:
# Normalized attention weights for input token 2: scale the raw scores by
# sqrt(d_k), then apply softmax over the token dimension so they sum to 1.
attention_weights_2 = torch.softmax(omega_2 / d_k ** 0.5, dim=0)
print(attention_weights_2)

tensor([0.2912, 0.0106, 0.0982, 0.0625, 0.4917, 0.0458])


In [41]:
# Context vector for input token 2: the attention-weighted sum of all value rows.
context_vector_2 = attention_weights_2 @ values

print(context_vector_2.shape)
print(context_vector_2)

torch.Size([28])
tensor([-1.5993,  0.0156,  1.2670,  0.0032, -0.6460, -1.1407, -0.4908, -1.4632,
         0.4747,  1.1926,  0.4506, -0.7110,  0.0602,  0.7125, -0.1628, -2.0184,
         0.3838, -2.1188, -0.8136, -1.5694,  0.7934, -0.2911, -1.3640, -0.2366,
        -0.9564, -0.5265,  0.0624,  1.7084])


In [42]:
# Multi-head attention: every head gets its own set of projection matrices.

# Number of attention heads
h = 3

# One weight matrix per head, stacked along a leading head dimension.
# The draw order (query, key, value) is kept unchanged.
multihead_W_query = torch.rand((h, d_q, d))
multihead_W_key = torch.rand((h, d_k, d))
multihead_W_value = torch.rand((h, d_v, d))

In [45]:
# Shape check: (heads, projection size, embedding dimension) for each weight stack.
multihead_W_query.shape, multihead_W_key.shape, multihead_W_value.shape

(torch.Size([3, 24, 16]), torch.Size([3, 24, 16]), torch.Size([3, 28, 16]))

In [46]:
# Project the second token with every head's query matrix at once:
# the product broadcasts over the leading head dimension, giving one query per head.
multihead_query_2 = torch.matmul(multihead_W_query, x_2)
print(multihead_query_2.shape)

torch.Size([3, 24])


In [48]:
# Same per-head projection for the second token's key and value vectors.
multihead_key_2 = torch.matmul(multihead_W_key, x_2)
multihead_value_2 = torch.matmul(multihead_W_value, x_2)
multihead_key_2.shape, multihead_value_2.shape

(torch.Size([3, 24]), torch.Size([3, 28]))

In [49]:
# Replicate the transposed embeddings once per head, giving shape
# (h, embedding dim, num tokens), so that `torch.bmm` can multiply each head's
# weight matrix with its own copy. The repeat count is `h` rather than a
# literal so it stays in sync with the number of heads in the weight stacks.
stacked_inputs = embedded_sentence.T.repeat(h, 1, 1)
print(stacked_inputs.shape)

torch.Size([3, 16, 6])


In [54]:
# Reminder of the original (tokens, embedding dim) layout, for comparison with `stacked_inputs`.
embedded_sentence.shape

torch.Size([6, 16])

In [53]:
# Show the embeddings again so they can be compared with one slice of `stacked_inputs`.
embedded_sentence

tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
          0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],
        [ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,
          0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465],
        [ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168, -1.7472,
         -1.6025, -1.0764,  0.9031, -0.7218, -0.5951, -0.7112,  0.6230, -1.3729],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,
          0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753, -0.2010,
          0.4965, -1.5723,  0.9666, -1.1481, -1.1589,  0.3255, -0.6315, -2.8400],
        [ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,
          2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293]])

In [52]:
# First replicated copy: the transposed embeddings, with one column per token.
stacked_inputs[0]

tensor([[ 0.3374,  0.5146,  0.2553, -1.3250, -0.0770,  0.8768],
        [-0.1778,  0.9938, -0.5496,  0.1784, -1.0205,  1.6221],
        [-0.3035, -0.2587,  1.0042, -2.1338, -0.1690, -1.4779],
        [-0.5880, -1.0826,  0.8272,  1.0524,  0.9178,  1.1331],
        [ 0.3486, -0.0444, -0.3948, -0.3885,  1.5810, -1.2203],
        [ 0.6603,  1.6236,  0.4892, -0.9343,  1.3010,  1.3139],
        [-0.2196, -2.3229, -0.2168, -0.4991,  1.2753,  1.0533],
        [-0.3792,  1.0878, -1.7472, -1.0867, -0.2010,  0.1388],
        [ 0.7671,  0.6716, -1.6025,  0.8805,  0.4965,  2.2473],
        [-1.1925,  0.6933, -1.0764,  1.5542, -1.5723, -0.8036],
        [ 0.6984, -0.9487,  0.9031,  0.6266,  0.9666, -0.2808],
        [-1.4097, -0.0765, -0.7218, -0.1755, -1.1481,  0.7697],
        [ 0.1794, -0.1526, -0.5951,  0.0983, -1.1589, -0.6596],
        [ 1.8951,  0.1167, -0.7112, -0.0935,  0.3255, -0.7979],
        [ 0.4954,  0.4403,  0.6230,  0.2662, -0.6315,  0.1838],
        [ 0.2692, -1.4465, -1.3729, -0.5

In [59]:
# Sanity check: on a 2-D tensor, `permute(1, 0)` and `.T` produce the same transpose,
# so the element-wise comparison is True everywhere.
stacked_inputs[0].permute(1, 0)==stacked_inputs[0].T #checking permute functionality

tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True]])

In [55]:
# Batched matrix product: for each head, multiply that head's weight matrix by
# its copy of the transposed embeddings. The result is (heads, projection size, tokens).
multihead_keys = multihead_W_key.bmm(stacked_inputs)
multihead_values = multihead_W_value.bmm(stacked_inputs)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)

multihead_keys.shape: torch.Size([3, 24, 6])
multihead_values.shape: torch.Size([3, 28, 6])


In [56]:
# Rearrange to (heads, tokens, projection size) so each row is one token's vector.
# The tensors are recomputed from their inputs rather than permuting the existing
# variables in place: re-running this cell then always yields the same layout
# instead of flipping the last two dimensions back on every second run.
multihead_keys = torch.bmm(multihead_W_key, stacked_inputs).permute(0, 2, 1)
multihead_values = torch.bmm(multihead_W_value, stacked_inputs).permute(0, 2, 1)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)

multihead_keys.shape: torch.Size([3, 6, 24])
multihead_values.shape: torch.Size([3, 6, 28])
