From d749708d78c207767b9e9ecacf099b553b342544 Mon Sep 17 00:00:00 2001 From: Niklas Koep Date: Sun, 2 Jul 2023 10:39:39 +0200 Subject: [PATCH 1/4] Add PoC numerics package Signed-off-by: Niklas Koep --- src/pymanopt/manifolds/group.py | 5 +-- src/pymanopt/numerics/__init__.py | 17 ++++++++++ src/pymanopt/numerics/_backends/__init__.py | 0 src/pymanopt/numerics/_backends/numpy.py | 23 +++++++++++++ src/pymanopt/numerics/core.py | 36 +++++++++++++++++++++ tests/numerics/__init__.py | 0 tests/numerics/test_core.py | 12 +++++++ 7 files changed, 91 insertions(+), 2 deletions(-) create mode 100644 src/pymanopt/numerics/__init__.py create mode 100644 src/pymanopt/numerics/_backends/__init__.py create mode 100644 src/pymanopt/numerics/_backends/numpy.py create mode 100644 src/pymanopt/numerics/core.py create mode 100644 tests/numerics/__init__.py create mode 100644 tests/numerics/test_core.py diff --git a/src/pymanopt/manifolds/group.py b/src/pymanopt/manifolds/group.py index 1c4f24d95..e99ac53a7 100644 --- a/src/pymanopt/manifolds/group.py +++ b/src/pymanopt/manifolds/group.py @@ -1,6 +1,7 @@ import numpy as np import scipy.special +import pymanopt.numerics as nx from pymanopt.manifolds.manifold import RiemannianSubmanifold from pymanopt.tools import extend_docstring from pymanopt.tools.multi import ( @@ -28,7 +29,7 @@ def __init__(self, name, dimension, retraction): raise ValueError(f"Invalid retraction type '{retraction}'") def inner_product(self, point, tangent_vector_a, tangent_vector_b): - return np.tensordot( + return nx.tensordot( tangent_vector_a.conj(), tangent_vector_b, axes=tangent_vector_a.ndim, @@ -222,7 +223,7 @@ def random_point(self): n, k = self._n, self._k if n == 1: point = np.ones((k, 1, 1)) + 1j * np.ones((k, 1, 1)) - point /= np.abs(point) + point /= nx.abs(point) else: point, _ = multiqr( np.random.normal(size=(k, n, n)) diff --git a/src/pymanopt/numerics/__init__.py b/src/pymanopt/numerics/__init__.py new file mode 100644 index 000000000..1412f859b --- /dev/null +++ b/src/pymanopt/numerics/__init__.py @@ -0,0 +1,17 @@ +__all__ = ["abs", "allclose", "exp", "tensordot"] + + +import importlib + +from pymanopt.numerics.core import abs, allclose, exp, tensordot + + +def _register_backends(): + for backend in ["numpy", "jax", "pytorch", "tensorflow"]: + try: + importlib.import_module(f"pymanopt.numerics._backends.{backend}") + except ImportError: + pass + + +_register_backends() diff --git a/src/pymanopt/numerics/_backends/__init__.py b/src/pymanopt/numerics/_backends/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/pymanopt/numerics/_backends/numpy.py b/src/pymanopt/numerics/_backends/numpy.py new file mode 100644 index 000000000..b90a88c9e --- /dev/null +++ b/src/pymanopt/numerics/_backends/numpy.py @@ -0,0 +1,23 @@ +import numpy as np + +import pymanopt.numerics.core as nx + + +@nx.abs.register(np.ndarray) +def _(array): + return np.abs(array) + + +@nx.allclose.register(np.ndarray) +def _(array_a, array_b): + return np.allclose(array_a, array_b) + + +@nx.exp.register(np.ndarray) +def _(array): + return np.abs(array) + + +@nx.tensordot.register(np.ndarray) +def _(array_a, array_b, axes: int): + return np.tensordot(array_a, array_b, axes=axes) diff --git a/src/pymanopt/numerics/core.py b/src/pymanopt/numerics/core.py new file mode 100644 index 000000000..6acb394b2 --- /dev/null +++ b/src/pymanopt/numerics/core.py @@ -0,0 +1,36 @@ +import functools + + +def _not_implemented(function): + @functools.wraps(function) + def inner(*arguments): + raise TypeError( + f"Function '{function.__name__}' not implemented for arguments of " + f"type '{type(arguments[0])}'" + ) + + return inner + + +@functools.singledispatch +@_not_implemented +def abs(_): + pass + + +@functools.singledispatch +@_not_implemented +def allclose(*_): + pass + + +@functools.singledispatch +@_not_implemented +def exp(_): + pass + + +@functools.singledispatch +@_not_implemented +def tensordot(*_): + pass diff --git a/tests/numerics/__init__.py b/tests/numerics/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/numerics/test_core.py b/tests/numerics/test_core.py new file mode 100644 index 000000000..8a1047793 --- /dev/null +++ b/tests/numerics/test_core.py @@ -0,0 +1,12 @@ +import numpy as np +import pytest + +import pymanopt.numerics as nx + + +@pytest.mark.parametrize( + "argument, expected_output", [(np.array([-4, 2]), np.array([4, 2]))] +) +def test_abs(argument, expected_output): + output = nx.abs(argument) + assert nx.allclose(output, expected_output) From ee0afc13539adb959f70e67f02172bf5a819ab1d Mon Sep 17 00:00:00 2001 From: Antoine Collas <22830806+antoinecollas@users.noreply.github.com> Date: Sun, 2 Jul 2023 12:07:24 +0200 Subject: [PATCH 2/4] fix nx.exp --- src/pymanopt/numerics/_backends/numpy.py | 2 +- tests/numerics/test_core.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/pymanopt/numerics/_backends/numpy.py b/src/pymanopt/numerics/_backends/numpy.py index b90a88c9e..45035bed2 100644 --- a/src/pymanopt/numerics/_backends/numpy.py +++ b/src/pymanopt/numerics/_backends/numpy.py @@ -15,7 +15,7 @@ def _(array_a, array_b): @nx.exp.register(np.ndarray) def _(array): - return np.abs(array) + return np.exp(array) @nx.tensordot.register(np.ndarray) diff --git a/tests/numerics/test_core.py b/tests/numerics/test_core.py index 8a1047793..bd0934605 100644 --- a/tests/numerics/test_core.py +++ b/tests/numerics/test_core.py @@ -10,3 +10,11 @@ def test_abs(argument, expected_output): output = nx.abs(argument) assert nx.allclose(output, expected_output) + + +@pytest.mark.parametrize( + "argument, expected_output", [(np.array([-4, 2]), np.exp([-4, 2]))] +) +def test_exp(argument, expected_output): + output = nx.exp(argument) + assert nx.allclose(output, expected_output) From 09b3881cb86c8144f1cf797f1e2a28327135cef7 Mon Sep 17 00:00:00 2001 From: Niklas Koep Date: Tue, 12 Sep 2023 22:40:49 +0200 Subject: [PATCH 3/4] Add stubs for each backend Signed-off-by: Niklas Koep --- src/pymanopt/numerics/__init__.py | 7 ++- src/pymanopt/numerics/_backends/jax.py | 25 +++++++++ src/pymanopt/numerics/_backends/numpy.py | 20 ++++--- src/pymanopt/numerics/_backends/pytorch.py | 32 +++++++++++ src/pymanopt/numerics/_backends/tensorflow.py | 24 ++++++++ tests/numerics/test_core.py | 56 ++++++++++++++++++- 6 files changed, 152 insertions(+), 12 deletions(-) create mode 100644 src/pymanopt/numerics/_backends/jax.py create mode 100644 src/pymanopt/numerics/_backends/pytorch.py create mode 100644 src/pymanopt/numerics/_backends/tensorflow.py diff --git a/src/pymanopt/numerics/__init__.py b/src/pymanopt/numerics/__init__.py index 1412f859b..ddddf07cd 100644 --- a/src/pymanopt/numerics/__init__.py +++ b/src/pymanopt/numerics/__init__.py @@ -6,7 +6,10 @@ from pymanopt.numerics.core import abs, allclose, exp, tensordot -def _register_backends(): +FUNCTIONS = [abs, allclose, exp, tensordot] + + +def register_backends(): for backend in ["numpy", "jax", "pytorch", "tensorflow"]: try: importlib.import_module(f"pymanopt.numerics._backends.{backend}") @@ -14,4 +17,4 @@ def _register_backends(): pass -_register_backends() +register_backends() diff --git a/src/pymanopt/numerics/_backends/jax.py b/src/pymanopt/numerics/_backends/jax.py new file mode 100644 index 000000000..f68fabbaf --- /dev/null +++ b/src/pymanopt/numerics/_backends/jax.py @@ -0,0 +1,25 @@ +import jax.numpy as jnp + +import pymanopt.numerics.core as nx + + +@nx.abs.register +def _(array: jnp.ndarray) -> jnp.ndarray: + return jnp.abs(array) + + +@nx.allclose.register +def _(array_a: jnp.ndarray, array_b: jnp.ndarray) -> bool: + return jnp.allclose(array_a, array_b) + + +@nx.exp.register +def _(array: jnp.ndarray) -> jnp.ndarray: + return jnp.exp(array) + + +@nx.tensordot.register +def _( + array_a: jnp.ndarray, array_b: jnp.ndarray, *, axes: int = 2 +) -> jnp.ndarray: + return jnp.tensordot(array_a, array_b, axes=axes) diff --git a/src/pymanopt/numerics/_backends/numpy.py b/src/pymanopt/numerics/_backends/numpy.py index b90a88c9e..1a28866af 100644 --- a/src/pymanopt/numerics/_backends/numpy.py +++ b/src/pymanopt/numerics/_backends/numpy.py @@ -3,21 +3,23 @@ import pymanopt.numerics.core as nx -@nx.abs.register(np.ndarray) -def _(array): +@nx.abs.register +def _(array: np.ndarray) -> np.ndarray: return np.abs(array) -@nx.allclose.register(np.ndarray) -def _(array_a, array_b): +@nx.allclose.register +def _(array_a: np.ndarray, array_b: np.ndarray) -> bool: return np.allclose(array_a, array_b) -@nx.exp.register(np.ndarray) -def _(array): - return np.abs(array) +@nx.exp.register +def _(array: np.ndarray) -> np.ndarray: + return np.exp(array) -@nx.tensordot.register(np.ndarray) -def _(array_a, array_b, axes: int): +@nx.tensordot.register +def _( + array_a: np.ndarray, array_b: np.ndarray, *, axes: int = 2 +) -> np.ndarray: return np.tensordot(array_a, array_b, axes=axes) diff --git a/src/pymanopt/numerics/_backends/pytorch.py b/src/pymanopt/numerics/_backends/pytorch.py new file mode 100644 index 000000000..b79bd7999 --- /dev/null +++ b/src/pymanopt/numerics/_backends/pytorch.py @@ -0,0 +1,32 @@ +import typing + +import torch + +import pymanopt.numerics.core as nx + + +@nx.abs.register +def _(array: torch.Tensor) -> torch.Tensor: + return torch.abs(array) + + +@nx.allclose.register +def _( + array_a: torch.Tensor, array_b: typing.Union[torch.Tensor, float, int] +) -> bool: + # PyTorch does not automatically coerce values to tensors. + if isinstance(array_b, (float, int)): + array_b = torch.Tensor([array_b]) + return torch.allclose(array_a, array_b) + + +@nx.exp.register +def _(array: torch.Tensor) -> torch.Tensor: + return torch.exp(array) + + +@nx.tensordot.register +def _( + array_a: torch.Tensor, array_b: torch.Tensor, *, axes: int = 2 +) -> torch.Tensor: + return torch.tensordot(array_a, array_b, dims=axes) diff --git a/src/pymanopt/numerics/_backends/tensorflow.py b/src/pymanopt/numerics/_backends/tensorflow.py new file mode 100644 index 000000000..61df4853d --- /dev/null +++ b/src/pymanopt/numerics/_backends/tensorflow.py @@ -0,0 +1,24 @@ +import tensorflow as tf +import tensorflow.experimental.numpy as tnp + +import pymanopt.numerics.core as nx + + +@nx.abs.register +def _(array: tf.Tensor) -> tf.Tensor: + return tnp.abs(array) + + +@nx.allclose.register +def _(array_a: tf.Tensor, array_b: tf.Tensor) -> bool: + return tnp.allclose(array_a, array_b) + + +@nx.exp.register +def _(array: tf.Tensor) -> tf.Tensor: + return tnp.exp(array) + + +@nx.tensordot.register +def _(array_a: tf.Tensor, array_b: tf.Tensor, *, axes: int = 2) -> tf.Tensor: + return tnp.tensordot(array_a, array_b, axes=axes) diff --git a/tests/numerics/test_core.py b/tests/numerics/test_core.py index 8a1047793..2c47fa08e 100644 --- a/tests/numerics/test_core.py +++ b/tests/numerics/test_core.py @@ -1,12 +1,66 @@ +import jax.numpy as jnp import numpy as np import pytest +import tensorflow as tf +import torch import pymanopt.numerics as nx @pytest.mark.parametrize( - "argument, expected_output", [(np.array([-4, 2]), np.array([4, 2]))] + "argument, expected_output", + [ + (np.array([-4, 2]), np.array([4, 2])), + (jnp.array([-4, 2]), jnp.array([4, 2])), + (torch.Tensor([-4, 2]), torch.Tensor([4, 2])), + (tf.constant([-4, 2]), tf.constant([4, 2])), + ], ) def test_abs(argument, expected_output): output = nx.abs(argument) assert nx.allclose(output, expected_output) + + +@pytest.mark.parametrize( + "argument_a, argument_b, expected_output", + [ + (np.array([4, 2]), np.array([4, 2]), True), + (np.array([4, 2]), np.array([2, 4]), False), + (jnp.array([4, 2]), jnp.array([4, 2]), True), + (jnp.array([4, 2]), jnp.array([2, 4]), False), + (torch.Tensor([4, 2]), torch.Tensor([4, 2]), True), + (torch.Tensor([4, 2]), torch.Tensor([2, 4]), False), + (tf.constant([4, 2]), tf.constant([4, 2]), True), + (tf.constant([4, 2]), tf.constant([2, 4]), False), + ], +) +def test_allclose(argument_a, argument_b, expected_output): + assert nx.allclose(argument_a, argument_b) == expected_output + + +@pytest.mark.parametrize( + "argument, expected_output", + [ + (np.log(np.array([4, 2])), np.array([4, 2])), + (jnp.log(jnp.array([4, 2])), jnp.array([4, 2])), + (torch.log(torch.Tensor([4, 2])), torch.Tensor([4, 2])), + (tf.math.log(tf.constant([4.0, 2.0])), tf.constant([4.0, 2.0])), + ], +) +def test_exp(argument, expected_output): + output = nx.exp(argument) + assert nx.allclose(output, expected_output) + + +@pytest.mark.parametrize( + "argument_a, argument_b, expected_output", + [ + (np.array([-4, 2]), np.array([1, 3]), 2), + (jnp.array([-4, 2]), jnp.array([1, 3]), 2), + (torch.Tensor([-4, 2]), torch.Tensor([1, 3]), 2), + (tf.constant([-4, 2]), tf.constant([1, 3]), 2), + ], +) +def test_tensordot(argument_a, argument_b, expected_output): + output = nx.tensordot(argument_a, argument_b, axes=argument_a.ndim) + assert nx.allclose(output, expected_output) From 3b62854b1f641dcabd3a7b9ccdf20913931c2d29 Mon Sep 17 00:00:00 2001 From: Niklas Koep Date: Tue, 12 Sep 2023 22:55:01 +0200 Subject: [PATCH 4/4] Remove unused 'FUNCTIONS' variable Signed-off-by: Niklas Koep --- src/pymanopt/numerics/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/pymanopt/numerics/__init__.py b/src/pymanopt/numerics/__init__.py index ddddf07cd..72ba2f363 100644 --- a/src/pymanopt/numerics/__init__.py +++ b/src/pymanopt/numerics/__init__.py @@ -6,9 +6,6 @@ from pymanopt.numerics.core import abs, allclose, exp, tensordot -FUNCTIONS = [abs, allclose, exp, tensordot] - - def register_backends(): for backend in ["numpy", "jax", "pytorch", "tensorflow"]: try: