Skip to content

Commit

Permalink
Merge pull request #2586 from cowlicks/inequalities-override
Browse files Browse the repository at this point in the history
ENH: sparse: implement inequality ops for sparse matrices
  • Loading branch information
pv committed Jun 27, 2013
2 parents cc09a77 + 34ecef7 commit 0e1dd62
Show file tree
Hide file tree
Showing 17 changed files with 50,568 additions and 11 deletions.
12 changes: 12 additions & 0 deletions scipy/sparse/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,18 @@ def __eq__(self, other):
def __ne__(self, other):
return self.tocsr().__ne__(other)

def __lt__(self,other):
return self.tocsr().__lt__(other)

def __gt__(self,other):
return self.tocsr().__gt__(other)

def __le__(self,other):
return self.tocsr().__le__(other)

def __ge__(self,other):
return self.tocsr().__ge__(other)

def __abs__(self):
return abs(self.tocsr())

Expand Down
3 changes: 2 additions & 1 deletion scipy/sparse/bsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,8 @@ def _binopt(self, other, op, in_shape=None, out_shape=None):
indptr = np.empty_like(self.indptr)
indices = np.empty(max_bnnz, dtype=np.intc)

if op == '_ne_':
bool_ops = ['_ne_', '_lt_', '_gt_', '_le_', '_ge_']
if op in bool_ops:
data = np.empty(R*C*max_bnnz, dtype=np.bool_)
else:
data = np.empty(R*C*max_bnnz, dtype=upcast(self.dtype,other.dtype))
Expand Down
135 changes: 131 additions & 4 deletions scipy/sparse/compressed.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def __eq__(self, other):
res = self._binopt(other_arr,'_ne_')
if other == 0:
warn("Comparing a sparse matrix with 0 using == is inefficient"
", try using != instead.")
", try using != instead.", SparseEfficiencyWarning)
all_true = self.__class__(np.ones(self.shape, dtype=np.bool_))
return all_true - res
else:
Expand All @@ -189,7 +189,7 @@ def __eq__(self, other):
# Sparse other.
elif isspmatrix(other):
warn("Comparing sparse matrices using == is inefficient, try using"
" != instead.")
" != instead.", SparseEfficiencyWarning)
#TODO sparse broadcasting
if self.shape != other.shape:
return False
Expand All @@ -206,7 +206,7 @@ def __ne__(self, other):
if isscalarlike(other):
if other != 0:
warn("Comparing a sparse matrix with a nonzero scalar using !="
" is inefficient, try using == instead.")
" is inefficient, try using == instead.", SparseEfficiencyWarning)
all_true = self.__class__(np.ones(self.shape), dtype=np.bool_)
res = (self == other)
return all_true - res
Expand All @@ -228,6 +228,132 @@ def __ne__(self, other):
else:
return True

def __lt__(self, other):
# Scalar other.
if isscalarlike(other):
if 0 < other:
warn("Comparing a sparse matrix with a scalar greater than "
"zero using < is inefficient, try using > instead.", SparseEfficiencyWarning)
other_arr = np.empty(self.shape)
other_arr.fill(other)
other_arr = self.__class__(other_arr)
return self._binopt(other_arr, '_lt_')
else:
other_arr = self.copy()
other_arr.data[:] = other
return self._binopt(other_arr, '_lt_')
# Dense other.
elif isdense(other):
return self.todense() < other
# Sparse other.
elif isspmatrix(other):
#TODO sparse broadcasting
if self.shape != other.shape:
raise ValueError("inconsistent shapes")
elif self.format != other.format:
other = other.asformat(self.format)
return self._binopt(other, '_lt_')
else:
raise ValueError("Operands could not be compared.")

def __gt__(self, other):
# Scalar other.
if isscalarlike(other):
if 0 > other:
warn("Comparing a sparse matrix with a scalar less than zero "
"using > is inefficient, try using < instead.", SparseEfficiencyWarning)
other_arr = np.empty(self.shape)
other_arr.fill(other)
other_arr = self.__class__(other_arr)
return self._binopt(other_arr, '_gt_')
else:
other_arr = self.copy()
other_arr.data[:] = other
return self._binopt(other_arr, '_gt_')
# Dense other.
elif isdense(other):
return self.todense() > other
# Sparse other.
elif isspmatrix(other):
#TODO sparse broadcasting
if self.shape != other.shape:
raise ValueError("inconsistent shapes")
elif self.format != other.format:
other = other.asformat(self.format)
return self._binopt(other, '_gt_')
else:
raise ValueError("Operands could not be compared.")

def __le__(self,other):
# Scalar other.
if isscalarlike(other):
if 0 == other:
raise NotImplementedError(" >= and <= don't work with 0.")
elif 0 <= other:
warn("Comparing a sparse matrix with a scalar less than zero "
"using <= is inefficient, try using < instead.", SparseEfficiencyWarning)
other_arr = np.empty(self.shape)
other_arr.fill(other)
other_arr = self.__class__(other_arr)
return self._binopt(other_arr, '_le_')
else:
# Casting as other's type avoids corner case like
# ``spmatrix(True) < -2'' from being True.
other_arr = self.astype(type(other)).copy()
other_arr.data[:] = other
return self._binopt(other_arr, '_le_')
# Dense other.
elif isdense(other):
return self.todense() <= other
# Sparse other.
elif isspmatrix(other):
#TODO sparse broadcasting
if self.shape != other.shape:
raise ValueError("inconsistent shapes")
elif self.format != other.format:
other = other.asformat(self.format)
warn("Comparing sparse matrices using >= and <= is inefficient, "
"using <, >, or !=, instead.", SparseEfficiencyWarning)
all_true = self.__class__(np.ones(self.shape))
res = self._binopt(other, '_gt_')
return all_true - res
else:
raise ValueError("Operands could not be compared.")

def __ge__(self,other):
# Scalar other.
if isscalarlike(other):
if 0 == other:
raise NotImplementedError(" >= and <= don't work with 0.")
elif 0 >= other:
warn("Comparing a sparse matrix with a scalar greater than zero"
" using >= is inefficient, try using < instead.", SparseEfficiencyWarning)
other_arr = np.empty(self.shape)
other_arr.fill(other)
other_arr = self.__class__(other_arr)
return self._binopt(other_arr, '_ge_')
else:
other_arr = self.astype(type(other)).copy()
other_arr.data[:] = other
return self._binopt(other_arr, '_ge_')
# Dense other.
elif isdense(other):
return self.todense() >= other
# Sparse other.
elif isspmatrix(other):
#TODO sparse broadcasting
if self.shape != other.shape:
raise ValueError("inconsistent shapes")
elif self.format != other.format:
other = other.asformat(self.format)
warn("Comparing sparse matrices using >= and <= is inefficient, "
"try using <, >, or !=, instead.", SparseEfficiencyWarning)
all_true = self.__class__(np.ones(self.shape))
res = self._binopt(other, '_lt_')
return all_true - res
else:
raise ValueError("Operands could not be compared.")

#################################
# Arithmatic operator overrides #
#################################
Expand Down Expand Up @@ -763,7 +889,8 @@ def _binopt(self, other, op):
indptr = np.empty_like(self.indptr)
indices = np.empty(maxnnz, dtype=np.intc)

if op == '_ne_':
bool_ops = ['_ne_', '_lt_', '_gt_', '_le_', '_ge_']
if op in bool_ops:
data = np.empty(maxnnz, dtype=np.bool_)
else:
data = np.empty(maxnnz, dtype=upcast(self.dtype,other.dtype))
Expand Down
36 changes: 36 additions & 0 deletions scipy/sparse/sparsetools/bsr.h
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,42 @@ void bsr_ne_bsr(const I n_row, const I n_col, const I R, const I C,
bsr_binop_bsr(n_row,n_col,R,C,Ap,Aj,Ax,Bp,Bj,Bx,Cp,Cj,Cx,std::not_equal_to<T>());
}

template <class I, class T, class T2>
void bsr_lt_bsr(const I n_row, const I n_col, const I R, const I C,
const I Ap[], const I Aj[], const T Ax[],
const I Bp[], const I Bj[], const T Bx[],
I Cp[], I Cj[], T2 Cx[])
{
bsr_binop_bsr(n_row,n_col,R,C,Ap,Aj,Ax,Bp,Bj,Bx,Cp,Cj,Cx,std::less<T>());
}

template <class I, class T, class T2>
void bsr_gt_bsr(const I n_row, const I n_col, const I R, const I C,
const I Ap[], const I Aj[], const T Ax[],
const I Bp[], const I Bj[], const T Bx[],
I Cp[], I Cj[], T2 Cx[])
{
bsr_binop_bsr(n_row,n_col,R,C,Ap,Aj,Ax,Bp,Bj,Bx,Cp,Cj,Cx,std::greater<T>());
}

template <class I, class T, class T2>
void bsr_le_bsr(const I n_row, const I n_col, const I R, const I C,
const I Ap[], const I Aj[], const T Ax[],
const I Bp[], const I Bj[], const T Bx[],
I Cp[], I Cj[], T2 Cx[])
{
bsr_binop_bsr(n_row,n_col,R,C,Ap,Aj,Ax,Bp,Bj,Bx,Cp,Cj,Cx,std::less_equal<T>());
}

template <class I, class T, class T2>
void bsr_ge_bsr(const I n_row, const I n_col, const I R, const I C,
const I Ap[], const I Aj[], const T Ax[],
const I Bp[], const I Bj[], const T Bx[],
I Cp[], I Cj[], T2 Cx[])
{
bsr_binop_bsr(n_row,n_col,R,C,Ap,Aj,Ax,Bp,Bj,Bx,Cp,Cj,Cx,std::greater_equal<T>());
}

template <class I, class T>
void bsr_elmul_bsr(const I n_row, const I n_col, const I R, const I C,
const I Ap[], const I Aj[], const T Ax[],
Expand Down
4 changes: 4 additions & 0 deletions scipy/sparse/sparsetools/bsr.i
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,7 @@ INSTANTIATE_ALL(bsr_minus_bsr)
INSTANTIATE_ALL(bsr_sort_indices)

INSTANTIATE_BOOL_OUT(bsr_ne_bsr)
INSTANTIATE_BOOL_OUT(bsr_lt_bsr)
INSTANTIATE_BOOL_OUT(bsr_gt_bsr)
INSTANTIATE_BOOL_OUT(bsr_le_bsr)
INSTANTIATE_BOOL_OUT(bsr_ge_bsr)
Loading

0 comments on commit 0e1dd62

Please sign in to comment.