In [None]:
!pip install pyspark

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count, when, lit
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline

In [2]:
spark = SparkSession.builder.getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/07/12 15:59:01 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
df = spark.read.parquet(
    "/kaggle/input/fda-cleaned-data/part-00000-f7897fa7-5c35-48b0-808e-96b66af99d3b-c000.snappy.parquet",
    "/kaggle/input/fda-cleaned-data/part-00001-f7897fa7-5c35-48b0-808e-96b66af99d3b-c000.snappy.parquet"
)

                                                                                

In [4]:
# --- Set top-K values ---
top_k_reactions = 100
top_k_drugs = 50

# --- Get top K reactions ---
top_reactions = df.groupBy("reaction").count().orderBy("count", ascending=False).limit(top_k_reactions)
top_reaction_values = [row["reaction"] for row in top_reactions.collect()]

# Replace rare reactions with "other"
df = df.withColumn(
    "reaction_filtered",
    when(col("reaction").isin(top_reaction_values), col("reaction")).otherwise(lit("other"))
)

# --- Get top K drugs ---
top_drugs = df.groupBy("drug").count().orderBy("count", ascending=False).limit(top_k_drugs)
top_drug_values = [row["drug"] for row in top_drugs.collect()]

# Replace rare drugs with "other"
df = df.withColumn(
    "drug_filtered",
    when(col("drug").isin(top_drug_values), col("drug")).otherwise(lit("other"))
)

                                                                                

In [16]:
df.toParquet()

+----+---+-------+--------+----------------+--------------------+---------+-----------------+-------------+
| age|sex|country|reaction|reaction_outcome|                drug|age_group|reaction_filtered|drug_filtered|
+----+---+-------+--------+----------------+--------------------+---------+-----------------+-------------+
|77.0|2.0|     US|Delirium|             1.0|            NUPLAZID|    71-80|            other|        other|
|77.0|2.0|     US|Delirium|             1.0|             ASPIRIN|    71-80|            other|      ASPIRIN|
|77.0|2.0|     US|Delirium|             1.0|        ATORVASTATIN|    71-80|            other|        other|
|77.0|2.0|     US|Delirium|             1.0|            BACLOFEN|    71-80|            other|        other|
|77.0|2.0|     US|Delirium|             1.0|          BUDESONIDE|    71-80|            other|   BUDESONIDE|
|77.0|2.0|     US|Delirium|             1.0|  CARBIDOPA\LEVODOPA|    71-80|            other|        other|
|77.0|2.0|     US|Delirium| 

In [15]:
categorical_cols = ["country", "reaction_filtered", "drug_filtered", "age_group"]


# Create StringIndexers for each
indexers = [
    StringIndexer(inputCol=col, outputCol=f"{col}_indexed", handleInvalid="keep")
    for col in categorical_cols
]

# Fit and transform using a pipeline
pipeline = Pipeline(stages=indexers)
df_indexed = pipeline.fit(df).transform(df_indexed if 'df_indexed' in locals() else df)

                                                                                

In [7]:
df_indexed = df_indexed.withColumnRenamed("reaction_outcome", "label")

In [7]:
from pyspark.ml.feature import VectorAssembler

feature_cols = ["age", "sex"] + [f"{col}_indexed" for col in categorical_cols]
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")

ml_df = assembler.transform(df_indexed).select("features", "label")

In [8]:
ml_df.show(10)

+--------------------+-----+
|            features|label|
+--------------------+-----+
|[77.0,2.0,1.0,0.0...|  1.0|
|[77.0,2.0,1.0,0.0...|  1.0|
|[77.0,2.0,1.0,0.0...|  1.0|
|[77.0,2.0,1.0,0.0...|  1.0|
|[77.0,2.0,1.0,0.0...|  1.0|
|[77.0,2.0,1.0,0.0...|  1.0|
|[77.0,2.0,1.0,0.0...|  1.0|
|[77.0,2.0,1.0,0.0...|  1.0|
|[77.0,2.0,1.0,0.0...|  1.0|
|[77.0,2.0,1.0,0.0...|  1.0|
+--------------------+-----+
only showing top 10 rows



In [9]:
ml_df.groupBy("label").count().orderBy("count", ascending=False).show()



+-----+--------+
|label|   count|
+-----+--------+
|  5.0|12891444|
|  6.0| 7636419|
|  3.0| 5799012|
|  1.0| 2013425|
|  2.0|  758002|
|  4.0|   52808|
+-----+--------+



                                                                                

In [10]:
train_df, test_df = ml_df.randomSplit([0.8, 0.2], seed=42)
print(f"Train count: {train_df.count()}, Test count: {test_df.count()}")



Train count: 23324234, Test count: 5826876


                                                                                

In [11]:
from pyspark.ml.classification import RandomForestClassifier

rf = RandomForestClassifier(featuresCol="features", labelCol="label", numTrees=50, maxBins=200)
rf_model = rf.fit(train_df)

25/07/08 00:06:20 WARN MemoryStore: Not enough space to cache rdd_104_1 in memory! (computed 101.2 MiB so far)
25/07/08 00:06:20 WARN BlockManager: Persisting block rdd_104_1 to disk instead.
25/07/08 00:06:34 WARN MemoryStore: Not enough space to cache rdd_104_0 in memory! (computed 151.9 MiB so far)
25/07/08 00:06:34 WARN BlockManager: Persisting block rdd_104_0 to disk instead.
25/07/08 00:07:54 WARN MemoryStore: Not enough space to cache rdd_104_1 in memory! (computed 360.0 MiB so far)
25/07/08 00:08:33 WARN MemoryStore: Not enough space to cache rdd_104_0 in memory! (computed 360.0 MiB so far)
25/07/08 00:09:55 WARN MemoryStore: Not enough space to cache rdd_104_1 in memory! (computed 151.9 MiB so far)
25/07/08 00:09:55 WARN MemoryStore: Not enough space to cache rdd_104_0 in memory! (computed 228.0 MiB so far)
25/07/08 00:12:15 WARN MemoryStore: Not enough space to cache rdd_104_1 in memory! (computed 151.9 MiB so far)
25/07/08 00:12:15 WARN MemoryStore: Not enough space to cache

In [12]:
predictions = rf_model.transform(test_df)
predictions.select("prediction", "label", "features").show(5, truncate=False)

[Stage 46:>                                                         (0 + 1) / 1]

+----------+-----+--------------+
|prediction|label|features      |
+----------+-----+--------------+
|5.0       |6.0  |(6,[0],[41.0])|
|5.0       |6.0  |(6,[0],[41.0])|
|5.0       |6.0  |(6,[0],[41.0])|
|5.0       |6.0  |(6,[0],[41.0])|
|5.0       |6.0  |(6,[0],[41.0])|
+----------+-----+--------------+
only showing top 5 rows



                                                                                

In [13]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

evaluator = MulticlassClassificationEvaluator(
    labelCol="label", predictionCol="prediction", metricName="accuracy"
)
accuracy = evaluator.evaluate(predictions)
print(f"Test Accuracy: {accuracy:.4f}")



Test Accuracy: 0.6347


                                                                                

In [14]:
f1_evaluator = MulticlassClassificationEvaluator(
    labelCol="label", predictionCol="prediction", metricName="f1"
)
f1 = f1_evaluator.evaluate(predictions)
print(f"F1 Score: {f1:.4f}")



F1 Score: 0.5914


                                                                                

In [15]:
predictions.groupBy("prediction").count().orderBy("count", ascending=False).show()

25/07/08 00:30:14 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/07/08 00:30:14 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/07/08 00:30:17 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/07/08 00:30:18 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/07/08 00:30:21 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/07/08 00:30:22 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/07/08 00:30:25 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/07/08 00:30:25 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/07/08 00:30:28 WARN RowBasedKeyValueBatch: Calling spill() on

+----------+-------+
|prediction|  count|
+----------+-------+
|       5.0|3406463|
|       6.0|1257206|
|       3.0|1160637|
|       1.0|   2570|
+----------+-------+



                                                                                

In [22]:
from pyspark.ml.classification import (
    LogisticRegression,
    DecisionTreeClassifier,
    RandomForestClassifier,
    GBTClassifier,
    NaiveBayes
)
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# 🚀 Define your models
models = {
    "LogisticRegression": LogisticRegression(labelCol="label", featuresCol="features"),
    # "DecisionTree": DecisionTreeClassifier(labelCol="label", featuresCol="features"),
    # "RandomForest": RandomForestClassifier(labelCol="label", featuresCol="features"),
    "GBT": GBTClassifier(labelCol="label", featuresCol="features"),
    "NaiveBayes": NaiveBayes(labelCol="label", featuresCol="features")
}

# 🎯 Evaluator
evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")

# 🔄 Train + evaluate each model
for name, model in models.items():
    fitted_model = model.fit(train_df)
    predictions = fitted_model.transform(test_df)
    acc = evaluator.evaluate(predictions)
    print(f"{name} Accuracy: {acc:.4f}")

25/07/08 01:02:40 WARN MemoryStore: Not enough space to cache rdd_288_1 in memory! (computed 177.0 MiB so far)
25/07/08 01:02:40 WARN BlockManager: Persisting block rdd_288_1 to disk instead.
25/07/08 01:02:53 WARN MemoryStore: Not enough space to cache rdd_288_0 in memory! (computed 113.0 MiB so far)
25/07/08 01:02:53 WARN BlockManager: Persisting block rdd_288_0 to disk instead.
25/07/08 01:03:10 WARN MemoryStore: Not enough space to cache rdd_288_1 in memory! (computed 272.8 MiB so far)
25/07/08 01:03:36 WARN MemoryStore: Not enough space to cache rdd_288_0 in memory! (computed 419.0 MiB so far)
25/07/08 01:03:44 WARN MemoryStore: Not enough space to cache rdd_288_0 in memory! (computed 177.0 MiB so far)
25/07/08 01:03:44 WARN MemoryStore: Not enough space to cache rdd_288_1 in memory! (computed 177.0 MiB so far)
25/07/08 01:03:54 WARN MemoryStore: Not enough space to cache rdd_288_1 in memory! (computed 177.0 MiB so far)
25/07/08 01:03:54 WARN MemoryStore: Not enough space to cache

LogisticRegression Accuracy: 0.5250


25/07/08 01:16:41 ERROR Executor: Exception in task 0.0 in stage 182.0 (TID 624)
java.lang.RuntimeException: Labels MUST be in {0, 1}, but got 6.0
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.project_doConsume_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at sc

Py4JJavaError: An error occurred while calling o921.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 182.0 failed 1 times, most recent failure: Lost task 0.0 in stage 182.0 (TID 624) (fed674ed51ff executor driver): java.lang.RuntimeException: Labels MUST be in {0, 1}, but got 6.0
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.project_doConsume_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$SliceIterator.hasNext(Iterator.scala:268)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:366)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:364)
	at scala.collection.AbstractIterator.to(Iterator.scala:1431)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:358)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:358)
	at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1431)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:345)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:339)
	at scala.collection.AbstractIterator.toArray(Iterator.scala:1431)
	at org.apache.spark.rdd.RDD.$anonfun$take$2(RDD.scala:1492)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2438)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	at java.base/java.lang.Thread.run(Thread.java:829)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2856)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2792)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2791)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2791)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1247)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3060)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2994)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2983)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:989)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2398)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2419)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2438)
	at org.apache.spark.rdd.RDD.$anonfun$take$1(RDD.scala:1492)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:410)
	at org.apache.spark.rdd.RDD.take(RDD.scala:1465)
	at org.apache.spark.ml.tree.impl.DecisionTreeMetadata$.buildMetadata(DecisionTreeMetadata.scala:119)
	at org.apache.spark.ml.tree.impl.GradientBoostedTrees$.boost(GradientBoostedTrees.scala:333)
	at org.apache.spark.ml.tree.impl.GradientBoostedTrees$.run(GradientBoostedTrees.scala:61)
	at org.apache.spark.ml.classification.GBTClassifier.$anonfun$train$1(GBTClassifier.scala:201)
	at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:191)
	at scala.util.Try$.apply(Try.scala:213)
	at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:191)
	at org.apache.spark.ml.classification.GBTClassifier.train(GBTClassifier.scala:170)
	at org.apache.spark.ml.classification.GBTClassifier.train(GBTClassifier.scala:58)
	at org.apache.spark.ml.Predictor.fit(Predictor.scala:114)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:829)
Caused by: java.lang.RuntimeException: Labels MUST be in {0, 1}, but got 6.0
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.project_doConsume_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$SliceIterator.hasNext(Iterator.scala:268)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:366)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:364)
	at scala.collection.AbstractIterator.to(Iterator.scala:1431)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:358)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:358)
	at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1431)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:345)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:339)
	at scala.collection.AbstractIterator.toArray(Iterator.scala:1431)
	at org.apache.spark.rdd.RDD.$anonfun$take$2(RDD.scala:1492)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2438)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	... 1 more
