In [43]:
import torch

# x sequence in {1..8}
sentence = torch.tensor([0, 7, 1, 2, 5, 6, 4, 3])
torch.manual_seed(123)
embed = torch.nn.Embedding(10, 16)
embedded_sentence = embed(sentence)
embedded_sentence.shape

torch.Size([8, 16])

In [44]:
omega = torch.zeros(8, 8)
for i, x_i in enumerate(embedded_sentence):
    for j, x_j in enumerate(embedded_sentence):
        omega[i, j] = torch.dot(x_i, x_j)
print(omega)
# embedded_sentence is 8x16, transpose: 16x8
# x_i的形式是行向量，每个行下标代表一个输入，即x_i
omega_mat = torch.matmul(embedded_sentence, embedded_sentence.T)
print(torch.allclose(omega, omega_mat))
import torch.nn.functional as F
# 在omega_i_j中，我们对j分量进行归一化
attention_weights = F.softmax(omega, dim=1)

tensor([[ 9.7601,  1.7326,  4.7543, -1.3587,  0.4752, -1.6717,  1.0227, -0.1286],
        [ 1.7326, 16.0787,  9.0642, -0.3370,  1.1368,  1.1972,  1.6485, -1.2789],
        [ 4.7543,  9.0642, 22.6615, -0.8519,  7.7799,  2.7483, -0.6832,  1.6236],
        [-1.3587, -0.3370, -0.8519, 13.9473, -1.4198, 10.9659, -0.5887,  2.3869],
        [ 0.4752,  1.1368,  7.7799, -1.4198, 13.7511, -6.8568, -2.5114, -3.3468],
        [-1.6717,  1.1972,  2.7483, 10.9659, -6.8568, 24.6738, -3.8294,  4.9581],
        [ 1.0227,  1.6485, -0.6832, -0.5887, -2.5114, -3.8294, 15.8691,  2.0269],
        [-0.1286, -1.2789,  1.6236,  2.3869, -3.3468,  4.9581,  2.0269, 18.7382]],
       grad_fn=<CopySlices>)
True


In [45]:
x_2 = embedded_sentence[1, :]
context_vec_2 = torch.zeros(x_2.shape)
for j in range(8):
    x_j = embedded_sentence[j, :]
    context_vec_2 += attention_weights[1, j] * x_j
print(context_vec_2)

tensor([-9.3975e-01, -4.6856e-01,  1.0311e+00, -2.8192e-01,  4.9373e-01,
        -1.2896e-02, -2.7327e-01, -7.6358e-01,  1.3958e+00, -9.9543e-01,
        -7.1287e-04,  1.2449e+00, -7.8077e-02,  1.2765e+00, -1.4589e+00,
        -2.1601e+00], grad_fn=<AddBackward0>)


In [46]:
context_vectors = torch.matmul(attention_weights, embedded_sentence)
print(context_vectors)

tensor([[ 3.3420e-01, -1.8324e-01, -3.0218e-01, -5.7772e-01,  3.5662e-01,
          6.6452e-01, -2.0998e-01, -3.7798e-01,  7.6537e-01, -1.1946e+00,
          6.9960e-01, -1.4067e+00,  1.7021e-01,  1.8838e+00,  4.8729e-01,
          2.4730e-01],
        [-9.3975e-01, -4.6856e-01,  1.0311e+00, -2.8192e-01,  4.9373e-01,
         -1.2896e-02, -2.7327e-01, -7.6358e-01,  1.3958e+00, -9.9543e-01,
         -7.1287e-04,  1.2449e+00, -7.8077e-02,  1.2765e+00, -1.4589e+00,
         -2.1601e+00],
        [-7.7021e-02, -1.0205e+00, -1.6895e-01,  9.1776e-01,  1.5810e+00,
          1.3010e+00,  1.2753e+00, -2.0095e-01,  4.9647e-01, -1.5723e+00,
          9.6657e-01, -1.1481e+00, -1.1589e+00,  3.2547e-01, -6.3151e-01,
         -2.8400e+00],
        [-1.3679e+00,  1.0614e-01, -2.1317e+00,  1.0480e+00, -3.7127e-01,
         -9.1234e-01, -4.3802e-01, -1.0329e+00,  9.3425e-01,  1.5453e+00,
          5.7218e-01, -1.8049e-01, -6.0454e-03, -8.8691e-02,  2.0559e-01,
         -5.2292e-01],
        [ 2.5444e-01

In [47]:
torch.manual_seed(123)
d = embedded_sentence.shape[1]
U_query = torch.rand(d, d)
U_key = torch.rand(d, d)
U_value = torch.rand(d, d)
x_2 = embedded_sentence[1]
query_2 = U_query.matmul(x_2)
key_2 = U_key.matmul(x_2)
value_2 = U_value.matmul(x_2)
keys = U_key.matmul(embedded_sentence.T).T
values = (U_value @ embedded_sentence.T).T
print(torch.allclose(keys[1], key_2))

True


In [48]:
omega_23 = query_2 @ keys[2]
omega_23

tensor(14.3667, grad_fn=<DotBackward0>)

In [49]:
omega_2 = query_2 @ keys.T
print(omega_2)
queries = (U_query @ embedded_sentence.T).T
omega = (queries @ keys.T)
print(omega)
torch.allclose(omega_2, omega[1])

tensor([-25.1623,   9.3602,  14.3667,  32.1482,  53.8976,  46.6626,  -1.2131,
        -32.9392], grad_fn=<SqueezeBackward4>)
tensor([[ -0.7569,  -3.7951,  -7.9465, -10.0615, -12.1732, -12.8006,   4.1644,
           6.3346],
        [-25.1623,   9.3602,  14.3667,  32.1482,  53.8976,  46.6626,  -1.2131,
         -32.9392],
        [-28.8096,  10.9046,  14.4355,  23.8255,  52.7999,  41.3237,   1.5884,
         -35.1890],
        [-15.5115,  17.5500,  19.8771,  21.5002,  42.0597,  35.2061,  -0.5541,
         -25.9203],
        [-36.3682,  20.2438,  27.1240,  49.8610,  84.9364,  85.7472,   5.8265,
         -69.9103],
        [-34.6901,  38.3814,  42.0269,  48.1298,  92.0512,  74.9869,  -6.6510,
         -65.5576],
        [ -1.1880,   3.7619,  -5.6129,  -6.8691,   6.3126, -13.3452,  -1.3225,
          -6.2390],
        [ 31.8297, -25.2041, -25.3536, -57.8440, -79.4676, -85.3054, -10.5390,
          64.5980]], grad_fn=<MmBackward0>)


True

In [50]:
X = embedded_sentence.T
Q = U_query @ X
K = U_key @ X
omega = Q.T @ K
print(torch.allclose(omega_23, omega[1, 2])) # True
query_2 = U_query @ x_2 # x_2 is a rank-1 tensor, we regard as dx1 vector
omega_2 = query_2 @ keys.T
print(torch.allclose(omega_2, omega[1, :])) # True

True
True


In [51]:
import torch.nn.functional as F
attention_weights_2 = F.softmax(omega_2 / d**5e-1, dim=0)
print(attention_weights_2)
context_vector_2 = attention_weights_2 @ values
print(context_vector_2)
print(torch.allclose(context_vector_2, values.T @ attention_weights_2))

tensor([2.2317e-09, 1.2499e-05, 4.3696e-05, 3.7242e-03, 8.5596e-01, 1.4026e-01,
        8.8897e-07, 3.1935e-10], grad_fn=<SoftmaxBackward0>)
tensor([-1.2226, -3.4387, -4.3928, -5.2125, -1.1249, -3.3041, -1.4316, -3.2765,
        -2.5114, -2.6105, -1.5793, -2.8433, -2.4142, -0.3998, -1.9917, -3.3499],
       grad_fn=<SqueezeBackward4>)
True


In [52]:
attention_weights = F.softmax(omega / d**5e-1, dim=1)
context_vectors = attention_weights @ values
print(
    torch.allclose(context_vectors[1], context_vector_2)
)
print(attention_weights.shape)

True
torch.Size([8, 8])


In [53]:
t1 = torch.tensor([1, 2])
t2 = torch.arange(0, 6)
t2 = t2.reshape((2, 3))
t = t1 @ t2
t

tensor([ 6,  9, 12])

In [57]:
import torch
torch.manual_seed(123)
d = embedded_sentence.shape[1]
one_U_query = torch.rand(d, d)
h = 8
multihead_U_query = torch.rand(h, d, d)
multihead_U_key = torch.rand(h, d, d)
multihead_U_value = torch.rand(h, d, d)

multihead_query_2 = multihead_U_query @ x_2
print(multihead_query_2.shape)
print(x_2.shape)
print(multihead_U_query.shape)

torch.Size([8, 16])
torch.Size([16])
torch.Size([8, 16, 16])


In [59]:
multihead_key_2 = multihead_U_key @ x_2
multihead_value_2 = multihead_U_value @ x_2
multihead_key_2[2]

tensor([-1.9619, -0.7701, -0.7280, -1.6840, -1.0801, -1.6778,  0.6763,  0.6547,
         1.4445, -2.7016, -1.1364, -1.1204, -2.4430, -0.5982, -0.8292, -1.4401],
       grad_fn=<SelectBackward0>)

In [60]:
print(embedded_sentence.T)

tensor([[ 3.3737e-01, -9.4053e-01, -7.7020e-02, -1.3250e+00,  2.5529e-01,
         -2.2150e+00,  5.1463e-01,  8.7684e-01],
        [-1.7778e-01, -4.6806e-01, -1.0205e+00,  1.7843e-01, -5.4963e-01,
         -1.3193e+00,  9.9376e-01,  1.6221e+00],
        [-3.0353e-01,  1.0322e+00, -1.6896e-01, -2.1338e+00,  1.0042e+00,
         -2.0915e+00, -2.5873e-01, -1.4779e+00],
        [-5.8801e-01, -2.8300e-01,  9.1776e-01,  1.0524e+00,  8.2723e-01,
          9.6285e-01, -1.0826e+00,  1.1331e+00],
        [ 3.4861e-01,  4.9275e-01,  1.5810e+00, -3.8848e-01, -3.9481e-01,
         -3.1861e-02, -4.4382e-02, -1.2203e+00],
        [ 6.6034e-01, -1.4078e-02,  1.3010e+00, -9.3435e-01,  4.8923e-01,
         -4.7896e-01,  1.6236e+00,  1.3139e+00],
        [-2.1964e-01, -2.7466e-01,  1.2753e+00, -4.9914e-01, -2.1681e-01,
          7.6681e-01, -2.3229e+00,  1.0533e+00],
        [-3.7917e-01, -7.6409e-01, -2.0095e-01, -1.0867e+00, -1.7472e+00,
          2.7468e-02,  1.0878e+00,  1.3881e-01],
        [ 7.6711

In [61]:
stacked_inputs = embedded_sentence.T.repeat(8, 1, 1)
print(stacked_inputs)

tensor([[[ 0.3374, -0.9405, -0.0770,  ..., -2.2150,  0.5146,  0.8768],
         [-0.1778, -0.4681, -1.0205,  ..., -1.3193,  0.9938,  1.6221],
         [-0.3035,  1.0322, -0.1690,  ..., -2.0915, -0.2587, -1.4779],
         ...,
         [ 1.8951,  1.2774,  0.3255,  ...,  0.0064,  0.1167, -0.7979],
         [ 0.4954, -1.4596, -0.6315,  ..., -0.9896,  0.4403,  0.1838],
         [ 0.2692, -2.1595, -2.8400,  ...,  0.7016, -1.4465,  0.2293]],

        [[ 0.3374, -0.9405, -0.0770,  ..., -2.2150,  0.5146,  0.8768],
         [-0.1778, -0.4681, -1.0205,  ..., -1.3193,  0.9938,  1.6221],
         [-0.3035,  1.0322, -0.1690,  ..., -2.0915, -0.2587, -1.4779],
         ...,
         [ 1.8951,  1.2774,  0.3255,  ...,  0.0064,  0.1167, -0.7979],
         [ 0.4954, -1.4596, -0.6315,  ..., -0.9896,  0.4403,  0.1838],
         [ 0.2692, -2.1595, -2.8400,  ...,  0.7016, -1.4465,  0.2293]],

        [[ 0.3374, -0.9405, -0.0770,  ..., -2.2150,  0.5146,  0.8768],
         [-0.1778, -0.4681, -1.0205,  ..., -1

In [74]:
# x是一个shape为(3, 2)的Tensor
x = torch.tensor([[1, 2], 
                  [3, 4], 
                  [5, 6]])
#得到shape为(2*3, 3*2)，即(6, 6)的Tensor
#将维度从后往前的考察。首先，将x的第二维扩充到3*x.shape[1]维，也就是重复x的第二维3次。
#随后，将x的第一维扩充到2*x.shape[0]维，扩充部分是repeat；
print(x.repeat(2, 3))
#再看一个例子，第一维重复扩充了3背，第二维保持。所以得到shape为(6,2)。
print(x.repeat(3, 1)) # 得到一个rank-2，每个保留了第二维，重复了2次第一维。
#对于扩充到高维的情况，倒数第一维相当于扩充了x的倒数第一维的3倍，采用重复的策略；
#倒数第二维将x的倒数第二维按照重复的策略扩充2倍；
#由于x没有倒数第三维，所以将已经重复扩充的结果，repeat，重复之。
print(x.repeat(2, 2, 3))

tensor([[1, 2, 1, 2, 1, 2],
        [3, 4, 3, 4, 3, 4],
        [5, 6, 5, 6, 5, 6],
        [1, 2, 1, 2, 1, 2],
        [3, 4, 3, 4, 3, 4],
        [5, 6, 5, 6, 5, 6]])
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [1, 2],
        [3, 4],
        [5, 6],
        [1, 2],
        [3, 4],
        [5, 6]])
tensor([[[1, 2, 1, 2, 1, 2],
         [3, 4, 3, 4, 3, 4],
         [5, 6, 5, 6, 5, 6],
         [1, 2, 1, 2, 1, 2],
         [3, 4, 3, 4, 3, 4],
         [5, 6, 5, 6, 5, 6]],

        [[1, 2, 1, 2, 1, 2],
         [3, 4, 3, 4, 3, 4],
         [5, 6, 5, 6, 5, 6],
         [1, 2, 1, 2, 1, 2],
         [3, 4, 3, 4, 3, 4],
         [5, 6, 5, 6, 5, 6]]])
