# Attention

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### 어텐션 가중치 계산

In [15]:
def attention(query, key, value, *, mask=None, scale=None):
    """Compute dot-product attention and return the context vectors.

    The shapes of the intermediate tensors are printed so the tutorial can
    show how they change at each step.

    Args:
        query: Query tensor; its last dimension must match ``key``'s last
            dimension.
        key: Key tensor; its last two dimensions are transposed for the
            Query-Key dot product.
        value: Value tensor; its second-to-last dimension must match the
            number of keys.
        mask: Optional tensor broadcastable to the score tensor. Positions
            where ``mask == 0`` (or ``False``) are excluded from attention.
            A row whose keys are all masked produces NaN weights.
        scale: Optional factor multiplied into the scores before softmax,
            e.g. ``1 / math.sqrt(d_k)`` for scaled dot-product attention.
            ``None`` (the default) leaves the raw scores unscaled.

    Returns:
        The context vectors, ``softmax(scores) @ value``.
    """
    # 1. Attention scores: dot product of every query with every key.
    scores = torch.matmul(query, key.transpose(-2, -1))
    if scale is not None:
        scores = scores * scale
    if mask is not None:
        # -inf turns into a weight of exactly 0 after softmax.
        scores = scores.masked_fill(mask == 0, float("-inf"))
    print('Attention Score shape:', scores.shape)

    # 2. Softmax over the key axis turns scores into weights summing to 1.
    attention_weights = F.softmax(scores, dim=-1)
    print('Attention weights shape:', attention_weights.shape)

    # 3. Weighted sum of the values gives the final context vectors.
    context_vector = torch.matmul(attention_weights, value)
    print('Context Vector shape:', context_vector.shape)

    return context_vector

In [16]:
# Example tokenisation result: each token is assigned an integer id in order.
vocab = {token: idx for idx, token in enumerate(["나는", "학교에", "간다", "<pad>"])}
vocab_size = len(vocab)  # number of distinct token ids (padding included)
EMBEDDING_DIM = 4

In [17]:
# Input sentence, already split into word tokens.
inputs = ["나는", "학교에", "간다"]
# Look up each token's id, then add a leading batch axis -> shape (1, 3).
inputs_ids = torch.tensor([vocab[w] for w in inputs]).unsqueeze(0)

In [18]:
# 1. Embed the token ids: (batch, seq) ids -> (batch, seq, EMBEDDING_DIM) vectors.
embedding_layer = nn.Embedding(vocab_size, EMBEDDING_DIM)
inputs_embedded = embedding_layer(inputs_ids)

# 2. Three independent linear projections produce Query, Key and Value.
HIDDEN_DIM = 4
# Layers are created in query, key, value order, so random initialisation
# consumes the RNG in the same sequence as separate assignments would.
W_query, W_key, W_value = (nn.Linear(EMBEDDING_DIM, HIDDEN_DIM) for _ in range(3))
input_query, input_key, input_value = (
    projection(inputs_embedded) for projection in (W_query, W_key, W_value)
)

print(input_query.shape, input_key.shape, input_value.shape)

torch.Size([1, 3, 4]) torch.Size([1, 3, 4]) torch.Size([1, 3, 4])


In [19]:
# Run attention on the projected Q/K/V; the bare name on the last line makes
# the notebook display the resulting context vectors.
context_vector = attention(query=input_query, key=input_key, value=input_value)
context_vector

Attetion Score shape: torch.Size([1, 3, 3])
Attention weights shape: torch.Size([1, 3, 3])
Context Vector shape: torch.Size([1, 3, 4])


tensor([[[ 0.5648,  0.1701, -0.2122,  0.2961],
         [ 0.0816, -0.1404, -0.0957,  0.3069],
         [-0.0932, -0.2525, -0.0530,  0.3114]]], grad_fn=<UnsafeViewBackward0>)

### Seq2Seq 모델에 어텐션 추가

In [46]:
class Attention(nn.Module):
    """Additive (Bahdanau-style) attention over encoder outputs.

    Each encoder time step is scored as ``v . tanh(W [hidden; encoder_output])``.
    The softmax-weighted sum of the encoder outputs is returned.

    Args:
        hidden_size: Feature size shared by ``hidden`` and each encoder output.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Projects the concatenated [hidden; encoder_output] pair back to hidden_size.
        self.attn = nn.Linear(hidden_size * 2, hidden_size)
        # Learned scoring vector reducing each energy vector to a scalar score.
        self.v = nn.Parameter(torch.rand(hidden_size))

    def forward(self, hidden, encoder_outputs, mask=None):
        """Attend over ``encoder_outputs`` using ``hidden`` as the query.

        Args:
            hidden: Query tensor of shape ``(batch, hidden_size)``.
            encoder_outputs: Tensor of shape ``(batch, seq_len, hidden_size)``.
            mask: Optional ``(batch, seq_len)`` tensor; time steps where
                ``mask == 0`` (e.g. padding) receive zero attention weight.

        Returns:
            ``(context_vector, attention_weights)`` with shapes
            ``(batch, hidden_size)`` and ``(batch, seq_len)``.
        """
        seq_len = encoder_outputs.size(1)
        # expand() broadcasts the query over every time step without copying it.
        hidden_expanded = hidden.unsqueeze(1).expand(-1, seq_len, -1)
        energy = torch.tanh(self.attn(torch.cat((hidden_expanded, encoder_outputs), dim=2)))
        attention_scores = torch.sum(self.v * energy, dim=2)
        if mask is not None:
            # -inf becomes an attention weight of exactly 0 after softmax.
            attention_scores = attention_scores.masked_fill(mask == 0, float("-inf"))
        attention_weights = F.softmax(attention_scores, dim=1)
        context_vector = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1)
        return context_vector, attention_weights

In [47]:
class Seq2SeqWithAttention(nn.Module):
    """GRU encoder-decoder whose output layer also sees an attention context.

    The encoder's final hidden state serves two purposes. It is the attention
    query over the encoder outputs, and it is the decoder's initial hidden
    state. The resulting context vector is concatenated onto every decoder
    output step before the final linear projection.

    Args:
        input_dim: Feature size of both encoder and decoder inputs.
        hidden_dim: GRU hidden size (also used as the attention size).
        output_dim: Size of each output vector.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.encoder = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.decoder = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.attention = Attention(hidden_dim)
        # Consumes [decoder output; context vector], hence hidden_dim * 2 inputs.
        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        # Maps decoder inputs from input_dim to the decoder GRU's hidden_dim.
        # Kept as the last submodule so parameter-initialisation order is stable.
        self.decoder_input_transform = nn.Linear(input_dim, hidden_dim)

    def forward(self, encoder_input, decoder_input):
        """Encode the source, attend once, and decode all given decoder steps.

        Args:
            encoder_input: ``(batch, src_len, input_dim)`` (batch_first GRU).
            decoder_input: ``(batch, tgt_len, input_dim)`` with any
                ``tgt_len >= 1``.

        Returns:
            Tensor of shape ``(batch, tgt_len, output_dim)``.
        """
        encoder_outputs, hidden = self.encoder(encoder_input)
        # hidden is (num_layers, batch, hidden_dim); hidden[-1] is the top layer.
        context_vector, _ = self.attention(hidden[-1], encoder_outputs)
        decoder_input_ = self.decoder_input_transform(decoder_input)
        output, _ = self.decoder(decoder_input_, hidden)
        # Broadcast the single context vector across every decoder time step so
        # the concatenation also works when the decoder input has tgt_len > 1.
        context = context_vector.unsqueeze(1).expand(-1, output.size(1), -1)
        combined = torch.cat((output, context), dim=2)
        return self.fc(combined)

In [48]:
# Toy dimensions for a single forward pass.
batch_size, seq_len = 1, 5
input_dim, hidden_dim, output_dim = 10, 20, 15

# Random source sequence and a one-step decoder input sharing the feature size.
encoder_input = torch.randn(batch_size, seq_len, input_dim)
decoder_input = torch.randn(batch_size, 1, input_dim)

model = Seq2SeqWithAttention(input_dim, hidden_dim, output_dim)
result = model(encoder_input, decoder_input)
print(result)


tensor([[[-0.1102, -0.1047,  0.0012,  0.0793,  0.1339, -0.0231,  0.1965,
           0.0913,  0.0587,  0.0933, -0.1886, -0.1545,  0.0729, -0.0372,
           0.1470]]], grad_fn=<ViewBackward0>)
