In [1]:
import sys

sys.path.insert(0, "../src")

from models.hyperbolic import ManifoldSkipGram

import networkx
import geoopt
import torch
import torch.nn as nn
import numpy as np
import random
import logging

In [2]:
import random
from itertools import accumulate

def skip_gram(x, i, w):
    return x[i], x[max(0, i-w):i] + x[i+1:i+w+1]

class SkipGramWithNegativeSampling:
    
    def __init__(self, window, vocabulary, negative=5, negative_probs=None):
        self.window = window
        self.vocabulary = vocabulary
        self.negative = negative
        if negative_probs is not None:
            self.negative_probs = list(accumulate(negative_probs))
        else:
            self.negative_probs = None
        
    def sample_negatives(self, query):
        if self.negative == 0:
            return []
        items = set(query)
        randoms = random.choices(self.vocabulary, k=len(items)*self.negative, cum_weights=self.negative_probs)
        return zip(list(items)*self.negative, randoms)
        
    def __call__(self, x):
        grams = [skip_gram(x, i, self.window) for i in range(len(x))]
        batches = [[w,c] for w,context in grams for c in context]
        negatives = list(self.sample_negatives(x))
        labels = [1] * len(batches) + [0] * len(negatives)
        return batches + negatives, labels
    
    
class ToTensor:
    
    def __init__(self, *dtypes):
        self.dtypes = dtypes
        
    def __call__(self, x):
        assert isinstance(x, tuple)
        assert len(x) == len(self.dtypes), f"Number of inputs {len(x)} does not match number of specified data types {len(self.dtypes)}"
        return tuple(torch.tensor(xi, dtype=di) for xi, di in zip(x, self.dtypes))

In [3]:
import sys
sys.path.insert(0, "../src/")
from models.transformer.loader import PlaylistDataset
from models.transformer.transform import *

In [4]:
import os

# utils to create this file list

def get_file_list(base):
    return [os.path.join(base, f) for f in os.listdir(base) if ".json" in f]

files = get_file_list("../data/processed/")
len(files)

20

In [5]:
"""Compute and save song frequencies
from collections import Counter
from tqdm import tqdm
import json

songs = Counter()
for f in tqdm(files):
    with open(f) as f:
        data = json.load(f)["playlists"]
        for pl in data:
            songs.update(pl)
            
with open("../data/frequencies.json", "w") as f:
    f.write(json.dumps(dict(songs)))
len(songs)
"""

'Compute and save song frequencies\nfrom collections import Counter\nfrom tqdm import tqdm\nimport json\n\nsongs = Counter()\nfor f in tqdm(files):\n    with open(f) as f:\n        data = json.load(f)["playlists"]\n        for pl in data:\n            songs.update(pl)\n            \nwith open("../data/frequencies.json", "w") as f:\n    f.write(json.dumps(dict(songs)))\nlen(songs)\n'

In [6]:
import json

MIN_FREQ = 10

with open("../data/frequencies.json") as f:
    frequencies = json.load(f)

frequencies = dict(filter(lambda item: item[1] >= MIN_FREQ, frequencies.items()))
songs = list(set(frequencies.keys()))
song2idx = {s: i for i,s in enumerate(songs)}
idx2song = {i: s for s,i in song2idx.items()}
len(frequencies)

172308

In [13]:
def collate_fn(data):
    pairs, labels = list(zip(*data))
    return torch.cat(pairs, dim=0), torch.cat(labels, dim=0)

# probs
alpha = 0.75
adjusted_song_weights = np.array([frequencies[s]**alpha for s in song2idx.keys()])


tf = Compose(
    RemoveUnknownTracks(songs),
    TrackURI2Idx(song2idx),
    SkipGramWithNegativeSampling(5, list(song2idx.values()), 10, negative_probs=adjusted_song_weights),
    ToTensor(torch.long, torch.float)
)

dataset = PlaylistDataset(files, 50_000, transform=tf)
loader = torch.utils.data.DataLoader(dataset, batch_size=16, collate_fn=collate_fn, num_workers=5)

In [14]:
from tqdm import tqdm

for _ in tqdm(loader):
    pass


  0%|                                                 | 0/62500 [00:00<?, ?it/s][A
  0%|                                      | 1/62500 [00:01<19:14:58,  1.11s/it][A
  0%|                                       | 5/62500 [00:01<3:20:40,  5.19it/s][A
  0%|                                        | 19/62500 [00:01<45:20, 22.97it/s][A
  0%|                                        | 34/62500 [00:01<24:25, 42.61it/s][A
  0%|                                        | 54/62500 [00:01<14:38, 71.09it/s][A
  0%|                                        | 71/62500 [00:01<11:30, 90.40it/s][A
  0%|                                       | 86/62500 [00:01<10:04, 103.18it/s][A
  0%|                                      | 106/62500 [00:01<08:41, 119.55it/s][A
  0%|                                      | 121/62500 [00:02<08:51, 117.28it/s][A
  0%|                                      | 136/62500 [00:02<08:42, 119.37it/s][A
  0%|                                      | 152/62500 [00:02<08:27, 122.74

In [8]:
NUM_EMBEDDING = len(songs)
EMBEDDING_DIM = 16

import geoopt
manifold = geoopt.manifolds.Lorentz()
model = ManifoldSkipGram(manifold, NUM_EMBEDDING, EMBEDDING_DIM, 
                         similarity="distance", opt_kwargs={"algo": "adam", "lr": 0.001})

In [9]:
import lightning as pl

trainer = pl.Trainer(max_epochs=50)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [10]:
trainer.fit(model, loader)


  | Name    | Type              | Params
----------------------------------------------
0 | encoder | ManifoldEmbedding | 2.8 M 
1 | sim     | ManifoldDistance  | 1     
2 | loss_fn | SGNSLoss          | 0     
----------------------------------------------
2.8 M     Trainable params
1         Non-trainable params
2.8 M     Total params
11.028    Total estimated model params size (MB)
2023-05-26 20:10:51.422934: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-05-26 20:10:51.423180: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
c = 0
for f in frequencies.values():
    if f < MIN_FREQ:
        c += f

c, c / sum(frequencies.values())