In [1]:
# Auto-reload imported modules so that edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Turn off the Jedi-based tab completer.
%config Completer.use_jedi = False

In [None]:
! pip install rs-datasets

# This notebook contains an example of feature preprocessing with PySpark for the RePlay LightFM model wrapper. It includes:
1. Data loading and reindexing
2. Features preprocessing with pyspark
3. Building LightFM model based on interaction matrix and features
4. Model evaluation

*Note: to run this notebook, you will need the __experimental__ version of RePlay.*

In [3]:
import warnings
from optuna.exceptions import ExperimentalWarning
# Suppress every UserWarning and optuna ExperimentalWarning raised from here on.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=ExperimentalWarning)

In [2]:
from pyspark.sql.types import IntegerType
from pyspark.sql.functions import array_contains, col, explode, split, substring

from replay.experimental.preprocessing.data_preparator import Indexer, DataPreparator
from replay.metrics import HitRate, NDCG, MAP, Coverage, Experiment
from replay.experimental.models import LightFMWrap
from replay.utils.session_handler import State
from replay.splitters import TimeSplitter
from replay.utils.spark_utils import get_log_info
from rs_datasets import MovieLens

In [None]:
# Take the Spark session from RePlay's State object; the bare name on the
# last line displays the session as the cell output.
spark = State().session
spark

In [16]:
# Raise the Spark log level to ERROR to cut down on log noise in cell outputs.
spark.sparkContext.setLogLevel('ERROR')

In [17]:
# Notebook-wide settings.
K = 10       # passed as `k` to predict() and used as the cutoff in the metrics
SEED = 1234  # passed as `random_state` to the LightFM wrapper

# 1. Data loading

We will use the MovieLens 10M dataset from the `rs_datasets` package, which provides a collection of recommendation datasets.

In [18]:
# Load the "10m" MovieLens variant and print a short preview of each of its
# tables (ratings, items, tags).
data = MovieLens("10m")
data.info()

65.6MB [00:01, 56.5MB/s]                            


ratings


Unnamed: 0,user_id,item_id,rating,timestamp
0,1,122,5.0,838985046
1,1,185,5.0,838983525
2,1,231,5.0,838983392



items


Unnamed: 0,item_id,title,genres
0,1,Toy Story (1995),Adventure|Animation|Children|Comedy|Fantasy
1,2,Jumanji (1995),Adventure|Children|Fantasy
2,3,Grumpier Old Men (1995),Comedy|Romance



tags


Unnamed: 0,user_id,item_id,tag,timestamp
0,15,4973,excellent!,1215184630
1,20,1747,politics,1188263867
2,20,1747,satire,1188263867





### 1.1 Convert interaction log to RePlay format

In [19]:
# Helper used in the next cell to bring the raw tables into RePlay's column naming.
preparator = DataPreparator()

In [None]:
%%time
# Interaction log: keys are the RePlay column names, values are the matching
# columns of the MovieLens ratings table (e.g. `rating` becomes `relevance`).
log = preparator.transform(
    columns_mapping={
        "user_id": "user_id",
        "item_id": "item_id",
        "relevance": "rating",
        "timestamp": "timestamp",
    },
    data=data.ratings,
)

# Item table: only the identifier column is mapped explicitly.
item_features = preparator.transform(
    columns_mapping={"item_id": "item_id"},
    data=data.items,
)

In [21]:
# Preview the first two rows of the converted interaction log.
log.show(2)

+-------+-------+---------+-------------------+
|user_id|item_id|relevance|          timestamp|
+-------+-------+---------+-------------------+
|      1|    122|      5.0|1996-08-02 11:24:06|
|      1|    185|      5.0|1996-08-02 10:58:45|
+-------+-------+---------+-------------------+
only showing top 2 rows



In [22]:
# Preview the first two rows of the converted item table.
item_features.show(2)

+-------+----------------+--------------------+
|item_id|           title|              genres|
+-------+----------------+--------------------+
|      1|Toy Story (1995)|Adventure|Animati...|
|      2|  Jumanji (1995)|Adventure|Childre...|
+-------+----------------+--------------------+
only showing top 2 rows



<a id='indexing'></a>
### 1.2. Indexing

Use the `Indexer` class to convert the original user and item identifiers (\_id) into consecutive integers starting at zero (\_idx).

In [23]:
# Indexer configured with the names of the raw user and item id columns.
indexer = Indexer(user_col='user_id', item_col='item_id')

In [24]:
%%time
# Fit on the users from the log and on the items from both the log and the
# item table, so that items present only in the item table are also covered.
indexer.fit(
    users=log.select("user_id"),
    items=log.select("item_id").union(item_features.select("item_id")),
)

                                                                                

CPU times: user 258 ms, sys: 18.9 ms, total: 277 ms
Wall time: 7.61 s


In [25]:
%%time
# Apply the fitted indexer to the log; the result has user_idx / item_idx
# columns in place of the raw ids.
log_replay = indexer.transform(df=log)
log_replay.show(2)

                                                                                

+--------+--------+---------+-------------------+
|user_idx|item_idx|relevance|          timestamp|
+--------+--------+---------+-------------------+
|   65232|    1057|      5.0|1996-08-02 11:24:06|
|   65232|      76|      5.0|1996-08-02 10:58:45|
+--------+--------+---------+-------------------+
only showing top 2 rows

CPU times: user 283 ms, sys: 21.4 ms, total: 304 ms
Wall time: 3.98 s


In [26]:
%%time
# Apply the same indexer to the item table so that it shares item_idx with the log.
item_features_replay = indexer.transform(df=item_features)
item_features_replay.show(2)

+--------+----------------+--------------------+
|item_idx|           title|              genres|
+--------+----------------+--------------------+
|      11|Toy Story (1995)|Adventure|Animati...|
|     117|  Jumanji (1995)|Adventure|Childre...|
+--------+----------------+--------------------+
only showing top 2 rows

CPU times: user 51.9 ms, sys: 4.19 ms, total: 56.1 ms
Wall time: 472 ms


### 1.3. Data split

In [28]:
# Time-based train/test split on the indexed log.
# Users and items that would be "cold" in the test part are dropped.
# NOTE(review): time_threshold=0.2 presumably sends the most recent 20% of
# interactions to the test part; confirm against the TimeSplitter docs.
train_spl = TimeSplitter(
    time_threshold=0.2,
    drop_cold_items=True,
    drop_cold_users=True,
    item_column="item_idx",
    query_column="user_idx",
)

train, test = train_spl.split(log_replay)

# Summarise both parts: number of rows, users and items.
print("train info:\n", get_log_info(train, user_col="user_idx", item_col="item_idx"))
print("test info:\n", get_log_info(test, user_col="user_idx", item_col="item_idx"))

                                                                                

train info:
 total lines: 8000043, total users: 59522, total items: 8989


                                                                                

test info:
 total lines: 249418, total users: 3196, total items: 8180


In [29]:
# Check whether the train DataFrame is currently marked as cached.
train.is_cached

False

# 2. Features preprocessing with pyspark

#### Year

In [30]:
# Extract the release year from the title.
# A negative start position counts from the end of the string, so
# substring(title, -5, 4) takes the four characters before the final one.
# Assumes every title ends with "(YYYY)", as in the previewed rows.
year = item_features_replay.select(
    "item_idx",
    substring(col("title"), -5, 4).astype(IntegerType()).alias("year"),
)
year.show(2)

+--------+----+
|item_idx|year|
+--------+----+
|      11|1995|
|     117|1995|
+--------+----+
only showing top 2 rows



#### Genres

In [31]:
# Turn the pipe-separated genre string into an array column.
# The split pattern is a regular expression, so the pipe must be escaped.
# A raw string is used because "\|" in a normal literal is an invalid escape
# sequence (DeprecationWarning, SyntaxWarning on newer Python versions).
genres = (
    item_features_replay
    .select(
        "item_idx",
        split("genres", r"\|").alias("genres")
    )
)

In [32]:
# Preview the genre arrays.
genres.show()

+--------+--------------------+
|item_idx|              genres|
+--------+--------------------+
|      11|[Adventure, Anima...|
|     117|[Adventure, Child...|
|     274|   [Comedy, Romance]|
|    1382|[Comedy, Drama, R...|
|     320|            [Comedy]|
|      89|[Action, Crime, T...|
|     252|   [Comedy, Romance]|
|    2179|[Adventure, Child...|
|    1018|            [Action]|
|      51|[Action, Adventur...|
|     139|[Comedy, Drama, R...|
|    1112|    [Comedy, Horror]|
|    2403|[Animation, Child...|
|     682|             [Drama]|
|    1348|[Action, Adventur...|
|     189|      [Crime, Drama]|
|     111|[Comedy, Drama, R...|
|     880|[Comedy, Drama, T...|
|     129|            [Comedy]|
|    1039|[Action, Comedy, ...|
+--------+--------------------+
only showing top 20 rows



In [33]:
# Collect the distinct genre names into a Python list on the driver.
genres_list = (
    genres
    # one row per (item, genre) pair
    .select(explode("genres").alias("genre"))
    .distinct()
    # drop the placeholder used for movies without genre information
    .filter('genre <> "(no genres listed)"')
    .toPandas()["genre"]
    .tolist()
)

In [34]:
# Display the collected genre names.
genres_list

['Children',
 'Crime',
 'Sci-Fi',
 'Musical',
 'Animation',
 'Mystery',
 'Action',
 'Documentary',
 'Fantasy',
 'Drama',
 'IMAX',
 'Comedy',
 'Horror',
 'Adventure',
 'Western',
 'Romance',
 'Thriller',
 'War',
 'Film-Noir']

In [35]:
# Build one 0/1 indicator column per genre.
# All columns are produced in a single select() rather than by calling
# withColumn() in a loop: every withColumn() call adds another projection to
# the query plan, which the PySpark documentation advises against.
# NOTE: this rebinds `item_features`, which until now held the item table
# returned by the preparator.
item_features = genres.select(
    "item_idx",
    *[
        # array_contains yields a boolean; the cast turns it into 1 / 0
        array_contains(col("genres"), genre).astype(IntegerType()).alias(genre)
        for genre in genres_list
    ],
).cache()
# count() is an action, so it forces the cached DataFrame to be computed.
item_features.count()

10681

In [36]:
# Add the release year to the genre indicators; the inner join keeps only
# items that are present in both tables.
# NOTE: this cell rebinds `item_features`, so running it a second time would
# join `year` again.
item_features = item_features.join(year, on="item_idx", how="inner").cache()
item_features.count()

10681

In [37]:
# Preview the final item feature table: one 0/1 column per genre plus `year`.
item_features.show(2)

+--------+--------+-----+------+-------+---------+-------+------+-----------+-------+-----+----+------+------+---------+-------+-------+--------+---+---------+----+
|item_idx|Children|Crime|Sci-Fi|Musical|Animation|Mystery|Action|Documentary|Fantasy|Drama|IMAX|Comedy|Horror|Adventure|Western|Romance|Thriller|War|Film-Noir|year|
+--------+--------+-----+------+-------+---------+-------+------+-----------+-------+-----+----+------+------+---------+-------+-------+--------+---+---------+----+
|      11|       1|    0|     0|      0|        1|      0|     0|          0|      1|    0|   0|     1|     0|        1|      0|      0|       0|  0|        0|1995|
|     117|       1|    0|     0|      0|        0|      0|     0|          0|      1|    0|   0|     0|     0|        1|      0|      0|       0|  0|        0|1995|
+--------+--------+-----+------+-------+---------+-------+------+-----------+-------+-----+----+------+------+---------+-------+-------+--------+---+---------+----+
only showi

# 3. Building LightFM model based on interaction matrix and item features

In [38]:
# LightFM wrapper with the 'warp' loss and 16 components, seeded with SEED.
model_feat = LightFMWrap(random_state=SEED, loss='warp', no_components=16)

In [39]:
%%time
# Train on the training interactions together with the item feature table.
model_feat.fit(train, item_features=item_features)

                                                                                

CPU times: user 11min 16s, sys: 1.02 s, total: 11min 17s
Wall time: 3min 41s


In [40]:
%%time
# Top-K recommendations for every distinct user in the test part.
# `train` is passed as the log and filter_seen_items=True is set.
# The result is cached because it is reused for metric computation.
recs = model_feat.predict(
    log=train,
    k=K,
    users=test.select("user_idx").distinct(),
    item_features=item_features,
    filter_seen_items=True,
).cache()
# count() forces the recommendations to be computed now.
recs.count()

                                                                                

CPU times: user 349 ms, sys: 59 ms, total: 408 ms
Wall time: 3min 27s


31960

# 4. Model evaluation

In [41]:
# Evaluation harness: ranking metrics at cutoff K (HitRate also at 1) plus Coverage.
# NOTE(review): `test` and `train` are passed positionally, presumably as the
# ground truth and the training interactions; confirm against the Experiment signature.
metrics = Experiment(
    [NDCG(K), MAP(K), HitRate([1, K]), Coverage(K)],
    test,
    train,
    query_column="user_idx",
    item_column="item_idx",
    rating_column="relevance",
)

In [42]:
%%time
# Score the recommendations under the given label and display the results table.
metrics.add_result("LightFM_item_features", recs)
metrics.results

                                                                                

CPU times: user 133 ms, sys: 32.3 ms, total: 165 ms
Wall time: 30.8 s


Unnamed: 0,NDCG@10,MAP@10,HitRate@1,HitRate@10,Coverage@10
LightFM_item_features,0.27179,0.185185,0.336671,0.659262,0.108688
