We'll work through Decision Trees, Random Forests, and Gradient-Boosted Trees.

We will also expand a little on the documentation example and show some more useful evaluation features!

In [1]:
from pyspark.sql import SparkSession
# Reuse an existing session if one is running, otherwise start one named 'mytree'.
spark = (
    SparkSession.builder
    .appName('mytree')
    .getOrCreate()
)

In [2]:
from pyspark.ml import Pipeline

In [3]:
from pyspark.ml.classification import (RandomForestClassifier, GBTClassifier, DecisionTreeClassifier)

In [4]:
# Load the LIBSVM-formatted sample file; the resulting DataFrame has `label` and `features` columns.
data = spark.read.load('sample_libsvm_data.txt', format='libsvm')

In [6]:
data.show(n=5)  # quick peek at the first few rows

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
+-----+--------------------+
only showing top 5 rows



In [7]:
# 70/30 train/test split. The fixed seed makes the split, and therefore every
# downstream model and metric, reproducible on Restart & Run All.
train_data, test_data = data.randomSplit([0.7, 0.3], seed=42)

In [9]:
# All three estimators use the default column names (labelCol='label',
# featuresCol='features'), which match the columns of `data`.
dtc = DecisionTreeClassifier()

# Accuracy tends to plateau as numTrees grows; 100 is intended to sit past that
# plateau (tune if needed). The fixed seed makes the bootstrap/feature sampling reproducible.
rfc = RandomForestClassifier(numTrees=100, seed=42)

# Default boosting parameters; the seed pins any random subsampling for reproducibility.
gbt = GBTClassifier(seed=42)

In [10]:
# Fit each estimator on the same training split, in DTC -> RFC -> GBT order.
dtc_model, rfc_model, gbt_model = (
    estimator.fit(train_data) for estimator in (dtc, rfc, gbt)
)

In [11]:
# Score the held-out split with each fitted model; adds rawPrediction/probability/prediction columns.
dtc_preds, rfc_preds, gbt_preds = (
    fitted.transform(test_data) for fitted in (dtc_model, rfc_model, gbt_model)
)

In [24]:
dtc_preds.show(n=4)  # decision-tree predictions on the test split

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[98,99,100,1...|    [0.0,1.0]|  [0.0,1.0]|       1.0|
|  0.0|(692,[121,122,123...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
+-----+--------------------+-------------+-----------+----------+
only showing top 4 rows



In [25]:
rfc_preds.show(n=4)  # random-forest predictions on the test split

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[98,99,100,1...|   [99.0,1.0]|[0.99,0.01]|       0.0|
|  0.0|(692,[121,122,123...|   [95.0,5.0]|[0.95,0.05]|       0.0|
|  0.0|(692,[123,124,125...|   [95.0,5.0]|[0.95,0.05]|       0.0|
|  0.0|(692,[124,125,126...|  [100.0,0.0]|  [1.0,0.0]|       0.0|
+-----+--------------------+-------------+-----------+----------+
only showing top 4 rows



In [26]:
gbt_preds.show(n=4)  # gradient-boosted-tree predictions on the test split

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[98,99,100,1...|[-0.5628413359418...|[0.24495872828862...|       1.0|
|  0.0|(692,[121,122,123...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[123,124,125...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
+-----+--------------------+--------------------+--------------------+----------+
only showing top 4 rows



In [16]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [17]:
# Accuracy evaluator; the column names below are the evaluator's defaults, spelled out for clarity.
acc_eval = MulticlassClassificationEvaluator(
    labelCol='label',
    predictionCol='prediction',
    metricName='accuracy',
)

In [23]:
print('DTC, RFC and GBT ACCURACY:')
# Same order as the header above: DTC, RFC, GBT.
[acc_eval.evaluate(preds) for preds in (dtc_preds, rfc_preds, gbt_preds)]

DTC, RFC and GBT ACCURACY:


[0.9375, 1.0, 0.9375]

In [22]:
# Per-feature importance of the fitted random forest, displayed as a SparseVector
# (feature index -> importance weight); only nonzero entries are listed.
rfc_model.featureImportances

SparseVector(692, {126: 0.0007, 131: 0.0019, 155: 0.0006, 156: 0.0013, 180: 0.001, 184: 0.0008, 186: 0.0007, 187: 0.0008, 207: 0.0041, 210: 0.0015, 215: 0.0007, 216: 0.0075, 231: 0.0018, 235: 0.0086, 236: 0.0007, 243: 0.0017, 244: 0.0077, 245: 0.0009, 260: 0.0008, 262: 0.0071, 263: 0.0069, 267: 0.0006, 268: 0.0006, 271: 0.0003, 289: 0.0076, 291: 0.0057, 295: 0.0142, 297: 0.0006, 298: 0.0007, 300: 0.0019, 302: 0.0024, 313: 0.0006, 314: 0.0014, 317: 0.0094, 318: 0.0135, 319: 0.0008, 320: 0.0005, 323: 0.0179, 324: 0.001, 325: 0.0007, 328: 0.0007, 329: 0.0094, 330: 0.0145, 331: 0.0006, 346: 0.0078, 348: 0.0004, 349: 0.0005, 350: 0.0397, 351: 0.0394, 352: 0.0041, 353: 0.0007, 354: 0.0004, 355: 0.0008, 357: 0.0014, 358: 0.0136, 370: 0.001, 372: 0.0091, 373: 0.0084, 375: 0.0014, 378: 0.0611, 379: 0.0194, 380: 0.0006, 381: 0.0016, 382: 0.0024, 383: 0.0011, 385: 0.0082, 386: 0.0004, 388: 0.0006, 397: 0.0007, 399: 0.0125, 400: 0.0014, 401: 0.0193, 402: 0.0007, 403: 0.0013, 405: 0.0182, 406: 0.03