Skip to content

Commit

Permalink
Add product of strictly postive vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
antoinecollas committed Apr 22, 2020
1 parent 21b73a9 commit 5d943f5
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 23 deletions.
19 changes: 5 additions & 14 deletions pymanopt/manifolds/strictly_positive_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def typicaldist(self):

def inner(self, x, u, v):
inv_x = (1./x)
return np.tensordot(inv_x*u, inv_x*v, axes=(-1, -1))
return np.sum(inv_x*u*inv_x*v, axis=0, keepdims=True)

def proj(self, x, u):
return u
Expand All @@ -38,26 +38,17 @@ def norm(self, x, u):
return np.sqrt(self.inner(x, u, u))

def rand(self):
if self._k == 1:
return rnd.uniform(low=1e-6, high=1, size=(self._n))
return rnd.uniform(low=1e-6, high=1, size=(self._k, self._n))
return rnd.uniform(low=1e-6, high=1, size=(self._n, self._k))

def randvec(self, x):
if self._k == 1:
u = rnd.randn(self._n)
else:
u = rnd.randn(self._k, self._n)
u = rnd.randn(self._n, self._k)
return u / self.norm(x, u)

def zerovec(self, x):
k = self._k
n = self._n
if k == 1:
return np.zeros(n)
return np.zeros(k, n)
return np.zeros(self._n, self._k)

def dist(self, x, y):
return la.norm(np.log(x)-np.log(y), axis=0)
return la.norm(np.log(x)-np.log(y), axis=0, keepdims=True)

egrad2rgrad = proj

Expand Down
17 changes: 8 additions & 9 deletions tests/test_manifolds/test_strictly_positive_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,17 @@
from .._test import TestCase


class TestSingleStrictlyPositiveVectors(TestCase):
class TestStrictlyPositiveVectors(TestCase):
def setUp(self):
self.n = n = 3
self.k = k = 1
self.k = k = 2
self.man = StrictlyPositiveVectors(n, k=k)

def test_inner(self):
x = self.man.rand()
g = self.man.randvec(x)
h = self.man.randvec(x)
inv_x = 1./x
np_testing.assert_almost_equal(np.sum(inv_x*g*inv_x*h),
self.man.inner(x, g, h))
assert (self.man.inner(x, g, h).shape == np.array([1, self.k])).all()

def test_proj(self):
# Test proj(proj(X)) == proj(X)
Expand All @@ -32,25 +30,26 @@ def test_proj(self):
def test_norm(self):
x = self.man.rand()
u = self.man.randvec(x)
inv_x = 1./x
x_u = (1./x) * u
np_testing.assert_almost_equal(
np.sqrt(np.sum(inv_x*u*inv_x*u)), self.man.norm(x, u))
la.norm(x_u, axis=0, keepdims=True),
self.man.norm(x, u))

def test_rand(self):
# Just make sure that things generated are on the manifold
# and that if you generate two they are not equal.
x = self.man.rand()
assert (x > 0).all()
y = self.man.rand()
assert self.man.dist(x, y) > 1e-6
assert (self.man.dist(x, y)).all() > 1e-6

def test_randvec(self):
# Just make sure that if you generate two they are not equal.
# check also if unit norm
x = self.man.rand()
g = self.man.randvec(x)
h = self.man.randvec(x)
assert la.norm(g - h) > 1e-6
assert (la.norm(g-h, axis=0) > 1e-6).all()
np_testing.assert_almost_equal(self.man.norm(x, g), 1)

def test_dist(self):
Expand Down

0 comments on commit 5d943f5

Please sign in to comment.