In [1]:
from pathlib import Path

import polars as pl
import numpy as np

from my_recsys_metrics import compute_metrics
from my_utils import make_submission


In [2]:
# Root of the music-recsys dataset, relative to this notebook.
data_path = Path("../data/music_recsys")

# Training interactions; the cells below rely on its `track_id` and `user_id` columns.
train_events = pl.read_parquet(data_path.joinpath("train_events.parquet"))

# Users that need recommendations; only its `user_id` column is used below.
users_for_submission = pl.read_parquet(data_path.joinpath("users_for_submission.parquet"))


### Baseline: recommend the most popular tracks (TopPopular) to every user

In [3]:
top10_tracks = (
    train_events
    .group_by("track_id")
    .agg(pl.col("user_id").count().alias("score"))
    .top_k(10, by="score")
)
top10_tracks


track_id,score
i32,u32
634651,18499
811300,16335
44204,14561
265134,13180
1355970,12791
1133665,12120
412548,11838
647096,11436
278845,11305
322362,10682


In [4]:
def populate_tracks_to_users(tracks: pl.DataFrame, users: pl.DataFrame) -> pl.DataFrame:
    """Assign the same list of tracks to every distinct user.

    Builds the Cartesian product of the distinct ``users["user_id"]`` values
    and the rows of ``tracks``, so that each user receives every track
    together with its score.

    Args:
        tracks: frame with ``track_id`` and ``score`` columns. For every user,
            rows are emitted in the existing row order of ``tracks``. Sort
            beforehand if the per-user order matters downstream.
        users: frame with a ``user_id`` column. Duplicate ids are collapsed.

    Returns:
        A frame with columns ``user_id``, ``track_id`` and ``score``, with
        ``n_distinct_users * len(tracks)`` rows. All rows of one user are
        contiguous, and users appear in first-appearance order within ``users``.
    """
    # maintain_order=True makes the user order deterministic. Plain unique()
    # returns values in an unspecified order, so the output order changed
    # between runs.
    user_ids = users["user_id"].unique(maintain_order=True).to_numpy()
    track_ids = tracks["track_id"].to_numpy()
    scores = tracks["score"].to_numpy()

    n_users = len(user_ids)
    n_tracks = len(track_ids)

    return pl.DataFrame({
        # u1 repeated n_tracks times, then u2 repeated n_tracks times, ...
        "user_id": np.repeat(user_ids, n_tracks),
        # The whole track list once per user, aligned row-wise with user_id.
        "track_id": np.tile(track_ids, n_users),
        "score": np.tile(scores, n_users),
    })

# Every submission user receives the same popularity-based track list.
toppop_recommendations = populate_tracks_to_users(
    tracks=top10_tracks,
    users=users_for_submission,
)
toppop_recommendations


user_id,track_id,score
i32,i32,u32
1000736,634651,18499
1000736,811300,16335
1000736,44204,14561
1000736,265134,13180
1000736,1355970,12791
1000736,1133665,12120
1000736,412548,11838
1000736,647096,11436
1000736,278845,11305
1000736,322362,10682


In [5]:
# Convert the long (user_id, track_id, score) frame into submission format.
# Judging by the rendered output, this yields one row per user with the track
# ids joined into a single string. NOTE(review): confirm against make_submission.
toppop_submission = make_submission(
    toppop_recommendations,
)
toppop_submission


user_id,track_id
i32,str
2160468,"""634651 811300 …"
3248455,"""634651 811300 …"
9282825,"""634651 811300 …"
5137308,"""634651 811300 …"
3631721,"""634651 811300 …"
2554503,"""634651 811300 …"
3509411,"""634651 811300 …"
4793464,"""634651 811300 …"
8338680,"""634651 811300 …"
1814567,"""634651 811300 …"


In [6]:
# Offline evaluation of the TopPopular submission against ground_truth.parquet.
ground_truth = pl.read_parquet(data_path / "ground_truth.parquet")
compute_metrics(toppop_submission, ground_truth)


{'ndcg@10': 0.005483177335739817, 'recall@10': 0.008520426941479572}