In [41]:
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, explode

In [42]:
# Reuse the active Spark session if there is one; otherwise start a new one
# registered under this notebook's application name.
spark = (
    SparkSession.builder
    .appName("Collaborative Filtering")
    .getOrCreate()
)
spark  # bare expression: rendered as the cell's output

In [43]:
# Both files have a header row; column types are inferred from the data
# rather than declared up front.
moviesDF = spark.read.csv("movies.csv", header=True, inferSchema=True)
ratingsDF = spark.read.csv("ratings.csv", header=True, inferSchema=True)

# Preview the first rows of each frame.
moviesDF.show()
ratingsDF.show()

+-------+--------------------+--------------------+
|movieId|               title|              genres|
+-------+--------------------+--------------------+
|      1|    Toy Story (1995)|Adventure|Animati...|
|      2|      Jumanji (1995)|Adventure|Childre...|
|      3|Grumpier Old Men ...|      Comedy|Romance|
|      4|Waiting to Exhale...|Comedy|Drama|Romance|
|      5|Father of the Bri...|              Comedy|
|      6|         Heat (1995)|Action|Crime|Thri...|
|      7|      Sabrina (1995)|      Comedy|Romance|
|      8| Tom and Huck (1995)|  Adventure|Children|
|      9| Sudden Death (1995)|              Action|
|     10|    GoldenEye (1995)|Action|Adventure|...|
|     11|American Presiden...|Comedy|Drama|Romance|
|     12|Dracula: Dead and...|       Comedy|Horror|
|     13|        Balto (1995)|Adventure|Animati...|
|     14|        Nixon (1995)|               Drama|
|     15|Cutthroat Island ...|Action|Adventure|...|
|     16|       Casino (1995)|         Crime|Drama|
|     17|Sen

In [44]:
# Both frames were loaded with inferSchema=True, so print the column types
# Spark chose before relying on them downstream.
moviesDF.printSchema()
ratingsDF.printSchema()

root
 |-- movieId: integer (nullable = true)
 |-- title: string (nullable = true)
 |-- genres: string (nullable = true)

root
 |-- userId: integer (nullable = true)
 |-- movieId: integer (nullable = true)
 |-- rating: double (nullable = true)
 |-- timestamp: integer (nullable = true)



In [45]:
# Left join: every rating row is kept, and the movie's title/genres are
# attached wherever the movieId has a match in moviesDF.
ratings = ratingsDF.join(moviesDF, on="movieId", how="left")

In [46]:
# Row count of the joined frame. This should equal the number of rows in
# ratingsDF provided movieId is unique in moviesDF — TODO confirm.
ratings.count()

100836

In [47]:
# Split the dataframe into train and test data (weights 0.8 / 0.2).
# The seed is fixed so that the split, and therefore every metric computed
# from it, is the same after a kernel restart; without it each run trains
# and evaluates on different rows.
train, test = ratings.randomSplit([0.8, 0.2], seed=42)

In [48]:
# Sanity check: the two splits together must account for every joined row.
ratings.count() == train.count() + test.count()

True

In [49]:
# Preview the first rows of the training split.
train.show()

+-------+------+------+----------+----------------+--------------------+
|movieId|userId|rating| timestamp|           title|              genres|
+-------+------+------+----------+----------------+--------------------+
|      1|     5|   4.0| 847434962|Toy Story (1995)|Adventure|Animati...|
|      1|     7|   4.5|1106635946|Toy Story (1995)|Adventure|Animati...|
|      1|    17|   4.5|1305696483|Toy Story (1995)|Adventure|Animati...|
|      1|    19|   4.0| 965705637|Toy Story (1995)|Adventure|Animati...|
|      1|    21|   3.5|1407618878|Toy Story (1995)|Adventure|Animati...|
|      1|    27|   3.0| 962685262|Toy Story (1995)|Adventure|Animati...|
|      1|    31|   5.0| 850466616|Toy Story (1995)|Adventure|Animati...|
|      1|    33|   3.0| 939647444|Toy Story (1995)|Adventure|Animati...|
|      1|    54|   3.0| 830247330|Toy Story (1995)|Adventure|Animati...|
|      1|    63|   5.0|1443199669|Toy Story (1995)|Adventure|Animati...|
|      1|    64|   4.0|1161520134|Toy Story (1995)|

In [50]:
## ALS Model

In [51]:
# Explicit-feedback ALS over (userId, movieId, rating) with non-negative
# factors. coldStartStrategy="drop" removes prediction rows for users or
# movies the model never saw during fitting, so the evaluator is not fed
# NaN predictions.
# The seed is fixed so the random factor initialisation, and the resulting
# model, is reproducible across sessions.
als = ALS(
    userCol="userId",
    itemCol="movieId",
    ratingCol="rating",
    nonnegative=True,
    implicitPrefs=False,
    coldStartStrategy="drop",
    seed=42,
)

In [52]:
# Hyperparameter grid: 4 ranks x 4 regularisation strengths = 16 candidates.
param_grid = (
    ParamGridBuilder()
    .addGrid(als.rank, [10, 50, 100, 150])
    .addGrid(als.regParam, [0.01, 0.05, 0.1, 0.15])
    .build()
)

In [53]:
# Root-mean-squared error between the "prediction" column and the true "rating".
evaluator = RegressionEvaluator(predictionCol="prediction", labelCol="rating", metricName="rmse")

In [54]:
# 5-fold cross-validation over every parameter combination in param_grid,
# scored with `evaluator`.
# The seed is fixed so the assignment of rows to folds, and therefore the
# selected model, is reproducible across runs.
cv = CrossValidator(
    estimator=als,
    estimatorParamMaps=param_grid,
    evaluator=evaluator,
    numFolds=5,
    seed=42,
)

In [59]:
# Run the cross-validated grid search on the training split. This fits the
# estimator once per parameter combination per fold, so it is the slow cell.
model = cv.fit(train)
# Keep the model fitted with the best-scoring parameter combination.
bestModel = model.bestModel
# Score the held-out test split with that model.
test_predictions = bestModel.transform(test)

In [60]:
# RMSE of the best model's predictions on the held-out test split.
RMSE = evaluator.evaluate(test_predictions)
print(RMSE)

0.8680240575181213


In [72]:
# Top-5 recommended movies per user: one row per user, with an array of
# (movieId, predicted rating) structs in the "recommendations" column.
recommendations = bestModel.recommendForAllUsers(5)

In [73]:
# Alias only: `df` refers to the same DataFrame object as `recommendations`.
df = recommendations

In [74]:
# truncate=False prints each recommendation array in full instead of
# cutting long cell values short.
df.show(truncate=False)

+------+---------------------------------------------------------------------------------------------------+
|userId|recommendations                                                                                    |
+------+---------------------------------------------------------------------------------------------------+
|1     |[{96004, 5.806157}, {170355, 5.806157}, {3379, 5.806157}, {33649, 5.629791}, {5490, 5.577285}]     |
|2     |[{131724, 4.805224}, {69524, 4.535378}, {96004, 4.5284743}, {170355, 4.5284743}, {3379, 4.5284743}]|
|3     |[{5746, 4.8621607}, {6835, 4.8621607}, {5919, 4.766895}, {5181, 4.759805}, {2851, 4.734143}]       |
|4     |[{1733, 4.760128}, {25825, 4.7033024}, {1046, 4.688176}, {4765, 4.6822395}, {2204, 4.5960803}]     |
|5     |[{170355, 4.6484933}, {96004, 4.6484933}, {3379, 4.6484933}, {5490, 4.4872494}, {7767, 4.4733486}] |
|6     |[{33649, 4.8036513}, {3200, 4.7189884}, {3086, 4.6860747}, {5490, 4.6601877}, {5867, 4.657937}]    |
|7     |[{3030, 4.6

In [77]:
# Flatten the per-user recommendation arrays: explode() emits one row per
# array element, so each user now has one row per recommended movie. The
# element (a struct) lands in the new "movieid_rating" column, and the
# original "recommendations" array column is kept alongside it.
# `explode` is imported in the notebook's top import cell.
df2 = df.withColumn("movieid_rating", explode("recommendations"))
df2.show()

+------+--------------------+-------------------+
|userId|     recommendations|     movieid_rating|
+------+--------------------+-------------------+
|     1|[{96004, 5.806157...|  {96004, 5.806157}|
|     1|[{96004, 5.806157...| {170355, 5.806157}|
|     1|[{96004, 5.806157...|   {3379, 5.806157}|
|     1|[{96004, 5.806157...|  {33649, 5.629791}|
|     1|[{96004, 5.806157...|   {5490, 5.577285}|
|     2|[{131724, 4.80522...| {131724, 4.805224}|
|     2|[{131724, 4.80522...|  {69524, 4.535378}|
|     2|[{131724, 4.80522...| {96004, 4.5284743}|
|     2|[{131724, 4.80522...|{170355, 4.5284743}|
|     2|[{131724, 4.80522...|  {3379, 4.5284743}|
|     3|[{5746, 4.8621607...|  {5746, 4.8621607}|
|     3|[{5746, 4.8621607...|  {6835, 4.8621607}|
|     3|[{5746, 4.8621607...|   {5919, 4.766895}|
|     3|[{5746, 4.8621607...|   {5181, 4.759805}|
|     3|[{5746, 4.8621607...|   {2851, 4.734143}|
|     4|[{1733, 4.760128}...|   {1733, 4.760128}|
|     4|[{1733, 4.760128}...| {25825, 4.7033024}|


In [79]:
# Unpack the struct into flat columns: one (userId, movieId, rating) row per
# recommendation. Dotted paths address the struct's fields and the resulting
# columns are named after the field ("movieId", "rating").
(
    df2
    .select("userId", "movieid_rating.movieId", "movieid_rating.rating")
    .show()
)

+------+-------+---------+
|userId|movieId|   rating|
+------+-------+---------+
|     1|  96004| 5.806157|
|     1| 170355| 5.806157|
|     1|   3379| 5.806157|
|     1|  33649| 5.629791|
|     1|   5490| 5.577285|
|     2| 131724| 4.805224|
|     2|  69524| 4.535378|
|     2|  96004|4.5284743|
|     2| 170355|4.5284743|
|     2|   3379|4.5284743|
|     3|   5746|4.8621607|
|     3|   6835|4.8621607|
|     3|   5919| 4.766895|
|     3|   5181| 4.759805|
|     3|   2851| 4.734143|
|     4|   1733| 4.760128|
|     4|  25825|4.7033024|
|     4|   1046| 4.688176|
|     4|   4765|4.6822395|
|     4|   2204|4.5960803|
+------+-------+---------+
only showing top 20 rows

