In [3]:
import torch
import math

def sequence_mask(x, valid_lens, value=0):
    max_len = x.size(1)
    mask = torch.arange(max_len, dtype=torch.float32,device=x.device)[None,:]<valid_lens[:,None]
    x[~mask] = value
    return x

def masked_softmax(x, valid_lens):
    if valid_lens is None:
        return torch.softmax(x,dim=-1)
    else:
        shape = x.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        x = sequence_mask(x.reshape(-1, shape[-1]), valid_lens, -1e6)
        return torch.softmax(x.reshape(shape),dim=-1)
    
class DotProductAttention(torch.nn.Module):
    def __init__(self, dropout, **kwargs) -> None:
        super().__init__(**kwargs)
        self.dropout = torch.nn.Dropout(dropout)
    
    def forward(self, queries, keys, values, valid_lens):
        d= queries.size(-1)
        scores = torch.bmm(queries, keys.transpose(1,2))/math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights),values)



In [None]:

def transpose_qkv(x, num_heads):
    print(x.is_contiguous(),x.data_ptr())
    x = x.reshape(x.shape[0], x.shape[1], num_heads, -1)
    print(x.is_contiguous(),x.data_ptr())
    x = x.permute(0,2,1,3)
    print(x.is_contiguous(),x.data_ptr())
    x = x.reshape(-1,x.shape[2],x.shape[3])
    print(x.is_contiguous(),x.data_ptr())
    return x

def transpose_output(x, num_heads):
    x = x.reshape(-1, num_heads, x.shape[1], x.shape[2])
    x = x.permute(0,2,1,3)
    x = x.reshape(x.shape[0],x.shape[1],-1)
    return x


x = torch.rand((2,3,4))
y = transpose_qkv(x,2)
#x.shape,y.shape


In [25]:
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, query_size, key_size, value_size, hidden_size, num_heads, dropout, bias=False, **kwargs) -> None:
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.w_q = torch.nn.Linear(query_size, hidden_size, bias=bias)
        self.w_k = torch.nn.Linear(key_size, hidden_size, bias=bias)
        self.w_v = torch.nn.Linear(value_size, hidden_size, bias=bias)
        self.w_o = torch.nn.Linear(hidden_size, hidden_size, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        queries = transpose_qkv(self.w_q(queries), self.num_heads)
        keys = transpose_qkv(self.w_k(keys), self.num_heads)
        values = transpose_qkv(self.w_v(values), self.num_heads)

        if valid_lens is not None:
            valid_lens = torch.repeat_interleave(valid_lens,repeats = self.num_heads, dim = 0)

        output = self.attention(queries, keys, values, valid_lens)
        output_concat = transpose_output(output, self.num_heads)
        return self.w_o(output_concat)


In [26]:
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()

MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (w_q): Linear(in_features=100, out_features=100, bias=False)
  (w_k): Linear(in_features=100, out_features=100, bias=False)
  (w_v): Linear(in_features=100, out_features=100, bias=False)
  (w_o): Linear(in_features=100, out_features=100, bias=False)
)

In [29]:
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens =  6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape

<built-in method contiguous of Tensor object at 0x0000024F3E34C450>
<built-in method contiguous of Tensor object at 0x0000024F3E1E76A0>
<built-in method contiguous of Tensor object at 0x0000024F3E34FE20>
<built-in method contiguous of Tensor object at 0x0000024F3E1E76A0>
<built-in method contiguous of Tensor object at 0x0000024F3E34C450>
<built-in method contiguous of Tensor object at 0x0000024F3E34FE20>
<built-in method contiguous of Tensor object at 0x0000024F3E34DEE0>
<built-in method contiguous of Tensor object at 0x0000024F3E34FE20>
<built-in method contiguous of Tensor object at 0x0000024F3E34C450>
<built-in method contiguous of Tensor object at 0x0000024F3E34DEE0>
<built-in method contiguous of Tensor object at 0x0000024F3E34CEA0>
<built-in method contiguous of Tensor object at 0x0000024F3E34DEE0>


torch.Size([2, 4, 100])