In [2]:
from pyspark.sql import SparkSession

In [3]:
# Obtain the Spark session for this notebook, registered under the app name
# 'dtree'; getOrCreate() hands back an already-running session if there is one.
spark = (
    SparkSession.builder
    .appName('dtree')
    .getOrCreate()
)

In [4]:
from pyspark.ml import Pipeline

In [5]:
from pyspark.ml.classification import RandomForestClassifier,GBTClassifier,DecisionTreeClassifier

In [6]:
# Read the sample file with the 'libsvm' reader; the resulting DataFrame has a
# `label` column and a sparse `features` vector column (see the output below).
data = (
    spark.read
    .format('libsvm')
    .load('sample_libsvm_data.txt')
)

In [7]:
# Preview the loaded rows (first 20, long cells truncated — the defaults).
data.show(n=20, truncate=True)

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [8]:
# 70/30 train/test split. A fixed seed keeps the split (and therefore every
# metric computed further down) reproducible across Restart & Run All.
train_data, test_data = data.randomSplit([0.7, 0.3], seed=42)

In [9]:
# Single decision tree with default hyperparameters; the column names are the
# library defaults, written out to match the schema of `data`.
dtc = DecisionTreeClassifier(
    featuresCol='features',
    labelCol='label',
)

In [12]:
# Random forest of 100 trees; the column names are the library defaults,
# written out to match the schema of `data`.
rfc = RandomForestClassifier(
    featuresCol='features',
    labelCol='label',
    numTrees=100,
)

In [11]:
# Gradient-boosted trees with default hyperparameters; the column names are
# the library defaults, written out to match the schema of `data`.
gbt = GBTClassifier(
    featuresCol='features',
    labelCol='label',
)

In [13]:
stc_model = dtc.fit(train_data)
rfc_model = rfc.fit(train_data)
gbt_model = gbt.fit(train_data)

In [15]:
dtc_preds = stc_model.transform(test_data)
rfc_preds = rfc_model.transform(test_data)
gbt_preds = gbt_model.transform(test_data)

In [16]:
# Preview the decision-tree predictions (first 20 rows, long cells truncated —
# the defaults).
dtc_preds.show(n=20, truncate=True)

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[98,99,100,1...|    [0.0,1.0]|  [0.0,1.0]|       1.0|
|  0.0|(692,[123,124,125...|   [29.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [29.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [29.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[125,126,127...|   [29.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [29.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [29.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[127,128,129...|   [29.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[127,128,129...|   [29.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[150,151,152...|   [29.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[151,152,153...|   [29.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[152,153,154...|   [29.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(69

In [18]:
# Preview the gradient-boosted-trees predictions (first 20 rows, long cells
# truncated — the defaults).
gbt_preds.show(n=20, truncate=True)

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[98,99,100,1...|[-0.7175574083687...|[0.19230298731890...|       1.0|
|  0.0|(692,[123,124,125...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[123,124,125...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[125,126,127...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[127,128,129...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[127,128,129...|[1.42019012053189...|[0.94481928974227...|       0.0|
|  0.0|(692,[150

In [19]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [20]:
# Accuracy evaluator comparing the `prediction` column against `label`; both
# column names are the library defaults, written out for the reader.
acc_eval = MulticlassClassificationEvaluator(
    predictionCol='prediction',
    labelCol='label',
    metricName='accuracy',
)

In [21]:
print('DTC accuracy:')
acc_eval.evaluate(gbt_preds)

DTC accuracy:


0.967741935483871

In [22]:
# Show the fitted random forest's per-feature importances; the result is a
# sparse vector over the 692 feature indices (see the output below).
rfc_model.featureImportances

SparseVector(692, {121: 0.0005, 128: 0.0006, 161: 0.0006, 184: 0.0003, 203: 0.0006, 207: 0.0007, 234: 0.0006, 235: 0.0062, 238: 0.0005, 242: 0.0004, 244: 0.0078, 262: 0.0075, 263: 0.008, 264: 0.0002, 268: 0.001, 272: 0.0184, 273: 0.0258, 285: 0.001, 289: 0.0004, 290: 0.0086, 295: 0.0007, 299: 0.0015, 300: 0.0305, 301: 0.0056, 303: 0.0006, 314: 0.0004, 318: 0.0006, 319: 0.0027, 322: 0.0007, 326: 0.0018, 327: 0.0058, 328: 0.0114, 329: 0.0018, 330: 0.0144, 343: 0.0006, 346: 0.0071, 347: 0.0045, 349: 0.001, 350: 0.0155, 351: 0.0212, 352: 0.0006, 353: 0.0005, 354: 0.0011, 356: 0.0125, 357: 0.0155, 359: 0.0035, 375: 0.0006, 377: 0.0066, 378: 0.0289, 379: 0.004, 385: 0.0323, 386: 0.0062, 400: 0.0026, 405: 0.002, 406: 0.0424, 407: 0.0053, 408: 0.0027, 409: 0.0002, 427: 0.0135, 428: 0.0042, 429: 0.0047, 433: 0.0175, 434: 0.0457, 435: 0.0099, 438: 0.0012, 440: 0.0292, 442: 0.0064, 454: 0.002, 455: 0.046, 456: 0.0157, 457: 0.0011, 458: 0.0012, 461: 0.0087, 462: 0.0239, 463: 0.0096, 466: 0.0011, 4