---
title: "Working with resampling sets"
vignette: >
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteIndexEntry{Working with resampling sets}
output:
  knitr:::html_vignette:
    toc: yes
---

```{r ex_setup, include=FALSE}
knitr::opts_chunk$set(
  message = FALSE,
  digits = 3,
  collapse = TRUE,
  comment = "#>",
  eval = rlang::is_installed("ggplot2") && rlang::is_installed("modeldata")
)
options(digits = 3, width = 90)
```

```{r ggplot2_setup, include = FALSE}
library(ggplot2)
theme_set(theme_bw())
```
## Introduction
The rsample package can be used to create objects containing resamples of your original data. This vignette demonstrates how those objects can be used for data analysis.

Let's use the `attrition` data set. From its documentation:

> These data are from the IBM Watson Analytics Lab. The website describes the data with "Uncover the factors that lead to employee attrition and explore important questions such as ‘show me a breakdown of distance from home by job role and attrition’ or 'compare average monthly income by education and attrition'. This is a fictional data set created by IBM data scientists." There are 1470 rows.

The data can be accessed using:
```{r attrition, message=FALSE}
library(rsample)
data("attrition", package = "modeldata")
names(attrition)
table(attrition$Attrition)
```
## Model Assessment
Let's fit a logistic regression model to the data with model terms for job satisfaction, gender, and monthly income.

If we were fitting the model to the entire data set, we might model attrition using:

```r
glm(Attrition ~ JobSatisfaction + Gender + MonthlyIncome,
    data = attrition, family = binomial)
```
For convenience, we'll create a formula object that will be used later:
```{r form, message=FALSE}
mod_form <- as.formula(Attrition ~ JobSatisfaction + Gender + MonthlyIncome)
```
To evaluate this model, we will use 10 repeats of 10-fold cross-validation and use the 100 holdout samples to evaluate the overall accuracy of the model.
First, let's make the splits of the data:
```{r model_vfold, message=FALSE}
library(rsample)
set.seed(4622)
rs_obj <- vfold_cv(attrition, v = 10, repeats = 10)
rs_obj
```
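As a quick sanity check on this design (not part of the original analysis), the sketch below re-creates the same splits and confirms that 10 repeats of 10-fold cross-validation yield 100 resamples, and that within a single repeat the ten assessment sets together cover every row exactly once. It relies on rsample's `complement()`, which returns the row indices of a split's assessment set; the object name `folds` is illustrative.

```r
library(rsample)
data("attrition", package = "modeldata")

set.seed(4622)
folds <- vfold_cv(attrition, v = 10, repeats = 10)

# 10 folds x 10 repeats should give 100 resamples
n_resamples <- nrow(folds)

# With repeats, `id` holds the repeat label and `id2` the fold label;
# keep the ten splits that belong to the first repeat
first_repeat <- folds$splits[folds$id == folds$id[1]]

# Row indices held out by each of those ten splits
held_out <- unlist(lapply(first_repeat, complement))

# Within one repeat, every row should be held out exactly once
covers_all_rows <- length(held_out) == nrow(attrition) &&
  setequal(held_out, seq_len(nrow(attrition)))

# Each assessment set should be about 10% of the 1470 rows
assess_sizes <- vapply(folds$splits, function(s) nrow(assessment(s)), integer(1))

n_resamples
covers_all_rows
range(assess_sizes)
```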
Now let's write a function that will, for each resample:

1. obtain the analysis data set (i.e., the 90% used for modeling)
1. fit a logistic regression model
1. predict the assessment data (the other 10% not used for the model) using the broom package
1. determine if each sample was predicted correctly

Here is our function:

```{r lm_func}
## splits will be the `rsplit` object with the 90/10 partition
holdout_results <- function(splits, ...) {
  # Fit the model to the 90%
  mod <- glm(..., data = analysis(splits), family = binomial)
  # Save the 10%
  holdout <- assessment(splits)
  # `augment` will save the predictions with the holdout data set
  res <- broom::augment(mod, newdata = holdout)
  # Class predictions on the assessment set: `.fitted` is on the log-odds
  # scale, so a cutoff of 0 corresponds to a predicted probability of 0.5
  lvls <- levels(holdout$Attrition)
  predictions <- factor(ifelse(res$.fitted > 0, lvls[2], lvls[1]),
                        levels = lvls)
  # Calculate whether the prediction was correct
  res$correct <- predictions == holdout$Attrition
  # Return the assessment data set with the additional columns
  res
}
```
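`holdout_results()` turns `.fitted` into class predictions by comparing it to 0. For a `glm` fit, `broom::augment()` returns the linear predictor (log-odds) by default (passing `type.predict = "response"` would return probabilities instead), and a cutoff of 0 on the log-odds scale is the same as a cutoff of 0.5 on the probability scale. The small base-R check below illustrates this with made-up log-odds values; `plogis()` is R's inverse-logit function.

```r
# Made-up log-odds values standing in for the `.fitted` column
log_odds <- c(-2.5, -0.3, 0, 0.4, 3)

# plogis() is the inverse logit: it maps log-odds to probabilities
probs <- plogis(log_odds)

# Class calls from the two cutoffs
by_link <- log_odds > 0
by_prob <- probs > 0.5

# The two rules agree for every value, including the boundary case 0
same_calls <- identical(by_link, by_prob)
same_calls
```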
For example:
```{r onefold, warning = FALSE}
example <- holdout_results(rs_obj$splits[[1]], mod_form)
dim(example)
dim(assessment(rs_obj$splits[[1]]))
## newly added columns:
example[1:10, setdiff(names(example), names(attrition))]
```
For this model, the `.fitted` value is the linear predictor in log-odds units.

To compute this data set for each of the 100 resamples, we'll use the `map()` function from the purrr package:

```{r model_purrr, warning=FALSE}
library(purrr)
rs_obj$results <- map(rs_obj$splits,
                      holdout_results,
                      mod_form)
rs_obj
```
Now we can compute the accuracy values for all of the assessment data sets:
```{r model_acc}
rs_obj$accuracy <- map_dbl(rs_obj$results, function(x) mean(x$correct))
summary(rs_obj$accuracy)
```
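With repeated cross-validation, the 100 accuracy values come in ten groups of ten, one group per repeat. One optional way to look at them, sketched below with the same model formula, is to average the folds within each repeat; keep in mind that folds within a repeat share training data, so these values are not independent. To stay self-contained, the sketch refits every model and swaps `broom::augment()` for base R's `predict()`, which also returns log-odds for a `glm` by default. The helper `fold_accuracy()` is hypothetical.

```r
library(rsample)
library(purrr)
data("attrition", package = "modeldata")

mod_form <- Attrition ~ JobSatisfaction + Gender + MonthlyIncome

set.seed(4622)
folds <- vfold_cv(attrition, v = 10, repeats = 10)

# Hypothetical helper: accuracy of a logistic regression on one assessment set
fold_accuracy <- function(split) {
  mod <- glm(mod_form, data = analysis(split), family = binomial)
  holdout <- assessment(split)
  lvls <- levels(holdout$Attrition)
  # predict.glm() returns the linear predictor (log-odds) by default
  log_odds <- predict(mod, newdata = holdout)
  pred <- factor(ifelse(log_odds > 0, lvls[2], lvls[1]), levels = lvls)
  mean(pred == holdout$Attrition)
}

folds$accuracy <- map_dbl(folds$splits, fold_accuracy)

# Average the ten folds within each repeat; `id` holds the repeat label
repeat_means <- tapply(folds$accuracy, folds$id, mean)
repeat_means

# Spread of the repeat-level averages versus the individual folds
sd(repeat_means)
sd(folds$accuracy)
```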
Keep in mind that the baseline accuracy to beat is the rate of non-attrition, which is `r ifelse(rlang::is_installed("modeldata"), round(mean(attrition$Attrition == "No"), 3), 0.839)`. Not a great model so far.

## Using the Bootstrap to Make Comparisons

Traditionally, the bootstrap has been primarily used to empirically determine the sampling distribution of a test statistic. Given a set of samples drawn with replacement, a statistic can be calculated on each analysis set, and the results can be used to make inferences (such as confidence intervals).

For example, are there differences in the median monthly income between genders?

```{r type_plot, fig.alt = "Two boxplots of monthly income separated by gender, showing a slight difference in median but largely overlapping boxes."}
ggplot(attrition, aes(x = Gender, y = MonthlyIncome)) +
  geom_boxplot() +
  scale_y_log10()
```
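The quantity the bootstrap will resample can first be computed once on the full data. This base-R sketch takes the per-group medians with `tapply()` and forms the Female minus Male difference used throughout this section; the object names are illustrative.

```r
data("attrition", package = "modeldata")

# Median monthly income within each gender
group_medians <- tapply(attrition$MonthlyIncome, attrition$Gender, median)
group_medians

# Point estimate of the difference (Female - Male) on the full data
observed_diff <- group_medians[["Female"]] - group_medians[["Male"]]
observed_diff
```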
If we wanted to compare the genders, we could conduct a _t_-test or a rank-based test. Instead, let's use the bootstrap to see if there is a difference in the median incomes for the two groups. We need a simple function to compute this statistic on the resample:

```{r mean_diff}
median_diff <- function(splits) {
  x <- analysis(splits)
  median(x$MonthlyIncome[x$Gender == "Female"]) -
    median(x$MonthlyIncome[x$Gender == "Male"])
}
```
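For comparison, the classical tests mentioned above take one line each in base R. This is a sketch rather than part of the vignette's analysis: `wilcox.test()` tests for a location shift between the two income distributions, which is related to, but not the same as, a difference in medians, and running the Welch _t_-test on `log10(MonthlyIncome)` is an assumption made here because incomes are right-skewed (matching the log axis in the plot above).

```r
data("attrition", package = "modeldata")

# Wilcoxon rank-sum (Mann-Whitney) test of MonthlyIncome between genders
rank_test <- wilcox.test(MonthlyIncome ~ Gender, data = attrition)
rank_test$p.value

# Welch t-test on the log scale; conf.int is for the difference in mean log10 income
t_res <- t.test(log10(MonthlyIncome) ~ Gender, data = attrition)
t_res$conf.int
```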
In practice, we would create a large number of bootstrap samples (say, 2000 or more). For illustration, we'll only use 500 in this document.
```{r boot_mean_diff}
set.seed(353)
bt_resamples <- bootstraps(attrition, times = 500)
```
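Each bootstrap split samples rows with replacement, so its analysis set has as many rows as the original data while some rows appear more than once. The rows never drawn form the assessment ("out-of-bag") set, which on average holds a fraction `(1 - 1/n)^n` of the data, close to `exp(-1)`, or about 36.8%, for large `n`. The sketch below inspects one split using only `analysis()` and `complement()`; the small number of resamples and the object names are arbitrary choices.

```r
library(rsample)
data("attrition", package = "modeldata")

set.seed(353)
boots <- bootstraps(attrition, times = 10)
one_split <- boots$splits[[1]]

n <- nrow(attrition)

# The analysis set has as many rows as the original data ...
n_analysis <- nrow(analysis(one_split))

# ... but some rows are drawn more than once; the rows never drawn are the
# assessment (out-of-bag) rows returned by complement()
n_out_of_bag <- length(complement(one_split))
n_distinct_in <- n - n_out_of_bag

# Observed out-of-bag fraction versus its expected value (1 - 1/n)^n
oob_fraction <- n_out_of_bag / n
c(observed = oob_fraction, expected = (1 - 1 / n)^n)
```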
This function is then applied to each resample:
```{r stats}
bt_resamples$wage_diff <- map_dbl(bt_resamples$splits, median_diff)
```
The bootstrap distribution of this statistic is slightly bimodal and skewed:

```{r stats_plot, fig.alt = "The bootstrap distribution of the differences in median monthly income: it is slightly bimodal and left-skewed."}
ggplot(bt_resamples, aes(x = wage_diff)) +
  geom_line(stat = "density", adjust = 1.25) +
  xlab("Difference in Median Monthly Income (Female - Male)")
```
The variation in this statistic is considerable. One method of computing a confidence interval is to take the percentiles of the bootstrap distribution. A 95% confidence interval for the difference in the medians would be:

```{r ci}
quantile(bt_resamples$wage_diff,
         probs = c(0.025, 0.975))
```
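Because the percentile interval is read off the tails of the bootstrap distribution, its endpoints carry Monte Carlo noise that shrinks as the number of resamples grows, which is one reason for the 2000+ recommendation. The sketch below (not part of the original analysis) rebuilds the interval with two different seeds at 500 and at 2000 resamples so the run-to-run movement of the endpoints can be compared. The helper `pctl_ci()` and the seeds are hypothetical choices, and the block takes a few seconds to run.

```r
library(rsample)
library(purrr)
data("attrition", package = "modeldata")

median_diff <- function(splits) {
  x <- analysis(splits)
  median(x$MonthlyIncome[x$Gender == "Female"]) -
    median(x$MonthlyIncome[x$Gender == "Male"])
}

# Hypothetical helper: 95% percentile interval from `times` bootstrap resamples
pctl_ci <- function(times, seed) {
  set.seed(seed)
  boots <- bootstraps(attrition, times = times)
  stats <- map_dbl(boots$splits, median_diff)
  quantile(stats, probs = c(0.025, 0.975))
}

# Same procedure, two seeds each; rows are seeds, columns are endpoints
ci_500 <- rbind(pctl_ci(500, seed = 1), pctl_ci(500, seed = 2))
ci_2000 <- rbind(pctl_ci(2000, seed = 1), pctl_ci(2000, seed = 2))

# Absolute change in each endpoint between the two seeds
abs(ci_500[1, ] - ci_500[2, ])
abs(ci_2000[1, ] - ci_2000[2, ])
```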
The calculated 95% confidence interval contains zero, so we don't have evidence of a difference in median income between these genders at a confidence level of 95%.

## Bootstrap Estimates of Model Coefficients

Unless there is already a column in the resample object that contains the fitted model, a function can be used to fit the model and save all of the model coefficients. The [broom package](https://cran.r-project.org/package=broom) has a `tidy()` function that will save the coefficients in a data frame. Instead of returning a data frame with a row for each model term, we will save a data frame with a single row and columns for each model term. As before, `purrr::map()` can be used to estimate and save these values for each split.

```{r coefs}
glm_coefs <- function(splits, ...) {
  ## use `analysis` or `as.data.frame` to get the analysis data
  mod <- glm(..., data = analysis(splits), family = binomial)
  as.data.frame(t(coef(mod)))
}
bt_resamples$betas <- map(.x = bt_resamples$splits,
                          .f = glm_coefs,
                          mod_form)
bt_resamples
bt_resamples$betas[[1]]
```
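Once each split carries a one-row data frame of coefficients, the rows can be stacked and summarized. The sketch below (not from the original text) repeats the coefficient step with 200 bootstrap samples to keep it quick, stacks the results with base R's `do.call(rbind, ...)`, and takes a 95% percentile interval for each coefficient, mirroring the interval used for the median difference.

```r
library(rsample)
library(purrr)
data("attrition", package = "modeldata")

mod_form <- Attrition ~ JobSatisfaction + Gender + MonthlyIncome

glm_coefs <- function(splits, ...) {
  mod <- glm(..., data = analysis(splits), family = binomial)
  as.data.frame(t(coef(mod)))
}

set.seed(353)
boots <- bootstraps(attrition, times = 200)
betas <- map(boots$splits, glm_coefs, mod_form)

# Stack the 200 one-row data frames into a single 200-row data frame
all_betas <- do.call(rbind, betas)

# 95% percentile interval for each coefficient (one column per term)
coef_ci <- sapply(all_betas, quantile, probs = c(0.025, 0.975))
coef_ci
```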
## Keeping Tidy
As previously mentioned, the [broom package](https://cran.r-project.org/package=broom) contains a generic function called `tidy()` that creates representations of objects that can be easily used for analysis, plotting, etc. rsample contains `tidy()` methods for `rset` and `rsplit` objects. For example:
```{r tidy_rsplit}
first_resample <- bt_resamples$splits[[1]]
class(first_resample)
tidy(first_resample)
```
and
```{r tidy_rset}
class(bt_resamples)
tidy(bt_resamples)
```
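Because `tidy()` returns an ordinary data frame, the usual tools apply to it. As a small sketch, the block below counts how many distinct rows fall in the analysis and assessment portions of a few bootstrap resamples. It assumes the column names `Row`, `Data`, and `Resample` produced by rsample's `tidy()` method for bootstrap `rset` objects, and that duplicated bootstrap rows are collapsed by default; check `names()` on your version if they differ.

```r
library(rsample)
data("attrition", package = "modeldata")

set.seed(353)
boots <- bootstraps(attrition, times = 3)

# One row per (data row, portion, resample); with duplicates collapsed, each
# resample should list every data row exactly once
tidy_boots <- tidy(boots)
head(tidy_boots)

# Distinct rows in the analysis vs. assessment portion of each resample
counts <- table(tidy_boots$Resample, tidy_boots$Data)
counts
```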