Fix TransformerEncoderLayer for bias=False #116760
Conversation
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/116760
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (5 Unrelated Failures) As of commit c9db481 with merge base a8a9695:
FLAKY - The following job failed but was likely due to flakiness present on trunk.
BROKEN TRUNK - The following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
  # Samples below where we pass reference_fn are for validating the fast path,
  # since the fast path requires no_grad mode, we run the fast path in .eval()
  # and no_grad() in the reference_fn and verify that against the results in train mode.
  def fast_path_reference_fn(module, parameters, *args, **kwargs):
-     assert not module.training
-     module = module.train(True)
-     output = module(*args, **kwargs)
-     module = module.train(False)
+     assert module.training
+     module.train(False)
+     with torch.no_grad():
+         output = module(*args, **kwargs)
+     module.train(True)
      return output
- if not training:
-     for norm_first in (True, False):
+ if training:
+     for norm_first, bias in itertools.product((True, False), (True, False)):
          samples.append(
              ModuleInput(
-                 constructor_input=FunctionInput(4, 2, 8, dropout=0.0, batch_first=True, norm_first=norm_first),
+                 constructor_input=FunctionInput(
+                     4, 2, 8, dropout=0.0, batch_first=True, norm_first=norm_first, bias=bias
+                 ),
                  forward_input=FunctionInput(
                      make_input((2, 3, 4)),
                  ),
-                 reference_fn=fast_path_reference_fn,
-                 desc="fast_path_norm_first" if norm_first else "fast_path"
+                 # fastpath doesn't run when bias=False
+                 reference_fn=fast_path_reference_fn if bias else None,
+                 desc=f'fastpath_{bias}_norm_first_{norm_first}'
The existing fastpath test was not actually testing the fastpath against training mode, since the fastpath is only run in no_grad mode and `TestModule.test_forward` is not run under a no_grad context.

I modified this so that the outputs in train mode are compared against a reference that runs the fastpath in eval/no_grad mode.
> I modified this such that the outputs in train mode are compared against a reference that runs the fastpath in eval/no_grad mode

Possibly dumb Q: do we need to compare train mode outputs vs. the fastpath reference?
From offline discussion: this swaps the reference to use the fastpath so that it runs under no_grad mode; since dropout=0, the outputs should match.
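For illustration, a minimal standalone sketch (not from the PR itself) of the property this reference relies on: with `dropout=0.0`, the train-mode output and the eval/no_grad (fastpath-eligible) output of a `TransformerEncoderLayer` should agree.

```python
import torch
import torch.nn as nn

# Hypothetical check mirroring the test's idea: with dropout=0.0 the slow path
# (train mode) and the fastpath-eligible path (eval + no_grad) should agree.
layer = nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=8,
                                   dropout=0.0, batch_first=True)
x = torch.randn(2, 3, 4)  # (batch, seq, d_model)

layer.train()
out_train = layer(x)          # slow path: training mode, autograd enabled

layer.eval()
with torch.no_grad():
    out_fast = layer(x)       # eligible for the fused fastpath

print(torch.allclose(out_train, out_fast, atol=1e-5))  # expected: True
```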
Fixes #116385

Don't call `torch._transformer_encoder_layer_fwd` when `bias=False` (which sets biases to `None`). This also prevents us from ever doing checks on properties of `tensor_args` in `TransformerEncoder`/`TransformerEncoderLayer` which contained the `None`s and was erroring on checks like `t.requires_grad for t in tensor_args`.

Alternative fix would be to
1) Pass `torch.zeros_like({*}.weight)` to the kernel when `bias=False` and filter `tensor_args` as appropriate
2) Fix `torch._transformer_encoder_layer_fwd` to take `Optional<Tensor>` for biases

Let me know if this approach is preferable

[ghstack-poisoned]
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Rebase failed due to Command
Raised by https://github.com/pytorch/pytorch/actions/runs/7405594863
Fixes #116385

Don't call `torch._transformer_encoder_layer_fwd` when `bias=False` (which sets biases to `None`). This also prevents us from ever doing checks on properties of `tensor_args` in `TransformerEncoder`/`TransformerEncoderLayer` which contained the `None`s and was erroring on checks like `t.requires_grad for t in tensor_args`.

Alternative fix would be to
1) Pass `torch.zeros_like({*}.weight)` to the kernel when `bias=False` and filter `tensor_args` as appropriate
2) Fix `torch._transformer_encoder_layer_fwd` to take `Optional<Tensor>` for biases and fix the kernels as appropriate

Let me know if either of these approaches is preferable

[ghstack-poisoned]
Fixes #116385

Don't call `torch._transformer_encoder_layer_fwd` when `bias=False`.

`bias=False` was not something that `torch._transformer_encoder_layer_fwd` was meant to work with; it was my bad that this wasn't tested as I approved #101687. `bias=False` was causing the `tensor_args` in [`TransformerEncoder`](https://github.com/pytorch/pytorch/blob/a17de2d6455e262f9b514584443ac60cf381bc85/torch/nn/modules/transformer.py#L364-L378)/[`TransformerEncoderLayer`](https://github.com/pytorch/pytorch/blob/a17de2d6455e262f9b514584443ac60cf381bc85/torch/nn/modules/transformer.py#L663-L677) to contain `None`s and error on checks for the fastpath like `t.requires_grad for t in tensor_args`.

Alternative fix would be to
1) Pass `torch.zeros_like({*}.weight)` to the kernel when `bias=False` and filter `tensor_args` as appropriate
2) Fix `torch._transformer_encoder_layer_fwd` to take `Optional<Tensor>` for biases and fix the kernels as appropriate

Let me know if these approaches are preferable

[ghstack-poisoned]
-                 batch_first=True,
+                 batch_first=batch_first,
                  kwargs_to_batchify={'tgt_key_padding_mask': 0, 'memory_key_padding_mask': 0}),
              desc='no_batch_dim_batch_first'
It feels like `batch_first` should not affect `no_batch_dim`, so I'm not entirely sure why we have these tests, but I didn't delete them in case they are testing an edge case.
I do think it's a good black-box testing surface to cover. IIRC there was a problem at some point where some `batch_first` setting blew up on no-batch-dim inputs.
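As a rough illustration of that surface (my own sketch, assuming `TransformerDecoderLayer` accepts unbatched `(seq, feature)` inputs), `batch_first` should be a no-op for no-batch-dim inputs:

```python
import torch
import torch.nn as nn

# No-batch-dim inputs: tgt is (T, E) and memory is (S, E). The batch_first flag
# should not change anything for these shapes, which is what the test exercises.
for batch_first in (True, False):
    layer = nn.TransformerDecoderLayer(d_model=4, nhead=2, dim_feedforward=8,
                                       dropout=0.0, batch_first=batch_first)
    tgt = torch.randn(3, 4)      # (T, E), no batch dimension
    memory = torch.randn(5, 4)   # (S, E), no batch dimension
    out = layer(tgt, memory)
    assert out.shape == tgt.shape  # output keeps the unbatched target shape
```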
Thanks! The fix looks solid, just a couple of minor testing comments for the `ModuleInfo` tests. As long as things pass, I'm good with it :)
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes #116385

Don't call `torch._transformer_encoder_layer_fwd` when `bias=False`.

`bias=False` was not something that `torch._transformer_encoder_layer_fwd` was meant to work with; it was my bad that this wasn't tested as I approved #101687. `bias=False` was causing the `tensor_args` in `TransformerEncoder`/`TransformerEncoderLayer` to contain `None`s and error on checks for the fastpath like `t.requires_grad for t in tensor_args`.

Alternative fix would be to
1) Pass `torch.zeros_like({*}.weight)` to the kernel when `bias=False` and filter `tensor_args` as appropriate
2) Fix `torch._transformer_encoder_layer_fwd` to take `Optional<Tensor>` for biases and fix the kernels as appropriate

Let me know if these approaches are preferable

Stack from ghstack (oldest at bottom):
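As a quick sanity check of the behavior described above (my own sketch, not part of the PR's test suite), constructing the layer with `bias=False` and running it under eval/no_grad should now fall back to the standard path instead of erroring:

```python
import torch
import torch.nn as nn

# With bias=False the attention and feedforward projections have no bias tensors,
# so the fused fastpath's preconditions are not met; the layer should simply take
# the standard (slow) path rather than raising (see #116385).
layer = nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=8,
                                   dropout=0.0, batch_first=True, bias=False)
layer.eval()

x = torch.randn(2, 3, 4)
with torch.no_grad():
    out = layer(x)  # previously errored; now runs via the standard path
print(out.shape)    # torch.Size([2, 3, 4])
```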