In [1]:
import warnings
from pprint import pprint

import pandas as pd
from implicit.nearest_neighbours import CosineRecommender, TFIDFRecommender, BM25Recommender
from rectools import Columns
from rectools.dataset import Interactions
from rectools.metrics import MAP, MeanInvUserFreq, calc_metrics
from rectools.metrics import Precision, Recall, NDCG, Serendipity
from rectools.model_selection import TimeRangeSplitter

from models.userknn import UserKnn

# Notebook-wide setup: silence every warning and let pandas render
# all columns with cell text up to 200 characters wide.
warnings.filterwarnings("ignore")
for option_name, option_value in (
    ('display.max_columns', None),
    ('display.max_colwidth', 200),
):
    pd.set_option(option_name, option_value)
import pickle

In [2]:
# Read the raw interaction log and map its column names onto the
# rectools `Columns` constants in a single chained expression.
interactions_df = pd.read_csv('../data/interactions.csv').rename(
    columns={
        'last_watch_dt': Columns.Datetime,
        'total_dur': Columns.Weight,
    }
)
# users = pd.read_csv('../data/users.csv')
# items = pd.read_csv('../data/items.csv')

# Interactions casts the column types and keeps its own DataFrame, exposed as `.df`.
interactions = Interactions(interactions_df)

# NOTE: to run this notebook quickly, uncomment the line below - it shrinks the data
# interactions = Interactions(interactions_df.sample(frac=0.01))

interactions.df.head()

Unnamed: 0,user_id,item_id,datetime,weight,watched_pct
0,176549,9506,2021-05-11,4250.0,72.0
1,699317,1659,2021-05-29,8317.0,100.0
2,656683,7107,2021-05-09,10.0,0.0
3,864613,7638,2021-07-05,14483.0,100.0
4,964868,9506,2021-04-30,6725.0,100.0


In [None]:
# Cross-validation settings, consumed by the TimeRangeSplitter built in the next cell.
N_SPLITS = 5  # passed to the splitter as `n_splits`
TEST_SIZE = '14D'  # passed as `test_size`; presumably a pandas-style offset meaning 14 days - TODO confirm

In [None]:
# Splitter configuration gathered in one place, then unpacked into the constructor.
# The three filter_* flags are all switched on; judging by their names they drop
# already-seen pairs and cold items/users from each test fold - TODO confirm
# against the rectools documentation.
splitter_kwargs = {
    'test_size': TEST_SIZE,
    'n_splits': N_SPLITS,
    'filter_already_seen': True,
    'filter_cold_items': True,
    'filter_cold_users': True,
}
cv = TimeRangeSplitter(**splitter_kwargs)

In [None]:
# Single cutoff shared by every metric. The "@K" part of each metric name is
# derived from it, so the labels cannot drift from the value actually used.
TOP_K = 10

# Metric name -> rectools metric object; this mapping is handed to calc_metrics
# and its keys become the metric columns of the results table.
metrics = {
    f'Precision@{TOP_K}': Precision(k=TOP_K),
    f'Recall@{TOP_K}': Recall(k=TOP_K),
    f'NDCG@{TOP_K}': NDCG(k=TOP_K),
    f'map@{TOP_K}': MAP(k=TOP_K),
    'novelty': MeanInvUserFreq(k=TOP_K),
    f'Serendipity@{TOP_K}': Serendipity(k=TOP_K),
}

# A few simple nearest-neighbour recommenders from `implicit` to compare; each
# one is passed to UserKnn as its `model` argument in the cross-validation loop.
models = {
    'cosine_userknn': CosineRecommender(),
    'tfidf_userknn': TFIDFRecommender(),
    'BM25_userknn': BM25Recommender(),
}

In [None]:
# One record per (fold, model) pair is accumulated here.
results = []

# With collect_fold_stats=True each split is unpacked below as
# (train_ids, test_ids, fold_info).
fold_iterator = cv.split(interactions, collect_fold_stats=True)

for i_fold, (train_ids, test_ids, fold_info) in enumerate(fold_iterator):
    print(f"\n==================== Fold {i_fold}")
    pprint(fold_info)

    # Positional row indices from the splitter -> train / test frames of this fold.
    all_rows = interactions.df
    df_train = all_rows.iloc[train_ids].copy()
    df_test = all_rows.iloc[test_ids][Columns.UserItem].copy()

    # Items present in this fold's training data; passed to calc_metrics as `catalog`.
    catalog = df_train[Columns.Item].unique()

    for model_name, model in models.items():
        # Wrap the base recommender in UserKnn, fit on train, predict for test rows.
        userknn_model = UserKnn(model=model, N_users=50)
        userknn_model.fit(df_train)
        recos = userknn_model.predict(df_test)

        metric_values = calc_metrics(
            metrics,
            reco=recos,
            interactions=df_test,
            prev_interactions=df_train,
            catalog=catalog,
        )

        # Fold and model identifiers first, then every computed metric.
        results.append({"fold": i_fold, "model": model_name, **metric_values})

In [None]:
# Turn the accumulated per-fold records into a table: one row per
# (fold, model) pair, one column per record key.
df_metrics = pd.DataFrame(data=results)
df_metrics

In [None]:
# Average each metric over the folds, per model (columns follow the order of `metrics`).
df_metrics.groupby('model')[list(metrics)].mean()

In [3]:
# Final model: a TF-IDF nearest-neighbour recommender wrapped in UserKnn.
tfidf_backbone = TFIDFRecommender()
uknn = UserKnn(tfidf_backbone, N_users=50)

In [4]:
# Fit on the complete interactions frame (no train/test split here); this is the
# model that gets pickled further down.
uknn.fit(interactions.df)

  0%|          | 0/962179 [00:00<?, ?it/s]

In [6]:
# Smoke-test the batch API: take the user of the first interaction row and
# ask for predictions through a one-row DataFrame.
first_user_id = interactions.df['user_id'].iloc[0]
uknn.predict(pd.DataFrame([{'user_id': first_user_id}]))

Unnamed: 0,user_id,item_id,score,rank
7,176549,15469,2.356529,1
18,176549,5518,2.35474,2
22,176549,12448,2.32781,3
49,176549,6737,2.294926,4
9,176549,5482,2.262816,5
11,176549,10688,2.252752,6
40,176549,4273,2.217489,7
17,176549,5695,2.177419,8
6,176549,7453,2.159942,9
0,176549,5600,2.159429,10


In [7]:
# Same user as the first interaction row, this time through the single-user API.
known_user_id = interactions.df.head(2).user_id.values[0]
uknn.recommend(known_user_id, N_recs=10)

[15469, 5518, 12448, 6737, 5482, 10688, 4273, 5695, 7453, 5600]

In [8]:
# An id far above the ones visible in the sample rows - presumably absent from the
# training data, so this exercises the unknown-user path (TODO confirm).
unseen_user_id = 1_000_000_000
uknn.recommend(unseen_user_id, N_recs=10)

[10440, 15297, 9728, 13865, 4151, 3734, 2657, 4880, 142, 6809]

In [9]:
# Persist the fitted model. The context manager guarantees the file is flushed
# and closed once the dump finishes (the bare open() call previously left the
# handle to be closed by garbage collection).
with open('../saved_models/userknn.pkl', "wb") as model_file:
    pickle.dump(uknn, model_file)

In [10]:
# Round-trip check: restore the model from the pickle file on disk.
# NOTE: unpickling can execute arbitrary code - only load files you produced yourself.
with open('../saved_models/userknn.pkl', mode='rb') as model_file:
    uknn_pkl = pickle.load(model_file)

In [11]:
# The restored model is queried for the user of the first interaction row;
# its output can be compared with the in-memory model's recommendations.
known_user_id = interactions.df.head(2).user_id.values[0]
uknn_pkl.recommend(known_user_id, N_recs=10)

[15469, 5518, 12448, 6737, 5482, 10688, 4273, 5695, 7453, 5600]