In [1]:
import os
import pandas as pd
import numpy as np

# 导入PySpark
from pyspark import SparkContext
from pyspark.sql import SparkSession

# 导入pyspark 的部分函数
from pyspark.sql.functions import col, min, max, avg, lit

# 导入pyspark 的机器学习相关的一些包
from pyspark.ml.recommendation import ALS 
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator # Cross-Validation
from pyspark.ml.evaluation import RegressionEvaluator # Performance metric

import seaborn as sns
import matplotlib.pyplot as plt

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

pd.set_option('display.max_columns', 200)
pd.set_option('display.max_colwidth', 400)

In [2]:
import os
os.environ ['JAVA_HOME'] = '/usr/lib/jvm/openlogic-openjdk-11-hotspot-amd64/'
#'C:\Program Files\Java\jre1.8.0_271'

In [3]:
sc = SparkContext(appName = "Movie-Recommendation")
print(sc)

<SparkContext master=local[*] appName=Movie-Recommendation>


In [4]:
spark = SparkSession.Builder().getOrCreate()

### 数据载入

In [8]:
ratings = spark.read.csv('./input/ratings.csv', header = True, inferSchema=True)
ratings.show(10)

+------+-------+------+----------+
|userId|movieId|rating| timestamp|
+------+-------+------+----------+
|     1|      2|   3.5|1112486027|
|     1|     29|   3.5|1112484676|
|     1|     32|   3.5|1112484819|
|     1|     47|   3.5|1112484727|
|     1|     50|   3.5|1112484580|
|     1|    112|   3.5|1094785740|
|     1|    151|   4.0|1094785734|
|     1|    223|   4.0|1112485573|
|     1|    253|   4.0|1112484940|
|     1|    260|   4.0|1112484826|
+------+-------+------+----------+
only showing top 10 rows



In [14]:
ratings = ratings.drop('timestamp')

In [15]:
ratings.head()

Row(userId=1, movieId=2, rating=3.5)

In [16]:
ratings.printSchema()

root
 |-- userId: integer (nullable = true)
 |-- movieId: integer (nullable = true)
 |-- rating: double (nullable = true)



In [18]:
numerator = ratings.select("rating").count()


num_users = ratings.select("userId").distinct().count()
num_items = ratings.select("movieId").distinct().count()

denominator = num_users * num_items

sparsity = (1.0 - (numerator * 1.0)/ denominator) * 100
print("The ratings dataframe is ", "%.2f" % sparsity + "% empty.")

The ratings dataframe is  98.95% empty.


In [19]:
num_users, num_items

(7120, 14026)

In [23]:
(ratings.groupBy("userId").count().filter("`count`  > 1").sort(col("count").desc()).show(n = 20))

+------+-----+
|userId|count|
+------+-----+
|  3907| 2711|
|  2261| 2644|
|   903| 2608|
|  4358| 2575|
|  4222| 2553|
|  3318| 2382|
|   741| 2212|
|  6719| 2206|
|   982| 2183|
|   156| 2179|
|  6636| 2052|
|  4507| 1969|
|   775| 1957|
|  3797| 1938|
|  6373| 1929|
|  3858| 1858|
|   768| 1769|
|  5843| 1765|
|  3284| 1742|
|  4967| 1719|
+------+-----+
only showing top 20 rows



In [25]:
(ratings.groupBy("movieId")
    .count()
    .filter("`count` > 1")
    .sort(col("count").desc())
    .show(n = 20))

+-------+-----+
|movieId|count|
+-------+-----+
|    296| 3498|
|    356| 3476|
|    593| 3247|
|    318| 3216|
|    480| 3129|
|    260| 2874|
|    110| 2799|
|    589| 2711|
|   2571| 2705|
|    527| 2598|
|      1| 2569|
|    457| 2568|
|    780| 2546|
|    150| 2512|
|     50| 2490|
|   1210| 2480|
|   1196| 2418|
|    592| 2406|
|   2858| 2355|
|     32| 2312|
+-------+-----+
only showing top 20 rows



In [26]:
# Avg num ratings per users
print("Avg num ratings per user: ")
ratings.groupBy("userId").count().select(avg("count")).show()

Avg num ratings per user: 
+------------------+
|        avg(count)|
+------------------+
|147.27176966292134|
+------------------+



In [34]:
# Create test and train set
(train, test) = ratings.randomSplit([0.8, 0.2], seed = 100)
print(train.count())
print(test.count())

838907
209668


### ALS

In [30]:
als = ALS(userCol="userId", itemCol="movieId", ratingCol="rating", 
          nonnegative = True, # Non negative matrix factorization
          coldStartStrategy = "drop", # What to do if user do not appear in train and test set
          implicitPrefs = False) # Explicit preference

param_grid = ParamGridBuilder() \
            .addGrid(als.rank, [10, 50, 100]) \
            .addGrid(als.maxIter, [5, 50, 100]) \
            .addGrid(als.regParam, [.01, .05, .1]) \
            .build()

# Define evaluator as RMSE
evaluator = RegressionEvaluator(metricName = "rmse", 
                                labelCol = "rating", 
                                predictionCol = "prediction")

cv = CrossValidator(estimator = als, 
                    estimatorParamMaps = param_grid, 
                    evaluator = evaluator, 
                    numFolds = 5)

In [31]:
# Print length of evaluator
print ("Num models to be tested: ", len(param_grid))

Num models to be tested:  27


In [35]:
als_model = als.fit(train)
test_pred = als_model.transform(test)
test_pred.show(n=10)

+------+-------+------+----------+
|userId|movieId|rating|prediction|
+------+-------+------+----------+
|  5585|    148|   3.0| 2.5450664|
|  3990|    148|   4.0| 2.3091614|
|  5186|    148|   2.0| 2.9203124|
|  5938|    148|   4.0|  3.121954|
|  1716|    148|   2.0| 2.9002583|
|  2671|    148|   3.0| 2.3622625|
|  3576|    148|   2.0| 2.2861638|
|  3335|    148|   5.0|  2.115857|
|  4923|    463|   3.0|  2.838025|
|   156|    463|   4.0| 3.3625484|
+------+-------+------+----------+
only showing top 10 rows



In [37]:
print("RMSE: ",  evaluator.evaluate(test_pred))

RMSE:  0.8250989919432624


In [38]:
# Generate n recommendations for all users
ALS_recommendations = als_model.recommendForAllUsers(numItems = 10) 
ALS_recommendations.show(n = 10)

+------+--------------------+
|userId|     recommendations|
+------+--------------------+
|  1580|[[110603, 5.27052...|
|  4900|[[83359, 5.970939...|
|  5300|[[40697, 5.931905...|
|  6620|[[40697, 5.615771...|
|   471|[[727, 4.839523],...|
|  1591|[[727, 5.8318753]...|
|  4101|[[83359, 5.125425...|
|  1342|[[40697, 5.78063]...|
|  2122|[[3817, 4.7367983...|
|  2142|[[6823, 5.82299],...|
+------+--------------------+
only showing top 10 rows



In [39]:
# Temporary table
ALS_recommendations.registerTempTable("ALS_recs_temp")

clean_recs = spark.sql("""SELECT userId,
                            movieIds_and_ratings.movieId AS movieId,
                            movieIds_and_ratings.rating AS prediction
                        FROM ALS_recs_temp
                        LATERAL VIEW explode(recommendations) exploded_table
                            AS movieIds_and_ratings""")
clean_recs.show()

+------+-------+----------+
|userId|movieId|prediction|
+------+-------+----------+
|  1580| 110603| 5.2705255|
|  1580|  25905|  5.057629|
|  1580|  73413| 4.8326144|
|  1580|    404| 4.6601586|
|  1580|  59295| 4.6551867|
|  1580|   4176| 4.6130996|
|  1580|   7077|  4.549643|
|  1580|  59376| 4.5258656|
|  1580|  72714| 4.5137124|
|  1580|   2538|  4.462394|
|  4900|  83359| 5.9709396|
|  4900|   3817|  5.428588|
|  4900|   6375|  5.392499|
|  4900|  95776| 5.3505826|
|  4900|  94410|    5.3489|
|  4900|  90378|  5.337767|
|  4900|  79987|  5.286213|
|  4900|  51455| 5.2859397|
|  4900|  82931|  5.273943|
|  4900|   7879| 5.2542114|
+------+-------+----------+
only showing top 20 rows



In [46]:
# Recommendations for unread books
new_movies = clean_recs.join(ratings, ["userId", "movieId"], "left").filter(ratings.rating.isNull())

In [47]:
new_movies.show()
print(new_movies.count())

+------+-------+----------+------+
|userId|movieId|prediction|rating|
+------+-------+----------+------+
|    52|   4454|  4.607009|  null|
|    86|  73469| 4.6333914|  null|
|    94|  88570|  4.390179|  null|
|   111|   4454|  4.965559|  null|
|   150|   2675| 4.6041727|  null|
|   156|  83359| 5.4530454|  null|
|   227|  54986|  4.214472|  null|
|   232|  53883|  5.464476|  null|
|   283|   4536|   5.10898|  null|
|   314|  95776|  5.238259|  null|
|   316|  79987|  4.425828|  null|
|   380|   3817| 5.1334486|  null|
|   392|   3817|  5.642511|  null|
|   424|  31337|  4.165175|  null|
|   457|    727| 4.5919275|  null|
|   459| 105943| 4.6700077|  null|
|   462|  79987|  4.887365|  null|
|   483|  95776| 5.9356556|  null|
|   514|  44168| 4.4875383|  null|
|   540|   2675| 4.6302557|  null|
+------+-------+----------+------+
only showing top 20 rows

71017
