In [1]:
import logging
import time
import warnings

import pandas as pd
import torch
import tqdm
from optuna.exceptions import ExperimentalWarning
from pyspark.sql import functions as sf

from replay.data_preparator import DataPreparator, Indexer
from replay.experiment import Experiment
from replay.metrics import HitRate, NDCG, MAP, MRR, Coverage, Surprisal
from replay.models import ALSWrap, KNN, LightFMWrap, SLIM, UCB, CQL, Wilson
from replay.session_handler import State
from replay.splitters import DateSplitter
from replay.utils import get_log_info

# Silence generic UserWarning messages and optuna's ExperimentalWarning so
# they do not clutter the cell outputs.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=ExperimentalWarning)
# True when PyTorch can see a CUDA device; passed to the CQL model later on.
use_gpu = torch.cuda.is_available()



## Set Spark log level

In [2]:
# Take the Spark session exposed by RePlay's State object.
# NOTE(review): presumably State() creates the session if none exists — confirm.
spark = State().session
# Only report Spark messages at ERROR level.
spark.sparkContext.setLogLevel('ERROR')

22/08/01 12:57:39 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/08/01 12:57:40 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).
22/08/01 12:57:41 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
22/08/01 12:57:41 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.


## Read ML1M dataset

In [3]:
# Directory that holds the ML1M files.
prefix = "./data/"

# Ratings are read as a tab-separated file with explicitly supplied column names.
df_log = pd.read_csv(
    f'{prefix}/ml1m_ratings.dat',
    sep='\t',
    names=['user_id', 'item_id', 'relevance', 'timestamp'],
)

# Item and user feature tables are not used in the cells below; kept for reference.
# df_items = pd.read_csv(f'{prefix}/ml1m_items.dat', sep='\t', names=['item_id', 'name', 'genre'])
# df_users = pd.read_csv(f'{prefix}/ml1m_users.dat', sep='\t', names=['user_id', 'gender', 'age', 'occupation', 'zip_code'])

In [4]:
# The pandas columns already carry the expected names, so every column maps to itself.
col_mapping = {
    column: column
    for column in ('user_id', 'item_id', 'relevance', 'timestamp')
}

# Convert the pandas ratings frame with DataPreparator; the result is shown
# and schema-checked in the next cells.
data_preparator = DataPreparator()
log = data_preparator.transform(data=df_log, columns_mapping=col_mapping)

01-Aug-22 12:57:42, replay, INFO: Columns with ids of users or items are present in mapping. The dataframe will be treated as an interactions log.
  if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
                                                                                

In [5]:
# Preview the first two rows of the prepared log.
log.show(2)

+-------+-------+---------+-------------------+
|user_id|item_id|relevance|          timestamp|
+-------+-------+---------+-------------------+
|      1|   1193|      5.0|2001-01-01 01:12:40|
|      1|    661|      3.0|2001-01-01 01:35:09|
+-------+-------+---------+-------------------+
only showing top 2 rows



In [6]:
# Check the column types of the prepared log.
log.printSchema()

root
 |-- user_id: long (nullable = true)
 |-- item_id: long (nullable = true)
 |-- relevance: double (nullable = true)
 |-- timestamp: timestamp (nullable = true)



In [7]:
# Summary string with the number of lines, users and items in the raw-id log.
get_log_info(log, user_col='user_id', item_col='item_id')

                                                                                

'total lines: 1000209, total users: 6040, total items: 3706'

## Experiment constants and hyperparameters

In [8]:
# Number of recommendations requested per user; also the cut-off used for most metrics.
K = 10
# Cut-offs at which HitRate is evaluated.
K_list_metrics = [1, 5, 10]
# Random seed passed to the models that accept one (ALS, LightFM, SLIM).
SEED = 12345

## Prepare data for training/testing

Apply PySpark indexing, then split by date into train/test datasets

In [9]:
# Fit the indexer on every user and item id present in the full log, so that
# any subset of the log can later be converted to user_idx / item_idx columns.
indexer = Indexer()
indexer.fit(users=log.select('user_id'), items=log.select('item_id'))

                                                                                

In [10]:
# Ratings of 3 and above count as positive feedback and are binarised to relevance = 1.
only_positives_log = (
    log
    .filter(sf.col('relevance') >= 3)
    .withColumn('relevance', sf.lit(1.))
)

# Ratings below 3 become negative feedback (relevance = 0); they are used
# only for the Wilson and UCB models.
only_negatives_log = (
    log
    .filter(sf.col('relevance') < 3)
    .withColumn('relevance', sf.lit(0.))
)

# Cell output: (number of positive rows, number of negative rows).
only_positives_log.count(), only_negatives_log.count()

                                                                                

(836478, 163731)

In [11]:
# Convert the positive-feedback log with the fitted indexer; the preview below
# shows user_idx / item_idx columns in place of the raw ids.
pos_log = indexer.transform(df=only_positives_log)
pos_log.show(2)

  if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
                                                                                

+--------+--------+---------+-------------------+
|user_idx|item_idx|relevance|          timestamp|
+--------+--------+---------+-------------------+
|    4131|      43|      1.0|2001-01-01 01:12:40|
|    4131|     585|      1.0|2001-01-01 01:35:09|
+--------+--------+---------+-------------------+
only showing top 2 rows



In [12]:
# Time-based train/test split of the positive interactions.
# NOTE(review): test_start=0.2 presumably sends the latest 20% of the timeline
# to test, and the drop_cold_* flags presumably remove test users/items that
# never appear in train — confirm against the DateSplitter documentation.
date_splitter = DateSplitter(test_start=0.2, drop_cold_items=True, drop_cold_users=True)
train, test = date_splitter.split(pos_log)

# Both splits are reused repeatedly below, so mark them for caching.
train.cache()
test.cache()

print('train info:\n', get_log_info(train))
print('test info:\n', get_log_info(test))

                                                                                

train info:
 total lines: 669181, total users: 5397, total items: 3569


[Stage 56:>                                                         (0 + 8) / 8]

test info:
 total lines: 86542, total users: 1139, total items: 3279


                                                                                

In [14]:
# Earliest timestamp present in the test split; rows strictly before it are
# treated as training-period data.
test_start = test.agg(sf.min('timestamp')).collect()[0][0]

# Indexed negative feedback from the training period.
negatives_train = indexer.transform(
    only_negatives_log.filter(sf.col('timestamp') < test_start)
)

# train with both positive and negative feedback
# NOTE(review): `union` matches columns by position, so this relies on `train`
# and the indexer output having the same column order — confirm.
pos_neg_train = train.withColumn('relevance', sf.lit(1.)).union(negatives_train)

# CQL is trained on the full training-period log with its original ratings.
cql_train = indexer.transform(log.filter(sf.col('timestamp') < test_start))
cql_train.cache()
cql_train.count()  # action: materialises the cache

pos_neg_train.cache()
# Cell output: number of rows in the combined positive + negative training log.
pos_neg_train.count()

  if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
  if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
  if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
  if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
                                                                                

798993

## Run training

In [17]:
def fit_predict_add_res(name, model, experiment, train, top_k, test_users):
    """
    Fit `model`, predict for `test_users`, evaluate the predictions and record timings.

    The prediction is cached and materialised with `count()` before metrics are
    computed, so that `predict_time` covers the actual prediction work and
    `metric_time` covers only `experiment.add_result`.

    :param name: row label under which results are stored in `experiment.results`
    :param model: object exposing `fit(log=...)` and `predict(log=..., k=..., users=...)`
    :param experiment: object exposing `add_result(name, pred)` and a `results`
        table that supports `.loc[row, column] = value`
    :param train: training log passed to both `fit` and `predict`
    :param top_k: number of recommendations requested per user
    :param test_users: users for whom recommendations are generated
    :return: None; `experiment.results` gains the `fit_time`, `predict_time`,
        `metric_time` and `full_time` columns (in seconds) for row `name`
    """
    # perf_counter is monotonic, so intervals cannot be distorted by
    # system clock adjustments the way time.time() differences can.
    started = time.perf_counter()

    model.fit(log=train)
    fitted = time.perf_counter()

    pred = model.predict(log=train, k=top_k, users=test_users)
    try:
        pred.cache()
        pred.count()  # action: forces the prediction to be computed and cached
        predicted = time.perf_counter()

        experiment.add_result(name, pred)
        evaluated = time.perf_counter()

        fit_time = fitted - started
        predict_time = predicted - fitted
        metric_time = evaluated - predicted

        experiment.results.loc[name, 'fit_time'] = fit_time
        experiment.results.loc[name, 'predict_time'] = predict_time
        experiment.results.loc[name, 'metric_time'] = metric_time
        experiment.results.loc[name, 'full_time'] = (fit_time + predict_time + metric_time)
    finally:
        # Release the cached prediction even if counting or metric evaluation fails.
        pred.unpersist()

In [16]:
# Metrics evaluated against the test split. Most use the single cut-off K;
# HitRate is reported at every cut-off in K_list_metrics. Coverage and
# Surprisal are constructed with the training log.
experiment = Experiment(
    test,
    {
        MAP(): K,
        NDCG(): K,
        HitRate(): K_list_metrics,
        Coverage(train): K,
        Surprisal(train): K,
        MRR(): K,
    },
)

# Models to compare, keyed by the label used in the results table.
# Dict order is the order in which they are trained.
algorithms = {
    'CQL': CQL(use_gpu=use_gpu, k=K, n_epochs=3),
    'ALS': ALSWrap(seed=SEED),
    'KNN': KNN(num_neighbours=K),
    'LightFM': LightFMWrap(random_state=SEED),
    'SLIM': SLIM(seed=SEED),
    'UCB': UCB(exploration_coef=0.5),
}

01-Aug-22 12:59:31, replay, INFO: The model is neural network with non-distributed training


In [18]:
%%time
logger = logging.getLogger("replay")
test_users = test.select('user_idx').distinct()

for name in tqdm.auto.tqdm(algorithms.keys(), desc='Model'):
    model = algorithms[name]
    
    logger.info(msg='{} started'.format(name))
    
    train_ = train
    if isinstance(model, (Wilson, UCB)):
        train_ = pos_neg_train
    if isinstance(model, CQL):
        train_ = cql_train

    fit_predict_add_res(name, model, experiment, train=train_, top_k=K, test_users=test_users)
    print(experiment.results[['NDCG@{}'.format(K), 'MRR@{}'.format(K), 'Coverage@{}'.format(K), 'fit_time']].sort_values('NDCG@{}'.format(K), ascending=False))

Model:   0%|          | 0/6 [00:00<?, ?it/s]

01-Aug-22 12:59:44, replay, INFO: CQL started
  if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
  if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):


2022-08-01 12:59.51 [debug    ] RoundIterator is selected.
2022-08-01 12:59.51 [info     ] Directory is created at d3rlpy_logs/CQL_20220801125951
2022-08-01 12:59.51 [debug    ] Building models...
2022-08-01 12:59.51 [debug    ] Models have been built.
2022-08-01 12:59.51 [info     ] Parameters are saved to d3rlpy_logs/CQL_20220801125951/params.json params={'action_scaler': None, 'actor_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rate': None}}, 'actor_learning_rate': 0.0001, 'actor_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_learning_rate': 0.0001, 'alpha_optim_factory': {'optim_cls': 'Adam', 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}, 'alpha_threshold': 10.0, 'batch_size': 256, 'conservative_weight': 5.0, 'critic_encoder_factory': {'type': 'default', 'params': {'activation': 'relu', 'use_batch_norm': False, 'dropout_rat

Epoch 1/3:   0%|          | 0/3120 [00:00<?, ?it/s]

2022-08-01 13:02.23 [info     ] CQL_20220801125951: epoch=1 step=3120 epoch=1 metrics={'time_sample_batch': 0.00042672944374573535, 'time_algorithm_update': 0.048037135142546435, 'temp_loss': -15.400636895574056, 'temp': 1.0406986279365344, 'alpha_loss': 6.2436805610282295, 'alpha': 0.9005338710852159, 'critic_loss': 565.8181018609649, 'actor_loss': 76.53403642513813, 'time_step': 0.04861592528147575} step=3120
2022-08-01 13:02.23 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220801125951/model_3120.pt


Epoch 2/3:   0%|          | 0/3120 [00:00<?, ?it/s]

2022-08-01 13:05.12 [info     ] CQL_20220801125951: epoch=2 step=6240 epoch=2 metrics={'time_sample_batch': 0.0006374119183956048, 'time_algorithm_update': 0.05316777886488499, 'temp_loss': -10.706878644380813, 'temp': 1.2415765372988505, 'alpha_loss': 10.586027692306118, 'alpha': 0.6472853791828339, 'critic_loss': 93.39644227325917, 'actor_loss': 44.98644273770161, 'time_step': 0.0539720892906189} step=6240
2022-08-01 13:05.12 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220801125951/model_6240.pt


Epoch 3/3:   0%|          | 0/3120 [00:00<?, ?it/s]

2022-08-01 13:08.13 [info     ] CQL_20220801125951: epoch=3 step=9360 epoch=3 metrics={'time_sample_batch': 0.0008298011162342169, 'time_algorithm_update': 0.05665710201630225, 'temp_loss': -14.97983264854321, 'temp': 1.5780648015630552, 'alpha_loss': 5.965013566990652, 'alpha': 0.4977218455993212, 'critic_loss': 228.11275364947625, 'actor_loss': 89.44396681174254, 'time_step': 0.05765042251501328} step=9360
2022-08-01 13:08.13 [info     ] Model parameters are saved to d3rlpy_logs/CQL_20220801125951/model_9360.pt


  if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
  if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
  if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
01-Aug-22 13:09:22, replay, INFO: ALS started                                   


      NDCG@10    MRR@10  Coverage@10    fit_time
CQL  0.009132  0.028198     0.013729  509.200611


01-Aug-22 13:11:09, replay, INFO: KNN started                                   


      NDCG@10    MRR@10  Coverage@10    fit_time
ALS  0.255471  0.413317     0.147380   25.643692
CQL  0.009132  0.028198     0.013729  509.200611


01-Aug-22 13:12:57, replay, INFO: LightFM started                               


      NDCG@10    MRR@10  Coverage@10    fit_time
ALS  0.255471  0.413317     0.147380   25.643692
KNN  0.246561  0.404074     0.102550   68.804712
CQL  0.009132  0.028198     0.013729  509.200611


  if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
  if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
01-Aug-22 13:14:24, replay, INFO: SLIM started                                  


          NDCG@10    MRR@10  Coverage@10    fit_time
ALS      0.255471  0.413317     0.147380   25.643692
KNN      0.246561  0.404074     0.102550   68.804712
LightFM  0.238080  0.395384     0.298123   41.198115
CQL      0.009132  0.028198     0.013729  509.200611


  if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
  if LooseVersion(pandas.__version__) < LooseVersion(minimum_pandas_version):
01-Aug-22 13:15:35, replay, INFO: UCB started                                   


          NDCG@10    MRR@10  Coverage@10    fit_time
ALS      0.255471  0.413317     0.147380   25.643692
KNN      0.246561  0.404074     0.102550   68.804712
SLIM     0.239382  0.414097     0.149902   32.560438
LightFM  0.238080  0.395384     0.298123   41.198115
CQL      0.009132  0.028198     0.013729  509.200611




          NDCG@10    MRR@10  Coverage@10    fit_time
ALS      0.255471  0.413317     0.147380   25.643692
KNN      0.246561  0.404074     0.102550   68.804712
SLIM     0.239382  0.414097     0.149902   32.560438
LightFM  0.238080  0.395384     0.298123   41.198115
CQL      0.009132  0.028198     0.013729  509.200611
UCB      0.000193  0.000878     0.019894    8.336780
CPU times: user 8min 27s, sys: 3min 54s, total: 12min 22s
Wall time: 16min 44s


                                                                                

In [19]:
# Final leaderboard with all metrics and timings, best NDCG@K first.
# The sort column is derived from K so this cell keeps working if K is changed.
experiment.results.sort_values('NDCG@{}'.format(K), ascending=False)

Unnamed: 0,Coverage@10,HitRate@1,HitRate@5,HitRate@10,MAP@10,MRR@10,NDCG@10,Surprisal@10,fit_time,predict_time,metric_time,full_time
ALS,0.14738,0.303775,0.569798,0.673398,0.163541,0.413317,0.255471,0.170315,25.643692,46.263832,35.321051,107.228575
KNN,0.10255,0.290606,0.56014,0.650571,0.156526,0.404074,0.246561,0.164692,68.804712,15.744625,23.813397,108.362734
SLIM,0.149902,0.298507,0.569798,0.66813,0.144292,0.414097,0.239382,0.17625,32.560438,20.411271,18.270732,71.242441
LightFM,0.298123,0.268657,0.553995,0.677788,0.144336,0.395384,0.23808,0.208283,41.198115,26.684952,18.362788,86.245855
CQL,0.013729,0.015803,0.043898,0.068481,0.003206,0.028198,0.009132,0.630448,509.200611,32.272322,36.361653,577.834586
UCB,0.019894,0.000878,0.000878,0.000878,8.8e-05,0.000878,0.000193,1.0,8.33678,24.129453,20.843065,53.309298
