diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index 162d496fc597..dfecb8a83b7b 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -1099,6 +1099,18 @@ TEST_F(AtenXlaTensorTest, TestArgMinSameValue) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestArgMinWrapper) {
+  at::Tensor a = at::rand({4, 4, 4}, at::TensorOptions(at::kFloat));
+  for (int dim : {1, -2}) {
+    at::Tensor b = at::_argmin(a, dim, /*keepdim=*/false);
+    ForEachDevice([&](const Device& device) {
+      at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
+      at::Tensor xla_b = at::_argmin(xla_a, dim, /*keepdim=*/false);
+      AllClose(b, xla_b);
+    });
+  }
+}
+
 TEST_F(AtenXlaTensorTest, TestArgMax) {
   at::Tensor a = at::rand({4, 4, 4}, at::TensorOptions(at::kFloat));
   at::Tensor b = at::argmax(a);
@@ -1143,6 +1155,18 @@ TEST_F(AtenXlaTensorTest, TestArgMaxSameValue) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestArgMaxWrapper) {
+  at::Tensor a = at::rand({4, 4, 4}, at::TensorOptions(at::kFloat));
+  for (int dim : {1, -2}) {
+    at::Tensor b = at::_argmax(a, dim, /*keepdim=*/false);
+    ForEachDevice([&](const Device& device) {
+      at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
+      at::Tensor xla_b = at::_argmax(xla_a, dim, /*keepdim=*/false);
+      AllClose(b, xla_b);
+    });
+  }
+}
+
 TEST_F(AtenXlaTensorTest, TestAsin) {
   at::Tensor a = at::rand({2, 2}, at::TensorOptions(at::kFloat));
   at::Tensor b = at::asin(a);
@@ -1497,6 +1521,18 @@ TEST_F(AtenXlaTensorTest, TestMatmulBcast) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestDot) {
+  at::Tensor a = at::rand({4}, at::TensorOptions(at::kFloat));
+  at::Tensor b = at::rand({4}, at::TensorOptions(at::kFloat));
+  at::Tensor c = at::dot(a, b);
+  ForEachDevice([&](const Device& device) {
+    at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
+    at::Tensor xla_b = bridge::CreateXlaTensor(b, device);
+    at::Tensor xla_c = at::dot(xla_a, xla_b);
+    AllClose(c, xla_c);
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestBatchMatMul) {
   at::Tensor a = at::rand({3, 6, 4}, at::TensorOptions(at::kFloat));
   at::Tensor b = at::rand({3, 4, 5}, at::TensorOptions(at::kFloat));
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 1db254dc8b9a..039a16b8c206 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -1401,6 +1401,15 @@ at::Tensor AtenXlaType::bmm(const at::Tensor& self,
       XLATensor::bmm(bridge::GetXlaTensor(self), bridge::GetXlaTensor(mat2)));
 }
 
+at::Tensor AtenXlaType::dot(const at::Tensor& self,
+                            const at::Tensor& tensor) const {
+  XLA_CHECK_EQ(self.dim(), 1)
+      << "dot: Expected 1-D argument self, but got " << self.dim() << "-D";
+  XLA_CHECK_EQ(tensor.dim(), 1)
+      << "dot: Expected 1-D argument tensor, but got " << tensor.dim() << "-D";
+  return matmul(self, tensor);
+}
+
 std::vector<at::Tensor> AtenXlaType::broadcast_tensors(
     at::TensorList tensors) const {
   return bridge::AtenFromXlaTensors(
@@ -1787,6 +1796,11 @@ at::Tensor AtenXlaType::argmax(const at::Tensor& self, int64_t dim,
       XLATensor::argmax(bridge::GetXlaTensor(self), dim, keepdim));
 }
 
+at::Tensor AtenXlaType::_argmax(const at::Tensor& self, int64_t dim,
+                                bool keepdim) const {
+  return at::native::_argmax(self, dim, keepdim);
+}
+
 at::Tensor AtenXlaType::argmax(const at::Tensor& self) const {
   return bridge::AtenFromXlaTensor(
       XLATensor::argmax(bridge::GetXlaTensor(self)));
@@ -1798,6 +1812,11 @@ at::Tensor AtenXlaType::argmin(const at::Tensor& self, int64_t dim,
       XLATensor::argmin(bridge::GetXlaTensor(self), dim, keepdim));
 }
 
+at::Tensor AtenXlaType::_argmin(const at::Tensor& self, int64_t dim,
+                                bool keepdim) const {
+  return at::native::_argmin(self, dim, keepdim);
+}
+
 at::Tensor AtenXlaType::argmin(const at::Tensor& self) const {
   return bridge::AtenFromXlaTensor(
       XLATensor::argmin(bridge::GetXlaTensor(self)));
diff --git a/torch_xla/csrc/aten_xla_type.h b/torch_xla/csrc/aten_xla_type.h
index 594de8824e97..1bd7a060f1df 100644
--- a/torch_xla/csrc/aten_xla_type.h
+++ b/torch_xla/csrc/aten_xla_type.h
@@ -429,6 +429,9 @@ class AtenXlaType : public AtenXlaTypeBase {
   at::Tensor bmm(const at::Tensor& self,
                  const at::Tensor& mat2) const override;
 
+  at::Tensor dot(const at::Tensor& self,
+                 const at::Tensor& tensor) const override;
+
   std::vector<at::Tensor> broadcast_tensors(
       at::TensorList tensors) const override;
 
@@ -562,10 +565,14 @@ class AtenXlaType : public AtenXlaTypeBase {
   at::Tensor argmax(const at::Tensor& self, int64_t dim,
                     bool keepdim) const override;
   at::Tensor argmax(const at::Tensor& self) const override;
+  at::Tensor _argmax(const at::Tensor& self, int64_t dim,
+                     bool keepdim) const override;
 
   at::Tensor argmin(const at::Tensor& self, int64_t dim,
                     bool keepdim) const override;
   at::Tensor argmin(const at::Tensor& self) const override;
+  at::Tensor _argmin(const at::Tensor& self, int64_t dim,
+                     bool keepdim) const override;
 
   std::tuple<at::Tensor, at::Tensor, at::Tensor> native_batch_norm_backward(
      const at::Tensor& grad_out, const at::Tensor& input,
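Aside, not part of the patch: the new AtenXlaType::dot override forwards to matmul after verifying that both arguments are 1-D, relying on the fact that at::matmul of two 1-D tensors already computes their inner product. A minimal standalone sketch of that equivalence against the public ATen API, assuming a standard libtorch build:

#include <iostream>
#include <torch/torch.h>

int main() {
  // Two 1-D float tensors, matching the shapes used by TestDot above.
  at::Tensor a = at::rand({4}, at::TensorOptions(at::kFloat));
  at::Tensor b = at::rand({4}, at::TensorOptions(at::kFloat));
  // at::dot and at::matmul agree on 1-D inputs; both return a 0-D tensor.
  at::Tensor via_dot = at::dot(a, b);
  at::Tensor via_matmul = at::matmul(a, b);
  std::cout << via_dot.item<float>() << " == " << via_matmul.item<float>()
            << std::endl;
  return 0;
}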