diff --git a/src/pymanopt/manifolds/group.py b/src/pymanopt/manifolds/group.py index 1c4f24d95..e99ac53a7 100644 --- a/src/pymanopt/manifolds/group.py +++ b/src/pymanopt/manifolds/group.py @@ -1,6 +1,7 @@ import numpy as np import scipy.special +import pymanopt.numerics as nx from pymanopt.manifolds.manifold import RiemannianSubmanifold from pymanopt.tools import extend_docstring from pymanopt.tools.multi import ( @@ -28,7 +29,7 @@ def __init__(self, name, dimension, retraction): raise ValueError(f"Invalid retraction type '{retraction}'") def inner_product(self, point, tangent_vector_a, tangent_vector_b): - return np.tensordot( + return nx.tensordot( tangent_vector_a.conj(), tangent_vector_b, axes=tangent_vector_a.ndim, @@ -222,7 +223,7 @@ def random_point(self): n, k = self._n, self._k if n == 1: point = np.ones((k, 1, 1)) + 1j * np.ones((k, 1, 1)) - point /= np.abs(point) + point /= nx.abs(point) else: point, _ = multiqr( np.random.normal(size=(k, n, n)) diff --git a/src/pymanopt/numerics/__init__.py b/src/pymanopt/numerics/__init__.py new file mode 100644 index 000000000..72ba2f363 --- /dev/null +++ b/src/pymanopt/numerics/__init__.py @@ -0,0 +1,17 @@ +__all__ = ["abs", "allclose", "exp", "tensordot"] + + +import importlib + +from pymanopt.numerics.core import abs, allclose, exp, tensordot + + +def register_backends(): + for backend in ["numpy", "jax", "pytorch", "tensorflow"]: + try: + importlib.import_module(f"pymanopt.numerics._backends.{backend}") + except ImportError: + pass + + +register_backends() diff --git a/src/pymanopt/numerics/_backends/__init__.py b/src/pymanopt/numerics/_backends/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/pymanopt/numerics/_backends/jax.py b/src/pymanopt/numerics/_backends/jax.py new file mode 100644 index 000000000..f68fabbaf --- /dev/null +++ b/src/pymanopt/numerics/_backends/jax.py @@ -0,0 +1,25 @@ +import jax.numpy as jnp + +import pymanopt.numerics.core as nx + + +@nx.abs.register +def _(array: jnp.ndarray) -> jnp.ndarray: + return jnp.abs(array) + + +@nx.allclose.register +def _(array_a: jnp.ndarray, array_b: jnp.ndarray) -> bool: + return jnp.allclose(array_a, array_b) + + +@nx.exp.register +def _(array: jnp.ndarray) -> jnp.ndarray: + return jnp.exp(array) + + +@nx.tensordot.register +def _( + array_a: jnp.ndarray, array_b: jnp.ndarray, *, axes: int = 2 +) -> jnp.ndarray: + return jnp.tensordot(array_a, array_b, axes=axes) diff --git a/src/pymanopt/numerics/_backends/numpy.py b/src/pymanopt/numerics/_backends/numpy.py new file mode 100644 index 000000000..1a28866af --- /dev/null +++ b/src/pymanopt/numerics/_backends/numpy.py @@ -0,0 +1,25 @@ +import numpy as np + +import pymanopt.numerics.core as nx + + +@nx.abs.register +def _(array: np.ndarray) -> np.ndarray: + return np.abs(array) + + +@nx.allclose.register +def _(array_a: np.ndarray, array_b: np.ndarray) -> bool: + return np.allclose(array_a, array_b) + + +@nx.exp.register +def _(array: np.ndarray) -> np.ndarray: + return np.exp(array) + + +@nx.tensordot.register +def _( + array_a: np.ndarray, array_b: np.ndarray, *, axes: int = 2 +) -> np.ndarray: + return np.tensordot(array_a, array_b, axes=axes) diff --git a/src/pymanopt/numerics/_backends/pytorch.py b/src/pymanopt/numerics/_backends/pytorch.py new file mode 100644 index 000000000..b79bd7999 --- /dev/null +++ b/src/pymanopt/numerics/_backends/pytorch.py @@ -0,0 +1,32 @@ +import typing + +import torch + +import pymanopt.numerics.core as nx + + +@nx.abs.register +def _(array: torch.Tensor) -> torch.Tensor: + return torch.abs(array) + + +@nx.allclose.register +def _( + array_a: torch.Tensor, array_b: typing.Union[torch.Tensor, float, int] +) -> bool: + # PyTorch does not automatically coerce values to tensors. + if isinstance(array_b, (float, int)): + array_b = torch.Tensor([array_b]) + return torch.allclose(array_a, array_b) + + +@nx.exp.register +def _(array: torch.Tensor) -> torch.Tensor: + return torch.exp(array) + + +@nx.tensordot.register +def _( + array_a: torch.Tensor, array_b: torch.Tensor, *, axes: int = 2 +) -> torch.Tensor: + return torch.tensordot(array_a, array_b, dims=axes) diff --git a/src/pymanopt/numerics/_backends/tensorflow.py b/src/pymanopt/numerics/_backends/tensorflow.py new file mode 100644 index 000000000..61df4853d --- /dev/null +++ b/src/pymanopt/numerics/_backends/tensorflow.py @@ -0,0 +1,24 @@ +import tensorflow as tf +import tensorflow.experimental.numpy as tnp + +import pymanopt.numerics.core as nx + + +@nx.abs.register +def _(array: tf.Tensor) -> tf.Tensor: + return tnp.abs(array) + + +@nx.allclose.register +def _(array_a: tf.Tensor, array_b: tf.Tensor) -> bool: + return tnp.allclose(array_a, array_b) + + +@nx.exp.register +def _(array: tf.Tensor) -> tf.Tensor: + return tnp.exp(array) + + +@nx.tensordot.register +def _(array_a: tf.Tensor, array_b: tf.Tensor, *, axes: int = 2) -> tf.Tensor: + return tnp.tensordot(array_a, array_b, axes=axes) diff --git a/src/pymanopt/numerics/core.py b/src/pymanopt/numerics/core.py new file mode 100644 index 000000000..6acb394b2 --- /dev/null +++ b/src/pymanopt/numerics/core.py @@ -0,0 +1,36 @@ +import functools + + +def _not_implemented(function): + @functools.wraps(function) + def inner(*arguments): + raise TypeError( + f"Function '{function.__name__}' not implemented for arguments of " + f"type '{type(arguments[0])}'" + ) + + return inner + + +@functools.singledispatch +@_not_implemented +def abs(_): + pass + + +@functools.singledispatch +@_not_implemented +def allclose(*_): + pass + + +@functools.singledispatch +@_not_implemented +def exp(_): + pass + + +@functools.singledispatch +@_not_implemented +def tensordot(*_): + pass diff --git a/tests/numerics/__init__.py b/tests/numerics/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/numerics/test_core.py b/tests/numerics/test_core.py new file mode 100644 index 000000000..2c47fa08e --- /dev/null +++ b/tests/numerics/test_core.py @@ -0,0 +1,66 @@ +import jax.numpy as jnp +import numpy as np +import pytest +import tensorflow as tf +import torch + +import pymanopt.numerics as nx + + +@pytest.mark.parametrize( + "argument, expected_output", + [ + (np.array([-4, 2]), np.array([4, 2])), + (jnp.array([-4, 2]), jnp.array([4, 2])), + (torch.Tensor([-4, 2]), torch.Tensor([4, 2])), + (tf.constant([-4, 2]), tf.constant([4, 2])), + ], +) +def test_abs(argument, expected_output): + output = nx.abs(argument) + assert nx.allclose(output, expected_output) + + +@pytest.mark.parametrize( + "argument_a, argument_b, expected_output", + [ + (np.array([4, 2]), np.array([4, 2]), True), + (np.array([4, 2]), np.array([2, 4]), False), + (jnp.array([4, 2]), jnp.array([4, 2]), True), + (jnp.array([4, 2]), jnp.array([2, 4]), False), + (torch.Tensor([4, 2]), torch.Tensor([4, 2]), True), + (torch.Tensor([4, 2]), torch.Tensor([2, 4]), False), + (tf.constant([4, 2]), tf.constant([4, 2]), True), + (tf.constant([4, 2]), tf.constant([2, 4]), False), + ], +) +def test_allclose(argument_a, argument_b, expected_output): + assert nx.allclose(argument_a, argument_b) == expected_output + + +@pytest.mark.parametrize( + "argument, expected_output", + [ + (np.log(np.array([4, 2])), np.array([4, 2])), + (jnp.log(jnp.array([4, 2])), jnp.array([4, 2])), + (torch.log(torch.Tensor([4, 2])), torch.Tensor([4, 2])), + (tf.math.log(tf.constant([4.0, 2.0])), tf.constant([4.0, 2.0])), + ], +) +def test_exp(argument, expected_output): + output = nx.exp(argument) + assert nx.allclose(output, expected_output) + + +@pytest.mark.parametrize( + "argument_a, argument_b, expected_output", + [ + (np.array([-4, 2]), np.array([1, 3]), 2), + (jnp.array([-4, 2]), jnp.array([1, 3]), 2), + (torch.Tensor([-4, 2]), torch.Tensor([1, 3]), 2), + (tf.constant([-4, 2]), tf.constant([1, 3]), 2), + ], +) +def test_tensordot(argument_a, argument_b, expected_output): + output = nx.tensordot(argument_a, argument_b, axes=argument_a.ndim) + assert nx.allclose(output, expected_output)