# RePlay bandit models comparison

This notebook demonstrates the main RePlay functionality and compares the performance of RePlay models on the well-known MovieLens dataset. For simplicity, we consider only bandit algorithms, both context-free and contextual. The strategies compared are:

Context-free algorithms:
* Most popular;
* Vanilla UCB algorithm;
* Vanilla TS algorithm (Beta-Binomial);
* KL-UCB [Smb et al];

Contextual bandit algorithms:
* Lin-UCB [Smb et al];
* Linear TS [Smb et al] (Thompson sampling with linear feature vectors);
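
To make the context-free strategies listed above more concrete, here is a minimal NumPy sketch of how a UCB-style rule and Beta-Binomial Thompson sampling can score items from per-item success and trial counts. It is an illustration under simplifying assumptions (binary rewards, every item shown at least once, a uniform Beta(1, 1) prior); the function names `ucb_scores` and `thompson_scores` are hypothetical and this is not RePlay's implementation.

```python
import numpy as np


def ucb_scores(successes, trials, coef=2.0):
    """UCB-style score: empirical success rate plus an exploration bonus."""
    successes = np.asarray(successes, dtype=float)
    trials = np.asarray(trials, dtype=float)  # assumed to be > 0 for every item
    total = trials.sum()
    mean = successes / trials
    bonus = np.sqrt(coef * np.log(total) / trials)
    return mean + bonus


def thompson_scores(successes, trials, rng):
    """Draw one sample per item from its Beta(1 + successes, 1 + failures) posterior."""
    successes = np.asarray(successes, dtype=float)
    failures = np.asarray(trials, dtype=float) - successes
    return rng.beta(1.0 + successes, 1.0 + failures)


# three items with similar success rates but very different amounts of evidence
successes = [90, 9, 1]
trials = [100, 10, 1]

ucb = ucb_scores(successes, trials)
ts = thompson_scores(successes, trials, np.random.default_rng(0))
print("UCB scores:", ucb)
print("TS samples:", ts)
```

In this toy example the item with a single observation receives the largest UCB score, because its exploration bonus dominates; Thompson sampling expresses the same uncertainty through the spread of its posterior samples.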


### Dataset
We will compare RePlay models on __MovieLens 1m__. 

### Dataset preprocessing: 
Ratings greater than or equal to 3 are considered as positive interactions.

### Data split
The dataset is split by date so that the last 20% of interactions are placed in the test part. Cold items and users are dropped.
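
The split described above can be sketched in plain pandas as follows. This is only an illustration of the idea, not the `TimeSplitter` implementation: the helper name `time_split`, the quantile-based threshold, and the toy data are assumptions made for the example.

```python
import pandas as pd


def time_split(log, test_fraction=0.2, user_col="user_id",
               item_col="item_id", time_col="timestamp"):
    """Send the latest `test_fraction` of interactions to test,
    then drop test rows whose user or item never occurs in train."""
    threshold = log[time_col].quantile(1.0 - test_fraction)
    train = log[log[time_col] < threshold]
    test = log[log[time_col] >= threshold]
    warm = test[user_col].isin(train[user_col]) & test[item_col].isin(train[item_col])
    return train, test[warm]


# toy log with integer timestamps (hypothetical data)
toy_log = pd.DataFrame({
    "user_id": [1, 1, 2, 2, 3, 1, 2, 3, 4, 1],
    "item_id": [10, 11, 10, 12, 11, 12, 11, 13, 10, 10],
    "timestamp": [0, 100, 200, 300, 400, 500, 600, 700, 800, 900],
})

toy_train, toy_test = time_split(toy_log)
print(len(toy_train), len(toy_test))
```

Here the last two interactions fall after the threshold, but one of them belongs to user 4, who has no history in train, so it is removed as a cold user.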

### Predict:
We will predict the top-10 most relevant films for each user.

### Metrics
Quality metrics used: __ndcg@k, map@k, mrr@k__ for k = 10 and __hitrate@k__ for k = 1, 5, 10.
Additional metrics used: __coverage@k__ and __surprisal@k__.
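
As a reminder of what these metrics measure, the following pure-Python sketch computes them for a single user with binary relevance. The helper names are hypothetical and the formulas are the common textbook definitions; RePlay's own metric classes may differ in details such as how coverage is normalized.

```python
import math


def hit_rate_at_k(recs, relevant, k):
    """1.0 if at least one of the top-k recommendations is relevant."""
    return float(any(item in relevant for item in recs[:k]))


def mrr_at_k(recs, relevant, k):
    """Reciprocal rank of the first relevant item within the top-k."""
    for rank, item in enumerate(recs[:k], start=1):
        if item in relevant:
            return 1.0 / rank
    return 0.0


def ndcg_at_k(recs, relevant, k):
    """DCG of the top-k list divided by the DCG of an ideal ordering."""
    dcg = sum(1.0 / math.log2(rank + 1)
              for rank, item in enumerate(recs[:k], start=1) if item in relevant)
    ideal_hits = min(len(relevant), k)
    idcg = sum(1.0 / math.log2(rank + 1) for rank in range(1, ideal_hits + 1))
    return dcg / idcg if idcg > 0 else 0.0


def coverage_at_k(all_recs, catalog_size, k):
    """Share of the catalog that appears in anyone's top-k list."""
    recommended = {item for recs in all_recs for item in recs[:k]}
    return len(recommended) / catalog_size


recs = [5, 3, 8]     # ranked recommendations for one user
relevant = {3, 9}    # items the user actually interacted with in test

hr1 = hit_rate_at_k(recs, relevant, 1)
hr3 = hit_rate_at_k(recs, relevant, 3)
mrr3 = mrr_at_k(recs, relevant, 3)
ndcg3 = ndcg_at_k(recs, relevant, 3)
cov2 = coverage_at_k([[5, 3, 8], [5, 3, 1]], catalog_size=10, k=2)
print(hr1, hr3, mrr3, round(ndcg3, 4), cov2)
```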

In [1]:
# ! pip install rs-datasets

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
%config Completer.use_jedi = False

In [4]:
import warnings
from optuna.exceptions import ExperimentalWarning
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=ExperimentalWarning)

The `State` object accepts an existing Spark session or creates a new one, which is then used by all RePlay modules.

To create a session with custom ``spark.driver.memory`` and ``spark.sql.shuffle.partitions`` parameters, use the `get_spark_session` function from the `session_handler` module.

In [5]:
import logging
import time

from pyspark.sql import functions as sf, types as st
from pyspark.sql.types import IntegerType

from replay.data import Dataset, FeatureHint, FeatureInfo, FeatureSchema, FeatureType
from replay.experimental.preprocessing.data_preparator import Indexer, DataPreparator
from replay.metrics import Experiment
from replay.metrics import Coverage, HitRate, MRR, MAP, NDCG, Surprisal
from replay.models import (
    PopRec, 
    RandomRec,
    UCB,
    Wilson, 
    ThompsonSampling, #added TS
    LinUCB, #added LinUCB (disjoint version)
)

from replay.models.base_rec import HybridRecommender
from replay.utils.session_handler import State
from replay.splitters import TimeSplitter
from replay.utils.spark_utils import convert2spark, get_log_info
from rs_datasets import MovieLens

import pandas as pd
import numpy as np

In [6]:
spark = State().session
spark

24/08/25 22:30:33 WARN Utils: Your hostname, sudakovcom-MS-7D48 resolves to a loopback address: 127.0.1.1; using 10.255.173.26 instead (on interface enp3s0)
24/08/25 22:30:33 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/08/25 22:30:34 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/08/25 22:30:34 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [7]:
spark.sparkContext.setLogLevel('ERROR')

In [8]:
logger = logging.getLogger("replay")

In [9]:
K = 10
K_list_metrics = [1, 5, 10]
BUDGET = 20
BUDGET_NN = 10
SEED = 12345

## 0. Preprocessing <a name='data-preparator'></a>

### 0.1 Data loading

In [10]:
data = MovieLens("1m")
data.info()

ratings


Unnamed: 0,user_id,item_id,rating,timestamp
0,1,1193,5,978300760
1,1,661,3,978302109
2,1,914,3,978301968



users


Unnamed: 0,user_id,gender,age,occupation,zip_code
0,1,F,1,10,48067
1,2,M,56,16,70072
2,3,M,25,15,55117



items


Unnamed: 0,item_id,title,genres
0,1,Toy Story (1995),Animation|Children's|Comedy
1,2,Jumanji (1995),Adventure|Children's|Fantasy
2,3,Grumpier Old Men (1995),Comedy|Romance





In [11]:
data.ratings

Unnamed: 0,user_id,item_id,rating,timestamp
0,1,1193,5,978300760
1,1,661,3,978302109
2,1,914,3,978301968
3,1,3408,4,978300275
4,1,2355,5,978824291
...,...,...,...,...
1000204,6040,1091,1,956716541
1000205,6040,1094,5,956704887
1000206,6040,562,5,956704746
1000207,6040,1096,4,956715648


#### log preprocessing

- converting to spark dataframe
- renaming columns
- checking for nulls
- converting timestamp to Timestamp format
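
The steps above can be sketched in plain pandas to show what the preparation amounts to. This is a hypothetical stand-in for illustration, not what `DataPreparator` does internally; the toy frame reuses the first two rows of the ratings table.

```python
import pandas as pd

raw = pd.DataFrame({
    "user_id": [1, 1],
    "item_id": [1193, 661],
    "rating": [5, 3],
    "timestamp": [978300760, 978302109],
})

# rename columns to the names used in the rest of the notebook
log_pd = raw.rename(columns={"rating": "relevance"})

# check for nulls before going further
if log_pd.isnull().any().any():
    raise ValueError("interaction log contains null values")

# cast relevance to float and convert Unix seconds to timestamps (UTC-based here)
log_pd["relevance"] = log_pd["relevance"].astype(float)
log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], unit="s")
print(log_pd)
```

Note that the sketch interprets the Unix seconds as UTC, whereas Spark renders timestamps in the session time zone, so the displayed clock times can differ.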

In [12]:
preparator = DataPreparator()

In [13]:
%%time
log = preparator.transform(columns_mapping={'user_id': 'user_id',
                                      'item_id': 'item_id',
                                      'relevance': 'rating',
                                      'timestamp': 'timestamp'
                                     }, 
                           data=data.ratings)

25-Aug-24 22:30:36, replay, INFO: Columns with ids of users or items are present in mapping. The dataframe will be treated as an interactions log.
  arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]


CPU times: user 18.3 ms, sys: 9.02 ms, total: 27.3 ms
Wall time: 2.66 s


In [14]:
log.show(2)
#total number of interactions
log.count()

+-------+-------+---------+-------------------+
|user_id|item_id|relevance|          timestamp|
+-------+-------+---------+-------------------+
|      1|   1193|      5.0|2001-01-01 01:12:40|
|      1|    661|      3.0|2001-01-01 01:35:09|
+-------+-------+---------+-------------------+
only showing top 2 rows



1000209

In [15]:
# will consider ratings >= 3 as positive feedback. A positive feedback is treated with relevance = 1
only_positives_log = log.filter(sf.col('relevance') >= 3).withColumn('relevance', sf.lit(1))
only_positives_log.count()

836478

<a id='indexing'></a>
### 0.2. Indexing

Convert given users' and items' identifiers (\_id) to integers starting at zero without gaps (\_idx) with Indexer class.

In [16]:
indexer = Indexer(user_col='user_id', item_col='item_id')

Take all available user and item ids from the log and features and pass them to the Indexer. The _ids_ may repeat; the indexes are ordered by label frequency, so the most frequent label gets index 0.
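
The frequency-ordered, gap-free indexing described above can be illustrated with a few lines of plain Python. The helper `fit_index` is hypothetical, and breaking frequency ties by the id value is an assumption made to keep the example deterministic.

```python
from collections import Counter


def fit_index(ids):
    """Map each distinct id to a dense integer index; more frequent ids get smaller indexes."""
    counts = Counter(ids)
    # sort by descending frequency, then by id to break ties (assumption)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: idx for idx, label in enumerate(ordered)}


item_ids = [1193, 661, 1193, 914, 1193, 661]
mapping = fit_index(item_ids)
print(mapping)  # {1193: 0, 661: 1, 914: 2}
```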

In [17]:
%%time
indexer.fit(users=log.select('user_id'),
           items=log.select('item_id'))

CPU times: user 14.3 ms, sys: 2.54 ms, total: 16.9 ms
Wall time: 1.04 s


In [18]:
%%time
log_replay = indexer.transform(df=only_positives_log)
log_replay.show(2)
log_replay.count()

+--------+--------+---------+-------------------+
|user_idx|item_idx|relevance|          timestamp|
+--------+--------+---------+-------------------+
|    4131|      43|        1|2001-01-01 01:12:40|
|    4131|     585|        1|2001-01-01 01:35:09|
+--------+--------+---------+-------------------+
only showing top 2 rows

CPU times: user 17.3 ms, sys: 6.54 ms, total: 23.8 ms
Wall time: 888 ms


836478

### 0.3. Data split

In [19]:
# train/test split 
train_spl = TimeSplitter(
    time_threshold=0.2,
    drop_cold_items=True,
    drop_cold_users=True,
    query_column="user_idx",
    item_column="item_idx",
)

train, test = train_spl.split(log_replay)
print('train info:\n', get_log_info(train))
print('test info:\n', get_log_info(test))

train info:
 total lines: 669181, total users: 5397, total items: 3569


                                                                                

test info:
 total lines: 86542, total users: 1139, total items: 3279


In [20]:
train.is_cached

False

In [21]:
# train/test split for hyperparameters selection
opt_train, opt_val = train_spl.split(train)
opt_train.count(), opt_val.count()

                                                                                

(535343, 24241)

In [22]:
opt_train.is_cached

False

In [23]:
# negative feedback will be used for Wilson and UCB models
only_negatives_log = indexer.transform(df=log.filter(sf.col('relevance') < 3).withColumn('relevance', sf.lit(0.)))
test_start = test.agg(sf.min('timestamp')).collect()[0][0]

# train with both positive and negative feedback
pos_neg_train=(train
              .withColumn('relevance', sf.lit(1.))
              .union(only_negatives_log.filter(sf.col('timestamp') < test_start))
             )
pos_neg_train.cache()
pos_neg_train.count()

798993
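
Models such as Wilson and UCB need both positive and negative feedback because they estimate each item's success rate from the number of positive interactions out of all interactions. As a rough illustration, the lower bound of the Wilson score interval can be computed as below. The function name `wilson_lower_bound` and the choice `z = 1.96` are assumptions for this sketch, and RePlay's `Wilson` model may differ in details.

```python
import math


def wilson_lower_bound(positives, total, z=1.96):
    """Lower bound of the Wilson score interval for a binomial proportion."""
    if total == 0:
        return 0.0
    p = positives / total
    denom = 1.0 + z * z / total
    centre = p + z * z / (2.0 * total)
    margin = z * math.sqrt(p * (1.0 - p) / total + z * z / (4.0 * total * total))
    return (centre - margin) / denom


# same 90% positive rate, but different amounts of evidence
lb_small = wilson_lower_bound(9, 10)
lb_large = wilson_lower_bound(900, 1000)
print(lb_small, lb_large)
```

With the same observed positive rate, the item with more interactions gets the higher lower bound, so such a ranking favors items whose quality is supported by more evidence.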

In [24]:
pos_neg_train.is_cached

True

In [25]:
pos_neg_train.show(20)

+--------+--------+---------+-------------------+
|user_idx|item_idx|relevance|          timestamp|
+--------+--------+---------+-------------------+
|     677|    1314|      1.0|2000-12-02 08:30:12|
|     677|    1282|      1.0|2000-12-02 08:53:52|
|     677|     731|      1.0|2000-12-02 08:41:26|
|     677|     234|      1.0|2000-12-02 08:23:47|
|     677|     190|      1.0|2000-12-02 08:50:33|
|     677|     133|      1.0|2000-12-02 08:27:28|
|     677|     546|      1.0|2000-12-02 08:32:02|
|     677|    2090|      1.0|2000-12-02 08:53:17|
|     677|     421|      1.0|2000-12-02 08:50:13|
|     677|     154|      1.0|2000-12-02 08:44:14|
|     677|      96|      1.0|2000-12-02 08:11:05|
|     677|     221|      1.0|2000-12-02 08:58:08|
|     677|     395|      1.0|2000-12-02 08:33:02|
|     677|      19|      1.0|2000-12-02 08:46:50|
|     677|      73|      1.0|2000-12-02 08:31:30|
|     677|     182|      1.0|2000-12-02 08:18:10|
|     677|     836|      1.0|2000-12-02 08:47:30|


In [26]:
A = pos_neg_train.toPandas()
A.head(20)

Unnamed: 0,user_idx,item_idx,relevance,timestamp
0,677,1314,1.0,2000-12-02 08:30:12
1,677,1282,1.0,2000-12-02 08:53:52
2,677,731,1.0,2000-12-02 08:41:26
3,677,234,1.0,2000-12-02 08:23:47
4,677,190,1.0,2000-12-02 08:50:33
5,677,133,1.0,2000-12-02 08:27:28
6,677,546,1.0,2000-12-02 08:32:02
7,677,2090,1.0,2000-12-02 08:53:17
8,677,421,1.0,2000-12-02 08:50:13
9,677,154,1.0,2000-12-02 08:44:14


## 2. Models training

In [52]:
train_dataset.interactions.count()

669181

In [51]:
test_dataset.interactions.count()

86542

In [54]:
def fit_predict_add_res(name, model, experiment, train, test, suffix=''):
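    """
    Fit `model`, predict top-K items for the test queries, and add the metrics
    together with the fit, predict and metric computation times to `experiment`.

    Note: bandit models (Wilson, UCB, ThompsonSampling, LinUCB) are fitted on the global
    `train_neg_dataset`, which contains both positive and negative feedback, instead of `train`.
    """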
    start_time=time.time()
    
    # logs = {'log': train}
    dataset = {'dataset': train}
    predict_params = {'k': K, 'queries': test.query_ids.distinct()}
    # predict_params = {'k': K, 'users': test.select('user_idx').distinct()}
    
    if isinstance(model, (Wilson, UCB, ThompsonSampling, LinUCB)):
        dataset['dataset'] = train_neg_dataset

    # if isinstance(model, LinUCB):
    #     logs['item_features'] = item_features
    #     logs['user_features'] = user_features
    
    predict_params.update(dataset)

    model.fit(**dataset)
    fit_time = time.time() - start_time

    pred=model.predict(**predict_params)
    pred.show(100)
    pred.cache()
    print(pred.count())
    predict_time = time.time() - start_time - fit_time

    experiment.add_result(name + suffix, pred)
    metric_time = time.time() - start_time - fit_time - predict_time
    experiment.results.loc[name + suffix, 'fit_time'] = fit_time
    experiment.results.loc[name + suffix, 'predict_time'] = predict_time
    experiment.results.loc[name + suffix, 'metric_time'] = metric_time
    experiment.results.loc[name + suffix, 'full_time'] = (fit_time + 
                                                          predict_time +
                                                          metric_time)
    pred.unpersist()
    print(experiment.results[['NDCG@{}'.format(K), 'MRR@{}'.format(K), 'Coverage@{}'.format(K), 'fit_time']].sort_values('NDCG@{}'.format(K), ascending=False))

In [28]:
def full_pipeline(models, experiment, train, test, suffix='', budget=BUDGET):
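    """
    For each model:
        - pass the model to `fit_predict_add_res`
        - save the model's initialization parameters to `experiment`

    Hyperparameter search is not run here, so `budget` and the second element
    of each `models` value (e.g. 'no_opt') are currently unused.
    """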
    
    for name, [model, params] in models.items():
        model.logger.info(msg='{} started'.format(name))
        
        logger.info(msg='{} fit_predict started'.format(name))
        fit_predict_add_res(name, model, experiment, train, test, suffix)
        # here we call protected attribute to get all parameters set during model initialization
        experiment.results.loc[name + suffix, 'params'] = str(model._init_args)

### 2.1. Contextual bandit models

#### 2.1.1. Item features preprocessing

In [29]:
item_features_original = preparator.transform(columns_mapping={'item_id': 'item_id'}, 
                           data=data.items)
item_features = indexer.transform(df=item_features_original)
item_features.show(2)
# check the range of item indexes

# use sf.max / sf.min instead of importing max and min, which would shadow the Python builtins
item_features.select(sf.max('item_idx')).show()
item_features.select(sf.min('item_idx')).show()
# just to check that the indexing is dense between 0 and 3882
item_features.count()

25-Aug-24 22:30:50, replay, INFO: Column with ids of users or items is absent in mapping. The dataframe will be treated as a users'/items' features dataframe.
  arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]


+--------+----------------+--------------------+
|item_idx|           title|              genres|
+--------+----------------+--------------------+
|      29|Toy Story (1995)|Animation|Childre...|
|     393|  Jumanji (1995)|Adventure|Childre...|
+--------+----------------+--------------------+
only showing top 2 rows

+-------------+
|max(item_idx)|
+-------------+
|         3882|
+-------------+

+-------------+
|min(item_idx)|
+-------------+
|            0|
+-------------+



3883

In [30]:
year = item_features.withColumn('year', sf.substring(sf.col('title'), -5, 4).astype(st.IntegerType())).select('item_idx', 'year')
year.show(2)

+--------+----+
|item_idx|year|
+--------+----+
|      29|1995|
|     393|1995|
+--------+----+
only showing top 2 rows



In [31]:
genres = (
    item_features.select(
        "item_idx",
        sf.split("genres", r"\|").alias("genres")
    )
)

In [32]:
genres_list = (
    genres.select(sf.explode("genres").alias("genre"))
    .distinct().filter('genre <> "(no genres listed)"')
    .toPandas()["genre"].tolist()
)

In [33]:
genres_list

['Mystery',
 'Action',
 'Documentary',
 "Children's",
 'Drama',
 'Adventure',
 'Film-Noir',
 'Crime',
 'Animation',
 'Fantasy',
 'Comedy',
 'Western',
 'Romance',
 'Thriller',
 'War',
 'Sci-Fi',
 'Musical',
 'Horror']

In [34]:
item_features = genres
for genre in genres_list:
    item_features = item_features.withColumn(
        genre,
        sf.array_contains(sf.col("genres"), genre).astype(IntegerType())
    )
item_features = item_features.drop("genres").cache()
item_features.count()
item_features = item_features.join(year, on='item_idx', how='inner')
item_features.cache()

DataFrame[item_idx: int, Mystery: int, Action: int, Documentary: int, Children's: int, Drama: int, Adventure: int, Film-Noir: int, Crime: int, Animation: int, Fantasy: int, Comedy: int, Western: int, Romance: int, Thriller: int, War: int, Sci-Fi: int, Musical: int, Horror: int, year: int]

#### 2.1.2. User features preprocessing

In [35]:
data.users.head()

Unnamed: 0,user_id,gender,age,occupation,zip_code
0,1,F,1,10,48067
1,2,M,56,16,70072
2,3,M,25,15,55117
3,4,M,45,7,2460
4,5,M,25,20,55455


In [36]:
# same preprocessing for users as was done for items in 2.1.1
user_features_original = preparator.transform(columns_mapping={'user_id': 'user_id'}, 
                           data=data.users)
user_features = indexer.transform(df=user_features_original)
# temporarily switch to pandas
user_features = user_features.toPandas()
user_features.head(2)

25-Aug-24 22:30:52, replay, INFO: Column with ids of users or items is absent in mapping. The dataframe will be treated as a users'/items' features dataframe.
  arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]


Unnamed: 0,user_idx,gender,age,occupation,zip_code
0,4131,F,1,10,48067
1,2364,M,56,16,70072


In [37]:
from sklearn.preprocessing import OneHotEncoder

In [38]:
print("max occupation index: ", user_features['occupation'].max())
print("min occupation index: ", user_features['occupation'].min())
count_diff_zips = user_features['zip_code'].unique().size
print("different zip codes: ", count_diff_zips)  # too many distinct zip codes, so drop them for now
users_pd = user_features.drop(columns=['zip_code'])
users_pd.head()
# bin the age variable into groups
# note: pd.cut uses right-closed intervals by default, so e.g. age 50 falls into the '40-49' bin
bins = [0, 20, 30, 40, 50, 60, np.inf]
names = ['<20', '20-29', '30-39','40-49', '51-60', '60+']

users_pd['agegroup'] = pd.cut(users_pd['age'], bins, labels=names)
users_pd = users_pd.drop(["age"], axis = 1)
users_pd.head()

# one-hot encode, following https://github.com/kfoofw/bandit_simulations/tree/master
# note: newer scikit-learn versions use `sparse_output=False` and `get_feature_names_out` instead
columnsToEncode = ["agegroup","gender","occupation"]
myEncoder = OneHotEncoder(sparse=False, handle_unknown='ignore')
myEncoder.fit(users_pd[columnsToEncode])

users_pd = pd.concat([users_pd.drop(columns=columnsToEncode),
                           pd.DataFrame(myEncoder.transform(users_pd[columnsToEncode]), 
                                        columns = myEncoder.get_feature_names(columnsToEncode))], axis=1).reindex()
users_pd.head()

max occupation index:  20
min occupation index:  0
different zip codes:  3439


Unnamed: 0,user_idx,agegroup_20-29,agegroup_30-39,agegroup_40-49,agegroup_51-60,agegroup_<20,gender_F,gender_M,occupation_0,occupation_1,...,occupation_11,occupation_12,occupation_13,occupation_14,occupation_15,occupation_16,occupation_17,occupation_18,occupation_19,occupation_20
0,4131,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,2364,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0
2,4217,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
3,5916,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,1603,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0


In [39]:
# convert back to a Spark DataFrame
user_features = spark.createDataFrame(users_pd)
user_features.printSchema()
user_features.show()
print("total users: ",user_features.count())

root
 |-- user_idx: integer (nullable = true)
 |-- agegroup_20-29: double (nullable = true)
 |-- agegroup_30-39: double (nullable = true)
 |-- agegroup_40-49: double (nullable = true)
 |-- agegroup_51-60: double (nullable = true)
 |-- agegroup_<20: double (nullable = true)
 |-- gender_F: double (nullable = true)
 |-- gender_M: double (nullable = true)
 |-- occupation_0: double (nullable = true)
 |-- occupation_1: double (nullable = true)
 |-- occupation_2: double (nullable = true)
 |-- occupation_3: double (nullable = true)
 |-- occupation_4: double (nullable = true)
 |-- occupation_5: double (nullable = true)
 |-- occupation_6: double (nullable = true)
 |-- occupation_7: double (nullable = true)
 |-- occupation_8: double (nullable = true)
 |-- occupation_9: double (nullable = true)
 |-- occupation_10: double (nullable = true)
 |-- occupation_11: double (nullable = true)
 |-- occupation_12: double (nullable = true)
 |-- occupation_13: double (nullable = true)
 |-- occupation_14: double

  arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]


+--------+--------------+--------------+--------------+--------------+------------+--------+--------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+
|user_idx|agegroup_20-29|agegroup_30-39|agegroup_40-49|agegroup_51-60|agegroup_<20|gender_F|gender_M|occupation_0|occupation_1|occupation_2|occupation_3|occupation_4|occupation_5|occupation_6|occupation_7|occupation_8|occupation_9|occupation_10|occupation_11|occupation_12|occupation_13|occupation_14|occupation_15|occupation_16|occupation_17|occupation_18|occupation_19|occupation_20|
+--------+--------------+--------------+--------------+--------------+------------+--------+--------+------------+------------+------------+------------+------------+------------+------------+------------+------------+----------

### 2.2. Fitting various bandits

In [55]:
bandit_models = {
          'Wilson': [Wilson(), 'no_opt'],
        #   'TS (context-free)': [ThompsonSampling(),'no_opt'],
          'Linear UCB (disjoint models)':[LinUCB(eps = 0.0, alpha = 1.0, regr_type = 'disjoint'), 'no_opt'],
          # 'UCB': [UCB(exploration_coef=2.0), 'no_opt'], 
         }
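
For readers unfamiliar with the disjoint variant of LinUCB, here is a minimal NumPy sketch of the idea: each item (arm) keeps its own ridge regression over the user feature vector and is scored by the predicted reward plus an uncertainty bonus scaled by `alpha`. The class `DisjointLinUCBArm` is hypothetical and only illustrates the textbook formulation; it is not the RePlay `LinUCB` code.

```python
import numpy as np


class DisjointLinUCBArm:
    """One arm of disjoint LinUCB: a separate ridge regression per item."""

    def __init__(self, dim, alpha=1.0):
        self.alpha = alpha
        self.A = np.eye(dim)    # regularized Gram matrix of seen contexts
        self.b = np.zeros(dim)  # reward-weighted sum of seen contexts

    def update(self, x, reward):
        """Incorporate one observed (context, reward) pair."""
        self.A += np.outer(x, x)
        self.b += reward * x

    def score(self, x):
        """Predicted reward plus the exploration bonus for context x."""
        A_inv = np.linalg.inv(self.A)
        theta = A_inv @ self.b
        return float(theta @ x + self.alpha * np.sqrt(x @ A_inv @ x))


arm = DisjointLinUCBArm(dim=2, alpha=1.0)
x = np.array([1.0, 0.0])  # toy user feature vector

score_before = arm.score(x)  # no data yet: pure exploration bonus
for _ in range(3):
    arm.update(x, reward=1.0)
score_after = arm.score(x)   # estimate grows, bonus shrinks
print(score_before, score_after)
```

Because the score is an estimate plus a bonus, it is not restricted to the [0, 1] range even when rewards are binary.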

In [41]:
feature_schema = FeatureSchema(
    [
        FeatureInfo(
            column="user_idx",
            feature_type=FeatureType.CATEGORICAL,
            feature_hint=FeatureHint.QUERY_ID,
        ),
        FeatureInfo(
            column="item_idx",
            feature_type=FeatureType.CATEGORICAL,
            feature_hint=FeatureHint.ITEM_ID,
        ),
        FeatureInfo(
            column="relevance",
            feature_type=FeatureType.NUMERICAL,
            feature_hint=FeatureHint.RATING,
        ),
        FeatureInfo(
            column="timestamp",
            feature_type=FeatureType.NUMERICAL,
            feature_hint=FeatureHint.TIMESTAMP,
        ),
    ]
)

In [42]:
all_dataset = Dataset(
    feature_schema=feature_schema,
    interactions=log_replay,
    item_features=item_features,
    query_features=user_features,
    categorical_encoded = True
)

train_dataset = Dataset(
    feature_schema=feature_schema,
    interactions=train,
    item_features=item_features,
    query_features=user_features,
    categorical_encoded = True
)

test_dataset = Dataset(
    feature_schema=feature_schema,
    interactions=test,
    item_features=item_features,
    query_features=user_features,
    categorical_encoded = True
)

train_neg_dataset = Dataset(
    feature_schema=feature_schema,
    interactions=pos_neg_train,
    item_features=item_features,
    query_features=user_features,
    categorical_encoded = True
)

opt_train_dataset = Dataset(
    feature_schema=feature_schema,
    interactions=opt_train,
    item_features=item_features,
    query_features=user_features,
    categorical_encoded = True
)

opt_val_dataset = Dataset(
    feature_schema=feature_schema,
    interactions=opt_val,
    item_features=item_features,
    query_features=user_features,
    categorical_encoded = True
)

                                                                                

In [43]:
# e = Experiment(test, {MAP(): K, NDCG(): K, HitRate(): K_list_metrics, Coverage(train): K, Surprisal(train): K, MRR(): K})

e = Experiment(
    [
        MAP(K),
        NDCG(K),
        HitRate(K_list_metrics),
        Coverage(K),
        Surprisal(K),
        MRR(K)
    ],
    test_dataset.interactions,
    train_dataset.interactions,
    query_column=train_dataset.feature_schema.query_id_column,
    item_column=train_dataset.feature_schema.item_id_column,
    rating_column=train_dataset.feature_schema.interactions_rating_column,
    )

In [56]:
%%time
full_pipeline(bandit_models, e, train_dataset, test_dataset)

25-Aug-24 22:54:25, replay, INFO: Wilson started
INFO:replay:Wilson started
25-Aug-24 22:54:25, replay, INFO: Wilson fit_predict started
INFO:replay:Wilson fit_predict started
                                                                                

+--------+--------+------------------+
|user_idx|item_idx|         relevance|
+--------+--------+------------------+
|      18|     707|0.9754524358354291|
|      18|    1209| 0.972226295602107|
|      18|     557| 0.971408022009853|
|      18|     682|0.9698080855041509|
|      18|     234|0.9688400027862992|
|      18|    1081|0.9685863803126885|
|      18|      32|0.9660186961163115|
|      18|     238|0.9647924991594925|
|      18|     239|0.9647424995568235|
|      18|    1080|0.9643378106230265|
|      38|     106|0.9821985244892515|
|      38|     186|0.9795457647582008|
|      38|     569|0.9745366264739433|
|      38|     672|0.9737961511019336|
|      38|    1209| 0.972226295602107|
|      38|     557| 0.971408022009853|
|      38|     682|0.9698080855041509|
|      38|     191|0.9691067567560038|
|      38|     234|0.9688400027862992|
|      38|    1081|0.9685863803126885|
|      46|     400|0.9900259255456659|
|      46|     106|0.9821985244892515|
|      46|     186|0.9795

25-Aug-24 22:54:36, replay, INFO: Linear UCB (disjoint models) started          
INFO:replay:Linear UCB (disjoint models) started
25-Aug-24 22:54:36, replay, INFO: Linear UCB (disjoint models) fit_predict started
INFO:replay:Linear UCB (disjoint models) fit_predict started


                               NDCG@10    MRR@10  Coverage@10   fit_time
Wilson                        0.092121  0.180976     0.019333   0.970134
Linear UCB (disjoint models)  0.008875  0.021985     0.184365  15.257740


  arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]


+--------+--------+------------------+
|user_idx|item_idx|         relevance|
+--------+--------+------------------+
|      18|    2595|1.3169453963924966|
|      18|    2254|1.2917079134269787|
|      18|    2733|1.2364928810499012|
|      18|    2951|1.2124520507524346|
|      18|    2244| 1.191809973084726|
|      18|    2818|1.1915501269577513|
|      18|    2303|1.1897062015493016|
|      18|    2685| 1.189340716982624|
|      18|    2878|1.1843715461948212|
|      18|    3133|1.1840596330275244|
|      38|    2725|1.3462268707302054|
|      38|    2541|1.2728948307644592|
|      38|    1937|1.2073352372867336|
|      38|    2464|1.1921167645459503|
|      38|    3030|1.1551753009088448|
|      38|    2166| 1.137642974305147|
|      38|    1847| 1.119909860508019|
|      38|    2303|1.0974129341342356|
|      38|    2454|1.0830263513776037|
|      38|    2329| 1.077229904940043|
|      46|    2904|1.2299356891191637|
|      46|    2832|1.2033312017558484|
|      46|    2682|1.1456

                                                                                

                               NDCG@10    MRR@10  Coverage@10   fit_time
Wilson                        0.092121  0.180976     0.019333   0.970134
Linear UCB (disjoint models)  0.008875  0.021985     0.184365  18.582876
CPU times: user 1min 7s, sys: 3min 4s, total: 4min 11s
Wall time: 43.2 s


In [45]:
e.results.sort_values('NDCG@10', ascending=False)

Unnamed: 0,MAP@10,NDCG@10,HitRate@1,HitRate@5,HitRate@10,Coverage@10,Surprisal@10,MRR@10,fit_time,predict_time,metric_time,full_time,params
Linear UCB (disjoint models),0.002883,0.008875,0.007902,0.034241,0.073749,0.184365,0.614196,0.021985,15.25774,7.935328,10.143539,33.336606,"{'regression type': 'disjoint', 'seed': None}"
