22 changes: 11 additions & 11 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -2884,7 +2884,7 @@ TEST_F(AtenXlaTensorTest, TestClampMinTensorExplicit) {
    AllClose(b, xla_b);
  });
  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::clamp_min_out", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::clamp_min", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestClampMaxExplicit) {
@@ -2912,7 +2912,7 @@ TEST_F(AtenXlaTensorTest, TestClampMaxTensorExplicit) {
    AllClose(b, xla_b);
  });
  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::clamp_max_out", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::clamp_max", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestClampMinExplicitInPlace) {
@@ -3460,7 +3460,7 @@ TEST_F(AtenXlaTensorTest, TestSiLU) {
    AllClose(b, xla_b, /*rtol=*/1e-3, /*atol=*/1e-5);
  });
  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::silu_out", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::silu", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestSiLUBackward) {
@@ -9029,7 +9029,7 @@ TEST_F(AtenXlaTensorTest, TestBitwiseOr) {
  });

  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::bitwise_or_out", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::bitwise_or", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestBitwiseOrInPlace) {
@@ -9047,7 +9047,7 @@ TEST_F(AtenXlaTensorTest, TestBitwiseOrInPlace) {
  });

  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::bitwise_or_out", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::bitwise_or", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestBitwiseOrScalar) {
@@ -9062,7 +9062,7 @@ TEST_F(AtenXlaTensorTest, TestBitwiseOrScalar) {
  });

  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::bitwise_or_out", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::bitwise_or", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestBitwiseOrScalarInPlace) {
@@ -9078,7 +9078,7 @@ TEST_F(AtenXlaTensorTest, TestBitwiseOrScalarInPlace) {
  });

  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::bitwise_or_out", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::bitwise_or", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestBitwiseXor) {
@@ -9095,7 +9095,7 @@ TEST_F(AtenXlaTensorTest, TestBitwiseXor) {
  });

  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::bitwise_xor_out", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::bitwise_xor", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestBitwiseXorInPlace) {
@@ -9113,7 +9113,7 @@ TEST_F(AtenXlaTensorTest, TestBitwiseXorInPlace) {
  });

  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::bitwise_xor_out", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::bitwise_xor", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestBitwiseXorScalar) {
@@ -9128,7 +9128,7 @@ TEST_F(AtenXlaTensorTest, TestBitwiseXorScalar) {
  });

  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::bitwise_xor_out", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::bitwise_xor", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestBitwiseXorScalarInPlace) {
@@ -9144,7 +9144,7 @@ TEST_F(AtenXlaTensorTest, TestBitwiseXorScalarInPlace) {
  });

  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
-  ExpectCounterChanged("xla::bitwise_xor_out", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::bitwise_xor", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestLshift) {
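
The tests above all follow one counter-check pattern: run the op on CPU and on the XLA device, compare results, then assert that no aten:: fallback fired and that the functional xla:: counter (rather than the old xla::<op>_out one) did. Below is a minimal sketch of that pattern, assuming the ForEachDevice/CopyToDevice/AllEqual helpers this suite already defines; the test name itself is hypothetical and not part of the diff.

TEST_F(AtenXlaTensorTest, TestBitwiseOrCounterSketch) {
  // Integer inputs: bitwise ops are only defined for integral/boolean types.
  torch::Tensor a =
      torch::randint(0, 63, {2, 2}, torch::TensorOptions(torch::kInt));
  torch::Tensor b =
      torch::randint(0, 63, {2, 2}, torch::TensorOptions(torch::kInt));
  torch::Tensor c = torch::bitwise_or(a, b);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor xla_a = CopyToDevice(a, device);
    torch::Tensor xla_b = CopyToDevice(b, device);
    torch::Tensor xla_c = torch::bitwise_or(xla_a, xla_b);
    AllEqual(c, xla_c);
  });
  // The op must lower through the functional XLA kernel, not an aten:: fallback.
  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
  ExpectCounterChanged("xla::bitwise_or", cpp_test::GetIgnoredCounters());
}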
98 changes: 43 additions & 55 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -191,6 +191,15 @@ at::Tensor DoBinaryOpWithoutPromo(const at::Tensor& self,
  return bridge::AtenFromXlaTensor(result);
}

+template <typename B>
+at::Tensor DoBinaryOpWithoutPromo(const at::Tensor& self,
+                                  const at::Scalar& other, const B& bin_op) {
+  at::ScalarType dtype = at::result_type(self, other);
+  XLATensor self_tensor = bridge::GetXlaTensor(self);
+  XLATensor result = bin_op(self_tensor, other);
+  return bridge::AtenFromXlaTensor(result);
+}
+
template <typename B>
void DoBinaryOpOut(const at::Tensor& self, const at::Tensor& other,
                   at::Tensor& out, const B& bin_op_out) {
@@ -936,57 +945,46 @@ at::Tensor XLANativeFunctions::bitwise_and(const at::Tensor& self,
      });
}

-at::Tensor& XLANativeFunctions::bitwise_not_out(const at::Tensor& self,
-                                                at::Tensor& out) {
+at::Tensor XLANativeFunctions::bitwise_not(const at::Tensor& self) {
  XLA_FN_COUNTER("xla::");
-  XLATensor out_tensor = bridge::GetXlaTensor(out);
  XLATensor self_tensor = bridge::GetXlaTensor(self);
-  XLATensor::bitwise_not_out(out_tensor, self_tensor);
-  return out;
+  return bridge::AtenFromXlaTensor(XLATensor::bitwise_not(self_tensor));
}

-at::Tensor& XLANativeFunctions::bitwise_or_out(const at::Tensor& self,
-                                               const at::Scalar& other,
-                                               at::Tensor& out) {
+at::Tensor XLANativeFunctions::bitwise_or(const at::Tensor& self,
+                                          const at::Scalar& other) {
  XLA_FN_COUNTER("xla::");
-  CheckBinaryOpTypePromotion(out, self, other);
-  XLATensor out_tensor = bridge::GetXlaTensor(out);
-  XLATensor::bitwise_or_out(out_tensor, bridge::GetXlaTensor(self), other);
-  return out;
+  return DoBinaryOpWithoutPromo(
+      self, other, [&](const XLATensor& xself, const at::Scalar& xother) {
+        return XLATensor::bitwise_or(xself, xother);
+      });
}

-at::Tensor& XLANativeFunctions::bitwise_or_out(const at::Tensor& self,
-                                               const at::Tensor& other,
-                                               at::Tensor& out) {
+at::Tensor XLANativeFunctions::bitwise_or(const at::Tensor& self,
+                                          const at::Tensor& other) {
  XLA_FN_COUNTER("xla::");
-  DoBinaryOpOut(
-      self, other, out,
-      [&](const XLATensor& xself, const XLATensor& xother, XLATensor& xout) {
-        XLATensor::bitwise_or_out(xout, xself, xother);
-      });
-  return out;
+  return DoBinaryOpWithoutPromo(
+      self, other, [&](const XLATensor& xself, const XLATensor& xother) {
+        return XLATensor::bitwise_or(xself, xother);
+      });
}

-at::Tensor& XLANativeFunctions::bitwise_xor_out(const at::Tensor& self,
-                                                const at::Scalar& other,
-                                                at::Tensor& out) {
+at::Tensor XLANativeFunctions::bitwise_xor(const at::Tensor& self,
+                                           const at::Scalar& other) {
  XLA_FN_COUNTER("xla::");
-  CheckBinaryOpTypePromotion(out, self, other);
-  XLATensor out_tensor = bridge::GetXlaTensor(out);
-  XLATensor::bitwise_xor_out(out_tensor, bridge::GetXlaTensor(self), other);
-  return out;
+  return DoBinaryOpWithoutPromo(
+      self, other, [&](const XLATensor& xself, const at::Scalar& xother) {
+        return XLATensor::bitwise_xor(xself, xother);
+      });
}

-at::Tensor& XLANativeFunctions::bitwise_xor_out(const at::Tensor& self,
-                                                const at::Tensor& other,
-                                                at::Tensor& out) {
+at::Tensor XLANativeFunctions::bitwise_xor(const at::Tensor& self,
+                                           const at::Tensor& other) {
  XLA_FN_COUNTER("xla::");
-  DoBinaryOpOut(
-      self, other, out,
-      [&](const XLATensor& xself, const XLATensor& xother, XLATensor& xout) {
-        XLATensor::bitwise_xor_out(xout, xself, xother);
-      });
-  return out;
+  return DoBinaryOpWithoutPromo(
+      self, other, [&](const XLATensor& xself, const XLATensor& xother) {
+        return XLATensor::bitwise_xor(xself, xother);
+      });
}

at::Tensor XLANativeFunctions::bmm(const at::Tensor& self,
@@ -1057,14 +1055,11 @@ at::Tensor XLANativeFunctions::clamp_max(const at::Tensor& self,
      XLATensor::clamp(bridge::GetXlaTensor(self), c10::nullopt, max));
}

-at::Tensor& XLANativeFunctions::clamp_max_out(const at::Tensor& self,
-                                              const at::Tensor& max,
-                                              at::Tensor& out) {
+at::Tensor XLANativeFunctions::clamp_max(const at::Tensor& self,
+                                         const at::Tensor& max) {
  XLA_FN_COUNTER("xla::");
-  XLATensor out_tensor = bridge::GetXlaTensor(out);
-  XLATensor::clamp_out(out_tensor, bridge::GetXlaTensor(self), c10::nullopt,
-                       max);
-  return out;
+  return bridge::AtenFromXlaTensor(
+      XLATensor::clamp(bridge::GetXlaTensor(self), c10::nullopt, max));
}

at::Tensor XLANativeFunctions::clamp_min(const at::Tensor& self,
@@ -1074,14 +1069,11 @@ at::Tensor XLANativeFunctions::clamp_min(const at::Tensor& self,
      XLATensor::clamp(bridge::GetXlaTensor(self), min, c10::nullopt));
}

-at::Tensor& XLANativeFunctions::clamp_min_out(const at::Tensor& self,
-                                              const at::Tensor& min,
-                                              at::Tensor& out) {
+at::Tensor XLANativeFunctions::clamp_min(const at::Tensor& self,
+                                         const at::Tensor& min) {
  XLA_FN_COUNTER("xla::");
-  XLATensor out_tensor = bridge::GetXlaTensor(out);
-  XLATensor::clamp_out(out_tensor, bridge::GetXlaTensor(self), min,
-                       c10::nullopt);
-  return out;
+  return bridge::AtenFromXlaTensor(
+      XLATensor::clamp(bridge::GetXlaTensor(self), min, c10::nullopt));
}

at::Tensor XLANativeFunctions::clone(
@@ -2880,13 +2872,9 @@ at::Tensor& XLANativeFunctions::selu_(at::Tensor& self) {
  return self;
}

-at::Tensor& XLANativeFunctions::silu_out(const at::Tensor& self,
-                                         at::Tensor& out) {
+at::Tensor XLANativeFunctions::silu(const at::Tensor& self) {
  XLA_FN_COUNTER("xla::");
-  XLATensor out_tensor = bridge::GetXlaTensor(out);
-  XLATensor self_tensor = bridge::GetXlaTensor(self);
-  XLATensor::silu_out(self_tensor, out_tensor);
-  return out;
+  return bridge::AtenFromXlaTensor(XLATensor::silu(bridge::GetXlaTensor(self)));
}

at::Tensor XLANativeFunctions::silu_backward(const at::Tensor& grad_output,
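
The rewritten kernels above all share one shape: bump the XLA_FN_COUNTER("xla::") counter, unwrap the at::Tensor arguments into XLATensors, call the functional XLATensor method, and wrap the result back with bridge::AtenFromXlaTensor. As a hedged sketch of wiring one more scalar op through the new DoBinaryOpWithoutPromo overload; bitwise_and with a scalar RHS is illustrative only, it is not part of this diff, and it assumes a matching XLATensor::bitwise_and scalar overload exists.

at::Tensor XLANativeFunctions::bitwise_and(const at::Tensor& self,
                                           const at::Scalar& other) {
  XLA_FN_COUNTER("xla::");
  // Forward to the scalar DoBinaryOpWithoutPromo overload added at the top of
  // this file; the lambda receives the unwrapped XLATensor plus the scalar.
  return DoBinaryOpWithoutPromo(
      self, other, [&](const XLATensor& xself, const at::Scalar& xother) {
        return XLATensor::bitwise_and(xself, xother);  // assumed overload
      });
}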
19 changes: 6 additions & 13 deletions torch_xla/csrc/tensor.h
@@ -437,19 +437,15 @@ class XLATensor : public c10::intrusive_ptr_target {

  static XLATensor bitwise_and(const XLATensor& input, const XLATensor& other);

-  static void bitwise_not_out(XLATensor& out, const XLATensor& input);
+  static XLATensor bitwise_not(const XLATensor& input);

-  static void bitwise_or_out(XLATensor& out, const XLATensor& input,
-                             const at::Scalar& other);
+  static XLATensor bitwise_or(const XLATensor& input, const at::Scalar& other);

-  static void bitwise_or_out(XLATensor& out, const XLATensor& input,
-                             const XLATensor& other);
+  static XLATensor bitwise_or(const XLATensor& input, const XLATensor& other);

-  static void bitwise_xor_out(XLATensor& out, const XLATensor& input,
-                              const at::Scalar& other);
+  static XLATensor bitwise_xor(const XLATensor& input, const at::Scalar& other);

-  static void bitwise_xor_out(XLATensor& out, const XLATensor& input,
-                              const XLATensor& other);
+  static XLATensor bitwise_xor(const XLATensor& input, const XLATensor& other);

  // Batch matrix multiplication. Both tensors must be 3D, the batch size must
  // match and the remaining two dimensions must be compatible for matrix
@@ -476,9 +472,6 @@ class XLATensor : public c10::intrusive_ptr_target {
  static XLATensor clamp(const XLATensor& input,
                         const c10::optional<at::Tensor>& min,
                         const c10::optional<at::Tensor>& max);
-  static void clamp_out(XLATensor& out, const XLATensor& input,
-                        const c10::optional<at::Tensor>& min,
-                        const c10::optional<at::Tensor>& max);

  static XLATensor clone(const XLATensor& input);

@@ -1022,7 +1015,7 @@ class XLATensor : public c10::intrusive_ptr_target {
  static XLATensor selu(const XLATensor& input);
  static void selu_(XLATensor& input);

-  static void silu_out(XLATensor& input, XLATensor& out);
+  static XLATensor silu(const XLATensor& input);
  static XLATensor silu_backward(XLATensor& grad_output, XLATensor& input);
  static XLATensor sigmoid(const XLATensor& input);
  static XLATensor sigmoid_backward(const XLATensor& grad_output,
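
At a call site, the header change reads as follows (a before/after sketch; the names `x` and `dest` are illustrative, not from the diff):

//   old: caller pre-allocates the destination, the method mutates it:
//            XLATensor::bitwise_not_out(dest, x);
//   new: the method returns a fresh lazy tensor wrapping the new IR node:
//            XLATensor y = XLATensor::bitwise_not(x);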
49 changes: 16 additions & 33 deletions torch_xla/csrc/tensor_methods.cpp
@@ -933,36 +933,36 @@ XLATensor XLATensor::bitwise_and(const XLATensor& input,
  return input.CreateFrom(BitwiseAnd(input.GetIrValue(), other.GetIrValue()));
}

-void XLATensor::bitwise_not_out(XLATensor& out, const XLATensor& input) {
-  out.SetIrValue(Not(input.GetIrValue()));
+XLATensor XLATensor::bitwise_not(const XLATensor& input) {
+  return input.CreateFrom(Not(input.GetIrValue()));
}

-void XLATensor::bitwise_or_out(XLATensor& out, const XLATensor& input,
-                               const at::Scalar& other) {
+XLATensor XLATensor::bitwise_or(const XLATensor& input,
+                                const at::Scalar& other) {
  CheckIsIntegralOrPred(input.shape(), "__or__");
  torch::lazy::Value constant =
      GetIrValueForScalar(other, input.shape(), input.GetDevice());
-  out.SetIrValue(BitwiseOr(input.GetIrValue(), constant));
+  return input.CreateFrom(BitwiseOr(input.GetIrValue(), constant));
}

-void XLATensor::bitwise_or_out(XLATensor& out, const XLATensor& input,
-                               const XLATensor& other) {
+XLATensor XLATensor::bitwise_or(const XLATensor& input,
+                                const XLATensor& other) {
  CheckIsIntegralOrPred(input.shape(), "__or__");
-  out.SetIrValue(BitwiseOr(input.GetIrValue(), other.GetIrValue()));
+  return input.CreateFrom(BitwiseOr(input.GetIrValue(), other.GetIrValue()));
}

-void XLATensor::bitwise_xor_out(XLATensor& out, const XLATensor& input,
-                                const at::Scalar& other) {
+XLATensor XLATensor::bitwise_xor(const XLATensor& input,
+                                 const at::Scalar& other) {
  CheckIsIntegralOrPred(input.shape(), "__xor__");
  torch::lazy::Value constant =
      GetIrValueForScalar(other, input.shape(), input.GetDevice());
-  out.SetIrValue(BitwiseXor(input.GetIrValue(), constant));
+  return input.CreateFrom(BitwiseXor(input.GetIrValue(), constant));
}

-void XLATensor::bitwise_xor_out(XLATensor& out, const XLATensor& input,
-                                const XLATensor& other) {
+XLATensor XLATensor::bitwise_xor(const XLATensor& input,
+                                 const XLATensor& other) {
  CheckIsIntegralOrPred(input.shape(), "__xor__");
-  out.SetIrValue(BitwiseXor(input.GetIrValue(), other.GetIrValue()));
+  return input.CreateFrom(BitwiseXor(input.GetIrValue(), other.GetIrValue()));
}

XLATensor XLATensor::bmm(const XLATensor& batch1, const XLATensor& batch2) {
@@ -1057,23 +1057,6 @@ XLATensor XLATensor::clamp(const XLATensor& input,
  return input.CreateFrom(res);
}

-void XLATensor::clamp_out(XLATensor& out, const XLATensor& input,
-                          const c10::optional<at::Tensor>& min,
-                          const c10::optional<at::Tensor>& max) {
-  XLA_CHECK(min || max)
-      << "At least one of \'min\' or \'max\' must not be None";
-  torch::lazy::Value res = input.GetIrValue();
-  if (min) {
-    res = torch::lazy::MakeNode<Maximum>(
-        res, bridge::GetXlaTensor(*min).GetIrValue(),
-        std::vector<torch::lazy::Shape>());
-  }
-  if (max) {
-    res = Min(res, bridge::GetXlaTensor(*max).GetIrValue());
-  }
-  out.SetInPlaceIrValue(res);
-}

XLATensor XLATensor::clone(const XLATensor& input) {
  return input.CreateFrom(input.GetIrValue());
}
@@ -2482,8 +2465,8 @@ void XLATensor::selu_(XLATensor& input) {
  input.SetInPlaceIrValue(Selu(input.GetIrValue()));
}

-void XLATensor::silu_out(XLATensor& input, XLATensor& out) {
-  out.SetInPlaceIrValue(SiLU(input.GetIrValue()));
+XLATensor XLATensor::silu(const XLATensor& input) {
+  return input.CreateFrom(SiLU(input.GetIrValue()));
}

XLATensor XLATensor::silu_backward(XLATensor& grad_output, XLATensor& input) {
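
Every method rewritten in this file now follows the same functional recipe: read the input's IR value, build a new IR node from it, and return input.CreateFrom(node), which manufactures a fresh lazy tensor on the same device instead of redirecting a caller-supplied out tensor through SetIrValue or SetInPlaceIrValue. A hedged sketch of that recipe for a scalar right-hand side, mirroring the bitwise_xor method above; bitwise_and with a scalar is illustrative only and not part of this diff.

XLATensor XLATensor::bitwise_and(const XLATensor& input,
                                 const at::Scalar& other) {
  // Bitwise ops are only defined for integral or boolean inputs.
  CheckIsIntegralOrPred(input.shape(), "__and__");
  // Materialize the scalar as a constant IR value on the input's device.
  torch::lazy::Value constant =
      GetIrValueForScalar(other, input.shape(), input.GetDevice());
  // Build the new node and wrap it in a fresh XLATensor; the input is
  // untouched, so aliasing and in-place bookkeeping are unnecessary.
  return input.CreateFrom(BitwiseAnd(input.GetIrValue(), constant));
}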