Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3416,6 +3416,24 @@ TEST_F(AtenXlaTensorTest, TestLogsumexp) {
}
}

TEST_F(AtenXlaTensorTest, TestXLogY) {
torch::Tensor a = torch::rand({5, 5}, torch::TensorOptions(torch::kFloat));
torch::Tensor b = torch::rand({5, 5}, torch::TensorOptions(torch::kFloat));
a[0][0] = 0.0;
b[0][2] = std::nan("1");
b[0][0] = std::nan("1");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's name this std::nan("2") for better debugging differentiation later?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the difference between nan(1) and nan(2)? I thought they are just nan.

Copy link
Collaborator

@miladm miladm Nov 2, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I was thinking when something fails it shows nan(1) vs. nan(2) in the error message. It's nit. Feel free to ignore if this is the only change you have left.

torch::Tensor c = torch::xlogy(a, b);
ForEachDevice([&](const torch::Device& device) {
torch::Tensor xla_a = CopyToDevice(a, device);
torch::Tensor xla_b = CopyToDevice(b, device);
torch::Tensor xla_c = torch::xlogy(xla_a, xla_b);
AllClose(b, xla_b, /*rtol=*/1e-3, /*atol=*/1e-5);
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::xlogy", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestSiLU) {
torch::Tensor a = torch::rand({2, 2}, torch::TensorOptions(torch::kFloat));
torch::Tensor b = torch::silu(a);
Expand Down Expand Up @@ -10508,6 +10526,7 @@ TEST_F(AtenXlaTensorTest, TestKlDivBackward) {
/*atol=*/1e-5);
});
}
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestEmbeddingBackward) {
Expand Down
10 changes: 5 additions & 5 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ def __new__(cls, name, variant_test_name=""):
AllowedOpInfoEntry('linalg.matrix_power'),
AllowedOpInfoEntry('linalg.qr'),
AllowedOpInfoEntry('linalg.slogdet'),
AllowedOpInfoEntry('log'),
AllowedOpInfoEntry('log10'),
AllowedOpInfoEntry('log1p'),
AllowedOpInfoEntry('log2'),
AllowedOpInfoEntry('logaddexp'),
AllowedOpInfoEntry('logaddexp2'),
AllowedOpInfoEntry('logical_not'),
Expand All @@ -132,7 +136,6 @@ def __new__(cls, name, variant_test_name=""):
AllowedOpInfoEntry('min', 'reduction_no_dim'),
AllowedOpInfoEntry('nansum'),
AllowedOpInfoEntry('quantile'),
AllowedOpInfoEntry('nanquantile'),
AllowedOpInfoEntry('maximum'),
AllowedOpInfoEntry('minimum'),
AllowedOpInfoEntry('nn.functional.hardswish'),
Expand Down Expand Up @@ -253,6 +256,7 @@ def __new__(cls, name, variant_test_name=""):
# AllowedOpInfoEntry('matmul'), // failing on CPU
# AllowedOpInfoEntry('__rmatmul__'), // failing on CPU
# AllowedOpInfoEntry('linalg.eigvals'), // failing on TPU
# AllowedOpInfoEntry('nanquantile'), // TODO: retried at head once xlogy pr merged
# AllowedOpInfoEntry('amax'),
# AllowedOpInfoEntry('amin'),
# AllowedOpInfoEntry('norm', 'nuc'),
Expand Down Expand Up @@ -293,10 +297,6 @@ def __new__(cls, name, variant_test_name=""):
# AllowedOpInfoEntry('linalg.norm'),
# AllowedOpInfoEntry('linalg.matrix_norm'),
# AllowedOpInfoEntry('linalg.vector_norm'),
# AllowedOpInfoEntry('log'),
# AllowedOpInfoEntry('log10'),
# AllowedOpInfoEntry('log1p'),
# AllowedOpInfoEntry('log2'),
# AllowedOpInfoEntry('std_mean'),
# AllowedOpInfoEntry('sum'),
# AllowedOpInfoEntry('mean'),
Expand Down
18 changes: 7 additions & 11 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1693,17 +1693,6 @@ at::Tensor XLANativeFunctions::kl_div(const at::Tensor& self,
return at::native::kl_div(self, target, reduction, log_target);
}

at::Tensor XLANativeFunctions::kl_div_backward(const at::Tensor& grad_output,
const at::Tensor& self,
const at::Tensor& target,
int64_t reduction,
bool log_target) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::kl_div_backward(
bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self),
bridge::GetXlaTensor(target), reduction, log_target));
}

std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::kthvalue(
const at::Tensor& self, int64_t k, int64_t dim, bool keepdim) {
XLA_FN_COUNTER("xla::");
Expand Down Expand Up @@ -1843,6 +1832,13 @@ at::Tensor XLANativeFunctions::logdet(const at::Tensor& self) {
XLATensor::logdet(bridge::GetXlaTensor(self)));
}

at::Tensor XLANativeFunctions::xlogy(const at::Tensor& self,
const at::Tensor& other) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::xlogy(
bridge::GetXlaTensor(self), bridge::GetXlaTensor(other)));
}

at::Tensor XLANativeFunctions::lt(const at::Tensor& self,
const at::Scalar& other) {
XLA_FN_COUNTER("xla::");
Expand Down
20 changes: 20 additions & 0 deletions torch_xla/csrc/ops/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,26 @@ NodePtr LogicalOr(const Value& input, const Value& other) {
std::move(lower_fn));
}

NodePtr XLogY(const Value& input, const Value& other) {
auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
xla::XlaOp xla_other = loctx->GetOutputOp(node.operand(1));
xla::XlaOp xla_output = BuildXLogY(xla_input, xla_other);
return node.ReturnOp(xla_output, loctx);
};
auto lower_for_shape_fn =
[](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
XLA_CHECK_EQ(operands.size(), 2) << "Unexpected number of operands";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we not have the same error check in lower_fn?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can add it

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually lower and shape fn share the same operand. Shape function is being run first so no need to redo the check in lower fn.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, that statement is not true..shape fn and lower fn uses different parameter. I guess we don't really have a good reason for only doing it in shape fn. I check the history and it seems like it is being added for one shape fn and then in the later code we just copy the code for other ops hence inherit this check.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I prefer this leave this check here or I can delete it. If we decide to add operand check for lowering function, we should do that for all ops. It is currently only being done for shape_fn sometimes which is very inconsistent.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sg. let's leave it out.

return BuildXLogY(operands[0], operands[1]);
};
return GenericOp(OpKind(at::aten::xlogy), {input, other},
[&]() {
return InferOutputShape({input.shape(), other.shape()},
lower_for_shape_fn);
},
std::move(lower_fn));
}

NodePtr NanToNum(const Value& input, const Value& nan, const Value& posinf,
const Value& neginf) {
auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
Expand Down
2 changes: 2 additions & 0 deletions torch_xla/csrc/ops/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ NodePtr LogicalAnd(const Value& input, const Value& other);

NodePtr LogicalOr(const Value& input, const Value& other);

NodePtr XLogY(const Value& input, const Value& other);

NodePtr NanToNum(const Value& input, const Value& nan, const Value& posinf,
const Value& neginf);

Expand Down
2 changes: 2 additions & 0 deletions torch_xla/csrc/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,8 @@ class XLATensor {
std::vector<xla::int64_t> dimensions,
bool keep_reduced_dimensions);

static XLATensor xlogy(const XLATensor& input, const XLATensor& other);

static XLATensor lt(const XLATensor& input, const at::Scalar& other);

static XLATensor lt(const XLATensor& input, const XLATensor& other);
Expand Down
47 changes: 34 additions & 13 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,7 @@ ir::Value GetIrValueOrDefault(const XLATensor& input,
ir::Value GetFloatingIrValue(const XLATensor& input,
at::ScalarType float_type) {
ir::Value input_value = input.GetIrValue();
if (!xla::primitive_util::IsFloatingPointType(
input_value.shape().element_type())) {
if (xla::primitive_util::IsIntegralType(input_value.shape().element_type())) {
input_value = ir::MakeNode<ir::ops::Cast>(input_value, float_type);
}
return input_value;
Expand Down Expand Up @@ -1449,14 +1448,6 @@ XLATensor XLATensor::isnan(const XLATensor& input) {
at::ScalarType::Bool);
}

XLATensor XLATensor::kl_div_backward(const XLATensor& grad_output,
const XLATensor& input,
const XLATensor& target,
xla::int64_t reduction, bool log_target) {
return tensor_ops::KlDivBackward(grad_output, input, target,
GetXlaReductionMode(reduction), log_target);
}

std::tuple<XLATensor, XLATensor> XLATensor::kthvalue(const XLATensor& input,
xla::int64_t k,
xla::int64_t dim,
Expand Down Expand Up @@ -1552,12 +1543,25 @@ XLATensor XLATensor::lerp(const XLATensor& input, const XLATensor& end,
}

XLATensor XLATensor::log(const XLATensor& input) {
return input.CreateFrom(ir::ops::Log(input.GetIrValue()));
// Here we explictly pass c10::nullopt as logical_element_type because
// otherwise result will inherit the input's logical_element_type. In the
// case of log(int) -> float, we want to derive the dtype from IR value
// instead of input's logical_element_type.
return input.CreateFrom(
ir::ops::Log(GetFloatingIrValue(input, at::ScalarType::Float)),
c10::nullopt);
}

XLATensor XLATensor::log_base(const XLATensor& input, ir::OpKind op,
double base) {
return input.CreateFrom(ir::ops::LogBase(input.GetIrValue(), op, base));
// Here we explictly pass c10::nullopt as logical_element_type because
// otherwise result will inherit the input's logical_element_type. In the
// case of logbase(int) -> float, we want to derive the dtype from IR value
// instead of input's logical_element_type.
return input.CreateFrom(
ir::ops::LogBase(GetFloatingIrValue(input, at::ScalarType::Float), op,
base),
c10::nullopt);
}

XLATensor XLATensor::log_sigmoid(const XLATensor& input) {
Expand Down Expand Up @@ -1599,7 +1603,13 @@ XLATensor XLATensor::log_softmax_backward(const XLATensor& grad_output,
}

XLATensor XLATensor::log1p(const XLATensor& input) {
return input.CreateFrom(ir::ops::Log1p(input.GetIrValue()));
// Here we explictly pass c10::nullopt as logical_element_type because
// otherwise result will inherit the input's logical_element_type. In the
// case of log1p(int) -> float, we want to derive the dtype from IR value
// instead of input's logical_element_type.
return input.CreateFrom(
ir::ops::Log1p(GetFloatingIrValue(input, at::ScalarType::Float)),
c10::nullopt);
}

void XLATensor::log1p_(XLATensor& input) {
Expand Down Expand Up @@ -1646,6 +1656,17 @@ XLATensor XLATensor::logsumexp(const XLATensor& input,
keep_reduced_dimensions));
}

XLATensor XLATensor::xlogy(const XLATensor& input, const XLATensor& other) {
// Here we explictly pass c10::nullopt as logical_element_type because
// otherwise result will inherit the input's logical_element_type. In the
// case of xlogy(int,int) -> float, we want to derive the dtype from IR value
// instead of input's logical_element_type.
return input.CreateFrom(
ir::ops::XLogY(input.GetIrValue(),
GetFloatingIrValue(other, at::ScalarType::Float)),
c10::nullopt);
}

XLATensor XLATensor::lt(const XLATensor& input, const at::Scalar& other) {
return DispatchComparisonOp(at::aten::lt, input, other);
}
Expand Down
25 changes: 0 additions & 25 deletions torch_xla/csrc/tensor_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,31 +59,6 @@ XLATensor Cross(const XLATensor& input, const XLATensor& other,
return XLATensor::stack({s1, s2, s3}, canonical_dim);
}

XLATensor KlDivBackward(const XLATensor& grad_output, const XLATensor& input,
const XLATensor& target, ReductionMode reduction,
bool log_target) {
auto input_shape_ref = input.shape();
XLATensor expanded_grad_output = XLATensor::expand(
grad_output,
xla::util::ToVector<xla::int64_t>(input_shape_ref.get().dimensions()));
XLATensor grad_input;
if (!log_target) {
grad_input = XLATensor::where(
XLATensor::gt(target, 0),
XLATensor::neg(XLATensor::mul(target, expanded_grad_output)),
XLATensor::full_like(input, 0, input.GetDevice(), c10::nullopt));
} else {
grad_input = XLATensor::neg(
XLATensor::mul(XLATensor::exp(target), expanded_grad_output));
}
if (reduction == ReductionMode::kMean) {
XLATensor dims_size = XLATensor::get_dimensions_size(
input, XlaHelpers::GetAllDimensions(input_shape_ref));
grad_input = XLATensor::div(grad_input, dims_size);
}
return grad_input;
}

XLATensor MakeMatrixWithDiagonal(const XLATensor& input,
xla::int64_t diagonal) {
xla::int64_t size = input.shape().get().dimensions(0);
Expand Down
4 changes: 0 additions & 4 deletions torch_xla/csrc/tensor_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@ namespace tensor_ops {
XLATensor Cross(const XLATensor& input, const XLATensor& other,
c10::optional<xla::int64_t> dim);

XLATensor KlDivBackward(const XLATensor& grad_output, const XLATensor& input,
const XLATensor& target, ReductionMode reduction,
bool log_target);

XLATensor MakeMatrixWithDiagonal(const XLATensor& input, xla::int64_t diagonal);

XLATensor SmoothL1Loss(const XLATensor& input, const XLATensor& target,
Expand Down
17 changes: 17 additions & 0 deletions torch_xla/csrc/xla_lower_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -866,4 +866,21 @@ std::vector<xla::XlaOp> BuildSgdOptimizerStep(
return results;
}

xla::XlaOp BuildXLogY(xla::XlaOp input, xla::XlaOp other) {
// input and xla::Log(other) can have different types, need to promote
// the multiply.
xla::XlaOp res = XlaHelpers::PromotedMul(input, xla::Log(other));
const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
const xla::Shape& res_shape = XlaHelpers::ShapeOfXlaOp(res);
xla::XlaOp zero = xla::Zero(input.builder(), input_shape.element_type());
xla::XlaOp zeros = xla::ZerosLike(res);
// expand the input and other to the result shape to filter the result.
input = BuildExpand(input, res_shape.dimensions());
other = BuildExpand(other, res_shape.dimensions());
res = xla::Select(xla::Eq(input, zero), zeros, res);
// nan replacement must happen after zero replacement
res = xla::Select(xla::IsNan(other), other, res);
return res;
}

} // namespace torch_xla
2 changes: 2 additions & 0 deletions torch_xla/csrc/xla_lower_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,6 @@ std::vector<xla::XlaOp> BuildSgdOptimizerStep(
const xla::XlaOp& lr, const xla::XlaOp& dampening, bool use_weight_decay,
bool use_momentum, bool use_nesterov);

xla::XlaOp BuildXLogY(xla::XlaOp input, xla::XlaOp other);

} // namespace torch_xla
2 changes: 1 addition & 1 deletion xla_native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ supported:
- inverse
- isnan
- kl_div
- kl_div_backward
- kthvalue
- l1_loss
- l1_loss_backward
Expand Down Expand Up @@ -325,6 +324,7 @@ supported:
- var.dim
- var_mean.correction
- view
- xlogy.Tensor
- zero_
autograd:
- max_pool2d
Expand Down