Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
minor slide updates and created student script
- Loading branch information
1 parent
523945d
commit ff1a689
Showing
4 changed files
with
127 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
--- | ||
title: "Supervised Modeling Process" | ||
output: html_notebook | ||
--- | ||
|
||
# Prerequisites | ||
|
||
```{r slide-4} | ||
# Packages required | ||
library(rsample) | ||
library(caret) | ||
library(tidyverse) | ||
# Data required | ||
## ames data | ||
ames <- AmesHousing::make_ames() | ||
## attrition data | ||
churn <- rsample::attrition | ||
``` | ||
|
||
|
||
# Mechanics of data splitting | ||
|
||
Two most common ways of splitting data include: | ||
|
||
* simple random sampling: randomly select observations | ||
* stratified sampling: preserving distributions | ||
- classification: sampling within the classes to preserve the | ||
distribution of the outcome in the training and test sets | ||
- regression: determine the quartiles of the data set and sample within those | ||
artificial groups | ||
|
||
```{r slide-8} | ||
set.seed(123) # for reproducibility | ||
split <- initial_split(diamonds, strata = "price", prop = 0.7) | ||
train <- training(split) | ||
test <- testing(split) | ||
# Do the distributions line up? | ||
ggplot(train, aes(x = price)) + | ||
geom_line(stat = "density", | ||
trim = TRUE) + | ||
geom_line(data = test, | ||
stat = "density", | ||
trim = TRUE, col = "red") | ||
``` | ||
|
||
|
||
# Your Turn! | ||
|
||
1. Use __rsample__ to split the Ames housing data (`ames`) and the Employee attrition data (`churn`) using stratified sampling and with a 80% split. | ||
2. Verify that the distribution between training and test sets are similar. | ||
|
||
```{r slide-9} | ||
# ames data | ||
set.seed(123) | ||
ames_split <- initial_split(ames, prop = _____, strata = "Sale_Price") | ||
ames_train <- training(_____) | ||
ames_test <- testing(_____) | ||
# attrition data | ||
set.seed(123) | ||
churn_split <- initial_split(churn, prop = _____, strata = "Attrition") | ||
churn_train <- training(_____) | ||
churn_test <- testing(_____) | ||
``` | ||
|
||
|
||
# Putting the process together | ||
|
||
Let's put these pieces together and analyze the Ames housing data: | ||
|
||
1. Split into training vs testing data | ||
2. Specify a resampling procedure | ||
3. Create our hyperparameter grid | ||
4. Execute grid search | ||
5. Evaluate performance | ||
|
||
___This grid search takes ~2 min___ | ||
|
||
```{r slide-35} | ||
# 1. stratified sampling with the rsample package | ||
set.seed(123) | ||
split <- initial_split(ames, prop = 0.7, strata = "Sale_Price") | ||
ames_train <- training(split) | ||
ames_test <- testing(split) | ||
# 2. create a resampling method | ||
cv <- trainControl( | ||
method = "repeatedcv", | ||
number = 10, | ||
repeats = 5 | ||
) | ||
# 3. create a hyperparameter grid search | ||
hyper_grid <- expand.grid(k = seq(2, 26, by = 2)) | ||
# 4. execute grid search with knn model | ||
# use RMSE as preferred metric | ||
knn_fit <- train( | ||
Sale_Price ~ ., | ||
data = ames_train, | ||
method = "knn", | ||
trControl = cv, | ||
tuneGrid = hyper_grid, | ||
metric = "RMSE" | ||
) | ||
# 5. evaluate results | ||
# print model results | ||
knn_fit | ||
# plot cross validation results | ||
ggplot(knn_fit$results, aes(k, RMSE)) + | ||
geom_line() + | ||
geom_point() + | ||
scale_y_continuous(labels = scales::dollar) | ||
``` |