In [1]:
import torch
import torch.nn as nn

In [2]:
class SelfAttentionV1(nn.Module):
    """Single-head scaled dot-product self-attention with raw weight matrices.

    The query, key and value projections are stored as ``nn.Parameter``
    matrices of shape ``(d_in, d_out)`` initialised with ``torch.randn``.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the query/key/value projections
            and of the returned context vectors.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.wq = nn.Parameter(torch.randn(d_in, d_out))
        self.wk = nn.Parameter(torch.randn(d_in, d_out))
        self.wv = nn.Parameter(torch.randn(d_in, d_out))

    def forward(self, x):
        """Compute the context vectors for ``x``.

        Args:
            x: Tensor of shape ``(num_rows, d_in)`` or, with leading batch
                dimensions, ``(..., num_rows, d_in)``.

        Returns:
            Tensor of shape ``(..., num_rows, d_out)``.
        """
        query = x @ self.wq
        keys = x @ self.wk
        values = x @ self.wv

        # Swap only the last two dimensions so batched (>2-D) inputs work;
        # for a 2-D input this is exactly ``keys.T``.
        attention_score = query @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before normalising each row with softmax.
        attention_weights = torch.softmax(
            attention_score / (keys.shape[-1] ** 0.5), dim=-1
        )
        context_vectors = attention_weights @ values
        return context_vectors

In [3]:
# Model dimensions: input width 3, output width 2.
d_in, d_out = 3, 2

# Seed first so that both the sample inputs and the randomly initialised
# weights are reproducible; the inputs are drawn before the model is built.
torch.manual_seed(123)
inputs = torch.rand(6, d_in)
sa_v1 = SelfAttentionV1(d_in, d_out)

In [4]:
# Run the V1 module on the sample inputs and print the context vectors.
# Per the original note: comparing this output with the one in the companion
# notebook self_attention_with_trainable_weights.ipynb should show an exact match.
print(sa_v1(inputs))

tensor([[-1.6836,  0.3672],
        [-1.7127,  0.3734],
        [-1.6905,  0.3618],
        [-1.6909,  0.3751],
        [-1.7109,  0.3617],
        [-1.7347,  0.3683]], grad_fn=<MmBackward0>)


In [5]:
class SelfAttentionV2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the query/key/value projections
            and of the returned context matrix.
        qkv_bias: Whether the three ``nn.Linear`` projections get a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.wq = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.wk = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.wv = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute the context matrix for ``x``.

        Args:
            x: Tensor of shape ``(num_rows, d_in)`` or, with leading batch
                dimensions, ``(..., num_rows, d_in)``.

        Returns:
            Tensor of shape ``(..., num_rows, d_out)``.
        """
        query = self.wq(x)
        keys = self.wk(x)
        values = self.wv(x)

        # Swap only the last two dimensions so batched (>2-D) inputs work;
        # for a 2-D input this is exactly ``keys.T``.
        attention_score = query @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before normalising each row with softmax.
        attention_weights = torch.softmax(
            attention_score / (keys.shape[-1] ** 0.5), dim=-1
        )
        context_matrix = attention_weights @ values
        return context_matrix

In [6]:
# Model dimensions: input width 3, output width 2.
d_in, d_out = 3, 2

# Re-seed and redraw the inputs before building the model, so the random
# state consumed ahead of the weight initialisation stays the same on re-runs.
torch.manual_seed(123)
inputs = torch.rand(6, d_in)
sa_v2 = SelfAttentionV2(d_in, d_out)

In [7]:
# Same seed and inputs as for sa_v1, yet the values below differ from the V1
# output: the V2 weights come from nn.Linear rather than torch.randn.
print(sa_v2(inputs))

tensor([[0.3801, 0.1600],
        [0.3776, 0.1624],
        [0.3817, 0.1591],
        [0.3783, 0.1620],
        [0.3802, 0.1597],
        [0.3773, 0.1621]], grad_fn=<MmBackward0>)


**The following is exercise 3.1 from the book.**

In [8]:
# Let's extract the linear-layer weights for the query, key and value projections.
list(sa_v2.named_parameters())

[('wq.weight',
  Parameter containing:
  tensor([[-0.1362,  0.1853,  0.4083],
          [ 0.1076,  0.1579,  0.5573]], requires_grad=True)),
 ('wk.weight',
  Parameter containing:
  tensor([[-0.2604,  0.1829, -0.2569],
          [ 0.4126,  0.4611, -0.5323]], requires_grad=True)),
 ('wv.weight',
  Parameter containing:
  tensor([[ 0.4929,  0.2757,  0.2516],
          [ 0.2377,  0.4800, -0.0762]], requires_grad=True))]

In [9]:
# Take the first (name, parameter) pair from the list: the query weight.
list(sa_v2.named_parameters())[0]

('wq.weight',
 Parameter containing:
 tensor([[-0.1362,  0.1853,  0.4083],
         [ 0.1076,  0.1579,  0.5573]], requires_grad=True))

In [10]:
# Look a parameter up by name; the trailing [0] keeps only its first row.
sa_v2.get_parameter("wq.weight")[0]

tensor([-0.1362,  0.1853,  0.4083], grad_fn=<SelectBackward0>)

In [11]:
# First attempt: collect the query/key/value weights via get_parameter.
# NOTE: the trailing [0] keeps only the first row of each weight matrix,
# so the list ends up holding rows rather than the full matrices.
weights = ["wq", "wk", "wv"]
weight_list = []
for weight in weights:
    first_row = sa_v2.get_parameter(f"{weight}.weight")[0]
    weight_list.append(first_row)

In [12]:
# Inspect the collected tensors: each entry is a single length-3 row.
weight_list

[tensor([-0.1362,  0.1853,  0.4083], grad_fn=<SelectBackward0>),
 tensor([-0.2604,  0.1829, -0.2569], grad_fn=<SelectBackward0>),
 tensor([0.4929, 0.2757, 0.2516], grad_fn=<SelectBackward0>)]

**I found that, rather than using `get_parameter`, I can simply use `state_dict` to access the weights and assign them, so I try again below. Using `get_parameter` complicates things because the tensors it yields stay attached to the autograd graph (note the `grad_fn` on the results above).**

In [13]:
# Second attempt: read the full projection matrices from the state dict,
# whose tensors carry no grad_fn.
weights = ["wq", "wk", "wv"]
v2_state = sa_v2.state_dict()

weight_list = []
for weight in weights:
    weight_list.append(v2_state[f"{weight}.weight"])

print("weight list: ", weight_list)

weight list:  [tensor([[-0.1362,  0.1853,  0.4083],
        [ 0.1076,  0.1579,  0.5573]]), tensor([[-0.2604,  0.1829, -0.2569],
        [ 0.4126,  0.4611, -0.5323]]), tensor([[ 0.4929,  0.2757,  0.2516],
        [ 0.2377,  0.4800, -0.0762]])]


In [17]:
# Show the transposed weights that will be loaded into sa_v1: nn.Linear stores
# its weight as (d_out, d_in), while SelfAttentionV1 multiplies by (d_in, d_out).
for param, weight in zip(weights, weight_list):
    print(param, weight.T)

# load_state_dict copies the values into sa_v1's parameters without tracking
# gradients and, unlike a bare in-place copy_ (which broadcasts), raises if a
# key is missing or a shape does not match. The trailing ';' keeps the
# notebook from echoing the call's return value.
sa_v1.load_state_dict(
    {param: weight.T for param, weight in zip(weights, weight_list)}
);

wq tensor([[-0.1362,  0.1076],
        [ 0.1853,  0.1579],
        [ 0.4083,  0.5573]])
wk tensor([[-0.2604,  0.4126],
        [ 0.1829,  0.4611],
        [-0.2569, -0.5323]])
wv tensor([[ 0.4929,  0.2377],
        [ 0.2757,  0.4800],
        [ 0.2516, -0.0762]])


In [18]:
# Inspect sa_v1's state after the copy: each entry should now equal the
# transposed sa_v2 weight printed above.
sa_v1.state_dict()

OrderedDict([('wq',
              tensor([[-0.1362,  0.1076],
                      [ 0.1853,  0.1579],
                      [ 0.4083,  0.5573]])),
             ('wk',
              tensor([[-0.2604,  0.4126],
                      [ 0.1829,  0.4611],
                      [-0.2569, -0.5323]])),
             ('wv',
              tensor([[ 0.4929,  0.2377],
                      [ 0.2757,  0.4800],
                      [ 0.2516, -0.0762]]))])

In [19]:
# For comparison: sa_v2's own weights, stored untransposed as (d_out, d_in) = (2, 3).
sa_v2.state_dict()

OrderedDict([('wq.weight',
              tensor([[-0.1362,  0.1853,  0.4083],
                      [ 0.1076,  0.1579,  0.5573]])),
             ('wk.weight',
              tensor([[-0.2604,  0.1829, -0.2569],
                      [ 0.4126,  0.4611, -0.5323]])),
             ('wv.weight',
              tensor([[ 0.4929,  0.2757,  0.2516],
                      [ 0.2377,  0.4800, -0.0762]]))])

In [20]:
# With sa_v2's weights copied in, sa_v1 should reproduce sa_v2's output.
sa_v1(inputs)

tensor([[0.3801, 0.1600],
        [0.3776, 0.1624],
        [0.3817, 0.1591],
        [0.3783, 0.1620],
        [0.3802, 0.1597],
        [0.3773, 0.1621]], grad_fn=<MmBackward0>)

In [21]:
# Reference output from sa_v2 for the same inputs.
sa_v2(inputs)

tensor([[0.3801, 0.1600],
        [0.3776, 0.1624],
        [0.3817, 0.1591],
        [0.3783, 0.1620],
        [0.3802, 0.1597],
        [0.3773, 0.1621]], grad_fn=<MmBackward0>)