In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator

In [2]:
spark = SparkSession \
    .builder \
    .appName("collabrative filtering") \
    .getOrCreate()

In [3]:
spark.version

'3.3.1'

In [4]:
game_df = spark.read.csv('games.csv', header=True, inferSchema=True)
rate_df = spark.read.csv('ratings.csv', header=True, inferSchema=True)

In [5]:
game_df.show()

+-------+--------------------+------------+--------------------+--------------------+
|game_id|                name|release_date|             summary|          meta_score|
+-------+--------------------+------------+--------------------+--------------------+
|      1|The Legend of Zel...|   23-Nov-98|As a young boy, L...|                  99|
|      2|Tony Hawk's Pro S...|   20-Sep-00|As most major pub...|                  98|
|      3| Grand Theft Auto IV|   29-Apr-08|"[Metacritic's 20...| fresh off the bo...|
|      4|         SoulCalibur|    8-Sep-99|This is a tale of...|                  98|
|      5|  Super Mario Galaxy|   12-Nov-07|[Metacritic's 200...|                  97|
|      6|Super Mario Galaxy 2|   23-May-10|Super Mario Galax...|                  97|
|      7|Red Dead Redempti...|   26-Oct-18|Developed by the ...|                  97|
|      8|  Grand Theft Auto V|   18-Nov-14|Grand Theft Auto ...|                  97|
|      9|Disco Elysium: Th...|   30-Mar-21|Disco Elysi

In [6]:
rate_df.show()

+-------+-------+------+
|game_id|user_id|rating|
+-------+-------+------+
|      1|    314|     5|
|      1|    439|     3|
|      1|    588|     5|
|      1|   1169|     4|
|      1|   1185|     4|
|      1|   2077|     4|
|      1|   2487|     4|
|      1|   2900|     5|
|      1|   3662|     4|
|      1|   3922|     5|
|      1|   5379|     5|
|      1|   5461|     3|
|      1|   5885|     5|
|      1|   6630|     5|
|      1|   7563|     3|
|      1|   9246|     1|
|      1|  10140|     4|
|      1|  10146|     5|
|      1|  10246|     4|
|      1|  10335|     4|
+-------+-------+------+
only showing top 20 rows



In [7]:
rate_df.registerTempTable("ratings")
game_df.registerTempTable("games")



In [8]:
rate_df.describe()

DataFrame[summary: string, game_id: string, user_id: string, rating: string]

In [9]:
(train, test) = rate_df.randomSplit([0.7, 0.3], seed=42)

In [34]:
als = ALS(maxIter=5, regParam=0.01, userCol='user_id',itemCol='game_id', ratingCol='rating')
als.setColdStartStrategy('drop')

ALS_80d2d0327da1

In [35]:
model = als.fit(train)

In [36]:
pred = model.transform(test)

In [37]:
pred.show()

+-------+-------+------+----------+
|game_id|user_id|rating|prediction|
+-------+-------+------+----------+
|    463|  32592|     5| 4.4580426|
|   1580|   3918|     4|  2.595739|
|   1591|  35982|     5|  3.674078|
|   1829|  39285|     4| 4.8084574|
|   1829|  44822|     5| 4.4008255|
|   2122|  25591|     3| 3.7442565|
|   2866|  49331|     5|  4.914663|
|    463|  46147|     5|  5.724244|
|   2866|  17172|     3| 3.2400088|
|   2142|   1339|     4| 2.9931293|
|   2142|  19526|     5| 4.3795586|
|   1342|  33337|     3| 3.6251206|
|   1342|  33337|     4| 3.6251206|
|   2366|  32832|     5|  4.872513|
|    833|  28343|     3|  3.285513|
|   1238|  10527|     3| 3.5316963|
|   1829|  27361|     2|  2.705662|
|   1580|    588|     3|  3.827312|
|   1591|  43689|     4|  5.320927|
|   1088|  49202|     2|  2.582635|
+-------+-------+------+----------+
only showing top 20 rows



In [38]:
eval = RegressionEvaluator(metricName="rmse", labelCol='rating', predictionCol='prediction')

In [39]:
rmse = eval.evaluate(pred)
print(f'RMSE: {rmse}')

RMSE: 2.1799436088153237


In [40]:
user_id = 564
myuser = test.filter(test['user_id'] == user_id).select(['game_id', 'user_id'])

In [41]:
myuser.show()

+-------+-------+
|game_id|user_id|
+-------+-------+
|   3187|    564|
|   3911|    564|
|   5237|    564|
|   5456|    564|
|   7266|    564|
|   8782|    564|
|   9645|    564|
|   9689|    564|
+-------+-------+



In [42]:
rec = model.transform(myuser)

In [33]:
rec.orderBy('prediction', ascending=False).show()

+-------+-------+------------+
|game_id|user_id|  prediction|
+-------+-------+------------+
|   3911|    564|   4.6399097|
|   9689|    564|   3.7790594|
|   9645|    564|    3.499237|
|   7266|    564|   3.1608658|
|   8782|    564|   2.7031126|
|   5456|    564|   2.2260807|
|   3187|    564|   1.9552112|
|   5237|    564|-0.074302256|
+-------+-------+------------+

