New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: orthogonal procrustes solver #3809
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
300c57d
ENH: orthogonal procrustes solver
alexbrc d0366b4
MAINT: replace asarray with more specific format request
alexbrc bcc46d8
MAINT: tweak import
alexbrc 67242f7
TST: more extensive procrustes testing
alexbrc 94cbac0
MAINT: fiddle with procrustes arrays
alexbrc 6602cf8
MAINT: add an optional return value to orthogonal procrustes and add …
alexbrc 671ff41
DOC: add Raises to orthogonal_procrustes docstring
alexbrc 1195439
MAINT: remove unnecessary keyword argument from orthogonal_procrustes
alexbrc File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
""" | ||
Solve the orthogonal Procrustes problem. | ||
|
||
""" | ||
from __future__ import division, print_function, absolute_import | ||
|
||
import numpy as np | ||
from .decomp_svd import svd | ||
|
||
|
||
__all__ = ['orthogonal_procrustes'] | ||
|
||
|
||
def orthogonal_procrustes(A, B, check_finite=True): | ||
""" | ||
Compute the matrix solution of the orthogonal Procrustes problem. | ||
|
||
Given matrices A and B of equal shape, find an orthogonal matrix R | ||
that most closely maps A to B [1]_. | ||
Note that unlike higher level Procrustes analyses of spatial data, | ||
this function only uses orthogonal transformations like rotations | ||
and reflections, and it does not use scaling or translation. | ||
|
||
Parameters | ||
---------- | ||
A : (M, N) array_like | ||
Matrix to be mapped. | ||
B : (M, N) array_like | ||
Target matrix. | ||
check_finite : bool, optional | ||
Whether to check that the input matrices contain only finite numbers. | ||
Disabling may give a performance gain, but may result in problems | ||
(crashes, non-termination) if the inputs do contain infinities or NaNs. | ||
|
||
Returns | ||
------- | ||
R : (N, N) ndarray | ||
The matrix solution of the orthogonal Procrustes problem. | ||
Minimizes the Frobenius norm of dot(A, R) - B, subject to | ||
dot(R.T, R) == I. | ||
scale : float | ||
The sum of singular values of an intermediate matrix. | ||
This value is not returned unless specifically requested. | ||
|
||
Raises | ||
------ | ||
ValueError | ||
If the input arrays are incompatibly shaped. | ||
This may also be raised if matrix A or B contains an inf or nan | ||
and check_finite is True, or if the matrix product AB contains | ||
an inf or nan. | ||
|
||
References | ||
---------- | ||
.. [1] Peter H. Schonemann, "A generalized solution of the orthogonal | ||
Procrustes problem", Psychometrica -- Vol. 31, No. 1, March, 1996. | ||
|
||
""" | ||
if check_finite: | ||
A = np.asarray_chkfinite(A) | ||
B = np.asarray_chkfinite(B) | ||
else: | ||
A = np.asanyarray(A) | ||
B = np.asanyarray(B) | ||
if A.ndim != 2: | ||
raise ValueError('expected ndim to be 2, but observed %s' % A.ndim) | ||
if A.shape != B.shape: | ||
raise ValueError('the shapes of A and B differ (%s vs %s)' % ( | ||
A.shape, B.shape)) | ||
# Be clever with transposes, with the intention to save memory. | ||
u, w, vt = svd(B.T.dot(A).T) | ||
R = u.dot(vt) | ||
scale = w.sum() | ||
return R, scale |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
from itertools import product, permutations | ||
|
||
import numpy as np | ||
from numpy.testing import assert_array_less, assert_allclose, assert_raises | ||
|
||
from scipy.linalg import inv, eigh, norm | ||
from scipy.linalg import orthogonal_procrustes | ||
|
||
|
||
def test_orthogonal_procrustes_ndim_too_large(): | ||
np.random.seed(1234) | ||
A = np.random.randn(3, 4, 5) | ||
B = np.random.randn(3, 4, 5) | ||
assert_raises(ValueError, orthogonal_procrustes, A, B) | ||
|
||
|
||
def test_orthogonal_procrustes_ndim_too_small(): | ||
np.random.seed(1234) | ||
A = np.random.randn(3) | ||
B = np.random.randn(3) | ||
assert_raises(ValueError, orthogonal_procrustes, A, B) | ||
|
||
|
||
def test_orthogonal_procrustes_shape_mismatch(): | ||
np.random.seed(1234) | ||
shapes = ((3, 3), (3, 4), (4, 3), (4, 4)) | ||
for a, b in permutations(shapes, 2): | ||
A = np.random.randn(*a) | ||
B = np.random.randn(*b) | ||
assert_raises(ValueError, orthogonal_procrustes, A, B) | ||
|
||
|
||
def test_orthogonal_procrustes_checkfinite_exception(): | ||
np.random.seed(1234) | ||
m, n = 2, 3 | ||
A_good = np.random.randn(m, n) | ||
B_good = np.random.randn(m, n) | ||
for bad_value in np.inf, -np.inf, np.nan: | ||
A_bad = A_good.copy() | ||
A_bad[1, 2] = bad_value | ||
B_bad = B_good.copy() | ||
B_bad[1, 2] = bad_value | ||
for A, B in ((A_good, B_bad), (A_bad, B_good), (A_bad, B_bad)): | ||
assert_raises(ValueError, orthogonal_procrustes, A, B) | ||
|
||
|
||
def test_orthogonal_procrustes_scale_invariance(): | ||
np.random.seed(1234) | ||
m, n = 4, 3 | ||
for i in range(3): | ||
A_orig = np.random.randn(m, n) | ||
B_orig = np.random.randn(m, n) | ||
R_orig, s = orthogonal_procrustes(A_orig, B_orig) | ||
for A_scale in np.square(np.random.randn(3)): | ||
for B_scale in np.square(np.random.randn(3)): | ||
R, s = orthogonal_procrustes(A_orig * A_scale, B_orig * B_scale) | ||
assert_allclose(R, R_orig) | ||
|
||
|
||
def test_orthogonal_procrustes_array_conversion(): | ||
np.random.seed(1234) | ||
for m, n in ((6, 4), (4, 4), (4, 6)): | ||
A_arr = np.random.randn(m, n) | ||
B_arr = np.random.randn(m, n) | ||
As = (A_arr, A_arr.tolist(), np.matrix(A_arr)) | ||
Bs = (B_arr, B_arr.tolist(), np.matrix(B_arr)) | ||
R_arr, s = orthogonal_procrustes(A_arr, B_arr) | ||
AR_arr = A_arr.dot(R_arr) | ||
for A, B in product(As, Bs): | ||
R, s = orthogonal_procrustes(A, B) | ||
AR = A_arr.dot(R) | ||
assert_allclose(AR, AR_arr) | ||
|
||
|
||
def test_orthogonal_procrustes(): | ||
np.random.seed(1234) | ||
for m, n in ((6, 4), (4, 4), (4, 6)): | ||
# Sample a random target matrix. | ||
B = np.random.randn(m, n) | ||
# Sample a random orthogonal matrix | ||
# by computing eigh of a sampled symmetric matrix. | ||
X = np.random.randn(n, n) | ||
w, V = eigh(X.T + X) | ||
assert_allclose(inv(V), V.T) | ||
# Compute a matrix with a known orthogonal transformation that gives B. | ||
A = np.dot(B, V.T) | ||
# Check that an orthogonal transformation from A to B can be recovered. | ||
R, s = orthogonal_procrustes(A, B) | ||
assert_allclose(inv(R), R.T) | ||
assert_allclose(A.dot(R), B) | ||
# Create a perturbed input matrix. | ||
A_perturbed = A + 1e-2 * np.random.randn(m, n) | ||
# Check that the orthogonal procrustes function can find an orthogonal | ||
# transformation that is better than the orthogonal transformation | ||
# computed from the original input matrix. | ||
R_prime, s = orthogonal_procrustes(A_perturbed, B) | ||
assert_allclose(inv(R_prime), R_prime.T) | ||
# Compute the naive and optimal transformations of the perturbed input. | ||
naive_approx = A_perturbed.dot(R) | ||
optim_approx = A_perturbed.dot(R_prime) | ||
# Compute the Frobenius norm errors of the matrix approximations. | ||
naive_approx_error = norm(naive_approx - B, ord='fro') | ||
optim_approx_error = norm(optim_approx - B, ord='fro') | ||
# Check that the orthogonal Procrustes approximation is better. | ||
assert_array_less(optim_approx_error, naive_approx_error) | ||
|
||
|
||
def _centered(A): | ||
mu = A.mean(axis=0) | ||
return A - mu, mu | ||
|
||
|
||
def test_orthogonal_procrustes_exact_example(): | ||
# Check a small application. | ||
# It uses translation, scaling, reflection, and rotation. | ||
# | ||
# | | ||
# a b | | ||
# | | ||
# d c | w | ||
# | | ||
# --------+--- x ----- z --- | ||
# | | ||
# | y | ||
# | | ||
# | ||
A_orig = np.array([[-3, 3], [-2, 3], [-2, 2], [-3, 2]], dtype=float) | ||
B_orig = np.array([[3, 2], [1, 0], [3, -2], [5, 0]], dtype=float) | ||
A, A_mu = _centered(A_orig) | ||
B, B_mu = _centered(B_orig) | ||
R, s = orthogonal_procrustes(A, B) | ||
scale = s / np.square(norm(A)) | ||
B_approx = scale * np.dot(A, R) + B_mu | ||
assert_allclose(B_approx, B_orig, atol=1e-8) | ||
|
||
|
||
def test_orthogonal_procrustes_stretched_example(): | ||
# Try again with a target with a stretched y axis. | ||
A_orig = np.array([[-3, 3], [-2, 3], [-2, 2], [-3, 2]], dtype=float) | ||
B_orig = np.array([[3, 40], [1, 0], [3, -40], [5, 0]], dtype=float) | ||
A, A_mu = _centered(A_orig) | ||
B, B_mu = _centered(B_orig) | ||
R, s = orthogonal_procrustes(A, B) | ||
scale = s / np.square(norm(A)) | ||
B_approx = scale * np.dot(A, R) + B_mu | ||
expected = np.array([[3, 21], [-18, 0], [3, -21], [24, 0]], dtype=float) | ||
assert_allclose(B_approx, expected, atol=1e-8) | ||
# Check disparity symmetry. | ||
expected_disparity = 0.4501246882793018 | ||
AB_disparity = np.square(norm(B_approx - B_orig) / norm(B)) | ||
assert_allclose(AB_disparity, expected_disparity) | ||
R, s = orthogonal_procrustes(B, A) | ||
scale = s / np.square(norm(B)) | ||
A_approx = scale * np.dot(B, R) + A_mu | ||
BA_disparity = np.square(norm(A_approx - A_orig) / norm(A)) | ||
assert_allclose(BA_disparity, expected_disparity) | ||
|
||
|
||
def test_orthogonal_procrustes_skbio_example(): | ||
# This transformation is also exact. | ||
# It uses translation, scaling, and reflection. | ||
# | ||
# | | ||
# | a | ||
# | b | ||
# | c d | ||
# --+--------- | ||
# | | ||
# | w | ||
# | | ||
# | x | ||
# | | ||
# | z y | ||
# | | ||
# | ||
A_orig = np.array([[4, -2], [4, -4], [4, -6], [2, -6]], dtype=float) | ||
B_orig = np.array([[1, 3], [1, 2], [1, 1], [2, 1]], dtype=float) | ||
B_standardized = np.array([ | ||
[-0.13363062, 0.6681531], | ||
[-0.13363062, 0.13363062], | ||
[-0.13363062, -0.40089186], | ||
[0.40089186, -0.40089186]]) | ||
A, A_mu = _centered(A_orig) | ||
B, B_mu = _centered(B_orig) | ||
R, s = orthogonal_procrustes(A, B) | ||
scale = s / np.square(norm(A)) | ||
B_approx = scale * np.dot(A, R) + B_mu | ||
assert_allclose(B_approx, B_orig) | ||
assert_allclose(B / norm(B), B_standardized) | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should there be a
Raises
section here?