In [5]:
from functools import wraps
import time

import numpy as np
import polars as pl


In [6]:
# Load the user/item interaction pairs, reading both id columns as 32-bit integers.
interactions = pl.read_csv(
    "../data/ml-1m/interactions_1k.csv",
    schema={"user_id": pl.Int32, "item_id": pl.Int32},
)
interactions


user_id,item_id
i32,i32
3391,2987
3391,1248
3391,1249
3391,719
3391,574
3391,2050
3391,2051
3391,3791
3391,2052
3391,1250


In [7]:
def timeit(func):
    """Decorate ``func`` so that every call prints its wall-clock duration.

    The duration is measured with ``time.perf_counter`` around the call and
    printed with four decimal places once the call has returned. The wrapped
    function's return value is passed through unchanged. If ``func`` raises,
    the exception propagates and nothing is printed.
    """
    def timeit_wrapper(*args, **kwargs):
        began = time.perf_counter()
        outcome = func(*args, **kwargs)
        elapsed = time.perf_counter() - began
        print(f'Function {func.__name__} Took {elapsed:.4f} seconds')
        return outcome

    # Copy name/docstring metadata from ``func`` onto the wrapper.
    return wraps(func)(timeit_wrapper)

@timeit
def user_similarity(interactions: pl.DataFrame, top_k: int = 10) -> pl.DataFrame:
    """Return, for every user, the ``top_k`` other users with the highest similarity.

    Each user's interactions are collected into a list of item ids, every user
    is paired with every other user, and each pair is scored as::

        len(items_a ∩ items_b) / (len(items_a) * len(items_b))

    NOTE(review): this score divides by the *product* of the two list lengths,
    which is neither Jaccard (intersection / union) nor cosine
    (intersection / sqrt of the product) — confirm this is the intended metric.

    Args:
        interactions: Frame with a ``user_id`` and an ``item_id`` column, one
            row per interaction.
        top_k: Maximum number of similar users to keep per user. Defaults to
            10, the value that used to be hard-coded.

    Returns:
        A frame with columns ``user_id``, ``similar_user_id`` and
        ``similarity`` (Float32), holding at most ``top_k`` rows per user,
        ordered by descending similarity within each user. The order of users
        is whatever ``group_by`` yields; sort the result if a stable order is
        needed.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be a positive integer, got {top_k!r}")

    # One row per user, carrying the list of item ids that user interacted with.
    items_per_user = (
        interactions.lazy()
        .group_by("user_id")
        .agg(pl.col("item_id"))
    )

    # Pair every user with every user. A cross join has no join keys, so no
    # `on=` is passed (it was meaningless here, and newer polars releases raise
    # when keys are supplied to a cross join). Right-hand columns get the
    # `_candidate` suffix.
    # NOTE: this materialises n_users ** 2 rows of list columns, so time and
    # memory grow quadratically with the number of users.
    user_pairs = items_per_user.join(
        items_per_user,
        how="cross",
        suffix="_candidate",
    )

    # NOTE(review): `list.lengths()` is kept to match the polars version this
    # notebook was run with; later releases appear to rename it to
    # `list.len()` — confirm against the installed version before upgrading.
    n_shared = pl.col("item_id").list.set_intersection("item_id_candidate").list.lengths()
    n_own = pl.col("item_id").list.lengths()
    n_candidate = pl.col("item_id_candidate").list.lengths()

    scored_pairs = (
        user_pairs
        # A user is never reported as similar to themselves.
        .filter(pl.col("user_id") != pl.col("user_id_candidate"))
        .select(
            pl.col("user_id"),
            pl.col("user_id_candidate").alias("similar_user_id"),
            (n_shared / (n_own * n_candidate))
            .cast(pl.Float32)
            .alias("similarity"),
        )
    )

    # Keep the best `top_k` candidates per user, then flatten the per-user
    # lists back into one row per (user, similar user) pair.
    top_similar_users = (
        scored_pairs
        .group_by("user_id")
        .agg(
            pl.col("similar_user_id", "similarity")
            .sort_by("similarity", descending=True)
            .head(top_k)
        )
        .explode("similar_user_id", "similarity")
    )

    return top_similar_users.collect()


In [9]:
# Order the result by user, and within each user by descending similarity.
(
    user_similarity(interactions)
    .sort("user_id", "similarity", descending=(False, True))
)


Function user_similarity Took 7.3902 seconds


user_id,similar_user_id,similarity
i32,i32,f32
13,1849,0.006173
13,3730,0.006066
13,4320,0.005258
13,3404,0.004986
13,1558,0.004986
13,5931,0.00481
13,3555,0.004522
13,2291,0.004517
13,4668,0.00434
13,1289,0.004274


In [10]:
# Load the 1k / 2k / 5k interaction files through a single read call so the
# schema (both id columns as Int32) is written once and cannot drift between them.
interactions_1k, interactions_2k, interactions_5k = (
    pl.read_csv(csv_path, schema={"user_id": pl.Int32, "item_id": pl.Int32})
    for csv_path in (
        "../data/ml-1m/interactions_1k.csv",
        "../data/ml-1m/interactions_2k.csv",
        "../data/ml-1m/interactions_5k.csv",
    )
)


In [11]:
# Run the timed similarity computation on the frame loaded from interactions_1k.csv.
user_similarity(interactions_1k)


Function user_similarity Took 6.7308 seconds


user_id,similar_user_id,similarity
i32,i32,f32
2856,2322,0.191919
2856,4172,0.173077
2856,5435,0.152941
2856,96,0.146552
2856,5169,0.14
2856,2937,0.138686
2856,2570,0.137931
2856,5442,0.137255
2856,5820,0.136364
2856,5057,0.135135


In [12]:
# Run the timed similarity computation on the frame loaded from interactions_2k.csv.
user_similarity(interactions_2k)


Function user_similarity Took 51.9429 seconds


user_id,similar_user_id,similarity
i32,i32,f32
640,574,0.2
640,578,0.195122
640,1829,0.186047
640,1909,0.164179
640,3100,0.151515
640,5586,0.139535
640,1773,0.138889
640,5160,0.136364
640,5266,0.136364
640,5814,0.136364


In [35]:
# This crashes!!! The recorded run on the 5k file produced no result.
# NOTE(review): presumably the kernel runs out of memory, since user_similarity
# cross-joins every user with every other user — confirm before re-running.
user_similarity(interactions_5k)


: 