[ONNX] Remove usage of isCompleteTensor() in symbolic functions #48162

Closed · wants to merge 18 commits · showing changes from 9 commits

36 changes: 32 additions & 4 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -754,7 +754,10 @@ def forward(self, x):
return x.transpose(0, 1)

x = torch.randn(32, 3, 64, 64)
self.run_test(TransposeModule(), x)
y = torch.randn(16, 3, 8, 64)
self.run_test(TransposeModule(), x, input_names=['x'],
dynamic_axes={'x': [0, 2]},
test_with_inputs=[y])

def squeeze_model_tests(self, d, x1, x2):
class Squeeze(torch.nn.Module):
@@ -841,7 +844,10 @@ def forward(self, x):
def test_maxpool_adaptive(self):
model = torch.nn.AdaptiveMaxPool1d((5), return_indices=False)
x = torch.randn(20, 16, 50, requires_grad=True)
self.run_test(model, x)
y = torch.randn(32, 16, 50, requires_grad=True)
self.run_test(model, x, input_names=['x'],
dynamic_axes={'x' : [0]},
test_with_inputs=[y])

def test_maxpool_2d(self):
model = torch.nn.MaxPool2d(5, padding=(1, 2))
@@ -903,7 +909,10 @@ def test_avgpool_2d_ceil(self):
def test_avgpool_3d_ceil(self):
model = torch.nn.AvgPool3d(3, 2, ceil_mode=True)
x = torch.randn(20, 16, 50, 44, 31)
self.run_test(model, x)
y = torch.randn(32, 8, 50, 44, 31)
self.run_test(model, x, input_names=['x'],
dynamic_axes={'x' : [0, 1]},
test_with_inputs=[y])

@skipIfUnsupportedMinOpsetVersion(9)
def test_floating_point(self):
@@ -3767,7 +3776,11 @@ def forward(self, x):
return x.unfold(dimension=2, size=2, step=2)

x = torch.randn(4, 2, 3, requires_grad=True)
self.run_test(UnfoldModel(), x)
y = torch.randn(2, 1, 3, requires_grad=True)
self.run_test(UnfoldModel(), x,
dynamic_axes={'x': [0, 1]},
input_names=['x'],
test_with_inputs=[y])

@skipIfONNXShapeInference(False)
def test_unfold_infer_shape(self):
@@ -3784,6 +3797,21 @@ def forward(self, x):
x = torch.randn(32, 3, 64)
self.run_test(UnfoldModule(), x)

def test_prelu(self):
class PReluModel(torch.nn.Module):
def __init__(self):
super(PReluModel, self).__init__()
self.prelu = torch.nn.PReLU()

def forward(self, x):
return self.prelu(x)

x = torch.randn(2, 3, 4)
y = torch.randn(2, 4, 5)
self.run_test(PReluModel(), x, input_names=['x'],
dynamic_axes={'x': [1, 2]},
test_with_inputs=[y])

def test_remainder(self):
class RemainderModel(torch.nn.Module):
def forward(self, input, other):
10 changes: 10 additions & 0 deletions torch/csrc/jit/python/python_ir.cpp
@@ -696,6 +696,16 @@ void initPythonIRBindings(PyObject* module_) {
}
return py::none();
})
.def(
"varyingSizes",
[](Type& t) -> py::object {
if (auto ptt = t.expect<TensorType>()) {
if (auto s = ptt->sizes().sizes()) {
return py::cast(s.value());
}
}
return py::none();
})
.def(
"strides",
[](Type& t) -> py::object {
26 changes: 25 additions & 1 deletion torch/onnx/symbolic_helper.py
@@ -177,6 +177,29 @@ def _is_tensor(x):
def _is_tensor_list(x):
return isinstance(x.type(), torch._C.ListType) and isinstance(x.type().getElementType(), torch._C.TensorType)

def _get_tensor_rank(x):
if not _is_tensor(x) or x.type() is None:
return None
return x.type().dim()

def _get_tensor_sizes(x, allow_nonstatic=True):
if not _is_tensor(x) or x.type() is None:
return None
if allow_nonstatic:
# Each individual symbol is returned as None.
# e.g. [1, 'a', 'b'] -> [1, None, None]
return x.type().varyingSizes()
# Returns None if any symbol exists in sizes.
# e.g. [1, 'a', 'b'] -> None
return x.type().sizes()

def _get_tensor_dim_size(x, dim):
try:
sizes = _get_tensor_sizes(x)
return sizes[dim]
except Exception:
pass
return None

def _unimplemented(op, msg):
warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")
@@ -319,7 +342,8 @@ def _get_interpolate_attributes(g, mode, args):

def _interpolate_get_scales(g, scale_factor, dim):
offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32))
if isinstance(scale_factor.type(), torch._C.ListType) or (scale_factor.isCompleteTensor() and scale_factor.type().dim() > 0):
scale_factor_rank = _get_tensor_rank(scale_factor)
if isinstance(scale_factor.type(), torch._C.ListType) or (scale_factor_rank is not None and scale_factor_rank > 0):
return g.op("Concat", offsets, scale_factor, axis_i=0)
else:
scale_factor = _unsqueeze_helper(g, scale_factor, 0)
7 changes: 4 additions & 3 deletions torch/onnx/symbolic_opset10.py
@@ -209,12 +209,13 @@ def embedding_bag(g,
import warnings
warnings.warn("Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. "
"Please use opset 11 or higher to export model for dynamic input shape.'")
if offsets.type().sizes() is not None:
offsets_dim_0 = sym_help._get_tensor_dim_size(offsets, 0)
if offsets_dim_0 is not None:
if include_last_offset:
offset_len = offsets.type().sizes()[0] - 1
offset_len = offsets_dim_0 - 1
offsets_extended = offsets
else:
offset_len = offsets.type().sizes()[0]
offset_len = offsets_dim_0
offsets_extended = [offsets, g.op("Constant", value_t=torch.tensor([maxsize]))]
offsets_extended = g.op("Concat", *offsets_extended, axis_i=0)
list_ = []
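
For context (illustration only, not part of the diff): the rewritten offsets handling above needs the static size of dimension 0 of offsets. When include_last_offset is set, the last entry already marks the end of the final bag; otherwise a maxsize sentinel is appended. A standalone sketch of that decision:

```python
# Standalone sketch (not from the diff) of the offset-length logic in the
# opset 10 embedding_bag symbolic above.
def plan_offsets(offsets_dim_0, include_last_offset):
    """Return (loop length, whether a maxsize sentinel must be appended)."""
    if offsets_dim_0 is None:
        return None  # dynamic offsets shape: not exportable in opset 10
    if include_last_offset:
        return offsets_dim_0 - 1, False  # last offset already marks the end
    return offsets_dim_0, True           # append maxsize so the final bag runs to the end

print(plan_offsets(4, include_last_offset=True))   # (3, False)
print(plan_offsets(4, include_last_offset=False))  # (4, True)
```
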
57 changes: 30 additions & 27 deletions torch/onnx/symbolic_opset11.py
@@ -97,21 +97,21 @@ def index_put(g, self, indices_list_value, values, accumulate=False):
# %28 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
# %29 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
# %15 : None = prim::Constant()
# %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
# aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
# %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
# %30 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
# %22 : int[] = prim::Constant[value=[-1]]()
# %23 : Tensor = aten::view(%16, %22)
# %24 : Tensor?[] = prim::ListConstruct(%23)
# %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
# aten::index_put(%mask, %24, %18, %30)
# return (%25)
#
# after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu),
# %some_const : Float(requires_grad=0, device=cpu)):
# %3 : Tensor = onnx::Equal(%0, %some_const)
# %4 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%3)
# %12 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Cast[to=9](%4)
# %19 : Tensor = onnx::Cast[to=9](%12)
# %20 : Tensor = onnx::Constant[value={1}]()
@@ -137,7 +137,7 @@ def index_put(g, self, indices_list_value, values, accumulate=False):
# %37 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
# %22 : None = prim::Constant()
# %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
# = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
# %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
# %30 : int[] = prim::Constant[value=[-1]]()
# %31 : Tensor = aten::view(%23, %30)
@@ -148,7 +148,7 @@ def index_put(g, self, indices_list_value, values, accumulate=False):
#
# after graph(%0 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu),
# %some_const : Float(requires_grad=0, device=cpu)):
# %3 : Float(8, strides=[1], requires_grad=0, device=cpu)
# = onnx::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]()
# %4 : Tensor = onnx::Equal(%0, %some_const)
# %5 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = onnx::Not(%4)
@@ -168,17 +168,17 @@ def index_put(g, self, indices_list_value, values, accumulate=False):
# %32 : Tensor = onnx::Constant[value={0}]()
# %33 : Tensor = onnx::Unsqueeze[axes=[0]](%32)
# %34 : Tensor = onnx::Slice(%24, %30, %31, %33)
# %35 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
# = onnx::ScatterND(%0, %22, %34)
# return (%35)

bool_inp = list(index.node().inputs())[0]
if bool_inp.type() is not None and bool_inp.type().scalarType() == 'Bool':
if values.type() is not None:
if values.type().dim() == 0:
from torch.onnx.symbolic_opset9 import masked_fill
return masked_fill(g, self, bool_inp, values)
return masked_scatter(g, self, bool_inp, values)
rank = sym_help._get_tensor_rank(values)
if rank is not None and rank == 0:
from torch.onnx.symbolic_opset9 import masked_fill
return masked_fill(g, self, bool_inp, values)
return masked_scatter(g, self, bool_inp, values)
broadcast_index_shape = g.op("Shape", index)
index = g.op("Unsqueeze", index, axes_i=[-1])
sub_data_shape = sym_help._slice_helper(
@@ -201,8 +201,8 @@ def index_put(g, self, indices_list_value, values, accumulate=False):

@parse_args('v', 'i')
def pixel_shuffle(g, self, upscale_factor):
dims = self.type().sizes()
if len(dims) != 4:
rank = sym_help._get_tensor_rank(self)
if rank is not None and rank != 4:
return _unimplemented("pixel_shuffle", "only support 4d input")
return g.op("DepthToSpace", self, blocksize_i=upscale_factor, mode_s="CRD")

@@ -280,11 +280,12 @@ def __interpolate(g, input, size, scale_factor, mode, align_corners, recompute_s
"while exporting interpolate. Assuming that it is not a scalar.")

if is_scalar:
if not input.type().dim():
rank = sym_help._get_tensor_rank(input)
if rank is None:
return sym_help._unimplemented("interpolate (with a scalar output_size)",
"missing input shape (try giving an array of output_size values)")
size = unsqueeze(g, size, 0)
size = [size for i in range(input.type().dim() - 2)]
size = [size for i in range(rank - 2)]
size = g.op("Concat", *size, axis_i=0)
size = g.op("Cast", size, to_i=sym_help.cast_pytorch_to_onnx['Long'])
size = g.op("Concat", input_size, size, axis_i=0)
@@ -299,9 +300,10 @@ def __interpolate(g, input, size, scale_factor, mode, align_corners, recompute_s
mode_s=mode, # nearest, linear, or cubic
nearest_mode_s="floor")
else: # if not sym_help._is_none(scales)
if not input.type().dim():
rank = sym_help._get_tensor_rank(input)
if rank is None:
return sym_help._unimplemented("interpolate (with scales)", "missing input shape")
scales = sym_help._interpolate_get_scales(g, scale_factor, input.type().dim())
scales = sym_help._interpolate_get_scales(g, scale_factor, rank)
return g.op("Resize",
input,
roi,
@@ -549,19 +551,19 @@ def constant_pad_nd(g, input, padding, value=None):
mode = "constant"
value = sym_help._maybe_get_scalar(value)
value = sym_help._if_scalar_type_as(g, value, input)
pad = _prepare_onnx_paddings(g, input.type().dim(), padding)
pad = _prepare_onnx_paddings(g, sym_help._get_tensor_rank(input), padding)
return g.op("Pad", input, pad, value, mode_s=mode)


def reflection_pad(g, input, padding):
mode = "reflect"
paddings = _prepare_onnx_paddings(g, input.type().dim(), padding)
paddings = _prepare_onnx_paddings(g, sym_help._get_tensor_rank(input), padding)
return g.op("Pad", input, paddings, mode_s=mode)


def replication_pad(g, input, padding):
mode = "edge"
paddings = _prepare_onnx_paddings(g, input.type().dim(), padding)
paddings = _prepare_onnx_paddings(g, sym_help._get_tensor_rank(input), padding)
return g.op("Pad", input, paddings, mode_s=mode)


@@ -639,9 +641,11 @@ def squeeze(g, self, dim=None):

dim = sym_help._get_const(dim, 'i', 'dim')

input_shape = self.type().sizes()
from torch.onnx.symbolic_helper import _onnx_shape_inference
if input_shape is None or not _onnx_shape_inference:
input_rank = sym_help._get_tensor_rank(self)
adjusted_dim = dim
if input_rank is not None and dim < 0:
adjusted_dim += input_rank
if (dim < 0 and input_rank is None) or sym_help._get_tensor_dim_size(self, adjusted_dim) is None:
# If onnx shape inference is not on, export always as dynamic.
# Because we cannot tell if observed static shape is also static at runtime.
# create 'cond' node (condition is shape[i]==1)
@@ -661,9 +665,8 @@ def squeeze(g, self, dim=None):
return if_node_outputs

# For static input shape
if dim < 0:
dim += self.type().dim()
if input_shape[dim] > 1:
dim = adjusted_dim
if sym_help._get_tensor_dim_size(self, dim) > 1:
warnings.warn("This model contains a squeeze operation on dimension " + str(dim) + ". The size of " +
"this dimension in the given input is " + str(input_shape[dim]) + ". The model will " +
"be exported without the squeeze node. If the model is intended to be used with dynamic " +
@@ -861,7 +864,7 @@ def narrow(g, input, dim, start, length):

@parse_args('v', 'i', 'i')
def flatten(g, input, start_dim, end_dim):
dim = input.type().dim()
dim = sym_help._get_tensor_rank(input)
# use ONNX's Flatten operator for cases where the output shape is 2D
if start_dim == 1:
if (end_dim == -1 or (dim is not None and end_dim == dim - 1)):
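
For reference (not part of the diff), the negative-dimension handling added to the squeeze symbolic above reduces to the small helper sketched here; when the input rank is unknown the dim stays negative, which sends export down the dynamic path.

```python
# Minimal sketch of the negative-dim normalization used by the squeeze symbolic
# above; plain Python, nothing PyTorch-specific.
def normalize_dim(dim, rank):
    # Mirrors: adjusted_dim = dim; if input_rank is not None and dim < 0: adjusted_dim += input_rank
    if rank is not None and dim < 0:
        return dim + rank
    return dim

assert normalize_dim(-2, 4) == 2      # squeeze(dim=-2) on a rank-4 tensor targets axis 2
assert normalize_dim(1, 4) == 1       # non-negative dims are left untouched
assert normalize_dim(-1, None) == -1  # unknown rank: dim stays negative, dynamic export path
```
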
9 changes: 4 additions & 5 deletions torch/onnx/symbolic_opset8.py
@@ -148,10 +148,9 @@ def matmul(g, self, other):


def prelu(g, self, weight):
if self.isCompleteTensor():
self_sizes = self.type().sizes()
if self_sizes and len(self_sizes) > 2:
weight = g.op("Unsqueeze", weight, axes_i=list(range(1, len(self_sizes) - 1)))
self_rank = sym_help._get_tensor_rank(self)
if self_rank is not None and self_rank > 2:
weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1)))
if _try_get_scalar_type(self):
old_type, self, weight = _try_cast_integer_to_float(g, self, weight)
return _cast_to_type(g, g.op("PRelu", self, weight), old_type)
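
To see why the weight is unsqueezed for inputs of rank greater than 2 (illustration only, not part of the diff): ONNX PRelu broadcasts the slope against the input, so a per-channel slope of shape (C,) has to be reshaped to (C, 1, ..., 1) to line up with the channel axis. The snippet below mimics that broadcasting in eager PyTorch.

```python
# Eager-mode illustration (not from the diff) of the unsqueeze in the opset 8
# prelu symbolic above: reshape a per-channel slope so it broadcasts over the
# trailing spatial dims.
import torch

x = torch.randn(2, 3, 8, 8)  # rank-4 input, C = 3
w = torch.rand(3)            # per-channel slope, shape (3,)

# Unsqueeze over axes 1 .. rank-2, i.e. (3,) -> (3, 1, 1),
# matching axes_i=list(range(1, self_rank - 1)) above.
w_broadcast = w.reshape(3, *([1] * (x.dim() - 2)))

out = torch.where(x >= 0, x, w_broadcast * x)  # PReLU semantics via broadcasting
print(out.shape)  # torch.Size([2, 3, 8, 8])
```
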
@@ -267,7 +266,7 @@ def full_like(g, input, fill_value, dtype, layout, device, pin_memory=False, mem
def repeat(g, self, repeats):
if not sym_help._is_value(repeats):
repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
if sym_help._is_packed_list(repeats):
repeat_size_len = len(sym_help._unpack_list(repeats))
else:
const_repeats = sym_help._maybe_get_const(repeats, 'is')