-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch '32-ndcg' into 'master'
Resolve "Proper performance metrics" Closes #32 See merge request recommend.games/board-game-recommender!25
- Loading branch information
Showing
12 changed files
with
1,960 additions
and
1,184 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -106,5 +106,6 @@ recommender/ | |
.tc*/ | ||
.bga* | ||
.bgg* | ||
*.csv | ||
*.ipynb | ||
*.npz |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
"""Evaluate recommender models.""" | ||
|
||
import logging | ||
from dataclasses import dataclass | ||
from pathlib import Path | ||
from typing import Dict, Iterable, Tuple, Union | ||
|
||
import numpy as np | ||
import polars as pl | ||
from sklearn.metrics import ndcg_score | ||
|
||
from board_game_recommender.base import BaseGamesRecommender | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
@dataclass(frozen=True) | ||
class RecommenderTestData: | ||
"""Test data for recommender model evaluation.""" | ||
|
||
user_ids: Tuple[str, ...] | ||
game_ids: np.ndarray | ||
ratings: np.ndarray | ||
|
||
|
||
def load_test_data( | ||
path: Union[str, Path], | ||
ratings_per_user: int, | ||
user_id_key: str = "bgg_user_name", | ||
game_id_key: str = "bgg_id", | ||
ratings_key: str = "bgg_user_rating", | ||
) -> RecommenderTestData: | ||
"""Load RecommenderTestData from CSV.""" | ||
|
||
path = Path(path).resolve() | ||
LOGGER.info("Loading test data from <%s>…", path) | ||
|
||
data = pl.read_csv(path) | ||
LOGGER.info("Read %d rows", len(data)) | ||
|
||
if len(data) % ratings_per_user != 0: | ||
raise ValueError( | ||
f"The number of rows ({len(data)}) is not divisible by " | ||
+ f"the number of ratings per user ({ratings_per_user})" | ||
) | ||
|
||
user_ids = tuple(data[user_id_key][::ratings_per_user]) | ||
game_ids = data[game_id_key].view().reshape((-1, ratings_per_user)) | ||
ratings = data[ratings_key].view().reshape((-1, ratings_per_user)) | ||
|
||
return RecommenderTestData(user_ids=user_ids, game_ids=game_ids, ratings=ratings) | ||
|
||
|
||
@dataclass(frozen=True) | ||
class RecommenderMetrics: | ||
"""Recommender model evaluation metrics.""" | ||
|
||
ndcg: Dict[int, float] | ||
ndcg_exp: Dict[int, float] | ||
|
||
|
||
def calculate_metrics( | ||
recommender: BaseGamesRecommender, | ||
test_data: RecommenderTestData, | ||
*, | ||
k_values: Union[None, int, Iterable[int]], | ||
) -> RecommenderMetrics: | ||
"""Calculate RecommenderMetrics for given recommender model and RecommenderTestData.""" | ||
|
||
y_true = test_data.ratings | ||
y_pred = np.array( | ||
[ | ||
recommender.recommend_as_numpy(users=(user,), games=games)[0, :] | ||
for user, games in zip(test_data.user_ids, test_data.game_ids) | ||
] | ||
) | ||
|
||
if y_true.shape != y_pred.shape: | ||
raise ValueError( | ||
f"Shape of ratings ({y_true.shape}) does not match " | ||
+ f"shape of predictions ({y_pred.shape})" | ||
) | ||
|
||
if k_values is None: | ||
k_values = frozenset() | ||
elif isinstance(k_values, int): | ||
k_values = frozenset({k_values}) | ||
else: | ||
k_values = frozenset(k_values) | ||
|
||
k_values = sorted(k_values | {y_true.shape[-1]}) | ||
ndcg = {} | ||
|
||
for k in k_values: | ||
ndcg[k] = ndcg_score( | ||
y_true=y_true, | ||
y_score=y_pred, | ||
k=k, | ||
) | ||
|
||
y_true = np.exp2(y_true) - 1 | ||
ndcg_exp = {} | ||
|
||
for k in k_values: | ||
ndcg_exp[k] = ndcg_score( | ||
y_true=y_true, | ||
y_score=y_pred, | ||
k=k, | ||
) | ||
|
||
return RecommenderMetrics(ndcg=ndcg, ndcg_exp=ndcg_exp) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# --- | ||
# jupyter: | ||
# jupytext: | ||
# formats: ipynb,py:percent | ||
# text_representation: | ||
# extension: .py | ||
# format_name: percent | ||
# format_version: '1.3' | ||
# jupytext_version: 1.14.5 | ||
# kernelspec: | ||
# display_name: Python 3 (ipykernel) | ||
# language: python | ||
# name: python3 | ||
# --- | ||
|
||
# %% | ||
import polars as pl | ||
|
||
# %load_ext nb_black | ||
# %load_ext lab_black | ||
|
||
# %% | ||
THRESHOLD_POWER_USERS = 200 | ||
NUM_LABELS = 100 | ||
|
||
# %% | ||
ratings = ( | ||
pl.scan_ndjson("../../board-game-data/scraped/bgg_RatingItem.jl") | ||
.filter(pl.col("bgg_user_rating").is_not_null()) | ||
.select( | ||
"bgg_id", | ||
"bgg_user_name", | ||
"bgg_user_rating", | ||
( | ||
(pl.col("bgg_id").count().over("bgg_user_name") >= THRESHOLD_POWER_USERS) | ||
& (pl.arange(0, pl.count()).shuffle().over("bgg_user_name") < NUM_LABELS) | ||
).alias("is_test_row"), | ||
) | ||
.collect() | ||
) | ||
|
||
# %% | ||
train_test = ratings.partition_by( | ||
"is_test_row", | ||
as_dict=True, | ||
) | ||
data_train = train_test[False] | ||
data_train.drop_in_place("is_test_row") | ||
data_train = data_train.sort("bgg_user_name", "bgg_id") | ||
data_test = train_test[True] | ||
data_test.drop_in_place("is_test_row") | ||
data_test = data_test.sort("bgg_user_name", "bgg_id") | ||
data_train.shape, data_test.shape | ||
|
||
# %% | ||
data_train.write_csv("ratings_train.csv") | ||
data_test.write_csv("ratings_test.csv") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.