In [1]:
import findspark
findspark.init('/home/siddharth/spark-2.4.1-bin-hadoop2.7/')

from pyspark.sql import SparkSession

In [2]:
spark = SparkSession.builder.appName('trees').getOrCreate()

In [3]:
data = spark.read.format('libsvm').load('./Python-and-Spark-for-Big-Data-master/Spark_for_Machine_Learning/Tree_Methods/sample_libsvm_data.txt')

In [4]:
data.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [5]:
from pyspark.ml.classification import RandomForestClassifier,DecisionTreeClassifier,GBTClassifier
from pyspark.ml import Pipeline

In [6]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator,MulticlassClassificationEvaluator

In [7]:
train_data,test_data = data.randomSplit([0.7,0.3])

In [8]:
dt = DecisionTreeClassifier()
rf = RandomForestClassifier()
gb = GBTClassifier()

In [9]:
dt_model = dt.fit(train_data)
rf_model = rf.fit(train_data)
gb_model = gb.fit(train_data)

In [10]:
multi_eval = MulticlassClassificationEvaluator(metricName='accuracy')

In [11]:
dt_pred = dt_model.transform(test_data)
rf_pred = rf_model.transform(test_data)
gb_pred = gb_model.transform(test_data)

In [12]:
dt_pred.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[95,96,97,12...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[100,101,102...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[127,128,129...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[151,152,153...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[152,153,154...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[234,235,237...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  1.0|(692,[99,100,101,...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  1.0|(692,[119,120,121...|   [0.0,38.0]|  [0.0,1.0]|       1.0|
|  1.0|(69

In [13]:
rf_pred.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[95,96,97,12...|   [19.0,1.0]|[0.95,0.05]|       0.0|
|  0.0|(692,[100,101,102...|   [13.0,7.0]|[0.65,0.35]|       0.0|
|  0.0|(692,[123,124,125...|   [20.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [20.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [19.0,1.0]|[0.95,0.05]|       0.0|
|  0.0|(692,[126,127,128...|   [19.0,1.0]|[0.95,0.05]|       0.0|
|  0.0|(692,[127,128,129...|   [20.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[151,152,153...|   [20.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[152,153,154...|   [19.0,1.0]|[0.95,0.05]|       0.0|
|  0.0|(692,[234,235,237...|   [16.0,4.0]|  [0.8,0.2]|       0.0|
|  1.0|(692,[99,100,101,...|   [9.0,11.0]|[0.45,0.55]|       1.0|
|  1.0|(692,[119,120,121...|   [5.0,15.0]|[0.25,0.75]|       1.0|
|  1.0|(69

In [14]:
gb_pred.show()

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[95,96,97,12...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[100,101,102...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[123,124,125...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[127,128,129...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[151,152,153...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[152,153,154...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[234

In [15]:
acc_dt = multi_eval.evaluate(dt_pred)
acc_dt

0.9655172413793104

In [16]:
acc_rf = multi_eval.evaluate(rf_pred)
acc_rf

1.0

In [17]:
acc_gb = multi_eval.evaluate(gb_pred)
acc_gb

0.9655172413793104