Merge pull request #206 from nkoep/jax-backend
Add JAX backend
nkoep committed Jan 2, 2023
2 parents 923c271 + 2541aab commit 88615c6
Showing 21 changed files with 300 additions and 33 deletions.
8 changes: 0 additions & 8 deletions .flake8

This file was deleted.

9 changes: 3 additions & 6 deletions TODO.md
@@ -1,17 +1,15 @@
# TODO/Roadmap

## 2.1.x:
- attrs
## 3.x:
- Add 'check_hessian' function
- Refactor optimizer implementations
- Add complex manifolds #125, #170
- Add JAX backend #115
- Add L-BFGS and other quasi-Newton optimizers
- Add patience parameter to terminate optimization if cost does not improve
anymore #114
- Add callback mechanism to allow for custom termination criteria
- Add callback mechanism to allow for custom termination criteria #133

## 3.0.x:
## 4.x:
- Raise exception if dimension of manifold is 0
- Add pep8-naming (requires breaking public API to fix all errors)
- Make FixedRankEmbedded manifold compatible with autodiff backends
@@ -24,4 +22,3 @@
- Revisit 'reuse_line_searcher' and 'self._line_searcher' vs.
'self.line_searcher' instance attributes
- Rename 'orth_value' to 'restart_threshold'
- Revisit checking docstrings with darglint if the package is more mature
3 changes: 2 additions & 1 deletion docs/quickstart.rst
@@ -21,7 +21,8 @@ Installation
Pymanopt is compatible with Python 3.6+, and depends on NumPy and SciPy.
Additionally, to use Pymanopt's built-in automatic differentiation, which we
strongly recommend, you need to set up your cost functions using either
`Autograd <https://github.com/HIPS/autograd>`_, `TensorFlow
`Autograd <https://github.com/HIPS/autograd>`_, `JAX
<https://jax.readthedocs.io/en/latest/>`_, `TensorFlow
<https://www.tensorflow.org>`_ or `PyTorch <http://www.pytorch.org/>`_.
If you are unfamiliar with these packages and you are unsure which to go for,
we suggest starting with Autograd.
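To see what the new backend looks like end to end, here is a minimal sketch modelled on the dominant-eigenvector example touched by this pull request. The 3x3 matrix is made up for illustration; Sphere, Problem and SteepestDescent are the existing Pymanopt API used throughout the examples below.

import jax.numpy as jnp

import pymanopt
from pymanopt.manifolds import Sphere
from pymanopt.optimizers import SteepestDescent

# Made-up symmetric matrix whose dominant eigenvector we want.
matrix = jnp.array([[2.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]])

manifold = Sphere(3)

# The new decorator marks the cost as a JAX function so that Pymanopt can
# derive Euclidean gradients (and Hessian-vector products) automatically.
@pymanopt.function.jax(manifold)
def cost(x):
    return -x.T @ matrix @ x

problem = pymanopt.Problem(manifold, cost)
result = SteepestDescent().run(problem)
print(result.point)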
9 changes: 8 additions & 1 deletion examples/closest_unit_norm_column_approximation.py
@@ -1,4 +1,5 @@
import autograd.numpy as np
import jax.numpy as jnp
import tensorflow as tf
import torch

@@ -8,7 +9,7 @@
from pymanopt.optimizers import ConjugateGradient


SUPPORTED_BACKENDS = ("autograd", "numpy", "pytorch", "tensorflow")
SUPPORTED_BACKENDS = ("autograd", "jax", "numpy", "pytorch", "tensorflow")


def create_cost_and_derivates(manifold, matrix, backend):
@@ -20,6 +21,12 @@ def create_cost_and_derivates(manifold, matrix, backend):
def cost(X):
return 0.5 * np.sum((X - matrix) ** 2)

elif backend == "jax":

@pymanopt.function.jax(manifold)
def cost(X):
return 0.5 * jnp.sum((X - matrix) ** 2)

elif backend == "numpy":

@pymanopt.function.numpy(manifold)
8 changes: 7 additions & 1 deletion examples/dominant_eigenvector.py
@@ -8,7 +8,7 @@
from pymanopt.optimizers import SteepestDescent


SUPPORTED_BACKENDS = ("autograd", "numpy", "pytorch", "tensorflow")
SUPPORTED_BACKENDS = ("autograd", "jax", "numpy", "pytorch", "tensorflow")


def create_cost_and_derivates(manifold, matrix, backend):
@@ -20,6 +20,12 @@ def create_cost_and_derivates(manifold, matrix, backend):
def cost(x):
return -x.T @ matrix @ x

elif backend == "jax":

@pymanopt.function.jax(manifold)
def cost(x):
return -x.T @ matrix @ x

elif backend == "numpy":

@pymanopt.function.numpy(manifold)
9 changes: 8 additions & 1 deletion examples/dominant_invariant_complex_subspace.py
@@ -1,4 +1,5 @@
import autograd.numpy as np
import jax.numpy as jnp
import tensorflow as tf
import torch

@@ -8,7 +9,7 @@
from pymanopt.optimizers import TrustRegions


SUPPORTED_BACKENDS = ("autograd", "numpy", "pytorch", "tensorflow")
SUPPORTED_BACKENDS = ("autograd", "jax", "numpy", "pytorch", "tensorflow")


def create_cost_and_derivates(manifold, matrix, backend):
@@ -20,6 +21,12 @@ def create_cost_and_derivates(manifold, matrix, backend):
def cost(X):
return -np.real(np.trace(np.conj(X.T) @ matrix @ X))

elif backend == "jax":

@pymanopt.function.jax(manifold)
def cost(X):
return -jnp.real(jnp.trace(jnp.conj(X.T) @ matrix @ X))

elif backend == "numpy":

@pymanopt.function.numpy(manifold)
9 changes: 8 additions & 1 deletion examples/dominant_invariant_subspace.py
@@ -1,4 +1,5 @@
import autograd.numpy as np
import jax.numpy as jnp
import tensorflow as tf
import torch

@@ -8,7 +9,7 @@
from pymanopt.optimizers import TrustRegions


SUPPORTED_BACKENDS = ("autograd", "numpy", "pytorch", "tensorflow")
SUPPORTED_BACKENDS = ("autograd", "jax", "numpy", "pytorch", "tensorflow")


def create_cost_and_derivates(manifold, matrix, backend):
@@ -20,6 +21,12 @@ def create_cost_and_derivates(manifold, matrix, backend):
def cost(X):
return -np.trace(X.T @ matrix @ X)

elif backend == "jax":

@pymanopt.function.jax(manifold)
def cost(X):
return -jnp.trace(X.T @ matrix @ X)

elif backend == "numpy":

@pymanopt.function.numpy(manifold)
10 changes: 9 additions & 1 deletion examples/low_rank_matrix_approximation.py
@@ -1,4 +1,5 @@
import autograd.numpy as np
import jax.numpy as jnp
import tensorflow as tf
import torch

@@ -8,7 +9,7 @@
from pymanopt.optimizers import ConjugateGradient


SUPPORTED_BACKENDS = ("autograd", "numpy", "pytorch", "tensorflow")
SUPPORTED_BACKENDS = ("autograd", "jax", "numpy", "pytorch", "tensorflow")


def create_cost_and_derivates(manifold, matrix, backend):
@@ -21,6 +22,13 @@ def cost(u, s, vt):
X = u @ np.diag(s) @ vt
return np.linalg.norm(X - matrix) ** 2

elif backend == "jax":

@pymanopt.function.jax(manifold)
def cost(u, s, vt):
X = u @ jnp.diag(s) @ vt
return jnp.linalg.norm(X - matrix) ** 2

elif backend == "numpy":

@pymanopt.function.numpy(manifold)
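Note that the cost above takes the three factors (u, s, vt) separately rather than a single point. Purely as an illustration of how JAX handles such multi-argument costs (this is not the backend code added by this commit), jax.grad can return one gradient per argument:

import jax
import jax.numpy as jnp

def cost(u, s, vt):
    X = u @ jnp.diag(s) @ vt
    return jnp.linalg.norm(X) ** 2

# argnums=(0, 1, 2) yields a tuple with one gradient per argument.
grad = jax.grad(cost, argnums=(0, 1, 2))

u = jnp.ones((5, 2))
s = jnp.ones(2)
vt = jnp.ones((2, 4))
grad_u, grad_s, grad_vt = grad(u, s, vt)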
9 changes: 8 additions & 1 deletion examples/low_rank_psd_matrix_approximation.py
@@ -1,4 +1,5 @@
import autograd.numpy as np
import jax.numpy as jnp
import tensorflow as tf
import torch

@@ -8,7 +9,7 @@
from pymanopt.optimizers import TrustRegions


SUPPORTED_BACKENDS = ("autograd", "numpy", "pytorch", "tensorflow")
SUPPORTED_BACKENDS = ("autograd", "jax", "numpy", "pytorch", "tensorflow")


def create_cost_and_derivates(manifold, matrix, backend):
@@ -20,6 +21,12 @@ def create_cost_and_derivates(manifold, matrix, backend):
def cost(Y):
return np.linalg.norm(Y @ Y.T - matrix, "fro") ** 2

elif backend == "jax":

@pymanopt.function.jax(manifold)
def cost(Y):
return jnp.linalg.norm(Y @ Y.T - matrix, "fro") ** 2

elif backend == "numpy":

@pymanopt.function.numpy(manifold)
10 changes: 8 additions & 2 deletions examples/multiple_linear_regression.py
@@ -1,4 +1,5 @@
import autograd.numpy as np
import jax.numpy as jnp
import tensorflow as tf
import torch

@@ -8,7 +9,7 @@
from pymanopt.optimizers import TrustRegions


SUPPORTED_BACKENDS = ("autograd", "numpy", "pytorch", "tensorflow")
SUPPORTED_BACKENDS = ("autograd", "jax", "numpy", "pytorch", "tensorflow")


def create_cost_and_derivates(manifold, samples, targets, backend):
@@ -18,9 +19,14 @@ def create_cost_and_derivates(manifold, samples, targets, backend):

@pymanopt.function.autograd(manifold)
def cost(weights):
# Use autograd's linalg.norm wrapper.
return np.linalg.norm(targets - samples @ weights) ** 2

elif backend == "jax":

@pymanopt.function.jax(manifold)
def cost(weights):
return jnp.linalg.norm(targets - samples @ weights) ** 2

elif backend == "numpy":

@pymanopt.function.numpy(manifold)
9 changes: 8 additions & 1 deletion examples/optimal_rotations.py
@@ -1,4 +1,5 @@
import autograd.numpy as np
import jax.numpy as jnp
import tensorflow as tf
import torch

@@ -8,7 +9,7 @@
from pymanopt.optimizers import SteepestDescent


SUPPORTED_BACKENDS = ("autograd", "numpy", "pytorch", "tensorflow")
SUPPORTED_BACKENDS = ("autograd", "jax", "numpy", "pytorch", "tensorflow")


def create_cost_and_derivates(manifold, ABt, backend):
@@ -20,6 +21,12 @@ def create_cost_and_derivates(manifold, ABt, backend):
def cost(X):
return -np.tensordot(X, ABt, axes=X.ndim)

elif backend == "jax":

@pymanopt.function.jax(manifold)
def cost(X):
return -jnp.tensordot(X, ABt, axes=X.ndim)

elif backend == "numpy":

@pymanopt.function.numpy(manifold)
14 changes: 13 additions & 1 deletion examples/packing_on_the_sphere.py
@@ -1,4 +1,5 @@
import autograd.numpy as np
import jax.numpy as jnp
import tensorflow as tf
import torch

@@ -8,7 +9,7 @@
from pymanopt.optimizers import ConjugateGradient


SUPPORTED_BACKENDS = ("autograd", "pytorch", "tensorflow")
SUPPORTED_BACKENDS = ("autograd", "jax", "pytorch", "tensorflow")


def create_cost(manifold, epsilon, backend):
Expand All @@ -26,6 +27,17 @@ def cost(X):
u = np.triu(expY, 1).sum()
return s + epsilon * np.log(u)

elif backend == "jax":

@pymanopt.function.jax(manifold)
def cost(X):
Y = X @ X.T
s = jnp.triu(Y, 1).max()
expY = jnp.exp((Y - s) / epsilon)
expY -= jnp.diag(jnp.diag(expY))
u = jnp.triu(expY, 1).sum()
return s + epsilon * jnp.log(u)

elif backend == "pytorch":

@pymanopt.function.pytorch(manifold)
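The JAX branch mirrors the autograd cost, including the shift by the largest off-diagonal inner product s before exponentiating. This is the standard log-sum-exp stabilization: for any shift s, epsilon * log(sum(exp(Y_ij / epsilon))) = s + epsilon * log(sum(exp((Y_ij - s) / epsilon))), and picking s as the maximum keeps the exponentials from overflowing. A quick numerical check (illustration only, not part of the change):

import jax.numpy as jnp

epsilon = 0.01
y = jnp.array([50.0, 49.5, 48.0])  # made-up off-diagonal entries
s = y.max()

naive = epsilon * jnp.log(jnp.sum(jnp.exp(y / epsilon)))             # overflows to inf
stable = s + epsilon * jnp.log(jnp.sum(jnp.exp((y - s) / epsilon)))  # finite, ~50.0
print(naive, stable)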
9 changes: 8 additions & 1 deletion examples/pca.py
@@ -1,4 +1,5 @@
import autograd.numpy as np
import jax.numpy as jnp
import tensorflow as tf
import torch

@@ -8,7 +9,7 @@
from pymanopt.optimizers import TrustRegions


SUPPORTED_BACKENDS = ("autograd", "numpy", "pytorch", "tensorflow")
SUPPORTED_BACKENDS = ("autograd", "jax", "numpy", "pytorch", "tensorflow")


def create_cost_and_derivates(manifold, samples, backend):
@@ -20,6 +21,12 @@ def create_cost_and_derivates(manifold, samples, backend):
def cost(w):
return np.linalg.norm(samples - samples @ w @ w.T) ** 2

elif backend == "jax":

@pymanopt.function.jax(manifold)
def cost(w):
return jnp.linalg.norm(samples - samples @ w @ w.T) ** 2

elif backend == "numpy":

@pymanopt.function.numpy(manifold)
9 changes: 8 additions & 1 deletion examples/rank_k_correlation_matrix_approximation.py
@@ -1,4 +1,5 @@
import autograd.numpy as np
import jax.numpy as jnp
import tensorflow as tf
import torch

@@ -8,7 +9,7 @@
from pymanopt.optimizers import TrustRegions


SUPPORTED_BACKENDS = ("autograd", "numpy", "pytorch", "tensorflow")
SUPPORTED_BACKENDS = ("autograd", "jax", "numpy", "pytorch", "tensorflow")


def create_cost_and_derivates(manifold, matrix, backend):
@@ -20,6 +21,12 @@ def create_cost_and_derivates(manifold, matrix, backend):
def cost(X):
return 0.25 * np.linalg.norm(X.T @ X - matrix) ** 2

elif backend == "jax":

@pymanopt.function.jax(manifold)
def cost(X):
return 0.25 * jnp.linalg.norm(X.T @ X - matrix) ** 2

elif backend == "numpy":

@pymanopt.function.numpy(manifold)
5 changes: 4 additions & 1 deletion setup.cfg
@@ -46,11 +46,14 @@ python_requires = >=3.7
[options.extras_require]
autograd =
autograd>=1.2
jax =
jax>=0.2.0
jaxlib
tensorflow =
tensorflow>=2.0
torch =
torch>=1.0
backends = pymanopt[autograd,tensorflow,torch]
backends = pymanopt[autograd,jax,tensorflow,torch]
dev =
black[jupyter]==22.3.0
flake8==5.0.4
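With these extras in place, the JAX backend should be installable once this change is released, e.g. via pip install "pymanopt[jax]", and the pymanopt[backends] extra now pulls in JAX alongside Autograd, TensorFlow and PyTorch.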
4 changes: 3 additions & 1 deletion src/pymanopt/autodiff/backends/__init__.py
@@ -1,13 +1,15 @@
__all__ = ["autograd", "numpy", "pytorch", "tensorflow"]
__all__ = ["autograd", "jax", "numpy", "pytorch", "tensorflow"]

from .. import backend_decorator_factory
from ._autograd import AutogradBackend
from ._jax import JaxBackend
from ._numpy import NumPyBackend
from ._pytorch import PyTorchBackend
from ._tensorflow import TensorFlowBackend


autograd = backend_decorator_factory(AutogradBackend)
jax = backend_decorator_factory(JaxBackend)
numpy = backend_decorator_factory(NumPyBackend)
pytorch = backend_decorator_factory(PyTorchBackend)
tensorflow = backend_decorator_factory(TensorFlowBackend)
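Here, backend_decorator_factory turns the new JaxBackend class (defined in _jax.py, whose diff is not shown in this view) into the pymanopt.function.jax decorator used by the examples above. The snippet below is not the contents of _jax.py, only a sketch of the core JAX facilities such a backend can rely on for Euclidean derivatives:

import jax
import jax.numpy as jnp

def cost(x):
    return jnp.sum(x**2)

# Euclidean gradient of the scalar-valued cost.
euclidean_gradient = jax.grad(cost)

def euclidean_hessian_vector_product(x, v):
    # Forward-over-reverse Hessian-vector product.
    return jax.jvp(jax.grad(cost), (x,), (v,))[1]

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([0.0, 1.0, 0.0])
print(euclidean_gradient(x))                   # 2 * x
print(euclidean_hessian_vector_product(x, v))  # 2 * v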
