
[ONNX] Add export of aten::is_floating_point #46442

Closed · wants to merge 16 commits
1 change: 1 addition & 0 deletions scripts/onnx/test.sh
@@ -51,6 +51,7 @@ pytest "${args[@]}" \
--ignore "$top_dir/test/onnx/test_custom_ops.py" \
--ignore "$top_dir/test/onnx/test_models_onnxruntime.py" \
--ignore "$top_dir/test/onnx/test_utility_funs.py" \
--ignore "$top_dir/test/onnx/test_pytorch_onnx_shape_inference.py" \
"${test_paths[@]}"

# onnxruntime only supports py3
54 changes: 54 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -895,6 +895,60 @@ def test_avgpool_3d_ceil(self):
        x = torch.randn(20, 16, 50, 44, 31)
        self.run_test(model, x)

    @skipIfUnsupportedMinOpsetVersion(9)
    def test_floating_point(self):
        class FloatingPoint(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                if x.is_floating_point():
                    return x.new_zeros(x.shape)
                return x.new_zeros(x.shape)

        x = torch.randn(2, 3, 4)
        self.run_test(FloatingPoint(), x)

        class FloatingPoint(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                if x.size(0) > 1:
                    a = x + 2
                    if a.is_floating_point():
                        return x + 1
                    return x + 1
                return x

        x = torch.randn(2, 3, 4)
        self.run_test(FloatingPoint(), x)

    @skipIfUnsupportedMinOpsetVersion(9)
    @skipIfONNXShapeInference(False)
    def test_floating_point_infer_dtype(self):
        class FloatingPoint(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                if x.size(0) > 1:
                    a = x + 2
                    if a.is_floating_point():
                        return x.new_zeros(x.shape[1:])
                    return x.new_zeros(x.shape)
                return x

        x = torch.randn(2, 3, 4)
        self.run_test(FloatingPoint(), x)

        class FloatingPoint(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, x):
                if x.size(0) > 1:
                    a = x + 2
                    if a.is_floating_point():
                        return x + 1
                    return x
                return x

        x = torch.randn(2, 3, 4).to(torch.int32)
        self.run_test(FloatingPoint(), x)

    def test_arithmetic(self):
        class ArithmeticModule(torch.nn.Module):
            def forward(self, x):
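Note: the tests above cover both dtype paths: a floating-point input, where is_floating_point() should fold to true, and an int32 input, where it should fold to false. A quick eager-mode sanity check of that expectation (a sketch, not part of this PR) is:

# Sketch only: eager-mode behaviour the exported graphs are expected to match.
import torch

x_float = torch.randn(2, 3, 4)
x_int = torch.randn(2, 3, 4).to(torch.int32)

assert x_float.is_floating_point()      # float32 input takes the is_floating_point branch
assert not x_int.is_floating_point()    # int32 input skips it
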
2 changes: 2 additions & 0 deletions torch/onnx/symbolic_helper.py
@@ -229,6 +229,8 @@ def _is_fp(value):
            return (type == 'torch.float32') or (type == 'torch.float64') or (type == 'torch.float16')
        else:
            type = value.type().scalarType()
            if type is None:
                warnings.warn("Type cannot be inferred, which might cause the exported graph to produce incorrect results.")
            return (type == 'Float') or (type == 'Double') or (type == 'Half')
    return False

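Note: the warning added here fires when the exporter cannot determine a value's scalar type (for example, when ONNX shape/type inference is disabled); _is_fp then falls through and returns False. A standalone restatement of that decision, as a sketch assuming scalar_type is the string from value.type().scalarType() or None when unknown:

# Sketch of the decision logic in _is_fp for JIT values; not the exporter code itself.
import warnings

def is_fp_scalar_type(scalar_type):
    if scalar_type is None:
        warnings.warn("Type cannot be inferred, which might cause the exported "
                      "graph to produce incorrect results.")
    # None is not in the tuple, so an unknown type is treated as non-floating-point.
    return scalar_type in ('Float', 'Double', 'Half')
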
6 changes: 6 additions & 0 deletions torch/onnx/symbolic_opset9.py
@@ -2194,6 +2194,12 @@ def prim_shape(g, self):
    return g.op('Shape', self)


def is_floating_point(g, self):
    if sym_help._is_fp(self):
        return g.op("Constant", value_t=torch.BoolTensor([1]))
    return g.op("Constant", value_t=torch.BoolTensor([0]))


@parse_args('v', 'i')
def one_hot(g, self, num_classes):
    values = g.op("Constant", value_t=torch.LongTensor([0, 1]))
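Note: because the input's dtype is known (or inferred) at export time, the symbolic above emits a constant boolean rather than a runtime check, which lets the exporter resolve the surrounding prim::If and keep only the taken branch. A plain-Python restatement of that folding, as a sketch rather than the exporter's actual code path:

# Sketch only: mirrors the constants emitted by the is_floating_point symbolic.
import torch

def fold_is_floating_point(dtype):
    if dtype in (torch.float16, torch.float32, torch.float64):
        return torch.BoolTensor([1])   # same as g.op("Constant", value_t=torch.BoolTensor([1]))
    return torch.BoolTensor([0])       # same as g.op("Constant", value_t=torch.BoolTensor([0]))

print(fold_is_floating_point(torch.float32))  # tensor([True])
print(fold_is_floating_point(torch.int64))    # tensor([False])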