In [17]:
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import ClusteringEvaluator
from pyspark.ml.feature import VectorAssembler
import matplotlib.pyplot as plt
import numpy as np

In [2]:
# Configure a Spark application named "Recommendation System" that runs on the local master
spark_conf = SparkConf()
spark_conf.setAppName("Recommendation System")
spark_conf.setMaster("local[*]")

# Reuse the active session if there is one, otherwise create it with this configuration
spark = SparkSession.builder.config(conf=spark_conf).getOrCreate()

24/12/12 01:48:52 WARN Utils: Your hostname, lanhf-rogstrixg513rc resolves to a loopback address: 127.0.1.1; using 192.168.53.103 instead (on interface enp3s0)
24/12/12 01:48:52 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/12/12 01:48:53 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# Load the ratings CSV: the first row holds the column names and column types are inferred
df = (
    spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv("../../data/ratings.csv")
)

In [4]:
# Preview the first 20 rows of the raw ratings
df.show(n=20, truncate=True)

+------+-------+------+----------+
|userId|movieId|rating| timestamp|
+------+-------+------+----------+
|     1|     31|   2.5|1260759144|
|     1|   1029|   3.0|1260759179|
|     1|   1061|   3.0|1260759182|
|     1|   1129|   2.0|1260759185|
|     1|   1172|   4.0|1260759205|
|     1|   1263|   2.0|1260759151|
|     1|   1287|   2.0|1260759187|
|     1|   1293|   2.0|1260759148|
|     1|   1339|   3.5|1260759125|
|     1|   1343|   2.0|1260759131|
|     1|   1371|   2.5|1260759135|
|     1|   1405|   1.0|1260759203|
|     1|   1953|   4.0|1260759191|
|     1|   2105|   4.0|1260759139|
|     1|   2150|   3.0|1260759194|
|     1|   2193|   2.0|1260759198|
|     1|   2294|   2.0|1260759108|
|     1|   2455|   2.5|1260759113|
|     1|   2968|   1.0|1260759200|
|     1|   3671|   3.0|1260759117|
+------+-------+------+----------+
only showing top 20 rows



In [5]:
# Print each column's name, inferred type and nullability
df.printSchema()

root
 |-- userId: integer (nullable = true)
 |-- movieId: integer (nullable = true)
 |-- rating: double (nullable = true)
 |-- timestamp: integer (nullable = true)



In [6]:
# Basic statistics (count, mean, stddev, min, max) for each column
df.describe().show(n=20, truncate=True)

24/12/12 01:48:57 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


+-------+-----------------+------------------+------------------+--------------------+
|summary|           userId|           movieId|            rating|           timestamp|
+-------+-----------------+------------------+------------------+--------------------+
|  count|            27678|             27678|             27678|               27678|
|   mean|93.76316930414048|12670.350531107739| 3.528289616301756| 1.132113489702146E9|
| stddev|58.15929455855169|26685.600433440402|1.0653290966311337|1.9136201070717102E8|
|    min|                1|                 1|               0.5|           832228796|
|    max|              200|            162376|               5.0|          1476086345|
+-------+-----------------+------------------+------------------+--------------------+



In [7]:
# Redistribute the ratings across 10 partitions
df = df.repartition(numPartitions=10)

In [10]:
# Keep only the ratings of the 50 most-rated movies.
# Drop any "count" column left over from a previous run of this cell, so that
# re-running it does not add a second, duplicate "count" column to df
# (drop is a no-op when the column is absent).
ratings_without_count = df.drop("count")

# Rank movies by number of ratings; movieId breaks ties so the selected set of
# 50 movies is the same on every run.
top_movies = (
    ratings_without_count.groupBy("movieId")
    .count()
    .orderBy(["count", "movieId"], ascending=[False, True])
    .limit(50)
)

# The inner join keeps only ratings of those movies and carries each movie's
# rating count along in the "count" column.
df = top_movies.join(ratings_without_count, "movieId", "inner")

In [12]:
# Summary statistics (including quartiles) of the filtered ratings
df.summary().show(n=20, truncate=True)

+-------+------------------+------------------+------------------+------------------+--------------------+
|summary|           movieId|             count|            userId|            rating|           timestamp|
+-------+------------------+------------------+------------------+------------------+--------------------+
|  count|              3249|              3249|              3249|              3249|                3249|
|   mean|1190.5333948907355| 67.80886426592798|101.39827639273622| 3.878885811018775|1.1013390588864267E9|
| stddev|1488.7901494781202|15.067441825804298| 56.72784346006155|0.9625198428374747|2.0065599654775926E8|
|    min|                 1|                50|                 2|               0.5|           832228796|
|    25%|               344|                57|                56|               3.0|           913058278|
|    50%|               590|                63|               102|               4.0|          1111482813|
|    75%|              1265|         

In [23]:
# Build a user x movie table: one row per userId, one column per movieId.
# avg (rather than sum) keeps a cell on the rating scale even if the same
# user/movie pair appears more than once; for unique pairs it yields the same value.
# User/movie pairs without a rating are filled with 2.5.
ratings = df.groupBy("userId").pivot("movieId").avg("rating").na.fill(2.5)


In [24]:
# Preview the first 20 rows of the user x movie table
ratings.show(n=20, truncate=True)

+------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
|userId|  1| 32| 47| 50|110|150|231|260|296|316|318|344|356|364|367|377|380|457|480|527|541|588|589|590|592|593|595|608|648|736|780|1097|1196|1198|1210|1240|1265|1270|1580|1721|2028|2571|2762|2858|2959|3578|4306|4993|5952|7153|
+------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
|   148|2.5|4.0|2.5|2.5|2.5|2.5|2.5|2.5|2.5|2.5|2.5|2.5|2.5|4.0|2.5|2.5|2.5|2.5|4.5|2.5|2.5|4.0|3.5|3.5|2.5|2.5|2.5|2.5|4.0|2.5|3.5| 5.0| 2.5| 2.5| 2.5| 2.5| 2.5| 2.5| 4.0| 2.5| 2.5| 4.5| 2.5| 5.0| 2.5| 2.5| 5.0| 5.0| 4.5| 2.5|
|    31|2.5|4.5|2.5|3.5|2.5|2.5|2.5|4.0|4.5|2.5|4.0|2.5|2.5|2.5|2.5|2.5|2.5|2.5|2.5|5.0|

In [25]:
# Assemble the per-movie rating columns into a single "features" vector column.
# Start from the frame without "features" so this cell can be re-run:
# VectorAssembler raises when its output column already exists, and drop is a
# no-op when the column is absent.
ratings_wide = ratings.drop("features")

# Every column except the userId key is a feature (selected by name, so the
# result does not depend on userId being the first column).
feature_cols = [name for name in ratings_wide.columns if name != "userId"]

assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")

ratings = assembler.transform(ratings_wide)

ratings.show()

+------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+--------------------+
|userId|  1| 32| 47| 50|110|150|231|260|296|316|318|344|356|364|367|377|380|457|480|527|541|588|589|590|592|593|595|608|648|736|780|1097|1196|1198|1210|1240|1265|1270|1580|1721|2028|2571|2762|2858|2959|3578|4306|4993|5952|7153|            features|
+------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+--------------------+
|   148|2.5|4.0|2.5|2.5|2.5|2.5|2.5|2.5|2.5|2.5|2.5|2.5|2.5|4.0|2.5|2.5|2.5|2.5|4.5|2.5|2.5|4.0|3.5|3.5|2.5|2.5|2.5|2.5|4.0|2.5|3.5| 5.0| 2.5| 2.5| 2.5| 2.5| 2.5| 2.5| 4.0| 2.5| 2.5| 4.5| 2.5| 5.0| 2.5| 2.5| 5.0| 5.0| 4.5| 2.5|[2.5,4.0,2.5,2.5,...|
|   

In [27]:
# Fit a KMeans model with 5 clusters and a fixed seed on the assembled ratings
kmeans = KMeans(k=5, seed=1)
model = kmeans.fit(ratings)

# Assign every row of `ratings` to a cluster
predictions = model.transform(ratings)

# Score the clustering with the evaluator's default metric (silhouette)
evaluator = ClusteringEvaluator()
silhouette = evaluator.evaluate(predictions)
print(f"Silhouette with squared euclidean distance = {silhouette}")

# Print the fitted cluster centers, one per cluster
centers = model.clusterCenters()
print("Cluster Centers: ")
for center in centers:
    print(center)

Silhouette with squared euclidean distance = 0.1897826068930241
Cluster Centers: 
[2.78191489 2.86702128 2.61702128 2.85638298 2.59574468 2.65957447
 2.50531915 2.77659574 3.01595745 2.54255319 2.95744681 2.43617021
 2.81382979 2.62234043 2.53191489 2.52659574 2.54787234 2.57978723
 2.71276596 2.83510638 2.66489362 2.53723404 2.61702128 2.61170213
 2.57978723 2.7393617  2.55319149 2.95744681 2.67553191 2.58510638
 2.80319149 2.64361702 2.68085106 2.64361702 2.71276596 2.61170213
 2.62765957 2.7287234  2.67021277 2.56914894 2.61702128 2.90957447
 2.80319149 2.92553191 2.77659574 2.66489362 2.68617021 2.74468085
 2.79255319 2.79787234]
[3.25510204 3.3877551  3.64285714 3.67346939 3.29591837 3.02040816
 2.73469388 4.03061224 4.06122449 2.89795918 3.79591837 2.59183673
 3.68367347 3.03061224 2.62244898 2.92857143 3.05102041 3.36734694
 3.32653061 3.67346939 3.58163265 3.02040816 3.76530612 2.91836735
 3.06122449 3.67346939 3.01020408 3.23469388 3.         2.66326531
 3.07142857 3.41836735 