In [1]:
from pyspark.sql import SparkSession

In [2]:
spark=SparkSession.builder.appName("Python Project Pump_it_up").config("spark.some.config.option","some-value").getOrCreate()

In [3]:
status = spark.read.format('com.databricks.spark.csv').options(delimiter = ',',header='true', inferSchema='true').load('/FileStore/tables/pumpitup_id_status.csv')
train = spark.read.format('com.databricks.spark.csv').options(delimiter = ',',header='true', inferSchema='true').load('/FileStore/tables/pumpitup_data2.csv')
test = spark.read.format('com.databricks.spark.csv').options(delimiter = ',',header='true', inferSchema='true').load('/FileStore/tables/pumpitup_data1.csv')

In [4]:
import pandas as pd
import numpy as np

In [5]:
train.printSchema()

In [6]:
print ("Our dataset has %d rows." % train.count())

In [7]:
train_df = train.toPandas()
status_df = status.toPandas()

In [8]:
from pyspark.sql.functions import col,sum
display(train.select(*(sum(col(c).isNull().cast("int")).alias(c) for c in train.columns)))

id,amount_tsh,date_recorded,funder,gps_height,installer,longitude,latitude,wpt_name,num_private,basin,subvillage,region,region_code,district_code,lga,ward,population,public_meeting,recorded_by,scheme_management,scheme_name,permit,construction_year,extraction_type,extraction_type_group,extraction_type_class,management,management_group,payment,payment_type,water_quality,quality_group,quantity,quantity_group,source,source_type,source_class,waterpoint_type,waterpoint_type_group
0,0,0,3635,0,3655,0,0,0,0,0,371,0,0,0,0,0,0,3334,0,3877,28166,3056,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


In [9]:
from numpy import mean, std
from pyspark.sql.functions import *

In [10]:
def removeOutliers(df):
    """Remove the outliers"""
    amount_tsh_loc = df.columns.get_loc('amount_tsh')
    gps_height_loc = df.columns.get_loc('gps_height')
    population_loc = df.columns.get_loc('population')
    for i in [amount_tsh_loc, gps_height_loc, population_loc]:
      col_mean = df.iloc[:,i].mean()
      col_std = df.iloc[:,i].std()
      Q3 = col_mean + 3*col_std
      for j in range(len(df)):
        if df.iloc[j,i] > Q3 :
          df.iloc[j,i] = Q3
    
    return df

In [11]:
train_df = removeOutliers(train_df)  

In [12]:
train = spark.createDataFrame(train_df)

In [13]:
train = train.withColumn('day',dayofmonth('date_recorded')).withColumn('month',month('date_recorded')).withColumn('year',year('date_recorded'))

In [14]:
train_df = train.toPandas()

In [15]:
for i in ['funder','installer','subvillage','public_meeting','scheme_management','permit']:
  train_df = train_df.fillna(train_df[i].value_counts().index[0])

In [16]:
train_df['public_meeting'] = train_df['public_meeting'].astype('str')
train_df['permit'] = train_df['permit'].astype('str')

In [17]:
train_df['age'] = train_df['year'] - train_df['construction_year']

In [18]:
train_df['age'].describe()

In [19]:
train_df[train_df['age']<2000]['age'].describe()

In [20]:
train_df.columns.get_loc('age')

In [21]:
train_df['age'] = np.where(train_df['age']>2000, 53, train_df['age'])

In [22]:
train_df['age'] = np.where(train_df['age']<0, 0, train_df['age'])

In [23]:
train_df['age'].describe()

In [24]:
from numpy import cos, sin

In [25]:
train_df['x'] = cos(train_df['latitude']) * cos(train_df['longitude'])
train_df['y'] = cos(train_df['latitude']) * sin(train_df['longitude']) 
train_df['z']= sin(train_df['latitude']) 

In [26]:
train = spark.createDataFrame(train_df)

In [27]:
train.columns

In [28]:
display(train.select(*(sum(col(c).isNull().cast("int")).alias(c) for c in train.columns)))

id,amount_tsh,date_recorded,funder,gps_height,installer,longitude,latitude,wpt_name,num_private,basin,subvillage,region,region_code,district_code,lga,ward,population,public_meeting,recorded_by,scheme_management,scheme_name,permit,construction_year,extraction_type,extraction_type_group,extraction_type_class,management,management_group,payment,payment_type,water_quality,quality_group,quantity,quantity_group,source,source_type,source_class,waterpoint_type,waterpoint_type_group,day,month,year,age,x,y,z
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


In [29]:
train = train.drop(
  'id',
 'date_recorded',
 'funder',
 'installer',
 'longitude',
 'latitude',
 'wpt_name',
 'num_private',
 'subvillage',
 'region_code',
 'district_code',
 'lga',
 'recorded_by',
 'scheme_name',
 'construction_year',
 'extraction_type_group',
 'extraction_type_class',
 'management_group',
 'payment_type',
 'quality_group',
 'quantity_group',
 'source_type',
 'source_class',
 'waterpoint_type_group',
 'day',
 'month',
 'year')

In [30]:
train.columns

In [31]:
for i in ['basin','region','ward','scheme_management','extraction_type','management','payment','water_quality','quantity','source','waterpoint_type']:
  train.groupBy(i).count().sort(col(i).desc()).show()

In [32]:
train_df = train.toPandas()

In [33]:
for i in ['basin','region','ward','scheme_management','extraction_type','management','payment','water_quality','quantity','source','waterpoint_type']:
  for j in range(len(train_df)):
    if train_df.iloc[j,train_df.columns.get_loc(i)] == 'unknown':
      train_df.iloc[j,train_df.columns.get_loc(i)] = 'other'

In [34]:
train = spark.createDataFrame(train_df)

In [35]:
train.columns

In [36]:
from pyspark.ml.feature import StringIndexer

In [37]:
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer

indexers = [StringIndexer(inputCol=column, outputCol=column+"index").fit(train) for column in ["basin",'region','ward','public_meeting','permit','scheme_management','extraction_type','management','payment','water_quality','quantity','source','waterpoint_type'] ]


pipeline = Pipeline(stages=indexers)
train = pipeline.fit(train).transform(train)


display(train)

amount_tsh,gps_height,basin,region,ward,population,public_meeting,scheme_management,permit,extraction_type,management,payment,water_quality,quantity,source,waterpoint_type,age,x,y,z,basinindex,regionindex,wardindex,public_meetingindex,permitindex,scheme_managementindex,extraction_typeindex,managementindex,paymentindex,water_qualityindex,quantityindex,sourceindex,waterpoint_typeindex
6000.0,1390.0,Lake Nyasa,Iringa,Mundindi,109.0,True,VWC,False,gravity,vwc,pay annually,soft,enough,spring,communal standpipe,12,0.8433323014130844,0.3373986455318504,0.4182735748126857,6.0,0.0,545.0,0.0,1.0,0.0,0.0,0.0,5.0,0.0,0.0,0.0,0.0
0.0,1399.0,Lake Victoria,Mara,Natta,280.0,Government Of Tanzania,Other,True,gravity,wug,never pay,soft,insufficient,rainwater harvesting,communal standpipe,3,0.5398050788475607,0.0767569311213749,-0.8382832757339295,0.0,14.0,136.0,2.0,0.0,9.0,0.0,1.0,0.0,0.0,1.0,4.0,0.0
25.0,686.0,Pangani,Manyara,Ngorika,250.0,True,VWC,True,gravity,vwc,pay per bucket,soft,enough,dam,communal standpipe multiple,4,-0.7557333249723,0.1836974060034108,0.6285876267903547,1.0,18.0,1591.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,7.0,3.0
0.0,263.0,Ruvuma / Southern Coast,Mtwara,Nanyumbu,58.0,True,VWC,True,submersible,vwc,never pay,soft,dry,machine dbh,communal standpipe multiple,27,0.1122761014608378,0.1126474220883527,0.9872713078670948,7.0,17.0,642.0,0.0,0.0,0.0,3.0,0.0,0.0,0.0,2.0,2.0,3.0
0.0,0.0,Lake Victoria,Kagera,Nyakasimbi,0.0,True,Government Of Tanzania,True,gravity,other,never pay,soft,seasonal,rainwater harvesting,communal standpipe,53,-0.2416583162038554,0.0708209377769432,-0.9677735545993762,0.0,6.0,1416.0,0.0,0.0,2.0,0.0,6.0,0.0,0.0,3.0,4.0,0.0
20.0,0.0,Pangani,Tanga,Moa,1.0,True,VWC,True,submersible,vwc,pay per bucket,salty,enough,other,communal standpipe multiple,2,0.0051556744886271,0.0529226734484029,0.9985853041456404,1.0,11.0,875.0,0.0,0.0,0.0,3.0,0.0,2.0,1.0,0.0,8.0,3.0
0.0,0.0,Internal,Shinyanga,Samuye,0.0,True,VWC,True,swn 80,vwc,never pay,soft,enough,machine dbh,hand pump,53,0.2976007505200422,-0.7545271788138771,0.5849124120080537,3.0,1.0,125.0,0.0,0.0,0.0,4.0,0.0,0.0,0.0,0.0,2.0,1.0
0.0,0.0,Lake Tanganyika,Shinyanga,Chambo,0.0,True,Government Of Tanzania,True,nira/tanira,wug,other,milky,enough,shallow well,hand pump,53,-0.1672712677754079,-0.4362955013372031,0.8841190861471809,4.0,1.0,770.0,0.0,0.0,2.0,1.0,1.0,1.0,3.0,0.0,1.0,1.0
0.0,0.0,Lake Tanganyika,Tabora,Itetemia,0.0,True,VWC,True,india mark ii,vwc,never pay,salty,seasonal,machine dbh,hand pump,53,0.1145181427241638,0.4049136158446174,0.9071552009941,4.0,15.0,670.0,0.0,0.0,0.0,6.0,0.0,0.0,1.0,3.0,2.0,1.0
0.0,0.0,Lake Victoria,Kagera,Kaisho,0.0,True,Government Of Tanzania,True,nira/tanira,vwc,never pay,soft,enough,shallow well,hand pump,53,0.2174564556928259,-0.2190005783798581,-0.951184228499822,0.0,6.0,333.0,0.0,0.0,2.0,1.0,0.0,0.0,0.0,0.0,1.0,1.0


In [38]:
train = train.drop("basin",'region','ward','public_meeting','permit','scheme_management','extraction_type','management','payment','water_quality','quantity','source','waterpoint_type')

In [39]:
display(train)

amount_tsh,gps_height,population,age,x,y,z,basinindex,regionindex,wardindex,public_meetingindex,permitindex,scheme_managementindex,extraction_typeindex,managementindex,paymentindex,water_qualityindex,quantityindex,sourceindex,waterpoint_typeindex
6000.0,1390.0,109.0,12,0.8433323014130844,0.3373986455318504,0.4182735748126857,6.0,0.0,545.0,0.0,1.0,0.0,0.0,0.0,5.0,0.0,0.0,0.0,0.0
0.0,1399.0,280.0,3,0.5398050788475607,0.0767569311213749,-0.8382832757339295,0.0,14.0,136.0,2.0,0.0,9.0,0.0,1.0,0.0,0.0,1.0,4.0,0.0
25.0,686.0,250.0,4,-0.7557333249723,0.1836974060034108,0.6285876267903547,1.0,18.0,1591.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,7.0,3.0
0.0,263.0,58.0,27,0.1122761014608378,0.1126474220883527,0.9872713078670948,7.0,17.0,642.0,0.0,0.0,0.0,3.0,0.0,0.0,0.0,2.0,2.0,3.0
0.0,0.0,0.0,53,-0.2416583162038554,0.0708209377769432,-0.9677735545993762,0.0,6.0,1416.0,0.0,0.0,2.0,0.0,6.0,0.0,0.0,3.0,4.0,0.0
20.0,0.0,1.0,2,0.0051556744886271,0.0529226734484029,0.9985853041456404,1.0,11.0,875.0,0.0,0.0,0.0,3.0,0.0,2.0,1.0,0.0,8.0,3.0
0.0,0.0,0.0,53,0.2976007505200422,-0.7545271788138771,0.5849124120080537,3.0,1.0,125.0,0.0,0.0,0.0,4.0,0.0,0.0,0.0,0.0,2.0,1.0
0.0,0.0,0.0,53,-0.1672712677754079,-0.4362955013372031,0.8841190861471809,4.0,1.0,770.0,0.0,0.0,2.0,1.0,1.0,1.0,3.0,0.0,1.0,1.0
0.0,0.0,0.0,53,0.1145181427241638,0.4049136158446174,0.9071552009941,4.0,15.0,670.0,0.0,0.0,0.0,6.0,0.0,0.0,1.0,3.0,2.0,1.0
0.0,0.0,0.0,53,0.2174564556928259,-0.2190005783798581,-0.951184228499822,0.0,6.0,333.0,0.0,0.0,2.0,1.0,0.0,0.0,0.0,0.0,1.0,1.0


In [40]:
train_df = train.toPandas()

In [41]:
train_df.head()

In [42]:
from sklearn.preprocessing import StandardScaler
std = StandardScaler()
train_df = std.fit_transform(train_df)

In [43]:
train_df = pd.DataFrame(train_df)

In [44]:
train_df = pd.concat([train_df,status_df["status_group"]],axis = 1)

In [45]:
train = spark.createDataFrame(train_df)

In [46]:
display(train)

0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,status_group
6.050930695119611,1.041254320199537,-0.1799035539015671,-0.8009102213062055,1.3962004320846364,0.7547120761043653,0.5399149399866159,1.1947475600798665,-1.3042273718195538,-0.0272303781146455,-0.3793480326893073,1.0299220179507251,-0.6062367125555588,-0.7567501099829999,-0.5205491879544772,2.242138619282834,-0.3296180002353916,-0.7303108282747479,-1.021633243900815,-0.8258501885606666,functional
-0.2633684888628234,1.0542392836808832,0.4284872985008684,-1.238287108098577,0.8289374581663396,0.2352943295026361,-1.3381129327561836,-1.2968586729872116,1.1313472618212952,-0.8487870849192739,3.464936389684633,-0.6791339622222851,3.328472719896385,-0.7567501099829999,0.006935446114612,-0.9130649074474269,-0.3296180002353916,0.2802527610634314,1.5950588496553757,-0.8258501885606666,functional
-0.2370589089295632,0.0255416212142464,0.3217520612372832,-1.1896896762327578,-1.592298309632278,0.4484097957256184,0.8542466405155031,-0.8815909674760318,1.8272257285758235,2.07386574735636,-0.3793480326893073,-0.6791339622222851,-0.6062367125555588,-0.7567501099829999,-0.5205491879544772,0.3490165032446776,-0.3296180002353916,-0.7303108282747479,3.557577919822519,1.987234081753981,functional
-0.2633684888628234,-0.5847516624090177,-0.3613534572496619,-0.0719487433189201,0.0299272175073271,0.3068184060867028,1.390328991478833,1.610015265591046,1.6532561118871911,0.1676131440370634,-0.3793480326893073,-0.6791339622222851,-0.6062367125555588,0.4173230544966715,-0.5205491879544772,-0.9130649074474269,-0.3296180002353916,1.2908163504016106,0.2867128028772806,1.987234081753981,non functional
-0.2633684888628234,-0.964201150808352,-0.5677082492925933,1.1915844851923745,-0.6315419207119726,0.2234648328551737,-1.531646836815326,-1.2968586729872116,-0.2604096716877614,1.7223439290414213,-0.3793480326893073,-0.6791339622222851,0.268143161322651,-0.7567501099829999,2.6443586164600585,-0.9130649074474269,-0.3296180002353916,2.30137993973979,1.5950588496553757,-0.8258501885606666,functional
-0.2423208249162153,-0.964201150808352,-0.5641504080504738,-1.2868845399643958,-0.1702704829099273,0.1877964205803391,1.4072386922647335,-0.8815909674760318,0.6094384117553989,0.6356393364220962,-0.3793480326893073,-0.6791339622222851,-0.6062367125555588,0.4173230544966715,-0.5205491879544772,0.3490165032446776,0.8945478878465305,-0.7303108282747479,4.211750943211566,1.987234081753981,functional
-0.2633684888628234,-0.964201150808352,-0.5677082492925933,1.1915844851923745,0.3762810327503556,-1.4213235305548366,0.7889704290384747,-0.0510555564536725,-1.130257755130922,-0.8708827420704986,-0.3793480326893073,-0.6791339622222851,-0.6062367125555588,0.8086807759898954,-0.5205491879544772,-0.9130649074474269,-0.3296180002353916,-0.7303108282747479,0.2867128028772806,0.1118445682108826,non functional
-0.2633684888628234,-0.964201150808352,-0.5677082492925933,1.1915844851923745,-0.4925197336545456,-0.7871380887268076,1.2361594865805292,0.3642121490575071,-1.130257755130922,0.4247262454331329,-0.3793480326893073,-0.6791339622222851,0.268143161322651,-0.365392388489776,0.006935446114612,-0.2820242021013746,3.342879664010375,-0.7303108282747479,-0.367460220511767,0.1118445682108826,non functional
-0.2633684888628234,-0.964201150808352,-0.5677082492925933,1.1915844851923745,0.0341173754244057,0.8892587422857886,1.270588860582508,0.3642121490575071,1.3053168785099272,0.2238566349674536,-0.3793480326893073,-0.6791339622222851,-0.6062367125555588,1.5913962189763429,-0.5205491879544772,-0.9130649074474269,0.8945478878465305,2.30137993973979,0.2867128028772806,0.1118445682108826,non functional
-0.2633684888628234,-0.964201150808352,-0.5677082492925933,1.1915844851923745,0.2264991099231296,-0.3541036472300391,-1.5068527204238316,-1.2968586729872116,-0.2604096716877614,-0.4530739523018857,-0.3793480326893073,-0.6791339622222851,0.268143161322651,-0.365392388489776,-0.5205491879544772,-0.9130649074474269,-0.3296180002353916,-0.7303108282747479,-0.367460220511767,0.1118445682108826,functional


In [47]:
train = StringIndexer(inputCol='status_group', outputCol='status_group'+"_index").fit(train).transform(train)

In [48]:
train = train.drop('status_group')

In [49]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator, BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
import functools
from pyspark.ml.feature import VectorAssembler

In [50]:
cols = ['0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 '10',
 '11',
 '12',
 '13',
 '14',
 '15',
 '16',
 '17',
 '18',
 '19']

In [51]:
assembler_features = VectorAssembler(inputCols=cols, outputCol='features')
tmp = [assembler_features]
pipeline = Pipeline(stages=tmp)
train = pipeline.fit(train).transform(train)
train.cache()

In [52]:
train = train.drop('0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 '10',
 '11',
 '12',
 '13',
 '14',
 '15',
 '16',
 '17',
 '18',
 '19')

In [53]:
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [54]:
from pyspark.mllib.util import MLUtils
from pyspark.mllib.evaluation import MulticlassMetrics

In [55]:
training, test = train.randomSplit([0.7, 0.3], seed=11)
training.cache()

In [56]:
rf = RandomForestClassifier(labelCol='status_group_index', featuresCol='features',numTrees=200)

In [57]:
model = rf.fit(training)

In [58]:
predictions = model.transform(test)

In [59]:
display(predictions)

status_group_index,features,rawPrediction,probability,prediction
0.0,"List(1, 20, List(), List(-0.26336848886282344, -1.0262404207747833, 0.49964412334325853, -0.16914360705055825, 0.5776431526090501, -0.401322204058304, 1.2320357290485873, 1.610015265591046, 1.6532561118871913, -0.15176953660336673, -0.37934803268930734, -0.6791339622222851, 0.268143161322651, -0.365392388489776, -0.5205491879544772, -0.2820242021013746, 4.5670455520922975, -0.7303108282747479, -0.367460220511767, 0.11184456821088264))","List(1, 3, List(), List(111.26418309982525, 71.06525373957226, 17.670563160602516))","List(1, 3, List(), List(0.5563209154991262, 0.35532626869786127, 0.08835281580301257))",0.0
0.0,"List(1, 20, List(), List(-0.26336848886282344, -1.0233548733344842, -0.21192412508064268, 1.1915844851923743, -0.14514441200620215, 0.020912287236359448, -1.5788432821039375, -0.4663232619648522, 0.4354687950667669, 0.8003524170039532, -0.37934803268930734, 1.0299220179507251, -0.6062367125555588, -0.365392388489776, -0.5205491879544772, -0.9130649074474267, 0.8945478878465305, -0.7303108282747479, -0.367460220511767, 0.11184456821088264))","List(1, 3, List(), List(114.8892209877864, 67.57888457214341, 17.531894440070136))","List(1, 3, List(), List(0.5744461049389321, 0.3378944228607171, 0.0876594722003507))",0.0
0.0,"List(1, 20, List(), List(-0.26336848886282344, -1.0089271361329886, 4.7690536138866655, -1.0924948125011196, 0.4618787180973067, -1.786791277597305, -0.15823716735446275, 1.610015265591046, 2.0011953452644553, 1.9714222460188637, 3.4649363896846332, -0.6791339622222851, 0.268143161322651, 0.4173230544966715, -0.5205491879544772, -0.2820242021013746, -0.3296180002353916, 0.28025276106343144, -1.0216332439008147, 1.987234081753981))","List(1, 3, List(), List(63.89351864464527, 122.13743196329393, 13.969049392060795))","List(1, 3, List(), List(0.3194675932232264, 0.6106871598164698, 0.06984524696030399))",1.0
0.0,"List(1, 20, List(), List(-0.26336848886282344, -1.0060415886926894, 1.21121237176716, 1.1915844851923743, 0.7344185754885079, -0.7958771538010762, 1.0396402267056541, 1.610015265591046, 1.6532561118871913, 1.5957960744480433, -0.37934803268930734, 1.0299220179507251, 0.268143161322651, -0.365392388489776, -0.5205491879544772, -0.9130649074474267, 0.8945478878465305, 0.28025276106343144, -0.367460220511767, 0.11184456821088264))","List(1, 3, List(), List(110.63125926530753, 71.32228046789201, 18.046460266800512))","List(1, 3, List(), List(0.5531562963265375, 0.35661140233946, 0.09023230133400253))",0.0
0.0,"List(1, 20, List(), List(-0.26336848886282344, -1.0017132675322407, -0.4431838058184106, 1.1915844851923743, -0.09188447976063348, -0.12507923128921053, -1.5700275546667872, -0.4663232619648522, 0.4354687950667669, 0.8103958975272372, -0.37934803268930734, 1.0299220179507251, -0.6062367125555588, -0.365392388489776, -0.5205491879544772, -0.9130649074474267, -0.3296180002353916, -0.7303108282747479, -0.367460220511767, 0.11184456821088264))","List(1, 3, List(), List(115.07612969197609, 67.20987646630294, 17.713993841720903))","List(1, 3, List(), List(0.5753806484598806, 0.3360493823315148, 0.08856996920860453))",0.0
0.0,"List(1, 20, List(), List(-0.26336848886282344, -0.9959421726516425, 2.2785647444030115, -1.3840794036960338, 0.5221612123066282, 1.8935415207227584, -0.3561472193279342, 0.7794798545686867, 0.4354687950667669, -1.081795833059462, 1.5427941784976629, -0.6791339622222851, 2.8912827829572803, 2.3741116619627904, 2.1168739823909695, -0.9130649074474267, -0.3296180002353916, -0.7303108282747479, 0.9408858262663283, -0.8258501885606666))","List(1, 3, List(), List(123.16112329439676, 64.61495708271501, 12.223919622888173))","List(1, 3, List(), List(0.615805616471984, 0.32307478541357515, 0.061119598114440885))",0.0
0.0,"List(1, 20, List(), List(-0.26336848886282344, -0.9858427566105955, -0.3293328860705864, 1.1915844851923743, -0.14426594066993442, -0.007737518848523199, -1.578012631681619, -0.4663232619648522, 0.4354687950667669, -0.430978295150661, -0.37934803268930734, 1.0299220179507251, -0.6062367125555588, 0.4173230544966715, -0.5205491879544772, -0.9130649074474267, 4.5670455520922975, -0.7303108282747479, 0.2867128028772806, -0.8258501885606666))","List(1, 3, List(), List(100.03748424661852, 87.11510579292629, 12.847409960455098))","List(1, 3, List(), List(0.5001874212330928, 0.43557552896463164, 0.06423704980227551))",0.0
0.0,"List(1, 20, List(), List(-0.26336848886282344, -0.984399982890446, -0.4858779007238447, 1.1915844851923743, -0.13933474514439875, -0.022794165651154046, -1.5773783172734754, -0.4663232619648522, 0.4354687950667669, -0.430978295150661, -0.37934803268930734, 1.0299220179507251, -0.6062367125555588, 0.4173230544966715, -0.5205491879544772, -0.9130649074474267, -0.3296180002353916, -0.7303108282747479, -1.0216332439008147, -0.8258501885606666))","List(1, 3, List(), List(103.28527579770615, 83.1208573735772, 13.593866828716589))","List(1, 3, List(), List(0.5164263789885309, 0.41560428686788614, 0.06796933414358297))",0.0
0.0,"List(1, 20, List(), List(-0.26336848886282344, -0.984399982890446, -0.4609730120290081, 1.1915844851923743, -0.09032410015519653, -0.11834502150778246, -1.5704883148461914, -0.4663232619648522, 0.4354687950667669, 0.8103958975272372, -0.37934803268930734, 1.0299220179507251, -0.6062367125555588, 1.9827539404695667, -0.5205491879544772, -0.9130649074474267, -0.3296180002353916, -0.7303108282747479, 0.2867128028772806, 0.11184456821088264))","List(1, 3, List(), List(101.68886346321628, 85.38182863841782, 12.929307898365844))","List(1, 3, List(), List(0.5084443173160815, 0.4269091431920892, 0.06464653949182923))",0.0
0.0,"List(1, 20, List(), List(-0.26336848886282344, -0.974300566849399, -0.42539459960781306, -0.2663384707821963, 0.5838971617749195, 1.829909761421742, -0.4632390324671274, 0.7794798545686867, 0.4354687950667669, -1.0275610382337286, -0.37934803268930734, -0.6791339622222851, -0.6062367125555588, 1.200038497483119, -0.5205491879544772, -0.9130649074474267, -0.3296180002353916, -0.7303108282747479, 0.9408858262663283, -0.8258501885606666))","List(1, 3, List(), List(104.09082771432405, 82.70237933322377, 13.206792952452153))","List(1, 3, List(), List(0.5204541385716204, 0.4135118966661189, 0.06603396476226077))",0.0


In [60]:
evaluator = MulticlassClassificationEvaluator(labelCol="status_group_index", predictionCol="prediction", metricName="accuracy")

In [61]:
print('Accuracy', evaluator.evaluate(predictions))