数据读取

In [34]:
from pyspark.sql import SparkSession
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS

spark = SparkSession.builder.appName("ALS_demo").getOrCreate()

# 读取数据，并进行类型转换
ratings = spark.read.csv('hdfs://localhost:9000/input/ml_data/ratings.csv', header=True)
ratings = ratings.withColumn("userId", ratings.userId.astype("int"))
ratings = ratings.withColumn("movieId", ratings.userId.astype("int"))
ratings = ratings.withColumn("rating", ratings.userId.astype("float"))
ratings = ratings.withColumn("timestamp", ratings.userId.astype("int"))

(training, test) = ratings.randomSplit([0.8, 0.2])
print("Read successfully!")

Read successfully!


ALS模型的构建与学习

In [35]:
als = ALS(
    rank=20,
    maxIter=10,
    regParam=0.1,
    userCol='userId',
    itemCol='movieId',
    ratingCol='rating',
    coldStartStrategy="drop",
    seed=0
)
als_model = als.fit(training)

print("Train successfully!")

Train successfully!


模型的预测与评估

In [36]:
predictions = als_model.transform(test)
evaluator = RegressionEvaluator(
    metricName="rmse",
    labelCol="rating",
    predictionCol="prediction"
)
rmse = evaluator.evaluate(predictions)

print("Predict successfully!")
print("RMS error = " + str(rmse))

Predict successfully!
RMS error = 0.01129221602343152


模型输出

In [39]:
import pandas as pd
users = ratings.select(als.getUserCol()).distinct().limit(3)
movies = ratings.select(als.getItemCol()).distinct().limit(3)
user_output = als_model.recommendForUserSubset(users, 2)
movie_output = als_model.recommendForItemSubset(movies, 2)

# print(user_output.show(3, False))
# print(movie_output.show(3, False))

# user_recs = user_output.toPandas().to_json(orient = 'records')
# movie_recs = movie_output.toPandas().to_json(orient = 'records')

# movies = spark.read.csv('hdfs://localhost:9000/input/ml_data/movies.csv', header=True).toPandas()

user_recs = user_output.toPandas()
print(user_recs)

# other process



   userId                                    recommendations
0     471    [(257725, 203451.578125), (278380, 203028.875)]
1     463  [(281564, 192630.09375), (256519, 189882.484375)]
2     148  [(267833, 203874.484375), (275326, 195999.15625)]
Empty DataFrame
Columns: [movieId, title, genres]
Index: []


KeyError: 1