
Commit 6c53a1e

Lower _conj_copy and alias operation. (#8686)
1 parent d06a9c9 commit 6c53a1e
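
This commit lets torch.conj (through the new _conj_copy lowering) and alias run on the XLA device instead of falling back to CPU. A minimal usage sketch, mirroring the new tests added below; it assumes a working torch_xla install with an available XLA device, and uses only the modules the new Python test itself imports:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

# Create the input on CPU so the factory op stays out of the fallback count.
t = torch.rand(2, 2, dtype=torch.complex64)
met.clear_all()

xla_t = t.to(xm.xla_device())

# conj now lowers through the new ConjCopy node; alias stays on device.
conj_result = torch.conj(xla_t)
alias_result = torch.ops.aten.alias(xla_t)

# As in the new test_conj test, no aten fallback ops should have executed.
assert met.executed_fallback_ops() == []
assert torch.allclose(torch.conj(t), conj_result.cpu())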

File tree

9 files changed (+72 -4 lines)

codegen/xla_native_functions.yaml

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,7 @@ full_codegen:
   - clamp.Tensor
   - clamp_max.Tensor
   - clamp_min.Tensor
+  - _conj_copy
   - cos
   - cosh
   - elu
@@ -138,6 +139,7 @@ supported:
   - add.Scalar
   - add.Tensor
   - addmm
+  - alias
   - alias_copy
   - arange.start_out
   - as_strided_copy

test/cpp/test_aten_xla_tensor_2.cpp

File mode changed from 100755 to 100644.
Lines changed: 21 additions & 0 deletions
@@ -2420,6 +2420,27 @@ TEST_F(AtenXlaTensorTest, TestAcoshInPlace) {
   ExpectCounterChanged("xla::acosh", cpp_test::GetIgnoredCounters());
 }
 
+TEST_F(AtenXlaTensorTest, TestAlias) {
+  torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
+  torch::Tensor b = torch::alias(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_a = CopyToDevice(a, device);
+    torch::Tensor xla_b = torch::alias(xla_a);
+    AllClose(b, xla_b, /*rtol=*/1e-3, /*atol=*/0e-5);
+  });
+}
+
+TEST_F(AtenXlaTensorTest, TestConj) {
+  torch::Tensor a =
+      torch::rand({2, 2}, torch::TensorOptions(torch::kComplexFloat));
+  torch::Tensor b = torch::conj(a);
+  ForEachDevice([&](const torch::Device& device) {
+    torch::Tensor xla_a = CopyToDevice(a, device);
+    torch::Tensor xla_b = torch::conj(xla_a);
+    AllClose(b, xla_b, /*rtol=*/1e-3, /*atol=*/0e-5);
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestCos) {
   torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
   torch::Tensor b = torch::cos(a);

test/test_operations.py

Lines changed: 18 additions & 0 deletions
@@ -2384,6 +2384,24 @@ def test_cummax_0_sized_dimension(self):
 
     self.assertEqual(actual, expected)
 
+  def test_conj(self):
+    # Leave the factory out of the fallback count.
+    tensor = torch.rand(2, 2, dtype=torch.complex64)
+
+    met.clear_all()
+
+    def run(device):
+      return torch.conj(tensor.to(device))
+
+    actual = run("cpu")
+    expected = run(xm.xla_device())
+
+    self.assertEqual(
+        met.executed_fallback_ops(), [],
+        message="expected no fallback operations.")
+    self.assertEqual(
+        actual, expected.cpu(), message="XLA results should match CPU results.")
+
 
 class MNISTComparator(nn.Module):
 

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 12 additions & 3 deletions
@@ -863,12 +863,17 @@ at::Tensor XLANativeFunctions::addmm(const at::Tensor& self,
       /*bias=*/bridge::GetXlaTensor(self)));
 }
 
-at::Tensor XLANativeFunctions::alias_copy(const at::Tensor& self) {
+at::Tensor XLANativeFunctions::alias(const at::Tensor& self) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   return bridge::AtenFromXlaTensor(
       tensor_methods::alias(bridge::GetXlaTensor(self)));
 }
 
+at::Tensor XLANativeFunctions::alias_copy(const at::Tensor& self) {
+  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
+  return alias(self);
+}
+
 at::Tensor& XLANativeFunctions::arange_out(const at::Scalar& start,
                                            const at::Scalar& end,
                                            const at::Scalar& step,
@@ -1333,8 +1338,12 @@ at::Tensor XLANativeFunctions::clone(
     const at::Tensor& self,
     std::optional<at::MemoryFormat> /* memory_format */) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  return bridge::AtenFromXlaTensor(
-      tensor_methods::clone(bridge::GetXlaTensor(self)));
+  auto tensor = bridge::GetXlaTensor(self);
+  if (self.is_conj()) {
+    // Materialize the conjugate if necessary.
+    tensor = tensor_methods::conj(tensor);
+  }
+  return bridge::AtenFromXlaTensor(tensor_methods::clone(tensor));
 }
 
 at::Tensor XLANativeFunctions::constant_pad_nd(const at::Tensor& self,
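
The new self.is_conj() branch in clone handles conjugate views: in PyTorch, Tensor.conj() on a complex tensor only sets the conjugate bit, so the pending conjugation has to be materialized once the data is actually copied. A small illustrative sketch of the behavior this path targets (a sketch only, assuming an available XLA device):

import torch
import torch_xla.core.xla_model as xm

x = torch.rand(2, 2, dtype=torch.complex64).to(xm.xla_device())

v = x.conj()  # lazy conjugate view: only the conj bit is set
assert v.is_conj()

# clone() copies real data, so the pending conjugation is materialized here
# (on XLA, via tensor_methods::conj and the ConjCopy lowering).
c = v.clone()
assert torch.allclose(c.cpu(), x.cpu().conj_physical())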

torch_xla/csrc/ops/ops_lower_fn.cpp

Lines changed: 5 additions & 0 deletions
@@ -342,6 +342,11 @@ torch_xla::XlaOpVector ClampMinTensor::Lower(LoweringContext* loctx) const {
   return ReturnOp(xla::Max(xla_input, xla_other), loctx);
 }
 
+torch_xla::XlaOpVector ConjCopy::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  return ReturnOp(xla::Conj(input), loctx);
+}
+
 torch_xla::XlaOpVector Cos::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
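
ConjCopy lowers to xla::Conj, which negates the imaginary part of a complex input, matching eager-mode torch.conj. A quick illustration of the expected values (plain eager PyTorch, for reference only):

import torch

t = torch.tensor([1 + 2j, 3 - 4j])
print(torch.conj(t))  # tensor([1.-2.j, 3.+4.j])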

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

Lines changed: 4 additions & 0 deletions
@@ -424,6 +424,10 @@ xla::Shape ClampMinTensorOutputShape(const torch::lazy::Value& input,
                                      lower_for_shape_fn);
 }
 
+xla::Shape ConjCopyOutputShape(const torch::lazy::Value& input) {
+  return GetXlaShape(input);
+}
+
 xla::Shape CosOutputShape(const torch::lazy::Value& input) {
   xla::Shape result_shape = GetXlaShape(input);
   if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {

torch_xla/csrc/ops/ops_xla_shape_fn.h

Lines changed: 3 additions & 1 deletion
@@ -108,6 +108,8 @@ xla::Shape ClampMaxTensorOutputShape(const torch::lazy::Value& input,
 xla::Shape ClampMinTensorOutputShape(const torch::lazy::Value& input,
                                      const torch::lazy::Value& target);
 
+xla::Shape ConjCopyOutputShape(const torch::lazy::Value& input);
+
 xla::Shape CosOutputShape(const torch::lazy::Value& input);
 
 xla::Shape CoshOutputShape(const torch::lazy::Value& input);
@@ -287,4 +289,4 @@ xla::Shape TruncOutputShape(const torch::lazy::Value& input);
 
 }  // namespace torch_xla
 
-#endif  // XLA_TORCH_XLA_CSRC_OPS_OPS_XLA_SHAPE_FN_H_
+#endif  // XLA_TORCH_XLA_CSRC_OPS_OPS_XLA_SHAPE_FN_H_

torch_xla/csrc/tensor_methods.cpp

Lines changed: 5 additions & 0 deletions
@@ -1238,6 +1238,11 @@ XLATensorPtr clone(const XLATensorPtr& input) {
   return cloned;
 }
 
+XLATensorPtr conj(const XLATensorPtr& input) {
+  auto ir = input->GetIrValue();
+  return input->CreateFrom(torch_xla::MakeNode<ConjCopy>(ir));
+}
+
 XLATensorPtr constant_pad_nd(const XLATensorPtr& input,
                              absl::Span<const int64_t> pad,
                              const at::Scalar& value) {

torch_xla/csrc/tensor_methods.h

Lines changed: 2 additions & 0 deletions
@@ -327,6 +327,8 @@ XLATensorPtr clamp(const XLATensorPtr& input,
 
 XLATensorPtr clone(const XLATensorPtr& input);
 
+XLATensorPtr conj(const XLATensorPtr& input);
+
 // Pad with the given value and size specified by the given list of low and
 // high paddings.
 XLATensorPtr constant_pad_nd(const XLATensorPtr& input,
