Fixing interpolate on uint8 unsqueezed 3D CL tensor (#100258)
Description:

- Fixed a memory-format handling bug:

When the input is a channels-last 4D tensor produced as follows:
```python
t = torch.ones(1, 3, 32, 32).contiguous(memory_format=torch.channels_last)
t = t[0]
t = t[None, ...]
```
upsampling produces output in channels-first memory format, but our AVX code did not take that into account.
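
The slicing keeps the channels-last strides but leaves a size-1 batch dimension, and PyTorch's contiguity checks skip size-1 dimensions, so the view still reports channels-last contiguity even though its strides no longer match a dense channels-last tensor. A quick check illustrating this (same construction as above; not part of the patch):
```python
import torch

t = torch.ones(1, 3, 32, 32).contiguous(memory_format=torch.channels_last)
t = t[0][None, ...]  # drop and re-add the batch dimension

# Size-1 dimensions are ignored by the contiguity checks, so the
# unsqueezed view still reports channels-last contiguity.
print(t.is_contiguous(memory_format=torch.channels_last))  # True
print(t.is_contiguous())                                   # False
print(t.stride())  # e.g. (3, 1, 96, 3) instead of the dense (3072, 1, 96, 3)
```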

Here is a repro showing that the nightly build is broken for this particular case:
```python
import torch

torch.manual_seed(0)

input = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8).contiguous(memory_format=torch.channels_last)
input = input[0]
input = input[None, ...]

assert input.is_contiguous(memory_format=torch.channels_last)

output = torch.nn.functional.interpolate(input, (224, 224), mode="bilinear", antialias=True)
expected = torch.nn.functional.interpolate(input.float(), (224, 224), mode="bilinear", antialias=True)

assert output.is_contiguous()
assert expected.is_contiguous()

torch.testing.assert_close(expected, output.float(), atol=1, rtol=1)
# >
# Traceback (most recent call last):
#   File "<stdin>", line 1, in <module>
#   File "/pytorch/torch/testing/_comparison.py", line 1511, in assert_close
#     raise error_metas[0].to_error(msg)
# AssertionError: Tensor-likes are not close!
#
# Mismatched elements: 14120 / 150528 (9.4%)
# Greatest absolute difference: 214.6112518310547 at index (0, 1, 152, 13) (up to 1 allowed)
# Greatest relative difference: 17.005144119262695 at index (0, 2, 26, 2) (up to 1 allowed)
```

- Also renamed `needs_unpacking` to `skip_unpacking` (and added a separate `skip_packing` flag computed from the output's memory format).

Pull Request resolved: #100258
Approved by: https://github.com/NicolasHug
vfdev-5 authored and pytorchmergebot committed May 4, 2023
1 parent 9b3552e commit ff974cd
Showing 2 changed files with 40 additions and 17 deletions.
aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h — 35 changes: 19 additions & 16 deletions

```diff
@@ -89,22 +89,24 @@ void pack_rgb(
     const at::Tensor& unpacked_tensor, // IN
     const at::Tensor& packed_tensor // OUT
 ) {
-  // Convert from unpacked channels last 4-channels tensor into original data layout.
+  // Convert from unpacked channels last 3-channels or 4-channels tensor into original data layout.
 
-  constexpr int rgba_size = 4;
   uint8_t* unpacked = (uint8_t*)unpacked_tensor.data_ptr<uint8_t>();
   uint8_t* packed = (uint8_t*)packed_tensor.data_ptr<uint8_t>();
   auto num_pixels = packed_tensor.size(1) * packed_tensor.size(2);
   auto num_channels = packed_tensor.size(0);
 
+  auto unpacked_increment = unpacked_tensor.size(0);
   auto packed_increment = packed_tensor.stride(2);
   auto packed_stride = packed_tensor.stride(0);
 
+  TORCH_INTERNAL_ASSERT(unpacked_increment == 3 || unpacked_increment == 4);
+
   for (const auto i C10_UNUSED : c10::irange(num_pixels)) {
     for (const auto j : c10::irange(num_channels)) {
      packed[j * packed_stride] = unpacked[j];
     }
-    unpacked += rgba_size;
+    unpacked += unpacked_increment;
     packed += packed_increment;
   }
 }
```
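
For intuition, the loop above walks the image pixel by pixel, copying each channel of the dense channels-last buffer back through the output tensor's own strides. A rough Python model of the same operation (function name and shapes are illustrative, not part of the patch):
```python
import torch

def pack_rgb_model(unpacked: torch.Tensor, packed: torch.Tensor) -> None:
    # unpacked: (H*W, K) uint8 buffer with K in {3, 4}; one pixel per row,
    #           channels stored adjacently (channels-last).
    # packed:   (C, H, W) uint8 view of the output in its original layout.
    num_channels = packed.size(0)
    # Drop the unused alpha slot when the buffer is RGBA but the output is RGB,
    # then scatter channels back through the output's strides via copy_().
    pixels = unpacked[:, :num_channels]
    packed.copy_(pixels.t().reshape(packed.shape))
```

The resampling driver in the same file changes as follows: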
```diff
@@ -323,11 +325,12 @@ void upsample_avx_bilinear_uint8(
   std::vector<at::Tensor> horiz_indices_weights, vert_indices_weights;
   unsigned int horiz_weights_precision, vert_weights_precision;
 
-  bool needs_unpacking = (num_channels == 3 || num_channels == 4) && input.is_contiguous(at::MemoryFormat::ChannelsLast);
+  bool skip_unpacking = (num_channels == 3 || num_channels == 4) && input.is_contiguous(at::MemoryFormat::ChannelsLast);
+  bool skip_packing = (num_channels == 3 || num_channels == 4) && output.is_contiguous(at::MemoryFormat::ChannelsLast);
 
   if (need_horizontal) {
     int interp_dim = 3;
-    auto stride = (needs_unpacking) ? num_channels : 4;
+    auto stride = (skip_unpacking) ? num_channels : 4;
     std::tie(horiz_indices_weights, ksize_horiz, horiz_weights_precision) =
       F::compute_indices_int16_weights_aa(
         /*input_size=*/xin,
@@ -343,7 +346,7 @@
 
   if (need_vertical) {
     int interp_dim = 2;
-    auto stride = (needs_unpacking) ? num_channels * xout : 4 * xout;
+    auto stride = (skip_unpacking) ? num_channels * xout : 4 * xout;
     std::tie(vert_indices_weights, ksize_vert, vert_weights_precision) =
       F::compute_indices_int16_weights_aa(
         /*input_size=*/yin,
@@ -360,25 +363,25 @@
   at::Tensor buffer_horiz, buffer_vert;
   // Minor optimization: we can avoid allocating an extra buffer if we're performing
   // horizontal-only or vertical-only interpolation, and if the tensor doesn't
-  // need unpacking
-  if (need_horizontal && !(needs_unpacking && !need_vertical)) {
-    auto c = (needs_unpacking) ? num_channels : 4;
+  // need repacking
+  if (need_horizontal && (need_vertical || !skip_packing)) {
+    auto c = (skip_unpacking) ? num_channels : 4;
     buffer_horiz = at::empty({c, yin, xout}, input.options());
   }
-  if (need_vertical && !needs_unpacking) {
-    auto c = (needs_unpacking) ? num_channels : 4;
+  if (need_vertical && !skip_packing) {
+    auto c = (skip_unpacking) ? num_channels : 4;
     buffer_vert = at::empty({c, yout, xout}, input.options());
   }
 
   for (const auto i : c10::irange(batch_size)) {
 
-    at::Tensor unpacked_input = (needs_unpacking) ? input[i] : unpack_rgb(input[i]);
+    at::Tensor unpacked_input = (skip_unpacking) ? input[i] : unpack_rgb(input[i]);
     at::Tensor unpacked_output;
 
     if (need_horizontal) {
-      at::Tensor unpacked_output_temp = (needs_unpacking && !need_vertical) ? output[i] : buffer_horiz;
+      at::Tensor unpacked_output_temp = (need_vertical || !skip_packing) ? buffer_horiz : output[i];
 
-      if (needs_unpacking && num_channels == 3) {
+      if (skip_unpacking && num_channels == 3) {
         ImagingResampleHorizontal<3>(
           unpacked_output_temp,
           unpacked_input,
@@ -396,7 +399,7 @@
       unpacked_output = unpacked_input = unpacked_output_temp;
     }
     if (need_vertical) {
-      unpacked_output = (needs_unpacking) ? output[i] : buffer_vert;
+      unpacked_output = (skip_packing) ? output[i] : buffer_vert;
 
       ImagingResampleVertical(
         unpacked_output,
@@ -409,7 +412,7 @@
 
     TORCH_INTERNAL_ASSERT(unpacked_output.defined());
 
-    if (!needs_unpacking) {
+    if (!skip_packing) {
       pack_rgb(unpacked_output, output[i]);
     }
   }
```
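
The key behavioral change is that unpacking and packing are now decided independently: the input's layout controls `skip_unpacking`, and the output's layout controls `skip_packing`. A minimal sketch of the two conditions (a Python-level restatement, not the actual kernel code):
```python
import torch

def compute_layout_flags(input: torch.Tensor, output: torch.Tensor):
    # skip_unpacking: the input is already a dense channels-last RGB/RGBA
    # tensor, so it can feed the AVX kernels directly without a copy.
    # skip_packing: the output is channels-last too, so results can be
    # written in place without the pack_rgb pass.
    rgb_like = input.size(1) in (3, 4)
    skip_unpacking = rgb_like and input.is_contiguous(memory_format=torch.channels_last)
    skip_packing = rgb_like and output.is_contiguous(memory_format=torch.channels_last)
    return skip_unpacking, skip_packing
```

For the unsqueezed tensor from the repro, `skip_unpacking` is True while `skip_packing` is False (the output comes back channels-first), which is exactly the combination the single `needs_unpacking` flag could not represent.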
test/test_nn.py — 22 changes: 21 additions & 1 deletion

```diff
@@ -9673,14 +9673,22 @@ def test_upsamplingBilinear2d_aa_correctness(self, device, memory_format):
     @parametrize_test("align_corners", [True, False])
     @parametrize_test("num_channels", [3, 5])
     @parametrize_test("output_size", [32, 600])
-    def test_upsamplingBiLinear2d_consistency(self, device, memory_format, antialias, align_corners, num_channels, output_size):
+    @parametrize_test("check_as_unsqueezed_3d_tensor", [True, False])
+    def test_upsamplingBiLinear2d_consistency(
+        self, device, memory_format, antialias, align_corners, num_channels, output_size, check_as_unsqueezed_3d_tensor
+    ):
         if torch.device(device).type == "cuda":
             raise SkipTest("CUDA implementation is not yet supporting uint8")
 
         mode = "bilinear"
         # Check if Max Abs Error between resized input_uint8 and resized input_float is smaller than a tolerated value, e.g. 1.0
         input_ui8 = torch.randint(0, 256, size=(1, num_channels, 400, 400), dtype=torch.uint8, device=device)
         input_ui8 = input_ui8.contiguous(memory_format=memory_format)
 
+        if check_as_unsqueezed_3d_tensor:
+            input_ui8 = input_ui8[0, ...]
+            input_ui8 = input_ui8[None, ...]
+
         input_f32 = input_ui8.float()
 
         output_f32 = F.interpolate(
@@ -9690,6 +9698,18 @@ def test_upsamplingBiLinear2d_consistency(
             input_ui8, size=(output_size, output_size), mode=mode, align_corners=align_corners, antialias=antialias
         )
 
+        # FIXME if-clause shows the current behaviour which is definitely unexpected.
+        # Ideally we want to fix it such that both the ui8 and f32 outputs are also channels_last
+        # See for more details: https://github.com/pytorch/pytorch/pull/100373
+        if check_as_unsqueezed_3d_tensor and memory_format == torch.channels_last:
+            self.assertTrue(input_ui8.is_contiguous(memory_format=torch.channels_last))
+            self.assertTrue(output_ui8.is_contiguous())
+            self.assertTrue(output_f32.is_contiguous())
+        else:
+            self.assertTrue(input_ui8.is_contiguous(memory_format=memory_format))
+            self.assertTrue(output_ui8.is_contiguous(memory_format=memory_format))
+            self.assertTrue(output_f32.is_contiguous(memory_format=memory_format))
+
         mae_tol = 0.5
         max_abs_err_tol = 1.0
         num_wrong_pixels_tol = 5
```
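
As the FIXME notes, for the unsqueezed channels-last input the outputs still come back channels-first; with this fix the values are correct even though the layout remains unexpected. A small check of that behavior (assuming a build that includes this patch):
```python
import torch
import torch.nn.functional as F

x = torch.randint(0, 256, size=(1, 3, 400, 400), dtype=torch.uint8)
x = x.contiguous(memory_format=torch.channels_last)[0][None, ...]

y = F.interpolate(x, size=(32, 32), mode="bilinear", antialias=True)

# Values now match the float path, but the output layout is channels-first
# rather than channels-last (tracked in pytorch/pytorch#100373).
assert x.is_contiguous(memory_format=torch.channels_last)
assert y.is_contiguous()
assert not y.is_contiguous(memory_format=torch.channels_last)
```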