Skip to content

Commit

Permalink
Rename compute_R to return_R, always compute R
Browse files Browse the repository at this point in the history
  • Loading branch information
pmli committed Feb 1, 2019
1 parent 56d0c02 commit dbd8db4
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 20 deletions.
26 changes: 10 additions & 16 deletions src/pymor/algorithms/gram_schmidt.py
Expand Up @@ -10,7 +10,7 @@


@defaults('atol', 'rtol', 'reiterate', 'reiteration_threshold', 'check', 'check_tol')
def gram_schmidt(A, product=None, compute_R=False, atol=1e-13, rtol=1e-13, offset=0,
def gram_schmidt(A, product=None, return_R=False, atol=1e-13, rtol=1e-13, offset=0,
reiterate=True, reiteration_threshold=1e-1, check=True, check_tol=1e-3,
copy=True):
"""Orthonormalize a |VectorArray| using the modified Gram-Schmidt algorithm.
Expand All @@ -22,7 +22,7 @@ def gram_schmidt(A, product=None, compute_R=False, atol=1e-13, rtol=1e-13, offse
product
The inner product |Operator| w.r.t. which to orthonormalize.
If `None`, the Euclidean product is used.
compute_R
return_R
If `True`, the R matrix from QR decomposition is returned.
atol
Vectors of norm smaller than `atol` are removed from the array.
Expand Down Expand Up @@ -58,10 +58,8 @@ def gram_schmidt(A, product=None, compute_R=False, atol=1e-13, rtol=1e-13, offse
if copy:
A = A.copy()

if compute_R:
R = np.eye(len(A))

# main loop
R = np.eye(len(A))
remove = [] # indices of to be removed vectors
for i in range(offset, len(A)):
# first calculate norm
Expand All @@ -74,8 +72,7 @@ def gram_schmidt(A, product=None, compute_R=False, atol=1e-13, rtol=1e-13, offse

if i == 0:
A[0].scal(1 / initial_norm)
if compute_R:
R[i, i] = initial_norm
R[i, i] = initial_norm
else:
norm = initial_norm
# If reiterate is True, reiterate as long as the norm of the vector changes
Expand All @@ -87,8 +84,7 @@ def gram_schmidt(A, product=None, compute_R=False, atol=1e-13, rtol=1e-13, offse
continue
p = A[j].pairwise_inner(A[i], product)[0]
A[i].axpy(-p, A[j])
if compute_R:
R[j, i] += p
R[j, i] += p

# calculate new norm
old_norm, norm = norm, A[i].norm(product)[0]
Expand All @@ -104,14 +100,12 @@ def gram_schmidt(A, product=None, compute_R=False, atol=1e-13, rtol=1e-13, offse
logger.info(f"Orthonormalizing vector {i} again")
else:
A[i].scal(1 / norm)
if compute_R:
R[i, i] = norm
R[i, i] = norm
break

if remove:
del A[remove]
if compute_R:
R = np.delete(R, remove, axis=0)
R = np.delete(R, remove, axis=0)

if check:
error_matrix = A[offset:len(A)].inner(A, product)
Expand All @@ -121,10 +115,10 @@ def gram_schmidt(A, product=None, compute_R=False, atol=1e-13, rtol=1e-13, offse
if err >= check_tol:
raise AccuracyError(f"result not orthogonal (max err={err})")

if compute_R:
if return_R:
return A, R

return A
else:
return A


def gram_schmidt_biorth(V, W, product=None,
Expand Down
8 changes: 4 additions & 4 deletions src/pymortests/algorithms/gram_schmidt.py
Expand Up @@ -28,13 +28,13 @@ def test_gram_schmidt_with_R(vector_array):
U = vector_array

V = U.copy()
onb, R = gram_schmidt(U, compute_R=True, copy=True)
onb, R = gram_schmidt(U, return_R=True, copy=True)
assert np.all(almost_equal(U, V))
assert np.allclose(onb.dot(onb), np.eye(len(onb)))
assert np.all(almost_equal(U, onb.lincomb(U.dot(onb)), rtol=1e-13))
assert np.all(almost_equal(V, onb.lincomb(R.T)))

onb2, R2 = gram_schmidt(U, compute_R=True, copy=False)
onb2, R2 = gram_schmidt(U, return_R=True, copy=False)
assert np.all(almost_equal(onb, onb2))
assert np.all(R == R2)
assert np.all(almost_equal(onb, U))
Expand All @@ -58,13 +58,13 @@ def test_gram_schmidt_with_product_and_R(operator_with_arrays_and_products):
_, _, U, _, p, _ = operator_with_arrays_and_products

V = U.copy()
onb, R = gram_schmidt(U, product=p, compute_R=True, copy=True)
onb, R = gram_schmidt(U, product=p, return_R=True, copy=True)
assert np.all(almost_equal(U, V))
assert np.allclose(p.apply2(onb, onb), np.eye(len(onb)))
assert np.all(almost_equal(U, onb.lincomb(p.apply2(U, onb)), rtol=1e-13))
assert np.all(almost_equal(U, onb.lincomb(R.T)))

onb2, R2 = gram_schmidt(U, product=p, compute_R=True, copy=False)
onb2, R2 = gram_schmidt(U, product=p, return_R=True, copy=False)
assert np.all(almost_equal(onb, onb2))
assert np.all(R == R2)
assert np.all(almost_equal(onb, U))
Expand Down

0 comments on commit dbd8db4

Please sign in to comment.