In [23]:
# Call all the imports
from pyspark.sql import SparkSession
from pyspark.ml.classification import DecisionTreeClassifier, RandomForestClassifier, GBTClassifier
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator

In [4]:
# Create a spark session
spark = SparkSession.builder.appName('trees').getOrCreate()

In [6]:
# work on libsvm data
data = spark.read.format("libsvm").load("sample_libsvm_data.txt")

In [7]:
data.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [9]:
data.head(1)

[Row(label=0.0, features=SparseVector(692, {127: 51.0, 128: 159.0, 129: 253.0, 130: 159.0, 131: 50.0, 154: 48.0, 155: 238.0, 156: 252.0, 157: 252.0, 158: 252.0, 159: 237.0, 181: 54.0, 182: 227.0, 183: 253.0, 184: 252.0, 185: 239.0, 186: 233.0, 187: 252.0, 188: 57.0, 189: 6.0, 207: 10.0, 208: 60.0, 209: 224.0, 210: 252.0, 211: 253.0, 212: 252.0, 213: 202.0, 214: 84.0, 215: 252.0, 216: 253.0, 217: 122.0, 235: 163.0, 236: 252.0, 237: 252.0, 238: 252.0, 239: 253.0, 240: 252.0, 241: 252.0, 242: 96.0, 243: 189.0, 244: 253.0, 245: 167.0, 262: 51.0, 263: 238.0, 264: 253.0, 265: 253.0, 266: 190.0, 267: 114.0, 268: 253.0, 269: 228.0, 270: 47.0, 271: 79.0, 272: 255.0, 273: 168.0, 289: 48.0, 290: 238.0, 291: 252.0, 292: 252.0, 293: 179.0, 294: 12.0, 295: 75.0, 296: 121.0, 297: 21.0, 300: 253.0, 301: 243.0, 302: 50.0, 316: 38.0, 317: 165.0, 318: 253.0, 319: 233.0, 320: 208.0, 321: 84.0, 328: 253.0, 329: 252.0, 330: 165.0, 343: 7.0, 344: 178.0, 345: 252.0, 346: 240.0, 347: 71.0, 348: 19.0, 349: 28.0

In [10]:
data.printSchema()

root
 |-- label: double (nullable = true)
 |-- features: vector (nullable = true)



In [13]:
# Split the data into train and test
train_data, test_data = data.randomSplit([0.7,0.3])

In [64]:
# Create the 3 types of classifiers
dtc = DecisionTreeClassifier()
rfc = RandomForestClassifier(numTrees=75)
gbt = GBTClassifier()

In [65]:
# Build a model on the training data
dtc_model = dtc.fit(train_data)
rfc_model = rfc.fit(train_data)
gbt_model = gbt.fit(train_data)

In [66]:
# Find out the prediction
dtc_preds = dtc_model.transform(test_data)
rfc_preds = rfc_model.transform(test_data)
gbt_preds = gbt_model.transform(test_data)

In [67]:
dtc_preds.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[121,122,123...|   [31.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [31.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [31.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [31.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [31.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [31.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[127,128,129...|   [31.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[127,128,129...|   [31.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[129,130,131...|   [31.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[153,154,155...|   [31.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[154,155,156...|   [31.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[181,182,183...|   [31.0,0.0]|  [1.0,0.0]|       0.0|
|  1.0|(69

In [68]:
gbt_preds.show()

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[121,122,123...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[123,124,125...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[127,128,129...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[127,128,129...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[129,130,131...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[153

In [69]:
# Perform evaluation on the predictions
binary_eval = BinaryClassificationEvaluator(metricName='areaUnderROC')

In [70]:
dtc_eval = binary_eval.evaluate(dtc_preds)
print("DTC AUC: ", dtc_eval)

DTC AUC:  0.9772727272727273


In [71]:
rfc_eval = binary_eval.evaluate(rfc_preds)
print("RFC AUC: ", rfc_eval)

RFC AUC:  1.0


In [72]:
gbt_eval = binary_eval.evaluate(gbt_preds)
print("GBT AUC: ", gbt_eval)

GBT AUC:  0.9810606060606061


In [73]:
# Perform evaluation on the predictions
multiclass_eval = MulticlassClassificationEvaluator(metricName='accuracy')

In [74]:
dtc_eval = multiclass_eval.evaluate(dtc_preds)
print("DTC accuracy: ", dtc_eval)

DTC accuracy:  0.9705882352941176


In [75]:
rfc_eval = multiclass_eval.evaluate(rfc_preds)
print("RFC accuracy: ", rfc_eval)

RFC accuracy:  1.0


In [76]:
gbt_eval = multiclass_eval.evaluate(gbt_preds)
print("GBT accuracy: ", gbt_eval)

GBT accuracy:  0.9705882352941176


### Working on College data

In [78]:
data = spark.read.csv("College.csv", inferSchema=True, header=True)

In [79]:
data.printSchema()

root
 |-- School: string (nullable = true)
 |-- Private: string (nullable = true)
 |-- Apps: integer (nullable = true)
 |-- Accept: integer (nullable = true)
 |-- Enroll: integer (nullable = true)
 |-- Top10perc: integer (nullable = true)
 |-- Top25perc: integer (nullable = true)
 |-- F_Undergrad: integer (nullable = true)
 |-- P_Undergrad: integer (nullable = true)
 |-- Outstate: integer (nullable = true)
 |-- Room_Board: integer (nullable = true)
 |-- Books: integer (nullable = true)
 |-- Personal: integer (nullable = true)
 |-- PhD: integer (nullable = true)
 |-- Terminal: integer (nullable = true)
 |-- S_F_Ratio: double (nullable = true)
 |-- perc_alumni: integer (nullable = true)
 |-- Expend: integer (nullable = true)
 |-- Grad_Rate: integer (nullable = true)



In [81]:
data.head(1)

[Row(School='Abilene Christian University', Private='Yes', Apps=1660, Accept=1232, Enroll=721, Top10perc=23, Top25perc=52, F_Undergrad=2885, P_Undergrad=537, Outstate=7440, Room_Board=3300, Books=450, Personal=2200, PhD=70, Terminal=78, S_F_Ratio=18.1, perc_alumni=12, Expend=7041, Grad_Rate=60)]

In [82]:
data.columns

['School',
 'Private',
 'Apps',
 'Accept',
 'Enroll',
 'Top10perc',
 'Top25perc',
 'F_Undergrad',
 'P_Undergrad',
 'Outstate',
 'Room_Board',
 'Books',
 'Personal',
 'PhD',
 'Terminal',
 'S_F_Ratio',
 'perc_alumni',
 'Expend',
 'Grad_Rate']

In [88]:
# Create a features column using assember
assembler=VectorAssembler(inputCols=['Apps',
 'Accept',
 'Enroll',
 'Top10perc',
 'Top25perc',
 'F_Undergrad',
 'P_Undergrad',
 'Outstate',
 'Room_Board',
 'Books',
 'Personal',
 'PhD',
 'Terminal',
 'S_F_Ratio',
 'perc_alumni',
 'Expend',
 'Grad_Rate'], outputCol='features')

features_data = assembler.transform(data)
features_data.show()

+--------------------+-------+----+------+------+---------+---------+-----------+-----------+--------+----------+-----+--------+---+--------+---------+-----------+------+---------+--------------------+
|              School|Private|Apps|Accept|Enroll|Top10perc|Top25perc|F_Undergrad|P_Undergrad|Outstate|Room_Board|Books|Personal|PhD|Terminal|S_F_Ratio|perc_alumni|Expend|Grad_Rate|            features|
+--------------------+-------+----+------+------+---------+---------+-----------+-----------+--------+----------+-----+--------+---+--------+---------+-----------+------+---------+--------------------+
|Abilene Christian...|    Yes|1660|  1232|   721|       23|       52|       2885|        537|    7440|      3300|  450|    2200| 70|      78|     18.1|         12|  7041|       60|[1660.0,1232.0,72...|
|  Adelphi University|    Yes|2186|  1924|   512|       16|       29|       2683|       1227|   12280|      6450|  750|    1500| 29|      30|     12.2|         16| 10527|       56|[2186.0,1924

In [94]:
print("Number of rows: ", features_data.count())

Number of rows:  777


In [105]:
# Convert the predictions column from string to index
private_idx = StringIndexer(inputCol='Private', outputCol='label')
final_data = private_idx.fit(features_data).transform(features_data)
final_data.show()

+--------------------+-------+----+------+------+---------+---------+-----------+-----------+--------+----------+-----+--------+---+--------+---------+-----------+------+---------+--------------------+-----+
|              School|Private|Apps|Accept|Enroll|Top10perc|Top25perc|F_Undergrad|P_Undergrad|Outstate|Room_Board|Books|Personal|PhD|Terminal|S_F_Ratio|perc_alumni|Expend|Grad_Rate|            features|label|
+--------------------+-------+----+------+------+---------+---------+-----------+-----------+--------+----------+-----+--------+---+--------+---------+-----------+------+---------+--------------------+-----+
|Abilene Christian...|    Yes|1660|  1232|   721|       23|       52|       2885|        537|    7440|      3300|  450|    2200| 70|      78|     18.1|         12|  7041|       60|[1660.0,1232.0,72...|  0.0|
|  Adelphi University|    Yes|2186|  1924|   512|       16|       29|       2683|       1227|   12280|      6450|  750|    1500| 29|      30|     12.2|         16| 1052

In [106]:
final_data.printSchema()

root
 |-- School: string (nullable = true)
 |-- Private: string (nullable = true)
 |-- Apps: integer (nullable = true)
 |-- Accept: integer (nullable = true)
 |-- Enroll: integer (nullable = true)
 |-- Top10perc: integer (nullable = true)
 |-- Top25perc: integer (nullable = true)
 |-- F_Undergrad: integer (nullable = true)
 |-- P_Undergrad: integer (nullable = true)
 |-- Outstate: integer (nullable = true)
 |-- Room_Board: integer (nullable = true)
 |-- Books: integer (nullable = true)
 |-- Personal: integer (nullable = true)
 |-- PhD: integer (nullable = true)
 |-- Terminal: integer (nullable = true)
 |-- S_F_Ratio: double (nullable = true)
 |-- perc_alumni: integer (nullable = true)
 |-- Expend: integer (nullable = true)
 |-- Grad_Rate: integer (nullable = true)
 |-- features: vector (nullable = true)
 |-- label: double (nullable = true)



In [108]:
final_data_truncated = final_data.select(['features','label'])

In [109]:
final_data_truncated.show()

+--------------------+-----+
|            features|label|
+--------------------+-----+
|[1660.0,1232.0,72...|  0.0|
|[2186.0,1924.0,51...|  0.0|
|[1428.0,1097.0,33...|  0.0|
|[417.0,349.0,137....|  0.0|
|[193.0,146.0,55.0...|  0.0|
|[587.0,479.0,158....|  0.0|
|[353.0,340.0,103....|  0.0|
|[1899.0,1720.0,48...|  0.0|
|[1038.0,839.0,227...|  0.0|
|[582.0,498.0,172....|  0.0|
|[1732.0,1425.0,47...|  0.0|
|[2652.0,1900.0,48...|  0.0|
|[1179.0,780.0,290...|  0.0|
|[1267.0,1080.0,38...|  0.0|
|[494.0,313.0,157....|  0.0|
|[1420.0,1093.0,22...|  0.0|
|[4302.0,992.0,418...|  0.0|
|[1216.0,908.0,423...|  0.0|
|[1130.0,704.0,322...|  0.0|
|[3540.0,2001.0,10...|  1.0|
+--------------------+-----+
only showing top 20 rows



In [110]:
# Split the data into train and test
train_data, test_data = final_data_truncated.randomSplit([0.7,0.3])
print('Training set size: ', train_data.count())
print('Test set size: ', test_data.count())

Training set size:  541
Test set size:  236


In [111]:
# Create the 3 types of classifiers
dtc = DecisionTreeClassifier()
rfc = RandomForestClassifier(numTrees=75)
gbt = GBTClassifier()

In [112]:
# Build a model on the training data
dtc_model = dtc.fit(train_data)
rfc_model = rfc.fit(train_data)
gbt_model = gbt.fit(train_data)

In [113]:
# Find out the prediction
dtc_preds = dtc_model.transform(test_data)
rfc_preds = rfc_model.transform(test_data)
gbt_preds = gbt_model.transform(test_data)

In [114]:
dtc_preds.show()

+--------------------+-----+-------------+-----------+----------+
|            features|label|rawPrediction|probability|prediction|
+--------------------+-----+-------------+-----------+----------+
|[81.0,72.0,51.0,3...|  0.0|  [292.0,0.0]|  [1.0,0.0]|       0.0|
|[100.0,90.0,35.0,...|  0.0|   [16.0,0.0]|  [1.0,0.0]|       0.0|
|[152.0,128.0,75.0...|  0.0|  [292.0,0.0]|  [1.0,0.0]|       0.0|
|[167.0,130.0,46.0...|  0.0|  [292.0,0.0]|  [1.0,0.0]|       0.0|
|[202.0,184.0,122....|  0.0|  [292.0,0.0]|  [1.0,0.0]|       0.0|
|[222.0,185.0,91.0...|  0.0|  [292.0,0.0]|  [1.0,0.0]|       0.0|
|[247.0,189.0,100....|  0.0|  [292.0,0.0]|  [1.0,0.0]|       0.0|
|[261.0,192.0,111....|  0.0|  [292.0,0.0]|  [1.0,0.0]|       0.0|
|[291.0,245.0,126....|  0.0|  [292.0,0.0]|  [1.0,0.0]|       0.0|
|[313.0,228.0,137....|  0.0|  [292.0,0.0]|  [1.0,0.0]|       0.0|
|[318.0,240.0,130....|  0.0|  [292.0,0.0]|  [1.0,0.0]|       0.0|
|[325.0,284.0,95.0...|  0.0|  [292.0,0.0]|  [1.0,0.0]|       0.0|
|[335.0,28

In [115]:
rfc_preds.show()

+--------------------+-----+--------------------+--------------------+----------+
|            features|label|       rawPrediction|         probability|prediction|
+--------------------+-----+--------------------+--------------------+----------+
|[81.0,72.0,51.0,3...|  0.0|[74.9028067954465...|[0.99870409060595...|       0.0|
|[100.0,90.0,35.0,...|  0.0|[73.6751012664404...|[0.98233468355253...|       0.0|
|[152.0,128.0,75.0...|  0.0|[73.7784274009853...|[0.98371236534647...|       0.0|
|[167.0,130.0,46.0...|  0.0|[74.8991303248583...|[0.99865507099811...|       0.0|
|[202.0,184.0,122....|  0.0|[74.7885560445287...|[0.99718074726038...|       0.0|
|[222.0,185.0,91.0...|  0.0|[74.8991303248583...|[0.99865507099811...|       0.0|
|[247.0,189.0,100....|  0.0|[74.8536757794037...|[0.99804901039205...|       0.0|
|[261.0,192.0,111....|  0.0|[74.8605022862302...|[0.99814003048307...|       0.0|
|[291.0,245.0,126....|  0.0|[69.4524283998074...|[0.92603237866409...|       0.0|
|[313.0,228.0,13

In [116]:
gbt_preds.show()

+--------------------+-----+--------------------+--------------------+----------+
|            features|label|       rawPrediction|         probability|prediction|
+--------------------+-----+--------------------+--------------------+----------+
|[81.0,72.0,51.0,3...|  0.0|[1.54393694199812...|[0.95638977419966...|       0.0|
|[100.0,90.0,35.0,...|  0.0|[1.55051578968326...|[0.95693527664652...|       0.0|
|[152.0,128.0,75.0...|  0.0|[1.54402910377526...|[0.95639746138870...|       0.0|
|[167.0,130.0,46.0...|  0.0|[1.54393694199812...|[0.95638977419966...|       0.0|
|[202.0,184.0,122....|  0.0|[1.54393694199812...|[0.95638977419966...|       0.0|
|[222.0,185.0,91.0...|  0.0|[1.54393694199812...|[0.95638977419966...|       0.0|
|[247.0,189.0,100....|  0.0|[1.54393694199812...|[0.95638977419966...|       0.0|
|[261.0,192.0,111....|  0.0|[1.54393694199812...|[0.95638977419966...|       0.0|
|[291.0,245.0,126....|  0.0|[1.52949084503695...|[0.95516871165448...|       0.0|
|[313.0,228.0,13

In [117]:
# Perform evaluation on the predictions
multiclass_eval = MulticlassClassificationEvaluator(metricName='accuracy')

In [118]:
dtc_eval = multiclass_eval.evaluate(dtc_preds)
print("DTC accuracy: ", dtc_eval)

DTC accuracy:  0.8983050847457628


In [121]:
rfc_eval = multiclass_eval.evaluate(rfc_preds)
print("RFC accuracy: ", rfc_eval)

RFC accuracy:  0.923728813559322


In [122]:
GBT_eval = multiclass_eval.evaluate(gbt_preds)
print("GBT accuracy: ", gbt_eval)

GBT accuracy:  0.9705882352941176


### Working with Dog Food dataset

In [124]:
data = spark.read.csv("dog_food.csv", inferSchema=True, header=True)

In [125]:
data.printSchema()

root
 |-- A: integer (nullable = true)
 |-- B: integer (nullable = true)
 |-- C: double (nullable = true)
 |-- D: integer (nullable = true)
 |-- Spoiled: double (nullable = true)



In [126]:
data.show()

+---+---+----+---+-------+
|  A|  B|   C|  D|Spoiled|
+---+---+----+---+-------+
|  4|  2|12.0|  3|    1.0|
|  5|  6|12.0|  7|    1.0|
|  6|  2|13.0|  6|    1.0|
|  4|  2|12.0|  1|    1.0|
|  4|  2|12.0|  3|    1.0|
| 10|  3|13.0|  9|    1.0|
|  8|  5|14.0|  5|    1.0|
|  5|  8|12.0|  8|    1.0|
|  6|  5|12.0|  9|    1.0|
|  3|  3|12.0|  1|    1.0|
|  9|  8|11.0|  3|    1.0|
|  1| 10|12.0|  3|    1.0|
|  1|  5|13.0| 10|    1.0|
|  2| 10|12.0|  6|    1.0|
|  1| 10|11.0|  4|    1.0|
|  5|  3|12.0|  2|    1.0|
|  4|  9|11.0|  8|    1.0|
|  5|  1|11.0|  1|    1.0|
|  4|  9|12.0| 10|    1.0|
|  5|  8|10.0|  9|    1.0|
+---+---+----+---+-------+
only showing top 20 rows



In [127]:
data.head(2)

[Row(A=4, B=2, C=12.0, D=3, Spoiled=1.0),
 Row(A=5, B=6, C=12.0, D=7, Spoiled=1.0)]

In [129]:
data.columns

['A', 'B', 'C', 'D', 'Spoiled']

In [130]:
# Create a features column using assember
assembler=VectorAssembler(inputCols=['A', 'B', 'C', 'D'], outputCol='features')

features_data = assembler.transform(data)
features_data.show()

+---+---+----+---+-------+-------------------+
|  A|  B|   C|  D|Spoiled|           features|
+---+---+----+---+-------+-------------------+
|  4|  2|12.0|  3|    1.0| [4.0,2.0,12.0,3.0]|
|  5|  6|12.0|  7|    1.0| [5.0,6.0,12.0,7.0]|
|  6|  2|13.0|  6|    1.0| [6.0,2.0,13.0,6.0]|
|  4|  2|12.0|  1|    1.0| [4.0,2.0,12.0,1.0]|
|  4|  2|12.0|  3|    1.0| [4.0,2.0,12.0,3.0]|
| 10|  3|13.0|  9|    1.0|[10.0,3.0,13.0,9.0]|
|  8|  5|14.0|  5|    1.0| [8.0,5.0,14.0,5.0]|
|  5|  8|12.0|  8|    1.0| [5.0,8.0,12.0,8.0]|
|  6|  5|12.0|  9|    1.0| [6.0,5.0,12.0,9.0]|
|  3|  3|12.0|  1|    1.0| [3.0,3.0,12.0,1.0]|
|  9|  8|11.0|  3|    1.0| [9.0,8.0,11.0,3.0]|
|  1| 10|12.0|  3|    1.0|[1.0,10.0,12.0,3.0]|
|  1|  5|13.0| 10|    1.0|[1.0,5.0,13.0,10.0]|
|  2| 10|12.0|  6|    1.0|[2.0,10.0,12.0,6.0]|
|  1| 10|11.0|  4|    1.0|[1.0,10.0,11.0,4.0]|
|  5|  3|12.0|  2|    1.0| [5.0,3.0,12.0,2.0]|
|  4|  9|11.0|  8|    1.0| [4.0,9.0,11.0,8.0]|
|  5|  1|11.0|  1|    1.0| [5.0,1.0,11.0,1.0]|
|  4|  9|12.0

In [132]:
final_data = features_data.select(['features','Spoiled'])
final_data.show()

+-------------------+-------+
|           features|Spoiled|
+-------------------+-------+
| [4.0,2.0,12.0,3.0]|    1.0|
| [5.0,6.0,12.0,7.0]|    1.0|
| [6.0,2.0,13.0,6.0]|    1.0|
| [4.0,2.0,12.0,1.0]|    1.0|
| [4.0,2.0,12.0,3.0]|    1.0|
|[10.0,3.0,13.0,9.0]|    1.0|
| [8.0,5.0,14.0,5.0]|    1.0|
| [5.0,8.0,12.0,8.0]|    1.0|
| [6.0,5.0,12.0,9.0]|    1.0|
| [3.0,3.0,12.0,1.0]|    1.0|
| [9.0,8.0,11.0,3.0]|    1.0|
|[1.0,10.0,12.0,3.0]|    1.0|
|[1.0,5.0,13.0,10.0]|    1.0|
|[2.0,10.0,12.0,6.0]|    1.0|
|[1.0,10.0,11.0,4.0]|    1.0|
| [5.0,3.0,12.0,2.0]|    1.0|
| [4.0,9.0,11.0,8.0]|    1.0|
| [5.0,1.0,11.0,1.0]|    1.0|
|[4.0,9.0,12.0,10.0]|    1.0|
| [5.0,8.0,10.0,9.0]|    1.0|
+-------------------+-------+
only showing top 20 rows



In [133]:
# Split the data into train and test
train_data, test_data = final_data.randomSplit([0.7,0.3])
print('Training set size: ', train_data.count())
print('Test set size: ', test_data.count())

Training set size:  335
Test set size:  155


In [136]:
# Create the 3 types of classifiers
dtc = DecisionTreeClassifier(labelCol='Spoiled')
rfc = RandomForestClassifier(labelCol='Spoiled')
gbt = GBTClassifier(labelCol='Spoiled')

In [137]:
# Build a model on the training data
dtc_model = dtc.fit(train_data)
rfc_model = rfc.fit(train_data)
gbt_model = gbt.fit(train_data)

In [138]:
# Find out the prediction
dtc_preds = dtc_model.transform(test_data)
rfc_preds = rfc_model.transform(test_data)
gbt_preds = gbt_model.transform(test_data)

In [139]:
dtc_preds.show()

+-------------------+-------+-------------+-----------+----------+
|           features|Spoiled|rawPrediction|probability|prediction|
+-------------------+-------+-------------+-----------+----------+
|  [1.0,2.0,9.0,1.0]|    0.0|    [5.0,0.0]|  [1.0,0.0]|       0.0|
|[1.0,4.0,13.0,10.0]|    1.0|    [0.0,7.0]|  [0.0,1.0]|       1.0|
| [1.0,5.0,8.0,10.0]|    0.0|  [180.0,0.0]|  [1.0,0.0]|       0.0|
|[1.0,5.0,13.0,10.0]|    1.0|    [0.0,7.0]|  [0.0,1.0]|       1.0|
|  [1.0,6.0,7.0,8.0]|    0.0|  [180.0,0.0]|  [1.0,0.0]|       0.0|
|[1.0,6.0,11.0,10.0]|    1.0|    [0.0,1.0]|  [0.0,1.0]|       1.0|
|  [1.0,7.0,7.0,2.0]|    0.0|   [29.0,0.0]|  [1.0,0.0]|       0.0|
|  [1.0,7.0,7.0,6.0]|    0.0|  [180.0,0.0]|  [1.0,0.0]|       0.0|
|  [1.0,7.0,8.0,2.0]|    0.0|   [29.0,0.0]|  [1.0,0.0]|       0.0|
|[1.0,7.0,11.0,10.0]|    1.0|    [0.0,1.0]|  [0.0,1.0]|       1.0|
|  [1.0,9.0,7.0,5.0]|    0.0|  [180.0,0.0]|  [1.0,0.0]|       0.0|
|  [1.0,9.0,9.0,7.0]|    0.0|  [180.0,0.0]|  [1.0,0.0]|       

In [140]:
rfc_preds.show()

+-------------------+-------+--------------------+--------------------+----------+
|           features|Spoiled|       rawPrediction|         probability|prediction|
+-------------------+-------+--------------------+--------------------+----------+
|  [1.0,2.0,9.0,1.0]|    0.0|[16.8333333333333...|[0.84166666666666...|       0.0|
|[1.0,4.0,13.0,10.0]|    1.0|          [0.0,20.0]|           [0.0,1.0]|       1.0|
| [1.0,5.0,8.0,10.0]|    0.0|          [20.0,0.0]|           [1.0,0.0]|       0.0|
|[1.0,5.0,13.0,10.0]|    1.0|          [0.0,20.0]|           [0.0,1.0]|       1.0|
|  [1.0,6.0,7.0,8.0]|    0.0|[19.7826086956521...|[0.98913043478260...|       0.0|
|[1.0,6.0,11.0,10.0]|    1.0|        [3.75,16.25]|     [0.1875,0.8125]|       1.0|
|  [1.0,7.0,7.0,2.0]|    0.0|          [20.0,0.0]|           [1.0,0.0]|       0.0|
|  [1.0,7.0,7.0,6.0]|    0.0|          [20.0,0.0]|           [1.0,0.0]|       0.0|
|  [1.0,7.0,8.0,2.0]|    0.0|          [20.0,0.0]|           [1.0,0.0]|       0.0|
|[1.

In [141]:
gbt_preds.show()

+-------------------+-------+--------------------+--------------------+----------+
|           features|Spoiled|       rawPrediction|         probability|prediction|
+-------------------+-------+--------------------+--------------------+----------+
|  [1.0,2.0,9.0,1.0]|    0.0|[1.24108734512716...|[0.92288271343223...|       0.0|
|[1.0,4.0,13.0,10.0]|    1.0|[-1.5435020027249...|[0.04364652142729...|       1.0|
| [1.0,5.0,8.0,10.0]|    0.0|[1.54364066094232...|[0.95636505271255...|       0.0|
|[1.0,5.0,13.0,10.0]|    1.0|[-1.5435020027249...|[0.04364652142729...|       1.0|
|  [1.0,6.0,7.0,8.0]|    0.0|[1.54364066094232...|[0.95636505271255...|       0.0|
|[1.0,6.0,11.0,10.0]|    1.0|[-1.5435020027249...|[0.04364652142729...|       1.0|
|  [1.0,7.0,7.0,2.0]|    0.0|[1.54476023772918...|[0.95645839926581...|       0.0|
|  [1.0,7.0,7.0,6.0]|    0.0|[1.54364066094232...|[0.95636505271255...|       0.0|
|  [1.0,7.0,8.0,2.0]|    0.0|[1.54476023772918...|[0.95645839926581...|       0.0|
|[1.

In [144]:
# Perform evaluation on the predictions
multiclass_eval = MulticlassClassificationEvaluator(labelCol='Spoiled', metricName='accuracy')

In [145]:
dtc_eval = multiclass_eval.evaluate(dtc_preds)
print("DTC accuracy: ", dtc_eval)

DTC accuracy:  0.9870967741935484


In [147]:
rfc_eval = multiclass_eval.evaluate(rfc_preds)
print("RFC accuracy: ", rfc_eval)

RFC accuracy:  0.9935483870967742


In [148]:
gbt_eval = multiclass_eval.evaluate(gbt_preds)
print("GBT accuracy: ", gbt_eval)

GBT accuracy:  0.9870967741935484


In [152]:
# Feature importances
dtc_model.featureImportances

SparseVector(4, {0: 0.0135, 1: 0.0087, 2: 0.9536, 3: 0.0242})

In [150]:
# Feature importances
rfc_model.featureImportances

SparseVector(4, {0: 0.0232, 1: 0.0343, 2: 0.9092, 3: 0.0334})

In [153]:
# Feature importances
gbt_model.featureImportances

SparseVector(4, {0: 0.0784, 1: 0.0863, 2: 0.7532, 3: 0.0821})