36 changes: 36 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
@@ -1099,6 +1099,18 @@ TEST_F(AtenXlaTensorTest, TestArgMinSameValue) {
});
}

TEST_F(AtenXlaTensorTest, TestArgMinWrapper) {
at::Tensor a = at::rand({4, 4, 4}, at::TensorOptions(at::kFloat));
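  // Cover both a positive dim and a negative dim that wraps to the same axis.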
for (int dim : {1, -2}) {
at::Tensor b = at::_argmin(a, dim, /*keepdim=*/false);
ForEachDevice([&](const Device& device) {
at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
at::Tensor xla_b = at::_argmin(xla_a, dim, /*keepdim=*/false);
AllClose(b, xla_b);
});
}
}

TEST_F(AtenXlaTensorTest, TestArgMax) {
at::Tensor a = at::rand({4, 4, 4}, at::TensorOptions(at::kFloat));
at::Tensor b = at::argmax(a);
@@ -1143,6 +1155,18 @@ TEST_F(AtenXlaTensorTest, TestArgMaxSameValue) {
});
}

TEST_F(AtenXlaTensorTest, TestArgMaxWrapper) {
at::Tensor a = at::rand({4, 4, 4}, at::TensorOptions(at::kFloat));
for (int dim : {1, -2}) {
at::Tensor b = at::_argmax(a, dim, /*keepdim=*/false);
ForEachDevice([&](const Device& device) {
at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
at::Tensor xla_b = at::_argmax(xla_a, dim, /*keepdim=*/false);
AllClose(b, xla_b);
});
}
}

TEST_F(AtenXlaTensorTest, TestAsin) {
at::Tensor a = at::rand({2, 2}, at::TensorOptions(at::kFloat));
at::Tensor b = at::asin(a);
@@ -1497,6 +1521,18 @@ TEST_F(AtenXlaTensorTest, TestMatmulBcast) {
});
}

TEST_F(AtenXlaTensorTest, TestDot) {
at::Tensor a = at::rand({4}, at::TensorOptions(at::kFloat));
at::Tensor b = at::rand({4}, at::TensorOptions(at::kFloat));
at::Tensor c = at::dot(a, b);
ForEachDevice([&](const Device& device) {
at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
at::Tensor xla_b = bridge::CreateXlaTensor(b, device);
at::Tensor xla_c = at::dot(xla_a, xla_b);
AllClose(c, xla_c);
});
}

TEST_F(AtenXlaTensorTest, TestBatchMatMul) {
at::Tensor a = at::rand({3, 6, 4}, at::TensorOptions(at::kFloat));
at::Tensor b = at::rand({3, 4, 5}, at::TensorOptions(at::kFloat));
19 changes: 19 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1401,6 +1401,15 @@ at::Tensor AtenXlaType::bmm(const at::Tensor& self,
XLATensor::bmm(bridge::GetXlaTensor(self), bridge::GetXlaTensor(mat2)));
}

at::Tensor AtenXlaType::dot(const at::Tensor& self,
const at::Tensor& tensor) const {
XLA_CHECK_EQ(self.dim(), 1)
<< "dot: Expected 1-D argument self, but got " << self.dim() << "-D";
XLA_CHECK_EQ(tensor.dim(), 1)
<< "dot: Expected 1-D argument tensor, but got " << tensor.dim() << "-D";
return matmul(self, tensor);
}

std::vector<at::Tensor> AtenXlaType::broadcast_tensors(
at::TensorList tensors) const {
return bridge::AtenFromXlaTensors(
@@ -1787,6 +1796,11 @@ at::Tensor AtenXlaType::argmax(const at::Tensor& self, int64_t dim,
XLATensor::argmax(bridge::GetXlaTensor(self), dim, keepdim));
}

at::Tensor AtenXlaType::_argmax(const at::Tensor& self, int64_t dim,
bool keepdim) const {
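  // Delegate the internal _argmax wrapper op to the reference ATen
  // implementation.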
return at::native::_argmax(self, dim, keepdim);
}

at::Tensor AtenXlaType::argmax(const at::Tensor& self) const {
return bridge::AtenFromXlaTensor(
XLATensor::argmax(bridge::GetXlaTensor(self)));
@@ -1798,6 +1812,11 @@ at::Tensor AtenXlaType::argmin(const at::Tensor& self, int64_t dim,
XLATensor::argmin(bridge::GetXlaTensor(self), dim, keepdim));
}

at::Tensor AtenXlaType::_argmin(const at::Tensor& self, int64_t dim,
bool keepdim) const {
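  // Delegate the internal _argmin wrapper op to the reference ATen
  // implementation.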
return at::native::_argmin(self, dim, keepdim);
}

at::Tensor AtenXlaType::argmin(const at::Tensor& self) const {
return bridge::AtenFromXlaTensor(
XLATensor::argmin(bridge::GetXlaTensor(self)));
7 changes: 7 additions & 0 deletions torch_xla/csrc/aten_xla_type.h
@@ -429,6 +429,9 @@ class AtenXlaType : public AtenXlaTypeBase {

at::Tensor bmm(const at::Tensor& self, const at::Tensor& mat2) const override;

at::Tensor dot(const at::Tensor& self,
const at::Tensor& tensor) const override;

std::vector<at::Tensor> broadcast_tensors(
at::TensorList tensors) const override;

@@ -562,10 +565,14 @@ at::Tensor argmax(const at::Tensor& self, int64_t dim,
at::Tensor argmax(const at::Tensor& self, int64_t dim,
bool keepdim) const override;
at::Tensor argmax(const at::Tensor& self) const override;
at::Tensor _argmax(const at::Tensor& self, int64_t dim,
bool keepdim) const override;

at::Tensor argmin(const at::Tensor& self, int64_t dim,
bool keepdim) const override;
at::Tensor argmin(const at::Tensor& self) const override;
at::Tensor _argmin(const at::Tensor& self, int64_t dim,
bool keepdim) const override;

std::tuple<at::Tensor, at::Tensor, at::Tensor> native_batch_norm_backward(
const at::Tensor& grad_out, const at::Tensor& input,