<a href="https://colab.research.google.com/github/pcamarillor/O2024_ESI3914O/blob/Lab12_Team08/Lab12.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import StructType, StructField, IntegerType
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator

# Start (or reuse) the Spark session used by every cell below.
spark = (
    SparkSession.builder
    .appName("MovieRecommendationALS")
    .getOrCreate()
)

# The ratings file is parsed as four nullable integer columns, in this order.
schema = StructType([
    StructField(column, IntegerType(), True)
    for column in ("userId", "movieId", "rating", "timestamp")
])

file_path = "/content/sample_data/sample_movielens_ratings.txt"

# Read the "::"-separated file with the explicit schema and discard the
# timestamp column straight away.
data = (
    spark.read
    .schema(schema)
    .option("sep", "::")
    .csv(file_path)
    .drop("timestamp")
)

# Distinct ids along each axis of the user x movie rating matrix.
users = data.select("userId").distinct()
movies = data.select("movieId").distinct()

In [2]:
# Count the rows whose rating is present (non-null).
has_rating = F.col("rating").isNotNull()
rating_count = data.where(has_rating).count()
print("rating count:", rating_count)

rating count: 1501


In [3]:
# Cells of the full user x movie grid that have no rating yet:
# (distinct users * distinct movies) minus the ratings already present.
n_users = users.count()
n_movies = movies.count()
missing_elements = n_users * n_movies - rating_count

In [4]:
# Build the ALS (alternating least squares) collaborative-filtering model.
als = ALS(
    userCol="userId",
    itemCol="movieId",
    ratingCol="rating",
    maxIter=10,
    regParam=0.1,
    rank=5,  # Dimensionality of the latent factor vectors for users and items.
    coldStartStrategy="drop",  # Drop rows whose prediction would be NaN.
    seed=42,  # Fixed seed so the fitted factors are repeatable across runs.
)

# Fit on every rating loaded into `data`.
model = als.fit(data)

# `numItems` is the number of recommendations returned *per user*, so the
# number of distinct movies is the value that yields a score for every
# user/movie pair. The previous value (the global count of unrated cells) only
# produced full lists while it happened to be >= the number of movies, and
# would truncate or empty the lists on a dense rating matrix.
user_recommendations = model.recommendForAllUsers(numItems=movies.count())

In [5]:
# RMSE evaluator comparing the observed `rating` with the model's `prediction`.
evaluator = RegressionEvaluator(
    labelCol="rating",
    predictionCol="prediction",
    metricName="rmse",
)

# Score the ratings in `data` with the fitted model and preview the result.
# NOTE(review): `data` is the same DataFrame that was passed to `als.fit`, so
# the RMSE below is an in-sample (training) error, not a held-out test error.
predictions = model.transform(data)
predictions.show(truncate=False)

# Root-mean-square error between actual and predicted ratings.
rmse = evaluator.evaluate(predictions)

+------+-------+------+----------+
|userId|movieId|rating|prediction|
+------+-------+------+----------+
|28    |0      |3     |2.32634   |
|28    |1      |1     |1.0017017 |
|28    |2      |4     |3.6121683 |
|28    |3      |1     |0.9337988 |
|28    |6      |1     |1.05962   |
|28    |7      |1     |1.7147777 |
|28    |12     |5     |2.9297867 |
|28    |13     |2     |1.7533139 |
|28    |14     |1     |1.3806543 |
|28    |15     |1     |0.97168183|
|28    |17     |1     |1.2037045 |
|28    |19     |3     |2.4436297 |
|28    |20     |1     |1.4833806 |
|28    |23     |3     |2.6713645 |
|28    |24     |3     |2.3909693 |
|28    |27     |1     |0.90194   |
|28    |29     |1     |1.1212653 |
|28    |33     |1     |1.1912075 |
|28    |34     |1     |1.4464692 |
|28    |36     |1     |1.5517595 |
+------+-------+------+----------+
only showing top 20 rows



In [6]:
# Report the value stored in `rmse` (the result of `evaluator.evaluate(predictions)`).
print(f"Root-mean-square error = {rmse}")

Root-mean-square error = 0.6011524942030569
