In [1]:
import torch
import pprint

# 创建嵌入张量

In [2]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# cell also runs on machines without CUDA instead of raising.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy vocabulary: every word is mapped to its position in `wordlist`.
wordlist: list[str] = ["apple", "and", "banana", "eat", "I", "like", "to"]
dictionary: dict[str, int] = {word: index for index, word in enumerate(wordlist)}

# Encode the sentence as a 1-D int32 tensor of vocabulary indices, one per
# whitespace-separated word, created directly on the selected device.
sentence: str = "I like to eat apple and banana"
sentence_index: torch.Tensor = torch.tensor(
    [dictionary[word] for word in sentence.split()],
    dtype=torch.int32,
    device=device,
)

# Randomly initialised lookup table with one 16-dimensional row per vocabulary
# word; indexing it with the sentence yields one embedding row per token.
embedding_layer = torch.nn.Embedding(len(wordlist), 16).to(device)
sentence_embedding = embedding_layer(sentence_index)

pprint.pprint((sentence_index, sentence_embedding))

(tensor([4, 5, 6, 3, 0, 1, 2], device='cuda:0', dtype=torch.int32),
 tensor([[-7.3236e-01,  6.1053e-01, -4.3208e-01,  5.3864e-01, -9.3632e-01,
         -1.9156e-01, -6.3060e-01,  1.1363e-01,  1.8043e+00, -9.4311e-02,
         -1.4965e-01, -5.4332e-01,  1.2353e+00, -1.4759e+00, -1.0286e+00,
          5.7266e-01],
        [-5.0617e-01,  2.0712e+00,  9.6715e-02, -4.5395e-01, -6.9383e-01,
         -1.2759e+00, -2.0697e+00,  7.8007e-01, -2.8403e+00,  1.5434e+00,
          1.4221e+00, -1.5882e+00,  2.8490e-01,  8.4525e-01,  2.7150e-01,
          4.0056e-01],
        [-9.2905e-01, -2.8743e-01,  1.4922e+00,  4.9263e-03, -5.2021e-01,
          9.4936e-01,  1.2950e+00,  1.2434e+00, -9.3441e-02,  6.4919e-01,
         -3.7933e-01, -7.7334e-01,  6.7946e-01, -9.8882e-01,  3.3297e-01,
         -1.0761e+00],
        [-2.3402e-01,  1.7017e+00, -1.1020e+00, -8.8707e-01,  8.3205e-01,
         -2.0504e+00, -2.2760e-01, -3.7249e-01, -9.0152e-01, -9.5659e-01,
         -6.8315e-01,  1.4186e+00, -6.3702e-02, 

In [6]:
# n is the vocabulary size and d_embedding the width of one token embedding.
# Queries and keys share the same width (24) because they are later combined
# with a dot product; values use their own width (28).
n, d_embedding, d_query, d_key, d_value = len(wordlist), sentence_embedding.shape[1], 24, 24, 28

# Create the projection weights directly on the device that holds the
# embeddings. Calling `.cuda()` on an existing Parameter returns a non-leaf copy
# rather than a Parameter, so the result could not be handed to an optimizer and
# its `.grad` would stay None; wrapping an already-placed tensor keeps each
# weight a leaf Parameter.
W_query = torch.nn.Parameter(torch.rand(d_query, d_embedding, device=sentence_embedding.device))
W_key = torch.nn.Parameter(torch.rand(d_key, d_embedding, device=sentence_embedding.device))
W_value = torch.nn.Parameter(torch.rand(d_value, d_embedding, device=sentence_embedding.device))

# Project every token embedding at once: (d_out, d_embedding) @ (d_embedding, tokens),
# then transpose so each row corresponds to one token.
keys = torch.matmul(W_key, sentence_embedding.T).T
values = torch.matmul(W_value, sentence_embedding.T).T
queries = torch.matmul(W_query, sentence_embedding.T).T

pprint.pprint({
    "keys.shape": keys.shape,
    "values.shape": values.shape,
    "queries.shape": queries.shape
})

{'keys.shape': torch.Size([7, 24]),
 'queries.shape': torch.Size([7, 24]),
 'values.shape': torch.Size([7, 28])}


In [15]:
# Worked example for a single token: the one at position 1 of the sentence.
# Project its embedding into query space, then score it against every key.
query_1 = torch.matmul(W_query, sentence_embedding[1])  # shape: (d_query,)
omega_1 = torch.matmul(query_1, keys.T)                 # shape: (number of tokens,)

# Divide the raw scores by sqrt(d_key) and normalise them with softmax so the
# weights over the tokens sum to 1; the context vector is the correspondingly
# weighted sum of the value rows.
attention_1 = torch.softmax(omega_1 / d_key ** 0.5, dim=0)  # shape: (number of tokens,)
context_1 = torch.matmul(attention_1, values)               # shape: (d_value,)

pprint.pprint(
    dict([
        ("query_1.shape", query_1.shape),
        ("omega_1.shape", omega_1.shape),
        ("attention_1.shape", attention_1.shape),
        ("context_1.shape", context_1.shape),
    ]),
    sort_dicts=False,
)

{'query_1.shape': torch.Size([24]),
 'omega_1.shape': torch.Size([7]),
 'attention_1.shape': torch.Size([7]),
 'context_1.shape': torch.Size([28])}
