In [12]:
import networkx
import geoopt
import torch
import torch.nn as nn
import numpy as np
import random
import logging

In [80]:
class ManifoldEmbedding(nn.Module):
    
    def __init__(self, manifold, num_embeddings, embedding_dim, dtype=torch.double, requires_grad=True, weights=None):
        super().__init__()
        if dtype != torch.double:
            logging.warning("Double precision is recommended for embeddings on manifold")
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self._manifold = manifold
        if weights is None:
            data = torch.zeros((num_embeddings, embedding_dim), dtype=dtype)
            data = geoopt.ManifoldTensor(data, manifold=self._manifold)
            self.w = geoopt.ManifoldParameter(data, requires_grad=requires_grad)
            self.reset_parameters()
        else:
            raise NotImplementedError()
            
    def forward(self, x):
        s0 = x.shape
        ws = self.w[x.view(-1)]
        return ws.view(*s0, self.embedding_dim)
    
    def reset_parameters(self) -> None:
        nn.init.normal_(self.w.data)
        self.w.data[:] = self._manifold.retr(torch.zeros(self.embedding_dim), self.w.data)
        
        
class LorentzEmbedding(ManifoldEmbedding):
    
    def __init__(self, num_embeddings, embedding_dim, k=1.0, **kwargs):
        manifold = geoopt.manifolds.Lorentz(k, learnable=False)
        super().__init__(manifold, num_embeddings, embedding_dim, **kwargs)
        
        
class LorentzSkipGram(nn.Module):
    
    def __init__(self, theta, k=1.0):
        super().__init__()
        self.theta = theta
        self._manifold = geoopt.manifolds.Lorentz(k)
        self.x0 = torch.zeros(10)
        
    def forward(self, a, b):
        return self._manifold.inner(self.x0, a, b)
    

class SGNSLoss(nn.Module):
    
    def __init__(self, reduction="mean"):
        super().__init__()
        self.reduction = reduction
        
    def forward(self, y_, y):
        y.masked_fill_(y == 0, -1)
        loss = torch.log(torch.sigmoid(y*y_))
        if self.reduction is None:
            return loss
        elif self.reduction == "mean":
            return loss.mean()
        elif self.rediction == "sum":
            return loss.sum()
        raise NotImplementedError()