diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 9e30d14c5602..6843d3e6b29d 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -489,6 +489,11 @@ at::Tensor AtenXlaType::fmod(const at::Tensor& self, XLATensor::fmod(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other))); } +at::Tensor AtenXlaType::fmod(const at::Tensor& self, at::Scalar other) const { + return bridge::AtenFromXlaTensor( + XLATensor::fmod(bridge::GetXlaTensor(self), other)); +} + at::Tensor& AtenXlaType::fmod_(at::Tensor& self, at::Scalar other) const { XLATensor self_tensor = bridge::GetXlaTensor(self); XLATensor::fmod_(self_tensor, other); @@ -723,8 +728,8 @@ at::Tensor& AtenXlaType::atan_(at::Tensor& self) const { at::Tensor AtenXlaType::atan2(const at::Tensor& self, const at::Tensor& other) const { - return bridge::AtenFromXlaTensor( - XLATensor::atan2(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other))); + return bridge::AtenFromXlaTensor(XLATensor::atan2( + bridge::GetXlaTensor(self), bridge::GetXlaTensor(other))); } at::Tensor& AtenXlaType::atan2_(at::Tensor& self, diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h index 949559c60489..283cc2493674 100644 --- a/torch_xla/csrc/aten_xla_type.h +++ b/torch_xla/csrc/aten_xla_type.h @@ -148,6 +148,7 @@ class AtenXlaType : public AtenXlaTypeBase { at::Tensor fmod(const at::Tensor& self, const at::Tensor& other) const override; + at::Tensor fmod(const at::Tensor& self, at::Scalar other) const override; at::Tensor& fmod_(at::Tensor& self, at::Scalar other) const override; at::Tensor& fmod_(at::Tensor& self, const at::Tensor& other) const override; @@ -224,7 +225,8 @@ class AtenXlaType : public AtenXlaTypeBase { at::Tensor atan(const at::Tensor& self) const override; at::Tensor& atan_(at::Tensor& self) const override; - at::Tensor atan2(const at::Tensor& self, const at::Tensor& other) const override; + at::Tensor atan2(const at::Tensor& self, + const at::Tensor& other) const override; at::Tensor& atan2_(at::Tensor& self, const at::Tensor& other) const override; at::Tensor tan(const at::Tensor& self) const override; diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 5cc6ab2b1528..59aab926895b 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -609,6 +609,11 @@ XLATensor XLATensor::fmod(const XLATensor& input, const XLATensor& other) { ir::ops::Fmod(input.GetIrValue(), other.GetIrValue())); } +XLATensor XLATensor::fmod(const XLATensor& input, at::Scalar other) { + ir::NodePtr constant = ir::ops::ScalarOp(other, input.shape()); + return input.CreateFrom(ir::ops::Fmod(input.GetIrValue(), constant)); +} + void XLATensor::fmod_(XLATensor& input, at::Scalar other) { ir::NodePtr constant = ir::ops::ScalarOp(other, input.shape()); input.SetIrValue(ir::ops::Fmod(input.GetIrValue(), constant)); @@ -1003,7 +1008,8 @@ void XLATensor::atan_(XLATensor& input) { } XLATensor XLATensor::atan2(const XLATensor& input, const XLATensor& other) { - return input.CreateFrom(ir::ops::Atan2(input.GetIrValue(), other.GetIrValue())); + return input.CreateFrom( + ir::ops::Atan2(input.GetIrValue(), other.GetIrValue())); } void XLATensor::atan2_(XLATensor& input, const XLATensor& other) { diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index a617be43ab87..f125d0520c0f 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -118,6 +118,7 @@ class XLATensor { static void div_(XLATensor& input, const at::Scalar& other); static XLATensor fmod(const XLATensor& input, const XLATensor& other); + static XLATensor fmod(const XLATensor& input, at::Scalar other); static void fmod_(XLATensor& input, at::Scalar other); static void fmod_(XLATensor& input, const XLATensor& other);