diff --git a/test/data/test_jit.py b/test/data/test_jit.py
index dff0d26b9a..8ac0284466 100644
--- a/test/data/test_jit.py
+++ b/test/data/test_jit.py
@@ -1,5 +1,5 @@
 import torch
-from torchtext.modules import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct
+from torchtext.nn import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct
 from torch.testing import assert_allclose
 
 from ..common.torchtext_test_case import TorchtextTestCase
diff --git a/test/data/test_modules.py b/test/data/test_modules.py
index 0de4e93239..7f88d131eb 100644
--- a/test/data/test_modules.py
+++ b/test/data/test_modules.py
@@ -1,7 +1,7 @@
 import torch
-from torchtext.modules import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct
+from torch.nn import Linear
+from torchtext.nn import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct
 from torch.nn.functional import multi_head_attention_forward as mha_forward
-from torch.testing import assert_allclose
 
 from ..common.torchtext_test_case import TorchtextTestCase
 
@@ -10,13 +10,13 @@ class TestModels(TorchtextTestCase):
     def test_multiheadattention(self):
         embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64
         # Build torchtext MultiheadAttention module
-        in_proj = InProjContainer(torch.nn.Linear(embed_dim, embed_dim, bias=False),
-                                  torch.nn.Linear(embed_dim, embed_dim, bias=False),
-                                  torch.nn.Linear(embed_dim, embed_dim, bias=False))
+        in_proj = InProjContainer(Linear(embed_dim, embed_dim, bias=False),
+                                  Linear(embed_dim, embed_dim, bias=False),
+                                  Linear(embed_dim, embed_dim, bias=False))
 
         MHA = MultiheadAttentionContainer(nhead, in_proj,
                                           ScaledDotProduct(),
-                                          torch.nn.Linear(embed_dim, embed_dim, bias=False))
+                                          Linear(embed_dim, embed_dim, bias=False))
 
         query = torch.rand((tgt_len, bsz, embed_dim))
         key = value = torch.rand((src_len, bsz, embed_dim))
@@ -40,10 +40,49 @@ def test_multiheadattention(self):
                                                           MHA.out_proj.weight, None,
                                                           attn_mask=torch_attn_mask)
-        assert_allclose(mha_output, torch_mha_output)
+        self.assertEqual(mha_output, torch_mha_output)
         # With bias_k and bias_v, src_len needs to plus 1
         attn_weights = attn_weights.view(bsz, nhead, tgt_len, src_len + 1).sum(dim=1) / nhead
-        assert_allclose(attn_weights, torch_mha_weights)
+        self.assertEqual(attn_weights, torch_mha_weights)
+
+    def test_mha_batch_first(self):
+        embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64
+        # Build torchtext MultiheadAttention module
+        in_proj = InProjContainer(Linear(embed_dim, embed_dim, bias=False),
+                                  Linear(embed_dim, embed_dim, bias=False),
+                                  Linear(embed_dim, embed_dim, bias=False))
+
+        MHA_batch_1st = MultiheadAttentionContainer(nhead, in_proj,
+                                                    ScaledDotProduct(),
+                                                    Linear(embed_dim, embed_dim, bias=False),
+                                                    batch_first=True)
+
+        query = torch.rand((tgt_len, bsz, embed_dim))
+        key = value = torch.rand((src_len, bsz, embed_dim))
+        attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool)
+        bias_k = bias_v = torch.rand((1, 1, embed_dim))
+        mha_output_1st, attn_weights_1st = MHA_batch_1st(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1),
+                                                         attn_mask=torch.stack([attn_mask_2D] * (bsz * nhead)),
+                                                         bias_k=bias_k.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1),
+                                                         bias_v=bias_v.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1))
+
+        # Use torch.nn.functional.multi_head_attention_forward
+        torch_attn_mask = torch.zeros((tgt_len, src_len)).masked_fill_(attn_mask_2D, float('-inf'))
+        in_proj_weight = torch.cat([MHA_batch_1st.in_proj_container.query_proj.weight,
+                                    MHA_batch_1st.in_proj_container.key_proj.weight,
+                                    MHA_batch_1st.in_proj_container.value_proj.weight])
+        torch_mha_output, torch_mha_weights = mha_forward(query, key, value,
+                                                          embed_dim, nhead,
+                                                          in_proj_weight, None,
+                                                          bias_k, bias_v,
+                                                          False, 0.0,
+                                                          MHA_batch_1st.out_proj.weight, None,
+                                                          attn_mask=torch_attn_mask)
+
+        self.assertEqual(mha_output_1st.transpose(0, 1), torch_mha_output)
+        # With bias_k and bias_v, src_len needs to plus 1
+        attn_weights_1st = attn_weights_1st.view(bsz, nhead, tgt_len, src_len + 1).sum(dim=1) / nhead
+        self.assertEqual(attn_weights_1st, torch_mha_weights)
 
     def test_broadcast_scaled_dot_product(self):
         embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64
@@ -61,15 +100,15 @@ def test_broadcast_scaled_dot_product(self):
         sdp_attn_output, sdp_attn_weights = SDP(query, key.expand(src_len, bsz * nhead, embed_dim),
                                                 value.expand(src_len, bsz * nhead, embed_dim),
                                                 attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))
-        assert_allclose(sdp_attn_output, sdp_attn_output_full)
-        assert_allclose(sdp_attn_weights, sdp_attn_weights_full)
+        self.assertEqual(sdp_attn_output, sdp_attn_output_full)
+        self.assertEqual(sdp_attn_weights, sdp_attn_weights_full)
 
         # key/value have a batch size of 1 while query has a batch size of bsz * nhead
         sdp_attn_output, sdp_attn_weights = SDP(query.expand(tgt_len, bsz * nhead, embed_dim), key, value,
                                                 attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))
-        assert_allclose(sdp_attn_output, sdp_attn_output_full)
-        assert_allclose(sdp_attn_weights, sdp_attn_weights_full)
+        self.assertEqual(sdp_attn_output, sdp_attn_output_full)
+        self.assertEqual(sdp_attn_weights, sdp_attn_weights_full)
 
         # key/value have a size of (3, 3, src_len, bsz * nhead, embed_dim)
         # while query has a size of (tgt_len, 1, embed_dim)
@@ -79,8 +118,8 @@ def test_broadcast_scaled_dot_product(self):
                                                 attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))
         assert list(sdp_attn_output.size()) == [3, 3, tgt_len, bsz * nhead, embed_dim]
         assert list(sdp_attn_weights.size()) == [3, 3, bsz * nhead, tgt_len, embed_dim]
-        assert_allclose(sdp_attn_output[2][2], sdp_attn_output_full)
-        assert_allclose(sdp_attn_weights[2][2], sdp_attn_weights_full)
+        self.assertEqual(sdp_attn_output[2][2], sdp_attn_output_full)
+        self.assertEqual(sdp_attn_weights[2][2], sdp_attn_weights_full)
         # dim -2 is not equal to neither key/value's dim -2 or 1
         with self.assertRaises(RuntimeError):
             SDP(query.expand(tgt_len, 2, embed_dim), key.expand(3, 3, src_len, bsz * nhead, embed_dim),
@@ -95,8 +134,8 @@ def test_broadcast_scaled_dot_product(self):
                                                 attn_mask=attn_mask_2D.expand(bsz * nhead, tgt_len, src_len))
         assert list(sdp_attn_output.size()) == [1, 2, 3, tgt_len, bsz * nhead, embed_dim]
         assert list(sdp_attn_weights.size()) == [1, 2, 3, bsz * nhead, tgt_len, embed_dim]
-        assert_allclose(sdp_attn_output[0][1][2], sdp_attn_output_full)
-        assert_allclose(sdp_attn_weights[0][1][2], sdp_attn_weights_full)
+        self.assertEqual(sdp_attn_output[0][1][2], sdp_attn_output_full)
+        self.assertEqual(sdp_attn_weights[0][1][2], sdp_attn_weights_full)
         # key dim -2 is not equal to value dim -2
         with self.assertRaisesRegex(AssertionError, "Shape of key, value must match"):
             SDP(query.expand(1, 2, 3, tgt_len, bsz * nhead, embed_dim), key.expand(src_len, 2, embed_dim),
@@ -114,8 +153,8 @@ def test_broadcast_scaled_dot_product(self):
                                                 key.expand(src_len, bsz * nhead, embed_dim),
                                                 value.expand(src_len, bsz * nhead, embed_dim),
                                                 attn_mask=attn_mask_2D.expand(1, tgt_len, src_len))
-        assert_allclose(sdp_attn_output, sdp_attn_output_full)
-        assert_allclose(sdp_attn_weights, sdp_attn_weights_full)
+        self.assertEqual(sdp_attn_output, sdp_attn_output_full)
+        self.assertEqual(sdp_attn_weights, sdp_attn_weights_full)
 
         # attn_mask's dim -3 is not equal to neither batch size or 1
         with self.assertRaisesRegex(RuntimeError, "The size of the attn_mask is not correct."):
             SDP(query.expand(tgt_len, bsz * nhead, embed_dim), key.expand(src_len, bsz * nhead, embed_dim),
diff --git a/torchtext/__init__.py b/torchtext/__init__.py
index 9ce2ce9c08..e7c08bb239 100644
--- a/torchtext/__init__.py
+++ b/torchtext/__init__.py
@@ -1,5 +1,5 @@
 from . import data
-from . import modules
+from . import nn
 from . import datasets
 from . import utils
 from . import vocab
@@ -12,7 +12,7 @@
     pass
 
 __all__ = ['data',
-           'modules',
+           'nn',
            'datasets',
            'utils',
            'vocab',
diff --git a/torchtext/nn/__init__.py b/torchtext/nn/__init__.py
new file mode 100644
index 0000000000..c48d6de70e
--- /dev/null
+++ b/torchtext/nn/__init__.py
@@ -0,0 +1 @@
+from .modules import *  # noqa: F401,F403
diff --git a/torchtext/modules/__init__.py b/torchtext/nn/modules/__init__.py
similarity index 100%
rename from torchtext/modules/__init__.py
rename to torchtext/nn/modules/__init__.py
diff --git a/torchtext/modules/multiheadattention.py b/torchtext/nn/modules/multiheadattention.py
similarity index 79%
rename from torchtext/modules/multiheadattention.py
rename to torchtext/nn/modules/multiheadattention.py
index f6d8e7675d..c67d7c6560 100644
--- a/torchtext/modules/multiheadattention.py
+++ b/torchtext/nn/modules/multiheadattention.py
@@ -3,14 +3,20 @@
 
 
 class MultiheadAttentionContainer(torch.nn.Module):
-    def __init__(self, nhead, in_proj_container, attention_layer, out_proj):
+    def __init__(self, nhead, in_proj_container, attention_layer, out_proj, batch_first=False):
         r""" A multi-head attention container
 
         Args:
             nhead: the number of heads in the multiheadattention model
             in_proj_container: A container of multi-head in-projection linear layers (a.k.a nn.Linear).
-            attention_layer: The attention layer.
+            attention_layer: The custom attention layer. The input sent from MHA container to the attention layer
+                is in the shape of `(..., L, N * H, E / H)` for query and `(..., S, N * H, E / H)` for key/value
+                while the output shape of the attention layer is expected to be `(..., L, N * H, E / H)`.
+                The attention_layer needs to support broadcast if users want the overall MultiheadAttentionContainer
+                with broadcast.
             out_proj: The multi-head out-projection layer (a.k.a nn.Linear).
+            batch_first: If ``True``, then the input and output tensors are provided
+                as `(..., N, L, E)`. Default: ``False``
 
         Examples::
             >>> import torch
@@ -33,6 +39,7 @@ def __init__(self, nhead, in_proj_container, attention_layer, out_proj):
         self.in_proj_container = in_proj_container
         self.attention_layer = attention_layer
         self.out_proj = out_proj
+        self.batch_first = batch_first
 
     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                 attn_mask: Optional[torch.Tensor] = None,
@@ -48,18 +55,24 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
 
         Shape:
             - Inputs:
-            - query: :math:`(L, N, E)`
-            - key: :math:`(S, N, E)`
-            - value: :math:`(S, N, E)`
+            - query: :math:`(..., L, N, E)`
+            - key: :math:`(..., S, N, E)`
+            - value: :math:`(..., S, N, E)`
             - attn_mask, bias_k and bias_v: same with the shape of the corresponding args in attention layer.
             - Outputs:
-            - attn_output: :math:`(L, N, E)`
+            - attn_output: :math:`(..., L, N, E)`
             - attn_output_weights: :math:`(N * H, L, S)`
 
+        Note: It's optional to have the query/key/value inputs with more than three dimensions (for broadcast purpose).
+            The MultiheadAttentionContainer module will operate on the last three dimensions.
+
             where L is the target length, S is the sequence length, H is the number of attention heads,
                 N is the batch size, and E is the embedding dimension.
         """
+        if self.batch_first:
+            query, key, value = query.transpose(-3, -2), key.transpose(-3, -2), value.transpose(-3, -2)
+
         tgt_len, src_len, bsz, embed_dim = query.size(-3), key.size(-3), query.size(-2), query.size(-1)
         q, k, v = self.in_proj_container(query, key, value)
         assert q.size(-1) % self.nhead == 0, "query's embed_dim must be divisible by the number of heads"
@@ -78,20 +91,26 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                                                                  bias_k=bias_k, bias_v=bias_v)
         attn_output = attn_output.reshape(tgt_len, bsz, embed_dim)
         attn_output = self.out_proj(attn_output)
+
+        if self.batch_first:
+            attn_output = attn_output.transpose(-3, -2)
+
         return attn_output, attn_output_weights
 
 
 class ScaledDotProduct(torch.nn.Module):
 
-    def __init__(self, dropout=0.0):
+    def __init__(self, dropout=0.0, batch_first=False):
         r"""Processes a projected query and key-value pair to apply
         scaled dot product attention.
 
         Args:
             dropout (float): probability of dropping an attention weight.
+            batch_first: If ``True``, then the input and output tensors are provided
+                as `(batch, seq, feature)`. Default: ``False``
 
         Examples::
-            >>> SDP = torchtext.models.ScaledDotProduct(0.1)
+            >>> SDP = torchtext.nn.ScaledDotProduct(dropout=0.1)
             >>> q = torch.randn(256, 21, 3)
             >>> k = v = torch.randn(256, 21, 3)
             >>> attn_output, attn_weights = SDP(q, k, v)
@@ -100,6 +119,7 @@ def __init__(self, dropout=0.0):
         """
         super(ScaledDotProduct, self).__init__()
         self.dropout = dropout
+        self.batch_first = batch_first
 
     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                 attn_mask: Optional[torch.Tensor] = None,
@@ -118,18 +138,24 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                 non-None to both arguments in order to activate them.
 
         Shape:
-            - query: :math:`(L, N * H, E / H)`
-            - key: :math:`(S, N * H, E / H)`
-            - value: :math:`(S, N * H, E / H)`
+            - query: :math:`(..., L, N * H, E / H)`
+            - key: :math:`(..., S, N * H, E / H)`
+            - value: :math:`(..., S, N * H, E / H)`
             - attn_mask: :math:`(N * H, L, S)`, positions with ``True`` are not allowed to attend
                 while ``False`` values will be unchanged.
             - bias_k and bias_v:bias: :math:`(1, N * H, E / H)`
-            - Output: :math:`(L, N * H, E / H)`, :math:`(N * H, L, S)`
+            - Output: :math:`(..., L, N * H, E / H)`, :math:`(N * H, L, S)`
+
+        Note: It's optional to have the query/key/value inputs with more than three dimensions (for broadcast purpose).
+            The ScaledDotProduct module will operate on the last three dimensions.
 
             where L is the target length, S is the source length, H is the number of attention heads,
                 N is the batch size, and E is the embedding dimension.
""" + if self.batch_first: + query, key, value = query.transpose(-3, -2), key.transpose(-3, -2), value.transpose(-3, -2) + if bias_k is not None and bias_v is not None: assert key.size(-1) == bias_k.size(-1) and key.size(-2) == bias_k.size(-2) and bias_k.size(-3) == 1, \ "Shape of bias_k is not supported" @@ -138,8 +164,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, key = torch.cat([key, bias_k]) value = torch.cat([value, bias_v]) if attn_mask is not None: - _attn_mask = attn_mask - attn_mask = torch.nn.functional.pad(_attn_mask, (0, 1)) + attn_mask = torch.nn.functional.pad(attn_mask, (0, 1)) tgt_len, head_dim = query.size(-3), query.size(-1) assert query.size(-1) == key.size(-1) == value.size(-1), "The feature dim of query, key, value must be equal." @@ -166,7 +191,11 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_output_weights = torch.nn.functional.softmax(attn_output_weights, dim=-1) attn_output_weights = torch.nn.functional.dropout(attn_output_weights, p=self.dropout, training=self.training) attn_output = torch.matmul(attn_output_weights, value) - return attn_output.transpose(-2, -3), attn_output_weights + + if self.batch_first: + return attn_output, attn_output_weights + else: + return attn_output.transpose(-3, -2), attn_output_weights class InProjContainer(torch.nn.Module):