In [None]:
! pip install -Uqr requirements.txt

In [2]:
import os
import json
from typing import Dict, List

import pandas as pd
import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.nn.init import xavier_normal_, constant_
from torch.utils.data import DataLoader, Dataset
from catalyst import dl, metrics
from catalyst.utils import set_global_seed

In [3]:
class InteractionsDataset(Dataset):
    """Per-user rows of a sparse user x track interaction matrix.

    The pickle at ``interactions_pickle_path`` is loaded with ``pd.read_pickle`` and
    must provide ``user`` and ``track`` columns. Every ``(user, track)`` pair becomes
    a 1.0 entry in a sparse matrix whose row index is the user id and column index is
    the track id. Item ``idx`` of the dataset is row ``idx`` returned as a dense
    float32 tensor.

    Notes:
        * The matrix shape is inferred from the largest ids present. This presumably
          assumes ids are 0-based contiguous integers (TODO confirm against the
          preprocessing that produced the pickle).
        * Repeated ``(user, track)`` pairs are summed, so a row may contain values > 1.
    """

    def __init__(self, interactions_pickle_path: str):
        # Security: unpickling can execute arbitrary code, so only load trusted files.
        data = pd.read_pickle(interactions_pickle_path)
        users = data["user"].to_numpy()
        tracks = data["track"].to_numpy()

        indices = torch.from_numpy(np.stack((users, tracks)).astype("int64"))
        values = torch.ones(data.shape[0], dtype=torch.float32)

        # `torch.sparse.FloatTensor` is a deprecated legacy constructor. Use the
        # supported `sparse_coo_tensor` (size is still inferred from the indices).
        # Coalescing sums duplicate pairs up front, which gives the same totals that
        # `to_dense()` produced per row, and leaves the indices in canonical order.
        self.interactions = torch.sparse_coo_tensor(indices, values).coalesce()

    def __len__(self):
        # One dataset item per matrix row (user).
        return self.interactions.shape[0]

    def __getitem__(self, idx):
        # Dense vector of length num_tracks for user `idx`.
        return self.interactions[idx].to_dense()

In [4]:
def collate_fn(batch: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Stack per-user rows into the dict batch format consumed by the runner.

    Args:
        batch: equal-length dense interaction rows, one per user.

    Returns:
        ``{"inputs": X, "targets": X}``, where ``X`` is ``torch.stack(batch)``.
        The autoencoder reconstructs its own input, so targets equal inputs. The two
        entries are stacked separately so they never share storage.
    """
    return {"inputs": torch.stack(batch), "targets": torch.stack(batch)}

In [5]:
class MultiDAE(nn.Module):
    """Denoising autoencoder over per-user item-interaction vectors (Mult-DAE style).

    Args:
        p_dims: decoder layer sizes, from the latent size to the number of items
            (e.g. ``[50, 300, n_items]``).
        q_dims: encoder layer sizes, from the number of items to the latent size.
            Defaults to ``p_dims`` reversed (a symmetric autoencoder).
        dropout: dropout probability applied to the L2-normalised input.

    The network is the chain ``q_dims + p_dims[1:]`` of ``nn.Linear`` layers, with
    ``tanh`` between consecutive layers and no activation after the last one.
    ``forward`` therefore returns raw logits, one per item.
    """

    def __init__(self, p_dims, q_dims=None, dropout=0.5):
        super().__init__()
        self.p_dims = p_dims
        if q_dims:
            assert q_dims[0] == p_dims[-1], "In and Out dimensions must equal to each other"
            assert q_dims[-1] == p_dims[0], "Latent dimension for p- and q- network mismatches."
            self.q_dims = q_dims
        else:
            self.q_dims = p_dims[::-1]

        # Convert both halves to lists so mixing list/tuple arguments cannot raise a
        # TypeError on concatenation. The shared latent size appears only once.
        self.dims = list(self.q_dims) + list(self.p_dims[1:])
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(self.dims[:-1], self.dims[1:])]
        )
        self.drop = nn.Dropout(dropout)

        self.init_weights()

    def forward(self, input):
        """Score every item for each input row.

        Args:
            input: batch of interaction vectors, presumably ``(batch, n_items)``.
                ``F.normalize`` L2-normalises along dim 1 by default.

        Returns:
            Unnormalised logits with the width of the last layer (``p_dims[-1]``).
        """
        h = self.drop(F.normalize(input))

        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            h = layer(h)
            if i != last:
                h = torch.tanh(h)
        return h

    def init_weights(self):
        """Xavier-normal weights and zero biases for every linear layer."""
        for layer in self.layers:
            # nn.init functions already run under torch.no_grad, so there is no need
            # to go through `.data` (which also bypasses autograd version tracking).
            xavier_normal_(layer.weight)
            constant_(layer.bias, 0)

In [6]:
# Fix the global RNG seed (catalyst helper) before the model is built, so weight
# initialisation is repeatable across runs.
set_global_seed(seed=42)

In [7]:
# Length of each user's recommendation list; also the @k cutoff for the ranking metrics.
top_k: int = 50

In [8]:
# One dataset row per user. Keep this loader unshuffled: the inference cell further
# down recovers each user id from the row's position in the loader.
train_dataset = InteractionsDataset("user-track.pkl")
loaders = dict(
    train=DataLoader(train_dataset, collate_fn=collate_fn, batch_size=64),
)

In [9]:
# Each dense user row has one slot per track, so its length is the model's
# input and output width.
item_num = train_dataset[0].shape[0]
# Decoder 50 -> 300 -> item_num; the encoder defaults to the mirror image.
model = MultiDAE(p_dims=[50, 300, item_num], dropout=0.5)
# The float multi-hot targets are treated as per-item class weights, giving
# -sum_i y_i * log_softmax(logits)_i per user (soft targets need PyTorch >= 1.10).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
engine = dl.DeviceEngine()

In [10]:
# Ranking metrics @top_k computed from the model logits against the multi-hot
# targets, followed by the optimizer step on the "loss" value.
_ranking_metric_callbacks = (
    dl.NDCGCallback,
    dl.MAPCallback,
    dl.MRRCallback,
    dl.HitrateCallback,
)
callbacks = [
    metric_callback("logits", "targets", [top_k])
    for metric_callback in _ranking_metric_callbacks
]
callbacks.append(dl.OptimizerCallback("loss", accumulation_steps=1))

In [11]:
# Key wiring:
#   batch["inputs"]  -> model -> "logits"
#   batch["targets"] is compared against "logits"
#   the criterion value is stored under "loss" (the key OptimizerCallback steps on)
runner = dl.SupervisedRunner(
    input_key="inputs",
    target_key="targets",
    output_key="logits",
    loss_key="loss",
)

In [12]:
# Only a "train" loader is supplied, so the ranking metrics printed below are
# in-sample: they are measured on the same interactions the model trains on.
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    engine=engine,
    loaders=loaders,
    callbacks=callbacks,
    num_epochs=10,
    logdir="./logs",
    verbose=True,
    timeit=False,
)

1/10 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (1/10) hitrate50: 0.02771001897230745 | hitrate50/std: 0.007372049422105452 | loss: 1141.0121966796871 | loss/mean: 1141.0121966796871 | loss/std: 125.0725787333086 | lr: 0.001 | map50: 0.08959206992387767 | map50/std: 0.022168381126007605 | momentum: 0.9 | mrr50: 0.11869113146066669 | mrr50/std: 0.03510939726795003 | ndcg50: 0.044124188467860224 | ndcg50/std: 0.0108324641600768
* Epoch (1/10) 


2/10 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (2/10) hitrate50: 0.050233050251007084 | hitrate50/std: 0.017648483853364153 | loss: 1124.923616210937 | loss/mean: 1124.923616210937 | loss/std: 124.77262995455854 | lr: 0.001 | map50: 0.18661329555511463 | map50/std: 0.07074875749611184 | momentum: 0.9 | mrr50: 0.23683163228034976 | mrr50/std: 0.09030268000052903 | ndcg50: 0.0891028039574623 | ndcg50/std: 0.03267231277851972
* Epoch (2/10) 


3/10 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (3/10) hitrate50: 0.1021085259079933 | hitrate50/std: 0.017299145658003997 | loss: 1097.418998437501 | loss/mean: 1097.418998437501 | loss/std: 124.19059911747867 | lr: 0.001 | map50: 0.3822198646545411 | map50/std: 0.05858738519085576 | momentum: 0.9 | mrr50: 0.4786859772682191 | mrr50/std: 0.0778157380620313 | ndcg50: 0.1872953398942947 | ndcg50/std: 0.029209548416805356
* Epoch (3/10) 


4/10 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (4/10) hitrate50: 0.13583794407844543 | hitrate50/std: 0.017210331637224522 | loss: 1073.3517447265633 | loss/mean: 1073.3517447265633 | loss/std: 123.35381631068275 | lr: 0.001 | map50: 0.4792158036231995 | map50/std: 0.04816452479046939 | momentum: 0.9 | mrr50: 0.6063136577606201 | mrr50/std: 0.0637752595817542 | ndcg50: 0.24710329809188858 | ndcg50/std: 0.02631341272986718
* Epoch (4/10) 


5/10 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (5/10) hitrate50: 0.16234579267501836 | hitrate50/std: 0.017898487349513822 | loss: 1053.08764765625 | loss/mean: 1053.08764765625 | loss/std: 121.9131781397834 | lr: 0.001 | map50: 0.5456693675994875 | map50/std: 0.04352279809105382 | momentum: 0.9 | mrr50: 0.6992070158004763 | mrr50/std: 0.05757661023102915 | ndcg50: 0.2944851998329164 | ndcg50/std: 0.025323797205005338
* Epoch (5/10) 


6/10 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (6/10) hitrate50: 0.181679709148407 | hitrate50/std: 0.017930186100983934 | loss: 1036.9584318359375 | loss/mean: 1036.9584318359375 | loss/std: 120.65008127364702 | lr: 0.001 | map50: 0.5821064191818235 | map50/std: 0.035546333716710565 | momentum: 0.9 | mrr50: 0.7506896734237671 | mrr50/std: 0.04991274351058211 | ndcg50: 0.32676026515960677 | ndcg50/std: 0.024705740385988835
* Epoch (6/10) 


7/10 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (7/10) hitrate50: 0.20066012344360343 | hitrate50/std: 0.019723475874187043 | loss: 1021.8781378906251 | loss/mean: 1021.8781378906251 | loss/std: 119.4447459239595 | lr: 0.001 | map50: 0.6175002902984619 | map50/std: 0.0344449835259542 | momentum: 0.9 | mrr50: 0.8013053871154786 | mrr50/std: 0.045770184515017666 | ndcg50: 0.3597115365982057 | ndcg50/std: 0.025468488419576222
* Epoch (7/10) 


8/10 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (8/10) hitrate50: 0.21726163568496695 | hitrate50/std: 0.019695218638443937 | loss: 1008.5955379882814 | loss/mean: 1008.5955379882814 | loss/std: 118.1334362730233 | lr: 0.001 | map50: 0.6403780801773075 | map50/std: 0.03246477767393164 | momentum: 0.9 | mrr50: 0.830375368499756 | mrr50/std: 0.04070047766620645 | ndcg50: 0.3856836753845214 | ndcg50/std: 0.025437281313332345
* Epoch (8/10) 


9/10 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (9/10) hitrate50: 0.2291357434749602 | hitrate50/std: 0.020220849328581893 | loss: 997.4983689453128 | loss/mean: 997.4983689453128 | loss/std: 117.32356382145517 | lr: 0.001 | map50: 0.6528657202720642 | map50/std: 0.03088319662070991 | momentum: 0.9 | mrr50: 0.8486783336639406 | mrr50/std: 0.03740061134501038 | ndcg50: 0.4048163892745971 | ndcg50/std: 0.025821068642376637
* Epoch (9/10) 


10/10 * Epoch (train):   0%|          | 0/157 [00:00<?, ?it/s]

train (10/10) hitrate50: 0.2399987279415131 | hitrate50/std: 0.021566481877022512 | loss: 987.3791653320308 | loss/mean: 987.3791653320308 | loss/std: 116.1773971689306 | lr: 0.001 | map50: 0.6680257558822631 | map50/std: 0.029461707948112665 | momentum: 0.9 | mrr50: 0.8688665008544924 | mrr50/std: 0.035047296617086963 | ndcg50: 0.42308673896789556 | ndcg50/std: 0.02714243027944873
* Epoch (10/10) 
Top best models:
logs/checkpoints/train.10.pth	10.0000


In [13]:
%%time
# Write one JSON line per user: {"user": <row index>, "tracks": [top_k track ids, best first]}.
# NOTE(review): tracks the user already interacted with are NOT filtered out.
with open("recommendations.json", "w") as rf:
    # Rows arrive in dataset order (the loader is not shuffled), so a running offset
    # recovers each row's user id without assuming every batch has batch_size rows.
    user_offset = 0
    for prediction in runner.predict_loader(loader=loaders["train"]):
        logits = prediction["logits"].detach()
        # torch.topk returns the k best indices already sorted high -> low, which
        # avoids a full argsort over every track. k is clamped so a catalogue
        # smaller than top_k does not raise.
        k = min(top_k, logits.shape[-1])
        top_tracks = torch.topk(logits, k=k, dim=-1).indices.cpu().tolist()
        for i, tracks in enumerate(top_tracks):
            recommendation = {"user": user_offset + i, "tracks": tracks}
            rf.write(json.dumps(recommendation) + "\n")
        user_offset += len(top_tracks)

CPU times: user 1min 28s, sys: 750 ms, total: 1min 28s
Wall time: 1min 28s
