Skip to content

Commit

Permalink
Merge pull request #125 from antoinecollas/complex_valued_manifolds
Browse files Browse the repository at this point in the history
Add Complex valued manifolds.
  • Loading branch information
antoinecollas committed Sep 12, 2023
2 parents acb52b2 + 2da40f2 commit 5314bab
Show file tree
Hide file tree
Showing 7 changed files with 1,024 additions and 402 deletions.
11 changes: 9 additions & 2 deletions src/pymanopt/manifolds/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
"ComplexGrassmann",
"Elliptope",
"Euclidean",
"ComplexEuclidean",
"FixedRankEmbedded",
"Grassmann",
"HermitianPositiveDefinite",
"SpecialHermitianPositiveDefinite",
"Oblique",
"PSDFixedRank",
"PSDFixedRankComplex",
Expand All @@ -23,14 +26,18 @@
]

from .complex_circle import ComplexCircle
from .euclidean import Euclidean, SkewSymmetric, Symmetric
from .euclidean import ComplexEuclidean, Euclidean, SkewSymmetric, Symmetric
from .fixed_rank import FixedRankEmbedded
from .grassmann import ComplexGrassmann, Grassmann
from .group import SpecialOrthogonalGroup, UnitaryGroup
from .hyperbolic import PoincareBall
from .oblique import Oblique
from .positive import Positive
from .positive_definite import SymmetricPositiveDefinite
from .positive_definite import (
HermitianPositiveDefinite,
SpecialHermitianPositiveDefinite,
SymmetricPositiveDefinite,
)
from .product import Product
from .psd import Elliptope, PSDFixedRank, PSDFixedRankComplex
from .sphere import (
Expand Down
47 changes: 45 additions & 2 deletions src/pymanopt/manifolds/euclidean.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@ def typical_dist(self):

def inner_product(self, point, tangent_vector_a, tangent_vector_b):
return float(
np.tensordot(
tangent_vector_a, tangent_vector_b, axes=tangent_vector_a.ndim
np.real(
np.tensordot(
tangent_vector_a.conj(),
tangent_vector_b,
axes=tangent_vector_a.ndim,
)
)
)

Expand Down Expand Up @@ -92,6 +96,45 @@ def __init__(self, *shape: int):
super().__init__(name, dimension, *shape)


class ComplexEuclidean(_Euclidean):
r"""Complex Euclidean manifold.
Args:
shape: Shape of points on the manifold.
Note:
If ``shape == (n,)``, this is the manifold of vectors with the
standard Euclidean inner product, i.e., :math:`\C^n`.
For ``shape == (m, n)``, it corresponds to the manifold of ``m x n``
matrices equipped with the standard trace inner product.
For ``shape == (n1, n2, ..., nk)``, the class represents the manifold
of tensors of shape ``n1 x n2 x ... x nk`` with the inner product
corresponding to the usual tensor dot product.
"""

def __init__(self, *shape):
if len(shape) == 0:
raise TypeError("Need shape parameters")
if len(shape) == 1:
(n1,) = shape
name = f"Complex Euclidean manifold of {n1}-vectors"
elif len(shape) == 2:
n1, n2 = shape
name = f"Complex Euclidean manifold of {n1}x{n2} matrices"
else:
name = f"Complex Euclidean manifold of shape {shape} tensors"
dimension = 2 * np.prod(shape)
super().__init__(name, dimension, *shape)

def random_point(self):
return np.random.randn(*self._shape) + 1j * np.random.randn(
*self._shape
)

def zero_vector(self, point):
return np.zeros(self._shape, dtype=complex)


class Symmetric(_Euclidean):
"""(Product) manifold of symmetric matrices.
Expand Down
22 changes: 11 additions & 11 deletions src/pymanopt/manifolds/manifold.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np


def _raise_not_implemented_error(method):
def raise_not_implemented_error(method):
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
raise NotImplementedError(
Expand Down Expand Up @@ -206,7 +206,7 @@ def zero_vector(self, point):

# Methods which are only required by certain optimizers.

@_raise_not_implemented_error
@raise_not_implemented_error
def dist(self, point_a, point_b):
"""The geodesic distance between two points on the manifold.
Expand All @@ -218,7 +218,7 @@ def dist(self, point_a, point_b):
The distance between ``point_a`` and ``point_b`` on the manifold.
"""

@_raise_not_implemented_error
@raise_not_implemented_error
def euclidean_to_riemannian_gradient(self, point, euclidean_gradient):
"""Converts the Euclidean to the Riemannian gradient.
Expand All @@ -233,7 +233,7 @@ def euclidean_to_riemannian_gradient(self, point, euclidean_gradient):
This must be a tangent vector at ``point``.
"""

@_raise_not_implemented_error
@raise_not_implemented_error
def euclidean_to_riemannian_hessian(
self, point, euclidean_gradient, euclidean_hessian, tangent_vector
):
Expand All @@ -257,7 +257,7 @@ def euclidean_to_riemannian_hessian(
The Riemannian Hessian as a tangent vector at ``point``.
"""

@_raise_not_implemented_error
@raise_not_implemented_error
def retraction(self, point, tangent_vector):
"""Retracts a tangent vector back to the manifold.
Expand All @@ -275,7 +275,7 @@ def retraction(self, point, tangent_vector):
the direction of ``tangent_vector``.
"""

@_raise_not_implemented_error
@raise_not_implemented_error
def exp(self, point, tangent_vector):
"""Computes the exponential map on the manifold.
Expand All @@ -288,7 +288,7 @@ def exp(self, point, tangent_vector):
along a geodesic in the direction of ``tangent_vector``.
"""

@_raise_not_implemented_error
@raise_not_implemented_error
def log(self, point_a, point_b):
"""Computes the logarithmic map on the manifold.
Expand All @@ -306,7 +306,7 @@ def log(self, point_a, point_b):
A tangent vector in the tangent space at ``point_a``.
"""

@_raise_not_implemented_error
@raise_not_implemented_error
def transport(self, point_a, point_b, tangent_vector_a):
"""Compute transport of tangent vectors between tangent spaces.
Expand All @@ -326,7 +326,7 @@ def transport(self, point_a, point_b, tangent_vector_a):
A tangent vector at ``point_b``.
"""

@_raise_not_implemented_error
@raise_not_implemented_error
def pair_mean(self, point_a, point_b):
"""Computes the intrinsic mean of two points on the manifold.
Expand All @@ -342,7 +342,7 @@ def pair_mean(self, point_a, point_b):
The mid-way point between ``point_a`` and ``point_b``.
"""

@_raise_not_implemented_error
@raise_not_implemented_error
def to_tangent_space(self, point, vector):
"""Re-tangentialize a vector.
Expand Down Expand Up @@ -411,7 +411,7 @@ class RiemannianSubmanifold(Manifold, metaclass=abc.ABCMeta):
the notes in section 5.11 of [Bou2020]_.
"""

@_raise_not_implemented_error
@raise_not_implemented_error
def weingarten(self, point, tangent_vector, normal_vector):
"""Compute the Weingarten map of the manifold.
Expand Down
Loading

0 comments on commit 5314bab

Please sign in to comment.