In [1]:
# Example sentence ("I recently went on a trip to Paris"), split on whitespace
# so that every space-separated word becomes one token.
input_text = "나는 최근 파리 여행을 다녀왔다"
input_text_list = input_text.split(sep=None)

input_text_list

['나는', '최근', '파리', '여행을', '다녀왔다']

In [2]:
# Two-way vocabulary over the tokens: position -> word, and word -> position
# (for a repeated word the last position wins).
idx2str = dict(enumerate(input_text_list))
str2idx = {word: idx for idx, word in idx2str.items()}

print(str2idx)
print(idx2str)

{'나는': 0, '최근': 1, '파리': 2, '여행을': 3, '다녀왔다': 4}
{0: '나는', 1: '최근', 2: '파리', 3: '여행을', 4: '다녀왔다'}


In [3]:
# Map every token of the sentence to its integer id in the vocabulary.
input_ids = [str2idx[token] for token in input_text_list]
input_ids

[0, 1, 2, 3, 4]

In [4]:
import torch
import torch.nn as nn

In [5]:
# Token embedding table: one trainable vector of size `embedding_dims` per vocabulary entry.
embedding_dims = 16
embed_layer = nn.Embedding(num_embeddings=len(str2idx), embedding_dim=embedding_dims)

In [6]:
# Look up the token vectors, then add a leading batch dimension of size 1.
input_embeddings = embed_layer(torch.tensor(input_ids)).unsqueeze(dim=0)
input_embeddings.shape

torch.Size([1, 5, 16])

In [7]:
# Learned absolute position embeddings: a trainable table with one vector per
# position index (0 .. max_position - 1), added element-wise to the token embeddings.
# Layers are created in the same order as before (tokens, then positions).
embedding_dim = 16
max_position = 12  # sequences longer than this would index past the table
embed_layer = nn.Embedding(num_embeddings=len(str2idx), embedding_dim=embedding_dim)
position_embed_layer = nn.Embedding(num_embeddings=max_position, embedding_dim=embedding_dim)

position_ids = torch.arange(len(input_ids), dtype=torch.long)[None, :]
token_embeddings = embed_layer(torch.tensor(input_ids)).unsqueeze(0)
position_encodings = position_embed_layer(position_ids)
input_embeddings = token_embeddings + position_encodings

input_embeddings.shape

torch.Size([1, 5, 16])

In [9]:
head_dim = 16

# Three independent linear projections (created in q, k, v order) that map
# each token embedding to its query, key and value vectors.
weight_q, weight_k, weight_v = (nn.Linear(embedding_dim, head_dim) for _ in range(3))

querys, keys, values = (
    projection(input_embeddings) for projection in (weight_q, weight_k, weight_v)
)

In [11]:
from math import sqrt
import torch.nn.functional as F

def compute_attention(querys, keys, values, is_causal=False):
    """Scaled dot-product attention.

    Scores are ``querys @ keys^T / sqrt(d)``, where ``d`` is the size of the last
    dimension of ``querys``. They are normalised with a softmax over the key axis
    and used to take a weighted sum of ``values``. The operation works on the last
    two dimensions (sequence, features); any leading dimensions are batch
    dimensions.

    Args:
        querys: tensor of shape (..., query_len, d).
        keys: tensor of shape (..., key_len, d).
        values: tensor of shape (..., key_len, d_v).
        is_causal: if True, query position i may only attend to key positions
            0..i (top-left-aligned lower-triangular mask, the same convention as
            ``torch.nn.functional.scaled_dot_product_attention``). Previously this
            flag was accepted but ignored.

    Returns:
        Tensor of shape (..., query_len, d_v).
    """
    dim_k = querys.size(-1)
    scores = querys @ keys.transpose(-2, -1) / sqrt(dim_k)
    if is_causal:
        query_length = querys.size(-2)
        key_length = keys.size(-2)
        # True where attention is allowed. Every row keeps at least key 0,
        # so no softmax row is all -inf (no NaNs).
        causal_mask = torch.ones(
            query_length, key_length, dtype=torch.bool, device=scores.device
        ).tril(diagonal=0)
        scores = scores.masked_fill(~causal_mask, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return weights @ values

In [12]:
# Self-attention keeps the (batch, sequence, feature) shape of its input.
after_attention_embeddings = compute_attention(querys, keys, values)

print("원본 입력 형태 : ", input_embeddings.shape)
print("어텐션 적용 후 입력 형태 : ", after_attention_embeddings.shape)

원본 입력 형태 :  torch.Size([1, 5, 16])
어텐션 적용 후 입력 형태 :  torch.Size([1, 5, 16])


In [None]:
class AttentionHead(nn.Module):
    """A single attention head.

    Each of the three inputs goes through its own linear projection, and the
    results are passed to ``compute_attention``.

    Args:
        token_embed_dim: number of input features of each projection.
        head_dim: number of output features of each projection.
        is_causal: stored and forwarded to ``compute_attention`` as its
            ``is_causal`` argument.
    """

    def __init__(self, token_embed_dim, head_dim, is_causal=False):
        super().__init__()
        # Registration order q, k, v is kept: it fixes both the random
        # initialisation sequence and the state_dict key order.
        self.weight_q = nn.Linear(token_embed_dim, head_dim)
        self.weight_k = nn.Linear(token_embed_dim, head_dim)
        self.weight_v = nn.Linear(token_embed_dim, head_dim)
        self.is_causal = is_causal

    def forward(self, querys, keys, values):
        """Project querys/keys/values and return the attention output."""
        projected_q = self.weight_q(querys)
        projected_k = self.weight_k(keys)
        projected_v = self.weight_v(values)
        return compute_attention(projected_q, projected_k, projected_v, self.is_causal)
        
