diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 8c973e608..000000000 --- a/.flake8 +++ /dev/null @@ -1,8 +0,0 @@ -[flake8] -application_import_names = pymanopt -docstring-convention = google -exclude = examples/notebooks/*.py -ignore = B024, B027, D1, E203, E501, W503 -import_order_style = pycharm -max-line-length = 79 -select = B, B950, C, D, E, F, W diff --git a/TODO.md b/TODO.md index 76c26bd9b..6c35f70a7 100644 --- a/TODO.md +++ b/TODO.md @@ -1,17 +1,15 @@ # TODO/Roadmap -## 2.1.x: - - attrs +## 3.x: - Add 'check_hessian' function - Refactor optimizer implementations - Add complex manifolds #125, #170 - - Add JAX backend #115 - Add L-BFGS and other quasi-Newton optimizers - Add patience parameter to terminate optimization if cost does not improve anymore #114 - - Add callback mechanism to allow for custom termination criteria + - Add callback mechanism to allow for custom termination criteria #133 -## 3.0.x: +## 4.x: - Raise exception if dimension of manifold is 0 - Add pep8-naming (requires breaking public API to fix all errors) - Make FixedRankEmbedded manifold compatible with autodiff backends @@ -24,4 +22,3 @@ - Revist 'reuse_line_searcher' and 'self._line_searcher' vs. 'self.line_searcher' instance attributes - Rename 'orth_value' to 'restart_threshold' - - Revisit checking docstrings with darglint if the package is more mature diff --git a/docs/quickstart.rst b/docs/quickstart.rst index 20219f17f..ac9bfa38d 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -21,7 +21,8 @@ Installation Pymanopt is compatible with Python 3.6+, and depends on NumPy and SciPy. Additionally, to use Pymanopt's built-in automatic differentiation, which we strongly recommend, you need to setup your cost functions using either -`Autograd <https://github.com/HIPS/autograd>`_, `TensorFlow +`Autograd <https://github.com/HIPS/autograd>`_, `JAX +<https://github.com/google/jax>`_, `TensorFlow <https://www.tensorflow.org/>`_ or `PyTorch <https://pytorch.org/>`_. If you are unfamiliar with these packages and you are unsure which to go for, we suggest to start with Autograd.
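For orientation, here is a minimal sketch of how the new JAX backend is used end to end once this change lands (illustrative only, not part of the diff; it assumes pymanopt's existing Sphere manifold, Problem class, and SteepestDescent optimizer):

    import jax.numpy as jnp
    import numpy as np

    import pymanopt
    from pymanopt.manifolds import Sphere
    from pymanopt.optimizers import SteepestDescent

    # Random symmetric matrix whose dominant eigenvector we estimate.
    rng = np.random.default_rng(42)
    matrix = rng.normal(size=(3, 3))
    matrix = 0.5 * (matrix + matrix.T)

    manifold = Sphere(3)

    @pymanopt.function.jax(manifold)
    def cost(x):
        # Negative Rayleigh quotient; the gradient is derived automatically
        # from jax.grad by the new backend.
        return -x.T @ matrix @ x

    problem = pymanopt.Problem(manifold, cost)
    result = SteepestDescent(verbosity=0).run(problem)
    print(result.point)

This mirrors the "jax" branches added to the examples below, e.g. dominant_eigenvector.py.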
diff --git a/examples/closest_unit_norm_column_approximation.py b/examples/closest_unit_norm_column_approximation.py index bbdf50c61..71a56e400 100644 --- a/examples/closest_unit_norm_column_approximation.py +++ b/examples/closest_unit_norm_column_approximation.py @@ -1,4 +1,5 @@ import autograd.numpy as np +import jax.numpy as jnp import tensorflow as tf import torch @@ -8,7 +9,7 @@ from pymanopt.optimizers import ConjugateGradient -SUPPORTED_BACKENDS = ("autograd", "numpy", "pytorch", "tensorflow") +SUPPORTED_BACKENDS = ("autograd", "jax", "numpy", "pytorch", "tensorflow") def create_cost_and_derivates(manifold, matrix, backend): @@ -20,6 +21,12 @@ def create_cost_and_derivates(manifold, matrix, backend): def cost(X): return 0.5 * np.sum((X - matrix) ** 2) + elif backend == "jax": + + @pymanopt.function.jax(manifold) + def cost(X): + return 0.5 * jnp.sum((X - matrix) ** 2) + elif backend == "numpy": @pymanopt.function.numpy(manifold) diff --git a/examples/dominant_eigenvector.py b/examples/dominant_eigenvector.py index fef814385..40f512f64 100644 --- a/examples/dominant_eigenvector.py +++ b/examples/dominant_eigenvector.py @@ -8,7 +8,7 @@ from pymanopt.optimizers import SteepestDescent -SUPPORTED_BACKENDS = ("autograd", "numpy", "pytorch", "tensorflow") +SUPPORTED_BACKENDS = ("autograd", "jax", "numpy", "pytorch", "tensorflow") def create_cost_and_derivates(manifold, matrix, backend): @@ -20,6 +20,12 @@ def create_cost_and_derivates(manifold, matrix, backend): def cost(x): return -x.T @ matrix @ x + elif backend == "jax": + + @pymanopt.function.jax(manifold) + def cost(x): + return -x.T @ matrix @ x + elif backend == "numpy": @pymanopt.function.numpy(manifold) diff --git a/examples/dominant_invariant_complex_subspace.py b/examples/dominant_invariant_complex_subspace.py index 111718673..a49316a86 100644 --- a/examples/dominant_invariant_complex_subspace.py +++ b/examples/dominant_invariant_complex_subspace.py @@ -1,4 +1,5 @@ import autograd.numpy as np +import jax.numpy as jnp import tensorflow as tf import torch @@ -8,7 +9,7 @@ from pymanopt.optimizers import TrustRegions -SUPPORTED_BACKENDS = ("autograd", "numpy", "pytorch", "tensorflow") +SUPPORTED_BACKENDS = ("autograd", "jax", "numpy", "pytorch", "tensorflow") def create_cost_and_derivates(manifold, matrix, backend): @@ -20,6 +21,12 @@ def create_cost_and_derivates(manifold, matrix, backend): def cost(X): return -np.real(np.trace(np.conj(X.T) @ matrix @ X)) + elif backend == "jax": + + @pymanopt.function.jax(manifold) + def cost(X): + return -jnp.real(jnp.trace(jnp.conj(X.T) @ matrix @ X)) + elif backend == "numpy": @pymanopt.function.numpy(manifold) diff --git a/examples/dominant_invariant_subspace.py b/examples/dominant_invariant_subspace.py index c618f34fc..e4db285d3 100644 --- a/examples/dominant_invariant_subspace.py +++ b/examples/dominant_invariant_subspace.py @@ -1,4 +1,5 @@ import autograd.numpy as np +import jax.numpy as jnp import tensorflow as tf import torch @@ -8,7 +9,7 @@ from pymanopt.optimizers import TrustRegions -SUPPORTED_BACKENDS = ("autograd", "numpy", "pytorch", "tensorflow") +SUPPORTED_BACKENDS = ("autograd", "jax", "numpy", "pytorch", "tensorflow") def create_cost_and_derivates(manifold, matrix, backend): @@ -20,6 +21,12 @@ def create_cost_and_derivates(manifold, matrix, backend): def cost(X): return -np.trace(X.T @ matrix @ X) + elif backend == "jax": + + @pymanopt.function.jax(manifold) + def cost(X): + return -jnp.trace(X.T @ matrix @ X) + elif backend == "numpy": @pymanopt.function.numpy(manifold) diff 
--git a/examples/low_rank_matrix_approximation.py b/examples/low_rank_matrix_approximation.py index acbe96091..ac449129f 100644 --- a/examples/low_rank_matrix_approximation.py +++ b/examples/low_rank_matrix_approximation.py @@ -1,4 +1,5 @@ import autograd.numpy as np +import jax.numpy as jnp import tensorflow as tf import torch @@ -8,7 +9,7 @@ from pymanopt.optimizers import ConjugateGradient -SUPPORTED_BACKENDS = ("autograd", "numpy", "pytorch", "tensorflow") +SUPPORTED_BACKENDS = ("autograd", "jax", "numpy", "pytorch", "tensorflow") def create_cost_and_derivates(manifold, matrix, backend): @@ -21,6 +22,13 @@ def cost(u, s, vt): X = u @ np.diag(s) @ vt return np.linalg.norm(X - matrix) ** 2 + elif backend == "jax": + + @pymanopt.function.jax(manifold) + def cost(u, s, vt): + X = u @ jnp.diag(s) @ vt + return jnp.linalg.norm(X - matrix) ** 2 + elif backend == "numpy": @pymanopt.function.numpy(manifold) diff --git a/examples/low_rank_psd_matrix_approximation.py b/examples/low_rank_psd_matrix_approximation.py index f1e9959bb..038f34bae 100644 --- a/examples/low_rank_psd_matrix_approximation.py +++ b/examples/low_rank_psd_matrix_approximation.py @@ -1,4 +1,5 @@ import autograd.numpy as np +import jax.numpy as jnp import tensorflow as tf import torch @@ -8,7 +9,7 @@ from pymanopt.optimizers import TrustRegions -SUPPORTED_BACKENDS = ("autograd", "numpy", "pytorch", "tensorflow") +SUPPORTED_BACKENDS = ("autograd", "jax", "numpy", "pytorch", "tensorflow") def create_cost_and_derivates(manifold, matrix, backend): @@ -20,6 +21,12 @@ def create_cost_and_derivates(manifold, matrix, backend): def cost(Y): return np.linalg.norm(Y @ Y.T - matrix, "fro") ** 2 + elif backend == "jax": + + @pymanopt.function.jax(manifold) + def cost(Y): + return jnp.linalg.norm(Y @ Y.T - matrix, "fro") ** 2 + elif backend == "numpy": @pymanopt.function.numpy(manifold) diff --git a/examples/multiple_linear_regression.py b/examples/multiple_linear_regression.py index fae5ac72c..a5f759cde 100644 --- a/examples/multiple_linear_regression.py +++ b/examples/multiple_linear_regression.py @@ -1,4 +1,5 @@ import autograd.numpy as np +import jax.numpy as jnp import tensorflow as tf import torch @@ -8,7 +9,7 @@ from pymanopt.optimizers import TrustRegions -SUPPORTED_BACKENDS = ("autograd", "numpy", "pytorch", "tensorflow") +SUPPORTED_BACKENDS = ("autograd", "jax", "numpy", "pytorch", "tensorflow") def create_cost_and_derivates(manifold, samples, targets, backend): @@ -18,9 +19,14 @@ def create_cost_and_derivates(manifold, samples, targets, backend): @pymanopt.function.autograd(manifold) def cost(weights): - # Use autograd's linalg.norm wrapper. 
return np.linalg.norm(targets - samples @ weights) ** 2 + elif backend == "jax": + + @pymanopt.function.jax(manifold) + def cost(weights): + return jnp.linalg.norm(targets - samples @ weights) ** 2 + elif backend == "numpy": @pymanopt.function.numpy(manifold) diff --git a/examples/optimal_rotations.py b/examples/optimal_rotations.py index f96d37b68..e5293544c 100644 --- a/examples/optimal_rotations.py +++ b/examples/optimal_rotations.py @@ -1,4 +1,5 @@ import autograd.numpy as np +import jax.numpy as jnp import tensorflow as tf import torch @@ -8,7 +9,7 @@ from pymanopt.optimizers import SteepestDescent -SUPPORTED_BACKENDS = ("autograd", "numpy", "pytorch", "tensorflow") +SUPPORTED_BACKENDS = ("autograd", "jax", "numpy", "pytorch", "tensorflow") def create_cost_and_derivates(manifold, ABt, backend): @@ -20,6 +21,12 @@ def create_cost_and_derivates(manifold, ABt, backend): def cost(X): return -np.tensordot(X, ABt, axes=X.ndim) + elif backend == "jax": + + @pymanopt.function.jax(manifold) + def cost(X): + return -jnp.tensordot(X, ABt, axes=X.ndim) + elif backend == "numpy": @pymanopt.function.numpy(manifold) diff --git a/examples/packing_on_the_sphere.py b/examples/packing_on_the_sphere.py index 9f73af64b..7f461b422 100644 --- a/examples/packing_on_the_sphere.py +++ b/examples/packing_on_the_sphere.py @@ -1,4 +1,5 @@ import autograd.numpy as np +import jax.numpy as jnp import tensorflow as tf import torch @@ -8,7 +9,7 @@ from pymanopt.optimizers import ConjugateGradient -SUPPORTED_BACKENDS = ("autograd", "pytorch", "tensorflow") +SUPPORTED_BACKENDS = ("autograd", "jax", "pytorch", "tensorflow") def create_cost(manifold, epsilon, backend): @@ -26,6 +27,17 @@ def cost(X): u = np.triu(expY, 1).sum() return s + epsilon * np.log(u) + elif backend == "jax": + + @pymanopt.function.jax(manifold) + def cost(X): + Y = X @ X.T + s = jnp.triu(Y, 1).max() + expY = jnp.exp((Y - s) / epsilon) + expY -= jnp.diag(jnp.diag(expY)) + u = jnp.triu(expY, 1).sum() + return s + epsilon * jnp.log(u) + elif backend == "pytorch": @pymanopt.function.pytorch(manifold) diff --git a/examples/pca.py b/examples/pca.py index 3a3408540..7ca425c4f 100644 --- a/examples/pca.py +++ b/examples/pca.py @@ -1,4 +1,5 @@ import autograd.numpy as np +import jax.numpy as jnp import tensorflow as tf import torch @@ -8,7 +9,7 @@ from pymanopt.optimizers import TrustRegions -SUPPORTED_BACKENDS = ("autograd", "numpy", "pytorch", "tensorflow") +SUPPORTED_BACKENDS = ("autograd", "jax", "numpy", "pytorch", "tensorflow") def create_cost_and_derivates(manifold, samples, backend): @@ -20,6 +21,12 @@ def create_cost_and_derivates(manifold, samples, backend): def cost(w): return np.linalg.norm(samples - samples @ w @ w.T) ** 2 + elif backend == "jax": + + @pymanopt.function.jax(manifold) + def cost(w): + return jnp.linalg.norm(samples - samples @ w @ w.T) ** 2 + elif backend == "numpy": @pymanopt.function.numpy(manifold) diff --git a/examples/rank_k_correlation_matrix_approximation.py b/examples/rank_k_correlation_matrix_approximation.py index a365d070c..2b2073b78 100644 --- a/examples/rank_k_correlation_matrix_approximation.py +++ b/examples/rank_k_correlation_matrix_approximation.py @@ -1,4 +1,5 @@ import autograd.numpy as np +import jax.numpy as jnp import tensorflow as tf import torch @@ -8,7 +9,7 @@ from pymanopt.optimizers import TrustRegions -SUPPORTED_BACKENDS = ("autograd", "numpy", "pytorch", "tensorflow") +SUPPORTED_BACKENDS = ("autograd", "jax", "numpy", "pytorch", "tensorflow") def create_cost_and_derivates(manifold, matrix, backend): 
@@ -20,6 +21,12 @@ def create_cost_and_derivates(manifold, matrix, backend): def cost(X): return 0.25 * np.linalg.norm(X.T @ X - matrix) ** 2 + elif backend == "jax": + + @pymanopt.function.jax(manifold) + def cost(X): + return 0.25 * jnp.linalg.norm(X.T @ X - matrix) ** 2 + elif backend == "numpy": @pymanopt.function.numpy(manifold) diff --git a/setup.cfg b/setup.cfg index 37e48c248..bfb112fd4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -46,11 +46,14 @@ python_requires = >=3.7 [options.extras_require] autograd = autograd>=1.2 +jax = + jax>=0.2.0 + jaxlib tensorflow = tensorflow>=2.0 torch = torch>=1.0 -backends = pymanopt[autograd,tensorflow,torch] +backends = pymanopt[autograd,jax,tensorflow,torch] dev = black[jupyter]==22.3.0 flake8==5.0.4 diff --git a/src/pymanopt/autodiff/backends/__init__.py b/src/pymanopt/autodiff/backends/__init__.py index dfdb93907..72f6ae184 100644 --- a/src/pymanopt/autodiff/backends/__init__.py +++ b/src/pymanopt/autodiff/backends/__init__.py @@ -1,13 +1,15 @@ -__all__ = ["autograd", "numpy", "pytorch", "tensorflow"] +__all__ = ["autograd", "jax", "numpy", "pytorch", "tensorflow"] from .. import backend_decorator_factory from ._autograd import AutogradBackend +from ._jax import JaxBackend from ._numpy import NumPyBackend from ._pytorch import PyTorchBackend from ._tensorflow import TensorFlowBackend autograd = backend_decorator_factory(AutogradBackend) +jax = backend_decorator_factory(JaxBackend) numpy = backend_decorator_factory(NumPyBackend) pytorch = backend_decorator_factory(PyTorchBackend) tensorflow = backend_decorator_factory(TensorFlowBackend) diff --git a/src/pymanopt/autodiff/backends/_jax.py b/src/pymanopt/autodiff/backends/_jax.py new file mode 100644 index 000000000..e84e4c960 --- /dev/null +++ b/src/pymanopt/autodiff/backends/_jax.py @@ -0,0 +1,75 @@ +import functools + +import numpy as np + + +try: + import jax +except ImportError: + jax = None +else: + import jax.numpy as jnp + from jax.config import config + + config.update("jax_enable_x64", True) + +from ...tools import bisect_sequence, unpack_singleton_sequence_return_value +from ._backend import Backend + + +def conjugate_result(function): + @functools.wraps(function) + def wrapper(*args, **kwargs): + return list(map(jnp.conj, function(*args, **kwargs))) + + return wrapper + + +def to_ndarray(function): + @functools.wraps(function) + def wrapper(*args, **kwargs): + return list(map(np.asarray, function(*args, **kwargs))) + + return wrapper + + +class JaxBackend(Backend): + def __init__(self): + super().__init__("Jax") + + @staticmethod + def is_available(): + return jax is not None + + @Backend._assert_backend_available + def prepare_function(self, function): + return function + + @Backend._assert_backend_available + def generate_gradient_operator(self, function, num_arguments): + gradient = to_ndarray( + conjugate_result(jax.grad(function, argnums=range(num_arguments))) + ) + if num_arguments == 1: + return unpack_singleton_sequence_return_value(gradient) + return gradient + + @Backend._assert_backend_available + def generate_hessian_operator(self, function, num_arguments): + @to_ndarray + @conjugate_result + def hessian_vector_product(arguments, vectors): + return jax.jvp( + jax.grad(function, argnums=range(num_arguments)), + arguments, + vectors, + )[1] + + @functools.wraps(hessian_vector_product) + def wrapper(*args): + arguments, vectors = bisect_sequence(args) + return hessian_vector_product(arguments, vectors) + + if num_arguments == 1: + return 
unpack_singleton_sequence_return_value(wrapper) + return wrapper diff --git a/src/pymanopt/function.py b/src/pymanopt/function.py index 3c877c88f..b59c543b1 100644 --- a/src/pymanopt/function.py +++ b/src/pymanopt/function.py @@ -1,3 +1,9 @@ -__all__ = ["autograd", "numpy", "pytorch", "tensorflow"] +__all__ = ["autograd", "jax", "numpy", "pytorch", "tensorflow"] -from pymanopt.autodiff.backends import autograd, numpy, pytorch, tensorflow +from pymanopt.autodiff.backends import ( + autograd, + jax, + numpy, + pytorch, + tensorflow, +) diff --git a/src/pymanopt/optimizers/optimizer.py b/src/pymanopt/optimizers/optimizer.py index f486805f6..06438eb5c 100644 --- a/src/pymanopt/optimizers/optimizer.py +++ b/src/pymanopt/optimizers/optimizer.py @@ -68,7 +68,6 @@ def run(self, problem, *, initial_point=None, **kwargs) -> OptimizerResult: Args: problem: Pymanopt problem class instance exposing the cost function and the manifold to optimize over. - The class must either initial_point: Initial point on the manifold. If no value is provided then a starting point will be randomly generated. diff --git a/tests/backends/test_jax.py b/tests/backends/test_jax.py new file mode 100644 index 000000000..c3a075ec3 --- /dev/null +++ b/tests/backends/test_jax.py @@ -0,0 +1,111 @@ +import jax.numpy as jnp +import pytest + +import pymanopt + +from . import _backend_tests + + +class TestUnaryFunction(_backend_tests.TestUnaryFunction): + @pytest.fixture(autouse=True) + def setup(self): + @pymanopt.function.jax(self.manifold) + def cost(x): + return jnp.sum(x**2) + + self.cost = cost + + +class TestUnaryComplexFunction(_backend_tests.TestUnaryComplexFunction): + @pytest.fixture(autouse=True) + def setup(self): + @pymanopt.function.jax(self.manifold) + def cost(x): + return jnp.real(jnp.sum(x**2)) + + self.cost = cost + + +class TestUnaryVarargFunction(_backend_tests.TestUnaryFunction): + @pytest.fixture(autouse=True) + def setup(self): + @pymanopt.function.jax(self.manifold) + def cost(*x): + (x,) = x + return jnp.sum(x**2) + + self.cost = cost + + +class TestNaryFunction(_backend_tests.TestNaryFunction): + @pytest.fixture(autouse=True) + def setup(self): + @pymanopt.function.jax(self.manifold) + def cost(x, y): + return x @ y + + self.cost = cost + + +class TestNaryVarargFunction(_backend_tests.TestNaryFunction): + @pytest.fixture(autouse=True) + def setup(self): + @pymanopt.function.jax(self.manifold) + def cost(*args): + return jnp.dot(*args) + + self.cost = cost + + +class TestNaryParameterGrouping(_backend_tests.TestNaryParameterGrouping): + @pytest.fixture(autouse=True) + def setup(self): + @pymanopt.function.jax(self.manifold) + def cost(x, y, z): + return jnp.sum(x**2 + y + z**3) + + self.cost = cost + + +class TestVector(_backend_tests.TestVector): + @pytest.fixture(autouse=True) + def setup(self): + @pymanopt.function.jax(self.manifold) + def cost(X): + return jnp.exp(jnp.sum(X**2)) + + self.cost = cost + + +class TestMatrix(_backend_tests.TestMatrix): + @pytest.fixture(autouse=True) + def setup(self): + @pymanopt.function.jax(self.manifold) + def cost(X): + return jnp.exp(jnp.sum(X**2)) + + self.cost = cost + + +class TestTensor3(_backend_tests.TestTensor3): + @pytest.fixture(autouse=True) + def setup(self): + @pymanopt.function.jax(self.manifold) + def cost(X): + return jnp.exp(jnp.sum(X**2)) + + self.cost = cost + + +class TestMixed(_backend_tests.TestMixed): + @pytest.fixture(autouse=True) + def setup(self): + @pymanopt.function.jax(self.manifold) + def cost(x, y, z): + return ( + jnp.exp(jnp.sum(x**2)) 
+ + jnp.exp(jnp.sum(y**2)) + + jnp.exp(jnp.sum(z**2)) + ) + + self.cost = cost diff --git a/tests/optimizers/test_conjugate_gradient.py b/tests/optimizers/test_conjugate_gradient.py index 80ce02bd2..b728558e4 100644 --- a/tests/optimizers/test_conjugate_gradient.py +++ b/tests/optimizers/test_conjugate_gradient.py @@ -78,5 +78,5 @@ def cost(X): column_indices = np.argsort(eigenvalues)[-subspace_dimension:] spanning_set = eigenvectors[:, column_indices] np_testing.assert_allclose( - manifold.dist(spanning_set, estimated_spanning_set), 0, atol=1e-6 + manifold.dist(spanning_set, estimated_spanning_set), 0, atol=1e-5 )
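One closing note on the new src/pymanopt/autodiff/backends/_jax.py above: generate_hessian_operator composes jax.jvp with jax.grad, i.e. the standard forward-over-reverse Hessian-vector product. A standalone sketch of that pattern for a single-argument cost (illustrative only, not part of the diff):

    import jax
    import jax.numpy as jnp


    def cost(x):
        return jnp.sum(x**4)


    def hessian_vector_product(point, vector):
        # jax.jvp returns (primal_output, tangent_output); the tangent of the
        # gradient in direction `vector` is the Hessian-vector product.
        return jax.jvp(jax.grad(cost), (point,), (vector,))[1]


    x = jnp.array([1.0, 2.0, 3.0])
    v = jnp.array([0.1, 0.0, -0.2])
    print(hessian_vector_product(x, v))  # 12 * x**2 * v for this cost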