Fix weight_norm decomposition behavior (#128956)
Upcast the norm to float32 to align with the CUDA and CPU behaviors: https://github.com/pytorch/pytorch/blob/e6d4451ae8987bf8d6ad85eb7cde685fac746f6f/aten/src/ATen/native/WeightNorm.cpp#L56-L59

Discovered this when I started running OpInfo tests; see https://github.com/pytorch/pytorch/actions/runs/9552858711/job/26332062502#step:20:1060

```
  File "/var/lib/jenkins/workspace/test/test_decomp.py", line 185, in op_assert_ref
    assert orig.dtype == decomp.dtype, f"{i} Operation: {op}"
AssertionError: 1 Operation: aten._weight_norm_interface.default
```

Pull Request resolved: #128956
Approved by: https://github.com/albanD
ghstack dependencies: #128955
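For illustration, here is a minimal, hypothetical sketch (not the actual PyTorch decomposition) of what the fixed behavior looks like: for a reduced-precision input, the norm is computed in float32 and returned in float32, which is what the eager CUDA/CPU kernels linked above do and what the `orig.dtype == decomp.dtype` check in `test_decomp.py` asserts. The function name, signature, and shapes are assumptions chosen for the example.

```python
import torch

def weight_norm_interface_decomp(v, g, dim=0):
    """Hypothetical sketch of a _weight_norm_interface-style decomposition.

    For reduced-precision inputs, the norm is computed (and returned)
    in float32, mirroring the eager CUDA/CPU behavior the commit
    aligns with.
    """
    # Reduce over every dimension except `dim`.
    reduce_dims = [i for i in range(v.dim()) if i != dim]
    # Upcast to float32 before the reduction so the returned norm's
    # dtype matches the eager kernels.
    norm = v.to(torch.float32).norm(2, reduce_dims, keepdim=True)
    # The normalized weight is cast back to the input dtype; the norm
    # itself stays float32.
    w = v * (g / norm.to(v.dtype))
    return w, norm

v = torch.randn(3, 4, dtype=torch.float16)
g = torch.randn(3, 1, dtype=torch.float16)
w, norm = weight_norm_interface_decomp(v, g)
# w keeps the input dtype (float16); norm is float32.
```

Without the float32 upcast, `norm` would come back as float16 here, which is exactly the dtype mismatch the `AssertionError` above reports.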