[PYTORCH] aten::norm support added (apache#5776)
siju-samuel authored and Trevor Morris committed Jun 18, 2020
1 parent 6f58957 commit 1655e81
Showing 2 changed files with 127 additions and 0 deletions.
40 changes: 40 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -1184,6 +1184,44 @@ def _impl(inputs, input_types):

return _impl

def _norm():
def _impl(inputs, input_types):
data = inputs[0]
axis = None
keepdims = False
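        # aten::norm is traced either as (x, p) over all elements or as
        # (x, p, dim, keepdim) over the given axes.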
if len(inputs) > 3:
axis = list(_infer_shape(inputs[2]))
keepdims = bool(inputs[3])

order = inputs[1]
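        # p = +/-inf lowers to a max/min reduction of |x|; any other order
        # uses the identity ||x||_p = (sum(|x|^p))^(1/p), built from Relay reduce ops.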
if order == np.inf:
return _op.reduce.max(_op.abs(data), axis=axis, keepdims=keepdims)
        elif order == -np.inf:
return _op.reduce.min(_op.abs(data), axis=axis, keepdims=keepdims)
else:
reci_order = _expr.const(1.0 / order)
order = _expr.const(order)
return _op.power(_op.reduce.sum(_op.power(_op.abs(data), order),
axis=axis,
keepdims=keepdims),
reci_order)
return _impl


def _frobenius_norm():
def _impl(inputs, input_types):
data = inputs[0]
axis = None
keepdims = False
if len(inputs) > 2:
axis = list(_infer_shape(inputs[1]))
keepdims = bool(inputs[2])

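        # Frobenius norm: sqrt(sum(x * x)) over the requested axes.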
return _op.sqrt(_op.reduce.sum((data * data), axis=axis, keepdims=keepdims))

return _impl


def _std():
def _impl(inputs, input_types):
data = inputs[0]
@@ -1853,6 +1891,8 @@ def _get_convert_map(prelude):
"aten::prod" : _reduce("prod"),
"aten::argmin" : _reduce("argmin"),
"aten::argmax" : _reduce("argmax"),
"aten::norm" : _norm(),
"aten::frobenius_norm" : _frobenius_norm(),
"aten::std" : _std(),
"aten::var" : _variance(),
"aten::abs" : _unary("abs"),
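The converter lowers aten::norm onto existing Relay reduce ops rather than a dedicated norm kernel: p = +inf becomes max(|x|), p = -inf becomes min(|x|), and any other order uses the identity ||x||_p = (sum(|x|^p))^(1/p). A minimal NumPy cross-check of that decomposition against torch.norm (illustrative only, not part of the commit):

import numpy as np
import torch

def norm_by_reduction(x, p, axis=None, keepdims=False):
    # Same case split as _norm() above, with NumPy reductions standing in for Relay ops.
    if p == np.inf:
        return np.max(np.abs(x), axis=axis, keepdims=keepdims)
    if p == -np.inf:
        return np.min(np.abs(x), axis=axis, keepdims=keepdims)
    return np.power(np.sum(np.power(np.abs(x), p), axis=axis, keepdims=keepdims), 1.0 / p)

x = np.random.rand(1, 3, 10, 10).astype("float32")
for p in [float("inf"), float("-inf"), 0.5, 1.0, 2.0]:
    expected = torch.norm(torch.from_numpy(x), p=p, dim=(1, 2)).numpy()
    np.testing.assert_allclose(norm_by_reduction(x, p, axis=(1, 2)), expected, rtol=1e-4)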
87 changes: 87 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
@@ -892,6 +892,91 @@ def forward(self, *args):
input_data = torch.rand(input_shape).float()
verify_model(LogSoftmax1().float().eval(), input_data=input_data)


def test_forward_norm():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
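    # Cover +/-inf, fractional, unit, and negative orders, over the full
    # tensor, a single axis, and multiple axes, with keepdim on and off.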

class Norm1(Module):
def forward(self, *args):
return torch.norm(args[0], p=float('inf'), dim=None, keepdim=False)

class Norm2(Module):
def forward(self, *args):
return torch.norm(args[0], p=float('-inf'), dim=None, keepdim=False)

class Norm3(Module):
def forward(self, *args):
return torch.norm(args[0], p=float('-inf'), dim=None, keepdim=True)

class Norm4(Module):
def forward(self, *args):
return torch.norm(args[0], p=float('inf'), dim=(1, 2), keepdim=False)

    class Norm5(Module):
        def forward(self, *args):
            return torch.norm(args[0], p=float('inf'), dim=1, keepdim=True)

    class Norm6(Module):
        def forward(self, *args):
            return torch.norm(args[0], p=0.5, dim=1, keepdim=True)

    class Norm7(Module):
        def forward(self, *args):
            return torch.norm(args[0], p=1.0, dim=None, keepdim=False)

    class Norm8(Module):
        def forward(self, *args):
            return torch.norm(args[0], p=2.0, dim=1, keepdim=True)

    class Norm9(Module):
        def forward(self, *args):
            return torch.norm(args[0], p=-0.5, dim=(1, 2), keepdim=True)

    class Norm10(Module):
        def forward(self, *args):
            return torch.norm(args[0], p=-2.0, dim=1, keepdim=False)

input_data = torch.rand(input_shape).float()
verify_model(Norm1().float().eval(), input_data=input_data)
verify_model(Norm2().float().eval(), input_data=input_data)
verify_model(Norm3().float().eval(), input_data=input_data)
verify_model(Norm4().float().eval(), input_data=input_data)
verify_model(Norm5().float().eval(), input_data=input_data)
verify_model(Norm6().float().eval(), input_data=input_data)
verify_model(Norm7().float().eval(), input_data=input_data)
verify_model(Norm8().float().eval(), input_data=input_data)
verify_model(Norm9().float().eval(), input_data=input_data)
verify_model(Norm10().float().eval(), input_data=input_data)


def test_forward_frobenius_norm():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
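    # torch.norm defaults to the Frobenius norm (p='fro').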

class FroNorm1(Module):
def forward(self, *args):
return torch.norm(args[0])

class FroNorm2(Module):
def forward(self, *args):
return torch.norm(args[0], p='fro', dim=None, keepdim=True)

class FroNorm3(Module):
def forward(self, *args):
            return torch.norm(args[0], p='fro', dim=1, keepdim=True)

class FroNorm4(Module):
def forward(self, *args):
return torch.norm(args[0], dim=None, keepdim=False)

input_data = torch.rand(input_shape).float()
verify_model(FroNorm1().float().eval(), input_data=input_data)
verify_model(FroNorm2().float().eval(), input_data=input_data)
verify_model(FroNorm3().float().eval(), input_data=input_data)
verify_model(FroNorm4().float().eval(), input_data=input_data)


def test_forward_sigmoid():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
@@ -2421,6 +2506,8 @@ def test_forward_pretrained_bert_base_uncased():
test_forward_reduce_prod()
test_forward_argmin()
test_forward_argmax()
test_forward_norm()
test_forward_frobenius_norm()
test_forward_std()
test_forward_variance()
test_forward_relu()
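End to end, the new converters can be exercised by tracing a module that calls torch.norm and importing it through the PyTorch frontend. A sketch following the usual from_pytorch flow (the module and the input name "input0" are illustrative, not from the commit):

import torch
from tvm import relay

class L2Norm(torch.nn.Module):
    def forward(self, x):
        return torch.norm(x, p=2.0, dim=1, keepdim=True)

inp = torch.rand(1, 3, 10, 10)
traced = torch.jit.trace(L2Norm().eval(), inp)

# from_pytorch takes (input_name, shape) pairs describing the graph inputs.
mod, params = relay.frontend.from_pytorch(traced, [("input0", inp.shape)])
print(mod)  # the Relay IR shows abs/power/sum reductions rather than a norm op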
