In [3]:
import torch
from torch.nn import functional as F

In [5]:
def verbose_attention(encoder_state_vectors, query_vector):
    """Dot-product attention spelled out with element-wise operations.

    Each encoder vector is scored against the query by multiplying the two
    element-wise and summing over the feature axis; the scores are softmaxed
    over the encoder positions and used to weight-sum the encoder vectors.

    Args:
        encoder_state_vectors (torch.Tensor): 3-D tensor output by the
            encoder's bidirectional GRU, unpacked below as
            ``(batch_size, num_vectors, vector_size)``.
        query_vector (torch.Tensor): the decoder GRU's hidden state; it must
            be viewable as ``(batch_size, 1, vector_size)``.

    Returns:
        tuple:
            - context vectors, ``(batch_size, vector_size)``
            - attention probabilities, ``(batch_size, num_vectors)``
            - raw (pre-softmax) scores, ``(batch_size, num_vectors)``
    """
    # TRY IT YOURSELF
    batch_size, num_vectors, vector_size = encoder_state_vectors.size()

    # Give the query a singleton position axis so it broadcasts over all encoder vectors.
    query_per_position = query_vector.view(batch_size, 1, vector_size)
    products = encoder_state_vectors * query_per_position
    vector_scores = products.sum(dim=2)

    vector_probabilities = F.softmax(vector_scores, dim=1)

    # Trailing singleton axis lets each weight scale a whole encoder vector.
    weights = vector_probabilities.view(batch_size, num_vectors, 1)
    context_vectors = (encoder_state_vectors * weights).sum(dim=1)

    return context_vectors, vector_probabilities, vector_scores

def terse_attention(encoder_state_vectors, query_vector):
    """Dot-product attention written with batched matrix multiplication.

    Computes the same attention as the element-wise formulation, letting
    ``torch.matmul`` perform the per-batch dot products.

    Args:
        encoder_state_vectors (torch.Tensor): 3-D tensor output by the
            encoder's bidirectional GRU, shaped
            ``(batch_size, num_vectors, vector_size)``.
        query_vector (torch.Tensor): the decoder GRU's hidden state, shaped
            ``(batch_size, vector_size)``.

    Returns:
        tuple:
            - context vectors, ``(batch_size, vector_size)``: attention-weighted
              sum of the encoder vectors.
            - attention probabilities, ``(batch_size, num_vectors)``: softmax
              over the encoder positions.
    """
    # TRY IT YOURSELF
    # (B, N, D) @ (B, D, 1) -> (B, N, 1) -> (B, N).
    # Squeeze ONLY the trailing singleton axis: a bare .squeeze() would also drop
    # the batch axis when batch_size == 1 (crashing the unsqueeze below) or the
    # position axis when num_vectors == 1 (making softmax normalize across the batch).
    vector_scores = torch.matmul(encoder_state_vectors,
                                 query_vector.unsqueeze(dim=-1)).squeeze(dim=-1)
    vector_probabilities = F.softmax(vector_scores, dim=-1)
    # (B, D, N) @ (B, N, 1) -> (B, D, 1) -> (B, D)
    context_vectors = torch.matmul(encoder_state_vectors.transpose(-2, -1),
                                   vector_probabilities.unsqueeze(dim=-1)).squeeze(dim=-1)

    return context_vectors, vector_probabilities

In [None]:
import numpy as np

# Fixed seed so the toy example and its printed probabilities are reproducible.
np.random.seed(1234)
encoder_states = np.random.rand(3, 5, 4) # batch_size, num_vectors, vector_size
# Use each sequence's last encoder state as the decoder query: (batch_size, vector_size).
query_vector = encoder_states[:, -1, :]

encoder_states = torch.tensor(encoder_states)
query_vector = torch.tensor(query_vector)

# Keep the verbose results under their own names so they are not overwritten below.
verbose_context, verbose_probabilities, vector_scores = verbose_attention(encoder_states, query_vector)
print(verbose_probabilities.numpy())

context_vectors, vector_probabilities = terse_attention(encoder_states, query_vector)
print(vector_probabilities.numpy())

# Both formulations compute the same dot-product attention; verify instead of eyeballing the prints.
np.testing.assert_allclose(vector_probabilities.numpy(), verbose_probabilities.numpy())
np.testing.assert_allclose(context_vectors.numpy(), verbose_context.numpy())

[[0.16908073 0.2026648  0.18250851 0.16886662 0.27687935]
 [0.12252482 0.24226674 0.17525572 0.15620217 0.30375056]
 [0.05641803 0.11084368 0.09659718 0.14712325 0.58901787]]
[[0.16908073 0.2026648  0.18250851 0.16886662 0.27687935]
 [0.12252482 0.24226674 0.17525572 0.15620217 0.30375056]
 [0.05641803 0.11084368 0.09659718 0.14712325 0.58901787]]


In [9]:
# The query is one vector per batch element: (batch_size, vector_size).
query_vector.size()

torch.Size([3, 4])