Fix TransformerEncoderLayer for bias=False #116760
@@ -2374,52 +2374,57 @@ def module_inputs_torch_nn_TransformerEncoderLayer(module_info, device, dtype, r
                 make_input((2, 3, 4))
             ),
             desc='gelu_activation'
-        ), ]
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 8, 0.0, bias=False),
+            forward_input=FunctionInput(
+                make_input((2, 3, 4))
+            ),
+            desc='no_bias'
+        ),]
 
     # Samples below are for validating the no-batch-dim support.
     key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
     attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3)))
-    for src_mask, src_key_padding_mask, norm_first in itertools.product(attn_masks, key_padding_masks, (True, False)):
+    for src_mask, src_key_padding_mask, norm_first, batch_first, bias in \
+            itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)):
         samples.append(
             ModuleInput(
                 constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
-                                                dropout=0.0, batch_first=True, norm_first=norm_first),
-                forward_input=FunctionInput(
-                    make_input((3, 4)), src_mask=src_mask, src_key_padding_mask=src_key_padding_mask
-                ),
-                reference_fn=partial(no_batch_dim_reference_fn,
-                                     batch_first=True, kwargs_to_batchify={'src_key_padding_mask': 0}),
-                desc='no_batch_dim_batch_first'
-            ))
-
-        samples.append(
-            ModuleInput(
-                constructor_input=FunctionInput(4, 2, 8, dropout=0.0, batch_first=False, norm_first=norm_first),
+                                                dropout=0.0, batch_first=batch_first,
+                                                norm_first=norm_first, bias=bias),
                 forward_input=FunctionInput(
                     make_input((3, 4)), src_mask=src_mask, src_key_padding_mask=src_key_padding_mask
                 ),
                 reference_fn=partial(no_batch_dim_reference_fn,
-                                     batch_first=False, kwargs_to_batchify={'src_key_padding_mask': 0}),
-                desc='no_batch_dim'
+                                     batch_first=batch_first, kwargs_to_batchify={'src_key_padding_mask': 0}),
+                desc=f'no_batch_dim_batch_first_{batch_first}'
             ))
 
     # Samples below where we pass reference_fn are for validating the fast path,
     # since the fast path requires no_grad mode, we run the fast path in .eval()
     # and no_grad() in the reference_fn and verify that against the results in train mode.
     def fast_path_reference_fn(module, parameters, *args, **kwargs):
-        assert not module.training
-        module = module.train(True)
-        output = module(*args, **kwargs)
-        module = module.train(False)
+        assert module.training
+        module.train(False)
+        with torch.no_grad():
+            output = module(*args, **kwargs)
+        module.train(True)
         return output
 
-    if not training:
-        for norm_first in (True, False):
+    if training:
+        for norm_first, bias in itertools.product((True, False), (True, False)):
             samples.append(
                 ModuleInput(
-                    constructor_input=FunctionInput(4, 2, 8, dropout=0.0, batch_first=True, norm_first=norm_first),
+                    constructor_input=FunctionInput(
+                        4, 2, 8, dropout=0.0, batch_first=True, norm_first=norm_first, bias=bias
+                    ),
                     forward_input=FunctionInput(
                         make_input((2, 3, 4)),
                     ),
-                    reference_fn=fast_path_reference_fn,
-                    desc="fast_path_norm_first" if norm_first else "fast_path"
+                    # fastpath doesn't run when bias=False
+                    reference_fn=fast_path_reference_fn if bias else None,
+                    desc=f'fastpath_{bias}_norm_first_{norm_first}'
                 )
             )
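As an aside on the no-batch-dim samples above, here is a minimal standalone sketch (my own illustration, not the `no_batch_dim_reference_fn` helper itself) of the equivalence these samples check: an unbatched `(seq, d_model)` input should produce the same result as the same input with a singleton batch dimension added and then squeezed back out. `dropout=0.0` keeps the comparison deterministic.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative check only; the test suite drives this through ModuleInput /
# no_batch_dim_reference_fn rather than a hand-rolled comparison like this.
layer = nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=8,
                                   dropout=0.0, batch_first=True)

src = torch.randn(3, 4)                 # (seq_len, d_model), no batch dimension

out_unbatched = layer(src)              # unbatched input path
out_batched = layer(src.unsqueeze(0))   # (1, seq_len, d_model) with batch_first=True

torch.testing.assert_close(out_unbatched, out_batched.squeeze(0))
```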
@@ -2443,40 +2448,51 @@ def module_inputs_torch_nn_TransformerDecoderLayer(module_info, device, dtype, r
                 make_input((2, 3, 4)), make_input((2, 3, 4))
             ),
             desc='gelu_activation'
-        ), ]
+        ),
+        ModuleInput(
+            constructor_input=FunctionInput(4, 2, 8, 0.0, bias=False),
+            forward_input=FunctionInput(
+                make_input((2, 3, 4)), make_input((2, 3, 4))
+            ),
+            desc='no_bias'
+        ), ]
 
     # Samples below are for validating the no-batch-dim support.
     key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
     attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3)))
-    for tgt_mask, tgt_key_padding_mask, norm_first in itertools.product(attn_masks, key_padding_masks, (True, False)):
+    for tgt_mask, tgt_key_padding_mask, norm_first, bias, batch_first in \
+            itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)):
         # Using same mask for tgt and memory
         memory_mask = tgt_mask
         memory_key_padding_mask = tgt_key_padding_mask
         samples.append(
             ModuleInput(
                 constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
-                                                dropout=0.0, batch_first=True, norm_first=norm_first),
+                                                dropout=0.0, batch_first=batch_first,
+                                                norm_first=norm_first, bias=bias),
                 forward_input=FunctionInput(
                     make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, memory_mask=memory_mask,
                     tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask
                 ),
                 reference_fn=partial(no_batch_dim_reference_fn,
-                                     batch_first=True,
+                                     batch_first=batch_first,
                                      kwargs_to_batchify={'tgt_key_padding_mask': 0, 'memory_key_padding_mask': 0}),
-                desc='no_batch_dim_batch_first'
+                desc=f'no_batch_dim_batch_first_{batch_first}'
             ))
 
+        src, tgt = make_input((2, 3, 4)), make_input((2, 3, 4))
+        if not batch_first:
+            src, tgt = src.transpose(0, 1), tgt.transpose(0, 1)
+        if tgt_key_padding_mask is not None:
+            memory_key_padding_mask, tgt_key_padding_mask = (tgt_key_padding_mask.expand(2, 3),) * 2
         samples.append(
             ModuleInput(
-                constructor_input=FunctionInput(4, 2, 8, dropout=0.0, batch_first=False, norm_first=norm_first),
+                constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
+                                                dropout=0.0, batch_first=batch_first,
+                                                norm_first=norm_first, bias=bias),
                 forward_input=FunctionInput(
-                    make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, memory_mask=memory_mask,
+                    src, tgt, tgt_mask=tgt_mask, memory_mask=memory_mask,
                     tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask
                 ),
-                reference_fn=partial(no_batch_dim_reference_fn,
-                                     batch_first=False,
-                                     kwargs_to_batchify={'tgt_key_padding_mask': 0, 'memory_key_padding_mask': 0}),
-                desc='no_batch_dim'
+                desc=f'norm_first_{norm_first}_batch_first_{batch_first}_bias_{bias}'
             ))
 
     return samples

Review comment on the batch_first/no-batch-dim samples: I do think it's a good black-box testing surface to cover. IIRC there was a problem at some point where some ...
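Since the new samples feed the same data through both layouts (transposing `src`/`tgt` when `batch_first=False`), here is a rough sketch of the equivalence being exercised. This is my own illustration, not code from the PR: the same weights run in `batch_first=True` and `batch_first=False` layouts should agree up to a transpose of the batch and sequence dimensions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer_bf = nn.TransformerDecoderLayer(d_model=4, nhead=2, dim_feedforward=8,
                                      dropout=0.0, batch_first=True)
layer_sf = nn.TransformerDecoderLayer(d_model=4, nhead=2, dim_feedforward=8,
                                      dropout=0.0, batch_first=False)
layer_sf.load_state_dict(layer_bf.state_dict())  # same parameters, different layout

tgt = torch.randn(2, 3, 4)     # (batch, seq, d_model) for batch_first=True
memory = torch.randn(2, 3, 4)

out_bf = layer_bf(tgt, memory)
out_sf = layer_sf(tgt.transpose(0, 1), memory.transpose(0, 1))  # (seq, batch, d_model)

torch.testing.assert_close(out_bf, out_sf.transpose(0, 1))
```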
@@ -2488,41 +2504,43 @@ def module_inputs_torch_nn_Transformer(module_info, device, dtype, requires_grad
     # Samples below are for validating the no-batch-dim support.
     key_padding_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool))
     attn_masks = (None, torch.tensor([False, False, True], device=device, dtype=torch.bool).expand((3, 3)))
-    for mask, key_padding_mask, norm_first, bias in \
-            itertools.product(attn_masks, key_padding_masks, (True, False), (True, False)):
+    for mask, key_padding_mask, norm_first, bias, batch_first in \
+            itertools.product(attn_masks, key_padding_masks, (True, False), (True, False), (True, False)):
         # Using same mask for tgt and memory
         src_mask, tgt_mask = (mask,) * 2
         src_key_padding_mask, tgt_key_padding_mask = (key_padding_mask,) * 2
         samples.append(
             ModuleInput(
                 constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
                                                 num_encoder_layers=1, num_decoder_layers=1,
-                                                dropout=0.0, batch_first=True, norm_first=norm_first, bias=bias),
+                                                dropout=0.0, batch_first=batch_first, norm_first=norm_first, bias=bias),
                 forward_input=FunctionInput(
                     make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, src_mask=src_mask,
                     tgt_key_padding_mask=tgt_key_padding_mask, src_key_padding_mask=src_key_padding_mask
                 ),
                 reference_fn=partial(no_batch_dim_reference_fn,
-                                     batch_first=True,
+                                     batch_first=batch_first,
                                      kwargs_to_batchify={'tgt_key_padding_mask': 0, 'src_key_padding_mask': 0}),
-                desc='no_batch_dim_batch_first'
+                desc=f'no_batch_dim_batch_first_{batch_first}'
             ))
 
+        src, tgt = make_input((2, 3, 4)), make_input((2, 3, 4))
+        if not batch_first:
+            src = src.transpose(0, 1)
+            tgt = tgt.transpose(0, 1)
+        if key_padding_mask is not None:
+            src_key_padding_mask, tgt_key_padding_mask = (key_padding_mask.expand(2, 3),) * 2
+
         samples.append(
             ModuleInput(
                 constructor_input=FunctionInput(d_model=4, nhead=2, dim_feedforward=8,
                                                 num_encoder_layers=1, num_decoder_layers=1,
-                                                dropout=0.0, batch_first=False, norm_first=norm_first),
+                                                dropout=0.0, batch_first=batch_first, norm_first=norm_first, bias=bias),
                 forward_input=FunctionInput(
-                    make_input((3, 4)), make_input((3, 4)), tgt_mask=tgt_mask, src_mask=src_mask,
+                    src, tgt, tgt_mask=tgt_mask, src_mask=src_mask,
                     tgt_key_padding_mask=tgt_key_padding_mask, src_key_padding_mask=src_key_padding_mask
                 ),
-                reference_fn=partial(no_batch_dim_reference_fn,
-                                     batch_first=False,
-                                     kwargs_to_batchify={'tgt_key_padding_mask': 0, 'src_key_padding_mask': 0}),
-                desc='no_batch_dim'
             ))
 
     return samples
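To make the new `no_bias` samples concrete, here is a minimal sketch of the configuration they construct, assuming a PyTorch build where these layers accept `bias=False` (2.1 or later). This is an illustration, not the PR's test code:

```python
import torch
import torch.nn as nn

# With bias=False the layer should build and run, and its Linear/LayerNorm
# submodules should hold no bias parameters.
layer = nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=8,
                                   dropout=0.0, bias=False)

assert layer.linear1.bias is None and layer.linear2.bias is None
assert layer.norm1.bias is None and layer.norm2.bias is None

out = layer(torch.randn(2, 3, 4))   # (seq, batch, d_model) since batch_first=False
print(out.shape)                    # torch.Size([2, 3, 4])
```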
The existing fastpath test was not actually testing the fastpath against training mode, since the fastpath only runs in `no_grad` mode and `TestModule.test_forward` is not run under a `no_grad` context. I modified this so that the outputs in train mode are compared against a reference that runs the fastpath in eval/`no_grad` mode.
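A rough sketch of the comparison described above (my own illustration, not the harness code): with `dropout=0.0`, the train-mode output should match what the layer produces under `eval()` plus `no_grad()`, which is the regime where the fastpath is eligible to run.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

layer = nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=8,
                                   dropout=0.0, batch_first=True)
x = torch.randn(2, 3, 4)   # (batch, seq, d_model)

out_train = layer(x)       # training mode + autograd: always the regular path

layer.eval()
with torch.no_grad():      # eval + no_grad: fastpath-eligible
    out_eval = layer(x)
layer.train()

# Loose tolerances: the fused fastpath kernels need not be bitwise identical
# to the regular implementation.
torch.testing.assert_close(out_train, out_eval, rtol=1e-4, atol=1e-4)
```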
possibly dumb Q: do we need to compare train mode outputs vs. fastpath reference?
From offline discussion: this swaps the reference to use the fastpath so that it runs under `no_grad` mode; since dropout=0, the outputs should match.