In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import explode
from pyspark.sql.types import FloatType, IntegerType
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

In [2]:
spark = SparkSession.builder.appName('Collaborative Filtering').getOrCreate()

22/12/08 19:50:03 WARN Utils: Your hostname, hasirama resolves to a loopback address: 127.0.1.1; using 192.168.0.219 instead (on interface enp7s0)
22/12/08 19:50:03 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


22/12/08 19:50:03 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
sc = spark.sparkContext

In [4]:
rating_rdf = spark.read.options(inferScehma=True, header=True).csv('ratings.csv')
rating_rdf.show()

+------+-------+------+---------+
|userId|movieId|rating|timestamp|
+------+-------+------+---------+
|     1|      1|   4.0|964982703|
|     1|      3|   4.0|964981247|
|     1|      6|   4.0|964982224|
|     1|     47|   5.0|964983815|
|     1|     50|   5.0|964982931|
|     1|     70|   3.0|964982400|
|     1|    101|   5.0|964980868|
|     1|    110|   4.0|964982176|
|     1|    151|   5.0|964984041|
|     1|    157|   5.0|964984100|
|     1|    163|   5.0|964983650|
|     1|    216|   5.0|964981208|
|     1|    223|   3.0|964980985|
|     1|    231|   5.0|964981179|
|     1|    235|   4.0|964980908|
|     1|    260|   5.0|964981680|
|     1|    296|   3.0|964982967|
|     1|    316|   3.0|964982310|
|     1|    333|   5.0|964981179|
|     1|    349|   4.0|964982563|
+------+-------+------+---------+
only showing top 20 rows



In [5]:
rating_rdf.printSchema()

root
 |-- userId: string (nullable = true)
 |-- movieId: string (nullable = true)
 |-- rating: string (nullable = true)
 |-- timestamp: string (nullable = true)



In [6]:
movie_rdf = spark.read.options(inferSchema=True, header=True).csv('movies.csv')
movie_rdf.show()

+-------+--------------------+--------------------+
|movieId|               title|              genres|
+-------+--------------------+--------------------+
|      1|    Toy Story (1995)|Adventure|Animati...|
|      2|      Jumanji (1995)|Adventure|Childre...|
|      3|Grumpier Old Men ...|      Comedy|Romance|
|      4|Waiting to Exhale...|Comedy|Drama|Romance|
|      5|Father of the Bri...|              Comedy|
|      6|         Heat (1995)|Action|Crime|Thri...|
|      7|      Sabrina (1995)|      Comedy|Romance|
|      8| Tom and Huck (1995)|  Adventure|Children|
|      9| Sudden Death (1995)|              Action|
|     10|    GoldenEye (1995)|Action|Adventure|...|
|     11|American Presiden...|Comedy|Drama|Romance|
|     12|Dracula: Dead and...|       Comedy|Horror|
|     13|        Balto (1995)|Adventure|Animati...|
|     14|        Nixon (1995)|               Drama|
|     15|Cutthroat Island ...|Action|Adventure|...|
|     16|       Casino (1995)|         Crime|Drama|
|     17|Sen

In [7]:
rdf = rating_rdf.join(movie_rdf, 'movieId', 'left')
rdf = rdf.withColumn('userId', rdf['userId'].cast(IntegerType()))
rdf = rdf.withColumn('movieId', rdf['movieId'].cast(IntegerType()))
rdf = rdf.withColumn('rating', rdf['rating'].cast(FloatType()))
rdf = rdf.drop('timestamp', 'title', 'genres')
rdf.show()

+-------+------+------+
|movieId|userId|rating|
+-------+------+------+
|      1|     1|   4.0|
|      3|     1|   4.0|
|      6|     1|   4.0|
|     47|     1|   5.0|
|     50|     1|   5.0|
|     70|     1|   3.0|
|    101|     1|   5.0|
|    110|     1|   4.0|
|    151|     1|   5.0|
|    157|     1|   5.0|
|    163|     1|   5.0|
|    216|     1|   5.0|
|    223|     1|   3.0|
|    231|     1|   5.0|
|    235|     1|   4.0|
|    260|     1|   5.0|
|    296|     1|   3.0|
|    316|     1|   3.0|
|    333|     1|   5.0|
|    349|     1|   4.0|
+-------+------+------+
only showing top 20 rows



In [8]:
rdf.printSchema()

root
 |-- movieId: integer (nullable = true)
 |-- userId: integer (nullable = true)
 |-- rating: float (nullable = true)



In [9]:
rdf.agg({'rating': 'max'}).collect()[0][0]

5.0

In [10]:
counts = rdf.groupBy('userId').count().orderBy('count', ascending=False)
counts.show()

+------+-----+
|userId|count|
+------+-----+
|   414| 2698|
|   599| 2478|
|   474| 2108|
|   448| 1864|
|   274| 1346|
|   610| 1302|
|    68| 1260|
|   380| 1218|
|   606| 1115|
|   288| 1055|
|   249| 1046|
|   387| 1027|
|   182|  977|
|   307|  975|
|   603|  943|
|   298|  939|
|   177|  904|
|   318|  879|
|   232|  862|
|   480|  836|
+------+-----+
only showing top 20 rows



In [11]:
pdf = counts.pandas_api()
pdf = pdf[pdf['count'] >= 50].copy()
counts = pdf.to_spark()
counts.show()



+------+-----+
|userId|count|
+------+-----+
|   414| 2698|
|   599| 2478|
|   474| 2108|
|   448| 1864|
|   274| 1346|
|   610| 1302|
|    68| 1260|
|   380| 1218|
|   606| 1115|
|   288| 1055|
|   249| 1046|
|   387| 1027|
|   182|  977|
|   307|  975|
|   603|  943|
|   298|  939|
|   177|  904|
|   318|  879|
|   232|  862|
|   480|  836|
+------+-----+
only showing top 20 rows



In [12]:
counts.count()

385

In [13]:
train, test = rdf.randomSplit([.8, .2], seed=4563)
train.show()

+-------+------+------+
|movieId|userId|rating|
+-------+------+------+
|      1|     5|   4.0|
|      1|     7|   4.5|
|      1|    15|   2.5|
|      1|    17|   4.5|
|      1|    18|   3.5|
|      1|    19|   4.0|
|      1|    21|   3.5|
|      1|    27|   3.0|
|      1|    31|   5.0|
|      1|    32|   3.0|
|      1|    40|   5.0|
|      1|    43|   5.0|
|      1|    44|   3.0|
|      1|    45|   4.0|
|      1|    46|   5.0|
|      1|    50|   3.0|
|      1|    63|   5.0|
|      1|    68|   2.5|
|      1|    71|   5.0|
|      1|    73|   4.5|
+-------+------+------+
only showing top 20 rows



In [14]:
test.show()

+-------+------+------+
|movieId|userId|rating|
+-------+------+------+
|      1|     1|   4.0|
|      1|    33|   3.0|
|      1|    54|   3.0|
|      1|    57|   5.0|
|      1|    64|   4.0|
|      1|    66|   4.0|
|      1|    98|   4.5|
|      1|   155|   3.0|
|      1|   161|   4.0|
|      1|   200|   3.5|
|      1|   213|   3.5|
|      1|   226|   3.5|
|      1|   232|   3.5|
|      1|   263|   4.0|
|      1|   270|   5.0|
|      1|   273|   5.0|
|      1|   276|   4.0|
|      1|   322|   3.5|
|      1|   323|   3.5|
|      1|   332|   4.0|
+-------+------+------+
only showing top 20 rows



In [15]:
train.count()

80658

In [16]:
test.count()

20178

In [17]:
als = ALS(
    userCol='userId', 
    itemCol='movieId',
    ratingCol='rating',
    nonnegative=True,
    implicitPrefs=False,
    coldStartStrategy='drop',
)
als

ALS_aa55f4c94221

In [18]:
pg = ParamGridBuilder().addGrid(als.rank, [10, 50, 100])
pg = pg.addGrid(als.regParam, [.01,.05, .1,])
pg = pg.build()
pg

[{Param(parent='ALS_aa55f4c94221', name='rank', doc='rank of the factorization'): 10,
  Param(parent='ALS_aa55f4c94221', name='regParam', doc='regularization parameter (>= 0).'): 0.01},
 {Param(parent='ALS_aa55f4c94221', name='rank', doc='rank of the factorization'): 10,
  Param(parent='ALS_aa55f4c94221', name='regParam', doc='regularization parameter (>= 0).'): 0.05},
 {Param(parent='ALS_aa55f4c94221', name='rank', doc='rank of the factorization'): 10,
  Param(parent='ALS_aa55f4c94221', name='regParam', doc='regularization parameter (>= 0).'): 0.1},
 {Param(parent='ALS_aa55f4c94221', name='rank', doc='rank of the factorization'): 50,
  Param(parent='ALS_aa55f4c94221', name='regParam', doc='regularization parameter (>= 0).'): 0.01},
 {Param(parent='ALS_aa55f4c94221', name='rank', doc='rank of the factorization'): 50,
  Param(parent='ALS_aa55f4c94221', name='regParam', doc='regularization parameter (>= 0).'): 0.05},
 {Param(parent='ALS_aa55f4c94221', name='rank', doc='rank of the factor

In [19]:
evaluator = RegressionEvaluator(
    metricName='rmse',
    labelCol='rating',
    predictionCol='prediction',
)
evaluator

RegressionEvaluator_d78fd0018868

In [20]:
cv = CrossValidator(evaluator=evaluator,estimator=als,estimatorParamMaps=pg,numFolds=3)
cv

CrossValidator_ac0fdacb4b3a

In [21]:
%%time
model = cv.fit(train)

22/12/08 19:50:10 WARN InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
22/12/08 19:50:10 WARN InstanceBuilder$NativeBLAS: Failed to load implementation from:dev.ludovic.netlib.blas.ForeignLinkerBLAS
CPU times: user 302 ms, sys: 81.7 ms, total: 384 ms
Wall time: 56.5 s


In [22]:
best_model = model.bestModel
predictions = best_model.transform(test)
error = evaluator.evaluate(predictions)
error

0.8690993834520041

In [23]:
recommendations = best_model.recommendForAllUsers(numItems=5)
recommendations.show()



+------+--------------------+
|userId|     recommendations|
+------+--------------------+
|     1|[{1262, 5.450519}...|
|     2|[{80906, 4.585936...|
|     3|[{5746, 4.9272103...|
|     4|[{3851, 4.9228463...|
|     5|[{8477, 4.7514753...|
|     6|[{2137, 4.615519}...|
|     7|[{260, 4.5581484}...|
|     8|[{1223, 4.6190696...|
|     9|[{89904, 4.861443...|
|    10|[{71579, 4.692719...|
|    11|[{27611, 5.144198...|
|    12|[{92259, 5.387304...|
|    13|[{7842, 4.8420177...|
|    14|[{80906, 4.430692...|
|    15|[{1204, 4.5635853...|
|    16|[{158966, 4.19272...|
|    17|[{170355, 4.69426...|
|    18|[{170355, 4.57990...|
|    19|[{1658, 4.3065524...|
|    20|[{720, 4.9089785}...|
+------+--------------------+
only showing top 20 rows



                                                                                

In [24]:
recommendations = recommendations.withColumn('recommendations', explode(recommendations.recommendations))
recommendations.show()

+------+------------------+
|userId|   recommendations|
+------+------------------+
|     1|  {1262, 5.450519}|
|     1| {7842, 5.4223576}|
|     1|  {3347, 5.362997}|
|     1|   {720, 5.344876}|
|     1|{132333, 5.316282}|
|     2|{80906, 4.5859365}|
|     2|  {89774, 4.58154}|
|     2|{106100, 4.545988}|
|     2|{171495, 4.541283}|
|     2| {61024, 4.483285}|
|     3| {5746, 4.9272103}|
|     3| {6835, 4.9272103}|
|     3| {5919, 4.8668237}|
|     3| {5181, 4.8611994}|
|     3| {7991, 4.7796674}|
|     4| {3851, 4.9228463}|
|     4| {3365, 4.7667346}|
|     4| {2390, 4.7218146}|
|     4| {4967, 4.6734924}|
|     4| {2583, 4.6633267}|
+------+------------------+
only showing top 20 rows



In [25]:
recommendations.printSchema()

root
 |-- userId: integer (nullable = false)
 |-- recommendations: struct (nullable = true)
 |    |-- movieId: integer (nullable = true)
 |    |-- rating: float (nullable = true)



In [26]:
recommendations = recommendations.select('userId', 'recommendations.movieId', 'recommendations.rating')
recommendations.show()

+------+-------+---------+
|userId|movieId|   rating|
+------+-------+---------+
|     1|   1262| 5.450519|
|     1|   7842|5.4223576|
|     1|   3347| 5.362997|
|     1|    720| 5.344876|
|     1| 132333| 5.316282|
|     2|  80906|4.5859365|
|     2|  89774|  4.58154|
|     2| 106100| 4.545988|
|     2| 171495| 4.541283|
|     2|  61024| 4.483285|
|     3|   5746|4.9272103|
|     3|   6835|4.9272103|
|     3|   5919|4.8668237|
|     3|   5181|4.8611994|
|     3|   7991|4.7796674|
|     4|   3851|4.9228463|
|     4|   3365|4.7667346|
|     4|   2390|4.7218146|
|     4|   4967|4.6734924|
|     4|   2583|4.6633267|
+------+-------+---------+
only showing top 20 rows



In [27]:
recommendations = recommendations.join(movie_rdf, 'movieId', 'left')
recommendations.show()

+-------+------+---------+--------------------+--------------------+
|movieId|userId|   rating|               title|              genres|
+-------+------+---------+--------------------+--------------------+
|   1262|     1| 5.450519|Great Escape, The...|Action|Adventure|...|
|   7842|     1|5.4223576|         Dune (2000)|Drama|Fantasy|Sci-Fi|
|   3347|     1| 5.362997|Never Cry Wolf (1...|     Adventure|Drama|
|    720|     1| 5.344876|Wallace & Gromit:...|Adventure|Animati...|
| 132333|     1| 5.316282|         Seve (2014)|   Documentary|Drama|
|  80906|     2|4.5859365|   Inside Job (2010)|         Documentary|
|  89774|     2|  4.58154|      Warrior (2011)|               Drama|
| 106100|     2| 4.545988|Dallas Buyers Clu...|               Drama|
| 171495|     2| 4.541283|              Cosmos|  (no genres listed)|
|  61024|     2| 4.483285|Pineapple Express...| Action|Comedy|Crime|
|   5746|     3|4.9272103|Galaxy of Terror ...|Action|Horror|Mys...|
|   6835|     3|4.9272103|Alien Co