In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.sql.functions import col, count

# Initialize Spark Session
spark = SparkSession.builder \
    .appName("ALS Recommendation System") \
    .getOrCreate()

24/04/21 17:56:47 WARN Utils: Your hostname, MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 10.0.0.195 instead (on interface en0)
24/04/21 17:56:47 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/04/21 17:56:48 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
df = spark.read.csv('data/interactions_top_books.csv', header=True)

                                                                                

In [3]:
df = df.withColumn("user_id", df["user_id"].cast('int'))
df = df.withColumn("book_id", df["book_id"].cast('int'))
df = df.withColumn("rating", df["rating"].cast('int'))

In [4]:
popular_books = df.where(col("is_read") == 1).groupBy("book_id").count().filter(col("count") > 2500).select("book_id")
active_users = df.where(col("is_read") == 1).groupBy("user_id").count().filter(col("count") > 50).select("user_id") 
df = df.join(popular_books, "book_id", "inner").join(active_users, "user_id", "inner")

In [5]:
# Split the data into training and test sets
(training, test) = df.randomSplit([0.8, 0.2])

# Build the recommendation model using ALS on the training data
als = ALS(maxIter=5, regParam=0.01, userCol="user_id", itemCol="book_id", ratingCol="rating")
model = als.fit(training)

# Evaluate the model by computing the RMSE on the test data
predictions = model.transform(test)
evaluator = RegressionEvaluator(metricName="rmse", labelCol="rating",
                                predictionCol="prediction")
rmse = evaluator.evaluate(predictions)
print("Root-mean-square error = " + str(rmse))

# Generate top 10 movie recommendations for each user
userRecs = model.recommendForAllUsers(10)
userRecs.show()

# Stop Spark Session
#spark.stop()


24/04/21 17:57:02 WARN GarbageCollectionMetrics: To enable non-built-in garbage collector(s) List(G1 Concurrent GC), users should configure it(them) to spark.eventLog.gcMetrics.youngGenerationGarbageCollectors or spark.eventLog.gcMetrics.oldGenerationGarbageCollectors
24/04/21 18:01:54 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
24/04/21 18:01:54 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.VectorBLAS
24/04/21 18:01:56 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.lapack.JNILAPACK
                                                                                

Root-mean-square error = 1.6786663670977267




+-------+--------------------+
|user_id|     recommendations|
+-------+--------------------+
|     13|[{6405, 6.3128276...|
|     16|[{637, 5.115455},...|
|     22|[{1313, 4.699243}...|
|     26|[{1200, 4.1176834...|
|     31|[{7106, 5.652311}...|
|     34|[{16144, 4.700937...|
|     44|[{12807, 6.12889}...|
|     52|[{15355, 4.610573...|
|     65|[{182, 5.648111},...|
|     78|[{6400, 2.4138167...|
|     85|[{19075, 5.44109}...|
|     91|[{7106, 4.4261475...|
|    132|[{790, 4.0650268}...|
|    157|[{13978, 6.530684...|
|    177|[{536, 3.8087547}...|
|    182|[{15356, 5.997926...|
|    190|[{182, 4.4612613}...|
|    192|[{16997, 4.95944}...|
|    193|[{182, 6.874717},...|
|    211|[{8425, 3.6904285...|
+-------+--------------------+
only showing top 20 rows



                                                                                

In [17]:
from pyspark.sql.functions import col, expr, array_contains, collect_list, when

# Create a DataFrame of known interactions from the training set
known_interactions = training.select("user_id", "book_id")

# Explode the recommendations to filter out known likes
exploded_recs = userRecs.withColumn("recommendation", expr("explode(recommendations)")) \
                        .select("user_id", col("recommendation.book_id").alias("book_id"))

# Join exploded recommendations with known interactions to filter them out
filtered_recs = exploded_recs.join(known_interactions, ["user_id", "book_id"], "left_anti")

# Group back the filtered recommendations
filtered_user_recs = filtered_recs.groupBy("user_id").agg(
    collect_list("book_id").alias("filtered_recommendations")
)

# Join the test data to the filtered recommendations for evaluation
test_with_recs = test.join(filtered_user_recs, "user_id")

# Calculate precision and recall
precision_recall = test_with_recs.withColumn(
    "is_relevant", array_contains(col("filtered_recommendations"), col("book_id"))
).withColumn(
    "is_relevant_int", when(col("is_relevant"), 1).otherwise(0)  # Convert boolean to integer
).groupBy("user_id").agg(
    expr("avg(is_relevant_int)").alias("precision"),  # Use integer for average
    expr("sum(is_relevant_int) / count(is_relevant)").alias("recall")  # Use integer for sum
)

# Show the calculated metrics
precision_recall.show()


24/04/21 19:17:51 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/04/21 19:17:51 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/04/21 19:17:54 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


+-------+--------------------+--------------------+
|user_id|           precision|              recall|
+-------+--------------------+--------------------+
|     26| 0.06666666666666667| 0.06666666666666667|
|     31|                 0.0|                 0.0|
|     34|                 0.0|                 0.0|
|     44|                 0.0|                 0.0|
|     65|                 0.0|                 0.0|
|     78|0.015873015873015872|0.015873015873015872|
|     85|                 0.0|                 0.0|
|    192|                 0.0|                 0.0|
|    193|                 0.0|                 0.0|
|    211|0.017241379310344827|0.017241379310344827|
|    255|0.058823529411764705|0.058823529411764705|
|    296|                 0.0|                 0.0|
|    321| 0.06666666666666667| 0.06666666666666667|
|    362|0.022222222222222223|0.022222222222222223|
|    384|                 0.0|                 0.0|
|    436|                 0.0|                 0.0|
|    451|   

                                                                                

In [19]:
from pyspark.sql.functions import avg

# Assuming 'precision_recall' is your DataFrame containing precision and recall for each user
overall_metrics = precision_recall.agg(
    avg("precision").alias("mean_precision"),
    avg("recall").alias("mean_recall")
)

# Display the aggregated metrics
overall_metrics.show()


24/04/21 19:24:29 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/04/21 19:24:29 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/04/21 19:24:29 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/04/21 19:24:29 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/04/21 19:24:29 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/04/21 19:24:30 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/04/21 19:24:30 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/04/21 19:24:30 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/04/21 19:24:30 WARN RowBasedKeyValueBatch: Calling spill() on

+--------------------+--------------------+
|      mean_precision|         mean_recall|
+--------------------+--------------------+
|0.006609833780335955|0.006609833780335955|
+--------------------+--------------------+



                                                                                