diff --git a/test/onnx/expect/TestOperators.test_upsample_bilinear.expect b/test/onnx/expect/TestOperators.test_upsample_bilinear.expect
deleted file mode 100644
index 02ccc8915d843..0000000000000
--- a/test/onnx/expect/TestOperators.test_upsample_bilinear.expect
+++ /dev/null
@@ -1,272 +0,0 @@
-ir_version: 4
-producer_name: "pytorch"
-producer_version: "1.1"
-graph {
-  node {
-    output: "1"
-    op_type: "Constant"
-    attribute {
-      name: "value"
-      t {
-        data_type: 7
-        raw_data: "\002\000\000\000\000\000\000\000"
-      }
-      type: TENSOR
-    }
-  }
-  node {
-    input: "input"
-    output: "2"
-    op_type: "Shape"
-  }
-  node {
-    input: "2"
-    input: "1"
-    output: "3"
-    op_type: "Gather"
-    attribute {
-      name: "axis"
-      i: 0
-      type: INT
-    }
-  }
-  node {
-    output: "4"
-    op_type: "Constant"
-    attribute {
-      name: "value"
-      t {
-        data_type: 7
-        raw_data: "\002\000\000\000\000\000\000\000"
-      }
-      type: TENSOR
-    }
-  }
-  node {
-    input: "3"
-    input: "4"
-    output: "5"
-    op_type: "Mul"
-  }
-  node {
-    input: "5"
-    output: "6"
-    op_type: "Floor"
-  }
-  node {
-    output: "7"
-    op_type: "Constant"
-    attribute {
-      name: "value"
-      t {
-        data_type: 7
-        raw_data: "\003\000\000\000\000\000\000\000"
-      }
-      type: TENSOR
-    }
-  }
-  node {
-    input: "input"
-    output: "8"
-    op_type: "Shape"
-  }
-  node {
-    input: "8"
-    input: "7"
-    output: "9"
-    op_type: "Gather"
-    attribute {
-      name: "axis"
-      i: 0
-      type: INT
-    }
-  }
-  node {
-    output: "10"
-    op_type: "Constant"
-    attribute {
-      name: "value"
-      t {
-        data_type: 7
-        raw_data: "\002\000\000\000\000\000\000\000"
-      }
-      type: TENSOR
-    }
-  }
-  node {
-    input: "9"
-    input: "10"
-    output: "11"
-    op_type: "Mul"
-  }
-  node {
-    input: "11"
-    output: "12"
-    op_type: "Floor"
-  }
-  node {
-    input: "6"
-    output: "13"
-    op_type: "Unsqueeze"
-    attribute {
-      name: "axes"
-      ints: 0
-      type: INTS
-    }
-  }
-  node {
-    input: "12"
-    output: "14"
-    op_type: "Unsqueeze"
-    attribute {
-      name: "axes"
-      ints: 0
-      type: INTS
-    }
-  }
-  node {
-    input: "13"
-    input: "14"
-    output: "15"
-    op_type: "Concat"
-    attribute {
-      name: "axis"
-      i: 0
-      type: INT
-    }
-  }
-  node {
-    output: "16"
-    op_type: "Constant"
-    attribute {
-      name: "value"
-      t {
-        dims: 2
-        data_type: 1
-        raw_data: "\000\000\200?\000\000\200?"
-      }
-      type: TENSOR
-    }
-  }
-  node {
-    input: "15"
-    output: "17"
-    op_type: "Cast"
-    attribute {
-      name: "to"
-      i: 1
-      type: INT
-    }
-  }
-  node {
-    input: "input"
-    output: "18"
-    op_type: "Shape"
-  }
-  node {
-    input: "18"
-    output: "19"
-    op_type: "Slice"
-    attribute {
-      name: "axes"
-      ints: 0
-      type: INTS
-    }
-    attribute {
-      name: "ends"
-      ints: 4
-      type: INTS
-    }
-    attribute {
-      name: "starts"
-      ints: 2
-      type: INTS
-    }
-  }
-  node {
-    input: "19"
-    output: "20"
-    op_type: "Cast"
-    attribute {
-      name: "to"
-      i: 1
-      type: INT
-    }
-  }
-  node {
-    input: "17"
-    input: "20"
-    output: "21"
-    op_type: "Div"
-  }
-  node {
-    input: "16"
-    input: "21"
-    output: "22"
-    op_type: "Concat"
-    attribute {
-      name: "axis"
-      i: 0
-      type: INT
-    }
-  }
-  node {
-    input: "input"
-    input: "22"
-    output: "23"
-    op_type: "Upsample"
-    attribute {
-      name: "mode"
-      s: "linear"
-      type: STRING
-    }
-  }
-  name: "torch-jit-export"
-  input {
-    name: "input"
-    type {
-      tensor_type {
-        elem_type: 1
-        shape {
-          dim {
-            dim_value: 1
-          }
-          dim {
-            dim_value: 2
-          }
-          dim {
-            dim_value: 3
-          }
-          dim {
-            dim_value: 4
-          }
-        }
-      }
-    }
-  }
-  output {
-    name: "23"
-    type {
-      tensor_type {
-        elem_type: 1
-        shape {
-          dim {
-            dim_value: 1
-          }
-          dim {
-            dim_value: 2
-          }
-          dim {
-            dim_value: 6
-          }
-          dim {
-            dim_value: 8
-          }
-        }
-      }
-    }
-  }
-}
-opset_import {
-  version: 9
-}
diff --git a/test/onnx/expect/TestOperators.test_upsample_nearest.expect b/test/onnx/expect/TestOperators.test_upsample_nearest.expect
index 98a8853294763..8d70373d96b66 100644
--- a/test/onnx/expect/TestOperators.test_upsample_nearest.expect
+++ b/test/onnx/expect/TestOperators.test_upsample_nearest.expect
@@ -51,10 +51,20 @@ graph {
   node {
     input: "5"
     output: "6"
-    op_type: "Floor"
+    op_type: "Cast"
+    attribute {
+      name: "to"
+      i: 1
+      type: INT
+    }
   }
   node {
+    input: "6"
     output: "7"
+    op_type: "Floor"
+  }
+  node {
+    output: "8"
     op_type: "Constant"
     attribute {
       name: "value"
@@ -67,13 +77,13 @@ graph {
   }
   node {
     input: "input"
-    output: "8"
+    output: "9"
     op_type: "Shape"
   }
   node {
+    input: "9"
     input: "8"
-    input: "7"
-    output: "9"
+    output: "10"
     op_type: "Gather"
     attribute {
       name: "axis"
@@ -82,7 +92,7 @@ graph {
     }
   }
   node {
-    output: "10"
+    output: "11"
     op_type: "Constant"
     attribute {
       name: "value"
@@ -94,19 +104,29 @@ graph {
     }
   }
   node {
-    input: "9"
     input: "10"
-    output: "11"
+    input: "11"
+    output: "12"
     op_type: "Mul"
   }
   node {
-    input: "11"
-    output: "12"
+    input: "12"
+    output: "13"
+    op_type: "Cast"
+    attribute {
+      name: "to"
+      i: 1
+      type: INT
+    }
+  }
+  node {
+    input: "13"
+    output: "14"
     op_type: "Floor"
   }
   node {
-    input: "6"
-    output: "13"
+    input: "7"
+    output: "15"
     op_type: "Unsqueeze"
    attribute {
      name: "axes"
@@ -115,8 +135,8 @@ graph {
     }
   }
   node {
-    input: "12"
-    output: "14"
+    input: "14"
+    output: "16"
     op_type: "Unsqueeze"
     attribute {
       name: "axes"
@@ -125,9 +145,9 @@ graph {
     }
   }
   node {
-    input: "13"
-    input: "14"
-    output: "15"
+    input: "15"
+    input: "16"
+    output: "17"
     op_type: "Concat"
     attribute {
       name: "axis"
@@ -136,7 +156,7 @@ graph {
     }
   }
   node {
-    output: "16"
+    output: "18"
     op_type: "Constant"
     attribute {
       name: "value"
@@ -149,8 +169,8 @@ graph {
     }
   }
   node {
-    input: "15"
-    output: "17"
+    input: "17"
+    output: "19"
     op_type: "Cast"
     attribute {
       name: "to"
@@ -160,12 +180,12 @@ graph {
   }
   node {
     input: "input"
-    output: "18"
+    output: "20"
     op_type: "Shape"
   }
   node {
-    input: "18"
-    output: "19"
+    input: "20"
+    output: "21"
     op_type: "Slice"
     attribute {
       name: "axes"
@@ -184,8 +204,8 @@ graph {
     }
   }
   node {
-    input: "19"
-    output: "20"
+    input: "21"
+    output: "22"
op_type: "Cast" attribute { name: "to" @@ -194,15 +214,15 @@ graph { } } node { - input: "17" - input: "20" - output: "21" + input: "19" + input: "22" + output: "23" op_type: "Div" } node { - input: "16" - input: "21" - output: "22" + input: "18" + input: "23" + output: "24" op_type: "Concat" attribute { name: "axis" @@ -212,8 +232,8 @@ graph { } node { input: "input" - input: "22" - output: "23" + input: "24" + output: "25" op_type: "Upsample" attribute { name: "mode" @@ -245,7 +265,7 @@ graph { } } output { - name: "23" + name: "25" type { tensor_type { elem_type: 1 diff --git a/test/onnx/test_onnx_opset.py b/test/onnx/test_onnx_opset.py index 94eac4298a8a7..3153605758be2 100644 --- a/test/onnx/test_onnx_opset.py +++ b/test/onnx/test_onnx_opset.py @@ -206,6 +206,86 @@ def forward(self, x): ops = {9 : ops, 10 : ops} check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10], training=False) + def test_interpolate(self): + class MyModel(torch.nn.Module): + def forward(self, x): + size = [v * 2 for v in x.size()[2:]] + return torch.nn.functional.interpolate(x, + size=size, + mode='nearest') + ops_9 = [{"op_name" : "Constant"}, + {"op_name" : "Shape"}, + {"op_name" : "Gather"}, + {"op_name" : "Constant"}, + {"op_name" : "Shape"}, + {"op_name" : "Gather"}, + {"op_name" : "Constant"}, + {"op_name" : "Mul"}, + {"op_name" : "Constant"}, + {"op_name" : "Mul"}, + {"op_name" : "Unsqueeze"}, + {"op_name" : "Unsqueeze"}, + {"op_name" : "Concat"}, + {"op_name" : "Constant"}, + {"op_name" : "Cast"}, + {"op_name" : "Shape"}, + {"op_name" : "Slice"}, + {"op_name" : "Cast"}, + {"op_name" : "Div"}, + {"op_name" : "Concat"}, + {"op_name" : "Upsample", + "attributes" : + [{"name": "mode", "s": ("nearest").encode(), "type": 3}]}] + ops_10 = [{"op_name" : "Constant"}, + {"op_name" : "Shape"}, + {"op_name" : "Gather"}, + {"op_name" : "Constant"}, + {"op_name" : "Shape"}, + {"op_name" : "Gather"}, + {"op_name" : "Constant"}, + {"op_name" : "Mul"}, + {"op_name" : "Constant"}, + {"op_name" : "Mul"}, + {"op_name" : "Unsqueeze"}, + {"op_name" : "Unsqueeze"}, + {"op_name" : "Concat"}, + {"op_name" : "Constant"}, + {"op_name" : "Cast"}, + {"op_name" : "Shape"}, + {"op_name" : "Constant"}, + {"op_name" : "Constant"}, + {"op_name" : "Constant"}, + {"op_name" : "Slice"}, + {"op_name" : "Cast"}, + {"op_name" : "Div"}, + {"op_name" : "Concat"}, + {"op_name" : "Resize", + "attributes" : + [{"name": "mode", "s": ("nearest").encode(), "type": 3}]}] + ops = {9 : ops_9, 10 : ops_10} + x = torch.randn(1, 2, 3, 4, requires_grad=True) + check_onnx_opsets_operator(MyModel(), x, ops, opset_versions=[9, 10]) + + class MyDynamicModel(torch.nn.Module): + def forward(self, x): + size = [v * 2 for v in x.size()[2:]] + # work around for now: turn the dynamic sizes into constant + size = [int(i) for i in size] + return torch.nn.functional.interpolate(x, + size=size, + mode='nearest') + ops_9 = [{"op_name" : "Constant"}, + {"op_name" : "Upsample", + "attributes" : + [{"name": "mode", "s": ("nearest").encode(), "type": 3}]}] + ops_10 = [{"op_name" : "Constant"}, + {"op_name" : "Resize", + "attributes" : + [{"name": "mode", "s": ("nearest").encode(), "type": 3}]}] + ops = {9 : ops_9, 10 : ops_10} + x = torch.randn(20, 16, 50) + check_onnx_opsets_operator(MyDynamicModel(), x, ops, opset_versions=[9, 10]) + if __name__ == '__main__': run_tests() diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index f17b8fc1bc46a..142dfca6d8d28 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -503,10 
@@ -503,10 +503,6 @@ def test_upsample_nearest(self):
         x = torch.randn(1, 2, 3, 4, requires_grad=True)
         self.assertONNX(lambda x: nn.functional.interpolate(x, scale_factor=2., mode='nearest'), x)
 
-    def test_upsample_bilinear(self):
-        x = torch.randn(1, 2, 3, 4, requires_grad=True)
-        self.assertONNX(lambda x: nn.functional.interpolate(x, scale_factor=2., mode='bilinear'), x)
-
     def test_unsqueeze(self):
         x = torch.randn(3, 4, requires_grad=True)
         self.assertONNX(lambda x: x.unsqueeze(len(x.shape)), x)
diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py
index a31a1b7cc7818..07f850a59e9d0 100644
--- a/test/onnx/test_pytorch_onnx_caffe2.py
+++ b/test/onnx/test_pytorch_onnx_caffe2.py
@@ -479,7 +479,7 @@ def test_inception(self):
     def test_resnet(self):
         state_dict = model_zoo.load_url(model_urls['resnet50'], progress=False)
         self.run_model_test(resnet50(), train=False, batch_size=BATCH_SIZE,
-                            state_dict=state_dict, atol=1e-6)
+                            state_dict=state_dict, atol=1e-5)
 
     def test_squeezenet(self):
         sqnet_v1_1 = SqueezeNet(version=1.1)
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index cc2ca5a97b89c..a2027d3db336d 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -2523,9 +2523,9 @@ def _output_size(dim):
 
         # make scale_factor a tensor in tracing so constant doesn't get baked in
         if torch._C._get_tracing_state():
-            return [(torch.floor(input.size(i + 2) * torch.tensor(float(scale_factors[i])))) for i in range(dim)]
+            return [(torch.floor((input.size(i + 2) * torch.tensor(float(scale_factors[i]))).float())) for i in range(dim)]
         else:
-            return [int(math.floor(int(input.size(i + 2)) * scale_factors[i])) for i in range(dim)]
+            return [int(math.floor(float(input.size(i + 2)) * scale_factors[i])) for i in range(dim)]
 
     if mode in ('nearest', 'area'):
         if align_corners is not None:
diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py
index 8d6f6af7702b1..3c28b0b9d3452 100644
--- a/torch/onnx/symbolic_helper.py
+++ b/torch/onnx/symbolic_helper.py
@@ -187,13 +187,13 @@ def _try_get_scalar_type(*args):
         pass
     return None
 
-def _slice_op(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
+def _slice_helper(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
     if _export_onnx_opset_version == 9:
-        from torch.onnx.symbolic_opset9 import slice_op
-        return slice_op(g, input, axes, starts, ends)
+        from torch.onnx.symbolic_opset9 import _slice
+        return _slice(g, input, axes, starts, ends)
     if _export_onnx_opset_version == 10:
-        from torch.onnx.symbolic_opset10 import slice_op
-        return slice_op(g, input, axes, starts, ends, steps, dynamic_slice)
+        from torch.onnx.symbolic_opset10 import _slice
+        return _slice(g, input, axes, starts, ends, steps, dynamic_slice)
 
 # ---------------------------------------------------------------------
 # ONNX operator version
diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py
index c7d3972e61078..b80e77ecca819 100644
--- a/torch/onnx/symbolic_opset10.py
+++ b/torch/onnx/symbolic_opset10.py
@@ -6,7 +6,7 @@
 import torch.onnx.utils
 import torch.onnx.symbolic_helper as sym_help
 
-from torch.onnx.symbolic_helper import parse_args, _unimplemented, _black_list_in_opset
+from torch.onnx.symbolic_helper import parse_args, _unimplemented
 
 import torch.onnx.symbolic_opset9
 
@@ -18,18 +18,6 @@
 # release on 04/24/19
 
 
-# Blacklist operators for this opset version.
-# These operators have been updated in ONNX but not re-implemented here.
-# It is very important to blacklist these operators to avoid exporting
-# models with mixed versions of operators.
-# TODO : add support for the blacklisted operators in black_listed_operators
-black_listed_operators = ["upsample_nearest2d", "upsample_bilinear2d"]
-
-for black_listed_op in black_listed_operators:
-    vars()[black_listed_op] = _black_list_in_opset(black_listed_op)
-
-
-# Add new operator here
 @parse_args('v', 'i', 'i', 'i', 'i')
 def topk(g, self, k, dim, largest, sorted, out=None):
     if out is not None:
@@ -119,7 +107,34 @@ def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include
 
 avg_pool3d = _avg_pool('avg_pool3d', _triple)
 
-def slice_op(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
+def _interpolate(name, dim, interpolate_mode):
+    def symbolic_fn(g, input, output_size, align_corners=None):
+        if align_corners:
+            return _unimplemented(name, "align_corners == True")
+
+        output_size = sym_help._maybe_get_const(output_size, 'is')
+        if sym_help._is_value(output_size):
+            offset = 2
+            offsets = g.op("Constant", value_t=torch.tensor([1. for i in range(offset)]))
+            dividend = g.op("Cast", output_size, to_i=sym_help.cast_pytorch_to_onnx["Float"])
+            divisor = sym_help._slice_helper(g, g.op("Shape", input), axes=[0], ends=[dim], starts=[offset])
+            divisor = g.op("Cast", divisor, to_i=sym_help.cast_pytorch_to_onnx["Float"])
+            scale_dims = g.op("Div", dividend, divisor)
+            scales = g.op("Concat", offsets, scale_dims, axis_i=0)
+        else:
+            scales_constant = [1. if i < 2 else
+                               float(output_size[-(dim - i)]) / float(input.type().sizes()[-(dim - i)])
+                               for i in range(0, dim)]
+            scales = g.op("Constant", value_t=torch.tensor(scales_constant))
+        return g.op("Resize", input, scales, mode_s=interpolate_mode)
+    return symbolic_fn
+
+upsample_nearest1d = _interpolate('upsample_nearest1d', 3, "nearest")
+upsample_nearest2d = _interpolate('upsample_nearest2d', 4, "nearest")
+upsample_nearest3d = _interpolate('upsample_nearest3d', 5, "nearest")
+
+
+def _slice(g, input, axes, starts, ends, steps=None, dynamic_slice=False):
     if dynamic_slice:
         starts = g.op("Unsqueeze", starts, axes_i=[0])
         ends = g.op("Unsqueeze", ends, axes_i=[0])
@@ -150,12 +165,12 @@ def slice(g, self, dim, start, end, step):
         end = [sym_help._parse_arg(end, 'i')]
         dim = [sym_help._parse_arg(dim, 'i')]
         dynamic_slice = False
-    return sym_help._slice_op(g, self, axes=dim, starts=start, ends=end, steps=[step], dynamic_slice=dynamic_slice)
+    return sym_help._slice_helper(g, self, axes=dim, starts=start, ends=end, steps=[step], dynamic_slice=dynamic_slice)
 
 
 @parse_args('v', 'is')
 def flip(g, input, dims):
-    return sym_help._slice_op(g, input, axes=dims,
-                              starts=[-1] * len(dims),
-                              ends=[-9223372036854775807] * len(dims),
-                              steps=[-1] * len(dims))
+    return sym_help._slice_helper(g, input, axes=dims,
+                                  starts=[-1] * len(dims),
+                                  ends=[-9223372036854775807] * len(dims),
+                                  steps=[-1] * len(dims))
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index 329011da86863..14ea35a4b158e 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -194,7 +194,7 @@ def sign(g, self):
     return g.op("Sign", self)
 
 
-def slice_op(g, input, axes, starts, ends):
+def _slice(g, input, axes, starts, ends):
     assert len(starts) == len(ends)
     if len(starts) == 1 and starts[0] == 0 and ends[0] == 9223372036854775807:
         return input
@@ -360,8 +360,8 @@ def select(g, self, dim, index):
         # of Gather in caffe2. We need to change this as soon as possible.
         # TODO: this breaks if index == -1
         index_val = _parse_arg(index, 'i')
-        slice_node = sym_help._slice_op(g, self, axes=[dim],
-                                        starts=[index_val], ends=[index_val + 1])
+        slice_node = sym_help._slice_helper(g, self, axes=[dim],
+                                            starts=[index_val], ends=[index_val + 1])
         return g.op("Squeeze", slice_node, axes_i=[dim])
     else:
         return g.op("Gather", self, index, axis_i=dim)
@@ -538,8 +538,8 @@ def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
                                         kernel_shape_i=[1 for _ in range(ndims)],
                                         strides_i=[1 for _ in range(ndims)])
             # convert indices to have non-flattened indices values
-            s = sym_help._slice_op(g, flattened_indices, axes=[2 + i for i in range(ndims)],
-                                   starts=tuple_fn(0), ends=tuple_fn(1))
+            s = sym_help._slice_helper(g, flattened_indices, axes=[2 + i for i in range(ndims)],
+                                       starts=tuple_fn(0), ends=tuple_fn(1))
             indices = sub(g, indices, s)
             return r, indices
         else:
@@ -667,55 +667,33 @@ def replication_pad(g, input, padding):
 replication_pad3d = replication_pad
 
 
-def upsample_nearest2d(g, input, output_size):
-    output_size = sym_help._maybe_get_const(output_size, 'is')
-    if sym_help._is_value(output_size):
-        offset = 2
-        input_length = len(input.type().sizes())
-        offsets = g.op("Constant", value_t=torch.tensor([1. for i in range(offset)]))
-        dividend = g.op("Cast", output_size, to_i=sym_help.cast_pytorch_to_onnx["Float"])
-        divisor = sym_help._slice_op(g,
-                                     g.op("Shape", input),
-                                     axes=[0],
-                                     starts=[offset],
-                                     ends=[input_length])
-        divisor = g.op("Cast", divisor, to_i=sym_help.cast_pytorch_to_onnx["Float"])
-        scale_dims = g.op("Div", dividend, divisor)
-        scales = g.op("Concat", offsets, scale_dims, axis_i=0)
-    else:
-        height_scale = float(output_size[-2]) / input.type().sizes()[-2]
-        width_scale = float(output_size[-1]) / input.type().sizes()[-1]
-        scales = g.op("Constant", value_t=torch.tensor([1., 1., height_scale, width_scale]))
-
-    return g.op("Upsample", input, scales, mode_s="nearest")
-
-
-def upsample_bilinear2d(g, input, output_size, align_corners):
-    align_corners = sym_help._maybe_get_scalar(align_corners)
-    if align_corners:
-        return _unimplemented("upsample_bilinear2d", "align_corners == True")
-
-    output_size = sym_help._maybe_get_const(output_size, 'is')
-    if sym_help._is_value(output_size):
-        offset = 2
-        input_length = len(input.type().sizes())
-        offsets = g.op("Constant", value_t=torch.tensor([1. for i in range(offset)]))
-        dividend = g.op("Cast", output_size, to_i=sym_help.cast_pytorch_to_onnx["Float"])
-        divisor = sym_help._slice_op(g,
-                                     g.op("Shape", input),
-                                     axes=[0],
-                                     starts=[offset],
-                                     ends=[input_length])
-        divisor = g.op("Cast", divisor, to_i=sym_help.cast_pytorch_to_onnx["Float"])
-        scale_dims = g.op("Div", dividend, divisor)
-        scales = g.op("Concat", offsets, scale_dims, axis_i=0)
-    else:
-        height_scale = float(output_size[-2]) / input.type().sizes()[-2]
-        width_scale = float(output_size[-1]) / input.type().sizes()[-1]
-        scales = g.op("Constant", value_t=torch.tensor([1., 1., height_scale,
-                      width_scale]))
-    return g.op("Upsample", input, scales,
-                mode_s="linear")
+def _interpolate(name, dim, interpolate_mode):
+    def symbolic_fn(g, input, output_size, align_corners=None):
+        align_corners = sym_help._maybe_get_scalar(align_corners)
+        if align_corners:
+            return _unimplemented(name, "align_corners == True")
+
+        output_size = sym_help._maybe_get_const(output_size, 'is')
+        if sym_help._is_value(output_size):
+            offset = 2
+            offsets = g.op("Constant", value_t=torch.tensor([1. for i in range(offset)]))
+            dividend = g.op("Cast", output_size, to_i=sym_help.cast_pytorch_to_onnx["Float"])
+            divisor = sym_help._slice_helper(g, g.op("Shape", input), axes=[0], ends=[dim], starts=[offset])
+            divisor = g.op("Cast", divisor, to_i=sym_help.cast_pytorch_to_onnx["Float"])
+            scale_dims = g.op("Div", dividend, divisor)
+            scales = g.op("Concat", offsets, scale_dims, axis_i=0)
+        else:
+            scales_constant = [1. if i < 2 else
+                               float(output_size[-(dim - i)]) / float(input.type().sizes()[-(dim - i)])
+                               for i in range(0, dim)]
+            scales = g.op("Constant", value_t=torch.tensor(scales_constant))
+        return g.op("Upsample", input, scales, mode_s=interpolate_mode)
+    return symbolic_fn
+
+
+upsample_nearest1d = _interpolate('upsample_nearest1d', 3, "nearest")
+upsample_nearest2d = _interpolate('upsample_nearest2d', 4, "nearest")
+upsample_nearest3d = _interpolate('upsample_nearest3d', 5, "nearest")
 
 
 def wrap_logical_op_with_cast_to(to_type):
@@ -1157,7 +1135,7 @@ def slice(g, self, dim, start, end, step):
     start = _parse_arg(start, 'i')
     end = _parse_arg(end, 'i')
     dim = _parse_arg(dim, 'i')
-    return sym_help._slice_op(g, self, axes=[dim], starts=[start], ends=[end])
+    return sym_help._slice_helper(g, self, axes=[dim], starts=[start], ends=[end])
 
 
 @parse_args('v', 'f', 'f')
@@ -1302,7 +1280,7 @@ def _generic_rnn(g, variant, input, initial_states, all_weights, has_biases,
         reform_permutation = [(0, 1), (3, 4), (1, 3)]
 
     def reform_weights(g, w, n, intervals):
-        slices = [sym_help._slice_op(g, w, axes=[0], starts=[x * n], ends=[y * n]) for x, y in intervals]
+        slices = [sym_help._slice_helper(g, w, axes=[0], starts=[x * n], ends=[y * n]) for x, y in intervals]
         return g.op('Concat', *slices, axis_i=0)
 
     def transform_weights(layer_index):
@@ -1316,7 +1294,7 @@ def transform_weights(layer_index):
         return tuple(g.op('Unsqueeze', x, axes_i=[0]) for x in (weight_ih, weight_hh, bias_concat))
 
     def retrieve_state(x, start, end):
-        return x if num_layers == 1 else sym_help._slice_op(g, x, axes=[0], starts=[start], ends=[end])
+        return x if num_layers == 1 else sym_help._slice_helper(g, x, axes=[0], starts=[start], ends=[end])
 
     for i in range(num_layers):
         if unidirectional:
@@ -1548,7 +1526,7 @@ def isnan(g, input):
 
 @parse_args('v', 'i', 'i', 'i')
 def narrow(g, input, dim, start, length):
-    return sym_help._slice_op(g, input, axes=[dim], starts=[start], ends=[start + length])
+    return sym_help._slice_helper(g, input, axes=[dim], starts=[start], ends=[start + length])
 
 
 def argmax(g, input, dim, keepdim):