[ONNX] Fix pow op export #38065

Closed · wants to merge 3 commits
21 changes: 21 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -1361,6 +1361,27 @@ def forward(self, input):
        model = StandardDeviation()
        self.run_test(model, x)

    def test_pow(self):
        class PowModule(torch.nn.Module):
            def forward(self, x, y):
                return x.pow(y)

        # float base, float exponent
        x = torch.randn(2, 3, 4)
        y = torch.randn(2, 3, 4)
        self.run_test(PowModule(), (x, y))

        # int64 base, int32 exponent
        x = torch.randint(10, (2, 3, 4))
        y = torch.randint(10, (2, 3, 4)).to(dtype=torch.int32)
        self.run_test(PowModule(), (x, y))

        # int64 base, int64 exponent
        x = torch.randint(10, (2, 3, 4))
        y = torch.randint(10, (2, 3, 4))
        self.run_test(PowModule(), (x, y))

        # float64 base, int64 exponent
        x = torch.randn(2, 3, 4).to(dtype=torch.float64)
        y = torch.randint(10, (2, 3, 4))
        self.run_test(PowModule(), (x, y))

    def test_std_along_dims(self):
        class StandardDeviation(torch.nn.Module):
            def forward(self, input):
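The new cases can also be run on their own; assuming the suite's usual test-class layout (the TestONNXRuntime class name is an assumption here, not shown in this diff), an invocation along these lines should exercise just this test:

python test/onnx/test_pytorch_onnx_onnxruntime.py TestONNXRuntime.test_pow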
5 changes: 5 additions & 0 deletions torch/onnx/symbolic_opset12.py
@@ -55,5 +55,10 @@ def nll_loss(g, self, target, weight, reduction, ignore_index):

    return nllloss


def nll_loss2d(g, self, target, weight, reduction, ignore_index):
    return nll_loss(g, self, target, weight, reduction, ignore_index)


def pow(g, self, exponent):
    return g.op("Pow", self, exponent)
11 changes: 10 additions & 1 deletion torch/onnx/symbolic_opset9.py
@@ -1253,7 +1253,16 @@ def log1p(g, self):


def pow(g, self, exponent):
    # ONNX Pow in opset 9 is only defined for floating-point tensors, so an
    # integer base is promoted to Float before the op and cast back afterwards.
    f_dtype = self_dtype = self.type().scalarType()
    if not sym_help._is_fp(self):
        f_dtype = 'Float'
        self = g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx[f_dtype])
    if not sym_help._is_fp(exponent):
        # Cast the exponent to match the (possibly promoted) base dtype.
        exponent = g.op("Cast", exponent, to_i=sym_help.cast_pytorch_to_onnx[f_dtype])
    pow = g.op("Pow", self, exponent)
    if self_dtype and self_dtype != f_dtype:
        # Restore the base's original dtype on the result.
        pow = g.op("Cast", pow, to_i=sym_help.cast_pytorch_to_onnx[self_dtype])
    return pow


def clamp(g, self, min, max):
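For context on the opset 9 change: Pow, as defined through opset 11, only accepts floating-point inputs, which is why the symbolic wraps it in Cast ops. A rough eager-mode sketch of the semantics being emulated (illustrative only; pow_like_opset9 is a made-up name, not part of this PR):

import torch

def pow_like_opset9(x, y):
    # Mirrors the Cast -> Pow -> Cast pattern emitted by the opset 9 symbolic.
    orig_dtype = x.dtype
    if not x.is_floating_point():
        x = x.float()              # promote an integer base to float32
    if not y.is_floating_point():
        y = y.to(x.dtype)          # match the exponent to the base dtype
    out = x.pow(y)                 # Pow computed on floating-point inputs
    if out.dtype != orig_dtype:
        out = out.to(orig_dtype)   # cast back, as the symbolic does for int bases
    return out

For example, pow_like_opset9(torch.tensor([2, 3]), torch.tensor([3, 2])) returns tensor([8, 9]) with dtype int64, matching torch.pow on integer inputs.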