# MF

In [20]:
# 加载依赖
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim

from tqdm import tqdm
from utils.movielens_dataset import MovieLensDatasetWithTorch
from utils.metric import Metric
from utils.trainer import Trainer

In [21]:
# 使用的超参数
config = {
    'TRAIN_BATCH_SIZE': 8,
    'VALID_BATCH_SIZE': 4,
    'TEST_BATCH_SIZE': 4,
    'DATASET_RATIO': [0.8, 0.1, 0.1],
    'DEVICE': 'cpu',
    'NUM_WORKERS': 0,
    'EPOCH': 30,
    'NUM_FEATURE': 111,
    'POS_WEIGHT': 2,
    'LEARNING_RATE': 1e-3,
    'K': 8,
    'TASK': 'recommend',
    'USER_NUM': 611,
    'ITEM_NUM': 193610,
}

In [22]:
# 加载数据集
# MovieLens
with open('../dataset/ml-latest-small-ratings.txt', 'r', encoding='utf-8') as f:
    ml_dataset = MovieLensDatasetWithTorch(f, task='rating')
dataset_size = len(ml_dataset)
dataset_split_size = [int(dataset_size * r) for r in config['DATASET_RATIO']]
if sum(dataset_split_size) != dataset_size:
    dataset_split_size[-1] += dataset_size - sum(dataset_split_size)
train_set, valid_set, test_set = torch.utils.data.random_split(ml_dataset, dataset_split_size)

user_num, item_num = ml_dataset.user_num, ml_dataset.item_num

In [23]:
ml_dataset[610, 193609]

[(610, 193609), tensor(0.)]

In [24]:
train_loader = DataLoader(
    dataset=train_set,
    batch_size=config['TRAIN_BATCH_SIZE'],
    shuffle=True,
    num_workers=config['NUM_WORKERS']
)

valid_loader = DataLoader(
    dataset=valid_set,
    batch_size=config['VALID_BATCH_SIZE'],
    shuffle=False,
    num_workers=config['NUM_WORKERS']
)

test_loader = DataLoader(
    dataset=test_set,
    batch_size=config['TEST_BATCH_SIZE'],
    shuffle=False,
    num_workers=config['NUM_WORKERS']
)

In [25]:
# 构建模型
class MF(nn.Module):
    def __init__(self, user_num, item_num, k):
        super(MF, self).__init__()
        U = torch.zeros((user_num, k))
        I = torch.zeros((item_num, k))
        user_bias = torch.zeros((user_num, k))
        item_bias = torch.zeros((item_num, k))

        nn.init.xavier_uniform_(U)
        nn.init.xavier_uniform_(I)
        nn.init.xavier_uniform_(user_bias)
        nn.init.xavier_uniform_(item_bias)

        self.U = nn.Parameter(U, requires_grad=True)
        self.I = nn.Parameter(I, requires_grad=True)
        self.user_bias = nn.Parameter(user_bias, requires_grad=True)
        self.item_bias = nn.Parameter(item_bias, requires_grad=True)

    def forward(self, batch_data):
        user_idx_lst = batch_data[0]
        sparse_indices = batch_data[1].coalesce().indices()
        for i in range(len(user_idx_lst)):
            sparse_indices[0, :] = torch.where(sparse_indices[0, :] == i, user_idx_lst[i], sparse_indices[0, :])
        user_idx = sparse_indices[0, :]
        item_idx = sparse_indices[1, :]

        user_factor = self.U[user_idx, :]
        item_factor = self.I[item_idx, :]
        user_bias = self.user_bias[user_idx, :]
        item_bias = self.item_bias[item_idx, :]
        # print(user_factor.shape)
        # print(item_factor.shape)
        return torch.sum(user_factor * item_factor + user_bias + item_bias, dim=1)

    def rec(self, user_idx_lst, k=50):
        # print(user_idx_lst)
        return torch.topk(torch.matmul(self.U[user_idx_lst, :], self.I.t()), k=k, dim=1)

In [26]:
class MFTrainer(Trainer):
    def step(self, batch_data, mode='train', **param_for_rec):
        device = self.config['DEVICE']

        def compute_pred(batch_data):
            # print(batch_data[1].shape)
            pred = self.model(batch_data)
            loss = self.loss_func(pred, batch_data[1].coalesce().values())
            gt_rec_lst = param_for_rec['gt_rec_lst']
            pred_rec_lst = param_for_rec['pred_rec_lst']
            if mode == 'evaluate':
                user_idx_lst = batch_data[0]
                sparse_indices = batch_data[1].coalesce().indices()
                for i in range(len(user_idx_lst)):
                    tmp_idx = torch.nonzero(sparse_indices[0, :] == i).view(-1)
                    gt_rec_lst.extend(sparse_indices[1, tmp_idx].tolist())
                pred_rec_lst.extend((self.model.rec(user_idx_lst)[1] + 1).tolist())
                self.metric.compute_metric(gt_rec_lst, pred_rec_lst, 200000, len(gt_rec_lst))

            return pred, loss, gt_rec_lst, pred_rec_lst

        if mode == 'train':
            self.model.train()
            self.optimizer.zero_grad()
            pred, loss = compute_pred(batch_data)
            loss.backward()
            self.optimizer.step()
            return loss.item(), pred
        elif mode == 'evaluate':
            with torch.no_grad():
                self.model.eval()
                pred, loss, gt_rec_lst, pred_rec_lst = compute_pred(batch_data)

                return loss.item(), pred, gt_rec_lst, pred_rec_lst
        else:
            raise ValueError("Wrong Mode")

In [27]:
model = MF(user_num, item_num, k=config['K'])
optimizer = optim.Adam(lr=config['LEARNING_RATE'], params=model.parameters())
loss_func = nn.MSELoss()
metric = Metric(k=(1, 10, 50))

trainer = Trainer(
    model=model,
    loss_func=loss_func,
    optimizer=optimizer,
    metric=metric,
    train_loader=train_loader,
    valid_loader=valid_loader,
    test_loader=test_loader,
    config=config,
)

TRAIN_BATCH_SIZE: 8
VALID_BATCH_SIZE: 4
TEST_BATCH_SIZE: 4
DATASET_RATIO: [0.8, 0.1, 0.1]
DEVICE: cpu
NUM_WORKERS: 0
EPOCH: 30
NUM_FEATURE: 111
POS_WEIGHT: 2
LEARNING_RATE: 0.001
K: 8
TASK: recommend
USER_NUM: 611
ITEM_NUM: 193610


In [28]:
if __name__ == '__main__':
    trainer.train()
    trainer.test()



100%|██████████| 61/61 [00:00<00:00, 114.16it/s]


Train Epoch: 1
Loss: 12.914455585792416


100%|██████████| 16/16 [00:00<00:00, 319.88it/s]


Valid Epoch: 1
hit:
	hit@1: 0.0492
	hit@10: 0.2623
	hit@50: 0.5574
gini_index:
	gini_index@1: 0.1157
	gini_index@10: 0.2734
	gini_index@50: 0.3429
diversity:
	diversity@1: 0.9951
	diversity@10: 0.9785
	diversity@50: 0.9656
nDCG:
	nDCG@1: 0.0492
	nDCG@10: 0.1418
	nDCG@50: 0.1983
MRR:
	MRR@1: 0.0492
	MRR@10: 0.1076
	MRR@50: 0.1211



100%|██████████| 61/61 [00:00<00:00, 111.77it/s]


Train Epoch: 2
Loss: 10.890722290414278


100%|██████████| 16/16 [00:00<00:00, 289.37it/s]


Valid Epoch: 2
hit:
	hit@1: 0.0492
	hit@10: 0.2295
	hit@50: 0.4590
gini_index:
	gini_index@1: 0.0876
	gini_index@10: 0.2650
	gini_index@50: 0.3368
diversity:
	diversity@1: 0.9967
	diversity@10: 0.9802
	diversity@50: 0.9676
nDCG:
	nDCG@1: 0.0492
	nDCG@10: 0.1310
	nDCG@50: 0.1715
MRR:
	MRR@1: 0.0492
	MRR@10: 0.1007
	MRR@50: 0.1102



100%|██████████| 61/61 [00:00<00:00, 105.85it/s]


Train Epoch: 3
Loss: 9.005512221914824


100%|██████████| 16/16 [00:00<00:00, 264.64it/s]


Valid Epoch: 3
hit:
	hit@1: 0.0656
	hit@10: 0.2459
	hit@50: 0.5246
gini_index:
	gini_index@1: 0.0627
	gini_index@10: 0.2416
	gini_index@50: 0.3220
diversity:
	diversity@1: 0.9973
	diversity@10: 0.9843
	diversity@50: 0.9713
nDCG:
	nDCG@1: 0.0656
	nDCG@10: 0.1310
	nDCG@50: 0.1927
MRR:
	MRR@1: 0.0656
	MRR@10: 0.1120
	MRR@50: 0.1235



100%|██████████| 61/61 [00:00<00:00, 109.39it/s]


Train Epoch: 4
Loss: 7.6024624558745835


100%|██████████| 16/16 [00:00<00:00, 293.64it/s]


Valid Epoch: 4
hit:
	hit@1: 0.1148
	hit@10: 0.2295
	hit@50: 0.5410
gini_index:
	gini_index@1: 0.1060
	gini_index@10: 0.2347
	gini_index@50: 0.3263
diversity:
	diversity@1: 0.9951
	diversity@10: 0.9847
	diversity@50: 0.9699
nDCG:
	nDCG@1: 0.1148
	nDCG@10: 0.1555
	nDCG@50: 0.2225
MRR:
	MRR@1: 0.1148
	MRR@10: 0.1512
	MRR@50: 0.1682



100%|██████████| 61/61 [00:00<00:00, 113.44it/s]


Train Epoch: 5
Loss: 6.3585824301985445


100%|██████████| 16/16 [00:00<00:00, 306.21it/s]


Valid Epoch: 5
hit:
	hit@1: 0.1311
	hit@10: 0.3770
	hit@50: 0.5246
gini_index:
	gini_index@1: 0.1157
	gini_index@10: 0.2474
	gini_index@50: 0.3382
diversity:
	diversity@1: 0.9951
	diversity@10: 0.9837
	diversity@50: 0.9671
nDCG:
	nDCG@1: 0.1311
	nDCG@10: 0.2138
	nDCG@50: 0.2419
MRR:
	MRR@1: 0.1311
	MRR@10: 0.1855
	MRR@50: 0.1912



100%|██████████| 61/61 [00:00<00:00, 114.11it/s]


Train Epoch: 6
Loss: 5.286051621202563


100%|██████████| 16/16 [00:00<00:00, 297.78it/s]


Valid Epoch: 6
hit:
	hit@1: 0.1311
	hit@10: 0.3770
	hit@50: 0.5246
gini_index:
	gini_index@1: 0.1626
	gini_index@10: 0.2596
	gini_index@50: 0.3548
diversity:
	diversity@1: 0.9918
	diversity@10: 0.9802
	diversity@50: 0.9605
nDCG:
	nDCG@1: 0.1311
	nDCG@10: 0.2405
	nDCG@50: 0.2612
MRR:
	MRR@1: 0.1311
	MRR@10: 0.1947
	MRR@50: 0.2006



100%|██████████| 61/61 [00:00<00:00, 105.72it/s]


Train Epoch: 7
Loss: 4.46349814289906


100%|██████████| 16/16 [00:00<00:00, 273.34it/s]


Valid Epoch: 7
hit:
	hit@1: 0.1475
	hit@10: 0.3770
	hit@50: 0.4754
gini_index:
	gini_index@1: 0.1779
	gini_index@10: 0.2883
	gini_index@50: 0.3732
diversity:
	diversity@1: 0.9896
	diversity@10: 0.9736
	diversity@50: 0.9540
nDCG:
	nDCG@1: 0.1475
	nDCG@10: 0.2418
	nDCG@50: 0.2611
MRR:
	MRR@1: 0.1475
	MRR@10: 0.1989
	MRR@50: 0.2046



100%|██████████| 61/61 [00:00<00:00, 102.68it/s]


Train Epoch: 8
Loss: 3.721196170713081


100%|██████████| 16/16 [00:00<00:00, 245.97it/s]


Valid Epoch: 8
hit:
	hit@1: 0.1148
	hit@10: 0.3934
	hit@50: 0.5082
gini_index:
	gini_index@1: 0.1577
	gini_index@10: 0.3107
	gini_index@50: 0.3900
diversity:
	diversity@1: 0.9918
	diversity@10: 0.9673
	diversity@50: 0.9472
nDCG:
	nDCG@1: 0.1148
	nDCG@10: 0.2507
	nDCG@50: 0.2771
MRR:
	MRR@1: 0.1148
	MRR@10: 0.1956
	MRR@50: 0.2009



100%|██████████| 61/61 [00:00<00:00, 103.96it/s]


Train Epoch: 9
Loss: 3.1941740727815473


100%|██████████| 16/16 [00:00<00:00, 257.28it/s]


Valid Epoch: 9
hit:
	hit@1: 0.0984
	hit@10: 0.3770
	hit@50: 0.4918
gini_index:
	gini_index@1: 0.2356
	gini_index@10: 0.3212
	gini_index@50: 0.3966
diversity:
	diversity@1: 0.9836
	diversity@10: 0.9648
	diversity@50: 0.9444
nDCG:
	nDCG@1: 0.0984
	nDCG@10: 0.2448
	nDCG@50: 0.2745
MRR:
	MRR@1: 0.0984
	MRR@10: 0.1885
	MRR@50: 0.1956



100%|██████████| 61/61 [00:00<00:00, 104.94it/s]


Train Epoch: 10
Loss: 2.6780152711711946


100%|██████████| 16/16 [00:00<00:00, 261.48it/s]


Valid Epoch: 10
hit:
	hit@1: 0.1148
	hit@10: 0.3934
	hit@50: 0.4918
gini_index:
	gini_index@1: 0.2642
	gini_index@10: 0.3172
	gini_index@50: 0.4069
diversity:
	diversity@1: 0.9765
	diversity@10: 0.9649
	diversity@50: 0.9404
nDCG:
	nDCG@1: 0.1148
	nDCG@10: 0.2552
	nDCG@50: 0.2816
MRR:
	MRR@1: 0.1148
	MRR@10: 0.2037
	MRR@50: 0.2111



100%|██████████| 61/61 [00:00<00:00, 107.79it/s]


Train Epoch: 11
Loss: 2.3549357887174263


100%|██████████| 16/16 [00:00<00:00, 254.59it/s]


Valid Epoch: 11
hit:
	hit@1: 0.1311
	hit@10: 0.4098
	hit@50: 0.4918
gini_index:
	gini_index@1: 0.2448
	gini_index@10: 0.3222
	gini_index@50: 0.4050
diversity:
	diversity@1: 0.9749
	diversity@10: 0.9631
	diversity@50: 0.9400
nDCG:
	nDCG@1: 0.1311
	nDCG@10: 0.2663
	nDCG@50: 0.2853
MRR:
	MRR@1: 0.1311
	MRR@10: 0.2151
	MRR@50: 0.2211



100%|██████████| 61/61 [00:00<00:00, 103.96it/s]


Train Epoch: 12
Loss: 2.038222604110593


100%|██████████| 16/16 [00:00<00:00, 257.43it/s]


Valid Epoch: 12
hit:
	hit@1: 0.1311
	hit@10: 0.4262
	hit@50: 0.5082
gini_index:
	gini_index@1: 0.2448
	gini_index@10: 0.3341
	gini_index@50: 0.4032
diversity:
	diversity@1: 0.9749
	diversity@10: 0.9609
	diversity@50: 0.9407
nDCG:
	nDCG@1: 0.1311
	nDCG@10: 0.2664
	nDCG@50: 0.2875
MRR:
	MRR@1: 0.1311
	MRR@10: 0.2205
	MRR@50: 0.2253



100%|██████████| 61/61 [00:00<00:00, 100.54it/s]


Train Epoch: 13
Loss: 1.7851169207057014


100%|██████████| 16/16 [00:00<00:00, 242.95it/s]


Valid Epoch: 13
hit:
	hit@1: 0.1311
	hit@10: 0.4426
	hit@50: 0.5082
gini_index:
	gini_index@1: 0.2493
	gini_index@10: 0.3413
	gini_index@50: 0.4051
diversity:
	diversity@1: 0.9743
	diversity@10: 0.9597
	diversity@50: 0.9404
nDCG:
	nDCG@1: 0.1311
	nDCG@10: 0.2714
	nDCG@50: 0.2875
MRR:
	MRR@1: 0.1311
	MRR@10: 0.2204
	MRR@50: 0.2240



100%|██████████| 61/61 [00:00<00:00, 100.52it/s]


Train Epoch: 14
Loss: 1.6469305235831464


100%|██████████| 16/16 [00:00<00:00, 218.89it/s]


Valid Epoch: 14
hit:
	hit@1: 0.1311
	hit@10: 0.3770
	hit@50: 0.5082
gini_index:
	gini_index@1: 0.2135
	gini_index@10: 0.3381
	gini_index@50: 0.4004
diversity:
	diversity@1: 0.9765
	diversity@10: 0.9613
	diversity@50: 0.9428
nDCG:
	nDCG@1: 0.1311
	nDCG@10: 0.2461
	nDCG@50: 0.2814
MRR:
	MRR@1: 0.1311
	MRR@10: 0.2085
	MRR@50: 0.2174



100%|██████████| 61/61 [00:00<00:00, 100.86it/s]


Train Epoch: 15
Loss: 1.4714358273099681


100%|██████████| 16/16 [00:00<00:00, 247.28it/s]


Valid Epoch: 15
hit:
	hit@1: 0.1475
	hit@10: 0.3607
	hit@50: 0.5082
gini_index:
	gini_index@1: 0.2135
	gini_index@10: 0.3325
	gini_index@50: 0.3966
diversity:
	diversity@1: 0.9765
	diversity@10: 0.9628
	diversity@50: 0.9451
nDCG:
	nDCG@1: 0.1475
	nDCG@10: 0.2403
	nDCG@50: 0.2776
MRR:
	MRR@1: 0.1475
	MRR@10: 0.2120
	MRR@50: 0.2218



100%|██████████| 61/61 [00:00<00:00, 102.98it/s]


Train Epoch: 16
Loss: 1.3561536142083466


100%|██████████| 16/16 [00:00<00:00, 238.54it/s]


Valid Epoch: 16
hit:
	hit@1: 0.1475
	hit@10: 0.3607
	hit@50: 0.5082
gini_index:
	gini_index@1: 0.2426
	gini_index@10: 0.3307
	gini_index@50: 0.3920
diversity:
	diversity@1: 0.9738
	diversity@10: 0.9626
	diversity@50: 0.9470
nDCG:
	nDCG@1: 0.1475
	nDCG@10: 0.2402
	nDCG@50: 0.2732
MRR:
	MRR@1: 0.1475
	MRR@10: 0.2114
	MRR@50: 0.2218



100%|██████████| 61/61 [00:00<00:00, 97.12it/s]


Train Epoch: 17
Loss: 1.2571158565458704


100%|██████████| 16/16 [00:00<00:00, 251.95it/s]


Valid Epoch: 17
hit:
	hit@1: 0.1475
	hit@10: 0.3607
	hit@50: 0.5082
gini_index:
	gini_index@1: 0.2176
	gini_index@10: 0.3291
	gini_index@50: 0.3887
diversity:
	diversity@1: 0.9749
	diversity@10: 0.9646
	diversity@50: 0.9493
nDCG:
	nDCG@1: 0.1475
	nDCG@10: 0.2370
	nDCG@50: 0.2657
MRR:
	MRR@1: 0.1475
	MRR@10: 0.2112
	MRR@50: 0.2204



100%|██████████| 61/61 [00:00<00:00, 97.09it/s]


Train Epoch: 18
Loss: 1.1733797554109917


100%|██████████| 16/16 [00:00<00:00, 255.86it/s]


Valid Epoch: 18
hit:
	hit@1: 0.1311
	hit@10: 0.3279
	hit@50: 0.5082
gini_index:
	gini_index@1: 0.2190
	gini_index@10: 0.3215
	gini_index@50: 0.3837
diversity:
	diversity@1: 0.9732
	diversity@10: 0.9668
	diversity@50: 0.9523
nDCG:
	nDCG@1: 0.1311
	nDCG@10: 0.2120
	nDCG@50: 0.2518
MRR:
	MRR@1: 0.1311
	MRR@10: 0.1962
	MRR@50: 0.2069



100%|██████████| 61/61 [00:00<00:00, 100.15it/s]


Train Epoch: 19
Loss: 1.1107181408366218


100%|██████████| 16/16 [00:00<00:00, 237.70it/s]


Valid Epoch: 19
hit:
	hit@1: 0.1148
	hit@10: 0.3279
	hit@50: 0.5246
gini_index:
	gini_index@1: 0.2156
	gini_index@10: 0.3157
	gini_index@50: 0.3783
diversity:
	diversity@1: 0.9776
	diversity@10: 0.9697
	diversity@50: 0.9540
nDCG:
	nDCG@1: 0.1148
	nDCG@10: 0.2035
	nDCG@50: 0.2448
MRR:
	MRR@1: 0.1148
	MRR@10: 0.1746
	MRR@50: 0.1835



100%|██████████| 61/61 [00:00<00:00, 96.61it/s]


Train Epoch: 20
Loss: 1.0490530817235102


100%|██████████| 16/16 [00:00<00:00, 244.45it/s]


Valid Epoch: 20
hit:
	hit@1: 0.0984
	hit@10: 0.3279
	hit@50: 0.5246
gini_index:
	gini_index@1: 0.1998
	gini_index@10: 0.3117
	gini_index@50: 0.3780
diversity:
	diversity@1: 0.9814
	diversity@10: 0.9710
	diversity@50: 0.9550
nDCG:
	nDCG@1: 0.0984
	nDCG@10: 0.1884
	nDCG@50: 0.2372
MRR:
	MRR@1: 0.0984
	MRR@10: 0.1550
	MRR@50: 0.1637



100%|██████████| 61/61 [00:00<00:00, 99.23it/s] 


Train Epoch: 21
Loss: 1.00379444243478


100%|██████████| 16/16 [00:00<00:00, 248.81it/s]


Valid Epoch: 21
hit:
	hit@1: 0.0820
	hit@10: 0.3279
	hit@50: 0.5410
gini_index:
	gini_index@1: 0.1397
	gini_index@10: 0.3147
	gini_index@50: 0.3787
diversity:
	diversity@1: 0.9902
	diversity@10: 0.9719
	diversity@50: 0.9552
nDCG:
	nDCG@1: 0.0820
	nDCG@10: 0.1754
	nDCG@50: 0.2327
MRR:
	MRR@1: 0.0820
	MRR@10: 0.1344
	MRR@50: 0.1434



100%|██████████| 61/61 [00:00<00:00, 99.57it/s] 


Train Epoch: 22
Loss: 0.9551085214145848


100%|██████████| 16/16 [00:00<00:00, 234.21it/s]


Valid Epoch: 22
hit:
	hit@1: 0.0656
	hit@10: 0.3115
	hit@50: 0.5410
gini_index:
	gini_index@1: 0.1466
	gini_index@10: 0.3196
	gini_index@50: 0.3780
diversity:
	diversity@1: 0.9923
	diversity@10: 0.9717
	diversity@50: 0.9548
nDCG:
	nDCG@1: 0.0656
	nDCG@10: 0.1568
	nDCG@50: 0.2212
MRR:
	MRR@1: 0.0656
	MRR@10: 0.1164
	MRR@50: 0.1263



100%|██████████| 61/61 [00:00<00:00, 102.23it/s]


Train Epoch: 23
Loss: 0.9448880133081655


100%|██████████| 16/16 [00:00<00:00, 230.37it/s]


Valid Epoch: 23
hit:
	hit@1: 0.0328
	hit@10: 0.2787
	hit@50: 0.5246
gini_index:
	gini_index@1: 0.1711
	gini_index@10: 0.3258
	gini_index@50: 0.3735
diversity:
	diversity@1: 0.9913
	diversity@10: 0.9701
	diversity@50: 0.9556
nDCG:
	nDCG@1: 0.0328
	nDCG@10: 0.1392
	nDCG@50: 0.2085
MRR:
	MRR@1: 0.0328
	MRR@10: 0.0950
	MRR@50: 0.1065



100%|██████████| 61/61 [00:00<00:00, 101.29it/s]


Train Epoch: 24
Loss: 0.8885579402329492


100%|██████████| 16/16 [00:00<00:00, 234.27it/s]


Valid Epoch: 24
hit:
	hit@1: 0.0328
	hit@10: 0.2623
	hit@50: 0.5246
gini_index:
	gini_index@1: 0.1779
	gini_index@10: 0.3240
	gini_index@50: 0.3748
diversity:
	diversity@1: 0.9896
	diversity@10: 0.9714
	diversity@50: 0.9548
nDCG:
	nDCG@1: 0.0328
	nDCG@10: 0.1386
	nDCG@50: 0.2067
MRR:
	MRR@1: 0.0328
	MRR@10: 0.1000
	MRR@50: 0.1127



100%|██████████| 61/61 [00:00<00:00, 99.10it/s] 


Train Epoch: 25
Loss: 0.8799901506939872


100%|██████████| 16/16 [00:00<00:00, 233.97it/s]


Valid Epoch: 25
hit:
	hit@1: 0.0328
	hit@10: 0.2787
	hit@50: 0.5246
gini_index:
	gini_index@1: 0.1679
	gini_index@10: 0.3307
	gini_index@50: 0.3763
diversity:
	diversity@1: 0.9913
	diversity@10: 0.9700
	diversity@50: 0.9543
nDCG:
	nDCG@1: 0.0328
	nDCG@10: 0.1394
	nDCG@50: 0.2031
MRR:
	MRR@1: 0.0328
	MRR@10: 0.0973
	MRR@50: 0.1082



100%|██████████| 61/61 [00:00<00:00, 100.32it/s]


Train Epoch: 26
Loss: 0.8364635715719129


100%|██████████| 16/16 [00:00<00:00, 262.36it/s]


Valid Epoch: 26
hit:
	hit@1: 0.0492
	hit@10: 0.1967
	hit@50: 0.5246
gini_index:
	gini_index@1: 0.1827
	gini_index@10: 0.3254
	gini_index@50: 0.3753
diversity:
	diversity@1: 0.9896
	diversity@10: 0.9707
	diversity@50: 0.9546
nDCG:
	nDCG@1: 0.0492
	nDCG@10: 0.1146
	nDCG@50: 0.2006
MRR:
	MRR@1: 0.0492
	MRR@10: 0.0910
	MRR@50: 0.1085



100%|██████████| 61/61 [00:00<00:00, 99.36it/s] 


Train Epoch: 27
Loss: 0.8180449976295722


100%|██████████| 16/16 [00:00<00:00, 238.62it/s]


Valid Epoch: 27
hit:
	hit@1: 0.0328
	hit@10: 0.1967
	hit@50: 0.5246
gini_index:
	gini_index@1: 0.2259
	gini_index@10: 0.3337
	gini_index@50: 0.3763
diversity:
	diversity@1: 0.9847
	diversity@10: 0.9687
	diversity@50: 0.9546
nDCG:
	nDCG@1: 0.0328
	nDCG@10: 0.1088
	nDCG@50: 0.1945
MRR:
	MRR@1: 0.0328
	MRR@10: 0.0813
	MRR@50: 0.0985



100%|██████████| 61/61 [00:00<00:00, 100.42it/s]


Train Epoch: 28
Loss: 0.7909796394285609


100%|██████████| 16/16 [00:00<00:00, 210.60it/s]


Valid Epoch: 28
hit:
	hit@1: 0.0328
	hit@10: 0.1967
	hit@50: 0.5246
gini_index:
	gini_index@1: 0.2448
	gini_index@10: 0.3202
	gini_index@50: 0.3757
diversity:
	diversity@1: 0.9814
	diversity@10: 0.9710
	diversity@50: 0.9547
nDCG:
	nDCG@1: 0.0328
	nDCG@10: 0.1056
	nDCG@50: 0.1920
MRR:
	MRR@1: 0.0328
	MRR@10: 0.0796
	MRR@50: 0.0967



100%|██████████| 61/61 [00:00<00:00, 97.38it/s] 


Train Epoch: 29
Loss: 0.7705813797770954


100%|██████████| 16/16 [00:00<00:00, 252.65it/s]


Valid Epoch: 29
hit:
	hit@1: 0.0164
	hit@10: 0.2131
	hit@50: 0.5246
gini_index:
	gini_index@1: 0.2120
	gini_index@10: 0.3135
	gini_index@50: 0.3720
diversity:
	diversity@1: 0.9863
	diversity@10: 0.9719
	diversity@50: 0.9549
nDCG:
	nDCG@1: 0.0164
	nDCG@10: 0.1075
	nDCG@50: 0.1894
MRR:
	MRR@1: 0.0164
	MRR@10: 0.0730
	MRR@50: 0.0882



100%|██████████| 61/61 [00:00<00:00, 97.33it/s] 


Train Epoch: 30
Loss: 0.7711418886653713


100%|██████████| 16/16 [00:00<00:00, 260.20it/s]


Valid Epoch: 30
hit:
	hit@1: 0.0328
	hit@10: 0.2131
	hit@50: 0.5082
gini_index:
	gini_index@1: 0.2208
	gini_index@10: 0.3166
	gini_index@50: 0.3709
diversity:
	diversity@1: 0.9858
	diversity@10: 0.9712
	diversity@50: 0.9549
nDCG:
	nDCG@1: 0.0328
	nDCG@10: 0.1144
	nDCG@50: 0.1892
MRR:
	MRR@1: 0.0328
	MRR@10: 0.0846
	MRR@50: 0.0997



100%|██████████| 16/16 [00:00<00:00, 250.61it/s]


Test Loss: 2.5238191708922386
hit:
	hit@1: 0.1290
	hit@10: 0.2581
	hit@50: 0.3387
gini_index:
	gini_index@1: 0.2309
	gini_index@10: 0.3276
	gini_index@50: 0.3660
diversity:
	diversity@1: 0.9857
	diversity@10: 0.9710
	diversity@50: 0.9562
nDCG:
	nDCG@1: 0.1290
	nDCG@10: 0.1771
	nDCG@50: 0.1775
MRR:
	MRR@1: 0.1290
	MRR@10: 0.1645
	MRR@50: 0.1677



In [29]:
model.rec([0, 1])

torch.return_types.topk(
values=tensor([[0.1160, 0.1099, 0.1079, 0.1066, 0.1062, 0.1055, 0.1046, 0.1045, 0.1043,
         0.1030, 0.1030, 0.1023, 0.1022, 0.1005, 0.1001, 0.0999, 0.0994, 0.0986,
         0.0984, 0.0980, 0.0977, 0.0975, 0.0968, 0.0967, 0.0964, 0.0962, 0.0962,
         0.0959, 0.0957, 0.0956, 0.0956, 0.0955, 0.0953, 0.0952, 0.0952, 0.0949,
         0.0949, 0.0943, 0.0942, 0.0942, 0.0941, 0.0941, 0.0940, 0.0939, 0.0938,
         0.0936, 0.0933, 0.0933, 0.0933, 0.0932],
        [0.7879, 0.7620, 0.7618, 0.7144, 0.7096, 0.7091, 0.7006, 0.6990, 0.6962,
         0.6919, 0.6863, 0.6770, 0.6751, 0.6740, 0.6729, 0.6695, 0.6673, 0.6663,
         0.6657, 0.6643, 0.6625, 0.6603, 0.6598, 0.6594, 0.6591, 0.6589, 0.6585,
         0.6582, 0.6571, 0.6542, 0.6537, 0.6501, 0.6492, 0.6477, 0.6440, 0.6439,
         0.6430, 0.6429, 0.6428, 0.6424, 0.6418, 0.6416, 0.6411, 0.6397, 0.6393,
         0.6390, 0.6385, 0.6381, 0.6380, 0.6374]], grad_fn=<TopkBackward0>),
indices=tensor([[  1394,   1272