---
title: "Classification - Solutions"
output: html_notebook
editor_options:
  chunk_output_type: inline
---
```{r setup, include=FALSE}
library(tidyverse)
library(tidymodels)
library(tune)
# read in the data
stackoverflow <- read_rds(here::here("materials/data/stackoverflow.rds"))
# split the data
set.seed(100) # Important!
so_split <- initial_split(stackoverflow, strata = remote)
so_train <- training(so_split)
so_test <- testing(so_split)
```
# Your Turn 1
Using the `so_train` and `so_test` data sets, how many individuals in our training set are remote? How about in the testing set?
```{r}
so_train %>%
  count(remote)

so_test %>%
  count(remote)
```
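Because `initial_split()` was called with `strata = remote`, the share of remote respondents should come out about the same in both sets. A minimal sketch of that check in base R, using made-up vectors (`toy_train` and `toy_test` are hypothetical stand-ins and do not reflect the real counts):

```{r}
# Hypothetical labels standing in for so_train$remote and so_test$remote
toy_train <- c(rep("Remote", 3), rep("Not remote", 27))
toy_test  <- c(rep("Remote", 1), rep("Not remote", 9))

# Share of remote respondents in each set
train_share <- mean(toy_train == "Remote")
test_share  <- mean(toy_test == "Remote")

c(train = train_share, test = test_share)
```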
# Your Turn 2
Fill in the blanks. Use the `tree_spec` model specification provided and `fit()` to:

1. Train a CART-based model with the formula `remote ~ years_coded_job + salary`.
2. Remind yourself what the output looks like!
3. Predict remote status with the testing data.
4. Keep `set.seed(100)` at the start of your code.
```{r}
tree_spec <-
  decision_tree() %>%
  set_engine(engine = "rpart") %>%
  set_mode("classification")

set.seed(100) # Important!
tree_fit <- fit(tree_spec,
                remote ~ years_coded_job + salary,
                data = so_train)
tree_fit

predict(tree_fit, new_data = so_test)
```
# Your Turn 3
Create a data frame of the observed and predicted remote status for the `so_test` data. Then use `count()` to count the number of individuals (i.e., rows) by their true and predicted remote status. Answer the following questions:

1. How many predictions did we make?
2. How many times is "remote" status predicted?
3. How many respondents are actually remote?
4. How many predictions did we get right?

*Hint: You can create a 2x2 table using* `count(var1, var2)`
```{r}
tree_predict <- predict(tree_fit, new_data = so_test)

all_preds <- so_test %>%
  select(remote) %>%
  bind_cols(tree_predict)

all_preds %>%
  count(.pred_class, truth = remote)
```
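The four questions can each be read off a data frame of true and predicted classes. A minimal base R sketch on made-up data follows; `toy_preds` is a hypothetical six-row stand-in for `all_preds`, so the numbers are illustrative only, and the level order `c("Remote", "Not remote")` is an assumption:

```{r}
# Hypothetical truth/prediction pairs, not the real Stack Overflow results
lvls <- c("Remote", "Not remote")
toy_preds <- data.frame(
  truth       = factor(c("Remote", "Remote", "Not remote",
                         "Not remote", "Not remote", "Remote"), levels = lvls),
  .pred_class = factor(c("Remote", "Not remote", "Not remote",
                         "Remote", "Not remote", "Remote"), levels = lvls)
)

n_predictions      <- nrow(toy_preds)                             # 1. predictions made
n_predicted_remote <- sum(toy_preds$.pred_class == "Remote")      # 2. predicted remote
n_actually_remote  <- sum(toy_preds$truth == "Remote")            # 3. actually remote
n_correct          <- sum(toy_preds$.pred_class == toy_preds$truth) # 4. correct predictions

# The same information as a 2x2 table of predicted vs. true class
table(predicted = toy_preds$.pred_class, truth = toy_preds$truth)
```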
# Your Turn 4
Build the necessary data frame, and use `roc_curve()` to calculate the data needed to construct the full ROC curve.
What threshold is needed to achieve a specificity greater than 0.75?
```{r}
all_preds <- so_test %>%
  select(remote) %>%
  bind_cols(predict(tree_fit, new_data = so_test)) %>%
  bind_cols(predict(tree_fit, new_data = so_test, type = "prob"))

roc_curve(all_preds, truth = remote, estimate = .pred_Remote)
```
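To answer the threshold question, keep only the rows of the ROC table whose specificity exceeds 0.75 and read off the smallest threshold among them. The base R sketch below builds a small ROC-style table by hand from made-up probabilities. `toy_truth`, `toy_prob`, and `toy_roc` are hypothetical, and treating a probability at or above the threshold as a "Remote" prediction is an assumption of this sketch. On the real output, the same idea amounts to filtering the tibble returned by `roc_curve()` on its `specificity` column.

```{r}
# Hypothetical true classes and predicted probabilities of "Remote"
toy_truth <- c("Remote", "Remote", "Not remote", "Remote",
               "Not remote", "Not remote", "Not remote", "Not remote")
toy_prob  <- c(0.9, 0.8, 0.7, 0.6, 0.4, 0.3, 0.2, 0.1)

# One row per candidate threshold, with column names mirroring roc_curve()
thresholds <- sort(unique(toy_prob))
toy_roc <- data.frame(
  .threshold  = thresholds,
  # sensitivity: share of truly remote cases predicted remote (prob >= threshold)
  sensitivity = sapply(thresholds, function(t) mean(toy_prob[toy_truth == "Remote"] >= t)),
  # specificity: share of non-remote cases predicted not remote (prob < threshold)
  specificity = sapply(thresholds, function(t) mean(toy_prob[toy_truth == "Not remote"] < t))
)

# Rows that meet the target, and the smallest threshold among them
meets_target     <- toy_roc[toy_roc$specificity > 0.75, ]
lowest_threshold <- min(meets_target$.threshold)
lowest_threshold
```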
# Your Turn 5
Use `roc_auc()` to calculate the area under the ROC curve. Then plot the ROC curve using `autoplot()`.
```{r}
roc_auc(all_preds, truth = remote, estimate = .pred_Remote)

roc_curve(all_preds, truth = remote, estimate = .pred_Remote) %>%
  autoplot()
```
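One way to interpret the number `roc_auc()` returns: the area under the ROC curve equals the probability that a randomly chosen remote respondent receives a higher predicted probability than a randomly chosen non-remote respondent, with ties counted as one half. The base R sketch below computes that pairwise quantity on made-up probabilities. `toy_pos` and `toy_neg` are hypothetical, and this illustrates the interpretation rather than how `roc_auc()` is implemented.

```{r}
# Hypothetical predicted probabilities of "Remote"
toy_pos <- c(0.9, 0.8, 0.6)            # truly remote respondents
toy_neg <- c(0.7, 0.4, 0.3, 0.2, 0.1)  # truly non-remote respondents

# Compare every remote/non-remote pair: 1 if the remote case ranks higher,
# 0.5 for a tie, 0 otherwise
pair_scores <- outer(toy_pos, toy_neg,
                     FUN = function(p, n) (p > n) + 0.5 * (p == n))

# AUC is the average score over all pairs
toy_auc <- mean(pair_scores)
toy_auc
```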