Skip to content

Conversation

@yf225
Copy link
Contributor

@yf225 yf225 commented Oct 28, 2025

Repro test is test_sigmoid_scalar_autocast. There are two specific issues:

  • The lowering of torch.sigmoid currently upcast the sigmoid input to fp32, and we need to cast it back to original input dtype to not deviate from the original type propagation result.
  • For an operation like 2.0 * x_bf16, Triton treats scalar literal like 2.0 as fp32, thus producing fp32 output, which is different from PyTorch behavior (bf16 output). We need to cast the mul output back to bf16 dtype to be consistent with the original type propagation result.

Fixes #1038.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 28, 2025
@yf225 yf225 force-pushed the fix_dtype_mismatch branch 9 times, most recently from bf01868 to 5ad8d13 Compare October 28, 2025 09:20
@yf225 yf225 force-pushed the fix_dtype_mismatch branch 2 times, most recently from 5c5ddb3 to f225319 Compare October 28, 2025 09:24
@yf225 yf225 force-pushed the fix_dtype_mismatch branch from f225319 to 5509e8f Compare October 28, 2025 09:27
@yf225 yf225 merged commit 8926c5c into main Oct 28, 2025
14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

dtype mismatch error due to automatic upcasting

4 participants