In [67]:
import torch

In [68]:
!pip show torch

Name: torch
Version: 2.1.1
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3
Location: C:\Users\Steve\anaconda3\Lib\site-packages
Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions
Required-by: monai-weekly


In [69]:
# Running example: the input sentence whose words are embedded and attended over below.
sentence = "Life is short, eat dessert first"


In [70]:
# Vocabulary: map every distinct word of the sentence (commas stripped) to an
# integer index, assigned in sorted order.
# De-duplicating with set() keeps the indices contiguous (0 .. len(dc) - 1) even
# if a word occurs more than once; enumerating the raw word list would skip an
# index for every repeated word.
dc = {
    word: idx
    for idx, word in enumerate(sorted(set(sentence.replace(',', '').split())))
}
print(dc)

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}


In [71]:
# Encode the sentence as a 1-D tensor of vocabulary indices, one per word in reading order.
token_ids = [dc[word] for word in sentence.replace(',', '').split()]
sentence_int = torch.tensor(token_ids)
print(sentence_int)

tensor([0, 4, 5, 2, 1, 3])


In [72]:
# Now, using the integer-vector representation of the input sentence, we can use an embedding 
# layer to encode the inputs into a real-vector embedding. Here, we will use a 16-dimensional 
# embedding such that each input word is represented by a 16-dimensional vector. Since the sentence 
# consists of 6 words, this will result in a 6×16-dimensional embedding

In [73]:
torch.manual_seed(123)  # fix the RNG state so the randomly initialised embedding weights are reproducible

# One embedding row per vocabulary entry. The table size is derived from the
# vocabulary instead of a hard-coded 6, so the layer stays valid if the sentence changes.
embed = torch.nn.Embedding(len(dc), 16)  # indices 0 .. len(dc) - 1, each mapped to a 16-dim vector
# Look up one row per word; detach() cuts the result off from the autograd graph
# of the embedding weights.
embedded_sentence = embed(sentence_int).detach()

print(embed)
print(embed.weight)  # the full lookup table: one 16-dim row per vocabulary index
print(embed.weight.shape)
print(embedded_sentence)  # the table rows selected by sentence_int, in sentence order
print(embedded_sentence.shape)
print(embedded_sentence[0].shape)

Embedding(6, 16)
Parameter containing:
tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
          0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],
        [-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753, -0.2010,
          0.4965, -1.5723,  0.9666, -1.1481, -1.1589,  0.3255, -0.6315, -2.8400],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,
          0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,
          2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293],
        [ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,
          0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465],
        [ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168, -1.7472,
         -1.6025, -1.0764,  0.9031, -0.7218, -0.5951, -0.7112,  0

In [74]:
# Self-attention utilizes three weight matrices, referred to as Wq, Wk, and Wv,
# which are adjusted as model parameters during training. These matrices serve 
# to project the inputs into query, key, and value components of the sequence, respectively.

In [75]:
# Query sequence: q(i)=Wq*x(i) for i∈[1,T]
# Key sequence: k(i)=Wk*x(i) for i∈[1,T]
# Value sequence: v(i)=Wv*x(i) for i∈[1,T]

In [76]:
# Re-seed so the three projection matrices drawn below are reproducible.
torch.manual_seed(123)

# Input (embedding) size, read off the second axis of the embedded sentence.
d = embedded_sentence.shape[1]
print("d: ", d)

# Output sizes of the query, key and value projections.
d_q = 24
d_k = 24
d_v = 28

# Randomly initialised projection matrices, drawn in the order query, key, value.
W_query = torch.nn.Parameter(torch.rand((d_q, d)))
W_key = torch.nn.Parameter(torch.rand((d_k, d)))
W_value = torch.nn.Parameter(torch.rand((d_v, d)))
print(
    "W_query shape: ", W_query.shape,
    " W_key shape: ", W_key.shape,
    "W_value shape: ", W_value.shape,
)

d:  16
W_query shape:  torch.Size([24, 16])  W_key shape:  torch.Size([24, 16]) W_value shape:  torch.Size([28, 16])


In [77]:
# Now, let’s suppose we are interested in computing the attention-vector for the second input element – the second input element acts as the query here:

In [78]:
# Take the embedding of the second word (row index 1) and project it into
# query, key and value space with the three weight matrices.
x_2 = embedded_sentence[1]
print("x_2 shape", x_2.shape)  # label matches the variable being printed (was "x_w")
query_2 = W_query @ x_2
key_2 = W_key @ x_2
value_2 = W_value @ x_2

print(query_2.shape)
print(key_2.shape)
print(value_2.shape)

x_w shape torch.Size([16])
torch.Size([24])
torch.Size([24])
torch.Size([28])


In [79]:
# Side demo: multiplying a 2-D matrix by a 1-D tensor with matmul yields a 1-D
# result, so the 1-D operand needs no explicit transpose.
torch.manual_seed(123)
A = torch.tensor([1, 2, 3], dtype=torch.float64)
print(A.shape)
# t() returns a 1-D tensor unchanged, so the shape is still (3,). It replaces
# `A.T`, whose use on tensors that are not 2-D is deprecated in PyTorch.
print(A.t().shape)
B = torch.rand(3, 3, dtype=torch.float64)
print(B.matmul(A))

torch.Size([3])
torch.Size([3])
tensor([2.1710, 2.6036, 4.5133], dtype=torch.float64)


In [80]:
# Project every word at once: W @ X^T gives one column per word, and the outer
# transpose turns that back into one row per word.
keys = (W_key @ embedded_sentence.T).T
values = (W_value @ embedded_sentence.T).T

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 24])
values.shape: torch.Size([6, 28])


In [81]:
# Unnormalised attention score: dot product of query_2 with the key at row index 4.
omega_24 = torch.dot(query_2, keys[4])
print(omega_24)

tensor(11.1466, grad_fn=<DotBackward0>)


In [82]:
# Scores of query_2 against every key at once: one dot product per row of `keys`.
omega_2 = query_2 @ keys.T
print(omega_2)

tensor([ 8.5808, -7.6597,  3.2558,  1.0395, 11.1466, -0.4800],
       grad_fn=<SqueezeBackward4>)


In [83]:
# The subsequent step in self-attention is to normalize the unnormalized attention weights, 
# ω, to obtain the normalized attention weights, α, by applying the softmax function. 
# Additionally, 1/√dk is used to scale ω before normalizing it through the softmax function

In [84]:
import torch.nn.functional as F

# Scale the raw scores by 1/sqrt(d_k), then apply softmax along the only axis
# to turn them into normalised attention weights.
scaled_scores_2 = omega_2 / d_k**0.5
attention_weights_2 = F.softmax(scaled_scores_2, dim=0)
print(attention_weights_2)

tensor([0.2912, 0.0106, 0.0982, 0.0625, 0.4917, 0.0458],
       grad_fn=<SoftmaxBackward0>)


In [85]:
# Finally, the last step is to compute the context vector z(2), which is an
# attention-weighted version of our original query input x(2), including all the
# other input elements as its context via the attention weights:

In [86]:
# Context vector for the second word: the value rows combined using its attention weights.
context_vector_2 = torch.matmul(attention_weights_2, values)

print(context_vector_2.shape)
print(context_vector_2)

torch.Size([28])
tensor([-1.5993,  0.0156,  1.2670,  0.0032, -0.6460, -1.1407, -0.4908, -1.4632,
         0.4747,  1.1926,  0.4506, -0.7110,  0.0602,  0.7125, -0.1628, -2.0184,
         0.3838, -2.1188, -0.8136, -1.5694,  0.7934, -0.2911, -1.3640, -0.2366,
        -0.9564, -0.5265,  0.0624,  1.7084], grad_fn=<SqueezeBackward4>)


In [87]:
# Note that this output vector has more dimensions (dv=28) than the original input vector (d=16) 
# since we specified dv>d earlier; however, the embedding size choice is arbitrary.

In [88]:
# Multi-Head Attention (3 heads as example)
h = 3  # number of attention heads; used as the leading axis of the stacked weight tensors below

In [89]:
# To illustrate this in code, suppose we have 3 attention heads, so we now extend the d′×d
# dimensional weight matrices to 3×d′×d:

In [90]:
# One projection matrix per head, stacked along a leading axis of size h.
# Drawn in the order query, key, value.
multihead_W_query = torch.nn.Parameter(torch.rand((h, d_q, d)))
multihead_W_key = torch.nn.Parameter(torch.rand((h, d_k, d)))
multihead_W_value = torch.nn.Parameter(torch.rand((h, d_v, d)))

print("multihead_W_query shape: ", multihead_W_query.shape)
print("multihead_W_key shape: ", multihead_W_key.shape)
print("multihead_W_value shape: ", multihead_W_value.shape)


multihead_W_query shape:  torch.Size([3, 24, 16])
multihead_W_key shape:  torch.Size([3, 24, 16])
multihead_W_value shape:  torch.Size([3, 28, 16])


In [91]:
# Apply every head's query matrix to x_2 in one call: (h, d_q, d) @ (d,) -> (h, d_q).
multihead_query_2 = torch.matmul(multihead_W_query, x_2)
print(multihead_query_2.shape)

torch.Size([3, 24])


In [92]:
# Transpose the embedded sentence and replicate it once per head, giving each
# head's weight matrix its own copy of the inputs for the batched multiply.
# The copy count uses `h` instead of a literal 3 so it always follows the head count.
stacked_inputs = embedded_sentence.T.repeat(h, 1, 1)
print(stacked_inputs.shape)

torch.Size([3, 16, 6])
tensor([[[ 0.3374,  0.5146,  0.2553, -1.3250, -0.0770,  0.8768],
         [-0.1778,  0.9938, -0.5496,  0.1784, -1.0205,  1.6221],
         [-0.3035, -0.2587,  1.0042, -2.1338, -0.1690, -1.4779],
         [-0.5880, -1.0826,  0.8272,  1.0524,  0.9178,  1.1331],
         [ 0.3486, -0.0444, -0.3948, -0.3885,  1.5810, -1.2203],
         [ 0.6603,  1.6236,  0.4892, -0.9343,  1.3010,  1.3139],
         [-0.2196, -2.3229, -0.2168, -0.4991,  1.2753,  1.0533],
         [-0.3792,  1.0878, -1.7472, -1.0867, -0.2010,  0.1388],
         [ 0.7671,  0.6716, -1.6025,  0.8805,  0.4965,  2.2473],
         [-1.1925,  0.6933, -1.0764,  1.5542, -1.5723, -0.8036],
         [ 0.6984, -0.9487,  0.9031,  0.6266,  0.9666, -0.2808],
         [-1.4097, -0.0765, -0.7218, -0.1755, -1.1481,  0.7697],
         [ 0.1794, -0.1526, -0.5951,  0.0983, -1.1589, -0.6596],
         [ 1.8951,  0.1167, -0.7112, -0.0935,  0.3255, -0.7979],
         [ 0.4954,  0.4403,  0.6230,  0.2662, -0.6315,  0.1838],
  

In [93]:
# Batched matrix multiplication: each head's weight matrix is multiplied with
# its own copy of the transposed inputs.
multihead_keys = multihead_W_key.bmm(stacked_inputs)
multihead_values = multihead_W_value.bmm(stacked_inputs)

print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)

multihead_keys.shape: torch.Size([3, 24, 6])
multihead_values.shape: torch.Size([3, 28, 6])


In [94]:
# Swap the last two axes so each head holds one row per word.
# The tensors are rebuilt from the projection inputs rather than permuting the
# existing variables in place: `x = x.permute(0, 2, 1)` is not idempotent, so
# re-running the cell would silently flip the axes back.
multihead_keys = torch.bmm(multihead_W_key, stacked_inputs).permute(0, 2, 1)
multihead_values = torch.bmm(multihead_W_value, stacked_inputs).permute(0, 2, 1)
print("multihead_keys.shape:", multihead_keys.shape)
print("multihead_values.shape:", multihead_values.shape)

multihead_keys.shape: torch.Size([3, 6, 24])
multihead_values.shape: torch.Size([3, 6, 28])


In [None]:
# Then, we follow the same steps as previously: compute the unscaled
# attention weights ω, apply the scaled softmax to obtain the attention
# weights α, and use them to obtain an h×dv (here: 3×dv) dimensional context
# vector z for the input element x(2).