Skip to content

Commit

Permalink
Merge pull request #291 from pyamg/simplify
Browse files Browse the repository at this point in the history
rework conditionals
  • Loading branch information
lukeolson committed Dec 18, 2021
2 parents 9871e66 + c77d6e3 commit 66ead2c
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 64 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ jobs:
- name: Run flake8
run: |
python -m pip install flake8
python -m pip install pep8-naming flake8-quotes flake8-use-fstring
#flake8-docstrings flake8-builtins flake8-pytest-style flake8-simplify
python -m pip install pep8-naming flake8-quotes flake8-use-fstring flake8-pytest-style
#flake8-docstrings flake8-builtins
python -m flake8 --statistics pyamg && echo "flake8 passed."
pylint:
Expand Down
15 changes: 9 additions & 6 deletions pyamg/amg_core/tests/test_bind_examples.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import pyamg.amg_core.tests.bind_examples as g
"""Test binding."""

import numpy as np
from numpy.testing import TestCase
from pytest import raises as assert_raises
import pytest

import pyamg.amg_core.tests.bind_examples as g


class TestDocstrings(TestCase):
Expand Down Expand Up @@ -104,25 +107,25 @@ def test_10f(self):
J = np.array([1, 1, 1], dtype=np.int8)
x = np.array([1.0, 2.0, 3.0], dtype=np.float32)

assert_raises(TypeError, g.test10, J, x)
pytest.raises(TypeError, g.test10, J, x)

def test_10g(self):
# int64, float32 (should FAIL on downconvert)
J = np.array([1, 1, 1], dtype=np.int64)
x = np.array([1.0, 2.0, 3.0], dtype=np.float32)

assert_raises(TypeError, g.test10, J, x)
pytest.raises(TypeError, g.test10, J, x)

def test_10h(self):
# int32, float16 (should FAIL on upconvert)
J = np.array([1, 1, 1], dtype=np.int32)
x = np.array([1.0, 2.0, 3.0], dtype=np.float16)

assert_raises(TypeError, g.test10, J, x)
pytest.raises(TypeError, g.test10, J, x)

def test_10i(self):
# int64, float32 (should FAIL on downconvert)
J = np.array([1, 1, 1], dtype=np.int32)
x = np.array([1.0, 2.0, 3.0], dtype=np.longdouble)

assert_raises(TypeError, g.test10, J, x)
pytest.raises(TypeError, g.test10, J, x)
6 changes: 3 additions & 3 deletions pyamg/classical/cr.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ def _CRsweep(A, B, Findex, Cindex, nu, thetacr, method):
it = 0

while True:
if method not in ('habituated', 'concurrent'):
raise NotImplementedError('method not recognized: need habituated '
'or concurrent')
if method == 'habituated':
gauss_seidel(A, e, z, iterations=1)
e[Cindex] = 0.0
elif method == 'concurrent':
gauss_seidel_indexed(A, e, z, indices=Findex, iterations=1)
else:
raise NotImplementedError('method not recognized: need habituated '
'or concurrent')

enorm_old = enorm
enorm = norm(e)
Expand Down
5 changes: 3 additions & 2 deletions pyamg/gallery/elasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,13 @@ def linear_elasticity_p1(vertices, elements, E=1e5, nu=0.3, format=None):
if elements.shape[1] != D + 1:
raise ValueError('dimension mismatch')

if D not in (2, 3):
raise ValueError('only dimension 2 and 3 are supported')

if D == 2:
local_K = p12d_local
elif D == 3:
local_K = p13d_local
else:
raise NotImplementedError('only dimension 2 and 3 are supported')

row = elements.repeat(D).reshape(-1, D)
row *= D
Expand Down
36 changes: 18 additions & 18 deletions pyamg/relaxation/relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,9 @@ def schwarz(A, x, b, iterations=1, subdomain=None, subdomain_ptr=None,
schwarz_parameters(A, subdomain, subdomain_ptr,
inv_subblock, inv_subblock_ptr)

if sweep not in ('forward', 'backward', 'symmetric'):
raise ValueError("valid sweep directions: 'forward', 'backward', and 'symmetric'")

if sweep == 'forward':
row_start, row_stop, row_step = 0, subdomain_ptr.shape[0]-1, 1
elif sweep == 'backward':
Expand All @@ -260,9 +263,6 @@ def schwarz(A, x, b, iterations=1, subdomain=None, subdomain_ptr=None,
subdomain_ptr=subdomain_ptr, inv_subblock=inv_subblock,
inv_subblock_ptr=inv_subblock_ptr, sweep='backward')
return
else:
raise ValueError("valid sweep directions are 'forward',\
'backward', and 'symmetric'")

# Call C code, need to make sure that subdomains are sorted and unique
for _iter in range(iterations):
Expand Down Expand Up @@ -328,6 +328,9 @@ def gauss_seidel(A, x, b, iterations=1, sweep='forward'):
raise ValueError('BSR blocks must be square')
blocksize = R

if sweep not in ('forward', 'backward', 'symmetric'):
raise ValueError('valid sweep directions: "forward", "backward", and "symmetric"')

if sweep == 'forward':
row_start, row_stop, row_step = 0, int(len(x)/blocksize), 1
elif sweep == 'backward':
Expand All @@ -337,9 +340,6 @@ def gauss_seidel(A, x, b, iterations=1, sweep='forward'):
gauss_seidel(A, x, b, iterations=1, sweep='forward')
gauss_seidel(A, x, b, iterations=1, sweep='backward')
return
else:
raise ValueError('valid sweep directions are "forward", '
'"backward", and "symmetric"')

if sparse.isspmatrix_csr(A):
for _iter in range(iterations):
Expand Down Expand Up @@ -568,6 +568,9 @@ def block_gauss_seidel(A, x, b, iterations=1, sweep='forward', blocksize=1,
elif (Dinv.shape[1] != blocksize) or (Dinv.shape[2] != blocksize):
raise ValueError('Dinv and blocksize are incompatible')

if sweep not in ('forward', 'backward', 'symmetric'):
raise ValueError('valid sweep directions: "forward", "backward", and "symmetric"')

if sweep == 'forward':
row_start, row_stop, row_step = 0, int(len(x)/blocksize), 1
elif sweep == 'backward':
Expand All @@ -579,9 +582,6 @@ def block_gauss_seidel(A, x, b, iterations=1, sweep='forward', blocksize=1,
block_gauss_seidel(A, x, b, iterations=1, sweep='backward',
blocksize=blocksize, Dinv=Dinv)
return
else:
raise ValueError('valid sweep directions are "forward", '
'"backward", and "symmetric"')

for _iter in range(iterations):
amg_core.block_gauss_seidel(A.indptr, A.indices, np.ravel(A.data),
Expand Down Expand Up @@ -718,6 +718,9 @@ def gauss_seidel_indexed(A, x, b, indices, iterations=1, sweep='forward'):
# if indices.max() >= A.shape[0]
# raise ValueError('row index (%d) is invalid' % indices.max())

if sweep not in ('forward', 'backward', 'symmetric'):
raise ValueError('valid sweep directions: "forward", "backward", and "symmetric"')

if sweep == 'forward':
row_start, row_stop, row_step = 0, len(indices), 1
elif sweep == 'backward':
Expand All @@ -729,9 +732,6 @@ def gauss_seidel_indexed(A, x, b, indices, iterations=1, sweep='forward'):
gauss_seidel_indexed(A, x, b, indices, iterations=1,
sweep='backward')
return
else:
raise ValueError('valid sweep directions are "forward", '
'"backward", and "symmetric"')

for _iter in range(iterations):
amg_core.gauss_seidel_indexed(A.indptr, A.indices, A.data,
Expand Down Expand Up @@ -889,6 +889,9 @@ def gauss_seidel_ne(A, x, b, iterations=1, sweep='forward', omega=1.0,
if Dinv is None:
Dinv = np.ravel(get_diagonal(A, norm_eq=2, inv=True))

if sweep not in ('forward', 'backward', 'symmetric'):
raise ValueError('valid sweep directions: "forward", "backward", and "symmetric"')

if sweep == 'forward':
row_start, row_stop, row_step = 0, len(x), 1
elif sweep == 'backward':
Expand All @@ -900,9 +903,6 @@ def gauss_seidel_ne(A, x, b, iterations=1, sweep='forward', omega=1.0,
gauss_seidel_ne(A, x, b, iterations=1, sweep='backward',
omega=omega, Dinv=Dinv)
return
else:
raise ValueError('valid sweep directions are "forward", '
'"backward", and "symmetric"')

for _i in range(iterations):
amg_core.gauss_seidel_ne(A.indptr, A.indices, A.data,
Expand Down Expand Up @@ -974,6 +974,9 @@ def gauss_seidel_nr(A, x, b, iterations=1, sweep='forward', omega=1.0,
if Dinv is None:
Dinv = np.ravel(get_diagonal(A, norm_eq=1, inv=True))

if sweep not in ('forward', 'backward', 'symmetric'):
raise ValueError('valid sweep directions: "forward", "backward", and "symmetric"')

if sweep == 'forward':
col_start, col_stop, col_step = 0, len(x), 1
elif sweep == 'backward':
Expand All @@ -985,9 +988,6 @@ def gauss_seidel_nr(A, x, b, iterations=1, sweep='forward', omega=1.0,
gauss_seidel_nr(A, x, b, iterations=1, sweep='backward',
omega=omega, Dinv=Dinv)
return
else:
raise ValueError("valid sweep directions are 'forward',\
'backward', and 'symmetric'")

# Calculate initial residual
r = b - A*x
Expand Down
18 changes: 12 additions & 6 deletions pyamg/relaxation/tests/test_relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,8 @@ def test_gauss_seidel_csr(self):
x = np.ones(N)
gauss_seidel(A, x, b, iterations=200, sweep='backward')
resid2 = np.linalg.norm(A*x, 2)
self.assertTrue(resid1 < 0.01 and resid2 < 0.01)
assert resid1 < 0.01
assert resid2 < 0.01
assert_almost_equal(resid1, resid2)

def test_gauss_seidel_indexed(self):
Expand Down Expand Up @@ -595,7 +596,8 @@ def test_gauss_seidel_ne_csr(self):
x = np.ones(N)
gauss_seidel_ne(A, x, b, iterations=200, sweep='backward')
resid2 = np.linalg.norm(A*x, 2)
self.assertTrue(resid1 < 0.2 and resid2 < 0.2)
assert resid1 < 0.2
assert resid2 < 0.2
assert_almost_equal(resid1, resid2)

def test_gauss_seidel_nr_bsr(self):
Expand Down Expand Up @@ -683,7 +685,8 @@ def gold(A, x, b, iterations, sweep):
x = np.ones(N)
gauss_seidel_nr(A, x, b, iterations=200, sweep='backward')
resid2 = np.linalg.norm(A*x, 2)
self.assertTrue(resid1 < 0.2 and resid2 < 0.2)
assert resid1 < 0.2
assert resid2 < 0.2
assert_almost_equal(resid1, resid2)

def test_schwarz_gold(self):
Expand Down Expand Up @@ -1113,7 +1116,8 @@ def test_gauss_seidel_csr(self):
x = x + 1.0j*x
gauss_seidel(A, x, b, iterations=200, sweep='backward')
resid2 = np.linalg.norm(A*x, 2)
self.assertTrue(resid1 < 0.03 and resid2 < 0.03)
assert resid1 < 0.03
assert resid2 < 0.03
assert_almost_equal(resid1, resid2)

def test_jacobi_ne(self):
Expand Down Expand Up @@ -1341,7 +1345,8 @@ def test_gauss_seidel_ne_csr(self):
x = x + 1.0j*x
gauss_seidel_ne(A, x, b, iterations=200, sweep='backward')
resid2 = np.linalg.norm(A*x, 2)
self.assertTrue(resid1 < 0.3 and resid2 < 0.3)
assert resid1 < 0.3
assert resid2 < 0.3
assert_almost_equal(resid1, resid2)

def test_gauss_seidel_nr_bsr(self):
Expand Down Expand Up @@ -1436,7 +1441,8 @@ def gold(A, x, b, iterations, sweep):
x = x + 1.0j*x
gauss_seidel_nr(A, x, b, iterations=200, sweep='backward')
resid2 = np.linalg.norm(A*x, 2)
self.assertTrue(resid1 < 0.3 and resid2 < 0.3)
assert resid1 < 0.3
assert resid2 < 0.3
assert_almost_equal(resid1, resid2)


Expand Down
10 changes: 6 additions & 4 deletions pyamg/strength.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,14 +186,15 @@ def classical_strength_of_connection(A, theta=0.0, norm='abs'):
Sj = np.empty_like(A.indices)
Sx = np.empty_like(A.data)

if norm not in ('abs', 'min'):
raise ValueError('Unknown norm')

if norm == 'abs':
amg_core.classical_strength_of_connection_abs(
A.shape[0], theta, A.indptr, A.indices, A.data, Sp, Sj, Sx)
elif norm == 'min':
amg_core.classical_strength_of_connection_min(
A.shape[0], theta, A.indptr, A.indices, A.data, Sp, Sj, Sx)
else:
raise ValueError('Unknown norm')

S = sparse.csr_matrix((Sx, Sj, Sp), shape=A.shape)

Expand Down Expand Up @@ -275,6 +276,9 @@ def symmetric_strength_of_connection(A, theta=0):
if theta < 0:
raise ValueError('expected a positive theta')

if not sparse.isspmatrix_csr(A) and not sparse.isspmatrix_bsr(A):
raise TypeError('expected csr_matrix or bsr_matrix')

if sparse.isspmatrix_csr(A):
# if theta == 0:
# return A
Expand Down Expand Up @@ -307,8 +311,6 @@ def symmetric_strength_of_connection(A, theta=0):
A = sparse.csr_matrix((data, A.indices, A.indptr),
shape=(int(M / R), int(N / C)))
return symmetric_strength_of_connection(A, theta)
else:
raise TypeError('expected csr_matrix or bsr_matrix')

# Strength represents "distance", so take the magnitude
S.data = np.abs(S.data)
Expand Down
42 changes: 19 additions & 23 deletions pyamg/vis/vis_coarse.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,6 @@ def vis_splitting(V, splitting, output='vtk', fname='output.vtu'):
elif len(a) >= 2:
fname1 = ''.join(a[:-1])
fname2 = a[-1]
else:
raise ValueError('problem with fname')

new_fname = fname
for d in range(0, Ndof):
Expand All @@ -225,6 +223,9 @@ def vis_splitting(V, splitting, output='vtk', fname='output.vtu'):

cdata = splitting[(d*N):((d+1)*N)]

if output not in ('vtk', 'matplotlib'):
raise ValueError('problem with outputtype')

if output == 'vtk':
write_basic_mesh(V=V, E2V=E2V, mesh_type='vertex',
cdata=cdata, fname=new_fname)
Expand All @@ -246,41 +247,36 @@ def vis_splitting(V, splitting, output='vtk', fname='output.vtu'):
plt.show()
except ImportError:
print('\nNote: matplotlib is needed for plotting.')
else:
raise ValueError('problem with outputtype')


def check_input(V=None, E2V=None, AggOp=None, A=None, splitting=None,
mesh_type=None):
def check_input(V=None, E2V=None, AggOp=None, A=None, splitting=None, mesh_type=None):
"""Check input for local functions."""
if V is not None:
if not np.issubdtype(V.dtype, np.floating):
raise ValueError('V should be of type float')
if V is not None and not np.issubdtype(V.dtype, np.floating):
raise ValueError('V should be of type float')

if E2V is not None:
if not np.issubdtype(E2V.dtype, np.integer):
raise ValueError('E2V should be of type integer')
if E2V.min() != 0:
warnings.warn(f'Element indices begin at {E2V.min()}')

if AggOp is not None:
if AggOp.shape[1] > AggOp.shape[0]:
raise ValueError('AggOp should be of size N x Nagg')
if AggOp is not None and AggOp.shape[1] > AggOp.shape[0]:
raise ValueError('AggOp should be of size N x Nagg')

if A is not None and AggOp is None:
raise ValueError('problem with check_input')

if (A is not None and AggOp is not None
and ((A.shape[0] != A.shape[1]) or (A.shape[0] != AggOp.shape[0]))):
raise ValueError('expected square matrix A and compatible with AggOp')

if A is not None:
if AggOp is not None:
if (A.shape[0] != A.shape[1]) or (A.shape[0] != AggOp.shape[0]):
raise ValueError('expected square matrix A and compatible with AggOp')
else:
raise ValueError('problem with check_input')
if splitting is not None and V is None:
raise ValueError('problem with check_input')

if splitting is not None:
splitting = splitting.ravel()
if V is not None:
if (len(splitting) % V.shape[0]) != 0:
raise ValueError('splitting must be a multiple of N')
else:
raise ValueError('problem with check_input')
if V is not None and (len(splitting) % V.shape[0]) != 0:
raise ValueError('splitting must be a multiple of N')

if mesh_type is not None:
valid_mesh_types = ('vertex', 'tri', 'quad', 'tet', 'hex')
Expand Down

0 comments on commit 66ead2c

Please sign in to comment.