-
Notifications
You must be signed in to change notification settings - Fork 8
/
attention.py
117 lines (93 loc) · 5.48 KB
/
attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Copyright (c) 2020, Soohwan Kim. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from speech_transformer.modules import Linear
from torch import Tensor
from typing import Optional, Tuple
class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention proposed in "Attention Is All You Need"
    Compute the dot products of the query with all keys, divide each by sqrt(dim),
    and apply a softmax function to obtain the weights on the values

    Args: dim
        dim (int): dimension of attention; scores are divided by sqrt(dim)

    Inputs: query, key, value, mask
        - **query** (batch, q_len, dim): tensor containing the queries.
        - **key** (batch, k_len, dim): tensor containing the keys.
        - **value** (batch, k_len, d_value): tensor containing the values; needs one row per key.
        - **mask** (optional): boolean tensor broadcastable to (batch, q_len, k_len).
          Positions where the mask is True are suppressed before the softmax.

    Returns: context, attn
        - **context** (batch, q_len, d_value): attention-weighted sum of the values.
        - **attn** (batch, q_len, k_len): attention weights; each row sums to 1 over k_len.
    """

    def __init__(self, dim: int) -> None:
        super(ScaledDotProductAttention, self).__init__()
        self.sqrt_dim = np.sqrt(dim)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        # (batch, q_len, dim) x (batch, dim, k_len) -> (batch, q_len, k_len)
        score = torch.bmm(query, key.transpose(1, 2)) / self.sqrt_dim

        if mask is not None:
            # A hard-coded -1e9 cannot be represented in float16 and makes
            # masked_fill_ raise an overflow error for half-precision scores.
            # The most negative finite value of the score dtype works for every
            # floating dtype. Unlike -inf, it keeps a fully masked row finite,
            # so the softmax yields uniform weights instead of NaN.
            score.masked_fill_(mask, torch.finfo(score.dtype).min)

        attn = F.softmax(score, -1)
        context = torch.bmm(attn, value)
        return context, attn
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention proposed in "Attention Is All You Need"
    Queries, keys and values are each passed through their own learned linear projection,
    split into num_heads slices of size d_head = d_model / num_heads, and attended to
    independently with scaled dot-product attention. The per-head results are then
    concatenated along the feature axis.

    NOTE: this module returns the concatenated heads as-is; no final output
    projection (W_o) is applied here.

    Args:
        d_model (int): The dimension of keys / values / queries (default: 512)
        num_heads (int): The number of attention heads; must divide d_model. (default: 8)

    Inputs: query, key, value, mask
        - **query** (batch, q_len, d_model): tensor containing the queries.
        - **key** (batch, k_len, d_model): tensor containing the keys.
        - **value** (batch, v_len, d_model): tensor containing the values.
        - **mask** (optional): tensor marking positions to be masked. It is tiled num_heads
          times along dim 0 before being handed to the attention module, so it is
          presumably (batch, q_len, k_len) or (batch, 1, k_len) -- confirm against callers.

    Returns: output, attn
        - **output** (batch, q_len, d_model): concatenation of the per-head context vectors.
        - **attn** (batch * num_heads, q_len, k_len): per-head attention weights, laid out
          head-major (all batch items of head 0, then head 1, ...).
    """

    def __init__(self, d_model: int = 512, num_heads: int = 8) -> None:
        super().__init__()

        assert d_model % num_heads == 0, "hidden_dim % num_heads should be zero."

        self.d_head = d_model // num_heads
        self.num_heads = num_heads

        # Linear comes from speech_transformer.modules (not visible here);
        # presumably a thin wrapper around nn.Linear -- confirm.
        projected_dim = self.d_head * num_heads
        self.query_proj = Linear(d_model, projected_dim)
        self.key_proj = Linear(d_model, projected_dim)
        self.value_proj = Linear(d_model, projected_dim)

        # Stored as an attribute but not read by this class's forward(); the
        # actual score scaling uses sqrt(d_head) inside scaled_dot_attn.
        self.sqrt_dim = np.sqrt(d_model)
        self.scaled_dot_attn = ScaledDotProductAttention(self.d_head)

    def _split_heads(self, projected: Tensor, batch_size: int) -> Tensor:
        """Reshape (B, T, N * D) into head-major (N * B, T, D)."""
        per_head = projected.view(batch_size, -1, self.num_heads, self.d_head)  # BxTxNxD
        head_major = per_head.permute(2, 0, 1, 3).contiguous()  # NxBxTxD
        return head_major.view(batch_size * self.num_heads, -1, self.d_head)  # NBxTxD

    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        batch_size = value.size(0)

        query = self._split_heads(self.query_proj(query), batch_size)  # NBxQ_LENxD
        key = self._split_heads(self.key_proj(key), batch_size)  # NBxK_LENxD
        value = self._split_heads(self.value_proj(value), batch_size)  # NBxV_LENxD

        if mask is not None:
            # Tile along dim 0 so the mask lines up with the head-major batch axis.
            mask = mask.repeat(self.num_heads, 1, 1)

        context, attn = self.scaled_dot_attn(query, key, value, mask)

        # Undo the head split: NBxTxD -> NxBxTxD -> BxTxNxD -> BxTxND
        per_head_context = context.view(self.num_heads, batch_size, -1, self.d_head)
        batch_major = per_head_context.permute(1, 2, 0, 3).contiguous()
        merged = batch_major.view(batch_size, -1, self.num_heads * self.d_head)

        return merged, attn