@@ -8827,7 +8827,6 @@ TEST_F(AtenXlaTensorTest, TestUnsqueezeInPlace) {
88278827}
88288828
88298829TEST_F (AtenXlaTensorTest, TestMaskedFill) {
8830- GTEST_SKIP () << " SegFault after functionalization" ;
88318830 torch::Tensor input =
88328831 torch::rand ({2 , 3 }, torch::TensorOptions (torch::kFloat ));
88338832 torch::Tensor mask =
@@ -8842,11 +8841,10 @@ TEST_F(AtenXlaTensorTest, TestMaskedFill) {
88428841 });
88438842
88448843 ExpectCounterNotChanged (" aten::.*" , cpp_test::GetIgnoredCounters ());
8845- ExpectCounterChanged (" xla::masked_fill_ " , cpp_test::GetIgnoredCounters ());
8844+ ExpectCounterChanged (" xla::masked_fill " , cpp_test::GetIgnoredCounters ());
88468845}
88478846
88488847TEST_F (AtenXlaTensorTest, TestMaskedFillInPlace) {
8849- GTEST_SKIP () << " SegFault after functionalization" ;
88508848 torch::Scalar value (42 );
88518849 torch::Tensor mask =
88528850 torch::randint (0 , 2 , {2 , 3 }, torch::TensorOptions (torch::kBool ));
@@ -8862,11 +8860,10 @@ TEST_F(AtenXlaTensorTest, TestMaskedFillInPlace) {
88628860 });
88638861
88648862 ExpectCounterNotChanged (" aten::.*" , cpp_test::GetIgnoredCounters ());
8865- ExpectCounterChanged (" xla::masked_fill_ " , cpp_test::GetIgnoredCounters ());
8863+ ExpectCounterChanged (" xla::masked_fill " , cpp_test::GetIgnoredCounters ());
88668864}
88678865
8868- TEST_F (AtenXlaTensorTest, TestMaskedFillBroadcast) {
8869- GTEST_SKIP () << " SegFault after functionalization" ;
8866+ TEST_F (AtenXlaTensorTest, TestMaskedFillBroadcast1) {
88708867 torch::Tensor input =
88718868 torch::rand ({2 , 5 , 4 , 3 }, torch::TensorOptions (torch::kFloat ));
88728869 torch::Tensor mask =
@@ -8881,7 +8878,25 @@ TEST_F(AtenXlaTensorTest, TestMaskedFillBroadcast) {
88818878 });
88828879
88838880 ExpectCounterNotChanged (" aten::.*" , cpp_test::GetIgnoredCounters ());
8884- ExpectCounterChanged (" xla::masked_fill_" , cpp_test::GetIgnoredCounters ());
8881+ ExpectCounterChanged (" xla::masked_fill" , cpp_test::GetIgnoredCounters ());
8882+ }
8883+
8884+ TEST_F (AtenXlaTensorTest, TestMaskedFillBroadcast2) {
8885+ torch::Tensor input =
8886+ torch::rand ({2 , 1 }, torch::TensorOptions (torch::kFloat ));
8887+ torch::Tensor mask =
8888+ torch::randint (0 , 2 , {2 , 3 }, torch::TensorOptions (torch::kBool ));
8889+ torch::Scalar value (42 );
8890+ torch::Tensor result = torch::masked_fill (input, mask, value);
8891+ ForEachDevice ([&](const torch::Device& device) {
8892+ torch::Tensor xla_input = CopyToDevice (input, device);
8893+ torch::Tensor xla_mask = CopyToDevice (mask, device);
8894+ torch::Tensor xla_result = torch::masked_fill (xla_input, xla_mask, value);
8895+ AllClose (result, xla_result);
8896+ });
8897+
8898+ ExpectCounterNotChanged (" aten::.*" , cpp_test::GetIgnoredCounters ());
8899+ ExpectCounterChanged (" xla::masked_fill" , cpp_test::GetIgnoredCounters ());
88858900}
88868901
88878902TEST_F (AtenXlaTensorTest, TestFill) {
@@ -11301,7 +11316,6 @@ TEST_F(AtenXlaTensorTest, TestBCEWithLogitsBackward) {
1130111316}
1130211317
1130311318TEST_F (AtenXlaTensorTest, TestKlDivBackward) {
11304- GTEST_SKIP () << " SegFault after functionalization" ;
1130511319 torch::Tensor input = torch::rand (
1130611320 {4 , 3 }, torch::TensorOptions (torch::kFloat ).requires_grad (true ));
1130711321 torch::Tensor target = torch::rand (
0 commit comments