# Decision Tree Example

This is just a quick walkthrough of the Documentation's Example of Random Forest:

In [12]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier, DecisionTreeClassifier, GBTClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [7]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('tree').getOrCreate()

In [8]:
# Load and parse the data file, converting it to a DataFrame.
data = spark.read.format("libsvm").load("./data/sample_libsvm_data.txt")


In [9]:
data.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [10]:
# Dividir los datos en train_data y test_data (70% y 30%) 
train_data,test_data = data.randomSplit([0.7,0.3])

In [15]:
#Crear los modelos

dtc = DecisionTreeClassifier()
rfc = RandomForestClassifier(numTrees= 100)
gbt = GBTClassifier()

In [16]:
# Entrenar los modelos con los datos de entrenamiento
dtc_model = dtc.fit(train_data)
rfc_model = rfc.fit(train_data)
gbt_model = gbt.fit(train_data)

In [17]:
# Hacer las predicciones
dtc_preds = dtc_model.transform(test_data)
rfc_preds = rfc_model.transform(test_data)
gbt_preds = gbt_model.transform(test_data)

In [19]:
dtc_preds.printSchema()

root
 |-- label: double (nullable = true)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)



In [20]:
dtc_preds.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[98,99,100,1...|    [0.0,1.0]|  [0.0,1.0]|       1.0|
|  0.0|(692,[124,125,126...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[127,128,129...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[151,152,153...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[153,154,155...|   [34.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[154,155,156...|   [0.0,39.0]|  [0.0,1.0]|       1.0|
|  1.0|(692,[97,98,99,12...|   [0.0,39.0]|  [0.0,1.0]|       1.0|
|  1.0|(692,[123,124,125...|   [0.0,39.0]|  [0.0,1.0]|       1.0|
|  1.0|(692,[123,124,125...|   [0.0,39.0]|  [0.0,1.0]|       1.0|
|  1.0|(69

In [21]:
rfc_preds.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[98,99,100,1...|   [98.0,2.0]|[0.98,0.02]|       0.0|
|  0.0|(692,[124,125,126...|   [98.0,2.0]|[0.98,0.02]|       0.0|
|  0.0|(692,[126,127,128...|   [96.0,4.0]|[0.96,0.04]|       0.0|
|  0.0|(692,[126,127,128...|   [98.0,2.0]|[0.98,0.02]|       0.0|
|  0.0|(692,[126,127,128...|   [99.0,1.0]|[0.99,0.01]|       0.0|
|  0.0|(692,[127,128,129...|   [93.0,7.0]|[0.93,0.07]|       0.0|
|  0.0|(692,[151,152,153...|   [96.0,4.0]|[0.96,0.04]|       0.0|
|  0.0|(692,[153,154,155...|   [93.0,7.0]|[0.93,0.07]|       0.0|
|  0.0|(692,[154,155,156...|  [70.0,30.0]|  [0.7,0.3]|       0.0|
|  1.0|(692,[97,98,99,12...|  [18.0,82.0]|[0.18,0.82]|       1.0|
|  1.0|(692,[123,124,125...|   [5.0,95.0]|[0.05,0.95]|       1.0|
|  1.0|(692,[123,124,125...|  [0.0,100.0]|  [0.0,1.0]|       1.0|
|  1.0|(69

In [22]:
gbt_preds.show()

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[98,99,100,1...|[-0.6467328511439...|[0.21526678244352...|       1.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.33683157264700...|[0.93545456275330...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[127,128,129...|[1.27088160333594...|[0.92701820697318...|       0.0|
|  0.0|(692,[151,152,153...|[1.20458335876362...|[0.91752364598482...|       0.0|
|  0.0|(692,[153,154,155...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[154,155,156...|[-1.2548359576286...|[0.07518291818174...|       1.0|
|  1.0|(692,[97,

In [23]:
#Selecionar (prediction, true label) y ejecutar con test
acc_eval = MulticlassClassificationEvaluator(metricName="accuracy")

In [24]:
print('DTC ACCURACY:')
acc_eval.evaluate(dtc_preds)

0.9230769230769231

In [26]:
print('RFC ACCURACY:')
acc_eval.evaluate(rfc_preds)

RFC ACCURACY:


1.0

In [27]:
print('GBT ACCURACY:')
acc_eval.evaluate(gbt_preds)

GBT ACCURACY:


0.9230769230769231

In [31]:
#rfc_model.featureImportances