In [2]:
from pyspark.sql import SparkSession

In [3]:
spark = SparkSession.builder.appName("rec").getOrCreate()

In [4]:
from pyspark.ml.recommendation import ALS

In [5]:
from pyspark.ml.evaluation import RegressionEvaluator

In [6]:
data = spark.read.csv('movielens_ratings.csv',inferSchema=True, header=True)

In [7]:
data.show()

+------+-------+------+
|userId|movieId|rating|
+------+-------+------+
|     1|      2|   3.5|
|     1|     29|   3.5|
|     1|     32|   3.5|
|     1|     47|   3.5|
|     1|     50|   3.5|
|     1|    112|   3.5|
|     1|    151|   4.0|
|     1|    223|   4.0|
|     1|    253|   4.0|
|     1|    260|   4.0|
|     1|    293|   4.0|
|     1|    296|   4.0|
|     1|    318|   4.0|
|     1|    337|   3.5|
|     1|    367|   3.5|
|     1|    541|   4.0|
|     1|    589|   3.5|
|     1|    593|   3.5|
|     1|    653|   3.0|
|     1|    919|   3.5|
+------+-------+------+
only showing top 20 rows



In [8]:
data.describe().show()

+-------+-----------------+------------------+------------------+
|summary|           userId|           movieId|            rating|
+-------+-----------------+------------------+------------------+
|  count|           100000|            100000|            100000|
|   mean|         362.8304|         8572.4658|          3.507605|
| stddev|196.8029033568026|19056.086005583176|1.0629280136183334|
|    min|                1|                 1|               0.5|
|    max|              702|            128594|               5.0|
+-------+-----------------+------------------+------------------+



In [9]:
training, test = data.randomSplit([0.8,0.2])

In [10]:
als = ALS(maxIter=5, regParam=0.01,userCol='userId',itemCol='movieId',ratingCol='rating')

In [11]:
model = als.fit(training)

In [12]:
predictions = model.transform(test)

In [13]:
predictions.show()

+------+-------+------+----------+
|userId|movieId|rating|prediction|
+------+-------+------+----------+
|   101|    471|   3.0| 3.9133773|
|   377|    471|   5.0| 3.0475051|
|   650|    471|   4.5|  4.039208|
|   579|    471|   5.0|  4.315957|
|   504|    471|   3.0|  4.682068|
|   489|    471|   5.0|  2.480452|
|   461|    471|   2.0| 2.9963946|
|   381|    471|   4.0|  3.633883|
|   648|    471|   5.0| 3.2886457|
|   245|    471|   5.0|  4.129113|
|   109|    471|   5.0| 3.6925142|
|    14|    471|   5.0| 4.0189414|
|    46|    833|   3.0|  4.018633|
|   133|   1088|   2.5| 2.3960078|
|    91|   1088|   2.5| 2.9443245|
|   654|   1088|   5.0|  4.002043|
|   206|   1088|   4.5| 3.6625388|
|    54|   1088|   3.0| 3.0165172|
|   279|   1088|   4.0| 3.4119813|
|   586|   1088|   1.0| 2.3550816|
+------+-------+------+----------+
only showing top 20 rows



In [14]:
evaluator = RegressionEvaluator(metricName='rmse',labelCol='rating',predictionCol='prediction')

In [15]:
rmse = evaluator.evaluate(predictions)

In [16]:
print("RMSE")
print(rmse)

RMSE
nan


In [17]:
single_user = test.filter(test['userId']==11).select(['movieId','userId'])

In [18]:
single_user.show()

+-------+------+
|movieId|userId|
+-------+------+
|    158|    11|
|    165|    11|
|    208|    11|
|    377|    11|
|    384|    11|
|    442|    11|
|    541|    11|
|    589|    11|
|    597|    11|
|    610|    11|
|   1028|    11|
|   1097|    11|
|   1196|    11|
|   1214|    11|
|   1291|    11|
|   1320|    11|
|   1339|    11|
|   1391|    11|
|   1584|    11|
|   1591|    11|
+-------+------+
only showing top 20 rows



In [19]:
recommendations = model.transform(single_user)

In [20]:
recommendations.orderBy('prediction',ascending= False).show()

+-------+------+----------+
|movieId|userId|prediction|
+-------+------+----------+
|  70227|    11|       NaN|
|   6795|    11|       NaN|
|   7345|    11|       NaN|
|   3745|    11| 5.8218637|
|  53125|    11| 5.6413517|
|   8371|    11| 5.5462937|
|  69526|    11|  5.424989|
|  52281|    11|  5.390704|
|   5283|    11| 5.3097944|
|  58559|    11| 5.1177745|
|   7153|    11| 5.1150546|
|   3000|    11|  5.075387|
|  48304|    11| 5.0643663|
|   2762|    11| 4.9932423|
|   2420|    11| 4.9323344|
|  50442|    11|  4.856162|
|   1876|    11| 4.8546805|
|  66297|    11| 4.8180857|
|  34319|    11| 4.8027062|
|   1196|    11|  4.749913|
+-------+------+----------+
only showing top 20 rows

