In [6]:
import numpy as np
import random
import torch
import torch.nn as nn
from torch.nn import TransformerEncoderLayer, TransformerEncoder, Embedding
from gensim.models import Word2Vec

In [7]:
np.__version__

'1.24.1'

In [8]:
wv = Word2Vec.load("../checkpoints/model_final.model").wv

In [9]:
def _generate_attn_mask_single(seq_mask):
    """
        If a BoolTensor is provided, positions with True are not allowed 
        to attend while False values will be unchanged.
        Softmax goes along -1 dimension
    """
    n = seq_mask.shape[0]
    mask = torch.zeros((n,n), dtype=torch.bool)
    mask[:, seq_mask.nonzero()] = True
    return mask

def _generate_attn_mask_batch(seq_mask, n_heads):
    bs, n = seq_mask.shape
    mask = torch.zeros((bs, n, n), dtype=torch.bool)
    nz = seq_mask.nonzero()
    a, b = nz[:, 0], nz[:, 1]
    mask[a, :, b] = True
    if n_heads > 1:
        mask = mask.repeat(1, n_heads, 1)
        mask = mask.view(bs * n_heads, n, n)
    return mask

def generate_attn_mask(seq_mask, n_heads=1):
    if len(seq_mask.shape) == 1:
        return _generate_attn_mask_single(seq_mask, n_heads)
    elif len(seq_mask.shape) == 2:
        return _generate_attn_mask_batch(seq_mask, n_heads)
    else:
        assert False, f"Input should be BATCH_SIZE * SEQ_LEN matrix, got {seq_mask.shape}"
            
mask = generate_attn_mask(torch.Tensor([[False, False, True, False, True], [False, False, False, False, True]]))
mask

tensor([[[False, False,  True, False,  True],
         [False, False,  True, False,  True],
         [False, False,  True, False,  True],
         [False, False,  True, False,  True],
         [False, False,  True, False,  True]],

        [[False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True]]])

In [10]:
import json
import logging

class PlaylistDataset(torch.utils.data.Dataset):
    
    def __init__(self, files, playlist_per_file, transform=None):
        self.files = files
        self.current_file_index = -1
        self.data = None
        self.ppf = playlist_per_file
        self.transform = transform
        
    def __len__(self):
        return self.ppf * len(self.files)
    
    def _load(self, path):
        with open(path, "r") as f:
            self.data = json.load(f)
    
    def __getitem__(self, index):
        file_index = index // self.ppf
        offset = index % self.ppf
        if self.current_file_index != file_index:
            logging.debug(f"Loading file {self.files[file_index]}")
            self._load(self.files[file_index])
            self.current_file_index = file_index
        tracks = self.data["playlists"][offset]
        
        if self.transform is not None:
            tracks = self.transform(tracks)
        
        return tracks

    
class Compose:
    
    def __init__(self, *tfs):
        self.tfs = tfs
        
    def __call__(self, x):
        for tf in self.tfs:
            x = tf(x)
        return x
    
    
class RemoveUnknownTracks:
    
    def __init__(self, known_tracks):
        kt = known_tracks
        if not isinstance(kt, set):
            kt = set(kt)
        self.kt = kt
        
    def __call__(self, x):
        return [xi for xi in x if xi in self.kt]
    
    
class TrackURI2Idx:
    
    def __init__(self, uri2idx, offset=0):
        self.offset = offset
        self.uri2idx = uri2idx
        
    def __call__(self, x):
        return [self.uri2idx[xi] + self.offset for xi in x]
    
    
class ToLongTensor:
    
    def __call__(self, x):
        return torch.LongTensor(x)
    
class PadOrTrim:
    
    def __init__(self, pad_token, target_length):
        self.token = pad_token
        self.t = target_length
    
    def __call__(self, x):
        if len(x) == self.t:
            return x
        if len(x) < self.t:
            return x + [self.token] * (self.t - len(x))
        return x[:self.t]
    
    
class MaskTracksTensor:
    
    def __init__(self, mask_token, padding_token, mask_proba):
        self.token = mask_token
        self.padding_token = padding_token
        self.proba = mask_proba
        
    def __call__(self, x):
        mask = torch.rand(x.shape[0]) < self.proba
        padding = x == self.padding_token
        # avoid masking padded tracks
        mask = mask & (~padding)
        x_ = x.clone()
        x_[mask] = self.token
        return x_, x, mask
    

In [11]:
files = [f"../playlists_data/chunk_{i}.json" for i in range(20)]
files

['../playlists_data/chunk_0.json',
 '../playlists_data/chunk_1.json',
 '../playlists_data/chunk_2.json',
 '../playlists_data/chunk_3.json',
 '../playlists_data/chunk_4.json',
 '../playlists_data/chunk_5.json',
 '../playlists_data/chunk_6.json',
 '../playlists_data/chunk_7.json',
 '../playlists_data/chunk_8.json',
 '../playlists_data/chunk_9.json',
 '../playlists_data/chunk_10.json',
 '../playlists_data/chunk_11.json',
 '../playlists_data/chunk_12.json',
 '../playlists_data/chunk_13.json',
 '../playlists_data/chunk_14.json',
 '../playlists_data/chunk_15.json',
 '../playlists_data/chunk_16.json',
 '../playlists_data/chunk_17.json',
 '../playlists_data/chunk_18.json',
 '../playlists_data/chunk_19.json']

In [38]:
PAD_TOKEN = 0
MASK_TOKEN = 1

transforms = Compose(
    RemoveUnknownTracks(wv.key_to_index.keys()),
    TrackURI2Idx(wv.key_to_index, offset=2),
    PadOrTrim(PAD_TOKEN, 5),
    ToLongTensor(),
    MaskTracksTensor(MASK_TOKEN, PAD_TOKEN, .1)
)

ds = PlaylistDataset(files, 50000, transforms)

In [39]:
from torch.utils.data import DataLoader

In [40]:
dl = DataLoader(ds, shuffle=False, batch_size=32)

In [41]:
for batch in dl:
    break

In [34]:
import lightning.pytorch as pl


# TODO: add final linear layer

GELU = nn.GELU()

class TransRec(pl.LightningModule):
    
    def __init__(self, wv_model, n_head, layer_kwargs={}, enc_kwargs={}):
        super().__init__()
        
        self.embd = Embedding(len(wv_model)+2, wv_model.vector_size) # +2 for <PAD> and <MASK> tokens
        self.embd.weight.data[2:].copy_(torch.tensor(wv_model.vectors, dtype=torch.float))
        self.embd.requires_grad_ = False
        
        encoder_layer = TransformerEncoderLayer(wv_model.vector_size, batch_first=True, \
                                                nhead=n_head, **layer_kwargs)
        self.n_head = n_head
        self.encoder = TransformerEncoder(encoder_layer, **enc_kwargs)
        self.linear = nn.Linear(wv_model.vector_size, wv_model.vector_size, bias=True)
        
    def forward(self, x, mask=None):
        x = self.embd(x)
        x = self.encoder(x, mask=mask)
        return GELU(self.linear(x))
    
    def _token_probs(self, x, mask):
        bs, seq_len, embd_dim = x.shape
        num_tokens = self.embd.weight.shape[0]
        x = x[mask, :]
        logits = torch.matmul(self.embd.weight, x.view(-1, embd_dim).T).view(num_tokens, -1)
        return logits.softmax(dim=0)
        
    def training_step(self, batch, batch_idx):
        x, y, mask = batch
        padding_mask = x == PAD_TOKEN
        attn_mask = generate_attn_mask(mask, n_heads=self.n_head)
        
        crit = nn.CrossEntropyLoss()
        y_ = self._token_probs(self.forward(x, mask=attn_mask), mask)
        loss = crit(y_.T, y[mask])
        print(loss.item())
        
        return loss
        
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.002)

In [35]:
tr = TransRec(wv, n_head=10, enc_kwargs={"num_layers": 3})

In [36]:
trainer = pl.Trainer(max_epochs=1)

INFO: GPU available: False, used: False
INFO:lightning.pytorch.utilities.rank_zero:GPU available: False, used: False
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [37]:
trainer.fit(tr, train_dataloaders=dl)

INFO: 
  | Name    | Type               | Params
-----------------------------------------------
0 | embd    | Embedding          | 28.1 M
1 | encoder | TransformerEncoder | 1.4 M 
2 | linear  | Linear             | 10.1 K
-----------------------------------------------
29.5 M    Trainable params
0         Non-trainable params
29.5 M    Total params
117.959   Total estimated model params size (MB)
INFO:lightning.pytorch.callbacks.model_summary:
  | Name    | Type               | Params
-----------------------------------------------
0 | embd    | Embedding          | 28.1 M
1 | encoder | TransformerEncoder | 1.4 M 
2 | linear  | Linear             | 10.1 K
-----------------------------------------------
29.5 M    Trainable params
0         Non-trainable params
29.5 M    Total params
117.959   Total estimated model params size (MB)
2023-04-06 18:10:03.527757: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dl

Epoch 0:   0%|                                        | 0/31250 [00:00<?, ?it/s]12.546979904174805
Epoch 0:   0%|                     | 1/31250 [00:00<8:06:05,  1.07it/s, v_num=8]12.546850204467773
Epoch 0:   0%|                     | 2/31250 [00:01<8:07:56,  1.07it/s, v_num=8]12.546832084655762
Epoch 0:   0%|                     | 3/31250 [00:02<8:32:59,  1.02it/s, v_num=8]12.546906471252441
Epoch 0:   0%|                     | 4/31250 [00:04<9:08:35,  1.05s/it, v_num=8]12.54690933227539
Epoch 0:   0%|                    | 5/31250 [00:05<10:04:56,  1.16s/it, v_num=8]12.54688835144043
Epoch 0:   0%|                    | 6/31250 [00:07<11:00:22,  1.27s/it, v_num=8]12.54714584350586
Epoch 0:   0%|                    | 7/31250 [00:09<11:35:40,  1.34s/it, v_num=8]12.547018051147461
Epoch 0:   0%|                    | 8/31250 [00:11<12:05:04,  1.39s/it, v_num=8]12.547054290771484
Epoch 0:   0%|                    | 9/31250 [00:14<13:34:53,  1.57s/it, v_num=8]12.54702377319336
Epoch 0:   0%|

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [18]:
10405103 / 281219

37.0

In [18]:
x, y, mask = batch
padding_mask = x == PAD_TOKEN
attn_mask = generate_attn_mask(mask, n_heads=tr.n_head)

In [23]:
out = tr._token_probs(tr(x, mask=attn_mask), mask=mask)

In [27]:
crit = nn.CrossEntropyLoss()

In [28]:
loss = crit(out.T, y[mask])

In [29]:
loss

tensor(12.5469, grad_fn=<NllLossBackward0>)