---
title: "9_3_DecisionTree_RF"
output: html_document
editor_options:
  chunk_output_type: inline
---
## Decision Tree
* Decision tree: on its own, a decision tree usually performs worse than regression and other classification approaches.
* Bootstrapping: instead of training one tree once, repeatedly draw random resamples of the data and build a tree on each, so every tree comes out a little different and classifies differently. For example, if 60 trees predict Survived and 40 predict Died, the ensemble reports a 60% survival rate: the single tree's 0/1 output becomes a probability. This bootstrap-and-aggregate procedure is called bagging (see the sketch after this list).
* Random forest: when each tree is built on a bootstrap resample, the greedy variable selection tends to fall into the same local optimum every time, so the trees all look alike. A random forest therefore randomly withholds part of the variables at each split, which makes the resulting trees much more diverse.
* Boosting: the ensemble is built in rounds, reweighting as it goes. Whatever was chosen that worked especially well in the previous round gets a higher weight, and whatever worked especially poorly gets a lower weight.
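To make bagging concrete, here is a minimal hand-rolled sketch. The names `n_trees`, `boot_idx`, and `bagged_prob`, and the choice of 100 trees, are illustrative assumptions, not part of the original lesson:
```{r}
# A hand-rolled bagging sketch: many trees, each grown on a bootstrap
# resample, vote on every passenger's class.
library(tree)
library(titanic)

dat <- na.omit(titanic_train[, c("Survived", "Age", "Sex")])
dat$Survived <- factor(ifelse(dat$Survived == 1, "Survived", "Died"))
dat$Sex <- factor(dat$Sex)

n_trees <- 100
votes <- replicate(n_trees, {
  boot_idx <- sample(nrow(dat), replace = TRUE)               # bootstrap resample
  fit <- tree(Survived ~ Age + Sex, data = dat[boot_idx, ])
  as.character(predict(fit, newdata = dat, type = "class"))   # one vote per tree
})
# Share of trees voting "Survived" = the bagged survival probability
bagged_prob <- rowMeans(votes == "Survived")
head(bagged_prob)
```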
### The Titanic
![](https://i.redd.it/v1g0agfdhcpy.jpg)
```{r}
library(tidyverse)
library(stringr)
library(tidytext)
library(magrittr)
library(lubridate)
library(modelr) # Modelling Functions that Work with the Pipe
library(broom) # Convert Statistical Analysis Objects into Tidy Data Frames
set.seed(1234)
theme_set(theme_minimal())
# browseURL("https://cfss.uchicago.edu/stat004_decision_trees.html")
# browseURL("https://docs.google.com/presentation/d/16caTlFSAhBUD4-WJhJ4swVojGrVXHmRO65XNtqsn0VA/edit#slide=id.g3b6b62405a_0_0")
```
## Loading the Titanic data
```{r}
# install.packages("titanic")
library(titanic)
titanic <- titanic_train %>%
as_tibble()
titanic %>%
head() %>%
knitr::kable()
```
## Simple tree
### category -> factor
```{r}
library(tree)
titanic_tree_data <- titanic %>%
  mutate(Survived = if_else(Survived == 1, "Survived", "Died"),
         Survived = as.factor(Survived),
         Sex = as.factor(Sex))
# Preview as an interactive table; `open` is macOS-specific
DT::saveWidget(DT::datatable(titanic_tree_data), "titanic.html")
system("open titanic.html")
```
## count variables
```{r}
# Cross-tabulate survival against candidate predictors
titanic_tree_data %>% count(Survived) %>% View
titanic_tree_data %>% count(Survived, Sex) %>% spread(Sex, n) %>% View
titanic_tree_data %>% count(Survived, Pclass) %>% spread(Pclass, n) %>% View
# Entropy of Survived within each Sex group, with group sizes attached
df <- titanic_tree_data %>% count(Survived, Sex) %>%
  group_by(Sex) %>%
  summarize(p = sum(-n/sum(n) * log2(n/sum(n)))) %>%
  ungroup() %>% left_join(count(titanic_tree_data, Sex))
# Binary entropy from the two class counts
entropy <- function(a, b){-a/(a+b)*log2(a/(a+b)) - b/(a+b)*log2(b/(a+b))}
all <- 549 + 342                # 549 died, 342 survived overall
entropy0 <- entropy(549, 342)   # entropy before any split
# Information gain = entropy before the split - weighted entropy after it
(IG.Sex <- entropy0 -
  entropy(81, 233) * (81 + 233) / all -    # female: 81 died, 233 survived
  entropy(468, 109) * (468 + 109) / all)   # male: 468 died, 109 survived
(IG.Pclass <- entropy0 -
  entropy(80, 136) * (80 + 136) / all -    # Pclass 1
  entropy(97, 87) * (97 + 87) / all -      # Pclass 2
  entropy(372, 119) * (372 + 119) / all)   # Pclass 3
```
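The computation above hard-codes the cell counts. A small helper (the name `info_gain` is hypothetical, introduced here for illustration) computes the same information gain directly from the data, which makes it easy to check other predictors:
```{r}
# A generic information-gain helper, computed straight from the data
info_gain <- function(df, outcome, predictor) {
  ent <- function(p) -sum(p * log2(p), na.rm = TRUE)   # entropy of a prob. vector
  h0 <- ent(table(df[[outcome]]) / nrow(df))           # entropy before the split
  groups <- split(df[[outcome]], df[[predictor]])
  h_cond <- sum(sapply(groups, function(g)             # weighted entropy after
    ent(table(g) / length(g)) * length(g) / nrow(df)))
  h0 - h_cond
}
info_gain(titanic_tree_data, "Survived", "Sex")     # should match IG.Sex
info_gain(titanic_tree_data, "Survived", "Pclass")  # should match IG.Pclass
```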
### training
```{r}
titanic_tree <- tree(Survived ~ Age + Sex, data = titanic_tree_data)
summary(titanic_tree)
```
### plot
```{r}
plot(titanic_tree)
text(titanic_tree, pretty = 0)
```
## Complicated tree
```{r}
titanic_tree_full_data <- titanic %>%
mutate(Survived = if_else(Survived == 1, "Survived",
if_else(Survived == 0, "Died",
NA_character_))) %>%
mutate_if(is.character, as.factor)
titanic_tree_full <- tree(Survived ~ Pclass + Sex + Age + SibSp +
Parch + Fare + Embarked,
data = titanic_tree_full_data)
summary(titanic_tree_full)
```
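A quick way to see what the extra predictors buy is to compare the two trees' training misclassification counts (a sketch; `misclass` is the component `summary.tree` uses for classification trees):
```{r}
# Training misclassification (count, n) for the simple vs. the fuller tree
summary(titanic_tree)$misclass
summary(titanic_tree_full)$misclass
```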
### plotting
```{r}
plot(titanic_tree_full)
text(titanic_tree_full, pretty = 0)
```
### Pruning
```{r}
titanic_tree_messy <- tree(Survived ~ Pclass + Sex + Age + SibSp +
Parch + Fare + Embarked,
data = titanic_tree_full_data,
control = tree.control(
nobs = nrow(titanic_tree_full_data),
mindev = 0,
minsize = 10)
)
summary(titanic_tree_messy)
```
```{r}
plot(titanic_tree_messy)
text(titanic_tree_messy, pretty = 0)
```
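The overgrown tree above clearly overfits. To actually prune it back, cross-validation can pick the tree size; here is a sketch using `cv.tree()` and `prune.misclass()` from the tree package (the seed choice is arbitrary):
```{r}
# Choose the tree size by 10-fold cross-validated misclassification error
set.seed(1234)
titanic_cv <- cv.tree(titanic_tree_messy, FUN = prune.misclass)
best_size <- titanic_cv$size[which.min(titanic_cv$dev)]
titanic_pruned <- prune.misclass(titanic_tree_messy, best = best_size)
plot(titanic_pruned)
text(titanic_pruned, pretty = 0)
```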
## Pros and cons of decision trees
Decision trees are an entirely different method of estimating functional forms as compared to linear regression.
**Pros**:
1. They are easy to explain. Most people, even if they lack statistical training, can understand decision trees.
2. They are easily presented as visualizations and are fairly interpretable.
3. Qualitative predictors are handled easily, without the need to create a long series of dummy variables.
**Cons**:
1. Their accuracy rates are generally lower than those of other regression and classification approaches.
2. Trees can be non-robust: a small change in the data, or the inclusion/exclusion of a handful of observations, can dramatically alter the final estimated tree (see the sketch below).
Fortunately, there is an easy way to improve on these poor predictions: **by aggregating many decision trees and averaging across them**, we can substantially improve performance.
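To see the non-robustness in point 2 concretely, fit the same specification on two disjoint half-samples and compare the results (an illustrative check; the exact splits will vary with the seed):
```{r}
# Same formula, two half-samples: the fitted trees often differ noticeably
set.seed(42)
idx <- sample(nrow(titanic_tree_full_data), floor(nrow(titanic_tree_full_data) / 2))
half_a <- tree(Survived ~ Pclass + Sex + Age, data = titanic_tree_full_data[idx, ])
half_b <- tree(Survived ~ Pclass + Sex + Age, data = titanic_tree_full_data[-idx, ])
summary(half_a)
summary(half_b)
```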
## glm via the caret package
```{r}
# devtools::install_github('topepo/caret/pkg/caret') (x)
# install.packages("ddalpha")
# install.packages('caret', dependencies=TRUE)
# install.packages('caret')
library(caret)
titanic_clean <- titanic %>%
filter(!is.na(Survived), !is.na(Age))
caret_glm <- train(Survived ~ Age, data = titanic_clean,
method = "glm",
family = binomial,
trControl = trainControl(method = "none"))
summary(caret_glm)
```
## The glm function gives results similar to caret's
```{r}
glm_glm <- glm(Survived ~ Age, data = titanic_clean, family = "binomial")
summary(glm_glm)
```
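To confirm the two fits really agree, tidy both coefficient tables side by side (a sketch; broom was loaded in the setup chunk, and `caret_glm$finalModel` is the underlying glm object):
```{r}
# Coefficients from caret's wrapped glm vs. the direct glm() call
bind_rows(
  caret = tidy(caret_glm$finalModel),
  glm   = tidy(glm_glm),
  .id = "fit"
)
```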
## Random forest
```{r}
titanic_rf_data <- titanic_tree_full_data %>%
select(Survived, Pclass, Sex, Age, SibSp, Parch, Fare, Embarked) %>%
na.omit()
titanic_rf_data
```
### training
```{r}
age_sex_rf <- train(Survived ~ Age + Sex, data = titanic_rf_data,
method = "rf",
ntree = 200,
trControl = trainControl(method = "oob"))
age_sex_rf
```
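The forest above uses only Age and Sex even though all predictors were prepared; extending to the full set is the same caret call with a different formula (an illustrative sketch; `full_rf` is a name introduced here):
```{r}
# Same setup, all prepared predictors
full_rf <- train(Survived ~ ., data = titanic_rf_data,
                 method = "rf",
                 ntree = 200,
                 trControl = trainControl(method = "oob"))
full_rf
```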
### result
```{r}
str(age_sex_rf, max.level = 1)
```
### model statistics
```{r}
age_sex_rf$finalModel
```
* 200 trees in total
* 2 variables randomly tried at each split
* the out-of-bag (OOB) estimate of the error rate is about 24%
```
Call:
 randomForest(x = x, y = y, ntree = 200, mtry = param$mtry)
               Type of random forest: classification
                     Number of trees: 200
No. of variables tried at each split: 2

        OOB estimate of  error rate: 24.23%
Confusion matrix:
         Died Survived class.error
Died      350       74   0.1745283
Survived   99      191   0.3413793
```
### confusion matrix
```{r}
knitr::kable(age_sex_rf$finalModel$confusion)
```
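As a sanity check, the OOB error rate reported above can be recovered from this confusion matrix by hand:
```{r}
# Off-diagonal share of the OOB confusion counts = the OOB error estimate
cm <- age_sex_rf$finalModel$confusion[, c("Died", "Survived")]
1 - sum(diag(cm)) / sum(cm)   # should match the ~24% reported above
```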
<!-- # Loading jieba -->
<!-- ```{r} -->
<!-- library(jiebaR) -->
<!-- segment_not <- c("台灣", "臺灣") -->
<!-- cutter <- worker() -->
<!-- new_user_word(cutter, segment_not) -->
<!-- stopWords <- readRDS("data/stopWords.rds") -->
<!-- ``` -->