Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion aten/src/ATen/native/UpSample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ TORCH_API c10::SmallVector<int64_t, 3> compute_output_size(
TORCH_CHECK(static_cast<int64_t>(scale_factors->size()) == spatial_dimensions);
c10::SmallVector<int64_t, 3> ret;
for (const auto i : c10::irange(spatial_dimensions)) {
ret.push_back(static_cast<double>(input_size[i+2]) * scale_factors.value()[i]);
// we perform round (i.e. int(0.5 + x)) to match opencv, scipy, scikit-image output size
ret.push_back(0.5 + static_cast<double>(input_size[i+2]) * scale_factors.value()[i]);
}
return ret;
}
Expand Down
61 changes: 44 additions & 17 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10575,6 +10575,31 @@ def test_channel_shuffle(self):
y = y.contiguous(memory_format=torch.contiguous_format)
self.assertEqual(y, y_ref)

def test_upsamplingNearest1d(self):
m = nn.Upsample(size=4, mode='nearest')
in_t = torch.ones(1, 1, 2)
in_uint8_t = torch.ones(1, 1, 2, dtype=torch.uint8)
with warnings.catch_warnings(record=True) as w:
out_t = m(in_t)
out_uint8_t = m(in_uint8_t)
self.assertEqual(torch.ones(1, 1, 4), out_t.data)
self.assertEqual(torch.ones(1, 1, 4, dtype=torch.uint8), out_uint8_t.data)

input = torch.randn(1, 1, 2, requires_grad=True)
gradcheck(lambda x: F.interpolate(x, 4, mode='nearest'), [input])

# Check https://github.com/pytorch/pytorch/issues/62396
test_scales = [0.1234, 0.9999, 1.8]
isize = 32
expected_out_sizes = [int(0.5 + s * isize) for s in test_scales]
t_in = torch.randint(0, 256, size=(1, 1, isize), dtype=torch.float)
for r in [True, False]:
for s, expected_osize in zip(test_scales, expected_out_sizes):
t_out = F.interpolate(
t_in, scale_factor=s, recompute_scale_factor=r, mode="nearest"
)
self.assertEqual(t_out.shape[-1], expected_osize)

def test_upsamplingLinear1d(self):
for align_corners in [True, False]:
for recompute_scale_factor in [True, False]:
Expand Down Expand Up @@ -10643,17 +10668,20 @@ def test_upsamplingBicubic2d(self):

def test_upsampling_not_recompute_scale_factor(self):
# test output against known input: result must match opencv
# opencv gives output of shape (5, 5, 2)
in_t = torch.arange(8.).view(1, 2, 2, 2)
expected_out_t = torch.tensor(
[[[[-0.32725, -0.08843, 0.37933, 0.79744],
[0.15039, 0.38921, 0.85697, 1.27508],
[1.08591, 1.32473, 1.79249, 2.21060],
[1.92213, 2.16095, 2.62871, 3.04682]],

[[3.67275, 3.91157, 4.37933, 4.79744],
[4.15039, 4.38921, 4.85697, 5.27508],
[5.08591, 5.32473, 5.79249, 6.21060],
[5.92213, 6.16095, 6.62871, 7.04682]]]])
[[[[-0.32725, -0.08843, 0.37933, 0.79744, 0.88296],
[0.15039, 0.38921, 0.85697, 1.27508, 1.3606],
[1.08591, 1.32473, 1.79249, 2.21060, 2.29613],
[1.92213, 2.16095, 2.62871, 3.04682, 3.13234],
[2.09318, 2.33200, 2.79976, 3.21787, 3.30340]],

[[3.67275, 3.91157, 4.37933, 4.79744, 4.88296],
[4.15039, 4.38921, 4.85697, 5.27508, 5.36060],
[5.08591, 5.32473, 5.79249, 6.21060, 6.29613],
[5.92213, 6.16095, 6.62871, 7.04682, 7.13234],
[6.09318, 6.33200, 6.79976, 7.21787, 7.30340]]]])
if IS_PPC:
# Both OpenCV and PyTorch give a slightly different result on PPC
expected_out_t = torch.tensor(
Expand All @@ -10668,7 +10696,10 @@ def test_upsampling_not_recompute_scale_factor(self):
[5.92212, 6.16094, 6.62870, 7.04680]]]])
out_t = F.interpolate(in_t, scale_factor=2.3, mode='bicubic', align_corners=False, recompute_scale_factor=False)
torch.set_printoptions(precision=5)
self.assertEqual(out_t, expected_out_t, atol=1e-4, rtol=0)
if IS_PPC:
self.assertEqual(out_t[..., :3, :3], expected_out_t[..., :3, :3], atol=1e-4, rtol=0)
else:
self.assertEqual(out_t, expected_out_t, atol=1e-4, rtol=0)

device_list = ['cpu']
if TEST_CUDA:
Expand All @@ -10681,7 +10712,7 @@ def test_upsampling_not_recompute_scale_factor(self):
for scale_factor in [0.6, 1.6, 2.3]:
in_t = torch.ones(2, 2, 2, 2).to(device)
out_t = F.interpolate(in_t, scale_factor=scale_factor, **kwargs)
out_size = int(math.floor(in_t.shape[-1] * scale_factor))
out_size = int(math.floor(0.5 + in_t.shape[-1] * scale_factor))
self.assertEqual(torch.ones(2, 2, out_size, out_size), out_t.data, atol=1e-5, rtol=0)

input = torch.randn(2, 2, 2, 2, requires_grad=True)
Expand Down Expand Up @@ -14806,19 +14837,15 @@ def test_upsamplingNearestExact1d_rescale(self, device):
# Checks https://github.com/pytorch/pytorch/issues/62237
isize = 20
in_t = torch.arange(isize, dtype=torch.float, device=device).unsqueeze(0).unsqueeze(0)
# for s in [1.00001, 0.99999]: # 0.9999 case is broken
# See issue: https://github.com/pytorch/pytorch/issues/62396
for s in [1.00001, ]:
for s in [1.00001, 0.99999]:
out_t = F.interpolate(
in_t, scale_factor=s, recompute_scale_factor=False, mode="nearest-exact"
)
expected_out = in_t
self.assertEqual(out_t, expected_out, msg=f"scale: {s}")

# checks data duplication if output_size == 2 * input_size
# for s in [2.00001, 1.99999]: # 1.99999 case is broken
# See issue: https://github.com/pytorch/pytorch/issues/62396
for s in [2.00001, ]:
for s in [2.00001, 1.99999]:
out_t = F.interpolate(
in_t, scale_factor=s, recompute_scale_factor=False, mode="nearest-exact"
)
Expand Down
6 changes: 4 additions & 2 deletions torch/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3814,13 +3814,15 @@ def interpolate(input: Tensor, size: Optional[int] = None, scale_factor: Optiona
# The C++ code will recompute it based on the (integer) output size.
if not torch.jit.is_scripting() and torch._C._get_tracing_state():
# make scale_factor a tensor in tracing so constant doesn't get baked in
# we perform round (i.e. floor(0.5 + x)) to match opencv, scipy, scikit-image output size
output_size = [
(torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i], dtype=torch.float32)).float()))
(torch.floor(0.5 + (input.size(i + 2).float() * torch.tensor(scale_factors[i], dtype=torch.float32)).float()))
for i in range(dim)
]
else:
assert scale_factors is not None
output_size = [int(math.floor(float(input.size(i + 2)) * scale_factors[i])) for i in range(dim)]
# we perform round (i.e. floor(0.5 + x)) to match opencv, scipy, scikit-image output size
output_size = [int(math.floor(0.5 + float(input.size(i + 2)) * scale_factors[i])) for i in range(dim)]
scale_factors = None

if input.dim() == 3 and mode == "nearest":
Expand Down