Skip to content

Commit

Permalink
Merge d749708 into acb52b2
Browse files Browse the repository at this point in the history
  • Loading branch information
nkoep committed Jul 2, 2023
2 parents acb52b2 + d749708 commit 37c4a94
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/pymanopt/manifolds/group.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import scipy.special

import pymanopt.numerics as nx
from pymanopt.manifolds.manifold import RiemannianSubmanifold
from pymanopt.tools import extend_docstring
from pymanopt.tools.multi import (
Expand Down Expand Up @@ -28,7 +29,7 @@ def __init__(self, name, dimension, retraction):
raise ValueError(f"Invalid retraction type '{retraction}'")

def inner_product(self, point, tangent_vector_a, tangent_vector_b):
return np.tensordot(
return nx.tensordot(
tangent_vector_a.conj(),
tangent_vector_b,
axes=tangent_vector_a.ndim,
Expand Down Expand Up @@ -222,7 +223,7 @@ def random_point(self):
n, k = self._n, self._k
if n == 1:
point = np.ones((k, 1, 1)) + 1j * np.ones((k, 1, 1))
point /= np.abs(point)
point /= nx.abs(point)
else:
point, _ = multiqr(
np.random.normal(size=(k, n, n))
Expand Down
17 changes: 17 additions & 0 deletions src/pymanopt/numerics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
__all__ = ["abs", "allclose", "exp", "tensordot"]


import importlib

from pymanopt.numerics.core import abs, allclose, exp, tensordot


def _register_backends():
for backend in ["numpy", "jax", "pytorch", "tensorflow"]:
try:
importlib.import_module(f"pymanopt.numerics._backends.{backend}")
except ImportError:
pass


_register_backends()
Empty file.
23 changes: 23 additions & 0 deletions src/pymanopt/numerics/_backends/numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import numpy as np

import pymanopt.numerics.core as nx


@nx.abs.register(np.ndarray)
def _(array):
return np.abs(array)


@nx.allclose.register(np.ndarray)
def _(array_a, array_b):
return np.allclose(array_a, array_b)


@nx.exp.register(np.ndarray)
def _(array):
return np.abs(array)


@nx.tensordot.register(np.ndarray)
def _(array_a, array_b, axes: int):
return np.tensordot(array_a, array_b, axes=axes)
36 changes: 36 additions & 0 deletions src/pymanopt/numerics/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import functools


def _not_implemented(function):
@functools.wraps(function)
def inner(*arguments):
raise TypeError(
f"Function '{function.__name__}' not implemented for arguments of "
f"type '{type(arguments[0])}'"
)

return inner


@functools.singledispatch
@_not_implemented
def abs(_):
pass


@functools.singledispatch
@_not_implemented
def allclose(*_):
pass


@functools.singledispatch
@_not_implemented
def exp(_):
pass


@functools.singledispatch
@_not_implemented
def tensordot(*_):
pass
Empty file added tests/numerics/__init__.py
Empty file.
12 changes: 12 additions & 0 deletions tests/numerics/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import numpy as np
import pytest

import pymanopt.numerics as nx


@pytest.mark.parametrize(
"argument, expected_output", [(np.array([-4, 2]), np.array([4, 2]))]
)
def test_abs(argument, expected_output):
output = nx.abs(argument)
assert nx.allclose(output, expected_output)

0 comments on commit 37c4a94

Please sign in to comment.