In [3]:
import os
import sys
sys.path.append("/home/romainlhardy/code/hyperbolic-cancer/Mixed-Curvature-Pathways")

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import math

from pytorch.hyperbolic_parameter import PoincareParameter, EuclideanParameter, SphericalParameter, HyperboloidParameter

In [None]:
class Embedding(nn.Module):
    """Embedding of ``n`` nodes in a single component space.

    The embedding matrix is built by ``param_cls`` and distances are measured
    with ``dist_fn``; every distance returned by this class is multiplied by
    ``exp(scale_log)``.

    Args:
        dist_fn: Callable taking two batches of points and returning one
            distance per row.
        param_cls: Callable invoked as ``param_cls(data=initialize, sizes=(n, d))``
            to create the embedding matrix.
        n: Number of nodes.
        d: Dimension of each node's embedding.
        project: Stored on the instance; not read by any method of this class.
        initialize: Forwarded to ``param_cls`` as ``data``.
        learn_scale: Whether ``scale_log`` requires gradients.
        initial_scale: Initial value of the log-scale (``0.0`` gives scale 1).
    """

    def __init__(self, dist_fn, param_cls, n, d, project=True, initialize=None, learn_scale=False, initial_scale=0.0):
        super().__init__()
        self.dist_fn = dist_fn
        self.n = n
        self.d = d
        self.project = project
        # Embedding matrix; created before scale_log so parameter order is (w, scale_log).
        self.w = param_cls(data=initialize, sizes=(n, d))
        log_scale = torch.tensor([initial_scale], dtype=torch.double)
        self.scale_log = nn.Parameter(log_scale, requires_grad=learn_scale)

    def scale(self):
        """Return the multiplicative scale ``exp(scale_log)``."""
        return torch.exp(self.scale_log)

    def dist_idx(self, idx):
        """Scaled distances for the node pairs listed in ``idx``.

        Column 0 of ``idx`` selects the first endpoint of each pair and
        column 1 the second.
        """
        first = torch.index_select(self.w, 0, idx[:, 0])
        second = torch.index_select(self.w, 0, idx[:, 1])
        return self.dist_fn(first, second) * self.scale()

    def dist_row(self, i):
        """Scaled distances from node ``i`` to every node."""
        num_rows = self.w.size(0)
        # Copy row i and tile it so dist_fn sees two equally sized batches.
        anchor = self.w[i, :].clone().unsqueeze(0)
        tiled = anchor.repeat(num_rows, 1)
        return self.dist_fn(tiled, self.w) * self.scale()

    def dist_matrix(self):
        """Full matrix of scaled pairwise distances, filled one row at a time."""
        num_rows = self.w.size(0)
        out = torch.zeros(num_rows, num_rows, dtype=torch.double)
        for row in range(num_rows):
            out[row, :] = self.dist_row(row)
        return out

    def normalize(self):
        """Call ``proj()`` on the embedding parameter."""
        self.w.proj()

In [None]:
def acosh(x):
    """Inverse hyperbolic cosine, evaluated as ``log(x + sqrt(x^2 - 1))``."""
    root = torch.sqrt(x ** 2 - 1)
    return torch.log(x + root)

def dot_h(x, y):
    """Inner product in hyperbolic space.

    This is the ordinary dot product over the last axis with the sign of the
    index-0 term flipped.
    """
    full_dot = (x * y).sum(-1)
    first_coord_term = 2 * x[..., 0] * y[..., 0]
    return full_dot - first_coord_term

def dist_h(x, y):
    """Distance in hyperbolic space: ``acosh`` of the negated hyperbolic inner product."""
    inner = dot_h(x, y)
    # Clamp strictly above 1 so the square root inside acosh never sees a
    # non-positive argument.
    clamped = torch.clamp(-inner, min=1.0 + 1e-8)
    return acosh(clamped)

def dist_e(u, v):
    """Distance in Euclidean space: the 2-norm of ``u - v`` taken along dimension 1."""
    difference = u - v
    return torch.norm(difference, p=2, dim=1)

def dist_s(u, v, eps=1e-9):
    """Distance in spherical space.

    Both inputs are passed through ``SphericalParameter._proj`` and the
    distance is the arc-cosine of their inner product over the last axis.

    Args:
        u: First batch of points.
        v: Second batch of points, same shape as ``u``.
        eps: Margin keeping the cosine inside ``[-1 + eps, 1 - eps]`` before
            ``acos`` is applied.

    Returns:
        Tensor of angles with the last axis of the inputs reduced away.
    """
    # NOTE(review): _proj presumably normalises rows onto the unit sphere --
    # confirm against SphericalParameter.
    uu = SphericalParameter._proj(u)
    vv = SphericalParameter._proj(v)
    # Inner product over the last axis, written inline so this function does
    # not depend on a separate `dot` helper being defined.
    cos_angle = torch.sum(uu * vv, -1)
    # acos has an unbounded derivative at +/-1, so keep the argument inside the open interval.
    return torch.acos(torch.clamp(cos_angle, -1 + eps, 1 - eps))

class ProductEmbedding(nn.Module):
    """Embedding into a product of hyperbolic, Euclidean and spherical components.

    Each component is an ``Embedding``; component distances are combined either
    by summation or, when ``riemann`` is set, by their 2-norm.

    Args:
        num_nodes: Number of embedded nodes (shared by every component).
        h_dim, h_copies: Dimension and count of hyperbolic components
            (``dist_h`` with ``HyperboloidParameter``).
        e_dim, e_copies: Dimension and count of Euclidean components
            (``dist_e`` with ``EuclideanParameter``); these are always built
            with ``project=False`` and ``learn_scale=False``.
        s_dim, s_copies: Dimension and count of spherical components
            (``dist_s`` with ``SphericalParameter``).
        project, initialize, learn_scale, initial_scale: Forwarded to the
            component ``Embedding`` constructors.
        absolute_loss, logrel_loss, dist_loss, square_loss: Loss selectors,
            checked in this order; the first one that is truthy wins.
        sym_loss: Adds the inverse-ratio term to the default loss.
        exponential_rescale: Stored on the instance; not read by this class.
        riemann: Combine component distances with a 2-norm instead of a sum.
    """

    def __init__(
        self,
        num_nodes,
        h_dim,
        h_copies=1,
        e_dim=1,
        e_copies=0,
        s_dim=1,
        s_copies=0,
        project=True,
        initialize=None,
        learn_scale=False,
        initial_scale=0.0,
        absolute_loss=False,
        logrel_loss=False,
        dist_loss=False,
        square_loss=False,
        sym_loss=False,
        exponential_rescale=None,
        riemann=False
    ):
        super().__init__()
        self.num_nodes = num_nodes
        self.riemann = riemann

        # Component spaces, registered in the order H, E, S.
        self.H = nn.ModuleList([
            Embedding(dist_h, HyperboloidParameter, num_nodes, h_dim, project, initialize, learn_scale, initial_scale)
            for _ in range(h_copies)
        ])
        self.E = nn.ModuleList([
            Embedding(dist_e, EuclideanParameter, num_nodes, e_dim, False, initialize, False, initial_scale)
            for _ in range(e_copies)
        ])
        self.S = nn.ModuleList([
            Embedding(dist_s, SphericalParameter, num_nodes, s_dim, project, initialize, learn_scale, initial_scale)
            for _ in range(s_copies)
        ])

        # Parameter groups. The scale list covers every component (Euclidean
        # ones included) and is empty unless the scales are learned.
        if learn_scale:
            self.scale_params = self.all_attr(lambda emb: emb.scale_log)
        else:
            self.scale_params = []
        self.hyp_params = [comp.w for comp in self.H]
        self.euc_params = [comp.w for comp in self.E]
        self.sph_params = [comp.w for comp in self.S]
        self.embed_params = self.hyp_params + self.euc_params + self.sph_params

        # Loss configuration.
        self.absolute_loss = absolute_loss
        self.logrel_loss = logrel_loss
        self.dist_loss = dist_loss
        self.square_loss = square_loss
        self.sym_loss = sym_loss

        self.exponential_rescale = exponential_rescale

    def all_attr(self, fn):
        """Apply ``fn`` to every component and return the results in H, E, S order."""
        return [fn(comp) for group in (self.H, self.E, self.S) for comp in group]

    def _combine(self, parts):
        """Merge per-component distances: 2-norm across components if ``riemann``, else their sum."""
        if not self.riemann:
            return sum(parts)
        return torch.norm(torch.stack(parts, 0), 2, dim=0)

    def embedding(self):
        """Concatenate every component's flattened embedding matrix into one vector."""
        flattened = self.all_attr(lambda emb: emb.w.view(-1))
        return torch.cat(flattened)

    def scale(self):
        """List of the components' scales."""
        return self.all_attr(lambda emb: emb.scale())

    def dist_idx(self, idx):
        """Combined distance for each node pair in ``idx``."""
        return self._combine(self.all_attr(lambda emb: emb.dist_idx(idx)))

    def dist_row(self, i):
        """Combined distances from node ``i`` to every node."""
        return self._combine(self.all_attr(lambda emb: emb.dist_row(i)))

    def dist_matrix(self):
        """Combined matrix of pairwise distances."""
        return self._combine(self.all_attr(lambda emb: emb.dist_matrix()))

    def loss(self, _x):
        """Distortion loss for a batch, divided by the number of target values.

        ``_x`` unpacks into ``(idx, values, w)``: the node pairs, their target
        distances and per-term weights. The weights are not applied in the
        ``logrel_loss`` variant.
        """
        idx, values, w = _x
        d = self.dist_idx(idx)
        count = values.size(0)

        if self.absolute_loss:
            return torch.sum(w * (d - values) ** 2) / count
        if self.logrel_loss:
            return torch.sum(torch.log((d / values) ** 2) ** 2) / count
        if self.dist_loss:
            return torch.sum(torch.abs(w * ((d / values) - 1))) / count
        if self.square_loss:
            return torch.sum(w * torch.abs((d / values) ** 2 - 1)) / count

        # Default: weighted squared relative error, optionally symmetrised
        # with the inverse ratio.
        forward_term = torch.sum(w * ((d / values) - 1) ** 2)
        if self.sym_loss:
            inverse_term = torch.sum(w * ((values / d) - 1) ** 2)
        else:
            inverse_term = 0
        return (forward_term + inverse_term) / count

    def normalize(self):
        """Normalize the hyperbolic and spherical components; Euclidean ones are left alone."""
        for group in (self.H, self.S):
            for comp in group:
                comp.normalize()