Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,16 @@ TEST_F(AtenXlaTensorTest, TestPow) {
});
}

TEST_F(AtenXlaTensorTest, TestFmod) {
at::Tensor a = at::rand({2, 2}, at::TensorOptions(at::kFloat)) * 100.0;
at::Tensor b = at::fmod(a, 2.0);
ForEachDevice([&](const Device& device) {
at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
at::Tensor xla_b = at::fmod(xla_a, 2.0);
AllClose(b, xla_b);
});
}

TEST_F(AtenXlaTensorTest, TestWhere) {
at::Tensor a = at::rand({3, 3}, at::TensorOptions(at::kFloat));
at::Tensor b = at::rand({3, 3}, at::TensorOptions(at::kFloat));
Expand Down
18 changes: 18 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,24 @@ at::Tensor& AtenXlaType::div_(at::Tensor& self, const at::Tensor& other) const {
return self;
}

at::Tensor AtenXlaType::fmod(const at::Tensor& self,
const at::Tensor& other) const {
return bridge::AtenFromXlaTensor(
XLATensor::fmod(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other)));
}

at::Tensor& AtenXlaType::fmod_(at::Tensor& self, at::Scalar other) const {
XLATensor self_tensor = bridge::GetXlaTensor(self);
XLATensor::fmod_(self_tensor, other);
return self;
}

at::Tensor& AtenXlaType::fmod_(at::Tensor& self, const at::Tensor& other) const {
XLATensor self_tensor = bridge::GetXlaTensor(self);
XLATensor::fmod_(self_tensor, bridge::GetXlaTensor(other));
return self;
}

at::Tensor AtenXlaType::ne(const at::Tensor& self, at::Scalar other) const {
return bridge::AtenFromXlaTensor(
XLATensor::ne(bridge::GetXlaTensor(self), other));
Expand Down
5 changes: 5 additions & 0 deletions torch_xla/csrc/aten_xla_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ class AtenXlaType : public AtenXlaTypeBase {
const at::Tensor& other) const override;
at::Tensor& div_(at::Tensor& self, const at::Tensor& other) const override;

at::Tensor fmod(const at::Tensor& self,
const at::Tensor& other) const override;
at::Tensor& fmod_(at::Tensor& self, at::Scalar other) const override;
at::Tensor& fmod_(at::Tensor& self, const at::Tensor& other) const override;

at::Tensor ne(const at::Tensor& self, at::Scalar other) const override;

at::Tensor ne(const at::Tensor& self, const at::Tensor& other) const override;
Expand Down
3 changes: 2 additions & 1 deletion torch_xla/csrc/ops/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ PTXLA_UNARY_OP(Floor, at::aten::floor, xla::Floor);
PTXLA_BINARY_OP(Min, at::aten::min, xla::Min);
PTXLA_BINARY_OP(Max, at::aten::max, xla::Max);
PTXLA_BINARY_OP(Pow, at::aten::pow, xla::Pow);
PTXLA_BINARY_OP(Fmod, at::aten::fmod, xla::Rem);

NodePtr ReciprocalOp(const Value& input) {
auto lower_fn = [](const ir::Node& node,
Expand All @@ -73,7 +74,7 @@ NodePtr ReciprocalOp(const Value& input) {
};
return ir::ops::GenericOp(ir::OpKind(at::aten::reciprocal), ir::OpList{input},
input.shape(), std::move(lower_fn));
}
}

NodePtr ReluOp(const Value& input) {
auto lower_fn = [](const ir::Node& node,
Expand Down
2 changes: 2 additions & 0 deletions torch_xla/csrc/ops/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ NodePtr ReciprocalOp(const Value& input);

NodePtr Pow(const Value& input, const Value& exponent);

NodePtr Fmod(const Value& dividend, const Value& divisor);

NodePtr Sigmoid(const Value& input);

NodePtr Clamp(const Value& input, c10::optional<at::Scalar> min,
Expand Down
14 changes: 14 additions & 0 deletions torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,20 @@ void XLATensor::div_(XLATensor& input, const at::Scalar& other) {
input.SetIrValue(input.GetIrValue() / constant);
}

XLATensor XLATensor::fmod(const XLATensor& input, const XLATensor& other) {
return Create(ir::ops::Fmod(input.GetIrValue(), other.GetIrValue()),
input.GetDevice());
}

void XLATensor::fmod_(XLATensor& input, at::Scalar other) {
ir::NodePtr constant = ir::ops::ScalarOp(other, input.shape());
input.SetIrValue(ir::ops::Fmod(input.GetIrValue(), constant));
}

void XLATensor::fmod_(XLATensor& input, const XLATensor& other) {
input.SetIrValue(ir::ops::Fmod(input.GetIrValue(), other.GetIrValue()));
}

void XLATensor::zero_(XLATensor& input) {
input.SetIrValue(ir::ops::ScalarOp(0.0, input.shape()));
}
Expand Down
6 changes: 5 additions & 1 deletion torch_xla/csrc/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ class XLATensor {
static void div_(XLATensor& input, const XLATensor& other);
static void div_(XLATensor& input, const at::Scalar& other);

static XLATensor fmod(const XLATensor& input, const XLATensor& other);
static void fmod_(XLATensor& input, at::Scalar other);
static void fmod_(XLATensor& input, const XLATensor& other);

static void zero_(XLATensor& input);

// Additional operations which are part of the PyTorch Tensor functionality.
Expand Down Expand Up @@ -403,7 +407,7 @@ class XLATensor {
const XLATensor& input,
tensorflow::gtl::ArraySlice<const xla::int64> repeats);

static std::vector<XLATensor> split(const XLATensor& self,
static std::vector<XLATensor> split(const XLATensor& input,
xla::int64 split_size, xla::int64 dim);

// Squeeze out all trivial (size 1) dimensions.
Expand Down