Sparse: Remove dispatch in parallel region #60598

Closed
wants to merge 2 commits into from
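
At a glance (a reading of the diff below, not an authoritative summary): the hot loops in these sparse CPU kernels previously invoked dispatching Tensor operations (`Tensor::stride`, indexing via `operator[]`, `Tensor::copy_`) for every non-zero element from inside `at::parallel_for`. The patch hoists the stride lookups out of the loops and replaces the per-element copies with a `TensorIterator` that is built once on the calling thread and reused by each worker, which only swaps raw data pointers via `unsafe_replace_operand` and runs the CPU copy kernel through `copy_stub`. A minimal sketch of that pattern, assuming ATen-internal headers are available; the `gather_rows_cpu` helper and tensor names are hypothetical:

```cpp
// Illustrative sketch (not part of the PR): copy selected rows of `src` into
// `dst` without dispatching inside the parallel region.
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/Copy.h>
#include <c10/util/irange.h>

namespace at { namespace native {

void gather_rows_cpu(Tensor& dst, const Tensor& src, const Tensor& index) {
  // Build one TensorIterator up front, on the calling thread. Declaring a
  // static shape (with the row dimension squashed) means the iterator never
  // re-inspects its operands later.
  const auto copy_iter = TensorIteratorConfig()
      .add_output(dst)
      .add_input(src)
      .resize_outputs(false)
      .declare_static_shape(dst.sizes(), /*squash_dims=*/0)
      .build();

  const auto idx = index.accessor<int64_t, 1>();
  const auto dst_data = reinterpret_cast<char*>(dst.data_ptr());
  const auto src_data = reinterpret_cast<char*>(src.data_ptr());
  const auto dst_stride = dst.strides()[0] * dst.element_size();
  const auto src_stride = src.strides()[0] * src.element_size();

  at::parallel_for(0, dst.size(0), 0, [&](int64_t start, int64_t end) {
    // Each worker gets a shallow copy of the pre-built iterator and only
    // swaps raw data pointers; no operator dispatch happens in this loop.
    TensorIterator iter_local(copy_iter);
    for (const auto i : c10::irange(start, end)) {
      iter_local.unsafe_replace_operand(0, dst_data + i * dst_stride);
      iter_local.unsafe_replace_operand(1, src_data + idx[i] * src_stride);
      copy_stub(kCPU, iter_local, /*non_blocking=*/false);
    }
  });
}

}} // namespace at::native
```

The important property is that `TensorIteratorConfig::build()` and the dispatcher are only touched on the calling thread; inside the parallel region each worker reuses its own copy of the iterator and reaches the CPU copy kernel directly through the `DispatchStub`.
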
1 change: 0 additions & 1 deletion aten/src/ATen/native/sparse/SparseMatMul.cpp
@@ -2,7 +2,6 @@
 #include <ATen/Config.h>
 #include <ATen/NamedTensorUtils.h>
 #include <ATen/NativeFunctions.h>
-#include <ATen/Parallel.h>
 #include <ATen/SparseTensorImpl.h>
 #include <ATen/SparseTensorUtils.h>
 #include <ATen/native/Resize.h>
25 changes: 23 additions & 2 deletions aten/src/ATen/native/sparse/SparseTensor.cpp
@@ -9,6 +9,7 @@
 #include <ATen/SparseTensorUtils.h>
 #include <ATen/native/IndexingUtils.h>
 
+#include <ATen/native/Copy.h>
 #include <ATen/native/CPUBlas.h>
 
 namespace at {
@@ -634,12 +635,13 @@ void inline sparse_mask_out_cpu_kernel(
   auto r_values_accessor = r_values.accessor<scalar_t, 1>();
   auto mask_indices_accessor = mask_indices.accessor<int64_t, 2>();
   scalar_t* t_ptr = t.data_ptr<scalar_t>();
+  auto t_strides = t.strides();
 
   at::parallel_for(0, r_nnz, 1000, [&](int64_t start, int64_t end) {
     for (auto i = start; i < end; i++) {
       int64_t idx = 0;
       for (int64_t d = 0; d < sparse_dim; d++) {
-        idx += mask_indices_accessor[d][i] * t.stride(d);
+        idx += mask_indices_accessor[d][i] * t_strides[d];
       }
       r_values_accessor[i] = t_ptr[idx];
     }
@@ -767,12 +769,31 @@ Tensor sparse_mask_helper_cpu(
 
   auto flattened_mask_indices =
       at::sparse::flatten_indices(mask_indices, full_size);
+
+  const auto copy_iter = TensorIteratorConfig()
+      .add_output(r_values)
+      .add_input(t_v)
+      .resize_outputs(false)
+      .declare_static_shape(r_values.sizes(), /*squash_dims=*/0)
+      .build();
+
   at::parallel_for(0, r_nnz, 0, [&](int64_t start, int64_t end) {
+    TensorIterator copy_iter_local(copy_iter);
+    const auto r_values_data = reinterpret_cast<char*>(r_values.data_ptr());
+    const auto t_values_data = reinterpret_cast<char*>(t_v.data_ptr());
+    const auto r_values_stride = r_values.strides()[0] * r_values.element_size();
+    const auto t_values_stride = t_v.strides()[0] * t_v.element_size();
+
     for (auto i = start; i < end; i++) {
       int64_t index = flattened_mask_indices.data_ptr<int64_t>()[i];
       auto iter = t_flatten_indices.find(index);
       if (iter != t_flatten_indices.end()) {
-        r_values[i] = t_v[iter->second];
+        // r_values[i].copy_(t_v[iter->second])
+        copy_iter_local.unsafe_replace_operand(
+            0, r_values_data + i * r_values_stride);
+        copy_iter_local.unsafe_replace_operand(
+            1, t_values_data + iter->second * t_values_stride);
+        copy_stub(kCPU, copy_iter_local, /*non_blocking=*/false);
       }
     }
   });
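
In `sparse_mask_helper_cpu` above, the removed `r_values[i] = t_v[iter->second];` effectively ran a dispatched indexing plus `copy_()` for every matched index, on every worker thread. The replacement builds the copy `TensorIterator` once outside `at::parallel_for`; each iteration then only swaps raw data pointers with `unsafe_replace_operand` and calls `copy_stub` directly, which is why `<ATen/native/Copy.h>` is now included. Likewise, `sparse_mask_out_cpu_kernel` caches `t.strides()` so the inner loop no longer calls `t.stride(d)` per element.
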
30 changes: 26 additions & 4 deletions aten/src/ATen/native/sparse/SparseTensorMath.cpp
@@ -11,6 +11,7 @@
 #include <ATen/SparseTensorUtils.h>
 #include <ATen/WrapDimUtilsMulti.h>
 #include <ATen/native/BinaryOps.h>
+#include <ATen/native/Copy.h>
 #include <ATen/native/CPUBlas.h>
 
 #include <algorithm>
@@ -645,13 +646,15 @@ void add_dense_sparse_worker_cpu(Tensor& r, const Scalar& value, const SparseTen
   auto values_accessor = values.accessor<scalar_t, 1>();
 
   scalar_t* r_ptr = r.data_ptr<scalar_t>();
+  auto r_strides = r.strides();
   scalar_t cast_value = value.to<scalar_t>();
+  const auto sparse_dim = sparse.sparse_dim();
 
   at::parallel_for(0, sparse._nnz(), 0, [&](int64_t start, int64_t end) {
     for (auto k: c10::irange(start, end)) {
       int64_t index = r.storage_offset();
-      for (auto d: c10::irange(sparse.sparse_dim())) {
-        index += r.stride(d) * indices_accessor[d][k];
+      for (auto d: c10::irange(sparse_dim)) {
+        index += r_strides[d] * indices_accessor[d][k];
       }
       r_ptr[index] += cast_value * values_accessor[k];
     }
@@ -1447,16 +1450,35 @@ Tensor _sparse_sum_backward_cpu(const Tensor& grad_, const SparseTensor& input_,
   auto input_indices_1D = flatten_indices_by_dims(input_indices, input_sizes, sparse_dims_to_keep_v);
   auto input_indices_1D_accessor = input_indices_1D.accessor<int64_t, 1>();
 
-  // binary search to find matching indices
+  const auto copy_iter = TensorIteratorConfig()
+      .add_output(grad_input_values)
+      .add_input(grad_values_expand)
+      .resize_outputs(false)
+      .declare_static_shape(grad_values_expand.sizes(), /*squash_dims=*/0)
+      .build();
+  const auto device_type = kCPU;
+
+  const auto gIv_data = reinterpret_cast<char*>(grad_input_values.data_ptr());
+  const auto gOv_data = reinterpret_cast<char*>(grad_values_expand.data_ptr());
+  const auto gIv_stride = (grad_input_values.strides()[0] *
+      grad_input_values.element_size());
+  const auto gOv_stride = (grad_values_expand.strides()[0] *
+      grad_values_expand.element_size());
+
+  // binary search to find matching indices
   at::parallel_for(0, input_nnz, 0, [&](int64_t start, int64_t end) {
+    TensorIterator copy_iter_local(copy_iter);
 
     for (auto i: c10::irange(start, end)) {
       int64_t input_idx = input_indices_1D_accessor[i];
       int64_t l = 0, r = grad_nnz - 1;
       while (l <= r) {
         int64_t m = l + (r - l) / 2;
         if (grad_indices_1D_accessor[m] == input_idx) {
-          grad_input_values[i].copy_(grad_values_expand[m]);
+          // grad_input_values[i].copy_(grad_values_expand[m])
+          copy_iter_local.unsafe_replace_operand(0, gIv_data + i * gIv_stride);
+          copy_iter_local.unsafe_replace_operand(1, gOv_data + m * gOv_stride);
+          copy_stub(device_type, copy_iter_local, /*non_blocking=*/false);
           break;
         }
         if (grad_indices_1D_accessor[m] < input_idx) {
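
The `SparseTensorMath.cpp` changes follow the same recipe: `add_dense_sparse_worker_cpu` reads the cached `r_strides` and `sparse_dim` instead of calling `r.stride(d)` and `sparse.sparse_dim()` for every non-zero, and the binary-search loop in `_sparse_sum_backward_cpu` replaces the per-element `grad_input_values[i].copy_(grad_values_expand[m])` with the pre-built iterator plus `copy_stub`, keeping dispatch off the worker threads.
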