Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 0 additions & 27 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -905,33 +905,6 @@ at::Tensor XLANativeFunctions::binary_cross_entropy_with_logits(
IsDefined(pos_weight) ? *pos_weight : at::Tensor(), reduction);
}

at::Tensor XLANativeFunctions::logical_not(const at::Tensor& self) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(
XLATensor::logical_not(bridge::GetXlaTensor(self)));
}

at::Tensor XLANativeFunctions::logical_xor(const at::Tensor& self,
const at::Tensor& other) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::logical_xor(
bridge::GetXlaTensor(self), bridge::GetXlaTensor(other)));
}

at::Tensor XLANativeFunctions::logical_and(const at::Tensor& self,
const at::Tensor& other) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::logical_and(
bridge::GetXlaTensor(self), bridge::GetXlaTensor(other)));
}

at::Tensor XLANativeFunctions::logical_or(const at::Tensor& self,
const at::Tensor& other) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::logical_or(
bridge::GetXlaTensor(self), bridge::GetXlaTensor(other)));
}

at::Tensor XLANativeFunctions::bitwise_and(const at::Tensor& self,
const at::Scalar& other) {
XLA_FN_COUNTER("xla::");
Expand Down
93 changes: 0 additions & 93 deletions torch_xla/csrc/ops/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -943,99 +943,6 @@ torch::lazy::NodePtr Lerp(const torch::lazy::Value& start,
return start + weight * (end - start);
}

torch::lazy::NodePtr LogicalNot(const torch::lazy::Value& input) {
auto lower_fn = [](const XlaNode& node,
LoweringContext* loctx) -> XlaOpVector {
xla::XlaOp op = loctx->GetOutputOp(node.operand(0));
return node.ReturnOp(XlaHelpers::PromotedLogicalUnaryOp(
op, [](xla::XlaOp lhs) { return xla::Not(lhs); }),
loctx);
};
auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
return XlaHelpers::PromotedLogicalUnaryOp(
operands[0], [](xla::XlaOp lhs) { return xla::Not(lhs); });
};
return GenericOp(
torch::lazy::OpKind(at::aten::logical_not), {input},
[&]() { return InferOutputShape({GetXlaShape(input)}, shape_fn); },
std::move(lower_fn));
}

torch::lazy::NodePtr LogicalXor(const torch::lazy::Value& input,
const torch::lazy::Value& other) {
auto lower_fn = [](const XlaNode& node,
LoweringContext* loctx) -> XlaOpVector {
xla::XlaOp op1 = loctx->GetOutputOp(node.operand(0));
xla::XlaOp op2 = loctx->GetOutputOp(node.operand(1));
return node.ReturnOp(
XlaHelpers::PromotedLogicalBinaryOp(
op1, op2,
[](xla::XlaOp lhs, xla::XlaOp rhs) { return xla::Xor(lhs, rhs); }),
loctx);
};
auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
return XlaHelpers::PromotedLogicalBinaryOp(
operands[0], operands[1],
[](xla::XlaOp lhs, xla::XlaOp rhs) { return xla::Xor(lhs, rhs); });
};
return GenericOp(torch::lazy::OpKind(at::aten::logical_xor), {input, other},
[&]() {
return InferOutputShape(
{GetXlaShape(input), GetXlaShape(other)}, shape_fn);
},
std::move(lower_fn));
}

torch::lazy::NodePtr LogicalAnd(const torch::lazy::Value& input,
const torch::lazy::Value& other) {
auto lower_fn = [](const XlaNode& node,
LoweringContext* loctx) -> XlaOpVector {
xla::XlaOp op1 = loctx->GetOutputOp(node.operand(0));
xla::XlaOp op2 = loctx->GetOutputOp(node.operand(1));
return node.ReturnOp(
XlaHelpers::PromotedLogicalBinaryOp(
op1, op2,
[](xla::XlaOp lhs, xla::XlaOp rhs) { return xla::And(lhs, rhs); }),
loctx);
};
auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
return XlaHelpers::PromotedLogicalBinaryOp(
operands[0], operands[1],
[](xla::XlaOp lhs, xla::XlaOp rhs) { return xla::And(lhs, rhs); });
};
return GenericOp(torch::lazy::OpKind(at::aten::logical_and), {input, other},
[&]() {
return InferOutputShape(
{GetXlaShape(input), GetXlaShape(other)}, shape_fn);
},
std::move(lower_fn));
}

torch::lazy::NodePtr LogicalOr(const torch::lazy::Value& input,
const torch::lazy::Value& other) {
auto lower_fn = [](const XlaNode& node,
LoweringContext* loctx) -> XlaOpVector {
xla::XlaOp op1 = loctx->GetOutputOp(node.operand(0));
xla::XlaOp op2 = loctx->GetOutputOp(node.operand(1));
return node.ReturnOp(
XlaHelpers::PromotedLogicalBinaryOp(
op1, op2,
[](xla::XlaOp lhs, xla::XlaOp rhs) { return xla::Or(lhs, rhs); }),
loctx);
};
auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
return XlaHelpers::PromotedLogicalBinaryOp(
operands[0], operands[1],
[](xla::XlaOp lhs, xla::XlaOp rhs) { return xla::Or(lhs, rhs); });
};
return GenericOp(torch::lazy::OpKind(at::aten::logical_or), {input, other},
[&]() {
return InferOutputShape(
{GetXlaShape(input), GetXlaShape(other)}, shape_fn);
},
std::move(lower_fn));
}

torch::lazy::NodePtr XLogY(const torch::lazy::Value& input,
const torch::lazy::Value& other) {
auto lower_fn = [](const XlaNode& node,
Expand Down
11 changes: 0 additions & 11 deletions torch_xla/csrc/ops/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,17 +256,6 @@ torch::lazy::NodePtr Lerp(const torch::lazy::Value& start,
const torch::lazy::Value& end,
const torch::lazy::Value& weight);

torch::lazy::NodePtr LogicalNot(const torch::lazy::Value& input);

torch::lazy::NodePtr LogicalXor(const torch::lazy::Value& input,
const torch::lazy::Value& other);

torch::lazy::NodePtr LogicalAnd(const torch::lazy::Value& input,
const torch::lazy::Value& other);

torch::lazy::NodePtr LogicalOr(const torch::lazy::Value& input,
const torch::lazy::Value& other);

torch::lazy::NodePtr XLogY(const torch::lazy::Value& input,
const torch::lazy::Value& other);

Expand Down
38 changes: 38 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,44 @@ torch_xla::XlaOpVector Logdet::Lower(LoweringContext* loctx) const {
return ReturnOp(xla::LogDet(xla_input), loctx);
}

torch_xla::XlaOpVector LogicalAnd::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));

return ReturnOp(
XlaHelpers::PromotedLogicalBinaryOp(
xla_input, xla_other,
[](xla::XlaOp lhs, xla::XlaOp rhs) { return xla::And(lhs, rhs); }),
loctx);
}

torch_xla::XlaOpVector LogicalNot::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
return ReturnOp(XlaHelpers::PromotedLogicalUnaryOp(
xla_input, [](xla::XlaOp lhs) { return xla::Not(lhs); }),
loctx);
}

torch_xla::XlaOpVector LogicalOr::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
return ReturnOp(
XlaHelpers::PromotedLogicalBinaryOp(
xla_input, xla_other,
[](xla::XlaOp lhs, xla::XlaOp rhs) { return xla::Or(lhs, rhs); }),
loctx);
}

torch_xla::XlaOpVector LogicalXor::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
return ReturnOp(
XlaHelpers::PromotedLogicalBinaryOp(
xla_input, xla_other,
[](xla::XlaOp lhs, xla::XlaOp rhs) { return xla::Xor(lhs, rhs); }),
loctx);
}

torch_xla::XlaOpVector Maximum::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
Expand Down
38 changes: 38 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,44 @@ xla::Shape LogdetOutputShape(const torch::lazy::Value& input) {
return logdet_shape;
}

xla::Shape LogicalAndOutputShape(const torch::lazy::Value& input,
const torch::lazy::Value& other) {
auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
return XlaHelpers::PromotedLogicalBinaryOp(
operands[0], operands[1],
[](xla::XlaOp lhs, xla::XlaOp rhs) { return xla::And(lhs, rhs); });
};
return InferOutputShape({GetXlaShape(input), GetXlaShape(other)}, shape_fn);
}

xla::Shape LogicalNotOutputShape(const torch::lazy::Value& input) {
auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
return XlaHelpers::PromotedLogicalUnaryOp(
operands[0], [](xla::XlaOp lhs) { return xla::Not(lhs); });
};
return InferOutputShape({GetXlaShape(input)}, shape_fn);
}

xla::Shape LogicalOrOutputShape(const torch::lazy::Value& input,
const torch::lazy::Value& other) {
auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
return XlaHelpers::PromotedLogicalBinaryOp(
operands[0], operands[1],
[](xla::XlaOp lhs, xla::XlaOp rhs) { return xla::Or(lhs, rhs); });
};
return InferOutputShape({GetXlaShape(input), GetXlaShape(other)}, shape_fn);
}

xla::Shape LogicalXorOutputShape(const torch::lazy::Value& input,
const torch::lazy::Value& other) {
auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
return XlaHelpers::PromotedLogicalBinaryOp(
operands[0], operands[1],
[](xla::XlaOp lhs, xla::XlaOp rhs) { return xla::Xor(lhs, rhs); });
};
return InferOutputShape({GetXlaShape(input), GetXlaShape(other)}, shape_fn);
}

xla::Shape MaximumOutputShape(const torch::lazy::Value& input,
const torch::lazy::Value& other) {
auto lower_for_shape_fn =
Expand Down
11 changes: 11 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,17 @@ xla::Shape InverseOutputShape(const torch::lazy::Value& input);

xla::Shape LogdetOutputShape(const torch::lazy::Value& input);

xla::Shape LogicalAndOutputShape(const torch::lazy::Value& input,
const torch::lazy::Value& other);

xla::Shape LogicalNotOutputShape(const torch::lazy::Value& input);

xla::Shape LogicalOrOutputShape(const torch::lazy::Value& input,
const torch::lazy::Value& other);

xla::Shape LogicalXorOutputShape(const torch::lazy::Value& input,
const torch::lazy::Value& other);

xla::Shape MaximumOutputShape(const torch::lazy::Value& input,
const torch::lazy::Value& other);

Expand Down
11 changes: 0 additions & 11 deletions torch_xla/csrc/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -769,17 +769,6 @@ class XLATensor : public c10::intrusive_ptr_target {
static XLATensorPtr log1p(const XLATensorPtr& input);
static void log1p_(XLATensorPtr& input);

static XLATensorPtr logical_not(const XLATensorPtr& input);

static XLATensorPtr logical_xor(const XLATensorPtr& input,
const XLATensorPtr& other);

static XLATensorPtr logical_and(const XLATensorPtr& input,
const XLATensorPtr& other);

static XLATensorPtr logical_or(const XLATensorPtr& input,
const XLATensorPtr& other);

static XLATensorPtr logsumexp(const XLATensorPtr& input,
std::vector<int64_t> dimensions,
bool keep_reduced_dimensions);
Expand Down
23 changes: 0 additions & 23 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1730,29 +1730,6 @@ void XLATensor::log1p_(XLATensorPtr& input) {
input->SetInPlaceIrValue(Log1p(input->GetIrValue()));
}

XLATensorPtr XLATensor::logical_not(const XLATensorPtr& input) {
return input->CreateFrom(LogicalNot(input->GetIrValue()),
at::ScalarType::Bool);
}

XLATensorPtr XLATensor::logical_xor(const XLATensorPtr& input,
const XLATensorPtr& other) {
return input->CreateFrom(LogicalXor(input->GetIrValue(), other->GetIrValue()),
at::ScalarType::Bool);
}

XLATensorPtr XLATensor::logical_and(const XLATensorPtr& input,
const XLATensorPtr& other) {
return input->CreateFrom(LogicalAnd(input->GetIrValue(), other->GetIrValue()),
at::ScalarType::Bool);
}

XLATensorPtr XLATensor::logical_or(const XLATensorPtr& input,
const XLATensorPtr& other) {
return input->CreateFrom(LogicalOr(input->GetIrValue(), other->GetIrValue()),
at::ScalarType::Bool);
}

XLATensorPtr XLATensor::logsumexp(const XLATensorPtr& input,
std::vector<int64_t> dimensions,
bool keep_reduced_dimensions) {
Expand Down
8 changes: 4 additions & 4 deletions xla_native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ full_codegen:
- floor
- inverse
- logdet
- logical_and
- logical_not
- logical_or
- logical_xor
- maximum
- minimum
- reciprocal
Expand Down Expand Up @@ -178,10 +182,6 @@ supported:
- log10
- log_sigmoid_backward
- log_sigmoid_forward
- logical_and
- logical_not
- logical_or
- logical_xor
- logsumexp
- lt.Scalar
- lt.Tensor
Expand Down