Skip to content

Commit

Permalink
BUG: sparse/dsolve: check sparse matrix inputs more carefully
Browse files Browse the repository at this point in the history
  • Loading branch information
pv committed Feb 22, 2014
1 parent 2344e41 commit a18a98a
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 50 deletions.
121 changes: 72 additions & 49 deletions scipy/sparse/linalg/dsolve/_superluobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -345,14 +345,24 @@ int NRFormat_from_spMatrix(SuperMatrix * A, int m, int n, int nnz,
PyArrayObject * nzvals, PyArrayObject * colind,
PyArrayObject * rowptr, int typenum)
{
int err = 0;

err = (nzvals->descr->type_num != typenum);
err += (nzvals->nd != 1);
err += (nnz > nzvals->dimensions[0]);
if (err) {
PyErr_SetString(PyExc_TypeError,
"Fourth argument must be a 1-D array at least as big as third argument.");
int ok = 0;

ok = (PyArray_EquivTypenums(PyArray_DESCR(nzvals)->type_num, typenum) &&
PyArray_EquivTypenums(PyArray_DESCR(colind)->type_num, NPY_INT) &&
PyArray_EquivTypenums(PyArray_DESCR(rowptr)->type_num, NPY_INT) &&
PyArray_NDIM(nzvals) == 1 &&
PyArray_NDIM(colind) == 1 &&
PyArray_NDIM(rowptr) == 1 &&
PyArray_ISCARRAY(nzvals) &&
PyArray_ISCARRAY(colind) &&
PyArray_ISCARRAY(rowptr) &&
nnz <= PyArray_DIM(nzvals, 0) &&
nnz <= PyArray_DIM(colind, 0) &&
m+1 <= PyArray_DIM(rowptr, 0));
if (!ok) {
PyErr_SetString(PyExc_ValueError,
"sparse matrix arrays must be 1-D C-contigous and of proper "
"sizes and types");
return -1;
}

Expand All @@ -378,14 +388,24 @@ int NCFormat_from_spMatrix(SuperMatrix * A, int m, int n, int nnz,
PyArrayObject * nzvals, PyArrayObject * rowind,
PyArrayObject * colptr, int typenum)
{
int err = 0;

err = (nzvals->descr->type_num != typenum);
err += (nzvals->nd != 1);
err += (nnz > nzvals->dimensions[0]);
if (err) {
PyErr_SetString(PyExc_TypeError,
"Fifth argument must be a 1-D array at least as big as fourth argument.");
int ok = 0;

ok = (PyArray_EquivTypenums(PyArray_DESCR(nzvals)->type_num, typenum) &&
PyArray_EquivTypenums(PyArray_DESCR(rowind)->type_num, NPY_INT) &&
PyArray_EquivTypenums(PyArray_DESCR(colptr)->type_num, NPY_INT) &&
PyArray_NDIM(nzvals) == 1 &&
PyArray_NDIM(rowind) == 1 &&
PyArray_NDIM(colptr) == 1 &&
PyArray_ISCARRAY(nzvals) &&
PyArray_ISCARRAY(rowind) &&
PyArray_ISCARRAY(colptr) &&
nnz <= PyArray_DIM(nzvals, 0) &&
nnz <= PyArray_DIM(rowind, 0) &&
n+1 <= PyArray_DIM(colptr, 0));
if (!ok) {
PyErr_SetString(PyExc_ValueError,
"sparse matrix arrays must be 1-D C-contigous and of proper "
"sizes and types");
return -1;
}

Expand Down Expand Up @@ -812,44 +832,47 @@ int set_superlu_options_from_dict(superlu_options_t * options,
_relax = sp_ienv(2);

if (option_dict == NULL) {
return 0;
/* Proceed with default options */
ret = 1;
}
else {
args = PyTuple_New(0);
ret = PyArg_ParseTupleAndKeywords(args, option_dict,
"|O&O&O&O&O&O&O&O&O&O&O&O&O&O&O&O&O&O&O&O&O&O&",
kwlist, fact_cvt, &options->Fact,
yes_no_cvt, &options->Equil,
colperm_cvt, &options->ColPerm,
trans_cvt, &options->Trans,
iterrefine_cvt, &options->IterRefine,
double_cvt,
&options->DiagPivotThresh,
yes_no_cvt, &options->PivotGrowth,
yes_no_cvt,
&options->ConditionNumber,
rowperm_cvt, &options->RowPerm,
yes_no_cvt, &options->SymmetricMode,
yes_no_cvt, &options->PrintStat,
yes_no_cvt,
&options->ReplaceTinyPivot,
yes_no_cvt,
&options->SolveInitialized,
yes_no_cvt,
&options->RefineInitialized,
norm_cvt, &options->ILU_Norm,
milu_cvt, &options->ILU_MILU,
double_cvt, &options->ILU_DropTol,
double_cvt, &options->ILU_FillTol,
double_cvt, &options->ILU_FillFactor,
droprule_cvt, &options->ILU_DropRule,
int_cvt, &_panel_size, int_cvt,
&_relax);
Py_DECREF(args);
}

args = PyTuple_New(0);
ret = PyArg_ParseTupleAndKeywords(args, option_dict,
"|O&O&O&O&O&O&O&O&O&O&O&O&O&O&O&O&O&O&O&O&O&O&",
kwlist, fact_cvt, &options->Fact,
yes_no_cvt, &options->Equil,
colperm_cvt, &options->ColPerm,
trans_cvt, &options->Trans,
iterrefine_cvt, &options->IterRefine,
double_cvt,
&options->DiagPivotThresh,
yes_no_cvt, &options->PivotGrowth,
yes_no_cvt,
&options->ConditionNumber,
rowperm_cvt, &options->RowPerm,
yes_no_cvt, &options->SymmetricMode,
yes_no_cvt, &options->PrintStat,
yes_no_cvt,
&options->ReplaceTinyPivot,
yes_no_cvt,
&options->SolveInitialized,
yes_no_cvt,
&options->RefineInitialized,
norm_cvt, &options->ILU_Norm,
milu_cvt, &options->ILU_MILU,
double_cvt, &options->ILU_DropTol,
double_cvt, &options->ILU_FillTol,
double_cvt, &options->ILU_FillFactor,
droprule_cvt, &options->ILU_DropRule,
int_cvt, &_panel_size, int_cvt,
&_relax);
Py_DECREF(args);

if (panel_size != NULL) {
*panel_size = _panel_size;
}

if (relax != NULL) {
*relax = _relax;
}
Expand Down
36 changes: 35 additions & 1 deletion scipy/sparse/linalg/dsolve/tests/test_linsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from scipy.linalg import norm, inv
from scipy.sparse import spdiags, SparseEfficiencyWarning, csc_matrix, csr_matrix, \
isspmatrix, dok_matrix, lil_matrix, bsr_matrix
from scipy.sparse.linalg.dsolve import spsolve, use_solver, splu, spilu, MatrixRankWarning
from scipy.sparse.linalg.dsolve import spsolve, use_solver, splu, spilu, \
MatrixRankWarning, _superlu

warnings.simplefilter('ignore',SparseEfficiencyWarning)

Expand Down Expand Up @@ -171,6 +172,39 @@ def test_ndarray_support(self):

assert_array_almost_equal(x, spsolve(A, b))

def test_gssv_badinput(self):
N = 10
d = arange(N) + 1.0
A = spdiags((d, 2*d, d[::-1]), (-3, 0, 5), N, N)

for spmatrix in (csc_matrix, csr_matrix):
A = spmatrix(A)
b = np.arange(N)

def not_c_contig(x):
return x.repeat(2)[::2]
def not_1dim(x):
return x[:,None]
def bad_type(x):
return x.astype(bool)
def too_short(x):
return x[:-1]

badops = [not_c_contig, not_1dim, bad_type, too_short]

for badop in badops:
msg = "%r %r" % (spmatrix, badop)
# Not C-contiguous
assert_raises((ValueError, TypeError), _superlu.gssv,
N, A.nnz, badop(A.data), A.indices, A.indptr,
b, int(spmatrix == csc_matrix), err_msg=msg)
assert_raises((ValueError, TypeError), _superlu.gssv,
N, A.nnz, A.data, badop(A.indices), A.indptr,
b, int(spmatrix == csc_matrix), err_msg=msg)
assert_raises((ValueError, TypeError), _superlu.gssv,
N, A.nnz, A.data, A.indices, badop(A.indptr),
b, int(spmatrix == csc_matrix), err_msg=msg)


class TestSplu(object):
def setUp(self):
Expand Down

0 comments on commit a18a98a

Please sign in to comment.