# 随机森林

## 1.创建 SparkSession

In [2]:
from pyspark.sql import SparkSession

# Reuse an already-running SparkSession if there is one; otherwise start one
# with the application name 'random_forest'.
spark = (
    SparkSession.builder
    .appName('random_forest')
    .getOrCreate()
)

## 2.读取数据集

In [3]:
# Load the affairs dataset: the first row holds column names and Spark infers
# each column's type from the data.
df = spark.read.options(header=True, inferSchema=True).csv('data/affairs.csv')
df

DataFrame[rate_marriage: int, age: double, yrs_married: double, children: double, religious: int, affairs: int]

## 3.数据分析

In [4]:
# Dataset shape as (row count, column count); count() triggers a Spark job.
n_rows = df.count()
n_cols = len(df.columns)
print((n_rows, n_cols))

(6366, 6)


In [5]:
# Print column names, the types inferred at load time, and nullability.
df.printSchema()

root
 |-- rate_marriage: integer (nullable = true)
 |-- age: double (nullable = true)
 |-- yrs_married: double (nullable = true)
 |-- children: double (nullable = true)
 |-- religious: integer (nullable = true)
 |-- affairs: integer (nullable = true)



In [6]:
# Peek at the first five raw rows.
df.show(n=5)

+-------------+----+-----------+--------+---------+-------+
|rate_marriage| age|yrs_married|children|religious|affairs|
+-------------+----+-----------+--------+---------+-------+
|            5|32.0|        6.0|     1.0|        3|      0|
|            4|22.0|        2.5|     0.0|        2|      0|
|            3|32.0|        9.0|     3.0|        3|      1|
|            3|27.0|       13.0|     3.0|        1|      1|
|            4|22.0|        2.5|     0.0|        1|      1|
+-------------+----+-----------+--------+---------+-------+
only showing top 5 rows



In [7]:
# count / mean / stddev / min / max for the five predictor columns
# (the `affairs` label is computed by describe() but not displayed).
summary_cols = ['rate_marriage', 'age', 'yrs_married', 'children', 'religious']
df.describe().select('summary', *summary_cols).show()

+-------+------------------+------------------+-----------------+------------------+------------------+
|summary|     rate_marriage|               age|      yrs_married|          children|         religious|
+-------+------------------+------------------+-----------------+------------------+------------------+
|  count|              6366|              6366|             6366|              6366|              6366|
|   mean| 4.109644989004084|29.082862079798932| 9.00942507068803|1.3968740182218033|2.4261702796104303|
| stddev|0.9614295945655025| 6.847881883668817|7.280119972766412| 1.433470828560344|0.8783688402641785|
|    min|                 1|              17.5|              0.5|               0.0|                 1|
|    max|                 5|              42.0|             23.0|               5.5|                 4|
+-------+------------------+------------------+-----------------+------------------+------------------+



In [8]:
# Row count per value of the `affairs` label, i.e. the class balance.
(df.groupby('affairs')
   .count()
   .show())

+-------+-----+
|affairs|count|
+-------+-----+
|      1| 2053|
|      0| 4313|
+-------+-----+



In [9]:
# Row count per `rate_marriage` level.
(df.groupby('rate_marriage')
   .count()
   .show())

+-------------+-----+
|rate_marriage|count|
+-------------+-----+
|            1|   99|
|            3|  993|
|            5| 2684|
|            4| 2242|
|            2|  348|
+-------------+-----+



In [10]:
# Cross-tab of label vs. `rate_marriage`, sorted by rate_marriage then affairs
# (sort is ascending by default).
(df.groupby('affairs', 'rate_marriage')
   .count()
   .orderBy('rate_marriage', 'affairs', 'count')
   .show())

+-------+-------------+-----+
|affairs|rate_marriage|count|
+-------+-------------+-----+
|      0|            1|   25|
|      1|            1|   74|
|      0|            2|  127|
|      1|            2|  221|
|      0|            3|  446|
|      1|            3|  547|
|      0|            4| 1518|
|      1|            4|  724|
|      0|            5| 2197|
|      1|            5|  487|
+-------+-------------+-----+



In [11]:
# Cross-tab of label vs. `religious`, sorted by religious then affairs
# (sort is ascending by default).
(df.groupby('affairs', 'religious')
   .count()
   .orderBy('religious', 'affairs', 'count')
   .show())

+-------+---------+-----+
|affairs|religious|count|
+-------+---------+-----+
|      0|        1|  613|
|      1|        1|  408|
|      0|        2| 1448|
|      1|        2|  819|
|      0|        3| 1715|
|      1|        3|  707|
|      0|        4|  537|
|      1|        4|  119|
+-------+---------+-----+



In [12]:
# Cross-tab of label vs. `children`, sorted by children then affairs
# (sort is ascending by default).
(df.groupby('affairs', 'children')
   .count()
   .orderBy('children', 'affairs', 'count')
   .show())

+-------+--------+-----+
|affairs|children|count|
+-------+--------+-----+
|      0|     0.0| 1912|
|      1|     0.0|  502|
|      0|     1.0|  747|
|      1|     1.0|  412|
|      0|     2.0|  873|
|      1|     2.0|  608|
|      0|     3.0|  460|
|      1|     3.0|  321|
|      0|     4.0|  197|
|      1|     4.0|  131|
|      0|     5.5|  124|
|      1|     5.5|   79|
+-------+--------+-----+



In [13]:
# Per-label average of every numeric column (avg is an alias of mean and
# yields the same `avg(<col>)` column names).
(df.groupby('affairs')
   .avg()
   .show())

+-------+------------------+------------------+------------------+------------------+------------------+------------+
|affairs|avg(rate_marriage)|          avg(age)|  avg(yrs_married)|     avg(children)|    avg(religious)|avg(affairs)|
+-------+------------------+------------------+------------------+------------------+------------------+------------+
|      1|3.6473453482708234|30.537018996590355|11.152459814905017|1.7289332683877252| 2.261568436434486|         1.0|
|      0| 4.329700904242986| 28.39067934152562| 7.989334569904939|1.2388128912589844|2.5045212149316023|         0.0|
+-------+------------------+------------------+------------------+------------------+------------------+------------+



## 4.特征工程

In [14]:
from pyspark.ml.feature import VectorAssembler

# Pack the five numeric predictors into one vector column, `features`, which is
# the input format Spark ML estimators expect.
df_assembler = VectorAssembler(
    inputCols=['rate_marriage', 'age', 'yrs_married', 'children', 'religious'],
    outputCol='features',
)
# VectorAssembler refuses to write to an output column that already exists, so
# re-running this cell (which reassigns `df`) used to fail. Dropping `features`
# first is a no-op on the first run and keeps the cell idempotent. `df` is still
# reassigned because later cells read `df` (model_df, the ml_attr metadata lookup).
df = df_assembler.transform(df.drop('features'))
df.printSchema()

root
 |-- rate_marriage: integer (nullable = true)
 |-- age: double (nullable = true)
 |-- yrs_married: double (nullable = true)
 |-- children: double (nullable = true)
 |-- religious: integer (nullable = true)
 |-- affairs: integer (nullable = true)
 |-- features: vector (nullable = true)



In [15]:
# First 10 assembled feature vectors next to their labels, untruncated.
df.select('features', 'affairs').show(n=10, truncate=False)

+-----------------------+-------+
|features               |affairs|
+-----------------------+-------+
|[5.0,32.0,6.0,1.0,3.0] |0      |
|[4.0,22.0,2.5,0.0,2.0] |0      |
|[3.0,32.0,9.0,3.0,3.0] |1      |
|[3.0,27.0,13.0,3.0,1.0]|1      |
|[4.0,22.0,2.5,0.0,1.0] |1      |
|[4.0,37.0,16.5,4.0,3.0]|1      |
|[5.0,27.0,9.0,1.0,1.0] |1      |
|[4.0,27.0,9.0,0.0,2.0] |1      |
|[5.0,37.0,23.0,5.5,2.0]|1      |
|[5.0,37.0,23.0,5.5,2.0]|1      |
+-----------------------+-------+
only showing top 10 rows



In [16]:
# Keep only what the classifier needs: the feature vector and the label.
model_df = df.select('features', 'affairs')

## 5.划分数据集

In [17]:
# 75/25 train/test split. Without a seed PySpark draws a fresh random seed on
# every call, so the split and every metric below change on each re-run;
# fixing it makes "Restart & Run All" reproducible.
train_df, test_df = model_df.randomSplit([0.75, 0.25], seed=42)
print(train_df.count())

4758


In [18]:
# Class balance inside the training split.
(train_df.groupby('affairs')
         .count()
         .show())

+-------+-----+
|affairs|count|
+-------+-----+
|      1| 1515|
|      0| 3243|
+-------+-----+



In [19]:
# Class balance inside the held-out test split.
(test_df.groupby('affairs')
        .count()
        .show())

+-------+-----+
|affairs|count|
+-------+-----+
|      1|  538|
|      0| 1070|
+-------+-----+



## 6.训练模型

In [20]:
from pyspark.ml.classification import RandomForestClassifier

# Fit a 50-tree random forest on the training split (features column defaults
# to 'features'). An explicit seed pins the bootstrap samples and per-node feature
# subsets; the implicit default is derived from Python's per-process salted
# hash(), so results otherwise drift between kernel sessions.
# NOTE: despite the name, `rl_classifier` holds the *fitted model* returned by
# fit(); the name is kept because later cells use it.
rl_classifier = RandomForestClassifier(labelCol='affairs', numTrees=50, seed=42).fit(train_df)

## 7.测试数据评估

In [23]:
# Score the held-out split; transform() appends rawPrediction, probability and
# prediction columns.
rl_predictions = rl_classifier.transform(test_df)
rl_predictions.show(n=20, truncate=False)

+-----------------------+-------+---------------------------------------+----------------------------------------+----------+
|features               |affairs|rawPrediction                          |probability                             |prediction|
+-----------------------+-------+---------------------------------------+----------------------------------------+----------+
|[1.0,22.0,2.5,0.0,1.0] |1      |[15.703172001412614,34.29682799858739] |[0.31406344002825226,0.6859365599717477]|1.0       |
|[1.0,22.0,2.5,1.0,1.0] |1      |[15.685716111714266,34.314283888285736]|[0.3137143222342853,0.6862856777657147] |1.0       |
|[1.0,22.0,2.5,1.0,2.0] |0      |[16.178993949848483,33.82100605015152] |[0.32357987899696966,0.6764201210030304]|1.0       |
|[1.0,22.0,2.5,1.0,2.0] |0      |[16.178993949848483,33.82100605015152] |[0.32357987899696966,0.6764201210030304]|1.0       |
|[1.0,22.0,2.5,1.0,3.0] |1      |[20.12165229148593,29.87834770851407]  |[0.4024330458297186,0.5975669541702814] |1.0 

In [24]:
# How many test rows were assigned to each predicted class.
(rl_predictions.groupby('prediction')
               .count()
               .show())

+----------+-----+
|prediction|count|
+----------+-----+
|       0.0| 1294|
|       1.0|  314|
+----------+-----+



In [25]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.evaluation import BinaryClassificationEvaluator

### 准确率

In [27]:
# Share of test rows whose `prediction` equals the `affairs` label.
accuracy_evaluator = MulticlassClassificationEvaluator(labelCol='affairs', metricName='accuracy')
rl_accuracy = accuracy_evaluator.evaluate(rl_predictions)
print('{0:.0%}'.format(rl_accuracy))

71%


### 加权精确率（weightedPrecision）

In [29]:
# Per-class precision averaged with weights proportional to class frequency.
precision_evaluator = MulticlassClassificationEvaluator(labelCol='affairs', metricName='weightedPrecision')
rl_precision = precision_evaluator.evaluate(rl_predictions)
print('{0:.0%}'.format(rl_precision))

69%


### ROC 曲线下面积（AUC）

In [31]:
# Area under the ROC curve, computed from the `rawPrediction` column
# ('areaUnderROC' is the evaluator's default metric, spelled out here).
auc_evaluator = BinaryClassificationEvaluator(labelCol='affairs', metricName='areaUnderROC')
rl_auc = auc_evaluator.evaluate(rl_predictions)
print('{0:.2%}'.format(rl_auc))

72.50%


In [32]:
# Importance of each entry of the `features` vector in the fitted forest.
# Indices follow the VectorAssembler inputCols order:
# 0=rate_marriage, 1=age, 2=yrs_married, 3=children, 4=religious.
rl_classifier.featureImportances

SparseVector(5, {0: 0.5858, 1: 0.03, 2: 0.2539, 3: 0.0454, 4: 0.0849})

In [33]:
# Column metadata VectorAssembler attached to `features`: maps each vector index
# back to its source column name, which is how featureImportances is interpreted.
df.schema['features'].metadata['ml_attr']['attrs']

{'numeric': [{'idx': 0, 'name': 'rate_marriage'},
  {'idx': 1, 'name': 'age'},
  {'idx': 2, 'name': 'yrs_married'},
  {'idx': 3, 'name': 'children'},
  {'idx': 4, 'name': 'religious'}]}

## 8.保存并重新加载模型

In [34]:
from pyspark.ml.classification import RandomForestClassificationModel

# Persist the fitted forest. Plain save() raises if 'data/rf_model' already
# exists, which breaks every re-run of the notebook; overwrite() replaces the
# previous copy instead.
rl_classifier.write().overwrite().save('data/rf_model')

In [38]:
# Reload the forest written to 'data/rf_model' and confirm it can still score data.
# NOTE(review): this scores `train_df` (in-sample), so its predictions say nothing
# about generalization; `test_df` may have been intended — confirm.
# `new_preditions` is a typo of "new_predictions", kept because the next cell uses it.
rl = RandomForestClassificationModel.load('data/rf_model')
new_preditions = rl.transform(train_df)

In [39]:
# First 20 rows scored by the reloaded model, vectors shown untruncated.
new_preditions.show(n=20, truncate=False)

+----------------------+-------+---------------------------------------+----------------------------------------+----------+
|features              |affairs|rawPrediction                          |probability                             |prediction|
+----------------------+-------+---------------------------------------+----------------------------------------+----------+
|[1.0,17.5,0.5,0.0,2.0]|0      |[42.72903564356653,7.270964356433465]  |[0.8545807128713306,0.1454192871286693] |0.0       |
|[1.0,22.0,2.5,0.0,1.0]|1      |[15.703172001412614,34.29682799858739] |[0.31406344002825226,0.6859365599717477]|1.0       |
|[1.0,22.0,2.5,0.0,2.0]|1      |[16.324302953793563,33.67569704620644] |[0.32648605907587125,0.6735139409241288]|1.0       |
|[1.0,22.0,2.5,1.0,1.0]|1      |[15.685716111714266,34.314283888285736]|[0.3137143222342853,0.6862856777657147] |1.0       |
|[1.0,22.0,2.5,1.0,2.0]|1      |[16.178993949848483,33.82100605015152] |[0.32357987899696966,0.6764201210030304]|1.0       |
