In [1]:
!python -m pip install lightning gensim wandb



In [2]:
import os
from argparse import ArgumentParser

import lightning as pl
import torch
import numpy as np
from gensim.models import Word2Vec
from loader import PlaylistDataset
from model import TransformerModel
from train import MaskedLanguageModel
from transform import *

In [3]:
PAD_TOKEN = 0

MODEL_PATH = "../models/song2vec"
DATA_PATH ="../data"

NHEADS = 4
NLAYERS = 2
DROPOUT = .2
DHIDDEN = 256

SEQLEN = 75
PPF = 50000

BATCH_SIZE = 48

In [4]:
wv = Word2Vec.load(MODEL_PATH).wv
dim = wv.vectors.shape[1]
# TODO save padding vector
embeddings = np.concatenate((np.random.normal(size=(1, dim)), wv.vectors), axis=0)
transformer = TransformerModel(
    embeddings=torch.tensor(embeddings),
    nhead=NHEADS,
    nlayers=NLAYERS,
    dropout=DROPOUT,
    d_hid=DHIDDEN
)
m = MaskedLanguageModel(transformer, PAD_TOKEN, device="cuda")

In [5]:
files = sorted([os.path.join(DATA_PATH, f) for f in os.listdir(DATA_PATH) if ".json" in f])
print(files)

transforms = Compose(
    RemoveUnknownTracks(wv.key_to_index.keys()),
    TrackURI2Idx(wv.key_to_index, offset=1),
    PadOrTrim(PAD_TOKEN, SEQLEN),
    ToLongTensor()
)

ds = PlaylistDataset(files, playlist_per_file=PPF, transform=transforms)
loader = torch.utils.data.DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False)

['../data/chunk_0.json', '../data/chunk_1.json', '../data/chunk_10.json', '../data/chunk_11.json', '../data/chunk_12.json', '../data/chunk_13.json', '../data/chunk_14.json', '../data/chunk_15.json', '../data/chunk_16.json', '../data/chunk_17.json', '../data/chunk_18.json', '../data/chunk_19.json', '../data/chunk_2.json', '../data/chunk_3.json', '../data/chunk_4.json', '../data/chunk_5.json', '../data/chunk_6.json', '../data/chunk_7.json', '../data/chunk_8.json', '../data/chunk_9.json']


In [12]:
from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(project="song2vec_transformer", log_model="all")
trainer = pl.Trainer(gradient_clip_val=0.5, accumulate_grad_batches=4, logger=wandb_logger)
wandb_logger.watch(m)

[34m[1mwandb[0m: Currently logged in as: [33mricsi[0m. Use [1m`wandb login --relogin`[0m to force relogin


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


In [13]:
trainer.fit(m, loader)

  rank_zero_warn(
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params
-------------------------------------------
0 | model | TransformerModel | 36.1 M
-------------------------------------------
100.0 K   Trainable params
36.0 M    Non-trainable params
36.1 M    Total params
144.383   Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [16]:
model = m.model.to("cpu")

In [17]:
torch.save(model.state_dict(), "../models/transformer/transformer_model.pt")

In [25]:
m2 = TransformerModel(
    embeddings=torch.tensor(embeddings),
    nhead=NHEADS,
    nlayers=NLAYERS,
    dropout=DROPOUT,
    d_hid=DHIDDEN
)
m2.load_state_dict(torch.load("../models/transformer/transformer_model.pt", map_location="cpu"))

<All keys matched successfully>

In [26]:
for p1, p2 in zip(model.parameters(), m2.parameters()):
    assert torch.allclose(p1, p2)