Skip to content

Commit

Permalink
Addressed PR review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed May 3, 2023
1 parent 5e1bf10 commit 39d59f6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
6 changes: 3 additions & 3 deletions aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,8 @@ void upsample_avx_bilinear_uint8(
at::Tensor buffer_horiz, buffer_vert;
// Minor optimization: we can avoid allocating an extra buffer if we're performing
// horizontal-only or vertical-only interpolation, and if the tensor doesn't
- // need unpacking
- if (need_horizontal && !(skip_packing && !need_vertical)) {
+ // need repacking
+ if (need_horizontal && (need_vertical || !skip_packing)) {
auto c = (skip_unpacking) ? num_channels : 4;
buffer_horiz = at::empty({c, yin, xout}, input.options());
}
Expand All @@ -379,7 +379,7 @@ void upsample_avx_bilinear_uint8(
at::Tensor unpacked_output;

if (need_horizontal) {
- at::Tensor unpacked_output_temp = (skip_packing && !need_vertical) ? output[i] : buffer_horiz;
+ at::Tensor unpacked_output_temp = (need_vertical || !skip_packing) ? buffer_horiz : output[i];

if (skip_unpacking && num_channels == 3) {
ImagingResampleHorizontal<3>(
Expand Down
4 changes: 3 additions & 1 deletion test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9698,7 +9698,9 @@ def test_upsamplingBiLinear2d_consistency(
input_ui8, size=(output_size, output_size), mode=mode, align_corners=align_corners, antialias=antialias
)

- # Check if output is CF for unsqueezed 3d CL tensor
+ # FIXME if-clause shows the current behaviour which is definitely unexpected.
+ # Ideally we want to fix it such that both the ui8 and f32 outputs are also channels_last
+ # See for more details: https://github.com/pytorch/pytorch/pull/100373
if check_as_unsqueezed_3d_tensor and memory_format == torch.channels_last:
self.assertTrue(input_ui8.is_contiguous(memory_format=torch.channels_last))
self.assertTrue(output_ui8.is_contiguous())
Expand Down

0 comments on commit 39d59f6

Please sign in to comment.