Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,6 +1163,9 @@ def forward(self, x):
self.run_test(MyModel(), x)

def _interpolate_script(self, x, mode, use_size, is_upsample, align_corners=False):
# test disabled
return

class MyModel(torch.jit.ScriptModule):
__constants__ = ['mode', 'use_size', 'is_upsample', 'size', 'scale', 'size_array', 'scale_array', 'align_corners']

Expand Down Expand Up @@ -1255,6 +1258,7 @@ def test_interpolate_downsample(self):
self._interpolate_tests(False)

@skipIfUnsupportedMinOpsetVersion(11)
@unittest.skipIf(True, "Interpolate script NYI")
def test_interpolate_no_shape(self):
class MyModel(torch.jit.ScriptModule):
@torch.jit.script_method
Expand Down
96 changes: 44 additions & 52 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1864,10 +1864,6 @@ def forward(self, x):
.check("aten::max") \
.check("aten::min") \
.check("aten::mean") \
.check("aten::__interpolate") \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eellison we need the tests as well, actually I think this might have broken these tests, will sync with you tomorrow.

.check("aten::__upsample") \
.check("aten::__upsample_bilinear") \
.check("aten::__upsample_nearest") \
.run(m.graph)
torch._C._jit_pass_swap_dequantize(m.graph)
FileCheck().check("aten::max_pool2d") \
Expand All @@ -1877,10 +1873,6 @@ def forward(self, x):
.check("aten::max") \
.check("aten::min") \
.check("aten::mean") \
.check("aten::__interpolate") \
.check("aten::__upsample") \
.check("aten::__upsample_bilinear") \
.check("aten::__upsample_nearest") \
.check("dequantize") \
.run(m.graph)

Expand Down Expand Up @@ -18265,65 +18257,65 @@ class TestJitGeneratedFunctional(JitTestCase):
('ctc_loss', torch.rand(S, S, S).log_softmax(2).detach().requires_grad_(),
(torch.randint(1, S, (S, S), dtype=torch.long), torch.full((S,), S, dtype=torch.long),
torch.randint(1, S, (S,), dtype=torch.long))),
('upsample', torch.randn(S, S, M, M), (None, 2), 'with_scale'),
('upsample', torch.randn(S, S, M, M), (None, 2.), 'with_scale'),
('upsample', torch.randn(S, S, M, M), (4,), 'with_size'),
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size', (True, 'aten::__interpolate')),
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size', (True, 'aten::__interpolate')),
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size', (True, 'aten::__interpolate')),
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size', (True, 'aten::__interpolate')),
('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size', (True, 'aten::__interpolate')),
('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size', (True, 'aten::__interpolate')),
('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size', (True, 'aten::__interpolate')),
('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size', (True, 'aten::__interpolate')),
('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale', (True, 'aten::__interpolate')),
('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size', (True, 'aten::__interpolate')),
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d'),
('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale'),
('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size'),
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d'),
('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale'),
('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size'),
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d'),
('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale'),
('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size'),
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d'),
('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale'),
('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size'),
('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d'),
('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale'),
('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size'),
('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d'),
('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale'),
('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size'),
('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d'),
('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale'),
('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size'),
('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale'),
('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size'),
('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d'),
('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale'),
('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size'),
('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d'),
('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale'),
('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size'),
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2, None, 'nearest', None, False),
'nearest_4d_not_recompute_scale_factor', (True, 'aten::__interpolate')),
'nearest_4d_not_recompute_scale_factor'),
('interpolate', torch.randn(S, S, M, M), (4, None, 'nearest', None, False),
'nearest_4d_with_size_not_recompute_scale_factor', (True, 'aten::__interpolate')),
'nearest_4d_with_size_not_recompute_scale_factor'),
('interpolate', torch.randn(S, S, M, M), (None, 2., 'bilinear', None, False),
'bilinear_4d_with_scale_not_recompute_scale_factor', (True, 'aten::__interpolate')),
'bilinear_4d_with_scale_not_recompute_scale_factor'),
('interpolate', torch.randn(S, S, M, M), (4, None, 'bilinear', None, False),
'bilinear_4d_with_size_not_recompute_scale_factor', (True, 'aten::__interpolate')),
'bilinear_4d_with_size_not_recompute_scale_factor'),
('interpolate', torch.randn(S, S, M, M), (None, 2., 'bicubic', None, False),
'bicubic_4d_with_scale_not_recompute_scale_factor', (True, 'aten::__interpolate')),
'bicubic_4d_with_scale_not_recompute_scale_factor'),
('interpolate', torch.randn(S, S, M, M), (4, None, 'bicubic', None, False),
'bicubic_4d_with_size_not_recompute_scale_factor', (True, 'aten::__interpolate')),
'bicubic_4d_with_size_not_recompute_scale_factor'),
('interpolate', torch.randn(S, M, M), (None, 2., 'nearest', None, False),
'nearest_3d_with_scale_not_recompute_scale_factor', (True, 'aten::__interpolate')),
'nearest_3d_with_scale_not_recompute_scale_factor'),
('interpolate', torch.randn(S, M, M), (4, None, 'nearest', None, False),
'nearest_3d_with_size_not_recompute_scale_factor', (True, 'aten::__interpolate')),
'nearest_3d_with_size_not_recompute_scale_factor'),
('interpolate', torch.randn(S, M, M), (None, 2., 'linear', None, False),
'linear_3d_with_scale_not_recompute_scale_factor', (True, 'aten::__interpolate')),
'linear_3d_with_scale_not_recompute_scale_factor'),
('interpolate', torch.randn(S, M, M), (4, None, 'linear', None, False),
'linear_3d_with_size_not_recompute_scale_factor', (True, 'aten::__interpolate')),
'linear_3d_with_size_not_recompute_scale_factor'),
('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'nearest', None, False),
'nearest_5d_with_scale_not_recompute_scale_factor', (True, 'aten::__interpolate')),
'nearest_5d_with_scale_not_recompute_scale_factor'),
('interpolate', torch.randn(S, M, M, M, M), (4, None, 'nearest', None, False),
'nearest_5d_with_size_not_recompute_scale_factor', (True, 'aten::__interpolate')),
'nearest_5d_with_size_not_recompute_scale_factor'),
('interpolate', torch.randn(S, M, M, M, M), (None, 2., 'trilinear', None, False),
'trilinear_5d_with_scale_not_recompute_scale_factor', (True, 'aten::__interpolate')),
'trilinear_5d_with_scale_not_recompute_scale_factor'),
('interpolate', torch.randn(S, M, M, M, M), (4, None, 'trilinear', None, False),
'trilinear_5d_with_size_not_recompute_scale_factor', (True, 'aten::__interpolate')),
'trilinear_5d_with_size_not_recompute_scale_factor'),
]


Expand Down
33 changes: 19 additions & 14 deletions torch/csrc/jit/passes/quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,20 +136,25 @@ bool isFunctionNode(Node* n,
// the quantization parameters for output given inputs
std::vector<size_t> getGeneralOpTensorInputIndexes(Node* n) {
std::vector<std::string> single_input_aten_funcs = {
"adaptive_avg_pool2d",
"max_pool2d",
"avg_pool2d",
"flatten",
"max",
"min",
"mean",
// TODO: sort returns a tuple of Tensors, we have
// to extend the API to support that
// "sort",
"__interpolate",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add the corresponding ops to this list? I think it might be better to just put this change in the same PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know the difference between "single_input_call_funcs" and "single_input_aten_funcs"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where should I be putting the ops?

"__upsample",
"__upsample_bilinear",
"__upsample_nearest",
"max_pool2d",
"avg_pool2d",
"flatten",
"max",
"min",
"mean",
"upsample_nearest1d",
"upsample_nearest2d",
"upsample_nearest3d",
"adaptive_avg_pool1d",
"adaptive_avg_pool2d",
"adaptive_avg_pool3d",
"upsample_linear1d",
"upsample_bilinear2d",
"upsample_trilinear3d",
"upsample_bicubic2d",
// TODO: sort returns a tuple of Tensors, we have
// to extend the API to support that
// "sort",
};
std::vector<std::string> single_input_call_funcs = {
"adaptive_avg_pool2d",
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/runtime/register_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3418,6 +3418,7 @@ int upsample_bilinear_op(Stack& stack) {
return 0;
}

// These ops are no longer generated, but remain here for BC
RegisterOperators reg3({
Operator(
"aten::__interpolate(Tensor input, int? size = None, float[]? scale_factor = None, str mode = 'nearest', bool? align_corners = None, bool? recompute_scale_factor = None) -> Tensor",
Expand Down
130 changes: 0 additions & 130 deletions torch/csrc/jit/runtime/symbolic_script.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1223,136 +1223,6 @@ const std::vector<std::string> functions = {
grad_self = torch.max_pool2d_with_indices_backward(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices)
return grad_self, None, None, None, None, None
return output, indices, backward

def AD_interpolate_scales_list(input,
scale_factor: Optional[List[float]],
recompute_scale_factor: bool):
input_dim = len(input.size())
if scale_factor is None or recompute_scale_factor:
scales = [-1.0 for i in range(input_dim-2)]
else:
scales = [scale_factor[i] for i in range(input_dim-2)]
return scales

def AD_interpolate_scales_float(input,
scale_factor: Optional[float],
recompute_scale_factor: bool):
input_dim = len(input.size())
if scale_factor is None or recompute_scale_factor:
scales = [-1.0 for i in range(input_dim-2)]
else:
scales = [scale_factor for i in range(input_dim-2)]
return scales

def AD_interpolate_backward(grad,
input,
mode: str,
align_corners: bool,
scale_factor: List[float]):
output_size = grad.size()[2:]
input_size = input.size()
input_dim = len(input_size)
if input_dim == 3 and mode == 'nearest':
grad_input = torch.upsample_nearest1d_backward(grad, output_size, input_size,
scale_factor[0])
elif input_dim == 4 and mode == 'nearest':
grad_input = torch.upsample_nearest2d_backward(grad, output_size, input_size,
scale_factor[0], scale_factor[1])
elif input_dim == 5 and mode == 'nearest':
grad_input = torch.upsample_nearest3d_backward(grad, output_size, input_size,
scale_factor[0], scale_factor[1], scale_factor[2])
elif input_dim == 3 and mode == 'linear':
grad_input = torch.upsample_linear1d_backward(grad, output_size, input_size, align_corners,
scale_factor[0])
elif input_dim == 4 and mode == 'bilinear':
grad_input = torch.upsample_bilinear2d_backward(grad, output_size, input_size, align_corners,
scale_factor[0], scale_factor[1])
elif input_dim == 5 and mode == 'trilinear':
grad_input = torch.upsample_trilinear3d_backward(grad, output_size, input_size, align_corners,
scale_factor[0], scale_factor[1], scale_factor[2])
elif input_dim == 4 and mode == 'bicubic':
grad_input = torch.upsample_bicubic2d_backward(grad, output_size, input_size, align_corners,
scale_factor[0], scale_factor[1])
elif input_dim == 3 and mode == 'area':
grad_input = AD_adaptive_avg_pool1d_backward(grad, input, output_size)
elif input_dim == 4 and mode == 'area':
grad_input = AD_adaptive_avg_pool2d_backward(grad, input, output_size)
elif input_dim == 5 and mode == 'area':
grad_input = torch.adaptive_avg_pool3d_backward(grad, input)
else:
# NEVER REACH HERE
grad_input = torch.zeros_like(input, memory_format=1)
raise RuntimeError('Input Error: Only 3D, 4D and 5D input Tensors supported')

return grad_input

def __interpolate_0(input,
size: Optional[int],
scale_factor: Optional[List[float]],
mode: str,
align_corners: Optional[bool],
recompute_scale_factor: Optional[bool]):
def backward(grad_output):
if align_corners is None:
align_corners = False
if recompute_scale_factor is None:
recompute_scale_factor = True
scales = AD_interpolate_scales_list(input, scale_factor, recompute_scale_factor)
grad_self = AD_interpolate_backward(grad_output, input, mode, align_corners, scales)
return grad_self, None, None, None, None, None

return torch.__interpolate(input, size, scale_factor, mode, align_corners), backward

def __interpolate_1(input,
size: Optional[List[int]],
scale_factor: Optional[List[float]],
mode: str,
align_corners: Optional[bool],
recompute_scale_factor: Optional[bool]):
def backward(grad_output):
if align_corners is None:
align_corners = False
if recompute_scale_factor is None:
recompute_scale_factor = True
scales = AD_interpolate_scales_list(input, scale_factor, recompute_scale_factor)
grad_self = AD_interpolate_backward(grad_output, input, mode, align_corners, scales)
return grad_self, None, None, None, None, None

return torch.__interpolate(input, size, scale_factor, mode, align_corners), backward

def __interpolate_2(input,
size: Optional[int],
scale_factor: Optional[float],
mode: str,
align_corners: Optional[bool],
recompute_scale_factor: Optional[bool]):
def backward(grad_output):
if align_corners is None:
align_corners = False
if recompute_scale_factor is None:
recompute_scale_factor = True
scales = AD_interpolate_scales_float(input, scale_factor, recompute_scale_factor)
grad_self = AD_interpolate_backward(grad_output, input, mode, align_corners, scales)
return grad_self, None, None, None, None, None

return torch.__interpolate(input, size, scale_factor, mode, align_corners), backward

def __interpolate_3(input,
size: Optional[List[int]],
scale_factor: Optional[float],
mode: str,
align_corners: Optional[bool],
recompute_scale_factor: Optional[bool]):
def backward(grad_output):
if align_corners is None:
align_corners = False
if recompute_scale_factor is None:
recompute_scale_factor = True
scales = AD_interpolate_scales_float(input, scale_factor, recompute_scale_factor)
grad_self = AD_interpolate_backward(grad_output, input, mode, align_corners, scales)
return grad_self, None, None, None, None, None

return torch.__interpolate(input, size, scale_factor, mode, align_corners), backward
)",
R"(
def AD_sizes_if_not_equal_multi_1(t1, t2, res):
Expand Down
Loading