Handle out_dtype in ReplacePT2DequantWithCadenceDequantPass (#19743)#19743
Handle out_dtype in ReplacePT2DequantWithCadenceDequantPass (#19743)#19743ethansfng wants to merge 1 commit into
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19743
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 2 Unrelated FailuresAs of commit fa1bce3 with merge base ec76470 ( NEW FAILURE - The following job has failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@ethansfng has exported this pull request. If you are a Meta employee, you can view the originating Diff in D105630451. |
This PR needs a
|
…19743) Summary: torchao's `convert_pt2e` adds `out_dtype` kwargs to dequant nodes for bf16 models. `cadence::dequantize_per_tensor` doesn't support this kwarg (it hardcodes float32 output), so `ReplacePT2DequantWithCadenceDequantPass` crashes when it forwards kwargs blindly to the cadence op. Strip `out_dtype` from kwargs before creating the cadence dequant node, and insert an `aten.to.dtype` cast after it to preserve the original output dtype semantics. Differential Revision: D105630451
50f7925 to
fa1bce3
Compare
…19743) Summary: torchao's `convert_pt2e` adds `out_dtype` kwargs to dequant nodes for bf16 models. `cadence::dequantize_per_tensor` doesn't support this kwarg (it hardcodes float32 output), so `ReplacePT2DequantWithCadenceDequantPass` crashes when it forwards kwargs blindly to the cadence op. Strip `out_dtype` from kwargs before creating the cadence dequant node, and insert an `aten.to.dtype` cast after it to preserve the original output dtype semantics. Differential Revision: D105630451
…19743) Summary: torchao's `convert_pt2e` adds `out_dtype` kwargs to dequant nodes for bf16 models. `cadence::dequantize_per_tensor` doesn't support this kwarg (it hardcodes float32 output), so `ReplacePT2DequantWithCadenceDequantPass` crashes when it forwards kwargs blindly to the cadence op. Strip `out_dtype` from kwargs before creating the cadence dequant node, and insert an `aten.to.dtype` cast after it to preserve the original output dtype semantics. Differential Revision: D105630451
…19743) Summary: torchao's `convert_pt2e` adds `out_dtype` kwargs to dequant nodes for bf16 models. `cadence::dequantize_per_tensor` doesn't support this kwarg (it hardcodes float32 output), so `ReplacePT2DequantWithCadenceDequantPass` crashes when it forwards kwargs blindly to the cadence op. Strip `out_dtype` from kwargs before creating the cadence dequant node, and insert an `aten.to.dtype` cast after it to preserve the original output dtype semantics. Differential Revision: D105630451
…19743) Summary: torchao's `convert_pt2e` adds `out_dtype` kwargs to dequant nodes for bf16 models. `cadence::dequantize_per_tensor` doesn't support this kwarg (it hardcodes float32 output), so `ReplacePT2DequantWithCadenceDequantPass` crashes when it forwards kwargs blindly to the cadence op. Strip `out_dtype` from kwargs before creating the cadence dequant node, and insert an `aten.to.dtype` cast after it to preserve the original output dtype semantics. Differential Revision: D105630451
…19743) Summary: torchao's `convert_pt2e` adds `out_dtype` kwargs to dequant nodes for bf16 models. `cadence::dequantize_per_tensor` doesn't support this kwarg (it hardcodes float32 output), so `ReplacePT2DequantWithCadenceDequantPass` crashes when it forwards kwargs blindly to the cadence op. Strip `out_dtype` from kwargs before creating the cadence dequant node, and insert an `aten.to.dtype` cast after it to preserve the original output dtype semantics. Differential Revision: D105630451
Summary:
torchao's
convert_pt2eaddsout_dtypekwargs to dequant nodes for bf16 models.cadence::dequantize_per_tensordoesn't support this kwarg (it hardcodes float32 output), soReplacePT2DequantWithCadenceDequantPasscrashes when it forwards kwargs blindly to the cadence op.Strip
out_dtypefrom kwargs before creating the cadence dequant node, and insert anaten.to.dtypecast after it to preserve the original output dtype semantics.Differential Revision: D105630451