In [1]:
from functools import wraps
import time
from typing import Optional, Tuple, List, Set

import numpy as np
import polars as pl

from numba import njit, prange


In [2]:
# Quick look at the raw interactions; both columns are read as 32-bit integers.
interactions = pl.read_csv(
    "../data/ml-1m/interactions_1k.csv",
    schema={"user_id": pl.Int32, "item_id": pl.Int32},
)
interactions


user_id,item_id
i32,i32
3391,2987
3391,1248
3391,1249
3391,719
3391,574
3391,2050
3391,2051
3391,3791
3391,2052
3391,1250


In [3]:
@njit(parallel=True, fastmath=True)
def compute_distances(
    user_ids: np.ndarray,
    item_ids: np.ndarray,
    k: int = 10,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Find, for every user, the ``k`` users whose item sets overlap the most.

    Despite the name, the returned scores are similarities (larger means
    closer); neighbours are ranked in descending score order. The score of
    two distinct users ``u`` and ``v`` is::

        |items(u) & items(v)| / (|items(u)| * |items(v)|)

    NOTE(review): the denominator is the plain product of the set sizes, not
    its square root (cosine) nor the size of the union (Jaccard) -- confirm
    that this normalisation is the intended one.

    A user's score with itself is left at 0, so a user can only show up in
    its own neighbour list among zero-score entries.

    Args:
        user_ids: 1-D array with one user id per interaction. All rows of a
            user must be contiguous (e.g. the data is sorted by user id);
            otherwise the same user is split into several groups.
        item_ids: 1-D array of item ids, aligned with ``user_ids``.
        k: Number of neighbours to return per user. Clamped to the number of
            distinct users so that small inputs do not fail.

    Returns:
        A tuple ``(unique_user_ids, top_user_ids, top_distances)`` where
        ``unique_user_ids`` is an int32 array of shape ``(n_users,)`` and the
        other two are ``(n_users, min(k, n_users))`` arrays (int32 ids and
        float32 scores) holding each user's neighbours, best first.
    """
    n_interactions = len(user_ids)

    # Record the start of each user's run of rows. Comparing against the
    # previous row (instead of a sentinel id) keeps every id value valid.
    unique_user_ids: List[int] = []
    item_id_offsets: List[int] = []
    for pos in range(n_interactions):
        if pos == 0 or user_ids[pos] != user_ids[pos - 1]:
            item_id_offsets.append(pos)
            unique_user_ids.append(user_ids[pos])
    # Closing offset, so that user u owns item_ids[offsets[u]:offsets[u + 1]].
    # Using the input length keeps this well defined for empty input.
    item_id_offsets.append(n_interactions)

    # There is one more offset than there are unique users
    assert item_id_offsets[0] == 0
    assert item_id_offsets[-1] == n_interactions
    n_users: int = len(item_id_offsets) - 1

    # Convert sequence of items to set for each user
    item_id_sets: List[Set[int]] = []
    for i in range(0, n_users):
        item_id_seq = item_ids[item_id_offsets[i] : item_id_offsets[i + 1]]
        item_id_sets.append(set(item_id_seq))

    # Fill the strict lower triangle in parallel (one row per task). Every
    # set holds at least one item, so the denominator is never zero.
    distances = np.zeros((n_users, n_users), dtype=np.float32)
    for i in prange(0, n_users):
        for j in range(0, i):
            x = item_id_sets[i]
            y = item_id_sets[j]
            distances[i, j] = len(x & y) / (len(x) * len(y))

    # Mirror the lower triangle; the diagonal stays at zero.
    distances = distances + distances.T

    # Never ask for more neighbours than there are users, otherwise the
    # sliced row would be shorter than the output row it is assigned to.
    n_neighbors = min(k, n_users)

    unique_user_ids_np = np.array(unique_user_ids, dtype=np.int32)
    top_distances = np.empty((n_users, n_neighbors), dtype=np.float32)
    top_user_ids = np.empty((n_users, n_neighbors), dtype=np.int32)

    for i in range(0, n_users):
        # Negate so that argsort's ascending order yields the largest scores first.
        top_user_indices = np.argsort(-distances[i])[:n_neighbors]
        top_distances[i, :] = distances[i][top_user_indices]
        top_user_ids[i, :] = unique_user_ids_np[top_user_indices]

    return unique_user_ids_np, top_user_ids, top_distances


def timeit(func):
    """Decorator that prints how long each call to ``func`` takes.

    The elapsed wall-clock time is measured with ``time.perf_counter`` and
    printed after the call returns; the wrapped function's return value is
    passed through unchanged.
    """

    @wraps(func)
    def timeit_wrapper(*args, **kwargs):
        started_at = time.perf_counter()
        outcome = func(*args, **kwargs)
        total_time = time.perf_counter() - started_at
        print(f'Function {func.__name__} Took {total_time:.4f} seconds')
        return outcome

    return timeit_wrapper

@timeit
def user_similarity(interactions: pl.DataFrame, k: int = 10) -> pl.DataFrame:
    """Return the ``k`` most similar users for every user in ``interactions``.

    Args:
        interactions: Frame with ``user_id`` and ``item_id`` columns, one row
            per interaction.
        k: Number of similar users to keep per user (defaults to 10, the
            value that was previously implied by ``compute_distances``).

    Returns:
        A long-format frame with columns ``user_id``, ``similar_user_id`` and
        ``similarity``: one row per (user, similar user) pair.
    """
    # compute_distances expects each user's interactions to be contiguous.
    interactions_sorted = interactions.sort("user_id")

    user_ids, similar_user_ids, similarity = compute_distances(
        interactions_sorted["user_id"].to_numpy(),
        interactions_sorted["item_id"].to_numpy(),
        k,
    )

    # The 2-D neighbour arrays become one list-like cell per user ...
    top_similar_users = pl.DataFrame({
        "user_id": user_ids,
        "similar_user_id": similar_user_ids,
        "similarity": similarity
    })
    # ... which are then unrolled together into one row per neighbour.
    top_similar_users = top_similar_users.explode("similar_user_id", "similarity")

    return top_similar_users


In [4]:
# The three samples share one explicit schema; define it once so the loads stay in sync.
interactions_schema = {"user_id": pl.Int32, "item_id": pl.Int32}

interactions_1k = pl.read_csv("../data/ml-1m/interactions_1k.csv", schema=interactions_schema)
interactions_2k = pl.read_csv("../data/ml-1m/interactions_2k.csv", schema=interactions_schema)
interactions_5k = pl.read_csv("../data/ml-1m/interactions_5k.csv", schema=interactions_schema)


In [5]:
# Warmup to compile Numba kernels: the first call pays the compilation cost,
# so its printed timing is not representative. The result is discarded.
_ = user_similarity(interactions_1k)


Function user_similarity Took 6.8698 seconds


In [6]:
# Timed run on the `interactions_1k` sample, after the warmup call above.
user_similarity(interactions_1k)


Function user_similarity Took 0.5343 seconds


user_id,similar_user_id,similarity
i32,i32,f32
13,1849,0.006173
13,3730,0.006066
13,4320,0.005258
13,3404,0.004986
13,1558,0.004986
13,5931,0.00481
13,3555,0.004522
13,2291,0.004517
13,4668,0.00434
13,1289,0.004274


In [8]:
# Timed run on the `interactions_2k` sample.
user_similarity(interactions_2k)


Function user_similarity Took 1.6901 seconds


user_id,similar_user_id,similarity
i32,i32,f32
4,2347,0.018634
4,3535,0.014778
4,3461,0.014157
4,1771,0.013889
4,5388,0.013889
4,3616,0.013605
4,1349,0.013289
4,4966,0.012897
4,1558,0.012821
4,892,0.012605


In [9]:
# Timed run on the `interactions_5k` sample.
user_similarity(interactions_5k)


Function user_similarity Took 9.1017 seconds


user_id,similar_user_id,similarity
i32,i32,f32
1,5343,0.00891
1,5190,0.007898
1,1283,0.007383
1,681,0.006951
1,5525,0.006604
1,2799,0.006563
1,5320,0.006563
1,317,0.006289
1,417,0.006003
1,80,0.005896
