In [1]:
import torch
torch.__version__

'2.1.2+cu118'

## EXERCISE 3.1 

EXERCISE 3.1 COMPARING SELFATTENTION_V1 AND SELFATTENTION_V2

- Note that nn.Linear in SelfAttention_v2 uses a different weight initialization
scheme than the nn.Parameter(torch.rand(d_in, d_out)) used in SelfAttention_v1,
which causes both mechanisms to produce different results. To check that both
implementations, SelfAttention_v1 and SelfAttention_v2, are otherwise similar,
we can transfer the weight matrices from a SelfAttention_v2 object to a
SelfAttention_v1, such that both objects then produce the same results.
Your task is to correctly assign the weights from an instance of SelfAttention_v2 to
an instance of SelfAttention_v1. To do this, you need to understand the
relationship between the weights in both versions. (Hint: nn.Linear stores the
weight matrix in a transposed form.) After the assignment, you should observe that
both instances produce the same outputs.

In [2]:
# Toy 3-dimensional embeddings, one row per token of the example sentence.
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # x^1: "Your"
    [0.55, 0.87, 0.66],  # x^2: "journey"
    [0.57, 0.85, 0.64],  # x^3: "starts"
    [0.22, 0.58, 0.33],  # x^4: "with"
    [0.77, 0.25, 0.10],  # x^5: "one"
    [0.05, 0.80, 0.55],  # x^6: "step"
])

# Show the shape first, then the values, each on its own line.
print(inputs.shape, inputs, sep="\n")

torch.Size([6, 3])
tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])


In [3]:
# Sizes passed to the attention modules below: input embedding size and
# the size of the query/key/value projections.
d_in = 3
d_out = 2

In [4]:
import torch.nn as nn 

class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention using raw weight matrices.

    The query, key and value projections are ``nn.Parameter`` matrices of
    shape ``(d_in, d_out)`` initialized with ``torch.rand`` and applied as
    ``x @ W``.
    """

    def __init__(self, d_in, d_out):
        """Create the three ``(d_in, d_out)`` projection matrices."""
        super().__init__()
        self.d_out = d_out
        # Keep the query -> key -> value creation order: under a fixed seed it
        # decides which random draws each matrix receives.
        shape = (d_in, d_out)
        self.W_query = nn.Parameter(torch.rand(*shape))
        self.W_key = nn.Parameter(torch.rand(*shape))
        self.W_value = nn.Parameter(torch.rand(*shape))

    def forward(self, x):
        """Compute one context vector per row of ``x``.

        Args:
            x: tensor whose last dimension is ``d_in``; expected to be 2-D
                (one row per token), since ``keys.T`` is used as a matrix
                transpose.

        Returns:
            Tensor holding a ``d_out``-dimensional context vector per row.
        """
        queries, keys, values = (
            x @ weight for weight in (self.W_query, self.W_key, self.W_value)
        )
        # Scale by sqrt of the key dimension before normalizing each row.
        scale = keys.shape[-1] ** 0.5
        attn_weights = torch.softmax((queries @ keys.T) / scale, dim=-1)
        return attn_weights @ values

In [5]:
# Seed the RNG so the torch.rand-initialized weights of sa_v1 are reproducible.
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)

In [6]:
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    The query, key and value projections are bias-free ``nn.Linear`` layers
    mapping ``d_in`` features to ``d_out`` features.
    """

    def __init__(self, d_in, d_out):
        """Create the three bias-free projection layers."""
        super().__init__()
        self.d_out = d_out

        def make_projection():
            return nn.Linear(d_in, d_out, bias=False)

        # Keep the query -> key -> value creation order so seeded
        # initialization assigns the same weights to each layer.
        self.W_query = make_projection()
        self.W_key = make_projection()
        self.W_value = make_projection()

    def forward(self, x):
        """Compute one context vector per row of ``x``.

        Args:
            x: tensor whose last dimension is ``d_in``; expected to be 2-D
                (one row per token), since ``keys.T`` is used as a matrix
                transpose.

        Returns:
            Tensor holding a ``d_out``-dimensional context vector per row.
        """
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Scale by sqrt of the key dimension before normalizing each row.
        scaled_scores = (queries @ keys.T) / keys.shape[-1] ** 0.5
        attn_weights = torch.softmax(scaled_scores, dim=-1)
        return attn_weights @ values

In [7]:
# Seed the RNG so the nn.Linear-initialized weights of sa_v2 are reproducible.
torch.manual_seed(123)
sa_v2 = SelfAttention_v2(d_in, d_out)

In [8]:
# Context vectors from sa_v1 using its own torch.rand-initialized weights.
sa_v1(inputs)

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)

In [9]:
# Context vectors from sa_v2; they differ from sa_v1's because nn.Linear
# initializes its weights differently from torch.rand.
sa_v2(inputs)

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)

- Transferring the weights from sa_v2 into sa_v1 (nn.Linear stores its weight matrix transposed, so each one is transposed on the way in)

In [10]:
# Copy sa_v2's projection weights into sa_v1.
# nn.Linear stores its weight as (d_out, d_in) and computes x @ weight.T, while
# SelfAttention_v1 computes x @ W with W of shape (d_in, d_out) -- hence the transpose.
# copy_ writes the values into sa_v1's existing parameters, so the two modules keep
# separate storage. Wrapping weight.T in a new nn.Parameter would instead alias
# sa_v2's memory, and a later in-place update of one module would change the other.
with torch.no_grad():
    sa_v1.W_query.copy_(sa_v2.W_query.weight.T)
    sa_v1.W_key.copy_(sa_v2.W_key.weight.T)
    sa_v1.W_value.copy_(sa_v2.W_value.weight.T)

In [11]:
# Run both modules on the same inputs; with matching weights they should agree.
output_v1, output_v2 = sa_v1(inputs), sa_v2(inputs)
print(output_v1, output_v2, sep="\n")

# Numerical confirmation that the two outputs match.
print(torch.allclose(output_v1, output_v2))

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)
tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)
True


## EXERCISE 3.2
**RETURNING 2-DIMENSIONAL EMBEDDING VECTORS**

- Change the input arguments for the MultiHeadAttentionWrapper(...,
 num_heads=2) call such that the output context vectors are 2-dimensional instead of
 4-dimensional while keeping the setting num_heads=2. Hint: You don't have to modify
 the class implementation; you just have to change one of the other input arguments.

In [12]:
class CausalAttention(nn.Module):
    """Single-head causal self-attention over batched inputs.

    Each token attends only to itself and earlier positions; dropout is
    applied to the attention weights.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        """Build the projections, the dropout layer and the causal mask.

        Args:
            d_in: size of each input token embedding.
            d_out: size of the query/key/value projections and of the output.
            context_length: side length of the square causal mask buffer.
            dropout: drop probability handed to ``nn.Dropout``.
            qkv_bias: whether the three projection layers carry a bias term.
        """
        super().__init__()
        self.d_out = d_out

        def make_projection():
            return nn.Linear(d_in, d_out, bias=qkv_bias)

        # Creation order (query, key, value) fixes the seeded initialization.
        self.W_query = make_projection()
        self.W_key = make_projection()
        self.W_value = make_projection()
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark the "future" positions.
        # Registered as a buffer so it travels with the module (device moves,
        # state_dict) without being a trainable parameter.
        future_positions = torch.triu(
            torch.ones(context_length, context_length), diagonal=1
        )
        self.register_buffer('mask', future_positions)

    def forward(self, x):
        """Return causally masked context vectors.

        Args:
            x: tensor of shape ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.
        """
        _, num_tokens, _ = x.shape
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Swap the token and feature axes of the keys; the batch axis stays first.
        attn_scores = queries @ keys.transpose(1, 2)

        # Slice the mask to the current sequence length; it broadcasts over the
        # batch. Masked scores become -inf so softmax gives them zero weight.
        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(causal_mask, -torch.inf)

        scale = keys.shape[-1] ** 0.5
        attn_weights = self.dropout(torch.softmax(attn_scores / scale, dim=-1))
        return attn_weights @ values

In [13]:
# A wrapper class to implement multi-head attention
class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention made of independent ``CausalAttention`` heads.

    Every head receives the same input; their outputs are concatenated along
    the last dimension, so the output size is ``num_heads * d_out``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        """Create ``num_heads`` heads that all share the same configuration."""
        super().__init__()
        heads = []
        for _ in range(num_heads):
            heads.append(
                CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
            )
        # nn.ModuleList registers each head as a submodule.
        self.heads = nn.ModuleList(heads)

    def forward(self, x):
        """Run every head on ``x`` and join the results on the feature axis."""
        head_outputs = [head(x) for head in self.heads]
        return torch.cat(head_outputs, dim=-1)

In [14]:
# Duplicate the toy sequence to get a batch of two identical inputs;
# torch.stack joins them along a new leading (batch) dimension.
batch = torch.stack([inputs, inputs])
print(batch.shape)  # 2 inputs, 6 tokens each, embedding dimension 3

torch.Size([2, 6, 3])


In [15]:
# Fix the RNG so the randomly initialized head weights are reproducible.
torch.manual_seed(123)

# The causal mask has to span every token position in the batch.
context_length = batch.shape[1]
d_in = 3
d_out = 2  # per-head output size
mha = MultiHeadAttentionWrapper(
    d_in,
    d_out,
    context_length,
    dropout=0.0,
    num_heads=2,
)
context_vecs = mha(batch)

print(context_vecs)
print(f'context_vecs.shape: {context_vecs.shape}')

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


### Exercise explanation

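If we want the output to consist of 2-dimensional embedding vectors, we can set `d_out` to 1: the wrapper concatenates the outputs of its `num_heads=2` heads, so the final dimension is `num_heads * d_out = 2 * 1 = 2`.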

In [16]:
# Fix the RNG so the randomly initialized head weights are reproducible.
torch.manual_seed(123)

# The causal mask has to span every token position in the batch.
context_length = batch.shape[1]
d_in = 3
d_out = 1  # one dimension per head, so two heads concatenate to 2 dimensions
mha = MultiHeadAttentionWrapper(
    d_in,
    d_out,
    context_length,
    dropout=0.0,
    num_heads=2,
)
context_vecs = mha(batch)

print(context_vecs)
print(f'context_vecs.shape: {context_vecs.shape}')

tensor([[[-0.5740,  0.2216],
         [-0.7320,  0.0155],
         [-0.7774, -0.0546],
         [-0.6979, -0.0817],
         [-0.6538, -0.0957],
         [-0.6424, -0.1065]],

        [[-0.5740,  0.2216],
         [-0.7320,  0.0155],
         [-0.7774, -0.0546],
         [-0.6979, -0.0817],
         [-0.6538, -0.0957],
         [-0.6424, -0.1065]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


## EXERCISE 3.3

 **EXERCISE 3.3 INITIALIZING GPT-2 SIZE ATTENTION MODULES**
 
 - Using the MultiHeadAttention class, initialize a multi-head attention module that
 has the same number of attention heads as the smallest GPT-2 model (12 attention
 heads). Also ensure that you use the respective input and output embedding sizes
 similar to GPT-2 (768 dimensions). Note that the smallest GPT-2 model supports a
 context length of 1024 tokens.

In [17]:
# An efficient multi-head attention class

class MultiHeadAttention(nn.Module):
    """Causal multi-head attention computed with one set of shared projections.

    The ``d_out``-wide query/key/value projections are split into
    ``num_heads`` slices of ``head_dim`` features, attention runs on all heads
    at once, and the merged result passes through an output projection.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        """Build the projections, the dropout layer and the causal mask.

        Args:
            d_in: size of each input token embedding.
            d_out: total output size across all heads; must be divisible by
                ``num_heads``.
            context_length: side length of the square causal mask buffer.
            dropout: drop probability handed to ``nn.Dropout``.
            num_heads: number of attention heads.
            qkv_bias: whether the query/key/value layers carry a bias term.
        """
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads                                       #A

        def make_projection():
            return nn.Linear(d_in, d_out, bias=qkv_bias)

        # Creation order (query, key, value, out_proj) fixes seeded initialization.
        self.W_query = make_projection()
        self.W_key = make_projection()
        self.W_value = make_projection()
        self.out_proj = nn.Linear(d_out, d_out)                                  #B
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark the "future" positions.
        future_positions = torch.triu(
            torch.ones(context_length, context_length), diagonal=1
        )
        self.register_buffer('mask', future_positions)

    def forward(self, x):
        """Return causally masked multi-head context vectors.

        Args:
            x: tensor of shape ``(b, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(b, num_tokens, d_out)``.
        """
        b, num_tokens, _ = x.shape

        def project_and_split(layer):
            projected = layer(x)                                                     #C
            per_head = projected.view(b, num_tokens, self.num_heads, self.head_dim)  #D
            return per_head.transpose(1, 2)                                          #E

        queries = project_and_split(self.W_query)
        keys = project_and_split(self.W_key)
        values = project_and_split(self.W_value)

        attn_scores = torch.matmul(queries, keys.transpose(2, 3))                #F
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]                   #G
        attn_scores.masked_fill_(mask_bool, -torch.inf)                          #H

        # After the split, the last dimension of the keys is head_dim.
        scale = self.head_dim ** 0.5
        attn_weights = self.dropout(torch.softmax(attn_scores / scale, dim=-1))

        context_vec = torch.matmul(attn_weights, values).transpose(1, 2)         #I
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)   #J
        return self.out_proj(context_vec)                                        #K

 #A Reduce the projection dim to match desired output dim
 #B Use a Linear layer to combine head outputs
 #C Tensor shape: (b, num_tokens, d_out)
 #D We implicitly split the matrix by adding a `num_heads` dimension. Then we unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
 #E Transpose from shape (b, num_tokens, num_heads, head_dim) to (b, num_heads, num_tokens, head_dim)
 #F Compute dot product for each head
 #G Mask truncated to the number of tokens
 #H Use the mask to fill attention scores
 #I Tensor shape: (b, num_tokens, n_heads, head_dim)
 #J Combine heads, where self.d_out = self.num_heads * self.head_dim
 #K Add an optional linear projection

In [33]:
# Fix the RNG so the module's initial weights are reproducible.
torch.manual_seed(123)

# GPT-2 (smallest model) attention settings
context_length = 1024
embedding_size = 768
num_heads = 12

d_in = embedding_size
d_out = embedding_size  # because in GPT-2 d_out = d_in

# Pass the num_heads variable defined above rather than repeating the literal 12,
# so the configuration has a single source of truth.
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=num_heads)
mha


MultiHeadAttention(
  (W_query): Linear(in_features=768, out_features=768, bias=False)
  (W_key): Linear(in_features=768, out_features=768, bias=False)
  (W_value): Linear(in_features=768, out_features=768, bias=False)
  (out_proj): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
)

Optionally, the number of parameters is as follows:

In [34]:
def count_parameters(model):
    """Return the total number of elements in the trainable parameters of ``model``.

    Parameters whose ``requires_grad`` flag is False are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [35]:
# ~2.36M trainable parameters in this single MultiHeadAttention module:
# four 768x768 weight matrices plus the 768 biases of out_proj.
count_parameters(mha)

2360064

The GPT-2 model has 117M parameters in total, but as we can see, most of its parameters are not in the multi-head attention module itself.

In [36]:
# Random batch of 4 sequences at the full context length, used to check the
# module's input and output shapes.
x = torch.rand(4, context_length, embedding_size)
x.shape

torch.Size([4, 1024, 768])

In [37]:
# Forward pass; the output has the same shape as the input because d_out equals d_in here.
context_vecs = mha(x)
print(context_vecs.shape)

torch.Size([4, 1024, 768])


In [38]:
# Spot check: first feature of the first token in the first input sequence.
x[0, 0, 0]

tensor(0.3475)

In [39]:
# Same position in the output; the value differs from the input because it has
# passed through attention and the output projection.
context_vecs[0, 0, 0]

tensor(0.2410, grad_fn=<SelectBackward0>)