Fix dtype mismatch error in se_block example #1040
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Repro test is
test_sigmoid_scalar_autocast. There are two specific issues:torch.sigmoidcurrently upcast the sigmoid input to fp32, and we need to cast it back to original input dtype to not deviate from the original type propagation result.2.0 * x_bf16, Triton treats scalar literal like2.0as fp32, thus producing fp32 output, which is different from PyTorch behavior (bf16 output). We need to cast themuloutput back to bf16 dtype to be consistent with the original type propagation result.Fixes #1038.