In [1]:
import findspark

In [2]:
findspark.init()

In [3]:
import pyspark

In [4]:
from pyspark.sql import SparkSession

In [5]:
spark = SparkSession.builder.appName("RandomForest").getOrCreate()

In [65]:
df = spark.read.csv(r'rf.csv',inferSchema=True,header=True)

In [14]:
df.count(),len(df.columns)

(6366, 6)

In [15]:
df.show(5)

+-------------+----+-----------+--------+---------+-------+
|rate_marriage| age|yrs_married|children|religious|affairs|
+-------------+----+-----------+--------+---------+-------+
|            5|32.0|        6.0|     1.0|        3|      0|
|            4|22.0|        2.5|     0.0|        2|      0|
|            3|32.0|        9.0|     3.0|        3|      1|
|            3|27.0|       13.0|     3.0|        1|      1|
|            4|22.0|        2.5|     0.0|        1|      1|
+-------------+----+-----------+--------+---------+-------+
only showing top 5 rows



In [31]:
df.groupBy('affairs').count().show()

+-------+-----+
|affairs|count|
+-------+-----+
|      1| 2053|
|      0| 4313|
+-------+-----+



In [48]:
df.filter(df['affairs'] == 1).show(5)

+-------------+----+-----------+--------+---------+-------+
|rate_marriage| age|yrs_married|children|religious|affairs|
+-------------+----+-----------+--------+---------+-------+
|            3|32.0|        9.0|     3.0|        3|      1|
|            3|27.0|       13.0|     3.0|        1|      1|
|            4|22.0|        2.5|     0.0|        1|      1|
|            4|37.0|       16.5|     4.0|        3|      1|
|            5|27.0|        9.0|     1.0|        1|      1|
+-------------+----+-----------+--------+---------+-------+
only showing top 5 rows



In [49]:
df.printSchema()

root
 |-- rate_marriage: integer (nullable = true)
 |-- age: double (nullable = true)
 |-- yrs_married: double (nullable = true)
 |-- children: double (nullable = true)
 |-- religious: integer (nullable = true)
 |-- affairs: integer (nullable = true)



In [50]:
from pyspark.ml.feature import VectorAssembler

In [66]:
vec_asm = VectorAssembler(inputCols=['rate_marriage', 'age', 'yrs_married', 'children', 'religious'], outputCol='features')

In [70]:
df = vec_asm.transform(df)

In [71]:
df.show(5)

+-------------+----+-----------+--------+---------+-------+--------------------+
|rate_marriage| age|yrs_married|children|religious|affairs|            features|
+-------------+----+-----------+--------+---------+-------+--------------------+
|            5|32.0|        6.0|     1.0|        3|      0|[5.0,32.0,6.0,1.0...|
|            4|22.0|        2.5|     0.0|        2|      0|[4.0,22.0,2.5,0.0...|
|            3|32.0|        9.0|     3.0|        3|      1|[3.0,32.0,9.0,3.0...|
|            3|27.0|       13.0|     3.0|        1|      1|[3.0,27.0,13.0,3....|
|            4|22.0|        2.5|     0.0|        1|      1|[4.0,22.0,2.5,0.0...|
+-------------+----+-----------+--------+---------+-------+--------------------+
only showing top 5 rows



In [75]:
df = df.select('features', 'affairs')

In [93]:
df.select(df['features'])

DataFrame[features: vector]

In [98]:
rownums = [2,5]

In [106]:
df.rdd.zipWithIndex().filter(lambda x: x[1] in rownums).map(lambda x: x[0]).collect()

[Row(features=DenseVector([3.0, 32.0, 9.0, 3.0, 3.0]), affairs=1),
 Row(features=DenseVector([4.0, 37.0, 16.5, 4.0, 3.0]), affairs=1)]

In [109]:
train_df, test_df = df.randomSplit([0.8,0.2])

In [115]:
train_df.groupBy('affairs').count().show()

+-------+-----+
|affairs|count|
+-------+-----+
|      1| 1602|
|      0| 3460|
+-------+-----+



In [116]:
from pyspark.ml.classification import RandomForestClassifier

In [117]:
rf = RandomForestClassifier(labelCol='affairs', numTrees=50).fit(train_df)

In [121]:
preds = rf.transform(test_df)

In [123]:
preds.show(5)

+--------------------+-------+--------------------+--------------------+----------+
|            features|affairs|       rawPrediction|         probability|prediction|
+--------------------+-------+--------------------+--------------------+----------+
|[1.0,22.0,2.5,0.0...|      1|[18.7304017666513...|[0.37460803533302...|       1.0|
|[1.0,22.0,2.5,0.0...|      1|[21.6518119199105...|[0.43303623839821...|       1.0|
|[1.0,22.0,2.5,1.0...|      0|[22.6488866777797...|[0.45297773355559...|       1.0|
|[1.0,22.0,2.5,1.0...|      1|[22.6488866777797...|[0.45297773355559...|       1.0|
|[1.0,27.0,2.5,0.0...|      0|[18.1714366212769...|[0.36342873242553...|       1.0|
+--------------------+-------+--------------------+--------------------+----------+
only showing top 5 rows



In [124]:
preds = preds.select('affairs','prediction')

In [127]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator, BinaryClassificationEvaluator

In [128]:
rf_acc = MulticlassClassificationEvaluator(labelCol='affairs', metricName='accuracy').evaluate(preds)

In [129]:
rf_acc

0.7032208588957055

In [131]:
BinaryClassificationEvaluator(labelCol='affairs').evaluate(rf.transform(test_df))

0.7371179845231256

In [140]:
rf.featureImportances

SparseVector(5, {0: 0.6571, 1: 0.0223, 2: 0.2316, 3: 0.0308, 4: 0.0583})

In [141]:
rf.save('rfmodel')

In [142]:
from pyspark.ml.classification import RandomForestClassificationModel

In [143]:
rf = RandomForestClassificationModel.load('rfmodel')

In [144]:
rf

RandomForestClassificationModel: uid=RandomForestClassifier_008eee78e262, numTrees=50, numClasses=2, numFeatures=5