Skip to content

Commit

Permalink
Add a decomposition for _weight_norm_interface. (pytorch#112193)
Browse files Browse the repository at this point in the history
Fixes pytorch#112086

Pull Request resolved: pytorch#112193
Approved by: https://github.com/ezyang
  • Loading branch information
qihqi authored and xuhancn committed Nov 8, 2023
1 parent b1a3f9f commit e4e7bfc
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 2 deletions.
2 changes: 0 additions & 2 deletions test/expect/HasDecompTest.test_has_decomposition.expect
Original file line number Diff line number Diff line change
Expand Up @@ -592,8 +592,6 @@ aten::_values
aten::_values_copy
aten::_values_copy.out
aten::_weight_int4pack_mm
aten::_weight_norm_interface
aten::_weight_norm_interface.out
aten::_weight_norm_interface_backward
aten::_weight_norm_interface_backward.out
aten::adaptive_avg_pool2d.out
Expand Down
12 changes: 12 additions & 0 deletions test/test_decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,18 @@ def test_threshold_backward_dtype(self, device):
res = torch._decomp.decompositions.threshold_backward(grad, input_tensor, 1)
self.assertEqual(ref.dtype, res.dtype)

@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
@onlyNativeDeviceTypes
@skipIfCrossRef
def test_weight_norm_interface(self, device):
g = torch.randn((3, 10, 10), device=device)
v = torch.randn((1, 1, 10), device=device)

ref = torch.ops.aten._weight_norm_interface(g, v, 2)
res = torch._decomp.decompositions._weight_norm_interface(g, v, 2)
self.assertTrue(torch.allclose(ref[0], res[0]))
self.assertTrue(torch.allclose(ref[1], res[1]))


instantiate_device_type_tests(DecompOneOffTests, globals())

Expand Down
1 change: 1 addition & 0 deletions torch/_decomp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,5 +438,6 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
aten.zero_,
aten.zeros,
aten.zeros_like,
aten._weight_norm_interface,
]
)
8 changes: 8 additions & 0 deletions torch/_decomp/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4240,6 +4240,14 @@ def squeeze_default(self: Tensor, dim: Optional[int] = None):
return aten.squeeze.dims(self, [dim])


@register_decomposition(torch.ops.aten._weight_norm_interface)
def _weight_norm_interface(x, y, dim):
# https://github.com/pytorch/pytorch/blob/852f8526c52190125446adc9a6ecbcc28fb66182/aten/src/ATen/native/WeightNorm.cpp#L58
keep_dim = tuple(i for i in range(len(x.shape)) if i != dim)
norm = x.norm(2, keep_dim, keepdim=True)
return x * (y / norm), norm


register_inplace(aten.addbmm_, aten.addbmm)
register_inplace(aten.addmm_, aten.addmm)
register_inplace(aten.addmv_, aten.addmv)
Expand Down

0 comments on commit e4e7bfc

Please sign in to comment.