In [7]:
import torch
import torch.nn as nn
import math

In [2]:
class GlobalGraph(nn.Module):
    """Multi-head scaled dot-product self-attention over a set of vectors.

    The input is linearly projected to queries, keys and values, split into
    ``num_attention_heads`` heads, attended per head, and the heads are
    concatenated again, so the last output dimension is ``all_head_size``.

    Args:
        hidden_size: Size of the last dimension of the input.
        attention_head_size: Width of each head. When ``None`` it defaults to
            ``hidden_size // num_attention_heads``.
        num_attention_heads: Number of attention heads.
    """

    def __init__(self, hidden_size, attention_head_size=None, num_attention_heads=1):
        super(GlobalGraph, self).__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads if attention_head_size is None else attention_head_size
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Multiplier on the projection width; fixed to 1 in this class.
        self.num_qkv = 1

        self.query = nn.Linear(hidden_size, self.all_head_size * self.num_qkv)
        self.key = nn.Linear(hidden_size, self.all_head_size * self.num_qkv)
        self.value = nn.Linear(hidden_size, self.all_head_size * self.num_qkv)

    def get_extended_attention_mask(self, attention_mask):
        """Turn a 0/1 attention mask into an additive bias for the scores.

        1 in ``attention_mask`` stands for doing attention, 0 for not doing
        attention. After this function, 1 turns to 0 and 0 turns to -10000.0.
        The -10000.0 is added to the scores before the softmax, where it
        drives the corresponding probability to (effectively) zero.

        A head axis is inserted at dim 1 so the result broadcasts over heads.
        Boolean and integer masks are accepted and converted to the default
        floating dtype first.
        """
        extended_attention_mask = attention_mask.unsqueeze(1)
        if not torch.is_floating_point(extended_attention_mask):
            # `1.0 - mask` raises for bool tensors; integer masks would be
            # promoted to the default float dtype anyway, so cast both here.
            extended_attention_mask = extended_attention_mask.to(torch.get_default_dtype())
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def transpose_for_scores(self, x):
        """Split the last dimension into heads and move the head axis forward.

        ``(batch, max_vector_num, all_head_size)`` becomes
        ``(batch, head, max_vector_num, head_size)``.
        """
        sz = x.size()[:-1] + (self.num_attention_heads,
                              self.attention_head_size)
        # (batch, max_vector_num, head, head_size)
        x = x.view(*sz)
        # (batch, head, max_vector_num, head_size)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None, return_scores=False):
        """Apply self-attention to ``hidden_states``.

        Args:
            hidden_states: Tensor of shape ``(batch, max_vector_num, hidden_size)``.
            attention_mask: Optional 0/1 mask (float, int or bool) that, after
                a head axis is inserted at dim 1, broadcasts against the
                ``(batch, head, max_vector_num, max_vector_num)`` scores.
            return_scores: When true, also return the attention probabilities.

        Returns:
            The context tensor of shape ``(batch, max_vector_num, all_head_size)``;
            or ``(context, attention_probs)`` when ``return_scores`` is true,
            with the head axis of ``attention_probs`` squeezed away if there
            is exactly one head.
        """
        mixed_query_layer = self.query(hidden_states)
        # The key projection is applied without its bias. A key bias adds the
        # same term (q . bias) to every score of a query row, and softmax is
        # invariant to a per-row constant, so the output does not depend on
        # it. Note that `self.key.bias` therefore never receives a gradient.
        mixed_key_layer = nn.functional.linear(hidden_states, self.key.weight)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Scale the queries (rather than the product) by sqrt(head_size).
        attention_scores = torch.matmul(
            query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2))
        if attention_mask is not None:
            attention_scores = attention_scores + self.get_extended_attention_mask(attention_mask)
        attention_probs = torch.softmax(attention_scores, dim=-1)

        # (batch, head, n, head_size) -> (batch, n, head, head_size) -> (batch, n, all_head_size)
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        if return_scores:
            # Only removes the head axis when num_attention_heads == 1.
            attention_probs = torch.squeeze(attention_probs, dim=1)
            return context_layer, attention_probs
        return context_layer


In [3]:


class CrossAttention(GlobalGraph):
    """Multi-head scaled dot-product attention from one sequence onto another.

    Queries are projected from ``hidden_states_query``; keys and values are
    both projected from ``hidden_states_key``.

    Args:
        hidden_size: Input width used for the query, key and value
            projections unless overridden below.
        attention_head_size: Width of each head (see ``GlobalGraph``).
        num_attention_heads: Number of attention heads.
        key_hidden_size: If given, input width of the key and value
            projections.
        query_hidden_size: If given, input width of the query projection.
    """

    def __init__(self, hidden_size, attention_head_size=None, num_attention_heads=1, key_hidden_size=None, query_hidden_size=None):
        super(CrossAttention, self).__init__(hidden_size, attention_head_size, num_attention_heads)
        # Replace the base-class projections when the inputs have other widths.
        if query_hidden_size is not None:
            self.query = nn.Linear(query_hidden_size, self.all_head_size * self.num_qkv)
        if key_hidden_size is not None:
            self.key = nn.Linear(key_hidden_size, self.all_head_size * self.num_qkv)
            self.value = nn.Linear(key_hidden_size, self.all_head_size * self.num_qkv)

    def forward(self, hidden_states_query, hidden_states_key=None, attention_mask=None, return_scores=False):
        """Attend from the query sequence onto the key sequence.

        Args:
            hidden_states_query: Tensor of shape ``(batch, num_query, query_width)``.
            hidden_states_key: Tensor of shape ``(batch, num_key, key_width)``.
                When ``None``, the query sequence is used, i.e. the module
                performs self-attention.
            attention_mask: Optional 0/1 mask of shape
                ``(batch, num_query, num_key)``; 1 means attend, 0 means do
                not attend.
            return_scores: When true, also return the attention probabilities.

        Returns:
            The context tensor of shape ``(batch, num_query, all_head_size)``;
            or ``(context, attention_probs)`` when ``return_scores`` is true,
            with the head axis of ``attention_probs`` squeezed away if there
            is exactly one head.
        """
        if hidden_states_key is None:
            # The parameter is optional in the signature; without this the
            # call fails inside the key projection with an opaque TypeError.
            hidden_states_key = hidden_states_query

        mixed_query_layer = self.query(hidden_states_query)
        mixed_key_layer = self.key(hidden_states_key)
        mixed_value_layer = self.value(hidden_states_key)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Scale the queries (rather than the product) by sqrt(head_size).
        attention_scores = torch.matmul(
            query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2))
        if attention_mask is not None:
            assert hidden_states_query.shape[1] == attention_mask.shape[1] \
                   and hidden_states_key.shape[1] == attention_mask.shape[2], \
                "attention_mask must have shape (batch, num_query, num_key)"
            attention_scores = attention_scores + self.get_extended_attention_mask(attention_mask)
        attention_probs = torch.softmax(attention_scores, dim=-1)

        # (batch, head, q, head_size) -> (batch, q, head, head_size) -> (batch, q, all_head_size)
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        if return_scores:
            # Only removes the head axis when num_attention_heads == 1.
            return context_layer, torch.squeeze(attention_probs, dim=1)
        return context_layer

In [4]:
# Single-head cross-attention over 512-wide features
# (the head size defaults to hidden_size // num_attention_heads = 512).
cross_attention = CrossAttention(hidden_size=512)

# Random inputs: `a` has 66666 vectors per batch item, `b` has one.
a = torch.randn(2, 66666, 512)
b = torch.randn(2, 1, 512)

In [8]:
# `a` supplies the queries and `b` the keys/values. `b` holds a single vector
# per batch item, so the softmax runs over one element and every query
# position receives the same context vector.
out = cross_attention(
    hidden_states_query=a,
    hidden_states_key=b,
)

In [17]:
# Show the first batch element of the query input, the key input and the output.
print(a[0], b[0], out[0], sep="\n")

tensor([[ 1.6470, -1.3187, -0.1688,  ...,  1.1961, -1.7335, -1.7147],
        [ 0.2732, -0.3632,  0.0078,  ...,  0.0604,  1.7235, -2.3359],
        [-0.6915, -0.3437,  1.6290,  ..., -1.4262,  1.3686,  1.3066],
        ...,
        [-0.0083, -1.8599, -0.7296,  ...,  0.2520,  0.6743,  1.5492],
        [ 1.7215,  0.8142,  0.1106,  ...,  0.7680, -0.8425,  1.7809],
        [ 0.2066,  0.1040, -0.1978,  ..., -1.3924, -1.0080,  0.9363]])
tensor([[-4.7907e-01,  2.8597e-01,  5.3855e-01,  2.1160e-01, -2.4862e-01,
         -5.8863e-01,  1.0730e+00,  1.9546e+00,  1.1577e+00, -3.3469e-01,
          1.6733e+00,  2.9119e-01, -1.8771e-01,  1.6734e+00, -6.8469e-01,
          4.2377e-01, -2.2570e-01,  6.0401e-01,  6.0938e-01,  8.2346e-01,
         -3.7224e-02,  1.8667e-01, -3.1069e-02,  4.0793e-01,  5.0892e-01,
         -2.3262e-01,  3.5800e-01,  2.0616e+00,  5.8345e-01,  6.5544e-01,
         -2.7921e-01,  1.9590e-01,  4.4828e-01,  1.8016e-01, -8.3591e-01,
          1.5395e+00,  2.8855e-01, -2.3225e-01, 