In [1]:
import torch
# Show the installed PyTorch version; the outputs recorded in this notebook were produced with it.
torch.__version__

'2.0.1+cu118'

## EXERCISE 3.1 

Exercise 3.1: Comparing `SelfAttention_v1` and `SelfAttention_v2`

- Note that `nn.Linear` in `SelfAttention_v2` uses a different weight initialization
scheme than the `nn.Parameter(torch.rand(d_in, d_out))` used in `SelfAttention_v1`,
which causes the two mechanisms to produce different results. To check that both
implementations, `SelfAttention_v1` and `SelfAttention_v2`, are otherwise equivalent,
we can transfer the weight matrices from a `SelfAttention_v2` object to a
`SelfAttention_v1` object, so that both objects then produce the same results.
Your task is to correctly assign the weights from an instance of SelfAttention_v2 to
an instance of SelfAttention_v1. To do this, you need to understand the
relationship between the weights in both versions. (Hint: nn.Linear stores the
weight matrix in a transposed form.) After the assignment, you should observe that
both instances produce the same outputs.

In [2]:
# One 3-dimensional embedding per token of the sentence "Your journey starts with one step".
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
])

# Shape first, then the values, each on its own line.
print(inputs.shape, inputs, sep="\n")

torch.Size([6, 3])
tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])


In [4]:
# Each input embedding has 3 features; queries, keys and values are projected to 2.
d_in = 3
d_out = 2

In [5]:
import torch.nn as nn 

class SelfAttention_v1(nn.Module):
    """Scaled dot-product self-attention using raw ``nn.Parameter`` weight matrices.

    Each projection matrix has shape ``(d_in, d_out)`` and is initialized with
    ``torch.rand``. Inputs are projected as ``x @ W``.

    Args:
        d_in: Size of each input token embedding (last dimension of ``x``).
        d_out: Size of the query, key and value vectors.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)``, optionally with leading
                batch dimensions, e.g. ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        # Swap only the last two dims. `.T` reverses *all* dims, which breaks
        # (and is deprecated) for inputs with leading batch dimensions.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before normalizing over the key axis.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [6]:
torch.manual_seed(123)  # fixed seed so the torch.rand weight initialization is reproducible
sa_v1 = SelfAttention_v1(d_in=d_in, d_out=d_out)

In [7]:
class SelfAttention_v2(nn.Module):
    """Scaled dot-product self-attention using bias-free ``nn.Linear`` projections.

    ``nn.Linear`` stores its weight with shape ``(d_out, d_in)`` and computes
    ``x @ weight.T``. Its default initialization differs from the ``torch.rand``
    used by ``SelfAttention_v1``.

    Args:
        d_in: Size of each input token embedding (last dimension of ``x``).
        d_out: Size of the query, key and value vectors.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)``, optionally with leading
                batch dimensions, e.g. ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Swap only the last two dims. `.T` reverses *all* dims, which breaks
        # (and is deprecated) for inputs with leading batch dimensions.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before normalizing over the key axis.
        attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

In [8]:
torch.manual_seed(123)  # same seed as for sa_v1, for reproducible nn.Linear initialization
sa_v2 = SelfAttention_v2(d_in=d_in, d_out=d_out)

In [9]:
# Context vectors from v1 with its own torch.rand-initialized weights (before the weight transfer).
sa_v1(inputs)

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)

In [10]:
# Context vectors from v2; nn.Linear initializes its weights differently, so these differ from sa_v1's.
sa_v2(inputs)

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)

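- Transfer the weights of `sa_v2` into `sa_v1`. Each `nn.Linear` weight is stored with shape `(d_out, d_in)`, so it must be transposed to match the `(d_in, d_out)` matrices used by `SelfAttention_v1`.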

In [11]:
# nn.Linear keeps its weight as (d_out, d_in) and computes x @ weight.T, while
# SelfAttention_v1 computes x @ W with W of shape (d_in, d_out), so transpose.
# Copy the values in place instead of wrapping `weight.T` in a new nn.Parameter:
# wrapping would make sa_v1's parameters views that alias sa_v2's storage (updating
# one model would silently change the other), and would replace sa_v1's Parameter
# objects, invalidating any optimizer that already references them.
with torch.no_grad():
    sa_v1.W_query.copy_(sa_v2.W_query.weight.T)
    sa_v1.W_key.copy_(sa_v2.W_key.weight.T)
    sa_v1.W_value.copy_(sa_v2.W_value.weight.T)

In [14]:
# After the weight transfer both modules compute the same function, so their outputs should match.
output_v1 = sa_v1(inputs)
output_v2 = sa_v2(inputs)
print(output_v1)
print(output_v2)

outputs_match = torch.allclose(output_v1, output_v2)
print(outputs_match)
# Fail loudly on Restart & Run All instead of relying on a reader noticing the printed flag.
assert outputs_match, "SelfAttention_v1 and SelfAttention_v2 outputs differ after the weight transfer"

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)
tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)
True
