In [13]:
import findspark
findspark.init()
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.conf import SparkConf
from pyspark.sql import functions as F
SparkSession.builder.config(conf=SparkConf())
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator



######################
# init spark session #
######################
spark = SparkSession.builder \
    .appName("training-model") \
    .getOrCreate()

In [14]:
movie_ratings = spark.read.json('../data/movies.json')
movie_ratings.show(5)

+---+--------+------+--------------------+-------+--------------------+------+------+------------+---------+--------------------+------+
|age|function|gender|               genre|movieId|                name|number|rating|release_date|timestamp|                 url|userId|
+---+--------+------+--------------------+-------+--------------------+------+------+------------+---------+--------------------+------+
| 60| retired|     M|[Animation, Child...|      1|    Toy Story (1995)| 95076|     4| 01-Jan-1995|887736532|http://us.imdb.co...|   308|
| 60| retired|     M|[Action, Comedy, ...|      4|   Get Shorty (1995)| 95076|     5| 01-Jan-1995|887737890|http://us.imdb.co...|   308|
| 60| retired|     M|[Crime, Drama, Th...|      5|      Copycat (1995)| 95076|     4| 01-Jan-1995|887739608|http://us.imdb.co...|   308|
| 60| retired|     M|     [Drama, Sci-Fi]|      7|Twelve Monkeys (1...| 95076|     4| 01-Jan-1995|887738847|http://us.imdb.co...|   308|
| 60| retired|     M|[Children's, Come...

In [15]:
movie_ratings.printSchema()

root
 |-- age: string (nullable = true)
 |-- function: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- genre: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- movieId: string (nullable = true)
 |-- name: string (nullable = true)
 |-- number: string (nullable = true)
 |-- rating: string (nullable = true)
 |-- release_date: string (nullable = true)
 |-- timestamp: string (nullable = true)
 |-- url: string (nullable = true)
 |-- userId: string (nullable = true)



In [16]:
movie_df = movie_ratings.select(
    F.col("age").cast("int").alias("age"),
    F.col("function").alias("function"),
    F.col("rating").cast("float").alias("rating"),
    F.col("gender").alias("gender"),
    F.col("genre").alias("genre"),
    F.col("movieId").cast("int").alias("movieId"),
    F.col("release_date").alias("release_date"),
    F.col("timestamp").alias("timestamp"),
    F.col("url").alias("url"),
    F.col("number").alias("codeZip"),
    F.col("userId").cast("int").alias("userId"),
)

In [17]:
movie_df.printSchema()

root
 |-- age: integer (nullable = true)
 |-- function: string (nullable = true)
 |-- rating: float (nullable = true)
 |-- gender: string (nullable = true)
 |-- genre: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- movieId: integer (nullable = true)
 |-- release_date: string (nullable = true)
 |-- timestamp: string (nullable = true)
 |-- url: string (nullable = true)
 |-- codeZip: string (nullable = true)
 |-- userId: integer (nullable = true)



<code>
split the data:
</code>

In [18]:
(training, testing) = movie_df.randomSplit([0.8, 0.2])

In [28]:
als = ALS(maxIter=5,
          rank=20,
          regParam=0.1, 
          userCol="userId", 
          itemCol="movieId", 
          ratingCol="rating",
          coldStartStrategy = "drop")

model = als.fit(training)

In [29]:
predictions = model.transform(testing)
evaluator = RegressionEvaluator(metricName="rmse", 
                                labelCol="rating",
                                predictionCol="prediction")
rmse = evaluator.evaluate(predictions)
# predictions.show()
print("Root-mean-square error = " + str(rmse))

Root-mean-square error = 0.921739590707558


In [21]:
model.save("../model/als-model")

In [30]:
def getUsers(movieId,limit):
    df = testing.where(f"movieId = {movieId}")
    return df.select("userId").distinct().limit(limit)

In [31]:
users = getUsers(252,5)
userSubsetRecs = model.recommendForUserSubset(users, 2)

In [32]:
userSubsetRecs.collect()

[Row(userId=580, recommendations=[Row(movieId=613, rating=4.714283466339111), Row(movieId=1664, rating=4.530240535736084)]),
 Row(userId=481, recommendations=[Row(movieId=1463, rating=5.287625312805176), Row(movieId=1643, rating=5.100905895233154)]),
 Row(userId=472, recommendations=[Row(movieId=1463, rating=5.345054626464844), Row(movieId=1169, rating=5.303918361663818)]),
 Row(userId=804, recommendations=[Row(movieId=1463, rating=4.943421363830566), Row(movieId=1449, rating=4.697652339935303)]),
 Row(userId=496, recommendations=[Row(movieId=919, rating=4.379052639007568), Row(movieId=921, rating=4.319170951843262)])]