In [1]:
import sys
sys.path.append("../src/models/transformer")

import os
from argparse import ArgumentParser

import lightning as pl
import torch
import numpy as np
from gensim.models import Word2Vec
from loader import PlaylistDataset
from model import TransformerModel
from train import MaskedLanguageModel
from transform import *

  from .autonotebook import tqdm as notebook_tqdm


In [17]:
PAD_TOKEN = 0

MODEL_PATH = "../models/w2v/song2vec"
DATA_PATH ="../data/processed/"

NHEADS = 4
NLAYERS = 2
DROPOUT = .2
DHIDDEN = 256

SEQLEN = 75
PPF = 50000

BATCH_SIZE = 1

In [3]:
wv = Word2Vec.load(MODEL_PATH).wv
dim = wv.vectors.shape[1]
# TODO save padding vector
embeddings = np.concatenate((np.random.normal(size=(1, dim)), wv.vectors), axis=0)
transformer = TransformerModel(
    embeddings=torch.tensor(embeddings),
    nhead=NHEADS,
    nlayers=NLAYERS,
    dropout=DROPOUT,
    d_hid=DHIDDEN
)
transformer.load_state_dict(torch.load("../models/transformer/transformer_model.pt", map_location="cpu"))

<All keys matched successfully>

In [18]:
files = sorted([os.path.join(DATA_PATH, f) for f in os.listdir(DATA_PATH) if ".json" in f])
print(files)

transforms = Compose(
    RemoveUnknownTracks(wv.key_to_index.keys()),
    TrackURI2Idx(wv.key_to_index, offset=1),
    PadOrTrim(PAD_TOKEN, SEQLEN),
    ToLongTensor()
)
ds = PlaylistDataset(files, playlist_per_file=PPF, transform=transforms)
loader = torch.utils.data.DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False)

['../data/processed/chunk_0.json', '../data/processed/chunk_1.json', '../data/processed/chunk_10.json', '../data/processed/chunk_11.json', '../data/processed/chunk_12.json', '../data/processed/chunk_13.json', '../data/processed/chunk_14.json', '../data/processed/chunk_15.json', '../data/processed/chunk_16.json', '../data/processed/chunk_17.json', '../data/processed/chunk_18.json', '../data/processed/chunk_19.json', '../data/processed/chunk_2.json', '../data/processed/chunk_3.json', '../data/processed/chunk_4.json', '../data/processed/chunk_5.json', '../data/processed/chunk_6.json', '../data/processed/chunk_7.json', '../data/processed/chunk_8.json', '../data/processed/chunk_9.json']


In [27]:
from tqdm import tqdm
import numpy as np

K = 50

result = []
with torch.no_grad():
    for batch in (pbar := tqdm(loader)):
        x = batch[0].view(-1, 1)
        seq_len = x.shape[0]
        src_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
        predictions = transformer(x, src_mask=src_mask, apply_softmax=False)
        topk = torch.topk(predictions, K, dim=2).indices.squeeze(1)
        x_ = x.view(-1).tolist()
        z = []
        for i in range(seq_len-1):
            tops = topk[i].tolist()
            hits = len(set(tops) & set(x_[i+1:]))
            if hits > 0:
                z.append(1)
            else:
                z.append(0)
        result.append(np.mean(z))
        pbar.set_description(f"{np.mean(result):.5f}")

0.26419:   0%|                                                                              | 211/1000000 [00:21<28:38:53,  9.69it/s]


KeyboardInterrupt: 