Implement indexing methods for sparse tensors (#24937)
Summary:
Resolves #7416.

This PR implements the following indexing methods for sparse tensors:
- [x] `select`
- [x] `index_select`
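
A minimal usage sketch of the new behavior (the example tensor is illustrative, not taken from the PR):

```python
import torch

i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])
v = torch.tensor([3.0, 4.0, 5.0])
x = torch.sparse_coo_tensor(i, v, (2, 3))  # dense form: [[0, 0, 3], [4, 0, 5]]

row = x.select(0, 1)                       # drops one sparse dim; still sparse here
print(row.to_dense())                      # tensor([4., 0., 5.])

cols = x.index_select(1, torch.tensor([0, 2]))
print(cols.to_dense())                     # tensor([[0., 3.],
                                           #         [4., 5.]])
```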

Note that this PR also modifies [gen.py](https://github.com/pytorch/pytorch/pull/24937/files#diff-76aa8cb3d0fad99c5f761d08cbcb4d19); that change is not directly required to resolve the original issue, but works around a CI build issue reported in #24931.
Pull Request resolved: #24937

Differential Revision: D17163796

Pulled By: ezyang

fbshipit-source-id: 06613301ec456d9ed3491b9ce48e804048600f09
pearu authored and facebook-github-bot committed Sep 3, 2019
1 parent 832c72a commit f793a7c
Showing 5 changed files with 192 additions and 0 deletions.
3 changes: 3 additions & 0 deletions aten/src/ATen/core/TensorMethods.h
@@ -4195,6 +4195,9 @@ inline Tensor Tensor::index_select(int64_t dim, const Tensor & index) const {
        case Backend::CPU:
            return CPUType::index_select(const_cast<Tensor&>(*this), dim, index);
            break;
        case Backend::SparseCPU:
            return SparseCPUType::index_select(const_cast<Tensor&>(*this), dim, index);
            break;
        default:
            AT_ERROR("index_select not implemented for ", at::toString(tensorTypeIdToBackend(type_id())));
    }
11 changes: 11 additions & 0 deletions aten/src/ATen/gen.py
@@ -369,6 +369,17 @@ def cmpfiles_with_eol_normalization(a, b, names):
                results[0].append(x)
            else:
                results[1].append(x)
                import difflib
                import sys
                d = difflib.Differ()
                sys.stdout.write('-' * 80 + '\n')
                sys.stdout.write('x={}, a={}, b={}\n'.format(x, a, b))
                for i, line in enumerate(list(d.compare(ax.splitlines(), bx.splitlines()))):
                    if line[:2] != '  ':
                        sys.stdout.write('{:5d}: {}\n'.format(i, line))
                sys.stdout.write('-' * 80 + '\n')
                sys.stdout.write(ax)
                sys.stdout.write('-' * 80 + '\n')
        except OSError:
            results[2].append(x)
    return results
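
For context, the added block prints the differing lines (and the full expected content) whenever a generated file mismatches during the build. A standalone sketch of the same `difflib.Differ` pattern, with hypothetical file contents:

```python
import difflib

ax = "alpha\nbeta\n"   # hypothetical generated file
bx = "alpha\ngamma\n"  # hypothetical checked-in file

d = difflib.Differ()
for i, line in enumerate(d.compare(ax.splitlines(), bx.splitlines())):
    if line[:2] != '  ':  # Differ prefixes lines common to both inputs with two spaces
        print('{:5d}: {}'.format(i, line))
# prints:
#     1: - beta
#     2: + gamma
```
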
121 changes: 121 additions & 0 deletions aten/src/ATen/native/TensorShape.cpp
@@ -456,6 +456,39 @@ Tensor reshape_as(const Tensor& self, const Tensor& other) {
  return self.reshape(other.sizes());
}

static Tensor select_sparse(const Tensor& self, int64_t dim, int64_t index) {
  int64_t sparse_dim = self.sparse_dim();
  int64_t dense_dim = self.dense_dim();
  TORCH_INTERNAL_ASSERT(dim >= 0 && dim < sparse_dim + dense_dim);

  auto indices = self._indices();
  auto values = self._values();
  auto new_sizes = self.sizes().vec();
  new_sizes.erase(new_sizes.begin() + dim);

  if (dim < sparse_dim) {
    auto nzIndices = (indices[dim] == index).nonzero().view(-1);
    auto new_values = values.index_select(0, nzIndices);
    if (sparse_dim == 1) {
      // return dense part:
      if (new_values.size(0) == 1) {
        return new_values[0];
      } else {
        return new_values.sum(0);
      }
    } else {
      auto dimIndices = (arange(0, sparse_dim, self.device()) != dim).nonzero().view(-1);
      auto new_indices = indices.index_select(1, nzIndices).index_select(0, dimIndices);
      return _sparse_coo_tensor_with_dims_and_tensors(
          sparse_dim - 1, dense_dim, new_sizes, new_indices, new_values, self.options());
    }
  } else {
    auto new_values = values.select(dim - sparse_dim + 1, index);
    return _sparse_coo_tensor_with_dims_and_tensors(
        sparse_dim, dense_dim - 1, new_sizes, indices, new_values, self.options());
  }
}
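
In Python terms, `select_sparse` behaves roughly as the sketch below (built on the COO accessors `_indices()`/`_values()`; a rough equivalent, not the shipped implementation — bounds checks and dim wrapping happen in `select()` before this is reached):

```python
import torch

def select_sparse_sketch(x, dim, index):
    """Rough Python equivalent of select() on a sparse COO tensor."""
    indices, values = x._indices(), x._values()
    if dim < x.sparse_dim():
        nz = (indices[dim] == index).nonzero().view(-1)  # matching nnz positions
        new_values = values.index_select(0, nz)
        if x.sparse_dim() == 1:
            # no sparse dims remain, so the result is dense
            return new_values[0] if new_values.size(0) == 1 else new_values.sum(0)
        keep = [d for d in range(x.sparse_dim()) if d != dim]
        new_indices = indices.index_select(1, nz)[keep]
    else:
        new_indices = indices
        new_values = values.select(dim - x.sparse_dim() + 1, index)
    new_size = x.shape[:dim] + x.shape[dim + 1:]
    return torch.sparse_coo_tensor(new_indices, new_values, new_size)
```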

Tensor select(const Tensor& self, int64_t dim, int64_t index) {
  int64_t ndim = self.dim();
  if (ndim == 0) {
@@ -476,6 +509,9 @@ Tensor select(const Tensor& self, int64_t dim, int64_t index) {
  if (index < 0) {
    index += size;
  }
  if (self.is_sparse()) {
    return select_sparse(self, dim, index);
  }
  auto sizes = self.sizes().vec();
  auto strides = self.strides().vec();
  auto storage_offset = self.storage_offset() + index * strides[dim];
@@ -494,6 +530,91 @@ Tensor select(const Tensor& self, Dimname dim, int64_t index) {
}
#endif

Tensor index_select_sparse(const Tensor& self, int64_t dim, const Tensor& index) {
  /*
    Algorithm:
    index - a 1-D tensor of indices with shape (n,)
    self - sparse tensor, its shape is sizes = sparse_shape + dense_shape
    indices - 2-D tensor of indices, shape is (sparse_dims, nnz)
    values - (1+len(dense_shape))-D tensor of values, shape is (nnz,) + dense_shape

    index_select(dim, index) returns a sparse tensor with the following data
      new_sizes = sizes[:dim] + (n,) + sizes[dim+1:]
      new_indices - shape is (sparse_dims, new_nnz)
      new_values - shape is (new_nnz,) + dense_shape

    if dim < len(sparse_shape):
        for i, idx in enumerate(index):
            for j, jdx in enumerate(indices[dim]):
                if idx == jdx:
                    icol = indices[:dim][j] + (i,) + indices[dim+1:][j]
                    new_indices.add_column(icol)
                    new_values.add_row(values[j])
    else:
        new_indices = indices
        new_values[k] = values[k].index_select(dim - len(sparse_shape), index) for k in range(nnz)
  */
  auto ndim = self.dim();
  if (ndim == 0) {
    AT_INDEX_ERROR("index_select() cannot be applied to a 0-dim tensor.");
  }
  if (!(index.dim() == 1 && index.dtype() == at::kLong)) {
    AT_INDEX_ERROR("index_select() argument index must be 1-D long-tensor.");
  }
  dim = maybe_wrap_dim(dim, ndim);
  auto size = self.size(dim);
  auto sparse_dim = self.sparse_dim();
  auto dense_dim = self.dense_dim();
  auto indices = self._indices();
  auto values = self._values();
  auto nnz = values.size(0);
  auto new_sizes = self.sizes().vec();
  new_sizes[dim] = index.size(0);

  if (dim < sparse_dim) {

    auto dim_indices = indices[dim];
    std::vector<int64_t> zindices;
    std::vector<int64_t> iindices;
    int64_t new_nnz = 0;
    for (int64_t i = 0; i < new_sizes[dim]; i++) {
      auto idx = index[i].item<int64_t>();
      if (idx < -size || idx >= size) {
        AT_INDEX_ERROR("index_select(): index contains ", idx, " that is out of range for tensor of size ",
                       self.sizes(), " at dimension ", dim);
      }
      if (idx < 0) {
        idx += size;
      }
      for (int64_t j = 0; j < nnz; j++) {
        auto jdx = dim_indices[j].item<int64_t>();
        if (idx == jdx) {
          new_nnz++;
          iindices.push_back(i);
          zindices.push_back(j);
        }
      }
    }
    auto zIndices = at::from_blob(zindices.data(), {new_nnz}, at::kLong).to(indices.device());
    auto new_indices = indices.index_select(1, zIndices);
    new_indices[dim] = at::from_blob(iindices.data(), {new_nnz}, at::kLong).to(indices.device());
    auto new_values = values.index_select(0, zIndices);
    return _sparse_coo_tensor_with_dims_and_tensors(
        sparse_dim, dense_dim, new_sizes, new_indices, new_values, self.options());

  } else {

    auto vsize = values.sizes().vec();
    vsize[dim + 1 - sparse_dim] = index.size(0);
    auto new_values = at::empty(vsize, values.options());
    for (int64_t k = 0; k < nnz; k++) {
      new_values[k] = values[k].index_select(dim - sparse_dim, index);
    }
    return _sparse_coo_tensor_with_dims_and_tensors(
        sparse_dim, dense_dim, new_sizes, indices, new_values, self.options());

  }
}
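
The pseudocode above maps to runnable Python roughly as follows (a sketch under the same COO assumptions; negative indices are wrapped and range checks are elided):

```python
import torch

def index_select_sparse_sketch(x, dim, index):
    """Rough Python equivalent of index_select() on a sparse COO tensor."""
    indices, values = x._indices(), x._values()
    if dim < x.sparse_dim():
        iidx, zidx = [], []  # new indices[dim] entries, and matching source nnz positions
        for i, idx in enumerate(index.tolist()):
            idx = idx % x.size(dim)  # wrap negative indices
            for j, jdx in enumerate(indices[dim].tolist()):
                if idx == jdx:
                    iidx.append(i)
                    zidx.append(j)
        z = torch.tensor(zidx, dtype=torch.long)
        new_indices = indices.index_select(1, z)
        new_indices[dim] = torch.tensor(iidx, dtype=torch.long)
        new_values = values.index_select(0, z)
    else:
        new_indices = indices
        new_values = values.index_select(dim - x.sparse_dim() + 1, index)
    new_size = x.shape[:dim] + (len(index),) + x.shape[dim + 1:]
    return torch.sparse_coo_tensor(new_indices, new_values, new_size)
```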

Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
int64_t ndim = self.dim();
if (ndim == 0) {
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -3847,6 +3847,8 @@
  dispatch:
    CPU: legacy::cpu::_th_index_select
    CUDA: legacy::cuda::_th_index_select
    SparseCPU: index_select_sparse
    SparseCUDA: index_select_sparse

- func: masked_select.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
55 changes: 55 additions & 0 deletions test/test_sparse.py
@@ -811,6 +811,61 @@ def test_shape(sparse_dims, nnz, sizes, unsqueeze_dim, fail_message=None):
        test_shape(3, 10, [5, 7, 11, 13, 17], -7, "Dimension out of range")
        test_shape(3, 10, [5, 7, 11, 13, 17], 6, "Dimension out of range")

    def test_select(self):
        def test_shape(sparse_dims, nnz, sizes, select_dim, select_index, fail_message=None):
            x, _, _ = self._gen_sparse(sparse_dims, nnz, sizes)
            if fail_message:
                with self.assertRaisesRegex(IndexError, fail_message):
                    torch.select(x, select_dim, select_index)
            else:
                result = torch.select(x, select_dim, select_index)
                if result.is_sparse:
                    result = result.to_dense()
                dense_result = torch.select(x.to_dense(), select_dim, select_index)
                self.assertEqual(dense_result, result)

        sizes = [5, 7, 11, 13, 17]
        # hybrid sparse/dense, select sparse dim, result is dense
        for i in range(sizes[0]):
            test_shape(1, 10, sizes, 0, i)
        test_shape(1, 10, sizes, 0, sizes[0] + 1, r'select[(][)][:] index \d out of range.*')

        # hybrid sparse/dense, select sparse dim, result is sparse
        for d in range(3):
            for i in range(sizes[d]):
                test_shape(3, 10, sizes, d, i)

        # hybrid sparse/dense, select dense dim, result is sparse
        for d in range(1, 3):
            for i in range(sizes[d]):
                test_shape(1, 10, sizes, d, i)

    def test_index_select(self):
        def test_shape(sparse_dims, nnz, sizes, select_dim, select_index, fail_message=None):
            if isinstance(select_index, int):
                select_index = [select_index]
            if isinstance(select_index, list):
                select_index = torch.tensor(select_index, device=self.device, dtype=torch.long)
            x, _, _ = self._gen_sparse(sparse_dims, nnz, sizes)
            if fail_message:
                with self.assertRaisesRegex(IndexError, fail_message):
                    torch.index_select(x, select_dim, select_index)
            else:
                result = torch.index_select(x, select_dim, select_index)
                if result.is_sparse:
                    result = result.to_dense()
                dense_result = torch.index_select(x.to_dense(), select_dim, select_index)
                self.assertEqual(dense_result, result)

        sizes = [5, 7, 11, 13, 17]
        for d in range(len(sizes)):
            for index in [0, sizes[d] - 1, [0, sizes[d] // 2, sizes[d] - 1]]:
                test_shape(1, 10, sizes, d, index)
                test_shape(len(sizes) // 2, 10, sizes, d, index)
                test_shape(len(sizes), 10, sizes, d, index)

    @cpu_only
    def test_mm(self):
        def test_shape(di, dj, dk, nnz):
