In [1]:
import numpy as np
import random
import torch
import torch.nn as nn
from torch.nn import TransformerEncoderLayer, TransformerEncoder, Embedding
from gensim.models import Word2Vec

In [2]:
# Load the trained Word2Vec model and keep only its KeyedVectors (.wv): vectors + key_to_index lookup.
wv = Word2Vec.load("../checkpoints/model_final.model").wv

In [3]:
# create embedding layer from gensim embeddings
# Row i holds wv.vectors[i], so indices here are raw wv.key_to_index values
# (no rows reserved for special tokens, unlike TransRec's embedding).
print(wv.vector_size, len(wv))
embd = Embedding(len(wv), wv.vector_size)
# copy_ is in-place and returns the weight tensor, which is what the cell echoes below.
embd.weight.data.copy_(torch.tensor(wv.vectors, dtype=torch.float))

100 281217


tensor([[ 1.0990,  2.4085,  0.3983,  ...,  0.0845,  0.2064,  2.6756],
        [ 0.2235,  5.3773,  2.8357,  ..., -4.8849, -5.0111, -2.0209],
        [-2.7695,  4.5329,  0.3006,  ..., -5.3204, -3.4041,  0.4088],
        ...,
        [-1.1284, -0.0572, -2.2218,  ..., -0.6975,  0.8115, -1.0279],
        [ 0.6771,  0.1054,  0.1341,  ...,  0.6022, -0.1027, -0.7091],
        [-0.8491, -0.9734, -1.2336,  ...,  0.8185, -0.1867, -1.1534]])

In [4]:
# d_model = embedding width, 10 attention heads (nhead must divide d_model);
# batch_first=True so inputs are (batch, seq, dim). Stack 3 copies of the layer.
encoder_layer = TransformerEncoderLayer(wv.vector_size, 10, batch_first=True)
encoder = TransformerEncoder(encoder_layer, 3)

In [5]:
# Draw 12 vocabulary tokens uniformly with replacement and look up their wv row indices.
seq = random.choices(list(wv.key_to_index), k=12)
t_seq = torch.tensor(list(map(wv.key_to_index.__getitem__, seq)), dtype=torch.long)

In [6]:
# Embed as a batch of one: (1, 12) ids -> (1, 12, vector_size); the encoder preserves that shape.
v = embd(t_seq.view(1,-1))
encoder(v).shape, v.shape

(torch.Size([1, 12, 100]), torch.Size([1, 12, 100]))

In [7]:
# Sanity check of the embedded input: (batch=1, seq_len=12, dim).
v.shape

torch.Size([1, 12, 100])

In [8]:
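# Two masking mechanisms in nn.TransformerEncoder:
# - mask padding: `src_key_padding_mask`, shape (batch_size, seq_len), True = ignore that key
# - mask some items: `mask`, shape (seq_len, seq_len) or (batch_size * n_heads, seq_len, seq_len)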

In [143]:
# NOTE(review): relies on `batch` from a later cell (In [128]); batch[2] is the boolean
# mask produced by MaskTracksTensor, so this is (batch_size, seq_len).
a,b = batch[2].shape
a, b

(64, 5)

In [147]:
def _generate_attn_mask_single(seq_mask):
    """
        If a BoolTensor is provided, positions with True are not allowed 
        to attend while False values will be unchanged.
        Softmax goes along -1 dimension
    """
    n = seq_mask.shape[0]
    mask = torch.zeros((n,n), dtype=torch.bool)
    mask[:, seq_mask.nonzero()] = True
    return mask

def _generate_attn_mask_batch(seq_mask):
    bs, n = seq_mask.shape
    mask = torch.zeros((bs, n, n), dtype=torch.bool)
    nz = seq_mask.nonzero()
    a, b = nz[:, 0], nz[:, 1]
    mask[a, :, b] = True
    return mask

def generate_attn_mask(seq_mask):
    """
    Build a boolean attention mask hiding the keys flagged in ``seq_mask``.

    Args:
        seq_mask: 1-D (seq_len,) or 2-D (batch_size, seq_len) tensor; non-zero / True
            entries are keys that no query may attend to.

    Returns:
        A (seq_len, seq_len) or (batch_size, seq_len, seq_len) torch.BoolTensor.

    Raises:
        ValueError: if ``seq_mask`` is neither 1-D nor 2-D.
    """
    if seq_mask.dim() == 1:
        return _generate_attn_mask_single(seq_mask)
    if seq_mask.dim() == 2:
        return _generate_attn_mask_batch(seq_mask)
    # Raise explicitly: `assert False` is stripped under `python -O`, which would make
    # this function silently return None.
    raise ValueError(f"Input should be a SEQ_LEN vector or BATCH_SIZE * SEQ_LEN matrix, got {seq_mask.shape}")
        
# Build the flags as a real bool tensor; torch.Tensor([...]) would silently make them floats.
mask = generate_attn_mask(torch.tensor([[False, False, True, False, True],
                                        [False, False, False, False, True]], dtype=torch.bool))
mask

tensor([[[False, False,  True, False,  True],
         [False, False,  True, False,  True],
         [False, False,  True, False,  True],
         [False, False,  True, False,  True],
         [False, False,  True, False,  True]],

        [[False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True]]])

In [10]:
# All-False sequence mask: nothing is hidden; this only exercises the `mask=` argument.
# A 3-D attention mask must be (batch_size * n_heads, seq_len, seq_len), so take the
# head count from the encoder rather than hard-coding it.
n_heads = encoder.layers[0].self_attn.num_heads
mask = torch.zeros(v.shape[1], dtype=torch.bool)
attn_mask = generate_attn_mask(mask).repeat(n_heads, 1, 1)
attn_mask.shape, v.shape

(torch.Size([10, 12, 12]), torch.Size([1, 12, 100]))

In [11]:
# Expect 0: no position is masked.
attn_mask.sum()

tensor(0)

In [12]:
# Forward pass with the all-False attention mask (every position may attend to every other).
out = encoder(v, mask=attn_mask)
out

tensor([[[-0.4049,  0.9781, -1.9790,  ...,  2.3513,  0.6293,  2.1527],
         [-1.1132,  1.9905, -1.0655,  ...,  0.5341,  0.0057,  1.9082],
         [ 0.2805, -0.1490, -1.7117,  ...,  0.5807,  0.0717,  1.0083],
         ...,
         [-1.3863,  2.4944,  0.3424,  ...,  0.8583,  0.8222,  0.4182],
         [-1.5273,  1.4244, -0.9384,  ...,  0.4088,  0.3260, -0.5216],
         [-0.8199,  0.5492, -0.1732,  ...,  1.3180,  1.1253,  0.8410]]],
       grad_fn=<NativeLayerNormBackward0>)

In [122]:
import json
import logging

class PlaylistDataset(torch.utils.data.Dataset):
    """
    Map-style dataset over playlists stored in a list of JSON files.

    Each file must contain a top-level "playlists" list with ``playlist_per_file``
    entries; global index i maps to file ``i // ppf``, entry ``i % ppf``. Only the
    most recently used file is cached, so sequential access loads each file once
    while random access may reload files repeatedly.

    Args:
        files: sequence of JSON file paths.
        playlist_per_file: number of playlists in every file.
        transform: optional callable applied to each playlist before returning it.
    """

    def __init__(self, files, playlist_per_file, transform=None):
        self.files = files
        self.current_file_index = -1  # -1: no file loaded yet
        self.data = None
        self.ppf = playlist_per_file
        self.transform = transform

    def __len__(self):
        return self.ppf * len(self.files)

    def _load(self, path):
        # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
        with open(path, "r", encoding="utf-8") as f:
            self.data = json.load(f)

    def __getitem__(self, index):
        # Negative indices would otherwise wrap into files[-1] via floor division
        # and silently return an unrelated playlist.
        if not 0 <= index < len(self):
            raise IndexError(f"index {index} out of range for dataset of size {len(self)}")
        file_index = index // self.ppf
        offset = index % self.ppf
        if self.current_file_index != file_index:
            logging.debug(f"Loading file {self.files[file_index]}")
            self._load(self.files[file_index])
            self.current_file_index = file_index
        tracks = self.data["playlists"][offset]

        if self.transform is not None:
            tracks = self.transform(tracks)

        return tracks

    
class Compose:
    """Chain callables left to right: Compose(f, g)(x) == g(f(x)).

    With no callables it is the identity.
    """

    def __init__(self, *tfs):
        self.tfs = tfs

    def __call__(self, x):
        result = x
        for step in self.tfs:
            result = step(result)
        return result
    
    
class RemoveUnknownTracks:
    """Keep only the tracks that appear in ``known_tracks``.

    ``known_tracks`` is turned into a set once (a set argument is kept as-is) so
    each membership test is a hash lookup. Order and duplicates of kept tracks are preserved.
    """

    def __init__(self, known_tracks):
        self.kt = known_tracks if isinstance(known_tracks, set) else set(known_tracks)

    def __call__(self, x):
        return list(filter(self.kt.__contains__, x))
    
    
class TrackURI2Idx:
    """Map each track URI to ``uri2idx[uri] + offset``.

    The offset leaves the lowest ids free (e.g. for special tokens); an unknown URI raises KeyError.
    """

    def __init__(self, uri2idx, offset=0):
        self.offset = offset
        self.uri2idx = uri2idx

    def __call__(self, x):
        lookup, shift = self.uri2idx, self.offset
        return [lookup[uri] + shift for uri in x]
    
    
class ToLongTensor:
    """Convert a sequence of integer ids into an int64 tensor."""

    def __call__(self, x):
        # Use torch.tensor with an explicit dtype instead of the legacy torch.LongTensor
        # constructor, which treats a bare int as a *size* and returns uninitialised memory.
        return torch.tensor(x, dtype=torch.long)
    
class PadOrTrim:
    """Force a list to exactly ``target_length`` items.

    Short lists are right-padded with ``pad_token``, long ones are truncated, and a
    list that already has the right length is returned unchanged (same object).
    """

    def __init__(self, pad_token, target_length):
        self.token = pad_token
        self.t = target_length

    def __call__(self, x):
        missing = self.t - len(x)
        if missing > 0:
            return x + [self.token] * missing
        if missing < 0:
            return x[:self.t]
        return x
    
    
class MaskTracksTensor:
    """Randomly replace non-padding ids of a 1-D tensor with ``mask_token``.

    Each position is selected independently with probability ``mask_proba``;
    padding positions are never selected. Returns ``(masked_ids, x, mask)`` where
    ``x`` is the untouched input and ``mask`` is a bool tensor of selected positions.
    """

    def __init__(self, mask_token, padding_token, mask_proba):
        self.token = mask_token
        self.padding_token = padding_token
        self.proba = mask_proba

    def __call__(self, x):
        selected = torch.rand(x.shape[0]) < self.proba
        # padding is never masked
        selected &= x != self.padding_token
        return x.masked_fill(selected, self.token), x, selected
    

In [123]:
# Masking rule demo: a position ends up masked only if it is selected AND not padding.
mask = torch.tensor([True, False, False, True], dtype=torch.bool)
pad = torch.tensor([False, False, True, True], dtype=torch.bool)

mask.logical_and(pad.logical_not())

tensor([ True, False, False, False])

In [124]:
# The 20 JSON chunks holding the playlists, in chunk order.
files = list(map(lambda i: f"../playlists_data/chunk_{i}.json", range(20)))
files

['../playlists_data/chunk_0.json',
 '../playlists_data/chunk_1.json',
 '../playlists_data/chunk_2.json',
 '../playlists_data/chunk_3.json',
 '../playlists_data/chunk_4.json',
 '../playlists_data/chunk_5.json',
 '../playlists_data/chunk_6.json',
 '../playlists_data/chunk_7.json',
 '../playlists_data/chunk_8.json',
 '../playlists_data/chunk_9.json',
 '../playlists_data/chunk_10.json',
 '../playlists_data/chunk_11.json',
 '../playlists_data/chunk_12.json',
 '../playlists_data/chunk_13.json',
 '../playlists_data/chunk_14.json',
 '../playlists_data/chunk_15.json',
 '../playlists_data/chunk_16.json',
 '../playlists_data/chunk_17.json',
 '../playlists_data/chunk_18.json',
 '../playlists_data/chunk_19.json']

In [125]:
# Special token ids. Real tracks are shifted by 2 (TrackURI2Idx offset below) so ids 0 and 1
# stay free for these; TransRec's embedding likewise reserves its first two rows.
PAD_TOKEN = 0
MASK_TOKEN = 1

# Per-playlist pipeline: drop tracks unknown to word2vec -> URI to id (+2) ->
# pad/trim to 5 ids -> LongTensor -> mask each non-padding id with probability 0.1.
transforms = Compose(
    RemoveUnknownTracks(wv.key_to_index.keys()),
    TrackURI2Idx(wv.key_to_index, offset=2),
    PadOrTrim(PAD_TOKEN, 5),
    ToLongTensor(),
    MaskTracksTensor(MASK_TOKEN, PAD_TOKEN, .1)
)

# NOTE(review): assumes every chunk file holds exactly 50000 playlists — confirm.
ds = PlaylistDataset(files, 50000, transforms)

In [126]:
from torch.utils.data import DataLoader

In [127]:
# shuffle=False keeps access sequential, so PlaylistDataset loads each chunk file only once.
dl = DataLoader(ds, shuffle=False, batch_size=64)

In [128]:
# Grab one batch for inspection: [masked_ids, original_ids, mask].
# next(iter(...)) fails loudly on an empty loader instead of leaving a stale `batch` around.
batch = next(iter(dl))

In [129]:
# Collated batch: list of tensors [masked_ids, original_ids, mask].
batch

[tensor([[  1228,    356,    177,    803,     68],
         [   380, 235913,      1, 126095,  10340],
         [ 99654, 153183, 140755,  55915,  11900],
         [ 31343,  37487, 131362,  38786,  34456],
         [  4191,    143,    846, 212700,   1043],
         [  1636,     98,      1,   1283,  60442],
         [ 11492,  20491, 179673,   2002,   5471],
         [ 64637,   3487,   3761,    401,      1],
         [ 17164,   7608,      1,  15300,  26847],
         [  4898,   4898,  12900,   2189,   2758],
         [  3686,    657,    187,    965,    748],
         [  6674, 191212,   2463,   5150,   3873],
         [  2827,    134,    121,   1838,    586],
         [   415,   2215,      1,  82184,   1788],
         [     1,   3431,   4127,   3236,   1443],
         [  1354,  15557,  11069,      1,   7987],
         [ 20766,  33418, 144088,  14677,      1],
         [  1332,   3587,      1,    130,      1],
         [  5716,  12725,   2210,   2717,    252],
         [ 31207,  61419,  2692

In [30]:
# NOTE(review): stale cell (In [30]) — it ran when `batch` was a single index tensor. The
# current `batch` is a list, and its ids are shifted by 2 (TrackURI2Idx offset), which does
# not match the un-shifted `embd` (len(wv) rows); re-check before re-running.
embd(batch).shape

torch.Size([64, 5, 100])

In [31]:
# NOTE(review): output below comes from an earlier run (In [31]) when `batch` was a plain tensor.
batch

tensor([[  1227,    355,    176,    802,     67],
        [   379, 235912, 225683, 126094,  10339],
        [ 99653, 153182, 140754,  55914,  11899],
        [ 31342,  37486, 131361,  38785,  34455],
        [  4190,    142,    845, 212699,   1042],
        [  1635,     97,    918,   1282,  60441],
        [ 11491,  20490, 179672,   2001,   5470],
        [ 64636,   3486,   3760,    400,  10090],
        [ 17163,   7607,  17485,  15299,  26846],
        [  4897,   4897,  12899,   2188,   2757],
        [  3685,    656,    186,    964,    747],
        [  6673, 191211,   2462,   5149,   3872],
        [  2826,    133,    120,   1837,    585],
        [   414,   2214,   4347,  82183,   1787],
        [  7760,   3430,   4126,   3235,   1442],
        [  1353,  15556,  11068,   4290,   7986],
        [ 20765,  33417, 144087,  14676,  11572],
        [  1331,   3586,   6645,    129,    423],
        [  5715,  12724,   2209,   2716,    251],
        [ 31206,  61418,  26919, 229897,  27662],


In [64]:
import lightning.pytorch as pl


# TODO: add final linear layer
# NOTE(review): TransRec below already defines self.linear; confirm whether this TODO is still open.

# Shared GELU activation instance (nn.GELU has no parameters).
GELU = nn.GELU()

class TransRec(pl.LightningModule):
    """
    Transformer encoder over playlist track ids, initialised from word2vec vectors.

    Token id layout: 0 = <PAD>, 1 = <MASK>, i + 2 = word2vec vector i.

    Args:
        wv_model: gensim KeyedVectors (uses len(), .vector_size and .vectors).
        n_head: number of attention heads; must divide wv_model.vector_size.
        layer_kwargs: extra keyword arguments for nn.TransformerEncoderLayer.
        enc_kwargs: extra keyword arguments for nn.TransformerEncoder (e.g. num_layers).
    """

    def __init__(self, wv_model, n_head, layer_kwargs=None, enc_kwargs=None):
        super().__init__()
        # None defaults avoid sharing one mutable dict across all calls.
        layer_kwargs = {} if layer_kwargs is None else layer_kwargs
        enc_kwargs = {} if enc_kwargs is None else enc_kwargs

        self.embd = Embedding(len(wv_model) + 2, wv_model.vector_size)  # +2 for <PAD> and <MASK> tokens
        self.embd.weight.data[2:].copy_(torch.tensor(wv_model.vectors, dtype=torch.float))
        # Freeze the embedding by *calling* requires_grad_; assigning to it would only
        # shadow the method. Note this also freezes the randomly initialised <PAD>/<MASK> rows.
        self.embd.requires_grad_(False)

        encoder_layer = TransformerEncoderLayer(wv_model.vector_size, batch_first=True,
                                                nhead=n_head, **layer_kwargs)
        self.encoder = TransformerEncoder(encoder_layer, **enc_kwargs)
        self.linear = nn.Linear(wv_model.vector_size, wv_model.vector_size, bias=True)

    def forward(self, x, src_key_padding_mask=None):
        """
        Args:
            x: (batch_size, seq_len) LongTensor of token ids.
            src_key_padding_mask: optional (batch_size, seq_len) bool tensor, True at
                positions the encoder must ignore (see _build_attn_mask).

        Returns:
            (num_tokens, batch_size, seq_len) tensor of probabilities over the
            vocabulary (softmax along dim 0).
        """
        x = self.embd(x)
        x = self.encoder(x, src_key_padding_mask=src_key_padding_mask)
        z = GELU(self.linear(x))
        bs, seq_len, embd_dim = x.shape
        num_tokens = self.embd.weight.shape[0]
        # Score every vocabulary entry at every position with the tied embedding matrix:
        # (num_tokens, dim) @ (dim, bs * seq_len), then split the columns back into (bs, seq_len).
        logits = torch.matmul(self.embd.weight, z.view(-1, embd_dim).T).view(num_tokens, bs, seq_len)
        return logits.softmax(dim=0)

    def _build_attn_mask(self, x):
        """Key-padding mask for a batch of ids: True where x == PAD_TOKEN (same shape as x).

        Suitable as ``src_key_padding_mask`` for forward().
        """
        return x == PAD_TOKEN

    def training_step(self, batch, batch_idx):
        # batch = (masked_ids, original_ids, mask) as produced by MaskTracksTensor.
        x, y, mask = batch
        padding_mask = self._build_attn_mask(x)
        # TODO: compute the loss on the masked positions and return it; returning None
        # makes Lightning skip the optimisation step.

    def configure_optimizers(self):
        # TODO: return an optimizer; with None Lightning trains without one.
        pass

In [65]:
# Same configuration as the standalone encoder above: 10 heads, 3 encoder layers.
tr = TransRec(wv, n_head=10, enc_kwargs={"num_layers": 3})

In [66]:
# NOTE(review): this ran (In [66]) before the current DataLoader batch existed; `batch` is now a
# list [masked_ids, original_ids, mask], so forward presumably needs batch[0] here.
out = tr(batch)
out.shape

torch.Size([281219, 64, 5])

In [67]:
# Most likely token id at each position: (batch_size, seq_len).
out.argmax(dim=0).shape

torch.Size([64, 5])

In [70]:
# forward() already applies softmax over the vocabulary (dim 0), so summing `out` directly
# checks normalisation: each of the batch_size * seq_len positions should contribute ~1.
# (Re-applying softmax would always sum to 1 and so could not detect anything.)
out.sum(dim=0).sum()

tensor(320.0490, grad_fn=<SumBackward0>)