# Spark Practical Work

Authors:
 - Ahajjan Ziggaf Kanjaa, Mohammed
 - Labchiri Boukhalef, Younes
 - Ramírez Castaño, Víctor

### Data loading

In [64]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import (col, sum)
from pyspark.ml import Pipeline
from pyspark.ml.feature import (
    SQLTransformer,
    StringIndexer, 
    OneHotEncoder, 
    VectorAssembler, 
    StandardScaler
)
from pyspark.ml.regression import (
    DecisionTreeRegressor,
    RandomForestRegressor,
    GBTRegressor
)
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import RegressionEvaluator
import gc

# Local Spark session for this notebook; the driver is given 8g of memory.
spark = (
    SparkSession.builder
    .appName("FlightModelPrediction")
    .config("spark.driver.memory", "8g")
    .getOrCreate()
)

# Raw flight records; the literal string "NA" in the CSV is read as NULL.
data_path = "../training_data/flight_data/1988.csv"
df = spark.read.csv(data_path, header=True, inferSchema=True, nullValue="NA")

                                                                                

### Exploratory data analysis (EDA)

In [65]:
df.printSchema()

# String-typed columns inspected by their most frequent values
cat_cols = ["UniqueCarrier", "Origin", "Dest", "TailNum", "CancellationCode"]

# Summary statistics for every column
df.describe().show()

# Five most frequent values of each categorical column
for column_name in cat_cols:
    print(f"\nColumn: {column_name}")
    value_counts = df.groupBy(column_name).count()
    value_counts.orderBy("count", ascending=False).show(5)

# Number of NULLs per column (note: `sum` here is pyspark.sql.functions.sum)
null_count_exprs = [
    sum(col(name).isNull().cast("int")).alias(name)
    for name in df.columns
]
df.select(null_count_exprs).show(truncate=False)

root
 |-- Year: integer (nullable = true)
 |-- Month: integer (nullable = true)
 |-- DayofMonth: integer (nullable = true)
 |-- DayOfWeek: integer (nullable = true)
 |-- DepTime: integer (nullable = true)
 |-- CRSDepTime: integer (nullable = true)
 |-- ArrTime: integer (nullable = true)
 |-- CRSArrTime: integer (nullable = true)
 |-- UniqueCarrier: string (nullable = true)
 |-- FlightNum: integer (nullable = true)
 |-- TailNum: string (nullable = true)
 |-- ActualElapsedTime: integer (nullable = true)
 |-- CRSElapsedTime: integer (nullable = true)
 |-- AirTime: string (nullable = true)
 |-- ArrDelay: integer (nullable = true)
 |-- DepDelay: integer (nullable = true)
 |-- Origin: string (nullable = true)
 |-- Dest: string (nullable = true)
 |-- Distance: integer (nullable = true)
 |-- TaxiIn: string (nullable = true)
 |-- TaxiOut: string (nullable = true)
 |-- Cancelled: integer (nullable = true)
 |-- CancellationCode: string (nullable = true)
 |-- Diverted: integer (nullable = true)
 |

                                                                                

+-------+-------+------------------+-----------------+------------------+------------------+------------------+-----------------+------------------+-------------+-----------------+-------+------------------+------------------+-------+------------------+------------------+-------+-------+------------------+------+-------+--------------------+----------------+--------------------+------------+------------+--------+-------------+-----------------+
|summary|   Year|             Month|       DayofMonth|         DayOfWeek|           DepTime|        CRSDepTime|          ArrTime|        CRSArrTime|UniqueCarrier|        FlightNum|TailNum| ActualElapsedTime|    CRSElapsedTime|AirTime|          ArrDelay|          DepDelay| Origin|   Dest|          Distance|TaxiIn|TaxiOut|           Cancelled|CancellationCode|            Diverted|CarrierDelay|WeatherDelay|NASDelay|SecurityDelay|LateAircraftDelay|
+-------+-------+------------------+-----------------+------------------+------------------+----------

                                                                                

+-------------+------+
|UniqueCarrier| count|
+-------------+------+
|           DL|753983|
|           AA|694757|
|           UA|587144|
|           US|494383|
|           PI|470957|
+-------------+------+
only showing top 5 rows

Column: Origin


                                                                                

+------+------+
|Origin| count|
+------+------+
|   ORD|271494|
|   ATL|259731|
|   DFW|216849|
|   LAX|169696|
|   DEN|161146|
+------+------+
only showing top 5 rows

Column: Dest


                                                                                

+----+------+
|Dest| count|
+----+------+
| ORD|274766|
| ATL|260573|
| DFW|220266|
| LAX|169699|
| DEN|163598|
+----+------+
only showing top 5 rows

Column: TailNum


                                                                                

+-------+-------+
|TailNum|  count|
+-------+-------+
|   NULL|5202096|
+-------+-------+


Column: CancellationCode


                                                                                

+----------------+-------+
|CancellationCode|  count|
+----------------+-------+
|            NULL|5202096|
+----------------+-------+





+----+-----+----------+---------+-------+----------+-------+----------+-------------+---------+-------+-----------------+--------------+-------+--------+--------+------+----+--------+-------+-------+---------+----------------+--------+------------+------------+--------+-------------+-----------------+
|Year|Month|DayofMonth|DayOfWeek|DepTime|CRSDepTime|ArrTime|CRSArrTime|UniqueCarrier|FlightNum|TailNum|ActualElapsedTime|CRSElapsedTime|AirTime|ArrDelay|DepDelay|Origin|Dest|Distance|TaxiIn |TaxiOut|Cancelled|CancellationCode|Diverted|CarrierDelay|WeatherDelay|NASDelay|SecurityDelay|LateAircraftDelay|
+----+-----+----------+---------+-------+----------+-------+----------+-------------+---------+-------+-----------------+--------------+-------+--------+--------+------+----+--------+-------+-------+---------+----------------+--------+------------+------------+--------+-------------+-----------------+
|0   |0    |0         |0        |50163  |0         |64599  |0         |0            |0     

                                                                                

In [66]:
# 80/20 train/test split with a fixed seed
split_weights = [0.8, 0.2]
train_data, test_data = df.randomSplit(weights=split_weights, seed=89)

In [67]:
# Ordered list of pipeline stages. The cells below append to this list, so
# re-run this cell first whenever any of them is re-executed; otherwise the
# same stage ends up in the list twice.
stages = list()

### Data filtering

In [None]:
# First pipeline stage: row filtering.
# Rows with a NULL in any column listed in the WHERE clause are dropped rather
# than imputed, and cancelled flights (Cancelled = 1) are removed.
# TaxiOut is the only column altered here: it is cast to INT and its NULLs are
# replaced by 0, so later stages never see a NULL TaxiOut.
sql_logic = """
SELECT Year,
       Month,
       DayofMonth,
       DayOfWeek,
       DepTime,
       CRSDepTime,
       ArrTime,
       CRSArrTime,
       UniqueCarrier,
       FlightNum,
       TailNum,
       ActualElapsedTime,
       CRSElapsedTime,
       AirTime,
       ArrDelay,
       DepDelay,
       Origin,
       Dest,
       Distance,
       TaxiIn,
       Cancelled,
       CancellationCode,
       Diverted,
       CarrierDelay,
       WeatherDelay,
       NASDelay,
       SecurityDelay,
       LateAircraftDelay,
       COALESCE(CAST(TaxiOut AS INT), 0) AS TaxiOut
FROM __THIS__
WHERE CRSDepTime IS NOT NULL
  AND CRSArrTime IS NOT NULL
  AND ArrDelay IS NOT NULL
  AND Year IS NOT NULL
  AND Month IS NOT NULL
  AND DayofMonth IS NOT NULL
  AND DayOfWeek IS NOT NULL
  AND UniqueCarrier IS NOT NULL
  AND CRSElapsedTime IS NOT NULL
  AND DepDelay IS NOT NULL
  AND Origin IS NOT NULL
  AND Dest IS NOT NULL
  AND Distance IS NOT NULL
  AND Cancelled != 1 -- A cancelled flight is not considered a delay, so that it does not give us useful information.
"""
# SQLTransformer substitutes __THIS__ with the DataFrame it is applied to.
sql_clean = SQLTransformer(statement=sql_logic)

stages.append(sql_clean)

### Feature engineering and selection
The following SQL query fills the missing values of ``TaxiOut`` with its median and creates two new variables: ``TakeOffTime``, the moment the plane leaves the ground, and ``LandingEst``, an updated estimate of the landing time that takes into account when the plane actually took off.

In [69]:
# Second pipeline stage: feature engineering.
# Keeps only the columns known at take-off time (plus the ArrDelay label) and
# derives CRSDepTimestamp, TakeOffTime and LandingEst.
#
# Changes with respect to the previous version of this query:
#  * A scheduled departure of 2400 is mapped to 00:00 of the *next* day by
#    adding INTERVAL 1 DAY to a valid timestamp. Building the date from
#    DayofMonth + 1 produced an invalid date on the last day of a month.
#  * TaxiOut keeps its per-row value and only falls back to the median (then 0)
#    when it is NULL, instead of overwriting every row with the median.
#  * The filled value is computed once in an inner query under its own name
#    (TaxiOutFilled). Previously the name TaxiOut inside the TakeOffTime /
#    LandingEst expressions was ambiguous between the input column and the
#    SELECT alias of the same name.
#
# NOTE: the filter stage earlier in this notebook already replaces NULL TaxiOut
# with 0, so the median fallback only takes effect if that stage changes.
# NOTE: SQLTransformer runs this statement on whatever dataset is being
# transformed, so the median is taken over that dataset.
sql_logic = """
SELECT
    Year, Month, DayofMonth, DayOfWeek, DepTime, CRSDepTime, CRSArrTime, UniqueCarrier, CRSElapsedTime, ArrDelay, DepDelay, Origin, Dest, Distance,
    TaxiOutFilled AS TaxiOut,
    CRSDepTimestamp,

    -- TakeOffTime (HHmm): scheduled departure + departure delay + taxi-out
    CAST(
        date_format(
            CRSDepTimestamp
            + (CAST(DepDelay AS INT)
            + TaxiOutFilled) * INTERVAL 1 MINUTE,
            'HHmm'
        ) AS INT
    ) AS TakeOffTime,

    -- LandingEst (HHmm): take-off time + scheduled elapsed time
    CAST(
        date_format(
            CRSDepTimestamp
            + (CAST(DepDelay AS INT)
            + TaxiOutFilled
            + CAST(CRSElapsedTime AS INT)) * INTERVAL 1 MINUTE,
            'HHmm'
        ) AS INT
    ) AS LandingEst

FROM (
    SELECT
        *,

        -- Per-row TaxiOut; median of the dataset (then 0) only when it is NULL
        COALESCE(
            CAST(TaxiOut AS INT),
            (SELECT percentile_approx(CAST(TaxiOut AS INT), 0.5) FROM __THIS__),
            0
        ) AS TaxiOutFilled,

        -- Scheduled departure as a timestamp
        CASE
            WHEN CRSDepTime = 2400 THEN
                to_timestamp(
                    concat(
                        Year,
                        lpad(Month, 2, '0'),
                        lpad(DayofMonth, 2, '0'),
                        '0000'
                    ),
                    'yyyyMMddHHmm'
                ) + INTERVAL 1 DAY
            ELSE
                to_timestamp(
                    concat(
                        Year,
                        lpad(Month, 2, '0'),
                        lpad(DayofMonth, 2, '0'),
                        lpad(CRSDepTime, 4, '0')
                    ),
                    'yyyyMMddHHmm'
                )
        END AS CRSDepTimestamp
    FROM __THIS__
) AS with_departure
"""

sql_trans = SQLTransformer(statement=sql_logic)

stages.append(sql_trans)

Because the ``HHmm`` format is not very useful for machine learning, we split each variable that uses it into two separate variables: one for the hour and one for the minutes.

In [70]:
# Third pipeline stage: split every HHmm-encoded integer into an hour column
# (value / 100, floored) and a minute column (value % 100).
# DepTime falls back to CRSDepTime when it is NULL.
time_split_sql = """
SELECT
    *,
    
    -- CRSDepTime
    CAST(FLOOR(CRSDepTime / 100) AS INT)      AS CRSDepTimeHour,
    CAST(CRSDepTime % 100 AS INT)             AS CRSDepTimeMinute,

    -- CRSArrTime
    CAST(FLOOR(CRSArrTime / 100) AS INT)      AS CRSArrTimeHour,
    CAST(CRSArrTime % 100 AS INT)             AS CRSArrTimeMinute,

    -- TakeOffTime
    CAST(FLOOR(TakeOffTime / 100) AS INT)     AS TakeOffTimeHour,
    CAST(TakeOffTime % 100 AS INT)            AS TakeOffTimeMinute,

    -- LandingEst
    CAST(FLOOR(LandingEst / 100) AS INT)      AS LandingEstHour,
    CAST(LandingEst % 100 AS INT)             AS LandingEstMinute,

    -- DepTime
    CAST(FLOOR(COALESCE(DepTime, CRSDepTime) / 100) AS INT) AS DepTimeHour,
    CAST((COALESCE(DepTime, CRSDepTime) % 100) AS INT) AS DepTimeMinute


FROM __THIS__
"""

time_splitter = SQLTransformer(statement=time_split_sql)

stages.append(time_splitter)

We drop all the ``HHmm`` variables, as well as the ``CRSDepTimestamp`` since it is just a variable used for creating others.

In [71]:
# Fourth pipeline stage: remove the original HHmm columns and the helper
# timestamp now that their hour/minute parts exist as separate columns.
drop_hhmm_sql = """
SELECT
    * EXCEPT (
        CRSDepTime,
        CRSArrTime,
        TakeOffTime,
        LandingEst,
        DepTime,
        CRSDepTimestamp
    )
FROM __THIS__
"""

drop_hhmm = SQLTransformer(statement=drop_hhmm_sql)
stages.append(drop_hhmm)

In [72]:
# Columns handled as categories: indexed, then one-hot encoded
cat_cols = ["UniqueCarrier", "Origin", "Dest", "Month", "DayOfWeek", "DayofMonth", "Year"]

# Columns handled as numbers: assembled into one vector, then standardised
num_cols = [
    "DepDelay", "TaxiOut", "Distance", "CRSElapsedTime", 
    "LandingEstMinute", "TakeOffTimeMinute", "CRSDepTimeMinute", "CRSArrTimeMinute", "DepTimeMinute",
    "LandingEstHour", "TakeOffTimeHour", "CRSDepTimeHour", "CRSArrTimeHour", "DepTimeHour"
]

# --- Categorical encoding ---
# One (StringIndexer, OneHotEncoder) pair per column, kept in column order.
# handleInvalid="keep" gives labels unseen during fitting their own index.
input_cols_ohe = [f"{c}_ohe" for c in cat_cols]
categorical_stages = []
for c in cat_cols:
    idx_col = f"{c}_idx"
    categorical_stages.append(
        StringIndexer(inputCol=c, outputCol=idx_col, handleInvalid="keep")
    )
    categorical_stages.append(
        OneHotEncoder(inputCol=idx_col, outputCol=f"{c}_ohe")
    )

stages.extend(categorical_stages)

# --- Numerical processing ---
# Pack the numeric columns into a single vector; handleInvalid="skip" drops
# every row that has a NULL/NaN in any of them.
num_assembler = VectorAssembler(
    inputCols=num_cols,
    outputCol="num_features_raw",
    handleInvalid="skip",
)
stages.append(num_assembler)

# Standardise: centre on the mean and scale to unit standard deviation
scaler = StandardScaler(
    inputCol="num_features_raw",
    outputCol="num_features_scaled",
    withMean=True,
    withStd=True,
)
stages.append(scaler)

# --- Final feature vector: one-hot blocks followed by the scaled numerics ---
assembler_all_inputs = input_cols_ohe + ["num_features_scaled"]
assembler_all = VectorAssembler(inputCols=assembler_all_inputs, outputCol="features")
stages.append(assembler_all)


### Models to train

In [73]:

# Candidate regressors, all predicting ArrDelay from the assembled "features"
# vector. Each one gets an explicit seed (the same value used for the
# train/test split and the cross validators) so that repeated runs of the
# notebook compare the models under the same random choices.
dt = DecisionTreeRegressor(labelCol="ArrDelay", featuresCol="features", seed=89)
rf = RandomForestRegressor(labelCol="ArrDelay", featuresCol="features", seed=89)
gbt = GBTRegressor(labelCol="ArrDelay", featuresCol="features", seed=89)

# One pipeline per model: the shared preprocessing stages followed by the regressor
pipeline_dt = Pipeline(stages=stages + [dt])
pipeline_rf = Pipeline(stages=stages + [rf])
pipeline_gbt = Pipeline(stages=stages + [gbt])





### Hyperparameter tuning

In [74]:
# Hyperparameter grids.
# Every parameter currently lists a single value to keep memory usage low, so
# cross validation evaluates exactly one configuration per model. Add more
# values to the lists below if the machine has enough RAM.
dt_grid_builder = ParamGridBuilder()
dt_grid_builder = dt_grid_builder.addGrid(dt.maxDepth, [5])
dt_grid_builder = dt_grid_builder.addGrid(dt.maxBins, [32])
paramGrid_dt = dt_grid_builder.build()

rf_grid_builder = ParamGridBuilder()
rf_grid_builder = rf_grid_builder.addGrid(rf.numTrees, [20])
rf_grid_builder = rf_grid_builder.addGrid(rf.maxDepth, [5])
paramGrid_rf = rf_grid_builder.build()

gbt_grid_builder = ParamGridBuilder()
gbt_grid_builder = gbt_grid_builder.addGrid(gbt.maxIter, [10])
gbt_grid_builder = gbt_grid_builder.addGrid(gbt.maxDepth, [3])
paramGrid_gbt = gbt_grid_builder.build()

### Cross validation

In [75]:

# RMSE on ArrDelay is the metric cross validation uses to compare parameter maps
evaluador = RegressionEvaluator(labelCol="ArrDelay", predictionCol="prediction", metricName="rmse")

# Settings shared by the three cross validators: 3 folds, fixed seed, and
# models fitted one at a time (parallelism=1)
cv_common = dict(evaluator=evaluador, numFolds=3, seed=89, parallelism=1)

cv_dt = CrossValidator(estimator=pipeline_dt, estimatorParamMaps=paramGrid_dt, **cv_common)
cv_rf = CrossValidator(estimator=pipeline_rf, estimatorParamMaps=paramGrid_rf, **cv_common)
cv_gbt = CrossValidator(estimator=pipeline_gbt, estimatorParamMaps=paramGrid_gbt, **cv_common)


### Model evaluation

In [76]:

# Cross validators to fit, paired with a display name
modelos_cv = [("Decision Tree", cv_dt), ("Random Forest", cv_rf), ("GBT", cv_gbt)]
resultados = []

# Second evaluator: mean absolute error on the same label/prediction columns
eval_mae = RegressionEvaluator(labelCol="ArrDelay", predictionCol="prediction", metricName="mae")

for nombre, cv in modelos_cv:
    print(f"Initiating training for {nombre}...")

    # Drop cached data and run the Python garbage collector before each fit
    # to free memory left over from the previous model
    spark.catalog.clearCache()
    gc.collect()

    # Training
    cv_model = cv.fit(train_data)

    # Testing
    predicciones = cv_model.transform(test_data)

    # Both evaluators only need the label and the prediction; caching these
    # two columns avoids running the whole pipeline on the test set twice.
    pred_eval = predicciones.select("ArrDelay", "prediction").cache()
    try:
        rmse = evaluador.evaluate(pred_eval)
        mae = eval_mae.evaluate(pred_eval)
    finally:
        pred_eval.unpersist()

    print(f"Results {nombre} -> RMSE: {rmse:.2f}, MAE: {mae:.2f}")

    resultados.append({
        "nombre": nombre,
        "modelo_fit": cv_model,
        "rmse": rmse,
        "mae": mae,
    })

# Choose ONE metric for every model so the scores are comparable with each
# other. Previously the metric was picked per model, which allowed the MAE of
# one model to be compared with the RMSE of another (and MAE is never larger
# than RMSE). MAE is used for all models as soon as any of them shows
# RMSE > 2 * MAE; otherwise RMSE is used.
usar_mae = any(r["rmse"] > (2 * r["mae"]) for r in resultados)
metrica_usada = "MAE" if usar_mae else "RMSE"

for r in resultados:
    r["score_final"] = r["mae"] if usar_mae else r["rmse"]
    r["metrica"] = metrica_usada
    print(f"Score {r['nombre']}: {r['score_final']:.2f} using {metrica_usada}")

Initiating training for Decision Tree...


26/01/06 16:27:00 ERROR Executor: Exception in task 0.0 in stage 336.0 (TID 1931)
java.lang.IllegalArgumentException: requirement failed: Nothing has been added to this summarizer.
	at scala.Predef$.require(Predef.scala:337)
	at org.apache.spark.ml.stat.SummarizerBuffer.mean(Summarizer.scala:650)
	at org.apache.spark.ml.stat.SummaryBuilderImpl$MetricsAggregate.$anonfun$eval$1(Summarizer.scala:389)
	at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:75)
	at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:35)
	at org.apache.spark.ml.stat.SummaryBuilderImpl$MetricsAggregate.eval(Summarizer.scala:388)
	at org.apache.spark.ml.stat.SummaryBuilderImpl$MetricsAggregate.eval(Summarizer.scala:355)
	at org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.eval(interfaces.scala:594)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator.$anonfun$generateResultProjection$5(AggregationIterator.scala:260)
	at org.apache.spark.sql.execution.aggregate.Obje

Py4JJavaError: An error occurred while calling o10611.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 336.0 failed 1 times, most recent failure: Lost task 0.0 in stage 336.0 (TID 1931) (192.168.1.108 executor driver): java.lang.IllegalArgumentException: requirement failed: Nothing has been added to this summarizer.
	at scala.Predef$.require(Predef.scala:337)
	at org.apache.spark.ml.stat.SummarizerBuffer.mean(Summarizer.scala:650)
	at org.apache.spark.ml.stat.SummaryBuilderImpl$MetricsAggregate.$anonfun$eval$1(Summarizer.scala:389)
	at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:75)
	at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:35)
	at org.apache.spark.ml.stat.SummaryBuilderImpl$MetricsAggregate.eval(Summarizer.scala:388)
	at org.apache.spark.ml.stat.SummaryBuilderImpl$MetricsAggregate.eval(Summarizer.scala:355)
	at org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.eval(interfaces.scala:594)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator.$anonfun$generateResultProjection$5(AggregationIterator.scala:260)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.next(ObjectAggregationIterator.scala:100)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.next(ObjectAggregationIterator.scala:35)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:50)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:402)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:901)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:901)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:374)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:338)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:171)
	at org.apache.spark.scheduler.Task.run(Task.scala:147)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$5(Executor.scala:647)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:80)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:77)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:99)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:650)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
	at java.base/java.lang.Thread.run(Thread.java:1583)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$3(DAGScheduler.scala:2935)
	at scala.Option.getOrElse(Option.scala:201)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2935)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2927)
	at scala.collection.immutable.List.foreach(List.scala:334)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2927)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1295)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1295)
	at scala.Option.foreach(Option.scala:437)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1295)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3207)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3141)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3130)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:50)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:1009)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2484)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2505)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2524)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2549)
	at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1057)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:417)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:1056)
	at org.apache.spark.sql.execution.SparkPlan.executeCollect(SparkPlan.scala:462)
	at org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec.$anonfun$executeCollect$1(AdaptiveSparkPlanExec.scala:402)
	at org.apache.spark.sql.execution.adaptive.ResultQueryStageExec.$anonfun$doMaterialize$1(QueryStageExec.scala:325)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$4(SQLExecution.scala:322)
	at org.apache.spark.sql.execution.SQLExecution$.withSessionTagsApplied(SQLExecution.scala:272)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$3(SQLExecution.scala:320)
	at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$2(SQLExecution.scala:316)
	at java.base/java.util.concurrent.CompletableFuture$AsyncSupply.run(CompletableFuture.java:1768)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
	at java.base/java.lang.Thread.run(Thread.java:1583)
Caused by: java.lang.IllegalArgumentException: requirement failed: Nothing has been added to this summarizer.
	at scala.Predef$.require(Predef.scala:337)
	at org.apache.spark.ml.stat.SummarizerBuffer.mean(Summarizer.scala:650)
	at org.apache.spark.ml.stat.SummaryBuilderImpl$MetricsAggregate.$anonfun$eval$1(Summarizer.scala:389)
	at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:75)
	at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:35)
	at org.apache.spark.ml.stat.SummaryBuilderImpl$MetricsAggregate.eval(Summarizer.scala:388)
	at org.apache.spark.ml.stat.SummaryBuilderImpl$MetricsAggregate.eval(Summarizer.scala:355)
	at org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate.eval(interfaces.scala:594)
	at org.apache.spark.sql.execution.aggregate.AggregationIterator.$anonfun$generateResultProjection$5(AggregationIterator.scala:260)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.next(ObjectAggregationIterator.scala:100)
	at org.apache.spark.sql.execution.aggregate.ObjectAggregationIterator.next(ObjectAggregationIterator.scala:35)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:50)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:402)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:901)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:901)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:374)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:338)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:171)
	at org.apache.spark.scheduler.Task.run(Task.scala:147)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$5(Executor.scala:647)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:80)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:77)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:99)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:650)
	... 3 more


### Find and save best model

In [None]:

# Nothing can be selected if the evaluation cell did not finish for at least
# one model (e.g. training raised an error); fail with an explicit message
# instead of the ValueError that min() raises on an empty list.
if not resultados:
    raise RuntimeError(
        "No evaluated models found: run the model evaluation cell to completion "
        "before selecting and saving the best model."
    )

# Lowest score wins (both RMSE and MAE are error measures)
mejor_resultado = min(resultados, key=lambda x: x["score_final"])

print("-" * 30)
print(f"Best model: {mejor_resultado['nombre']}")
print(f"Criteria: {mejor_resultado['metrica']} of {mejor_resultado['score_final']:.4f}")
print("-" * 30)

# Retrieve the best fitted pipeline found by the winning cross validator
mejor_modelo_final = mejor_resultado['modelo_fit'].bestModel

# Save the model for the app, replacing any model previously saved at this path
path_guardado = "best_model"
mejor_modelo_final.write().overwrite().save(path_guardado)

print(f"The model '{mejor_resultado['nombre']}' has been saved at folder '{path_guardado}'")