Skip to content

Commit

Permalink
[ONNX] Fix bug in unfold symbolic (#50504)
Browse files Browse the repository at this point in the history
Fix bug in unfold symbolic

[ghstack-poisoned]
  • Loading branch information
BowenBao committed Feb 2, 2021
1 parent 3856bf8 commit 5ab9320
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
7 changes: 7 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4324,6 +4324,13 @@ def forward(self, x):
x = torch.randn(4, 2, 4, requires_grad=True)
self.run_test(UnfoldModel(), x)

class UnfoldModel(torch.nn.Module):
def forward(self, x):
return x.unfold(dimension=2, size=x.shape[1], step=1)

x = torch.randn(4, 2, 4, requires_grad=True)
self.run_test(UnfoldModel(), x)

def test_prelu(self):
class PReluModel(torch.nn.Module):
def __init__(self):
Expand Down
9 changes: 4 additions & 5 deletions torch/onnx/symbolic_opset12.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import torch
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented
Expand Down Expand Up @@ -124,11 +123,11 @@ def le(g, input, other):

@parse_args('v', 'i', 'v', 'v')
def unfold(g, input, dimension, size, step):
size = sym_help._maybe_get_const(size, 'i')
step = sym_help._maybe_get_const(step, 'i')
if not sym_help._is_value(size) and not sym_help._is_value(step):
const_size = sym_help._maybe_get_const(size, 'i')
const_step = sym_help._maybe_get_const(step, 'i')
if not sym_help._is_value(const_size) and not sym_help._is_value(const_step):
from torch.onnx.symbolic_opset9 import unfold as _unfold
return _unfold(g, input, dimension, size, step)
return _unfold(g, input, dimension, const_size, const_step)
if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
return g.op("ATen", input, operator_s="unfold", dimension_i=dimension, size_i=size, step_i=step)

Expand Down

0 comments on commit 5ab9320

Please sign in to comment.