
Commit 407d259

wonjoo-wj authored and alanwaketan committed
[Functionalization] Lower masked_fill.Tensor and masked_fill.Scalar ops (#4616)
* Lower masked_fill.Scalar and masked_fill.Tensor to fix related cpp tests
* Remove in-place versions for masked_fill
* Clean up some code
* Update tensor_methods::masked_fill to expand the input tensor if needed
* Add a check to expand only if the rank of the input tensor is less than that of the mask tensor
* Update the tensor-rank comparison condition
* Enable the KlDivBackward cpp test
1 parent 8f026c1 commit 407d259
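
For context, masked_fill replaces the elements of the input where the boolean mask is true with a fill value; the Tensor overload only accepts a 0-dimensional value tensor, which the lowering below checks before forwarding to the Scalar overload via value.item(). A minimal eager-mode sketch of the two overloads being lowered, using only libtorch calls that also appear in the tests (variable names are illustrative):

#include <torch/torch.h>

int main() {
  torch::Tensor input =
      torch::rand({2, 3}, torch::TensorOptions(torch::kFloat));
  torch::Tensor mask =
      torch::randint(0, 2, {2, 3}, torch::TensorOptions(torch::kBool));

  // masked_fill.Scalar: fill the positions where mask is true with 42.
  torch::Tensor out_scalar = torch::masked_fill(input, mask, 42);

  // masked_fill.Tensor: the fill value must be a 0-dimensional tensor.
  torch::Tensor value = torch::scalar_tensor(42.0f);
  torch::Tensor out_tensor = torch::masked_fill(input, mask, value);
  return 0;
}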

File tree

5 files changed, +51 -46 lines changed

test/cpp/test_aten_xla_tensor.cpp

Lines changed: 22 additions & 8 deletions
@@ -8827,7 +8827,6 @@ TEST_F(AtenXlaTensorTest, TestUnsqueezeInPlace) {
 }
 
 TEST_F(AtenXlaTensorTest, TestMaskedFill) {
-  GTEST_SKIP() << "SegFault after functionalization";
   torch::Tensor input =
       torch::rand({2, 3}, torch::TensorOptions(torch::kFloat));
   torch::Tensor mask =
@@ -8842,11 +8841,10 @@ TEST_F(AtenXlaTensorTest, TestMaskedFill) {
   });
 
   ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::masked_fill_", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::masked_fill", cpp_test::GetIgnoredCounters());
 }
 
 TEST_F(AtenXlaTensorTest, TestMaskedFillInPlace) {
-  GTEST_SKIP() << "SegFault after functionalization";
   torch::Scalar value(42);
   torch::Tensor mask =
       torch::randint(0, 2, {2, 3}, torch::TensorOptions(torch::kBool));
@@ -8862,11 +8860,10 @@ TEST_F(AtenXlaTensorTest, TestMaskedFillInPlace) {
   });
 
   ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::masked_fill_", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::masked_fill", cpp_test::GetIgnoredCounters());
 }
 
-TEST_F(AtenXlaTensorTest, TestMaskedFillBroadcast) {
-  GTEST_SKIP() << "SegFault after functionalization";
+TEST_F(AtenXlaTensorTest, TestMaskedFillBroadcast1) {
   torch::Tensor input =
       torch::rand({2, 5, 4, 3}, torch::TensorOptions(torch::kFloat));
   torch::Tensor mask =
@@ -8881,7 +8878,25 @@ TEST_F(AtenXlaTensorTest, TestMaskedFillBroadcast) {
   });
 
   ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::masked_fill_", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::masked_fill", cpp_test::GetIgnoredCounters());
+}
+
+TEST_F(AtenXlaTensorTest, TestMaskedFillBroadcast2) {
+  torch::Tensor input =
+      torch::rand({2, 1}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor mask =
+      torch::randint(0, 2, {2, 3}, torch::TensorOptions(torch::kBool));
+  torch::Scalar value(42);
+  torch::Tensor result = torch::masked_fill(input, mask, value);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_input = CopyToDevice(input, device);
+    torch::Tensor xla_mask = CopyToDevice(mask, device);
+    torch::Tensor xla_result = torch::masked_fill(xla_input, xla_mask, value);
+    AllClose(result, xla_result);
+  });
+
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::masked_fill", cpp_test::GetIgnoredCounters());
 }
 
 TEST_F(AtenXlaTensorTest, TestFill) {
@@ -11301,7 +11316,6 @@ TEST_F(AtenXlaTensorTest, TestBCEWithLogitsBackward) {
 }
 
 TEST_F(AtenXlaTensorTest, TestKlDivBackward) {
-  GTEST_SKIP() << "SegFault after functionalization";
   torch::Tensor input = torch::rand(
       {4, 3}, torch::TensorOptions(torch::kFloat).requires_grad(true));
   torch::Tensor target = torch::rand(

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 13 additions & 27 deletions
@@ -1643,23 +1643,23 @@ at::Tensor XLANativeFunctions::xlogy(const at::Tensor& self,
       bridge::GetXlaTensor(self), bridge::GetXlaTensor(other)));
 }
 
-at::Tensor& XLANativeFunctions::masked_fill_(at::Tensor& self,
-                                             const at::Tensor& mask,
-                                             const at::Scalar& value) {
-  TORCH_LAZY_FN_COUNTER("xla::");
-  XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
-  tensor_methods::masked_fill_(self_tensor, bridge::GetXlaTensor(mask), value);
-  return self;
-}
-
-at::Tensor& XLANativeFunctions::masked_fill_(at::Tensor& self,
-                                             const at::Tensor& mask,
-                                             const at::Tensor& value) {
+at::Tensor XLANativeFunctions::masked_fill(const at::Tensor& self,
+                                           const at::Tensor& mask,
+                                           const at::Tensor& value) {
   TORCH_LAZY_FN_COUNTER("xla::");
   XLA_CHECK_EQ(value.dim(), 0) << "masked_fill_ only supports a 0-dimensional "
                                << "value tensor, but got tensor "
                                << "with " << value.dim() << " dimension(s).";
-  return masked_fill_(self, mask, value.item());
+  return masked_fill(self, mask, value.item());
+}
+
+at::Tensor XLANativeFunctions::masked_fill(const at::Tensor& self,
+                                           const at::Tensor& mask,
+                                           const at::Scalar& value) {
+  TORCH_LAZY_FN_COUNTER("xla::");
+  XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
+  return bridge::AtenFromXlaTensor(tensor_methods::masked_fill(
+      self_tensor, bridge::GetXlaTensor(mask), value));
 }
 
 at::Tensor XLANativeFunctions::masked_scatter(const at::Tensor& self,
@@ -3392,20 +3392,6 @@ at::Tensor XLANativeFunctions::linalg_pinv(
       linalg_pinv, atol_rtol_tensor)>::call(self, atol, rtol, hermitian);
 }
 
-at::Tensor XLANativeFunctions::masked_fill(const at::Tensor& self,
-                                           const at::Tensor& mask,
-                                           const at::Tensor& value) {
-  return at::functionalization::functionalize_aten_op<ATEN_OP2(
-      masked_fill, Tensor)>::call(self, mask, value);
-}
-
-at::Tensor XLANativeFunctions::masked_fill(const at::Tensor& self,
-                                           const at::Tensor& mask,
-                                           const at::Scalar& value) {
-  return at::functionalization::functionalize_aten_op<ATEN_OP2(
-      masked_fill, Scalar)>::call(self, mask, value);
-}
-
 at::Tensor XLANativeFunctions::mvlgamma(const at::Tensor& self, int64_t p) {
   return at::functionalization::functionalize_aten_op<ATEN_OP(mvlgamma)>::call(
       self, p);

torch_xla/csrc/tensor_methods.cpp

Lines changed: 12 additions & 4 deletions
@@ -1536,11 +1536,19 @@ XLATensorPtr lt(const XLATensorPtr& input, const XLATensorPtr& other) {
   return DispatchComparisonOp(at::aten::lt, input, other);
 }
 
-void masked_fill_(XLATensorPtr& input, const XLATensorPtr& mask,
-                  const at::Scalar& value) {
+XLATensorPtr masked_fill(XLATensorPtr& input, const XLATensorPtr& mask,
+                         const at::Scalar& value) {
   torch::lazy::ScopePusher ir_scope(at::aten::masked_fill.toQualString());
-  input->SetIrValue(torch::lazy::MakeNode<MaskedFill>(
-      input->GetIrValue(), MaybeExpand(mask->GetIrValue(), input->shape()),
+  auto input_value = input->GetIrValue();
+  // Expand input tensor to mask if needed (same as masked_scatter below).
+  // An additional check makes sure to only expand if the rank of input tensor
+  // is less than that of the mask tensor.
+  if (input->shape().get().rank() <= mask->shape().get().rank() &&
+      input->shape().get().dimensions() < mask->shape().get().dimensions()) {
+    input_value = MaybeExpand(input->GetIrValue(), mask->shape());
+  }
+  return input->CreateFrom(torch::lazy::MakeNode<MaskedFill>(
+      input_value, MaybeExpand(mask->GetIrValue(), GetXlaShape(input_value)),
       value));
 }
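
Taken together with the tests above, the lowering now broadcasts in both directions: the mask is still expanded to the input's shape when the input is larger, and the input is now expanded to the mask's shape when its rank does not exceed the mask's and its dimensions compare smaller. A minimal eager-mode sketch of the two cases, with shapes taken from TestMaskedFillBroadcast1 and TestMaskedFillBroadcast2 (on XLA both now hit the single xla::masked_fill lowering):

#include <torch/torch.h>

int main() {
  // Case 1: mask {4, 3} is broadcast against input {2, 5, 4, 3}.
  torch::Tensor a =
      torch::rand({2, 5, 4, 3}, torch::TensorOptions(torch::kFloat));
  torch::Tensor m1 =
      torch::randint(0, 2, {4, 3}, torch::TensorOptions(torch::kBool));
  torch::Tensor r1 = torch::masked_fill(a, m1, 42);  // shape {2, 5, 4, 3}

  // Case 2: input {2, 1} is expanded to mask {2, 3} before the MaskedFill
  // IR node is built; this is the behavior added by this commit.
  torch::Tensor b = torch::rand({2, 1}, torch::TensorOptions(torch::kFloat));
  torch::Tensor m2 =
      torch::randint(0, 2, {2, 3}, torch::TensorOptions(torch::kBool));
  torch::Tensor r2 = torch::masked_fill(b, m2, 42);  // shape {2, 3}
  return 0;
}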

torch_xla/csrc/tensor_methods.h

Lines changed: 2 additions & 3 deletions
@@ -484,9 +484,8 @@ XLATensorPtr lt(const XLATensorPtr& input, const at::Scalar& other);
 
 XLATensorPtr lt(const XLATensorPtr& input, const XLATensorPtr& other);
 
-// In-place version of the method above.
-void masked_fill_(XLATensorPtr& input, const XLATensorPtr& mask,
-                  const at::Scalar& value);
+XLATensorPtr masked_fill(XLATensorPtr& input, const XLATensorPtr& mask,
+                         const at::Scalar& value);
 
 XLATensorPtr masked_scatter(XLATensorPtr& input, const XLATensorPtr& mask,
                             const XLATensorPtr& source);

xla_native_functions.yaml

Lines changed: 2 additions & 4 deletions
@@ -212,8 +212,8 @@ supported:
 - log2
 - log10
 - logsumexp
-- masked_fill_.Scalar
-- masked_fill_.Tensor
+- masked_fill.Scalar
+- masked_fill.Tensor
 - masked_scatter
 - masked_select
 - max
@@ -368,8 +368,6 @@ supported:
 - _trilinear
 - linalg_pinv.atol_rtol_tensor
 - _cdist_forward
-- masked_fill.Scalar
-- masked_fill.Tensor
 - mvlgamma
 - permute
 # The same applies to these ops, but we already have direct lowerings for them
