Dataset source: https://archive.ics.uci.edu/ml/datasets/Adult <br/>
PySpark ML Classifier Reference: https://spark.apache.org/docs/2.3.0/ml-classification-regression.html#classification <br/>
<b> <i> Classifier Models Used: </i> </b> <br/>
<ul>
  <li> Logistic Regression </li>
  <li> Naive Bayes </li>
  <li> Decision Tree </li>
  <li> Gradient-boosted Tree </li>
  <li> Random Forest </li>
  <li> Multilayer Perceptron </li>
  <li> One-vs-Rest / One-vs-All (wrapping Logistic Regression and Random Forest) </li>
</ul>
<hr/>

In [2]:
# Computational and Visualisation Packages
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from ggplot import *

# Pyspark Packages
from pyspark.sql import functions as F
from pyspark.sql.functions import col, desc, trim
from pyspark.sql.types import *
from pyspark.ml import Pipeline
from pyspark.ml.classification import NaiveBayes, LogisticRegression, DecisionTreeClassifier, GBTClassifier, RandomForestClassifier, OneVsRest, MultilayerPerceptronClassifier
from pyspark.ml.feature import OneHotEncoderEstimator, StringIndexer, VectorAssembler
from pyspark.ml.evaluation import RegressionEvaluator, BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

In [3]:
# Load the header-less Adult CSV (columns arrive as _c0.._c14) and the
# country lookup table registered in the metastore.
adult_gov_data = spark.read.csv('/databricks-datasets/adult/adult.data')
country_codes = spark.sql('SELECT * FROM country_codes')

# Positional CSV column -> (output name, cast to double?). The list position is
# the source column index (_c<position>) and also fixes the output column order.
adult_column_spec = [
    ('age', True), ('workclass', False), ('fnlwgt', True), ('education', False),
    ('education_num', True), ('marital_status', False), ('occupation', False),
    ('relationship', False), ('race', False), ('sex', False),
    ('capital_gain', True), ('capital_loss', True), ('hours_per_week', True),
    ('native_country', False), ('income', False),
]
adult_gov_data = adult_gov_data.select([
    (col('_c%d' % idx).cast('double') if is_numeric else col('_c%d' % idx)).alias(name)
    for idx, (name, is_numeric) in enumerate(adult_column_spec)
])

# Normalise country names for the join: hyphens become spaces, then
# surrounding whitespace is stripped.
adult_gov_data = adult_gov_data.withColumn(
    'native_country',
    trim(F.regexp_replace(col('native_country'), '-', ' ')))

# Left join: rows whose country has no match keep null lookup columns.
expanded_cols = adult_gov_data.columns + ['alpha_3_code', 'latitude', 'longitude']
adult_gov_data_expanded = (
    adult_gov_data
    .join(country_codes, adult_gov_data.native_country == country_codes.country, how='left')
    .select(expanded_cols)
)

In [4]:
# Show a ~1% random sample (drawn without replacement) of the joined data;
# the trailing 250 is passed straight through to display().
display(adult_gov_data_expanded.sample(withReplacement=False, fraction=0.01), 250)

age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income,alpha_3_code,latitude,longitude
35.0,?,129305.0,HS-grad,9.0,Married-civ-spouse,?,Husband,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0
47.0,Local-gov,186009.0,Some-college,10.0,Divorced,Adm-clerical,Unmarried,White,Female,0.0,0.0,38.0,Mexico,<=50K,MEX,23.0,-102.0
19.0,Private,104112.0,HS-grad,9.0,Never-married,Sales,Unmarried,Black,Male,0.0,0.0,30.0,Haiti,<=50K,HTI,19.0,-72.4167
45.0,Private,261192.0,HS-grad,9.0,Married-civ-spouse,Other-service,Husband,Black,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0
59.0,Local-gov,286967.0,HS-grad,9.0,Married-civ-spouse,Transport-moving,Husband,White,Male,0.0,0.0,45.0,United States,<=50K,USA,38.0,-97.0
46.0,Federal-gov,371373.0,HS-grad,9.0,Divorced,Adm-clerical,Not-in-family,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0
52.0,Private,25826.0,10th,6.0,Married-civ-spouse,Craft-repair,Husband,White,Male,0.0,1887.0,47.0,United States,>50K,USA,38.0,-97.0
30.0,Private,103649.0,Bachelors,13.0,Married-civ-spouse,Adm-clerical,Wife,Black,Female,0.0,0.0,40.0,United States,>50K,USA,38.0,-97.0
36.0,Private,131414.0,Some-college,10.0,Never-married,Sales,Not-in-family,Black,Female,0.0,0.0,36.0,United States,<=50K,USA,38.0,-97.0
36.0,Private,51838.0,Some-college,10.0,Divorced,Adm-clerical,Unmarried,White,Female,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0


In [5]:
# Summary statistics (count/mean/stddev/min/max) for every column of the joined data.
display (adult_gov_data_expanded.describe())

summary,age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income,alpha_3_code,latitude,longitude
count,32561.0,32561,32561.0,32561,32561.0,32561,32561,32561,32561,32561,32561.0,32561.0,32561.0,32561,32561,31606,31606.0,31606.0
mean,38.58164675532078,,189778.36651208505,,10.0806793403151,,,,,,1077.6488437087312,87.303829734959,40.437455852093,,,,37.09002754856691,-91.46084107764304
stddev,13.640432553581356,,105549.97769702227,,2.572720332067397,,,,,,7385.292084840354,402.960218649002,12.347428681731838,,,,5.110597872169294,30.59850465528006
min,17.0,?,12285.0,10th,1.0,Divorced,?,Husband,Amer-Indian-Eskimo,Female,0.0,0.0,1.0,?,<=50K,CAN,-10.0,-102.0
max,90.0,Without-pay,1484705.0,Some-college,16.0,Widowed,Transport-moving,Wife,White,Male,99999.0,4356.0,99.0,Yugoslavia,>50K,VNM,60.0,138.0


In [6]:
# Total weekly working hours per country code; rows without a country match
# fall into the null alpha_3_code group.
hours_by_country = adult_gov_data_expanded.groupBy('alpha_3_code')
display(hours_by_country.agg(F.sum(col('hours_per_week')).alias('total_hours')))

alpha_3_code,total_hours
HTI,1624.0
POL,2300.0
JAM,3126.0
CUB,3720.0
FRA,1307.0
ITA,3037.0
GTM,2511.0
,39644.0
MEX,25939.0
HUN,463.0


In [7]:
# Feature engineering: each categorical column is label-encoded (StringIndexer)
# and then one-hot encoded; the income string becomes the numeric `label`; all
# encoded and numeric columns are packed into a single `features` vector.
col_categorical = ['workclass', 'education', 'marital_status', 'occupation', 'relationship', 'race', 'sex', 'native_country']
pipeline_steps = []

for column in col_categorical:
    string_indexed = StringIndexer(inputCol=column, outputCol=column + "Index")
    one_hot_encoded = OneHotEncoderEstimator(inputCols=[string_indexed.getOutputCol()], outputCols=[column + "classVec"])
    pipeline_steps += [string_indexed, one_hot_encoded]

# StringIndexer orders labels by descending frequency by default, so the most
# common income class is encoded as 0.0.
label_stringIdx = StringIndexer(inputCol="income", outputCol="label")
pipeline_steps += [label_stringIdx]

numerical_columns = ['age', 'fnlwgt', 'education_num', 'capital_gain', 'capital_loss', 'hours_per_week']
# The loop variable is deliberately NOT named `col`: that name is
# pyspark.sql.functions.col, and under Python 2 a list-comprehension variable
# leaks into (and would overwrite) the notebook scope.
assembler_col = [cat_col + "classVec" for cat_col in col_categorical] + numerical_columns
assemblerModel = VectorAssembler(inputCols=assembler_col, outputCol="features")
pipeline_steps += [assemblerModel]

# Apply the pipeline to the whole dataset, then split into train/test.
# NOTE(review): fitting before the split lets the StringIndexer vocabularies see
# test rows. This avoids unseen-category errors at transform time, but it is a
# mild form of leakage.
pipelineInst = Pipeline(stages=pipeline_steps)
pipelineModel = pipelineInst.fit(adult_gov_data_expanded)
adult_gov_data_processed = pipelineModel.transform(adult_gov_data_expanded).select(adult_gov_data_expanded.columns + ['label', 'features'])
train, test = adult_gov_data_processed.randomSplit([.75, .25], seed=121)

In [8]:
# Summary statistics of the 75% training split (now including the numeric `label`).
display (train.describe())

summary,age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income,alpha_3_code,latitude,longitude,label
count,24386.0,24386,24386.0,24386,24386.0,24386,24386,24386,24386,24386,24386.0,24386.0,24386.0,24386,24386,23650,23650.0,23650.0,24386.0
mean,38.581686213401134,,189675.10838185844,,10.100303452800787,,,,,,1093.426966292135,89.11801853522513,40.41700155827114,,,,37.08677924735743,-91.2902892304436,0.2413679980316575
stddev,13.62386997654128,,105774.27339437084,,2.577346921180809,,,,,,7492.912946262998,407.9837982980192,12.320523728635733,,,,5.131110552140711,31.13485399639285,0.4279217179126274
min,17.0,?,13769.0,10th,1.0,Divorced,?,Husband,Amer-Indian-Eskimo,Female,0.0,0.0,1.0,?,<=50K,CAN,-10.0,-102.0,0.0
max,90.0,Without-pay,1484705.0,Some-college,16.0,Widowed,Transport-moving,Wife,White,Male,99999.0,4356.0,99.0,Yugoslavia,>50K,VNM,60.0,138.0,1.0


In [9]:
# Subset of columns shown when inspecting model predictions in the cells below.
display_cols = ['label', 'age', 'occupation', 'probability', 'prediction']

In [10]:
# Baseline logistic regression on the assembled `features` vector.
lrInst = LogisticRegression(labelCol='label', featuresCol='features', maxIter=50)
lrModel = lrInst.fit(train)
lrPredictions = lrModel.transform(test)

# BinaryClassificationEvaluator scores area under the ROC curve (its default
# metric), not accuracy. Name the metric explicitly so the printout is accurate.
lrbceInst = BinaryClassificationEvaluator(rawPredictionCol="rawPrediction", metricName="areaUnderROC")
print("Model areaUnderROC = %.15f" % lrbceInst.evaluate(lrPredictions))

In [11]:
# Inspect baseline logistic-regression predictions on the test split.
display(lrPredictions.select(display_cols))

label,age,occupation,probability,prediction
0.0,17.0,?,"List(1, 2, List(), List(0.9995291852265232, 4.7081477347669937E-4))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9995686741141787, 4.313258858211785E-4))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9990154651814908, 9.845348185092257E-4))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9997517830442086, 2.4821695579129805E-4))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9989538349801137, 0.0010461650198864013))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9992783383236725, 7.216616763274089E-4))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9997843370346564, 2.1566296534353187E-4))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.999749112688926, 2.5088731107402525E-4))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9989380971078602, 0.0010619028921397965))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.998910170455052, 0.0010898295449478783))",0.0


In [12]:
# Grid search (3 x 3 x 3 = 27 candidates, 5 folds each) over regularisation
# strength, L1/L2 mix and iteration budget, scored with the baseline LR evaluator.
lrpgInst = (ParamGridBuilder()
             .addGrid(lrInst.regParam, [0.01, 0.5, 2.0])
             .addGrid(lrInst.elasticNetParam, [0.0, 0.5, 1.0])
             .addGrid(lrInst.maxIter, [15, 30, 45])
             .build())
lrcvInst = CrossValidator(estimator=lrInst, estimatorParamMaps=lrpgInst, evaluator=lrbceInst, numFolds=5)  # 5 folds
lrcvModel = lrcvInst.fit(train)
lrcvPredictions = lrcvModel.transform(test)
print('Best Model Score: %.15f' % lrbceInst.evaluate(lrcvPredictions))
print('Best Model Intercept: %.15f' % lrcvModel.bestModel.intercept)
# One row per coefficient of the best model, in `features`-vector order.
# Built through the SparkSession (`spark`), which the rest of the notebook
# uses, rather than the legacy SQLContext entry point.
new_frame_lrweights = spark.createDataFrame([(float(w),) for w in lrcvModel.bestModel.coefficients], ["Feature Weight"])

In [13]:
# Coefficients of the cross-validated best logistic-regression model (one row per feature).
display(new_frame_lrweights)

Feature Weight
0.1028467959532855
-0.3129548816125161
-0.0530022998461281
-0.2838372555748569
-0.2117968704909418
0.3102261181640913
0.6293828591877813
-2.257986491427108
-0.1837038306507316
-0.0010259397380516


In [14]:
# Inspect predictions from the cross-validated best logistic-regression model.
display(lrcvPredictions.select(display_cols))

label,age,occupation,probability,prediction
0.0,17.0,?,"List(1, 2, List(), List(0.9984893958285243, 0.0015106041714757887))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9984298671990526, 0.0015701328009475087))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.997177829126295, 0.002822170873704959))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9990351540834088, 9.648459165913357E-4))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9970251070100321, 0.002974892989967852))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9973158469657829, 0.0026841530342171225))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9991272727427515, 8.727272572485495E-4))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.999026140234414, 9.738597655858769E-4))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9969863030497849, 0.0030136969502151663))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9969176341262079, 0.0030823658737921655))",0.0


In [15]:
# Multinomial Naive Bayes with additive smoothing of 2.0.
# Multinomial NB requires non-negative features. That holds here: the vector is
# one-hot flags plus numeric columns whose minimums are all >= 0.
nbInst = NaiveBayes(labelCol='label', featuresCol='features', smoothing=2.0, modelType="multinomial")
nbModel = nbInst.fit(train)
nbPredictions = nbModel.transform(test)

# (A stray `a` statement here raised NameError and aborted the cell; removed.)
nbmceInst = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
print("Model Accuracy = %.15f" % nbmceInst.evaluate(nbPredictions))

In [16]:
# Inspect Naive Bayes predictions on the test split.
display(nbPredictions.select(display_cols))

label,age,occupation,probability,prediction
0.0,17.0,?,"List(1, 2, List(), List(1.0, 0.0))",0.0
0.0,17.0,?,"List(1, 2, List(), List(1.0, 0.0))",0.0
0.0,17.0,?,"List(1, 2, List(), List(1.0, 0.0))",0.0
0.0,17.0,?,"List(1, 2, List(), List(1.0, 0.0))",0.0
0.0,17.0,?,"List(1, 2, List(), List(1.0, 0.0))",0.0
0.0,17.0,?,"List(1, 2, List(), List(1.0, 0.0))",0.0
0.0,17.0,?,"List(1, 2, List(), List(1.0, 0.0))",0.0
0.0,17.0,?,"List(1, 2, List(), List(1.0, 0.0))",0.0
0.0,17.0,?,"List(1, 2, List(), List(1.0, 0.0))",0.0
0.0,17.0,?,"List(1, 2, List(), List(1.0, 0.0))",0.0


In [17]:
# 5-fold cross-validation over the Naive Bayes smoothing parameter, selected by accuracy.
smoothing_values = [1.0, 2.0, 3.0, 4.0, 5.0]
nbpgInst = ParamGridBuilder().addGrid(nbInst.smoothing, smoothing_values).build()
nbcvInst = CrossValidator(estimator=nbInst,
                          estimatorParamMaps=nbpgInst,
                          evaluator=nbmceInst,
                          numFolds=5)
nbcvModel = nbcvInst.fit(train)
nbcvBestPredictions = nbcvModel.transform(test)
nb_best_score = nbmceInst.evaluate(nbcvBestPredictions)
print('Best Model Score: %.15f' % nb_best_score)

In [18]:
# Full rows (features, rawPrediction, probability, prediction) from the best Naive Bayes model.
display(nbcvBestPredictions)

age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income,alpha_3_code,latitude,longitude,label,features,rawPrediction,probability,prediction
17.0,?,47407.0,11th,7.0,Never-married,?,Own-child,White,Male,0.0,0.0,10.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 13, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 47407.0, 7.0, 10.0))","List(1, 2, List(), List(-481.4571733052232, -1503.1024106107902))","List(1, 2, List(), List(1.0, 0.0))",0.0
17.0,?,86786.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 86786.0, 6.0, 40.0))","List(1, 2, List(), List(-774.4349671804495, -2637.4436206472974))","List(1, 2, List(), List(1.0, 0.0))",0.0
17.0,?,89870.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 89870.0, 6.0, 40.0))","List(1, 2, List(), List(-791.8046641768052, -2720.7069971978494))","List(1, 2, List(), List(1.0, 0.0))",0.0
17.0,?,158762.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,20.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 158762.0, 6.0, 20.0))","List(1, 2, List(), List(-714.702559991093, -4125.809263367621))","List(1, 2, List(), List(1.0, 0.0))",0.0
17.0,?,161981.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 161981.0, 6.0, 40.0))","List(1, 2, List(), List(-902.2371323389249, -4379.232743050699))","List(1, 2, List(), List(1.0, 0.0))",0.0
17.0,?,166759.0,12th,8.0,Never-married,?,Own-child,White,Female,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 19, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 166759.0, 8.0, 40.0))","List(1, 2, List(), List(-917.495686148562, -4496.77387381484))","List(1, 2, List(), List(1.0, 0.0))",0.0
17.0,?,170320.0,11th,7.0,Never-married,?,Own-child,White,Female,0.0,0.0,8.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 13, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 170320.0, 7.0, 8.0))","List(1, 2, List(), List(-640.046409019491, -4301.016898437263))","List(1, 2, List(), List(1.0, 0.0))",0.0
17.0,?,171461.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,20.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 171461.0, 6.0, 20.0))","List(1, 2, List(), List(-734.1501049314039, -4417.881463816725))","List(1, 2, List(), List(1.0, 0.0))",0.0
17.0,?,179715.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 179715.0, 6.0, 40.0))","List(1, 2, List(), List(-929.395393774587, -4787.108040535888))","List(1, 2, List(), List(1.0, 0.0))",0.0
17.0,?,210547.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 210547.0, 6.0, 40.0))","List(1, 2, List(), List(-976.6122380615124, -5496.232387669266))","List(1, 2, List(), List(1.0, 0.0))",0.0


In [19]:
# One-vs-Rest wrapper whose base classifier re-uses the hyper-parameters chosen
# by the logistic-regression cross-validation.
lrPrevBestModel = lrcvModel.bestModel
# NOTE(review): the chosen params are read through the private `_java_obj`
# handle, presumably because the fitted PySpark model does not expose them
# directly in this Spark version. Confirm on upgrade.
best_lr_java = lrPrevBestModel._java_obj
ovalrInst = LogisticRegression(
    labelCol='label',
    featuresCol='features',
    fitIntercept=True,
    maxIter=best_lr_java.getMaxIter(),
    elasticNetParam=best_lr_java.getElasticNetParam(),
    regParam=best_lr_java.getRegParam(),
)
ovaInst = OneVsRest(classifier=ovalrInst)
ovaModel = ovaInst.fit(train)
ovaPredictions = ovaModel.transform(test)

# Evaluator relies on its default label/prediction column names.
ovamceInst = MulticlassClassificationEvaluator(metricName="accuracy")
ova_accuracy = ovamceInst.evaluate(ovaPredictions)
print("Model Accuracy = %.15f" % ova_accuracy)

In [20]:
# OneVsRest output carries only `prediction` (no probability/rawPrediction columns).
display (ovaPredictions)

age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income,alpha_3_code,latitude,longitude,label,features,prediction
17.0,?,47407.0,11th,7.0,Never-married,?,Own-child,White,Male,0.0,0.0,10.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 13, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 47407.0, 7.0, 10.0))",0.0
17.0,?,86786.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 86786.0, 6.0, 40.0))",0.0
17.0,?,89870.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 89870.0, 6.0, 40.0))",0.0
17.0,?,158762.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,20.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 158762.0, 6.0, 20.0))",0.0
17.0,?,161981.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 161981.0, 6.0, 40.0))",0.0
17.0,?,166759.0,12th,8.0,Never-married,?,Own-child,White,Female,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 19, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 166759.0, 8.0, 40.0))",0.0
17.0,?,170320.0,11th,7.0,Never-married,?,Own-child,White,Female,0.0,0.0,8.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 13, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 170320.0, 7.0, 8.0))",0.0
17.0,?,171461.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,20.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 171461.0, 6.0, 20.0))",0.0
17.0,?,179715.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 179715.0, 6.0, 40.0))",0.0
17.0,?,210547.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 210547.0, 6.0, 40.0))",0.0


In [21]:
# Single decision tree limited to depth 4.
dtInst = DecisionTreeClassifier(labelCol="label", featuresCol="features", maxDepth=4)
dtModel = dtInst.fit(train)
dtPredictions = dtModel.transform(test)
# The evaluator's defaults are spelled out below, which makes clear that the
# "fit score" is ROC AUC.
dtbceInst = BinaryClassificationEvaluator(rawPredictionCol="rawPrediction",
                                          labelCol="label",
                                          metricName="areaUnderROC")
print('Model Fit Score: %.15f' % dtbceInst.evaluate(dtPredictions))

In [22]:
# Inspect decision-tree predictions on the test split.
display(dtPredictions.select(display_cols))

label,age,occupation,probability,prediction
0.0,17.0,?,"List(1, 2, List(), List(0.9849181100506658, 0.015081889949334275))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9849181100506658, 0.015081889949334275))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9849181100506658, 0.015081889949334275))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9849181100506658, 0.015081889949334275))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9849181100506658, 0.015081889949334275))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9849181100506658, 0.015081889949334275))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9849181100506658, 0.015081889949334275))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9849181100506658, 0.015081889949334275))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9849181100506658, 0.015081889949334275))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9849181100506658, 0.015081889949334275))",0.0


In [23]:
# Cross-validated decision tree: 4 depths x 3 bin counts = 12 candidates, 5 folds each,
# scored with the decision-tree evaluator.
dtGridBuilder = ParamGridBuilder()
dtGridBuilder.addGrid(dtInst.maxDepth, [1, 3, 8, 12])
dtGridBuilder.addGrid(dtInst.maxBins, [25, 50, 90])
dtpgInst = dtGridBuilder.build()
dtcvInst = CrossValidator(estimator=dtInst, estimatorParamMaps=dtpgInst,
                          evaluator=dtbceInst, numFolds=5)
dtcvModel = dtcvInst.fit(train)
dtcvBestPredictions = dtcvModel.transform(test)
print('Best Model Score: %.15f' % dtbceInst.evaluate(dtcvBestPredictions))

In [24]:
# Inspect predictions from the cross-validated best decision tree.
display(dtcvBestPredictions.select(display_cols))

label,age,occupation,probability,prediction
0.0,17.0,?,"List(1, 2, List(), List(0.9972826086956522, 0.002717391304347826))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9972826086956522, 0.002717391304347826))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9972826086956522, 0.002717391304347826))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9972826086956522, 0.002717391304347826))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9972826086956522, 0.002717391304347826))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9972826086956522, 0.002717391304347826))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9972826086956522, 0.002717391304347826))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9972826086956522, 0.002717391304347826))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9972826086956522, 0.002717391304347826))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9972826086956522, 0.002717391304347826))",0.0


In [25]:
# Gradient-boosted trees: 20 boosting iterations of depth-5 trees, scored by accuracy.
gbtInst = GBTClassifier(labelCol="label", featuresCol="features",
                        maxIter=20, maxDepth=5)
gbtModel = gbtInst.fit(train)
gbtPredictions = gbtModel.transform(test)
gbtmceInst = MulticlassClassificationEvaluator(labelCol="label",
                                               predictionCol="prediction",
                                               metricName="accuracy")
gbt_accuracy = gbtmceInst.evaluate(gbtPredictions)
print('Model Fit Score: %.15f' % gbt_accuracy)

In [26]:
# Inspect gradient-boosted-tree predictions on the test split.
display(gbtPredictions.select(display_cols))

label,age,occupation,probability,prediction
0.0,17.0,?,"List(1, 2, List(), List(0.9539880436433699, 0.046011956356630135))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9539880436433699, 0.046011956356630135))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9539880436433699, 0.046011956356630135))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9539880436433699, 0.046011956356630135))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9539880436433699, 0.046011956356630135))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9539880436433699, 0.046011956356630135))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9539880436433699, 0.046011956356630135))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9539880436433699, 0.046011956356630135))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9539880436433699, 0.046011956356630135))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9539880436433699, 0.046011956356630135))",0.0


In [27]:
# Cross-validated GBT: 3 depths x 3 iteration counts = 9 candidates, 5 folds each,
# selected by accuracy.
gbtGridBuilder = ParamGridBuilder()
for gbtParam, gbtValues in [(gbtInst.maxDepth, [3, 5, 7]),
                            (gbtInst.maxIter, [25, 40, 50])]:
    gbtGridBuilder.addGrid(gbtParam, gbtValues)
gbtpgInst = gbtGridBuilder.build()
gbtcvInst = CrossValidator(estimator=gbtInst, estimatorParamMaps=gbtpgInst,
                           evaluator=gbtmceInst, numFolds=5)
gbtcvModel = gbtcvInst.fit(train)
gbtcvBestPredictions = gbtcvModel.transform(test)
print('Best Model Score: %.15f' % gbtmceInst.evaluate(gbtcvBestPredictions))

In [28]:
# Full rows (features, rawPrediction, probability, prediction) from the best GBT model.
display (gbtcvBestPredictions)

age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income,alpha_3_code,latitude,longitude,label,features,rawPrediction,probability,prediction
17.0,?,47407.0,11th,7.0,Never-married,?,Own-child,White,Male,0.0,0.0,10.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 13, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 47407.0, 7.0, 10.0))","List(1, 2, List(), List(1.906925294069026, -1.906925294069026))","List(1, 2, List(), List(0.9784132111774687, 0.02158678882253129))",0.0
17.0,?,86786.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 86786.0, 6.0, 40.0))","List(1, 2, List(), List(1.905824604163853, -1.905824604163853))","List(1, 2, List(), List(0.9783666672763492, 0.021633332723650756))",0.0
17.0,?,89870.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 89870.0, 6.0, 40.0))","List(1, 2, List(), List(1.905824604163853, -1.905824604163853))","List(1, 2, List(), List(0.9783666672763492, 0.021633332723650756))",0.0
17.0,?,158762.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,20.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 158762.0, 6.0, 20.0))","List(1, 2, List(), List(1.906925294069026, -1.906925294069026))","List(1, 2, List(), List(0.9784132111774687, 0.02158678882253129))",0.0
17.0,?,161981.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 161981.0, 6.0, 40.0))","List(1, 2, List(), List(1.905824604163853, -1.905824604163853))","List(1, 2, List(), List(0.9783666672763492, 0.021633332723650756))",0.0
17.0,?,166759.0,12th,8.0,Never-married,?,Own-child,White,Female,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 19, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 166759.0, 8.0, 40.0))","List(1, 2, List(), List(1.905824604163853, -1.905824604163853))","List(1, 2, List(), List(0.9783666672763492, 0.021633332723650756))",0.0
17.0,?,170320.0,11th,7.0,Never-married,?,Own-child,White,Female,0.0,0.0,8.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 13, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 170320.0, 7.0, 8.0))","List(1, 2, List(), List(1.906925294069026, -1.906925294069026))","List(1, 2, List(), List(0.9784132111774687, 0.02158678882253129))",0.0
17.0,?,171461.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,20.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 171461.0, 6.0, 20.0))","List(1, 2, List(), List(1.906925294069026, -1.906925294069026))","List(1, 2, List(), List(0.9784132111774687, 0.02158678882253129))",0.0
17.0,?,179715.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 179715.0, 6.0, 40.0))","List(1, 2, List(), List(1.905824604163853, -1.905824604163853))","List(1, 2, List(), List(0.9783666672763492, 0.021633332723650756))",0.0
17.0,?,210547.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 210547.0, 6.0, 40.0))","List(1, 2, List(), List(1.905824604163853, -1.905824604163853))","List(1, 2, List(), List(0.9783666672763492, 0.021633332723650756))",0.0


In [29]:
# Random forest with library-default hyper-parameters.
rfInst = RandomForestClassifier(labelCol="label", featuresCol="features")
rfModel = rfInst.fit(train)
rfPredictions = rfModel.transform(test)
# With no arguments, BinaryClassificationEvaluator scores areaUnderROC.
rfbceInst = BinaryClassificationEvaluator()
rf_score = rfbceInst.evaluate(rfPredictions)
print('Model Fit Score: %.15f' % rf_score)

In [30]:
# Inspect random-forest predictions on the test split.
display(rfPredictions.select(display_cols))

label,age,occupation,probability,prediction
0.0,17.0,?,"List(1, 2, List(), List(0.9647656715281634, 0.03523432847183654))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9779761302760008, 0.02202386972399928))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9647656715281634, 0.03523432847183654))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9774397243155428, 0.022560275684457204))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9642292655677055, 0.035770734432294456))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9774397243155428, 0.022560275684457204))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9774397243155428, 0.022560275684457204))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9774397243155428, 0.022560275684457204))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9642292655677055, 0.035770734432294456))",0.0
0.0,17.0,?,"List(1, 2, List(), List(0.9642292655677055, 0.035770734432294456))",0.0


In [31]:
# Cross-validated random forest: 3 depths x 2 bin counts x 2 forest sizes = 12
# candidates, 5 folds each, scored with the random-forest evaluator.
rfGridBuilder = ParamGridBuilder()
rfGridBuilder.addGrid(rfInst.maxDepth, [1, 3, 8])
rfGridBuilder.addGrid(rfInst.maxBins, [25, 50])
rfGridBuilder.addGrid(rfInst.numTrees, [5, 20])
rfpgInst = rfGridBuilder.build()
rfcvInst = CrossValidator(estimator=rfInst, estimatorParamMaps=rfpgInst,
                          evaluator=rfbceInst, numFolds=5)
rfcvModel = rfcvInst.fit(train)
rfcvBestPredictions = rfcvModel.transform(test)
rf_best_score = rfbceInst.evaluate(rfcvBestPredictions)
print('Best Model Score: %.15f' % rf_best_score)

In [32]:
# Full rows (features, rawPrediction, probability, prediction) from the best random forest.
display (rfcvBestPredictions)

age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income,alpha_3_code,latitude,longitude,label,features,rawPrediction,probability,prediction
17.0,?,47407.0,11th,7.0,Never-married,?,Own-child,White,Male,0.0,0.0,10.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 13, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 47407.0, 7.0, 10.0))","List(1, 2, List(), List(19.832105485565403, 0.16789451443459927))","List(1, 2, List(), List(0.9916052742782699, 0.008394725721729963))",0.0
17.0,?,86786.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 86786.0, 6.0, 40.0))","List(1, 2, List(), List(19.81839571368025, 0.18160428631975262))","List(1, 2, List(), List(0.9909197856840123, 0.00908021431598763))",0.0
17.0,?,89870.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 89870.0, 6.0, 40.0))","List(1, 2, List(), List(19.811886699516585, 0.18811330048341662))","List(1, 2, List(), List(0.9905943349758293, 0.00940566502417083))",0.0
17.0,?,158762.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,20.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 158762.0, 6.0, 20.0))","List(1, 2, List(), List(19.828128073777574, 0.17187192622242903))","List(1, 2, List(), List(0.9914064036888786, 0.00859359631112145))",0.0
17.0,?,161981.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 161981.0, 6.0, 40.0))","List(1, 2, List(), List(19.811886699516585, 0.18811330048341662))","List(1, 2, List(), List(0.9905943349758293, 0.00940566502417083))",0.0
17.0,?,166759.0,12th,8.0,Never-married,?,Own-child,White,Female,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 19, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 166759.0, 8.0, 40.0))","List(1, 2, List(), List(19.821515838485244, 0.1784841615147604))","List(1, 2, List(), List(0.991075791924262, 0.008924208075738018))",0.0
17.0,?,170320.0,11th,7.0,Never-married,?,Own-child,White,Female,0.0,0.0,8.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 13, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 170320.0, 7.0, 8.0))","List(1, 2, List(), List(19.84173462453406, 0.15826537546594302))","List(1, 2, List(), List(0.9920867312267028, 0.00791326877329715))",0.0
17.0,?,171461.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,20.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 171461.0, 6.0, 20.0))","List(1, 2, List(), List(19.831248198582568, 0.16875180141743681))","List(1, 2, List(), List(0.9915624099291283, 0.00843759007087184))",0.0
17.0,?,179715.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 179715.0, 6.0, 40.0))","List(1, 2, List(), List(19.813161805871392, 0.18683819412860891))","List(1, 2, List(), List(0.9906580902935695, 0.009341909706430446))",0.0
17.0,?,210547.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 210547.0, 6.0, 40.0))","List(1, 2, List(), List(19.813161805871392, 0.18683819412860891))","List(1, 2, List(), List(0.9906580902935695, 0.009341909706430446))",0.0


In [33]:
# One-vs-Rest wrapper around a Random Forest whose maxDepth / maxBins / numTrees
# are copied from the best Random Forest found by the earlier cross-validation
# (rfcvModel), then scored on the test split by accuracy.
rfPrevBestModel = rfcvModel.bestModel
# NOTE(review): _java_obj is a private PySpark attribute; presumably used because
# the Python model wrapper in this Spark version does not expose these param
# values directly — confirm before upgrading Spark.
rfBestJava = rfPrevBestModel._java_obj
ovarfBaseInst = RandomForestClassifier(
    labelCol='label',
    featuresCol='features',
    maxDepth=rfBestJava.getMaxDepth(),
    maxBins=rfBestJava.getMaxBins(),
    numTrees=rfBestJava.getNumTrees(),
)
ovarfInst = OneVsRest(classifier=ovarfBaseInst)
ovarfModel = ovarfInst.fit(train)
ovarfPredictions = ovarfModel.transform(test)

ovamceInst = MulticlassClassificationEvaluator(metricName="accuracy")
print("Model Accuracy = %.15f" % ovamceInst.evaluate(ovarfPredictions))

In [34]:
# Show the One-vs-Rest Random Forest predictions on the test split.
display(ovarfPredictions)

age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income,alpha_3_code,latitude,longitude,label,features,prediction
17.0,?,47407.0,11th,7.0,Never-married,?,Own-child,White,Male,0.0,0.0,10.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 13, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 47407.0, 7.0, 10.0))",0.0
17.0,?,86786.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 86786.0, 6.0, 40.0))",0.0
17.0,?,89870.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 89870.0, 6.0, 40.0))",0.0
17.0,?,158762.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,20.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 158762.0, 6.0, 20.0))",0.0
17.0,?,161981.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 161981.0, 6.0, 40.0))",0.0
17.0,?,166759.0,12th,8.0,Never-married,?,Own-child,White,Female,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 19, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 166759.0, 8.0, 40.0))",0.0
17.0,?,170320.0,11th,7.0,Never-married,?,Own-child,White,Female,0.0,0.0,8.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 13, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 170320.0, 7.0, 8.0))",0.0
17.0,?,171461.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,20.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 171461.0, 6.0, 20.0))",0.0
17.0,?,179715.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 179715.0, 6.0, 40.0))",0.0
17.0,?,210547.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 210547.0, 6.0, 40.0))",0.0


In [35]:
# Multilayer perceptron: the input layer is sized from the assembled feature
# vector's metadata, followed by hidden layers of 20 and 10 units and a
# 2-unit output layer.
# NOTE(review): the displayed predictions of this model show the same
# rawPrediction/probability for every row, which suggests the network is
# emitting a constant output. The feature vector carries unscaled raw values
# (e.g. fnlwgt ~1e5), so consider a StandardScaler stage before the MLP — confirm.
num_input_features = train.schema["features"].metadata["ml_attr"]["num_attrs"]
layers = [num_input_features, 20, 10, 2]
mpcInst = MultilayerPerceptronClassifier(
    maxIter=100,
    layers=layers,
    blockSize=128,
    seed=451,
)
mpcModel = mpcInst.fit(train)
mpcPredictions = mpcModel.transform(test)
# Despite the "Bce" suffix, this is a multiclass evaluator measuring accuracy;
# it is reused by the cross-validation cell below.
mpcBce = MulticlassClassificationEvaluator(metricName="accuracy")
mpcAccuracy = mpcBce.evaluate(mpcPredictions)
print("Model Accuracy = %.15f" % mpcAccuracy)

In [36]:
# Show the multilayer perceptron predictions on the test split.
display(mpcPredictions)

age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income,alpha_3_code,latitude,longitude,label,features,rawPrediction,probability,prediction
17.0,?,47407.0,11th,7.0,Never-married,?,Own-child,White,Male,0.0,0.0,10.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 13, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 47407.0, 7.0, 10.0))","List(1, 2, List(), List(0.953811992394703, -0.22494370514379952))","List(1, 2, List(), List(0.7647240011164417, 0.23527599888355832))",0.0
17.0,?,86786.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 86786.0, 6.0, 40.0))","List(1, 2, List(), List(0.953811992394703, -0.22494370514379952))","List(1, 2, List(), List(0.7647240011164417, 0.23527599888355832))",0.0
17.0,?,89870.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 89870.0, 6.0, 40.0))","List(1, 2, List(), List(0.953811992394703, -0.22494370514379952))","List(1, 2, List(), List(0.7647240011164417, 0.23527599888355832))",0.0
17.0,?,158762.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,20.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 158762.0, 6.0, 20.0))","List(1, 2, List(), List(0.953811992394703, -0.22494370514379952))","List(1, 2, List(), List(0.7647240011164417, 0.23527599888355832))",0.0
17.0,?,161981.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 161981.0, 6.0, 40.0))","List(1, 2, List(), List(0.953811992394703, -0.22494370514379952))","List(1, 2, List(), List(0.7647240011164417, 0.23527599888355832))",0.0
17.0,?,166759.0,12th,8.0,Never-married,?,Own-child,White,Female,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 19, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 166759.0, 8.0, 40.0))","List(1, 2, List(), List(0.953811992394703, -0.22494370514379952))","List(1, 2, List(), List(0.7647240011164417, 0.23527599888355832))",0.0
17.0,?,170320.0,11th,7.0,Never-married,?,Own-child,White,Female,0.0,0.0,8.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 13, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 170320.0, 7.0, 8.0))","List(1, 2, List(), List(0.953811992394703, -0.22494370514379952))","List(1, 2, List(), List(0.7647240011164417, 0.23527599888355832))",0.0
17.0,?,171461.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,20.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 171461.0, 6.0, 20.0))","List(1, 2, List(), List(0.953811992394703, -0.22494370514379952))","List(1, 2, List(), List(0.7647240011164417, 0.23527599888355832))",0.0
17.0,?,179715.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 179715.0, 6.0, 40.0))","List(1, 2, List(), List(0.953811992394703, -0.22494370514379952))","List(1, 2, List(), List(0.7647240011164417, 0.23527599888355832))",0.0
17.0,?,210547.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 210547.0, 6.0, 40.0))","List(1, 2, List(), List(0.953811992394703, -0.22494370514379952))","List(1, 2, List(), List(0.7647240011164417, 0.23527599888355832))",0.0


In [37]:
# Tune the multilayer perceptron over iteration count and block size using
# 5-fold cross-validation, scored with the accuracy evaluator mpcBce.
# NOTE(review): blockSize controls how input rows are stacked into matrices
# for computation; it may have little effect on accuracy — confirm it is
# worth searching over.
mpcMaxIterOptions = [50, 100, 150]
mpcBlockSizeOptions = [128, 256, 512]
mpcpgInst = (
    ParamGridBuilder()
    .addGrid(mpcInst.maxIter, mpcMaxIterOptions)
    .addGrid(mpcInst.blockSize, mpcBlockSizeOptions)
    .build()
)
mpccvInst = CrossValidator(
    estimator=mpcInst,
    estimatorParamMaps=mpcpgInst,
    evaluator=mpcBce,
    numFolds=5,
)
mpccvModel = mpccvInst.fit(train)
# Transforming with the CrossValidatorModel uses the best model it found.
mpccvBestPredictions = mpccvModel.transform(test)
mpccvBestAccuracy = mpcBce.evaluate(mpccvBestPredictions)
print("Best Model Accuracy Score = %.15f" % mpccvBestAccuracy)

In [38]:
# Show the test-split predictions of the cross-validated multilayer perceptron.
display(mpccvBestPredictions)

age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income,alpha_3_code,latitude,longitude,label,features,rawPrediction,probability,prediction
17.0,?,47407.0,11th,7.0,Never-married,?,Own-child,White,Male,0.0,0.0,10.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 13, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 47407.0, 7.0, 10.0))","List(1, 2, List(), List(0.6876307671197011, -0.4942175011979356))","List(1, 2, List(), List(0.7652799645741133, 0.23472003542588674))",0.0
17.0,?,86786.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 86786.0, 6.0, 40.0))","List(1, 2, List(), List(0.6876307671197011, -0.4942175011979356))","List(1, 2, List(), List(0.7652799645741133, 0.23472003542588674))",0.0
17.0,?,89870.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 89870.0, 6.0, 40.0))","List(1, 2, List(), List(0.6876307671197011, -0.4942175011979356))","List(1, 2, List(), List(0.7652799645741133, 0.23472003542588674))",0.0
17.0,?,158762.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,20.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 158762.0, 6.0, 20.0))","List(1, 2, List(), List(0.6876307671197011, -0.4942175011979356))","List(1, 2, List(), List(0.7652799645741133, 0.23472003542588674))",0.0
17.0,?,161981.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 161981.0, 6.0, 40.0))","List(1, 2, List(), List(0.6876307671197011, -0.4942175011979356))","List(1, 2, List(), List(0.7652799645741133, 0.23472003542588674))",0.0
17.0,?,166759.0,12th,8.0,Never-married,?,Own-child,White,Female,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 19, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 166759.0, 8.0, 40.0))","List(1, 2, List(), List(0.6876307671197011, -0.4942175011979356))","List(1, 2, List(), List(0.7652799645741133, 0.23472003542588674))",0.0
17.0,?,170320.0,11th,7.0,Never-married,?,Own-child,White,Female,0.0,0.0,8.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 13, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 170320.0, 7.0, 8.0))","List(1, 2, List(), List(0.6876307671197011, -0.4942175011979356))","List(1, 2, List(), List(0.7652799645741133, 0.23472003542588674))",0.0
17.0,?,171461.0,10th,6.0,Never-married,?,Own-child,White,Female,0.0,0.0,20.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 171461.0, 6.0, 20.0))","List(1, 2, List(), List(0.6876307671197011, -0.4942175011979356))","List(1, 2, List(), List(0.7652799645741133, 0.23472003542588674))",0.0
17.0,?,179715.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 179715.0, 6.0, 40.0))","List(1, 2, List(), List(0.6876307671197011, -0.4942175011979356))","List(1, 2, List(), List(0.7652799645741133, 0.23472003542588674))",0.0
17.0,?,210547.0,10th,6.0,Never-married,?,Own-child,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(3, 15, 24, 36, 45, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 17.0, 210547.0, 6.0, 40.0))","List(1, 2, List(), List(0.6876307671197011, -0.4942175011979356))","List(1, 2, List(), List(0.7652799645741133, 0.23472003542588674))",0.0


In [39]:
# Final model: the best Random Forest found by the earlier cross-validation
# (rfcvModel), applied to the whole processed dataset.
selectedModel = rfcvModel.bestModel
selectedPredictions = selectedModel.transform(adult_gov_data_processed)
# NOTE: presumably train/test were split from adult_gov_data_processed, so this
# "fit" score includes rows the model was trained on and is optimistic. The
# held-out test-split score is printed alongside it as the generalisation
# estimate. rfbceInst is the evaluator defined in an earlier Random Forest cell.
print("Model Fit Score = %.15f" % rfbceInst.evaluate(selectedPredictions))
print("Held-out Test Score = %.15f"
      % rfbceInst.evaluate(selectedModel.transform(test)))

In [40]:
# Show the selected Random Forest's predictions over the full processed dataset.
display(selectedPredictions)

age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income,alpha_3_code,latitude,longitude,label,features,rawPrediction,probability,prediction
39.0,State-gov,77516.0,Bachelors,13.0,Never-married,Adm-clerical,Not-in-family,White,Male,2174.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(4, 10, 24, 32, 44, 48, 52, 53, 94, 95, 96, 97, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 39.0, 77516.0, 13.0, 2174.0, 40.0))","List(1, 2, List(), List(17.917590901421665, 2.0824090985783323))","List(1, 2, List(), List(0.8958795450710834, 0.10412045492891664))",0.0
50.0,Self-emp-not-inc,83311.0,Bachelors,13.0,Married-civ-spouse,Exec-managerial,Husband,White,Male,0.0,0.0,13.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(1, 10, 23, 31, 43, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 50.0, 83311.0, 13.0, 13.0))","List(1, 2, List(), List(7.898398051088204, 12.101601948911796))","List(1, 2, List(), List(0.39491990255441023, 0.6050800974455898))",1.0
38.0,Private,215646.0,HS-grad,9.0,Divorced,Handlers-cleaners,Not-in-family,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(0, 8, 25, 38, 44, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 38.0, 215646.0, 9.0, 40.0))","List(1, 2, List(), List(18.718690776480567, 1.2813092235194286))","List(1, 2, List(), List(0.9359345388240286, 0.06406546117597144))",0.0
53.0,Private,234721.0,11th,7.0,Married-civ-spouse,Handlers-cleaners,Husband,Black,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(0, 13, 23, 38, 43, 49, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 53.0, 234721.0, 7.0, 40.0))","List(1, 2, List(), List(14.362465642341656, 5.637534357658344))","List(1, 2, List(), List(0.7181232821170828, 0.28187671788291724))",0.0
28.0,Private,338409.0,Bachelors,13.0,Married-civ-spouse,Prof-specialty,Wife,Black,Female,0.0,0.0,40.0,Cuba,<=50K,CUB,21.5,-80.0,0.0,"List(0, 100, List(0, 10, 23, 29, 47, 49, 62, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 28.0, 338409.0, 13.0, 40.0))","List(1, 2, List(), List(10.28674916937619, 9.713250830623808))","List(1, 2, List(), List(0.5143374584688095, 0.48566254153119043))",0.0
37.0,Private,284582.0,Masters,14.0,Married-civ-spouse,Exec-managerial,Wife,White,Female,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(0, 11, 23, 31, 47, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 37.0, 284582.0, 14.0, 40.0))","List(1, 2, List(), List(9.111774750436037, 10.88822524956396))","List(1, 2, List(), List(0.45558873752180185, 0.544411262478198))",1.0
49.0,Private,160187.0,9th,5.0,Married-spouse-absent,Other-service,Not-in-family,Black,Female,0.0,0.0,16.0,Jamaica,<=50K,JAM,18.25,-77.5,0.0,"List(0, 100, List(0, 18, 28, 34, 44, 49, 64, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 49.0, 160187.0, 5.0, 16.0))","List(1, 2, List(), List(19.2238734485469, 0.776126551453101))","List(1, 2, List(), List(0.961193672427345, 0.03880632757265505))",0.0
52.0,Self-emp-not-inc,209642.0,HS-grad,9.0,Married-civ-spouse,Exec-managerial,Husband,White,Male,0.0,0.0,45.0,United States,>50K,USA,38.0,-97.0,1.0,"List(0, 100, List(1, 8, 23, 31, 43, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 52.0, 209642.0, 9.0, 45.0))","List(1, 2, List(), List(11.803534864169333, 8.196465135830667))","List(1, 2, List(), List(0.5901767432084666, 0.40982325679153336))",0.0
31.0,Private,45781.0,Masters,14.0,Never-married,Prof-specialty,Not-in-family,White,Female,14084.0,0.0,50.0,United States,>50K,USA,38.0,-97.0,1.0,"List(0, 100, List(0, 11, 24, 29, 44, 48, 53, 94, 95, 96, 97, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 31.0, 45781.0, 14.0, 14084.0, 50.0))","List(1, 2, List(), List(5.989804833295561, 14.010195166704438))","List(1, 2, List(), List(0.29949024166477806, 0.7005097583352219))",1.0
42.0,Private,159449.0,Bachelors,13.0,Married-civ-spouse,Exec-managerial,Husband,White,Male,5178.0,0.0,40.0,United States,>50K,USA,38.0,-97.0,1.0,"List(0, 100, List(0, 10, 23, 31, 43, 48, 52, 53, 94, 95, 96, 97, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 42.0, 159449.0, 13.0, 5178.0, 40.0))","List(1, 2, List(), List(7.303673845589384, 12.696326154410615))","List(1, 2, List(), List(0.36518369227946923, 0.6348163077205308))",1.0


In [41]:
# NOTE(review): this repeats display(selectedPredictions) from cell In [40] and
# shows the same DataFrame; consider removing or displaying a different view.
display(selectedPredictions)

age,workclass,fnlwgt,education,education_num,marital_status,occupation,relationship,race,sex,capital_gain,capital_loss,hours_per_week,native_country,income,alpha_3_code,latitude,longitude,label,features,rawPrediction,probability,prediction
39.0,State-gov,77516.0,Bachelors,13.0,Never-married,Adm-clerical,Not-in-family,White,Male,2174.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(4, 10, 24, 32, 44, 48, 52, 53, 94, 95, 96, 97, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 39.0, 77516.0, 13.0, 2174.0, 40.0))","List(1, 2, List(), List(17.917590901421665, 2.0824090985783323))","List(1, 2, List(), List(0.8958795450710834, 0.10412045492891664))",0.0
50.0,Self-emp-not-inc,83311.0,Bachelors,13.0,Married-civ-spouse,Exec-managerial,Husband,White,Male,0.0,0.0,13.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(1, 10, 23, 31, 43, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 50.0, 83311.0, 13.0, 13.0))","List(1, 2, List(), List(7.898398051088204, 12.101601948911796))","List(1, 2, List(), List(0.39491990255441023, 0.6050800974455898))",1.0
38.0,Private,215646.0,HS-grad,9.0,Divorced,Handlers-cleaners,Not-in-family,White,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(0, 8, 25, 38, 44, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 38.0, 215646.0, 9.0, 40.0))","List(1, 2, List(), List(18.718690776480567, 1.2813092235194286))","List(1, 2, List(), List(0.9359345388240286, 0.06406546117597144))",0.0
53.0,Private,234721.0,11th,7.0,Married-civ-spouse,Handlers-cleaners,Husband,Black,Male,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(0, 13, 23, 38, 43, 49, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 53.0, 234721.0, 7.0, 40.0))","List(1, 2, List(), List(14.362465642341656, 5.637534357658344))","List(1, 2, List(), List(0.7181232821170828, 0.28187671788291724))",0.0
28.0,Private,338409.0,Bachelors,13.0,Married-civ-spouse,Prof-specialty,Wife,Black,Female,0.0,0.0,40.0,Cuba,<=50K,CUB,21.5,-80.0,0.0,"List(0, 100, List(0, 10, 23, 29, 47, 49, 62, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 28.0, 338409.0, 13.0, 40.0))","List(1, 2, List(), List(10.28674916937619, 9.713250830623808))","List(1, 2, List(), List(0.5143374584688095, 0.48566254153119043))",0.0
37.0,Private,284582.0,Masters,14.0,Married-civ-spouse,Exec-managerial,Wife,White,Female,0.0,0.0,40.0,United States,<=50K,USA,38.0,-97.0,0.0,"List(0, 100, List(0, 11, 23, 31, 47, 48, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 37.0, 284582.0, 14.0, 40.0))","List(1, 2, List(), List(9.111774750436037, 10.88822524956396))","List(1, 2, List(), List(0.45558873752180185, 0.544411262478198))",1.0
49.0,Private,160187.0,9th,5.0,Married-spouse-absent,Other-service,Not-in-family,Black,Female,0.0,0.0,16.0,Jamaica,<=50K,JAM,18.25,-77.5,0.0,"List(0, 100, List(0, 18, 28, 34, 44, 49, 64, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 49.0, 160187.0, 5.0, 16.0))","List(1, 2, List(), List(19.2238734485469, 0.776126551453101))","List(1, 2, List(), List(0.961193672427345, 0.03880632757265505))",0.0
52.0,Self-emp-not-inc,209642.0,HS-grad,9.0,Married-civ-spouse,Exec-managerial,Husband,White,Male,0.0,0.0,45.0,United States,>50K,USA,38.0,-97.0,1.0,"List(0, 100, List(1, 8, 23, 31, 43, 48, 52, 53, 94, 95, 96, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 52.0, 209642.0, 9.0, 45.0))","List(1, 2, List(), List(11.803534864169333, 8.196465135830667))","List(1, 2, List(), List(0.5901767432084666, 0.40982325679153336))",0.0
31.0,Private,45781.0,Masters,14.0,Never-married,Prof-specialty,Not-in-family,White,Female,14084.0,0.0,50.0,United States,>50K,USA,38.0,-97.0,1.0,"List(0, 100, List(0, 11, 24, 29, 44, 48, 53, 94, 95, 96, 97, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 31.0, 45781.0, 14.0, 14084.0, 50.0))","List(1, 2, List(), List(5.989804833295561, 14.010195166704438))","List(1, 2, List(), List(0.29949024166477806, 0.7005097583352219))",1.0
42.0,Private,159449.0,Bachelors,13.0,Married-civ-spouse,Exec-managerial,Husband,White,Male,5178.0,0.0,40.0,United States,>50K,USA,38.0,-97.0,1.0,"List(0, 100, List(0, 10, 23, 31, 43, 48, 52, 53, 94, 95, 96, 97, 99), List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 42.0, 159449.0, 13.0, 5178.0, 40.0))","List(1, 2, List(), List(7.303673845589384, 12.696326154410615))","List(1, 2, List(), List(0.36518369227946923, 0.6348163077205308))",1.0


In [42]:
# Summary table of per-classifier scores. The percentages are hard-coded
# literals, not computed from the model objects in this notebook, so they must
# be updated by hand if any model is retrained.
classifier_names = ['LR', 'NB', 'DT', 'GBT', 'RF', 'MLP', 'ovA-LR', 'ovA-RF']
classifier_scores = [90.1, 78.8, 76.7, 85.5, 89.1, 76.5, 84.8, 84.8]
classifier_comparison_pdf = pd.DataFrame({
    'classifiers': classifier_names,
    'accuracy_scores': classifier_scores,
})
classifier_comparison_df = spark.createDataFrame(classifier_comparison_pdf)
display(classifier_comparison_df)

accuracy_scores,classifiers
90.1,LR
78.8,NB
76.7,DT
85.5,GBT
89.1,RF
76.5,MLP
84.8,ovA-LR
84.8,ovA-RF


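We implemented all the classifiers available in the PySpark ML module. In the comparison below, Logistic Regression achieved the highest accuracy (90.1%), closely followed by Random Forest (89.1%); the cross-validated Random Forest model is the one used for the final predictions over the full dataset. We plan to implement XGBoost on the same data in future notebooks. <br/> 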

For the best model of each classifier, the following accuracy scores were achieved: <br/>

<ul>
  <li> Logistic Regression - 90.1% </li>
  <li> Naive Bayes - 78.8% </li>
  <li> Decision Tree - 76.7% </li>
  <li> Gradient-boosted Tree - 85.5% </li>
  <li> Random Forest - 89.1% </li>
  <li> Multilayer Perceptron - 76.5%</li>
  <li> One-vs-All (Logistic Regression - 84.8%, Random Forest - 84.8%) </li>
</ul>

The published version of the notebook is available at - <br/> 
https://databricks-prod-cloudfront.cloud.databricks.com/public/4027ec902e239c93eaaa8714f173bcfc/3173713035751393/675963439015456/2308983777460038/latest.html