In [1]:
import networkx
import geoopt
import torch
import torch.nn as nn
import numpy as np
import random
import logging

In [2]:
class ManifoldEmbedding(nn.Module):
    """Embedding lookup table whose weight matrix is a ``geoopt.ManifoldParameter``.

    Args:
        manifold: manifold object handed to ``geoopt.ManifoldParameter`` and used
            for the retraction in :meth:`reset_parameters`.
        num_embeddings: number of rows in the table.
        embedding_dim: size of each embedding vector.
        dtype: dtype of the weight matrix; a warning is logged for anything
            other than ``torch.double``.
        requires_grad: forwarded to ``geoopt.ManifoldParameter``.
        weights: pre-computed weights. Not supported yet.

    Raises:
        NotImplementedError: if ``weights`` is not ``None``.
    """

    def __init__(self, manifold, num_embeddings, embedding_dim, dtype=torch.double, requires_grad=True, weights=None):
        super().__init__()
        if dtype != torch.double:
            logging.warning("Double precision is recommended for embeddings on manifold")
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self._manifold = manifold
        if weights is not None:
            raise NotImplementedError()
        data = torch.zeros((num_embeddings, embedding_dim), dtype=dtype)
        self.w = geoopt.ManifoldParameter(data, requires_grad=requires_grad, manifold=self._manifold)
        self.reset_parameters()

    def forward(self, x):
        """Look up the embeddings for an integer index tensor.

        Args:
            x: index tensor of any shape.

        Returns:
            Tensor of shape ``(*x.shape, embedding_dim)``.
        """
        index_shape = x.shape
        # reshape (unlike view) also accepts non-contiguous index tensors,
        # e.g. transposed or column-sliced batches.
        rows = self.w[x.reshape(-1)]
        return rows.view(*index_shape, self.embedding_dim)

    def reset_parameters(self) -> None:
        """Re-initialise the table: draw N(0, 0.25^2) noise, then retract it."""
        nn.init.normal_(self.w.data, std=0.25)
        # Build the base point with the parameter's own dtype/device so the
        # retraction never mixes precisions or devices.
        # NOTE(review): an all-zero base point may not itself lie on the
        # manifold — confirm this is the intended starting point.
        base = torch.zeros(self.embedding_dim, dtype=self.w.dtype, device=self.w.device)
        self.w.data[:] = self._manifold.retr(base, self.w.data)
        
        
class LorentzEmbedding(ManifoldEmbedding):
    """``ManifoldEmbedding`` bound to a ``geoopt.manifolds.Lorentz`` manifold.

    Args:
        num_embeddings: number of rows in the table.
        embedding_dim: size of each embedding vector.
        k: first argument of ``geoopt.manifolds.Lorentz``; built with
            ``learnable=False``.
        **kwargs: forwarded unchanged to ``ManifoldEmbedding.__init__``.
    """

    def __init__(self, num_embeddings, embedding_dim, k=1.0, **kwargs):
        super().__init__(
            geoopt.manifolds.Lorentz(k, learnable=False),
            num_embeddings,
            embedding_dim,
            **kwargs,
        )
        
        
class LorentzSkipGram(nn.Module):
    """Scores pairs of embeddings with the inner product of a Lorentz manifold.

    Args:
        k: first argument of ``geoopt.manifolds.Lorentz``; built with
            ``learnable=False``.
    """

    def __init__(self, k=1.0):
        super().__init__()
        self._manifold = geoopt.manifolds.Lorentz(k, learnable=False)

    def forward(self, a, b):
        """Return ``manifold.inner(x0, a, b)`` for an all-zero base point ``x0``.

        Args:
            a: first batch of embeddings; its last dimension sets the size of ``x0``.
            b: second batch of embeddings.
        """
        # Allocate the base point directly with the inputs' dtype and device
        # instead of creating a float32 CPU tensor and copying it over.
        x0 = torch.zeros(a.shape[-1], dtype=a.dtype, device=a.device)
        return self._manifold.inner(x0, a, b)
    
    
class SGNSLoss(nn.Module):
    """Skip-gram negative-sampling loss, ``-log(sigmoid(y * y_))``.

    Labels equal to 0 are treated as -1, so positive pairs are pushed towards
    high scores and negative pairs towards low scores.

    Args:
        reduction: ``"mean"`` (default), ``"sum"``, or ``None`` / ``"none"``
            to get the per-element loss.
    """

    def __init__(self, reduction="mean"):
        super().__init__()
        self.reduction = reduction

    def forward(self, y_, y):
        """Compute the loss.

        Args:
            y_: predicted scores.
            y: labels; zeros mark negative pairs. The tensor is not modified.

        Returns:
            The reduced loss, or the element-wise loss when no reduction is set.

        Raises:
            NotImplementedError: for an unknown ``reduction`` value.
        """
        # Out-of-place fill: the caller's label tensor must stay untouched.
        signs = y.masked_fill(y == 0, -1)
        # logsigmoid stays finite where log(sigmoid(.)) would underflow to -inf.
        loss = -nn.functional.logsigmoid(signs * y_)
        if self.reduction is None or self.reduction == "none":
            return loss
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        raise NotImplementedError()

In [3]:
import lightning as pl


class Model(pl.LightningModule):
    """Lightning module that trains Lorentz embeddings with an SGNS objective.

    Args:
        num_embeddings: number of rows in the embedding table.
        embedding_dim: size of each embedding vector.
        theta: constant added to the pairwise score before the loss.
        k: passed to both ``LorentzEmbedding`` and ``LorentzSkipGram``.
    """

    def __init__(self, num_embeddings, embedding_dim, theta, k=1.0):
        super().__init__()
        self.embd = LorentzEmbedding(num_embeddings, embedding_dim, k)
        self.sg = LorentzSkipGram(k)
        self.loss_fn = SGNSLoss()
        self.theta = theta

    def training_step(self, batch, batch_idx):
        """Score every index pair in the batch and return the SGNS loss."""
        pairs, labels = batch
        # Column 0 and column 1 hold the two members of each pair.
        left = self.embd(pairs[:, 0])
        right = self.embd(pairs[:, 1])
        scores = self.sg(left, right) + self.theta
        loss = self.loss_fn(scores, labels)
        self.log("training_loss", loss.item(), prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Optimise all parameters with ``RiemannianAdam`` at learning rate 1e-3."""
        return geoopt.optim.RiemannianAdam(self.parameters(), 1e-3)


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import random

def skip_gram(x, i, w):
    """Return ``(x[i], context)`` where context is up to ``w`` items on each side of ``i``.

    The centre item itself is excluded from the context; the window is
    clipped at both ends of ``x``.
    """
    before = x[max(0, i - w):i]
    after = x[i + 1:i + w + 1]
    return x[i], before + after

class SkipGramWithNegativeSampling:
    """Turn a sequence into skip-gram (item, context) pairs plus random negatives.

    Args:
        window: number of neighbours taken on each side of an item.
        vocabulary: sequence that negative samples are drawn from.
        negative: number of negative pairs generated per distinct item.
        negative_probs: optional relative weights, one per vocabulary entry,
            used when drawing negatives; ``None`` means uniform sampling.
    """

    def __init__(self, window, vocabulary, negative=5, negative_probs=None):
        self.window = window
        self.vocabulary = vocabulary
        self.negative = negative
        self.negative_probs = negative_probs

    def sample_negatives(self, query):
        """Return a list of ``(item, random_vocabulary_entry)`` pairs.

        Every distinct item of ``query`` appears ``self.negative`` times.
        Negatives are drawn with replacement and are not filtered against the
        true contexts.
        """
        if self.negative == 0:
            return []
        items = list(set(query))
        # weights=None makes random.choices sample uniformly, so the
        # unweighted behaviour is unchanged.
        randoms = random.choices(
            self.vocabulary,
            weights=self.negative_probs,
            k=len(items) * self.negative,
        )
        return list(zip(items * self.negative, randoms))

    def __call__(self, x):
        """Return ``(pairs, labels)``: positives labelled 1 first, then negatives labelled 0."""
        grams = [skip_gram(x, i, self.window) for i in range(len(x))]
        batches = [[w, c] for w, context in grams for c in context]
        negatives = list(self.sample_negatives(x))
        labels = [1] * len(batches) + [0] * len(negatives)
        return batches + negatives, labels
    
    
class ToTensor:
    """Convert a tuple of array-likes into a tuple of tensors, one dtype per element.

    Args:
        *dtypes: torch dtypes, matched positionally against the input tuple.
    """

    def __init__(self, *dtypes):
        self.dtypes = dtypes

    def __call__(self, x):
        """Return ``tuple(torch.tensor(x[i], dtype=dtypes[i]))``; ``x`` must be a tuple of matching length."""
        assert isinstance(x, tuple)
        assert len(x) == len(self.dtypes), f"Number of inputs {len(x)} does not match number of specified data types {len(self.dtypes)}"
        converted = []
        for value, dtype in zip(x, self.dtypes):
            converted.append(torch.tensor(value, dtype=dtype))
        return tuple(converted)

In [5]:
import sys
sys.path.insert(0, "../src/")
from models.transformer.loader import PlaylistDataset
from models.transformer.transform import *

In [6]:
import os

# Helper that collects the paths of the JSON data files in a directory

def get_file_list(base):
    """Return the sorted paths of all entries in ``base`` whose name contains ".json".

    ``os.listdir`` yields entries in arbitrary order, so the result is sorted
    to give the same file order on every run and machine.
    """
    # NOTE(review): this is a substring match, so names such as "x.json.bak"
    # are also accepted — confirm that is intended.
    return sorted(os.path.join(base, f) for f in os.listdir(base) if ".json" in f)

# Collect the input files from the processed-data directory and show how many were found.
files = get_file_list("../data/processed/")
len(files)

20

In [7]:
"""Compute and save song frequencies
from collections import Counter
from tqdm import tqdm
import json

songs = Counter()
for f in tqdm(files):
    with open(f) as f:
        data = json.load(f)["playlists"]
        for pl in data:
            songs.update(pl)
            
with open("../data/frequencies.json", "w") as f:
    f.write(json.dumps(dict(songs)))
len(songs)
"""

'Compute and save song frequencies\nfrom collections import Counter\nfrom tqdm import tqdm\nimport json\n\nsongs = Counter()\nfor f in tqdm(files):\n    with open(f) as f:\n        data = json.load(f)["playlists"]\n        for pl in data:\n            songs.update(pl)\n            \nwith open("../data/frequencies.json", "w") as f:\n    f.write(json.dumps(dict(songs)))\nlen(songs)\n'

In [8]:
import json

# Songs seen fewer than MIN_FREQ times are dropped from the vocabulary.
MIN_FREQ = 5

with open("../data/frequencies.json") as f:
    frequencies = json.load(f)

frequencies = {song: count for song, count in frequencies.items() if count >= MIN_FREQ}
# Sort the keys so the song -> index mapping is identical on every kernel start;
# iterating a set of strings depends on per-process hash randomisation, which
# would make saved embeddings impossible to map back to songs after a restart.
songs = sorted(frequencies)
song2idx = {s: i for i, s in enumerate(songs)}
idx2song = dict(enumerate(songs))
len(frequencies)

599341

In [9]:
def collate_fn(data):
    """Merge per-sample ``(pairs, labels)`` tensors into one batch.

    Each sample contributes a variable number of rows, so the tensors are
    concatenated along dimension 0 rather than stacked.
    """
    pair_chunks, label_chunks = zip(*data)
    merged_pairs = torch.cat(pair_chunks, dim=0)
    merged_labels = torch.cat(label_chunks, dim=0)
    return merged_pairs, merged_labels

# Per-playlist transform pipeline.
# NOTE(review): Compose, RemoveUnknownTracks and TrackURI2Idx come from the
# project's `models.transformer.transform` star import; presumably the stages
# run in the order listed — confirm against that module.
tf = Compose(
    RemoveUnknownTracks(songs),
    TrackURI2Idx(song2idx),
    # window=5, vocabulary=all song indices, negative=10 samples per distinct item
    SkipGramWithNegativeSampling(5, list(song2idx.values()), 10),
    # index pairs become long tensors, labels become float tensors
    ToTensor(torch.long, torch.float)
)

# NOTE(review): the meaning of the second positional argument (50_000) is
# defined by PlaylistDataset, which is not visible here — TODO document it.
dataset = PlaylistDataset(files, 50_000, transform=tf)
# batch_size counts dataset samples; collate_fn concatenates the pairs of all
# samples in a batch, so the number of training pairs per batch varies.
loader = torch.utils.data.DataLoader(dataset, batch_size=1200, collate_fn=collate_fn, num_workers=4)

In [10]:
# Model hyper-parameters.
NUM_EMBEDDING = len(songs)  # one embedding row per retained song
EMBEDDING_DIM = 32  # size of each embedding vector
THETA = 3  # constant added to the pairwise score in Model.training_step

model = Model(NUM_EMBEDDING, EMBEDDING_DIM, THETA)

In [11]:
# Lightning trainer created with no arguments, i.e. all library defaults.
trainer = pl.Trainer()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Train `model` on the batches produced by `loader` (no validation loader is passed).
trainer.fit(model, loader)

  rank_zero_warn(

  | Name    | Type             | Params
---------------------------------------------
0 | embd    | LorentzEmbedding | 19.2 M
1 | sg      | LorentzSkipGram  | 1     
2 | loss_fn | SGNSLoss         | 0     
---------------------------------------------
19.2 M    Trainable params
2         Non-trainable params
19.2 M    Total params
76.716    Total estimated model params size (MB)
2023-05-09 10:56:50.419382: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-05-09 10:56:50.419422: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


Epoch 0:   2%|  | 18/834 [00:48<36:45,  2.70s/it, v_num=36, training_loss=1.010]