Fix weight_norm decomposition behavior (#128956)
The fix upcasts the norm to float32, aligning the decomposition with the CUDA and CPU eager behaviors:
https://github.com/pytorch/pytorch/blob/e6d4451ae8987bf8d6ad85eb7cde685fac746f6f/aten/src/ATen/native/WeightNorm.cpp#L56-L59
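
For illustration, a minimal sketch (not part of the commit) of the eager behavior the decomposition now matches; calling the op through `torch.ops.aten` is my assumption about how to exercise it directly:

```python
import torch

v = torch.randn(3, 4, dtype=torch.bfloat16)
g = torch.randn(3, 1, dtype=torch.bfloat16)

# The eager CPU/CUDA kernels keep the norm in float32 for bfloat16 inputs,
# even though the normalized weight itself stays in bfloat16.
w, norm = torch.ops.aten._weight_norm_interface(v, g, 0)
print(w.dtype, norm.dtype)  # expected: torch.bfloat16 torch.float32
```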

Discovered this when I started running OpInfo tests; see https://github.com/pytorch/pytorch/actions/runs/9552858711/job/26332062502#step:20:1060
```
  File "/var/lib/jenkins/workspace/test/test_decomp.py", line 185, in op_assert_ref
    assert orig.dtype == decomp.dtype, f"{i} Operation:  {op}"
AssertionError: 1 Operation:  aten._weight_norm_interface.default
```
Pull Request resolved: #128956
Approved by: https://github.com/albanD
ghstack dependencies: #128955
malfet authored and pytorchmergebot committed Jun 18, 2024
1 parent 2227da4 commit e47603a
Showing 1 changed file with 4 additions and 2 deletions.
torch/_decomp/decompositions.py

```diff
@@ -4773,8 +4773,10 @@ def squeeze_default(self: Tensor, dim: Optional[int] = None):
 def _weight_norm_interface(v, g, dim=0):
     # https://github.com/pytorch/pytorch/blob/852f8526c52190125446adc9a6ecbcc28fb66182/aten/src/ATen/native/WeightNorm.cpp#L58
     keep_dim = tuple(i for i in range(len(v.shape)) if i != dim)
-    norm = v.norm(2, keep_dim, keepdim=True)
-    return v * (g / norm), norm
+    # align with cuda behavior, keep norm in 'float' when g is 'bfloat16'
+    norm_dtype = torch.float if g.dtype == torch.bfloat16 else None
+    norm = v.norm(2, keep_dim, keepdim=True, dtype=norm_dtype)
+    return v * (g / norm.to(g.dtype)), norm


 @register_decomposition(aten.isin)
```
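
A hedged verification sketch of what the OpInfo test asserts: the decomposition and the eager kernel should now report the same norm dtype for bfloat16 inputs. Looking the function up in `torch._decomp.decomposition_table` is my assumption about how to reach the registered decomposition.

```python
import torch
from torch._decomp import decomposition_table

# Assumed lookup: the table is keyed by OpOverload objects.
decomp = decomposition_table[torch.ops.aten._weight_norm_interface.default]

v = torch.randn(3, 4, dtype=torch.bfloat16)
g = torch.randn(3, 1, dtype=torch.bfloat16)

eager_w, eager_norm = torch.ops.aten._weight_norm_interface(v, g, 0)
decomp_w, decomp_norm = decomp(v, g, 0)

# Before this fix, decomp_norm stayed in bfloat16 while eager_norm was float32.
assert eager_norm.dtype == decomp_norm.dtype
```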
