In [2]:
import networkx
import geoopt
import torch
import torch.nn as nn
import numpy as np
import random
import logging

In [3]:
class ManifoldEmbedding(nn.Module):
    
    def __init__(self, manifold, num_embeddings, embedding_dim, dtype=torch.double, requires_grad=True, weights=None):
        super().__init__()
        if dtype != torch.double:
            logging.warning("Double precision is recommended for embeddings on manifold")
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self._manifold = manifold
        if weights is None:
            data = torch.zeros((num_embeddings, embedding_dim), dtype=dtype)
            data = geoopt.ManifoldTensor(data, manifold=self._manifold)
            self.w = geoopt.ManifoldParameter(data, requires_grad=requires_grad)
            self.reset_parameters()
        else:
            raise NotImplementedError()
            
    def forward(self, x):
        s0 = x.shape
        ws = self.w[x.view(-1)]
        return ws.view(*s0, self.embedding_dim)
    
    def reset_parameters(self) -> None:
        nn.init.normal_(self.w.data)
        self.w.data[:] = self._manifold.retr(torch.zeros(self.embedding_dim), self.w.data)
        
        
class LorentzEmbedding(ManifoldEmbedding):
    
    def __init__(self, num_embeddings, embedding_dim, k=1.0, **kwargs):
        manifold = geoopt.manifolds.Lorentz(k, learnable=False)
        super().__init__(manifold, num_embeddings, embedding_dim, **kwargs)
        
        
class LorentzSkipGram(nn.Module):
    
    def __init__(self, theta, k=1.0):
        super().__init__()
        self.theta = theta
        self._manifold = geoopt.manifolds.Lorentz(k)
        self.x0 = torch.zeros(10)
        
    def forward(self, a, b):
        return self._manifold.inner(self.x0, a, b)
    

class SGNSLoss(nn.Module):
    
    def __init__(self, reduction="mean"):
        super().__init__()
        self.reduction = reduction
        
    def forward(self, y_, y):
        y.masked_fill_(y == 0, -1)
        loss = torch.log(torch.sigmoid(y*y_))
        if self.reduction is None:
            return loss
        elif self.reduction == "mean":
            return loss.mean()
        elif self.rediction == "sum":
            return loss.sum()
        raise NotImplementedError()

In [6]:
import lightning as pl


class Model(pl.LightningModule):
    
    def __init__(self, num_embeddings, embedding_dim, theta, k=1.0):
        super().__init__()
        self.embd = LorentzEmbedding(num_embeddings, embedding_dim, k)
        self.sg = LorentzSkipGram(theta, k)
        self.loss_fn = SGNSLoss()
        
    def training_step(self, batch, batch_idx):
        x, y = batch
        e1, e2 = x[:, 0], x[:, 1]
        z = self.sg(e1, e2)
        return self.loss_fn(z, y)
        
    def configure_optimizers(self):
        optimizer = geoopt.optim.RiemannianAdam(self.parameters(), 1e-3)
        return optimizer


In [None]:
class SGNSDataset(torch.utils.data.Dataset):
    
    def __init__(self, files, playlist_per_file, window=5, transform=None):
        super().__init__()
        self.files = files
        self.current_file_index = -1
        self.data = None
        self.transform = transform
        self.ppf = playlist_per_file
        self.window = window
        
    def __len__(self):
        return self.ppf * len(self.files)
        
    def _load(self, path):
        with open(path, "r") as f:
            self.data = json.load(f)
    
    def __getitem__(self, index):
        file_index = index // self.ppf
        offset = index % self.ppf
        if self.current_file_index != file_index:
            logging.debug(f"Loading file {self.files[file_index]}")
            self._load(self.files[file_index])
            self.current_file_index = file_index
        tracks = self.data["playlists"][offset]
        
        if self.transform is not None:
            tracks = self.transform(tracks)
        
        return tracks

In [55]:
import random

def skip_gram(x, i, w):
    return x[i], x[max(0, i-w):i] + x[i+1:i+w+1]

class SkipGramWithNegativeSampling:
    
    def __init__(self, window, vocabulary, negative=5, negative_probs=None):
        self.window = window
        self.vocabulary = vocabulary
        self.negative = negative
        self.negative_probs = negative_probs
        
    def sample_negatives(self, query):
        assert self.negative_probs is None, "Weighted sampling not implemented yet!"
        items = set(query)
        randoms = random.choices(self.vocabulary, k=len(items)*self.negative)
        return zip(list(items)*self.negative, randoms)
        
    def __call__(self, x):
        grams = [skip_gram(x, i, self.window) for i in range(len(x))]
        batches = [[w,c] for w,context in grams for c in context]
        negatives = list(self.sample_negatives(x))
        labels = [1] * len(batches) + [0] * len(negatives)
        return batches + negatives, labels

In [61]:
vocab = list("abcdefghijklmnopqrstuvxyz")
sg = SkipGramWithNegativeSampling(2, vocab, negative=4)
for a,b in zip(*sg("almapaprika")):
    print(a,b)

['a', 'l'] 1
['a', 'm'] 1
['l', 'a'] 1
['l', 'm'] 1
['l', 'a'] 1
['m', 'a'] 1
['m', 'l'] 1
['m', 'a'] 1
['m', 'p'] 1
['a', 'l'] 1
['a', 'm'] 1
['a', 'p'] 1
['a', 'a'] 1
['p', 'm'] 1
['p', 'a'] 1
['p', 'a'] 1
['p', 'p'] 1
['a', 'a'] 1
['a', 'p'] 1
['a', 'p'] 1
['a', 'r'] 1
['p', 'p'] 1
['p', 'a'] 1
['p', 'r'] 1
['p', 'i'] 1
['r', 'a'] 1
['r', 'p'] 1
['r', 'i'] 1
['r', 'k'] 1
['i', 'p'] 1
['i', 'r'] 1
['i', 'k'] 1
['i', 'a'] 1
['k', 'r'] 1
['k', 'i'] 1
['k', 'a'] 1
['a', 'i'] 1
['a', 'k'] 1
('a', 'd') 0
('r', 'u') 0
('m', 'd') 0
('p', 'k') 0
('i', 'e') 0
('k', 'y') 0
('l', 'e') 0
('a', 'r') 0
('r', 'f') 0
('m', 'o') 0
('p', 'p') 0
('i', 'z') 0
('k', 'x') 0
('l', 'f') 0
('a', 's') 0
('r', 'x') 0
('m', 'u') 0
('p', 'r') 0
('i', 'i') 0
('k', 'c') 0
('l', 'c') 0
('a', 'i') 0
('r', 'k') 0
('m', 'i') 0
('p', 'k') 0
('i', 'k') 0
('k', 'g') 0
('l', 'e') 0
