diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index 34a3da5e7ebb..54cbc12ae1e9 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -4311,6 +4311,47 @@ TEST_F(AtenXlaTensorTest, TestUpsampleBilinear2D) {
                        cpp_test::GetIgnoredCounters());
 }
 
+TEST_F(AtenXlaTensorTest, TestUpsampleBilinear2DWithScale) {
+  struct ImageInfo {
+    int batch_size;
+    int h;
+    int w;
+    int chans;
+    double scale_h;
+    double scale_w;
+  };
+
+  /* clang-format off */
+  std::vector<ImageInfo> inputs = {
+    {/*batch_size=*/2, /*h=*/5, /*w=*/5, /*chans=*/2, /*scale_h*/8.0/5, /*scale_w*/8.0/5},
+    {/*batch_size=*/2, /*h=*/1335, /*w=*/1335, /*chans=*/3, /*scale_h*/255.0/1335, /*scale_w*/255.0/1335},
+    {/*batch_size=*/2, /*h=*/255, /*w=*/255, /*chans=*/3, /*scale_h*/1335.0/255, /*scale_w*/1335.0/255},
+    {/*batch_size=*/2, /*h=*/254, /*w=*/243, /*chans=*/3, /*scale_h*/784.0/254, /*scale_w*/214.0/243}
+  };
+  /* clang-format on */
+
+  for (const auto& img_info : inputs) {
+    for (bool align_corners : {true, false}) {
+      torch::Tensor input = torch::rand(
+          {img_info.batch_size, img_info.chans, img_info.h, img_info.w},
+          torch::TensorOptions(torch::kFloat));
+      ForEachDevice([&](const torch::Device& device) {
+        torch::Tensor xla_input = CopyToDevice(input, device);
+        torch::Tensor result = torch::upsample_bilinear2d(
+            input, c10::nullopt, align_corners,
+            at::ArrayRef<double>{img_info.scale_h, img_info.scale_w});
+        torch::Tensor xla_result = torch::upsample_bilinear2d(
+            xla_input, c10::nullopt, align_corners,
+            at::ArrayRef<double>{img_info.scale_h, img_info.scale_w});
+        AllClose(result, xla_result, /*rtol=*/1e-4, /*atol=*/1e-4);
+      });
+    }
+  }
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::upsample_bilinear2d",
+                       cpp_test::GetIgnoredCounters());
+}
+
 TEST_F(AtenXlaTensorTest, TestUpsampleBilinear2DBackward) {
   int batch_size = 2;
   int h = 5;
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index d12955e501af..906d4e0fae85 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2904,16 +2904,21 @@ at::Tensor XLANativeFunctions::upsample_bilinear2d(
     c10::optional<double> scales_h, c10::optional<double> scales_w) {
   TORCH_LAZY_FN_COUNTER("xla::");
   XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
+  absl::Span<const int64_t> input_dims =
+      self_tensor->shape().get().dimensions();
+  std::vector<int64_t> scaled_output_size =
+      torch::lazy::ToVector<int64_t>(output_size);
   if ((scales_h && *scales_h != 1.0) || (scales_w && *scales_w != 1.0)) {
-    return at::native::call_fallback_fn<
-        &xla_cpu_fallback, ATEN_OP(upsample_bilinear2d)>::call(self,
-                                                               output_size,
-                                                               align_corners,
-                                                               scales_h,
-                                                               scales_w);
+    scaled_output_size = GetOutputSizeWithScale(input_dims, scales_h, scales_w,
+                                                scaled_output_size);
+    if (!output_size.empty()) {
+      XLA_CHECK(scaled_output_size.at(0) == output_size.at(0) &&
+                scaled_output_size.at(1) == output_size.at(1))
+          << "Inferred output size and output_size from upstream are different";
+    }
   }
   return bridge::AtenFromXlaTensor(tensor_methods::upsample_bilinear2d(
-      self_tensor, torch::lazy::ToVector<int64_t>(output_size), align_corners));
+      self_tensor, scaled_output_size, align_corners));
 }
 
 at::Tensor XLANativeFunctions::upsample_bilinear2d_backward(