@@ -4266,24 +4266,52 @@ TEST_F(AtenXlaTensorTest, TestUpsampleNearest2DWithScale) {
 }
 
 TEST_F(AtenXlaTensorTest, TestUpsampleNearest2DBackwardWithScale) {
-  int batch_size = 2;
-  int h = 5;
-  int w = 5;
-  int chans = 2;
-  double scale_h = 2.5;
-  double scale_w = 3.4;
-  auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
-    return torch::upsample_nearest2d(inputs[0], c10::nullopt,
-                                     at::ArrayRef<double>{scale_h, scale_w});
+  struct ImageInfo {
+    int batch_size;
+    int h;
+    int w;
+    int chans;
+    double scale_h;
+    double scale_w;
   };
-  ForEachDevice([&](const torch::Device& device) {
-    TestBackward(
-        {torch::rand({batch_size, chans, h, w},
-                     torch::TensorOptions(torch::kFloat).requires_grad(true))},
-        device, testfn);
-  });
-  ExpectCounterChanged("xla::upsample_nearest2d_backward",
-                       cpp_test::GetIgnoredCounters());
+
+  /* clang-format off */
+  std::vector<ImageInfo> inputs = {
+    {/*batch_size=*/2, /*h=*/5, /*w=*/5, /*chans=*/2, /*scale_h*/2.5, /*scale_w*/3.4},
+    {/*batch_size=*/2, /*h=*/1335, /*w=*/1335, /*chans=*/3, /*scale_h*/2.5, /*scale_w*/3.4},
+    {/*batch_size=*/2, /*h=*/1335, /*w=*/1335, /*chans=*/3, /*scale_h*/0.5, /*scale_w*/0.5},
+  };
+  /* clang-format on */
+
+  for (const auto& img_info : inputs) {
+    for (bool align_corners : {true, false}) {
+      auto testfn =
+          [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
+        return torch::upsample_nearest2d(
+            inputs[0], c10::nullopt,
+            at::ArrayRef<double>{img_info.scale_h, img_info.scale_w});
+      };
+      ForEachDevice([&](const torch::Device& device) {
+        TestBackward(
+            {torch::rand(
+                 {img_info.batch_size, img_info.chans, img_info.h, img_info.w},
+                 torch::TensorOptions(torch::kFloat).requires_grad(true))},
+            device, testfn);
+        XlaDeviceType device_type = static_cast<XlaDeviceType>(
+            bridge::AtenDeviceToXlaDevice(device).type());
+        if (device_type == XlaDeviceType::TPU) {
+          // Only lowered for TPU, fallback for CPU.
+          ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+          ExpectCounterChanged("xla::upsample_nearest2d_backward",
+                               cpp_test::GetIgnoredCounters());
+          ResetCounters();
+        } else {
+          ExpectCounterChanged("aten::.*", cpp_test::GetIgnoredCounters());
+          ResetCounters();
+        }
+      });
+    }
+  }
 }
 
 TEST_F(AtenXlaTensorTest, TestUpsampleBilinear2D) {
@@ -4388,6 +4416,54 @@ TEST_F(AtenXlaTensorTest, TestUpsampleBilinear2DBackward) {
   }
 }
 
+TEST_F(AtenXlaTensorTest, TestUpsampleBilinear2DBackwardWithScale) {
+  struct ImageInfo {
+    int batch_size;
+    int h;
+    int w;
+    int chans;
+    double scale_h;
+    double scale_w;
+  };
+
+  /* clang-format off */
+  std::vector<ImageInfo> inputs = {
+    {/*batch_size=*/2, /*h=*/5, /*w=*/5, /*chans=*/2, /*scale_h*/8.0/5, /*scale_w*/8.0/5},
+    {/*batch_size=*/2, /*h=*/1335, /*w=*/1335, /*chans=*/3, /*scale_h*/255.0/1335, /*scale_w*/255.0/1335},
+  };
+  /* clang-format on */
+
+  for (const auto& img_info : inputs) {
+    for (bool align_corners : {true, false}) {
+      auto testfn =
+          [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
+        return torch::upsample_bilinear2d(
+            inputs[0], c10::nullopt, align_corners,
+            at::ArrayRef<double>{img_info.scale_h, img_info.scale_w});
+      };
+      ForEachDevice([&](const torch::Device& device) {
+        TestBackward(
+            {torch::rand(
+                 {img_info.batch_size, img_info.chans, img_info.h, img_info.w},
+                 torch::TensorOptions(torch::kFloat).requires_grad(true))},
+            device, testfn);
+        XlaDeviceType device_type = static_cast<XlaDeviceType>(
+            bridge::AtenDeviceToXlaDevice(device).type());
+        if (device_type == XlaDeviceType::TPU) {
+          // Only lowered for TPU, fallback for CPU.
+          ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+          ExpectCounterChanged("xla::upsample_bilinear2d_backward",
+                               cpp_test::GetIgnoredCounters());
+          ResetCounters();
+        } else {
+          ExpectCounterChanged("aten::.*", cpp_test::GetIgnoredCounters());
+          ResetCounters();
+        }
+      });
+    }
+  }
+}
+
 TEST_F(AtenXlaTensorTest, TestAddCMul) {
   torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
   torch::Tensor b = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));