In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
import os
import pyspark.sql.functions as sql_func

DATA_DIR = "/data/ml-latest"
!ls {DATA_DIR}/tf_idf.parquet

In [4]:
# spark session
from pyspark.sql import SparkSession

spark = (
    SparkSession
    .builder
    .master("local[*]")
    .config("spark.driver.memory", "4g")
    .getOrCreate()
)

In [5]:
tf_idf = spark.read.parquet(os.path.join(DATA_DIR, "tf_idf.parquet")).cache()
tf_idf.show()

+-------+--------------------+--------------------+
|movieId|               title|              tf_idf|
+-------+--------------------+--------------------+
|     35|   Carrington (1995)|(1024,[8,74,189,2...|
|    503| New Age, The (1994)|(1024,[434,769,82...|
|    583|Dear Diary (Caro ...|(1024,[434,741,84...|
|    594|Snow White and th...|(1024,[29,52,60,8...|
|    610|  Heavy Metal (1981)|(1024,[32,93,112,...|
|    614|       Loaded (1994)|(1024,[263,434],[...|
|    761| Phantom, The (1996)|(1024,[43,169,196...|
|    880|Island of Dr. Mor...|(1024,[44,81,219,...|
|   1369|I Can't Sleep (J'...|(1024,[263,434],[...|
|   1519|Broken English (1...|(1024,[434,829],[...|
|   1589|     Cop Land (1997)|(1024,[37,61,164,...|
|   1815|         Eden (1997)|(1024,[434,829],[...|
|   1881|Quest for Camelot...|(1024,[57,165,337...|
|   2080|Lady and the Tram...|(1024,[29,37,83,1...|
|   2324|Life Is Beautiful...|(1024,[3,31,32,45...|
|   2444|24 7: Twenty Four...|(1024,[122,221,43...|
|   2445|At 

In [6]:
#ranking counting
ratings = (
    spark
    .read
    .csv(
        os.path.join(DATA_DIR, "ratings.csv"),
        header=True,
        inferSchema=True
    )
    # только о миллионе. потому что даже на таком объёме обсчёт модели на четырёх ядрах
    # занимает пару часов
    .limit(100000)
    .select("movieId", "userId", "rating")
)

In [13]:
# averege rank by user
user_avg = ratings.groupBy('userId').agg(sql_func.avg("rating").alias("avg_rating_user"))
user_avg.show()

+------+------------------+
|userId|   avg_rating_user|
+------+------------------+
|     1| 4.277777777777778|
|     2|3.3181818181818183|
|     3|               3.1|
|     4|               3.5|
|     5| 4.269230769230769|
|     6|              3.75|
|     7|3.3679245283018866|
|     8|2.9911504424778763|
|     9|3.8511904761904763|
|    10| 4.230769230769231|
|    11| 3.211453744493392|
|    12|3.8548387096774195|
|    13|2.8181818181818183|
|    14|               3.4|
|    15|3.7367256637168142|
|    16| 4.417582417582418|
|    17|3.5833333333333335|
|    18| 4.166666666666667|
|    19|2.2666666666666666|
|    20| 3.846296296296296|
+------+------------------+
only showing top 20 rows



In [58]:
# averege rank by movie
movie_avg = ratings.groupBy('movieID').agg(sql_func.avg("rating").alias("avg_rating_movie"))
movie_avg.show()

+-------+------------------+
|movieID|  avg_rating_movie|
+-------+------------------+
|    110| 3.911290322580645|
|    147| 3.480769230769231|
|    858|4.4511111111111115|
|   1221| 4.248466257668712|
|   1246|3.8317757009345796|
|   1968|3.8440366972477062|
|   2762| 4.015151515151516|
|   2918| 4.008064516129032|
|   2959| 4.192139737991266|
|   4226| 4.043103448275862|
|   4878|3.7548076923076925|
|   5577|3.0416666666666665|
|  33794|3.7288135593220337|
|  54503|3.7804878048780486|
|  58559| 4.082781456953643|
|  59315|3.9402173913043477|
|  68358| 4.172131147540983|
|  69844| 3.769230769230769|
|  73017|3.6444444444444444|
|  81834|               3.9|
+-------+------------------+
only showing top 20 rows



In [None]:
#Joining data
joined_data = (ratings.alias("r")\
        .join(tf_idf.alias("tf-idf"), sql_func.col("tf-idf.movieId") ==  sql_func.col("r.movieId"))\
        .join(user_avg.alias("user"), sql_func.col("user.userId") ==  sql_func.col("r.userId"))\
        .join(movie_avg.alias("movie"), sql_func.col("movie.movieId") ==  sql_func.col("r.movieId"))\
        .select(sql_func.col("r.userId"),sql_func.col("r.movieId"),sql_func.col("r.rating"),\
        sql_func.col("user.avg_rating_user"),sql_func.col("movie.avg_rating_movie"),sql_func.col("tf-idf.tf_idf")))

In [7]:
# building vector
from sklearn.linear_model import ElasticNet
import numpy as np
from pyspark.sql.types import FloatType, ArrayType

def sklearn_lr(spark_x: list, spark_y: list) -> list:

    numpy_x = np.array([vector.toArray() for vector in spark_x])
    numpy_y = np.array(spark_y).reshape(-1, 1)
    lr = ElasticNet().fit(numpy_x, numpy_y)

    return [lr.sparse_coef_.todense().tolist()[0], lr.intercept_.tolist()]

reg_udf = sql_func.udf(sklearn_lr, returnType=ArrayType(ArrayType(FloatType())))

In [77]:
# train _ test _ split

train_data, test_data  = joined_data\
        .select(sql_func.col("userId"),sql_func.col("movieId"),sql_func.col("rating"),\
        list_concat("tf_idf","avg_rating_user","avg_rating_movie").alias("tf_idf")).randomSplit([0.8, 0.2], seed=50)
                         
train_data.cache()
test_data.cache()

Py4JJavaError: An error occurred while calling o1218.cache.
: java.util.concurrent.TimeoutException: Futures timed out after [300 seconds]
	at scala.concurrent.impl.Promise$DefaultPromise.ready(Promise.scala:219)
	at scala.concurrent.impl.Promise$DefaultPromise.result(Promise.scala:223)
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:201)
	at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.doExecuteBroadcast(BroadcastExchangeExec.scala:136)
	at org.apache.spark.sql.execution.InputAdapter.doExecuteBroadcast(WholeStageCodegenExec.scala:367)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeBroadcast$1.apply(SparkPlan.scala:144)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeBroadcast$1.apply(SparkPlan.scala:140)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
	at org.apache.spark.sql.execution.SparkPlan.executeBroadcast(SparkPlan.scala:140)
	at org.apache.spark.sql.execution.joins.BroadcastHashJoinExec.prepareBroadcast(BroadcastHashJoinExec.scala:135)
	at org.apache.spark.sql.execution.joins.BroadcastHashJoinExec.codegenInner(BroadcastHashJoinExec.scala:232)
	at org.apache.spark.sql.execution.joins.BroadcastHashJoinExec.doConsume(BroadcastHashJoinExec.scala:102)
	at org.apache.spark.sql.execution.CodegenSupport$class.consume(WholeStageCodegenExec.scala:181)
	at org.apache.spark.sql.execution.FilterExec.consume(basicPhysicalOperators.scala:85)
	at org.apache.spark.sql.execution.FilterExec.doConsume(basicPhysicalOperators.scala:206)
	at org.apache.spark.sql.execution.CodegenSupport$class.consume(WholeStageCodegenExec.scala:181)
	at org.apache.spark.sql.execution.InputAdapter.consume(WholeStageCodegenExec.scala:354)
	at org.apache.spark.sql.execution.InputAdapter.doProduce(WholeStageCodegenExec.scala:383)
	at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:88)
	at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:83)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
	at org.apache.spark.sql.execution.CodegenSupport$class.produce(WholeStageCodegenExec.scala:83)
	at org.apache.spark.sql.execution.InputAdapter.produce(WholeStageCodegenExec.scala:354)
	at org.apache.spark.sql.execution.FilterExec.doProduce(basicPhysicalOperators.scala:125)
	at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:88)
	at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:83)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
	at org.apache.spark.sql.execution.CodegenSupport$class.produce(WholeStageCodegenExec.scala:83)
	at org.apache.spark.sql.execution.FilterExec.produce(basicPhysicalOperators.scala:85)
	at org.apache.spark.sql.execution.joins.BroadcastHashJoinExec.doProduce(BroadcastHashJoinExec.scala:97)
	at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:88)
	at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:83)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
	at org.apache.spark.sql.execution.CodegenSupport$class.produce(WholeStageCodegenExec.scala:83)
	at org.apache.spark.sql.execution.joins.BroadcastHashJoinExec.produce(BroadcastHashJoinExec.scala:39)
	at org.apache.spark.sql.execution.ProjectExec.doProduce(basicPhysicalOperators.scala:45)
	at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:88)
	at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:83)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
	at org.apache.spark.sql.execution.CodegenSupport$class.produce(WholeStageCodegenExec.scala:83)
	at org.apache.spark.sql.execution.ProjectExec.produce(basicPhysicalOperators.scala:35)
	at org.apache.spark.sql.execution.joins.BroadcastHashJoinExec.doProduce(BroadcastHashJoinExec.scala:97)
	at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:88)
	at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:83)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
	at org.apache.spark.sql.execution.CodegenSupport$class.produce(WholeStageCodegenExec.scala:83)
	at org.apache.spark.sql.execution.joins.BroadcastHashJoinExec.produce(BroadcastHashJoinExec.scala:39)
	at org.apache.spark.sql.execution.ProjectExec.doProduce(basicPhysicalOperators.scala:45)
	at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:88)
	at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:83)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
	at org.apache.spark.sql.execution.CodegenSupport$class.produce(WholeStageCodegenExec.scala:83)
	at org.apache.spark.sql.execution.ProjectExec.produce(basicPhysicalOperators.scala:35)
	at org.apache.spark.sql.execution.joins.BroadcastHashJoinExec.doProduce(BroadcastHashJoinExec.scala:97)
	at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:88)
	at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:83)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
	at org.apache.spark.sql.execution.CodegenSupport$class.produce(WholeStageCodegenExec.scala:83)
	at org.apache.spark.sql.execution.joins.BroadcastHashJoinExec.produce(BroadcastHashJoinExec.scala:39)
	at org.apache.spark.sql.execution.ProjectExec.doProduce(basicPhysicalOperators.scala:45)
	at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:88)
	at org.apache.spark.sql.execution.CodegenSupport$$anonfun$produce$1.apply(WholeStageCodegenExec.scala:83)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
	at org.apache.spark.sql.execution.CodegenSupport$class.produce(WholeStageCodegenExec.scala:83)
	at org.apache.spark.sql.execution.ProjectExec.produce(basicPhysicalOperators.scala:35)
	at org.apache.spark.sql.execution.WholeStageCodegenExec.doCodeGen(WholeStageCodegenExec.scala:524)
	at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:576)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
	at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
	at org.apache.spark.sql.execution.python.EvalPythonExec.doExecute(EvalPythonExec.scala:89)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
	at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
	at org.apache.spark.sql.execution.InputAdapter.inputRDDs(WholeStageCodegenExec.scala:371)
	at org.apache.spark.sql.execution.ProjectExec.inputRDDs(basicPhysicalOperators.scala:41)
	at org.apache.spark.sql.execution.SortExec.inputRDDs(SortExec.scala:121)
	at org.apache.spark.sql.execution.SampleExec.inputRDDs(basicPhysicalOperators.scala:271)
	at org.apache.spark.sql.execution.WholeStageCodegenExec.doExecute(WholeStageCodegenExec.scala:605)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:131)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:127)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:155)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:152)
	at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:127)
	at org.apache.spark.sql.execution.columnar.InMemoryRelation.buildBuffers(InMemoryRelation.scala:97)
	at org.apache.spark.sql.execution.columnar.InMemoryRelation.<init>(InMemoryRelation.scala:92)
	at org.apache.spark.sql.execution.columnar.InMemoryRelation$.apply(InMemoryRelation.scala:42)
	at org.apache.spark.sql.execution.CacheManager$$anonfun$cacheQuery$1.apply(CacheManager.scala:97)
	at org.apache.spark.sql.execution.CacheManager.writeLock(CacheManager.scala:67)
	at org.apache.spark.sql.execution.CacheManager.cacheQuery(CacheManager.scala:91)
	at org.apache.spark.sql.Dataset.persist(Dataset.scala:2902)
	at org.apache.spark.sql.Dataset.cache(Dataset.scala:2912)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:214)
	at java.lang.Thread.run(Thread.java:748)


In [65]:
model_coef = (
    train_data
    .groupBy("userId")
    .agg(
        sql_func.collect_list("tf_idf").alias("x"),
        sql_func.collect_list("rating").alias("y")
    )
    .withColumn("model_coeff", reg_udf("x", "y"))
    .cache()
)

In [67]:
model_coef.count()

1010

In [68]:
model_coef.show()

+------+--------------------+--------------------+--------------------+
|userId|                   x|                   y|         model_coeff|
+------+--------------------+--------------------+--------------------+
|   148|[(1024,[1,2,5,7,8...|[4.0, 4.0, 4.0, 3...|[[0.0, 0.0, 0.0, ...|
|   463|[(1024,[4,45,46,4...|[1.0, 1.0, 4.0, 2...|[[0.0, 0.0, 0.0, ...|
|   471|[(1024,[3,20,27,2...|[3.5, 5.0, 5.0, 4...|[[0.0, 0.0, 0.0, ...|
|   496|[(1024,[1,2,5,7,8...|[5.0, 4.0, 3.0, 5...|[[0.0, 0.0, 0.0, ...|
|   833|[(1024,[23,24,35,...|     [5.0, 4.0, 4.0]|[[0.0, 0.0, 0.0, ...|
|   243|[(1024,[0,5,65,68...|[4.0, 3.0, 3.5, 3...|[[0.008234739, 0....|
|   392|[(1024,[23,24,27,...|[4.0, 4.0, 2.0, 5...|[[0.0, 0.0, 0.0, ...|
|   540|[(1024,[11,12,29,...|[4.0, 5.0, 5.0, 5...|[[0.0, 0.0, 0.0, ...|
|   623|[(1024,[4,45,46,4...|[4.0, 5.0, 5.0, 4...|[[0.0, 0.0, 0.0, ...|
|   737|[(1024,[3,4,13,29...|[1.5, 1.0, 3.5, 4.0]|[[0.0, 0.0, 0.0, ...|
|   858|[(1024,[1,2,5,7,8...|[5.0, 5.0, 5.0, 1...|[[0.0, 0.0, 0.

In [69]:
from pyspark.ml.linalg import SparseVector

def lr_apply(x: SparseVector, lr_coef: list) -> float:
    """
        param x: вектор фич для регрессии
        param lr_coef: 
        return: предсказанное моделью регрессии значение
    """
    return float(np.array(x).dot(np.array(lr_coef[0])) + lr_coef[1][0])

lr_apply_udf = sql_func.udf(lr_apply, returnType=FloatType())

In [70]:
#make prediction func
from pyspark.sql import DataFrame

def get_prediction(data: DataFrame) -> DataFrame:
    return (
        data
        .join(model_coef, "userId")
        .select(
            "userId",
            "rating",
            "movieId",
            "tf_idf", 
            lr_apply_udf("tf_idf", "model_coeff").alias("prediction"))
        .cache()
    )

In [64]:
train_prediction = get_prediction(train_data)
(
    train_prediction.write.mode("overwrite")
    .parquet(os.path.join(DATA_DIR, "train_prediction.parquet"))
)

NameError: name 'get_prediction' is not defined

In [14]:
train_prediction.show()

+------+------+-------+--------------------+----------+
|userId|rating|movieId|              tf_idf|prediction|
+------+------+-------+--------------------+----------+
|   148|   4.0|    260|(1024,[1,2,5,7,8,...|  3.995837|
|   148|   4.0|    589|(1024,[3,9,29,36,...| 4.0452337|
|   148|   3.5|   1196|(1024,[9,10,23,29...|  3.517842|
|   148|   3.5|   1210|(1024,[4,9,23,29,...| 3.5223696|
|   148|   2.5|   1291|(1024,[23,29,51,5...| 2.5620809|
|   148|   4.0|   2028|(1024,[3,4,29,52,...|  3.994406|
|   148|   5.0|   3300|(1024,[27,29,37,4...|  4.906541|
|   148|   1.5|   4246|(1024,[13,44,50,5...| 1.6575054|
|   148|   5.0|   4993|(1024,[4,9,10,13,...|  5.053576|
|   148|   4.5|   5445|(1024,[8,15,44,46...|  4.464959|
|   148|   4.5|   5574|(1024,[17,49,54,5...| 4.4506645|
|   148|   5.0|   5952|(1024,[9,10,13,17...| 4.9269485|
|   148|   3.5|   6373|(1024,[5,29,40,51...|  3.593023|
|   148|   4.5|   7143|(1024,[1,10,13,32...|  4.452602|
|   148|   5.0|   7153|(1024,[3,6,9,17,2...|  4.

In [15]:
# get train predictions
def evaluate_prediction(prediction: DataFrame) -> float:
    return np.sqrt(
        prediction
        .selectExpr("""
            CASE
                WHEN prediction > 5 THEN 5
                WHEN prediction < 0.5 THEN 0.5
                ELSE prediction
            END AS prediction
        """, "rating")
        .select(
            sql_func.pow(sql_func.col("rating") - sql_func.col("prediction"), 2)
            .alias("squared_error")
        )
        .agg(sql_func.avg("squared_error"))
        .first()[0]
    )

In [16]:
evaluate_prediction(train_prediction)

0.6416910375297744

In [17]:
test_prediction = get_prediction(test_data)
(
    test_prediction.write.mode("overwrite")
    .parquet(os.path.join(DATA_DIR, "test_prediction.parquet"))
)

In [18]:
evaluate_prediction(test_prediction)

0.9859142865193938