In [1]:
import findspark
findspark.init('/home/shashank/spark-2.3.2-bin-hadoop2.7')

In [2]:
import pyspark
from pyspark.sql import SparkSession

In [3]:
spark = SparkSession.builder.appName('tree').getOrCreate()

In [4]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier, DecisionTreeClassifier

In [5]:
data = spark.read.format('libsvm').load('sample_libsvm_data.txt')

In [6]:
data.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [7]:
train, test = data.randomSplit([0.7,0.3])

In [8]:
#Decision Tree
dtc = DecisionTreeClassifier(featuresCol='features', labelCol='label', maxBins=20, maxDepth=3)
rfc = RandomForestClassifier(featuresCol='features', labelCol='label', maxBins=100, maxDepth=10, minInfoGain=0.01, numTrees=80)
gbc = GBTClassifier(featuresCol='features', labelCol='label', maxBins= 100, maxDepth=10, minInfoGain=0.01, maxIter=40)

In [9]:
dtc_model = dtc.fit(train)
rfc_model = rfc.fit(train)
gbc_model = gbc.fit(train)

In [10]:
dtc_pred = dtc_model.transform(test)
rfc_pred = rfc_model.transform(test)
gbc_pred = gbc_model.transform(test)

In [11]:
dtc_pred.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[100,101,102...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[128,129,130...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[154,155,156...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[181,182,183...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  1.0|(692,[97,98,99,12...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  1.0|(692,[123,124,125...|   [0.0,40.0]|  [0.0,1.0]|       1.0|
|  1.0|(692,[124,125,126...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  1.0|(69

In [12]:
rfc_pred.show()

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[100,101,102...|[50.9722222222222...|[0.63715277777777...|       0.0|
|  0.0|(692,[123,124,125...|[78.9722222222222...|[0.98715277777777...|       0.0|
|  0.0|(692,[124,125,126...|[79.9722222222222...|[0.99965277777777...|       0.0|
|  0.0|(692,[124,125,126...|[75.9722222222222...|[0.94965277777777...|       0.0|
|  0.0|(692,[126,127,128...|          [78.0,2.0]|       [0.975,0.025]|       0.0|
|  0.0|(692,[126,127,128...|[78.9722222222222...|[0.98715277777777...|       0.0|
|  0.0|(692,[128,129,130...|[79.9722222222222...|[0.99965277777777...|       0.0|
|  0.0|(692,[154,155,156...|[56.9722222222222...|[0.71215277777777...|       0.0|
|  0.0|(692,[181,182,183...|[73.9722222222222...|[0.92465277777777...|       0.0|
|  1.0|(692,[97,

In [13]:
gbc_pred.show()

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[100,101,102...|[1.53832761151627...|[0.95591945897249...|       0.0|
|  0.0|(692,[123,124,125...|[1.80732231417309...|[0.97377953278270...|       0.0|
|  0.0|(692,[124,125,126...|[1.80732231417309...|[0.97377953278270...|       0.0|
|  0.0|(692,[124,125,126...|[1.80732231417309...|[0.97377953278270...|       0.0|
|  0.0|(692,[126,127,128...|[1.80732231417309...|[0.97377953278270...|       0.0|
|  0.0|(692,[126,127,128...|[1.80732231417309...|[0.97377953278270...|       0.0|
|  0.0|(692,[128,129,130...|[1.80732231417309...|[0.97377953278270...|       0.0|
|  0.0|(692,[154,155,156...|[1.19399142778609...|[0.91590633404825...|       0.0|
|  0.0|(692,[181,182,183...|[1.80732231417309...|[0.97377953278270...|       0.0|
|  1.0|(692,[97,

In [15]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator

In [17]:
acc_eval = MulticlassClassificationEvaluator(predictionCol='prediction', labelCol='label', metricName='accuracy')

In [19]:
dtc_eval = acc_eval.evaluate(dtc_pred)
rfc_eval = acc_eval.evaluate(rfc_pred)
gbc_eval = acc_eval.evaluate(gbc_pred)

In [21]:
dtc_eval

0.88

In [22]:
rfc_eval

1.0

In [23]:
gbc_eval

0.88

In [24]:
rfc_model.featureImportances

SparseVector(692, {100: 0.0016, 154: 0.0003, 174: 0.0007, 183: 0.0006, 185: 0.0002, 187: 0.0004, 206: 0.0019, 207: 0.0008, 208: 0.0004, 209: 0.0008, 212: 0.0021, 215: 0.0072, 232: 0.003, 238: 0.0012, 240: 0.0014, 243: 0.0006, 244: 0.0072, 261: 0.0086, 262: 0.0204, 263: 0.0158, 264: 0.0012, 273: 0.0091, 285: 0.0007, 293: 0.0005, 298: 0.0012, 300: 0.0099, 301: 0.0099, 302: 0.0045, 316: 0.0006, 328: 0.0039, 330: 0.0172, 331: 0.0004, 341: 0.0029, 346: 0.0056, 348: 0.0019, 349: 0.0006, 350: 0.004, 351: 0.0052, 352: 0.0012, 356: 0.0209, 357: 0.0083, 358: 0.0093, 360: 0.0018, 372: 0.0083, 374: 0.003, 378: 0.0506, 379: 0.0378, 384: 0.0141, 387: 0.0028, 398: 0.0012, 399: 0.0088, 400: 0.011, 401: 0.0008, 403: 0.0011, 405: 0.0365, 406: 0.0418, 407: 0.0437, 408: 0.0007, 425: 0.0005, 426: 0.004, 427: 0.0185, 428: 0.0104, 429: 0.0007, 433: 0.0708, 434: 0.012, 453: 0.0004, 454: 0.0045, 455: 0.0327, 456: 0.0094, 460: 0.0091, 461: 0.0225, 462: 0.0732, 463: 0.0126, 465: 0.0004, 468: 0.0045, 469: 0.0183,