Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,11 @@ at::Tensor AtenXlaType::fmod(const at::Tensor& self,
XLATensor::fmod(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other)));
}

at::Tensor AtenXlaType::fmod(const at::Tensor& self, at::Scalar other) const {
return bridge::AtenFromXlaTensor(
XLATensor::fmod(bridge::GetXlaTensor(self), other));
}

at::Tensor& AtenXlaType::fmod_(at::Tensor& self, at::Scalar other) const {
XLATensor self_tensor = bridge::GetXlaTensor(self);
XLATensor::fmod_(self_tensor, other);
Expand Down Expand Up @@ -723,8 +728,8 @@ at::Tensor& AtenXlaType::atan_(at::Tensor& self) const {

at::Tensor AtenXlaType::atan2(const at::Tensor& self,
const at::Tensor& other) const {
return bridge::AtenFromXlaTensor(
XLATensor::atan2(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other)));
return bridge::AtenFromXlaTensor(XLATensor::atan2(
bridge::GetXlaTensor(self), bridge::GetXlaTensor(other)));
}

at::Tensor& AtenXlaType::atan2_(at::Tensor& self,
Expand Down
4 changes: 3 additions & 1 deletion torch_xla/csrc/aten_xla_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ class AtenXlaType : public AtenXlaTypeBase {

at::Tensor fmod(const at::Tensor& self,
const at::Tensor& other) const override;
at::Tensor fmod(const at::Tensor& self, at::Scalar other) const override;
at::Tensor& fmod_(at::Tensor& self, at::Scalar other) const override;
at::Tensor& fmod_(at::Tensor& self, const at::Tensor& other) const override;

Expand Down Expand Up @@ -224,7 +225,8 @@ class AtenXlaType : public AtenXlaTypeBase {
at::Tensor atan(const at::Tensor& self) const override;
at::Tensor& atan_(at::Tensor& self) const override;

at::Tensor atan2(const at::Tensor& self, const at::Tensor& other) const override;
at::Tensor atan2(const at::Tensor& self,
const at::Tensor& other) const override;
at::Tensor& atan2_(at::Tensor& self, const at::Tensor& other) const override;

at::Tensor tan(const at::Tensor& self) const override;
Expand Down
8 changes: 7 additions & 1 deletion torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,11 @@ XLATensor XLATensor::fmod(const XLATensor& input, const XLATensor& other) {
ir::ops::Fmod(input.GetIrValue(), other.GetIrValue()));
}

XLATensor XLATensor::fmod(const XLATensor& input, at::Scalar other) {
ir::NodePtr constant = ir::ops::ScalarOp(other, input.shape());
return input.CreateFrom(ir::ops::Fmod(input.GetIrValue(), constant));
}

void XLATensor::fmod_(XLATensor& input, at::Scalar other) {
ir::NodePtr constant = ir::ops::ScalarOp(other, input.shape());
input.SetIrValue(ir::ops::Fmod(input.GetIrValue(), constant));
Expand Down Expand Up @@ -1003,7 +1008,8 @@ void XLATensor::atan_(XLATensor& input) {
}

XLATensor XLATensor::atan2(const XLATensor& input, const XLATensor& other) {
return input.CreateFrom(ir::ops::Atan2(input.GetIrValue(), other.GetIrValue()));
return input.CreateFrom(
ir::ops::Atan2(input.GetIrValue(), other.GetIrValue()));
}

void XLATensor::atan2_(XLATensor& input, const XLATensor& other) {
Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class XLATensor {
static void div_(XLATensor& input, const at::Scalar& other);

static XLATensor fmod(const XLATensor& input, const XLATensor& other);
static XLATensor fmod(const XLATensor& input, at::Scalar other);
static void fmod_(XLATensor& input, at::Scalar other);
static void fmod_(XLATensor& input, const XLATensor& other);

Expand Down