Implement indexing methods for sparse tensors #24937

Closed · wants to merge 16 commits
aten/src/ATen/core/TensorMethods.h (3 additions, 0 deletions)
@@ -4195,6 +4195,9 @@ inline Tensor Tensor::index_select(int64_t dim, const Tensor & index) const {
    case Backend::CPU:
      return CPUType::index_select(const_cast<Tensor&>(*this), dim, index);
      break;
    case Backend::SparseCPU:
      return SparseCPUType::index_select(const_cast<Tensor&>(*this), dim, index);
      break;
    default:
      AT_ERROR("index_select not implemented for ", at::toString(tensorTypeIdToBackend(type_id())));
  }
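The generated method simply switches on the tensor's backend and forwards to the matching kernel. As a toy illustration of that dispatch idea, here is a minimal Python sketch (hypothetical `backend` attribute and kernel names, for illustration only; not PyTorch's actual dispatcher):

# Toy backend dispatch: map a backend tag to a kernel, erroring out on
# unsupported backends.
def index_select_cpu(t, dim, index):
    raise NotImplementedError  # stand-in for the dense CPU kernel

def index_select_sparse_cpu(t, dim, index):
    raise NotImplementedError  # stand-in for index_select_sparse

KERNELS = {
    "CPU": index_select_cpu,
    "SparseCPU": index_select_sparse_cpu,
}

def index_select(t, dim, index):
    kernel = KERNELS.get(t.backend)
    if kernel is None:
        raise RuntimeError("index_select not implemented for " + t.backend)
    return kernel(t, dim, index)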
aten/src/ATen/gen.py (11 additions, 0 deletions)
@@ -369,6 +369,17 @@ def cmpfiles_with_eol_normalization(a, b, names):
                results[0].append(x)
            else:
                results[1].append(x)
                import difflib
Contributor:
I'm not sure what is going on here. Why did this have to change?

Collaborator Author:
When this code path was hit, it was really difficult to guess what was wrong, especially since the failure showed up in only a few CI instances; in most CI runs, as well as locally, the code was never executed. Showing the difference was helpful for continuing to debug why the files were different.

Collaborator Author:
@ezyang @zou3519 Do you want to keep the diff output code for debugging? Or shall I undo these changes to gen.py?

                import sys
                d = difflib.Differ()
                sys.stdout.write('-' * 80 + '\n')
                sys.stdout.write('x={}, a={}, b={}\n'.format(x, a, b))
                for i, line in enumerate(list(d.compare(ax.splitlines(), bx.splitlines()))):
                    if line[:2] != '  ':  # Differ prefixes unchanged lines with two spaces
                        sys.stdout.write('{:5d}: {}\n'.format(i, line))
                sys.stdout.write('-' * 80 + '\n')
                sys.stdout.write(ax)
                sys.stdout.write('-' * 80 + '\n')
        except OSError:
            results[2].append(x)
    return results
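For reference, a standalone sketch of the same difflib-based debug printing, with hypothetical in-memory file contents standing in for gen.py's `ax`/`bx`:

import difflib
import sys

ax = "line one\nline two\nline three\n"   # hypothetical expected contents
bx = "line one\nline 2\nline three\n"     # hypothetical actual contents

d = difflib.Differ()
for i, line in enumerate(d.compare(ax.splitlines(), bx.splitlines())):
    if line[:2] != '  ':                  # keep only changed/annotation lines
        sys.stdout.write('{:5d}: {}\n'.format(i, line))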
aten/src/ATen/native/TensorShape.cpp (121 additions, 0 deletions)
@@ -456,6 +456,39 @@ Tensor reshape_as(const Tensor& self, const Tensor& other) {
  return self.reshape(other.sizes());
}

static Tensor select_sparse(const Tensor& self, int64_t dim, int64_t index) {
  int64_t sparse_dim = self.sparse_dim();
  int64_t dense_dim = self.dense_dim();
  TORCH_INTERNAL_ASSERT(dim >= 0 && dim < sparse_dim + dense_dim);

  auto indices = self._indices();
  auto values = self._values();
  auto new_sizes = self.sizes().vec();
  new_sizes.erase(new_sizes.begin() + dim);

  if (dim < sparse_dim) {
    auto nzIndices = (indices[dim] == index).nonzero().view(-1);
    auto new_values = values.index_select(0, nzIndices);
    if (sparse_dim == 1) {
      // return dense part:
      if (new_values.size(0) == 1) {
        return new_values[0];
      } else {
        // an uncoalesced input can hold duplicate entries for `index`; summing
        // the matches (or an empty slice, which gives zeros) reproduces the
        // to_dense() semantics
        return new_values.sum(0);
      }
    } else {
      auto dimIndices = (arange(0, sparse_dim, self.device()) != dim).nonzero().view(-1);
      auto new_indices = indices.index_select(1, nzIndices).index_select(0, dimIndices);
      return _sparse_coo_tensor_with_dims_and_tensors(
          sparse_dim - 1, dense_dim, new_sizes, new_indices, new_values, self.options());
    }
  } else {
    auto new_values = values.select(dim - sparse_dim + 1, index);
    return _sparse_coo_tensor_with_dims_and_tensors(
        sparse_dim, dense_dim - 1, new_sizes, indices, new_values, self.options());
  }
}

Tensor select(const Tensor& self, int64_t dim, int64_t index) {
  int64_t ndim = self.dim();
  if (ndim == 0) {
@@ -476,6 +509,9 @@ Tensor select(const Tensor& self, int64_t dim, int64_t index) {
  if (index < 0) {
    index += size;
  }
  if (self.is_sparse()) {
    return select_sparse(self, dim, index);
  }
  auto sizes = self.sizes().vec();
  auto strides = self.strides().vec();
  auto storage_offset = self.storage_offset() + index * strides[dim];
@@ -494,6 +530,91 @@ Tensor select(const Tensor& self, Dimname dim, int64_t index) {
}
#endif
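A minimal usage sketch of select on sparse tensors, checked against the dense result; it assumes a build that includes this PR:

import torch

i = torch.tensor([[0, 0, 1],
                  [0, 2, 1]])
v = torch.tensor([1.0, 2.0, 3.0])
x = torch.sparse_coo_tensor(i, v, (2, 3))

row = x.select(0, 0)                  # one sparse dim remains -> still sparse
assert row.is_sparse
assert torch.equal(row.to_dense(), x.to_dense().select(0, 0))

# with a single sparse dim, selecting it returns the dense part directly
xv = torch.sparse_coo_tensor(torch.tensor([[0, 2]]), torch.randn(2, 5), (4, 5))
d = xv.select(0, 2)
assert not d.is_sparse and d.shape == (5,)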

Tensor index_select_sparse(const Tensor& self, int64_t dim, const Tensor& index) {
  /*
    Algorithm:
    index - a 1-D tensor of indices with shape (n,)
    self - sparse tensor, its shape is sizes = sparse_shape + dense_shape
    indices - 2-D tensor of indices, shape is (sparse_dims, nnz)
    values - (1+len(dense_shape))-D tensor of values, shape is (nnz,) + dense_shape
    index_select(dim, index) returns a sparse tensor with the following data
      new_sizes = sizes[:dim] + (n,) + sizes[dim+1:]
      new_indices - shape is (sparse_dims, new_nnz)
      new_values - shape is (new_nnz,) + dense_shape

    if dim < len(sparse_shape):
        for i, idx in enumerate(index):
            for j, jdx in enumerate(indices[dim]):
                if idx == jdx:
                    icol = indices[:dim][j] + (i,) + indices[dim+1:][j]
                    new_indices.add_column(icol)
                    new_values.add_row(values[j])
    else:
        new_indices = indices
        new_values[k] = values[k].index_select(dim - len(sparse_shape), index) for k in range(nnz)
  */
  auto ndim = self.dim();
  if (ndim == 0) {
    AT_INDEX_ERROR("index_select() cannot be applied to a 0-dim tensor.");
  }
  if (!(index.dim() == 1 && index.dtype() == at::kLong)) {
    AT_INDEX_ERROR("index_select() argument index must be 1-D long-tensor.");
  }
  dim = maybe_wrap_dim(dim, ndim);
  auto size = self.size(dim);
  auto sparse_dim = self.sparse_dim();
  auto dense_dim = self.dense_dim();
  auto indices = self._indices();
  auto values = self._values();
  auto nnz = values.size(0);
  auto new_sizes = self.sizes().vec();
  new_sizes[dim] = index.size(0);

  if (dim < sparse_dim) {
    auto dim_indices = indices[dim];
    std::vector<int64_t> zindices;
    std::vector<int64_t> iindices;
    int64_t new_nnz = 0;
    // for each requested index, scan the sparse indices along dim and record
    // the matching (output position, nnz position) pairs; O(n * nnz) overall
    for (int64_t i = 0; i < new_sizes[dim]; i++) {
      auto idx = index[i].item<int64_t>();
      if (idx < -size || idx >= size) {
        AT_INDEX_ERROR("index_select(): index contains ", idx, " that is out of range for tensor of size ",
                       self.sizes(), " at dimension ", dim);
      }
      if (idx < 0) {
        idx += size;
      }
      for (int64_t j = 0; j < nnz; j++) {
        auto jdx = dim_indices[j].item<int64_t>();
        if (idx == jdx) {
          new_nnz++;
          iindices.push_back(i);
          zindices.push_back(j);
        }
      }
    }
    auto zIndices = at::from_blob(zindices.data(), {new_nnz}, at::kLong).to(indices.device());
    auto new_indices = indices.index_select(1, zIndices);
    new_indices[dim] = at::from_blob(iindices.data(), {new_nnz}, at::kLong).to(indices.device());
    auto new_values = values.index_select(0, zIndices);
    return _sparse_coo_tensor_with_dims_and_tensors(
        sparse_dim, dense_dim, new_sizes, new_indices, new_values, self.options());
  } else {
    // dense dimension: keep the sparse structure and index_select each value
    auto vsize = values.sizes().vec();
    vsize[dim + 1 - sparse_dim] = index.size(0);
    auto new_values = at::empty(vsize, values.options());
    for (int64_t k = 0; k < nnz; k++) {
      new_values[k] = values[k].index_select(dim - sparse_dim, index);
    }
    return _sparse_coo_tensor_with_dims_and_tensors(
        sparse_dim, dense_dim, new_sizes, indices, new_values, self.options());
  }
}
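A minimal usage sketch of the new kernel on a hybrid tensor (one sparse and one dense dimension), checked against the dense result; it assumes a build that includes this PR:

import torch

i = torch.tensor([[0, 2]])            # indices: (sparse_dims, nnz) = (1, 2)
v = torch.randn(2, 4)                 # values: (nnz,) + dense_shape = (2, 4)
x = torch.sparse_coo_tensor(i, v, (3, 4))

idx = torch.tensor([0, 2, 0])
for dim in (0, 1):                    # dim 0 is sparse, dim 1 is dense
    y = x.index_select(dim, idx)
    assert y.is_sparse
    assert torch.equal(y.to_dense(), x.to_dense().index_select(dim, idx))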

Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
  int64_t ndim = self.dim();
  if (ndim == 0) {
aten/src/ATen/native/native_functions.yaml (2 additions, 0 deletions)
@@ -3847,6 +3847,8 @@
  dispatch:
    CPU: legacy::cpu::_th_index_select
    CUDA: legacy::cuda::_th_index_select
    SparseCPU: index_select_sparse
    SparseCUDA: index_select_sparse
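With these dispatch rows in place, the same Python entry point routes to different kernels by backend; a quick sanity check (assumes a build that includes this PR):

import torch

dense = torch.randn(3, 4)
sparse = dense.to_sparse()
idx = torch.tensor([0, 2])

a = torch.index_select(dense, 0, idx)    # CPU kernel
b = torch.index_select(sparse, 0, idx)   # SparseCPU kernel (this PR)
assert torch.equal(a, b.to_dense())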

- func: masked_select.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
test/test_sparse.py (55 additions, 0 deletions)
@@ -811,6 +811,61 @@ def test_shape(sparse_dims, nnz, sizes, unsqueeze_dim, fail_message=None):
        test_shape(3, 10, [5, 7, 11, 13, 17], -7, "Dimension out of range")
        test_shape(3, 10, [5, 7, 11, 13, 17], 6, "Dimension out of range")

    def test_select(self):
        def test_shape(sparse_dims, nnz, sizes, select_dim, select_index, fail_message=None):
            x, _, _ = self._gen_sparse(sparse_dims, nnz, sizes)
            if fail_message:
                with self.assertRaisesRegex(IndexError, fail_message):
                    torch.select(x, select_dim, select_index)
            else:
                result = torch.select(x, select_dim, select_index)
                if result.is_sparse:
                    result = result.to_dense()
                dense_result = torch.select(x.to_dense(), select_dim, select_index)
                self.assertEqual(dense_result, result)

        sizes = [5, 7, 11, 13, 17]
        # hybrid sparse/dense, select sparse dim, result is dense
        for i in range(sizes[0]):
            test_shape(1, 10, sizes, 0, i)
        test_shape(1, 10, sizes, 0, sizes[0] + 1, r'select[(][)][:] index \d out of range.*')

        # hybrid sparse/dense, select sparse dim, result is sparse
        for d in range(3):
            for i in range(sizes[d]):
                test_shape(3, 10, sizes, d, i)

        # hybrid sparse/dense, select dense dim, result is sparse
        for d in range(1, 3):
            for i in range(sizes[d]):
                test_shape(1, 10, sizes, d, i)


    def test_index_select(self):
        def test_shape(sparse_dims, nnz, sizes, select_dim, select_index, fail_message=None):
            if isinstance(select_index, int):
                select_index = [select_index]
            if isinstance(select_index, list):
                select_index = torch.tensor(select_index, device=self.device, dtype=torch.long)
            x, _, _ = self._gen_sparse(sparse_dims, nnz, sizes)
            if fail_message:
                with self.assertRaisesRegex(IndexError, fail_message):
                    torch.index_select(x, select_dim, select_index)
            else:
                result = torch.index_select(x, select_dim, select_index)
                if result.is_sparse:
                    result = result.to_dense()
                dense_result = torch.index_select(x.to_dense(), select_dim, select_index)
                self.assertEqual(dense_result, result)

        sizes = [5, 7, 11, 13, 17]
        for d in range(len(sizes)):
            for index in [0, sizes[d] - 1, [0, sizes[d] // 2, sizes[d] - 1]]:
                test_shape(1, 10, sizes, d, index)
                test_shape(len(sizes) // 2, 10, sizes, d, index)
                test_shape(len(sizes), 10, sizes, d, index)
    @cpu_only
    def test_mm(self):
        def test_shape(di, dj, dk, nnz):