
Add batch_first support in MHA and update docs #839

Merged: 21 commits, Jul 15, 2020
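
For context, a minimal usage sketch of what this PR enables (not part of the PR itself; the sizes are illustrative and the constructor arguments follow the test and docstring changes below). With batch_first=True the container accepts and returns (batch, seq, feature) tensors; internally it still computes in the (seq, batch, feature) layout and only transposes at the module boundary.

import torch
from torch.nn import Linear
from torchtext.modules import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct

embed_dim, nhead, bsz, tgt_len, src_len = 10, 5, 64, 6, 10
in_proj = InProjContainer(Linear(embed_dim, embed_dim, bias=False),
                          Linear(embed_dim, embed_dim, bias=False),
                          Linear(embed_dim, embed_dim, bias=False))
mha = MultiheadAttentionContainer(nhead, in_proj, ScaledDotProduct(),
                                  Linear(embed_dim, embed_dim, bias=False),
                                  batch_first=True)

# Inputs are (batch, seq, feature) rather than the default (seq, batch, feature).
query = torch.rand((bsz, tgt_len, embed_dim))
key = value = torch.rand((bsz, src_len, embed_dim))
attn_output, attn_weights = mha(query, key, value)
print(attn_output.shape)  # expected: torch.Size([64, 6, 10])
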
73 changes: 56 additions & 17 deletions test/data/test_modules.py
@@ -1,7 +1,7 @@
import torch
from torch.nn import Linear
from torchtext.modules import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct
from torch.nn.functional import multi_head_attention_forward as mha_forward
from torch.testing import assert_allclose
from ..common.torchtext_test_case import TorchtextTestCase


@@ -10,13 +10,13 @@ class TestModels(TorchtextTestCase):
def test_multiheadattention(self):
embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64
# Build torchtext MultiheadAttention module
in_proj = InProjContainer(torch.nn.Linear(embed_dim, embed_dim, bias=False),
torch.nn.Linear(embed_dim, embed_dim, bias=False),
torch.nn.Linear(embed_dim, embed_dim, bias=False))
in_proj = InProjContainer(Linear(embed_dim, embed_dim, bias=False),
Linear(embed_dim, embed_dim, bias=False),
Linear(embed_dim, embed_dim, bias=False))

MHA = MultiheadAttentionContainer(nhead, in_proj,
ScaledDotProduct(),
torch.nn.Linear(embed_dim, embed_dim, bias=False))
Linear(embed_dim, embed_dim, bias=False))

query = torch.rand((tgt_len, bsz, embed_dim))
key = value = torch.rand((src_len, bsz, embed_dim))
@@ -40,10 +40,49 @@ def test_multiheadattention(self):
MHA.out_proj.weight, None,
attn_mask=torch_attn_mask)

assert_allclose(mha_output, torch_mha_output)
self.assertEqual(mha_output, torch_mha_output)
# With bias_k and bias_v, src_len needs to be increased by 1
attn_weights = attn_weights.view(bsz, nhead, tgt_len, src_len + 1).sum(dim=1) / nhead
assert_allclose(attn_weights, torch_mha_weights)
self.assertEqual(attn_weights, torch_mha_weights)

def test_mha_batch_first(self):
embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64
# Build torchtext MultiheadAttention module
in_proj = InProjContainer(Linear(embed_dim, embed_dim, bias=False),
Linear(embed_dim, embed_dim, bias=False),
Linear(embed_dim, embed_dim, bias=False))

MHA_batch_1st = MultiheadAttentionContainer(nhead, in_proj,
ScaledDotProduct(),
Linear(embed_dim, embed_dim, bias=False),
batch_first=True)

query = torch.rand((tgt_len, bsz, embed_dim))
key = value = torch.rand((src_len, bsz, embed_dim))
attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool)
bias_k = bias_v = torch.rand((1, 1, embed_dim))
mha_output_1st, attn_weights_1st = MHA_batch_1st(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1),
attn_mask=torch.stack([attn_mask_2D] * (bsz * nhead)),
bias_k=bias_k.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1),
bias_v=bias_v.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1))

# Use torch.nn.functional.multi_head_attention_forward
torch_attn_mask = torch.zeros((tgt_len, src_len)).masked_fill_(attn_mask_2D, float('-inf'))
in_proj_weight = torch.cat([MHA_batch_1st.in_proj_container.query_proj.weight,
MHA_batch_1st.in_proj_container.key_proj.weight,
MHA_batch_1st.in_proj_container.value_proj.weight])
torch_mha_output, torch_mha_weights = mha_forward(query, key, value,
embed_dim, nhead,
in_proj_weight, None,
bias_k, bias_v,
False, 0.0,
MHA_batch_1st.out_proj.weight, None,
attn_mask=torch_attn_mask)

self.assertEqual(mha_output_1st.transpose(0, 1), torch_mha_output)
# With bias_k and bias_v, src_len needs to be increased by 1
attn_weights_1st = attn_weights_1st.view(bsz, nhead, tgt_len, src_len + 1).sum(dim=1) / nhead
self.assertEqual(attn_weights_1st, torch_mha_weights)

def test_broadcast_scaled_dot_product(self):
embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64
@@ -61,15 +100,15 @@ def test_broadcast_scaled_dot_product(self):
sdp_attn_output, sdp_attn_weights = SDP(query, key.expand(src_len, bsz * nhead, embed_dim),
value.expand(src_len, bsz * nhead, embed_dim),
attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))
assert_allclose(sdp_attn_output, sdp_attn_output_full)
assert_allclose(sdp_attn_weights, sdp_attn_weights_full)
self.assertEqual(sdp_attn_output, sdp_attn_output_full)
self.assertEqual(sdp_attn_weights, sdp_attn_weights_full)

# key/value have a batch size of 1 while query has a batch size of bsz * nhead
sdp_attn_output, sdp_attn_weights = SDP(query.expand(tgt_len, bsz * nhead, embed_dim),
key, value,
attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))
assert_allclose(sdp_attn_output, sdp_attn_output_full)
assert_allclose(sdp_attn_weights, sdp_attn_weights_full)
self.assertEqual(sdp_attn_output, sdp_attn_output_full)
self.assertEqual(sdp_attn_weights, sdp_attn_weights_full)

# key/value have a size of (3, 3, src_len, bsz * nhead, embed_dim)
# while query has a size of (tgt_len, 1, embed_dim)
@@ -79,8 +118,8 @@ def test_broadcast_scaled_dot_product(self):
attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))
assert list(sdp_attn_output.size()) == [3, 3, tgt_len, bsz * nhead, embed_dim]
assert list(sdp_attn_weights.size()) == [3, 3, bsz * nhead, tgt_len, embed_dim]
assert_allclose(sdp_attn_output[2][2], sdp_attn_output_full)
assert_allclose(sdp_attn_weights[2][2], sdp_attn_weights_full)
self.assertEqual(sdp_attn_output[2][2], sdp_attn_output_full)
self.assertEqual(sdp_attn_weights[2][2], sdp_attn_weights_full)
# query's dim -2 is equal to neither key/value's dim -2 nor 1
with self.assertRaises(RuntimeError):
SDP(query.expand(tgt_len, 2, embed_dim), key.expand(3, 3, src_len, bsz * nhead, embed_dim),
@@ -95,8 +134,8 @@ def test_broadcast_scaled_dot_product(self):
attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))
assert list(sdp_attn_output.size()) == [1, 2, 3, tgt_len, bsz * nhead, embed_dim]
assert list(sdp_attn_weights.size()) == [1, 2, 3, bsz * nhead, tgt_len, embed_dim]
assert_allclose(sdp_attn_output[0][1][2], sdp_attn_output_full)
assert_allclose(sdp_attn_weights[0][1][2], sdp_attn_weights_full)
self.assertEqual(sdp_attn_output[0][1][2], sdp_attn_output_full)
self.assertEqual(sdp_attn_weights[0][1][2], sdp_attn_weights_full)
# key dim -2 is not equal to value dim -2
with self.assertRaisesRegex(AssertionError, "Shape of key, value must match"):
SDP(query.expand(1, 2, 3, tgt_len, bsz * nhead, embed_dim), key.expand(src_len, 2, embed_dim),
@@ -114,8 +153,8 @@ def test_broadcast_scaled_dot_product(self):
key.expand(src_len, bsz * nhead, embed_dim),
value.expand(src_len, bsz * nhead, embed_dim),
attn_mask=attn_mask_2D.expand(1, tgt_len, src_len))
assert_allclose(sdp_attn_output, sdp_attn_output_full)
assert_allclose(sdp_attn_weights, sdp_attn_weights_full)
self.assertEqual(sdp_attn_output, sdp_attn_output_full)
self.assertEqual(sdp_attn_weights, sdp_attn_weights_full)
# attn_mask's dim -3 is equal to neither the batch size nor 1
with self.assertRaisesRegex(RuntimeError, "The size of the attn_mask is not correct."):
SDP(query.expand(tgt_len, bsz * nhead, embed_dim), key.expand(src_len, bsz * nhead, embed_dim),
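
The essence of the new test_mha_batch_first, stripped of the attention mask and bias terms: a sketch (not from the PR) showing that two containers sharing the same projection layers agree across layouts once the batch_first output is transposed back.

import torch
from torch.nn import Linear
from torchtext.modules import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct

embed_dim, nhead, bsz, tgt_len, src_len = 10, 5, 64, 6, 10
in_proj = InProjContainer(Linear(embed_dim, embed_dim, bias=False),
                          Linear(embed_dim, embed_dim, bias=False),
                          Linear(embed_dim, embed_dim, bias=False))
out_proj = Linear(embed_dim, embed_dim, bias=False)

# Same projection layers, different expected input/output layout.
mha_seq_first = MultiheadAttentionContainer(nhead, in_proj, ScaledDotProduct(), out_proj)
mha_batch_first = MultiheadAttentionContainer(nhead, in_proj, ScaledDotProduct(), out_proj,
                                              batch_first=True)

query = torch.rand((tgt_len, bsz, embed_dim))
key = value = torch.rand((src_len, bsz, embed_dim))

out_seq, _ = mha_seq_first(query, key, value)
out_bat, _ = mha_batch_first(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
torch.testing.assert_allclose(out_bat.transpose(0, 1), out_seq)  # same result, different layout
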
35 changes: 28 additions & 7 deletions torchtext/modules/multiheadattention.py
@@ -3,14 +3,18 @@


class MultiheadAttentionContainer(torch.nn.Module):
def __init__(self, nhead, in_proj_container, attention_layer, out_proj):
def __init__(self, nhead, in_proj_container, attention_layer, out_proj, batch_first=False):
r""" A multi-head attention container

Args:
nhead: the number of heads in the multiheadattention model
in_proj_container: A container of multi-head in-projection linear layers (a.k.a nn.Linear).
attention_layer: The attention layer.
attention_layer: The custom attention layer. The input sent from the MHA container to the attention layer
is in the shape of `(seq, batch, feature)` while the output shape of the attention layer
is expected to be `(seq, batch, feature)`.

Inline review comments:
Contributor: Does this also take care of broadcasting?
Contributor Author: The custom attention layer needs to take care of broadcasting. Updated the doc to reflect this.
Contributor: I'd then augment the shape to (..., seq, batch, feature) and explain what that means and also that it's optional, i.e. enough to only handle 3-dim.
out_proj: The multi-head out-projection layer (a.k.a nn.Linear).
batch_first: If ``True``, then the input and output tensors are provided
as `(batch, seq, feature)`. Default: ``False``

Examples::
>>> import torch
@@ -33,6 +37,7 @@ def __init__(self, nhead, in_proj_container, attention_layer, out_proj):
self.in_proj_container = in_proj_container
self.attention_layer = attention_layer
self.out_proj = out_proj
self.batch_first = batch_first

def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
@@ -60,6 +65,9 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
where L is the target length, S is the sequence length, H is the number of attention heads,
N is the batch size, and E is the embedding dimension.
"""
if self.batch_first:
query, key, value = query.transpose(-3, -2), key.transpose(-3, -2), value.transpose(-3, -2)

tgt_len, src_len, bsz, embed_dim = query.size(-3), key.size(-3), query.size(-2), query.size(-1)
q, k, v = self.in_proj_container(query, key, value)
assert q.size(-1) % self.nhead == 0, "query's embed_dim must be divisible by the number of heads"
@@ -78,20 +86,26 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
bias_k=bias_k, bias_v=bias_v)
attn_output = attn_output.reshape(tgt_len, bsz, embed_dim)
attn_output = self.out_proj(attn_output)

if self.batch_first:
attn_output = attn_output.transpose(-3, -2)

return attn_output, attn_output_weights


class ScaledDotProduct(torch.nn.Module):

def __init__(self, dropout=0.0):
def __init__(self, dropout=0.0, batch_first=False):
r"""Processes a projected query and key-value pair to apply
scaled dot product attention.

Args:
dropout (float): probability of dropping an attention weight.
batch_first: If ``True``, then the input and output tensors are provided
as `(batch, seq, feature)`. Default: ``False``

Examples::
>>> SDP = torchtext.models.ScaledDotProduct(0.1)
>>> SDP = torchtext.modules.ScaledDotProduct(dropout=0.1)
Inline review comments:
Contributor: Should we mirror the pytorch path conventions here? torchtext.nn and torchtext.nn.functional?
Contributor Author: The other two domains use torchvision/audio.models.
Contributor: Yes, but this isn't a model, right?
Contributor Author: OK, will fix it.
>>> q = torch.randn(256, 21, 3)
>>> k = v = torch.randn(256, 21, 3)
>>> attn_output, attn_weights = SDP(q, k, v)
@@ -100,6 +114,7 @@ def __init__(self, dropout=0.0):
"""
super(ScaledDotProduct, self).__init__()
self.dropout = dropout
self.batch_first = batch_first

def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
@@ -130,6 +145,9 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
where L is the target length, S is the source length, H is the number
of attention heads, N is the batch size, and E is the embedding dimension.
"""
if self.batch_first:
query, key, value = query.transpose(-3, -2), key.transpose(-3, -2), value.transpose(-3, -2)

if bias_k is not None and bias_v is not None:
assert key.size(-1) == bias_k.size(-1) and key.size(-2) == bias_k.size(-2) and bias_k.size(-3) == 1, \
"Shape of bias_k is not supported"
@@ -138,8 +156,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
key = torch.cat([key, bias_k])
value = torch.cat([value, bias_v])
if attn_mask is not None:
_attn_mask = attn_mask
attn_mask = torch.nn.functional.pad(_attn_mask, (0, 1))
attn_mask = torch.nn.functional.pad(attn_mask, (0, 1))

tgt_len, head_dim = query.size(-3), query.size(-1)
assert query.size(-1) == key.size(-1) == value.size(-1), "The feature dim of query, key, value must be equal."
@@ -166,7 +183,11 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
attn_output_weights = torch.nn.functional.softmax(attn_output_weights, dim=-1)
attn_output_weights = torch.nn.functional.dropout(attn_output_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_output_weights, value)
return attn_output.transpose(-2, -3), attn_output_weights

if self.batch_first:
return attn_output, attn_output_weights
else:
return attn_output.transpose(-3, -2), attn_output_weights


class InProjContainer(torch.nn.Module):
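
On the ScaledDotProduct side, a sketch (again not from the PR) of what batch_first=True means for its inputs, which already have the heads folded into the batch dimension; the sizes mirror the docstring example and the expected shapes are inferred from the diff above, so treat them as assumptions.

import torch
from torchtext.modules import ScaledDotProduct

SDP = ScaledDotProduct(dropout=0.1, batch_first=True)
# With batch_first=True, q/k/v are (batch * nhead, seq, head_dim)
# rather than the default (seq, batch * nhead, head_dim).
q = torch.randn(21, 256, 3)
k = v = torch.randn(21, 256, 3)
attn_output, attn_weights = SDP(q, k, v)
print(attn_output.shape)   # expected: torch.Size([21, 256, 3])
print(attn_weights.shape)  # expected: torch.Size([21, 256, 256])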