In [14]:
# findspark.init() is called before any pyspark import below.
# NOTE(review): presumably this is what makes the local Spark installation
# importable in this kernel — confirm against the environment setup.
import findspark
findspark.init()
# NOTE(review): SparkConf is imported twice (here and from pyspark.conf
# below); the second import rebinds the same name.
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.conf import SparkConf
from pyspark.sql import functions as F
# The value returned by this call is discarded.
# NOTE(review): SparkSession.builder looks like a shared builder object, so
# this may still register the default SparkConf options on it — confirm
# whether the line is actually needed.
SparkSession.builder.config(conf=SparkConf())
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator
# NOTE(review): TrainValidationSplit is imported but not used in the cells
# visible here.
from pyspark.ml.tuning import TrainValidationSplit, ParamGridBuilder, CrossValidator



# --- Spark session ---------------------------------------------------
# Get the already-running session if there is one, otherwise start a new
# one registered under the application name "train-model".
spark = (
    SparkSession.builder
    .appName("train-model")
    .getOrCreate()
)

In [15]:
# Load the raw ratings export and preview the first five rows.
ratings_path = '../data/movies.json'
movie_ratings = spark.read.json(ratings_path)
movie_ratings.show(n=5)

+---+--------+------+--------------------+-------+--------------------+------+------+------------+---------+--------------------+------+
|age|function|gender|               genre|movieId|                name|number|rating|release_date|timestamp|                 url|userId|
+---+--------+------+--------------------+-------+--------------------+------+------+------------+---------+--------------------+------+
| 60| retired|     M|[Animation, Child...|      1|    Toy Story (1995)| 95076|     4| 01-Jan-1995|887736532|http://us.imdb.co...|   308|
| 60| retired|     M|[Action, Comedy, ...|      4|   Get Shorty (1995)| 95076|     5| 01-Jan-1995|887737890|http://us.imdb.co...|   308|
| 60| retired|     M|[Crime, Drama, Th...|      5|      Copycat (1995)| 95076|     4| 01-Jan-1995|887739608|http://us.imdb.co...|   308|
| 60| retired|     M|     [Drama, Sci-Fi]|      7|Twelve Monkeys (1...| 95076|     4| 01-Jan-1995|887738847|http://us.imdb.co...|   308|
| 60| retired|     M|[Children's, Come...

In [16]:
# Inspect the schema of the raw frame before any casting is applied.
movie_ratings.printSchema()

root
 |-- age: string (nullable = true)
 |-- function: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- genre: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- movieId: string (nullable = true)
 |-- name: string (nullable = true)
 |-- number: string (nullable = true)
 |-- rating: string (nullable = true)
 |-- release_date: string (nullable = true)
 |-- timestamp: string (nullable = true)
 |-- url: string (nullable = true)
 |-- userId: string (nullable = true)



In [17]:
# Projection plan: (source column, output column, cast type or None).
# Numeric fields used downstream (age, rating, movieId, userId) are cast;
# everything else is passed through as-is. "number" is renamed to "codeZip".
column_plan = [
    ("age", "age", "int"),
    ("function", "function", None),
    ("rating", "rating", "float"),
    ("gender", "gender", None),
    ("genre", "genre", None),
    ("movieId", "movieId", "int"),
    ("release_date", "release_date", None),
    ("timestamp", "timestamp", None),
    ("url", "url", None),
    ("number", "codeZip", None),
    ("userId", "userId", "int"),
]

# Build the typed frame in the order given by the plan above.
movie_df = movie_ratings.select(*[
    (F.col(src).cast(dtype) if dtype is not None else F.col(src)).alias(dst)
    for src, dst, dtype in column_plan
])

In [18]:
# Check the schema of the projected frame built from movie_ratings.
movie_df.printSchema()

root
 |-- age: integer (nullable = true)
 |-- function: string (nullable = true)
 |-- rating: float (nullable = true)
 |-- gender: string (nullable = true)
 |-- genre: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- movieId: integer (nullable = true)
 |-- release_date: string (nullable = true)
 |-- timestamp: string (nullable = true)
 |-- url: string (nullable = true)
 |-- codeZip: string (nullable = true)
 |-- userId: integer (nullable = true)



<code>
Split the data into training (80%) and test (20%) sets:
</code>

In [19]:
# 80/20 train/test split. A fixed seed keeps the split repeatable, so the
# tuning results and the reported test RMSE can be compared across runs of
# the notebook.
training, testing = movie_df.randomSplit([0.8, 0.2], seed=42)

In [21]:
# Tune an ALS recommender with 3-fold cross-validation over rank and
# regularisation strength, then report the winning hyper-parameters.

# Configure the ALS model.
# coldStartStrategy='drop' and nonnegative=True are kept from the original
# configuration; the seed is fixed so repeated runs are comparable.
als = ALS(userCol='userId', itemCol='movieId', ratingCol='rating',
          coldStartStrategy='drop', nonnegative=True, seed=42)

# Search grid: 3 ranks x 1 maxIter x 2 regParams = 6 candidate settings.
param_grid = (
    ParamGridBuilder()
    .addGrid(als.rank, [1, 20, 30])
    .addGrid(als.maxIter, [20])
    .addGrid(als.regParam, [.05, .15])
    .build()
)
evaluator = RegressionEvaluator(metricName='rmse', labelCol='rating', predictionCol='prediction')

# The seed also fixes the fold assignment inside the cross-validator.
cv = CrossValidator(
        estimator=als,
        estimatorParamMaps=param_grid,
        evaluator=evaluator,
        numFolds=3,
        seed=42)

model = cv.fit(training)

best_model = model.bestModel

# avgMetrics is index-aligned with the param grid handed to the
# cross-validator; the metric is RMSE, so the smallest value wins. Reading
# the winning settings from the grid avoids the private `_java_obj` handle.
best_index = min(range(len(param_grid)), key=lambda i: model.avgMetrics[i])
best_params = param_grid[best_index]

print('rank: ', best_model.rank)
print('MaxIter: ', best_params[als.maxIter])
print('RegParam: ', best_params[als.regParam])

rank:  30
MaxIter:  20
RegParam:  0.15


In [8]:
# Score the held-out split with the fitted model and report its RMSE.
evaluator = RegressionEvaluator(
    predictionCol="prediction",
    labelCol="rating",
    metricName="rmse",
)
predictions = model.transform(testing)
rmse = evaluator.evaluate(predictions)

rmse_message = "Root-mean-square error = " + str(rmse)
print(rmse_message)

Root-mean-square error = 0.9924835470959829


In [14]:
# `model` is the cross-validation wrapper returned by cv.fit(); persist the
# tuned ALS model itself (best_model) under "als-model".
# NOTE(review): assumes consumers load this path as an ALSModel — confirm.
# overwrite() lets the notebook be re-run when the directory already exists;
# a plain save() raises in that case.
best_model.write().overwrite().save("als-model")

In [9]:
# Top-100 movie recommendations for every user.
# recommendForAllUsers is a method of the fitted ALS model, not of the
# cross-validation wrapper held in `model`, so call it on best_model.
userRecs = best_model.recommendForAllUsers(100)



In [11]:
# Preview the recommendations for the first five users.
userRecs.show(5)

+------+--------------------+
|userId|     recommendations|
+------+--------------------+
|   540|[{1554, 9.537097}...|
|   580|[{1104, 9.678745}...|
|   530|[{1643, 7.4660716...|
|   830|[{1242, 11.867263...|
|   210|[{1554, 7.995552}...|
+------+--------------------+
only showing top 5 rows

