
Conversation

@mikaylagawarecki (Contributor) commented Jan 4, 2024

Fixes #116385

Don't call `torch._transformer_encoder_layer_fwd` when `bias=False`.

`bias=False` was not something that `torch._transformer_encoder_layer_fwd` was meant to work with; it was my bad that this wasn't tested when I approved #101687.

`bias=False` was causing the `tensor_args` in `TransformerEncoder`/`TransformerEncoderLayer` to contain `None`s and to error on fastpath checks like `t.requires_grad for t in tensor_args`.

Alternative fixes would be to

  1. Pass `torch.zeros_like({*}.weight)` to the kernel when `bias=False` and filter `tensor_args` as appropriate
  2. Fix `torch._transformer_encoder_layer_fwd` to take `Optional<Tensor>` for biases and fix the kernels as appropriate

Let me know if either of these approaches is preferable.
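
For concreteness, here is a minimal sketch of the gating the chosen fix boils down to. This is illustrative only, not the actual `transformer.py` diff: the `can_use_fastpath` helper, its name, and its exact checks are my own framing of the idea (bail out on `None` biases before probing tensor properties).

```python
import torch
import torch.nn as nn

def can_use_fastpath(layer: nn.TransformerEncoderLayer) -> bool:
    # Hypothetical helper mirroring the fix: refuse the fused kernel when
    # bias=False left any bias as None, *before* touching tensor properties.
    tensor_args = (
        layer.self_attn.in_proj_weight, layer.self_attn.in_proj_bias,
        layer.self_attn.out_proj.weight, layer.self_attn.out_proj.bias,
        layer.norm1.weight, layer.norm1.bias,
        layer.norm2.weight, layer.norm2.bias,
        layer.linear1.weight, layer.linear1.bias,
        layer.linear2.weight, layer.linear2.bias,
    )
    if any(t is None for t in tensor_args):
        return False  # bias=False: skip torch._transformer_encoder_layer_fwd
    # Safe now: no Nones, so the usual requires_grad-style gate works.
    return not (torch.is_grad_enabled()
                and any(t.requires_grad for t in tensor_args))

biased = nn.TransformerEncoderLayer(4, 2, 8, dropout=0.0, bias=True).eval()
unbiased = nn.TransformerEncoderLayer(4, 2, 8, dropout=0.0, bias=False).eval()
with torch.no_grad():
    print(can_use_fastpath(biased))    # True
    print(can_use_fastpath(unbiased))  # False -> fall back to the slow path
```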

Stack from ghstack (oldest at bottom):

pytorch-bot bot commented Jan 4, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/116760

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (5 Unrelated Failures)

As of commit c9db481 with merge base a8a9695:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were already present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@mikaylagawarecki mikaylagawarecki added topic: bug fixes topic category release notes: nn release notes category labels Jan 4, 2024
Comment on lines +2404 to +2427
```diff
+# Samples below where we pass reference_fn are for validating the fast path,
+# since the fast path requires no_grad mode, we run the fast path in .eval()
+# and no_grad() in the reference_fn and verify that against the results in train mode.
 def fast_path_reference_fn(module, parameters, *args, **kwargs):
-    assert not module.training
-    module = module.train(True)
-    output = module(*args, **kwargs)
-    module = module.train(False)
+    assert module.training
+    module.train(False)
+    with torch.no_grad():
+        output = module(*args, **kwargs)
+    module.train(True)
     return output

-if not training:
-    for norm_first in (True, False):
+if training:
+    for norm_first, bias in itertools.product((True, False), (True, False)):
         samples.append(
             ModuleInput(
-                constructor_input=FunctionInput(4, 2, 8, dropout=0.0, batch_first=True, norm_first=norm_first),
+                constructor_input=FunctionInput(
+                    4, 2, 8, dropout=0.0, batch_first=True, norm_first=norm_first, bias=bias
+                ),
                 forward_input=FunctionInput(
                     make_input((2, 3, 4)),
                 ),
-                reference_fn=fast_path_reference_fn,
-                desc="fast_path_norm_first" if norm_first else "fast_path"
+                # fastpath doesn't run when bias=False
+                reference_fn=fast_path_reference_fn if bias else None,
+                desc=f'fastpath_{bias}_norm_first_{norm_first}'
```
@mikaylagawarecki (Contributor, Author) commented Jan 4, 2024:

The existing fastpath test was not actually testing the fastpath against training mode, since the fastpath only runs in no_grad mode and `TestModule.test_forward` is not run under a no_grad context.

I modified this such that the outputs in train mode are compared against a reference that runs the fastpath in eval/no_grad mode.

A contributor replied:

> I modified this such that the outputs in train mode are compared against a reference that runs the fastpath in eval/no_grad mode

possibly dumb Q: do we need to compare train mode outputs vs. fastpath reference?

@mikaylagawarecki (Contributor, Author) replied:

From offline discussion: this swaps the reference to use the fastpath so that it can run under no_grad mode; since dropout=0, the outputs should match.
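
As a standalone sanity check of that reasoning (sizes and tolerances here are my own choices, not from the test suite): with `dropout=0.0`, a train-mode forward and an eval/no_grad forward, the latter being eligible for the fastpath, should produce matching outputs.

```python
import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=8,
                                   dropout=0.0, batch_first=True)
x = torch.randn(2, 3, 4)  # (batch, seq, d_model)

layer.train()
out_train = layer(x)

layer.eval()
with torch.no_grad():
    out_eval = layer(x)  # eligible for the fused fastpath here

# With dropout disabled, train and eval should agree up to numerical tolerance.
torch.testing.assert_close(out_train, out_eval)
```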

@mikaylagawarecki changed the title from "Fix TransformerEncoderLayer fastpath for bias=False" to "Fix TransformerEncoderLayer for bias=False" Jan 4, 2024
@mikaylagawarecki mikaylagawarecki marked this pull request as ready for review January 4, 2024 04:36
@mikaylagawarecki (Contributor, Author) commented:
@pytorchbot rebase

@pytorchmergebot (Collaborator) commented:
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator) commented:
Rebase failed due to `git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict gh/mikaylagawarecki/172/orig` returning non-zero exit code 1:

Rebasing (1/1)
Auto-merging test/test_transformers.py
Auto-merging torch/nn/modules/transformer.py
CONFLICT (content): Merge conflict in torch/nn/modules/transformer.py
error: could not apply 594e9fa8bb9... Fix TransformerEncoderLayer fastpath for bias=False
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
Could not apply 594e9fa8bb9... Fix TransformerEncoderLayer fastpath for bias=False

Raised by https://github.com/pytorch/pytorch/actions/runs/7405594863

mikaylagawarecki added a commit that referenced this pull request Jan 4, 2024
ghstack-source-id: 920f89b
Pull Request resolved: #116760
```diff
-    batch_first=True,
+    batch_first=batch_first,
     kwargs_to_batchify={'tgt_key_padding_mask': 0, 'memory_key_padding_mask': 0}),
     desc='no_batch_dim_batch_first'
```
@mikaylagawarecki (Contributor, Author) commented Jan 4, 2024:

It feels like `batch_first` should not affect `no_batch_dim`, so I'm not entirely sure why we have these tests, but I didn't delete them in case they're testing an edge case.

A contributor replied:

I do think it's a good black-box testing surface to cover. IIRC there was a problem at some point where some batch_first setting blew up on no-batch-dim inputs.
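
To make the expectation concrete, here is a quick illustrative check; it uses `TransformerEncoderLayer` for brevity even though the snippet above is from the decoder tests, and the sizes are arbitrary. For unbatched input there is no batch dimension for `batch_first` to move, so both settings should yield the same output.

```python
import torch
import torch.nn as nn

def make_layer(batch_first: bool) -> nn.TransformerEncoderLayer:
    torch.manual_seed(0)  # same seed -> identical weights for both layers
    return nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=8,
                                      dropout=0.0, batch_first=batch_first)

x = torch.randn(3, 4)  # unbatched: (seq_len, d_model), no batch dim

out_bf = make_layer(batch_first=True)(x)
out = make_layer(batch_first=False)(x)
torch.testing.assert_close(out_bf, out)  # batch_first is a no-op here
```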

@albanD albanD removed their request for review January 4, 2024 11:05
@jbschlosser (Contributor) left a comment:
Thanks! Fix looks solid; a couple of minor testing comments for the ModuleInfo tests. As long as things pass, I'm good with it :)



@mikaylagawarecki (Contributor, Author) commented:
@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jan 4, 2024
@pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@facebook-github-bot facebook-github-bot deleted the gh/mikaylagawarecki/172/head branch January 8, 2024 15:23