
Commit f594c29

Lower UpsampleBilinear/Nearest2DBackward with scale factor on TPU (#4710)
1 parent d86b323 commit f594c29
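
In short: on TPU, the backward passes of upsample_nearest2d and upsample_bilinear2d no longer fall back to the ATen CPU implementation when the caller supplies scale factors instead of an explicit output size. The sketch below is only a rough illustration of the call shape this commit keeps in the XLA lowering; it reuses the helpers that appear in the test diff (ForEachDevice, TestBackward), and the shapes and scale factors are example values taken from the new tests, not a prescribed configuration.

// Sketch only: exercising the newly lowered path, modeled on the tests in
// this commit. The test fixture, includes, and helpers (ForEachDevice,
// TestBackward) are assumed from test_aten_xla_tensor.cpp.
auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
  // Passing c10::nullopt for output_size and only scale factors previously
  // forced a CPU fallback for the backward op, even on TPU.
  return torch::upsample_bilinear2d(inputs[0], c10::nullopt,
                                    /*align_corners=*/false,
                                    at::ArrayRef<double>{2.5, 3.4});
};
ForEachDevice([&](const torch::Device& device) {
  TestBackward(
      {torch::rand({2, 2, 5, 5},
                   torch::TensorOptions(torch::kFloat).requires_grad(true))},
      device, testfn);
});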

2 files changed (+124, -23 lines)

test/cpp/test_aten_xla_tensor.cpp

Lines changed: 93 additions & 17 deletions
@@ -4266,24 +4266,52 @@ TEST_F(AtenXlaTensorTest, TestUpsampleNearest2DWithScale) {
 }
 
 TEST_F(AtenXlaTensorTest, TestUpsampleNearest2DBackwardWithScale) {
-  int batch_size = 2;
-  int h = 5;
-  int w = 5;
-  int chans = 2;
-  double scale_h = 2.5;
-  double scale_w = 3.4;
-  auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
-    return torch::upsample_nearest2d(inputs[0], c10::nullopt,
-                                     at::ArrayRef<double>{scale_h, scale_w});
+  struct ImageInfo {
+    int batch_size;
+    int h;
+    int w;
+    int chans;
+    double scale_h;
+    double scale_w;
   };
-  ForEachDevice([&](const torch::Device& device) {
-    TestBackward(
-        {torch::rand({batch_size, chans, h, w},
-                     torch::TensorOptions(torch::kFloat).requires_grad(true))},
-        device, testfn);
-  });
-  ExpectCounterChanged("xla::upsample_nearest2d_backward",
-                       cpp_test::GetIgnoredCounters());
+
+  /* clang-format off */
+  std::vector<ImageInfo> inputs = {
+    {/*batch_size=*/2, /*h=*/5, /*w=*/5, /*chans=*/2, /*scale_h*/2.5, /*scale_w*/3.4},
+    {/*batch_size=*/2, /*h=*/1335, /*w=*/1335, /*chans=*/3, /*scale_h*/2.5, /*scale_w*/3.4},
+    {/*batch_size=*/2, /*h=*/1335, /*w=*/1335, /*chans=*/3, /*scale_h*/0.5, /*scale_w*/0.5},
+  };
+  /* clang-format on */
+
+  for (const auto& img_info : inputs) {
+    for (bool align_corners : {true, false}) {
+      auto testfn =
+          [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
+        return torch::upsample_nearest2d(
+            inputs[0], c10::nullopt,
+            at::ArrayRef<double>{img_info.scale_h, img_info.scale_w});
+      };
+      ForEachDevice([&](const torch::Device& device) {
+        TestBackward(
+            {torch::rand(
+                 {img_info.batch_size, img_info.chans, img_info.h, img_info.w},
+                 torch::TensorOptions(torch::kFloat).requires_grad(true))},
+            device, testfn);
+        XlaDeviceType device_type = static_cast<XlaDeviceType>(
+            bridge::AtenDeviceToXlaDevice(device).type());
+        if (device_type == XlaDeviceType::TPU) {
+          // Only lowered for TPU, fallback for CPU.
+          ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+          ExpectCounterChanged("xla::upsample_nearest2d_backward",
+                               cpp_test::GetIgnoredCounters());
+          ResetCounters();
+        } else {
+          ExpectCounterChanged("aten::.*", cpp_test::GetIgnoredCounters());
+          ResetCounters();
+        }
+      });
+    }
+  }
 }
 
 TEST_F(AtenXlaTensorTest, TestUpsampleBilinear2D) {
@@ -4388,6 +4416,54 @@ TEST_F(AtenXlaTensorTest, TestUpsampleBilinear2DBackward) {
   }
 }
 
+TEST_F(AtenXlaTensorTest, TestUpsampleBilinear2DBackwardWithScale) {
+  struct ImageInfo {
+    int batch_size;
+    int h;
+    int w;
+    int chans;
+    double scale_h;
+    double scale_w;
+  };
+
+  /* clang-format off */
+  std::vector<ImageInfo> inputs = {
+    {/*batch_size=*/2, /*h=*/5, /*w=*/5, /*chans=*/2, /*scale_h*/8.0/5, /*scale_w*/8.0/5},
+    {/*batch_size=*/2, /*h=*/1335, /*w=*/1335, /*chans=*/3, /*scale_h*/255.0/1335, /*scale_w*/255.0/1335},
+  };
+  /* clang-format on */
+
+  for (const auto& img_info : inputs) {
+    for (bool align_corners : {true, false}) {
+      auto testfn =
+          [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
+        return torch::upsample_bilinear2d(
+            inputs[0], c10::nullopt, align_corners,
+            at::ArrayRef<double>{img_info.scale_h, img_info.scale_w});
+      };
+      ForEachDevice([&](const torch::Device& device) {
+        TestBackward(
+            {torch::rand(
+                 {img_info.batch_size, img_info.chans, img_info.h, img_info.w},
+                 torch::TensorOptions(torch::kFloat).requires_grad(true))},
+            device, testfn);
+        XlaDeviceType device_type = static_cast<XlaDeviceType>(
+            bridge::AtenDeviceToXlaDevice(device).type());
+        if (device_type == XlaDeviceType::TPU) {
+          // Only lowered for TPU, fallback for CPU.
+          ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+          ExpectCounterChanged("xla::upsample_bilinear2d_backward",
+                               cpp_test::GetIgnoredCounters());
+          ResetCounters();
+        } else {
+          ExpectCounterChanged("aten::.*", cpp_test::GetIgnoredCounters());
+          ResetCounters();
+        }
+      });
+    }
+  }
+}
+
 TEST_F(AtenXlaTensorTest, TestAddCMul) {
   torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
   torch::Tensor b = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 31 additions & 6 deletions
@@ -2951,16 +2951,26 @@ at::Tensor XLANativeFunctions::upsample_bilinear2d_backward(
   // our XLA lowering.
   XlaDeviceType hw_type =
       static_cast<XlaDeviceType>(grad_output_tensor->GetDevice().type());
-  if (hw_type != XlaDeviceType::TPU || (scales_h && *scales_h != 1.0) ||
-      (scales_w && *scales_w != 1.0)) {
+  if (hw_type != XlaDeviceType::TPU) {
     return at::native::call_fallback_fn<
         &xla_cpu_fallback,
         ATEN_OP(upsample_bilinear2d_backward)>::call(grad_output, output_size,
                                                      input_size, align_corners,
                                                      scales_h, scales_w);
   }
+  std::vector<int64_t> scaled_output_size =
+      torch::lazy::ToVector<int64_t>(output_size);
+  if ((scales_h && *scales_h != 1.0) || (scales_w && *scales_w != 1.0)) {
+    scaled_output_size = GetOutputSizeWithScale(input_size, scales_h, scales_w,
+                                                scaled_output_size);
+    if (!output_size.empty()) {
+      XLA_CHECK(scaled_output_size.at(0) == output_size.at(0) &&
+                scaled_output_size.at(1) == output_size.at(1))
+          << "Inferred output size and output_size from upstream are different";
+    }
+  }
   return bridge::AtenFromXlaTensor(tensor_methods::upsample_bilinear2d_backward(
-      grad_output_tensor, torch::lazy::ToVector<int64_t>(output_size),
+      grad_output_tensor, torch::lazy::ToVector<int64_t>(scaled_output_size),
       torch::lazy::ToVector<int64_t>(input_size), align_corners));
 }

@@ -2976,6 +2986,11 @@ at::Tensor XLANativeFunctions::upsample_nearest2d(
   if ((scales_h && *scales_h != 1.0) || (scales_w && *scales_w != 1.0)) {
     scaled_output_size = GetOutputSizeWithScale(input_dims, scales_h, scales_w,
                                                 scaled_output_size);
+    if (!output_size.empty()) {
+      XLA_CHECK(scaled_output_size.at(0) == output_size.at(0) &&
+                scaled_output_size.at(1) == output_size.at(1))
+          << "Inferred output size and output_size from upstream are different";
+    }
   }
   return bridge::AtenFromXlaTensor(
       tensor_methods::upsample_nearest2d(self_tensor, scaled_output_size));
@@ -2991,16 +3006,26 @@ at::Tensor XLANativeFunctions::upsample_nearest2d_backward(
   // our XLA lowering.
   XlaDeviceType hw_type =
       static_cast<XlaDeviceType>(grad_output_tensor->GetDevice().type());
-  if (hw_type != XlaDeviceType::TPU || (scales_h && *scales_h != 1.0) ||
-      (scales_w && *scales_w != 1.0)) {
+  if (hw_type != XlaDeviceType::TPU) {
     return at::native::call_fallback_fn<
         &xla_cpu_fallback,
         ATEN_OP(upsample_nearest2d_backward)>::call(grad_output, output_size,
                                                     input_size, scales_h,
                                                     scales_w);
   }
+  std::vector<int64_t> scaled_output_size =
+      torch::lazy::ToVector<int64_t>(output_size);
+  if ((scales_h && *scales_h != 1.0) || (scales_w && *scales_w != 1.0)) {
+    scaled_output_size = GetOutputSizeWithScale(input_size, scales_h, scales_w,
+                                                scaled_output_size);
+    if (!output_size.empty()) {
+      XLA_CHECK(scaled_output_size.at(0) == output_size.at(0) &&
+                scaled_output_size.at(1) == output_size.at(1))
+          << "Inferred output size and output_size from upstream are different";
+    }
+  }
   return bridge::AtenFromXlaTensor(tensor_methods::upsample_nearest2d_backward(
-      grad_output_tensor, torch::lazy::ToVector<int64_t>(output_size),
+      grad_output_tensor, torch::lazy::ToVector<int64_t>(scaled_output_size),
       torch::lazy::ToVector<int64_t>(input_size)));
 }

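Across the hunks above, the pattern is the same: when only scale factors are supplied, the output size is now inferred via GetOutputSizeWithScale and cross-checked against any output_size passed down from upstream, instead of bailing out to the CPU fallback. GetOutputSizeWithScale itself is not part of this diff, so the helper below is only an assumed sketch of the size inference that the new XLA_CHECK relies on (PyTorch conventionally floors input_size * scale when given a scale factor); it is not the actual implementation, and its name and signature are made up for illustration.

// Assumed sketch (not the real GetOutputSizeWithScale): derive the spatial
// output size from an NCHW input size and optional per-axis scales by
// flooring input * scale, PyTorch's usual convention. Needs <cmath>.
std::vector<int64_t> InferScaledOutputSize(at::IntArrayRef input_size,
                                           c10::optional<double> scales_h,
                                           c10::optional<double> scales_w) {
  int64_t h = input_size[2];  // input_size is {N, C, H, W}
  int64_t w = input_size[3];
  int64_t out_h =
      scales_h ? static_cast<int64_t>(std::floor(h * *scales_h)) : h;
  int64_t out_w =
      scales_w ? static_cast<int64_t>(std::floor(w * *scales_w)) : w;
  // E.g. h = 5 with scale_h = 2.5 (as in the new tests) gives out_h = 12.
  return {out_h, out_w};
}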