In [1]:
from pyspark.sql import functions as F
from pyspark.sql import Window

def jaccard_topk(interactions, K=200):
    """Recommend up to K unseen items per user via item-item Jaccard similarity.

    For every pair of items that share at least one user, the Jaccard
    similarity ``|U_i & U_j| / |U_i | U_j|`` is computed, where ``U_x`` is the
    set of distinct users that interacted with item ``x``. A candidate item's
    score for a user is the sum of its similarities to the distinct items the
    user has already interacted with. Items the user has already interacted
    with are never returned.

    Args:
        interactions: Spark DataFrame with at least the columns ``userId``
            and ``itemId``. Repeated (userId, itemId) rows are allowed and
            count as a single interaction.
        K: Maximum number of recommendations kept per user.

    Returns:
        Spark DataFrame with columns ``userId``, ``itemId`` and
        ``jaccard_score``, holding at most ``K`` rows per user.
    """
    # Work on distinct (userId, itemId) pairs everywhere. Without this,
    # repeated interaction rows would add the same neighbour similarity to a
    # user's score several times and would blow up the self-join below.
    ui = interactions.select("userId", "itemId").distinct()

    # Item degrees |U_i|. Rows are already distinct, so a plain count of
    # userId equals the number of distinct users of the item.
    item_deg = ui.groupBy("itemId").agg(F.count("userId").alias("deg"))

    # Co-occurrence counts: for each user, every item pair (i, j) with i < j,
    # so each unordered pair is produced exactly once.
    a = ui.alias("a")
    b = ui.alias("b")
    pairs = (
        a.join(b, on="userId")
        .where(F.col("a.itemId") < F.col("b.itemId"))
        .groupBy(F.col("a.itemId").alias("i"), F.col("b.itemId").alias("j"))
        .agg(F.count("userId").alias("cij"))
    )

    # Jaccard = intersection / union, with union size = di + dj - cij.
    deg_i = item_deg.withColumnRenamed("itemId", "i").withColumnRenamed("deg", "di")
    deg_j = item_deg.withColumnRenamed("itemId", "j").withColumnRenamed("deg", "dj")
    pairs = (
        pairs.join(deg_i, "i")
        .join(deg_j, "j")
        .withColumn(
            "jac",
            F.col("cij") / (F.col("di") + F.col("dj") - F.col("cij")),
        )
    )

    # Symmetrize to (itemId, nbr, jac) so that either item of a pair can be
    # looked up as the "seen" side.
    sim_ij = pairs.select(
        F.col("i").alias("itemId"), F.col("j").alias("nbr"), F.col("jac")
    )
    sim_ji = pairs.select(
        F.col("j").alias("itemId"), F.col("i").alias("nbr"), F.col("jac")
    )
    sims = sim_ij.unionByName(sim_ji)

    # For each item a user has seen, emit that item's neighbours as candidates,
    # then sum the similarities of candidates reached through several items.
    cand = (
        ui.alias("ui")
        .join(sims.alias("s"), F.col("ui.itemId") == F.col("s.itemId"))
        .select(
            F.col("ui.userId").alias("userId"),
            F.col("s.nbr").alias("itemId"),
            F.col("s.jac").alias("score"),
        )
        .groupBy("userId", "itemId")
        .agg(F.sum("score").alias("score"))
    )

    # Drop items the user already interacted with. An anti-join avoids
    # collecting every user's full item set into a single array.
    cand = cand.join(ui, ["userId", "itemId"], "left_anti")

    # Keep the K best candidates per user. itemId is a secondary sort key so
    # that ties on score are resolved the same way on every run.
    by_user = Window.partitionBy("userId").orderBy(
        F.col("score").desc(), F.col("itemId").asc()
    )
    topk = (
        cand.withColumn("__rk", F.row_number().over(by_user))
        .where(F.col("__rk") <= F.lit(K))
        .select("userId", "itemId", "score")
    )
    return topk.withColumnRenamed("score", "jaccard_score")