Commit

Merge 3b62854 into 5314bab
nkoep committed Sep 12, 2023
2 parents 5314bab + 3b62854 commit 9f304b1
Showing 10 changed files with 228 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/pymanopt/manifolds/group.py
@@ -1,6 +1,7 @@
 import numpy as np
 import scipy.special
 
+import pymanopt.numerics as nx
 from pymanopt.manifolds.manifold import RiemannianSubmanifold
 from pymanopt.tools import extend_docstring
 from pymanopt.tools.multi import (
@@ -28,7 +29,7 @@ def __init__(self, name, dimension, retraction):
             raise ValueError(f"Invalid retraction type '{retraction}'")
 
     def inner_product(self, point, tangent_vector_a, tangent_vector_b):
-        return np.tensordot(
+        return nx.tensordot(
             tangent_vector_a.conj(),
             tangent_vector_b,
             axes=tangent_vector_a.ndim,
@@ -222,7 +223,7 @@ def random_point(self):
         n, k = self._n, self._k
         if n == 1:
             point = np.ones((k, 1, 1)) + 1j * np.ones((k, 1, 1))
-            point /= np.abs(point)
+            point /= nx.abs(point)
         else:
             point, _ = multiqr(
                 np.random.normal(size=(k, n, n))
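The switch from np.* to nx.* in these hunks routes the manifold's array math through the new pymanopt.numerics dispatch layer, so the same inner_product and random_point code runs regardless of which backend produced the arrays. A minimal sketch of the resulting behaviour, assuming this branch of pymanopt is installed (NumPy shown here; a JAX, PyTorch, or TensorFlow array would dispatch to its own backend in the same way):

import numpy as np

import pymanopt.numerics as nx

a = np.array([[1.0, 2.0], [3.0, 4.0]])
b = np.array([[5.0, 6.0], [7.0, 8.0]])

# Full contraction over all axes, as in inner_product above; dispatch is on
# the type of the first argument, here np.ndarray.
print(nx.tensordot(a.conj(), b, axes=a.ndim))  # 70.0
print(nx.abs(np.array([-1.5, 2.0])))           # [1.5 2. ]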
17 changes: 17 additions & 0 deletions src/pymanopt/numerics/__init__.py
@@ -0,0 +1,17 @@
__all__ = ["abs", "allclose", "exp", "tensordot"]


import importlib

from pymanopt.numerics.core import abs, allclose, exp, tensordot


def register_backends():
    for backend in ["numpy", "jax", "pytorch", "tensorflow"]:
        try:
            importlib.import_module(f"pymanopt.numerics._backends.{backend}")
        except ImportError:
            pass


register_backends()
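register_backends() imports each backend shim and silently skips any whose library is not installed, so the dispatch table only ever contains implementations that can actually run. A minimal sketch of what that buys, assuming only NumPy is available; a call on a type with no registered implementation falls through to the singledispatch default defined in core.py below and raises a TypeError:

import numpy as np

import pymanopt.numerics as nx

# Dispatches to the NumPy implementation registered by _backends/numpy.py.
print(nx.exp(np.array([0.0, 1.0])))  # [1.         2.71828183]

# No implementation is registered for plain Python lists, so the default
# handler raises a TypeError naming the offending type.
try:
    nx.exp([0.0, 1.0])
except TypeError as error:
    print(error)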
Empty file added src/pymanopt/numerics/_backends/__init__.py
25 changes: 25 additions & 0 deletions src/pymanopt/numerics/_backends/jax.py
@@ -0,0 +1,25 @@
import jax.numpy as jnp

import pymanopt.numerics.core as nx


@nx.abs.register
def _(array: jnp.ndarray) -> jnp.ndarray:
    return jnp.abs(array)


@nx.allclose.register
def _(array_a: jnp.ndarray, array_b: jnp.ndarray) -> bool:
    return jnp.allclose(array_a, array_b)


@nx.exp.register
def _(array: jnp.ndarray) -> jnp.ndarray:
    return jnp.exp(array)


@nx.tensordot.register
def _(
    array_a: jnp.ndarray, array_b: jnp.ndarray, *, axes: int = 2
) -> jnp.ndarray:
    return jnp.tensordot(array_a, array_b, axes=axes)
25 changes: 25 additions & 0 deletions src/pymanopt/numerics/_backends/numpy.py
@@ -0,0 +1,25 @@
import numpy as np

import pymanopt.numerics.core as nx


@nx.abs.register
def _(array: np.ndarray) -> np.ndarray:
    return np.abs(array)


@nx.allclose.register
def _(array_a: np.ndarray, array_b: np.ndarray) -> bool:
    return np.allclose(array_a, array_b)


@nx.exp.register
def _(array: np.ndarray) -> np.ndarray:
    return np.exp(array)


@nx.tensordot.register
def _(
    array_a: np.ndarray, array_b: np.ndarray, *, axes: int = 2
) -> np.ndarray:
    return np.tensordot(array_a, array_b, axes=axes)
32 changes: 32 additions & 0 deletions src/pymanopt/numerics/_backends/pytorch.py
@@ -0,0 +1,32 @@
import typing

import torch

import pymanopt.numerics.core as nx


@nx.abs.register
def _(array: torch.Tensor) -> torch.Tensor:
    return torch.abs(array)


@nx.allclose.register
def _(
    array_a: torch.Tensor, array_b: typing.Union[torch.Tensor, float, int]
) -> bool:
    # PyTorch does not automatically coerce values to tensors.
    if isinstance(array_b, (float, int)):
        array_b = torch.Tensor([array_b])
    return torch.allclose(array_a, array_b)


@nx.exp.register
def _(array: torch.Tensor) -> torch.Tensor:
    return torch.exp(array)


@nx.tensordot.register
def _(
    array_a: torch.Tensor, array_b: torch.Tensor, *, axes: int = 2
) -> torch.Tensor:
    return torch.tensordot(array_a, array_b, dims=axes)
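The isinstance check above exists because torch.allclose only accepts tensors as its second argument; with the coercion in place, comparisons against Python scalars still work through the generic entry point. A minimal usage sketch, assuming PyTorch is installed:

import torch

import pymanopt.numerics as nx

# The second argument is a plain float; the registered PyTorch overload wraps
# it in a tensor before delegating to torch.allclose.
print(nx.allclose(torch.Tensor([1.0]), 1.0))  # True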
24 changes: 24 additions & 0 deletions src/pymanopt/numerics/_backends/tensorflow.py
@@ -0,0 +1,24 @@
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

import pymanopt.numerics.core as nx


@nx.abs.register
def _(array: tf.Tensor) -> tf.Tensor:
    return tnp.abs(array)


@nx.allclose.register
def _(array_a: tf.Tensor, array_b: tf.Tensor) -> bool:
    return tnp.allclose(array_a, array_b)


@nx.exp.register
def _(array: tf.Tensor) -> tf.Tensor:
    return tnp.exp(array)


@nx.tensordot.register
def _(array_a: tf.Tensor, array_b: tf.Tensor, *, axes: int = 2) -> tf.Tensor:
    return tnp.tensordot(array_a, array_b, axes=axes)
36 changes: 36 additions & 0 deletions src/pymanopt/numerics/core.py
@@ -0,0 +1,36 @@
import functools


def _not_implemented(function):
    @functools.wraps(function)
    def inner(*arguments):
        raise TypeError(
            f"Function '{function.__name__}' not implemented for arguments of "
            f"type '{type(arguments[0])}'"
        )

    return inner


@functools.singledispatch
@_not_implemented
def abs(_):
    pass


@functools.singledispatch
@_not_implemented
def allclose(*_):
    pass


@functools.singledispatch
@_not_implemented
def exp(_):
    pass


@functools.singledispatch
@_not_implemented
def tensordot(*_):
    pass
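Each public function is a functools.singledispatch generic whose un-dispatched default is the _not_implemented wrapper, so an unsupported argument type fails loudly instead of silently. Concrete implementations are attached with .register, exactly as the _backends modules above do. The following sketch registers a purely hypothetical overload for Python lists to illustrate the mechanism; it is not part of this commit:

import numpy as np

import pymanopt.numerics.core as nx


# Hypothetical overload (illustration only): singledispatch picks the
# implementation from the annotated type of the first parameter.
@nx.abs.register
def _(array: list) -> np.ndarray:
    return np.abs(np.asarray(array))


print(nx.abs([-3, 5]))  # [3 5]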
Empty file added tests/numerics/__init__.py
66 changes: 66 additions & 0 deletions tests/numerics/test_core.py
@@ -0,0 +1,66 @@
import jax.numpy as jnp
import numpy as np
import pytest
import tensorflow as tf
import torch

import pymanopt.numerics as nx


@pytest.mark.parametrize(
    "argument, expected_output",
    [
        (np.array([-4, 2]), np.array([4, 2])),
        (jnp.array([-4, 2]), jnp.array([4, 2])),
        (torch.Tensor([-4, 2]), torch.Tensor([4, 2])),
        (tf.constant([-4, 2]), tf.constant([4, 2])),
    ],
)
def test_abs(argument, expected_output):
    output = nx.abs(argument)
    assert nx.allclose(output, expected_output)


@pytest.mark.parametrize(
    "argument_a, argument_b, expected_output",
    [
        (np.array([4, 2]), np.array([4, 2]), True),
        (np.array([4, 2]), np.array([2, 4]), False),
        (jnp.array([4, 2]), jnp.array([4, 2]), True),
        (jnp.array([4, 2]), jnp.array([2, 4]), False),
        (torch.Tensor([4, 2]), torch.Tensor([4, 2]), True),
        (torch.Tensor([4, 2]), torch.Tensor([2, 4]), False),
        (tf.constant([4, 2]), tf.constant([4, 2]), True),
        (tf.constant([4, 2]), tf.constant([2, 4]), False),
    ],
)
def test_allclose(argument_a, argument_b, expected_output):
    assert nx.allclose(argument_a, argument_b) == expected_output


@pytest.mark.parametrize(
    "argument, expected_output",
    [
        (np.log(np.array([4, 2])), np.array([4, 2])),
        (jnp.log(jnp.array([4, 2])), jnp.array([4, 2])),
        (torch.log(torch.Tensor([4, 2])), torch.Tensor([4, 2])),
        (tf.math.log(tf.constant([4.0, 2.0])), tf.constant([4.0, 2.0])),
    ],
)
def test_exp(argument, expected_output):
    output = nx.exp(argument)
    assert nx.allclose(output, expected_output)


@pytest.mark.parametrize(
    "argument_a, argument_b, expected_output",
    [
        (np.array([-4, 2]), np.array([1, 3]), 2),
        (jnp.array([-4, 2]), jnp.array([1, 3]), 2),
        (torch.Tensor([-4, 2]), torch.Tensor([1, 3]), 2),
        (tf.constant([-4, 2]), tf.constant([1, 3]), 2),
    ],
)
def test_tensordot(argument_a, argument_b, expected_output):
    output = nx.tensordot(argument_a, argument_b, axes=argument_a.ndim)
    assert nx.allclose(output, expected_output)
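Assuming NumPy, JAX, PyTorch, and TensorFlow are all installed, the new suite exercises every registered backend through the single pymanopt.numerics entry point and can be run with, for example:

pytest tests/numerics/test_core.py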
