# Random Forest Example

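This is a quick walkthrough of the Spark documentation's random forest example. For comparison, it also fits a decision tree and a gradient-boosted tree classifier on the same data: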

In [1]:
from pyspark.sql import SparkSession


# Obtain the SparkSession for this notebook under the app name 'mytree'
# (getOrCreate hands back an already-active session if there is one).
spark = (
    SparkSession.builder
    .appName('mytree')
    .getOrCreate()
)

In [2]:
from pyspark.ml import Pipeline

In [3]:
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier, DecisionTreeClassifier

In [4]:
# Load the sample dataset through the libsvm reader; the path is relative to
# the notebook's working directory.
data = (
    spark.read
    .format('libsvm')
    .load('sample_libsvm_data.txt')
)

In [5]:
# Preview the loaded frame: first 20 rows with long cells truncated
# (these are show()'s defaults, written out explicitly).
data.show(n=20, truncate=True)

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [6]:
# Hold out roughly 30% of the rows for evaluation and train on the other ~70%.
# randomSplit samples rows at random, so without a seed every run produces a
# different split and the accuracies below cannot be reproduced. A fixed seed
# keeps the split stable across Restart & Run All on the same input.
train_data, test_data = data.randomSplit([0.7, 0.3], seed=42)

In [9]:
# One estimator per tree method. The column names are pyspark's defaults,
# spelled out so the link to the `label` / `features` columns of `data` is
# visible. The forest uses 100 trees; everything else stays at its default.
dtc = DecisionTreeClassifier(featuresCol='features', labelCol='label')
rfc = RandomForestClassifier(featuresCol='features', labelCol='label', numTrees=100)
gbt = GBTClassifier(featuresCol='features', labelCol='label')

In [11]:
# Fit every estimator on the same training split, in the order
# decision tree -> random forest -> gradient-boosted trees.
dtc_model, rfc_model, gbt_model = [
    estimator.fit(train_data) for estimator in (dtc, rfc, gbt)
]

In [12]:
# Score the held-out rows with the fitted decision tree.
dtc_preds = dtc_model.transform(dataset=test_data)

In [13]:
# Score the same held-out rows with the random forest and the GBT model.
rfc_preds, gbt_preds = (
    fitted.transform(test_data) for fitted in (rfc_model, gbt_model)
)

In [14]:
# Inspect the decision tree's predictions next to the true labels
# (show()'s default row count and truncation, written out).
dtc_preds.show(n=20, truncate=True)

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[122,123,124...|   [35.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[122,123,148...|   [35.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [35.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [35.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [35.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [35.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [35.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[181,182,183...|   [35.0,0.0]|  [1.0,0.0]|       0.0|
|  1.0|(692,[97,98,99,12...|   [0.0,40.0]|  [0.0,1.0]|       1.0|
|  1.0|(692,[123,124,125...|   [0.0,40.0]|  [0.0,1.0]|       1.0|
|  1.0|(692,[124,125,126...|   [0.0,40.0]|  [0.0,1.0]|       1.0|
|  1.0|(692,[125,126,127...|   [0.0,40.0]|  [0.0,1.0]|       1.0|
|  1.0|(69

In [15]:
# Inspect the random forest's predictions next to the true labels
# (show()'s default row count and truncation, written out).
rfc_preds.show(n=20, truncate=True)

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[122,123,124...|   [98.0,2.0]|[0.98,0.02]|       0.0|
|  0.0|(692,[122,123,148...|  [86.0,14.0]|[0.86,0.14]|       0.0|
|  0.0|(692,[123,124,125...|   [99.0,1.0]|[0.99,0.01]|       0.0|
|  0.0|(692,[124,125,126...|  [85.0,15.0]|[0.85,0.15]|       0.0|
|  0.0|(692,[124,125,126...|   [99.0,1.0]|[0.99,0.01]|       0.0|
|  0.0|(692,[124,125,126...|   [97.0,3.0]|[0.97,0.03]|       0.0|
|  0.0|(692,[126,127,128...|   [93.0,7.0]|[0.93,0.07]|       0.0|
|  0.0|(692,[181,182,183...|   [92.0,8.0]|[0.92,0.08]|       0.0|
|  1.0|(692,[97,98,99,12...|  [23.0,77.0]|[0.23,0.77]|       1.0|
|  1.0|(692,[123,124,125...|  [0.0,100.0]|  [0.0,1.0]|       1.0|
|  1.0|(692,[124,125,126...|  [0.0,100.0]|  [0.0,1.0]|       1.0|
|  1.0|(692,[125,126,127...|  [0.0,100.0]|  [0.0,1.0]|       1.0|
|  1.0|(69

In [16]:
# Inspect the gradient-boosted trees' predictions next to the true labels
# (show()'s default row count and truncation, written out).
gbt_preds.show(n=20, truncate=True)

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[122,123,124...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[122,123,148...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[123,124,125...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[181,182,183...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  1.0|(692,[97,98,99,12...|[-1.5435020027249...|[0.04364652142729...|       1.0|
|  1.0|(692,[123

In [17]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [18]:
# Accuracy evaluator shared by all three models. It compares the `prediction`
# column against `label` (the evaluator's default columns, made explicit).
acc_eval = MulticlassClassificationEvaluator(
    predictionCol='prediction',
    labelCol='label',
    metricName='accuracy',
)

In [19]:
# Print the label, then leave the decision tree's test accuracy as the
# cell's displayed value.
print('DTC Accuracy: ')
dtc_accuracy = acc_eval.evaluate(dtc_preds)
dtc_accuracy

DTC Accuracy: 


0.9583333333333334

In [21]:
# Print the label, then leave the random forest's test accuracy as the
# cell's displayed value.
print('rfc Accuracy: ')
rfc_accuracy = acc_eval.evaluate(rfc_preds)
rfc_accuracy

rfc Accuracy: 


1.0

In [22]:
# Report the gradient-boosted trees' test accuracy. The label used to read
# 'rfc Accuracy: ' (copied from the random forest cell) even though this cell
# evaluates gbt_preds, so the printed number was attributed to the wrong model.
print('GBT Accuracy: ')
acc_eval.evaluate(gbt_preds)

rfc Accuracy: 


0.9583333333333334

In [23]:
# Display the fitted random forest's per-feature importance scores
# (the cell's value; rendered as a SparseVector over the 692 feature indices).
rfc_model.featureImportances

SparseVector(692, {100: 0.0009, 119: 0.0008, 147: 0.0011, 156: 0.0007, 178: 0.0002, 180: 0.0009, 181: 0.0004, 182: 0.0011, 185: 0.0003, 190: 0.0005, 207: 0.0021, 214: 0.0003, 234: 0.0013, 236: 0.0002, 243: 0.001, 244: 0.0083, 245: 0.0081, 258: 0.0009, 262: 0.0077, 263: 0.0196, 271: 0.0012, 274: 0.0005, 287: 0.0007, 290: 0.0164, 291: 0.0017, 298: 0.0017, 299: 0.0057, 301: 0.0114, 316: 0.0022, 317: 0.011, 318: 0.0015, 323: 0.0052, 327: 0.0043, 328: 0.013, 330: 0.0088, 345: 0.011, 347: 0.0018, 349: 0.0005, 350: 0.0021, 351: 0.0208, 356: 0.0138, 357: 0.0083, 372: 0.004, 375: 0.0032, 377: 0.009, 378: 0.0417, 379: 0.0245, 383: 0.0007, 384: 0.001, 385: 0.0005, 388: 0.0002, 399: 0.0007, 400: 0.0085, 402: 0.0005, 405: 0.0189, 406: 0.0236, 407: 0.064, 408: 0.0009, 411: 0.0015, 412: 0.0014, 426: 0.0011, 429: 0.0149, 430: 0.0005, 432: 0.0022, 433: 0.0399, 434: 0.0312, 435: 0.0358, 440: 0.0094, 441: 0.0104, 442: 0.0057, 443: 0.0029, 454: 0.0038, 455: 0.0086, 456: 0.0251, 457: 0.0007, 460: 0.0024, 4