# RePlay bandit models comparison

We will show the main RePlay functionality and compare the performance of RePlay models on the well-known MovieLens dataset. For simplicity, we consider only bandit algorithms here, both context-free and context-aware. The strategies considered for comparison are:

Context-free algorithms:
* Most popular;
* Vanilla UCB algorithm;
* Vanilla TS algorithm (Beta-Binomial);
* KL-UCB [Smb et al];
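Each context-free strategy scores an item only from that item's own feedback counts. The sketch below is a minimal illustration, assuming the textbook forms: a UCB1-style index with a tunable coefficient (`coef` is a hypothetical name standing in for something like `exploration_coef`), Beta-Binomial Thompson sampling with a uniform Beta(1, 1) prior, and the Bernoulli KL-UCB index found by bisection (using plain `log(t)` without the optional `log log t` term). RePlay's `UCB` and `ThompsonSampling` may use different constants, priors or handling of unseen items.

```python
import math
import random


def ucb_score(positives: int, trials: int, total_trials: int, coef: float = 2.0) -> float:
    """UCB1-style index: empirical positive rate plus an exploration bonus.

    Items that were never shown get +inf, so they are tried first.
    """
    if trials == 0:
        return math.inf
    mean = positives / trials
    return mean + math.sqrt(coef * math.log(total_trials) / trials)


def ts_sample(positives: int, trials: int, rng: random.Random) -> float:
    """Beta-Binomial Thompson sampling: draw from the Beta(1 + pos, 1 + neg) posterior."""
    return rng.betavariate(positives + 1, trials - positives + 1)


def bernoulli_kl(p: float, q: float) -> float:
    """KL divergence between Bernoulli(p) and Bernoulli(q), clamped away from 0 and 1."""
    eps = 1e-12
    p = min(max(p, eps), 1 - eps)
    q = min(max(q, eps), 1 - eps)
    return p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))


def kl_ucb_score(positives: int, trials: int, total_trials: int, iters: int = 50) -> float:
    """Largest q >= mean with trials * KL(mean, q) <= log(total_trials), found by bisection."""
    if trials == 0:
        return math.inf
    mean = positives / trials
    bound = math.log(total_trials) / trials
    low, high = mean, 1.0
    for _ in range(iters):
        mid = (low + high) / 2
        if bernoulli_kl(mean, mid) <= bound:
            low = mid  # mid is still inside the confidence region
        else:
            high = mid
    return low


# usage: item -> (positive interactions, all interactions)
stats = {"a": (30, 100), "b": (5, 10), "c": (0, 0)}
total = sum(n for _, n in stats.values())
rng = random.Random(0)
for item, (pos, n) in stats.items():
    print(item, ucb_score(pos, n, total), kl_ucb_score(pos, n, total), ts_sample(pos, n, rng))
```

KL-UCB never exceeds 1 for Bernoulli feedback, so its bonus is usually much tighter than the UCB1 square-root term for items with few trials.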

Contextual bandit algorithms:
* Lin-UCB [Smb et al];
* Linear TS [Smb et al] (Thompson sampling with linear feature vectors);
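In the disjoint variant of LinUCB, every item keeps its own ridge regression over the context (here assumed to be the user's feature vector) and is scored by its predicted reward plus a confidence width. The sketch below is a minimal pure-Python illustration that maintains the inverse of `A = ridge * I + sum(x x^T)` with Sherman-Morrison updates. The names `ridge` and `explore_coef` are hypothetical: the notebook passes `eps` and `alpha` to RePlay's `LinUCB`, and how those map onto this formula is not shown here.

```python
import math


class DisjointLinUCBArm:
    """Per-item ridge regression with an upper-confidence bonus (disjoint LinUCB)."""

    def __init__(self, dim: int, ridge: float = 1.0):
        # A = ridge * I + sum(x x^T) is tracked through its inverse; b = sum(reward * x)
        self.a_inv = [[(1.0 / ridge if i == j else 0.0) for j in range(dim)] for i in range(dim)]
        self.b = [0.0] * dim

    def _a_inv_times(self, v):
        return [sum(row[j] * v[j] for j in range(len(v))) for row in self.a_inv]

    def update(self, x, reward: float) -> None:
        # Sherman-Morrison: (A + x x^T)^-1 = A^-1 - (A^-1 x)(A^-1 x)^T / (1 + x^T A^-1 x)
        ax = self._a_inv_times(x)
        denom = 1.0 + sum(xi * axi for xi, axi in zip(x, ax))
        n = len(x)
        for i in range(n):
            for j in range(n):
                self.a_inv[i][j] -= ax[i] * ax[j] / denom
        for i in range(n):
            self.b[i] += reward * x[i]

    def score(self, x, explore_coef: float = 1.0) -> float:
        theta = self._a_inv_times(self.b)  # ridge estimate theta = A^-1 b
        ax = self._a_inv_times(x)
        mean = sum(t * xi for t, xi in zip(theta, x))
        width = math.sqrt(max(sum(xi * axi for xi, axi in zip(x, ax)), 0.0))
        return mean + explore_coef * width


# usage: one arm per item; the context is the user's feature vector
arms = {item: DisjointLinUCBArm(dim=3) for item in ("a", "b")}
arms["a"].update([1.0, 0.0, 1.0], 1.0)
arms["b"].update([1.0, 0.0, 1.0], 0.0)
user_x = [1.0, 0.0, 1.0]
best = max(arms, key=lambda item: arms[item].score(user_x, explore_coef=1.0))
print(best)
```

Keeping `A^-1` directly avoids a matrix inversion per score; Linear TS would reuse the same statistics but sample `theta` from a Gaussian instead of adding a fixed width.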


### Dataset
We will compare RePlay models on __MovieLens 1m__. 

### Dataset preprocessing
Ratings greater than or equal to 3 are considered as positive interactions.

### Data split
The dataset is split by date so that the last 20% of interactions are placed in the test part. Cold items and users (those absent from the train part) are dropped from the test part.
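The split can be sketched in plain Python, assuming that a float `time_threshold=0.2` means "the latest 20% of interactions by timestamp go to test". Interactions that share a timestamp at the cut are assigned arbitrarily here; RePlay's `TimeSplitter` may resolve such ties differently.

```python
def time_split(log, test_fraction=0.2):
    """Put the latest `test_fraction` of interactions into test, then drop cold users/items.

    log: list of (user, item, timestamp) tuples.
    """
    ordered = sorted(log, key=lambda row: row[2])
    cut = len(ordered) - round(len(ordered) * test_fraction)
    train, test = ordered[:cut], ordered[cut:]
    train_users = {user for user, _, _ in train}
    train_items = {item for _, item, _ in train}
    # cold users/items never appear in train, so no model can score them
    test = [row for row in test if row[0] in train_users and row[1] in train_items]
    return train, test
```

Because cold rows are removed after the cut, the test part can end up much smaller than the nominal 20%, as the `get_log_info` output in section 0.3 shows.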

### Prediction
We will predict the top-10 most relevant films for each user.

### Metrics
Quality metrics used: __ndcg@k, hitrate@k, map@k, mrr@k__ for k = 1, 5, 10.
Additional metrics used: __coverage@k__ and __surprisal@k__.
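For binary relevance, the quality metrics can be written per user as below and then averaged over test users. This is a minimal sketch assuming common conventions: NDCG uses the `1 / log2(rank + 1)` discount, and MAP@k normalizes by `min(k, number of relevant items)`. RePlay's exact normalization may differ.

```python
import math


def hit_rate_at_k(recs, relevant, k):
    """1 if at least one relevant item is in the top-k, else 0."""
    return 1.0 if any(item in relevant for item in recs[:k]) else 0.0


def mrr_at_k(recs, relevant, k):
    """Reciprocal rank of the first relevant item in the top-k (0 if none)."""
    for rank, item in enumerate(recs[:k], start=1):
        if item in relevant:
            return 1.0 / rank
    return 0.0


def ndcg_at_k(recs, relevant, k):
    """Binary-relevance NDCG: DCG of the list divided by the DCG of an ideal list."""
    dcg = sum(1.0 / math.log2(rank + 1)
              for rank, item in enumerate(recs[:k], start=1) if item in relevant)
    ideal = sum(1.0 / math.log2(rank + 1) for rank in range(1, min(k, len(relevant)) + 1))
    return dcg / ideal if ideal > 0 else 0.0


def map_at_k(recs, relevant, k):
    """Average precision at relevant positions, normalized by min(k, |relevant|)."""
    hits, precision_sum = 0, 0.0
    for rank, item in enumerate(recs[:k], start=1):
        if item in relevant:
            hits += 1
            precision_sum += hits / rank
    denom = min(k, len(relevant))
    return precision_sum / denom if denom > 0 else 0.0


def mean_over_users(metric, recs_by_user, relevant_by_user, k):
    """Average a per-user metric over the test users; users without recommendations score 0."""
    scores = [metric(recs_by_user.get(user, []), relevant, k)
              for user, relevant in relevant_by_user.items()]
    return sum(scores) / len(scores)
```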

In [1]:
# ! pip install rs-datasets

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
%config Completer.use_jedi = False

In [4]:
import warnings
from optuna.exceptions import ExperimentalWarning
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=ExperimentalWarning)

The `State` object allows passing an existing Spark session or creating a new one; that session is then used by all RePlay modules.

To create a session with custom ``spark.driver.memory`` and ``spark.sql.shuffle.partitions`` parameters, use the `get_spark_session` function from the `session_handler` module.

In [5]:
import logging
import time

from pyspark.sql import functions as sf, types as st
from pyspark.sql.types import IntegerType

from replay.data import Dataset, FeatureHint, FeatureInfo, FeatureSchema, FeatureType
from replay.experimental.preprocessing.data_preparator import Indexer, DataPreparator
from replay.metrics import Experiment
from replay.metrics import Coverage, HitRate, MRR, MAP, NDCG, Surprisal
from replay.models import (
    PopRec, 
    RandomRec,
    UCB,
    Wilson, 
    ThompsonSampling,  # context-free Thompson sampling
    LinUCB,  # LinUCB (disjoint version)
)

from replay.models.base_rec import HybridRecommender
from replay.utils.session_handler import State
from replay.splitters import TimeSplitter
from replay.utils.spark_utils import convert2spark, get_log_info
from rs_datasets import MovieLens

import pandas as pd
import numpy as np

In [6]:
spark = State().session
spark

24/08/27 00:17:44 WARN Utils: Your hostname, sudakovcom-MS-7D48 resolves to a loopback address: 127.0.1.1; using 10.255.173.26 instead (on interface enp3s0)
24/08/27 00:17:44 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/08/27 00:17:45 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/08/27 00:17:45 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


In [7]:
spark.sparkContext.setLogLevel('ERROR')

In [8]:
logger = logging.getLogger("replay")

In [9]:
K = 10
K_list_metrics = [1, 5, 10]
BUDGET = 20
BUDGET_NN = 10
SEED = 12345

## 0. Preprocessing <a name='data-preparator'></a>

### 0.1 Data loading

In [10]:
data = MovieLens("1m")
data.info()

ratings


| | user_id | item_id | rating | timestamp |
|---|---|---|---|---|
| 0 | 1 | 1193 | 5 | 978300760 |
| 1 | 1 | 661 | 3 | 978302109 |
| 2 | 1 | 914 | 3 | 978301968 |



users


| | user_id | gender | age | occupation | zip_code |
|---|---|---|---|---|---|
| 0 | 1 | F | 1 | 10 | 48067 |
| 1 | 2 | M | 56 | 16 | 70072 |
| 2 | 3 | M | 25 | 15 | 55117 |



items


| | item_id | title | genres |
|---|---|---|---|
| 0 | 1 | Toy Story (1995) | Animation\|Children's\|Comedy |
| 1 | 2 | Jumanji (1995) | Adventure\|Children's\|Fantasy |
| 2 | 3 | Grumpier Old Men (1995) | Comedy\|Romance |





In [11]:
data.ratings

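| | user_id | item_id | rating | timestamp |
|---|---|---|---|---|
| 0 | 1 | 1193 | 5 | 978300760 |
| 1 | 1 | 661 | 3 | 978302109 |
| 2 | 1 | 914 | 3 | 978301968 |
| 3 | 1 | 3408 | 4 | 978300275 |
| 4 | 1 | 2355 | 5 | 978824291 |
| ... | ... | ... | ... | ... |
| 1000204 | 6040 | 1091 | 1 | 956716541 |
| 1000205 | 6040 | 1094 | 5 | 956704887 |
| 1000206 | 6040 | 562 | 5 | 956704746 |
| 1000207 | 6040 | 1096 | 4 | 956715648 |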


#### Log preprocessing

- converting to spark dataframe
- renaming columns
- checking for nulls
- converting timestamp to Timestamp format
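`DataPreparator` performs these steps on a Spark DataFrame. The plain-Python sketch below mirrors the four listed steps on a list of dicts; raising `ValueError` on nulls and converting to UTC are assumptions of this sketch, not necessarily what `DataPreparator` does. For reference, the raw value 978300760 is 2000-12-31 22:12:40 UTC, while the `log.show(2)` output displays 2001-01-01 01:12:40, which suggests the Spark session renders timestamps in a local time zone (apparently UTC+3).

```python
from datetime import datetime, timezone


def prepare_log(rows, columns_mapping):
    """Rename columns, reject nulls and convert unix seconds to timestamps.

    rows: list of dicts read from the raw ratings file.
    columns_mapping: target column name -> source column name.
    """
    prepared = []
    for row in rows:
        record = {target: row.get(source) for target, source in columns_mapping.items()}
        if any(value is None for value in record.values()):
            raise ValueError(f"null value in row: {row}")
        record["relevance"] = float(record["relevance"])
        # unix seconds -> timezone-aware UTC datetime
        record["timestamp"] = datetime.fromtimestamp(int(record["timestamp"]), tz=timezone.utc)
        prepared.append(record)
    return prepared
```

The mapping has the same shape as the `columns_mapping` passed to `preparator.transform`, e.g. `{'relevance': 'rating'}` renames the source `rating` column to `relevance`.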

In [12]:
preparator = DataPreparator()

In [13]:
%%time
log = preparator.transform(columns_mapping={'user_id': 'user_id',
                                      'item_id': 'item_id',
                                      'relevance': 'rating',
                                      'timestamp': 'timestamp'
                                     }, 
                           data=data.ratings)

27-Aug-24 00:17:47, replay, INFO: Columns with ids of users or items are present in mapping. The dataframe will be treated as an interactions log.


CPU times: user 19.5 ms, sys: 7.53 ms, total: 27.1 ms
Wall time: 2.85 s


In [14]:
log.show(2)
#total number of interactions
log.count()

+-------+-------+---------+-------------------+
|user_id|item_id|relevance|          timestamp|
+-------+-------+---------+-------------------+
|      1|   1193|      5.0|2001-01-01 01:12:40|
|      1|    661|      3.0|2001-01-01 01:35:09|
+-------+-------+---------+-------------------+
only showing top 2 rows



1000209

In [15]:
# consider ratings >= 3 as positive feedback; positive interactions get relevance = 1
only_positives_log = log.filter(sf.col('relevance') >= 3).withColumn('relevance', sf.lit(1))
only_positives_log.count()

836478

<a id='indexing'></a>
### 0.2. Indexing

Convert the users' and items' identifiers (\_id) into consecutive integers starting at zero (\_idx) with the `Indexer` class.

In [16]:
indexer = Indexer(user_col='user_id', item_col='item_id')

Take all available user and item ids from the log and features and pass them to `Indexer`. The ids may repeat: indexes are ordered by label frequency, so the most frequent label gets index 0.
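Frequency-ordered indexing can be sketched with `collections.Counter`. This is a minimal illustration; the tie-breaking rule (first-seen order) is an assumption of this sketch, and RePlay's `Indexer` may order equally frequent labels differently.

```python
from collections import Counter


def fit_frequency_index(labels):
    """Map each distinct label to 0..n-1 so that the most frequent label gets 0.

    Counter.most_common sorts by count in descending order; ties keep first-seen order.
    """
    return {label: idx for idx, (label, _) in enumerate(Counter(labels).most_common())}


def apply_index(labels, mapping):
    """Replace raw ids with their integer indexes."""
    return [mapping[label] for label in labels]
```

Dense indexes from 0 let models store per-user and per-item parameters in plain arrays.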

In [17]:
%%time
indexer.fit(users=log.select('user_id'),
           items=log.select('item_id'))

CPU times: user 14.3 ms, sys: 4.33 ms, total: 18.6 ms
Wall time: 1.05 s


In [18]:
%%time
log_replay = indexer.transform(df=only_positives_log)
log_replay.show(2)
log_replay.count()

+--------+--------+---------+-------------------+
|user_idx|item_idx|relevance|          timestamp|
+--------+--------+---------+-------------------+
|    4131|      43|        1|2001-01-01 01:12:40|
|    4131|     585|        1|2001-01-01 01:35:09|
+--------+--------+---------+-------------------+
only showing top 2 rows

CPU times: user 21.3 ms, sys: 3.21 ms, total: 24.5 ms
Wall time: 919 ms


836478

### 0.3. Data split

In [19]:
# train/test split 
train_spl = TimeSplitter(
    time_threshold=0.2,
    drop_cold_items=True,
    drop_cold_users=True,
    query_column="user_idx",
    item_column="item_idx",
)

train, test = train_spl.split(log_replay)
print('train info:\n', get_log_info(train))
print('test info:\n', get_log_info(test))

train info:
 total lines: 669181, total users: 5397, total items: 3569


                                                                                

test info:
 total lines: 86542, total users: 1139, total items: 3279


In [20]:
train.is_cached

False

In [21]:
# train/test split for hyperparameters selection
opt_train, opt_val = train_spl.split(train)
opt_train.count(), opt_val.count()

(535343, 24241)

In [22]:
opt_train.is_cached

False

In [23]:
# negative feedback (ratings < 3, relevance = 0) is used by the Wilson, UCB, TS and LinUCB models
only_negatives_log = indexer.transform(df=log.filter(sf.col('relevance') < 3).withColumn('relevance', sf.lit(0.)))
test_start = test.agg(sf.min('timestamp')).collect()[0][0]

# train with both positive and negative feedback
pos_neg_train=(train
              .withColumn('relevance', sf.lit(1.))
              .union(only_negatives_log.filter(sf.col('timestamp') < test_start))
             )
pos_neg_train.cache()
pos_neg_train.count()

798993
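RePlay's `Wilson` model ranks items by a lower confidence bound on their positive-feedback rate. As a minimal sketch, assuming the standard Wilson score interval with z = 1.96 (roughly a 95% two-sided interval; RePlay's default confidence level and exact formula are not shown in this notebook), the bound can be computed from each item's counts of positive and total interactions. It also shows why the negative (relevance = 0) rows matter: with positives only, every item's empirical rate is 1 and the ranking would depend on interaction counts alone.

```python
import math


def wilson_lower_bound(positives: int, total: int, z: float = 1.96) -> float:
    """Lower end of the Wilson score interval for a Bernoulli rate."""
    if total == 0:
        return 0.0
    p = positives / total
    denom = 1 + z * z / total
    center = p + z * z / (2 * total)
    margin = z * math.sqrt(p * (1 - p) / total + z * z / (4 * total * total))
    return (center - margin) / denom


# usage: item -> (positive interactions, all interactions)
counts = {"a": (9, 10), "b": (90, 100), "c": (1, 1)}
ranking = sorted(counts, key=lambda item: wilson_lower_bound(*counts[item]), reverse=True)
print(ranking)
```

Items `a` and `b` have the same 90% positive rate, but `b` ranks higher because its estimate rests on ten times more evidence.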

In [24]:
pos_neg_train.is_cached

True

In [25]:
pos_neg_train.show(20)

+--------+--------+---------+-------------------+
|user_idx|item_idx|relevance|          timestamp|
+--------+--------+---------+-------------------+
|     677|    1314|      1.0|2000-12-02 08:30:12|
|     677|    1282|      1.0|2000-12-02 08:53:52|
|     677|     731|      1.0|2000-12-02 08:41:26|
|     677|     234|      1.0|2000-12-02 08:23:47|
|     677|     190|      1.0|2000-12-02 08:50:33|
|     677|     133|      1.0|2000-12-02 08:27:28|
|     677|     546|      1.0|2000-12-02 08:32:02|
|     677|    2090|      1.0|2000-12-02 08:53:17|
|     677|     421|      1.0|2000-12-02 08:50:13|
|     677|     154|      1.0|2000-12-02 08:44:14|
|     677|      96|      1.0|2000-12-02 08:11:05|
|     677|     221|      1.0|2000-12-02 08:58:08|
|     677|     395|      1.0|2000-12-02 08:33:02|
|     677|      19|      1.0|2000-12-02 08:46:50|
|     677|      73|      1.0|2000-12-02 08:31:30|
|     677|     182|      1.0|2000-12-02 08:18:10|
|     677|     836|      1.0|2000-12-02 08:47:30|


In [26]:
A = pos_neg_train.toPandas()
A.head(20)

| | user_idx | item_idx | relevance | timestamp |
|---|---|---|---|---|
| 0 | 677 | 1314 | 1.0 | 2000-12-02 08:30:12 |
| 1 | 677 | 1282 | 1.0 | 2000-12-02 08:53:52 |
| 2 | 677 | 731 | 1.0 | 2000-12-02 08:41:26 |
| 3 | 677 | 234 | 1.0 | 2000-12-02 08:23:47 |
| 4 | 677 | 190 | 1.0 | 2000-12-02 08:50:33 |
| 5 | 677 | 133 | 1.0 | 2000-12-02 08:27:28 |
| 6 | 677 | 546 | 1.0 | 2000-12-02 08:32:02 |
| 7 | 677 | 2090 | 1.0 | 2000-12-02 08:53:17 |
| 8 | 677 | 421 | 1.0 | 2000-12-02 08:50:13 |
| 9 | 677 | 154 | 1.0 | 2000-12-02 08:44:14 |


# 2. Models training

In [27]:
def fit_predict_add_res(name, model, experiment, train, test, suffix=''):
    """
    Run fit and predict for the `model`, measure the time of each step and evaluate metrics.
    Note: models that use negative feedback are trained on the global `train_neg_dataset`
    instead of `train`.
    """
    start_time = time.time()

    dataset = {'dataset': train}
    predict_params = {'k': K, 'queries': test.interactions.select('user_idx').distinct()}

    # these models use both positive (relevance = 1) and negative (relevance = 0) feedback
    if isinstance(model, (Wilson, UCB, ThompsonSampling, LinUCB)):
        dataset['dataset'] = train_neg_dataset

    predict_params.update(dataset)

    model.fit(**dataset)
    fit_time = time.time() - start_time

    pred = model.predict(**predict_params)
    pred.cache()
    # Spark is lazy: materialize the predictions so that predict_time covers the whole predict
    pred.count()
    pred.show(100)
    predict_time = time.time() - start_time - fit_time

    experiment.add_result(name + suffix, pred)
    metric_time = time.time() - start_time - fit_time - predict_time
    experiment.results.loc[name + suffix, 'fit_time'] = fit_time
    experiment.results.loc[name + suffix, 'predict_time'] = predict_time
    experiment.results.loc[name + suffix, 'metric_time'] = metric_time
    experiment.results.loc[name + suffix, 'full_time'] = fit_time + predict_time + metric_time
    pred.unpersist()
    print(experiment.results[[f'NDCG@{K}', f'MRR@{K}', f'Coverage@{K}', 'fit_time']]
          .sort_values(f'NDCG@{K}', ascending=False))

In [28]:
def full_pipeline(models, experiment, train, test, suffix='', budget=BUDGET):
    """
    For each model:
        - if required, run a hyperparameter search on the global `opt_train_dataset` /
          `opt_val_dataset`, set the best params and save the param values to `experiment`
        - pass the model to `fit_predict_add_res`
    """
    
    for name, [model, params] in models.items():
        model.logger.info(msg='{} started'.format(name))

        if params != 'no_opt':
            model.logger.info(msg='{} optimization started'.format(name))
            best_params = model.optimize(opt_train_dataset, 
                                         opt_val_dataset, 
                                         param_borders=params, 
                                         k=K, 
                                         budget=budget)
            logger.info(msg='best params for {} are: {}'.format(name, best_params))
            model.set_params(**best_params)

        
        logger.info(msg='{} fit_predict started'.format(name))
        fit_predict_add_res(name, model, experiment, train, test, suffix)
        # here we call protected attribute to get all parameters set during model initialization
        experiment.results.loc[name + suffix, 'params'] = str(model._init_args)
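`model.optimize` searches the `param_borders` ranges within `budget` trials; the logs below show that RePlay runs this through Optuna, scoring each trial on the validation split. The sketch below is a hypothetical stand-in that uses plain uniform random search instead of Optuna's sampler, only to show the shape of a budgeted search over `[low, high]` borders. The function `random_search` and its `evaluate` callback are illustrative names, not RePlay API.

```python
import random


def random_search(evaluate, param_borders, budget, seed=0):
    """Sample `budget` points uniformly inside the [low, high] borders; keep the best one.

    evaluate: callable mapping a params dict to a score (higher is better),
    e.g. NDCG@K on a validation split.
    """
    rng = random.Random(seed)
    best_params, best_score = None, float("-inf")
    for _ in range(budget):
        params = {name: rng.uniform(low, high) for name, (low, high) in param_borders.items()}
        score = evaluate(params)
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score
```

A model-based sampler such as Optuna's focuses later trials near promising regions, which usually matters when the budget is as small as `BUDGET = 20`.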

### 2.1. Contextual bandit models

### 2.1.1 Item features preprocessing

In [29]:
item_features_original = preparator.transform(columns_mapping={'item_id': 'item_id'},
                                              data=data.items)
item_features = indexer.transform(df=item_features_original)
item_features.show(2)

# check that the item indexing is dense: indexes run from 0 to 3882 (3883 items)
# sf.max / sf.min are used instead of importing max/min, which would shadow Python builtins
item_features.select(sf.max(item_features.item_idx)).show()
item_features.select(sf.min(item_features.item_idx)).show()
item_features.count()

27-Aug-24 00:18:01, replay, INFO: Column with ids of users or items is absent in mapping. The dataframe will be treated as a users'/items' features dataframe.


+--------+----------------+--------------------+
|item_idx|           title|              genres|
+--------+----------------+--------------------+
|      29|Toy Story (1995)|Animation|Childre...|
|     393|  Jumanji (1995)|Adventure|Childre...|
+--------+----------------+--------------------+
only showing top 2 rows

+-------------+
|max(item_idx)|
+-------------+
|         3882|
+-------------+

+-------------+
|min(item_idx)|
+-------------+
|            0|
+-------------+



3883

In [30]:
year = item_features.withColumn('year', sf.substring(sf.col('title'), -5, 4).astype(st.IntegerType())).select('item_idx', 'year')
year.show(2)

+--------+----+
|item_idx|year|
+--------+----+
|      29|1995|
|     393|1995|
+--------+----+
only showing top 2 rows
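`sf.substring(sf.col('title'), -5, 4)` takes the four characters just before the closing parenthesis, and in Spark's default (non-ANSI) mode the cast to integer yields null when they are not digits. The plain-Python sketch below mirrors that slice and, for comparison, shows a regex variant that also tolerates trailing spaces. Both functions are illustrative, not part of RePlay.

```python
import re


def extract_year(title: str):
    """Mirror of substring(title, -5, 4) + int cast: None if the slice is not a number."""
    candidate = title[-5:-1]
    return int(candidate) if candidate.isdigit() else None


def extract_year_regex(title: str):
    """Stricter variant: a 4-digit year in parentheses at the end of the title."""
    match = re.search(r"\((\d{4})\)\s*$", title)
    return int(match.group(1)) if match else None
```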



In [31]:
genres = (
    item_features.select(
        "item_idx",
        sf.split("genres", r"\|").alias("genres")  # the pattern is a regex, so "|" is escaped
    )
)

In [32]:
genres_list = (
    genres.select(sf.explode("genres").alias("genre"))
    .distinct().filter('genre <> "(no genres listed)"')
    .toPandas()["genre"].tolist()
)

In [33]:
genres_list

['Mystery',
 'Action',
 'Documentary',
 "Children's",
 'Drama',
 'Adventure',
 'Film-Noir',
 'Crime',
 'Animation',
 'Fantasy',
 'Comedy',
 'Western',
 'Romance',
 'Thriller',
 'War',
 'Sci-Fi',
 'Musical',
 'Horror']

In [34]:
item_features = genres
for genre in genres_list:
    item_features = item_features.withColumn(
        genre,
        sf.array_contains(sf.col("genres"), genre).astype(IntegerType())
    )
item_features = item_features.drop("genres").cache()
item_features.count()
item_features = item_features.join(year, on='item_idx', how='inner')
item_features.cache()

DataFrame[item_idx: int, Mystery: int, Action: int, Documentary: int, Children's: int, Drama: int, Adventure: int, Film-Noir: int, Crime: int, Animation: int, Fantasy: int, Comedy: int, Western: int, Romance: int, Thriller: int, War: int, Sci-Fi: int, Musical: int, Horror: int, year: int]
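The loop above turns each movie's genre list into one 0/1 column per genre (multi-hot encoding). A plain-Python sketch of the same idea follows; note that Python's `str.split("|")` is a literal split, whereas Spark's `split` takes a regular expression, which is why the Spark code escapes the pipe. The helper names are illustrative.

```python
def genre_vocabulary(genre_fields):
    """Sorted list of distinct genres over all 'A|B|C' strings."""
    return sorted({genre for field in genre_fields for genre in field.split("|")})


def multi_hot_genres(genres_field: str, vocabulary):
    """One 0/1 flag per genre in the vocabulary."""
    present = set(genres_field.split("|"))
    return {genre: int(genre in present) for genre in vocabulary}
```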

### 2.1.2 User features preprocessing

In [35]:
data.users.head()

| | user_id | gender | age | occupation | zip_code |
|---|---|---|---|---|---|
| 0 | 1 | F | 1 | 10 | 48067 |
| 1 | 2 | M | 56 | 16 | 70072 |
| 2 | 3 | M | 25 | 15 | 55117 |
| 3 | 4 | M | 45 | 7 | 2460 |
| 4 | 5 | M | 25 | 20 | 55455 |


In [36]:
# same preprocessing for users as was done for items in 2.1.1
user_features_original = preparator.transform(columns_mapping={'user_id': 'user_id'},
                                              data=data.users)
user_features = indexer.transform(df=user_features_original)
# switch to pandas temporarily for one-hot encoding
user_features = user_features.toPandas()
user_features.head(2)

27-Aug-24 00:18:03, replay, INFO: Column with ids of users or items is absent in mapping. The dataframe will be treated as a users'/items' features dataframe.


| | user_idx | gender | age | occupation | zip_code |
|---|---|---|---|---|---|
| 0 | 4131 | F | 1 | 10 | 48067 |
| 1 | 2364 | M | 56 | 16 | 70072 |


In [37]:
from sklearn.preprocessing import OneHotEncoder

In [38]:
print("max occupation index: ", user_features['occupation'].max())
print("min occupation index: ", user_features['occupation'].min())
count_diff_zips = user_features['zip_code'].unique().size
print("different zip codes: ", count_diff_zips)  # too many distinct zip codes to encode, drop them for now
users_pd = user_features.drop(columns=['zip_code'])
users_pd.head()

# bin the age variable; pd.cut uses right-closed intervals by default, e.g. (40, 50],
# so the labels are approximate (age code 50 lands in '40-49')
bins = [0, 20, 30, 40, 50, 60, np.inf]
names = ['<20', '20-29', '30-39', '40-49', '51-60', '60+']

users_pd['agegroup'] = pd.cut(users_pd['age'], bins, labels=names)
users_pd = users_pd.drop(["age"], axis=1)
users_pd.head()

# one-hot encode, following https://github.com/kfoofw/bandit_simulations/tree/master
# (in scikit-learn >= 1.2, use sparse_output=False and get_feature_names_out instead)
columnsToEncode = ["agegroup", "gender", "occupation"]
myEncoder = OneHotEncoder(sparse=False, handle_unknown='ignore')
myEncoder.fit(users_pd[columnsToEncode])

users_pd = pd.concat([users_pd.drop(columnsToEncode, axis=1),
                      pd.DataFrame(myEncoder.transform(users_pd[columnsToEncode]),
                                   columns=myEncoder.get_feature_names(columnsToEncode))],
                     axis=1)
users_pd.head()

max occupation index:  20
min occupation index:  0
different zip codes:  3439




| | user_idx | agegroup_20-29 | agegroup_30-39 | agegroup_40-49 | agegroup_51-60 | agegroup_<20 | gender_F | gender_M | occupation_0 | occupation_1 | ... | occupation_11 | occupation_12 | occupation_13 | occupation_14 | occupation_15 | occupation_16 | occupation_17 | occupation_18 | occupation_19 | occupation_20 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 4131 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 1 | 2364 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 2 | 4217 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 3 | 5916 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 4 | 1603 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 |
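MovieLens 1M stores age as a small set of codes (1, 18, 25, 35, 45, 50, 56) rather than exact ages. The sketch below reproduces the right-closed binning that `pd.cut` applies with its default `right=True`, using `bisect` on the same edges and labels as the notebook, to show which group each code lands in. It explains why the encoded table has no `agegroup_60+` column and why code 50 is counted as `40-49`.

```python
import bisect
import math

EDGES = [0, 20, 30, 40, 50, 60, math.inf]
LABELS = ['<20', '20-29', '30-39', '40-49', '51-60', '60+']


def age_group(age):
    """Right-closed bins (EDGES[i-1], EDGES[i]], like pd.cut with right=True."""
    i = bisect.bisect_left(EDGES, age)
    if i == 0 or i == len(EDGES):
        return None  # outside (0, inf]: pd.cut would give NaN
    return LABELS[i - 1]


print({code: age_group(code) for code in (1, 18, 25, 35, 45, 50, 56)})
```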


In [39]:
# convert back to a Spark DataFrame
user_features = spark.createDataFrame(users_pd)
user_features.printSchema()
user_features.show()
print("total users: ",user_features.count())

root
 |-- user_idx: integer (nullable = true)
 |-- agegroup_20-29: double (nullable = true)
 |-- agegroup_30-39: double (nullable = true)
 |-- agegroup_40-49: double (nullable = true)
 |-- agegroup_51-60: double (nullable = true)
 |-- agegroup_<20: double (nullable = true)
 |-- gender_F: double (nullable = true)
 |-- gender_M: double (nullable = true)
 |-- occupation_0: double (nullable = true)
 |-- occupation_1: double (nullable = true)
 |-- occupation_2: double (nullable = true)
 |-- occupation_3: double (nullable = true)
 |-- occupation_4: double (nullable = true)
 |-- occupation_5: double (nullable = true)
 |-- occupation_6: double (nullable = true)
 |-- occupation_7: double (nullable = true)
 |-- occupation_8: double (nullable = true)
 |-- occupation_9: double (nullable = true)
 |-- occupation_10: double (nullable = true)
 |-- occupation_11: double (nullable = true)
 |-- occupation_12: double (nullable = true)
 |-- occupation_13: double (nullable = true)
 |-- occupation_14: double



+--------+--------------+--------------+--------------+--------------+------------+--------+--------+------------+------------+------------+------------+------------+------------+------------+------------+------------+------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+
|user_idx|agegroup_20-29|agegroup_30-39|agegroup_40-49|agegroup_51-60|agegroup_<20|gender_F|gender_M|occupation_0|occupation_1|occupation_2|occupation_3|occupation_4|occupation_5|occupation_6|occupation_7|occupation_8|occupation_9|occupation_10|occupation_11|occupation_12|occupation_13|occupation_14|occupation_15|occupation_16|occupation_17|occupation_18|occupation_19|occupation_20|
+--------+--------------+--------------+--------------+--------------+------------+--------+--------+------------+------------+------------+------------+------------+------------+------------+------------+------------+----------

### 2.2. Fitting various bandits

In [40]:
bandit_models = {
          'Popular': [PopRec(), 'no_opt'], 
          'Random (uniform)': [RandomRec(seed=SEED, distribution='uniform'), 'no_opt'], 
          'Random (popularity-based)': [RandomRec(seed=SEED, distribution='popular_based'), {"alpha": [-0.5, 100]}],
          'UCB': [UCB(exploration_coef=2.0), 'no_opt'], #2.0 as default, 0.5 as original 
          'Wilson': [Wilson(), 'no_opt'],
          'TS (context-free)': [ThompsonSampling(),'no_opt'],
          'Linear UCB (eps = -10.0)(disjoint models)':[LinUCB(eps = -10.0, alpha = 1.0, regr_type = 'disjoint'), 'no_opt'],
          'Linear UCB (eps = -5.0)(disjoint models)':[LinUCB(eps = -5.0, alpha = 1.0, regr_type = 'disjoint'), 'no_opt'],
          'Linear UCB (eps = -2.0)(disjoint models)':[LinUCB(eps = -2.0, alpha = 1.0, regr_type = 'disjoint'), 'no_opt'],
          'Linear UCB (disjoint models)':[LinUCB(eps = 0.0, alpha = 1.0, regr_type = 'disjoint'), {"eps": [-20.0, 10.0], "alpha": [0.001, 10.0]}]
         }

In [41]:
feature_schema = FeatureSchema(
    [
        FeatureInfo(
            column="user_idx",
            feature_type=FeatureType.CATEGORICAL,
            feature_hint=FeatureHint.QUERY_ID,
        ),
        FeatureInfo(
            column="item_idx",
            feature_type=FeatureType.CATEGORICAL,
            feature_hint=FeatureHint.ITEM_ID,
        ),
        FeatureInfo(
            column="relevance",
            feature_type=FeatureType.NUMERICAL,
            feature_hint=FeatureHint.RATING,
        ),
        FeatureInfo(
            column="timestamp",
            feature_type=FeatureType.NUMERICAL,
            feature_hint=FeatureHint.TIMESTAMP,
        ),
    ]
)

In [42]:
feature_schema.columns

['user_idx', 'item_idx', 'relevance', 'timestamp']

In [43]:
feature_schema.query_id_column

'user_idx'

In [44]:
all_dataset = Dataset(
    feature_schema=feature_schema,
    interactions=log_replay,
    item_features=item_features,
    query_features=user_features,
    categorical_encoded = True
)

train_dataset = Dataset(
    feature_schema=feature_schema,
    interactions=train,
    item_features=item_features,
    query_features=user_features,
    categorical_encoded = True
)

test_dataset = Dataset(
    feature_schema=feature_schema,
    interactions=test,
    item_features=item_features,
    query_features=user_features,
    categorical_encoded = True
)

train_neg_dataset = Dataset(
    feature_schema=feature_schema,
    interactions=pos_neg_train,
    item_features=item_features,
    query_features=user_features,
    categorical_encoded = True
)

opt_train_dataset = Dataset(
    feature_schema=feature_schema,
    interactions=opt_train,
    item_features=item_features,
    query_features=user_features,
    categorical_encoded = True
)

opt_val_dataset = Dataset(
    feature_schema=feature_schema,
    interactions=opt_val,
    item_features=item_features,
    query_features=user_features,
    categorical_encoded = True
)

In [45]:
e = Experiment(
    [
        MAP(K),
        NDCG(K),
        HitRate(K_list_metrics),
        Coverage(K),
        Surprisal(K),
        MRR(K)
    ],
    test_dataset.interactions,
    train_dataset.interactions,
    query_column=train_dataset.feature_schema.query_id_column,
    item_column=train_dataset.feature_schema.item_id_column,
    rating_column=train_dataset.feature_schema.interactions_rating_column,
    )
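The two beyond-accuracy metrics can be sketched as follows, assuming common definitions: coverage@k is the share of train items that appear in at least one top-k list, and surprisal@k is the mean self-information `-log2(share of users who interacted with the item)`, normalized by `log2(n_users)` so it lies in [0, 1], with items unseen in train given the maximum value. RePlay's `Coverage` and `Surprisal` may differ in details.

```python
import math
from collections import defaultdict


def coverage_at_k(recs_by_user, train_items, k):
    """Share of train items that appear in at least one user's top-k list."""
    train_items = set(train_items)
    recommended = {item for recs in recs_by_user.values() for item in recs[:k]}
    return len(recommended & train_items) / len(train_items)


def surprisal_at_k(recs_by_user, train_log, k):
    """Mean normalized self-information of recommended items; train_log holds (user, item) pairs."""
    users_per_item = defaultdict(set)
    users = set()
    for user, item in train_log:
        users_per_item[item].add(user)
        users.add(user)
    n_users = len(users)
    max_info = math.log2(n_users) if n_users > 1 else 1.0

    def info(item):
        count = len(users_per_item.get(item, ()))
        if count == 0:
            return 1.0  # item unseen in train: treated as maximally surprising
        return -math.log2(count / n_users) / max_info

    per_user = []
    for recs in recs_by_user.values():
        top = recs[:k]
        if top:
            per_user.append(sum(info(item) for item in top) / len(top))
    return sum(per_user) / len(per_user) if per_user else 0.0
```

Under these definitions, a popularity recommender gets low coverage and low surprisal, while random recommendations score high on both, which matches the direction of the results table below.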

In [46]:
%%time
full_pipeline(bandit_models, e, train_dataset, test_dataset)

27-Aug-24 00:18:17, replay, INFO: Popular started
27-Aug-24 00:18:17, replay, INFO: Popular fit_predict started
                                                                                

+--------+--------+-------------------+
|user_idx|item_idx|          relevance|
+--------+--------+-------------------+
|      18|      14| 0.3258278145695364|
|      18|      32| 0.2682119205298013|
|      18|      48|  0.223841059602649|
|      18|      47| 0.2187086092715232|
|      18|      89|0.16258278145695365|
|      18|     114|0.15960264900662252|
|      18|     101| 0.1490066225165563|
|      18|     105| 0.1478476821192053|
|      18|     124|0.14718543046357616|
|      18|     157|0.14089403973509934|
|      46|      12|0.33410596026490064|
|      46|      18| 0.3086092715231788|
|      46|      19| 0.2905629139072848|
|      46|      22| 0.2900662251655629|
|      46|      28| 0.2814569536423841|
|      46|      30| 0.2682119205298013|
|      46|      31|               0.25|
|      46|      35|0.24188741721854304|
|      46|      34|0.23791390728476822|
|      46|      43|0.22996688741721855|
|     186|      12|0.33410596026490064|
|     186|      17| 0.3130794701986755|


27-Aug-24 00:18:33, replay, INFO: Random (uniform) started                      
27-Aug-24 00:18:33, replay, INFO: Random (uniform) fit_predict started


         NDCG@10    MRR@10  Coverage@10  fit_time
Popular  0.24367  0.390414     0.033903  1.740541


                                                                                

+--------+--------+------------------+
|user_idx|item_idx|         relevance|
+--------+--------+------------------+
|      18|     757|0.9993475987324224|
|      18|    1948|0.9958625269029056|
|      18|    1658|0.9958581137949202|
|      18|    1522|0.9951778117229308|
|      18|    3706|0.9930875755613394|
|      18|     197|0.9930827474156511|
|      18|    2644|0.9929310746247295|
|      18|    2730|0.9927206244709196|
|      18|    3202|0.9926568867577302|
|      18|    1782|0.9921495467387857|
|      46|    3827|0.9994992249598874|
|      46|    3207|0.9991521584153258|
|      46|    3222|0.9991089319441182|
|      46|    2128|0.9955462962465224|
|      46|    1540|0.9949174902180918|
|      46|    2273|0.9942066707891665|
|      46|    2286|0.9940611391623346|
|      46|    3396| 0.991560971355034|
|      46|    1596|0.9901570394788072|
|      46|     311|0.9865957485430298|
|     186|    3637|0.9969224342334095|
|     186|     734|0.9964346644226986|
|     186|     493|0.9940

27-Aug-24 00:18:47, replay, INFO: Random (popularity-based) started             
27-Aug-24 00:18:47, replay, INFO: Random (popularity-based) optimization started
[I 2024-08-27 00:18:47,676] A new study created in memory with name: no-name-ca65b5c6-62bc-4d48-bcdb-fee08f366963


                   NDCG@10    MRR@10  Coverage@10  fit_time
Popular           0.243670  0.390414     0.033903  1.740541
Random (uniform)  0.022416  0.058170     0.942281  1.568186


[I 2024-08-27 00:18:57,033] Trial 0 finished with value: 0.026565537776860486 and parameters: {'distribution': 'popular_based', 'alpha': 0.0}. Best is trial 0 with value: 0.026565537776860486.
[I 2024-08-27 00:19:05,627] Trial 1 finished with value: 0.029649158269859532 and parameters: {'distribution': 'popular_based', 'alpha': 97.8655970002552}. Best is trial 1 with value: 0.029649158269859532.
[I 2024-08-27 00:19:14,621] Trial 2 finished with value: 0.03110143492922837 and parameters: {'distribution': 'popular_based', 'alpha': 46.420884965440386}. Best is trial 2 with value: 0.03110143492922837.
[I 2024-08-27 00:19:23,076] Trial 3 finished with value: 0.03335014317834524 and parameters: {'distribution': 'popular_based', 'alpha': 45.90756397067951}. Best is trial 3 with value:

+--------+--------+------------------+
|user_idx|item_idx|         relevance|
+--------+--------+------------------+
|      18|     656|0.9997139455542501|
|      18|    2892|0.9994633451862948|
|      18|    3531|0.9914165570707009|
|      18|    2815|0.9906310529803498|
|      18|    2760|0.9889421916372966|
|      18|    1823|0.9888254013706983|
|      18|    2223|0.9872588270995802|
|      18|    1641|0.9869414493052643|
|      18|    2436|0.9865949351371692|
|      18|    1674|0.9864513183475915|
|      46|    2673| 0.999130288906365|
|      46|    2175|0.9951027381140876|
|      46|    2181|0.9942066707891665|
|      46|    2913| 0.993477244244402|
|      46|    2676|0.9925041031141819|
|      46|     564|0.9905236941807516|
|      46|    2519|0.9870461757938168|
|      46|     219|0.9865957485430298|
|      46|    2797|0.9858041049618765|
|      46|    1871|0.9824781310008317|
|     186|    2409|0.9983813611913886|
|     186|     700|0.9964346644226986|
|     186|    2623|0.9952

27-Aug-24 00:21:50, replay, INFO: UCB started
27-Aug-24 00:21:50, replay, INFO: UCB fit_predict started


                            NDCG@10    MRR@10  Coverage@10  fit_time
Popular                    0.243670  0.390414     0.033903  1.740541
Random (popularity-based)  0.027783  0.065515     0.737181  1.732029
Random (uniform)           0.022416  0.058170     0.942281  1.568186


                                                                                

+--------+--------+-----------------+
|user_idx|item_idx|        relevance|
+--------+--------+-----------------+
|      18|    3882|6.213656579361331|
|      18|    3881|6.213656579361331|
|      18|    3880|6.213656579361331|
|      18|    3879|6.213656579361331|
|      18|    3878|6.213656579361331|
|      18|    3877|6.213656579361331|
|      18|    3876|6.213656579361331|
|      18|    3875|6.213656579361331|
|      18|    3874|6.213656579361331|
|      18|    3873|6.213656579361331|
|      46|    3882|6.213656579361331|
|      46|    3881|6.213656579361331|
|      46|    3880|6.213656579361331|
|      46|    3879|6.213656579361331|
|      46|    3878|6.213656579361331|
|      46|    3877|6.213656579361331|
|      46|    3876|6.213656579361331|
|      46|    3875|6.213656579361331|
|      46|    3874|6.213656579361331|
|      46|    3873|6.213656579361331|
|     186|    3882|6.213656579361331|
|     186|    3881|6.213656579361331|
|     186|    3880|6.213656579361331|
|     186|  

27-Aug-24 00:22:01, replay, INFO: Wilson started
27-Aug-24 00:22:01, replay, INFO: Wilson fit_predict started


                            NDCG@10    MRR@10  Coverage@10  fit_time
Popular                    0.243670  0.390414     0.033903  1.740541
Random (popularity-based)  0.027783  0.065515     0.737181  1.732029
Random (uniform)           0.022416  0.058170     0.942281  1.568186
UCB                        0.000000  0.000000     0.000000  0.930765
+--------+--------+------------------+
|user_idx|item_idx|         relevance|
+--------+--------+------------------+
|      18|     707|0.9754524358354291|
|      18|    1209| 0.972226295602107|
|      18|     557| 0.971408022009853|
|      18|     682|0.9698080855041509|
|      18|     234|0.9688400027862992|
|      18|    1081|0.9685863803126885|
|      18|      32|0.9660186961163115|
|      18|     238|0.9647924991594925|
|      18|     239|0.9647424995568235|
|      18|    1080|0.9643378106230265|
|      46|     400|0.9900259255456659|
|      46|     106|0.9821985244892515|
|      46|     186|0.9795457647582008|
|      46|     707|0.9754524358

27-Aug-24 00:22:12, replay, INFO: TS (context-free) started
27-Aug-24 00:22:12, replay, INFO: TS (context-free) fit_predict started


                            NDCG@10    MRR@10  Coverage@10  fit_time
Popular                    0.243670  0.390414     0.033903  1.740541
Wilson                     0.092121  0.180976     0.017092  0.813008
Random (popularity-based)  0.027783  0.065515     0.737181  1.732029
Random (uniform)           0.022416  0.058170     0.942281  1.568186
UCB                        0.000000  0.000000     0.000000  0.930765
+--------+--------+------------------+
|user_idx|item_idx|         relevance|
+--------+--------+------------------+
|      18|    2145| 0.999321272887753|
|      18|    2888|0.9982630237251762|
|      18|    3038|0.9979932560667897|
|      18|    2897|0.9975859876409066|
|      18|    2847|0.9967084466736913|
|      18|    1209|0.9965098151983792|
|      18|    3177| 0.996482216371184|
|      18|    2072|0.9957395895496687|
|      18|    1880|0.9956008285138828|
|      18|    2826|0.9951267068379085|
|      46|    2145| 0.999321272887753|
|      46|    2888|0.9982630237251762|
|

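The context-free TS model compared here follows the Beta-Binomial scheme. A minimal sketch of textbook Beta-Bernoulli Thompson sampling, assuming each item's history is reduced to counts of positive and negative interactions and a uniform Beta(1, 1) prior; this illustrates the technique and is not RePlay's `ThompsonSampling` code. The names `thompson_scores`, `recommend` and the `stats` counts are hypothetical.

```python
import random


def thompson_scores(stats, rng):
    """Draw one Beta(positives + 1, negatives + 1) sample per item (uniform Beta(1, 1) prior)."""
    return {item: rng.betavariate(pos + 1, neg + 1) for item, (pos, neg) in stats.items()}


def recommend(stats, k, rng):
    """Rank items by their sampled success probability and return the top-k item ids."""
    scores = thompson_scores(stats, rng)
    return sorted(scores, key=scores.get, reverse=True)[:k]


# Hypothetical per-item (positives, negatives) counts.
stats = {"a": (90, 10), "b": (10, 90), "c": (1, 0)}
rng = random.Random(0)  # fixed seed so the draw is reproducible
print(recommend(stats, k=2, rng=rng))
```

Because each call draws fresh samples, an item with few interactions (like `"c"`) can still occasionally reach the top, which is how the method keeps exploring.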
27-Aug-24 00:22:24, replay, INFO: Linear UCB (eps = -10.0)(disjoint models) started
27-Aug-24 00:22:24, replay, INFO: Linear UCB (eps = -10.0)(disjoint models) fit_predict started


                            NDCG@10    MRR@10  Coverage@10  fit_time
Popular                    0.243670  0.390414     0.033903  1.740541
Wilson                     0.092121  0.180976     0.017092  0.813008
Random (popularity-based)  0.027783  0.065515     0.737181  1.732029
Random (uniform)           0.022416  0.058170     0.942281  1.568186
TS (context-free)          0.012693  0.027208     0.005884  1.161516
UCB                        0.000000  0.000000     0.000000  0.930765



+--------+--------+--------------------+
|user_idx|item_idx|           relevance|
+--------+--------+--------------------+
|      18|      14|  0.3461363553113034|
|      18|      32| 0.27435792832019157|
|      18|      47| 0.15771052776502625|
|      18|      48| 0.13450903873946263|
|      18|     114| 0.12801077388571092|
|      18|      89| 0.08495628340764438|
|      18|     146| 0.07927673357502185|
|      18|     179|0.005015392648520689|
|      18|     161|4.280268349391436E-4|
|      18|     105|-0.00684628959884...|
|      46|      18|  0.1361872498697485|
|      46|      22| 0.07383363680189003|
|      46|      19|  0.0644250713023885|
|      46|      28|0.013051815491661345|
|      46|      43|0.005695501342363185|
|      46|      52|0.002297268160284638|
|      46|      30|-0.01056825841032...|
|      46|      51|-0.04858197305413492|
|      46|      35|-0.04873115483799839|
|      46|      34|-0.05283451313261711|
|     186|      24|-0.02627412210296...|
|     186|      

27-Aug-24 00:22:47, replay, INFO: Linear UCB (eps = -5.0)(disjoint models) started
27-Aug-24 00:22:47, replay, INFO: Linear UCB (eps = -5.0)(disjoint models) fit_predict started


                                            NDCG@10    MRR@10  Coverage@10  \
Linear UCB (eps = -10.0)(disjoint models)  0.256768  0.415579     0.054637   
Popular                                    0.243670  0.390414     0.033903   
Wilson                                     0.092121  0.180976     0.017092   
Random (popularity-based)                  0.027783  0.065515     0.737181   
Random (uniform)                           0.022416  0.058170     0.942281   
TS (context-free)                          0.012693  0.027208     0.005884   
UCB                                        0.000000  0.000000     0.000000   

                                            fit_time  
Linear UCB (eps = -10.0)(disjoint models)  11.212709  
Popular                                     1.740541  
Wilson                                      0.813008  
Random (popularity-based)                   1.732029  
Random (uniform)                            1.568186  
TS (context-free)                           1


+--------+--------+--------------------+
|user_idx|item_idx|           relevance|
+--------+--------+--------------------+
|      18|      14|   0.649542231904279|
|      18|      32|  0.6281206748210149|
|      18|      47|  0.5525380331294993|
|      18|     114|  0.5445372581847239|
|      18|      48|  0.5351774710095131|
|      18|     146|  0.5305180432459238|
|      18|     215|  0.4902560887751861|
|      18|      89|  0.4896392451855751|
|      18|     105|  0.4781346775650069|
|      18|     272| 0.44883281149300946|
|      46|      18|  0.5421071766363647|
|      46|      43|  0.4938265122073314|
|      46|      22|  0.4794601733026936|
|      46|      52| 0.47476643612590097|
|      46|      28|  0.4696753943237524|
|      46|      19|  0.4628401599238217|
|      46|      30|  0.4538245732203905|
|      46|      51|   0.451656994741109|
|      46|     106|  0.4500501878811629|
|      46|      49|  0.4475531002761074|
|     186|      24| 0.48689472947139034|
|     186|      

27-Aug-24 00:23:11, replay, INFO: Linear UCB (eps = -2.0)(disjoint models) started
27-Aug-24 00:23:11, replay, INFO: Linear UCB (eps = -2.0)(disjoint models) fit_predict started


                                            NDCG@10    MRR@10  Coverage@10  \
Linear UCB (eps = -10.0)(disjoint models)  0.256768  0.415579     0.054637   
Linear UCB (eps = -5.0)(disjoint models)   0.256633  0.411666     0.055758   
Popular                                    0.243670  0.390414     0.033903   
Wilson                                     0.092121  0.180976     0.017092   
Random (popularity-based)                  0.027783  0.065515     0.737181   
Random (uniform)                           0.022416  0.058170     0.942281   
TS (context-free)                          0.012693  0.027208     0.005884   
UCB                                        0.000000  0.000000     0.000000   

                                            fit_time  
Linear UCB (eps = -10.0)(disjoint models)  11.212709  
Linear UCB (eps = -5.0)(disjoint models)   12.185961  
Popular                                     1.740541  
Wilson                                      0.813008  
Random (popularity-bas


+--------+--------+-------------------+
|user_idx|item_idx|          relevance|
+--------+--------+-------------------+
|      18|      32| 0.8403783227215088|
|      18|      14| 0.8315857578600642|
|      18|     146| 0.8012628290484649|
|      18|     114| 0.7944531487641316|
|      18|     215| 0.7894818531273808|
|      18|      47| 0.7894345363481833|
|      18|     378|  0.778006634051391|
|      18|      48| 0.7755785303715433|
|      18|     105| 0.7691232578633155|
|      18|     239|  0.761065239180516|
|      46|      43| 0.7867051187263123|
|      46|      18| 0.7856591326963344|
|      46|     106| 0.7733858596354098|
|      46|      52| 0.7582479369052708|
|      46|      49| 0.7529335943151868|
|      46|      51| 0.7518003754182554|
|      46|      85|  0.747046939519117|
|      46|     118| 0.7462389880392188|
|      46|      35| 0.7443712422878337|
|      46|      28|  0.743649541623007|
|     186|      24| 0.7947960404160025|
|     186|      17| 0.7551338983711547|


27-Aug-24 00:23:34, replay, INFO: Linear UCB (disjoint models) started
27-Aug-24 00:23:34, replay, INFO: Linear UCB (disjoint models) optimization started
[I 2024-08-27 00:23:34,116] A new study created in memory with name: no-name-e3dbb469-70ca-4694-9d49-473c3357c5fc


                                            NDCG@10    MRR@10  Coverage@10  \
Linear UCB (eps = -10.0)(disjoint models)  0.256768  0.415579     0.054637   
Linear UCB (eps = -5.0)(disjoint models)   0.256633  0.411666     0.055758   
Popular                                    0.243670  0.390414     0.033903   
Linear UCB (eps = -2.0)(disjoint models)   0.243607  0.389884     0.061362   
Wilson                                     0.092121  0.180976     0.017092   
Random (popularity-based)                  0.027783  0.065515     0.737181   
Random (uniform)                           0.022416  0.058170     0.942281   
TS (context-free)                          0.012693  0.027208     0.005884   
UCB                                        0.000000  0.000000     0.000000   

                                            fit_time  
Linear UCB (eps = -10.0)(disjoint models)  11.212709  
Linear UCB (eps = -5.0)(disjoint models)   12.185961  
Popular                                     1.740541  

[I 2024-08-27 00:23:54,387] Trial 0 finished with value: 0.00018929778129825437 and parameters: {'eps': 7.638594005737332, 'alpha': 1.223538244049235}. Best is trial 0 with value: 0.00018929778129825437.
[I 2024-08-27 00:24:13,190] Trial 1 finished with value: 0.20548078172407847 and parameters: {'eps': -13.4732822534441, 'alpha': 9.403637030783353}. Best is trial 1 with value: 0.20548078172407847.
[I 2024-08-27 00:24:32,496] Trial 2 finished with value: 0.2053791236831272 and parameters: {'eps': -13.741414923932908, 'alpha': 7.943905489421209}. Best is trial 1 with value: 0.2054807817240

+--------+--------+--------------------+
|user_idx|item_idx|           relevance|
+--------+--------+--------------------+
|      18|      14|-0.18708212010482073|
|      18|      32| -0.3476935865107834|
|      18|      47| -0.5370679486599175|
|      18|      48| -0.5707485592606527|
|      18|     114| -0.6052127904146967|
|      18|      89| -0.6273064647581226|
|      18|     146| -0.7157448610921825|
|      18|     179| -0.7524737723498248|
|      18|     161| -0.7740641306316931|
|      18|     101| -0.8060688404408217|
|      46|      18|  -0.577361975265282|
|      46|      19| -0.6363398184917698|
|      46|      22|  -0.639538007429163|
|      46|      28| -0.7901290436212559|
|      46|      30| -0.8273627209168122|
|      46|      52| -0.8291209422804702|
|      46|      43| -0.8534509409473068|
|      46|      31| -0.8855277055753827|
|      46|      34| -0.9078039165172507|
|      46|      35| -0.9211055865280252|
|     186|      12| -0.8847899564046459|
|     186|      

In [47]:
e.results.sort_values('NDCG@10', ascending=False)

| Model | MAP@10 | NDCG@10 | HitRate@1 | HitRate@5 | HitRate@10 | Coverage@10 | Surprisal@10 | MRR@10 | fit_time | predict_time | metric_time | full_time | params |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Linear UCB (eps = -10.0)(disjoint models) | 0.167851 | 0.256768 | 0.311677 | 0.561896 | 0.65935 | 0.054637 | 0.125029 | 0.415579 | 11.212709 | 5.647594 | 5.934319 | 22.794623 | `{'regression type': 'disjoint', 'seed': None}` |
| Linear UCB (eps = -5.0)(disjoint models) | 0.167926 | 0.256633 | 0.307287 | 0.567164 | 0.656716 | 0.055758 | 0.126491 | 0.411666 | 12.185961 | 5.676229 | 5.996803 | 23.858992 | `{'regression type': 'disjoint', 'seed': None}` |
| Linear UCB (disjoint models) | 0.165757 | 0.25455 | 0.313433 | 0.554873 | 0.662862 | 0.053797 | 0.124824 | 0.415563 | 11.931105 | 5.719393 | 6.197528 | 23.848026 | `{'regression type': 'disjoint', 'seed': None}` |
| Popular | 0.157257 | 0.24367 | 0.28446 | 0.53029 | 0.645303 | 0.033903 | 0.118354 | 0.390414 | 1.740541 | 5.675722 | 9.396888 | 16.813151 | `{'use_rating': False, 'add_cold_items': True, ...` |
| Linear UCB (eps = -2.0)(disjoint models) | 0.155757 | 0.243607 | 0.271291 | 0.558385 | 0.654083 | 0.061362 | 0.134384 | 0.389884 | 11.550395 | 5.525273 | 5.960969 | 23.036637 | `{'regression type': 'disjoint', 'seed': None}` |
| Wilson | 0.045002 | 0.092121 | 0.083406 | 0.34504 | 0.414399 | 0.017092 | 0.26219 | 0.180976 | 0.813008 | 3.724447 | 6.287893 | 10.825347 | `{'alpha': 0.05, 'add_cold_items': True, 'cold_...` |
| Random (popularity-based) | 0.009907 | 0.027783 | 0.023705 | 0.113257 | 0.201932 | 0.737181 | 0.480604 | 0.065515 | 1.732029 | 4.429888 | 6.115834 | 12.277751 | `{'distribution': 'popular_based', 'alpha': 25....` |
| Random (uniform) | 0.00763 | 0.022416 | 0.022827 | 0.09921 | 0.171203 | 0.942281 | 0.559885 | 0.05817 | 1.568186 | 5.520329 | 6.752461 | 13.840976 | `{'distribution': 'uniform', 'alpha': 0.0, 'see...` |
| TS (context-free) | 0.003724 | 0.012693 | 0.011414 | 0.032485 | 0.113257 | 0.005884 | 0.548572 | 0.027208 | 1.161516 | 3.927576 | 6.493408 | 11.5825 | `{'sample': False, 'seed': None}` |
| UCB | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.930765 | 4.369114 | 6.231866 | 11.531745 | `{'exploration_coef': 2.0, 'sample': False, 'se...` |
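To read the Linear UCB runs against the strongest baseline more directly, here is a small stdlib-only sketch that copies NDCG@10 and Coverage@10 from the results table for the three fixed-`eps` runs, the Optuna-tuned run (`Linear UCB (disjoint models)`) and `Popular`, then picks the best setting per metric. The helpers `best_by` and `lift_over` are hypothetical; in the notebook itself the same can be done by filtering `e.results` with pandas.

```python
# NDCG@10 and Coverage@10 copied from the results table above.
RESULTS = {
    "Linear UCB (eps = -10.0)(disjoint models)": {"NDCG@10": 0.256768, "Coverage@10": 0.054637},
    "Linear UCB (eps = -5.0)(disjoint models)": {"NDCG@10": 0.256633, "Coverage@10": 0.055758},
    "Linear UCB (eps = -2.0)(disjoint models)": {"NDCG@10": 0.243607, "Coverage@10": 0.061362},
    "Linear UCB (disjoint models)": {"NDCG@10": 0.25455, "Coverage@10": 0.053797},  # Optuna-tuned
    "Popular": {"NDCG@10": 0.24367, "Coverage@10": 0.033903},
}


def best_by(results, metric, prefix=""):
    """Return (name, value) with the highest `metric` among models whose name starts with `prefix`."""
    candidates = {name: m[metric] for name, m in results.items() if name.startswith(prefix)}
    name = max(candidates, key=candidates.get)
    return name, candidates[name]


def lift_over(results, model, baseline, metric):
    """Relative improvement of `model` over `baseline` on `metric` (0.05 means +5%)."""
    return results[model][metric] / results[baseline][metric] - 1.0


best_name, best_ndcg = best_by(RESULTS, "NDCG@10", prefix="Linear UCB")
print(best_name, best_ndcg)
print(f"NDCG@10 lift over Popular: {lift_over(RESULTS, best_name, 'Popular', 'NDCG@10'):.1%}")
print(best_by(RESULTS, "Coverage@10", prefix="Linear UCB"))
```

Within this single run, the lowest `eps` gives the best NDCG@10 while `eps = -2.0` gives the highest Coverage@10, so the choice of `eps` trades accuracy against coverage; the NDCG@10 gap between `eps = -10.0` and `eps = -5.0` is very small.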
