In [1]:
from matplotlib import pyplot as plt
from pyspark.conf import SparkConf
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.classification import MultilayerPerceptronClassifier
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.tuning import CrossValidator
from pyspark.ml.tuning import ParamGridBuilder
from pyspark.sql import SparkSession
from pyspark.sql.functions import when, col
from pyspark.sql.types import IntegerType


In [2]:
# Create PySpark SparkSession
conf: SparkConf = SparkConf().setAppName("ECD_TCC").setMaster("local[*]")
spark: SparkSession = SparkSession.builder.config(conf=conf).getOrCreate()


In [3]:
# Filtrando apenas casos de covid
df_raw = spark.read.option("delimiter", ";").option("header", True).csv("../datasets/raw").where("CLASSI_FIN == 5")


In [4]:
# Selecionando apenas as colunas com fatores de risco
select_cols = ["PUERPERA", "CARDIOPATI", "HEMATOLOGI", "SIND_DOWN", "HEPATICA", "ASMA", "DIABETES", "NEUROLOGIC",
               "PNEUMOPATI", "IMUNODEPRE", "RENAL", "OBESIDADE", "VACINA_COV", "VACINA", "EVOLUCAO"]


In [6]:
# Preenchendo valores nulos como ignorados e normalizando os preenchimentos de fatores de risco
df = df_raw.select(select_cols).where("VACINA_COV <> '12/02/2021'").cache()
for column in select_cols:
    df = df.withColumn(
        column,
        when(col(column) == "1", "Y")
        .when(col(column) == "2", "N")
    )
df = df.withColumnRenamed("EVOLUCAO", "label").where(col("label").isNotNull()).cache()


In [7]:
print("Dataset com os dados relevantes")
pandas_df = df.toPandas()
pandas_df.head()


Dataset com os dados relevantes


Unnamed: 0,PUERPERA,CARDIOPATI,HEMATOLOGI,SIND_DOWN,HEPATICA,ASMA,DIABETES,NEUROLOGIC,PNEUMOPATI,IMUNODEPRE,RENAL,OBESIDADE,VACINA_COV,VACINA,label
0,,,,,,,Y,,,,,,Y,Y,Y
1,,Y,,,,Y,,,Y,,,,,,Y
2,N,N,N,N,N,N,Y,N,Y,N,N,Y,N,N,Y
3,N,N,N,N,N,N,Y,N,N,N,N,N,N,,N
4,,,,,,,,,,,,,Y,,Y


In [8]:
pandas_df.to_csv("../datasets/processed/df.csv", index=False, header=True)

In [None]:
# Valores de camadas inciais e finais para o MLP
features_n = len(select_cols) - 1
classes_n = df.select("EVOLUCAO").distinct().count() + 1


AnalysisException: Column 'EVOLUCAO' does not exist. Did you mean one of the following? [HEPATICA, VACINA, ASMA, NEUROLOGIC, RENAL, DIABETES, HEMATOLOGI, OBESIDADE, PUERPERA, VACINA_COV, label, CARDIOPATI, IMUNODEPRE, PNEUMOPATI, SIND_DOWN];
'Project ['EVOLUCAO]
+- Project [PUERPERA#440, CARDIOPATI#456, HEMATOLOGI#472, SIND_DOWN#488, HEPATICA#504, ASMA#520, DIABETES#536, NEUROLOGIC#552, PNEUMOPATI#568, IMUNODEPRE#584, RENAL#600, OBESIDADE#616, VACINA_COV#632, VACINA#648, EVOLUCAO#664 AS label#680]
   +- Project [PUERPERA#440, CARDIOPATI#456, HEMATOLOGI#472, SIND_DOWN#488, HEPATICA#504, ASMA#520, DIABETES#536, NEUROLOGIC#552, PNEUMOPATI#568, IMUNODEPRE#584, RENAL#600, OBESIDADE#616, VACINA_COV#632, VACINA#648, CASE WHEN (EVOLUCAO#126 = 1) THEN Y WHEN (EVOLUCAO#126 = 2) THEN N END AS EVOLUCAO#664]
      +- Project [PUERPERA#440, CARDIOPATI#456, HEMATOLOGI#472, SIND_DOWN#488, HEPATICA#504, ASMA#520, DIABETES#536, NEUROLOGIC#552, PNEUMOPATI#568, IMUNODEPRE#584, RENAL#600, OBESIDADE#616, VACINA_COV#632, CASE WHEN (VACINA#73 = 1) THEN Y WHEN (VACINA#73 = 2) THEN N END AS VACINA#648, EVOLUCAO#126]
         +- Project [PUERPERA#440, CARDIOPATI#456, HEMATOLOGI#472, SIND_DOWN#488, HEPATICA#504, ASMA#520, DIABETES#536, NEUROLOGIC#552, PNEUMOPATI#568, IMUNODEPRE#584, RENAL#600, OBESIDADE#616, CASE WHEN (VACINA_COV#171 = 1) THEN Y WHEN (VACINA_COV#171 = 2) THEN N END AS VACINA_COV#632, VACINA#73, EVOLUCAO#126]
            +- Project [PUERPERA#440, CARDIOPATI#456, HEMATOLOGI#472, SIND_DOWN#488, HEPATICA#504, ASMA#520, DIABETES#536, NEUROLOGIC#552, PNEUMOPATI#568, IMUNODEPRE#584, RENAL#600, CASE WHEN (OBESIDADE#69 = 1) THEN Y WHEN (OBESIDADE#69 = 2) THEN N END AS OBESIDADE#616, VACINA_COV#171, VACINA#73, EVOLUCAO#126]
               +- Project [PUERPERA#440, CARDIOPATI#456, HEMATOLOGI#472, SIND_DOWN#488, HEPATICA#504, ASMA#520, DIABETES#536, NEUROLOGIC#552, PNEUMOPATI#568, IMUNODEPRE#584, CASE WHEN (RENAL#68 = 1) THEN Y WHEN (RENAL#68 = 2) THEN N END AS RENAL#600, OBESIDADE#69, VACINA_COV#171, VACINA#73, EVOLUCAO#126]
                  +- Project [PUERPERA#440, CARDIOPATI#456, HEMATOLOGI#472, SIND_DOWN#488, HEPATICA#504, ASMA#520, DIABETES#536, NEUROLOGIC#552, PNEUMOPATI#568, CASE WHEN (IMUNODEPRE#67 = 1) THEN Y WHEN (IMUNODEPRE#67 = 2) THEN N END AS IMUNODEPRE#584, RENAL#68, OBESIDADE#69, VACINA_COV#171, VACINA#73, EVOLUCAO#126]
                     +- Project [PUERPERA#440, CARDIOPATI#456, HEMATOLOGI#472, SIND_DOWN#488, HEPATICA#504, ASMA#520, DIABETES#536, NEUROLOGIC#552, CASE WHEN (PNEUMOPATI#66 = 1) THEN Y WHEN (PNEUMOPATI#66 = 2) THEN N END AS PNEUMOPATI#568, IMUNODEPRE#67, RENAL#68, OBESIDADE#69, VACINA_COV#171, VACINA#73, EVOLUCAO#126]
                        +- Project [PUERPERA#440, CARDIOPATI#456, HEMATOLOGI#472, SIND_DOWN#488, HEPATICA#504, ASMA#520, DIABETES#536, CASE WHEN (NEUROLOGIC#65 = 1) THEN Y WHEN (NEUROLOGIC#65 = 2) THEN N END AS NEUROLOGIC#552, PNEUMOPATI#66, IMUNODEPRE#67, RENAL#68, OBESIDADE#69, VACINA_COV#171, VACINA#73, EVOLUCAO#126]
                           +- Project [PUERPERA#440, CARDIOPATI#456, HEMATOLOGI#472, SIND_DOWN#488, HEPATICA#504, ASMA#520, CASE WHEN (DIABETES#64 = 1) THEN Y WHEN (DIABETES#64 = 2) THEN N END AS DIABETES#536, NEUROLOGIC#65, PNEUMOPATI#66, IMUNODEPRE#67, RENAL#68, OBESIDADE#69, VACINA_COV#171, VACINA#73, EVOLUCAO#126]
                              +- Project [PUERPERA#440, CARDIOPATI#456, HEMATOLOGI#472, SIND_DOWN#488, HEPATICA#504, CASE WHEN (ASMA#63 = 1) THEN Y WHEN (ASMA#63 = 2) THEN N END AS ASMA#520, DIABETES#64, NEUROLOGIC#65, PNEUMOPATI#66, IMUNODEPRE#67, RENAL#68, OBESIDADE#69, VACINA_COV#171, VACINA#73, EVOLUCAO#126]
                                 +- Project [PUERPERA#440, CARDIOPATI#456, HEMATOLOGI#472, SIND_DOWN#488, CASE WHEN (HEPATICA#62 = 1) THEN Y WHEN (HEPATICA#62 = 2) THEN N END AS HEPATICA#504, ASMA#63, DIABETES#64, NEUROLOGIC#65, PNEUMOPATI#66, IMUNODEPRE#67, RENAL#68, OBESIDADE#69, VACINA_COV#171, VACINA#73, EVOLUCAO#126]
                                    +- Project [PUERPERA#440, CARDIOPATI#456, HEMATOLOGI#472, CASE WHEN (SIND_DOWN#61 = 1) THEN Y WHEN (SIND_DOWN#61 = 2) THEN N END AS SIND_DOWN#488, HEPATICA#62, ASMA#63, DIABETES#64, NEUROLOGIC#65, PNEUMOPATI#66, IMUNODEPRE#67, RENAL#68, OBESIDADE#69, VACINA_COV#171, VACINA#73, EVOLUCAO#126]
                                       +- Project [PUERPERA#440, CARDIOPATI#456, CASE WHEN (HEMATOLOGI#60 = 1) THEN Y WHEN (HEMATOLOGI#60 = 2) THEN N END AS HEMATOLOGI#472, SIND_DOWN#61, HEPATICA#62, ASMA#63, DIABETES#64, NEUROLOGIC#65, PNEUMOPATI#66, IMUNODEPRE#67, RENAL#68, OBESIDADE#69, VACINA_COV#171, VACINA#73, EVOLUCAO#126]
                                          +- Project [PUERPERA#440, CASE WHEN (CARDIOPATI#59 = 1) THEN Y WHEN (CARDIOPATI#59 = 2) THEN N END AS CARDIOPATI#456, HEMATOLOGI#60, SIND_DOWN#61, HEPATICA#62, ASMA#63, DIABETES#64, NEUROLOGIC#65, PNEUMOPATI#66, IMUNODEPRE#67, RENAL#68, OBESIDADE#69, VACINA_COV#171, VACINA#73, EVOLUCAO#126]
                                             +- Project [CASE WHEN (PUERPERA#57 = 1) THEN Y WHEN (PUERPERA#57 = 2) THEN N END AS PUERPERA#440, CARDIOPATI#59, HEMATOLOGI#60, SIND_DOWN#61, HEPATICA#62, ASMA#63, DIABETES#64, NEUROLOGIC#65, PNEUMOPATI#66, IMUNODEPRE#67, RENAL#68, OBESIDADE#69, VACINA_COV#171, VACINA#73, EVOLUCAO#126]
                                                +- Filter NOT (VACINA_COV#171 = 12/02/2021)
                                                   +- Project [PUERPERA#57, CARDIOPATI#59, HEMATOLOGI#60, SIND_DOWN#61, HEPATICA#62, ASMA#63, DIABETES#64, NEUROLOGIC#65, PNEUMOPATI#66, IMUNODEPRE#67, RENAL#68, OBESIDADE#69, VACINA_COV#171, VACINA#73, EVOLUCAO#126]
                                                      +- Filter (cast(CLASSI_FIN#123 as int) = 5)
                                                         +- Relation [DT_NOTIFIC#17,SEM_NOT#18,DT_SIN_PRI#19,SEM_PRI#20,SG_UF_NOT#21,ID_REGIONA#22,CO_REGIONA#23,ID_MUNICIP#24,CO_MUN_NOT#25,ID_UNIDADE#26,CO_UNI_NOT#27,CS_SEXO#28,DT_NASC#29,NU_IDADE_N#30,TP_IDADE#31,COD_IDADE#32,CS_GESTANT#33,CS_RACA#34,CS_ESCOL_N#35,ID_PAIS#36,CO_PAIS#37,SG_UF#38,ID_RG_RESI#39,CO_RG_RESI#40,... 142 more fields] csv


In [None]:
# Juntar as variáveis em um vetor de features
assembler = VectorAssembler(inputCols=feature_list, outputCol="features")
assembledDF = assembler.transform(df)


IllegalArgumentException: Data type string of column PUERPERA is not supported.
Data type string of column CARDIOPATI is not supported.
Data type string of column HEMATOLOGI is not supported.
Data type string of column SIND_DOWN is not supported.
Data type string of column HEPATICA is not supported.
Data type string of column ASMA is not supported.
Data type string of column DIABETES is not supported.
Data type string of column NEUROLOGIC is not supported.
Data type string of column PNEUMOPATI is not supported.
Data type string of column IMUNODEPRE is not supported.
Data type string of column RENAL is not supported.
Data type string of column OBESIDADE is not supported.
Data type string of column VACINA_COV is not supported.
Data type string of column VACINA is not supported.

In [None]:
# Modelos a testar: regressão logística, Floresta aleatória e multilayerperceptron
lr = LogisticRegression().setFamily("binomial")
rf = RandomForestClassifier()
mlp = MultilayerPerceptronClassifier()


In [None]:
# Pipelines de transformações para a validação cruzada
pipeline_lr = Pipeline(stages=[rf])
pipeline_rf = Pipeline(stages=[rf])
pipeline_mlp = Pipeline(stages=[mlp])


In [None]:
# Parâmetros para a validação cruzada
paramGrid_lr = (
    ParamGridBuilder()
    .addGrid(lr.regParam, [0.01, 0.1, 0.5, 1.0, 2.0])
    .addGrid(lr.elasticNetParam, [0.0, 0.25, 0.5, 0.75, 1.0])
    .addGrid(lr.maxIter, [1, 5, 10, 20, 50])
    .build()
)
paramGrid_rf = (
    ParamGridBuilder()
    .addGrid(rf.numTrees, [10, 30, 50])
    .addGrid(rf.maxDepth, [5, 10, 15])
    .build()
)
paramGrid_mlp = (
    ParamGridBuilder()
    .addGrid(mlp.layers, [
        [features_n, 3, 3, classes_n],
        [features_n, 4, 4, classes_n],
        [features_n, 3, 3, 3, classes_n],
        [features_n, 4, 4, 4, classes_n],
    ]).build()
)

In [None]:
# Validador cruzado
crossval_lr = CrossValidator(estimator=pipeline_lr, estimatorParamMaps=paramGrid_lr,
                             evaluator=MulticlassClassificationEvaluator())
crossval_rf = CrossValidator(estimator=pipeline_rf, estimatorParamMaps=paramGrid_rf,
                             evaluator=MulticlassClassificationEvaluator())
crossval_mlp = CrossValidator(estimator=pipeline_mlp, estimatorParamMaps=paramGrid_mlp,
                              evaluator=MulticlassClassificationEvaluator())


In [None]:
# Separar os dados em treino e teste
taxa_de_treino = 0.00001
(trainingData, testData) = df.randomSplit([taxa_de_treino, 1 - taxa_de_treino])


In [None]:
# Salvando as massa utilizadas em treino e teste para fazer análises depois de gerar os modelos
trainingData.write.mode("overwrite").parquet("../datasets/training")
testData.write.mode("overwrite").parquet("../datasets/test")

In [None]:
crossval_lr_model = crossval_lr.fit(assembledDF)

Py4JJavaError: An error occurred while calling o11341.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 477.0 failed 1 times, most recent failure: Lost task 0.0 in stage 477.0 (TID 2762) (192.168.1.113 executor driver): scala.MatchError: [null,1.0,[-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0.0,0.0]] (of class org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema)
	at org.apache.spark.ml.PredictorParams.$anonfun$extractInstances$1(Predictor.scala:81)
	at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
	at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
	at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
	at scala.collection.Iterator$SliceIterator.next(Iterator.scala:273)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:366)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:364)
	at scala.collection.AbstractIterator.to(Iterator.scala:1431)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:358)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:358)
	at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1431)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:345)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:339)
	at scala.collection.AbstractIterator.toArray(Iterator.scala:1431)
	at org.apache.spark.rdd.RDD.$anonfun$take$2(RDD.scala:1470)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2268)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:136)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	at java.base/java.lang.Thread.run(Thread.java:833)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2672)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2608)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2607)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2607)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1182)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2860)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2802)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2791)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:952)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2228)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2249)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2268)
	at org.apache.spark.rdd.RDD.$anonfun$take$1(RDD.scala:1470)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:406)
	at org.apache.spark.rdd.RDD.take(RDD.scala:1443)
	at org.apache.spark.ml.tree.impl.DecisionTreeMetadata$.buildMetadata(DecisionTreeMetadata.scala:119)
	at org.apache.spark.ml.tree.impl.RandomForest$.run(RandomForest.scala:274)
	at org.apache.spark.ml.classification.RandomForestClassifier.$anonfun$train$1(RandomForestClassifier.scala:161)
	at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:191)
	at scala.util.Try$.apply(Try.scala:213)
	at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:191)
	at org.apache.spark.ml.classification.RandomForestClassifier.train(RandomForestClassifier.scala:138)
	at org.apache.spark.ml.classification.RandomForestClassifier.train(RandomForestClassifier.scala:46)
	at org.apache.spark.ml.Predictor.fit(Predictor.scala:151)
	at org.apache.spark.ml.Predictor.fit(Predictor.scala:115)
	at jdk.internal.reflect.GeneratedMethodAccessor131.invoke(Unknown Source)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:568)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:833)
Caused by: scala.MatchError: [null,1.0,[-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0.0,0.0]] (of class org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema)
	at org.apache.spark.ml.PredictorParams.$anonfun$extractInstances$1(Predictor.scala:81)
	at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
	at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
	at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
	at scala.collection.Iterator$SliceIterator.next(Iterator.scala:273)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:366)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:364)
	at scala.collection.AbstractIterator.to(Iterator.scala:1431)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:358)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:358)
	at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1431)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:345)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:339)
	at scala.collection.AbstractIterator.toArray(Iterator.scala:1431)
	at org.apache.spark.rdd.RDD.$anonfun$take$2(RDD.scala:1470)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2268)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:136)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	... 1 more


In [None]:
crossval_lr_model.write().save("model/lr")

In [None]:
# Avaliando o modelo
predictions = cvModel.transform(testData)
evaluator = MulticlassClassificationEvaluator()
rmse = evaluator.evaluate(predictions)

predictions.select("label", "features", "rawPrediction",
                   "probability", "prediction").show()
predictions.select("prediction").distinct().show()

result = predictions.toPandas()

plt.plot(result.label, result.prediction, 'bo')
plt.xlabel('Sobrevivencia')
plt.ylabel('Prediction')
plt.suptitle("Model Performance RMSE: %f" % rmse)
plt.show()


In [None]:
# Selecionando o melhor modelo
bestPipeline = cvModel.bestModel
bestModel = bestPipeline.stages[2]

importances = bestModel.featureImportances

x_values = list(range(len(importances)))

plt.bar(x_values, importances, orientation='vertical')
plt.xticks(x_values, feature_list, rotation=40)
plt.ylabel('Importance')
plt.xlabel('Feature')
plt.title('Feature Importances')


In [None]:
print('numTrees - ', bestModel.getNumTrees)
print('maxDepth - ', bestModel.getOrDefault('maxDepth'))
