# бейзлайны:

In [1]:
import pandas as pd
import numpy as np
import torch
import pickle

from metrics import ( 
    calculate_grouped_ndcg_random, 
    calculate_grouped_ndcg_sum_popularity,
    calculate_grouped_ndcg_with_embeddings,
    calculate_grouped_ndcg_for_bert4rec_output
)

In [2]:
rnames = ['user_id', 'movie_id', 'rating', 'timestamp']
ratings = pd.read_table('movielens_1m_dataset/ratings.dat', sep='::', header=None, names=rnames, engine='python', encoding='ISO-8859-1')

In [3]:
ratings

Unnamed: 0,user_id,movie_id,rating,timestamp
0,1,1193,5,978300760
1,1,661,3,978302109
2,1,914,3,978301968
3,1,3408,4,978300275
4,1,2355,5,978824291
...,...,...,...,...
1000204,6040,1091,1,956716541
1000205,6040,1094,5,956704887
1000206,6040,562,5,956704746
1000207,6040,1096,4,956715648


In [4]:
ratings.sort_values(["user_id", "timestamp"], inplace=True)

ratings = ratings[["user_id", "movie_id", "rating"]]

ratings.reset_index(drop=True, inplace=True)

ratings = ratings.astype(int)

ratings

Unnamed: 0,user_id,movie_id,rating
0,1,3186,4
1,1,1270,5
2,1,1721,4
3,1,1022,5
4,1,2340,3
...,...,...,...
1000204,6040,2917,4
1000205,6040,1921,4
1000206,6040,1784,3
1000207,6040,161,3


In [5]:
ratings.user_id.nunique(), ratings.user_id.min(), ratings.user_id.max()

(6040, 1, 6040)

In [6]:
ratings.movie_id.nunique(), ratings.movie_id.min(), ratings.movie_id.max()

(3706, 1, 3952)

In [7]:
ratings.rating = 1

ratings

Unnamed: 0,user_id,movie_id,rating
0,1,3186,1
1,1,1270,1
2,1,1721,1
3,1,1022,1
4,1,2340,1
...,...,...,...
1000204,6040,2917,1
1000205,6040,1921,1
1000206,6040,1784,1
1000207,6040,161,1


In [8]:
# ratings = ratings[ratings.groupby("movie_id")['user_id'].transform('count') >= 5]
# дабы подстроиться под берт. потом вернем
ratings = ratings[ratings.groupby("user_id")['movie_id'].transform('count') >= 5]
ratings = ratings[ratings.groupby("movie_id")['user_id'].transform('count') >= 0]

ratings.reset_index(drop=True, inplace=True)

ratings

Unnamed: 0,user_id,movie_id,rating
0,1,3186,1
1,1,1270,1
2,1,1721,1
3,1,1022,1
4,1,2340,1
...,...,...,...
1000204,6040,2917,1
1000205,6040,1921,1
1000206,6040,1784,1
1000207,6040,161,1


In [9]:
print("statistics:", ratings.user_id.nunique(), ratings.movie_id.nunique())

statistics: 6040 3706


In [10]:
def create_index_mapping(df):
    """
    Создает словарь, отображающий айдишники юзеров и айтемов в
    небольшие различные числа (соответствующие индексам в enumerate)
    Зачем это нужно: implicit ALS хочет спарс матрицу для обучения, а размер scipy спарс
    матрицы зависит от максимального значения айдишника. Применение такого отображения к данным
    позволит работать с матрицами значительно меньших размеров.
    Parameters
        df: датасет, с колонок которого создаются словари
    Returns
        u2ix, i2ix: вышеописанные словари
    """
    u2ix = {user_id: i for i, user_id in enumerate(df.user_id.unique())}
    i2ix = {item_id: i for i, item_id in enumerate(df.movie_id.unique())}
    return u2ix, i2ix

def apply_index_mapping(df, u2ix, i2ix) -> pd.DataFrame:
    """
    Применяет вышеописанные словари к колонкам "user_id" и "movie_id".
    Parameters
        df: датасет, к которому применяется маппинг по словарям
        u2ix: словарь "user_id": "небольшое число"
        i2ix: словарь "movie_id": "небольшое число"
    Returns
        df: датасет с измененными (согласно u2ix и i2ix) значениями в колонках
    """
    df.user_id = df.user_id.map(lambda x: u2ix[x])
    df.movie_id = df.movie_id.map(lambda x: i2ix[x])
    return df

u2ix, i2ix = create_index_mapping(ratings)
ratings = apply_index_mapping(ratings, u2ix, i2ix)

In [11]:
ratings.user_id.min(), ratings.user_id.max()

(0, 6039)

In [12]:
ratings.movie_id.min(), ratings.movie_id.max()

(0, 3705)

leave-last-out:

In [13]:
test = ratings.groupby("user_id").tail(1)

test

Unnamed: 0,user_id,movie_id,rating
52,0,52,1
181,1,174,1
232,2,207,1
253,3,88,1
451,4,384,1
...,...,...,...
999522,6035,927,1
999724,6036,684,1
999744,6037,1587,1
999867,6038,579,1


In [14]:
train_and_val = ratings.drop(test.index, axis=0)

train_and_val

Unnamed: 0,user_id,movie_id,rating
0,0,0,1
1,0,1,1
2,0,2,1
3,0,3,1
4,0,4,1
...,...,...,...
1000203,6039,1097,1
1000204,6039,1248,1
1000205,6039,370,1
1000206,6039,89,1


In [15]:
val = train_and_val.groupby("user_id").tail(1)

val

Unnamed: 0,user_id,movie_id,rating
51,0,51,1
180,1,173,1
231,2,206,1
252,3,217,1
450,4,383,1
...,...,...,...
999521,6035,3356,1
999723,6036,244,1
999743,6037,252,1
999866,6038,562,1


In [16]:
train = train_and_val.drop(val.index, axis=0)

train

Unnamed: 0,user_id,movie_id,rating
0,0,0,1
1,0,1,1
2,0,2,1
3,0,3,1
4,0,4,1
...,...,...,...
1000202,6039,180,1
1000203,6039,1097,1
1000204,6039,1248,1
1000205,6039,370,1


In [17]:
train.reset_index(drop=True, inplace=True)
val.reset_index(drop=True, inplace=True)
test.reset_index(drop=True, inplace=True)

train_and_val.reset_index(drop=True, inplace=True) # нужно только для бейзлайнов и алс, мы там ниче не смотрим на val и не тюним

In [18]:
# так, берту не нужно время, ток порядок
train_dict = dict(train.groupby('user_id').apply(lambda d: list(d['movie_id'])))
val_dict = dict(val.groupby('user_id').apply(lambda d: list(d['movie_id'])))
test_dict = dict(test.groupby('user_id').apply(lambda d: list(d['movie_id'])))

In [19]:
dataset_for_bert4rec = {'train': train_dict,
                       'val': val_dict,
                       'test': test_dict,
                       'umap': u2ix,
                       'smap': i2ix}

In [20]:
# import pickle

# with open('prepared_dataset/dataset_for_bert4rec.pickle', 'wb') as handle:
#     pickle.dump(dataset_for_bert4rec, handle, protocol=pickle.HIGHEST_PROTOCOL)

... продолжаем:

In [21]:
train_and_val

Unnamed: 0,user_id,movie_id,rating
0,0,0,1
1,0,1,1
2,0,2,1
3,0,3,1
4,0,4,1
...,...,...,...
994164,6039,1097,1
994165,6039,1248,1
994166,6039,370,1
994167,6039,89,1


In [22]:
test

Unnamed: 0,user_id,movie_id,rating
0,0,52,1
1,1,174,1
2,2,207,1
3,3,88,1
4,4,384,1
...,...,...,...
6035,6035,927,1
6036,6036,684,1
6037,6037,1587,1
6038,6038,579,1


In [23]:
user_positively_interacted_with = ratings.groupby('user_id')['movie_id'].apply(set).to_dict()

In [24]:
assert len(user_positively_interacted_with[1]) == len(train_and_val[train_and_val.user_id == 1]) + 1

In [25]:
item_count = ratings["movie_id"].value_counts().to_dict()
item_probabilities = {k: v / sum([x for x in item_count.values()]) for k, v in item_count.items()}

In [26]:
item_probabilities[62]

0.003427283697707179

In [27]:
item_probabilities[731]

0.00035592561154718663

In [28]:
import numpy as np
from numpy.random import choice

negative_samples = dict()
test_rows = []

for _, row in test.iterrows():
    test_interactions = [(row["user_id"], row["movie_id"], row["rating"],)]
    
    np.random.seed(row["user_id"])
    negative_sampled_interactions = list(choice(list(item_probabilities.keys()), 800, replace=False, p=list(item_probabilities.values())))
    negative_sampled_interactions = [x for x in negative_sampled_interactions if x not in user_positively_interacted_with[row["user_id"]]]
    negative_sampled_interactions = negative_sampled_interactions[:100]
    
    # потом, когда будем делать берт, будем ссылаться на сгенерированные здесь негативы:
    negative_samples[row["user_id"]] = negative_sampled_interactions

    test_interactions.extend([(row["user_id"], x, 0,) for x in negative_sampled_interactions])
    
    test_rows.extend(test_interactions)
    
test_rows[:10]

[(0, 52, 1),
 (0, 1068, 0),
 (0, 731, 0),
 (0, 1026, 0),
 (0, 314, 0),
 (0, 1148, 0),
 (0, 1008, 0),
 (0, 1030, 0),
 (0, 1513, 0),
 (0, 2575, 0)]

In [29]:
negative_samples[0][:9]

[1068, 731, 1026, 314, 1148, 1008, 1030, 1513, 2575]

In [30]:
neg_sampled_test = pd.DataFrame(test_rows, columns=["user_id", "movie_id", "rating"])

neg_sampled_test

Unnamed: 0,user_id,movie_id,rating
0,0,52,1
1,0,1068,0
2,0,731,0
3,0,1026,0
4,0,314,0
...,...,...,...
610035,6039,442,0
610036,6039,581,0
610037,6039,147,0
610038,6039,1327,0


### random baseline:

In [29]:
ndcgresults = []
for i in range(10):
    ndcg_value = calculate_grouped_ndcg_random(train_and_val, neg_sampled_test, 10, i)
    
    ndcgresults.append(ndcg_value)
    
np.mean(ndcgresults), np.std(ndcgresults)

(0.04438110400529842, 0.001780072420999918)

### pop baseline:

In [30]:
calculate_grouped_ndcg_sum_popularity(train_and_val, neg_sampled_test, 10)

0.06391453558404456

### iALS baseline:

In [31]:
import scipy.sparse as sparse


def create_sparse_matrix(df: pd.DataFrame) -> sparse.csr_matrix:
    """
    Делает разреженную матрицу из пандас датафрейма. Нужно для
    обучения implicit.als.AlternatingLeastSquares.
    Parameters
        df (pd.DataFrame): датасет
    Returns
        csr (sparse.csr_matrix): разреженная матрица
    """
    csr = sparse.csr_matrix((df.rating, (df.user_id, df.movie_id)))
    return csr

sparse_train_and_val = create_sparse_matrix(train_and_val)

In [32]:
import implicit

In [33]:
iALS = implicit.als.AlternatingLeastSquares(factors=256, iterations=150)

  check_blas_config()


In [34]:
iALS.fit(sparse_train_and_val)

  0%|          | 0/150 [00:00<?, ?it/s]

In [35]:
calculate_grouped_ndcg_with_embeddings(neg_sampled_test, iALS, 10)

0.2209271961839956

In [36]:
BPRMF = implicit.bpr.BayesianPersonalizedRanking(factors=256, iterations=150)

In [37]:
BPRMF.fit(sparse_train_and_val)

  0%|          | 0/150 [00:00<?, ?it/s]

In [38]:
calculate_grouped_ndcg_with_embeddings(neg_sampled_test, BPRMF, 10)

0.2718568152928149

# начинаем берт4рек:

In [31]:
import argparse

import pandas as pd
import pickle
import random
import torch

from tqdm import trange
from collections import Counter

import numpy as np
from numpy.random import choice

In [32]:
# fix argparse in ipython
import sys
sys.argv = ['']

In [33]:
from bert4rec_modules_and_configs.utils import *
from bert4rec_modules_and_configs.options import args

In [34]:
# то что сохранили

def read_data(prepared_data_path):
    with open(prepared_data_path, 'rb') as handle:
        dataset = pickle.load(handle)
    return dataset

In [35]:
data = read_data(args.prepared_data_path)

train_data = data['train']
val_data = data['val']
test_data = data['test']
umap = data['umap']
smap = data['smap']

In [36]:
# отслеживаем путь юзера с индексом 0 (оригинальный айди 1)
print(train_data[0])
print(val_data[0])
print(test_data[0])
print(umap[1])

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50]
[51]
[52]
0


готовим трейн датасет и даталоадер:

In [37]:
class BertTrainDataset(torch.utils.data.Dataset):
    def __init__(self, u2seq, max_len, mask_prob, mask_token, num_items, rng):
        self.u2seq = u2seq
        self.users = sorted(self.u2seq.keys())  # дополнительный "надежный" (с заданным порядком) способ ходить по train_data, в отличие от хождения по train_data по ключу
        self.max_len = max_len
        self.mask_prob = mask_prob
        print("self.mask_prob", self.mask_prob)
        self.mask_token = mask_token
        print("self.mask_token", self.mask_token)
        self.num_items = num_items
        self.rng = rng

    def __len__(self):
        return len(self.users)

    def __getitem__(self, index):
        user = self.users[index]
        seq = self._getseq(user)

        tokens = []
        labels = []
        for s in seq: # касательно каждого s заполняем токен и лейбл
            prob = self.rng.random()  # сгенерили для конкретного s вероятность uniform, deterministic
            if prob < self.mask_prob:  # если prob меньше self.mask_prob == 0.15: TLDR С ВЕРОЯТНОСТЬЮ 80% ЗАПОЛНИМ ТОКЕН МАСКТОКЕНОМ, А НЕ ЭТИМ ITEM INDEX
                prob /= self.mask_prob  # то сильно бустим, затем...

                # if prob < 0.8:
                if prob < 0.99:
                    tokens.append(self.mask_token)  # если оч слабо то заполняем токен масктокеном - который max(item !INDEX!) + 1 или self.item_count + 1
                # if prob < 0.9:
                elif prob < 0.999:
                    tokens.append(self.rng.randint(1, self.num_items))  # если попал в маленькое окошко - то рандомно между 1 и self.item_count (== self.num_items) НАХУЯ?
                else:
                    tokens.append(s) # emergency(?) вариант, но лэйбл всё равно заполнится ненормальным значением - он заполнится ЭТИМ ITEM INDEX

                labels.append(s)
            else:
                tokens.append(s)
                labels.append(0)  # !!!ЭТО СТРАННО ВЕДЬ СУЩНОСТЬ S [0, MAX INDEX]!!!

        # в итоге каждому s присвоится либо (s, 0) если не замаскирован, либо (mask_token<под вопросом>, s). окей, для чего? !!!ЭТО СТРАННО ВЕДЬ СУЩНОСТЬ S [0, MAX INDEX]!!!

        # видимо max_len это не длина окна подпоследовательности из всей последовательности (если использование такой подпоследовательности ваще в этой реализации будет как-то фигурировать), а то, раньше чего мы 100% забываем
        tokens = tokens[-self.max_len:]
        labels = labels[-self.max_len:]
        # был 140 - стал 100
        # был 90 - стал 90

        # тупа паддинг, если изначальный seq был меньше max_len
        mask_len = self.max_len - len(tokens)

        # прилепляем нули слева. почему нули???? !!!ЭТО СТРАННО ВЕДЬ СУЩНОСТЬ S [0, MAX INDEX]!!!
        tokens = [0] * mask_len + tokens
        labels = [0] * mask_len + labels

        # лол ща попробуем маскировать с 10% вероятностью последний айтем
        prob = self.rng.random()
        if prob < 0.1:
            tokens[-1] = self.mask_token
            labels[-1] = seq[-1]

        # в итоге каждому s присвоится либо (s, 0) если не замаскирован, либо (mask_token<под вопросом><уже не под вопросом>, s) + в начале будут прилеплены нули, если не дотягивает до max_len, или иначе обрублено [-max_len:]. окей, для чего? !!!ЭТО СТРАННО ВЕДЬ СУЩНОСТЬ S [0, MAX INDEX]!!!
        # tokens    0   0   0   2969, 1574,   957,    1178,   <3707>, 1658,   <3707>, 1117
        # labels    0   0   0   0     0       0       0       2147    0       3177    0
        return torch.LongTensor(tokens), torch.LongTensor(labels)

        # КАК ЕЩЁ ЧАСТО ДЕЛАЕТСЯ И КАК ДЕЛАЛОСЬ В СТАТЬЕ: БЕРЕТСЯ СЛУЧАЙНО 10% ЮЗЕРСКИХ ПОСЛЕДОВАТЕЛЬНОСТЕЙ ИЗ ТРЕЙНА И ДЛЯ НИХ ТОЛЬКО ПОСЛЕДНИЙ ТОКЕН ЗАМЕНЯЕТСЯ МАСКОЙ. мы так не делаем, в представленной реализации которую я разбираю такого случайного выбора 10% из трейна (для навешивания маски только на конец) нет.
        # теперь я тоже так делаю
        
    def _getseq(self, user):
        return self.u2seq[user]

In [38]:
train_torch_dataset = BertTrainDataset(
    u2seq=      train_data,
    max_len=    args.bert_max_len,
    mask_prob=  args.bert_mask_prob,
    mask_token= len(smap) + 1,
    num_items=  len(smap),
    rng=        random.Random(args.dataloader_random_seed)
)

self.mask_prob 0.15
self.mask_token 3707


In [39]:
# отслеживаем путь юзера с индексом 0 (оригинальный айди 1)
print(train_torch_dataset[0])

(tensor([   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    1,    2,    3,    4,    5,    6,    7,    8,    9,   10,
          11,   12,   13,   14,   15,   16,   17,   18,   19,   20,   21,   22,
          23,   24, 3707,   26,   27,   28,   29,   30,   31,   32,   33,   34,
        3707,   36,   37,   38,   39, 3707,   41,   42,   43,   44,   45,   46,
          47,   48,   49,   50]), tensor([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
     

In [40]:
train_torch_dataloader = torch.utils.data.DataLoader(
    dataset=train_torch_dataset,
    batch_size=args.train_batch_size,
    shuffle=True,
    pin_memory=True
)

готовим валидационные и тестовые датасет и даталоадер:

прежде чем сделать эти датасеты, сгенерируем негативы для юзеров, чтобы у них был не только один true positive на отложенной выборке, а ещё сотня негативов - тогда сможем замерить валидац метрики. семплируем по популярности (согласно статье)

In [41]:
# popularity = Counter()
# for user in range(len(umap)):
#     popularity.update(train_data[user])
#     popularity.update(val_data[user])
#     popularity.update(test_data[user])
# item_probabilities = {k: v / sum([x for x in popularity.values()]) for k, v in popularity.items()}

# print("PROB CHECK")
# print(item_probabilities[62])
# print(item_probabilities[731])

# negative_samples = {}
# print('Sampling negative items')
# for user in trange(len(umap)):
#     seen = set(train_data[user])
#     seen.update(val_data[user])
#     seen.update(test_data[user])

#     np.random.seed(user)

#     negative_sampled_interactions = list(choice(list(item_probabilities.keys()), 800, replace=False, p=list(item_probabilities.values())))
#     negative_sampled_interactions = [x for x in negative_sampled_interactions if x not in seen]
#     negative_sampled_interactions = negative_sampled_interactions[:100]

#     negative_samples[user] = negative_sampled_interactions
###

# взяли negative_samples, сделанные на этапе построения бейзлайнов, для 100% идентичности замера ndcg@10.

In [42]:
# отслеживаем путь юзера с индексом 0 (оригинальный айди 1)
print(negative_samples[0])

[1068, 731, 1026, 314, 1148, 1008, 1030, 1513, 2575, 388, 398, 525, 1438, 1682, 215, 258, 129, 2122, 1954, 1869, 2824, 859, 748, 983, 183, 1436, 233, 1803, 1182, 874, 1305, 2056, 752, 844, 971, 1065, 671, 2549, 2153, 325, 981, 1766, 60, 880, 1028, 557, 144, 948, 1302, 3160, 743, 424, 796, 947, 180, 1191, 458, 534, 386, 1110, 622, 208, 592, 2055, 739, 1215, 2751, 489, 2863, 502, 1095, 73, 138, 580, 522, 443, 699, 125, 653, 390, 832, 193, 1906, 284, 663, 147, 1022, 241, 597, 1294, 1776, 1688, 809, 1260, 1529, 738, 652, 631, 665, 203]


In [43]:
class BertEvalDataset(torch.utils.data.Dataset):
    def __init__(self, u2seq, u2answer, max_len, mask_token, negative_samples):
        self.u2seq = u2seq
        self.users = sorted(self.u2seq.keys()) # дополнительный "надежный" (с заданным порядком) способ ходить по train_data, в отличие от хождения по train_data по ключу
        self.u2answer = u2answer
        self.max_len = max_len
        self.mask_token = mask_token
        self.negative_samples = negative_samples

    def __len__(self):
        return len(self.users)

    def __getitem__(self, index):
        user = self.users[index]
        seq = self.u2seq[user]
        answer = self.u2answer[user]
        negs = self.negative_samples[user]

        candidates = answer + negs  # ну понятно, [228] + [142, 1488, 0, ...]
        labels = [1] * len(answer) + [0] * len(negs)  # сказали что первое позитив, остальное негативы

        seq = seq + [self.mask_token]  # прилепили <3707> в конец. теперь хорошо обученная модель сможет угадать его label

        seq = seq[-self.max_len:]  # эти три строки - всё так же как в трейне
        padding_len = self.max_len - len(seq)
        seq = [0] * padding_len + seq

        return torch.LongTensor(seq), torch.LongTensor(candidates), torch.LongTensor(labels)
        # например:
        # seq (в сущности как tokens в трейне)  0   0   2969,   1574,   957,    1178,   2147,   1658,   3177,   1117, <3707>
        #
        # candidates: [индекс айтема реального позитива; сгенерированные негативы]
        # labels:     [1, 0, 0, 0, 0, ...]

In [44]:
val_torch_dataset = BertEvalDataset(
    u2seq=              train_data,
    u2answer=           val_data,
    max_len=            args.bert_max_len,
    mask_token=         len(smap) + 1,
    negative_samples=   negative_samples
)

In [45]:
# отслеживаем путь юзера с индексом 0 (оригинальный айди 1)
print(val_torch_dataset[0])

(tensor([   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    1,    2,    3,    4,    5,    6,    7,    8,    9,   10,   11,
          12,   13,   14,   15,   16,   17,   18,   19,   20,   21,   22,   23,
          24,   25,   26,   27,   28,   29,   30,   31,   32,   33,   34,   35,
          36,   37,   38,   39,   40,   41,   42,   43,   44,   45,   46,   47,
          48,   49,   50, 3707]), tensor([  51, 1068,  731, 1026,  314, 1148, 1008, 1030, 1513, 2575,  388,  398,
         525, 1438, 1682,  215,  258,  129, 2122, 1954, 1869, 2824,  859,  748,
         983,  183, 1436,  233, 1803, 1182,  874, 1305, 2056,  752,  844,  971,
        1065,  671, 2549, 2153,  325,  981, 1766,   60,  880, 1028,  557,  144,
     

In [46]:
val_torch_dataloader = torch.utils.data.DataLoader(val_torch_dataset, batch_size=args.val_batch_size,
                                       shuffle=False, pin_memory=True)

In [47]:
test_torch_dataset = BertEvalDataset(
        u2seq=              train_data,
        u2answer=           test_data,
        max_len=            args.bert_max_len,
        mask_token=         len(smap) + 1,
        negative_samples=   negative_samples
    )

In [48]:
# отслеживаем путь юзера с индексом 0 (оригинальный айди 1)
print(test_torch_dataset[0])

(tensor([   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    1,    2,    3,    4,    5,    6,    7,    8,    9,   10,   11,
          12,   13,   14,   15,   16,   17,   18,   19,   20,   21,   22,   23,
          24,   25,   26,   27,   28,   29,   30,   31,   32,   33,   34,   35,
          36,   37,   38,   39,   40,   41,   42,   43,   44,   45,   46,   47,
          48,   49,   50, 3707]), tensor([  52, 1068,  731, 1026,  314, 1148, 1008, 1030, 1513, 2575,  388,  398,
         525, 1438, 1682,  215,  258,  129, 2122, 1954, 1869, 2824,  859,  748,
         983,  183, 1436,  233, 1803, 1182,  874, 1305, 2056,  752,  844,  971,
        1065,  671, 2549, 2153,  325,  981, 1766,   60,  880, 1028,  557,  144,
     

In [49]:
test_torch_dataloader = torch.utils.data.DataLoader(test_torch_dataset, batch_size=args.test_batch_size,
                                                       shuffle=False, pin_memory=True)

In [50]:
args.num_items = len(smap)

описываем модель:

In [53]:
from torch import nn as nn

from bert4rec_modules_and_configs.models.bert_modules.embedding import BERTEmbedding
from bert4rec_modules_and_configs.models.bert_modules.transformer import TransformerBlock


class BERT(nn.Module):
    def __init__(self, args):
        super().__init__()

        fix_random_seed_as(args.model_init_seed)
        # self.init_weights()

        max_len = args.bert_max_len
        num_items = args.num_items
        n_layers = args.bert_num_blocks
        heads = args.bert_num_heads
        vocab_size = num_items + 2
        hidden = args.bert_hidden_units
        self.hidden = hidden
        dropout = args.bert_dropout

        # embedding for BERT, sum of positional, segment, token embeddings
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=self.hidden, max_len=max_len, dropout=dropout)

        # multi-layers transformer blocks, deep network
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(hidden, heads, hidden * 4, dropout) for _ in range(n_layers)])

    def forward(self, x):
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)

        # embedding the indexed sequence to sequence of vectors
        x = self.embedding(x)

        # running over multiple transformer blocks
        for transformer in self.transformer_blocks:
            x = transformer.forward(x, mask)

        return x

    def init_weights(self):
        pass

In [54]:
class BERTModel(nn.Module):
    def __init__(self, args):
        super().__init__()
        
        self.args = args
        
        self.bert = BERT(args)
        self.out = nn.Linear(self.bert.hidden, args.num_items + 1)

    def code(cls):
        return 'bert'

    def forward(self, x):
        x = self.bert(x)
        return self.out(x)

In [55]:
# from bert4rec_modules_and_configs.models.bert import BERTModel

In [56]:
model = BERTModel(args)

In [57]:
# отслеживаем путь юзера с индексом 0 (оригинальный айди 1)
some_batch = next(iter(train_torch_dataloader))
seq, _ = some_batch
print(seq.shape)
print(seq[0, ...].unsqueeze(0).shape)
print(model(seq[0, ...].unsqueeze(0)).shape)

torch.Size([128, 100])
torch.Size([1, 100])
torch.Size([1, 100, 3707])


In [58]:
from bert4rec_modules_and_configs.trainers import trainer_factory

export_root = setup_train(args)

trainer = trainer_factory(args, model, train_torch_dataloader, val_torch_dataloader, test_torch_dataloader, export_root)

Folder created: C:\Users\atama\Documents\GitHub\bert4rec_ysda_seminar\seminar\experiments\test_2024-03-11_4
experiments\test_2024-03-11_4
experiments\test_2024-03-11_4\config.json

Namespace(mode='train', template=None, test_model_path=None, dataset_code='ml-1m', min_rating=0, min_uc=5, min_sc=0, split='leave_one_out', dataset_split_seed=98765, eval_set_size=500, prepared_data_path='prepared_dataset/dataset_for_bert4rec.pickle', dataloader_code='bert', dataloader_random_seed=0.0, train_batch_size=128, val_batch_size=128, test_batch_size=128, train_negative_sampler_code='random', train_negative_sample_size=0, train_negative_sampling_seed=0, test_negative_sampler_code='popular', test_negative_sample_size=100, test_negative_sampling_seed=98765, trainer_code='bert', device='cuda', num_gpu=1, device_idx='0', optimizer='Adam', lr=0.001, weight_decay=0, momentum=None, decay_step=25, gamma=1.0, num_epochs=100, log_period_as_iter=12800, metric_ks=[1, 5, 10, 20, 50, 100], best_metric='NDCG@10', 

In [59]:
trainer.train()

Val: N@1 0.007, N@5 0.028, N@10 0.045, R@1 0.007, R@5 0.050, R@10 0.102: 100%|██████████| 48/48 [00:04<00:00,  9.85it/s]


Update Best NDCG@10 Model at 1


Epoch 1, loss 7.829 : 100%|██████████| 48/48 [00:02<00:00, 18.39it/s]
Val: N@1 0.021, N@5 0.052, N@10 0.074, R@1 0.021, R@5 0.084, R@10 0.154: 100%|██████████| 48/48 [00:02<00:00, 21.99it/s]


Update Best NDCG@10 Model at 1


Epoch 2, loss 7.526 : 100%|██████████| 48/48 [00:02<00:00, 18.78it/s]
Val: N@1 0.025, N@5 0.066, N@10 0.089, R@1 0.025, R@5 0.107, R@10 0.180: 100%|██████████| 48/48 [00:02<00:00, 22.05it/s]


Update Best NDCG@10 Model at 2


Epoch 3, loss 7.369 : 100%|██████████| 48/48 [00:02<00:00, 18.60it/s] 
Val: N@1 0.038, N@5 0.094, N@10 0.125, R@1 0.038, R@5 0.150, R@10 0.248: 100%|██████████| 48/48 [00:02<00:00, 22.01it/s]


Update Best NDCG@10 Model at 3


Epoch 4, loss 7.182 : 100%|██████████| 48/48 [00:02<00:00, 18.77it/s]
Val: N@1 0.053, N@5 0.122, N@10 0.160, R@1 0.053, R@5 0.193, R@10 0.312: 100%|██████████| 48/48 [00:02<00:00, 22.20it/s]


Update Best NDCG@10 Model at 4


Epoch 5, loss 6.986 : 100%|██████████| 48/48 [00:02<00:00, 18.80it/s]  
Val: N@1 0.070, N@5 0.163, N@10 0.204, R@1 0.070, R@5 0.255, R@10 0.383: 100%|██████████| 48/48 [00:02<00:00, 21.57it/s]


Update Best NDCG@10 Model at 5


Epoch 6, loss 6.835 : 100%|██████████| 48/48 [00:02<00:00, 18.63it/s]
Val: N@1 0.095, N@5 0.203, N@10 0.249, R@1 0.095, R@5 0.306, R@10 0.450: 100%|██████████| 48/48 [00:02<00:00, 22.03it/s]


Update Best NDCG@10 Model at 6


Epoch 7, loss 6.678 : 100%|██████████| 48/48 [00:02<00:00, 18.61it/s]  
Val: N@1 0.108, N@5 0.228, N@10 0.277, R@1 0.108, R@5 0.342, R@10 0.494: 100%|██████████| 48/48 [00:02<00:00, 21.59it/s]


Update Best NDCG@10 Model at 7


Epoch 8, loss 6.564 : 100%|██████████| 48/48 [00:02<00:00, 18.34it/s]
Val: N@1 0.125, N@5 0.243, N@10 0.288, R@1 0.125, R@5 0.357, R@10 0.497: 100%|██████████| 48/48 [00:02<00:00, 21.18it/s]


Update Best NDCG@10 Model at 8


Epoch 9, loss 6.443 : 100%|██████████| 48/48 [00:02<00:00, 18.67it/s]  
Val: N@1 0.142, N@5 0.269, N@10 0.317, R@1 0.142, R@5 0.390, R@10 0.539: 100%|██████████| 48/48 [00:02<00:00, 21.99it/s]


Update Best NDCG@10 Model at 9


Epoch 10, loss 6.359 : 100%|██████████| 48/48 [00:02<00:00, 18.72it/s]
Val: N@1 0.149, N@5 0.282, N@10 0.327, R@1 0.149, R@5 0.409, R@10 0.550: 100%|██████████| 48/48 [00:02<00:00, 22.02it/s]


Update Best NDCG@10 Model at 10


Epoch 11, loss 6.285 : 100%|██████████| 48/48 [00:02<00:00, 18.87it/s] 
Val: N@1 0.151, N@5 0.283, N@10 0.329, R@1 0.151, R@5 0.408, R@10 0.548: 100%|██████████| 48/48 [00:02<00:00, 22.01it/s]


Update Best NDCG@10 Model at 11


Epoch 12, loss 6.206 : 100%|██████████| 48/48 [00:02<00:00, 18.62it/s]
Val: N@1 0.174, N@5 0.310, N@10 0.356, R@1 0.174, R@5 0.438, R@10 0.578: 100%|██████████| 48/48 [00:02<00:00, 21.79it/s]


Update Best NDCG@10 Model at 12


Epoch 13, loss 6.142 : 100%|██████████| 48/48 [00:02<00:00, 18.43it/s] 
Val: N@1 0.173, N@5 0.312, N@10 0.357, R@1 0.173, R@5 0.443, R@10 0.581: 100%|██████████| 48/48 [00:02<00:00, 21.91it/s]


Update Best NDCG@10 Model at 13


Epoch 14, loss 6.098 : 100%|██████████| 48/48 [00:02<00:00, 18.68it/s]
Val: N@1 0.176, N@5 0.314, N@10 0.360, R@1 0.176, R@5 0.442, R@10 0.585: 100%|██████████| 48/48 [00:02<00:00, 22.02it/s]


Update Best NDCG@10 Model at 14


Epoch 15, loss 6.040 : 100%|██████████| 48/48 [00:02<00:00, 18.61it/s] 
Val: N@1 0.192, N@5 0.328, N@10 0.374, R@1 0.192, R@5 0.452, R@10 0.593: 100%|██████████| 48/48 [00:02<00:00, 22.03it/s]


Update Best NDCG@10 Model at 15


Epoch 16, loss 5.996 : 100%|██████████| 48/48 [00:02<00:00, 18.79it/s]
Val: N@1 0.192, N@5 0.339, N@10 0.382, R@1 0.192, R@5 0.474, R@10 0.609: 100%|██████████| 48/48 [00:02<00:00, 21.51it/s]


Update Best NDCG@10 Model at 16


Epoch 17, loss 5.969 : 100%|██████████| 48/48 [00:02<00:00, 18.49it/s] 
Val: N@1 0.191, N@5 0.338, N@10 0.382, R@1 0.191, R@5 0.475, R@10 0.611: 100%|██████████| 48/48 [00:02<00:00, 21.44it/s]


Update Best NDCG@10 Model at 17


Epoch 18, loss 5.921 : 100%|██████████| 48/48 [00:02<00:00, 18.15it/s]
Val: N@1 0.201, N@5 0.348, N@10 0.392, R@1 0.201, R@5 0.484, R@10 0.619: 100%|██████████| 48/48 [00:02<00:00, 20.29it/s]


Update Best NDCG@10 Model at 18


Epoch 19, loss 5.867 : 100%|██████████| 48/48 [00:02<00:00, 18.50it/s]
Val: N@1 0.211, N@5 0.360, N@10 0.401, R@1 0.211, R@5 0.497, R@10 0.626: 100%|██████████| 48/48 [00:02<00:00, 22.01it/s]


Update Best NDCG@10 Model at 19


Epoch 20, loss 5.853 : 100%|██████████| 48/48 [00:02<00:00, 18.77it/s]
Val: N@1 0.206, N@5 0.356, N@10 0.399, R@1 0.206, R@5 0.493, R@10 0.625: 100%|██████████| 48/48 [00:02<00:00, 21.69it/s]
Epoch 21, loss 5.814 : 100%|██████████| 48/48 [00:02<00:00, 18.35it/s]
Val: N@1 0.216, N@5 0.366, N@10 0.411, R@1 0.216, R@5 0.504, R@10 0.640: 100%|██████████| 48/48 [00:02<00:00, 22.09it/s]


Update Best NDCG@10 Model at 21


Epoch 22, loss 5.775 : 100%|██████████| 48/48 [00:02<00:00, 18.51it/s] 
Val: N@1 0.218, N@5 0.368, N@10 0.410, R@1 0.218, R@5 0.506, R@10 0.635: 100%|██████████| 48/48 [00:02<00:00, 22.10it/s]
Epoch 23, loss 5.755 : 100%|██████████| 48/48 [00:02<00:00, 18.72it/s]
Val: N@1 0.216, N@5 0.362, N@10 0.406, R@1 0.216, R@5 0.496, R@10 0.631: 100%|██████████| 48/48 [00:02<00:00, 21.90it/s]
Epoch 24, loss 5.730 : 100%|██████████| 48/48 [00:02<00:00, 18.77it/s] 
Val: N@1 0.223, N@5 0.371, N@10 0.412, R@1 0.223, R@5 0.507, R@10 0.634: 100%|██████████| 48/48 [00:02<00:00, 21.64it/s]


Update Best NDCG@10 Model at 24


Epoch 25, loss 5.693 : 100%|██████████| 48/48 [00:02<00:00, 18.78it/s]
Val: N@1 0.223, N@5 0.375, N@10 0.417, R@1 0.223, R@5 0.514, R@10 0.645: 100%|██████████| 48/48 [00:02<00:00, 21.85it/s]


Update Best NDCG@10 Model at 25


Epoch 26, loss 5.681 : 100%|██████████| 48/48 [00:02<00:00, 18.87it/s] 
Val: N@1 0.223, N@5 0.375, N@10 0.418, R@1 0.223, R@5 0.514, R@10 0.644: 100%|██████████| 48/48 [00:02<00:00, 21.47it/s]


Update Best NDCG@10 Model at 26


Epoch 27, loss 5.651 : 100%|██████████| 48/48 [00:02<00:00, 18.56it/s]
Val: N@1 0.230, N@5 0.383, N@10 0.427, R@1 0.230, R@5 0.521, R@10 0.656: 100%|██████████| 48/48 [00:02<00:00, 21.96it/s]


Update Best NDCG@10 Model at 27


Epoch 28, loss 5.645 : 100%|██████████| 48/48 [00:02<00:00, 18.76it/s] 
Val: N@1 0.232, N@5 0.387, N@10 0.428, R@1 0.232, R@5 0.530, R@10 0.656: 100%|██████████| 48/48 [00:02<00:00, 21.55it/s]


Update Best NDCG@10 Model at 28


Epoch 29, loss 5.616 : 100%|██████████| 48/48 [00:02<00:00, 18.63it/s]
Val: N@1 0.239, N@5 0.387, N@10 0.427, R@1 0.239, R@5 0.524, R@10 0.647: 100%|██████████| 48/48 [00:02<00:00, 21.86it/s]
Epoch 30, loss 5.593 : 100%|██████████| 48/48 [00:02<00:00, 18.72it/s] 
Val: N@1 0.238, N@5 0.389, N@10 0.431, R@1 0.238, R@5 0.527, R@10 0.657: 100%|██████████| 48/48 [00:02<00:00, 22.07it/s]


Update Best NDCG@10 Model at 30


Epoch 31, loss 5.566 : 100%|██████████| 48/48 [00:02<00:00, 18.63it/s]
Val: N@1 0.245, N@5 0.392, N@10 0.433, R@1 0.245, R@5 0.526, R@10 0.653: 100%|██████████| 48/48 [00:02<00:00, 21.86it/s]


Update Best NDCG@10 Model at 31


Epoch 32, loss 5.556 : 100%|██████████| 48/48 [00:02<00:00, 18.79it/s] 
Val: N@1 0.236, N@5 0.389, N@10 0.429, R@1 0.236, R@5 0.527, R@10 0.650: 100%|██████████| 48/48 [00:02<00:00, 21.94it/s]
Epoch 33, loss 5.546 : 100%|██████████| 48/48 [00:02<00:00, 18.85it/s]
Val: N@1 0.248, N@5 0.400, N@10 0.442, R@1 0.248, R@5 0.536, R@10 0.666: 100%|██████████| 48/48 [00:02<00:00, 21.84it/s]


Update Best NDCG@10 Model at 33


Epoch 34, loss 5.522 : 100%|██████████| 48/48 [00:02<00:00, 18.63it/s] 
Val: N@1 0.239, N@5 0.392, N@10 0.433, R@1 0.239, R@5 0.532, R@10 0.659: 100%|██████████| 48/48 [00:02<00:00, 21.41it/s]
Epoch 35, loss 5.502 : 100%|██████████| 48/48 [00:02<00:00, 18.29it/s]
Val: N@1 0.251, N@5 0.403, N@10 0.443, R@1 0.251, R@5 0.540, R@10 0.665: 100%|██████████| 48/48 [00:02<00:00, 21.94it/s]


Update Best NDCG@10 Model at 35


Epoch 36, loss 5.489 : 100%|██████████| 48/48 [00:02<00:00, 18.66it/s]
Val: N@1 0.247, N@5 0.402, N@10 0.442, R@1 0.247, R@5 0.542, R@10 0.665: 100%|██████████| 48/48 [00:02<00:00, 21.74it/s]
Epoch 37, loss 5.481 : 100%|██████████| 48/48 [00:02<00:00, 18.77it/s]
Val: N@1 0.247, N@5 0.398, N@10 0.440, R@1 0.247, R@5 0.532, R@10 0.663: 100%|██████████| 48/48 [00:02<00:00, 21.96it/s]
Epoch 38, loss 5.456 : 100%|██████████| 48/48 [00:02<00:00, 18.68it/s]
Val: N@1 0.242, N@5 0.399, N@10 0.440, R@1 0.242, R@5 0.542, R@10 0.668: 100%|██████████| 48/48 [00:02<00:00, 21.70it/s]
Epoch 39, loss 5.456 : 100%|██████████| 48/48 [00:02<00:00, 18.66it/s]
Val: N@1 0.251, N@5 0.403, N@10 0.444, R@1 0.251, R@5 0.542, R@10 0.668: 100%|██████████| 48/48 [00:02<00:00, 22.15it/s]


Update Best NDCG@10 Model at 39


Epoch 40, loss 5.429 : 100%|██████████| 48/48 [00:02<00:00, 18.89it/s]
Val: N@1 0.249, N@5 0.406, N@10 0.445, R@1 0.249, R@5 0.549, R@10 0.671: 100%|██████████| 48/48 [00:02<00:00, 21.90it/s]


Update Best NDCG@10 Model at 40


Epoch 41, loss 5.439 : 100%|██████████| 48/48 [00:02<00:00, 18.83it/s] 
Val: N@1 0.250, N@5 0.404, N@10 0.444, R@1 0.250, R@5 0.544, R@10 0.667: 100%|██████████| 48/48 [00:02<00:00, 21.93it/s]
Epoch 42, loss 5.432 : 100%|██████████| 48/48 [00:02<00:00, 18.66it/s]
Val: N@1 0.251, N@5 0.399, N@10 0.443, R@1 0.251, R@5 0.535, R@10 0.669: 100%|██████████| 48/48 [00:02<00:00, 21.89it/s]
Epoch 43, loss 5.399 : 100%|██████████| 48/48 [00:02<00:00, 18.91it/s] 
Val: N@1 0.253, N@5 0.405, N@10 0.449, R@1 0.253, R@5 0.542, R@10 0.678: 100%|██████████| 48/48 [00:02<00:00, 22.12it/s]


Update Best NDCG@10 Model at 43


Epoch 44, loss 5.395 : 100%|██████████| 48/48 [00:02<00:00, 18.69it/s]
Val: N@1 0.247, N@5 0.400, N@10 0.444, R@1 0.247, R@5 0.540, R@10 0.678: 100%|██████████| 48/48 [00:02<00:00, 21.74it/s]
Epoch 45, loss 5.384 : 100%|██████████| 48/48 [00:02<00:00, 18.50it/s] 
Val: N@1 0.255, N@5 0.406, N@10 0.450, R@1 0.255, R@5 0.543, R@10 0.676: 100%|██████████| 48/48 [00:02<00:00, 21.99it/s]


Update Best NDCG@10 Model at 45


Epoch 46, loss 5.370 : 100%|██████████| 48/48 [00:02<00:00, 18.87it/s]
Val: N@1 0.260, N@5 0.413, N@10 0.454, R@1 0.260, R@5 0.551, R@10 0.680: 100%|██████████| 48/48 [00:02<00:00, 22.06it/s]


Update Best NDCG@10 Model at 46


Epoch 47, loss 5.361 : 100%|██████████| 48/48 [00:02<00:00, 18.82it/s] 
Val: N@1 0.256, N@5 0.409, N@10 0.450, R@1 0.256, R@5 0.548, R@10 0.675: 100%|██████████| 48/48 [00:02<00:00, 21.73it/s]
Epoch 48, loss 5.355 : 100%|██████████| 48/48 [00:02<00:00, 18.74it/s]
Val: N@1 0.266, N@5 0.418, N@10 0.458, R@1 0.266, R@5 0.558, R@10 0.682: 100%|██████████| 48/48 [00:02<00:00, 21.78it/s]


Update Best NDCG@10 Model at 48


Epoch 49, loss 5.330 : 100%|██████████| 48/48 [00:02<00:00, 18.76it/s] 
Val: N@1 0.262, N@5 0.410, N@10 0.453, R@1 0.262, R@5 0.545, R@10 0.675: 100%|██████████| 48/48 [00:02<00:00, 21.92it/s]
Epoch 50, loss 5.329 : 100%|██████████| 48/48 [00:02<00:00, 18.78it/s]
Val: N@1 0.264, N@5 0.414, N@10 0.455, R@1 0.264, R@5 0.548, R@10 0.673: 100%|██████████| 48/48 [00:02<00:00, 22.04it/s]
Epoch 51, loss 5.312 : 100%|██████████| 48/48 [00:02<00:00, 18.77it/s] 
Val: N@1 0.254, N@5 0.404, N@10 0.448, R@1 0.254, R@5 0.539, R@10 0.674: 100%|██████████| 48/48 [00:02<00:00, 21.88it/s]
Epoch 52, loss 5.312 : 100%|██████████| 48/48 [00:02<00:00, 18.70it/s]
Val: N@1 0.259, N@5 0.409, N@10 0.449, R@1 0.259, R@5 0.542, R@10 0.668: 100%|██████████| 48/48 [00:02<00:00, 21.93it/s]
Logging to Tensorboard: 100%|██████████| 48/48 [00:02<00:00, 18.73it/s]
Val: N@1 0.260, N@5 0.411, N@10 0.452, R@1 0.260, R@5 0.549, R@10 0.675: 100%|██████████| 48/48 [00:02<00:00, 22.03it/s]
Epoch 54, loss 5.289 : 100%|█████████

Update Best NDCG@10 Model at 62


Epoch 63, loss 5.202 : 100%|██████████| 48/48 [00:02<00:00, 18.66it/s]
Val: N@1 0.261, N@5 0.410, N@10 0.451, R@1 0.261, R@5 0.546, R@10 0.672: 100%|██████████| 48/48 [00:02<00:00, 21.68it/s]
Epoch 64, loss 5.183 : 100%|██████████| 48/48 [00:02<00:00, 18.54it/s] 
Val: N@1 0.257, N@5 0.408, N@10 0.449, R@1 0.257, R@5 0.545, R@10 0.674: 100%|██████████| 48/48 [00:02<00:00, 21.86it/s]
Epoch 65, loss 5.204 : 100%|██████████| 48/48 [00:02<00:00, 18.57it/s]
Val: N@1 0.260, N@5 0.413, N@10 0.453, R@1 0.260, R@5 0.553, R@10 0.679: 100%|██████████| 48/48 [00:02<00:00, 21.79it/s]
Epoch 66, loss 5.173 : 100%|██████████| 48/48 [00:02<00:00, 18.79it/s] 
Val: N@1 0.262, N@5 0.415, N@10 0.456, R@1 0.262, R@5 0.553, R@10 0.680: 100%|██████████| 48/48 [00:02<00:00, 22.15it/s]
Epoch 67, loss 5.171 : 100%|██████████| 48/48 [00:02<00:00, 18.82it/s]
Val: N@1 0.264, N@5 0.412, N@10 0.453, R@1 0.264, R@5 0.549, R@10 0.673: 100%|██████████| 48/48 [00:02<00:00, 21.97it/s]
Epoch 68, loss 5.188 : 100%|██████████

Update Best NDCG@10 Model at 75


Epoch 76, loss 5.125 : 100%|██████████| 48/48 [00:02<00:00, 18.66it/s]
Val: N@1 0.267, N@5 0.415, N@10 0.453, R@1 0.267, R@5 0.547, R@10 0.665: 100%|██████████| 48/48 [00:02<00:00, 21.99it/s]
Epoch 77, loss 5.117 : 100%|██████████| 48/48 [00:02<00:00, 18.73it/s] 
Val: N@1 0.264, N@5 0.415, N@10 0.454, R@1 0.264, R@5 0.553, R@10 0.670: 100%|██████████| 48/48 [00:02<00:00, 21.95it/s]
Epoch 78, loss 5.095 : 100%|██████████| 48/48 [00:02<00:00, 18.69it/s]
Val: N@1 0.267, N@5 0.419, N@10 0.459, R@1 0.267, R@5 0.555, R@10 0.678: 100%|██████████| 48/48 [00:02<00:00, 21.96it/s]
Epoch 79, loss 5.104 : 100%|██████████| 48/48 [00:02<00:00, 18.62it/s] 
Val: N@1 0.262, N@5 0.416, N@10 0.455, R@1 0.262, R@5 0.558, R@10 0.678: 100%|██████████| 48/48 [00:02<00:00, 21.82it/s]
Epoch 80, loss 5.089 : 100%|██████████| 48/48 [00:02<00:00, 18.75it/s]
Val: N@1 0.273, N@5 0.424, N@10 0.462, R@1 0.273, R@5 0.561, R@10 0.679: 100%|██████████| 48/48 [00:02<00:00, 22.08it/s]
Epoch 81, loss 5.088 : 100%|██████████

Update Best NDCG@10 Model at 84


Epoch 85, loss 5.056 : 100%|██████████| 48/48 [00:02<00:00, 18.50it/s] 
Val: N@1 0.262, N@5 0.418, N@10 0.456, R@1 0.262, R@5 0.558, R@10 0.677: 100%|██████████| 48/48 [00:02<00:00, 21.68it/s]
Epoch 86, loss 5.055 : 100%|██████████| 48/48 [00:02<00:00, 18.89it/s]
Val: N@1 0.268, N@5 0.422, N@10 0.461, R@1 0.268, R@5 0.561, R@10 0.680: 100%|██████████| 48/48 [00:02<00:00, 21.76it/s]
Epoch 87, loss 5.057 : 100%|██████████| 48/48 [00:02<00:00, 18.72it/s] 
Val: N@1 0.268, N@5 0.420, N@10 0.461, R@1 0.268, R@5 0.558, R@10 0.684: 100%|██████████| 48/48 [00:02<00:00, 22.02it/s]
Epoch 88, loss 5.049 : 100%|██████████| 48/48 [00:02<00:00, 18.71it/s]
Val: N@1 0.271, N@5 0.423, N@10 0.462, R@1 0.271, R@5 0.560, R@10 0.681: 100%|██████████| 48/48 [00:02<00:00, 22.03it/s]
Epoch 89, loss 5.041 : 100%|██████████| 48/48 [00:02<00:00, 18.63it/s]
Val: N@1 0.269, N@5 0.418, N@10 0.456, R@1 0.269, R@5 0.554, R@10 0.672: 100%|██████████| 48/48 [00:02<00:00, 21.94it/s]
Epoch 90, loss 5.046 : 100%|██████████

In [60]:
trainer.test()

Test best model with test set!


Val: N@1 0.215, N@5 0.356, N@10 0.395, R@1 0.215, R@5 0.486, R@10 0.605: 100%|██████████| 48/48 [00:02<00:00, 19.85it/s]

{'Recall@100': 0.9970703125, 'NDCG@100': 0.47501263581216335, 'Recall@50': 0.8770073788861433, 'NDCG@50': 0.4555890740205844, 'Recall@20': 0.7222764765222868, 'NDCG@20': 0.42481197354694206, 'Recall@10': 0.60498046875, 'NDCG@10': 0.3951025288552046, 'Recall@5': 0.4856228306889534, 'NDCG@5': 0.35642762916783494, 'Recall@1': 0.2152235247194767, 'NDCG@1': 0.2152235247194767}





old, recheck reproducibility:

In [56]:
trainer.train()

Val: N@1 0.007, N@5 0.028, N@10 0.045, R@1 0.007, R@5 0.050, R@10 0.102: 100%|██████████| 48/48 [00:05<00:00,  9.57it/s]


Update Best NDCG@10 Model at 1


Epoch 1, loss 7.829 : 100%|██████████| 48/48 [00:02<00:00, 17.95it/s]
Val: N@1 0.021, N@5 0.052, N@10 0.074, R@1 0.021, R@5 0.084, R@10 0.154: 100%|██████████| 48/48 [00:02<00:00, 22.47it/s]


Update Best NDCG@10 Model at 1


Epoch 2, loss 7.526 : 100%|██████████| 48/48 [00:02<00:00, 18.91it/s]
Val: N@1 0.025, N@5 0.066, N@10 0.089, R@1 0.025, R@5 0.107, R@10 0.180: 100%|██████████| 48/48 [00:02<00:00, 22.52it/s]


Update Best NDCG@10 Model at 2


Epoch 3, loss 7.369 : 100%|██████████| 48/48 [00:02<00:00, 18.76it/s] 
Val: N@1 0.038, N@5 0.094, N@10 0.125, R@1 0.038, R@5 0.150, R@10 0.248: 100%|██████████| 48/48 [00:02<00:00, 22.51it/s]


Update Best NDCG@10 Model at 3


Epoch 4, loss 7.182 : 100%|██████████| 48/48 [00:02<00:00, 18.83it/s]
Val: N@1 0.053, N@5 0.122, N@10 0.160, R@1 0.053, R@5 0.193, R@10 0.312: 100%|██████████| 48/48 [00:02<00:00, 22.24it/s]


Update Best NDCG@10 Model at 4


Epoch 5, loss 6.986 : 100%|██████████| 48/48 [00:02<00:00, 18.89it/s]  
Val: N@1 0.070, N@5 0.163, N@10 0.204, R@1 0.070, R@5 0.255, R@10 0.383: 100%|██████████| 48/48 [00:02<00:00, 22.28it/s]


Update Best NDCG@10 Model at 5


Epoch 6, loss 6.835 : 100%|██████████| 48/48 [00:02<00:00, 18.85it/s]
Val: N@1 0.095, N@5 0.203, N@10 0.249, R@1 0.095, R@5 0.306, R@10 0.450: 100%|██████████| 48/48 [00:02<00:00, 22.60it/s]


Update Best NDCG@10 Model at 6


Epoch 7, loss 6.678 : 100%|██████████| 48/48 [00:02<00:00, 18.72it/s]  
Val: N@1 0.108, N@5 0.228, N@10 0.277, R@1 0.108, R@5 0.342, R@10 0.494: 100%|██████████| 48/48 [00:02<00:00, 21.63it/s]


Update Best NDCG@10 Model at 7


Epoch 8, loss 6.564 : 100%|██████████| 48/48 [00:02<00:00, 18.65it/s]
Val: N@1 0.125, N@5 0.243, N@10 0.288, R@1 0.125, R@5 0.357, R@10 0.497: 100%|██████████| 48/48 [00:02<00:00, 22.32it/s]


Update Best NDCG@10 Model at 8


Epoch 9, loss 6.443 : 100%|██████████| 48/48 [00:02<00:00, 18.71it/s]  
Val: N@1 0.142, N@5 0.269, N@10 0.317, R@1 0.142, R@5 0.390, R@10 0.539: 100%|██████████| 48/48 [00:02<00:00, 22.44it/s]


Update Best NDCG@10 Model at 9


Epoch 10, loss 6.359 : 100%|██████████| 48/48 [00:02<00:00, 18.73it/s]
Val: N@1 0.149, N@5 0.282, N@10 0.327, R@1 0.149, R@5 0.409, R@10 0.550: 100%|██████████| 48/48 [00:02<00:00, 22.46it/s]


Update Best NDCG@10 Model at 10


Epoch 11, loss 6.285 : 100%|██████████| 48/48 [00:02<00:00, 18.80it/s] 
Val: N@1 0.151, N@5 0.283, N@10 0.329, R@1 0.151, R@5 0.408, R@10 0.548: 100%|██████████| 48/48 [00:02<00:00, 22.46it/s]


Update Best NDCG@10 Model at 11


Epoch 12, loss 6.206 : 100%|██████████| 48/48 [00:02<00:00, 18.79it/s]
Val: N@1 0.174, N@5 0.310, N@10 0.356, R@1 0.174, R@5 0.438, R@10 0.578: 100%|██████████| 48/48 [00:02<00:00, 22.53it/s]


Update Best NDCG@10 Model at 12


Epoch 13, loss 6.142 : 100%|██████████| 48/48 [00:02<00:00, 18.74it/s] 
Val: N@1 0.173, N@5 0.312, N@10 0.357, R@1 0.173, R@5 0.443, R@10 0.581: 100%|██████████| 48/48 [00:02<00:00, 22.42it/s]


Update Best NDCG@10 Model at 13


Epoch 14, loss 6.098 : 100%|██████████| 48/48 [00:02<00:00, 18.70it/s]
Val: N@1 0.176, N@5 0.314, N@10 0.360, R@1 0.176, R@5 0.442, R@10 0.585: 100%|██████████| 48/48 [00:02<00:00, 22.28it/s]


Update Best NDCG@10 Model at 14


Epoch 15, loss 6.040 : 100%|██████████| 48/48 [00:02<00:00, 18.76it/s] 
Val: N@1 0.192, N@5 0.328, N@10 0.374, R@1 0.192, R@5 0.452, R@10 0.593: 100%|██████████| 48/48 [00:02<00:00, 22.32it/s]


Update Best NDCG@10 Model at 15


Epoch 16, loss 5.996 : 100%|██████████| 48/48 [00:02<00:00, 18.79it/s]
Val: N@1 0.192, N@5 0.339, N@10 0.382, R@1 0.192, R@5 0.474, R@10 0.609: 100%|██████████| 48/48 [00:02<00:00, 22.22it/s]


Update Best NDCG@10 Model at 16


Epoch 17, loss 5.969 : 100%|██████████| 48/48 [00:02<00:00, 18.76it/s] 
Val: N@1 0.191, N@5 0.338, N@10 0.382, R@1 0.191, R@5 0.475, R@10 0.611: 100%|██████████| 48/48 [00:02<00:00, 22.04it/s]


Update Best NDCG@10 Model at 17


Epoch 18, loss 5.921 : 100%|██████████| 48/48 [00:02<00:00, 19.01it/s]
Val: N@1 0.201, N@5 0.348, N@10 0.392, R@1 0.201, R@5 0.484, R@10 0.619: 100%|██████████| 48/48 [00:02<00:00, 22.24it/s]


Update Best NDCG@10 Model at 18


Epoch 19, loss 5.867 : 100%|██████████| 48/48 [00:02<00:00, 18.97it/s]
Val: N@1 0.211, N@5 0.360, N@10 0.401, R@1 0.211, R@5 0.497, R@10 0.626: 100%|██████████| 48/48 [00:02<00:00, 22.50it/s]


Update Best NDCG@10 Model at 19


Epoch 20, loss 5.853 : 100%|██████████| 48/48 [00:02<00:00, 18.87it/s]
Val: N@1 0.206, N@5 0.356, N@10 0.399, R@1 0.206, R@5 0.493, R@10 0.625: 100%|██████████| 48/48 [00:02<00:00, 22.32it/s]
Epoch 21, loss 5.814 : 100%|██████████| 48/48 [00:02<00:00, 18.60it/s]
Val: N@1 0.216, N@5 0.366, N@10 0.411, R@1 0.216, R@5 0.504, R@10 0.640: 100%|██████████| 48/48 [00:02<00:00, 21.98it/s]


Update Best NDCG@10 Model at 21


Epoch 22, loss 5.775 : 100%|██████████| 48/48 [00:02<00:00, 18.49it/s] 
Val: N@1 0.218, N@5 0.368, N@10 0.410, R@1 0.218, R@5 0.506, R@10 0.635: 100%|██████████| 48/48 [00:02<00:00, 21.78it/s]
Epoch 23, loss 5.755 : 100%|██████████| 48/48 [00:02<00:00, 18.91it/s]
Val: N@1 0.216, N@5 0.362, N@10 0.406, R@1 0.216, R@5 0.496, R@10 0.631: 100%|██████████| 48/48 [00:02<00:00, 22.29it/s]
Epoch 24, loss 5.730 : 100%|██████████| 48/48 [00:02<00:00, 18.76it/s] 
Val: N@1 0.223, N@5 0.371, N@10 0.412, R@1 0.223, R@5 0.507, R@10 0.634: 100%|██████████| 48/48 [00:02<00:00, 21.86it/s]


Update Best NDCG@10 Model at 24


Epoch 25, loss 5.693 : 100%|██████████| 48/48 [00:02<00:00, 18.45it/s]
Val: N@1 0.223, N@5 0.375, N@10 0.417, R@1 0.223, R@5 0.514, R@10 0.645: 100%|██████████| 48/48 [00:02<00:00, 21.04it/s]


Update Best NDCG@10 Model at 25


Epoch 26, loss 5.681 : 100%|██████████| 48/48 [00:02<00:00, 18.86it/s] 
Val: N@1 0.223, N@5 0.375, N@10 0.418, R@1 0.223, R@5 0.514, R@10 0.644: 100%|██████████| 48/48 [00:02<00:00, 22.27it/s]


Update Best NDCG@10 Model at 26


Epoch 27, loss 5.651 : 100%|██████████| 48/48 [00:02<00:00, 18.78it/s]
Val: N@1 0.230, N@5 0.383, N@10 0.427, R@1 0.230, R@5 0.521, R@10 0.656: 100%|██████████| 48/48 [00:02<00:00, 21.95it/s]


Update Best NDCG@10 Model at 27


Epoch 28, loss 5.645 : 100%|██████████| 48/48 [00:02<00:00, 18.82it/s] 
Val: N@1 0.232, N@5 0.387, N@10 0.428, R@1 0.232, R@5 0.530, R@10 0.656: 100%|██████████| 48/48 [00:02<00:00, 22.08it/s]


Update Best NDCG@10 Model at 28


Epoch 29, loss 5.616 : 100%|██████████| 48/48 [00:02<00:00, 19.03it/s]
Val: N@1 0.239, N@5 0.387, N@10 0.427, R@1 0.239, R@5 0.524, R@10 0.647: 100%|██████████| 48/48 [00:02<00:00, 22.20it/s]
Epoch 30, loss 5.593 : 100%|██████████| 48/48 [00:02<00:00, 18.81it/s] 
Val: N@1 0.238, N@5 0.389, N@10 0.431, R@1 0.238, R@5 0.527, R@10 0.657: 100%|██████████| 48/48 [00:02<00:00, 22.28it/s]


Update Best NDCG@10 Model at 30


Epoch 31, loss 5.566 : 100%|██████████| 48/48 [00:02<00:00, 18.66it/s]
Val: N@1 0.245, N@5 0.392, N@10 0.433, R@1 0.245, R@5 0.526, R@10 0.653: 100%|██████████| 48/48 [00:02<00:00, 21.82it/s]


Update Best NDCG@10 Model at 31


Epoch 32, loss 5.556 : 100%|██████████| 48/48 [00:02<00:00, 18.82it/s] 
Val: N@1 0.236, N@5 0.389, N@10 0.429, R@1 0.236, R@5 0.527, R@10 0.650: 100%|██████████| 48/48 [00:02<00:00, 22.31it/s]
Epoch 33, loss 5.546 : 100%|██████████| 48/48 [00:02<00:00, 18.63it/s]
Val: N@1 0.248, N@5 0.400, N@10 0.442, R@1 0.248, R@5 0.536, R@10 0.666: 100%|██████████| 48/48 [00:02<00:00, 21.87it/s]


Update Best NDCG@10 Model at 33


Epoch 34, loss 5.522 : 100%|██████████| 48/48 [00:02<00:00, 18.79it/s] 
Val: N@1 0.239, N@5 0.392, N@10 0.433, R@1 0.239, R@5 0.532, R@10 0.659: 100%|██████████| 48/48 [00:02<00:00, 22.04it/s]
Epoch 35, loss 5.502 : 100%|██████████| 48/48 [00:02<00:00, 18.88it/s]
Val: N@1 0.251, N@5 0.403, N@10 0.443, R@1 0.251, R@5 0.540, R@10 0.665: 100%|██████████| 48/48 [00:02<00:00, 21.10it/s]


Update Best NDCG@10 Model at 35


Epoch 36, loss 5.489 : 100%|██████████| 48/48 [00:02<00:00, 18.22it/s]
Val: N@1 0.247, N@5 0.402, N@10 0.442, R@1 0.247, R@5 0.542, R@10 0.665: 100%|██████████| 48/48 [00:02<00:00, 21.45it/s]
Epoch 37, loss 5.481 : 100%|██████████| 48/48 [00:02<00:00, 18.79it/s]
Val: N@1 0.247, N@5 0.398, N@10 0.440, R@1 0.247, R@5 0.532, R@10 0.663: 100%|██████████| 48/48 [00:02<00:00, 20.75it/s]
Epoch 38, loss 5.456 : 100%|██████████| 48/48 [00:02<00:00, 18.46it/s]
Val: N@1 0.242, N@5 0.399, N@10 0.440, R@1 0.242, R@5 0.542, R@10 0.668: 100%|██████████| 48/48 [00:02<00:00, 22.21it/s]
Epoch 39, loss 5.456 : 100%|██████████| 48/48 [00:02<00:00, 18.98it/s]
Val: N@1 0.251, N@5 0.403, N@10 0.444, R@1 0.251, R@5 0.542, R@10 0.668: 100%|██████████| 48/48 [00:02<00:00, 21.95it/s]


Update Best NDCG@10 Model at 39


Epoch 40, loss 5.429 : 100%|██████████| 48/48 [00:02<00:00, 18.88it/s]
Val: N@1 0.249, N@5 0.406, N@10 0.445, R@1 0.249, R@5 0.549, R@10 0.671: 100%|██████████| 48/48 [00:02<00:00, 22.10it/s]


Update Best NDCG@10 Model at 40


Epoch 41, loss 5.439 : 100%|██████████| 48/48 [00:02<00:00, 18.71it/s] 
Val: N@1 0.250, N@5 0.404, N@10 0.444, R@1 0.250, R@5 0.544, R@10 0.667: 100%|██████████| 48/48 [00:02<00:00, 22.29it/s]
Epoch 42, loss 5.432 : 100%|██████████| 48/48 [00:02<00:00, 18.98it/s]
Val: N@1 0.251, N@5 0.399, N@10 0.443, R@1 0.251, R@5 0.535, R@10 0.669: 100%|██████████| 48/48 [00:02<00:00, 21.75it/s]
Epoch 43, loss 5.399 : 100%|██████████| 48/48 [00:02<00:00, 18.89it/s] 
Val: N@1 0.253, N@5 0.405, N@10 0.449, R@1 0.253, R@5 0.542, R@10 0.678: 100%|██████████| 48/48 [00:02<00:00, 21.98it/s]


Update Best NDCG@10 Model at 43


Epoch 44, loss 5.395 : 100%|██████████| 48/48 [00:02<00:00, 18.60it/s]
Val: N@1 0.247, N@5 0.400, N@10 0.444, R@1 0.247, R@5 0.540, R@10 0.678: 100%|██████████| 48/48 [00:02<00:00, 22.10it/s]
Epoch 45, loss 5.384 : 100%|██████████| 48/48 [00:02<00:00, 18.92it/s] 
Val: N@1 0.255, N@5 0.406, N@10 0.450, R@1 0.255, R@5 0.543, R@10 0.676: 100%|██████████| 48/48 [00:02<00:00, 22.11it/s]


Update Best NDCG@10 Model at 45


Epoch 46, loss 5.370 : 100%|██████████| 48/48 [00:02<00:00, 18.72it/s]
Val: N@1 0.260, N@5 0.413, N@10 0.454, R@1 0.260, R@5 0.551, R@10 0.680: 100%|██████████| 48/48 [00:02<00:00, 22.04it/s]


Update Best NDCG@10 Model at 46


Epoch 47, loss 5.361 : 100%|██████████| 48/48 [00:02<00:00, 18.77it/s] 
Val: N@1 0.256, N@5 0.409, N@10 0.450, R@1 0.256, R@5 0.548, R@10 0.675: 100%|██████████| 48/48 [00:02<00:00, 21.84it/s]
Epoch 48, loss 5.355 : 100%|██████████| 48/48 [00:02<00:00, 19.00it/s]
Val: N@1 0.266, N@5 0.418, N@10 0.458, R@1 0.266, R@5 0.558, R@10 0.682: 100%|██████████| 48/48 [00:02<00:00, 21.95it/s]


Update Best NDCG@10 Model at 48


Epoch 49, loss 5.330 : 100%|██████████| 48/48 [00:02<00:00, 18.49it/s] 
Val: N@1 0.262, N@5 0.410, N@10 0.453, R@1 0.262, R@5 0.545, R@10 0.675: 100%|██████████| 48/48 [00:02<00:00, 22.07it/s]
Epoch 50, loss 5.329 : 100%|██████████| 48/48 [00:02<00:00, 18.42it/s]
Val: N@1 0.264, N@5 0.414, N@10 0.455, R@1 0.264, R@5 0.548, R@10 0.673: 100%|██████████| 48/48 [00:02<00:00, 21.95it/s]
Epoch 51, loss 5.312 : 100%|██████████| 48/48 [00:02<00:00, 18.69it/s] 
Val: N@1 0.254, N@5 0.404, N@10 0.448, R@1 0.254, R@5 0.539, R@10 0.674: 100%|██████████| 48/48 [00:02<00:00, 22.03it/s]
Epoch 52, loss 5.312 : 100%|██████████| 48/48 [00:02<00:00, 18.51it/s]
Val: N@1 0.259, N@5 0.409, N@10 0.449, R@1 0.259, R@5 0.542, R@10 0.668: 100%|██████████| 48/48 [00:02<00:00, 21.87it/s]
Logging to Tensorboard: 100%|██████████| 48/48 [00:02<00:00, 18.74it/s]
Val: N@1 0.260, N@5 0.411, N@10 0.452, R@1 0.260, R@5 0.549, R@10 0.675: 100%|██████████| 48/48 [00:02<00:00, 22.03it/s]
Epoch 54, loss 5.289 : 100%|█████████

Update Best NDCG@10 Model at 62


Epoch 63, loss 5.202 : 100%|██████████| 48/48 [00:02<00:00, 18.82it/s]
Val: N@1 0.261, N@5 0.410, N@10 0.451, R@1 0.261, R@5 0.546, R@10 0.672: 100%|██████████| 48/48 [00:02<00:00, 22.29it/s]
Epoch 64, loss 5.183 : 100%|██████████| 48/48 [00:02<00:00, 18.88it/s] 
Val: N@1 0.257, N@5 0.408, N@10 0.449, R@1 0.257, R@5 0.545, R@10 0.674: 100%|██████████| 48/48 [00:02<00:00, 22.04it/s]
Epoch 65, loss 5.204 : 100%|██████████| 48/48 [00:02<00:00, 18.72it/s]
Val: N@1 0.260, N@5 0.413, N@10 0.453, R@1 0.260, R@5 0.553, R@10 0.679: 100%|██████████| 48/48 [00:02<00:00, 22.38it/s]
Epoch 66, loss 5.173 : 100%|██████████| 48/48 [00:02<00:00, 18.55it/s] 
Val: N@1 0.262, N@5 0.415, N@10 0.456, R@1 0.262, R@5 0.553, R@10 0.680: 100%|██████████| 48/48 [00:02<00:00, 21.98it/s]
Epoch 67, loss 5.171 : 100%|██████████| 48/48 [00:02<00:00, 18.89it/s]
Val: N@1 0.264, N@5 0.412, N@10 0.453, R@1 0.264, R@5 0.549, R@10 0.673: 100%|██████████| 48/48 [00:02<00:00, 21.29it/s]
Epoch 68, loss 5.188 : 100%|██████████

Update Best NDCG@10 Model at 75


Epoch 76, loss 5.125 : 100%|██████████| 48/48 [00:02<00:00, 18.70it/s]
Val: N@1 0.267, N@5 0.415, N@10 0.453, R@1 0.267, R@5 0.547, R@10 0.665: 100%|██████████| 48/48 [00:02<00:00, 22.02it/s]
Epoch 77, loss 5.117 : 100%|██████████| 48/48 [00:02<00:00, 18.89it/s] 
Val: N@1 0.264, N@5 0.415, N@10 0.454, R@1 0.264, R@5 0.553, R@10 0.670: 100%|██████████| 48/48 [00:02<00:00, 21.87it/s]
Epoch 78, loss 5.095 : 100%|██████████| 48/48 [00:02<00:00, 18.72it/s]
Val: N@1 0.267, N@5 0.419, N@10 0.459, R@1 0.267, R@5 0.555, R@10 0.678: 100%|██████████| 48/48 [00:02<00:00, 22.18it/s]
Epoch 79, loss 5.104 : 100%|██████████| 48/48 [00:02<00:00, 18.75it/s] 
Val: N@1 0.262, N@5 0.416, N@10 0.455, R@1 0.262, R@5 0.558, R@10 0.678: 100%|██████████| 48/48 [00:02<00:00, 21.91it/s]
Epoch 80, loss 5.089 : 100%|██████████| 48/48 [00:02<00:00, 18.22it/s]
Val: N@1 0.273, N@5 0.424, N@10 0.462, R@1 0.273, R@5 0.561, R@10 0.679: 100%|██████████| 48/48 [00:02<00:00, 21.44it/s]
Epoch 81, loss 5.088 : 100%|██████████

Update Best NDCG@10 Model at 84


Epoch 85, loss 5.056 : 100%|██████████| 48/48 [00:02<00:00, 18.75it/s] 
Val: N@1 0.262, N@5 0.418, N@10 0.456, R@1 0.262, R@5 0.558, R@10 0.677: 100%|██████████| 48/48 [00:02<00:00, 22.17it/s]
Epoch 86, loss 5.055 : 100%|██████████| 48/48 [00:02<00:00, 18.59it/s]
Val: N@1 0.268, N@5 0.422, N@10 0.461, R@1 0.268, R@5 0.561, R@10 0.680: 100%|██████████| 48/48 [00:02<00:00, 21.40it/s]
Epoch 87, loss 5.057 : 100%|██████████| 48/48 [00:02<00:00, 18.67it/s] 
Val: N@1 0.268, N@5 0.420, N@10 0.461, R@1 0.268, R@5 0.558, R@10 0.684: 100%|██████████| 48/48 [00:02<00:00, 21.52it/s]
Epoch 88, loss 5.049 : 100%|██████████| 48/48 [00:02<00:00, 18.68it/s]
Val: N@1 0.271, N@5 0.423, N@10 0.462, R@1 0.271, R@5 0.560, R@10 0.681: 100%|██████████| 48/48 [00:02<00:00, 22.24it/s]
Epoch 89, loss 5.041 : 100%|██████████| 48/48 [00:02<00:00, 18.76it/s]
Val: N@1 0.269, N@5 0.418, N@10 0.456, R@1 0.269, R@5 0.554, R@10 0.672: 100%|██████████| 48/48 [00:02<00:00, 22.05it/s]
Epoch 90, loss 5.046 : 100%|██████████

In [57]:
trainer.test()

Test best model with test set!


Val: N@1 0.215, N@5 0.356, N@10 0.395, R@1 0.215, R@5 0.486, R@10 0.605: 100%|██████████| 48/48 [00:02<00:00, 19.79it/s]

{'Recall@100': 0.9970703125, 'NDCG@100': 0.47501263581216335, 'Recall@50': 0.8770073788861433, 'NDCG@50': 0.4555890740205844, 'Recall@20': 0.7222764765222868, 'NDCG@20': 0.42481197354694206, 'Recall@10': 0.60498046875, 'NDCG@10': 0.3951025288552046, 'Recall@5': 0.4856228306889534, 'NDCG@5': 0.35642762916783494, 'Recall@1': 0.2152235247194767, 'NDCG@1': 0.2152235247194767}





---

посчитаем ndcg@10 как в бейзлайнах - по датафрейму neg_sampled_test и по той функции рассчёта, которая использовалась там:

In [61]:
best_model = torch.load(os.path.join(export_root, 'models', 'best_acc_model.pth')).get('model_state_dict')

In [62]:
model.load_state_dict(best_model)

<All keys matched successfully>

In [63]:
model.eval()

BERTModel(
  (bert): BERT(
    (embedding): BERTEmbedding(
      (token): TokenEmbedding(3708, 256, padding_idx=0)
      (position): PositionalEmbedding(
        (pe): Embedding(100, 256)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer_blocks): ModuleList(
      (0): TransformerBlock(
        (attention): MultiHeadedAttention(
          (linear_layers): ModuleList(
            (0): Linear(in_features=256, out_features=256, bias=True)
            (1): Linear(in_features=256, out_features=256, bias=True)
            (2): Linear(in_features=256, out_features=256, bias=True)
          )
          (output_linear): Linear(in_features=256, out_features=256, bias=True)
          (attention): Attention()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=256, out_features=1024, bias=True)
          (w_2): Linear(in_features=1024, out_features=256, bias=True)
          (dr

In [64]:
from tqdm import tqdm

# получаем предикты для mask_token на тестовой последовательности:
scores_for_mask_token_for_each_batch = []

with torch.no_grad():
    tqdm_dataloader = tqdm(test_torch_dataloader)
    
    for batch_idx, batch in enumerate(tqdm_dataloader):
        batch = [x.to('cuda') for x in batch]
        
        scores = model(batch[0])
        scores_for_mask_token = scores[:, -1, :]
        
        scores_for_mask_token_for_each_batch.append(scores_for_mask_token.cpu())

100%|██████████| 48/48 [00:00<00:00, 50.93it/s]


In [65]:
scores_for_mask_token = torch.cat(scores_for_mask_token_for_each_batch, dim=0)

scores_for_mask_token.shape

torch.Size([6040, 3707])

In [66]:
user_to_scores_of_items = dict()

for user_ix, scores in enumerate(scores_for_mask_token):
    user_to_scores_of_items[user_ix] = dict()
    
    for item_ix, score in enumerate(scores):
        user_to_scores_of_items[user_ix][item_ix] = score.item()

In [67]:
# отслеживаем путь юзера с индексом 0 (оригинальный айди 1)
print(
    user_to_scores_of_items[0][49], # видел на трейне
    user_to_scores_of_items[0][50], # видел на трейне
    user_to_scores_of_items[0][51], # не видел на трейне (а стоило бы вернуть - мы его украли в валидационное взаимодействие)
    user_to_scores_of_items[0][52], # не видел на трейне
    user_to_scores_of_items[0][53], # не видел на трейне (и в реальности его не было в train + val + test)
    user_to_scores_of_items[0][54], # не видел на трейне (и в реальности его не было в train + val + test)
)

6.759389877319336 10.463557243347168 7.453557968139648 10.459239959716797 2.9089155197143555 3.2975363731384277


выглядит правдоподобно

In [68]:
from metrics import *

In [69]:
def calculate_grouped_ndcg_for_bert4rec_output(valDf, model_scores, k):
    data = valDf.copy()
    
    data["predicted_rating"] = data.apply(lambda row: model_scores[row["user_id"]][row["movie_id"]], axis=1)
    
    nonnull_users = set(data[data.rating > 0].user_id)
    data = data[data.user_id.isin(nonnull_users)]
    
    return np.mean(data.groupby("user_id").apply(lambda x: ndcg_score(x.rating, x.predicted_rating, k)))

In [70]:
calculate_grouped_ndcg_for_bert4rec_output(neg_sampled_test, user_to_scores_of_items, 10)

0.3927996803071895