In [1]:
import torch, math
from HTorch.manifolds import Euclidean, PoincareBall, Lorentz, HalfSpace, Sphere
import HTorch
sq_norm = HTorch.utils.sq_norm
arcosh = HTorch.utils.arcosh
arsinh = HTorch.utils.arsinh

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Lookup table from a model name to its HTorch manifold class.
# All keys are lower-case; 'sphere' was originally registered only as 'Sphere',
# so a lookup by the lower-case name raised KeyError. The capitalised key is
# kept as an alias so existing lookups by 'Sphere' keep working.
manifold_maps = {
    'euclidean': Euclidean,
    'poincare': PoincareBall,
    'lorentz': Lorentz,
    'halfspace': HalfSpace,
    'sphere': Sphere,
    'Sphere': Sphere,
}

In [125]:
# One instance of each manifold model, shared by every section below.
poin, lore, hal, sph = PoincareBall(), Lorentz(), HalfSpace(), Sphere()
# Curvature parameter passed to every distance/map call in the Poincare,
# Lorentz and half-space sections (the sphere section later overrides it).
c = 1.2

### Initialize Poincaré ball points

In [149]:
# Sample point xp, the origin, and yp = q * xp on the same ray through the origin.
q = 0.1
xp = torch.tensor([0.3, 0.3, 0.6])
ori = torch.tensor([0.0, 0.0, 0.0])
yp = q * xp
# Sanity check that xp is inside the ball: its squared norm is compared
# against 1/c (presumably the squared ball radius for curvature parameter c).
print(sq_norm(xp), 1.0 / c)
print(poin.distance(xp, yp, c))

tensor([0.5400]) 0.8333333333333334
tensor([1.8841])


In [155]:
def Hyx(x, y, c):
    """Power of ``y`` with respect to a sphere determined by ``x``.

    Algebraically the result equals ``|y - m|^2 - (|m|^2 - 1/c)`` with
    ``m = (1 + 1/(c |x|^2)) * x / 2``. That sphere is centred on the ray
    through ``x`` and passes through ``x`` (the value is 0 at ``y = x``). Its
    power at the origin is ``1/c``. Presumably this is the geodesic hyperplane
    through ``x`` orthogonal to that ray, as used by ``pro_dist`` -- confirm.

    Args:
        x, y: points; reduced over the last dimension.
        c: curvature parameter (positive scalar).

    Returns:
        Tensor whose last dimension is kept with size 1, matching the shape of
        ``sq_norm``'s output.

    Note:
        Not finite for ``x = 0`` because of the ``1 / |x|^2`` term.
    """
    # keepdim=True so the dot product has the same shape as sq_norm's result
    # (sq_norm appears to keep the reduced dim: it prints tensor([0.5400]) for
    # a 1-D input). Without it, batched (N, D) inputs broadcast (N, 1) against
    # (N,) and silently produce an (N, N) tensor.
    xy = (x * y).sum(dim=-1, keepdim=True)
    return sq_norm(y) + 1 / c - (1.0 + 1.0 / (c * sq_norm(x))) * xy

def pro_dist(x, y, c):
    """Signed distance-like quantity from ``y`` to the hyperplane given by ``Hyx``.

    Evaluates ``arsinh(2 c^1.5 H(x, y) |x| / ((1 - c|x|^2)(1 - c|y|^2))) / sqrt(c)``,
    where ``H`` is ``Hyx``. Properties that follow from the formula:

    * it is 0 when ``y = x`` (``Hyx`` vanishes there);
    * at ``y = 0`` it reduces to ``arsinh(2 sqrt(c)|x| / (1 - c|x|^2)) / sqrt(c)``,
      the same expression ``hori_dist`` gives for the origin-to-point distance;
    * its sign follows the sign of ``Hyx``.

    It appears to be the signed hyperbolic distance from ``y`` to the
    hyperplane through ``x`` orthogonal to the ray from the origin -- confirm.
    Undefined for ``x = 0``.
    """
    x_sq = sq_norm(x)
    power = Hyx(x, y, c)
    numerator = 2 * c ** 1.5 * power * x_sq ** 0.5
    denominator = (1.0 - c * x_sq) * (1.0 - c * sq_norm(y))
    return arsinh(numerator / denominator) / c ** 0.5

def hori_dist(x, y, c):
    """Distance from ``y`` to the line through the origin in direction ``x``.

    Computes ``arsinh(2 sqrt(c) |y| sin(theta) / (1 - c|y|^2)) / sqrt(c)``, where
    ``theta`` is the Euclidean angle between ``x`` and ``y``.
    ``2 sqrt(c)|y| / (1 - c|y|^2)`` is ``sinh(sqrt(c) d(0, y))``; the cell
    comparing against ``poin.distance(ori, xp, c)`` shows both give 2.0313.
    So this matches the hyperbolic right-triangle relation
    ``sinh(a) = sinh(h) sin(theta)``. Presumably it is the distance from ``y``
    to the geodesic through the origin along ``x`` -- confirm.

    When ``x`` is the origin the cosine is 0, ``theta = pi/2``, and the result
    reduces to the distance from the origin to ``y``.

    Args:
        x, y: points; reduced over the last dimension.
        c: curvature parameter (positive scalar).

    Returns:
        Tensor with the last dimension kept as size 1.
    """
    x_norm = sq_norm(x) ** 0.5
    y_norm = sq_norm(y) ** 0.5
    # keepdim=True keeps the dot product shape-compatible with sq_norm's output
    # for batched inputs; the 1e-8 floor avoids 0/0 when x or y is the origin.
    cos_theta = (x * y).sum(dim=-1, keepdim=True) / (x_norm * y_norm).clamp(min=1e-8)
    # Rounding can push |cos_theta| slightly above 1 for (anti)parallel x and y,
    # which would make acos return NaN.
    theta = torch.acos(cos_theta.clamp(min=-1.0, max=1.0))
    nom = 2 * c ** 0.5 * y_norm * torch.sin(theta)
    denom = 1.0 - c * sq_norm(y)
    return arsinh(nom / denom) / c ** 0.5

In [156]:
# With x at the origin the angle term is pi/2, so hori_dist reduces to the
# distance from the origin to xp; compare with poin.distance in the next cell.
hori_dist(ori, xp, c)

tensor([2.0313])

In [158]:
# Reference value for the hori_dist check above.
poin.distance(ori, xp, c)

tensor([2.0313])

In [None]:
# NOTE(review): this cell is marked In [None] (never executed in the saved run),
# yet the exp-map cells below print ap as [0.2, 0.3, 0.1], i.e. they depend on
# these values. Run this cell before them when re-running the notebook.
xp = torch.tensor([0.2, 0.3, 0.1])
yp = torch.tensor([0.39, 0.41, 0.5])
print(poin.distance(xp, yp, c))

In [5]:
# poincare exp map
op = torch.tensor([0.0,0.0,0.0])
ap = xp.clone().requires_grad_()
print(ap)

tensor([0.2000, 0.3000, 0.1000], requires_grad=True)


In [6]:
# Euclidean gradient of the distance from ap to the origin.
# backward() accumulates into ap.grad, so clear it first; otherwise re-running
# this cell prints (and the next cell uses) a doubled gradient.
ap.grad = None
dist_ap = poin.distance(ap, op, c)
print(dist_ap)
dist_ap.backward()
print(ap.grad)

tensor(0.7950, grad_fn=<DivBackward0>)
tensor([1.2849, 1.9274, 0.6425])


In [7]:
# Convert the Euclidean gradient into a Riemannian one and take one exp-map step
# along it; the distance to the origin goes from 0.7950 to 1.7950.
grad = poin.egrad2rgrad(ap, ap.grad, c)
print(grad)
bp = poin.expmap(ap, grad, c)
print(bp, poin.distance(bp, op, c))

tensor([0.2224, 0.3335, 0.1112])
tensor([0.3681, 0.5522, 0.1841], grad_fn=<DivBackward0>) tensor(1.7950, grad_fn=<DivBackward0>)


In [8]:
# Tangent-space norm of the Riemannian gradient at ap (prints 1.0000).
print(poin.norm_t(grad, ap, c))

tensor([1.0000])


### Convert Poincaré to Lorentz

In [9]:
# Map both points into the Lorentz model. If to_lorentz is an isometry, the
# distance should agree with the Poincare-ball distance between xp and yp.
xl = poin.to_lorentz(xp, c)
yl = poin.to_lorentz(yp, c)
print(lore.distance(xl, yl, c))

tensor([1.5769])


Convert the Lorentz points back to the Poincaré ball to check the round trip.

In [10]:
# Round trip Lorentz -> Poincare; should reproduce xp and yp.
lore.to_poincare(xl, c), lore.to_poincare(yl, c)

(tensor([0.2000, 0.3000, 0.1000]), tensor([0.3900, 0.4100, 0.5000]))

In [11]:
# lorentz exp map
ol = torch.tensor([0.0, 0.0, 0.0, 1/math.sqrt(c)])
al = xl.clone().requires_grad_()
print(al)

tensor([0.4808, 0.7212, 0.2404, 1.2815], requires_grad=True)


In [12]:
# Euclidean gradient of the distance from al to the Lorentz origin.
# backward() accumulates into al.grad, so clear it first to keep the cell
# re-runnable.
al.grad = None
dist_al = lore.distance(al, ol, c)
print(dist_al)
dist_al.backward()
print(al.grad)

tensor([0.7950], grad_fn=<ClampBackward1>)
tensor([0.0000, 0.0000, 0.0000, 1.0149])


In [13]:
# Same gradient step in the Lorentz model; the distance to ol again goes from
# 0.7950 to 1.7950, matching the Poincare-ball experiment.
grad = lore.egrad2rgrad(al, al.grad, c)
print(grad)
bl = lore.expmap(al, grad, c)
print(bl, lore.distance(bl, ol, c))

tensor([0.7504, 1.1256, 0.3752, 0.9853], grad_fn=<AddcmulBackward0>)
tensor([1.7090, 2.5635, 0.8545, 3.3250], grad_fn=<AddBackward0>) tensor([1.7950], grad_fn=<ClampBackward1>)


In [14]:
# Norm of the Riemannian gradient (prints 1.0000).
# NOTE(review): unlike the other models, norm_t is called here without the base
# point and c; confirm this matches Lorentz.norm_t's signature.
print(lore.norm_t(grad))

tensor([1.0000], grad_fn=<SqrtBackward0>)


### Convert Poincaré to half-space

In [15]:
# Map both points into the upper half-space model; the distance matches the
# Lorentz-model value (1.5769).
xh = poin.to_halfspace(xp, c)
yh = poin.to_halfspace(yp, c)
print(hal.distance(xh, yh, c))

tensor([1.5769])


Convert the half-space points back to the Poincaré ball to check the round trip.

In [16]:
# Round trip half-space -> Poincare; should reproduce xp and yp.
hal.to_poincare(xh, c), hal.to_poincare(yh, c)

(tensor([0.2000, 0.3000, 0.1000]), tensor([0.3900, 0.4100, 0.5000]))

In [17]:
# halfspace exp map
# Reference point for the half-space experiments: zeros except the last
# coordinate 1/sqrt(c). Its distance to ah (printed below as 0.7950) matches
# the Poincare-ball distance from xp to the origin computed earlier.
oh = torch.tensor([0.0, 0.0, 1 / math.sqrt(c)])
# Gradient-tracking copy of the half-space image of xp.
ah = xh.clone().requires_grad_()
print(ah)

tensor([0.4215, 0.6323, 0.8004], requires_grad=True)


In [18]:
# Euclidean gradient of the distance from ah to oh.
# backward() accumulates into ah.grad, so clear it first to keep the cell
# re-runnable.
ah.grad = None
dist_ah = hal.distance(ah, oh, c)
print(dist_ah)
dist_ah.backward()
print(ah.grad)

tensor([0.7950], grad_fn=<DivBackward0>)
tensor([ 0.5345,  0.8018, -0.6101])


In [19]:
# Inspect ah after backward(): its values are unchanged.
ah

tensor([0.4215, 0.6323, 0.8004], requires_grad=True)

In [20]:
# Riemannian gradient step in the half-space model; the distance to oh goes
# from 0.7950 to 1.7950, as in the other models.
grad = hal.egrad2rgrad(ah, ah.grad, c)
print('grad', grad)
bh = hal.expmap(ah, grad, c)
print(bh, hal.distance(bh, oh, c))

grad tensor([ 0.4109,  0.6164, -0.4690], grad_fn=<MulBackward0>)
tensor([0.6315, 0.9472, 0.3373], grad_fn=<CatBackward0>) tensor([1.7950], grad_fn=<DivBackward0>)


In [21]:
# Tangent-space norm of the Riemannian gradient at ah (prints 1.).
print(hal.norm_t(grad, ah, c))

tensor([1.], grad_fn=<SqrtBackward0>)


### Test exp/log maps in the half-space model

In [22]:
# Small tangent vector at xh, used to check that logmap inverts expmap.
v = torch.tensor([0.01, 0.02, 0.001])

In [23]:
# Step from xh along v.
xxh = hal.expmap(xh, v, c)
print(xxh)

tensor([0.4315, 0.6523, 0.8011])


In [24]:
# Geodesic distance travelled; should equal the tangent norm of v (next cell).
hal.distance(xh, xxh, c)

tensor([0.0255])

In [25]:
# Riemannian norm of v at xh.
hal.norm_t(v, xh, c)

tensor([0.0255])

In [26]:
# logmap should recover v = [0.01, 0.02, 0.001].
hal.logmap(xh, xxh, c)

tensor([0.0100, 0.0200, 0.0010])

### Test sphere exp/log maps

In [4]:
# NOTE(review): this cell overwrites c (was 1.2), xp and yp from the hyperbolic
# sections; re-run those setup cells before returning to them.
c=2.0
xp = torch.tensor([0.2, 0.3, 0.1])
yp = torch.tensor([0.39, 0.41, 0.5])
# Project both points onto the sphere (the projected xp below has squared
# norm ~0.5 = 1/c).
xp = sph.proj(xp, c)
yp = sph.proj(yp, c)
print(xp, yp)
print(sph.distance(xp, yp, c))

tensor([0.3780, 0.5669, 0.1890]) tensor([0.3652, 0.3839, 0.4682])
tensor(0.3373)


In [5]:
# sphere exp map
xp.requires_grad_()
print(xp)

tensor([0.3780, 0.5669, 0.1890], requires_grad=True)


In [6]:
# Euclidean gradient of the sphere distance from xp to yp.
# backward() accumulates into xp.grad, so clear it first to keep the cell
# re-runnable.
xp.grad = None
dist_ap = sph.distance(xp, yp, c)
print(dist_ap)
dist_ap.backward()
print(xp.grad)

tensor(0.3373, grad_fn=<MulBackward0>)
tensor([-1.1249, -1.1826, -1.4422])


In [7]:
# Riemannian gradient step on the sphere; the distance to yp goes from
# 0.3373 to 1.3373.
grad = sph.egrad2rgrad(xp, xp.grad, c)
print(grad)
bp = sph.expmap(xp, grad, c)
print(bp, sph.distance(bp, yp, c))

tensor([-0.0907,  0.3688, -0.9251])
tensor([-0.0044,  0.3460, -0.6167], grad_fn=<AddBackward0>) tensor(1.3373, grad_fn=<MulBackward0>)


In [15]:
# Presumably verifies that the exp-map result bp still lies on the sphere.
sph.check(bp, c)

tensor(True)

In [8]:
# Tangent-space norm of the Riemannian gradient at xp (prints 1.0000).
print(sph.norm_t(grad, xp, c))

tensor([1.0000])


In [9]:
# Small ambient vector turned into a tangent vector at yp via egrad2rgrad
# (used here as a tangent projection -- confirm that is its role on Sphere).
v = torch.tensor([0.01, 0.02, 0.003])
v = sph.egrad2rgrad(yp, v, c)
print(v)

tensor([ 0.0007,  0.0102, -0.0089])


In [10]:
# Step from yp along v.
yyp = sph.expmap(yp, v, c)
print(yp, yyp)

tensor([0.3652, 0.3839, 0.4682]) tensor([0.3658, 0.3941, 0.4592])


In [11]:
# Distance travelled should equal the tangent norm of v.
sph.distance(yp, yyp, c), sph.norm_t(v, yp, c)

(tensor(0.0136), tensor([0.0136]))

In [12]:
# logmap should recover the tangent vector v printed above.
sph.logmap(yp, yyp, c)

tensor([ 0.0007,  0.0102, -0.0089])