In [1]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('Tree').getOrCreate()

In [2]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import (RandomForestClassifier, GBTClassifier, DecisionTreeClassifier)

In [3]:
data = spark.read.format('libsvm').load('Dataset/sample_libsvm_data.txt')

In [4]:
data.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [5]:
train_data , test_data = data.randomSplit([0.7,0.3])

In [6]:
dtc = DecisionTreeClassifier(labelCol='label',featuresCol='features')
rfc = RandomForestClassifier(labelCol='label',featuresCol='features',numTrees=100)
gbt = GBTClassifier(labelCol='label',featuresCol='features')

In [7]:
dtc_model = dtc.fit(train_data)
rfc_model = rfc.fit(train_data)
gbt_model = gbt.fit(train_data)

In [8]:
dtc_preds = dtc_model.transform(test_data)
rfc_preds = rfc_model.transform(test_data)
gbt_preds = gbt_model.transform(test_data)

In [9]:
dtc_preds.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[123,124,125...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[127,128,129...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[151,152,153...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[152,153,154...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[153,154,155...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[154,155,156...|   [0.0,39.0]|  [0.0,1.0]|       1.0|
|  0.0|(692,[154,155,156...|   [0.0,39.0]|  [0.0,1.0]|       1.0|
|  1.0|(692,[119,120,121...|   [0.0,39.0]|  [0.0,1.0]|       1.0|
|  1.0|(69

In [10]:
gbt_preds.show()

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[123,124,125...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[127,128,129...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[151,152,153...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[152,153,154...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[153,154,155...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[154

In [11]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [12]:
acc_eval = MulticlassClassificationEvaluator(metricName='accuracy')

In [13]:
print('DTC Accuracy')
acc_eval.evaluate(dtc_preds)

DTC Accuracy


0.9310344827586207

In [14]:
print('RFC Accuracy')
acc_eval.evaluate(rfc_preds)

RFC Accuracy


1.0

In [15]:
print('GBT Accuracy')
acc_eval.evaluate(gbt_preds)

GBT Accuracy


0.9310344827586207

In [16]:
rfc_model.featureImportances

SparseVector(692, {100: 0.0005, 126: 0.0005, 128: 0.001, 152: 0.001, 154: 0.0003, 158: 0.0005, 180: 0.0003, 181: 0.0011, 183: 0.001, 184: 0.0005, 207: 0.0017, 208: 0.0003, 213: 0.0005, 214: 0.0006, 215: 0.0005, 234: 0.0002, 242: 0.0023, 245: 0.0005, 260: 0.0051, 262: 0.0149, 263: 0.0004, 271: 0.0002, 272: 0.0143, 273: 0.0145, 285: 0.0003, 286: 0.003, 290: 0.0028, 291: 0.0006, 300: 0.0154, 301: 0.0066, 304: 0.0018, 316: 0.0005, 317: 0.0095, 318: 0.0003, 322: 0.0028, 323: 0.0082, 326: 0.0005, 327: 0.0022, 328: 0.0069, 329: 0.0072, 341: 0.0011, 347: 0.0009, 350: 0.0312, 351: 0.0394, 355: 0.0006, 357: 0.0072, 369: 0.0009, 371: 0.0098, 373: 0.0071, 374: 0.0005, 375: 0.0015, 377: 0.0263, 378: 0.0227, 379: 0.0471, 382: 0.0006, 385: 0.0149, 386: 0.002, 398: 0.0015, 401: 0.0075, 405: 0.0103, 406: 0.0349, 407: 0.0647, 408: 0.005, 414: 0.0045, 415: 0.0029, 416: 0.002, 427: 0.0005, 430: 0.0004, 433: 0.0184, 434: 0.0093, 435: 0.0245, 436: 0.0006, 438: 0.0027, 440: 0.0078, 441: 0.0018, 442: 0.0141, 