In [38]:
import numpy as np
import random
import torch
import torch.nn as nn
from torch.nn import TransformerEncoderLayer, TransformerEncoder, Embedding
from gensim.models import Word2Vec

In [4]:
wv = Word2Vec.load("../checkpoints/model_final.model").wv

In [30]:
# create embedding layer from gensim embeddings
print(wv.vector_size, len(wv))
embd = Embedding(len(wv), wv.vector_size)
embd.weight.data.copy_(torch.tensor(wv.vectors, dtype=torch.float))

100 281217


tensor([[ 1.0990,  2.4085,  0.3983,  ...,  0.0845,  0.2064,  2.6756],
        [ 0.2235,  5.3773,  2.8357,  ..., -4.8849, -5.0111, -2.0209],
        [-2.7695,  4.5329,  0.3006,  ..., -5.3204, -3.4041,  0.4088],
        ...,
        [-1.1284, -0.0572, -2.2218,  ..., -0.6975,  0.8115, -1.0279],
        [ 0.6771,  0.1054,  0.1341,  ...,  0.6022, -0.1027, -0.7091],
        [-0.8491, -0.9734, -1.2336,  ...,  0.8185, -0.1867, -1.1534]])

In [108]:
encoder_layer = TransformerEncoderLayer(wv.vector_size, 10, batch_first=True)
encoder = TransformerEncoder(encoder_layer, 3)

In [109]:
seq = random.choices(list(wv.key_to_index.keys()), k=12)
t_seq = torch.tensor([wv.key_to_index[item] for item in seq], dtype=torch.long)

In [110]:
v = embd(t_seq.view(1,-1))
encoder(v).shape, v.shape

(torch.Size([1, 12, 100]), torch.Size([1, 12, 100]))

In [111]:
v.shape

torch.Size([1, 12, 100])

In [112]:
# mask padding: src_key_padding_mask
# mask some items: mask

In [142]:
def generate_attn_mask(seq_mask):
    """
        If a BoolTensor is provided, positions with True are not allowed 
        to attend while False values will be unchanged.
        Softmax goes along -1 dimension
    """
    n = seq_mask.shape[-1]
    mask = torch.zeros((n,n), dtype=torch.bool)
    mask[:, seq_mask.nonzero()] = True
    return mask
    
mask = generate_attn_mask(torch.Tensor([False, False, True, False, True]))
mask

tensor([[False, False,  True, False,  True],
        [False, False,  True, False,  True],
        [False, False,  True, False,  True],
        [False, False,  True, False,  True],
        [False, False,  True, False,  True]])

In [148]:
mask = torch.Tensor([False] * v.shape[1])
attn_mask = generate_attn_mask(mask).repeat(10,1,1)
attn_mask.shape, v.shape, 

(torch.Size([10, 12, 12]), torch.Size([1, 12, 100]))

In [144]:
attn_mask.sum()

tensor(0)

In [169]:
out = encoder(v, mask=attn_mask)
out

tensor([[[ 0.0812, -0.1265, -0.8604,  ..., -0.1958,  2.6957,  1.0880],
         [ 0.2169, -0.0378, -0.5093,  ...,  0.4028,  3.0029,  0.0229],
         [ 2.4005,  0.4884, -0.8056,  ...,  0.4309,  0.6416, -0.3111],
         ...,
         [ 2.2272,  0.7680,  0.0664,  ...,  0.4633,  0.1944, -0.7390],
         [ 1.7706,  0.5730, -1.4109,  ...,  0.1600,  1.1616, -0.4409],
         [ 1.9712, -0.0155, -1.8148,  ...,  0.3185,  1.3445, -0.2619]]],
       grad_fn=<NativeLayerNormBackward0>)

In [300]:
import json
import logging

class PlaylistDataset(torch.utils.data.Dataset):
    
    def __init__(self, files, playlist_per_file, transform=None):
        self.files = files
        self.current_file_index = -1
        self.data = None
        self.ppf = playlist_per_file
        self.transform = transform
        
    def __len__(self):
        return self.ppf * len(self.files)
    
    def _load(self, path):
        with open(path, "r") as f:
            self.data = json.load(f)
    
    def __getitem__(self, index):
        file_index = index // self.ppf
        offset = index % self.ppf
        if self.current_file_index != file_index:
            logging.debug(f"Loading file {self.files[file_index]}")
            self._load(self.files[file_index])
            self.current_file_index = file_index
        tracks = self.data["playlists"][offset]
        
        if self.transform is not None:
            tracks = self.transform(tracks)
        
        return tracks

    
class Compose:
    
    def __init__(self, *tfs):
        self.tfs = tfs
        
    def __call__(self, x):
        for tf in self.tfs:
            x = tf(x)
        return x
    
    
class RemoveUnknownTracks:
    
    def __init__(self, known_tracks):
        kt = known_tracks
        if not isinstance(kt, set):
            kt = set(kt)
        self.kt = kt
        
    def __call__(self, x):
        return [xi for xi in x if xi in self.kt]
    
    
class TrackURI2Idx:
    
    def __init__(self, uri2idx, offset=0):
        self.offset = offset
        self.uri2idx = uri2idx
        
    def __call__(self, x):
        return [self.uri2idx[xi] + self.offset for xi in x]
    
    
class ToLongTensor:
    
    def __call__(self, x):
        return torch.LongTensor(x)
    
class PadOrTrim:
    
    def __init__(self, pad_token, target_length):
        self.token = pad_token
        self.t = target_length
    
    def __call__(self, x):
        if len(x) == self.t:
            return x
        if len(x) < self.t:
            return x + [self.token] * (self.t - len(x))
        return x[:self.t]
            

In [301]:
files = [f"../playlists_data/chunk_{i}.json" for i in range(20)]
files

['../playlists_data/chunk_0.json',
 '../playlists_data/chunk_1.json',
 '../playlists_data/chunk_2.json',
 '../playlists_data/chunk_3.json',
 '../playlists_data/chunk_4.json',
 '../playlists_data/chunk_5.json',
 '../playlists_data/chunk_6.json',
 '../playlists_data/chunk_7.json',
 '../playlists_data/chunk_8.json',
 '../playlists_data/chunk_9.json',
 '../playlists_data/chunk_10.json',
 '../playlists_data/chunk_11.json',
 '../playlists_data/chunk_12.json',
 '../playlists_data/chunk_13.json',
 '../playlists_data/chunk_14.json',
 '../playlists_data/chunk_15.json',
 '../playlists_data/chunk_16.json',
 '../playlists_data/chunk_17.json',
 '../playlists_data/chunk_18.json',
 '../playlists_data/chunk_19.json']

In [302]:
transforms = Compose(
    RemoveUnknownTracks(wv.key_to_index.keys()),
    TrackURI2Idx(wv.key_to_index, offset=1),
    PadOrTrim(0, 20),
    ToLongTensor()
)

ds = PlaylistDataset(files, 50000, transforms)

In [243]:
"""
import os
from tqdm import tqdm

BATCH_SIZE = 50000

def get_playlists(data):
    for plist in data["playlists"]:
        sequence = [song["track_uri"] for song in plist["tracks"]]
        yield sequence

def dump(batch, path):
    print(f"Writing {path}")
    with open(path, "w") as f:
        json.dump({"playlists": batch}, f)
        
batch = []
i = 0
for fname in tqdm(list(sorted(os.listdir("../data")))):
    if not (fname.startswith("mpd.slice.") and fname.endswith(".json")):
        continue
        
    fullpath = os.sep.join(("../data", fname))
    with open(fullpath, "r") as f:
        data = json.load(f)
    batch += list(get_playlists(data))
    
    if len(batch) >= BATCH_SIZE:
        dump(batch, f"../playlists_data/chunk_{i}.json")
        batch = []
        i += 1
        
if len(batch) > 0:
    dump(batch, f"../playlists_data/chunk_{i}.json")
"""

  5%|██                                       | 49/1000 [00:07<02:02,  7.75it/s]

Writing ../playlists_data/chunk_0.json


 10%|████                                     | 99/1000 [00:16<02:36,  5.75it/s]

Writing ../playlists_data/chunk_1.json


 15%|█████▉                                  | 149/1000 [00:26<02:18,  6.14it/s]

Writing ../playlists_data/chunk_2.json


 20%|███████▉                                | 199/1000 [00:36<01:57,  6.83it/s]

Writing ../playlists_data/chunk_3.json


 25%|█████████▉                              | 249/1000 [00:46<01:56,  6.45it/s]

Writing ../playlists_data/chunk_4.json


 30%|███████████▉                            | 299/1000 [00:56<01:53,  6.18it/s]

Writing ../playlists_data/chunk_5.json


 35%|█████████████▉                          | 349/1000 [01:06<02:38,  4.10it/s]

Writing ../playlists_data/chunk_6.json


 40%|███████████████▉                        | 399/1000 [01:16<01:41,  5.90it/s]

Writing ../playlists_data/chunk_7.json


 45%|█████████████████▉                      | 449/1000 [01:27<01:39,  5.56it/s]

Writing ../playlists_data/chunk_8.json


 50%|███████████████████▉                    | 499/1000 [01:36<01:20,  6.20it/s]

Writing ../playlists_data/chunk_9.json


 55%|█████████████████████▉                  | 549/1000 [01:46<01:18,  5.71it/s]

Writing ../playlists_data/chunk_10.json


 60%|███████████████████████▉                | 599/1000 [01:56<01:07,  5.91it/s]

Writing ../playlists_data/chunk_11.json


 65%|█████████████████████████▉              | 649/1000 [02:06<00:58,  5.97it/s]

Writing ../playlists_data/chunk_12.json


 70%|███████████████████████████▉            | 699/1000 [02:16<00:52,  5.71it/s]

Writing ../playlists_data/chunk_13.json


 75%|█████████████████████████████▉          | 749/1000 [02:27<00:43,  5.83it/s]

Writing ../playlists_data/chunk_14.json


 80%|███████████████████████████████▉        | 799/1000 [02:36<00:49,  4.08it/s]

Writing ../playlists_data/chunk_15.json


 85%|█████████████████████████████████▉      | 849/1000 [02:46<00:25,  5.93it/s]

Writing ../playlists_data/chunk_16.json


 90%|███████████████████████████████████▉    | 899/1000 [02:57<00:18,  5.57it/s]

Writing ../playlists_data/chunk_17.json


 95%|█████████████████████████████████████▉  | 949/1000 [03:07<00:10,  5.08it/s]

Writing ../playlists_data/chunk_18.json


100%|███████████████████████████████████████▉| 999/1000 [03:17<00:00,  5.83it/s]

Writing ../playlists_data/chunk_19.json


100%|███████████████████████████████████████| 1000/1000 [03:18<00:00,  5.03it/s]


In [303]:
from torch.utils.data import DataLoader

In [304]:
dl = DataLoader(ds, shuffle=False, batch_size=512)

In [308]:
for batch in dl:
    print(batch[:, -1])
    asd

tensor([  1307,     53,  41116,  51833,      0,   2424,      0,   4191,    162,
             0,   2507,   6671,      0,      0,  13741,      0,  85067,    191,
          3817, 160357,      0,      0,   4323,  32338,      0,  30586,  10295,
             0,     14,    313,  31142,     40,    976,      0,   1607,     35,
             0,   1681, 153092, 167702,    675,    359,   5382,   4940,  24167,
           361,      0,      0,   1778, 159385,   4531,  13787,      0,    863,
             0,    481,      0,      0,  24654,  37324,  28972,      0,    119,
         15746,  47823,   3068,  38377,      0,  22133,      0, 212766,    510,
         30946,     61,      0,      0, 163122,  60665,      0,      0,  88003,
          1131,      0,  76051,  43209,    191, 165599,  22655,   9727,  24915,
         22285,  34323,      0,  21196,   4052,   5997,      0,  16815, 102969,
         49558,      0,      0,    730,   2210,   3559,    381,   1989,   3488,
         22475,  10542,      0,   1925, 

NameError: name 'asd' is not defined

In [271]:
len(ds)

TypeError: 'list' object cannot be interpreted as an integer