In [1]:
import networkx
import geoopt
import torch
import torch.nn as nn
import numpy as np
import random
import logging

In [3]:
geoopt.manifolds.PoincareBall?

[0;31mInit signature:[0m [0mgeoopt[0m[0;34m.[0m[0mmanifolds[0m[0;34m.[0m[0mPoincareBall[0m[0;34m([0m[0mc[0m[0;34m=[0m[0;36m1.0[0m[0;34m,[0m [0mlearnable[0m[0;34m=[0m[0;32mFalse[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
Poincare ball model.

See more in :doc:`/extended/stereographic`

Parameters
----------
c : float|tensor
    ball's negative curvature. The parametrization is constrained to have positive c

Notes
-----
It is extremely recommended to work with this manifold in double precision


See Also
--------
:class:`Stereographic`
:class:`StereographicExact`
:class:`PoincareBallExact`
:class:`SphereProjection`
:class:`SphereProjectionExact`
[0;31mInit docstring:[0m Initializes internal Module state, shared by both nn.Module and ScriptModule.
[0;31mFile:[0m           ~/.venvs/base/lib/python3.10/site-packages/geoopt/manifolds/stereographic/manifold.py
[0;31mType:[0m           ABCMeta
[0;31mSubclasses:[0m     PoincareBallEx

In [44]:
class ManifoldEmbedding(nn.Module):
    
    def __init__(self, manifold, num_embeddings, embedding_dim, dtype=torch.double, requires_grad=True, weights=None):
        super().__init__()
        if dtype != torch.double:
            logging.warning("Double precision is recommended for embeddings on manifold")
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self._manifold = manifold
        if weights is None:
            data = torch.zeros((num_embeddings, embedding_dim), dtype=dtype)
            self.w = geoopt.ManifoldParameter(data, requires_grad=requires_grad, manifold=self._manifold)
            self.reset_parameters()
        else:
            raise NotImplementedError()
            
    def forward(self, x):
        s0 = x.shape
        ws = self.w[x.view(-1)]
        return ws.view(*s0, self.embedding_dim)
    
    def reset_parameters(self) -> None:
        nn.init.normal_(self.w.data, std=0.25)
        self.w.data[:] = self._manifold.retr(torch.zeros(self.embedding_dim), self.w.data)
        
        
class PoincareEmbedding(ManifoldEmbedding):
    
    def __init__(self, num_embeddings, embedding_dim, c=1.0, **kwargs):
        manifold = geoopt.manifolds.PoincareBall(c, learnable=False)
        super().__init__(manifold, num_embeddings, embedding_dim, **kwargs)
        
        
class ManifoldSquaredDistance(nn.Module):
    
    def __init__(self, manifold):
        super().__init__()
        self.manifold = manifold
        
    def forward(self, a, b):
        return self.manifold.dist2(a,b)
    
    
class SGNSLoss(nn.Module):
    
    def __init__(self, reduction="mean"):
        super().__init__()
        self.reduction = reduction
        
    def forward(self, d2, y):
        y.masked_fill_(y == 0, -1)
        loss = torch.log(torch.sigmoid(d2 * (-y)))
        if self.reduction is None:
            return -loss
        elif self.reduction == "mean":
            return -loss.mean()
        elif self.rediction == "sum":
            return -loss.sum()
        raise NotImplementedError()

In [66]:
import lightning as pl


class Model(pl.LightningModule):
    
    def __init__(self, num_embeddings, embedding_dim, k=1.0):
        super().__init__()
        self.embd = PoincareEmbedding(num_embeddings, embedding_dim, k)
        self.d2 = ManifoldSquaredDistance(embd._manifold)
        self.loss_fn = SGNSLoss()
        
    def training_step(self, batch, batch_idx):
        x, y = batch
        print(x.shape, y.mean())
        e1, e2 = x[:, 0], x[:, 1]
        v1, v2 = self.embd(e1), self.embd(e2)
        d2 = self.d2(v1, v2)
        loss = self.loss_fn(d2, y)
        print(loss.item())
        self.log("training_loss", loss.item(), prog_bar=True)
        return loss
        
    def configure_optimizers(self):
        optimizer = geoopt.optim.RiemannianAdam(self.parameters(), 1e-3)
        return optimizer


In [54]:
import random
from itertools import accumulate

def skip_gram(x, i, w):
    return x[i], x[max(0, i-w):i] + x[i+1:i+w+1]

class SkipGramWithNegativeSampling:
    
    def __init__(self, window, vocabulary, negative=5, negative_probs=None):
        self.window = window
        self.vocabulary = vocabulary
        self.negative = negative
        if negative_probs is not None:
            self.negative_probs = list(accumulate(negative_probs))
        else:
            self.negative_probs = None
        
    def sample_negatives(self, query):
        if self.negative == 0:
            return []
        items = set(query)
        randoms = random.choices(self.vocabulary, k=len(items)*self.negative, cum_weights=self.negative_probs)
        return zip(list(items)*self.negative, randoms)
        
    def __call__(self, x):
        grams = [skip_gram(x, i, self.window) for i in range(len(x))]
        batches = [[w,c] for w,context in grams for c in context]
        negatives = list(self.sample_negatives(x))
        labels = [1] * len(batches) + [0] * len(negatives)
        return batches + negatives, labels
    
    
class ToTensor:
    
    def __init__(self, *dtypes):
        self.dtypes = dtypes
        
    def __call__(self, x):
        assert isinstance(x, tuple)
        assert len(x) == len(self.dtypes), f"Number of inputs {len(x)} does not match number of specified data types {len(self.dtypes)}"
        return tuple(torch.tensor(xi, dtype=di) for xi, di in zip(x, self.dtypes))

In [55]:
import sys
sys.path.insert(0, "../src/")
from models.transformer.loader import PlaylistDataset
from models.transformer.transform import *

In [56]:
import os

# utils to create this file list

def get_file_list(base):
    return [os.path.join(base, f) for f in os.listdir(base) if ".json" in f]

files = get_file_list("../data/processed/")
len(files)

20

In [57]:
import json

MIN_FREQ = 15

with open("../data/frequencies.json") as f:
    frequencies = json.load(f)

frequencies = dict(filter(lambda item: item[1] >= MIN_FREQ, frequencies.items()))
songs = list(set(frequencies.keys()))
song2idx = {s: i for i,s in enumerate(songs)}
idx2song = {i: s for s,i in song2idx.items()}
len(frequencies)

281217

In [62]:
def collate_fn(data):
    pairs, labels = list(zip(*data))
    return torch.cat(pairs, dim=0), torch.cat(labels, dim=0)

# probs
alpha = 0.75
adjusted_song_weights = np.array([frequencies[s]**alpha for s in song2idx.keys()])


tf = Compose(
    RemoveUnknownTracks(songs),
    TrackURI2Idx(song2idx),
    SkipGramWithNegativeSampling(5, list(song2idx.values()), 10, negative_probs=adjusted_song_weights),
    ToTensor(torch.long, torch.float)
)

dataset = PlaylistDataset(files, 50_000, transform=tf)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=collate_fn, num_workers=5)

In [67]:
NUM_EMBEDDING = len(songs)
EMBEDDING_DIM = 32
THETA = 3

model = Model(NUM_EMBEDDING, EMBEDDING_DIM, THETA)

In [68]:
trainer = pl.Trainer()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model, loader)


  | Name    | Type                    | Params
----------------------------------------------------
0 | embd    | PoincareEmbedding       | 9.0 M 
1 | d2      | ManifoldSquaredDistance | 1     
2 | loss_fn | SGNSLoss                | 0     
----------------------------------------------------
9.0 M     Trainable params
2         Non-trainable params
9.0 M     Total params
35.996    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

torch.Size([5620, 2]) tensor(0.4982)
2.1158828101975757
torch.Size([430, 2]) tensor(0.4419)
1.9365777741379142
torch.Size([1000, 2]) tensor(0.4700)
2.041391946547476
torch.Size([4760, 2]) tensor(0.4958)
2.1232267549794526
torch.Size([4170, 2]) tensor(0.4940)
2.097162534012782
torch.Size([2250, 2]) tensor(0.4889)
2.0847246248830715
torch.Size([320, 2]) tensor(0.4062)
1.7389213962627443
torch.Size([2960, 2]) tensor(0.4899)
2.0763438357831623
torch.Size([3620, 2]) tensor(0.4972)
2.1191803129590845
torch.Size([4080, 2]) tensor(0.4926)
2.1022303911109863
torch.Size([1590, 2]) tensor(0.4906)
2.084970279897508
torch.Size([3740, 2]) tensor(0.5000)
2.1177839828562774
torch.Size([1390, 2]) tensor(0.4820)
2.0462985446344737
torch.Size([1940, 2]) tensor(0.4845)
2.0790010134866144
torch.Size([2420, 2]) tensor(0.4876)
2.069303238998115
torch.Size([2420, 2]) tensor(0.4876)
2.0719898407321455
torch.Size([1810, 2]) tensor(0.4972)
2.1293263006469645
torch.Size([3620, 2]) tensor(0.4917)
2.086537512831321