Add batch_first support in MHA and update docs #839
Changes from 11 commits
@@ -3,14 +3,18 @@
 
 
 class MultiheadAttentionContainer(torch.nn.Module):
-    def __init__(self, nhead, in_proj_container, attention_layer, out_proj):
+    def __init__(self, nhead, in_proj_container, attention_layer, out_proj, batch_first=False):
         r""" A multi-head attention container
 
         Args:
             nhead: the number of heads in the multiheadattention model
             in_proj_container: A container of multi-head in-projection linear layers (a.k.a nn.Linear).
-            attention_layer: The attention layer.
+            attention_layer: The custom attention layer. The input sent from MHA container to the attention layer
+                is in the shape of `(seq, batch, feature)` while the output shape of the attention layer
+                is expected to be `(seq, batch, feature)`.
             out_proj: The multi-head out-projection layer (a.k.a nn.Linear).
+            batch_first: If ``True``, then the input and output tensors are provided
+                as `(batch, seq, feature)`. Default: ``False``
 
         Examples::
             >>> import torch
@@ -33,6 +37,7 @@ def __init__(self, nhead, in_proj_container, attention_layer, out_proj):
         self.in_proj_container = in_proj_container
         self.attention_layer = attention_layer
         self.out_proj = out_proj
+        self.batch_first = batch_first
 
     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                 attn_mask: Optional[torch.Tensor] = None,
@@ -60,6 +65,9 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
             where L is the target length, S is the sequence length, H is the number of attention heads,
             N is the batch size, and E is the embedding dimension.
         """
+        if self.batch_first:
+            query, key, value = query.transpose(-3, -2), key.transpose(-3, -2), value.transpose(-3, -2)
+
         tgt_len, src_len, bsz, embed_dim = query.size(-3), key.size(-3), query.size(-2), query.size(-1)
         q, k, v = self.in_proj_container(query, key, value)
         assert q.size(-1) % self.nhead == 0, "query's embed_dim must be divisible by the number of heads"
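transpose(-3, -2) swaps the sequence and batch dimensions while leaving any extra leading dimensions untouched, which is why negative indices are used instead of 0 and 1. A quick illustration:

import torch
x = torch.rand(64, 21, 10)           # (batch, seq, feature)
print(x.transpose(-3, -2).shape)     # torch.Size([21, 64, 10]), i.e. (seq, batch, feature)
y = torch.rand(4, 64, 21, 10)        # an extra leading dim passes through unchanged
print(y.transpose(-3, -2).shape)     # torch.Size([4, 21, 64, 10])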
@@ -78,20 +86,26 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                                                           bias_k=bias_k, bias_v=bias_v)
         attn_output = attn_output.reshape(tgt_len, bsz, embed_dim)
         attn_output = self.out_proj(attn_output)
+
+        if self.batch_first:
+            attn_output = attn_output.transpose(-3, -2)
+
         return attn_output, attn_output_weights
 
 
 class ScaledDotProduct(torch.nn.Module):
 
-    def __init__(self, dropout=0.0):
+    def __init__(self, dropout=0.0, batch_first=False):
         r"""Processes a projected query and key-value pair to apply
         scaled dot product attention.
 
         Args:
             dropout (float): probability of dropping an attention weight.
+            batch_first: If ``True``, then the input and output tensors are provided
+                as `(batch, seq, feature)`. Default: ``False``
 
         Examples::
-            >>> SDP = torchtext.models.ScaledDotProduct(0.1)
+            >>> SDP = torchtext.modules.ScaledDotProduct(dropout=0.1)
Review thread on the line above:

- Reviewer: Should we mirror the pytorch path conventions here? torchtext.nn and torchtext.nn.functional?
- Author: The other two domains use …
- Reviewer: Yes, but this isn't a model, right?
- Author: OK, will fix it.
             >>> q = torch.randn(256, 21, 3)
             >>> k = v = torch.randn(256, 21, 3)
             >>> attn_output, attn_weights = SDP(q, k, v)
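With the flag, the same module consumes batch-first projections directly. A sketch mirroring the docstring example with batch and sequence swapped (import path assumed, per the review thread above):

import torch
from torchtext.nn import ScaledDotProduct  # assumed path

SDP = ScaledDotProduct(dropout=0.1, batch_first=True)
q = torch.randn(21, 256, 3)          # (batch, seq, feature) rather than (seq, batch, feature)
k = v = torch.randn(21, 256, 3)
attn_output, attn_weights = SDP(q, k, v)
print(attn_output.shape)             # torch.Size([21, 256, 3]); attn_weights: (21, 256, 256)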
@@ -100,6 +114,7 @@ def __init__(self, dropout=0.0):
         """
         super(ScaledDotProduct, self).__init__()
         self.dropout = dropout
+        self.batch_first = batch_first
 
     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                 attn_mask: Optional[torch.Tensor] = None,
@@ -130,6 +145,9 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
             where L is the target length, S is the source length, H is the number
             of attention heads, N is the batch size, and E is the embedding dimension.
         """
+        if self.batch_first:
+            query, key, value = query.transpose(-3, -2), key.transpose(-3, -2), value.transpose(-3, -2)
+
         if bias_k is not None and bias_v is not None:
             assert key.size(-1) == bias_k.size(-1) and key.size(-2) == bias_k.size(-2) and bias_k.size(-3) == 1, \
                 "Shape of bias_k is not supported"
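For reference, the shapes this assertion accepts: bias_k and bias_v each contribute exactly one extra source position per batch element. Illustrative sizes:

import torch
src_len, bsz, embed_dim = 256, 21, 3
key = torch.rand(src_len, bsz, embed_dim)   # (seq, batch, feature), the internal layout
bias_k = torch.rand(1, bsz, embed_dim)      # size(-3) == 1: a single appended key position
key = torch.cat([key, bias_k])              # source length grows by one
print(key.shape)                            # torch.Size([257, 21, 3])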
@@ -138,8 +156,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
             key = torch.cat([key, bias_k])
             value = torch.cat([value, bias_v])
             if attn_mask is not None:
-                _attn_mask = attn_mask
-                attn_mask = torch.nn.functional.pad(_attn_mask, (0, 1))
+                attn_mask = torch.nn.functional.pad(attn_mask, (0, 1))
 
         tgt_len, head_dim = query.size(-3), query.size(-1)
         assert query.size(-1) == key.size(-1) == value.size(-1), "The feature dim of query, key, value must be equal."
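The rewrite drops a redundant temporary without changing behavior: torch.nn.functional.pad with (0, 1) appends one column on the right of the mask's last (source-length) dimension, covering the extra position added by bias_k/bias_v. With an illustrative mask shape:

import torch
attn_mask = torch.zeros(21, 256, 256)                   # (batch, tgt_len, src_len)
attn_mask = torch.nn.functional.pad(attn_mask, (0, 1))  # src_len 256 -> 257
print(attn_mask.shape)                                  # torch.Size([21, 256, 257])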
@@ -166,7 +183,11 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
         attn_output_weights = torch.nn.functional.softmax(attn_output_weights, dim=-1)
         attn_output_weights = torch.nn.functional.dropout(attn_output_weights, p=self.dropout, training=self.training)
         attn_output = torch.matmul(attn_output_weights, value)
-        return attn_output.transpose(-2, -3), attn_output_weights
+
+        if self.batch_first:
+            return attn_output, attn_output_weights
+        else:
+            return attn_output.transpose(-3, -2), attn_output_weights
 
 
 class InProjContainer(torch.nn.Module):
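Why no transpose is needed in the batch-first branch: the final matmul already produces output with the (batch * heads) dimension leading, so the old unconditional transpose was precisely the conversion to the seq-first layout. Skipping it when batch_first=True removes work rather than adding it. A shape-only sketch with made-up sizes:

import torch
bh, tgt_len, src_len, head_dim = 105, 256, 256, 3   # bh = batch * heads
weights = torch.rand(bh, tgt_len, src_len)
value = torch.rand(bh, src_len, head_dim)
attn_output = torch.matmul(weights, value)          # (bh, tgt_len, head_dim): already batch-first
print(attn_output.shape)                            # torch.Size([105, 256, 3])
print(attn_output.transpose(-3, -2).shape)          # torch.Size([256, 105, 3]): seq-first layout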
Review thread on the batch_first transpose in ScaledDotProduct.forward:

- Reviewer: Does this also take care of broadcasting?
- Author: The custom attention layer needs to take care of broadcasting. Updated the doc to reflect this.
- Reviewer: I'd then augment the shape to (..., seq, batch, feature), explain what that means, and also note that it's optional, i.e. it's enough to only handle the 3-dim case.
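To make the suggestion concrete: because ScaledDotProduct indexes dimensions from the end and matmul broadcasts over leading dimensions, (..., seq, batch, feature) inputs should pass through with the extra dimensions untouched, while a custom attention layer is only required to handle the plain 3-dim case. A sketch, assuming the module broadcasts as its shape docs suggest (masks add their own dimensionality constraints, not exercised here):

import torch
from torchtext.nn import ScaledDotProduct  # assumed path

SDP = ScaledDotProduct()
q = torch.randn(4, 256, 21, 3)   # (extra, seq, batch, feature): one leading broadcast dim
k = v = torch.randn(4, 256, 21, 3)
attn_output, attn_weights = SDP(q, k, v)
print(attn_output.shape)         # torch.Size([4, 256, 21, 3])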