
### Imports

In [1]:
import torch
import torch.nn as nn

### `SelfAttentionV1`: Simple Self Attention with weights as `nn.Parameter`

In [2]:
class SelfAttentionV1(nn.Module):
  def __init__(self, d_in, d_out):
    super(SelfAttentionV1, self).__init__()
    self.W_query = nn.Parameter(torch.rand(d_in, d_out))
    self.W_key = nn.Parameter(torch.rand(d_in, d_out))
    self.W_value = nn.Parameter(torch.rand(d_in, d_out))

  def forward(self, x):
    '''
    x -> torch.tensor, shape: [num_tokens, d_in]
    '''
    query = x @ self.W_query
    key = x @ self.W_key
    value = x @ self.W_value
    attention_scores = query @ key.T
    # scale by the square root of the key dimension before the softmax
    attention_weights = torch.softmax(attention_scores/key.shape[-1] ** 0.5,
                                      dim=-1)
    context_vector = attention_weights @ value
    return context_vector

  def set_weights(self, W_query, W_key, W_value):
    self.W_query = nn.Parameter(W_query)
    self.W_key = nn.Parameter(W_key)
    self.W_value = nn.Parameter(W_value)

### `SelfAttentionV2`: Simple Self Attention with weights as `nn.Linear` with `bias=False`

In [3]:
class SelfAttentionV2(nn.Module):
  def __init__(self, d_in, d_out):
    super(SelfAttentionV2, self).__init__()
    self.W_query = nn.Linear(d_in, d_out, bias=False)
    self.W_key = nn.Linear(d_in, d_out, bias=False)
    self.W_value = nn.Linear(d_in, d_out, bias=False)

  def forward(self, x):
    '''
    x -> torch.tensor, shape: [num_tokens, d_in]
    '''
    query = self.W_query(x)
    key = self.W_key(x)
    value = self.W_value(x)
    attention_scores = query @ key.T
    # scale by the square root of the key dimension before the softmax
    attention_weights = torch.softmax(attention_scores/key.shape[-1] ** 0.5,
                                      dim=-1)
    context_vector = attention_weights @ value
    return context_vector

### Checking outputs for both Self Attention implementations

In [4]:
inputs = torch.tensor([[0.43, 0.15, 0.89],
                       [0.55, 0.87, 0.66],
                       [0.57, 0.85, 0.64],
                       [0.22, 0.58, 0.33],
                       [0.77, 0.25, 0.10],
                       [0.05, 0.80, 0.55]])

d_in = 3
d_out = 2

torch.manual_seed(123)
sa_v1 = SelfAttentionV1(d_in, d_out)
sa_v1_outs = sa_v1(inputs)
print('Self Attention V1 outputs')
print(sa_v1_outs)

sa_v2 = SelfAttentionV2(d_in, d_out)
sa_v2_outs = sa_v2(inputs)
print('Self Attention V2 outputs')
print(sa_v2_outs)

Self Attention V1 outputs
tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)
Self Attention V2 outputs
tensor([[0.5085, 0.3508],
        [0.5084, 0.3508],
        [0.5084, 0.3506],
        [0.5074, 0.3471],
        [0.5076, 0.3446],
        [0.5077, 0.3493]], grad_fn=<MmBackward0>)


In [5]:
print('Self Attention Weight Shapes')
print('Query: V1: ', sa_v1.W_query.shape, ' V2: ', sa_v2.W_query.weight.shape)
print('Key: V1: ', sa_v1.W_key.shape, ' V2: ', sa_v2.W_key.weight.shape)
print('Value: V1: ', sa_v1.W_value.shape, ' V2: ', sa_v2.W_value.weight.shape)

Self Attention Weight Shapes
Query: V1:  torch.Size([3, 2])  V2:  torch.Size([2, 3])
Key: V1:  torch.Size([3, 2])  V2:  torch.Size([2, 3])
Value: V1:  torch.Size([3, 2])  V2:  torch.Size([2, 3])


### Transferring weights from `SelfAttentionV2` -> `SelfAttentionV1` and matching outputs

In [6]:
W_query = sa_v2.W_query.weight.T
W_key = sa_v2.W_key.weight.T
W_value = sa_v2.W_value.weight.T

sa_v1.set_weights(W_query, W_key, W_value)
sa_v1_outs_new = sa_v1(inputs)
print(sa_v1_outs_new)
print((sa_v1_outs_new == sa_v2_outs).all())

tensor([[0.5085, 0.3508],
        [0.5084, 0.3508],
        [0.5084, 0.3506],
        [0.5074, 0.3471],
        [0.5076, 0.3446],
        [0.5077, 0.3493]], grad_fn=<MmBackward0>)
tensor(True)
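The transpose in the weight transfer is needed because `nn.Linear` stores its weight as `[d_out, d_in]` and computes `x @ weight.T`, while `SelfAttentionV1` stores `[d_in, d_out]` and computes `x @ W`. A minimal sketch with a stand-alone layer (the tensor `x` here is made up for illustration); it also uses `torch.allclose`, a tolerance-based comparison that is usually a safer check than `==` for floating point outputs:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.rand(6, 3)                   # 6 tokens, d_in = 3 (illustrative values)
linear = nn.Linear(3, 2, bias=False)   # weight is stored as [d_out, d_in]

# nn.Linear computes x @ weight.T, so weight.T has the [d_in, d_out]
# layout that SelfAttentionV1 expects
manual = x @ linear.weight.T

print(linear.weight.shape)                # torch.Size([2, 3])
print(torch.allclose(linear(x), manual))  # True
```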


### `CausalSelfAttention`: Self Attention with Causal Masking

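Why do we set the masked attention scores to $-\infty$ before computing the softmax? Since $e^{-\infty} = 0$, the softmax gives those positions a weight of exactly zero and normalizes each row over the remaining positions in a single step. Zeroing out the attention weights above the diagonal after the softmax and then rescaling each row so that it sums to 1 gives the same result, but needs an extra normalization step.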
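The two masking strategies can be compared on a small example. This is a sketch on made-up scores, not part of the classes in this notebook:

```python
import torch

torch.manual_seed(0)
scores = torch.rand(4, 4)  # illustrative raw attention scores for 4 tokens
future = torch.triu(torch.ones(4, 4), diagonal=1).bool()  # True above the diagonal

# Option 1: fill future positions with -inf, then apply softmax
w_inf = torch.softmax(scores.masked_fill(future, -torch.inf), dim=-1)

# Option 2: softmax first, zero out future positions, renormalize each row
w_zeroed = torch.softmax(scores, dim=-1).masked_fill(future, 0.0)
w_renorm = w_zeroed / w_zeroed.sum(dim=-1, keepdim=True)

print(torch.allclose(w_inf, w_renorm))  # True
```

Both options give rows that sum to 1 and have zero weight on future tokens; the $-\infty$ version simply gets there with one softmax call.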

There are two ways to apply dropout in attention, so that the model does not depend too much on any single position:
1. apply it directly to the attention weight matrix
2. apply it to the context vector obtained by multiplying the attention weights with the value vectors

After dropout, the surviving entries are scaled up by $1/(1-p)$, where $p$ is the drop rate, so that their expected value stays the same during training and inference. Note that dropout is only active during training; it is disabled at inference.
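The rescaling and the train/inference difference can be checked directly on `nn.Dropout`. A small sketch with an all-ones matrix standing in for the attention weights:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
weights = torch.ones(4, 4)   # stand-in for an attention weight matrix

dropout.train()              # training mode: entries are dropped at random
dropped = dropout(weights)   # every entry is either 0 or 1 / (1 - 0.5) = 2

dropout.eval()               # inference mode: dropout does nothing
kept = dropout(weights)

print(dropped)
print(torch.equal(kept, weights))  # True
```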

In [7]:
class CausalSelfAttention(nn.Module):
  def __init__(self, d_in, d_out, context_length, drop_rate, qkv_bias=False):
    super(CausalSelfAttention, self).__init__()
    self.d_out = d_out
    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.dropout = nn.Dropout(drop_rate)

    # register_buffer stores the mask as part of the module, so it moves to the
    # right device together with the parameters and we don't place it manually
    self.register_buffer('mask',
                         torch.triu(torch.ones(context_length, context_length), diagonal=1))

  def forward(self, x):
    '''
    x -> torch.tensor, shape: [num_sequences, num_tokens, d_in]
    '''
    query = self.W_query(x)
    key = self.W_key(x)
    value = self.W_value(x)

    attention_scores = query @ key.transpose(-1, -2)

    # masked_fill_ is an in-place operation (the name ends with _)
    # the mask covers context_length positions, but the sequence only has
    # num_tokens tokens, so we slice the mask to [num_tokens, num_tokens]
    num_tokens = x.shape[1]
    attention_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
    attention_weights = torch.softmax(attention_scores/key.shape[-1] ** 0.5, dim=-1)
    drop_attention_weights = self.dropout(attention_weights)

    context_vector = drop_attention_weights @ value
    return context_vector
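What `register_buffer` provides can be seen on a small stand-alone module. This is a sketch; `MaskHolder` is a hypothetical class used only for illustration:

```python
import torch
import torch.nn as nn

class MaskHolder(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, context_length):
        super().__init__()
        self.proj = nn.Linear(2, 2)
        self.register_buffer(
            'mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

m = MaskHolder(4)
param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]
print(param_names)               # ['proj.weight', 'proj.bias'] -> mask is not trained
print(buffer_names)              # ['mask']
print('mask' in m.state_dict())  # True -> mask is saved with the module

# Module-wide conversions also apply to buffers; .to(device) behaves the same way
m = m.to(torch.float64)
print(m.mask.dtype)              # torch.float64
```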

In [8]:
batch = torch.stack([inputs, inputs], dim=0)
print(batch)
print(batch.shape)
ca = CausalSelfAttention(d_in, d_out, batch.shape[1], 0.0)
ca_outputs = ca(batch)
print('Causal Attention Outputs')
print(ca_outputs)
print(ca_outputs.shape)

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])
torch.Size([2, 6, 3])
Causal Attention Outputs
tensor([[[0.4566, 0.2729],
         [0.5792, 0.3011],
         [0.6249, 0.3102],
         [0.5691, 0.2785],
         [0.5543, 0.2520],
         [0.5337, 0.2499]],

        [[0.4566, 0.2729],
         [0.5792, 0.3011],
         [0.6249, 0.3102],
         [0.5691, 0.2785],
         [0.5543, 0.2520],
         [0.5337, 0.2499]]], grad_fn=<UnsafeViewBackward0>)
torch.Size([2, 6, 2])


### `MultiHeadCausalAttentionWrapper`: Multi Head Attention with separate `CausalSelfAttention` heads

In [9]:
class MultiHeadCausalAttentionWrapper(nn.Module):
  def __init__(self, d_in, d_out, context_length, drop_rate, num_heads, qkv_bias=False):
    super(MultiHeadCausalAttentionWrapper, self).__init__()
    # pass qkv_bias through to every head instead of hard-coding False
    self.heads = nn.ModuleList([CausalSelfAttention(d_in, d_out, context_length, drop_rate, qkv_bias=qkv_bias)
                                for _ in range(num_heads)])

  def forward(self, x):
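    '''
    The heads run one after another here. A batched matrix multiplication
    (e.g. torch.bmm) could process all heads at once instead.
    '''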
    return torch.cat([h(x) for h in self.heads], dim=-1)

In [10]:
mhsa_w = MultiHeadCausalAttentionWrapper(d_in, d_out, batch.shape[1], 0.0, 2)
mhsa_w_outputs = mhsa_w(batch)
print('Multi Head Causal Self Attention Wrapper Outputs')
print(mhsa_w_outputs)
print(mhsa_w_outputs.shape)

Multi Head Causal Self Attention Wrapper Outputs
tensor([[[-0.5684,  0.5063, -0.4821,  0.4336],
         [-0.5388,  0.6447, -0.5368,  0.5483],
         [-0.5242,  0.6954, -0.5545,  0.5886],
         [-0.4578,  0.6471, -0.4937,  0.5311],
         [-0.4006,  0.5921, -0.4589,  0.5169],
         [-0.3997,  0.5971, -0.4479,  0.4971]],

        [[-0.5684,  0.5063, -0.4821,  0.4336],
         [-0.5388,  0.6447, -0.5368,  0.5483],
         [-0.5242,  0.6954, -0.5545,  0.5886],
         [-0.4578,  0.6471, -0.4937,  0.5311],
         [-0.4006,  0.5921, -0.4589,  0.5169],
         [-0.3997,  0.5971, -0.4479,  0.4971]]], grad_fn=<CatBackward0>)
torch.Size([2, 6, 4])


In [11]:
mhsa_w1 = MultiHeadCausalAttentionWrapper(d_in, 1, batch.shape[1], 0.0, 2)
mhsa_w1_outputs = mhsa_w1(batch)
print('New Multi Head Causal Self Attention Wrapper Outputs')
print(mhsa_w1_outputs)
print(mhsa_w1_outputs.shape)

New Multi Head Causal Self Attention Wrapper Outputs
tensor([[[0.7128, 0.4106],
         [0.8309, 0.3569],
         [0.8696, 0.3342],
         [0.7802, 0.2922],
         [0.7388, 0.2238],
         [0.7163, 0.2381]],

        [[0.7128, 0.4106],
         [0.8309, 0.3569],
         [0.8696, 0.3342],
         [0.7802, 0.2922],
         [0.7388, 0.2238],
         [0.7163, 0.2381]]], grad_fn=<CatBackward0>)
torch.Size([2, 6, 2])


### `MultiHeadSelfAttention`: Efficient Multi Head Self Attention with batch matrix multiplication

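Why don't we directly use

```
query = query.view(num_seq, self.num_heads, num_tokens, self.head_dim)
```
This would need fewer transposes, but it is not equivalent. For each token, the last dimension of the projected tensor holds `num_heads` chunks of `head_dim` features, so it has to be split as `(num_tokens, num_heads, head_dim)` first and then transposed. Viewing it directly as `(num_heads, num_tokens, head_dim)` reinterprets the same memory in a different order and mixes features that belong to different tokens and heads.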
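To see what goes wrong with the direct `view`, here is a small sketch with one sequence, 2 tokens, 2 heads, and `head_dim = 2`. The values are chosen so that each entry is easy to trace:

```python
import torch

num_seq, num_tokens, num_heads, head_dim = 1, 2, 2, 2
# token 0 -> [0, 1, 2, 3], token 1 -> [4, 5, 6, 7]
x = torch.arange(8).view(num_seq, num_tokens, num_heads * head_dim)

# split the feature dimension into heads, then move the head axis forward
correct = x.view(num_seq, num_tokens, num_heads, head_dim).transpose(1, 2)

# reinterpret the memory directly as (num_seq, num_heads, num_tokens, head_dim)
direct = x.view(num_seq, num_heads, num_tokens, head_dim)

print(correct[0, 0].tolist())  # [[0, 1], [4, 5]] -> first chunk of token 0 and token 1
print(direct[0, 0].tolist())   # [[0, 1], [2, 3]] -> both chunks of token 0
```

Both tensors have shape `[1, 2, 2, 2]`, but only the first one keeps head 0 as "the first `head_dim` features of every token".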

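1. `tensor.contiguous()` returns a tensor whose data is laid out in memory in the same order as if it had been created from scratch with that shape. It copies the data only if the tensor is not already contiguous.
2. `tensor.view()` and `tensor.transpose()` return a tensor that shares the same memory and only changes the shape/stride mapping. `tensor.reshape()` does the same when it can, and copies the data otherwise.
3. Also, `view/reshape` is not the same as `transpose`: `view/reshape` keeps the order of the elements, while `transpose` permutes them. We can verify this by using `flatten` or `reshape(-1)`.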

Reference: [stackoverflow](https://stackoverflow.com/questions/48915810/what-does-contiguous-do-in-pytorch)
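The three points above can be checked on a tiny tensor. A sketch with illustrative values:

```python
import torch

x = torch.arange(6).view(2, 3)   # [[0, 1, 2], [3, 4, 5]]
t = x.transpose(0, 1)            # [[0, 3], [1, 4], [2, 5]] -> elements permuted
v = x.view(3, 2)                 # [[0, 1], [2, 3], [4, 5]] -> element order kept

print(t.flatten().tolist())      # [0, 3, 1, 4, 2, 5]
print(v.flatten().tolist())      # [0, 1, 2, 3, 4, 5]
print(x.is_contiguous(), t.is_contiguous())  # True False

# view needs a compatible memory layout, so it fails on the transposed tensor
try:
    t.view(-1)
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)               # True

# contiguous() copies the data into the transposed order, after which view works
flat = t.contiguous().view(-1)
print(flat.tolist())             # [0, 3, 1, 4, 2, 5]
```

This is why the forward pass calls `.contiguous()` after the transpose and before the final `.view(...)`.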


In [12]:
class MultiHeadSelfAttention(nn.Module):
  def __init__(self, d_in, d_out, context_length, num_heads, drop_rate, qkv_bias=False):
    super(MultiHeadSelfAttention, self).__init__()

    self.in_dim = d_in
    self.out_dim = d_out
    self.num_heads = num_heads
    self.head_dim = self.out_dim//self.num_heads
    self.context_length = context_length

    self.W_query = nn.Linear(self.in_dim, self.out_dim, bias=qkv_bias)
    self.W_key = nn.Linear(self.in_dim, self.out_dim, bias=qkv_bias)
    self.W_value = nn.Linear(self.in_dim, self.out_dim, bias=qkv_bias)
    # combines the outputs from all heads
    self.out_proj = nn.Linear(self.out_dim, self.out_dim)

    self.dropout = nn.Dropout(drop_rate)
    self.register_buffer('mask',
                         torch.triu(torch.ones(self.context_length, self.context_length), diagonal=1))

  def forward(self, x):
    num_seq, num_tokens, _ = x.shape

    query = self.W_query(x)
    key = self.W_key(x)
    value = self.W_value(x)

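    # split each projection into heads; key and value must be reshaped from
    # their own projections, not from query
    query = query.view(num_seq, num_tokens, self.num_heads, self.head_dim)
    key = key.view(num_seq, num_tokens, self.num_heads, self.head_dim)
    value = value.view(num_seq, num_tokens, self.num_heads, self.head_dim)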

    attention_scores = query.transpose(1, 2) @ key.transpose(1, 2).transpose(2, 3)   # --> num_seq, num_heads, num_tokens, num_tokens
    attention_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
    attention_weights = torch.softmax(attention_scores/self.head_dim ** 0.5, dim=-1)
    drop_attention_weights = self.dropout(attention_weights)

    context_vector = drop_attention_weights @ value.transpose(1, 2) # --> num_seq, num_heads, num_tokens, head_dim
    context_vector = context_vector.transpose(1, 2) # --> num_seq, num_tokens, num_heads, head_dim --> transpose is not the same as view/reshape
    # --> create same memory mapping as if created from scratch
    context_vector = context_vector.contiguous().view(num_seq, num_tokens, self.out_dim) # --> num_seq, num_tokens, out_dim
    context_vector = self.out_proj(context_vector) # --> num_seq, num_tokens, out_dim
    return context_vector

In [13]:
mhsa = MultiHeadSelfAttention(d_in, d_out, batch.shape[1], 2, 0.0)
mhsa_outputs = mhsa(batch)
print('Multi Head Self Attention Outputs')
print(mhsa_outputs)
print(mhsa_outputs.shape)

Multi Head Self Attention Outputs
tensor([[[-0.1700,  0.2093],
         [-0.1529,  0.2570],
         [-0.1486,  0.2705],
         [-0.1606,  0.2722],
         [-0.1784,  0.2464],
         [-0.1737,  0.2648]],

        [[-0.1700,  0.2093],
         [-0.1529,  0.2570],
         [-0.1486,  0.2705],
         [-0.1606,  0.2722],
         [-0.1784,  0.2464],
         [-0.1737,  0.2648]]], grad_fn=<ViewBackward0>)
torch.Size([2, 6, 2])


`MultiHeadCausalAttentionWrapper` vs `MultiHeadSelfAttention`:

- The wrapper loops over the heads one after another, so it is slower. The combined class computes all heads in parallel, so it is faster.
- The wrapper concatenates the outputs of individual causal attention heads. The combined class uses a batched matrix multiplication.
- The wrapper has a separate `W_query`, `W_key`, and `W_value` matrix for each head, so each head must be built with `d_out//num_heads` as its `d_out` for the concatenated output to have dimension `d_out`. The combined class has one `W_query`, `W_key`, and `W_value` matrix for all heads and splits their outputs into as many vectors as there are heads, so its `d_out` argument is the full output dimension.
- For the `W_query`, `W_key`, and `W_value` matrices, both implementations have the same number of parameters. The combined class additionally has the `out_proj` layer.
- In either implementation, we do not rescale anything manually after dropout. This is not needed, because `nn.Dropout` already scales the surviving entries by $1/(1-p)$ during training.
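The parameter counts can be checked with plain `nn.Linear` layers that mirror the shapes used by the two classes. This is a sketch with GPT-2-like sizes; the `count` helper is introduced here for illustration:

```python
import torch.nn as nn

d_in, d_out, num_heads = 768, 768, 12
head_dim = d_out // num_heads

def count(module):
    # total number of trainable values in a module (illustrative helper)
    return sum(p.numel() for p in module.parameters())

# wrapper style: 3 projections (query, key, value) of size [d_in, head_dim] per head
per_head = nn.ModuleList([nn.Linear(d_in, head_dim, bias=False)
                          for _ in range(3 * num_heads)])
# combined style: 3 projections of size [d_in, d_out] shared by all heads
combined = nn.ModuleList([nn.Linear(d_in, d_out, bias=False) for _ in range(3)])
# extra output projection that only the combined class has (with bias)
out_proj = nn.Linear(d_out, d_out)

print(count(per_head))   # 1769472
print(count(combined))   # 1769472
print(count(out_proj))   # 590592
```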

Note: `torch.bmm()` works with 3D tensors only, whereas `@` works with tensors of any shape as long as the last two dimensions permit matrix multiplication.
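A short sketch of the difference on random tensors. The shapes are illustrative: 2 sequences, 3 heads, 6 tokens, `head_dim = 4`:

```python
import torch

torch.manual_seed(0)

# 3D inputs: torch.bmm and @ agree
a3, b3 = torch.rand(2, 6, 4), torch.rand(2, 4, 6)
same_3d = torch.allclose(torch.bmm(a3, b3), a3 @ b3)

# 4D inputs (num_seq, num_heads, num_tokens, head_dim): @ broadcasts over
# the leading dimensions, torch.bmm raises an error
a4, b4 = torch.rand(2, 3, 6, 4), torch.rand(2, 3, 4, 6)
out4 = a4 @ b4
try:
    torch.bmm(a4, b4)
    bmm_failed = False
except RuntimeError:
    bmm_failed = True

# to use torch.bmm anyway, fold the head dimension into the batch dimension
folded = torch.bmm(a4.reshape(6, 6, 4), b4.reshape(6, 4, 6)).view(2, 3, 6, 6)

print(same_3d, bmm_failed)          # True True
print(out4.shape)                   # torch.Size([2, 3, 6, 6])
print(torch.allclose(out4, folded)) # True
```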

### GPT2 Self Attention

In [14]:
gpt2_context_length = 1024
gpt2_embedding_dim = 768
gpt2_attention_heads = 12

gpt2_mhsa_w1 = MultiHeadCausalAttentionWrapper(gpt2_embedding_dim,
                                               gpt2_embedding_dim//gpt2_attention_heads,
                                               gpt2_context_length,
                                               0.1, gpt2_attention_heads)
gpt2_mhsa = MultiHeadSelfAttention(gpt2_embedding_dim,
                                   gpt2_embedding_dim,
                                   gpt2_context_length,
                                   gpt2_attention_heads, 0.1)

# new batch with the GPT-2 embedding size; the previous batch of shape
# [2, 6, 3] does not match d_in = 768 and caused an error in this step
batch = torch.rand(size=(10, 512, 768))

In [15]:
%%time
mhsa_w1_outputs = gpt2_mhsa_w1(batch)

CPU times: user 911 ms, sys: 998 ms, total: 1.91 s
Wall time: 2.28 s


In [16]:
%%time
mhsa_outputs = gpt2_mhsa(batch)

CPU times: user 773 ms, sys: 1.07 s, total: 1.84 s
Wall time: 1.83 s
