In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql.functions import *

spark = SparkSession \
    .builder \
    .appName("baseline") \
    .getOrCreate()

In [3]:
# read in the data
data_processed_path = '../data_processed/'

train_lab_dx = spark.read.format("csv") \
                  .option("header", "true") \
                  .option("inferSchema", "true") \
            .load(data_processed_path + "train_final.csv")
            
validation_lab_dx = spark.read.format("csv") \
                  .option("header", "true") \
                  .option("inferSchema", "true") \
            .load(data_processed_path + "validation_final.csv")

test_lab_dx = spark.read.format("csv") \
                  .option("header", "true") \
                  .option("inferSchema", "true") \
            .load(data_processed_path + "test_final.csv")

print(train_lab_dx.count())
print(validation_lab_dx.count())
print(test_lab_dx.count())
# train_lab_dx.schema

24089
6037
6120


In [4]:
# transform the data
from pyspark.sql.types import *
from pyspark.sql.functions import udf
from collections import namedtuple

# column_names = ['dx231_mean','dx146_mean','dx96_mean','dx85_mean','dx2616_mean','dx225_mean','dx86_mean','dx79_mean','dx40_mean','dx202_mean','dx215_mean','dx94_mean','dx81_mean','dx55_mean','dx655_mean','dx18_mean','dx232_mean','dx141_mean','dx243_mean','dx2613_mean','dx58_mean','dx181_mean','dx80_mean','dx178_mean','dx16_mean','Creatinine_mean','dx77_mean','dx250_mean','dx104_mean','dx124_mean','dx106_mean','dx147_mean','dx17_mean','dx137_mean','dx2619_mean','dx203_mean','dx119_mean','dx236_mean','dx663_mean','dx78_mean','dx2607_mean','Calcium_mean','dx90_mean','dx161_mean','dx91_mean','AST_mean','dx32_mean','dx204_mean','dx2615_mean','dx19_mean','dx83_mean','Potassium_mean','dx49_mean','dx175_mean','dx208_mean','dx97_mean','dx660_mean','dx139_mean','dx659_mean','WBC_mean','dx144_mean','dx2621_mean','dx116_mean','dx126_mean','dx100_mean','dx195_mean','dx29_mean','Bicarbonate_mean','dx229_mean','dx26_mean','dx2_mean','dx242_mean','dx15_mean','dx148_mean','dx131_mean','dx226_mean','dx127_mean','dx98_mean','dx2608_mean','dx166_mean','dx88_mean','dx44_mean','dx62_mean','dx209_mean','dx111_mean','dx120_mean','dx118_mean','dx103_mean','dx654_mean','dx151_mean','dx168_mean','dx125_mean','dx240_mean','dx114_mean','dx158_mean','dx252_mean','dx42_mean','Ammonia_mean','dx92_mean','dx257_mean','dx230_mean','dx76_mean','dx4_mean','dx13_mean','dx108_mean','dx64_mean','ALT_mean','dx129_mean','dx224_mean','dx8_mean','dx39_mean','dx237_mean','dx60_mean','dx5_mean','dx239_mean','dx23_mean','dx171_mean','dx33_mean','dx138_mean','dx123_mean','GENDER','dx653_mean','dx170_mean','dx162_mean','dx246_mean','dx93_mean','dx47_mean','dx6_mean','dx3_mean','dx200_mean','dx206_mean','dx251_mean','dx95_mean','dx241_mean','is_emergency','dx110_mean','dx84_mean','dx197_mean','dx143_mean','pO2_mean','dx211_mean','dx163_mean','dx670_mean','dx37_mean','dx140_mean','dx154_mean','dx121_mean','dx198_mean','dx651_mean','dx245_mean','dx207_mean','dx99_mean','dx50_mean','dx159_mean','dx27_mean','dx164_mean','dx253_mean','Hematocrit_mean','dx122_mean','dx35_mean','dx51_mean','Magnesium_mean','dx63_mean','dx2620_mean','dx22_mean','dx128_mean','dx133_mean','dx87_mean','dx82_mean','dx136_mean','dx201_mean','dx233_mean','dx152_mean','dx255_mean','dx249_mean','dx45_mean','dx135_mean','dx156_mean','Sodium_mean','dx132_mean','dx57_mean','dx228_mean','dx14_mean','dx256_mean','dx652_mean','dx130_mean','dx107_mean','dx234_mean','dx661_mean','dx59_mean','dx102_mean','dx43_mean','dx25_mean','dx213_mean','dx46_mean','dx36_mean','dx105_mean','dx212_mean','dx54_mean','dx112_mean','dx650_mean','dx157_mean','dx115_mean','dx244_mean','dx89_mean','dx117_mean','dx238_mean','dx657_mean','dx248_mean','dx210_mean','dx2618_mean','dx217_mean','dx145_mean','dx2617_mean','dx10_mean','dx48_mean','dx52_mean','Albumin_mean','dx38_mean','dx173_mean','dx155_mean','dx199_mean','dx142_mean','dx658_mean','dx662_mean','dx235_mean','dx101_mean','Temperature_mean','dx53_mean','dx149_mean','dx134_mean','dx24_mean','dx259_mean','dx12_mean','dx11_mean','dx113_mean','dx205_mean','age','dx165_mean','dx109_mean','is_urgent','dx247_mean','dx2603_mean','dx160_mean','dx153_mean','UreaNitrogen_mean','dx7_mean','Bilirubin_mean','pH_mean','admission_LOS_days','readmission']
column_names = ['dx231_mean','dx146_mean','dx119_mean','dx85_mean','dx87_mean','Chloride_mean','dx105_mean','dx2613_mean','dx2618_mean','dx136_mean','dx202_mean','dx246_mean','dx215_mean','dx252_mean','dx81_mean','dx154_mean','dx655_mean','dx18_mean','dx96_mean','dx145_mean','dx243_mean','dx58_mean','dx181_mean','dx80_mean','dx178_mean','dx16_mean','Creatinine_mean','dx77_mean','Globulin_mean','dx104_mean','dx124_mean','dx106_mean','dx147_mean','dx17_mean','dx137_mean','dx88_mean','dx91_mean','Triglycerides_mean','Albumin_mean','dx663_mean','dx55_mean','dx78_mean','dx2607_mean','dx90_mean','dx161_mean','dx203_mean','dx32_mean','dx209_mean','dx95_mean','UreaNitrogen_mean','dx2616_mean','dx204_mean','dx2615_mean','dx98_mean','dx19_mean','dx94_mean','dx83_mean','dx108_mean','dx49_mean','dx175_mean','dx208_mean','dx6_mean','dx2619_mean','dx139_mean','dx659_mean','WBC_mean','dx144_mean','dx116_mean','dx126_mean','dx100_mean','dx29_mean','Bicarbonate_mean','dx142_mean','dx26_mean','dx2_mean','dx242_mean','dx15_mean','dx148_mean','dx131_mean','dx226_mean','dx127_mean','dx212_mean','dx2608_mean','dx166_mean','dx241_mean','pCO2_mean','dx62_mean','Glucose_mean','dx111_mean','dx120_mean','dx103_mean','dx654_mean','dx151_mean','dx125_mean','dx114_mean','dx158_mean','dx660_mean','Hemoglobin_mean','dx42_mean','Ammonia_mean','dx92_mean','dx230_mean','dx64_mean','dx76_mean','dx135_mean','dx228_mean','Potassium_mean','dx107_mean','dx129_mean','dx244_mean','dx86_mean','dx8_mean','dx39_mean','dx237_mean','dx225_mean','dx199_mean','dx239_mean','dx23_mean','dx171_mean','dx33_mean','dx138_mean','dx247_mean','dx653_mean','dx170_mean','dx162_mean','Amylase_mean','dx118_mean','dx232_mean','dx134_mean','dx3_mean','dx200_mean','dx206_mean','dx251_mean','dx79_mean','dx650_mean','is_emergency','dx110_mean','dx84_mean','dx197_mean','dx60_mean','pO2_mean','dx44_mean','dx211_mean','dx163_mean','dx670_mean','dx37_mean','dx140_mean','dx93_mean','dx121_mean','dx198_mean','dx651_mean','LOS','dx99_mean','dx50_mean','dx159_mean','dx27_mean','dx164_mean','dx253_mean','Hematocrit_mean','dx122_mean','dx248_mean','GENDER','Magnesium_mean','dx63_mean','dx2620_mean','dx22_mean','dx128_mean','dx133_mean','dx207_mean','dx82_mean','dx168_mean','dx201_mean','dx250_mean','dx233_mean','AsparateAminotransferaseAST_mean','dx152_mean','dx255_mean','dx249_mean','dx45_mean','dx4_mean','dx156_mean','Sodium_mean','dx132_mean','dx57_mean','dx13_mean','dx14_mean','dx256_mean','dx652_mean','dx130_mean','dx165_mean','dx234_mean','dx661_mean','dx59_mean','dx102_mean','dx43_mean','dx245_mean','dx213_mean','dx46_mean','dx36_mean','dx257_mean','dx2621_mean','dx54_mean','dx112_mean','Protein_mean','dx157_mean','dx115_mean','dx97_mean','dx89_mean','dx117_mean','dx238_mean','dx657_mean','dx35_mean','dx210_mean','dx52_mean','dx217_mean','dx141_mean','dx2617_mean','dx10_mean','dx48_mean','dx51_mean','dx236_mean','dx38_mean','dx173_mean','dx155_mean','dx5_mean','dx229_mean','dx658_mean','dx662_mean','dx235_mean','dx101_mean','Temperature_mean','dx53_mean','dx149_mean','dx143_mean','dx47_mean','dx24_mean','dx259_mean','dx12_mean','dx11_mean','dx113_mean','dx205_mean','age','dx25_mean','dx109_mean','is_urgent','dx123_mean','dx2603_mean','dx160_mean','dx153_mean','dx7_mean','pH_mean','dx40_mean','readmission']

DataRow = namedtuple("DataRow", column_names + ["duplicate_labels"])

def organize_values(lab_dx_vals):
    lab_dx_rdd = lab_dx_vals.rdd.map(lambda row: (row.SUBJECT_ID, DataRow._make(list(row[1:]) + [[row.readmission]])))

    lab_dx_rdd = lab_dx_rdd.map(lambda row: row[1])
    # print(lab_dx_rdd.take(1))

    # gather some counts so we know how much to upsample
    num_without_readmission = lab_dx_rdd.map(lambda x: x.readmission == 0).sum()
    num_readmission = lab_dx_rdd.map(lambda x: x.readmission == 1).sum()
    print("num without readmission %s" % num_without_readmission)
    print("num readmission %s" % num_readmission)
    print("total %s" % lab_dx_rdd.count())
    upsample_amount = (num_without_readmission // num_readmission) + 1
    print("upsample_amount: {}".format(upsample_amount))

    def upsample(row):
        # 0 is the dominant class so don't need to upsample
        if row.readmission == 0:
            return row
        dup_labels = [row.readmission]*upsample_amount
        return DataRow._make(list(row[:-1]) + [dup_labels])

    lab_dx_rdd = lab_dx_rdd.map(lambda row: upsample(row))
    
    return lab_dx_rdd
    
train_lab_dx_rdd = organize_values(train_lab_dx)
validation_lab_dx_rdd = organize_values(validation_lab_dx)
test_lab_dx_rdd = organize_values(test_lab_dx)

# print(train_lab_dx_rdd.take(1))

num without readmission 21536
num readmission 2553
total 24089
upsample_amount: 9
num without readmission 5430
num readmission 607
total 6037
upsample_amount: 9
num without readmission 5498
num readmission 622
total 6120
upsample_amount: 9


In [5]:
from pyspark.sql.functions import explode

train_lab_dx_balanced = train_lab_dx_rdd.toDF()
train_lab_dx_balanced.groupBy('readmission').count().show()
validation_lab_dx_balanced = validation_lab_dx_rdd.toDF()
validation_lab_dx_balanced.groupBy('readmission').count().show()
test_lab_dx_balanced = test_lab_dx_rdd.toDF()
test_lab_dx_balanced.groupBy('readmission').count().show()

# balance the classes
train_lab_dx_balanced = train_lab_dx_rdd.toDF().withColumn("duplicate_labels", explode("duplicate_labels"))
validation_lab_dx_balanced = validation_lab_dx_rdd.toDF().withColumn("duplicate_labels", explode("duplicate_labels"))
test_lab_dx_balanced = test_lab_dx_rdd.toDF().withColumn("duplicate_labels", explode("duplicate_labels"))


train_lab_dx_balanced.groupBy('readmission').count().show()
validation_lab_dx_balanced.groupBy('readmission').count().show()
test_lab_dx_balanced.groupBy('readmission').count().show()

+-----------+-----+
|readmission|count|
+-----------+-----+
|          0|21536|
|          1| 2553|
+-----------+-----+

+-----------+-----+
|readmission|count|
+-----------+-----+
|          0| 5430|
|          1|  607|
+-----------+-----+

+-----------+-----+
|readmission|count|
+-----------+-----+
|          0| 5498|
|          1|  622|
+-----------+-----+

+-----------+-----+
|readmission|count|
+-----------+-----+
|          0|21536|
|          1|22977|
+-----------+-----+

+-----------+-----+
|readmission|count|
+-----------+-----+
|          0| 5430|
|          1| 5463|
+-----------+-----+

+-----------+-----+
|readmission|count|
+-----------+-----+
|          0| 5498|
|          1| 5598|
+-----------+-----+



In [6]:
# train logistic regression
from pyspark.mllib.classification import LogisticRegressionWithLBFGS
from pyspark.ml.classification import LogisticRegression
from pyspark.mllib.evaluation import BinaryClassificationMetrics
from pyspark.mllib.regression import LabeledPoint

points_train_simple = train_lab_dx_balanced.rdd.map(lambda row: LabeledPoint(row.readmission, \
                                    [row.LOS, row.age, row.GENDER]))
points_validation_simple = validation_lab_dx_balanced.rdd.map(lambda row: LabeledPoint(row.readmission, \
                                    [row.LOS, row.age, row.GENDER]))
points_test_simple = test_lab_dx_balanced.rdd.map(lambda row: LabeledPoint(row.readmission, \
                                    [row.LOS, row.age, row.GENDER]))

log_reg_simple = LogisticRegressionWithLBFGS.train(points_train_simple)


# [row.dx231_mean,row.dx146_mean,row.dx96_mean,row.dx85_mean,row.dx2616_mean,row.dx225_mean,row.dx86_mean,row.dx79_mean,row.dx40_mean,row.dx202_mean,row.dx215_mean,row.dx94_mean,row.dx81_mean,row.dx55_mean,row.dx655_mean,row.dx18_mean,row.dx232_mean,row.dx141_mean,row.dx243_mean,row.dx2613_mean,row.dx58_mean,row.dx181_mean,row.dx80_mean,row.dx178_mean,row.dx16_mean,row.Creatinine_mean,row.dx77_mean,row.dx250_mean,row.dx104_mean,row.dx124_mean,row.dx106_mean,row.dx147_mean,row.dx17_mean,row.dx137_mean,row.dx2619_mean,row.dx203_mean,row.dx119_mean,row.dx236_mean,row.dx663_mean,row.dx78_mean,row.dx2607_mean,row.Calcium_mean,row.dx90_mean,row.dx161_mean,row.dx91_mean,row.AST_mean,row.dx32_mean,row.dx204_mean,row.dx2615_mean,row.dx19_mean,row.dx83_mean,row.Potassium_mean,row.dx49_mean,row.dx175_mean,row.dx208_mean,row.dx97_mean,row.dx660_mean,row.dx139_mean,row.dx659_mean,row.WBC_mean,row.dx144_mean,row.dx2621_mean,row.dx116_mean,row.dx126_mean,row.dx100_mean,row.dx195_mean,row.dx29_mean,row.Bicarbonate_mean,row.dx229_mean,row.dx26_mean,row.dx2_mean,row.dx242_mean,row.dx15_mean,row.dx148_mean,row.dx131_mean,row.dx226_mean,row.dx127_mean,row.dx98_mean,row.dx2608_mean,row.dx166_mean,row.dx88_mean,row.dx44_mean,row.dx62_mean,row.dx209_mean,row.dx111_mean,row.dx120_mean,row.dx118_mean,row.dx103_mean,row.dx654_mean,row.dx151_mean,row.dx168_mean,row.dx125_mean,row.dx240_mean,row.dx114_mean,row.dx158_mean,row.dx252_mean,row.dx42_mean,row.Ammonia_mean,row.dx92_mean,row.dx257_mean,row.dx230_mean,row.dx76_mean,row.dx4_mean,row.dx13_mean,row.dx108_mean,row.dx64_mean,row.ALT_mean,row.dx129_mean,row.dx224_mean,row.dx8_mean,row.dx39_mean,row.dx237_mean,row.dx60_mean,row.dx5_mean,row.dx239_mean,row.dx23_mean,row.dx171_mean,row.dx33_mean,row.dx138_mean,row.dx123_mean,row.GENDER,row.dx653_mean,row.dx170_mean,row.dx162_mean,row.dx246_mean,row.dx93_mean,row.dx47_mean,row.dx6_mean,row.dx3_mean,row.dx200_mean,row.dx206_mean,row.dx251_mean,row.dx95_mean,row.dx241_mean,row.is_emergency,row.dx110_mean,row.dx84_mean,row.dx197_mean,row.dx143_mean,row.pO2_mean,row.dx211_mean,row.dx163_mean,row.dx670_mean,row.dx37_mean,row.dx140_mean,row.dx154_mean,row.dx121_mean,row.dx198_mean,row.dx651_mean,row.dx245_mean,row.dx207_mean,row.dx99_mean,row.dx50_mean,row.dx159_mean,row.dx27_mean,row.dx164_mean,row.dx253_mean,row.Hematocrit_mean,row.dx122_mean,row.dx35_mean,row.dx51_mean,row.Magnesium_mean,row.dx63_mean,row.dx2620_mean,row.dx22_mean,row.dx128_mean,row.dx133_mean,row.dx87_mean,row.dx82_mean,row.dx136_mean,row.dx201_mean,row.dx233_mean,row.dx152_mean,row.dx255_mean,row.dx249_mean,row.dx45_mean,row.dx135_mean,row.dx156_mean,row.Sodium_mean,row.dx132_mean,row.dx57_mean,row.dx228_mean,row.dx14_mean,row.dx256_mean,row.dx652_mean,row.dx130_mean,row.dx107_mean,row.dx234_mean,row.dx661_mean,row.dx59_mean,row.dx102_mean,row.dx43_mean,row.dx25_mean,row.dx213_mean,row.dx46_mean,row.dx36_mean,row.dx105_mean,row.dx212_mean,row.dx54_mean,row.dx112_mean,row.dx650_mean,row.dx157_mean,row.dx115_mean,row.dx244_mean,row.dx89_mean,row.dx117_mean,row.dx238_mean,row.dx657_mean,row.dx248_mean,row.dx210_mean,row.dx2618_mean,row.dx217_mean,row.dx145_mean,row.dx2617_mean,row.dx10_mean,row.dx48_mean,row.dx52_mean,row.Albumin_mean,row.dx38_mean,row.dx173_mean,row.dx155_mean,row.dx199_mean,row.dx142_mean,row.dx658_mean,row.dx662_mean,row.dx235_mean,row.dx101_mean,row.Temperature_mean,row.dx53_mean,row.dx149_mean,row.dx134_mean,row.dx24_mean,row.dx259_mean,row.dx12_mean,row.dx11_mean,row.dx113_mean,row.dx205_mean,row.age,row.dx165_mean,row.dx109_mean,row.is_urgent,row.dx247_mean,row.dx2603_mean,row.dx160_mean,row.dx153_mean,row.UreaNitrogen_mean,row.dx7_mean,row.Bilirubin_mean,row.pH_mean,row.admission_LOS_days]\
points_train_all = train_lab_dx_balanced.rdd.map(lambda row: LabeledPoint(row.readmission, \
                                    [row.dx231_mean,row.dx146_mean,row.dx119_mean,row.dx85_mean,row.dx87_mean,row.Chloride_mean,row.dx105_mean,row.dx2613_mean,row.dx2618_mean,row.dx136_mean,row.dx202_mean,row.dx246_mean,row.dx215_mean,row.dx252_mean,row.dx81_mean,row.dx154_mean,row.dx655_mean,row.dx18_mean,row.dx96_mean,row.dx145_mean,row.dx243_mean,row.dx58_mean,row.dx181_mean,row.dx80_mean,row.dx178_mean,row.dx16_mean,row.Creatinine_mean,row.dx77_mean,row.Globulin_mean,row.dx104_mean,row.dx124_mean,row.dx106_mean,row.dx147_mean,row.dx17_mean,row.dx137_mean,row.dx88_mean,row.dx91_mean,row.Triglycerides_mean,row.Albumin_mean,row.dx663_mean,row.dx55_mean,row.dx78_mean,row.pCO2_mean,row.dx90_mean,row.dx161_mean,row.dx203_mean,row.dx32_mean,row.dx209_mean,row.dx95_mean,row.UreaNitrogen_mean,row.dx2616_mean,row.dx204_mean,row.dx2615_mean,row.dx98_mean,row.dx19_mean,row.dx94_mean,row.dx83_mean,row.dx108_mean,row.dx49_mean,row.dx175_mean,row.dx208_mean,row.dx6_mean,row.dx2619_mean,row.dx139_mean,row.dx659_mean,row.WBC_mean,row.dx144_mean,row.dx116_mean,row.dx126_mean,row.dx100_mean,row.dx29_mean,row.Bicarbonate_mean,row.dx142_mean,row.dx26_mean,row.dx2_mean,row.dx242_mean,row.dx15_mean,row.dx148_mean,row.dx131_mean,row.dx226_mean,row.dx127_mean,row.dx212_mean,row.dx2608_mean,row.dx166_mean,row.dx241_mean,row.dx2607_mean,row.dx62_mean,row.Glucose_mean,row.dx111_mean,row.dx120_mean,row.dx103_mean,row.dx654_mean,row.dx151_mean,row.dx125_mean,row.dx114_mean,row.dx158_mean,row.dx660_mean,row.Hemoglobin_mean,row.dx42_mean,row.Ammonia_mean,row.dx92_mean,row.dx230_mean,row.dx64_mean,row.dx76_mean,row.dx135_mean,row.dx228_mean,row.Potassium_mean,row.dx107_mean,row.dx129_mean,row.dx244_mean,row.dx86_mean,row.dx8_mean,row.dx39_mean,row.dx237_mean,row.dx225_mean,row.dx199_mean,row.dx239_mean,row.dx23_mean,row.dx171_mean,row.dx33_mean,row.dx138_mean,row.dx247_mean,row.dx653_mean,row.dx170_mean,row.dx162_mean,row.Amylase_mean,row.dx118_mean,row.dx232_mean,row.dx134_mean,row.dx3_mean,row.dx200_mean,row.dx206_mean,row.dx251_mean,row.dx79_mean,row.dx650_mean,row.is_emergency,row.dx110_mean,row.dx84_mean,row.dx197_mean,row.dx60_mean,row.pO2_mean,row.dx44_mean,row.dx211_mean,row.dx163_mean,row.dx670_mean,row.dx37_mean,row.dx140_mean,row.dx93_mean,row.dx121_mean,row.dx198_mean,row.dx651_mean,row.LOS,row.dx99_mean,row.dx50_mean,row.dx159_mean,row.dx27_mean,row.dx164_mean,row.dx253_mean,row.Hematocrit_mean,row.dx122_mean,row.dx248_mean,row.GENDER,row.Magnesium_mean,row.dx63_mean,row.dx2620_mean,row.dx22_mean,row.dx128_mean,row.dx133_mean,row.dx207_mean,row.dx82_mean,row.dx168_mean,row.dx201_mean,row.dx250_mean,row.dx233_mean,row.AsparateAminotransferaseAST_mean,row.dx152_mean,row.dx255_mean,row.dx249_mean,row.dx45_mean,row.dx4_mean,row.dx156_mean,row.Sodium_mean,row.dx132_mean,row.dx57_mean,row.dx13_mean,row.dx14_mean,row.dx256_mean,row.dx652_mean,row.dx130_mean,row.dx165_mean,row.dx234_mean,row.dx661_mean,row.dx59_mean,row.dx102_mean,row.dx43_mean,row.dx245_mean,row.dx213_mean,row.dx46_mean,row.dx36_mean,row.dx257_mean,row.dx2621_mean,row.dx54_mean,row.dx112_mean,row.Protein_mean,row.dx157_mean,row.dx115_mean,row.dx97_mean,row.dx89_mean,row.dx117_mean,row.dx238_mean,row.dx657_mean,row.dx35_mean,row.dx210_mean,row.dx52_mean,row.dx217_mean,row.dx141_mean,row.dx2617_mean,row.dx10_mean,row.dx48_mean,row.dx51_mean,row.dx236_mean,row.dx38_mean,row.dx173_mean,row.dx155_mean,row.dx5_mean,row.dx229_mean,row.dx658_mean,row.dx662_mean,row.dx235_mean,row.dx101_mean,row.Temperature_mean,row.dx53_mean,row.dx149_mean,row.dx143_mean,row.dx47_mean,row.dx24_mean,row.dx259_mean,row.dx12_mean,row.dx11_mean,row.dx113_mean,row.dx205_mean,row.age,row.dx25_mean,row.dx109_mean,row.is_urgent,row.dx123_mean,row.dx2603_mean,row.dx160_mean,row.dx153_mean,row.dx7_mean,row.pH_mean,row.dx40_mean]))
points_validation_all = validation_lab_dx_balanced.rdd.map(lambda row: LabeledPoint(row.readmission, \
                                    [row.dx231_mean,row.dx146_mean,row.dx119_mean,row.dx85_mean,row.dx87_mean,row.Chloride_mean,row.dx105_mean,row.dx2613_mean,row.dx2618_mean,row.dx136_mean,row.dx202_mean,row.dx246_mean,row.dx215_mean,row.dx252_mean,row.dx81_mean,row.dx154_mean,row.dx655_mean,row.dx18_mean,row.dx96_mean,row.dx145_mean,row.dx243_mean,row.dx58_mean,row.dx181_mean,row.dx80_mean,row.dx178_mean,row.dx16_mean,row.Creatinine_mean,row.dx77_mean,row.Globulin_mean,row.dx104_mean,row.dx124_mean,row.dx106_mean,row.dx147_mean,row.dx17_mean,row.dx137_mean,row.dx88_mean,row.dx91_mean,row.Triglycerides_mean,row.Albumin_mean,row.dx663_mean,row.dx55_mean,row.dx78_mean,row.pCO2_mean,row.dx90_mean,row.dx161_mean,row.dx203_mean,row.dx32_mean,row.dx209_mean,row.dx95_mean,row.UreaNitrogen_mean,row.dx2616_mean,row.dx204_mean,row.dx2615_mean,row.dx98_mean,row.dx19_mean,row.dx94_mean,row.dx83_mean,row.dx108_mean,row.dx49_mean,row.dx175_mean,row.dx208_mean,row.dx6_mean,row.dx2619_mean,row.dx139_mean,row.dx659_mean,row.WBC_mean,row.dx144_mean,row.dx116_mean,row.dx126_mean,row.dx100_mean,row.dx29_mean,row.Bicarbonate_mean,row.dx142_mean,row.dx26_mean,row.dx2_mean,row.dx242_mean,row.dx15_mean,row.dx148_mean,row.dx131_mean,row.dx226_mean,row.dx127_mean,row.dx212_mean,row.dx2608_mean,row.dx166_mean,row.dx241_mean,row.dx2607_mean,row.dx62_mean,row.Glucose_mean,row.dx111_mean,row.dx120_mean,row.dx103_mean,row.dx654_mean,row.dx151_mean,row.dx125_mean,row.dx114_mean,row.dx158_mean,row.dx660_mean,row.Hemoglobin_mean,row.dx42_mean,row.Ammonia_mean,row.dx92_mean,row.dx230_mean,row.dx64_mean,row.dx76_mean,row.dx135_mean,row.dx228_mean,row.Potassium_mean,row.dx107_mean,row.dx129_mean,row.dx244_mean,row.dx86_mean,row.dx8_mean,row.dx39_mean,row.dx237_mean,row.dx225_mean,row.dx199_mean,row.dx239_mean,row.dx23_mean,row.dx171_mean,row.dx33_mean,row.dx138_mean,row.dx247_mean,row.dx653_mean,row.dx170_mean,row.dx162_mean,row.Amylase_mean,row.dx118_mean,row.dx232_mean,row.dx134_mean,row.dx3_mean,row.dx200_mean,row.dx206_mean,row.dx251_mean,row.dx79_mean,row.dx650_mean,row.is_emergency,row.dx110_mean,row.dx84_mean,row.dx197_mean,row.dx60_mean,row.pO2_mean,row.dx44_mean,row.dx211_mean,row.dx163_mean,row.dx670_mean,row.dx37_mean,row.dx140_mean,row.dx93_mean,row.dx121_mean,row.dx198_mean,row.dx651_mean,row.LOS,row.dx99_mean,row.dx50_mean,row.dx159_mean,row.dx27_mean,row.dx164_mean,row.dx253_mean,row.Hematocrit_mean,row.dx122_mean,row.dx248_mean,row.GENDER,row.Magnesium_mean,row.dx63_mean,row.dx2620_mean,row.dx22_mean,row.dx128_mean,row.dx133_mean,row.dx207_mean,row.dx82_mean,row.dx168_mean,row.dx201_mean,row.dx250_mean,row.dx233_mean,row.AsparateAminotransferaseAST_mean,row.dx152_mean,row.dx255_mean,row.dx249_mean,row.dx45_mean,row.dx4_mean,row.dx156_mean,row.Sodium_mean,row.dx132_mean,row.dx57_mean,row.dx13_mean,row.dx14_mean,row.dx256_mean,row.dx652_mean,row.dx130_mean,row.dx165_mean,row.dx234_mean,row.dx661_mean,row.dx59_mean,row.dx102_mean,row.dx43_mean,row.dx245_mean,row.dx213_mean,row.dx46_mean,row.dx36_mean,row.dx257_mean,row.dx2621_mean,row.dx54_mean,row.dx112_mean,row.Protein_mean,row.dx157_mean,row.dx115_mean,row.dx97_mean,row.dx89_mean,row.dx117_mean,row.dx238_mean,row.dx657_mean,row.dx35_mean,row.dx210_mean,row.dx52_mean,row.dx217_mean,row.dx141_mean,row.dx2617_mean,row.dx10_mean,row.dx48_mean,row.dx51_mean,row.dx236_mean,row.dx38_mean,row.dx173_mean,row.dx155_mean,row.dx5_mean,row.dx229_mean,row.dx658_mean,row.dx662_mean,row.dx235_mean,row.dx101_mean,row.Temperature_mean,row.dx53_mean,row.dx149_mean,row.dx143_mean,row.dx47_mean,row.dx24_mean,row.dx259_mean,row.dx12_mean,row.dx11_mean,row.dx113_mean,row.dx205_mean,row.age,row.dx25_mean,row.dx109_mean,row.is_urgent,row.dx123_mean,row.dx2603_mean,row.dx160_mean,row.dx153_mean,row.dx7_mean,row.pH_mean,row.dx40_mean]))
points_test_all = test_lab_dx_balanced.rdd.map(lambda row: LabeledPoint(row.readmission, \
                                    [row.dx231_mean,row.dx146_mean,row.dx119_mean,row.dx85_mean,row.dx87_mean,row.Chloride_mean,row.dx105_mean,row.dx2613_mean,row.dx2618_mean,row.dx136_mean,row.dx202_mean,row.dx246_mean,row.dx215_mean,row.dx252_mean,row.dx81_mean,row.dx154_mean,row.dx655_mean,row.dx18_mean,row.dx96_mean,row.dx145_mean,row.dx243_mean,row.dx58_mean,row.dx181_mean,row.dx80_mean,row.dx178_mean,row.dx16_mean,row.Creatinine_mean,row.dx77_mean,row.Globulin_mean,row.dx104_mean,row.dx124_mean,row.dx106_mean,row.dx147_mean,row.dx17_mean,row.dx137_mean,row.dx88_mean,row.dx91_mean,row.Triglycerides_mean,row.Albumin_mean,row.dx663_mean,row.dx55_mean,row.dx78_mean,row.pCO2_mean,row.dx90_mean,row.dx161_mean,row.dx203_mean,row.dx32_mean,row.dx209_mean,row.dx95_mean,row.UreaNitrogen_mean,row.dx2616_mean,row.dx204_mean,row.dx2615_mean,row.dx98_mean,row.dx19_mean,row.dx94_mean,row.dx83_mean,row.dx108_mean,row.dx49_mean,row.dx175_mean,row.dx208_mean,row.dx6_mean,row.dx2619_mean,row.dx139_mean,row.dx659_mean,row.WBC_mean,row.dx144_mean,row.dx116_mean,row.dx126_mean,row.dx100_mean,row.dx29_mean,row.Bicarbonate_mean,row.dx142_mean,row.dx26_mean,row.dx2_mean,row.dx242_mean,row.dx15_mean,row.dx148_mean,row.dx131_mean,row.dx226_mean,row.dx127_mean,row.dx212_mean,row.dx2608_mean,row.dx166_mean,row.dx241_mean,row.dx2607_mean,row.dx62_mean,row.Glucose_mean,row.dx111_mean,row.dx120_mean,row.dx103_mean,row.dx654_mean,row.dx151_mean,row.dx125_mean,row.dx114_mean,row.dx158_mean,row.dx660_mean,row.Hemoglobin_mean,row.dx42_mean,row.Ammonia_mean,row.dx92_mean,row.dx230_mean,row.dx64_mean,row.dx76_mean,row.dx135_mean,row.dx228_mean,row.Potassium_mean,row.dx107_mean,row.dx129_mean,row.dx244_mean,row.dx86_mean,row.dx8_mean,row.dx39_mean,row.dx237_mean,row.dx225_mean,row.dx199_mean,row.dx239_mean,row.dx23_mean,row.dx171_mean,row.dx33_mean,row.dx138_mean,row.dx247_mean,row.dx653_mean,row.dx170_mean,row.dx162_mean,row.Amylase_mean,row.dx118_mean,row.dx232_mean,row.dx134_mean,row.dx3_mean,row.dx200_mean,row.dx206_mean,row.dx251_mean,row.dx79_mean,row.dx650_mean,row.is_emergency,row.dx110_mean,row.dx84_mean,row.dx197_mean,row.dx60_mean,row.pO2_mean,row.dx44_mean,row.dx211_mean,row.dx163_mean,row.dx670_mean,row.dx37_mean,row.dx140_mean,row.dx93_mean,row.dx121_mean,row.dx198_mean,row.dx651_mean,row.LOS,row.dx99_mean,row.dx50_mean,row.dx159_mean,row.dx27_mean,row.dx164_mean,row.dx253_mean,row.Hematocrit_mean,row.dx122_mean,row.dx248_mean,row.GENDER,row.Magnesium_mean,row.dx63_mean,row.dx2620_mean,row.dx22_mean,row.dx128_mean,row.dx133_mean,row.dx207_mean,row.dx82_mean,row.dx168_mean,row.dx201_mean,row.dx250_mean,row.dx233_mean,row.AsparateAminotransferaseAST_mean,row.dx152_mean,row.dx255_mean,row.dx249_mean,row.dx45_mean,row.dx4_mean,row.dx156_mean,row.Sodium_mean,row.dx132_mean,row.dx57_mean,row.dx13_mean,row.dx14_mean,row.dx256_mean,row.dx652_mean,row.dx130_mean,row.dx165_mean,row.dx234_mean,row.dx661_mean,row.dx59_mean,row.dx102_mean,row.dx43_mean,row.dx245_mean,row.dx213_mean,row.dx46_mean,row.dx36_mean,row.dx257_mean,row.dx2621_mean,row.dx54_mean,row.dx112_mean,row.Protein_mean,row.dx157_mean,row.dx115_mean,row.dx97_mean,row.dx89_mean,row.dx117_mean,row.dx238_mean,row.dx657_mean,row.dx35_mean,row.dx210_mean,row.dx52_mean,row.dx217_mean,row.dx141_mean,row.dx2617_mean,row.dx10_mean,row.dx48_mean,row.dx51_mean,row.dx236_mean,row.dx38_mean,row.dx173_mean,row.dx155_mean,row.dx5_mean,row.dx229_mean,row.dx658_mean,row.dx662_mean,row.dx235_mean,row.dx101_mean,row.Temperature_mean,row.dx53_mean,row.dx149_mean,row.dx143_mean,row.dx47_mean,row.dx24_mean,row.dx259_mean,row.dx12_mean,row.dx11_mean,row.dx113_mean,row.dx205_mean,row.age,row.dx25_mean,row.dx109_mean,row.is_urgent,row.dx123_mean,row.dx2603_mean,row.dx160_mean,row.dx153_mean,row.dx7_mean,row.pH_mean,row.dx40_mean]))


# readmission_position = 28
# points_train_all = train_lab_dx_balanced.rdd.map(lambda row: LabeledPoint(row.readmission, \
#                                     list(row)[:readmission_position] + list(row)[readmission_position+1:]))
# points_validation_all = validation_lab_dx_balanced.rdd.map(lambda row: LabeledPoint(row.readmission, \
#                                     list(row)[:readmission_position] + list(row)[readmission_position+1:]))
# points_test_all = test_lab_dx_balanced.rdd.map(lambda row: LabeledPoint(row.readmission, \
#                                     list(row)[:readmission_position] + list(row)[readmission_position+1:]))

log_reg_all = LogisticRegressionWithLBFGS.train(points_train_all)


In [7]:
def check_results_calc(predictionAndLabels):    
    print("Num predict class 0: %s" % predictionAndLabels.map(lambda x: x[0] == 0.0).sum())
    print("Num predict class 1: %s" % predictionAndLabels.map(lambda x: x[0] == 1.0).sum())

    num_correct = predictionAndLabels.map(lambda x: x[0] == x[1]).sum()
    accuracy = num_correct / predictionAndLabels.count()
    print("Accuracy = %s" % accuracy)

    metrics = BinaryClassificationMetrics(predictionAndLabels)

    print("Area under precision-recall curve = %s" % metrics.areaUnderPR)

    print("Area under ROC curve = %s" % metrics.areaUnderROC)

# check results
def check_results(log_reg_model, point_vals):    
    predictionAndLabels = point_vals.map(lambda lp: (float(log_reg_model.predict(lp.features)), lp.label))
    check_results_calc(predictionAndLabels)
    
print('train logistic regression 3 features')
check_results(log_reg_simple, points_train_simple)
print('\nvalidation logistic regression 3 features')
check_results(log_reg_simple, points_validation_simple)
print('\ntest logistic regression 3 features')
check_results(log_reg_simple, points_test_simple)

print('\ntrain logistic regression')
check_results(log_reg_all, points_train_all)
print('\nvalidation logistic regression')
check_results(log_reg_all, points_validation_all)
print('\ntest logistic regression')
check_results(log_reg_all, points_test_all)

train logistic regression 3 features
Num predict class 0: 21574
Num predict class 1: 22939
Accuracy = 0.5633185810886707
Area under precision-recall curve = 0.6860459006088558
Area under ROC curve = 0.5628881312024604

validation logistic regression 3 features
Num predict class 0: 5320
Num predict class 1: 5573
Accuracy = 0.5536583126778665
Area under precision-recall curve = 0.6685582768390615
Area under ROC curve = 0.5536236237147338

test logistic regression 3 features
Num predict class 0: 5264
Num predict class 1: 5832
Accuracy = 0.5717375630857967
Area under precision-recall curve = 0.6862903824299297
Area under ROC curve = 0.5715127044151688

train logistic regression
Num predict class 0: 20256
Num predict class 1: 24257
Accuracy = 0.6404645833801361
Area under precision-recall curve = 0.7443571142633225
Area under ROC curve = 0.6391555305386141

validation logistic regression
Num predict class 0: 5151
Num predict class 1: 5742
Accuracy = 0.6240705039933903
Area under precision-r

In [8]:
from pyspark.ml.linalg import VectorUDT
from pyspark.sql.functions import udf

as_ml = udf(lambda v: v.asML() if v is not None else None, VectorUDT())

data_train=spark.createDataFrame(points_train_all, ['features', 'label']).withColumn('newfeatures',as_ml(col('features')))
data_validation=spark.createDataFrame(points_validation_all, ['features', 'label']).withColumn('newfeatures',as_ml(col('features')))
data_test=spark.createDataFrame(points_test_all, ['features', 'label']).withColumn('newfeatures',as_ml(col('features')))

In [9]:
from pyspark.ml import *
from pyspark.ml.classification import *
from pyspark.ml.feature import *
from pyspark.ml.evaluation import *
from pyspark.ml.feature import *



In [10]:
print(data_train.count())
print(data_validation.count())
print(data_test.count())

44513
10893
11096


In [11]:
data_train.take(1)

[Row(features=DenseVector([0.0, 0.0, 0.0, 0.0, 0.0, 104.25, 105.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 96.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7125, 0.0, 2.749, 0.0, 0.0, 106.0, 0.0, 0.0, 0.0, 0.0, 0.0, 163.6682, 2.9546, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 21.75, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 19.5101, 0.0, 0.0, 0.0, 0.0, 0.0, 24.125, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 41.3209, 0.0, 107.3771, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11.5076, 0.0, 59.9429, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.0265, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 237.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 106.1954, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 100.5405, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 7.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 31.9875, 0.0, 0.0, 1.0, 2.2875, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 138.8243, 0.0, 0.0, 0.0, 0.0, 0.0,

In [12]:
# a = sc.parallelize(data)
data_train.printSchema()

root
 |-- features: vector (nullable = true)
 |-- label: double (nullable = true)
 |-- newfeatures: vector (nullable = true)



In [13]:
data_train.select("label").take(3)

[Row(label=0.0), Row(label=0.0), Row(label=0.0)]

In [14]:
# Random Forest
###########
from pyspark.mllib.classification import *
# Train a RandomForest model.

############

classifier = RandomForestClassifier(labelCol="label", featuresCol="newfeatures", maxDepth=8, numTrees=25)

# Chain indexers and forest in a Pipeline
pipeline = Pipeline(stages=[classifier])
# classifier.train(data_train)

# # Train model.
# trainingFeats=data_train.features
model = pipeline.fit(data_train)

# # Make predictions.
# testFeats=data_validation.features
predictions_rf_train = model.transform(data_train)
predictions_rf_validation = model.transform(data_validation)
predictions_rf_test = model.transform(data_test)

# # Select example rows to display. Comment it out for GBT
# predictions.select("prediction", "label", "probability").show(5)

classifierModel = model.stages[0]
print(classifierModel)  # summary only, 

RandomForestClassificationModel (uid=RandomForestClassifier_479bba33ff73d16960c2) with 25 trees


In [15]:
# Select (prediction, true label) and compute test error
def print_results(prediction_vals):
    evaluator = MulticlassClassificationEvaluator(
        labelCol="label", predictionCol="prediction", metricName="f1")
    f1=evaluator.evaluate(prediction_vals)
    print("f1 = %g" % f1)
    
    evaluator = MulticlassClassificationEvaluator(
        labelCol="label", predictionCol="prediction", metricName="accuracy")
    accuracy=evaluator.evaluate(prediction_vals)
    print("Accuracy = %g" % accuracy)
    
    evaluator = BinaryClassificationEvaluator(
        labelCol="label", rawPredictionCol="prediction", metricName="areaUnderPR")
    areaUnderPR = evaluator.evaluate(prediction_vals)
    print("Area under precision-recall curve = %g" % (areaUnderPR))
    
    evaluator = BinaryClassificationEvaluator(
        labelCol="label", rawPredictionCol="prediction", metricName="areaUnderROC")
    areaUnderROC = evaluator.evaluate(prediction_vals)
    print("Area under ROC curve = %g" % (areaUnderROC))

print('train random forest')
print_results(predictions_rf_train)
print('\nvalidation random forest')
print_results(predictions_rf_validation)
print('\ntest random forest')
print_results(predictions_rf_test)

train random forest
f1 = 0.706487
Accuracy = 0.706445
Area under precision-recall curve = 0.788916
Area under ROC curve = 0.706308

validation random forest
f1 = 0.625431
Accuracy = 0.626549
Area under precision-recall curve = 0.715133
Area under ROC curve = 0.626716

test random forest
f1 = 0.619702
Accuracy = 0.620494
Area under precision-recall curve = 0.713262
Area under ROC curve = 0.620917


In [16]:
# predictions = model.transform(beforeSampledDF)
# # Select (prediction, true label) and compute test error
evaluator = BinaryClassificationEvaluator(
    labelCol="label", rawPredictionCol="prediction", metricName="areaUnderROC")
# areaUnderROC = evaluator.evaluate(predictions)

# print("areaUnderROC = %g" % (areaUnderROC))

In [17]:
# path = '/group/cdac/test/Projects/Lab_Pipeline_rfModel'
path = 'Projects/Lab_Pipeline_rfModel'
# model.save(path)

In [18]:
# Tuning hyperparameters
# Create ParamGrid for Cross Validation
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

rfc = RandomForestClassifier(labelCol="label", featuresCol="newfeatures")

paramGrid = (ParamGridBuilder()
             .addGrid(rfc.maxDepth, [8, 10])
             .addGrid(rfc.maxBins, [16, 32])
             .addGrid(rfc.numTrees, [20, 30])
             .build())

In [19]:
# Create 3-fold CrossValidator
cv = CrossValidator(estimator=rfc, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=3)

# Run cross validations.  This can take about 15 minutes since it is training over 3x2x3x3 = 54 Random Forests!
cvModel = cv.fit(data_train)

In [20]:
classifierModel = cvModel.bestModel

print(classifierModel.getNumTrees)
# print(classifierModel.featureImportances)
# print(classifierModel.trees)

20


In [21]:
### Evaluate tuned random forest ###
predictions_rftuned_train = cvModel.transform(data_train)
predictions_rftuned_validation = cvModel.transform(data_validation)
predictions_rftuned_test = cvModel.transform(data_test)

print('train random forest tuned')
print_results(predictions_rftuned_train)
print('\nvalidation random forest tuned')
print_results(predictions_rftuned_validation)
print('\ntest random forest tuned')
print_results(predictions_rftuned_test)


train random forest tuned
f1 = 0.772068
Accuracy = 0.772022
Area under precision-recall curve = 0.837345
Area under ROC curve = 0.772559

validation random forest tuned
f1 = 0.618925
Accuracy = 0.624438
Area under precision-recall curve = 0.709336
Area under ROC curve = 0.624803

test random forest tuned
f1 = 0.607255
Accuracy = 0.612833
Area under precision-recall curve = 0.701593
Area under ROC curve = 0.613918


In [22]:
# gradient boosted tree
###########
from pyspark.mllib.classification import *
# Train a gbt model.

############

# classifier = GBTClassifier(labelCol="label", featuresCol="newfeatures",maxDepth=20).setMaxIter(20)          #  areaUnderROC = 0.994652, computational intense
classifier = GBTClassifier(labelCol="label", featuresCol="newfeatures",maxDepth=8).setMaxIter(8)


# Chain indexers and forest in a Pipeline
pipeline = Pipeline(stages=[classifier])
# classifier.train(data_train)

# # Train model.
# trainingFeats=data_train.features
model = pipeline.fit(data_train)

# # Make predictions.
# testFeats=data_validation.features
predictions_gbt_train = model.transform(data_train)
predictions_gbt_validation = model.transform(data_validation)
predictions_gbt_test = model.transform(data_test)

# # Select example rows to display. Comment it out for GBT
# predictions.select("prediction", "label", "probability").show(5)

classifierModel = model.stages[0]
print(classifierModel)  # summary only

GBTClassificationModel (uid=GBTClassifier_40399dca6297003910b4) with 8 trees


In [23]:
print('train gbt')
print_results(predictions_gbt_train)
print('\nvalidation gbt')
print_results(predictions_gbt_validation)
print('\ntest gbt')
print_results(predictions_gbt_test)

train gbt
f1 = 0.737279
Accuracy = 0.737515
Area under precision-recall curve = 0.811443
Area under ROC curve = 0.736661

validation gbt
f1 = 0.601931
Accuracy = 0.60369
Area under precision-recall curve = 0.695291
Area under ROC curve = 0.603893

test gbt
f1 = 0.605928
Accuracy = 0.607606
Area under precision-recall curve = 0.700979
Area under ROC curve = 0.608205


In [24]:
from pyspark.mllib.classification import *
# Train a MultilayerPerceptronClassifier

############
layers = [4, 5, 4, 3]

# create the trainer and set its parameters
classifier = MultilayerPerceptronClassifier(labelCol="label", featuresCol="newfeatures",maxIter=100, layers=layers, blockSize=128, seed=1234)

# Chain indexers and forest in a Pipeline
pipeline = Pipeline(stages=[classifier])
# classifier.train(data_train)

# # Train model.
# trainingFeats=data_train.features
model = pipeline.fit(data_train)

# # Make predictions.
# testFeats=data_validation.features
predictions_mlp_train = model.transform(data_train)
predictions_mlp_validation = model.transform(data_validation)
predictions_mlp_test = model.transform(data_test)

classifierModel = model.stages[0]
print(classifierModel)  # summary only

MultilayerPerceptronClassifier_41e08bbe5e7da1850a1e


In [25]:
# predictions_mlp_validation.take(1)

# predictionAndLabels_mlp_validation = predictions_mlp_validation.select("prediction", "label").rdd
# print('train mlp')
# # check_results_calc(predictionAndLabels)
# print('\nvalidation mlp')
# check_results_calc(predictionAndLabels_mlp_validation)