In [1]:
#import SparkSession
from pyspark.sql import SparkSession
# Build (or reuse) the SparkSession for this notebook under the app name 'decision_tree'.
spark = (
    SparkSession.builder
    .appName('decision_tree')
    .getOrCreate()
)

In [8]:
# Load the CSV, treating the first line as column names and letting Spark infer column types.
csv_reader = spark.read.option('header', True).option('inferSchema', True)
df = csv_reader.csv('affairs.csv')

In [9]:
# Report the dataset shape as a (rows, columns) tuple.
n_rows = df.count()
n_cols = len(df.columns)
print((n_rows, n_cols))

(6366, 6)


In [10]:
# Print the schema of the loaded DataFrame: column names, inferred types and nullability.
df.printSchema()

root
 |-- rate_marriage: integer (nullable = true)
 |-- age: double (nullable = true)
 |-- yrs_married: double (nullable = true)
 |-- children: double (nullable = true)
 |-- religious: integer (nullable = true)
 |-- affairs: integer (nullable = true)



In [12]:
# Peek at the first five rows of the dataset.
df.show(n=5)

+-------------+----+-----------+--------+---------+-------+
|rate_marriage| age|yrs_married|children|religious|affairs|
+-------------+----+-----------+--------+---------+-------+
|            5|32.0|        6.0|     1.0|        3|      0|
|            4|22.0|        2.5|     0.0|        2|      0|
|            3|32.0|        9.0|     3.0|        3|      1|
|            3|27.0|       13.0|     3.0|        1|      1|
|            4|22.0|        2.5|     0.0|        1|      1|
+-------------+----+-----------+--------+---------+-------+
only showing top 5 rows



In [13]:
# Exploratory data analysis: summary statistics (count, mean, stddev, min, max) per column.
summary_stats = df.describe()
summary_stats.show()

+-------+------------------+------------------+-----------------+------------------+------------------+------------------+
|summary|     rate_marriage|               age|      yrs_married|          children|         religious|           affairs|
+-------+------------------+------------------+-----------------+------------------+------------------+------------------+
|  count|              6366|              6366|             6366|              6366|              6366|              6366|
|   mean| 4.109644989004084|29.082862079798932| 9.00942507068803|1.3968740182218033|2.4261702796104303|0.3224945020420987|
| stddev|0.9614295945655025| 6.847881883668817|7.280119972766412| 1.433470828560344|0.8783688402641785| 0.467467779921086|
|    min|                 1|              17.5|              0.5|               0.0|                 1|                 0|
|    max|                 5|              42.0|             23.0|               5.5|                 4|                 1|
+-------+-------

In [14]:
# Class balance of the target column `affairs`.
affairs_counts = df.groupBy('affairs').count()
affairs_counts.show()

+-------+-----+
|affairs|count|
+-------+-----+
|      1| 2053|
|      0| 4313|
+-------+-----+



In [17]:
# Number of records for each value of `rate_marriage`.
rate_marriage_counts = df.groupBy('rate_marriage').count()
rate_marriage_counts.show()

+-------------+-----+
|rate_marriage|count|
+-------------+-----+
|            1|   99|
|            3|  993|
|            5| 2684|
|            4| 2242|
|            2|  348|
+-------------+-----+



In [19]:
from pyspark.ml.feature import VectorAssembler

In [21]:
# Pack the five predictor columns into a single vector column named "features".
feature_cols = ['rate_marriage', 'age', 'yrs_married', 'children', 'religious']
df_assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
# Drop any "features" column left over from an earlier run of this cell before transforming.
# DataFrame.drop is a no-op when the column is absent, so the first run is unchanged, while
# a re-run no longer fails because the assembler's output column already exists.
df = df_assembler.transform(df.drop("features"))

In [22]:
# Print the schema again to confirm the assembled `features` vector column was added.
df.printSchema()

root
 |-- rate_marriage: integer (nullable = true)
 |-- age: double (nullable = true)
 |-- yrs_married: double (nullable = true)
 |-- children: double (nullable = true)
 |-- religious: integer (nullable = true)
 |-- affairs: integer (nullable = true)
 |-- features: vector (nullable = true)



In [23]:
# Show the assembled feature vector next to the label for the first ten rows, untruncated.
df.select('features', 'affairs').show(n=10, truncate=False)

+-----------------------+-------+
|features               |affairs|
+-----------------------+-------+
|[5.0,32.0,6.0,1.0,3.0] |0      |
|[4.0,22.0,2.5,0.0,2.0] |0      |
|[3.0,32.0,9.0,3.0,3.0] |1      |
|[3.0,27.0,13.0,3.0,1.0]|1      |
|[4.0,22.0,2.5,0.0,1.0] |1      |
|[4.0,37.0,16.5,4.0,3.0]|1      |
|[5.0,27.0,9.0,1.0,1.0] |1      |
|[4.0,27.0,9.0,0.0,2.0] |1      |
|[5.0,37.0,23.0,5.5,2.0]|1      |
|[5.0,37.0,23.0,5.5,2.0]|1      |
+-----------------------+-------+
only showing top 10 rows



In [24]:
# Keep only the columns the model needs: the feature vector and the label.
model_df = df.select('features', 'affairs')

In [25]:
from pyspark.ml.classification import DecisionTreeClassifier

In [26]:
# Split the rows with weights 0.75 / 0.25 into training and test sets. randomSplit samples
# rows at random, so a fixed seed is passed to make the split (and the metrics computed
# from it) repeatable across runs.
train_df, test_df = model_df.randomSplit([0.75, 0.25], seed=42)

In [27]:
# Number of rows that landed in the training split.
train_df.count()

4775

In [29]:
# Class balance of the training split.
train_class_counts = train_df.groupBy('affairs').count()
train_class_counts.show()

+-------+-----+
|affairs|count|
+-------+-----+
|      1| 1547|
|      0| 3228|
+-------+-----+



In [30]:
# Class balance of the test split.
test_class_counts = test_df.groupBy('affairs').count()
test_class_counts.show()

+-------+-----+
|affairs|count|
+-------+-----+
|      1|  506|
|      0| 1085|
+-------+-----+



In [32]:
# Configure a decision tree that predicts the `affairs` label, then fit it on the training split.
# Note that `dt_classifier` ends up holding the fitted model, not the estimator.
dt_estimator = DecisionTreeClassifier(labelCol='affairs')
dt_classifier = dt_estimator.fit(train_df)

In [33]:
# Display the fitted model's summary representation.
dt_classifier

DecisionTreeClassificationModel (uid=DecisionTreeClassifier_4e048cf4b473ad266b10) of depth 5 with 61 nodes

In [34]:
# Score the test split with the fitted decision tree. The result keeps the input columns
# and adds the rawPrediction, probability and prediction columns.
predictions=dt_classifier.transform(test_df)

In [36]:
# Inspect the first ten scored test rows without truncating the vector columns.
predictions.show(n=10, truncate=False)

+-----------------------+-------+-------------+----------------------------------------+----------+
|features               |affairs|rawPrediction|probability                             |prediction|
+-----------------------+-------+-------------+----------------------------------------+----------+
|[1.0,22.0,2.5,1.0,1.0] |1      |[13.0,32.0]  |[0.28888888888888886,0.7111111111111111]|1.0       |
|[1.0,22.0,2.5,1.0,1.0] |1      |[13.0,32.0]  |[0.28888888888888886,0.7111111111111111]|1.0       |
|[1.0,22.0,2.5,1.0,3.0] |1      |[18.0,13.0]  |[0.5806451612903226,0.41935483870967744]|0.0       |
|[1.0,27.0,2.5,0.0,2.0] |1      |[13.0,32.0]  |[0.28888888888888886,0.7111111111111111]|1.0       |
|[1.0,27.0,6.0,1.0,1.0] |1      |[22.0,37.0]  |[0.3728813559322034,0.6271186440677966] |1.0       |
|[1.0,27.0,6.0,1.0,2.0] |0      |[22.0,37.0]  |[0.3728813559322034,0.6271186440677966] |1.0       |
|[1.0,27.0,6.0,1.0,3.0] |1      |[22.0,37.0]  |[0.3728813559322034,0.6271186440677966] |1.0       |


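1. features : The input feature vector of each record (the same columns the model was trained on)
2. affairs : The true label of each record
3. rawPrediction : A measure of confidence in each possible label (larger = more confident); for a decision tree these are the counts of training records of each class in the leaf node the record falls into
4. probability : The conditional probability of each class, obtained by normalizing rawPrediction so that it sums to 1 (e.g. [13.0,32.0] becomes [0.289,0.711])
5. prediction : The predicted label, i.e. the class with the largest rawPrediction/probability (the statistical mode of the class counts in the leaf)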

In [37]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [39]:
# Accuracy of the decision tree's predictions on the test split.
accuracy_evaluator = MulticlassClassificationEvaluator(
    labelCol='affairs',
    metricName='accuracy',
)
accuracy = accuracy_evaluator.evaluate(predictions)

In [40]:
# Display the decision tree's test-set accuracy.
accuracy

0.7108736643620365

In [46]:
accuracy=MulticlassClassificationEvaluator(labelCol='affairs',metricName='f1').evaluate(predictions)

In [47]:
accuracy

0.6947861869441101

In [50]:
# Weighted precision of the decision tree's predictions on the test split.
precision_evaluator = MulticlassClassificationEvaluator(
    labelCol='affairs',
    metricName='weightedPrecision',
)
precision = precision_evaluator.evaluate(predictions)

In [51]:
# Display the decision tree's weighted precision on the test split.
precision

0.692945539246159

In [52]:
# Weighted recall of the decision tree's predictions on the test split.
recall_evaluator = MulticlassClassificationEvaluator(
    labelCol='affairs',
    metricName='weightedRecall',
)
recall = recall_evaluator.evaluate(predictions)

In [53]:
# Display the decision tree's weighted recall on the test split.
recall

0.7108736643620365

In [75]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

In [73]:
# Binary-classification score for the decision tree on the test split. No metricName is
# passed, so the evaluator's default metric applies (presumably area under the ROC curve,
# as the variable name suggests; confirm against the Spark docs).
auc_evaluator = BinaryClassificationEvaluator(labelCol='affairs')
auc = auc_evaluator.evaluate(predictions)

In [74]:
# Display the decision tree's binary-evaluator score on the test split.
auc

0.6016365822116173

In [54]:
## Random Forest

In [55]:
from pyspark.ml.classification import RandomForestClassifier

In [59]:
# Fit a random forest (default number of trees) on the training split. Forest training is
# randomised, so a fixed seed is supplied to make the fitted model and its metrics
# repeatable across runs.
rf_classifier = RandomForestClassifier(labelCol='affairs', seed=42).fit(train_df)

In [60]:
# Score the test split with the fitted random forest.
rf_predictions=rf_classifier.transform(test_df)

In [62]:
# Compare predicted class probabilities and labels with the true label for ten test rows.
rf_preview = rf_predictions.select('probability', 'affairs', 'prediction')
rf_preview.show(n=10, truncate=False)

+----------------------------------------+-------+----------+
|probability                             |affairs|prediction|
+----------------------------------------+-------+----------+
|[0.3752227609655098,0.6247772390344901] |1      |1.0       |
|[0.3752227609655098,0.6247772390344901] |1      |1.0       |
|[0.48105740260018903,0.5189425973998109]|1      |1.0       |
|[0.4081563861358649,0.5918436138641351] |1      |1.0       |
|[0.34565995949820005,0.6543400405018]   |1      |1.0       |
|[0.38992422528637677,0.6100757747136232]|0      |1.0       |
|[0.4004383825310048,0.5995616174689953] |1      |1.0       |
|[0.3436447664079562,0.6563552335920437] |1      |1.0       |
|[0.2852497569325212,0.7147502430674788] |1      |1.0       |
|[0.2852497569325212,0.7147502430674788] |1      |1.0       |
+----------------------------------------+-------+----------+
only showing top 10 rows



In [63]:
# Accuracy of the random forest's predictions on the test split.
rf_accuracy_evaluator = MulticlassClassificationEvaluator(
    labelCol='affairs',
    metricName='accuracy',
)
rf_accuracy = rf_accuracy_evaluator.evaluate(rf_predictions)

In [64]:
# Display the random forest's test-set accuracy.
rf_accuracy

0.7165304839723444

In [65]:
# Weighted precision of the random forest's predictions on the test split.
rf_precision_evaluator = MulticlassClassificationEvaluator(
    labelCol='affairs',
    metricName='weightedPrecision',
)
rf_precision = rf_precision_evaluator.evaluate(rf_predictions)

In [66]:
# Display the random forest's weighted precision on the test split.
rf_precision

0.69605938625242

In [69]:
#AUC


In [72]:
# Binary-classification score for the random forest on the test split. No metricName is
# passed, so the evaluator's default metric applies (presumably area under the ROC curve,
# as the variable name suggests; confirm against the Spark docs).
rf_auc_evaluator = BinaryClassificationEvaluator(labelCol='affairs')
rf_auc = rf_auc_evaluator.evaluate(rf_predictions)

In [76]:
# Display the random forest's binary-evaluator score on the test split.
rf_auc

0.7464782062257519

In [77]:
# Feature importance

In [83]:
# Importance of each feature in the fitted random forest, keyed by the feature's position
# in the assembled `features` vector.
rf_classifier.featureImportances

SparseVector(5, {0: 0.5704, 1: 0.0388, 2: 0.2259, 3: 0.0512, 4: 0.1138})

In [91]:
# Look up the vector column's ML attribute metadata, which maps each vector index back
# to the name of the source column.
features_field = df.schema["features"]
features_field.metadata["ml_attr"]["attrs"]

{'numeric': [{'idx': 0, 'name': 'rate_marriage'},
  {'idx': 1, 'name': 'age'},
  {'idx': 2, 'name': 'yrs_married'},
  {'idx': 3, 'name': 'children'},
  {'idx': 4, 'name': 'religious'}]}

In [96]:
# Refit the random forest with 500 trees. A fixed seed is supplied so that any change in
# the metrics reflects the larger forest rather than run-to-run randomness in training.
rf_classifier = RandomForestClassifier(labelCol='affairs', numTrees=500, seed=42).fit(train_df)

In [97]:
# Re-score the test split with the current `rf_classifier`, replacing the earlier `rf_predictions`.
rf_predictions=rf_classifier.transform(test_df)

In [98]:
# Accuracy of the current random forest predictions on the test split.
rf_accuracy_evaluator = MulticlassClassificationEvaluator(
    labelCol='affairs',
    metricName='accuracy',
)
rf_accuracy = rf_accuracy_evaluator.evaluate(rf_predictions)

In [99]:
# Display the accuracy of the current random forest predictions.
rf_accuracy

0.7184160905091138

In [100]:
# Weighted precision of the current random forest predictions on the test split.
rf_precision_evaluator = MulticlassClassificationEvaluator(
    labelCol='affairs',
    metricName='weightedPrecision',
)
rf_precision = rf_precision_evaluator.evaluate(rf_predictions)

In [101]:
# Display the weighted precision of the current random forest predictions.
rf_precision

0.6985822498970077

In [103]:
# Look up the vector column's ML attribute metadata, which maps each vector index back
# to the name of the source column.
features_field = df.schema["features"]
features_field.metadata["ml_attr"]["attrs"]

{'numeric': [{'idx': 0, 'name': 'rate_marriage'},
  {'idx': 1, 'name': 'age'},
  {'idx': 2, 'name': 'yrs_married'},
  {'idx': 3, 'name': 'children'},
  {'idx': 4, 'name': 'religious'}]}

In [102]:
# Importance of each feature in the current random forest, keyed by the feature's position
# in the assembled `features` vector.
rf_classifier.featureImportances

SparseVector(5, {0: 0.565, 1: 0.0352, 2: 0.2347, 3: 0.0576, 4: 0.1075})