Skip to content

Commit

Permalink
add narrow() support for sparse tensors re: pytorch#8853
Browse files Browse the repository at this point in the history
  • Loading branch information
Doug Friedman committed Sep 19, 2018
1 parent 3da8d71 commit cdedaa7
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 0 deletions.
40 changes: 40 additions & 0 deletions aten/src/ATen/native/TensorShape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,48 @@ Tensor &as_strided_(Tensor& self, IntList size, IntList stride) {
return at::as_strided_(self, size, stride, self.storage_offset());
}

Tensor _narrow_sparse(const Tensor& self, int64_t dim, int64_t start, int64_t length){
  // narrow() for sparse COO tensors: keep only the nonzeros whose coordinate
  // along `dim` lies in [start, start + length), and shift that coordinate
  // down by `start` in the result. Returns a new sparse tensor; `self` is
  // left untouched.
  //
  // NOTE: the previous implementation did `indices[dim] = indices[dim].add(-start)`,
  // which wrote through the view returned by _indices() and thereby mutated
  // the *input* tensor's coordinates as a side effect of narrow(). All writes
  // below go only to freshly allocated output tensors.
  LongTensor indices = self._indices();   // shape: dims x nnz (aliases self's storage)
  Tensor values = self._values();
  int64_t numCoords = indices.size(1);    // number of stored nonzeros (nnz)
  int64_t dims = indices.size(0);         // number of sparse dimensions

  std::vector<int64_t> newSizes = self.sizes().vec();
  newSizes[dim] = length;

  // Collect the positions of the nonzeros that survive the narrow.
  // Read-only view of the coordinates along `dim` — no copy needed since
  // we no longer mutate it.
  LongTensor narrowDim = indices[dim];
  std::vector<int64_t> keep;
  keep.reserve(numCoords);
  int64_t end = start + length;
  for (int64_t i = 0; i < numCoords; i++) {
    // Per-element host read; matches the original's element-wise scan.
    int64_t val = narrowDim[i].toCLong();
    if (val >= start && val < end)
      keep.push_back(i);
  }

  int64_t keepSize = keep.size();
  // Build the output transposed (nnz x dims) so each kept coordinate tuple
  // is copied with a single row assignment, then transpose back.
  LongTensor newIndices_tmp = indices.type().tensor({keepSize, dims});
  Tensor newValues = values.type().tensor({keepSize});
  LongTensor tpose = indices.t();

  int64_t out = 0;
  for (int64_t k : keep) {
    newIndices_tmp[out] = tpose[k];
    newValues[out] = values[k];
    out++;
  }

  LongTensor newIndices = newIndices_tmp.t();
  // Re-base the narrowed dimension so the result's coordinates start at 0.
  // This writes into the output we just built, never into `self`.
  newIndices[dim] = newIndices[dim].add(-start);
  return self.type().sparse_coo_tensor(newIndices, newValues, newSizes);
}

Tensor narrow(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
AT_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor.");
if (self.type().is_sparse()) {
return _narrow_sparse(self, dim, start, length);
}
auto cur_size = self.size(dim);
if (start != cur_size) { // start being the end is valid, but not a valid dim specification.
start = maybe_wrap_dim(start, cur_size);
Expand Down
26 changes: 26 additions & 0 deletions test/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,32 @@ def _test_zeros(self, shape, out_shape_i, out_shape_v=None):
self.assertEqual(out._sparseDims(), len(shape))
self.assertEqual(out._denseDims(), 0)

def test_narrow(self):
    """narrow() on a sparse tensor must match narrow() on its dense
    equivalent, for coalesced and uncoalesced inputs, on the device
    under test (CPU or CUDA).

    Fix over the previous version: the uncoalesced input was always
    built on the CPU, so the CUDA uncoalesced path was never exercised
    when self.is_cuda was True.
    """
    def make_sparse(index_rows, vals):
        # Build a 1-D (size-3) sparse double tensor on the device under test.
        i = torch.LongTensor(index_rows).transpose(1, 0)
        v = torch.FloatTensor(vals)
        if self.is_cuda:
            return torch.cuda.sparse.DoubleTensor(i.cuda(), v.cuda(), torch.Size([3]))
        return torch.sparse.DoubleTensor(i, v, torch.Size([3]))

    input = make_sparse([[0], [1], [2]], [3, 4, 5])

    narrow_args = [0, 0, 2]  # dim, start, length
    expected = torch.tensor([3., 4., 5.]).narrow(*narrow_args)

    self.assertEqual(expected, input.narrow(*narrow_args).to_dense())
    self.assertEqual(expected, input.coalesce().narrow(*narrow_args).to_dense())

    # Duplicate coordinates: coalescing must sum [2,3,4] + [1,1,1] -> [3,4,5].
    uncoalesced = make_sparse([[0], [1], [2], [0], [1], [2]],
                              [2, 3, 4, 1, 1, 1])

    self.assertEqual(expected, uncoalesced.narrow(*narrow_args).to_dense())
    self.assertEqual(expected, uncoalesced.coalesce().narrow(*narrow_args).to_dense())

@skipIfRocm
def test_log1p(self):
if self.is_cuda:
Expand Down

0 comments on commit cdedaa7

Please sign in to comment.