Permalink
Browse files

added better CSR slicing

  • Loading branch information...
1 parent 82be113 commit c19c8a8faa9fdc71e74e8776e55bef0991c32956 wnbell committed Feb 14, 2008
Showing with 249 additions and 5 deletions.
  1. +186 −2 scipy/sparse/csr.py
  2. +63 −3 scipy/sparse/tests/test_base.py
View
@@ -8,10 +8,11 @@
import numpy
from numpy import array, matrix, asarray, asmatrix, zeros, rank, intc, \
empty, hstack, isscalar, ndarray, shape, searchsorted, where, \
- concatenate, deprecate
+ concatenate, deprecate, arange, ones
from base import spmatrix, isspmatrix
-from sparsetools import csr_tocsc, csr_tobsr, csr_count_blocks
+from sparsetools import csr_tocsc, csr_tobsr, csr_count_blocks, \
+ get_csr_submatrix
from sputils import upcast, to_native, isdense, isshape, getdtype, \
isscalarlike, isintlike
@@ -182,6 +183,189 @@ def _swap(self,x):
return (x[0],x[1])
+ def __getitem__(self, key):
+ def asindices(x):
+ try:
+ x = asarray(x,dtype='intc')
+ except:
+ raise IndexError('invalid index')
+ else:
+ return x
+
+ def extractor(indices,N):
+ """Return a sparse matrix P so that P*self implements
+ slicing of the form self[[1,2,3],:]
+ """
+ indices = asindices(indices)
+
+ max_indx = indices.max()
+
+ if max_indx > N:
+ raise ValueError('index (%d) out of range' % max_indx)
+
+ min_indx = indices.min()
+ if min_indx < -N:
+ raise ValueError('index (%d) out of range' % (N + min_indx))
+
+ if min_indx < 0:
+ indices = indices.copy()
+ indices[indices < 0] += N
+
+ indptr = arange(len(indices) + 1, dtype='intc')
+ data = ones(len(indices), dtype=self.dtype)
+ shape = (len(indices),N)
+
+ return csr_matrix( (data,indices,indptr), shape=shape)
+
+
+ if isinstance(key, tuple):
+ row = key[0]
+ col = key[1]
+
+ #TODO implement CSR[ [1,2,3], X ] with sparse matmat
+ #TODO make use of sorted indices
+
+ if isintlike(row):
+ #[1,??]
+ if isintlike(col):
+ return self._get_single_element(row, col) #[i,j]
+ elif isinstance(col, slice):
+ return self._get_row_slice(row, col) #[i,1:2]
+ else:
+ P = extractor(col,self.shape[1]).T #[i,[1,2]]
+ return self[row,:]*P
+
+ elif isinstance(row, slice):
+ #[1:2,??]
+ if isintlike(col) or isinstance(col, slice):
+ return self._get_submatrix(row, col) #[1:2,j]
+ else:
+ P = extractor(col,self.shape[1]).T #[1:2,[1,2]]
+ return self[row,:]*P
+ else:
+ #[[1,2],??]
+ if isintlike(col) or isinstance(col,slice):
+ P = extractor(row, self.shape[0])
+ return (P*self)[:,col] #[[1,2],j] or [[1,2],1:2]
+ else:
+ row = asindices(row) #[[1,2],[1,2]]
+ col = asindices(col)
+ if len(row) != len(col):
+ raise ValueError('number of row and column indices differ')
+ val = []
+ for i,j in zip(row,col):
+ val.append(self._get_single_element(i,j))
+ return asmatrix(val)
+
+
+ elif isintlike(key) or isinstance(key,slice):
+ return self[key,:] #[i] or [1:2]
+ else:
+ return self[asindices(key),:] #[[1,2]]
+
+
+ def _get_single_element(self,row,col):
+ M, N = self.shape
+ if (row < 0):
+ row += M
+ if (col < 0):
+ col += N
+ if not (0<=row<M) or not (0<=col<N):
+ raise IndexError, "index out of bounds"
+
+ start = self.indptr[row]
+ end = self.indptr[row+1]
+ indxs = where(col == self.indices[start:end])[0]
+
+ num_matches = len(indxs)
+
+ if num_matches == 0:
+ # entry does not appear in the matrix
+ return self.dtype.type(0)
+ elif num_matches == 1:
+ return self.data[start:end][indxs[0]]
+ else:
+ raise ValueError('nonzero entry (%d,%d) occurs more than once' % (row,col) )
+
+ def _get_row_slice(self, i, cslice ):
+ """Returns a copy of self[i, cslice]
+ """
+ if i < 0:
+ i += self.shape[0]
+
+ if i < 0:
+ raise ValueError('index (%d) out of range' % i )
+
+ start, stop, stride = cslice.indices(self.shape[1])
+
+ if stride != 1:
+ raise ValueError, "slicing with step != 1 not supported"
+ if stop <= start:
+ raise ValueError, "slice width must be >= 1"
+
+ #TODO make [i,:] faster
+ #TODO implement [i,x:y:z]
+
+ indices = []
+
+ for ind in xrange(self.indptr[i], self.indptr[i+1]):
+ if self.indices[ind] >= start and self.indices[ind] < stop:
+ indices.append(ind)
+
+ index = self.indices[indices] - start
+ data = self.data[indices]
+ indptr = numpy.array([0, len(indices)])
+ return csr_matrix( (data, index, indptr), shape=(1, stop-start) )
+
+ def _get_submatrix( self, row_slice, col_slice ):
+ """Return a submatrix of this matrix (new matrix is created)."""
+
+ M,N = self.shape
+
+ def process_slice( sl, num ):
+ if isinstance( sl, slice ):
+ i0, i1 = sl.start, sl.stop
+ if i0 is None:
+ i0 = 0
+ elif i0 < 0:
+ i0 = num + i0
+
+ if i1 is None:
+ i1 = num
+ elif i1 < 0:
+ i1 = num + i1
+
+ return i0, i1
+
+ elif isscalar( sl ):
+ if sl < 0:
+ sl += num
+
+ return sl, sl + 1
+
+ else:
+ raise TypeError('expected slice or scalar')
+
+ def check_bounds( i0, i1, num ):
+ if not (0<=i0<num) or not (0<i1<=num) or not (i0<i1):
+ raise IndexError,\
+ "index out of bounds: 0<=%d<%d, 0<=%d<%d, %d<%d" %\
+ (i0, num, i1, num, i0, i1)
+
+ i0, i1 = process_slice( row_slice, M )
+ j0, j1 = process_slice( col_slice, N )
+ check_bounds( i0, i1, M )
+ check_bounds( j0, j1, N )
+
+ indptr, indices, data = get_csr_submatrix( M, N, \
+ self.indptr, self.indices, self.data, i0, i1, j0, j1 )
+
+ shape = (i1 - i0, j1 - j0)
+
+ return self.__class__( (data,indices,indptr), shape=shape )
+
+
+
from sputils import _isinstance
def isspmatrix_csr(x):
@@ -628,9 +628,9 @@ class _TestBothSlicing:
def test_get_slices(self):
B = asmatrix(arange(50.).reshape(5,10))
A = self.spmatrix(B)
- assert_array_equal(B[2:5,0:3], A[2:5,0:3].todense())
- assert_array_equal(B[1:,:-1], A[1:,:-1].todense())
- assert_array_equal(B[:-1,1:], A[:-1,1:].todense())
+ assert_array_equal(A[2:5,0:3].todense(), B[2:5,0:3])
+ assert_array_equal(A[1:,:-1].todense(), B[1:,:-1])
+ assert_array_equal(A[:-1,1:].todense(), B[:-1,1:])
# Now test slicing when a column contains only zeros
E = matrix([[1, 0, 1], [4, 0, 0], [0, 0, 0], [0, 0, 1]])
@@ -852,6 +852,66 @@ def test_eliminate_zeros(self):
assert_array_equal(asp.todense(),bsp.todense())
+ def test_fancy_slicing(self):
+ #TODO add this to csc_matrix
+ B = asmatrix(arange(50).reshape(5,10))
+ A = csr_matrix( B )
+
+ # [i,j]
+ assert_equal(A[2,3],B[2,3])
+ assert_equal(A[-1,8],B[-1,8])
+ assert_equal(A[-1,-2],B[-1,-2])
+
+ # [i,1:2]
+ assert_equal(A[2,:].todense(),B[2,:])
+ assert_equal(A[2,5:-2].todense(),B[2,5:-2])
+
+ # [i,[1,2]]
+ assert_equal(A[3,[1,3]].todense(),B[3,[1,3]])
+ assert_equal(A[-1,[2,-5]].todense(),B[-1,[2,-5]])
+
+ # [1:2,j]
+ assert_equal(A[:,2].todense(),B[:,2])
+ assert_equal(A[3:4,9].todense(),B[3:4,9])
+ assert_equal(A[1:4,-5].todense(),B[1:4,-5])
+
+ # [1:2,[1,2]]
+ assert_equal(A[:,[2,8,3,-1]].todense(),B[:,[2,8,3,-1]])
+ assert_equal(A[3:4,[9]].todense(),B[3:4,[9]])
+ assert_equal(A[1:4,[-1,-5]].todense(),B[1:4,[-1,-5]])
+
+ # [[1,2],j]
+ assert_equal(A[[1,3],3].todense(),B[[1,3],3])
+ assert_equal(A[[2,-5],-4].todense(),B[[2,-5],-4])
+
+ # [[1,2],1:2]
+ assert_equal(A[[1,3],:].todense(),B[[1,3],:])
+ assert_equal(A[[2,-5],8:-1].todense(),B[[2,-5],8:-1])
+
+ # [[1,2],[1,2]]
+ assert_equal(A[[1,3],[2,4]],B[[1,3],[2,4]])
+ assert_equal(A[[-1,-3],[2,-4]],B[[-1,-3],[2,-4]])
+
+ # [i]
+ assert_equal(A[1].todense(),B[1])
+ assert_equal(A[-2].todense(),B[-2])
+
+ # [1:2]
+ assert_equal(A[1:4].todense(),B[1:4])
+ assert_equal(A[1:-2].todense(),B[1:-2])
+
+ # [[1,2]]
+ assert_equal(A[[1,3]].todense(),B[[1,3]])
+ assert_equal(A[[-1,-3]].todense(),B[[-1,-3]])
+
+ # [[1,2],:][:,[1,2]]
+ assert_equal(A[[1,3],:][:,[2,4]].todense(), B[[1,3],:][:,[2,4]] )
+ assert_equal(A[[-1,-3],:][:,[2,-4]].todense(), B[[-1,-3],:][:,[2,-4]] )
+
+ # [:,[1,2]][[1,2],:]
+ assert_equal(A[:,[1,3]][[2,4],:].todense(), B[:,[1,3]][[2,4],:] )
+ assert_equal(A[:,[-1,-3]][[2,-4],:].todense(), B[:,[-1,-3]][[2,-4],:] )
+
class TestCSC(_TestCommon, _TestGetSet, _TestSolve,
_TestInplaceArithmetic, _TestArithmetic, _TestMatvecOutput,
_TestHorizSlicing, _TestVertSlicing, _TestBothSlicing,

0 comments on commit c19c8a8

Please sign in to comment.