Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1487,6 +1487,32 @@ TEST_F(AtenXlaTensorTest, TestCumSumCast) {
}
}

// Checks that XLA cumsum on kLong tensors agrees with the ATen reference
// along every dimension, including negative dimension indices.
TEST_F(AtenXlaTensorTest, TestCumSumLong) {
  at::Tensor t = at::randint(1000, {4, 3, 4}, at::TensorOptions(at::kLong));
  int ndims = t.dim();
  for (int d = -ndims; d < ndims; ++d) {
    at::Tensor expected = at::cumsum(t, d);
    ForEachDevice([&](const Device& device) {
      at::Tensor dev_input = bridge::CreateXlaTensor(t, device);
      at::Tensor actual = at::cumsum(dev_input, d);
      AllClose(expected, actual);
    });
  }
}

// Checks that XLA cumsum with an explicit kLong result dtype on a float
// input produces exactly the same values as the ATen reference.
TEST_F(AtenXlaTensorTest, TestCumSumCastLong) {
  at::Tensor t = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
  int ndims = t.dim();
  for (int d = -ndims; d < ndims; ++d) {
    at::Tensor expected = at::cumsum(t, d, at::kLong);
    ForEachDevice([&](const Device& device) {
      at::Tensor dev_input = bridge::CreateXlaTensor(t, device);
      at::Tensor actual = at::cumsum(dev_input, d, at::kLong);
      // Integer results must match exactly, not approximately.
      EXPECT_TRUE(EqualValues(expected, actual));
    });
  }
}

TEST_F(AtenXlaTensorTest, TestCumProd) {
at::Tensor input = at::rand({4, 3, 4}, at::TensorOptions(at::kFloat));
int rank = input.dim();
Expand Down Expand Up @@ -1514,6 +1540,32 @@ TEST_F(AtenXlaTensorTest, TestCumProdCast) {
}
}

// Checks that XLA cumprod on kLong tensors agrees with the ATen reference
// along every dimension, including negative dimension indices.
// BUG FIX: the original test called at::cumsum, so the cumprod long path
// was never exercised despite the test's name.
TEST_F(AtenXlaTensorTest, TestCumProdLong) {
  // Small values/shape keep the running product within int64 range.
  at::Tensor input = at::randint(7, {2, 3}, at::TensorOptions(at::kLong));
  int rank = input.dim();
  for (int dim = -rank; dim < rank; ++dim) {
    at::Tensor result = at::cumprod(input, dim);
    ForEachDevice([&](const Device& device) {
      at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
      at::Tensor xla_result = at::cumprod(xla_input, dim);
      AllClose(result, xla_result);
    });
  }
}

// Checks that XLA cumprod with an explicit kLong result dtype on a float
// input produces exactly the same values as the ATen reference.
// BUG FIX: the original test called at::cumsum, so the cumprod cast path
// was never exercised despite the test's name.
TEST_F(AtenXlaTensorTest, TestCumProdCastLong) {
  // Scale to [0, 7) so truncation to long yields varied small factors.
  at::Tensor input = at::rand({2, 3}, at::TensorOptions(at::kFloat)) * 7;
  int rank = input.dim();
  for (int dim = -rank; dim < rank; ++dim) {
    at::Tensor result = at::cumprod(input, dim, at::kLong);
    ForEachDevice([&](const Device& device) {
      at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
      at::Tensor xla_result = at::cumprod(xla_input, dim, at::kLong);
      // Integer results must match exactly, not approximately.
      EXPECT_TRUE(EqualValues(result, xla_result));
    });
  }
}

TEST_F(AtenXlaTensorTest, TestArgMin) {
at::Tensor a = at::rand({4, 4, 4}, at::TensorOptions(at::kFloat));
at::Tensor b = at::argmin(a, c10::nullopt, /*keepdim=*/false);
Expand Down
26 changes: 20 additions & 6 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -945,24 +945,38 @@ at::Tensor AtenXlaType::cross(const at::Tensor& self, const at::Tensor& other,

// cumprod with an explicit result dtype. Falls back to the ATen reference
// implementation when the target device cannot represent `dtype` natively
// (e.g. 64-bit integers on TPU), otherwise lowers through XLATensor.
// NOTE(review): this span is a diff rendering that merged the pre-change
// return statement with the new body; the stale duplicated lines are removed.
at::Tensor AtenXlaType::cumprod(const at::Tensor& self, int64_t dim,
                                at::ScalarType dtype) const {
  XLATensor self_tensor = bridge::GetXlaTensor(self);
  if (!HasNativeSupport(dtype, self_tensor.GetDevice())) {
    return AtenXlaTypeBase::cumprod(self, dim, dtype);
  }
  return bridge::AtenFromXlaTensor(XLATensor::cumprod(self_tensor, dim, dtype));
}

// cumprod without an explicit dtype: the result keeps the input's dtype.
// Falls back to the ATen reference implementation when that dtype has no
// native device support; otherwise lowers through XLATensor.
// NOTE(review): removed a stale pre-change continuation line left over from
// the diff rendering of this hunk.
at::Tensor AtenXlaType::cumprod(const at::Tensor& self, int64_t dim) const {
  XLATensor self_tensor = bridge::GetXlaTensor(self);
  if (!HasNativeSupport(self_tensor.dtype(), self_tensor.GetDevice())) {
    return AtenXlaTypeBase::cumprod(self, dim);
  }
  return bridge::AtenFromXlaTensor(
      XLATensor::cumprod(self_tensor, dim, c10::nullopt));
}

// cumsum with an explicit result dtype. Falls back to the ATen reference
// implementation when the target device cannot represent `dtype` natively
// (e.g. 64-bit integers on TPU), otherwise lowers through XLATensor.
// NOTE(review): this span is a diff rendering that merged the pre-change
// return statement with the new body; the stale duplicated lines are removed.
at::Tensor AtenXlaType::cumsum(const at::Tensor& self, int64_t dim,
                               at::ScalarType dtype) const {
  XLATensor self_tensor = bridge::GetXlaTensor(self);
  if (!HasNativeSupport(dtype, self_tensor.GetDevice())) {
    return AtenXlaTypeBase::cumsum(self, dim, dtype);
  }
  return bridge::AtenFromXlaTensor(XLATensor::cumsum(self_tensor, dim, dtype));
}

// cumsum without an explicit dtype: the result keeps the input's dtype.
// Falls back to the ATen reference implementation when that dtype has no
// native device support; otherwise lowers through XLATensor.
// NOTE(review): removed a stale pre-change continuation line left over from
// the diff rendering of this hunk.
at::Tensor AtenXlaType::cumsum(const at::Tensor& self, int64_t dim) const {
  XLATensor self_tensor = bridge::GetXlaTensor(self);
  if (!HasNativeSupport(self_tensor.dtype(), self_tensor.GetDevice())) {
    return AtenXlaTypeBase::cumsum(self, dim);
  }
  return bridge::AtenFromXlaTensor(
      XLATensor::cumsum(self_tensor, dim, c10::nullopt));
}

at::Tensor AtenXlaType::diag(const at::Tensor& self, int64_t diagonal) const {
Expand Down
5 changes: 5 additions & 0 deletions torch_xla/csrc/tensor_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,9 @@ xla::PrimitiveType GetDevicePrimitiveType(xla::PrimitiveType type,
xla::PrimitiveType MakeXlaPrimitiveType(at::ScalarType scalar_type,
const Device* device);

// Returns true iff `type` survives a round trip through the device's XLA
// primitive-type mapping unchanged, i.e. the device represents it natively
// rather than via a narrowing substitution.
inline bool HasNativeSupport(at::ScalarType type, const Device& device) {
  const xla::PrimitiveType device_type = MakeXlaPrimitiveType(type, &device);
  return TensorTypeFromXlaType(device_type) == type;
}

} // namespace torch_xla