In [20]:
from functools import wraps
import time

import numpy as np
import pandas as pd


In [21]:
interactions = pd.read_csv("../data/ml-1m/interactions_1k.csv", dtype=np.int32)
interactions


Unnamed: 0,user_id,item_id
0,3391,2987
1,3391,1248
2,3391,1249
3,3391,719
4,3391,574
...,...,...
172564,3025,1238
172565,3025,3926
172566,3025,3928
172567,3025,1244


In [22]:
interactions_grouped_by_user = interactions.groupby("user_id", as_index=False).agg(list)
interactions_grouped_by_user


Unnamed: 0,user_id,item_id
0,13,"[2987, 648, 2628, 2054, 1259, 589, 1690, 2, 15..."
1,17,"[1179, 2553, 2554, 3932, 3863, 3793, 1253, 720..."
2,24,"[2987, 3424, 648, 3354, 2628, 1259, 3361, 585,..."
3,32,"[1249, 3429, 3286, 1683, 1250, 2707, 1834, 589..."
4,34,"[3424, 1680, 3358, 2558, 2485, 2054, 3868, 125..."
...,...,...
995,6013,"[587, 1, 3005, 3008, 593, 595, 1416, 2083, 208..."
996,6018,"[2054, 589, 1408, 2067, 592, 3020, 3032, 3033,..."
997,6021,"[589, 1, 1408, 590, 592, 594, 2080, 2081, 2085..."
998,6027,"[2082, 2094, 3418, 1876, 292, 2858, 1188, 2133..."


In [24]:
interactions_grouped_by_user.dtypes


user_id     int64
item_id    object
dtype: object

In [23]:
cross_rows = interactions_grouped_by_user.join(
    interactions_grouped_by_user,
    how="cross",
    lsuffix="_lhs",
    rsuffix="_rhs",
)
cross_rows


Unnamed: 0,user_id_lhs,item_id_lhs,user_id_rhs,item_id_rhs
0,13,"[2987, 648, 2628, 2054, 1259, 589, 1690, 2, 15...",13,"[2987, 648, 2628, 2054, 1259, 589, 1690, 2, 15..."
1,13,"[2987, 648, 2628, 2054, 1259, 589, 1690, 2, 15...",17,"[1179, 2553, 2554, 3932, 3863, 3793, 1253, 720..."
2,13,"[2987, 648, 2628, 2054, 1259, 589, 1690, 2, 15...",24,"[2987, 3424, 648, 3354, 2628, 1259, 3361, 585,..."
3,13,"[2987, 648, 2628, 2054, 1259, 589, 1690, 2, 15...",32,"[1249, 3429, 3286, 1683, 1250, 2707, 1834, 589..."
4,13,"[2987, 648, 2628, 2054, 1259, 589, 1690, 2, 15...",34,"[3424, 1680, 3358, 2558, 2485, 2054, 3868, 125..."
...,...,...,...,...
999995,6028,"[3000, 6, 592, 1429, 1438, 2402, 2427, 227, 26...",6013,"[587, 1, 3005, 3008, 593, 595, 1416, 2083, 208..."
999996,6028,"[3000, 6, 592, 1429, 1438, 2402, 2427, 227, 26...",6018,"[2054, 589, 1408, 2067, 592, 3020, 3032, 3033,..."
999997,6028,"[3000, 6, 592, 1429, 1438, 2402, 2427, 227, 26...",6021,"[589, 1, 1408, 590, 592, 594, 2080, 2081, 2085..."
999998,6028,"[3000, 6, 592, 1429, 1438, 2402, 2427, 227, 26...",6027,"[2082, 2094, 3418, 1876, 292, 2858, 1188, 2133..."


In [17]:
def compute_similarity(row):
    if row["user_id_lhs"] == row["user_id_rhs"]:
        return -1

    x = set(row["item_id_lhs"])
    y = set(row["item_id_rhs"])

    return len(x & y) / (len(x) * len(y))

cross_rows["similarity"] = cross_rows.apply(compute_similarity, axis=1)
cross_rows


Unnamed: 0,user_id_lhs,item_id_lhs,user_id_rhs,item_id_rhs,similarity
0,13,"[2987, 648, 2628, 2054, 1259, 589, 1690, 2, 15...",13,"[2987, 648, 2628, 2054, 1259, 589, 1690, 2, 15...",-1.000000
1,13,"[2987, 648, 2628, 2054, 1259, 589, 1690, 2, 15...",17,"[1179, 2553, 2554, 3932, 3863, 3793, 1253, 720...",0.001799
2,13,"[2987, 648, 2628, 2054, 1259, 589, 1690, 2, 15...",24,"[2987, 3424, 648, 3354, 2628, 1259, 3361, 585,...",0.001838
3,13,"[2987, 648, 2628, 2054, 1259, 589, 1690, 2, 15...",32,"[1249, 3429, 3286, 1683, 1250, 2707, 1834, 589...",0.001543
4,13,"[2987, 648, 2628, 2054, 1259, 589, 1690, 2, 15...",34,"[3424, 1680, 3358, 2558, 2485, 2054, 3868, 125...",0.000678
...,...,...,...,...,...
999995,6028,"[3000, 6, 592, 1429, 1438, 2402, 2427, 227, 26...",6013,"[587, 1, 3005, 3008, 593, 595, 1416, 2083, 208...",0.000858
999996,6028,"[3000, 6, 592, 1429, 1438, 2402, 2427, 227, 26...",6018,"[2054, 589, 1408, 2067, 592, 3020, 3032, 3033,...",0.002944
999997,6028,"[3000, 6, 592, 1429, 1438, 2402, 2427, 227, 26...",6021,"[589, 1, 1408, 590, 592, 594, 2080, 2081, 2085...",0.003224
999998,6028,"[3000, 6, 592, 1429, 1438, 2402, 2427, 227, 26...",6027,"[2082, 2094, 3418, 1876, 292, 2858, 1188, 2133...",0.004137


In [18]:
cross_rows.sort_values(
    ["user_id_lhs", "similarity"],
    ascending=(True, False),
    inplace=True,
    ignore_index=True,
)

top_similar_users = (
    cross_rows[["user_id_lhs", "user_id_rhs", "similarity"]]
    .groupby("user_id_lhs")
    .head(10)
    .rename(columns={
        "user_id_lhs": "user_id",
        "user_id_rhs": "similar_user_id",
    })
)

top_similar_users


Unnamed: 0,user_id,similar_user_id,similarity
0,13,1849,0.006173
1,13,3730,0.006066
2,13,4320,0.005258
3,13,1558,0.004986
4,13,3404,0.004986
...,...,...,...
999005,6028,3809,0.006790
999006,6028,562,0.006325
999007,6028,394,0.006159
999008,6028,288,0.006079


In [19]:
def timeit(func):
    @wraps(func)
    def timeit_wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        end_time = time.perf_counter()
        total_time = end_time - start_time
        print(f'Function {func.__name__} Took {total_time:.4f} seconds')
        return result
    return timeit_wrapper

@timeit
def user_similarity(interactions: pd.DataFrame) -> pd.DataFrame:
    interactions_grouped_by_user = interactions.groupby("user_id", as_index=False).agg(list)

    cross_rows = interactions_grouped_by_user.join(
        interactions_grouped_by_user,
        how="cross",
        lsuffix="_lhs",
        rsuffix="_rhs",
    )

    def compute_similarity_fn(row):
        if row["user_id_lhs"] == row["user_id_rhs"]:
            return -1

        x = set(row["item_id_lhs"])
        y = set(row["item_id_rhs"])

        return len(x & y) / len(x | y)

    cross_rows["similarity"] = cross_rows.apply(compute_similarity_fn, axis=1)

    cross_rows.sort_values(
        ["user_id_lhs", "similarity"],
        ascending=(True, False),
        inplace=True,
        ignore_index=True,
    )

    top_similar_users = (
        cross_rows[["user_id_lhs", "user_id_rhs", "similarity"]]
        .groupby("user_id_lhs")
        .head(10)
        .rename(columns={
            "user_id_lhs": "user_id",
            "user_id_rhs": "similar_user_id",
        })
    )

    return top_similar_users


In [9]:
interactions_1k = pd.read_csv("../data/ml-1m/interactions_1k.csv", dtype=np.int32)
interactions_2k = pd.read_csv("../data/ml-1m/interactions_2k.csv", dtype=np.int32)
interactions_5k = pd.read_csv("../data/ml-1m/interactions_5k.csv", dtype=np.int32)


In [10]:
user_similarity(interactions_1k)


Function user_similarity Took 38.8633 seconds


Unnamed: 0,user_id,similar_user_id,similarity
0,13,3555,0.362069
1,13,4320,0.321678
2,13,5931,0.275862
3,13,3000,0.266272
4,13,4706,0.248555
...,...,...,...
999005,6028,5731,0.188889
999006,6028,6008,0.188406
999007,6028,288,0.183908
999008,6028,4862,0.175000


In [11]:
user_similarity(interactions_2k)


Function user_similarity Took 142.7556 seconds


Unnamed: 0,user_id,similar_user_id,similarity
0,4,2347,0.257143
1,4,526,0.234375
2,4,3461,0.234043
3,4,4966,0.232143
4,4,1349,0.230769
...,...,...,...
3998005,6039,5887,0.191860
3998006,6039,1325,0.187683
3998007,6039,556,0.183333
3998008,6039,1699,0.180451


In [12]:
user_similarity(interactions_5k)


Function user_similarity Took 864.6690 seconds


Unnamed: 0,user_id,similar_user_id,similarity
0,1,5343,0.236111
1,1,5190,0.230769
2,1,1283,0.222222
3,1,1353,0.183673
4,1,681,0.181818
...,...,...,...
24995005,6040,2608,0.278652
24995006,6040,2627,0.276662
24995007,6040,5090,0.272575
24995008,6040,1077,0.272016
