Assert if padding mask type is unexpected (#86353) #87106

Closed · wants to merge 1 commit
18 changes: 4 additions & 14 deletions test/test_nn.py
@@ -5399,7 +5399,7 @@ def _create_src_lengths_mask(batch_size, src_lengths):
return (src_indices < src_lengths).int().detach()

def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, add_zero_attn=False,
saved_kv=False, same_embed_dim=False, byte_mask=False,
saved_kv=False, same_embed_dim=False,
average_attn_weights=average_attn_weights):
for _ in range(100):
batch_sz, seq_len = [random.randint(2, 10) for r in range(2)]
@@ -5428,20 +5428,15 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
seq_mask = np.random.randint(0, 2, (1, seq_len))
key_padding_mask = (np.repeat(seq_mask, batch_sz, axis=0) == 1)
key_padding_mask_tensor = torch.from_numpy(key_padding_mask)
if byte_mask:
key_padding_mask_tensor = key_padding_mask_tensor.byte()
decoder_state = np.random.rand(batch_sz, d_model)
K = np.random.rand(*dims)
V = K
Q = np.expand_dims(decoder_state, 1)
attn_mask = np.random.randint(0 , 2, size=(1, seq_len))
attn_mask_tensor = torch.from_numpy(attn_mask).float()
if byte_mask:
attn_mask_tensor = (attn_mask_tensor == 0).byte()
else:
attn_mask_tensor.masked_fill_(attn_mask_tensor == 0, float('-inf'))
attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, float('0.0'))
attn_mask_tensor = attn_mask_tensor.double()
attn_mask_tensor.masked_fill_(attn_mask_tensor == 0, float('-inf'))
attn_mask_tensor.masked_fill_(attn_mask_tensor > 0, float('0.0'))
attn_mask_tensor = attn_mask_tensor.double()

decoder_state_tensor = torch.from_numpy(decoder_state).to(torch.get_default_dtype())
source_hid_tensor = torch.from_numpy(K).to(torch.get_default_dtype()).transpose(0, 1)
@@ -5588,10 +5583,6 @@ def test_multihead_attn_all_arguments3():
_multihead_attn_test_helper(add_key_padding_mask=True, add_zero_attn=True,
saved_kv=True, same_embed_dim=True)

def test_multihead_attn_all_arguments4():
_multihead_attn_test_helper(add_key_padding_mask=True, add_zero_attn=True,
saved_kv=True, same_embed_dim=True, byte_mask=True)

test_multihead_attn_add_zero_attn() # Test MultiheadAttention with add_zero_attn
test_multihead_attn_add_bias_kv() # Test MultiheadAttention with add_bias_kv
test_multihead_attn_no_masking() # Test MultiheadAttention without masking
@@ -5602,7 +5593,6 @@ def test_multihead_attn_all_arguments4():
with self.assertRaisesRegex(AssertionError, "bias cannot be added to static key."):
test_multihead_attn_all_arguments2() # Test MultiheadAttention with all the argument.
test_multihead_attn_all_arguments3() # Test MultiheadAttention with all the argument.
test_multihead_attn_all_arguments4() # Test MultiheadAttention with all the argument.

def test_multihead_attn_3d_attn_mask(self):
embed_dim = 8
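The removed `byte_mask` path above reflects the behavioral change in this PR: `nn.MultiheadAttention` no longer converts a byte `key_padding_mask` with a deprecation warning, it rejects it outright. A minimal migration sketch (shapes and values here are illustrative, not taken from the test):

```python
import torch
import torch.nn as nn

# Legacy code often built uint8 padding masks; after this change they must be
# converted to bool (True = ignore that key position) before the call.
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
q = k = v = torch.randn(1, 4, 8)

legacy_mask = torch.tensor([[0, 0, 1, 1]], dtype=torch.uint8)  # 1 marks padding
bool_mask = legacy_mask.to(torch.bool)

out, _ = mha(q, k, v, key_padding_mask=bool_mask)   # supported
# mha(q, k, v, key_padding_mask=legacy_mask)        # now raises AssertionError
```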
54 changes: 54 additions & 0 deletions test/test_transformers.py
@@ -8,6 +8,7 @@
from unittest.mock import patch
import math
from torch.backends.cuda import sdp_kernel
import torch.optim as optim

from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
@@ -69,6 +70,59 @@ def test_self_attn_TxT_attn_mask(self):

self.assertEqual(output_mask_4d, output_mask_TxT)

@parametrize("device", device_list)
def test_train_with_pad_and_catch_error(self, device):
iters = 100
pad_mask = torch.tensor([[1, 1, 0, 0]], dtype=torch.bool).to(device)
layer = nn.TransformerEncoderLayer(
d_model=2,
dim_feedforward=4,
nhead=2,
batch_first=True,
activation="gelu",
dropout=0,
)
criterion = nn.MSELoss()
encoder = nn.TransformerEncoder(layer, 2).to(device)
optimizer = optim.SGD(encoder.parameters(), lr=0.1, momentum=0.9)
encoder.train()
for i in range(iters):
encoder.train()
optimizer.zero_grad()
inputs = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1).to(device)

outputs = encoder(inputs, src_key_padding_mask=pad_mask)

loss = criterion(outputs[:, 0:2, :], inputs[:, 0:2, :])
loss.backward()
optimizer.step()

with torch.no_grad():
test = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1).to(device)

# Expect uint8 type not supported
ex = None
try:
test_train_uint8 = encoder(test, src_key_padding_mask=pad_mask.to(torch.uint8))
except AssertionError as e:
continue
self.assertFalse(e, "Failed to catch unsupported uint8 type exception")

test_train_bool = encoder(test, src_key_padding_mask=pad_mask)
encoder.eval()

# Expect long type not supported
ex = None
try:
test_eval_uint8 = encoder(test, src_key_padding_mask=pad_mask.to(torch.int64))
except AssertionError as e:
continue
self.assertFalse(e, "Failed to catch unsupported Long type exception")

test_eval_bool = encoder(test, src_key_padding_mask=pad_mask)
l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item()
self.assertTrue(l1_bool < 1e-4, "Eval/Train difference in pad_mask BOOL")

@parametrize("device", device_list)
@parametrize("nhead", [1, 4, 8])
def test_transformerencoderlayer_src_mask(self, device, nhead):
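The new test above drives both the training and eval paths; the core behavior it checks can be reproduced standalone. A short sketch, independent of the test and with made-up sizes, assuming a build that includes this change:

```python
import torch
import torch.nn as nn

# A bool src_key_padding_mask is accepted; integer dtypes now raise AssertionError.
layer = nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=4,
                                   batch_first=True, dropout=0)
encoder = nn.TransformerEncoder(layer, num_layers=2)

x = torch.randn(1, 4, 2)
pad_mask = torch.tensor([[False, False, True, True]])

encoder(x, src_key_padding_mask=pad_mask)  # ok
try:
    encoder(x, src_key_padding_mask=pad_mask.to(torch.uint8))
except AssertionError as err:
    print("rejected as expected:", err)
```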
14 changes: 7 additions & 7 deletions torch/nn/functional.py
@@ -4966,8 +4966,8 @@ def multi_head_attention_forward(
- value: :math:`(S, E)` or :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
will be unchanged. If a BoolTensor is provided, the positions with the
If a FloatTensor is provided, it will be directly added to the value.
If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
@@ -5037,6 +5037,11 @@ def multi_head_attention_forward(
# set up shape vars
tgt_len, bsz, embed_dim = query.shape
src_len, _, _ = key.shape
if key_padding_mask is not None:
_kpm_dtype = key_padding_mask.dtype
if _kpm_dtype != torch.bool and not torch.is_floating_point(key_padding_mask):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported")
assert embed_dim == embed_dim_to_check, \
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
if isinstance(embed_dim, torch.Tensor):
@@ -5089,11 +5094,6 @@ def multi_head_attention_forward(
else:
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

# prep key padding mask
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
key_padding_mask = key_padding_mask.to(torch.bool)

# add bias along batch dimension (currently second)
if bias_k is not None and bias_v is not None:
assert static_k is None, "bias cannot be added to static key."
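The guard added to multi_head_attention_forward boils down to a small dtype check. An equivalent standalone sketch (the helper name is made up for illustration):

```python
import torch

def check_key_padding_mask(key_padding_mask):
    # Mirrors the added guard: only bool masks and floating-point masks pass.
    if key_padding_mask is not None:
        if key_padding_mask.dtype != torch.bool and not torch.is_floating_point(key_padding_mask):
            raise AssertionError(
                "only bool and floating types of key_padding_mask are supported")

check_key_padding_mask(torch.zeros(2, 4, dtype=torch.bool))     # passes
check_key_padding_mask(torch.zeros(2, 4, dtype=torch.float32))  # passes
# check_key_padding_mask(torch.zeros(2, 4, dtype=torch.uint8))  # raises AssertionError
```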
8 changes: 6 additions & 2 deletions torch/nn/modules/activation.py
@@ -1031,8 +1031,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O
to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
Binary and byte masks are supported.
For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
the purpose of attention. For a byte mask, a non-zero value indicates that the corresponding ``key``
value will be ignored.
the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
Default: ``True``.
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
@@ -1062,6 +1061,11 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O
`batch_first` argument is ignored for unbatched inputs.
"""
is_batched = query.dim() == 3
if key_padding_mask is not None:
_kpm_dtype = key_padding_mask.dtype
if _kpm_dtype != torch.bool and not torch.is_floating_point(key_padding_mask):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported")
why_not_fast_path = ''
if not is_batched:
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
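The updated docstring describes two supported key_padding_mask forms: a bool mask (``True`` means ignore that key) and a float mask that is added directly, so a float mask of 0 / -inf should match the corresponding bool mask. A rough sketch of that equivalence (not part of the PR; values are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
q = k = v = torch.randn(1, 4, 8)

bool_mask = torch.tensor([[False, False, True, True]])  # ignore keys 2 and 3
float_mask = torch.zeros(1, 4)
float_mask[0, 2:] = float("-inf")                        # same positions via -inf

out_bool, _ = mha(q, k, v, key_padding_mask=bool_mask)
out_float, _ = mha(q, k, v, key_padding_mask=float_mask)
print(torch.allclose(out_bool, out_float))  # expected: True
```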
10 changes: 10 additions & 0 deletions torch/nn/modules/transformer.py
@@ -203,6 +203,11 @@ def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_ma
Shape:
see the docs in Transformer class.
"""
if src_key_padding_mask is not None:
_skpm_dtype = src_key_padding_mask.dtype
if _skpm_dtype != torch.bool and not torch.is_floating_point(src_key_padding_mask):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported")
output = src
convert_to_nested = False
first_layer = self.layers[0]
@@ -442,6 +447,11 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
see the docs in Transformer class.
"""

if src_key_padding_mask is not None:
_skpm_dtype = src_key_padding_mask.dtype
if _skpm_dtype != torch.bool and not torch.is_floating_point(src_key_padding_mask):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported")
# see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
why_not_sparsity_fast_path = ''
if not src.dim() == 3:
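The same check is duplicated in TransformerEncoderLayer.forward, so calling a layer directly (outside a TransformerEncoder stack) fails identically for integer masks. A brief sketch, again with made-up sizes:

```python
import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=4, nhead=2, batch_first=True)
x = torch.randn(2, 5, 4)
int_mask = torch.zeros(2, 5, dtype=torch.int64)

try:
    layer(x, src_key_padding_mask=int_mask)              # long mask: rejected
except AssertionError as err:
    print(err)

layer(x, src_key_padding_mask=int_mask.to(torch.bool))   # bool mask: accepted
```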