From 89af7973a3274db5d4c7c315c4d236213bcd73c7 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 24 Mar 2025 12:15:46 -0700 Subject: [PATCH 1/7] use setWidgetIdSeed() in learn/statistics/survival-case-study --- .../survival-case-study/index/execute-results/html.json | 4 ++-- learn/statistics/survival-case-study/index.html.md | 4 ++-- learn/statistics/survival-case-study/index.qmd | 2 ++ 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/_freeze/learn/statistics/survival-case-study/index/execute-results/html.json b/_freeze/learn/statistics/survival-case-study/index/execute-results/html.json index b8779a48..173a3b6e 100644 --- a/_freeze/learn/statistics/survival-case-study/index/execute-results/html.json +++ b/_freeze/learn/statistics/survival-case-study/index/execute-results/html.json @@ -1,8 +1,8 @@ { - "hash": "a568498b54fcca7979e9432c0977d394", + "hash": "647e49f07e35fcb8b406fc0488cf7703", "result": { "engine": "knitr", - "markdown": "---\ntitle: \"How long until building complaints are dispositioned? A survival analysis case study\"\ncategories:\n - statistical analysis\n - survival analysis\ntype: learn-subsection\nweight: 9\ndescription: | \n Learn how to use tidymodels for survival analysis.\ntoc: true\ntoc-depth: 2\ninclude-after-body: ../../../resources.html\n---\n\n\n\n\n\n\n\n\n## Introduction\n\nTo use code in this article, you will need to install the following packages: aorsf, censored, glmnet, modeldatatoo, and tidymodels.\n\nSurvival analysis is a field of statistics and machine learning for analyzing the time to an event. While it has its roots in medical research, the event of interest can be anything from customer churn to machine failure. Methods from survival analysis take into account that some observations may not yet have experienced the event of interest and are thus _censored_. \n\nHere we want to predict the time it takes for a complaint to be dispositioned^[In this context, the term _disposition_ means that there has been a decision or resolution regarding the complaint that is the conclusion of the process.] by the Department of Buildings in New York City. We are going to walk through a complete analysis from beginning to end, showing how to analyze time-to-event data.\n\nLet's start with loading the tidymodels and censored packages (the parsnip extension package for survival analysis models).\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(tidymodels)\nlibrary(censored)\n#> Loading required package: survival\n```\n:::\n\n\n\n\n## The buildings complaints data\n\nThe city of New York publishes data on the [complaints](https://data.cityofnewyork.us/Housing-Development/DOB-Complaints-Received/eabe-havv/about_data) received by the Department of Buildings. The data includes information on the type of complaint, the date it was entered in their records, the date it was dispositioned, and the location of the building the complaint was about. 
We are using a subset of the data, available in the modeldatatoo package.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nbuilding_complaints <- modeldatatoo::data_building_complaints()\nglimpse(building_complaints)\n#> Rows: 4,234\n#> Columns: 11\n#> $ days_to_disposition 72, 1, 41, 45, 16, 62, 56, 11, 35, 38, 39, 106, 1,…\n#> $ status \"ACTIVE\", \"ACTIVE\", \"ACTIVE\", \"ACTIVE\", \"ACTIVE\", …\n#> $ year_entered 2023, 2023, 2023, 2023, 2023, 2023, 2023, 2023, 20…\n#> $ latitude 40.66173, 40.57668, 40.73242, 40.68245, 40.63156, …\n#> $ longitude -73.98297, -74.00453, -73.87630, -73.79367, -73.99…\n#> $ borough Brooklyn, Brooklyn, Queens, Queens, Brooklyn, Quee…\n#> $ special_district None, None, None, None, None, None, None, None, No…\n#> $ unit Q-L, Q-L, SPOPS, Q-L, BKLYN, Q-L, Q-L, SPOPS, Q-L,…\n#> $ community_board 307, 313, 404, 412, 312, 406, 306, 306, 409, 404, …\n#> $ complaint_category 45, 45, 49, 45, 31, 45, 45, 49, 45, 45, 45, 4A, 31…\n#> $ complaint_priority B, B, C, B, C, B, B, C, B, B, B, B, C, C, B, B, B,…\n```\n:::\n\n\n\n\nBefore we dive into survival analysis, let's get a impression of how the complaints are distributed across the city. We have complaints in all five boroughs, albeit with a somewhat lower density of complaints in Staten Island.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n::: {.cell-output-display}\n\n```{=html}\n
\n\n```\n\n\nBuilding complaints in New York City (closed complaints in purple, active complaints in pink).\n:::\n:::\n\n\n\n\nIn the dataset, we can see the `days_to_disposition` as well as the `status` of the complaint. For a complaint with the status `\"ACTIVE\"`, the time to disposition is censored, meaning we do know that it has taken at least that long, but not how long for it to be completely resolved. \n\nThe standard form for time-to-event data are `Surv` objects which capture the time as well as the event status. As with all transformations of the response, it is advisable to do this before heading into the model fitting process with tidymodels.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nbuilding_complaints <- building_complaints %>% \n mutate(\n disposition_surv = Surv(days_to_disposition, status == \"CLOSED\"), \n .keep = \"unused\"\n )\n```\n:::\n\n\n\n\n## Data splitting and resampling\n\nFor our resampling strategy, let's use a [3-way split](https://www.tmwr.org/resampling#validation) into training, validation, and test set.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nset.seed(403)\ncomplaints_split <- initial_validation_split(building_complaints)\n```\n:::\n\n\n\n\nFirst, let's pull out the training data and have a brief look at the response using a [Kaplan-Meier curve](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3059453/). \n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ncomplaints_train <- training(complaints_split)\n\nsurvfit(disposition_surv ~ 1, data = complaints_train) %>% plot()\n```\n\n::: {.cell-output-display}\n![](figs/unnamed-chunk-7-1.svg){fig-align='center' fig-alt='A Kaplan-Meier curve dropping rapidly initially, then reaching about 10% survival rate at around 100 days, and finally trailing off until about 400 days.' width=672}\n:::\n:::\n\n\n\n\nWe can see that the majority of complaints is dispositioned relatively quickly, but some complaints are still active after 100 days.\n\n## A first model\n\nThe censored package includes parametric, semi-parametric, and tree-based models for this type of analysis. To start, we are fitting a parametric survival model with the default of assuming a Weibull distribution on the time to disposition. We'll explore the more flexible models once we have a sense of how well this more restrictive model performs on this dataset.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nsurvreg_spec <- survival_reg() %>% \n set_engine(\"survival\") %>% \n set_mode(\"censored regression\")\n```\n:::\n\n\n\n\nWe have several missing values in `complaint_priority` that we are turning into a separate category, `\"unknown\"`. We are also combining the less common categories for `community_board` and `unit` into an `\"other\"` category to reduce the number of levels in the predictors. The complaint category often does not tell us much more than the unit, with several complaint categories being handled by a specific unit only. This can lead to the model being unable to estimate some of the coefficients. 
Since our goal here is only to get a rough idea of how well the model performs, we are removing the complaint category for now.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nrec_other <- recipe(disposition_surv ~ ., data = complaints_train) %>% \n step_unknown(complaint_priority) %>% \n step_rm(complaint_category) %>% \n step_novel(community_board, unit) %>%\n step_other(community_board, unit, threshold = 0.02)\n```\n:::\n\n\n\n\nWe combine the recipe and the model into a workflow. This allows us to easily resample the model because all preprocessing steps are applied to the training set and the validation set for us.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nsurvreg_wflow <- workflow() %>% \n add_recipe(rec_other) %>% \n add_model(survreg_spec)\n```\n:::\n\n\n\n\nTo fit and evaluate the model, we need the training and validation sets. While we can access them each on their own, `validation_set()` extracts them both, in a manner that emulates a single resample of the data. This enables us to use `fit_resamples()` and other tuning functions in the same way as if we had used some other resampling scheme (such as cross-validation). \n\nWe are calculating several performance metrics: the Brier score, its integrated version, the area under the ROC curve, and the concordance index. Note that all of these are used in a version tailored to survival analysis. The concordance index uses the predicted event time to measure the model’s ability to rank the observations correctly. The Brier score and the ROC curve use the predicted probability of survival at a given time. We evaluate these metrics every 30 days up to 300 days, as provided in the `eval_time` argument. The Brier score is a measure of the accuracy of the predicted probabilities, while the ROC curve is a measure of the model’s ability to discriminate between events and non-events at the given time point. Because these metrics are defined “at a given time,” they are also referred to as *dynamic metrics*.\n\n::: {.callout-tip}\nFor more information see the [Dynamic Performance Metrics for Event Time Data](../survival-metrics/) article.\n:::\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ncomplaints_rset <- validation_set(complaints_split)\n\nsurvival_metrics <- metric_set(brier_survival_integrated, brier_survival,\n roc_auc_survival, concordance_survival)\nevaluation_time_points <- seq(0, 300, 30)\n\nset.seed(1)\nsurvreg_res <- fit_resamples(\n survreg_wflow,\n resamples = complaints_rset,\n metrics = survival_metrics,\n eval_time = evaluation_time_points, \n control = control_resamples(save_pred = TRUE)\n)\n```\n:::\n\n\n\n\nThe structure of survival model predictions is slightly different from classification and regression model predictions:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\npreds <- collect_predictions(survreg_res)\npreds\n#> # A tibble: 847 × 6\n#> .pred .pred_time id .row disposition_surv .config \n#> \n#> 1 96.6 validation 2541 35+ Preprocessor1…\n#> 2 18.7 validation 2542 129+ Preprocessor1…\n#> 3 29.5 validation 2543 4+ Preprocessor1…\n#> 4 29.8 validation 2544 5+ Preprocessor1…\n#> 5 24.8 validation 2545 1+ Preprocessor1…\n#> 6 58.4 validation 2546 76+ Preprocessor1…\n#> 7 71.3 validation 2547 51+ Preprocessor1…\n#> 8 102. 
validation 2548 44+ Preprocessor1…\n#> 9 47.1 validation 2549 15+ Preprocessor1…\n#> 10 28.5 validation 2550 61+ Preprocessor1…\n#> # ℹ 837 more rows\n```\n:::\n\n\n\n\nThe predicted survival time is in the `.pred_time` column and the predicted survival probabilities are in the `.pred` list column. \n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\npreds$.pred[[6]]\n#> # A tibble: 11 × 3\n#> .eval_time .pred_survival .weight_censored\n#> \n#> 1 0 1 1 \n#> 2 30 0.554 1.04\n#> 3 60 0.360 1.19\n#> 4 90 0.245 NA \n#> 5 120 0.171 NA \n#> 6 150 0.121 NA \n#> 7 180 0.0874 NA \n#> 8 210 0.0637 NA \n#> 9 240 0.0468 NA \n#> 10 270 0.0347 NA \n#> 11 300 0.0259 NA\n```\n:::\n\n\n\n\nFor each observation, `.pred` contains a tibble with the evaluation time `.eval_time` and the corresponding survival probability `.pred_survival`. The column `.weight_censored` contains the weights used in the calculation of the dynamic performance metrics. \n\n::: {.callout-tip}\nFor details on the weights see the [Accounting for Censoring in Performance Metrics for Event Time Data](../survival-metrics-details/) article.\n:::\n\nOf the metrics we calculated with these predictions, let's take a look at the AUC ROC first.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ncollect_metrics(survreg_res) %>% \n filter(.metric == \"roc_auc_survival\") %>% \n ggplot(aes(.eval_time, mean)) + \n geom_line() + \n labs(x = \"Evaluation Time\", y = \"Area Under the ROC Curve\")\n```\n\n::: {.cell-output-display}\n![](figs/survreg-roc-auc-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nWe can discriminate between events and non-events reasonably well, especially in the first 30 and 60 days. How about the probabilities that the categorization into event and non-event is based on? \n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ncollect_metrics(survreg_res) %>% \n filter(.metric == \"brier_survival\") %>% \n ggplot(aes(.eval_time, mean)) + \n geom_line() + \n labs(x = \"Evaluation Time\", y = \"Brier Score\")\n```\n\n::: {.cell-output-display}\n![](figs/survreg-brier-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nThe accuracy of the predicted probabilities is generally good, albeit lowest for evaluation times of 30 and 60 days. The integrated Brier score is a measure of the overall accuracy of the predicted probabilities. \n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ncollect_metrics(survreg_res) %>% \n filter(.metric == \"brier_survival_integrated\")\n#> # A tibble: 1 × 7\n#> .metric .estimator .eval_time mean n std_err .config \n#> \n#> 1 brier_survival_integrated standard NA 0.0512 1 NA Preproce…\n```\n:::\n\n\n\n\nWhich metric to optimise for depends on whether separation or calibration is more important in the modeling problem at hand. We'll go with calibration here. Since we don't have a particular evaluation time that we want to predict well at, we are going to use the integrated Brier score as our main performance metric.\n\n## Try out more models\n\nLumping factor levels together based on frequencies can lead to a loss of information so let's also try some different approaches. We can let a random forest model group the factor levels via the tree splits. 
Alternatively, we can turn the factors into dummy variables and use a regularized model to select relevant factor levels.\n\nFirst, let’s create the recipes for these two approaches:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nrec_unknown <- recipe(disposition_surv ~ ., data = complaints_train) %>% \n step_unknown(complaint_priority) \n\nrec_dummies <- rec_unknown %>% \n step_novel(all_nominal_predictors()) %>%\n step_dummy(all_nominal_predictors()) %>% \n step_zv(all_predictors()) %>% \n step_normalize(all_numeric_predictors())\n```\n:::\n\n\n\n\nNext, let's create the model specifications and tag several hyperparameters for tuning. \nFor the random forest, we are using the `\"aorsf\"` engine for accelerated oblique random survival forests. An oblique tree can split on linear combinations of the predictors, i.e., it provides more flexibility in the splits than a tree which splits on a single predictor.\nFor the regularized model, we are using the `\"glmnet\"` engine for a semi-parametric Cox proportional hazards model.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\noblique_spec <- rand_forest(mtry = tune(), min_n = tune()) %>% \n set_engine(\"aorsf\") %>% \n set_mode(\"censored regression\")\n\noblique_wflow <- workflow() %>% \n add_recipe(rec_unknown) %>% \n add_model(oblique_spec)\n\ncoxnet_spec <- proportional_hazards(penalty = tune()) %>% \n set_engine(\"glmnet\") %>% \n set_mode(\"censored regression\")\n\ncoxnet_wflow <- workflow() %>% \n add_recipe(rec_dummies) %>% \n add_model(coxnet_spec)\n```\n:::\n\n\n\n\nWe can tune workflows with any of the `tune_*()` functions such as `tune_grid()` for grid search or `tune_bayes()` for Bayesian optimization. Here we are using grid search for simplicity.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nset.seed(1)\noblique_res <- tune_grid(\n oblique_wflow,\n resamples = complaints_rset,\n grid = 10,\n metrics = survival_metrics,\n eval_time = evaluation_time_points, \n control = control_grid(save_workflow = TRUE)\n)\n#> i Creating pre-processing data to finalize unknown parameter: mtry\n\nset.seed(1)\ncoxnet_res <- tune_grid(\n coxnet_wflow,\n resamples = complaints_rset,\n grid = 10,\n metrics = survival_metrics,\n eval_time = evaluation_time_points, \n control = control_grid(save_workflow = TRUE)\n)\n```\n:::\n\n\n\n\nSo do any of these models perform better than the parametric survival model?\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nshow_best(oblique_res, metric = \"brier_survival_integrated\", n = 5)\n#> # A tibble: 5 × 9\n#> mtry min_n .metric .estimator .eval_time mean n std_err .config\n#> \n#> 1 9 27 brier_survival… standard NA 0.0469 1 NA Prepro…\n#> 2 6 23 brier_survival… standard NA 0.0469 1 NA Prepro…\n#> 3 5 6 brier_survival… standard NA 0.0471 1 NA Prepro…\n#> 4 8 10 brier_survival… standard NA 0.0472 1 NA Prepro…\n#> 5 7 40 brier_survival… standard NA 0.0475 1 NA Prepro…\n\nshow_best(coxnet_res, metric = \"brier_survival_integrated\", n = 5)\n#> # A tibble: 5 × 8\n#> penalty .metric .estimator .eval_time mean n std_err .config\n#> \n#> 1 0.00517 brier_surviv… standard NA 0.0499 1 NA Prepro…\n#> 2 0.000000316 brier_surviv… standard NA 0.0506 1 NA Prepro…\n#> 3 0.0000379 brier_surviv… standard NA 0.0506 1 NA Prepro…\n#> 4 0.00000000240 brier_surviv… standard NA 0.0506 1 NA Prepro…\n#> 5 0.0000000277 brier_surviv… standard NA 0.0506 1 NA Prepro…\n```\n:::\n\n::: {.cell layout-align=\"center\"}\n\n:::\n\n\n\n\nThe best regularized Cox model 
performs a little better than the parametric survival model, with an integrated Brier score of 0.0499 compared to 0.0512 for the parametric model. The random forest performs yet a little better with an integrated Brier score of 0.0469.\n\n## The final model\n\nWe chose the random forest model as the final model. So let's finalize the workflow by replacing the `tune()` placeholders with the best hyperparameters.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nparam_best <- select_best(oblique_res, metric = \"brier_survival_integrated\")\n\nlast_oblique_wflow <- finalize_workflow(oblique_wflow, param_best)\n```\n:::\n\n\n\n\nWe can now fit the final model on the training data and evaluate it on the test data.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nset.seed(2)\nlast_oblique_fit <- last_fit(\n last_oblique_wflow, \n split = complaints_split,\n metrics = survival_metrics,\n eval_time = evaluation_time_points, \n)\n\ncollect_metrics(last_oblique_fit) %>% \n filter(.metric == \"brier_survival_integrated\")\n#> # A tibble: 1 × 5\n#> .metric .estimator .estimate .eval_time .config \n#> \n#> 1 brier_survival_integrated standard 0.0431 NA Preprocessor1_Model1\n```\n:::\n\n\n\n\nThe Brier score across the different evaluation time points is also very similar between the validation set and the test set.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nbrier_val <- collect_metrics(oblique_res) %>% \n filter(.metric == \"brier_survival\") %>% \n filter(mtry == param_best$mtry, min_n == param_best$min_n) %>% \n mutate(Data = \"Validation\") \nbrier_test <- collect_metrics(last_oblique_fit) %>% \n filter(.metric == \"brier_survival\") %>% \n mutate(Data = \"Testing\") %>% \n rename(mean = .estimate)\nbind_rows(brier_val, brier_test) %>% \n ggplot(aes(.eval_time, mean, col = Data)) + \n geom_line() + \n labs(x = \"Evaluation Time\", y = \"Brier Score\")\n```\n\n::: {.cell-output-display}\n![](figs/final-fit-brier-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nTo finish, we can extract the fitted workflow to either predict directly on new data or deploy the model.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ncomplaints_model <- extract_workflow(last_oblique_fit)\n\ncomplaints_5 <- testing(complaints_split) %>% slice(1:5)\npredict(complaints_model, new_data = complaints_5, type = \"time\")\n#> # A tibble: 5 × 1\n#> .pred_time\n#> \n#> 1 81.1\n#> 2 47.4\n#> 3 96.4\n#> 4 79.9\n#> 5 77.7\n```\n:::\n\n\n\n\nFor more information on survival analysis with tidymodels see the [`survival analysis` tag](https://www.tidymodels.org/learn/index.html#category=survival%20analysis).\n\n## Session information {#session-info}\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```\n#> ─ Session info ─────────────────────────────────────────────────────\n#> version R version 4.4.2 (2024-10-31)\n#> language (EN)\n#> date 2025-03-24\n#> pandoc 3.6.1\n#> quarto 1.6.42\n#> \n#> ─ Packages ─────────────────────────────────────────────────────────\n#> package version date (UTC) source\n#> aorsf 0.1.5 2024-05-30 CRAN (R 4.4.0)\n#> broom 1.0.7 2024-09-26 CRAN (R 4.4.1)\n#> censored 0.3.3 2025-02-14 CRAN (R 4.4.1)\n#> dials 1.4.0 2025-02-13 CRAN (R 4.4.2)\n#> dplyr 1.1.4 2023-11-17 CRAN (R 4.4.0)\n#> ggplot2 3.5.1 2024-04-23 CRAN (R 4.4.0)\n#> glmnet 4.1-8 2023-08-22 CRAN (R 4.4.0)\n#> infer 1.0.7 2024-03-25 CRAN (R 4.4.0)\n#> modeldatatoo 0.3.0 2024-03-29 CRAN (R 4.4.0)\n#> parsnip 1.3.1 2025-03-12 CRAN (R 4.4.1)\n#> purrr 1.0.4 2025-02-05 
CRAN (R 4.4.1)\n#> recipes 1.2.0 2025-03-17 CRAN (R 4.4.1)\n#> rlang 1.1.5 2025-01-17 CRAN (R 4.4.2)\n#> rsample 1.2.1 2024-03-25 CRAN (R 4.4.0)\n#> tibble 3.2.1 2023-03-20 CRAN (R 4.4.0)\n#> tidymodels 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> tune 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> workflows 1.2.0 2025-02-19 CRAN (R 4.4.1)\n#> yardstick 1.3.2 2025-01-22 CRAN (R 4.4.1)\n#> \n#> ────────────────────────────────────────────────────────────────────\n```\n:::\n", + "markdown": "---\ntitle: \"How long until building complaints are dispositioned? A survival analysis case study\"\ncategories:\n - statistical analysis\n - survival analysis\ntype: learn-subsection\nweight: 9\ndescription: | \n Learn how to use tidymodels for survival analysis.\ntoc: true\ntoc-depth: 2\ninclude-after-body: ../../../resources.html\n---\n\n\n\n\n\n\n\n\n## Introduction\n\nTo use code in this article, you will need to install the following packages: aorsf, censored, glmnet, modeldatatoo, and tidymodels.\n\nSurvival analysis is a field of statistics and machine learning for analyzing the time to an event. While it has its roots in medical research, the event of interest can be anything from customer churn to machine failure. Methods from survival analysis take into account that some observations may not yet have experienced the event of interest and are thus _censored_. \n\nHere we want to predict the time it takes for a complaint to be dispositioned^[In this context, the term _disposition_ means that there has been a decision or resolution regarding the complaint that is the conclusion of the process.] by the Department of Buildings in New York City. We are going to walk through a complete analysis from beginning to end, showing how to analyze time-to-event data.\n\nLet's start with loading the tidymodels and censored packages (the parsnip extension package for survival analysis models).\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(tidymodels)\nlibrary(censored)\n#> Loading required package: survival\n```\n:::\n\n\n\n\n## The buildings complaints data\n\nThe city of New York publishes data on the [complaints](https://data.cityofnewyork.us/Housing-Development/DOB-Complaints-Received/eabe-havv/about_data) received by the Department of Buildings. The data includes information on the type of complaint, the date it was entered in their records, the date it was dispositioned, and the location of the building the complaint was about. 
We are using a subset of the data, available in the modeldatatoo package.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nbuilding_complaints <- modeldatatoo::data_building_complaints()\nglimpse(building_complaints)\n#> Rows: 4,234\n#> Columns: 11\n#> $ days_to_disposition 72, 1, 41, 45, 16, 62, 56, 11, 35, 38, 39, 106, 1,…\n#> $ status \"ACTIVE\", \"ACTIVE\", \"ACTIVE\", \"ACTIVE\", \"ACTIVE\", …\n#> $ year_entered 2023, 2023, 2023, 2023, 2023, 2023, 2023, 2023, 20…\n#> $ latitude 40.66173, 40.57668, 40.73242, 40.68245, 40.63156, …\n#> $ longitude -73.98297, -74.00453, -73.87630, -73.79367, -73.99…\n#> $ borough Brooklyn, Brooklyn, Queens, Queens, Brooklyn, Quee…\n#> $ special_district None, None, None, None, None, None, None, None, No…\n#> $ unit Q-L, Q-L, SPOPS, Q-L, BKLYN, Q-L, Q-L, SPOPS, Q-L,…\n#> $ community_board 307, 313, 404, 412, 312, 406, 306, 306, 409, 404, …\n#> $ complaint_category 45, 45, 49, 45, 31, 45, 45, 49, 45, 45, 45, 4A, 31…\n#> $ complaint_priority B, B, C, B, C, B, B, C, B, B, B, B, C, C, B, B, B,…\n```\n:::\n\n\n\n\nBefore we dive into survival analysis, let's get a impression of how the complaints are distributed across the city. We have complaints in all five boroughs, albeit with a somewhat lower density of complaints in Staten Island.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n::: {.cell-output-display}\n\n```{=html}\n
\n\n```\n\n\nBuilding complaints in New York City (closed complaints in purple, active complaints in pink).\n:::\n:::\n\n\n\n\nIn the dataset, we can see the `days_to_disposition` as well as the `status` of the complaint. For a complaint with the status `\"ACTIVE\"`, the time to disposition is censored, meaning we do know that it has taken at least that long, but not how long for it to be completely resolved. \n\nThe standard form for time-to-event data are `Surv` objects which capture the time as well as the event status. As with all transformations of the response, it is advisable to do this before heading into the model fitting process with tidymodels.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nbuilding_complaints <- building_complaints %>% \n mutate(\n disposition_surv = Surv(days_to_disposition, status == \"CLOSED\"), \n .keep = \"unused\"\n )\n```\n:::\n\n\n\n\n## Data splitting and resampling\n\nFor our resampling strategy, let's use a [3-way split](https://www.tmwr.org/resampling#validation) into training, validation, and test set.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nset.seed(403)\ncomplaints_split <- initial_validation_split(building_complaints)\n```\n:::\n\n\n\n\nFirst, let's pull out the training data and have a brief look at the response using a [Kaplan-Meier curve](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3059453/). \n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ncomplaints_train <- training(complaints_split)\n\nsurvfit(disposition_surv ~ 1, data = complaints_train) %>% plot()\n```\n\n::: {.cell-output-display}\n![](figs/unnamed-chunk-7-1.svg){fig-align='center' fig-alt='A Kaplan-Meier curve dropping rapidly initially, then reaching about 10% survival rate at around 100 days, and finally trailing off until about 400 days.' width=672}\n:::\n:::\n\n\n\n\nWe can see that the majority of complaints is dispositioned relatively quickly, but some complaints are still active after 100 days.\n\n## A first model\n\nThe censored package includes parametric, semi-parametric, and tree-based models for this type of analysis. To start, we are fitting a parametric survival model with the default of assuming a Weibull distribution on the time to disposition. We'll explore the more flexible models once we have a sense of how well this more restrictive model performs on this dataset.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nsurvreg_spec <- survival_reg() %>% \n set_engine(\"survival\") %>% \n set_mode(\"censored regression\")\n```\n:::\n\n\n\n\nWe have several missing values in `complaint_priority` that we are turning into a separate category, `\"unknown\"`. We are also combining the less common categories for `community_board` and `unit` into an `\"other\"` category to reduce the number of levels in the predictors. The complaint category often does not tell us much more than the unit, with several complaint categories being handled by a specific unit only. This can lead to the model being unable to estimate some of the coefficients. 
Since our goal here is only to get a rough idea of how well the model performs, we are removing the complaint category for now.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nrec_other <- recipe(disposition_surv ~ ., data = complaints_train) %>% \n step_unknown(complaint_priority) %>% \n step_rm(complaint_category) %>% \n step_novel(community_board, unit) %>%\n step_other(community_board, unit, threshold = 0.02)\n```\n:::\n\n\n\n\nWe combine the recipe and the model into a workflow. This allows us to easily resample the model because all preprocessing steps are applied to the training set and the validation set for us.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nsurvreg_wflow <- workflow() %>% \n add_recipe(rec_other) %>% \n add_model(survreg_spec)\n```\n:::\n\n\n\n\nTo fit and evaluate the model, we need the training and validation sets. While we can access them each on their own, `validation_set()` extracts them both, in a manner that emulates a single resample of the data. This enables us to use `fit_resamples()` and other tuning functions in the same way as if we had used some other resampling scheme (such as cross-validation). \n\nWe are calculating several performance metrics: the Brier score, its integrated version, the area under the ROC curve, and the concordance index. Note that all of these are used in a version tailored to survival analysis. The concordance index uses the predicted event time to measure the model’s ability to rank the observations correctly. The Brier score and the ROC curve use the predicted probability of survival at a given time. We evaluate these metrics every 30 days up to 300 days, as provided in the `eval_time` argument. The Brier score is a measure of the accuracy of the predicted probabilities, while the ROC curve is a measure of the model’s ability to discriminate between events and non-events at the given time point. Because these metrics are defined “at a given time,” they are also referred to as *dynamic metrics*.\n\n::: {.callout-tip}\nFor more information see the [Dynamic Performance Metrics for Event Time Data](../survival-metrics/) article.\n:::\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ncomplaints_rset <- validation_set(complaints_split)\n\nsurvival_metrics <- metric_set(brier_survival_integrated, brier_survival,\n roc_auc_survival, concordance_survival)\nevaluation_time_points <- seq(0, 300, 30)\n\nset.seed(1)\nsurvreg_res <- fit_resamples(\n survreg_wflow,\n resamples = complaints_rset,\n metrics = survival_metrics,\n eval_time = evaluation_time_points, \n control = control_resamples(save_pred = TRUE)\n)\n```\n:::\n\n\n\n\nThe structure of survival model predictions is slightly different from classification and regression model predictions:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\npreds <- collect_predictions(survreg_res)\npreds\n#> # A tibble: 847 × 6\n#> .pred .pred_time id .row disposition_surv .config \n#> \n#> 1 96.6 validation 2541 35+ Preprocessor1…\n#> 2 18.7 validation 2542 129+ Preprocessor1…\n#> 3 29.5 validation 2543 4+ Preprocessor1…\n#> 4 29.8 validation 2544 5+ Preprocessor1…\n#> 5 24.8 validation 2545 1+ Preprocessor1…\n#> 6 58.4 validation 2546 76+ Preprocessor1…\n#> 7 71.3 validation 2547 51+ Preprocessor1…\n#> 8 102. 
validation 2548 44+ Preprocessor1…\n#> 9 47.1 validation 2549 15+ Preprocessor1…\n#> 10 28.5 validation 2550 61+ Preprocessor1…\n#> # ℹ 837 more rows\n```\n:::\n\n\n\n\nThe predicted survival time is in the `.pred_time` column and the predicted survival probabilities are in the `.pred` list column. \n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\npreds$.pred[[6]]\n#> # A tibble: 11 × 3\n#> .eval_time .pred_survival .weight_censored\n#> \n#> 1 0 1 1 \n#> 2 30 0.554 1.04\n#> 3 60 0.360 1.19\n#> 4 90 0.245 NA \n#> 5 120 0.171 NA \n#> 6 150 0.121 NA \n#> 7 180 0.0874 NA \n#> 8 210 0.0637 NA \n#> 9 240 0.0468 NA \n#> 10 270 0.0347 NA \n#> 11 300 0.0259 NA\n```\n:::\n\n\n\n\nFor each observation, `.pred` contains a tibble with the evaluation time `.eval_time` and the corresponding survival probability `.pred_survival`. The column `.weight_censored` contains the weights used in the calculation of the dynamic performance metrics. \n\n::: {.callout-tip}\nFor details on the weights see the [Accounting for Censoring in Performance Metrics for Event Time Data](../survival-metrics-details/) article.\n:::\n\nOf the metrics we calculated with these predictions, let's take a look at the AUC ROC first.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ncollect_metrics(survreg_res) %>% \n filter(.metric == \"roc_auc_survival\") %>% \n ggplot(aes(.eval_time, mean)) + \n geom_line() + \n labs(x = \"Evaluation Time\", y = \"Area Under the ROC Curve\")\n```\n\n::: {.cell-output-display}\n![](figs/survreg-roc-auc-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nWe can discriminate between events and non-events reasonably well, especially in the first 30 and 60 days. How about the probabilities that the categorization into event and non-event is based on? \n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ncollect_metrics(survreg_res) %>% \n filter(.metric == \"brier_survival\") %>% \n ggplot(aes(.eval_time, mean)) + \n geom_line() + \n labs(x = \"Evaluation Time\", y = \"Brier Score\")\n```\n\n::: {.cell-output-display}\n![](figs/survreg-brier-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nThe accuracy of the predicted probabilities is generally good, albeit lowest for evaluation times of 30 and 60 days. The integrated Brier score is a measure of the overall accuracy of the predicted probabilities. \n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ncollect_metrics(survreg_res) %>% \n filter(.metric == \"brier_survival_integrated\")\n#> # A tibble: 1 × 7\n#> .metric .estimator .eval_time mean n std_err .config \n#> \n#> 1 brier_survival_integrated standard NA 0.0512 1 NA Preproce…\n```\n:::\n\n\n\n\nWhich metric to optimise for depends on whether separation or calibration is more important in the modeling problem at hand. We'll go with calibration here. Since we don't have a particular evaluation time that we want to predict well at, we are going to use the integrated Brier score as our main performance metric.\n\n## Try out more models\n\nLumping factor levels together based on frequencies can lead to a loss of information so let's also try some different approaches. We can let a random forest model group the factor levels via the tree splits. 
Alternatively, we can turn the factors into dummy variables and use a regularized model to select relevant factor levels.\n\nFirst, let’s create the recipes for these two approaches:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nrec_unknown <- recipe(disposition_surv ~ ., data = complaints_train) %>% \n step_unknown(complaint_priority) \n\nrec_dummies <- rec_unknown %>% \n step_novel(all_nominal_predictors()) %>%\n step_dummy(all_nominal_predictors()) %>% \n step_zv(all_predictors()) %>% \n step_normalize(all_numeric_predictors())\n```\n:::\n\n\n\n\nNext, let's create the model specifications and tag several hyperparameters for tuning. \nFor the random forest, we are using the `\"aorsf\"` engine for accelerated oblique random survival forests. An oblique tree can split on linear combinations of the predictors, i.e., it provides more flexibility in the splits than a tree which splits on a single predictor.\nFor the regularized model, we are using the `\"glmnet\"` engine for a semi-parametric Cox proportional hazards model.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\noblique_spec <- rand_forest(mtry = tune(), min_n = tune()) %>% \n set_engine(\"aorsf\") %>% \n set_mode(\"censored regression\")\n\noblique_wflow <- workflow() %>% \n add_recipe(rec_unknown) %>% \n add_model(oblique_spec)\n\ncoxnet_spec <- proportional_hazards(penalty = tune()) %>% \n set_engine(\"glmnet\") %>% \n set_mode(\"censored regression\")\n\ncoxnet_wflow <- workflow() %>% \n add_recipe(rec_dummies) %>% \n add_model(coxnet_spec)\n```\n:::\n\n\n\n\nWe can tune workflows with any of the `tune_*()` functions such as `tune_grid()` for grid search or `tune_bayes()` for Bayesian optimization. Here we are using grid search for simplicity.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nset.seed(1)\noblique_res <- tune_grid(\n oblique_wflow,\n resamples = complaints_rset,\n grid = 10,\n metrics = survival_metrics,\n eval_time = evaluation_time_points, \n control = control_grid(save_workflow = TRUE)\n)\n#> i Creating pre-processing data to finalize unknown parameter: mtry\n\nset.seed(1)\ncoxnet_res <- tune_grid(\n coxnet_wflow,\n resamples = complaints_rset,\n grid = 10,\n metrics = survival_metrics,\n eval_time = evaluation_time_points, \n control = control_grid(save_workflow = TRUE)\n)\n```\n:::\n\n\n\n\nSo do any of these models perform better than the parametric survival model?\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nshow_best(oblique_res, metric = \"brier_survival_integrated\", n = 5)\n#> # A tibble: 5 × 9\n#> mtry min_n .metric .estimator .eval_time mean n std_err .config\n#> \n#> 1 9 27 brier_survival… standard NA 0.0469 1 NA Prepro…\n#> 2 6 23 brier_survival… standard NA 0.0469 1 NA Prepro…\n#> 3 5 6 brier_survival… standard NA 0.0471 1 NA Prepro…\n#> 4 8 10 brier_survival… standard NA 0.0472 1 NA Prepro…\n#> 5 7 40 brier_survival… standard NA 0.0475 1 NA Prepro…\n\nshow_best(coxnet_res, metric = \"brier_survival_integrated\", n = 5)\n#> # A tibble: 5 × 8\n#> penalty .metric .estimator .eval_time mean n std_err .config\n#> \n#> 1 0.00517 brier_surviv… standard NA 0.0499 1 NA Prepro…\n#> 2 0.000000316 brier_surviv… standard NA 0.0506 1 NA Prepro…\n#> 3 0.0000379 brier_surviv… standard NA 0.0506 1 NA Prepro…\n#> 4 0.00000000240 brier_surviv… standard NA 0.0506 1 NA Prepro…\n#> 5 0.0000000277 brier_surviv… standard NA 0.0506 1 NA Prepro…\n```\n:::\n\n::: {.cell layout-align=\"center\"}\n\n:::\n\n\n\n\nThe best regularized Cox model 
performs a little better than the parametric survival model, with an integrated Brier score of 0.0499 compared to 0.0512 for the parametric model. The random forest performs yet a little better with an integrated Brier score of 0.0469.\n\n## The final model\n\nWe chose the random forest model as the final model. So let's finalize the workflow by replacing the `tune()` placeholders with the best hyperparameters.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nparam_best <- select_best(oblique_res, metric = \"brier_survival_integrated\")\n\nlast_oblique_wflow <- finalize_workflow(oblique_wflow, param_best)\n```\n:::\n\n\n\n\nWe can now fit the final model on the training data and evaluate it on the test data.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nset.seed(2)\nlast_oblique_fit <- last_fit(\n last_oblique_wflow, \n split = complaints_split,\n metrics = survival_metrics,\n eval_time = evaluation_time_points, \n)\n\ncollect_metrics(last_oblique_fit) %>% \n filter(.metric == \"brier_survival_integrated\")\n#> # A tibble: 1 × 5\n#> .metric .estimator .estimate .eval_time .config \n#> \n#> 1 brier_survival_integrated standard 0.0431 NA Preprocessor1_Model1\n```\n:::\n\n\n\n\nThe Brier score across the different evaluation time points is also very similar between the validation set and the test set.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nbrier_val <- collect_metrics(oblique_res) %>% \n filter(.metric == \"brier_survival\") %>% \n filter(mtry == param_best$mtry, min_n == param_best$min_n) %>% \n mutate(Data = \"Validation\") \nbrier_test <- collect_metrics(last_oblique_fit) %>% \n filter(.metric == \"brier_survival\") %>% \n mutate(Data = \"Testing\") %>% \n rename(mean = .estimate)\nbind_rows(brier_val, brier_test) %>% \n ggplot(aes(.eval_time, mean, col = Data)) + \n geom_line() + \n labs(x = \"Evaluation Time\", y = \"Brier Score\")\n```\n\n::: {.cell-output-display}\n![](figs/final-fit-brier-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nTo finish, we can extract the fitted workflow to either predict directly on new data or deploy the model.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ncomplaints_model <- extract_workflow(last_oblique_fit)\n\ncomplaints_5 <- testing(complaints_split) %>% slice(1:5)\npredict(complaints_model, new_data = complaints_5, type = \"time\")\n#> # A tibble: 5 × 1\n#> .pred_time\n#> \n#> 1 81.1\n#> 2 47.4\n#> 3 96.4\n#> 4 79.9\n#> 5 77.7\n```\n:::\n\n\n\n\nFor more information on survival analysis with tidymodels see the [`survival analysis` tag](https://www.tidymodels.org/learn/index.html#category=survival%20analysis).\n\n## Session information {#session-info}\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```\n#> ─ Session info ─────────────────────────────────────────────────────\n#> version R version 4.4.2 (2024-10-31)\n#> language (EN)\n#> date 2025-03-24\n#> pandoc 3.6.1\n#> quarto 1.6.42\n#> \n#> ─ Packages ─────────────────────────────────────────────────────────\n#> package version date (UTC) source\n#> aorsf 0.1.5 2024-05-30 CRAN (R 4.4.0)\n#> broom 1.0.7 2024-09-26 CRAN (R 4.4.1)\n#> censored 0.3.3 2025-02-14 CRAN (R 4.4.1)\n#> dials 1.4.0 2025-02-13 CRAN (R 4.4.2)\n#> dplyr 1.1.4 2023-11-17 CRAN (R 4.4.0)\n#> ggplot2 3.5.1 2024-04-23 CRAN (R 4.4.0)\n#> glmnet 4.1-8 2023-08-22 CRAN (R 4.4.0)\n#> infer 1.0.7 2024-03-25 CRAN (R 4.4.0)\n#> modeldatatoo 0.3.0 2024-03-29 CRAN (R 4.4.0)\n#> parsnip 1.3.1 2025-03-12 CRAN (R 4.4.1)\n#> purrr 1.0.4 2025-02-05 
CRAN (R 4.4.1)\n#> recipes 1.2.0 2025-03-17 CRAN (R 4.4.1)\n#> rlang 1.1.5 2025-01-17 CRAN (R 4.4.2)\n#> rsample 1.2.1 2024-03-25 CRAN (R 4.4.0)\n#> tibble 3.2.1 2023-03-20 CRAN (R 4.4.0)\n#> tidymodels 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> tune 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> workflows 1.2.0 2025-02-19 CRAN (R 4.4.1)\n#> yardstick 1.3.2 2025-01-22 CRAN (R 4.4.1)\n#> \n#> ────────────────────────────────────────────────────────────────────\n```\n:::\n", "supporting": [], "filters": [ "rmarkdown/pagebreak.lua" diff --git a/learn/statistics/survival-case-study/index.html.md b/learn/statistics/survival-case-study/index.html.md index c0c30a2e..94051ee5 100644 --- a/learn/statistics/survival-case-study/index.html.md +++ b/learn/statistics/survival-case-study/index.html.md @@ -62,8 +62,8 @@ Before we dive into survival analysis, let's get a impression of how the complai ::: {.cell-output-display} ```{=html} -
- +
+ ``` Building complaints in New York City (closed complaints in purple, active complaints in pink). diff --git a/learn/statistics/survival-case-study/index.qmd b/learn/statistics/survival-case-study/index.qmd index be436a1e..ee44b2ef 100644 --- a/learn/statistics/survival-case-study/index.qmd +++ b/learn/statistics/survival-case-study/index.qmd @@ -26,6 +26,8 @@ source(here::here("common.R")) library(tidymodels) library(sessioninfo) library(leaflet) +library(htmlwidgets) +setWidgetIdSeed(1234) pkgs <- c("tidymodels", "censored", "modeldatatoo", "glmnet", "aorsf") theme_set(theme_bw() + theme(legend.position = "top")) ``` From e385345f8e11d60c894c7bb8fb827ca469a6b9d8 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 24 Mar 2025 12:22:10 -0700 Subject: [PATCH 2/7] set seed in learn/models/sub-sampling --- .../index/execute-results/html.json | 4 +- .../sub-sampling/figs/merge-metrics-1.svg | 259 +++++++++--------- learn/models/sub-sampling/index.html.md | 7 +- learn/models/sub-sampling/index.qmd | 3 +- 4 files changed, 136 insertions(+), 137 deletions(-) diff --git a/_freeze/learn/models/sub-sampling/index/execute-results/html.json b/_freeze/learn/models/sub-sampling/index/execute-results/html.json index 6682b681..cbfa10e6 100644 --- a/_freeze/learn/models/sub-sampling/index/execute-results/html.json +++ b/_freeze/learn/models/sub-sampling/index/execute-results/html.json @@ -1,8 +1,8 @@ { - "hash": "7700bbf97d82157e85703159af1aab39", + "hash": "47f50af03cb12d3ec299b80c5d93b121", "result": { "engine": "knitr", - "markdown": "---\ntitle: \"Subsampling for class imbalances\"\ncategories:\n - model fitting\n - pre-processing\ntype: learn-subsection\nweight: 3\ndescription: | \n Improve model performance in imbalanced data sets through undersampling or oversampling.\ntoc: true\ntoc-depth: 2\ninclude-after-body: ../../../resources.html\n---\n\n\n\n\n\n\n\n\n\n## Introduction\n\nTo use code in this article, you will need to install the following packages: discrim, klaR, readr, ROSE, themis, and tidymodels.\n\nSubsampling a training set, either undersampling or oversampling the appropriate class or classes, can be a helpful approach to dealing with classification data where one or more classes occur very infrequently. In such a situation (without compensating for it), most models will overfit to the majority class and produce very good statistics for the class containing the frequently occurring classes while the minority classes have poor performance. \n\nThis article describes subsampling for dealing with class imbalances. For better understanding, some knowledge of classification metrics like sensitivity, specificity, and receiver operating characteristic curves is required. See Section 3.2.2 in [Kuhn and Johnson (2019)](https://bookdown.org/max/FES/measuring-performance.html) for more information on these metrics. \n\n## Simulated data\n\nConsider a two-class problem where the first class has a very low rate of occurrence. 
The data were simulated and can be imported into R using the code below:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nimbal_data <- \n \n readr::read_csv(\"https://tidymodels.org/learn/models/sub-sampling/imbal_data.csv\") %>% \n mutate(Class = factor(Class))\ndim(imbal_data)\n#> [1] 1200 16\ntable(imbal_data$Class)\n#> \n#> Class1 Class2 \n#> 60 1140\n```\n:::\n\n\n\n\nIf \"Class1\" is the event of interest, it is very likely that a classification model would be able to achieve very good _specificity_ since almost all of the data are of the second class. _Sensitivity_, however, would likely be poor since the models will optimize accuracy (or other loss functions) by predicting everything to be the majority class. \n\nOne result of class imbalance when there are two classes is that the default probability cutoff of 50% is inappropriate; a different cutoff that is more extreme might be able to achieve good performance. \n\n## Subsampling the data\n\nOne way to alleviate this issue is to _subsample_ the data. There are a number of ways to do this but the most simple one is to _sample down_ (undersample) the majority class data until it occurs with the same frequency as the minority class. While it may seem counterintuitive, throwing out a large percentage of your data can be effective at producing a useful model that can recognize both the majority and minority classes. In some cases, this even means that the overall performance of the model is better (e.g. improved area under the ROC curve). However, subsampling almost always produces models that are _better calibrated_, meaning that the distributions of the class probabilities are more well behaved. As a result, the default 50% cutoff is much more likely to produce better sensitivity and specificity values than they would otherwise. \n\nLet's explore subsampling using `themis::step_rose()` in a recipe for the simulated data. It uses the ROSE (random over sampling examples) method from [Menardi, G. and Torelli, N. (2014)](https://scholar.google.com/scholar?hl=en&q=%22training+and+assessing+classification+rules+with+imbalanced+data%22). This is an example of an oversampling strategy, rather than undersampling.\n\nIn terms of workflow:\n\n * It is extremely important that subsampling occurs _inside of resampling_. Otherwise, the resampling process can produce [poor estimates of model performance](https://topepo.github.io/caret/subsampling-for-class-imbalances.html#resampling). \n * The subsampling process should only be applied to the analysis set. The assessment set should reflect the event rates seen \"in the wild\" and, for this reason, the `skip` argument to `step_downsample()` and other subsampling recipes steps has a default of `TRUE`. \n\nHere is a simple recipe implementing oversampling: \n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(tidymodels)\nlibrary(themis)\nimbal_rec <- \n recipe(Class ~ ., data = imbal_data) %>%\n step_rose(Class)\n```\n:::\n\n\n\n\nFor a model, let's use a [quadratic discriminant analysis](https://en.wikipedia.org/wiki/Quadratic_classifier#Quadratic_discriminant_analysis) (QDA) model. 
From the discrim package, this model can be specified using:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(discrim)\nqda_mod <- \n discrim_regularized(frac_common_cov = 0, frac_identity = 0) %>% \n set_engine(\"klaR\")\n```\n:::\n\n\n\n\nTo keep these objects bound together, they can be combined in a [workflow](https://workflows.tidymodels.org/):\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nqda_rose_wflw <- \n workflow() %>% \n add_model(qda_mod) %>% \n add_recipe(imbal_rec)\nqda_rose_wflw\n#> ══ Workflow ══════════════════════════════════════════════════════════\n#> Preprocessor: Recipe\n#> Model: discrim_regularized()\n#> \n#> ── Preprocessor ──────────────────────────────────────────────────────\n#> 1 Recipe Step\n#> \n#> • step_rose()\n#> \n#> ── Model ─────────────────────────────────────────────────────────────\n#> Regularized Discriminant Model Specification (classification)\n#> \n#> Main Arguments:\n#> frac_common_cov = 0\n#> frac_identity = 0\n#> \n#> Computational engine: klaR\n```\n:::\n\n\n\n\n## Model performance\n\nStratified, repeated 10-fold cross-validation is used to resample the model:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nset.seed(5732)\ncv_folds <- vfold_cv(imbal_data, strata = \"Class\", repeats = 5)\n```\n:::\n\n\n\n\nTo measure model performance, let's use two metrics:\n\n * The area under the [ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) is an overall assessment of performance across _all_ cutoffs. Values near one indicate very good results while values near 0.5 would imply that the model is very poor. \n * The _J_ index (a.k.a. [Youden's _J_](https://en.wikipedia.org/wiki/Youden%27s_J_statistic) statistic) is `sensitivity + specificity - 1`. Values near one are once again best. \n\nIf a model is poorly calibrated, the ROC curve value might not show diminished performance. However, the _J_ index would be lower for models with pathological distributions for the class probabilities. The yardstick package will be used to compute these metrics. \n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ncls_metrics <- metric_set(roc_auc, j_index)\n```\n:::\n\n\n\n\nNow, we train the models and generate the results using `tune::fit_resamples()`:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nset.seed(2180)\nqda_rose_res <- fit_resamples(\n qda_rose_wflw, \n resamples = cv_folds, \n metrics = cls_metrics\n)\n\ncollect_metrics(qda_rose_res)\n#> # A tibble: 2 × 6\n#> .metric .estimator mean n std_err .config \n#> \n#> 1 j_index binary 0.804 50 0.0178 Preprocessor1_Model1\n#> 2 roc_auc binary 0.953 50 0.00459 Preprocessor1_Model1\n```\n:::\n\n\n\n\nWhat do the results look like without using ROSE? We can create another workflow and fit the QDA model along the same resamples:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nqda_wflw <- \n workflow() %>% \n add_model(qda_mod) %>% \n add_formula(Class ~ .)\n\nset.seed(2180)\nqda_only_res <- fit_resamples(qda_wflw, resamples = cv_folds, metrics = cls_metrics)\ncollect_metrics(qda_only_res)\n#> # A tibble: 2 × 6\n#> .metric .estimator mean n std_err .config \n#> \n#> 1 j_index binary 0.250 50 0.0288 Preprocessor1_Model1\n#> 2 roc_auc binary 0.953 50 0.00479 Preprocessor1_Model1\n```\n:::\n\n\n\n\nIt looks like ROSE helped a lot, especially with the J-index. 
Class imbalance sampling methods tend to greatly improve metrics based on the hard class predictions (i.e., the categorical predictions) because the default cutoff tends to be a better balance of sensitivity and specificity. \n\nLet's plot the metrics for each resample to see how the individual results changed. \n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nno_sampling <- \n qda_only_res %>% \n collect_metrics(summarize = FALSE) %>% \n dplyr::select(-.estimator) %>% \n mutate(sampling = \"no_sampling\")\n\nwith_sampling <- \n qda_rose_res %>% \n collect_metrics(summarize = FALSE) %>% \n dplyr::select(-.estimator) %>% \n mutate(sampling = \"rose\")\n\nbind_rows(no_sampling, with_sampling) %>% \n mutate(label = paste(id2, id)) %>% \n ggplot(aes(x = sampling, y = .estimate, group = label)) + \n geom_line(alpha = .4) + \n facet_wrap(~ .metric, scales = \"free_y\")\n```\n\n::: {.cell-output-display}\n![](figs/merge-metrics-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nThis visually demonstrates that the subsampling mostly affects metrics that use the hard class predictions. \n\n## Session information {#session-info}\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```\n#> ─ Session info ─────────────────────────────────────────────────────\n#> version R version 4.4.2 (2024-10-31)\n#> language (EN)\n#> date 2025-03-24\n#> pandoc 3.6.1\n#> quarto 1.6.42\n#> \n#> ─ Packages ─────────────────────────────────────────────────────────\n#> package version date (UTC) source\n#> broom 1.0.7 2024-09-26 CRAN (R 4.4.1)\n#> dials 1.4.0 2025-02-13 CRAN (R 4.4.2)\n#> discrim 1.0.1 2023-03-08 CRAN (R 4.4.0)\n#> dplyr 1.1.4 2023-11-17 CRAN (R 4.4.0)\n#> ggplot2 3.5.1 2024-04-23 CRAN (R 4.4.0)\n#> infer 1.0.7 2024-03-25 CRAN (R 4.4.0)\n#> klaR 1.7-3 2023-12-13 CRAN (R 4.4.0)\n#> parsnip 1.3.1 2025-03-12 CRAN (R 4.4.1)\n#> purrr 1.0.4 2025-02-05 CRAN (R 4.4.1)\n#> readr 2.1.5 2024-01-10 CRAN (R 4.4.0)\n#> recipes 1.2.0 2025-03-17 CRAN (R 4.4.1)\n#> rlang 1.1.5 2025-01-17 CRAN (R 4.4.2)\n#> ROSE 0.0-4 2021-06-14 CRAN (R 4.4.0)\n#> rsample 1.2.1 2024-03-25 CRAN (R 4.4.0)\n#> themis 1.0.3 2025-01-23 CRAN (R 4.4.1)\n#> tibble 3.2.1 2023-03-20 CRAN (R 4.4.0)\n#> tidymodels 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> tune 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> workflows 1.2.0 2025-02-19 CRAN (R 4.4.1)\n#> yardstick 1.3.2 2025-01-22 CRAN (R 4.4.1)\n#> \n#> ────────────────────────────────────────────────────────────────────\n```\n:::\n", + "markdown": "---\ntitle: \"Subsampling for class imbalances\"\ncategories:\n - model fitting\n - pre-processing\ntype: learn-subsection\nweight: 3\ndescription: | \n Improve model performance in imbalanced data sets through undersampling or oversampling.\ntoc: true\ntoc-depth: 2\ninclude-after-body: ../../../resources.html\n---\n\n\n\n\n\n\n\n\n\n\n## Introduction\n\nTo use code in this article, you will need to install the following packages: discrim, klaR, readr, ROSE, themis, and tidymodels.\n\nSubsampling a training set, either undersampling or oversampling the appropriate class or classes, can be a helpful approach to dealing with classification data where one or more classes occur very infrequently. In such a situation (without compensating for it), most models will overfit to the majority class and produce very good statistics for the class containing the frequently occurring classes while the minority classes have poor performance. \n\nThis article describes subsampling for dealing with class imbalances. 
For better understanding, some knowledge of classification metrics like sensitivity, specificity, and receiver operating characteristic curves is required. See Section 3.2.2 in [Kuhn and Johnson (2019)](https://bookdown.org/max/FES/measuring-performance.html) for more information on these metrics. \n\n## Simulated data\n\nConsider a two-class problem where the first class has a very low rate of occurrence. The data were simulated and can be imported into R using the code below:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nimbal_data <- \n readr::read_csv(\"https://tidymodels.org/learn/models/sub-sampling/imbal_data.csv\") %>% \n mutate(Class = factor(Class))\ndim(imbal_data)\n#> [1] 1200 16\ntable(imbal_data$Class)\n#> \n#> Class1 Class2 \n#> 60 1140\n```\n:::\n\n\n\n\n\nIf \"Class1\" is the event of interest, it is very likely that a classification model would be able to achieve very good _specificity_ since almost all of the data are of the second class. _Sensitivity_, however, would likely be poor since the models will optimize accuracy (or other loss functions) by predicting everything to be the majority class. \n\nOne result of class imbalance when there are two classes is that the default probability cutoff of 50% is inappropriate; a different cutoff that is more extreme might be able to achieve good performance. \n\n## Subsampling the data\n\nOne way to alleviate this issue is to _subsample_ the data. There are a number of ways to do this but the most simple one is to _sample down_ (undersample) the majority class data until it occurs with the same frequency as the minority class. While it may seem counterintuitive, throwing out a large percentage of your data can be effective at producing a useful model that can recognize both the majority and minority classes. In some cases, this even means that the overall performance of the model is better (e.g. improved area under the ROC curve). However, subsampling almost always produces models that are _better calibrated_, meaning that the distributions of the class probabilities are more well behaved. As a result, the default 50% cutoff is much more likely to produce better sensitivity and specificity values than they would otherwise. \n\nLet's explore subsampling using `themis::step_rose()` in a recipe for the simulated data. It uses the ROSE (random over sampling examples) method from [Menardi, G. and Torelli, N. (2014)](https://scholar.google.com/scholar?hl=en&q=%22training+and+assessing+classification+rules+with+imbalanced+data%22). This is an example of an oversampling strategy, rather than undersampling.\n\nIn terms of workflow:\n\n * It is extremely important that subsampling occurs _inside of resampling_. Otherwise, the resampling process can produce [poor estimates of model performance](https://topepo.github.io/caret/subsampling-for-class-imbalances.html#resampling). \n * The subsampling process should only be applied to the analysis set. The assessment set should reflect the event rates seen \"in the wild\" and, for this reason, the `skip` argument to `step_downsample()` and other subsampling recipes steps has a default of `TRUE`. 
\n\nHere is a simple recipe implementing oversampling: \n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(tidymodels)\nlibrary(themis)\nset.seed(1234)\n\nimbal_rec <- \n recipe(Class ~ ., data = imbal_data) %>%\n step_rose(Class)\n```\n:::\n\n\n\n\n\nFor a model, let's use a [quadratic discriminant analysis](https://en.wikipedia.org/wiki/Quadratic_classifier#Quadratic_discriminant_analysis) (QDA) model. From the discrim package, this model can be specified using:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(discrim)\nqda_mod <- \n discrim_regularized(frac_common_cov = 0, frac_identity = 0) %>% \n set_engine(\"klaR\")\n```\n:::\n\n\n\n\n\nTo keep these objects bound together, they can be combined in a [workflow](https://workflows.tidymodels.org/):\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nqda_rose_wflw <- \n workflow() %>% \n add_model(qda_mod) %>% \n add_recipe(imbal_rec)\nqda_rose_wflw\n#> ══ Workflow ══════════════════════════════════════════════════════════\n#> Preprocessor: Recipe\n#> Model: discrim_regularized()\n#> \n#> ── Preprocessor ──────────────────────────────────────────────────────\n#> 1 Recipe Step\n#> \n#> • step_rose()\n#> \n#> ── Model ─────────────────────────────────────────────────────────────\n#> Regularized Discriminant Model Specification (classification)\n#> \n#> Main Arguments:\n#> frac_common_cov = 0\n#> frac_identity = 0\n#> \n#> Computational engine: klaR\n```\n:::\n\n\n\n\n\n## Model performance\n\nStratified, repeated 10-fold cross-validation is used to resample the model:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nset.seed(5732)\ncv_folds <- vfold_cv(imbal_data, strata = \"Class\", repeats = 5)\n```\n:::\n\n\n\n\n\nTo measure model performance, let's use two metrics:\n\n * The area under the [ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) is an overall assessment of performance across _all_ cutoffs. Values near one indicate very good results while values near 0.5 would imply that the model is very poor. \n * The _J_ index (a.k.a. [Youden's _J_](https://en.wikipedia.org/wiki/Youden%27s_J_statistic) statistic) is `sensitivity + specificity - 1`. Values near one are once again best. \n\nIf a model is poorly calibrated, the ROC curve value might not show diminished performance. However, the _J_ index would be lower for models with pathological distributions for the class probabilities. The yardstick package will be used to compute these metrics. \n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ncls_metrics <- metric_set(roc_auc, j_index)\n```\n:::\n\n\n\n\n\nNow, we train the models and generate the results using `tune::fit_resamples()`:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nset.seed(2180)\nqda_rose_res <- fit_resamples(\n qda_rose_wflw, \n resamples = cv_folds, \n metrics = cls_metrics\n)\n\ncollect_metrics(qda_rose_res)\n#> # A tibble: 2 × 6\n#> .metric .estimator mean n std_err .config \n#> \n#> 1 j_index binary 0.777 50 0.0199 Preprocessor1_Model1\n#> 2 roc_auc binary 0.949 50 0.00508 Preprocessor1_Model1\n```\n:::\n\n\n\n\n\nWhat do the results look like without using ROSE? 
We can create another workflow and fit the QDA model along the same resamples:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nqda_wflw <- \n workflow() %>% \n add_model(qda_mod) %>% \n add_formula(Class ~ .)\n\nset.seed(2180)\nqda_only_res <- fit_resamples(qda_wflw, resamples = cv_folds, metrics = cls_metrics)\ncollect_metrics(qda_only_res)\n#> # A tibble: 2 × 6\n#> .metric .estimator mean n std_err .config \n#> \n#> 1 j_index binary 0.250 50 0.0288 Preprocessor1_Model1\n#> 2 roc_auc binary 0.953 50 0.00479 Preprocessor1_Model1\n```\n:::\n\n\n\n\n\nIt looks like ROSE helped a lot, especially with the J-index. Class imbalance sampling methods tend to greatly improve metrics based on the hard class predictions (i.e., the categorical predictions) because the default cutoff tends to be a better balance of sensitivity and specificity. \n\nLet's plot the metrics for each resample to see how the individual results changed. \n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nno_sampling <- \n qda_only_res %>% \n collect_metrics(summarize = FALSE) %>% \n dplyr::select(-.estimator) %>% \n mutate(sampling = \"no_sampling\")\n\nwith_sampling <- \n qda_rose_res %>% \n collect_metrics(summarize = FALSE) %>% \n dplyr::select(-.estimator) %>% \n mutate(sampling = \"rose\")\n\nbind_rows(no_sampling, with_sampling) %>% \n mutate(label = paste(id2, id)) %>% \n ggplot(aes(x = sampling, y = .estimate, group = label)) + \n geom_line(alpha = .4) + \n facet_wrap(~ .metric, scales = \"free_y\")\n```\n\n::: {.cell-output-display}\n![](figs/merge-metrics-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\n\nThis visually demonstrates that the subsampling mostly affects metrics that use the hard class predictions. \n\n## Session information {#session-info}\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```\n#> ─ Session info ─────────────────────────────────────────────────────\n#> version R version 4.4.2 (2024-10-31)\n#> language (EN)\n#> date 2025-03-24\n#> pandoc 3.6.1\n#> quarto 1.6.42\n#> \n#> ─ Packages ─────────────────────────────────────────────────────────\n#> package version date (UTC) source\n#> broom 1.0.7 2024-09-26 CRAN (R 4.4.1)\n#> dials 1.4.0 2025-02-13 CRAN (R 4.4.2)\n#> discrim 1.0.1 2023-03-08 CRAN (R 4.4.0)\n#> dplyr 1.1.4 2023-11-17 CRAN (R 4.4.0)\n#> ggplot2 3.5.1 2024-04-23 CRAN (R 4.4.0)\n#> infer 1.0.7 2024-03-25 CRAN (R 4.4.0)\n#> klaR 1.7-3 2023-12-13 CRAN (R 4.4.0)\n#> parsnip 1.3.1 2025-03-12 CRAN (R 4.4.1)\n#> purrr 1.0.4 2025-02-05 CRAN (R 4.4.1)\n#> readr 2.1.5 2024-01-10 CRAN (R 4.4.0)\n#> recipes 1.2.0 2025-03-17 CRAN (R 4.4.1)\n#> rlang 1.1.5 2025-01-17 CRAN (R 4.4.2)\n#> ROSE 0.0-4 2021-06-14 CRAN (R 4.4.0)\n#> rsample 1.2.1 2024-03-25 CRAN (R 4.4.0)\n#> themis 1.0.3 2025-01-23 CRAN (R 4.4.1)\n#> tibble 3.2.1 2023-03-20 CRAN (R 4.4.0)\n#> tidymodels 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> tune 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> workflows 1.2.0 2025-02-19 CRAN (R 4.4.1)\n#> yardstick 1.3.2 2025-01-22 CRAN (R 4.4.1)\n#> \n#> ────────────────────────────────────────────────────────────────────\n```\n:::\n", "supporting": [], "filters": [ "rmarkdown/pagebreak.lua" diff --git a/learn/models/sub-sampling/figs/merge-metrics-1.svg b/learn/models/sub-sampling/figs/merge-metrics-1.svg index 21a7fe9a..94651630 100644 --- a/learn/models/sub-sampling/figs/merge-metrics-1.svg +++ b/learn/models/sub-sampling/figs/merge-metrics-1.svg @@ -30,67 +30,66 @@ - - - - - - - - - + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
- - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -102,65 +101,65 @@ - - - - - - + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -194,24 +193,22 @@ no_sampling rose -0.85 -0.90 -0.95 +0.85 +0.90 +0.95 1.00 - - - + + + -0.00 -0.25 -0.50 -0.75 -1.00 - - - - - +0.00 +0.25 +0.50 +0.75 + + + + sampling .estimate diff --git a/learn/models/sub-sampling/index.html.md b/learn/models/sub-sampling/index.html.md index 50ca90a3..dadae7f0 100644 --- a/learn/models/sub-sampling/index.html.md +++ b/learn/models/sub-sampling/index.html.md @@ -28,7 +28,6 @@ Consider a two-class problem where the first class has a very low rate of occurr ```{.r .cell-code} imbal_data <- - readr::read_csv("https://tidymodels.org/learn/models/sub-sampling/imbal_data.csv") %>% mutate(Class = factor(Class)) dim(imbal_data) @@ -62,6 +61,8 @@ Here is a simple recipe implementing oversampling: ```{.r .cell-code} library(tidymodels) library(themis) +set.seed(1234) + imbal_rec <- recipe(Class ~ ., data = imbal_data) %>% step_rose(Class) @@ -152,8 +153,8 @@ collect_metrics(qda_rose_res) #> # A tibble: 2 × 6 #> .metric .estimator mean n std_err .config #> -#> 1 j_index binary 0.804 50 0.0178 Preprocessor1_Model1 -#> 2 roc_auc binary 0.953 50 0.00459 Preprocessor1_Model1 +#> 1 j_index binary 0.777 50 0.0199 Preprocessor1_Model1 +#> 2 roc_auc binary 0.949 50 0.00508 Preprocessor1_Model1 ``` ::: diff --git a/learn/models/sub-sampling/index.qmd b/learn/models/sub-sampling/index.qmd index acd59c42..c2f94067 100644 --- a/learn/models/sub-sampling/index.qmd +++ b/learn/models/sub-sampling/index.qmd @@ -54,7 +54,6 @@ Consider a two-class problem where the first class has a very low rate of occurr #| label: "load-data" #| message: false imbal_data <- - readr::read_csv("https://tidymodels.org/learn/models/sub-sampling/imbal_data.csv") %>% mutate(Class = factor(Class)) dim(imbal_data) @@ -82,6 +81,8 @@ Here is a simple recipe implementing oversampling: #| label: "rec" library(tidymodels) library(themis) +set.seed(1234) + imbal_rec <- recipe(Class ~ ., data = imbal_data) %>% step_rose(Class) From cd61d6be14e8a7bcd53e793fb8e7a68cfafa3d69 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 24 Mar 2025 12:25:24 -0700 Subject: [PATCH 3/7] set seed in learn/statistics/infer --- .../infer/index/execute-results/html.json | 4 +- .../infer/figs/unnamed-chunk-23-1.svg | 90 ++++++------ learn/statistics/infer/figs/visualize-1.svg | 121 ++++++++-------- learn/statistics/infer/figs/visualize2-1.svg | 131 +++++++++--------- learn/statistics/infer/index.html.md | 77 +++++----- learn/statistics/infer/index.qmd | 3 + 6 files changed, 219 insertions(+), 207 deletions(-) diff --git a/_freeze/learn/statistics/infer/index/execute-results/html.json b/_freeze/learn/statistics/infer/index/execute-results/html.json index 4917b2e7..035adcec 100644 --- a/_freeze/learn/statistics/infer/index/execute-results/html.json +++ b/_freeze/learn/statistics/infer/index/execute-results/html.json @@ -1,8 +1,8 @@ { - "hash": "583880b2d724eb59ddf30bf0ad3e3f01", + "hash": "4158a1c8a702b6966548fdb8f055388e", "result": { "engine": "knitr", - "markdown": "---\ntitle: \"Hypothesis testing using resampling and tidy data\"\ncategories:\n - statistical analysis\n - hypothesis testing\n - bootstrapping\ntype: learn-subsection\nweight: 
4\ndescription: | \n Perform common hypothesis tests for statistical inference using flexible functions.\ntoc: true\ntoc-depth: 2\ninclude-after-body: ../../../resources.html\n---\n\n\n\n\n\n\n\n\n## Introduction\n\nThis article only requires the tidymodels package. \n\nThe tidymodels package [infer](https://infer.tidymodels.org/) implements an expressive grammar to perform statistical inference that coheres with the `tidyverse` design framework. Rather than providing methods for specific statistical tests, this package consolidates the principles that are shared among common hypothesis tests into a set of 4 main verbs (functions), supplemented with many utilities to visualize and extract information from their outputs.\n\nRegardless of which hypothesis test we're using, we're still asking the same kind of question: \n\n>Is the effect or difference in our observed data real, or due to chance? \n\nTo answer this question, we start by assuming that the observed data came from some world where \"nothing is going on\" (i.e. the observed effect was simply due to random chance), and call this assumption our **null hypothesis**. (In reality, we might not believe in the null hypothesis at all; the null hypothesis is in opposition to the **alternate hypothesis**, which supposes that the effect present in the observed data is actually due to the fact that \"something is going on.\") We then calculate a **test statistic** from our data that describes the observed effect. We can use this test statistic to calculate a **p-value**, giving the probability that our observed data could come about if the null hypothesis was true. If this probability is below some pre-defined **significance level** $\\alpha$, then we can reject our null hypothesis.\n\nIf you are new to hypothesis testing, take a look at \n\n* [Section 9.2 of _Statistical Inference via Data Science_](https://moderndive.com/9-hypothesis-testing.html#understanding-ht)\n* The American Statistical Association's recent [statement on p-values](https://doi.org/10.1080/00031305.2016.1154108) \n\nThe workflow of this package is designed around these ideas. Starting from some data set,\n\n+ `specify()` allows you to specify the variable, or relationship between variables, that you're interested in,\n+ `hypothesize()` allows you to declare the null hypothesis,\n+ `generate()` allows you to generate data reflecting the null hypothesis, and\n+ `calculate()` allows you to calculate a distribution of statistics from the generated data to form the null distribution.\n\nThroughout this vignette, we make use of `gss`, a data set available in infer containing a sample of 500 observations of 11 variables from the *General Social Survey*. 
\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(tidymodels) # Includes the infer package\n\n# load in the data set\ndata(gss)\n\n# take a look at its structure\ndplyr::glimpse(gss)\n#> Rows: 500\n#> Columns: 11\n#> $ year 2014, 1994, 1998, 1996, 1994, 1996, 1990, 2016, 2000, 1998, 20…\n#> $ age 36, 34, 24, 42, 31, 32, 48, 36, 30, 33, 21, 30, 38, 49, 25, 56…\n#> $ sex male, female, male, male, male, female, female, female, female…\n#> $ college degree, no degree, degree, no degree, degree, no degree, no de…\n#> $ partyid ind, rep, ind, ind, rep, rep, dem, ind, rep, dem, dem, ind, de…\n#> $ hompop 3, 4, 1, 4, 2, 4, 2, 1, 5, 2, 4, 3, 4, 4, 2, 2, 3, 2, 1, 2, 5,…\n#> $ hours 50, 31, 40, 40, 40, 53, 32, 20, 40, 40, 23, 52, 38, 72, 48, 40…\n#> $ income $25000 or more, $20000 - 24999, $25000 or more, $25000 or more…\n#> $ class middle class, working class, working class, working class, mid…\n#> $ finrela below average, below average, below average, above average, ab…\n#> $ weight 0.8960034, 1.0825000, 0.5501000, 1.0864000, 1.0825000, 1.08640…\n```\n:::\n\n\n\n\nEach row is an individual survey response, containing some basic demographic information on the respondent as well as some additional variables. See `?gss` for more information on the variables included and their source. Note that this data (and our examples on it) are for demonstration purposes only, and will not necessarily provide accurate estimates unless weighted properly. For these examples, let's suppose that this data set is a representative sample of a population we want to learn about: American adults.\n\n## Specify variables\n\nThe `specify()` function can be used to specify which of the variables in the data set you're interested in. If you're only interested in, say, the `age` of the respondents, you might write:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(response = age)\n#> Response: age (numeric)\n#> # A tibble: 500 × 1\n#> age\n#> \n#> 1 36\n#> 2 34\n#> 3 24\n#> 4 42\n#> 5 31\n#> 6 32\n#> 7 48\n#> 8 36\n#> 9 30\n#> 10 33\n#> # ℹ 490 more rows\n```\n:::\n\n\n\n\nOn the front end, the output of `specify()` just looks like it selects off the columns in the dataframe that you've specified. 
What do we see if we check the class of this object, though?\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(response = age) %>%\n class()\n#> [1] \"infer\" \"tbl_df\" \"tbl\" \"data.frame\"\n```\n:::\n\n\n\n\nWe can see that the infer class has been appended on top of the dataframe classes; this new class stores some extra metadata.\n\nIf you're interested in two variables (`age` and `partyid`, for example) you can `specify()` their relationship in one of two (equivalent) ways:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# as a formula\ngss %>%\n specify(age ~ partyid)\n#> Response: age (numeric)\n#> Explanatory: partyid (factor)\n#> # A tibble: 500 × 2\n#> age partyid\n#> \n#> 1 36 ind \n#> 2 34 rep \n#> 3 24 ind \n#> 4 42 ind \n#> 5 31 rep \n#> 6 32 rep \n#> 7 48 dem \n#> 8 36 ind \n#> 9 30 rep \n#> 10 33 dem \n#> # ℹ 490 more rows\n\n# with the named arguments\ngss %>%\n specify(response = age, explanatory = partyid)\n#> Response: age (numeric)\n#> Explanatory: partyid (factor)\n#> # A tibble: 500 × 2\n#> age partyid\n#> \n#> 1 36 ind \n#> 2 34 rep \n#> 3 24 ind \n#> 4 42 ind \n#> 5 31 rep \n#> 6 32 rep \n#> 7 48 dem \n#> 8 36 ind \n#> 9 30 rep \n#> 10 33 dem \n#> # ℹ 490 more rows\n```\n:::\n\n\n\n\nIf you're doing inference on one proportion or a difference in proportions, you will need to use the `success` argument to specify which level of your `response` variable is a success. For instance, if you're interested in the proportion of the population with a college degree, you might use the following code:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# specifying for inference on proportions\ngss %>%\n specify(response = college, success = \"degree\")\n#> Response: college (factor)\n#> # A tibble: 500 × 1\n#> college \n#> \n#> 1 degree \n#> 2 no degree\n#> 3 degree \n#> 4 no degree\n#> 5 degree \n#> 6 no degree\n#> 7 no degree\n#> 8 degree \n#> 9 degree \n#> 10 no degree\n#> # ℹ 490 more rows\n```\n:::\n\n\n\n\n## Declare the hypothesis\n\nThe next step in the infer pipeline is often to declare a null hypothesis using `hypothesize()`. The first step is to supply one of \"independence\" or \"point\" to the `null` argument. If your null hypothesis assumes independence between two variables, then this is all you need to supply to `hypothesize()`:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(college ~ partyid, success = \"degree\") %>%\n hypothesize(null = \"independence\")\n#> Response: college (factor)\n#> Explanatory: partyid (factor)\n#> Null Hypothesis: independence\n#> # A tibble: 500 × 2\n#> college partyid\n#> \n#> 1 degree ind \n#> 2 no degree rep \n#> 3 degree ind \n#> 4 no degree ind \n#> 5 degree rep \n#> 6 no degree rep \n#> 7 no degree dem \n#> 8 degree ind \n#> 9 degree rep \n#> 10 no degree dem \n#> # ℹ 490 more rows\n```\n:::\n\n\n\n\nIf you're doing inference on a point estimate, you will also need to provide one of `p` (the true proportion of successes, between 0 and 1), `mu` (the true mean), `med` (the true median), or `sigma` (the true standard deviation). 
For instance, if the null hypothesis is that the mean number of hours worked per week in our population is 40, we would write:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(response = hours) %>%\n hypothesize(null = \"point\", mu = 40)\n#> Response: hours (numeric)\n#> Null Hypothesis: point\n#> # A tibble: 500 × 1\n#> hours\n#> \n#> 1 50\n#> 2 31\n#> 3 40\n#> 4 40\n#> 5 40\n#> 6 53\n#> 7 32\n#> 8 20\n#> 9 40\n#> 10 40\n#> # ℹ 490 more rows\n```\n:::\n\n\n\n\nAgain, from the front-end, the dataframe outputted from `hypothesize()` looks almost exactly the same as it did when it came out of `specify()`, but infer now \"knows\" your null hypothesis.\n\n## Generate the distribution\n\nOnce we've asserted our null hypothesis using `hypothesize()`, we can construct a null distribution based on this hypothesis. We can do this using one of several methods, supplied in the `type` argument:\n\n* `bootstrap`: A bootstrap sample will be drawn for each replicate, where a sample of size equal to the input sample size is drawn (with replacement) from the input sample data. \n* `permute`: For each replicate, each input value will be randomly reassigned (without replacement) to a new output value in the sample. \n* `simulate`: A value will be sampled from a theoretical distribution with parameters specified in `hypothesize()` for each replicate. (This option is currently only applicable for testing point estimates.) \n\nContinuing on with our example above, about the average number of hours worked a week, we might write:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(response = hours) %>%\n hypothesize(null = \"point\", mu = 40) %>%\n generate(reps = 5000, type = \"bootstrap\")\n#> Response: hours (numeric)\n#> Null Hypothesis: point\n#> # A tibble: 2,500,000 × 2\n#> # Groups: replicate [5,000]\n#> replicate hours\n#> \n#> 1 1 53.6\n#> 2 1 38.6\n#> 3 1 48.6\n#> 4 1 39.6\n#> 5 1 53.6\n#> 6 1 38.6\n#> 7 1 38.6\n#> 8 1 46.6\n#> 9 1 28.6\n#> 10 1 38.6\n#> # ℹ 2,499,990 more rows\n```\n:::\n\n\n\n\nIn the above example, we take 5000 bootstrap samples to form our null distribution.\n\nTo generate a null distribution for the independence of two variables, we could also randomly reshuffle the pairings of explanatory and response variables to break any existing association. For instance, to generate 5000 replicates that can be used to create a null distribution under the assumption that political party affiliation is not affected by age:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(partyid ~ age) %>%\n hypothesize(null = \"independence\") %>%\n generate(reps = 5000, type = \"permute\")\n#> Response: partyid (factor)\n#> Explanatory: age (numeric)\n#> Null Hypothesis: independence\n#> # A tibble: 2,500,000 × 3\n#> # Groups: replicate [5,000]\n#> partyid age replicate\n#> \n#> 1 dem 36 1\n#> 2 dem 34 1\n#> 3 dem 24 1\n#> 4 rep 42 1\n#> 5 ind 31 1\n#> 6 dem 32 1\n#> 7 ind 48 1\n#> 8 ind 36 1\n#> 9 rep 30 1\n#> 10 ind 33 1\n#> # ℹ 2,499,990 more rows\n```\n:::\n\n\n\n\n## Calculate statistics\n\nDepending on whether you're carrying out computation-based inference or theory-based inference, you will either supply `calculate()` with the output of `generate()` or `hypothesize()`, respectively. 
The function, for one, takes in a `stat` argument, which is currently one of `\"mean\"`, `\"median\"`, `\"sum\"`, `\"sd\"`, `\"prop\"`, `\"count\"`, `\"diff in means\"`, `\"diff in medians\"`, `\"diff in props\"`, `\"Chisq\"`, `\"F\"`, `\"t\"`, `\"z\"`, `\"slope\"`, or `\"correlation\"`. For example, continuing our example above to calculate the null distribution of mean hours worked per week:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(response = hours) %>%\n hypothesize(null = \"point\", mu = 40) %>%\n generate(reps = 5000, type = \"bootstrap\") %>%\n calculate(stat = \"mean\")\n#> Response: hours (numeric)\n#> Null Hypothesis: point\n#> # A tibble: 5,000 × 2\n#> replicate stat\n#> \n#> 1 1 39.7\n#> 2 2 39.4\n#> 3 3 39.3\n#> 4 4 39.6\n#> 5 5 40.1\n#> 6 6 40.9\n#> 7 7 39.0\n#> 8 8 40.9\n#> 9 9 38.3\n#> 10 10 39.6\n#> # ℹ 4,990 more rows\n```\n:::\n\n\n\n\nThe output of `calculate()` here shows us the sample statistic (in this case, the mean) for each of our 1000 replicates. If you're carrying out inference on differences in means, medians, or proportions, or $t$ and $z$ statistics, you will need to supply an `order` argument, giving the order in which the explanatory variables should be subtracted. For instance, to find the difference in mean age of those that have a college degree and those that don't, we might write:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(age ~ college) %>%\n hypothesize(null = \"independence\") %>%\n generate(reps = 5000, type = \"permute\") %>%\n calculate(\"diff in means\", order = c(\"degree\", \"no degree\"))\n#> Response: age (numeric)\n#> Explanatory: college (factor)\n#> Null Hypothesis: independence\n#> # A tibble: 5,000 × 2\n#> replicate stat\n#> \n#> 1 1 -2.89 \n#> 2 2 -2.62 \n#> 3 3 -0.620 \n#> 4 4 -0.320 \n#> 5 5 0.0680\n#> 6 6 0.112 \n#> 7 7 1.47 \n#> 8 8 -1.47 \n#> 9 9 0.139 \n#> 10 10 -0.390 \n#> # ℹ 4,990 more rows\n```\n:::\n\n\n\n\n## Other utilities\n\nThe infer package also offers several utilities to extract meaning out of summary statistics and null distributions; the package provides functions to visualize where a statistic is relative to a distribution (with `visualize()`), calculate p-values (with `get_p_value()`), and calculate confidence intervals (with `get_confidence_interval()`).\n\nTo illustrate, we'll go back to the example of determining whether the mean number of hours worked per week is 40 hours.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# find the point estimate\npoint_estimate <- gss %>%\n specify(response = hours) %>%\n calculate(stat = \"mean\")\n\n# generate a null distribution\nnull_dist <- gss %>%\n specify(response = hours) %>%\n hypothesize(null = \"point\", mu = 40) %>%\n generate(reps = 5000, type = \"bootstrap\") %>%\n calculate(stat = \"mean\")\n```\n:::\n\n\n\n\n(Notice the warning: `Removed 1244 rows containing missing values.` This would be worth noting if you were actually carrying out this hypothesis test.)\n\nOur point estimate 41.382 seems *pretty* close to 40, but a little bit different. 
We might wonder if this difference is just due to random chance, or if the mean number of hours worked per week in the population really isn't 40.\n\nWe could initially just visualize the null distribution.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nnull_dist %>%\n visualize()\n```\n\n::: {.cell-output-display}\n![](figs/visualize-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nWhere does our sample's observed statistic lie on this distribution? We can use the `obs_stat` argument to specify this.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nnull_dist %>%\n visualize() +\n shade_p_value(obs_stat = point_estimate, direction = \"two_sided\")\n```\n\n::: {.cell-output-display}\n![](figs/visualize2-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nNotice that infer has also shaded the regions of the null distribution that are as (or more) extreme than our observed statistic. (Also, note that we now use the `+` operator to apply the `shade_p_value()` function. This is because `visualize()` outputs a plot object from ggplot2 instead of a dataframe, and the `+` operator is needed to add the p-value layer to the plot object.) The red bar looks like it's slightly far out on the right tail of the null distribution, so observing a sample mean of 41.382 hours would be somewhat unlikely if the mean was actually 40 hours. How unlikely, though?\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# get a two-tailed p-value\np_value <- null_dist %>%\n get_p_value(obs_stat = point_estimate, direction = \"two_sided\")\n\np_value\n#> # A tibble: 1 × 1\n#> p_value\n#> \n#> 1 0.0292\n```\n:::\n\n\n\n\nIt looks like the p-value is 0.0292, which is pretty small---if the true mean number of hours worked per week was actually 40, the probability of our sample mean being this far (1.382 hours) from 40 would be 0.0292. This may or may not be statistically significantly different, depending on the significance level $\\alpha$ you decided on *before* you ran this analysis. If you had set $\\alpha = .05$, then this difference would be statistically significant, but if you had set $\\alpha = .01$, then it would not be.\n\nTo get a confidence interval around our estimate, we can write:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# start with the null distribution\nnull_dist %>%\n # calculate the confidence interval around the point estimate\n get_confidence_interval(point_estimate = point_estimate,\n # at the 95% confidence level\n level = .95,\n # using the standard error\n type = \"se\")\n#> # A tibble: 1 × 2\n#> lower_ci upper_ci\n#> \n#> 1 40.1 42.7\n```\n:::\n\n\n\n\nAs you can see, 40 hours per week is not contained in this interval, which aligns with our previous conclusion that this finding is significant at the confidence level $\\alpha = .05$.\n\n## Theoretical methods\n\nThe infer package also provides functionality to use theoretical methods for `\"Chisq\"`, `\"F\"` and `\"t\"` test statistics. \n\nGenerally, to find a null distribution using theory-based methods, use the same code that you would use to find the null distribution using randomization-based methods, but skip the `generate()` step. 
For example, if we wanted to find a null distribution for the relationship between age (`age`) and party identification (`partyid`) using randomization, we could write:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nnull_f_distn <- gss %>%\n specify(age ~ partyid) %>%\n hypothesize(null = \"independence\") %>%\n generate(reps = 5000, type = \"permute\") %>%\n calculate(stat = \"F\")\n```\n:::\n\n\n\n\nTo find the null distribution using theory-based methods, instead, skip the `generate()` step entirely:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nnull_f_distn_theoretical <- gss %>%\n specify(age ~ partyid) %>%\n hypothesize(null = \"independence\") %>%\n calculate(stat = \"F\")\n```\n:::\n\n\n\n\nWe'll calculate the observed statistic to make use of in the following visualizations; this procedure is the same, regardless of the methods used to find the null distribution.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nF_hat <- gss %>% \n specify(age ~ partyid) %>%\n calculate(stat = \"F\")\n```\n:::\n\n\n\n\nNow, instead of just piping the null distribution into `visualize()`, as we would do if we wanted to visualize the randomization-based null distribution, we also need to provide `method = \"theoretical\"` to `visualize()`.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nvisualize(null_f_distn_theoretical, method = \"theoretical\") +\n shade_p_value(obs_stat = F_hat, direction = \"greater\")\n```\n\n::: {.cell-output-display}\n![](figs/unnamed-chunk-22-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nTo get a sense of how the theory-based and randomization-based null distributions relate, we can pipe the randomization-based null distribution into `visualize()` and also specify `method = \"both\"`\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nvisualize(null_f_distn, method = \"both\") +\n shade_p_value(obs_stat = F_hat, direction = \"greater\")\n```\n\n::: {.cell-output-display}\n![](figs/unnamed-chunk-23-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nThat's it! This vignette covers most all of the key functionality of infer. 
See `help(package = \"infer\")` for a full list of functions and vignettes.\n\n\n## Session information {#session-info}\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```\n#> ─ Session info ─────────────────────────────────────────────────────\n#> version R version 4.4.2 (2024-10-31)\n#> language (EN)\n#> date 2025-03-24\n#> pandoc 3.6.1\n#> quarto 1.6.42\n#> \n#> ─ Packages ─────────────────────────────────────────────────────────\n#> package version date (UTC) source\n#> broom 1.0.7 2024-09-26 CRAN (R 4.4.1)\n#> dials 1.4.0 2025-02-13 CRAN (R 4.4.2)\n#> dplyr 1.1.4 2023-11-17 CRAN (R 4.4.0)\n#> ggplot2 3.5.1 2024-04-23 CRAN (R 4.4.0)\n#> infer 1.0.7 2024-03-25 CRAN (R 4.4.0)\n#> parsnip 1.3.1 2025-03-12 CRAN (R 4.4.1)\n#> purrr 1.0.4 2025-02-05 CRAN (R 4.4.1)\n#> recipes 1.2.0 2025-03-17 CRAN (R 4.4.1)\n#> rlang 1.1.5 2025-01-17 CRAN (R 4.4.2)\n#> rsample 1.2.1 2024-03-25 CRAN (R 4.4.0)\n#> tibble 3.2.1 2023-03-20 CRAN (R 4.4.0)\n#> tidymodels 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> tune 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> workflows 1.2.0 2025-02-19 CRAN (R 4.4.1)\n#> yardstick 1.3.2 2025-01-22 CRAN (R 4.4.1)\n#> \n#> ────────────────────────────────────────────────────────────────────\n```\n:::\n", + "markdown": "---\ntitle: \"Hypothesis testing using resampling and tidy data\"\ncategories:\n - statistical analysis\n - hypothesis testing\n - bootstrapping\ntype: learn-subsection\nweight: 4\ndescription: | \n Perform common hypothesis tests for statistical inference using flexible functions.\ntoc: true\ntoc-depth: 2\ninclude-after-body: ../../../resources.html\n---\n\n\n\n\n\n\n\n\n\n## Introduction\n\nThis article only requires the tidymodels package. \n\nThe tidymodels package [infer](https://infer.tidymodels.org/) implements an expressive grammar to perform statistical inference that coheres with the `tidyverse` design framework. Rather than providing methods for specific statistical tests, this package consolidates the principles that are shared among common hypothesis tests into a set of 4 main verbs (functions), supplemented with many utilities to visualize and extract information from their outputs.\n\nRegardless of which hypothesis test we're using, we're still asking the same kind of question: \n\n>Is the effect or difference in our observed data real, or due to chance? \n\nTo answer this question, we start by assuming that the observed data came from some world where \"nothing is going on\" (i.e. the observed effect was simply due to random chance), and call this assumption our **null hypothesis**. (In reality, we might not believe in the null hypothesis at all; the null hypothesis is in opposition to the **alternate hypothesis**, which supposes that the effect present in the observed data is actually due to the fact that \"something is going on.\") We then calculate a **test statistic** from our data that describes the observed effect. We can use this test statistic to calculate a **p-value**, giving the probability that our observed data could come about if the null hypothesis was true. If this probability is below some pre-defined **significance level** $\\alpha$, then we can reject our null hypothesis.\n\nIf you are new to hypothesis testing, take a look at \n\n* [Section 9.2 of _Statistical Inference via Data Science_](https://moderndive.com/9-hypothesis-testing.html#understanding-ht)\n* The American Statistical Association's recent [statement on p-values](https://doi.org/10.1080/00031305.2016.1154108) \n\nThe workflow of this package is designed around these ideas. 
Starting from some data set,\n\n+ `specify()` allows you to specify the variable, or relationship between variables, that you're interested in,\n+ `hypothesize()` allows you to declare the null hypothesis,\n+ `generate()` allows you to generate data reflecting the null hypothesis, and\n+ `calculate()` allows you to calculate a distribution of statistics from the generated data to form the null distribution.\n\nThroughout this vignette, we make use of `gss`, a data set available in infer containing a sample of 500 observations of 11 variables from the *General Social Survey*. \n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(tidymodels) # Includes the infer package\n\n# Set seed\nset.seed(1234)\n\n# load in the data set\ndata(gss)\n\n# take a look at its structure\ndplyr::glimpse(gss)\n#> Rows: 500\n#> Columns: 11\n#> $ year 2014, 1994, 1998, 1996, 1994, 1996, 1990, 2016, 2000, 1998, 20…\n#> $ age 36, 34, 24, 42, 31, 32, 48, 36, 30, 33, 21, 30, 38, 49, 25, 56…\n#> $ sex male, female, male, male, male, female, female, female, female…\n#> $ college degree, no degree, degree, no degree, degree, no degree, no de…\n#> $ partyid ind, rep, ind, ind, rep, rep, dem, ind, rep, dem, dem, ind, de…\n#> $ hompop 3, 4, 1, 4, 2, 4, 2, 1, 5, 2, 4, 3, 4, 4, 2, 2, 3, 2, 1, 2, 5,…\n#> $ hours 50, 31, 40, 40, 40, 53, 32, 20, 40, 40, 23, 52, 38, 72, 48, 40…\n#> $ income $25000 or more, $20000 - 24999, $25000 or more, $25000 or more…\n#> $ class middle class, working class, working class, working class, mid…\n#> $ finrela below average, below average, below average, above average, ab…\n#> $ weight 0.8960034, 1.0825000, 0.5501000, 1.0864000, 1.0825000, 1.08640…\n```\n:::\n\n\n\n\n\nEach row is an individual survey response, containing some basic demographic information on the respondent as well as some additional variables. See `?gss` for more information on the variables included and their source. Note that this data (and our examples on it) are for demonstration purposes only, and will not necessarily provide accurate estimates unless weighted properly. For these examples, let's suppose that this data set is a representative sample of a population we want to learn about: American adults.\n\n## Specify variables\n\nThe `specify()` function can be used to specify which of the variables in the data set you're interested in. If you're only interested in, say, the `age` of the respondents, you might write:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(response = age)\n#> Response: age (numeric)\n#> # A tibble: 500 × 1\n#> age\n#> \n#> 1 36\n#> 2 34\n#> 3 24\n#> 4 42\n#> 5 31\n#> 6 32\n#> 7 48\n#> 8 36\n#> 9 30\n#> 10 33\n#> # ℹ 490 more rows\n```\n:::\n\n\n\n\n\nOn the front end, the output of `specify()` just looks like it selects off the columns in the dataframe that you've specified. 
What do we see if we check the class of this object, though?\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(response = age) %>%\n class()\n#> [1] \"infer\" \"tbl_df\" \"tbl\" \"data.frame\"\n```\n:::\n\n\n\n\n\nWe can see that the infer class has been appended on top of the dataframe classes; this new class stores some extra metadata.\n\nIf you're interested in two variables (`age` and `partyid`, for example) you can `specify()` their relationship in one of two (equivalent) ways:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# as a formula\ngss %>%\n specify(age ~ partyid)\n#> Response: age (numeric)\n#> Explanatory: partyid (factor)\n#> # A tibble: 500 × 2\n#> age partyid\n#> \n#> 1 36 ind \n#> 2 34 rep \n#> 3 24 ind \n#> 4 42 ind \n#> 5 31 rep \n#> 6 32 rep \n#> 7 48 dem \n#> 8 36 ind \n#> 9 30 rep \n#> 10 33 dem \n#> # ℹ 490 more rows\n\n# with the named arguments\ngss %>%\n specify(response = age, explanatory = partyid)\n#> Response: age (numeric)\n#> Explanatory: partyid (factor)\n#> # A tibble: 500 × 2\n#> age partyid\n#> \n#> 1 36 ind \n#> 2 34 rep \n#> 3 24 ind \n#> 4 42 ind \n#> 5 31 rep \n#> 6 32 rep \n#> 7 48 dem \n#> 8 36 ind \n#> 9 30 rep \n#> 10 33 dem \n#> # ℹ 490 more rows\n```\n:::\n\n\n\n\n\nIf you're doing inference on one proportion or a difference in proportions, you will need to use the `success` argument to specify which level of your `response` variable is a success. For instance, if you're interested in the proportion of the population with a college degree, you might use the following code:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# specifying for inference on proportions\ngss %>%\n specify(response = college, success = \"degree\")\n#> Response: college (factor)\n#> # A tibble: 500 × 1\n#> college \n#> \n#> 1 degree \n#> 2 no degree\n#> 3 degree \n#> 4 no degree\n#> 5 degree \n#> 6 no degree\n#> 7 no degree\n#> 8 degree \n#> 9 degree \n#> 10 no degree\n#> # ℹ 490 more rows\n```\n:::\n\n\n\n\n\n## Declare the hypothesis\n\nThe next step in the infer pipeline is often to declare a null hypothesis using `hypothesize()`. The first step is to supply one of \"independence\" or \"point\" to the `null` argument. If your null hypothesis assumes independence between two variables, then this is all you need to supply to `hypothesize()`:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(college ~ partyid, success = \"degree\") %>%\n hypothesize(null = \"independence\")\n#> Response: college (factor)\n#> Explanatory: partyid (factor)\n#> Null Hypothesis: independence\n#> # A tibble: 500 × 2\n#> college partyid\n#> \n#> 1 degree ind \n#> 2 no degree rep \n#> 3 degree ind \n#> 4 no degree ind \n#> 5 degree rep \n#> 6 no degree rep \n#> 7 no degree dem \n#> 8 degree ind \n#> 9 degree rep \n#> 10 no degree dem \n#> # ℹ 490 more rows\n```\n:::\n\n\n\n\n\nIf you're doing inference on a point estimate, you will also need to provide one of `p` (the true proportion of successes, between 0 and 1), `mu` (the true mean), `med` (the true median), or `sigma` (the true standard deviation). 
For instance, if the null hypothesis is that the mean number of hours worked per week in our population is 40, we would write:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(response = hours) %>%\n hypothesize(null = \"point\", mu = 40)\n#> Response: hours (numeric)\n#> Null Hypothesis: point\n#> # A tibble: 500 × 1\n#> hours\n#> \n#> 1 50\n#> 2 31\n#> 3 40\n#> 4 40\n#> 5 40\n#> 6 53\n#> 7 32\n#> 8 20\n#> 9 40\n#> 10 40\n#> # ℹ 490 more rows\n```\n:::\n\n\n\n\n\nAgain, from the front-end, the dataframe outputted from `hypothesize()` looks almost exactly the same as it did when it came out of `specify()`, but infer now \"knows\" your null hypothesis.\n\n## Generate the distribution\n\nOnce we've asserted our null hypothesis using `hypothesize()`, we can construct a null distribution based on this hypothesis. We can do this using one of several methods, supplied in the `type` argument:\n\n* `bootstrap`: A bootstrap sample will be drawn for each replicate, where a sample of size equal to the input sample size is drawn (with replacement) from the input sample data. \n* `permute`: For each replicate, each input value will be randomly reassigned (without replacement) to a new output value in the sample. \n* `simulate`: A value will be sampled from a theoretical distribution with parameters specified in `hypothesize()` for each replicate. (This option is currently only applicable for testing point estimates.) \n\nContinuing on with our example above, about the average number of hours worked a week, we might write:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(response = hours) %>%\n hypothesize(null = \"point\", mu = 40) %>%\n generate(reps = 5000, type = \"bootstrap\")\n#> Response: hours (numeric)\n#> Null Hypothesis: point\n#> # A tibble: 2,500,000 × 2\n#> # Groups: replicate [5,000]\n#> replicate hours\n#> \n#> 1 1 58.6\n#> 2 1 35.6\n#> 3 1 28.6\n#> 4 1 38.6\n#> 5 1 28.6\n#> 6 1 38.6\n#> 7 1 38.6\n#> 8 1 57.6\n#> 9 1 58.6\n#> 10 1 38.6\n#> # ℹ 2,499,990 more rows\n```\n:::\n\n\n\n\n\nIn the above example, we take 5000 bootstrap samples to form our null distribution.\n\nTo generate a null distribution for the independence of two variables, we could also randomly reshuffle the pairings of explanatory and response variables to break any existing association. For instance, to generate 5000 replicates that can be used to create a null distribution under the assumption that political party affiliation is not affected by age:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(partyid ~ age) %>%\n hypothesize(null = \"independence\") %>%\n generate(reps = 5000, type = \"permute\")\n#> Response: partyid (factor)\n#> Explanatory: age (numeric)\n#> Null Hypothesis: independence\n#> # A tibble: 2,500,000 × 3\n#> # Groups: replicate [5,000]\n#> partyid age replicate\n#> \n#> 1 ind 36 1\n#> 2 ind 34 1\n#> 3 ind 24 1\n#> 4 rep 42 1\n#> 5 dem 31 1\n#> 6 dem 32 1\n#> 7 dem 48 1\n#> 8 rep 36 1\n#> 9 ind 30 1\n#> 10 dem 33 1\n#> # ℹ 2,499,990 more rows\n```\n:::\n\n\n\n\n\n## Calculate statistics\n\nDepending on whether you're carrying out computation-based inference or theory-based inference, you will either supply `calculate()` with the output of `generate()` or `hypothesize()`, respectively. 
The function, for one, takes in a `stat` argument, which is currently one of `\"mean\"`, `\"median\"`, `\"sum\"`, `\"sd\"`, `\"prop\"`, `\"count\"`, `\"diff in means\"`, `\"diff in medians\"`, `\"diff in props\"`, `\"Chisq\"`, `\"F\"`, `\"t\"`, `\"z\"`, `\"slope\"`, or `\"correlation\"`. For example, continuing our example above to calculate the null distribution of mean hours worked per week:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(response = hours) %>%\n hypothesize(null = \"point\", mu = 40) %>%\n generate(reps = 5000, type = \"bootstrap\") %>%\n calculate(stat = \"mean\")\n#> Response: hours (numeric)\n#> Null Hypothesis: point\n#> # A tibble: 5,000 × 2\n#> replicate stat\n#> \n#> 1 1 39.8\n#> 2 2 39.6\n#> 3 3 39.8\n#> 4 4 39.2\n#> 5 5 39.0\n#> 6 6 39.8\n#> 7 7 40.6\n#> 8 8 40.6\n#> 9 9 40.4\n#> 10 10 39.0\n#> # ℹ 4,990 more rows\n```\n:::\n\n\n\n\n\nThe output of `calculate()` here shows us the sample statistic (in this case, the mean) for each of our 1000 replicates. If you're carrying out inference on differences in means, medians, or proportions, or $t$ and $z$ statistics, you will need to supply an `order` argument, giving the order in which the explanatory variables should be subtracted. For instance, to find the difference in mean age of those that have a college degree and those that don't, we might write:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(age ~ college) %>%\n hypothesize(null = \"independence\") %>%\n generate(reps = 5000, type = \"permute\") %>%\n calculate(\"diff in means\", order = c(\"degree\", \"no degree\"))\n#> Response: age (numeric)\n#> Explanatory: college (factor)\n#> Null Hypothesis: independence\n#> # A tibble: 5,000 × 2\n#> replicate stat\n#> \n#> 1 1 -0.0378\n#> 2 2 1.55 \n#> 3 3 0.465 \n#> 4 4 1.39 \n#> 5 5 -0.161 \n#> 6 6 -0.179 \n#> 7 7 0.0151\n#> 8 8 0.914 \n#> 9 9 -1.32 \n#> 10 10 -0.426 \n#> # ℹ 4,990 more rows\n```\n:::\n\n\n\n\n\n## Other utilities\n\nThe infer package also offers several utilities to extract meaning out of summary statistics and null distributions; the package provides functions to visualize where a statistic is relative to a distribution (with `visualize()`), calculate p-values (with `get_p_value()`), and calculate confidence intervals (with `get_confidence_interval()`).\n\nTo illustrate, we'll go back to the example of determining whether the mean number of hours worked per week is 40 hours.\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# find the point estimate\npoint_estimate <- gss %>%\n specify(response = hours) %>%\n calculate(stat = \"mean\")\n\n# generate a null distribution\nnull_dist <- gss %>%\n specify(response = hours) %>%\n hypothesize(null = \"point\", mu = 40) %>%\n generate(reps = 5000, type = \"bootstrap\") %>%\n calculate(stat = \"mean\")\n```\n:::\n\n\n\n\n\n(Notice the warning: `Removed 1244 rows containing missing values.` This would be worth noting if you were actually carrying out this hypothesis test.)\n\nOur point estimate 41.382 seems *pretty* close to 40, but a little bit different. 
We might wonder if this difference is just due to random chance, or if the mean number of hours worked per week in the population really isn't 40.\n\nWe could initially just visualize the null distribution.\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nnull_dist %>%\n visualize()\n```\n\n::: {.cell-output-display}\n![](figs/visualize-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\n\nWhere does our sample's observed statistic lie on this distribution? We can use the `obs_stat` argument to specify this.\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nnull_dist %>%\n visualize() +\n shade_p_value(obs_stat = point_estimate, direction = \"two_sided\")\n```\n\n::: {.cell-output-display}\n![](figs/visualize2-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\n\nNotice that infer has also shaded the regions of the null distribution that are as (or more) extreme than our observed statistic. (Also, note that we now use the `+` operator to apply the `shade_p_value()` function. This is because `visualize()` outputs a plot object from ggplot2 instead of a dataframe, and the `+` operator is needed to add the p-value layer to the plot object.) The red bar looks like it's slightly far out on the right tail of the null distribution, so observing a sample mean of 41.382 hours would be somewhat unlikely if the mean was actually 40 hours. How unlikely, though?\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# get a two-tailed p-value\np_value <- null_dist %>%\n get_p_value(obs_stat = point_estimate, direction = \"two_sided\")\n\np_value\n#> # A tibble: 1 × 1\n#> p_value\n#> \n#> 1 0.046\n```\n:::\n\n\n\n\n\nIt looks like the p-value is 0.046, which is pretty small---if the true mean number of hours worked per week was actually 40, the probability of our sample mean being this far (1.382 hours) from 40 would be 0.046. This may or may not be statistically significantly different, depending on the significance level $\\alpha$ you decided on *before* you ran this analysis. If you had set $\\alpha = .05$, then this difference would be statistically significant, but if you had set $\\alpha = .01$, then it would not be.\n\nTo get a confidence interval around our estimate, we can write:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# start with the null distribution\nnull_dist %>%\n # calculate the confidence interval around the point estimate\n get_confidence_interval(point_estimate = point_estimate,\n # at the 95% confidence level\n level = .95,\n # using the standard error\n type = \"se\")\n#> # A tibble: 1 × 2\n#> lower_ci upper_ci\n#> \n#> 1 40.1 42.7\n```\n:::\n\n\n\n\n\nAs you can see, 40 hours per week is not contained in this interval, which aligns with our previous conclusion that this finding is significant at the confidence level $\\alpha = .05$.\n\n## Theoretical methods\n\nThe infer package also provides functionality to use theoretical methods for `\"Chisq\"`, `\"F\"` and `\"t\"` test statistics. \n\nGenerally, to find a null distribution using theory-based methods, use the same code that you would use to find the null distribution using randomization-based methods, but skip the `generate()` step. 
For example, if we wanted to find a null distribution for the relationship between age (`age`) and party identification (`partyid`) using randomization, we could write:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nnull_f_distn <- gss %>%\n specify(age ~ partyid) %>%\n hypothesize(null = \"independence\") %>%\n generate(reps = 5000, type = \"permute\") %>%\n calculate(stat = \"F\")\n```\n:::\n\n\n\n\n\nTo find the null distribution using theory-based methods, instead, skip the `generate()` step entirely:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nnull_f_distn_theoretical <- gss %>%\n specify(age ~ partyid) %>%\n hypothesize(null = \"independence\") %>%\n calculate(stat = \"F\")\n```\n:::\n\n\n\n\n\nWe'll calculate the observed statistic to make use of in the following visualizations; this procedure is the same, regardless of the methods used to find the null distribution.\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nF_hat <- gss %>% \n specify(age ~ partyid) %>%\n calculate(stat = \"F\")\n```\n:::\n\n\n\n\n\nNow, instead of just piping the null distribution into `visualize()`, as we would do if we wanted to visualize the randomization-based null distribution, we also need to provide `method = \"theoretical\"` to `visualize()`.\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nvisualize(null_f_distn_theoretical, method = \"theoretical\") +\n shade_p_value(obs_stat = F_hat, direction = \"greater\")\n```\n\n::: {.cell-output-display}\n![](figs/unnamed-chunk-22-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\n\nTo get a sense of how the theory-based and randomization-based null distributions relate, we can pipe the randomization-based null distribution into `visualize()` and also specify `method = \"both\"`\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nvisualize(null_f_distn, method = \"both\") +\n shade_p_value(obs_stat = F_hat, direction = \"greater\")\n```\n\n::: {.cell-output-display}\n![](figs/unnamed-chunk-23-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\n\nThat's it! This vignette covers most all of the key functionality of infer. 
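As a compact recap, the four verbs compose into a single pipeline. Here is a sketch that condenses the one-sample mean example from earlier (with fewer replicates, so the p-value will differ slightly):

```r
# observed statistic
obs_mean <- gss %>% 
  specify(response = hours) %>% 
  calculate(stat = "mean")

# null distribution and p-value in one pipe
gss %>% 
  specify(response = hours) %>% 
  hypothesize(null = "point", mu = 40) %>% 
  generate(reps = 1000, type = "bootstrap") %>% 
  calculate(stat = "mean") %>% 
  get_p_value(obs_stat = obs_mean, direction = "two_sided")
```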
See `help(package = \"infer\")` for a full list of functions and vignettes.\n\n\n## Session information {#session-info}\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```\n#> ─ Session info ─────────────────────────────────────────────────────\n#> version R version 4.4.2 (2024-10-31)\n#> language (EN)\n#> date 2025-03-24\n#> pandoc 3.6.1\n#> quarto 1.6.42\n#> \n#> ─ Packages ─────────────────────────────────────────────────────────\n#> package version date (UTC) source\n#> broom 1.0.7 2024-09-26 CRAN (R 4.4.1)\n#> dials 1.4.0 2025-02-13 CRAN (R 4.4.2)\n#> dplyr 1.1.4 2023-11-17 CRAN (R 4.4.0)\n#> ggplot2 3.5.1 2024-04-23 CRAN (R 4.4.0)\n#> infer 1.0.7 2024-03-25 CRAN (R 4.4.0)\n#> parsnip 1.3.1 2025-03-12 CRAN (R 4.4.1)\n#> purrr 1.0.4 2025-02-05 CRAN (R 4.4.1)\n#> recipes 1.2.0 2025-03-17 CRAN (R 4.4.1)\n#> rlang 1.1.5 2025-01-17 CRAN (R 4.4.2)\n#> rsample 1.2.1 2024-03-25 CRAN (R 4.4.0)\n#> tibble 3.2.1 2023-03-20 CRAN (R 4.4.0)\n#> tidymodels 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> tune 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> workflows 1.2.0 2025-02-19 CRAN (R 4.4.1)\n#> yardstick 1.3.2 2025-01-22 CRAN (R 4.4.1)\n#> \n#> ────────────────────────────────────────────────────────────────────\n```\n:::\n", "supporting": [], "filters": [ "rmarkdown/pagebreak.lua" diff --git a/learn/statistics/infer/figs/unnamed-chunk-23-1.svg b/learn/statistics/infer/figs/unnamed-chunk-23-1.svg index 66689fae..abcbec69 100644 --- a/learn/statistics/infer/figs/unnamed-chunk-23-1.svg +++ b/learn/statistics/infer/figs/unnamed-chunk-23-1.svg @@ -30,60 +30,60 @@ - - - - - - - - + + + + + + + + - - - + + + - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + 0.0 -0.2 -0.4 -0.6 +0.2 +0.4 +0.6 - - - + + + - - - + + + 0 -2 -4 -6 +2 +4 +6 F stat density Simulation-Based and Theoretical F Null Distributions diff --git a/learn/statistics/infer/figs/visualize-1.svg b/learn/statistics/infer/figs/visualize-1.svg index c04ddcaa..64d830e5 100644 --- a/learn/statistics/infer/figs/visualize-1.svg +++ b/learn/statistics/infer/figs/visualize-1.svg @@ -24,69 +24,72 @@ - - + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + -0 -250 -500 -750 - - - - - - - - - -38 -39 -40 -41 -42 -stat +0 +250 +500 +750 +1000 + + + + + + + + + + +38 +39 +40 +41 +42 +stat count -Simulation-Based Null Distribution +Simulation-Based Null Distribution diff --git a/learn/statistics/infer/figs/visualize2-1.svg b/learn/statistics/infer/figs/visualize2-1.svg index cc932d21..f9c52955 100644 --- a/learn/statistics/infer/figs/visualize2-1.svg +++ b/learn/statistics/infer/figs/visualize2-1.svg @@ -24,74 +24,77 @@ - - + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + -0 -250 -500 -750 - - - - - - - - - -38 -39 -40 -41 -42 -stat +0 +250 +500 +750 +1000 + + + + + + + + + + +38 +39 +40 +41 +42 +stat count -Simulation-Based Null Distribution +Simulation-Based Null Distribution diff --git a/learn/statistics/infer/index.html.md b/learn/statistics/infer/index.html.md index 7e7d86c2..aa3df7ea 100644 --- a/learn/statistics/infer/index.html.md +++ b/learn/statistics/infer/index.html.md @@ -44,6 +44,9 @@ Throughout this vignette, we make use of `gss`, a data set available in infer co ```{.r .cell-code} library(tidymodels) # Includes the infer package +# Set seed +set.seed(1234) + # load in the data set data(gss) @@ 
-263,15 +266,15 @@ gss %>% #> # Groups: replicate [5,000] #> replicate hours #> -#> 1 1 53.6 -#> 2 1 38.6 -#> 3 1 48.6 -#> 4 1 39.6 -#> 5 1 53.6 +#> 1 1 58.6 +#> 2 1 35.6 +#> 3 1 28.6 +#> 4 1 38.6 +#> 5 1 28.6 #> 6 1 38.6 #> 7 1 38.6 -#> 8 1 46.6 -#> 9 1 28.6 +#> 8 1 57.6 +#> 9 1 58.6 #> 10 1 38.6 #> # ℹ 2,499,990 more rows ``` @@ -295,16 +298,16 @@ gss %>% #> # Groups: replicate [5,000] #> partyid age replicate #> -#> 1 dem 36 1 -#> 2 dem 34 1 -#> 3 dem 24 1 +#> 1 ind 36 1 +#> 2 ind 34 1 +#> 3 ind 24 1 #> 4 rep 42 1 -#> 5 ind 31 1 +#> 5 dem 31 1 #> 6 dem 32 1 -#> 7 ind 48 1 -#> 8 ind 36 1 -#> 9 rep 30 1 -#> 10 ind 33 1 +#> 7 dem 48 1 +#> 8 rep 36 1 +#> 9 ind 30 1 +#> 10 dem 33 1 #> # ℹ 2,499,990 more rows ``` ::: @@ -326,16 +329,16 @@ gss %>% #> # A tibble: 5,000 × 2 #> replicate stat #> -#> 1 1 39.7 -#> 2 2 39.4 -#> 3 3 39.3 -#> 4 4 39.6 -#> 5 5 40.1 -#> 6 6 40.9 -#> 7 7 39.0 -#> 8 8 40.9 -#> 9 9 38.3 -#> 10 10 39.6 +#> 1 1 39.8 +#> 2 2 39.6 +#> 3 3 39.8 +#> 4 4 39.2 +#> 5 5 39.0 +#> 6 6 39.8 +#> 7 7 40.6 +#> 8 8 40.6 +#> 9 9 40.4 +#> 10 10 39.0 #> # ℹ 4,990 more rows ``` ::: @@ -356,16 +359,16 @@ gss %>% #> # A tibble: 5,000 × 2 #> replicate stat #> -#> 1 1 -2.89 -#> 2 2 -2.62 -#> 3 3 -0.620 -#> 4 4 -0.320 -#> 5 5 0.0680 -#> 6 6 0.112 -#> 7 7 1.47 -#> 8 8 -1.47 -#> 9 9 0.139 -#> 10 10 -0.390 +#> 1 1 -0.0378 +#> 2 2 1.55 +#> 3 3 0.465 +#> 4 4 1.39 +#> 5 5 -0.161 +#> 6 6 -0.179 +#> 7 7 0.0151 +#> 8 8 0.914 +#> 9 9 -1.32 +#> 10 10 -0.426 #> # ℹ 4,990 more rows ``` ::: @@ -439,11 +442,11 @@ p_value #> # A tibble: 1 × 1 #> p_value #> -#> 1 0.0292 +#> 1 0.046 ``` ::: -It looks like the p-value is 0.0292, which is pretty small---if the true mean number of hours worked per week was actually 40, the probability of our sample mean being this far (1.382 hours) from 40 would be 0.0292. This may or may not be statistically significantly different, depending on the significance level $\alpha$ you decided on *before* you ran this analysis. If you had set $\alpha = .05$, then this difference would be statistically significant, but if you had set $\alpha = .01$, then it would not be. +It looks like the p-value is 0.046, which is pretty small---if the true mean number of hours worked per week was actually 40, the probability of our sample mean being this far (1.382 hours) from 40 would be 0.046. This may or may not be statistically significantly different, depending on the significance level $\alpha$ you decided on *before* you ran this analysis. If you had set $\alpha = .05$, then this difference would be statistically significant, but if you had set $\alpha = .01$, then it would not be. 
To get a confidence interval around our estimate, we can write: diff --git a/learn/statistics/infer/index.qmd b/learn/statistics/infer/index.qmd index 45d27854..756a2930 100644 --- a/learn/statistics/infer/index.qmd +++ b/learn/statistics/infer/index.qmd @@ -62,6 +62,9 @@ Throughout this vignette, we make use of `gss`, a data set available in infer co #| message: false library(tidymodels) # Includes the infer package +# Set seed +set.seed(1234) + # load in the data set data(gss) From 45ba8700e7ba9973070fdc83483d453e219ff2a2 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 24 Mar 2025 12:27:03 -0700 Subject: [PATCH 4/7] set seed in learn/statistics/xtabs --- .../xtabs/index/execute-results/html.json | 4 +- .../xtabs/figs/visualize-indep-1.svg | 89 +++++++++--------- .../xtabs/figs/visualize-indep-both-1.svg | 94 +++++++++---------- .../xtabs/figs/visualize-indep-gof-1.svg | 82 ++++++++-------- learn/statistics/xtabs/index.html.md | 9 +- learn/statistics/xtabs/index.qmd | 1 + 6 files changed, 142 insertions(+), 137 deletions(-) diff --git a/_freeze/learn/statistics/xtabs/index/execute-results/html.json b/_freeze/learn/statistics/xtabs/index/execute-results/html.json index 8212fdfc..77344fdf 100644 --- a/_freeze/learn/statistics/xtabs/index/execute-results/html.json +++ b/_freeze/learn/statistics/xtabs/index/execute-results/html.json @@ -1,8 +1,8 @@ { - "hash": "6b1827a3ca3453e761cd6eed17884cd9", + "hash": "317489a76f9d246bc4a62f4fc19ebf67", "result": { "engine": "knitr", - "markdown": "---\ntitle: \"Statistical analysis of contingency tables\"\ncategories:\n - statistical analysis\n - hypothesis testing\ntype: learn-subsection\nweight: 5\ndescription: | \n Use tests of independence and goodness of fit to analyze tables of counts.\ntoc: true\ntoc-depth: 2\ninclude-after-body: ../../../resources.html\n---\n\n\n\n\n\n\n\n\n\n## Introduction\n\nThis article only requires that you have the tidymodels package installed.\n\nIn this vignette, we'll walk through conducting a $\\chi^2$ (chi-squared) test of independence and a chi-squared goodness of fit test using infer. We'll start out with a chi-squared test of independence, which can be used to test the association between two categorical variables. Then, we'll move on to a chi-squared goodness of fit test, which tests how well the distribution of one categorical variable can be approximated by some theoretical distribution.\n\nThroughout this vignette, we'll make use of the `ad_data` data set (available in the modeldata package, which is part of tidymodels). This data set is related to cognitive impairment in 333 patients from [Craig-Schapiro _et al_ (2011)](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3079734/). See `?ad_data` for more information on the variables included and their source. One of the main research questions in these data were how a person's genetics related to the Apolipoprotein E gene affect their cognitive skills. The data shows: \n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(tidymodels) # Includes the infer package\n\ndata(ad_data, package = \"modeldata\")\nad_data %>%\n select(Genotype, Class)\n#> # A tibble: 333 × 2\n#> Genotype Class \n#> \n#> 1 E3E3 Control \n#> 2 E3E4 Control \n#> 3 E3E4 Control \n#> 4 E3E4 Control \n#> 5 E3E3 Control \n#> 6 E4E4 Impaired\n#> 7 E2E3 Control \n#> 8 E2E3 Control \n#> 9 E3E3 Control \n#> 10 E2E3 Impaired\n#> # ℹ 323 more rows\n```\n:::\n\n\n\n\nThe three main genetic variants are called E2, E3, and E4. 
The values in `Genotype` represent the genetic makeup of patients based on what they inherited from their parents (i.e, a value of \"E2E4\" means E2 from one parent and E4 from the other). \n\n## Test of independence\n\nTo carry out a chi-squared test of independence, we'll examine the association between their cognitive ability (impaired and healthy) and the genetic makeup. This is what the relationship looks like in the sample data:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n::: {.cell-output-display}\n![](figs/plot-indep-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nIf there were no relationship, we would expect to see the purple bars reaching to the same length, regardless of cognitive ability. Are the differences we see here, though, just due to random noise?\n\nFirst, to calculate the observed statistic, we can use `specify()` and `calculate()`.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# calculate the observed statistic\nobserved_indep_statistic <- ad_data %>%\n specify(Genotype ~ Class) %>%\n calculate(stat = \"Chisq\")\n```\n:::\n\n\n\n\nThe observed $\\chi^2$ statistic is 21.5774809. Now, we want to compare this statistic to a null distribution, generated under the assumption that these variables are not actually related, to get a sense of how likely it would be for us to see this observed statistic if there were actually no association between cognitive ability and genetics.\n\nWe can `generate()` the null distribution in one of two ways: using randomization or theory-based methods. The randomization approach permutes the response and explanatory variables, so that each person's genetics is matched up with a random cognitive rating from the sample in order to break up any association between the two.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# generate the null distribution using randomization\nnull_distribution_simulated <- ad_data %>%\n specify(Genotype ~ Class) %>%\n hypothesize(null = \"independence\") %>%\n generate(reps = 5000, type = \"permute\") %>%\n calculate(stat = \"Chisq\")\n```\n:::\n\n\n\n\nNote that, in the line `specify(Genotype ~ Class)` above, we could use the equivalent syntax `specify(response = Genotype, explanatory = Class)`. The same goes in the code below, which generates the null distribution using theory-based methods instead of randomization.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# generate the null distribution by theoretical approximation\nnull_distribution_theoretical <- ad_data %>%\n specify(Genotype ~ Class) %>%\n hypothesize(null = \"independence\") %>%\n # note that we skip the generation step here!\n calculate(stat = \"Chisq\")\n```\n:::\n\n\n\n\nTo get a sense for what these distributions look like, and where our observed statistic falls, we can use `visualize()`:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# visualize the null distribution and test statistic!\nnull_distribution_simulated %>%\n visualize() + \n shade_p_value(observed_indep_statistic,\n direction = \"greater\")\n```\n\n::: {.cell-output-display}\n![](figs/visualize-indep-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nWe could also visualize the observed statistic against the theoretical null distribution. 
Note that we skip the `generate()` and `calculate()` steps when using the theoretical approach, and that we now need to provide `method = \"theoretical\"` to `visualize()`.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# visualize the theoretical null distribution and test statistic!\nad_data %>%\n specify(Genotype ~ Class) %>%\n hypothesize(null = \"independence\") %>%\n visualize(method = \"theoretical\") + \n shade_p_value(observed_indep_statistic,\n direction = \"greater\")\n```\n\n::: {.cell-output-display}\n![](figs/visualize-indep-theor-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nTo visualize both the randomization-based and theoretical null distributions to get a sense of how the two relate, we can pipe the randomization-based null distribution into `visualize()`, and further provide `method = \"both\"`.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# visualize both null distributions and the test statistic!\nnull_distribution_simulated %>%\n visualize(method = \"both\") + \n shade_p_value(observed_indep_statistic,\n direction = \"greater\")\n```\n\n::: {.cell-output-display}\n![](figs/visualize-indep-both-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nEither way, it looks like our observed test statistic would be fairly unlikely if there were actually no association between cognition and genotype. More exactly, we can calculate the p-value:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# calculate the p value from the observed statistic and null distribution\np_value_independence <- null_distribution_simulated %>%\n get_p_value(obs_stat = observed_indep_statistic,\n direction = \"greater\")\n\np_value_independence\n#> # A tibble: 1 × 1\n#> p_value\n#> \n#> 1 0.0002\n```\n:::\n\n\n\n\nThus, if there were really no relationship between cognition and genotype, the probability that we would see a statistic as or more extreme than 21.5774809 is approximately 2\\times 10^{-4}.\n\nNote that, equivalently to the steps shown above, the package supplies a wrapper function, `chisq_test`, to carry out Chi-Squared tests of independence on tidy data. The syntax goes like this:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nchisq_test(ad_data, Genotype ~ Class)\n#> # A tibble: 1 × 3\n#> statistic chisq_df p_value\n#> \n#> 1 21.6 5 0.000630\n```\n:::\n\n\n\n\n\n## Goodness of fit\n\nNow, moving on to a chi-squared goodness of fit test, we'll take a look at just the genotype data. Many papers have investigated the relationship of Apolipoprotein E to diseases. For example, [Song _et al_ (2004)](https://annals.org/aim/article-abstract/717641/meta-analysis-apolipoprotein-e-genotypes-risk-coronary-heart-disease) conducted a meta-analysis of numerous studies that looked at this gene and heart disease. In their paper, they describe the frequency of the different genotypes across many samples. For the cognition study, it might be interesting to see if our sample of genotypes was consistent with this literature (treating the rates, for this analysis, as known). \n\nThe rates of the meta-analysis and our observed data are: \n \n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# Song, Y., Stampfer, M. J., & Liu, S. (2004). Meta-Analysis: Apolipoprotein E \n# Genotypes and Risk for Coronary Heart Disease. 
Annals of Internal Medicine, \n# 141(2), 137.\nmeta_rates <- c(\"E2E2\" = 0.71, \"E2E3\" = 11.4, \"E2E4\" = 2.32,\n \"E3E3\" = 61.0, \"E3E4\" = 22.6, \"E4E4\" = 2.22)\nmeta_rates <- meta_rates/sum(meta_rates) # these add up to slightly > 100%\n\nobs_rates <- table(ad_data$Genotype)/nrow(ad_data)\nround(cbind(obs_rates, meta_rates) * 100, 2)\n#> obs_rates meta_rates\n#> E2E2 0.60 0.71\n#> E2E3 11.11 11.37\n#> E2E4 2.40 2.31\n#> E3E3 50.15 60.85\n#> E3E4 31.83 22.54\n#> E4E4 3.90 2.21\n```\n:::\n\n\n\n\nSuppose our null hypothesis is that `Genotype` follows the same frequency distribution as the meta-analysis. Lets now test whether this difference in distributions is statistically significant.\n\nFirst, to carry out this hypothesis test, we would calculate our observed statistic.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# calculating the null distribution\nobserved_gof_statistic <- ad_data %>%\n specify(response = Genotype) %>%\n hypothesize(null = \"point\", p = meta_rates) %>%\n calculate(stat = \"Chisq\")\n```\n:::\n\n\n\n\nThe observed statistic is 23.3838483. Now, generating a null distribution, by just dropping in a call to `generate()`:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# generating a null distribution\nnull_distribution_gof <- ad_data %>%\n specify(response = Genotype) %>%\n hypothesize(null = \"point\", p = meta_rates) %>%\n generate(reps = 5000, type = \"simulate\") %>%\n calculate(stat = \"Chisq\")\n```\n:::\n\n\n\n\nAgain, to get a sense for what these distributions look like, and where our observed statistic falls, we can use `visualize()`:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# visualize the null distribution and test statistic!\nnull_distribution_gof %>%\n visualize() + \n shade_p_value(observed_gof_statistic,\n direction = \"greater\")\n```\n\n::: {.cell-output-display}\n![](figs/visualize-indep-gof-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nThis statistic seems like it would be unlikely if our rates were the same as the rates from the meta-analysis! How unlikely, though? Calculating the p-value:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# calculate the p-value\np_value_gof <- null_distribution_gof %>%\n get_p_value(observed_gof_statistic,\n direction = \"greater\")\n\np_value_gof\n#> # A tibble: 1 × 1\n#> p_value\n#> \n#> 1 0.0012\n```\n:::\n\n\n\n\nThus, if each genotype occurred at the same rate as the Song paper, the probability that we would see a distribution like the one we did is approximately 0.0012.\n\nAgain, equivalently to the steps shown above, the package supplies a wrapper function, `chisq_test`, to carry out chi-squared goodness of fit tests on tidy data. 
The syntax goes like this:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nchisq_test(ad_data, response = Genotype, p = meta_rates)\n#> # A tibble: 1 × 3\n#> statistic chisq_df p_value\n#> \n#> 1 23.4 5 0.000285\n```\n:::\n\n\n\n\n\n\n## Session information {#session-info}\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```\n#> ─ Session info ─────────────────────────────────────────────────────\n#> version R version 4.4.2 (2024-10-31)\n#> language (EN)\n#> date 2025-03-24\n#> pandoc 3.6.1\n#> quarto 1.6.42\n#> \n#> ─ Packages ─────────────────────────────────────────────────────────\n#> package version date (UTC) source\n#> broom 1.0.7 2024-09-26 CRAN (R 4.4.1)\n#> dials 1.4.0 2025-02-13 CRAN (R 4.4.2)\n#> dplyr 1.1.4 2023-11-17 CRAN (R 4.4.0)\n#> ggplot2 3.5.1 2024-04-23 CRAN (R 4.4.0)\n#> infer 1.0.7 2024-03-25 CRAN (R 4.4.0)\n#> parsnip 1.3.1 2025-03-12 CRAN (R 4.4.1)\n#> purrr 1.0.4 2025-02-05 CRAN (R 4.4.1)\n#> recipes 1.2.0 2025-03-17 CRAN (R 4.4.1)\n#> rlang 1.1.5 2025-01-17 CRAN (R 4.4.2)\n#> rsample 1.2.1 2024-03-25 CRAN (R 4.4.0)\n#> tibble 3.2.1 2023-03-20 CRAN (R 4.4.0)\n#> tidymodels 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> tune 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> workflows 1.2.0 2025-02-19 CRAN (R 4.4.1)\n#> yardstick 1.3.2 2025-01-22 CRAN (R 4.4.1)\n#> \n#> ────────────────────────────────────────────────────────────────────\n```\n:::\n", + "markdown": "---\ntitle: \"Statistical analysis of contingency tables\"\ncategories:\n - statistical analysis\n - hypothesis testing\ntype: learn-subsection\nweight: 5\ndescription: | \n Use tests of independence and goodness of fit to analyze tables of counts.\ntoc: true\ntoc-depth: 2\ninclude-after-body: ../../../resources.html\n---\n\n\n\n\n\n\n\n\n\n## Introduction\n\nThis article only requires that you have the tidymodels package installed.\n\nIn this vignette, we'll walk through conducting a $\\chi^2$ (chi-squared) test of independence and a chi-squared goodness of fit test using infer. We'll start out with a chi-squared test of independence, which can be used to test the association between two categorical variables. Then, we'll move on to a chi-squared goodness of fit test, which tests how well the distribution of one categorical variable can be approximated by some theoretical distribution.\n\nThroughout this vignette, we'll make use of the `ad_data` data set (available in the modeldata package, which is part of tidymodels). This data set is related to cognitive impairment in 333 patients from [Craig-Schapiro _et al_ (2011)](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3079734/). See `?ad_data` for more information on the variables included and their source. One of the main research questions in these data were how a person's genetics related to the Apolipoprotein E gene affect their cognitive skills. The data shows: \n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(tidymodels) # Includes the infer package\nset.seed(1234)\n\ndata(ad_data, package = \"modeldata\")\nad_data %>%\n select(Genotype, Class)\n#> # A tibble: 333 × 2\n#> Genotype Class \n#> \n#> 1 E3E3 Control \n#> 2 E3E4 Control \n#> 3 E3E4 Control \n#> 4 E3E4 Control \n#> 5 E3E3 Control \n#> 6 E4E4 Impaired\n#> 7 E2E3 Control \n#> 8 E2E3 Control \n#> 9 E3E3 Control \n#> 10 E2E3 Impaired\n#> # ℹ 323 more rows\n```\n:::\n\n\n\n\nThe three main genetic variants are called E2, E3, and E4. 
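As an aside, the test of independence below summarizes the cross-tabulation of `Genotype` and `Class`; a quick sketch of how to inspect that table directly (using `count()` from dplyr and `tidyr::pivot_wider()`, both installed with tidymodels):

```r
# Sketch: the observed contingency table that the chi-squared test summarizes.
ad_data %>%
  count(Class, Genotype) %>%
  tidyr::pivot_wider(names_from = Genotype, values_from = n, values_fill = 0)
```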
The values in `Genotype` represent the genetic makeup of patients based on what they inherited from their parents (i.e, a value of \"E2E4\" means E2 from one parent and E4 from the other). \n\n## Test of independence\n\nTo carry out a chi-squared test of independence, we'll examine the association between their cognitive ability (impaired and healthy) and the genetic makeup. This is what the relationship looks like in the sample data:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n::: {.cell-output-display}\n![](figs/plot-indep-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nIf there were no relationship, we would expect to see the purple bars reaching to the same length, regardless of cognitive ability. Are the differences we see here, though, just due to random noise?\n\nFirst, to calculate the observed statistic, we can use `specify()` and `calculate()`.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# calculate the observed statistic\nobserved_indep_statistic <- ad_data %>%\n specify(Genotype ~ Class) %>%\n calculate(stat = \"Chisq\")\n```\n:::\n\n\n\n\nThe observed $\\chi^2$ statistic is 21.5774809. Now, we want to compare this statistic to a null distribution, generated under the assumption that these variables are not actually related, to get a sense of how likely it would be for us to see this observed statistic if there were actually no association between cognitive ability and genetics.\n\nWe can `generate()` the null distribution in one of two ways: using randomization or theory-based methods. The randomization approach permutes the response and explanatory variables, so that each person's genetics is matched up with a random cognitive rating from the sample in order to break up any association between the two.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# generate the null distribution using randomization\nnull_distribution_simulated <- ad_data %>%\n specify(Genotype ~ Class) %>%\n hypothesize(null = \"independence\") %>%\n generate(reps = 5000, type = \"permute\") %>%\n calculate(stat = \"Chisq\")\n```\n:::\n\n\n\n\nNote that, in the line `specify(Genotype ~ Class)` above, we could use the equivalent syntax `specify(response = Genotype, explanatory = Class)`. The same goes in the code below, which generates the null distribution using theory-based methods instead of randomization.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# generate the null distribution by theoretical approximation\nnull_distribution_theoretical <- ad_data %>%\n specify(Genotype ~ Class) %>%\n hypothesize(null = \"independence\") %>%\n # note that we skip the generation step here!\n calculate(stat = \"Chisq\")\n```\n:::\n\n\n\n\nTo get a sense for what these distributions look like, and where our observed statistic falls, we can use `visualize()`:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# visualize the null distribution and test statistic!\nnull_distribution_simulated %>%\n visualize() + \n shade_p_value(observed_indep_statistic,\n direction = \"greater\")\n```\n\n::: {.cell-output-display}\n![](figs/visualize-indep-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nWe could also visualize the observed statistic against the theoretical null distribution. 
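The theoretical null in this case is a chi-squared distribution with 5 degrees of freedom (the `chisq_df` reported by `chisq_test()` further down), so, as a sketch, its upper tail area at the observed statistic gives the classical p-value:

```r
# Upper tail of a chi-squared distribution with 5 df at the observed statistic;
# roughly 0.00063, matching the chisq_test() output shown later.
pchisq(21.5774809, df = 5, lower.tail = FALSE)
```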
Note that we skip the `generate()` and `calculate()` steps when using the theoretical approach, and that we now need to provide `method = \"theoretical\"` to `visualize()`.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# visualize the theoretical null distribution and test statistic!\nad_data %>%\n specify(Genotype ~ Class) %>%\n hypothesize(null = \"independence\") %>%\n visualize(method = \"theoretical\") + \n shade_p_value(observed_indep_statistic,\n direction = \"greater\")\n```\n\n::: {.cell-output-display}\n![](figs/visualize-indep-theor-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nTo visualize both the randomization-based and theoretical null distributions to get a sense of how the two relate, we can pipe the randomization-based null distribution into `visualize()`, and further provide `method = \"both\"`.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# visualize both null distributions and the test statistic!\nnull_distribution_simulated %>%\n visualize(method = \"both\") + \n shade_p_value(observed_indep_statistic,\n direction = \"greater\")\n```\n\n::: {.cell-output-display}\n![](figs/visualize-indep-both-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nEither way, it looks like our observed test statistic would be fairly unlikely if there were actually no association between cognition and genotype. More exactly, we can calculate the p-value:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# calculate the p value from the observed statistic and null distribution\np_value_independence <- null_distribution_simulated %>%\n get_p_value(obs_stat = observed_indep_statistic,\n direction = \"greater\")\n\np_value_independence\n#> # A tibble: 1 × 1\n#> p_value\n#> \n#> 1 0.0006\n```\n:::\n\n\n\n\nThus, if there were really no relationship between cognition and genotype, the probability that we would see a statistic as or more extreme than 21.5774809 is approximately 6\\times 10^{-4}.\n\nNote that, equivalently to the steps shown above, the package supplies a wrapper function, `chisq_test`, to carry out Chi-Squared tests of independence on tidy data. The syntax goes like this:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nchisq_test(ad_data, Genotype ~ Class)\n#> # A tibble: 1 × 3\n#> statistic chisq_df p_value\n#> \n#> 1 21.6 5 0.000630\n```\n:::\n\n\n\n\n\n## Goodness of fit\n\nNow, moving on to a chi-squared goodness of fit test, we'll take a look at just the genotype data. Many papers have investigated the relationship of Apolipoprotein E to diseases. For example, [Song _et al_ (2004)](https://annals.org/aim/article-abstract/717641/meta-analysis-apolipoprotein-e-genotypes-risk-coronary-heart-disease) conducted a meta-analysis of numerous studies that looked at this gene and heart disease. In their paper, they describe the frequency of the different genotypes across many samples. For the cognition study, it might be interesting to see if our sample of genotypes was consistent with this literature (treating the rates, for this analysis, as known). \n\nThe rates of the meta-analysis and our observed data are: \n \n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# Song, Y., Stampfer, M. J., & Liu, S. (2004). Meta-Analysis: Apolipoprotein E \n# Genotypes and Risk for Coronary Heart Disease. 
Annals of Internal Medicine, \n# 141(2), 137.\nmeta_rates <- c(\"E2E2\" = 0.71, \"E2E3\" = 11.4, \"E2E4\" = 2.32,\n \"E3E3\" = 61.0, \"E3E4\" = 22.6, \"E4E4\" = 2.22)\nmeta_rates <- meta_rates/sum(meta_rates) # these add up to slightly > 100%\n\nobs_rates <- table(ad_data$Genotype)/nrow(ad_data)\nround(cbind(obs_rates, meta_rates) * 100, 2)\n#> obs_rates meta_rates\n#> E2E2 0.60 0.71\n#> E2E3 11.11 11.37\n#> E2E4 2.40 2.31\n#> E3E3 50.15 60.85\n#> E3E4 31.83 22.54\n#> E4E4 3.90 2.21\n```\n:::\n\n\n\n\nSuppose our null hypothesis is that `Genotype` follows the same frequency distribution as the meta-analysis. Lets now test whether this difference in distributions is statistically significant.\n\nFirst, to carry out this hypothesis test, we would calculate our observed statistic.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# calculating the null distribution\nobserved_gof_statistic <- ad_data %>%\n specify(response = Genotype) %>%\n hypothesize(null = \"point\", p = meta_rates) %>%\n calculate(stat = \"Chisq\")\n```\n:::\n\n\n\n\nThe observed statistic is 23.3838483. Now, generating a null distribution, by just dropping in a call to `generate()`:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# generating a null distribution\nnull_distribution_gof <- ad_data %>%\n specify(response = Genotype) %>%\n hypothesize(null = \"point\", p = meta_rates) %>%\n generate(reps = 5000, type = \"simulate\") %>%\n calculate(stat = \"Chisq\")\n```\n:::\n\n\n\n\nAgain, to get a sense for what these distributions look like, and where our observed statistic falls, we can use `visualize()`:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# visualize the null distribution and test statistic!\nnull_distribution_gof %>%\n visualize() + \n shade_p_value(observed_gof_statistic,\n direction = \"greater\")\n```\n\n::: {.cell-output-display}\n![](figs/visualize-indep-gof-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nThis statistic seems like it would be unlikely if our rates were the same as the rates from the meta-analysis! How unlikely, though? Calculating the p-value:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# calculate the p-value\np_value_gof <- null_distribution_gof %>%\n get_p_value(observed_gof_statistic,\n direction = \"greater\")\n\np_value_gof\n#> # A tibble: 1 × 1\n#> p_value\n#> \n#> 1 0.001\n```\n:::\n\n\n\n\nThus, if each genotype occurred at the same rate as the Song paper, the probability that we would see a distribution like the one we did is approximately 0.001.\n\nAgain, equivalently to the steps shown above, the package supplies a wrapper function, `chisq_test`, to carry out chi-squared goodness of fit tests on tidy data. 
The syntax goes like this:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nchisq_test(ad_data, response = Genotype, p = meta_rates)\n#> # A tibble: 1 × 3\n#> statistic chisq_df p_value\n#> \n#> 1 23.4 5 0.000285\n```\n:::\n\n\n\n\n\n\n## Session information {#session-info}\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```\n#> ─ Session info ─────────────────────────────────────────────────────\n#> version R version 4.4.2 (2024-10-31)\n#> language (EN)\n#> date 2025-03-24\n#> pandoc 3.6.1\n#> quarto 1.6.42\n#> \n#> ─ Packages ─────────────────────────────────────────────────────────\n#> package version date (UTC) source\n#> broom 1.0.7 2024-09-26 CRAN (R 4.4.1)\n#> dials 1.4.0 2025-02-13 CRAN (R 4.4.2)\n#> dplyr 1.1.4 2023-11-17 CRAN (R 4.4.0)\n#> ggplot2 3.5.1 2024-04-23 CRAN (R 4.4.0)\n#> infer 1.0.7 2024-03-25 CRAN (R 4.4.0)\n#> parsnip 1.3.1 2025-03-12 CRAN (R 4.4.1)\n#> purrr 1.0.4 2025-02-05 CRAN (R 4.4.1)\n#> recipes 1.2.0 2025-03-17 CRAN (R 4.4.1)\n#> rlang 1.1.5 2025-01-17 CRAN (R 4.4.2)\n#> rsample 1.2.1 2024-03-25 CRAN (R 4.4.0)\n#> tibble 3.2.1 2023-03-20 CRAN (R 4.4.0)\n#> tidymodels 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> tune 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> workflows 1.2.0 2025-02-19 CRAN (R 4.4.1)\n#> yardstick 1.3.2 2025-01-22 CRAN (R 4.4.1)\n#> \n#> ────────────────────────────────────────────────────────────────────\n```\n:::\n", "supporting": [], "filters": [ "rmarkdown/pagebreak.lua" diff --git a/learn/statistics/xtabs/figs/visualize-indep-1.svg b/learn/statistics/xtabs/figs/visualize-indep-1.svg index d63b3e5b..e7ba7413 100644 --- a/learn/statistics/xtabs/figs/visualize-indep-1.svg +++ b/learn/statistics/xtabs/figs/visualize-indep-1.svg @@ -30,59 +30,62 @@ - - - - - - - - + + + + + + + + - - + + + - - - - - - + + + + + + - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + 0 -500 -1000 +400 +800 +1200 - - + + + - - - - + + + + 0 -5 -10 -15 -20 +5 +10 +15 +20 stat count Simulation-Based Null Distribution diff --git a/learn/statistics/xtabs/figs/visualize-indep-both-1.svg b/learn/statistics/xtabs/figs/visualize-indep-both-1.svg index 626aba1d..a647330a 100644 --- a/learn/statistics/xtabs/figs/visualize-indep-both-1.svg +++ b/learn/statistics/xtabs/figs/visualize-indep-both-1.svg @@ -30,63 +30,63 @@ - - - - - - - - + + + + + + + + - - - + + + - - - - - - + + + + + + - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + 0.00 -0.05 -0.10 -0.15 +0.05 +0.10 +0.15 - - - + + + - - - - + + + + 0 -5 -10 -15 -20 +5 +10 +15 +20 Chi-Square stat density Simulation-Based and Theoretical Chi-Square Null Distributions diff --git a/learn/statistics/xtabs/figs/visualize-indep-gof-1.svg b/learn/statistics/xtabs/figs/visualize-indep-gof-1.svg index eb974acf..7f3e3c2a 100644 --- a/learn/statistics/xtabs/figs/visualize-indep-gof-1.svg +++ b/learn/statistics/xtabs/figs/visualize-indep-gof-1.svg @@ -30,57 +30,57 @@ - - - - - - + + + + + + - - - + + + - - - - - + + + + + - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + 0 -500 -1000 -1500 +500 +1000 +1500 - - - + + + - - - + + + 0 -10 -20 -30 +10 +20 +30 stat count Simulation-Based Null Distribution diff --git a/learn/statistics/xtabs/index.html.md b/learn/statistics/xtabs/index.html.md index 4c71414d..bd4edde1 100644 --- a/learn/statistics/xtabs/index.html.md +++ b/learn/statistics/xtabs/index.html.md @@ -24,6 +24,7 @@ Throughout this vignette, we'll make use of the `ad_data` data set (available in ```{.r .cell-code} library(tidymodels) # Includes the infer package +set.seed(1234) data(ad_data, 
package = "modeldata") ad_data %>% @@ -168,11 +169,11 @@ p_value_independence #> # A tibble: 1 × 1 #> p_value #> -#> 1 0.0002 +#> 1 0.0006 ``` ::: -Thus, if there were really no relationship between cognition and genotype, the probability that we would see a statistic as or more extreme than 21.5774809 is approximately 2\times 10^{-4}. +Thus, if there were really no relationship between cognition and genotype, the probability that we would see a statistic as or more extreme than 21.5774809 is approximately 6\times 10^{-4}. Note that, equivalently to the steps shown above, the package supplies a wrapper function, `chisq_test`, to carry out Chi-Squared tests of independence on tidy data. The syntax goes like this: @@ -276,11 +277,11 @@ p_value_gof #> # A tibble: 1 × 1 #> p_value #> -#> 1 0.0012 +#> 1 0.001 ``` ::: -Thus, if each genotype occurred at the same rate as the Song paper, the probability that we would see a distribution like the one we did is approximately 0.0012. +Thus, if each genotype occurred at the same rate as the Song paper, the probability that we would see a distribution like the one we did is approximately 0.001. Again, equivalently to the steps shown above, the package supplies a wrapper function, `chisq_test`, to carry out chi-squared goodness of fit tests on tidy data. The syntax goes like this: diff --git a/learn/statistics/xtabs/index.qmd b/learn/statistics/xtabs/index.qmd index d973a850..b705fd94 100644 --- a/learn/statistics/xtabs/index.qmd +++ b/learn/statistics/xtabs/index.qmd @@ -43,6 +43,7 @@ Throughout this vignette, we'll make use of the `ad_data` data set (available in #| warning: false #| message: false library(tidymodels) # Includes the infer package +set.seed(1234) data(ad_data, package = "modeldata") ad_data %>% From c60fc44d86342df9db3a03e47dbd8fd149f7b91f Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 24 Mar 2025 12:28:37 -0700 Subject: [PATCH 5/7] set seed in learn/work/nested-resampling --- .../index/execute-results/html.json | 4 +- .../work/nested-resampling/figs/choose-1.svg | 34 +- .../nested-resampling/figs/rmse-plot-1.svg | 218 +++++----- learn/work/nested-resampling/index.html.md | 4 +- learn/work/nested-resampling/index.qmd | 2 + learn/work/nested-resampling/index.rmarkdown | 371 ++++++++++++++++++ 6 files changed, 504 insertions(+), 129 deletions(-) create mode 100644 learn/work/nested-resampling/index.rmarkdown diff --git a/_freeze/learn/work/nested-resampling/index/execute-results/html.json b/_freeze/learn/work/nested-resampling/index/execute-results/html.json index 13e0ff7f..2fc7e86b 100644 --- a/_freeze/learn/work/nested-resampling/index/execute-results/html.json +++ b/_freeze/learn/work/nested-resampling/index/execute-results/html.json @@ -1,8 +1,8 @@ { - "hash": "ed87ff97154a407d5822e1c2d6b5d9d2", + "hash": "07e8f62215dc10f0d38750bd159b4bc0", "result": { "engine": "knitr", - "markdown": "---\ntitle: \"Nested resampling\"\ncategories:\n - SVMs\ntype: learn-subsection\nweight: 2\ndescription: | \n Estimate the best hyperparameters for a model using nested resampling.\ntoc: true\ntoc-depth: 2\ninclude-after-body: ../../../resources.html\n---\n\n\n\n\n\n\n\n\n## Introduction\n\nTo use code in this article, you will need to install the following packages: furrr, kernlab, mlbench, scales, and tidymodels.\n\nIn this article, we discuss an alternative method for evaluating and tuning models, called [nested resampling](https://scholar.google.com/scholar?hl=en&as_sdt=0%2C7&q=%22nested+resampling%22+inner+outer&btnG=). 
While it is more computationally taxing and challenging to implement than other resampling methods, it has the potential to produce better estimates of model performance.\n\n## Resampling models\n\nA typical scheme for splitting the data when developing a predictive model is to create an initial split of the data into a training and test set. If resampling is used, it is executed on the training set. A series of binary splits is created. In rsample, we use the term *analysis set* for the data that are used to fit the model and the term *assessment set* for the set used to compute performance:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n::: {.cell-output-display}\n![](img/resampling.svg){fig-align='center' width=70%}\n:::\n:::\n\n\n\n\nA common method for tuning models is [grid search](/learn/work/tune-svm/) where a candidate set of tuning parameters is created. The full set of models for every combination of the tuning parameter grid and the resamples is fitted. Each time, the assessment data are used to measure performance and the average value is determined for each tuning parameter.\n\nThe potential problem is that once we pick the tuning parameter associated with the best performance, this performance value is usually quoted as the performance of the model. There is serious potential for *optimization bias* since we use the same data to tune the model and to assess performance. This would result in an optimistic estimate of performance.\n\nNested resampling uses an additional layer of resampling that separates the tuning activities from the process used to estimate the efficacy of the model. An *outer* resampling scheme is used and, for every split in the outer resample, another full set of resampling splits are created on the original analysis set. For example, if 10-fold cross-validation is used on the outside and 5-fold cross-validation on the inside, a total of 500 models will be fit. The parameter tuning will be conducted 10 times and the best parameters are determined from the average of the 5 assessment sets. This process occurs 10 times.\n\nOnce the tuning results are complete, a model is fit to each of the outer resampling splits using the best parameter associated with that resample. The average of the outer method's assessment sets are a unbiased estimate of the model.\n\nWe will simulate some regression data to illustrate the methods. The mlbench package has a function `mlbench::mlbench.friedman1()` that can simulate a complex regression data structure from the [original MARS publication](https://scholar.google.com/scholar?hl=en&q=%22Multivariate+adaptive+regression+splines%22&btnG=&as_sdt=1%2C7&as_sdtp=). A training set size of 100 data points are generated as well as a large set that will be used to characterize how well the resampling procedure performed.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(mlbench)\nsim_data <- function(n) {\n tmp <- mlbench.friedman1(n, sd = 1)\n tmp <- cbind(tmp$x, tmp$y)\n tmp <- as.data.frame(tmp)\n names(tmp)[ncol(tmp)] <- \"y\"\n tmp\n}\n\nset.seed(9815)\ntrain_dat <- sim_data(100)\nlarge_dat <- sim_data(10^5)\n```\n:::\n\n\n\n\n## Nested resampling\n\nTo get started, the types of resampling methods need to be specified. This isn't a large data set, so 5 repeats of 10-fold cross validation will be used as the *outer* resampling method for generating the estimate of overall performance. 
To tune the model, it would be good to have precise estimates for each of the values of the tuning parameter so let's use 25 iterations of the bootstrap. This means that there will eventually be `5 * 10 * 25 = 1250` models that are fit to the data *per tuning parameter*. These models will be discarded once the performance of the model has been quantified.\n\nTo create the tibble with the resampling specifications:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(tidymodels)\nresults <- nested_cv(train_dat, \n outside = vfold_cv(repeats = 5), \n inside = bootstraps(times = 25))\nresults\n#> # Nested resampling:\n#> # outer: 10-fold cross-validation repeated 5 times\n#> # inner: Bootstrap sampling\n#> # A tibble: 50 × 4\n#> splits id id2 inner_resamples\n#> \n#> 1 Repeat1 Fold01 \n#> 2 Repeat1 Fold02 \n#> 3 Repeat1 Fold03 \n#> 4 Repeat1 Fold04 \n#> 5 Repeat1 Fold05 \n#> 6 Repeat1 Fold06 \n#> 7 Repeat1 Fold07 \n#> 8 Repeat1 Fold08 \n#> 9 Repeat1 Fold09 \n#> 10 Repeat1 Fold10 \n#> # ℹ 40 more rows\n```\n:::\n\n\n\n\nThe splitting information for each resample is contained in the `split` objects. Focusing on the second fold of the first repeat:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nresults$splits[[2]]\n#> \n#> <90/10/100>\n```\n:::\n\n\n\n\n`<90/10/100>` indicates the number of observations in the analysis set, assessment set, and the original data.\n\nEach element of `inner_resamples` has its own tibble with the bootstrapping splits.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nresults$inner_resamples[[5]]\n#> # Bootstrap sampling \n#> # A tibble: 25 × 2\n#> splits id \n#> \n#> 1 Bootstrap01\n#> 2 Bootstrap02\n#> 3 Bootstrap03\n#> 4 Bootstrap04\n#> 5 Bootstrap05\n#> 6 Bootstrap06\n#> 7 Bootstrap07\n#> 8 Bootstrap08\n#> 9 Bootstrap09\n#> 10 Bootstrap10\n#> # ℹ 15 more rows\n```\n:::\n\n\n\n\nThese are self-contained, meaning that the bootstrap sample is aware that it is a sample of a specific 90% of the data:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nresults$inner_resamples[[5]]$splits[[1]]\n#> \n#> <90/31/90>\n```\n:::\n\n\n\n\nTo start, we need to define how the model will be created and measured. Let's use a radial basis support vector machine model via the function `kernlab::ksvm`. This model is generally considered to have *two* tuning parameters: the SVM cost value and the kernel parameter `sigma`. For illustration purposes here, only the cost value will be tuned and the function `kernlab::sigest` will be used to estimate `sigma` during each model fit. This is automatically done by `ksvm`.\n\nAfter the model is fit to the analysis set, the root-mean squared error (RMSE) is computed on the assessment set. **One important note:** for this model, it is critical to center and scale the predictors before computing dot products. 
We don't do this operation here because `mlbench.friedman1` simulates all of the predictors to be standardized uniform random variables.\n\nOur function to fit the model and compute the RMSE is:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(kernlab)\n\n# `object` will be an `rsplit` object from our `results` tibble\n# `cost` is the tuning parameter\nsvm_rmse <- function(object, cost = 1) {\n y_col <- ncol(object$data)\n mod <- \n svm_rbf(mode = \"regression\", cost = cost) %>% \n set_engine(\"kernlab\") %>% \n fit(y ~ ., data = analysis(object))\n \n holdout_pred <- \n predict(mod, assessment(object) %>% dplyr::select(-y)) %>% \n bind_cols(assessment(object) %>% dplyr::select(y))\n rmse(holdout_pred, truth = y, estimate = .pred)$.estimate\n}\n\n# In some case, we want to parameterize the function over the tuning parameter:\nrmse_wrapper <- function(cost, object) svm_rmse(object, cost)\n```\n:::\n\n\n\n\nFor the nested resampling, a model needs to be fit for each tuning parameter and each bootstrap split. To do this, create a wrapper:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# `object` will be an `rsplit` object for the bootstrap samples\ntune_over_cost <- function(object) {\n tibble(cost = 2 ^ seq(-2, 8, by = 1)) %>% \n mutate(RMSE = map_dbl(cost, rmse_wrapper, object = object))\n}\n```\n:::\n\n\n\n\nSince this will be called across the set of outer cross-validation splits, another wrapper is required:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# `object` is an `rsplit` object in `results$inner_resamples` \nsummarize_tune_results <- function(object) {\n # Return row-bound tibble that has the 25 bootstrap results\n map_df(object$splits, tune_over_cost) %>%\n # For each value of the tuning parameter, compute the \n # average RMSE which is the inner bootstrap estimate. \n group_by(cost) %>%\n summarize(mean_RMSE = mean(RMSE, na.rm = TRUE),\n n = length(RMSE),\n .groups = \"drop\")\n}\n```\n:::\n\n\n\n\nNow that those functions are defined, we can execute all the inner resampling loops:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ntuning_results <- map(results$inner_resamples, summarize_tune_results) \n```\n:::\n\n\n\n\nAlternatively, since these computations can be run in parallel, we can use the furrr package. Instead of using `map()`, the function `future_map()` parallelizes the iterations using the [future package](https://cran.r-project.org/web/packages/future/vignettes/future-1-overview.html). The `multisession` plan uses the local cores to process the inner resampling loop. 
The end results are the same as the sequential computations.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(furrr)\nplan(multisession)\n\ntuning_results <- future_map(results$inner_resamples, summarize_tune_results) \n```\n:::\n\n\n\n\nThe object `tuning_results` is a list of data frames for each of the 50 outer resamples.\n\nLet's make a plot of the averaged results to see what the relationship is between the RMSE and the tuning parameters for each of the inner bootstrapping operations:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(scales)\n\npooled_inner <- tuning_results %>% bind_rows\n\nbest_cost <- function(dat) dat[which.min(dat$mean_RMSE),]\n\np <- \n ggplot(pooled_inner, aes(x = cost, y = mean_RMSE)) + \n scale_x_continuous(trans = 'log2') +\n xlab(\"SVM Cost\") + ylab(\"Inner RMSE\")\n\nfor (i in 1:length(tuning_results))\n p <- p +\n geom_line(data = tuning_results[[i]], alpha = .2) +\n geom_point(data = best_cost(tuning_results[[i]]), pch = 16, alpha = 3/4)\n\np <- p + geom_smooth(data = pooled_inner, se = FALSE)\np\n```\n\n::: {.cell-output-display}\n![](figs/rmse-plot-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nEach gray line is a separate bootstrap resampling curve created from a different 90% of the data. The blue line is a LOESS smooth of all the results pooled together.\n\nTo determine the best parameter estimate for each of the outer resampling iterations:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ncost_vals <- \n tuning_results %>% \n map_df(best_cost) %>% \n select(cost)\n\nresults <- \n bind_cols(results, cost_vals) %>% \n mutate(cost = factor(cost, levels = paste(2 ^ seq(-2, 8, by = 1))))\n\nggplot(results, aes(x = cost)) + \n geom_bar() + \n xlab(\"SVM Cost\") + \n scale_x_discrete(drop = FALSE)\n```\n\n::: {.cell-output-display}\n![](figs/choose-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nMost of the resamples produced an optimal cost value of 2.0, but the distribution is right-skewed due to the flat trend in the resampling profile once the cost value becomes 10 or larger.\n\nNow that we have these estimates, we can compute the outer resampling results for each of the 50 splits using the corresponding tuning parameter value:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nresults <- \n results %>% \n mutate(RMSE = map2_dbl(splits, cost, svm_rmse))\n\nsummary(results$RMSE)\n#> Min. 1st Qu. Median Mean 3rd Qu. Max. \n#> 1.574 2.095 2.688 2.697 3.265 4.350\n```\n:::\n\n\n\n\nThe estimated RMSE for the model tuning process is 2.7.\n\nWhat is the RMSE estimate for the non-nested procedure when only the outer resampling method is used? For each cost value in the tuning grid, 50 SVM models are fit and their RMSE values are averaged. The table of cost values and mean RMSE estimates is used to determine the best cost value. 
The associated RMSE is the biased estimate.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nnot_nested <- \n map(results$splits, tune_over_cost) %>%\n bind_rows\n\nouter_summary <- not_nested %>% \n group_by(cost) %>% \n summarize(outer_RMSE = mean(RMSE), n = length(RMSE))\n\nouter_summary\n#> # A tibble: 11 × 3\n#> cost outer_RMSE n\n#> \n#> 1 0.25 3.54 50\n#> 2 0.5 3.11 50\n#> 3 1 2.77 50\n#> 4 2 2.62 50\n#> 5 4 2.65 50\n#> 6 8 2.75 50\n#> 7 16 2.82 50\n#> 8 32 2.82 50\n#> 9 64 2.83 50\n#> 10 128 2.83 50\n#> 11 256 2.82 50\n\nggplot(outer_summary, aes(x = cost, y = outer_RMSE)) + \n geom_point() + \n geom_line() + \n scale_x_continuous(trans = 'log2') +\n xlab(\"SVM Cost\") + ylab(\"RMSE\")\n```\n\n::: {.cell-output-display}\n![](figs/not-nested-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nThe non-nested procedure estimates the RMSE to be 2.62. Both estimates are fairly close.\n\nThe approximately true RMSE for an SVM model with a cost value of 2.0 can be approximated with the large sample that was simulated at the beginning.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nfinalModel <- ksvm(y ~ ., data = train_dat, C = 2)\nlarge_pred <- predict(finalModel, large_dat[, -ncol(large_dat)])\nsqrt(mean((large_dat$y - large_pred) ^ 2, na.rm = TRUE))\n#> [1] 2.712059\n```\n:::\n\n\n\n\nThe nested procedure produces a closer estimate to the approximate truth but the non-nested estimate is very similar.\n\n## Session information {#session-info}\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```\n#> ─ Session info ─────────────────────────────────────────────────────\n#> version R version 4.4.2 (2024-10-31)\n#> language (EN)\n#> date 2025-03-24\n#> pandoc 3.6.1\n#> quarto 1.6.42\n#> \n#> ─ Packages ─────────────────────────────────────────────────────────\n#> package version date (UTC) source\n#> broom 1.0.7 2024-09-26 CRAN (R 4.4.1)\n#> dials 1.4.0 2025-02-13 CRAN (R 4.4.2)\n#> dplyr 1.1.4 2023-11-17 CRAN (R 4.4.0)\n#> furrr 0.3.1 2022-08-15 CRAN (R 4.4.0)\n#> ggplot2 3.5.1 2024-04-23 CRAN (R 4.4.0)\n#> infer 1.0.7 2024-03-25 CRAN (R 4.4.0)\n#> kernlab 0.9-33 2024-08-13 CRAN (R 4.4.0)\n#> mlbench 2.1-6 2024-12-30 CRAN (R 4.4.1)\n#> parsnip 1.3.1 2025-03-12 CRAN (R 4.4.1)\n#> purrr 1.0.4 2025-02-05 CRAN (R 4.4.1)\n#> recipes 1.2.0 2025-03-17 CRAN (R 4.4.1)\n#> rlang 1.1.5 2025-01-17 CRAN (R 4.4.2)\n#> rsample 1.2.1 2024-03-25 CRAN (R 4.4.0)\n#> scales 1.3.0 2023-11-28 CRAN (R 4.4.0)\n#> tibble 3.2.1 2023-03-20 CRAN (R 4.4.0)\n#> tidymodels 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> tune 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> workflows 1.2.0 2025-02-19 CRAN (R 4.4.1)\n#> yardstick 1.3.2 2025-01-22 CRAN (R 4.4.1)\n#> \n#> ────────────────────────────────────────────────────────────────────\n```\n:::\n", + "markdown": "---\ntitle: \"Nested resampling\"\ncategories:\n - SVMs\ntype: learn-subsection\nweight: 2\ndescription: | \n Estimate the best hyperparameters for a model using nested resampling.\ntoc: true\ntoc-depth: 2\ninclude-after-body: ../../../resources.html\n---\n\n\n\n\n\n\n\n\n## Introduction\n\nTo use code in this article, you will need to install the following packages: furrr, kernlab, mlbench, scales, and tidymodels.\n\nIn this article, we discuss an alternative method for evaluating and tuning models, called [nested resampling](https://scholar.google.com/scholar?hl=en&as_sdt=0%2C7&q=%22nested+resampling%22+inner+outer&btnG=). 
While it is more computationally taxing and challenging to implement than other resampling methods, it has the potential to produce better estimates of model performance.\n\n## Resampling models\n\nA typical scheme for splitting the data when developing a predictive model is to create an initial split of the data into a training and test set. If resampling is used, it is executed on the training set. A series of binary splits is created. In rsample, we use the term *analysis set* for the data that are used to fit the model and the term *assessment set* for the set used to compute performance:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n::: {.cell-output-display}\n![](img/resampling.svg){fig-align='center' width=70%}\n:::\n:::\n\n\n\n\nA common method for tuning models is [grid search](/learn/work/tune-svm/) where a candidate set of tuning parameters is created. The full set of models for every combination of the tuning parameter grid and the resamples is fitted. Each time, the assessment data are used to measure performance and the average value is determined for each tuning parameter.\n\nThe potential problem is that once we pick the tuning parameter associated with the best performance, this performance value is usually quoted as the performance of the model. There is serious potential for *optimization bias* since we use the same data to tune the model and to assess performance. This would result in an optimistic estimate of performance.\n\nNested resampling uses an additional layer of resampling that separates the tuning activities from the process used to estimate the efficacy of the model. An *outer* resampling scheme is used and, for every split in the outer resample, another full set of resampling splits are created on the original analysis set. For example, if 10-fold cross-validation is used on the outside and 5-fold cross-validation on the inside, a total of 500 models will be fit. The parameter tuning will be conducted 10 times and the best parameters are determined from the average of the 5 assessment sets. This process occurs 10 times.\n\nOnce the tuning results are complete, a model is fit to each of the outer resampling splits using the best parameter associated with that resample. The average of the outer method's assessment sets are a unbiased estimate of the model.\n\nWe will simulate some regression data to illustrate the methods. The mlbench package has a function `mlbench::mlbench.friedman1()` that can simulate a complex regression data structure from the [original MARS publication](https://scholar.google.com/scholar?hl=en&q=%22Multivariate+adaptive+regression+splines%22&btnG=&as_sdt=1%2C7&as_sdtp=). A training set size of 100 data points are generated as well as a large set that will be used to characterize how well the resampling procedure performed.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(mlbench)\nsim_data <- function(n) {\n tmp <- mlbench.friedman1(n, sd = 1)\n tmp <- cbind(tmp$x, tmp$y)\n tmp <- as.data.frame(tmp)\n names(tmp)[ncol(tmp)] <- \"y\"\n tmp\n}\n\nset.seed(9815)\ntrain_dat <- sim_data(100)\nlarge_dat <- sim_data(10^5)\n```\n:::\n\n\n\n\n## Nested resampling\n\nTo get started, the types of resampling methods need to be specified. This isn't a large data set, so 5 repeats of 10-fold cross validation will be used as the *outer* resampling method for generating the estimate of overall performance. 
To tune the model, it would be good to have precise estimates for each of the values of the tuning parameter so let's use 25 iterations of the bootstrap. This means that there will eventually be `5 * 10 * 25 = 1250` models that are fit to the data *per tuning parameter*. These models will be discarded once the performance of the model has been quantified.\n\nTo create the tibble with the resampling specifications:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(tidymodels)\nresults <- nested_cv(train_dat, \n outside = vfold_cv(repeats = 5), \n inside = bootstraps(times = 25))\nresults\n#> # Nested resampling:\n#> # outer: 10-fold cross-validation repeated 5 times\n#> # inner: Bootstrap sampling\n#> # A tibble: 50 × 4\n#> splits id id2 inner_resamples\n#> \n#> 1 Repeat1 Fold01 \n#> 2 Repeat1 Fold02 \n#> 3 Repeat1 Fold03 \n#> 4 Repeat1 Fold04 \n#> 5 Repeat1 Fold05 \n#> 6 Repeat1 Fold06 \n#> 7 Repeat1 Fold07 \n#> 8 Repeat1 Fold08 \n#> 9 Repeat1 Fold09 \n#> 10 Repeat1 Fold10 \n#> # ℹ 40 more rows\n```\n:::\n\n\n\n\nThe splitting information for each resample is contained in the `split` objects. Focusing on the second fold of the first repeat:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nresults$splits[[2]]\n#> \n#> <90/10/100>\n```\n:::\n\n\n\n\n`<90/10/100>` indicates the number of observations in the analysis set, assessment set, and the original data.\n\nEach element of `inner_resamples` has its own tibble with the bootstrapping splits.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nresults$inner_resamples[[5]]\n#> # Bootstrap sampling \n#> # A tibble: 25 × 2\n#> splits id \n#> \n#> 1 Bootstrap01\n#> 2 Bootstrap02\n#> 3 Bootstrap03\n#> 4 Bootstrap04\n#> 5 Bootstrap05\n#> 6 Bootstrap06\n#> 7 Bootstrap07\n#> 8 Bootstrap08\n#> 9 Bootstrap09\n#> 10 Bootstrap10\n#> # ℹ 15 more rows\n```\n:::\n\n\n\n\nThese are self-contained, meaning that the bootstrap sample is aware that it is a sample of a specific 90% of the data:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nresults$inner_resamples[[5]]$splits[[1]]\n#> \n#> <90/31/90>\n```\n:::\n\n\n\n\nTo start, we need to define how the model will be created and measured. Let's use a radial basis support vector machine model via the function `kernlab::ksvm`. This model is generally considered to have *two* tuning parameters: the SVM cost value and the kernel parameter `sigma`. For illustration purposes here, only the cost value will be tuned and the function `kernlab::sigest` will be used to estimate `sigma` during each model fit. This is automatically done by `ksvm`.\n\nAfter the model is fit to the analysis set, the root-mean squared error (RMSE) is computed on the assessment set. **One important note:** for this model, it is critical to center and scale the predictors before computing dot products. 
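For data that do not arrive pre-scaled, one option (an illustrative sketch, not part of this analysis; `svm_wflow` is a hypothetical name) is to fold the centering and scaling into a recipe so that it is re-estimated on each analysis set:

```r
# Sketch: bundle standardization with the SVM so preprocessing happens
# inside every resampling fit.
svm_wflow <-
  workflow() %>%
  add_recipe(
    recipe(y ~ ., data = train_dat) %>%
      step_normalize(all_numeric_predictors())
  ) %>%
  add_model(svm_rbf(mode = "regression", cost = 1) %>% set_engine("kernlab"))

# e.g. fit(svm_wflow, data = analysis(object)) inside a resampling loop
```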
We don't do this operation here because `mlbench.friedman1` simulates all of the predictors to be standardized uniform random variables.\n\nOur function to fit the model and compute the RMSE is:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(kernlab)\n\n# `object` will be an `rsplit` object from our `results` tibble\n# `cost` is the tuning parameter\nsvm_rmse <- function(object, cost = 1) {\n y_col <- ncol(object$data)\n mod <- \n svm_rbf(mode = \"regression\", cost = cost) %>% \n set_engine(\"kernlab\") %>% \n fit(y ~ ., data = analysis(object))\n \n holdout_pred <- \n predict(mod, assessment(object) %>% dplyr::select(-y)) %>% \n bind_cols(assessment(object) %>% dplyr::select(y))\n rmse(holdout_pred, truth = y, estimate = .pred)$.estimate\n}\n\n# In some case, we want to parameterize the function over the tuning parameter:\nrmse_wrapper <- function(cost, object) svm_rmse(object, cost)\n```\n:::\n\n\n\n\nFor the nested resampling, a model needs to be fit for each tuning parameter and each bootstrap split. To do this, create a wrapper:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# `object` will be an `rsplit` object for the bootstrap samples\ntune_over_cost <- function(object) {\n tibble(cost = 2 ^ seq(-2, 8, by = 1)) %>% \n mutate(RMSE = map_dbl(cost, rmse_wrapper, object = object))\n}\n```\n:::\n\n\n\n\nSince this will be called across the set of outer cross-validation splits, another wrapper is required:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# `object` is an `rsplit` object in `results$inner_resamples` \nsummarize_tune_results <- function(object) {\n # Return row-bound tibble that has the 25 bootstrap results\n map_df(object$splits, tune_over_cost) %>%\n # For each value of the tuning parameter, compute the \n # average RMSE which is the inner bootstrap estimate. \n group_by(cost) %>%\n summarize(mean_RMSE = mean(RMSE, na.rm = TRUE),\n n = length(RMSE),\n .groups = \"drop\")\n}\n```\n:::\n\n\n\n\nNow that those functions are defined, we can execute all the inner resampling loops:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ntuning_results <- map(results$inner_resamples, summarize_tune_results) \n```\n:::\n\n\n\n\nAlternatively, since these computations can be run in parallel, we can use the furrr package. Instead of using `map()`, the function `future_map()` parallelizes the iterations using the [future package](https://cran.r-project.org/web/packages/future/vignettes/future-1-overview.html). The `multisession` plan uses the local cores to process the inner resampling loop. 
The end results are the same as the sequential computations.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(furrr)\nplan(multisession)\n\ntuning_results <- future_map(results$inner_resamples, summarize_tune_results) \n```\n:::\n\n\n\n\nThe object `tuning_results` is a list of data frames for each of the 50 outer resamples.\n\nLet's make a plot of the averaged results to see what the relationship is between the RMSE and the tuning parameters for each of the inner bootstrapping operations:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(scales)\n\npooled_inner <- tuning_results %>% bind_rows\n\nbest_cost <- function(dat) dat[which.min(dat$mean_RMSE),]\n\np <- \n ggplot(pooled_inner, aes(x = cost, y = mean_RMSE)) + \n scale_x_continuous(trans = 'log2') +\n xlab(\"SVM Cost\") + ylab(\"Inner RMSE\")\n\nfor (i in 1:length(tuning_results))\n p <- p +\n geom_line(data = tuning_results[[i]], alpha = .2) +\n geom_point(data = best_cost(tuning_results[[i]]), pch = 16, alpha = 3/4)\n\np <- p + geom_smooth(data = pooled_inner, se = FALSE)\np\n```\n\n::: {.cell-output-display}\n![](figs/rmse-plot-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nEach gray line is a separate bootstrap resampling curve created from a different 90% of the data. The blue line is a LOESS smooth of all the results pooled together.\n\nTo determine the best parameter estimate for each of the outer resampling iterations:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ncost_vals <- \n tuning_results %>% \n map_df(best_cost) %>% \n select(cost)\n\nresults <- \n bind_cols(results, cost_vals) %>% \n mutate(cost = factor(cost, levels = paste(2 ^ seq(-2, 8, by = 1))))\n\nggplot(results, aes(x = cost)) + \n geom_bar() + \n xlab(\"SVM Cost\") + \n scale_x_discrete(drop = FALSE)\n```\n\n::: {.cell-output-display}\n![](figs/choose-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nMost of the resamples produced an optimal cost value of 2.0, but the distribution is right-skewed due to the flat trend in the resampling profile once the cost value becomes 10 or larger.\n\nNow that we have these estimates, we can compute the outer resampling results for each of the 50 splits using the corresponding tuning parameter value:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nresults <- \n results %>% \n mutate(RMSE = map2_dbl(splits, cost, svm_rmse))\n\nsummary(results$RMSE)\n#> Min. 1st Qu. Median Mean 3rd Qu. Max. \n#> 1.574 2.095 2.668 2.683 3.252 4.350\n```\n:::\n\n\n\n\nThe estimated RMSE for the model tuning process is 2.68.\n\nWhat is the RMSE estimate for the non-nested procedure when only the outer resampling method is used? For each cost value in the tuning grid, 50 SVM models are fit and their RMSE values are averaged. The table of cost values and mean RMSE estimates is used to determine the best cost value. 
The associated RMSE is the biased estimate.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nnot_nested <- \n map(results$splits, tune_over_cost) %>%\n bind_rows\n\nouter_summary <- not_nested %>% \n group_by(cost) %>% \n summarize(outer_RMSE = mean(RMSE), n = length(RMSE))\n\nouter_summary\n#> # A tibble: 11 × 3\n#> cost outer_RMSE n\n#> \n#> 1 0.25 3.54 50\n#> 2 0.5 3.11 50\n#> 3 1 2.77 50\n#> 4 2 2.62 50\n#> 5 4 2.65 50\n#> 6 8 2.75 50\n#> 7 16 2.82 50\n#> 8 32 2.82 50\n#> 9 64 2.83 50\n#> 10 128 2.83 50\n#> 11 256 2.82 50\n\nggplot(outer_summary, aes(x = cost, y = outer_RMSE)) + \n geom_point() + \n geom_line() + \n scale_x_continuous(trans = 'log2') +\n xlab(\"SVM Cost\") + ylab(\"RMSE\")\n```\n\n::: {.cell-output-display}\n![](figs/not-nested-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nThe non-nested procedure estimates the RMSE to be 2.62. Both estimates are fairly close.\n\nThe approximately true RMSE for an SVM model with a cost value of 2.0 can be approximated with the large sample that was simulated at the beginning.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nfinalModel <- ksvm(y ~ ., data = train_dat, C = 2)\nlarge_pred <- predict(finalModel, large_dat[, -ncol(large_dat)])\nsqrt(mean((large_dat$y - large_pred) ^ 2, na.rm = TRUE))\n#> [1] 2.712059\n```\n:::\n\n\n\n\nThe nested procedure produces a closer estimate to the approximate truth but the non-nested estimate is very similar.\n\n## Session information {#session-info}\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```\n#> ─ Session info ─────────────────────────────────────────────────────\n#> version R version 4.4.2 (2024-10-31)\n#> language (EN)\n#> date 2025-03-24\n#> pandoc 3.6.1\n#> quarto 1.6.42\n#> \n#> ─ Packages ─────────────────────────────────────────────────────────\n#> package version date (UTC) source\n#> broom 1.0.7 2024-09-26 CRAN (R 4.4.1)\n#> dials 1.4.0 2025-02-13 CRAN (R 4.4.2)\n#> dplyr 1.1.4 2023-11-17 CRAN (R 4.4.0)\n#> furrr 0.3.1 2022-08-15 CRAN (R 4.4.0)\n#> ggplot2 3.5.1 2024-04-23 CRAN (R 4.4.0)\n#> infer 1.0.7 2024-03-25 CRAN (R 4.4.0)\n#> kernlab 0.9-33 2024-08-13 CRAN (R 4.4.0)\n#> mlbench 2.1-6 2024-12-30 CRAN (R 4.4.1)\n#> parsnip 1.3.1 2025-03-12 CRAN (R 4.4.1)\n#> purrr 1.0.4 2025-02-05 CRAN (R 4.4.1)\n#> recipes 1.2.0 2025-03-17 CRAN (R 4.4.1)\n#> rlang 1.1.5 2025-01-17 CRAN (R 4.4.2)\n#> rsample 1.2.1 2024-03-25 CRAN (R 4.4.0)\n#> scales 1.3.0 2023-11-28 CRAN (R 4.4.0)\n#> tibble 3.2.1 2023-03-20 CRAN (R 4.4.0)\n#> tidymodels 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> tune 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> workflows 1.2.0 2025-02-19 CRAN (R 4.4.1)\n#> yardstick 1.3.2 2025-01-22 CRAN (R 4.4.1)\n#> \n#> ────────────────────────────────────────────────────────────────────\n```\n:::\n", "supporting": [], "filters": [ "rmarkdown/pagebreak.lua" diff --git a/learn/work/nested-resampling/figs/choose-1.svg b/learn/work/nested-resampling/figs/choose-1.svg index ddca71d0..9a3c3761 100644 --- a/learn/work/nested-resampling/figs/choose-1.svg +++ b/learn/work/nested-resampling/figs/choose-1.svg @@ -30,12 +30,13 @@ - - - + + + - - + + + @@ -48,22 +49,23 @@ - - - - - - - + + + + + + 0 -10 -20 +10 +20 +30 - - + + + diff --git a/learn/work/nested-resampling/figs/rmse-plot-1.svg b/learn/work/nested-resampling/figs/rmse-plot-1.svg index 91e53428..14125612 100644 --- a/learn/work/nested-resampling/figs/rmse-plot-1.svg +++ b/learn/work/nested-resampling/figs/rmse-plot-1.svg @@ -30,126 +30,126 @@ - - - + + + - - + + - - - - - - - - - - - - - - - - - - - - - - 
- - - + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + -3.0 -3.5 - - +3.0 +3.5 + + diff --git a/learn/work/nested-resampling/index.html.md b/learn/work/nested-resampling/index.html.md index dd9ff905..dca8f49a 100644 --- a/learn/work/nested-resampling/index.html.md +++ b/learn/work/nested-resampling/index.html.md @@ -293,11 +293,11 @@ results <- summary(results$RMSE) #> Min. 1st Qu. Median Mean 3rd Qu. Max. -#> 1.574 2.095 2.688 2.697 3.265 4.350 +#> 1.574 2.095 2.668 2.683 3.252 4.350 ``` ::: -The estimated RMSE for the model tuning process is 2.7. +The estimated RMSE for the model tuning process is 2.68. What is the RMSE estimate for the non-nested procedure when only the outer resampling method is used? For each cost value in the tuning grid, 50 SVM models are fit and their RMSE values are averaged. The table of cost values and mean RMSE estimates is used to determine the best cost value. The associated RMSE is the biased estimate. diff --git a/learn/work/nested-resampling/index.qmd b/learn/work/nested-resampling/index.qmd index fa067167..37e67577 100644 --- a/learn/work/nested-resampling/index.qmd +++ b/learn/work/nested-resampling/index.qmd @@ -28,6 +28,8 @@ library(mlbench) library(kernlab) library(furrr) +set.seed(1234) + pkgs <- c("tidymodels", "scales", "mlbench", "kernlab", "furrr") theme_set(theme_bw() + theme(legend.position = "top")) diff --git a/learn/work/nested-resampling/index.rmarkdown b/learn/work/nested-resampling/index.rmarkdown new file mode 100644 index 00000000..7de0b15f --- /dev/null +++ b/learn/work/nested-resampling/index.rmarkdown @@ -0,0 +1,371 @@ +--- +title: "Nested resampling" +categories: + - SVMs +type: learn-subsection +weight: 2 +description: | + Estimate the best hyperparameters for a model using nested resampling. +toc: true +toc-depth: 2 +include-after-body: ../../../resources.html +--- + +```{r} +#| label: "setup" +#| include: false +#| message: false +#| warning: false +source(here::here("common.R")) +``` + +```{r} +#| label: "load" +#| include: false +library(tidymodels) +library(scales) +library(mlbench) +library(kernlab) +library(furrr) + +set.seed(1234) + +pkgs <- c("tidymodels", "scales", "mlbench", "kernlab", "furrr") + +theme_set(theme_bw() + theme(legend.position = "top")) +``` + + + +## Introduction + +`r article_req_pkgs(pkgs)` + +In this article, we discuss an alternative method for evaluating and tuning models, called [nested resampling](https://scholar.google.com/scholar?hl=en&as_sdt=0%2C7&q=%22nested+resampling%22+inner+outer&btnG=). While it is more computationally taxing and challenging to implement than other resampling methods, it has the potential to produce better estimates of model performance. + +## Resampling models + +A typical scheme for splitting the data when developing a predictive model is to create an initial split of the data into a training and test set. If resampling is used, it is executed on the training set. A series of binary splits is created. 
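+
+As a quick illustration of that initial split (a sketch only, using the built-in `mtcars` data as a stand-in rather than anything from this analysis):
+
+```{r}
+#| label: "initial-split-sketch"
+#| eval: false
+library(rsample)
+
+# 75% of the rows go to training, the remainder to the test set
+car_split <- initial_split(mtcars, prop = 0.75)
+car_train <- training(car_split)
+car_test  <- testing(car_split)
+```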
In rsample, we use the term *analysis set* for the data that are used to fit the model and the term *assessment set* for the set used to compute performance: + + + +```{r} +#| label: "resampling-fig" +#| echo: false +#| fig-align: center +#| out-width: "70%" +knitr::include_graphics("img/resampling.svg") +``` + + + +A common method for tuning models is [grid search](/learn/work/tune-svm/) where a candidate set of tuning parameters is created. The full set of models for every combination of the tuning parameter grid and the resamples is fitted. Each time, the assessment data are used to measure performance and the average value is determined for each tuning parameter. + +The potential problem is that once we pick the tuning parameter associated with the best performance, this performance value is usually quoted as the performance of the model. There is serious potential for *optimization bias* since we use the same data to tune the model and to assess performance. This would result in an optimistic estimate of performance. + +Nested resampling uses an additional layer of resampling that separates the tuning activities from the process used to estimate the efficacy of the model. An *outer* resampling scheme is used and, for every split in the outer resample, another full set of resampling splits are created on the original analysis set. For example, if 10-fold cross-validation is used on the outside and 5-fold cross-validation on the inside, a total of 500 models will be fit. The parameter tuning will be conducted 10 times and the best parameters are determined from the average of the 5 assessment sets. This process occurs 10 times. + +Once the tuning results are complete, a model is fit to each of the outer resampling splits using the best parameter associated with that resample. The average of the outer method's assessment sets are a unbiased estimate of the model. + +We will simulate some regression data to illustrate the methods. The mlbench package has a function `mlbench::mlbench.friedman1()` that can simulate a complex regression data structure from the [original MARS publication](https://scholar.google.com/scholar?hl=en&q=%22Multivariate+adaptive+regression+splines%22&btnG=&as_sdt=1%2C7&as_sdtp=). A training set size of 100 data points are generated as well as a large set that will be used to characterize how well the resampling procedure performed. + + + +```{r} +#| label: "sim-data" +library(mlbench) +sim_data <- function(n) { + tmp <- mlbench.friedman1(n, sd = 1) + tmp <- cbind(tmp$x, tmp$y) + tmp <- as.data.frame(tmp) + names(tmp)[ncol(tmp)] <- "y" + tmp +} + +set.seed(9815) +train_dat <- sim_data(100) +large_dat <- sim_data(10^5) +``` + + + +## Nested resampling + +To get started, the types of resampling methods need to be specified. This isn't a large data set, so 5 repeats of 10-fold cross validation will be used as the *outer* resampling method for generating the estimate of overall performance. To tune the model, it would be good to have precise estimates for each of the values of the tuning parameter so let's use 25 iterations of the bootstrap. This means that there will eventually be `5 * 10 * 25 = 1250` models that are fit to the data *per tuning parameter*. These models will be discarded once the performance of the model has been quantified. 
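+
+As a sanity check on that bookkeeping (a sketch; `results` is not created until the next chunk, hence `eval: false`):
+
+```{r}
+#| label: "fit-count-check"
+#| eval: false
+# 5 repeats x 10 folds = 50 outer resamples, each holding 25 bootstraps
+nrow(results)
+sum(purrr::map_int(results$inner_resamples, nrow))
+```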
+ +To create the tibble with the resampling specifications: + + + +```{r} +#| label: "tibble-gen" +library(tidymodels) +results <- nested_cv(train_dat, + outside = vfold_cv(repeats = 5), + inside = bootstraps(times = 25)) +results +``` + + + +The splitting information for each resample is contained in the `split` objects. Focusing on the second fold of the first repeat: + + + +```{r} +#| label: "split-example" +results$splits[[2]] +``` + + + +`<90/10/100>` indicates the number of observations in the analysis set, assessment set, and the original data. + +Each element of `inner_resamples` has its own tibble with the bootstrapping splits. + + + +```{r} +#| label: "inner-splits" +results$inner_resamples[[5]] +``` + + + +These are self-contained, meaning that the bootstrap sample is aware that it is a sample of a specific 90% of the data: + + + +```{r} +#| label: "inner-boot-split" +results$inner_resamples[[5]]$splits[[1]] +``` + + + +To start, we need to define how the model will be created and measured. Let's use a radial basis support vector machine model via the function `kernlab::ksvm`. This model is generally considered to have *two* tuning parameters: the SVM cost value and the kernel parameter `sigma`. For illustration purposes here, only the cost value will be tuned and the function `kernlab::sigest` will be used to estimate `sigma` during each model fit. This is automatically done by `ksvm`. + +After the model is fit to the analysis set, the root-mean squared error (RMSE) is computed on the assessment set. **One important note:** for this model, it is critical to center and scale the predictors before computing dot products. We don't do this operation here because `mlbench.friedman1` simulates all of the predictors to be standardized uniform random variables. + +Our function to fit the model and compute the RMSE is: + + + +```{r} +#| label: "rmse-func" +library(kernlab) + +# `object` will be an `rsplit` object from our `results` tibble +# `cost` is the tuning parameter +svm_rmse <- function(object, cost = 1) { + y_col <- ncol(object$data) + mod <- + svm_rbf(mode = "regression", cost = cost) %>% + set_engine("kernlab") %>% + fit(y ~ ., data = analysis(object)) + + holdout_pred <- + predict(mod, assessment(object) %>% dplyr::select(-y)) %>% + bind_cols(assessment(object) %>% dplyr::select(y)) + rmse(holdout_pred, truth = y, estimate = .pred)$.estimate +} + +# In some case, we want to parameterize the function over the tuning parameter: +rmse_wrapper <- function(cost, object) svm_rmse(object, cost) +``` + + + +For the nested resampling, a model needs to be fit for each tuning parameter and each bootstrap split. To do this, create a wrapper: + + + +```{r} +#| label: "inner-tune-func" +# `object` will be an `rsplit` object for the bootstrap samples +tune_over_cost <- function(object) { + tibble(cost = 2 ^ seq(-2, 8, by = 1)) %>% + mutate(RMSE = map_dbl(cost, rmse_wrapper, object = object)) +} +``` + + + +Since this will be called across the set of outer cross-validation splits, another wrapper is required: + + + +```{r} +#| label: "inner-func" +# `object` is an `rsplit` object in `results$inner_resamples` +summarize_tune_results <- function(object) { + # Return row-bound tibble that has the 25 bootstrap results + map_df(object$splits, tune_over_cost) %>% + # For each value of the tuning parameter, compute the + # average RMSE which is the inner bootstrap estimate. 
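+    # (Each of the 25 bootstrap splits contributes one RMSE per cost value,
+    # so every `cost` group below should end up with n = 25.)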
+ group_by(cost) %>% + summarize(mean_RMSE = mean(RMSE, na.rm = TRUE), + n = length(RMSE), + .groups = "drop") +} +``` + + + +Now that those functions are defined, we can execute all the inner resampling loops: + + + +```{r} +#| label: "inner-runs" +#| eval: false +tuning_results <- map(results$inner_resamples, summarize_tune_results) +``` + + + +Alternatively, since these computations can be run in parallel, we can use the furrr package. Instead of using `map()`, the function `future_map()` parallelizes the iterations using the [future package](https://cran.r-project.org/web/packages/future/vignettes/future-1-overview.html). The `multisession` plan uses the local cores to process the inner resampling loop. The end results are the same as the sequential computations. + + + +```{r} +#| label: "inner-runs-parallel" +#| warning: false +library(furrr) +plan(multisession) + +tuning_results <- future_map(results$inner_resamples, summarize_tune_results) +``` + + + +The object `tuning_results` is a list of data frames for each of the 50 outer resamples. + +Let's make a plot of the averaged results to see what the relationship is between the RMSE and the tuning parameters for each of the inner bootstrapping operations: + + + +```{r} +#| label: "rmse-plot" +#| fig-height: 4 +#| message: false +library(scales) + +pooled_inner <- tuning_results %>% bind_rows + +best_cost <- function(dat) dat[which.min(dat$mean_RMSE),] + +p <- + ggplot(pooled_inner, aes(x = cost, y = mean_RMSE)) + + scale_x_continuous(trans = 'log2') + + xlab("SVM Cost") + ylab("Inner RMSE") + +for (i in 1:length(tuning_results)) + p <- p + + geom_line(data = tuning_results[[i]], alpha = .2) + + geom_point(data = best_cost(tuning_results[[i]]), pch = 16, alpha = 3/4) + +p <- p + geom_smooth(data = pooled_inner, se = FALSE) +p +``` + + + +Each gray line is a separate bootstrap resampling curve created from a different 90% of the data. The blue line is a LOESS smooth of all the results pooled together. + +To determine the best parameter estimate for each of the outer resampling iterations: + + + +```{r} +#| label: "choose" +#| fig-height: 4 +cost_vals <- + tuning_results %>% + map_df(best_cost) %>% + select(cost) + +results <- + bind_cols(results, cost_vals) %>% + mutate(cost = factor(cost, levels = paste(2 ^ seq(-2, 8, by = 1)))) + +ggplot(results, aes(x = cost)) + + geom_bar() + + xlab("SVM Cost") + + scale_x_discrete(drop = FALSE) +``` + + + +Most of the resamples produced an optimal cost value of 2.0, but the distribution is right-skewed due to the flat trend in the resampling profile once the cost value becomes 10 or larger. + +Now that we have these estimates, we can compute the outer resampling results for each of the `r nrow(results)` splits using the corresponding tuning parameter value: + + + +```{r} +#| label: "run-out-r" +results <- + results %>% + mutate(RMSE = map2_dbl(splits, cost, svm_rmse)) + +summary(results$RMSE) +``` + + + +The estimated RMSE for the model tuning process is `r round(mean(results$RMSE), 2)`. + +What is the RMSE estimate for the non-nested procedure when only the outer resampling method is used? For each cost value in the tuning grid, `r nrow(results)` SVM models are fit and their RMSE values are averaged. The table of cost values and mean RMSE estimates is used to determine the best cost value. The associated RMSE is the biased estimate. 
+ + + +```{r} +#| label: "not-nested" +#| fig-height: 4 +not_nested <- + map(results$splits, tune_over_cost) %>% + bind_rows + +outer_summary <- not_nested %>% + group_by(cost) %>% + summarize(outer_RMSE = mean(RMSE), n = length(RMSE)) + +outer_summary + +ggplot(outer_summary, aes(x = cost, y = outer_RMSE)) + + geom_point() + + geom_line() + + scale_x_continuous(trans = 'log2') + + xlab("SVM Cost") + ylab("RMSE") +``` + + + +The non-nested procedure estimates the RMSE to be `r round(min(outer_summary$outer_RMSE), 2)`. Both estimates are fairly close. + +The approximately true RMSE for an SVM model with a cost value of 2.0 can be approximated with the large sample that was simulated at the beginning. + + + +```{r} +#| label: "large-sample-estimate" +finalModel <- ksvm(y ~ ., data = train_dat, C = 2) +large_pred <- predict(finalModel, large_dat[, -ncol(large_dat)]) +sqrt(mean((large_dat$y - large_pred) ^ 2, na.rm = TRUE)) +``` + + + +The nested procedure produces a closer estimate to the approximate truth but the non-nested estimate is very similar. + +## Session information {#session-info} + + + +```{r} +#| label: "si" +#| echo: false +small_session(pkgs) +``` + From d1bbd11f102d7efde4ba345b106cf8e3de81f98b Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 24 Mar 2025 12:41:15 -0700 Subject: [PATCH 6/7] set.seed in learn/work/nested-resampling --- .../index/execute-results/html.json | 4 +- .../work/nested-resampling/figs/choose-1.svg | 7 +- .../nested-resampling/figs/not-nested-1.svg | 60 +-- .../nested-resampling/figs/rmse-plot-1.svg | 218 +++++----- learn/work/nested-resampling/index.html.md | 25 +- learn/work/nested-resampling/index.qmd | 1 + learn/work/nested-resampling/index.rmarkdown | 371 ------------------ 7 files changed, 157 insertions(+), 529 deletions(-) delete mode 100644 learn/work/nested-resampling/index.rmarkdown diff --git a/_freeze/learn/work/nested-resampling/index/execute-results/html.json b/_freeze/learn/work/nested-resampling/index/execute-results/html.json index 2fc7e86b..ac703001 100644 --- a/_freeze/learn/work/nested-resampling/index/execute-results/html.json +++ b/_freeze/learn/work/nested-resampling/index/execute-results/html.json @@ -1,8 +1,8 @@ { - "hash": "07e8f62215dc10f0d38750bd159b4bc0", + "hash": "fa56727b4a316d11d23834f0388c1cfa", "result": { "engine": "knitr", - "markdown": "---\ntitle: \"Nested resampling\"\ncategories:\n - SVMs\ntype: learn-subsection\nweight: 2\ndescription: | \n Estimate the best hyperparameters for a model using nested resampling.\ntoc: true\ntoc-depth: 2\ninclude-after-body: ../../../resources.html\n---\n\n\n\n\n\n\n\n\n## Introduction\n\nTo use code in this article, you will need to install the following packages: furrr, kernlab, mlbench, scales, and tidymodels.\n\nIn this article, we discuss an alternative method for evaluating and tuning models, called [nested resampling](https://scholar.google.com/scholar?hl=en&as_sdt=0%2C7&q=%22nested+resampling%22+inner+outer&btnG=). While it is more computationally taxing and challenging to implement than other resampling methods, it has the potential to produce better estimates of model performance.\n\n## Resampling models\n\nA typical scheme for splitting the data when developing a predictive model is to create an initial split of the data into a training and test set. If resampling is used, it is executed on the training set. A series of binary splits is created. 
In rsample, we use the term *analysis set* for the data that are used to fit the model and the term *assessment set* for the set used to compute performance:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n::: {.cell-output-display}\n![](img/resampling.svg){fig-align='center' width=70%}\n:::\n:::\n\n\n\n\nA common method for tuning models is [grid search](/learn/work/tune-svm/) where a candidate set of tuning parameters is created. The full set of models for every combination of the tuning parameter grid and the resamples is fitted. Each time, the assessment data are used to measure performance and the average value is determined for each tuning parameter.\n\nThe potential problem is that once we pick the tuning parameter associated with the best performance, this performance value is usually quoted as the performance of the model. There is serious potential for *optimization bias* since we use the same data to tune the model and to assess performance. This would result in an optimistic estimate of performance.\n\nNested resampling uses an additional layer of resampling that separates the tuning activities from the process used to estimate the efficacy of the model. An *outer* resampling scheme is used and, for every split in the outer resample, another full set of resampling splits are created on the original analysis set. For example, if 10-fold cross-validation is used on the outside and 5-fold cross-validation on the inside, a total of 500 models will be fit. The parameter tuning will be conducted 10 times and the best parameters are determined from the average of the 5 assessment sets. This process occurs 10 times.\n\nOnce the tuning results are complete, a model is fit to each of the outer resampling splits using the best parameter associated with that resample. The average of the outer method's assessment sets are a unbiased estimate of the model.\n\nWe will simulate some regression data to illustrate the methods. The mlbench package has a function `mlbench::mlbench.friedman1()` that can simulate a complex regression data structure from the [original MARS publication](https://scholar.google.com/scholar?hl=en&q=%22Multivariate+adaptive+regression+splines%22&btnG=&as_sdt=1%2C7&as_sdtp=). A training set size of 100 data points are generated as well as a large set that will be used to characterize how well the resampling procedure performed.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(mlbench)\nsim_data <- function(n) {\n tmp <- mlbench.friedman1(n, sd = 1)\n tmp <- cbind(tmp$x, tmp$y)\n tmp <- as.data.frame(tmp)\n names(tmp)[ncol(tmp)] <- \"y\"\n tmp\n}\n\nset.seed(9815)\ntrain_dat <- sim_data(100)\nlarge_dat <- sim_data(10^5)\n```\n:::\n\n\n\n\n## Nested resampling\n\nTo get started, the types of resampling methods need to be specified. This isn't a large data set, so 5 repeats of 10-fold cross validation will be used as the *outer* resampling method for generating the estimate of overall performance. To tune the model, it would be good to have precise estimates for each of the values of the tuning parameter so let's use 25 iterations of the bootstrap. This means that there will eventually be `5 * 10 * 25 = 1250` models that are fit to the data *per tuning parameter*. 
These models will be discarded once the performance of the model has been quantified.\n\nTo create the tibble with the resampling specifications:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(tidymodels)\nresults <- nested_cv(train_dat, \n outside = vfold_cv(repeats = 5), \n inside = bootstraps(times = 25))\nresults\n#> # Nested resampling:\n#> # outer: 10-fold cross-validation repeated 5 times\n#> # inner: Bootstrap sampling\n#> # A tibble: 50 × 4\n#> splits id id2 inner_resamples\n#> \n#> 1 Repeat1 Fold01 \n#> 2 Repeat1 Fold02 \n#> 3 Repeat1 Fold03 \n#> 4 Repeat1 Fold04 \n#> 5 Repeat1 Fold05 \n#> 6 Repeat1 Fold06 \n#> 7 Repeat1 Fold07 \n#> 8 Repeat1 Fold08 \n#> 9 Repeat1 Fold09 \n#> 10 Repeat1 Fold10 \n#> # ℹ 40 more rows\n```\n:::\n\n\n\n\nThe splitting information for each resample is contained in the `split` objects. Focusing on the second fold of the first repeat:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nresults$splits[[2]]\n#> \n#> <90/10/100>\n```\n:::\n\n\n\n\n`<90/10/100>` indicates the number of observations in the analysis set, assessment set, and the original data.\n\nEach element of `inner_resamples` has its own tibble with the bootstrapping splits.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nresults$inner_resamples[[5]]\n#> # Bootstrap sampling \n#> # A tibble: 25 × 2\n#> splits id \n#> \n#> 1 Bootstrap01\n#> 2 Bootstrap02\n#> 3 Bootstrap03\n#> 4 Bootstrap04\n#> 5 Bootstrap05\n#> 6 Bootstrap06\n#> 7 Bootstrap07\n#> 8 Bootstrap08\n#> 9 Bootstrap09\n#> 10 Bootstrap10\n#> # ℹ 15 more rows\n```\n:::\n\n\n\n\nThese are self-contained, meaning that the bootstrap sample is aware that it is a sample of a specific 90% of the data:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nresults$inner_resamples[[5]]$splits[[1]]\n#> \n#> <90/31/90>\n```\n:::\n\n\n\n\nTo start, we need to define how the model will be created and measured. Let's use a radial basis support vector machine model via the function `kernlab::ksvm`. This model is generally considered to have *two* tuning parameters: the SVM cost value and the kernel parameter `sigma`. For illustration purposes here, only the cost value will be tuned and the function `kernlab::sigest` will be used to estimate `sigma` during each model fit. This is automatically done by `ksvm`.\n\nAfter the model is fit to the analysis set, the root-mean squared error (RMSE) is computed on the assessment set. **One important note:** for this model, it is critical to center and scale the predictors before computing dot products. 
We don't do this operation here because `mlbench.friedman1` simulates all of the predictors to be standardized uniform random variables.\n\nOur function to fit the model and compute the RMSE is:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(kernlab)\n\n# `object` will be an `rsplit` object from our `results` tibble\n# `cost` is the tuning parameter\nsvm_rmse <- function(object, cost = 1) {\n y_col <- ncol(object$data)\n mod <- \n svm_rbf(mode = \"regression\", cost = cost) %>% \n set_engine(\"kernlab\") %>% \n fit(y ~ ., data = analysis(object))\n \n holdout_pred <- \n predict(mod, assessment(object) %>% dplyr::select(-y)) %>% \n bind_cols(assessment(object) %>% dplyr::select(y))\n rmse(holdout_pred, truth = y, estimate = .pred)$.estimate\n}\n\n# In some case, we want to parameterize the function over the tuning parameter:\nrmse_wrapper <- function(cost, object) svm_rmse(object, cost)\n```\n:::\n\n\n\n\nFor the nested resampling, a model needs to be fit for each tuning parameter and each bootstrap split. To do this, create a wrapper:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# `object` will be an `rsplit` object for the bootstrap samples\ntune_over_cost <- function(object) {\n tibble(cost = 2 ^ seq(-2, 8, by = 1)) %>% \n mutate(RMSE = map_dbl(cost, rmse_wrapper, object = object))\n}\n```\n:::\n\n\n\n\nSince this will be called across the set of outer cross-validation splits, another wrapper is required:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# `object` is an `rsplit` object in `results$inner_resamples` \nsummarize_tune_results <- function(object) {\n # Return row-bound tibble that has the 25 bootstrap results\n map_df(object$splits, tune_over_cost) %>%\n # For each value of the tuning parameter, compute the \n # average RMSE which is the inner bootstrap estimate. \n group_by(cost) %>%\n summarize(mean_RMSE = mean(RMSE, na.rm = TRUE),\n n = length(RMSE),\n .groups = \"drop\")\n}\n```\n:::\n\n\n\n\nNow that those functions are defined, we can execute all the inner resampling loops:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ntuning_results <- map(results$inner_resamples, summarize_tune_results) \n```\n:::\n\n\n\n\nAlternatively, since these computations can be run in parallel, we can use the furrr package. Instead of using `map()`, the function `future_map()` parallelizes the iterations using the [future package](https://cran.r-project.org/web/packages/future/vignettes/future-1-overview.html). The `multisession` plan uses the local cores to process the inner resampling loop. 
The end results are the same as the sequential computations.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(furrr)\nplan(multisession)\n\ntuning_results <- future_map(results$inner_resamples, summarize_tune_results) \n```\n:::\n\n\n\n\nThe object `tuning_results` is a list of data frames for each of the 50 outer resamples.\n\nLet's make a plot of the averaged results to see what the relationship is between the RMSE and the tuning parameters for each of the inner bootstrapping operations:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(scales)\n\npooled_inner <- tuning_results %>% bind_rows\n\nbest_cost <- function(dat) dat[which.min(dat$mean_RMSE),]\n\np <- \n ggplot(pooled_inner, aes(x = cost, y = mean_RMSE)) + \n scale_x_continuous(trans = 'log2') +\n xlab(\"SVM Cost\") + ylab(\"Inner RMSE\")\n\nfor (i in 1:length(tuning_results))\n p <- p +\n geom_line(data = tuning_results[[i]], alpha = .2) +\n geom_point(data = best_cost(tuning_results[[i]]), pch = 16, alpha = 3/4)\n\np <- p + geom_smooth(data = pooled_inner, se = FALSE)\np\n```\n\n::: {.cell-output-display}\n![](figs/rmse-plot-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nEach gray line is a separate bootstrap resampling curve created from a different 90% of the data. The blue line is a LOESS smooth of all the results pooled together.\n\nTo determine the best parameter estimate for each of the outer resampling iterations:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ncost_vals <- \n tuning_results %>% \n map_df(best_cost) %>% \n select(cost)\n\nresults <- \n bind_cols(results, cost_vals) %>% \n mutate(cost = factor(cost, levels = paste(2 ^ seq(-2, 8, by = 1))))\n\nggplot(results, aes(x = cost)) + \n geom_bar() + \n xlab(\"SVM Cost\") + \n scale_x_discrete(drop = FALSE)\n```\n\n::: {.cell-output-display}\n![](figs/choose-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nMost of the resamples produced an optimal cost value of 2.0, but the distribution is right-skewed due to the flat trend in the resampling profile once the cost value becomes 10 or larger.\n\nNow that we have these estimates, we can compute the outer resampling results for each of the 50 splits using the corresponding tuning parameter value:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nresults <- \n results %>% \n mutate(RMSE = map2_dbl(splits, cost, svm_rmse))\n\nsummary(results$RMSE)\n#> Min. 1st Qu. Median Mean 3rd Qu. Max. \n#> 1.574 2.095 2.668 2.683 3.252 4.350\n```\n:::\n\n\n\n\nThe estimated RMSE for the model tuning process is 2.68.\n\nWhat is the RMSE estimate for the non-nested procedure when only the outer resampling method is used? For each cost value in the tuning grid, 50 SVM models are fit and their RMSE values are averaged. The table of cost values and mean RMSE estimates is used to determine the best cost value. 
The associated RMSE is the biased estimate.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nnot_nested <- \n map(results$splits, tune_over_cost) %>%\n bind_rows\n\nouter_summary <- not_nested %>% \n group_by(cost) %>% \n summarize(outer_RMSE = mean(RMSE), n = length(RMSE))\n\nouter_summary\n#> # A tibble: 11 × 3\n#> cost outer_RMSE n\n#> \n#> 1 0.25 3.54 50\n#> 2 0.5 3.11 50\n#> 3 1 2.77 50\n#> 4 2 2.62 50\n#> 5 4 2.65 50\n#> 6 8 2.75 50\n#> 7 16 2.82 50\n#> 8 32 2.82 50\n#> 9 64 2.83 50\n#> 10 128 2.83 50\n#> 11 256 2.82 50\n\nggplot(outer_summary, aes(x = cost, y = outer_RMSE)) + \n geom_point() + \n geom_line() + \n scale_x_continuous(trans = 'log2') +\n xlab(\"SVM Cost\") + ylab(\"RMSE\")\n```\n\n::: {.cell-output-display}\n![](figs/not-nested-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nThe non-nested procedure estimates the RMSE to be 2.62. Both estimates are fairly close.\n\nThe approximately true RMSE for an SVM model with a cost value of 2.0 can be approximated with the large sample that was simulated at the beginning.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nfinalModel <- ksvm(y ~ ., data = train_dat, C = 2)\nlarge_pred <- predict(finalModel, large_dat[, -ncol(large_dat)])\nsqrt(mean((large_dat$y - large_pred) ^ 2, na.rm = TRUE))\n#> [1] 2.712059\n```\n:::\n\n\n\n\nThe nested procedure produces a closer estimate to the approximate truth but the non-nested estimate is very similar.\n\n## Session information {#session-info}\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```\n#> ─ Session info ─────────────────────────────────────────────────────\n#> version R version 4.4.2 (2024-10-31)\n#> language (EN)\n#> date 2025-03-24\n#> pandoc 3.6.1\n#> quarto 1.6.42\n#> \n#> ─ Packages ─────────────────────────────────────────────────────────\n#> package version date (UTC) source\n#> broom 1.0.7 2024-09-26 CRAN (R 4.4.1)\n#> dials 1.4.0 2025-02-13 CRAN (R 4.4.2)\n#> dplyr 1.1.4 2023-11-17 CRAN (R 4.4.0)\n#> furrr 0.3.1 2022-08-15 CRAN (R 4.4.0)\n#> ggplot2 3.5.1 2024-04-23 CRAN (R 4.4.0)\n#> infer 1.0.7 2024-03-25 CRAN (R 4.4.0)\n#> kernlab 0.9-33 2024-08-13 CRAN (R 4.4.0)\n#> mlbench 2.1-6 2024-12-30 CRAN (R 4.4.1)\n#> parsnip 1.3.1 2025-03-12 CRAN (R 4.4.1)\n#> purrr 1.0.4 2025-02-05 CRAN (R 4.4.1)\n#> recipes 1.2.0 2025-03-17 CRAN (R 4.4.1)\n#> rlang 1.1.5 2025-01-17 CRAN (R 4.4.2)\n#> rsample 1.2.1 2024-03-25 CRAN (R 4.4.0)\n#> scales 1.3.0 2023-11-28 CRAN (R 4.4.0)\n#> tibble 3.2.1 2023-03-20 CRAN (R 4.4.0)\n#> tidymodels 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> tune 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> workflows 1.2.0 2025-02-19 CRAN (R 4.4.1)\n#> yardstick 1.3.2 2025-01-22 CRAN (R 4.4.1)\n#> \n#> ────────────────────────────────────────────────────────────────────\n```\n:::\n", + "markdown": "---\ntitle: \"Nested resampling\"\ncategories:\n - SVMs\ntype: learn-subsection\nweight: 2\ndescription: | \n Estimate the best hyperparameters for a model using nested resampling.\ntoc: true\ntoc-depth: 2\ninclude-after-body: ../../../resources.html\n---\n\n\n\n\n\n\n\n\n## Introduction\n\nTo use code in this article, you will need to install the following packages: furrr, kernlab, mlbench, scales, and tidymodels.\n\nIn this article, we discuss an alternative method for evaluating and tuning models, called [nested resampling](https://scholar.google.com/scholar?hl=en&as_sdt=0%2C7&q=%22nested+resampling%22+inner+outer&btnG=). 
While it is more computationally taxing and challenging to implement than other resampling methods, it has the potential to produce better estimates of model performance.\n\n## Resampling models\n\nA typical scheme for splitting the data when developing a predictive model is to create an initial split of the data into a training and test set. If resampling is used, it is executed on the training set. A series of binary splits is created. In rsample, we use the term *analysis set* for the data that are used to fit the model and the term *assessment set* for the set used to compute performance:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n::: {.cell-output-display}\n![](img/resampling.svg){fig-align='center' width=70%}\n:::\n:::\n\n\n\n\nA common method for tuning models is [grid search](/learn/work/tune-svm/) where a candidate set of tuning parameters is created. The full set of models for every combination of the tuning parameter grid and the resamples is fitted. Each time, the assessment data are used to measure performance and the average value is determined for each tuning parameter.\n\nThe potential problem is that once we pick the tuning parameter associated with the best performance, this performance value is usually quoted as the performance of the model. There is serious potential for *optimization bias* since we use the same data to tune the model and to assess performance. This would result in an optimistic estimate of performance.\n\nNested resampling uses an additional layer of resampling that separates the tuning activities from the process used to estimate the efficacy of the model. An *outer* resampling scheme is used and, for every split in the outer resample, another full set of resampling splits are created on the original analysis set. For example, if 10-fold cross-validation is used on the outside and 5-fold cross-validation on the inside, a total of 500 models will be fit. The parameter tuning will be conducted 10 times and the best parameters are determined from the average of the 5 assessment sets. This process occurs 10 times.\n\nOnce the tuning results are complete, a model is fit to each of the outer resampling splits using the best parameter associated with that resample. The average of the outer method's assessment sets are a unbiased estimate of the model.\n\nWe will simulate some regression data to illustrate the methods. The mlbench package has a function `mlbench::mlbench.friedman1()` that can simulate a complex regression data structure from the [original MARS publication](https://scholar.google.com/scholar?hl=en&q=%22Multivariate+adaptive+regression+splines%22&btnG=&as_sdt=1%2C7&as_sdtp=). A training set size of 100 data points are generated as well as a large set that will be used to characterize how well the resampling procedure performed.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(mlbench)\nsim_data <- function(n) {\n tmp <- mlbench.friedman1(n, sd = 1)\n tmp <- cbind(tmp$x, tmp$y)\n tmp <- as.data.frame(tmp)\n names(tmp)[ncol(tmp)] <- \"y\"\n tmp\n}\n\nset.seed(9815)\ntrain_dat <- sim_data(100)\nlarge_dat <- sim_data(10^5)\n```\n:::\n\n\n\n\n## Nested resampling\n\nTo get started, the types of resampling methods need to be specified. This isn't a large data set, so 5 repeats of 10-fold cross validation will be used as the *outer* resampling method for generating the estimate of overall performance. 
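Viewed on its own, the outer scheme is plain repeated V-fold cross-validation; just as an illustration (not needed for the analysis itself):

```r
# 10 folds x 5 repeats = 50 outer resamples
vfold_cv(train_dat, v = 10, repeats = 5)
```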
To tune the model, it would be good to have precise estimates for each of the values of the tuning parameter so let's use 25 iterations of the bootstrap. This means that there will eventually be `5 * 10 * 25 = 1250` models that are fit to the data *per tuning parameter*. These models will be discarded once the performance of the model has been quantified.\n\nTo create the tibble with the resampling specifications:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(tidymodels)\nresults <- nested_cv(train_dat, \n outside = vfold_cv(repeats = 5), \n inside = bootstraps(times = 25))\nresults\n#> # Nested resampling:\n#> # outer: 10-fold cross-validation repeated 5 times\n#> # inner: Bootstrap sampling\n#> # A tibble: 50 × 4\n#> splits id id2 inner_resamples\n#> \n#> 1 Repeat1 Fold01 \n#> 2 Repeat1 Fold02 \n#> 3 Repeat1 Fold03 \n#> 4 Repeat1 Fold04 \n#> 5 Repeat1 Fold05 \n#> 6 Repeat1 Fold06 \n#> 7 Repeat1 Fold07 \n#> 8 Repeat1 Fold08 \n#> 9 Repeat1 Fold09 \n#> 10 Repeat1 Fold10 \n#> # ℹ 40 more rows\n```\n:::\n\n\n\n\nThe splitting information for each resample is contained in the `split` objects. Focusing on the second fold of the first repeat:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nresults$splits[[2]]\n#> \n#> <90/10/100>\n```\n:::\n\n\n\n\n`<90/10/100>` indicates the number of observations in the analysis set, assessment set, and the original data.\n\nEach element of `inner_resamples` has its own tibble with the bootstrapping splits.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nresults$inner_resamples[[5]]\n#> # Bootstrap sampling \n#> # A tibble: 25 × 2\n#> splits id \n#> \n#> 1 Bootstrap01\n#> 2 Bootstrap02\n#> 3 Bootstrap03\n#> 4 Bootstrap04\n#> 5 Bootstrap05\n#> 6 Bootstrap06\n#> 7 Bootstrap07\n#> 8 Bootstrap08\n#> 9 Bootstrap09\n#> 10 Bootstrap10\n#> # ℹ 15 more rows\n```\n:::\n\n\n\n\nThese are self-contained, meaning that the bootstrap sample is aware that it is a sample of a specific 90% of the data:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nresults$inner_resamples[[5]]$splits[[1]]\n#> \n#> <90/31/90>\n```\n:::\n\n\n\n\nTo start, we need to define how the model will be created and measured. Let's use a radial basis support vector machine model via the function `kernlab::ksvm`. This model is generally considered to have *two* tuning parameters: the SVM cost value and the kernel parameter `sigma`. For illustration purposes here, only the cost value will be tuned and the function `kernlab::sigest` will be used to estimate `sigma` during each model fit. This is automatically done by `ksvm`.\n\nAfter the model is fit to the analysis set, the root-mean squared error (RMSE) is computed on the assessment set. **One important note:** for this model, it is critical to center and scale the predictors before computing dot products. 
We don't do this operation here because `mlbench.friedman1` simulates all of the predictors to be standardized uniform random variables.\n\nOur function to fit the model and compute the RMSE is:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(kernlab)\n\n# `object` will be an `rsplit` object from our `results` tibble\n# `cost` is the tuning parameter\nsvm_rmse <- function(object, cost = 1) {\n set.seed(1234)\n y_col <- ncol(object$data)\n mod <- \n svm_rbf(mode = \"regression\", cost = cost) %>% \n set_engine(\"kernlab\") %>% \n fit(y ~ ., data = analysis(object))\n \n holdout_pred <- \n predict(mod, assessment(object) %>% dplyr::select(-y)) %>% \n bind_cols(assessment(object) %>% dplyr::select(y))\n rmse(holdout_pred, truth = y, estimate = .pred)$.estimate\n}\n\n# In some case, we want to parameterize the function over the tuning parameter:\nrmse_wrapper <- function(cost, object) svm_rmse(object, cost)\n```\n:::\n\n\n\n\nFor the nested resampling, a model needs to be fit for each tuning parameter and each bootstrap split. To do this, create a wrapper:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# `object` will be an `rsplit` object for the bootstrap samples\ntune_over_cost <- function(object) {\n tibble(cost = 2 ^ seq(-2, 8, by = 1)) %>% \n mutate(RMSE = map_dbl(cost, rmse_wrapper, object = object))\n}\n```\n:::\n\n\n\n\nSince this will be called across the set of outer cross-validation splits, another wrapper is required:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# `object` is an `rsplit` object in `results$inner_resamples` \nsummarize_tune_results <- function(object) {\n # Return row-bound tibble that has the 25 bootstrap results\n map_df(object$splits, tune_over_cost) %>%\n # For each value of the tuning parameter, compute the \n # average RMSE which is the inner bootstrap estimate. \n group_by(cost) %>%\n summarize(mean_RMSE = mean(RMSE, na.rm = TRUE),\n n = length(RMSE),\n .groups = \"drop\")\n}\n```\n:::\n\n\n\n\nNow that those functions are defined, we can execute all the inner resampling loops:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ntuning_results <- map(results$inner_resamples, summarize_tune_results) \n```\n:::\n\n\n\n\nAlternatively, since these computations can be run in parallel, we can use the furrr package. Instead of using `map()`, the function `future_map()` parallelizes the iterations using the [future package](https://cran.r-project.org/web/packages/future/vignettes/future-1-overview.html). The `multisession` plan uses the local cores to process the inner resampling loop. 
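Once the parallel work is finished, the sequential plan can be restored (optional housekeeping, shown only as a sketch):

```r
# Return to sequential processing when the parallel section is done
plan(sequential)
```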
The end results are the same as the sequential computations.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(furrr)\nplan(multisession)\n\ntuning_results <- future_map(results$inner_resamples, summarize_tune_results) \n```\n:::\n\n\n\n\nThe object `tuning_results` is a list of data frames for each of the 50 outer resamples.\n\nLet's make a plot of the averaged results to see what the relationship is between the RMSE and the tuning parameters for each of the inner bootstrapping operations:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(scales)\n\npooled_inner <- tuning_results %>% bind_rows\n\nbest_cost <- function(dat) dat[which.min(dat$mean_RMSE),]\n\np <- \n ggplot(pooled_inner, aes(x = cost, y = mean_RMSE)) + \n scale_x_continuous(trans = 'log2') +\n xlab(\"SVM Cost\") + ylab(\"Inner RMSE\")\n\nfor (i in 1:length(tuning_results))\n p <- p +\n geom_line(data = tuning_results[[i]], alpha = .2) +\n geom_point(data = best_cost(tuning_results[[i]]), pch = 16, alpha = 3/4)\n\np <- p + geom_smooth(data = pooled_inner, se = FALSE)\np\n```\n\n::: {.cell-output-display}\n![](figs/rmse-plot-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nEach gray line is a separate bootstrap resampling curve created from a different 90% of the data. The blue line is a LOESS smooth of all the results pooled together.\n\nTo determine the best parameter estimate for each of the outer resampling iterations:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ncost_vals <- \n tuning_results %>% \n map_df(best_cost) %>% \n select(cost)\n\nresults <- \n bind_cols(results, cost_vals) %>% \n mutate(cost = factor(cost, levels = paste(2 ^ seq(-2, 8, by = 1))))\n\nggplot(results, aes(x = cost)) + \n geom_bar() + \n xlab(\"SVM Cost\") + \n scale_x_discrete(drop = FALSE)\n```\n\n::: {.cell-output-display}\n![](figs/choose-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nMost of the resamples produced an optimal cost value of 2.0, but the distribution is right-skewed due to the flat trend in the resampling profile once the cost value becomes 10 or larger.\n\nNow that we have these estimates, we can compute the outer resampling results for each of the 50 splits using the corresponding tuning parameter value:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nresults <- \n results %>% \n mutate(RMSE = map2_dbl(splits, cost, svm_rmse))\n\nsummary(results$RMSE)\n#> Min. 1st Qu. Median Mean 3rd Qu. Max. \n#> 1.676 2.090 2.589 2.682 3.220 4.228\n```\n:::\n\n\n\n\nThe estimated RMSE for the model tuning process is 2.68.\n\nWhat is the RMSE estimate for the non-nested procedure when only the outer resampling method is used? For each cost value in the tuning grid, 50 SVM models are fit and their RMSE values are averaged. The table of cost values and mean RMSE estimates is used to determine the best cost value. 
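Concretely, once `outer_summary` is computed in the next code block, that lookup is just the row with the smallest mean RMSE (sketch only):

```r
# Assumes `outer_summary` from the code block that follows
outer_summary %>% slice_min(outer_RMSE, n = 1)
```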
The associated RMSE is the biased estimate.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nnot_nested <- \n map(results$splits, tune_over_cost) %>%\n bind_rows\n\nouter_summary <- not_nested %>% \n group_by(cost) %>% \n summarize(outer_RMSE = mean(RMSE), n = length(RMSE))\n\nouter_summary\n#> # A tibble: 11 × 3\n#> cost outer_RMSE n\n#> \n#> 1 0.25 3.54 50\n#> 2 0.5 3.11 50\n#> 3 1 2.78 50\n#> 4 2 2.63 50\n#> 5 4 2.66 50\n#> 6 8 2.78 50\n#> 7 16 2.84 50\n#> 8 32 2.84 50\n#> 9 64 2.84 50\n#> 10 128 2.84 50\n#> 11 256 2.84 50\n\nggplot(outer_summary, aes(x = cost, y = outer_RMSE)) + \n geom_point() + \n geom_line() + \n scale_x_continuous(trans = 'log2') +\n xlab(\"SVM Cost\") + ylab(\"RMSE\")\n```\n\n::: {.cell-output-display}\n![](figs/not-nested-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nThe non-nested procedure estimates the RMSE to be 2.63. Both estimates are fairly close.\n\nThe approximately true RMSE for an SVM model with a cost value of 2.0 can be approximated with the large sample that was simulated at the beginning.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nfinalModel <- ksvm(y ~ ., data = train_dat, C = 2)\nlarge_pred <- predict(finalModel, large_dat[, -ncol(large_dat)])\nsqrt(mean((large_dat$y - large_pred) ^ 2, na.rm = TRUE))\n#> [1] 2.695091\n```\n:::\n\n\n\n\nThe nested procedure produces a closer estimate to the approximate truth but the non-nested estimate is very similar.\n\n## Session information {#session-info}\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```\n#> ─ Session info ─────────────────────────────────────────────────────\n#> version R version 4.4.2 (2024-10-31)\n#> language (EN)\n#> date 2025-03-24\n#> pandoc 3.6.1\n#> quarto 1.6.42\n#> \n#> ─ Packages ─────────────────────────────────────────────────────────\n#> package version date (UTC) source\n#> broom 1.0.7 2024-09-26 CRAN (R 4.4.1)\n#> dials 1.4.0 2025-02-13 CRAN (R 4.4.2)\n#> dplyr 1.1.4 2023-11-17 CRAN (R 4.4.0)\n#> furrr 0.3.1 2022-08-15 CRAN (R 4.4.0)\n#> ggplot2 3.5.1 2024-04-23 CRAN (R 4.4.0)\n#> infer 1.0.7 2024-03-25 CRAN (R 4.4.0)\n#> kernlab 0.9-33 2024-08-13 CRAN (R 4.4.0)\n#> mlbench 2.1-6 2024-12-30 CRAN (R 4.4.1)\n#> parsnip 1.3.1 2025-03-12 CRAN (R 4.4.1)\n#> purrr 1.0.4 2025-02-05 CRAN (R 4.4.1)\n#> recipes 1.2.0 2025-03-17 CRAN (R 4.4.1)\n#> rlang 1.1.5 2025-01-17 CRAN (R 4.4.2)\n#> rsample 1.2.1 2024-03-25 CRAN (R 4.4.0)\n#> scales 1.3.0 2023-11-28 CRAN (R 4.4.0)\n#> tibble 3.2.1 2023-03-20 CRAN (R 4.4.0)\n#> tidymodels 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> tune 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> workflows 1.2.0 2025-02-19 CRAN (R 4.4.1)\n#> yardstick 1.3.2 2025-01-22 CRAN (R 4.4.1)\n#> \n#> ────────────────────────────────────────────────────────────────────\n```\n:::\n", "supporting": [], "filters": [ "rmarkdown/pagebreak.lua" diff --git a/learn/work/nested-resampling/figs/choose-1.svg b/learn/work/nested-resampling/figs/choose-1.svg index 9a3c3761..1fbe8e44 100644 --- a/learn/work/nested-resampling/figs/choose-1.svg +++ b/learn/work/nested-resampling/figs/choose-1.svg @@ -50,11 +50,8 @@ - - - - - + + diff --git a/learn/work/nested-resampling/figs/not-nested-1.svg b/learn/work/nested-resampling/figs/not-nested-1.svg index 8c272c40..a67583b1 100644 --- a/learn/work/nested-resampling/figs/not-nested-1.svg +++ b/learn/work/nested-resampling/figs/not-nested-1.svg @@ -30,48 +30,48 @@ - - - - - + + + + + - - - - - + + + + + - - + + - - - - - - - - + + + + + + + + -2.6 -2.8 -3.0 -3.2 -3.4 - - - - - +2.6 +2.8 +3.0 +3.2 +3.4 + + + + 
+ diff --git a/learn/work/nested-resampling/figs/rmse-plot-1.svg b/learn/work/nested-resampling/figs/rmse-plot-1.svg index 14125612..be5c624a 100644 --- a/learn/work/nested-resampling/figs/rmse-plot-1.svg +++ b/learn/work/nested-resampling/figs/rmse-plot-1.svg @@ -30,126 +30,126 @@ - - - + + + - - + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + -3.0 -3.5 - - +3.0 +3.5 + + diff --git a/learn/work/nested-resampling/index.html.md b/learn/work/nested-resampling/index.html.md index dca8f49a..e0ffa868 100644 --- a/learn/work/nested-resampling/index.html.md +++ b/learn/work/nested-resampling/index.html.md @@ -151,6 +151,7 @@ library(kernlab) # `object` will be an `rsplit` object from our `results` tibble # `cost` is the tuning parameter svm_rmse <- function(object, cost = 1) { + set.seed(1234) y_col <- ncol(object$data) mod <- svm_rbf(mode = "regression", cost = cost) %>% @@ -293,7 +294,7 @@ results <- summary(results$RMSE) #> Min. 1st Qu. Median Mean 3rd Qu. Max. -#> 1.574 2.095 2.668 2.683 3.252 4.350 +#> 1.676 2.090 2.589 2.682 3.220 4.228 ``` ::: @@ -318,15 +319,15 @@ outer_summary #> #> 1 0.25 3.54 50 #> 2 0.5 3.11 50 -#> 3 1 2.77 50 -#> 4 2 2.62 50 -#> 5 4 2.65 50 -#> 6 8 2.75 50 -#> 7 16 2.82 50 -#> 8 32 2.82 50 -#> 9 64 2.83 50 -#> 10 128 2.83 50 -#> 11 256 2.82 50 +#> 3 1 2.78 50 +#> 4 2 2.63 50 +#> 5 4 2.66 50 +#> 6 8 2.78 50 +#> 7 16 2.84 50 +#> 8 32 2.84 50 +#> 9 64 2.84 50 +#> 10 128 2.84 50 +#> 11 256 2.84 50 ggplot(outer_summary, aes(x = cost, y = outer_RMSE)) + geom_point() + @@ -340,7 +341,7 @@ ggplot(outer_summary, aes(x = cost, y = outer_RMSE)) + ::: ::: -The non-nested procedure estimates the RMSE to be 2.62. Both estimates are fairly close. +The non-nested procedure estimates the RMSE to be 2.63. Both estimates are fairly close. The approximately true RMSE for an SVM model with a cost value of 2.0 can be approximated with the large sample that was simulated at the beginning. @@ -350,7 +351,7 @@ The approximately true RMSE for an SVM model with a cost value of 2.0 can be app finalModel <- ksvm(y ~ ., data = train_dat, C = 2) large_pred <- predict(finalModel, large_dat[, -ncol(large_dat)]) sqrt(mean((large_dat$y - large_pred) ^ 2, na.rm = TRUE)) -#> [1] 2.712059 +#> [1] 2.695091 ``` ::: diff --git a/learn/work/nested-resampling/index.qmd b/learn/work/nested-resampling/index.qmd index 37e67577..f1bf65b3 100644 --- a/learn/work/nested-resampling/index.qmd +++ b/learn/work/nested-resampling/index.qmd @@ -130,6 +130,7 @@ library(kernlab) # `object` will be an `rsplit` object from our `results` tibble # `cost` is the tuning parameter svm_rmse <- function(object, cost = 1) { + set.seed(1234) y_col <- ncol(object$data) mod <- svm_rbf(mode = "regression", cost = cost) %>% diff --git a/learn/work/nested-resampling/index.rmarkdown b/learn/work/nested-resampling/index.rmarkdown deleted file mode 100644 index 7de0b15f..00000000 --- a/learn/work/nested-resampling/index.rmarkdown +++ /dev/null @@ -1,371 +0,0 @@ ---- -title: "Nested resampling" -categories: - - SVMs -type: learn-subsection -weight: 2 -description: | - Estimate the best hyperparameters for a model using nested resampling. 
-toc: true -toc-depth: 2 -include-after-body: ../../../resources.html ---- - -```{r} -#| label: "setup" -#| include: false -#| message: false -#| warning: false -source(here::here("common.R")) -``` - -```{r} -#| label: "load" -#| include: false -library(tidymodels) -library(scales) -library(mlbench) -library(kernlab) -library(furrr) - -set.seed(1234) - -pkgs <- c("tidymodels", "scales", "mlbench", "kernlab", "furrr") - -theme_set(theme_bw() + theme(legend.position = "top")) -``` - - - -## Introduction - -`r article_req_pkgs(pkgs)` - -In this article, we discuss an alternative method for evaluating and tuning models, called [nested resampling](https://scholar.google.com/scholar?hl=en&as_sdt=0%2C7&q=%22nested+resampling%22+inner+outer&btnG=). While it is more computationally taxing and challenging to implement than other resampling methods, it has the potential to produce better estimates of model performance. - -## Resampling models - -A typical scheme for splitting the data when developing a predictive model is to create an initial split of the data into a training and test set. If resampling is used, it is executed on the training set. A series of binary splits is created. In rsample, we use the term *analysis set* for the data that are used to fit the model and the term *assessment set* for the set used to compute performance: - - - -```{r} -#| label: "resampling-fig" -#| echo: false -#| fig-align: center -#| out-width: "70%" -knitr::include_graphics("img/resampling.svg") -``` - - - -A common method for tuning models is [grid search](/learn/work/tune-svm/) where a candidate set of tuning parameters is created. The full set of models for every combination of the tuning parameter grid and the resamples is fitted. Each time, the assessment data are used to measure performance and the average value is determined for each tuning parameter. - -The potential problem is that once we pick the tuning parameter associated with the best performance, this performance value is usually quoted as the performance of the model. There is serious potential for *optimization bias* since we use the same data to tune the model and to assess performance. This would result in an optimistic estimate of performance. - -Nested resampling uses an additional layer of resampling that separates the tuning activities from the process used to estimate the efficacy of the model. An *outer* resampling scheme is used and, for every split in the outer resample, another full set of resampling splits are created on the original analysis set. For example, if 10-fold cross-validation is used on the outside and 5-fold cross-validation on the inside, a total of 500 models will be fit. The parameter tuning will be conducted 10 times and the best parameters are determined from the average of the 5 assessment sets. This process occurs 10 times. - -Once the tuning results are complete, a model is fit to each of the outer resampling splits using the best parameter associated with that resample. The average of the outer method's assessment sets are a unbiased estimate of the model. - -We will simulate some regression data to illustrate the methods. The mlbench package has a function `mlbench::mlbench.friedman1()` that can simulate a complex regression data structure from the [original MARS publication](https://scholar.google.com/scholar?hl=en&q=%22Multivariate+adaptive+regression+splines%22&btnG=&as_sdt=1%2C7&as_sdtp=). 
A training set size of 100 data points are generated as well as a large set that will be used to characterize how well the resampling procedure performed. - - - -```{r} -#| label: "sim-data" -library(mlbench) -sim_data <- function(n) { - tmp <- mlbench.friedman1(n, sd = 1) - tmp <- cbind(tmp$x, tmp$y) - tmp <- as.data.frame(tmp) - names(tmp)[ncol(tmp)] <- "y" - tmp -} - -set.seed(9815) -train_dat <- sim_data(100) -large_dat <- sim_data(10^5) -``` - - - -## Nested resampling - -To get started, the types of resampling methods need to be specified. This isn't a large data set, so 5 repeats of 10-fold cross validation will be used as the *outer* resampling method for generating the estimate of overall performance. To tune the model, it would be good to have precise estimates for each of the values of the tuning parameter so let's use 25 iterations of the bootstrap. This means that there will eventually be `5 * 10 * 25 = 1250` models that are fit to the data *per tuning parameter*. These models will be discarded once the performance of the model has been quantified. - -To create the tibble with the resampling specifications: - - - -```{r} -#| label: "tibble-gen" -library(tidymodels) -results <- nested_cv(train_dat, - outside = vfold_cv(repeats = 5), - inside = bootstraps(times = 25)) -results -``` - - - -The splitting information for each resample is contained in the `split` objects. Focusing on the second fold of the first repeat: - - - -```{r} -#| label: "split-example" -results$splits[[2]] -``` - - - -`<90/10/100>` indicates the number of observations in the analysis set, assessment set, and the original data. - -Each element of `inner_resamples` has its own tibble with the bootstrapping splits. - - - -```{r} -#| label: "inner-splits" -results$inner_resamples[[5]] -``` - - - -These are self-contained, meaning that the bootstrap sample is aware that it is a sample of a specific 90% of the data: - - - -```{r} -#| label: "inner-boot-split" -results$inner_resamples[[5]]$splits[[1]] -``` - - - -To start, we need to define how the model will be created and measured. Let's use a radial basis support vector machine model via the function `kernlab::ksvm`. This model is generally considered to have *two* tuning parameters: the SVM cost value and the kernel parameter `sigma`. For illustration purposes here, only the cost value will be tuned and the function `kernlab::sigest` will be used to estimate `sigma` during each model fit. This is automatically done by `ksvm`. - -After the model is fit to the analysis set, the root-mean squared error (RMSE) is computed on the assessment set. **One important note:** for this model, it is critical to center and scale the predictors before computing dot products. We don't do this operation here because `mlbench.friedman1` simulates all of the predictors to be standardized uniform random variables. 
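For data sets whose predictors are *not* already on a common scale, one option would be to bundle the centering and scaling with the model via a recipe and a workflow. The sketch below is illustrative only: it is not needed for this simulated data, and the object name `svm_wflow` and the cost value of 2 are placeholders.

```{r}
#| label: "scaling-sketch"
#| eval: false
# Illustrative sketch: bundle normalization with the SVM in a workflow.
# Not needed here because the simulated predictors share a common scale.
svm_wflow <-
  workflow() %>%
  add_recipe(
    recipe(y ~ ., data = train_dat) %>%
      step_normalize(all_numeric_predictors())
  ) %>%
  add_model(
    svm_rbf(mode = "regression", cost = 2) %>%
      set_engine("kernlab")
  )

fit(svm_wflow, data = train_dat)
```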
- -Our function to fit the model and compute the RMSE is: - - - -```{r} -#| label: "rmse-func" -library(kernlab) - -# `object` will be an `rsplit` object from our `results` tibble -# `cost` is the tuning parameter -svm_rmse <- function(object, cost = 1) { - y_col <- ncol(object$data) - mod <- - svm_rbf(mode = "regression", cost = cost) %>% - set_engine("kernlab") %>% - fit(y ~ ., data = analysis(object)) - - holdout_pred <- - predict(mod, assessment(object) %>% dplyr::select(-y)) %>% - bind_cols(assessment(object) %>% dplyr::select(y)) - rmse(holdout_pred, truth = y, estimate = .pred)$.estimate -} - -# In some case, we want to parameterize the function over the tuning parameter: -rmse_wrapper <- function(cost, object) svm_rmse(object, cost) -``` - - - -For the nested resampling, a model needs to be fit for each tuning parameter and each bootstrap split. To do this, create a wrapper: - - - -```{r} -#| label: "inner-tune-func" -# `object` will be an `rsplit` object for the bootstrap samples -tune_over_cost <- function(object) { - tibble(cost = 2 ^ seq(-2, 8, by = 1)) %>% - mutate(RMSE = map_dbl(cost, rmse_wrapper, object = object)) -} -``` - - - -Since this will be called across the set of outer cross-validation splits, another wrapper is required: - - - -```{r} -#| label: "inner-func" -# `object` is an `rsplit` object in `results$inner_resamples` -summarize_tune_results <- function(object) { - # Return row-bound tibble that has the 25 bootstrap results - map_df(object$splits, tune_over_cost) %>% - # For each value of the tuning parameter, compute the - # average RMSE which is the inner bootstrap estimate. - group_by(cost) %>% - summarize(mean_RMSE = mean(RMSE, na.rm = TRUE), - n = length(RMSE), - .groups = "drop") -} -``` - - - -Now that those functions are defined, we can execute all the inner resampling loops: - - - -```{r} -#| label: "inner-runs" -#| eval: false -tuning_results <- map(results$inner_resamples, summarize_tune_results) -``` - - - -Alternatively, since these computations can be run in parallel, we can use the furrr package. Instead of using `map()`, the function `future_map()` parallelizes the iterations using the [future package](https://cran.r-project.org/web/packages/future/vignettes/future-1-overview.html). The `multisession` plan uses the local cores to process the inner resampling loop. The end results are the same as the sequential computations. - - - -```{r} -#| label: "inner-runs-parallel" -#| warning: false -library(furrr) -plan(multisession) - -tuning_results <- future_map(results$inner_resamples, summarize_tune_results) -``` - - - -The object `tuning_results` is a list of data frames for each of the 50 outer resamples. - -Let's make a plot of the averaged results to see what the relationship is between the RMSE and the tuning parameters for each of the inner bootstrapping operations: - - - -```{r} -#| label: "rmse-plot" -#| fig-height: 4 -#| message: false -library(scales) - -pooled_inner <- tuning_results %>% bind_rows - -best_cost <- function(dat) dat[which.min(dat$mean_RMSE),] - -p <- - ggplot(pooled_inner, aes(x = cost, y = mean_RMSE)) + - scale_x_continuous(trans = 'log2') + - xlab("SVM Cost") + ylab("Inner RMSE") - -for (i in 1:length(tuning_results)) - p <- p + - geom_line(data = tuning_results[[i]], alpha = .2) + - geom_point(data = best_cost(tuning_results[[i]]), pch = 16, alpha = 3/4) - -p <- p + geom_smooth(data = pooled_inner, se = FALSE) -p -``` - - - -Each gray line is a separate bootstrap resampling curve created from a different 90% of the data. 
The blue line is a LOESS smooth of all the results pooled together. - -To determine the best parameter estimate for each of the outer resampling iterations: - - - -```{r} -#| label: "choose" -#| fig-height: 4 -cost_vals <- - tuning_results %>% - map_df(best_cost) %>% - select(cost) - -results <- - bind_cols(results, cost_vals) %>% - mutate(cost = factor(cost, levels = paste(2 ^ seq(-2, 8, by = 1)))) - -ggplot(results, aes(x = cost)) + - geom_bar() + - xlab("SVM Cost") + - scale_x_discrete(drop = FALSE) -``` - - - -Most of the resamples produced an optimal cost value of 2.0, but the distribution is right-skewed due to the flat trend in the resampling profile once the cost value becomes 10 or larger. - -Now that we have these estimates, we can compute the outer resampling results for each of the `r nrow(results)` splits using the corresponding tuning parameter value: - - - -```{r} -#| label: "run-out-r" -results <- - results %>% - mutate(RMSE = map2_dbl(splits, cost, svm_rmse)) - -summary(results$RMSE) -``` - - - -The estimated RMSE for the model tuning process is `r round(mean(results$RMSE), 2)`. - -What is the RMSE estimate for the non-nested procedure when only the outer resampling method is used? For each cost value in the tuning grid, `r nrow(results)` SVM models are fit and their RMSE values are averaged. The table of cost values and mean RMSE estimates is used to determine the best cost value. The associated RMSE is the biased estimate. - - - -```{r} -#| label: "not-nested" -#| fig-height: 4 -not_nested <- - map(results$splits, tune_over_cost) %>% - bind_rows - -outer_summary <- not_nested %>% - group_by(cost) %>% - summarize(outer_RMSE = mean(RMSE), n = length(RMSE)) - -outer_summary - -ggplot(outer_summary, aes(x = cost, y = outer_RMSE)) + - geom_point() + - geom_line() + - scale_x_continuous(trans = 'log2') + - xlab("SVM Cost") + ylab("RMSE") -``` - - - -The non-nested procedure estimates the RMSE to be `r round(min(outer_summary$outer_RMSE), 2)`. Both estimates are fairly close. - -The approximately true RMSE for an SVM model with a cost value of 2.0 can be approximated with the large sample that was simulated at the beginning. - - - -```{r} -#| label: "large-sample-estimate" -finalModel <- ksvm(y ~ ., data = train_dat, C = 2) -large_pred <- predict(finalModel, large_dat[, -ncol(large_dat)]) -sqrt(mean((large_dat$y - large_pred) ^ 2, na.rm = TRUE)) -``` - - - -The nested procedure produces a closer estimate to the approximate truth but the non-nested estimate is very similar. 
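If the goal were to go on and use a tidymodels model object rather than calling `ksvm()` directly, the final fit could also be expressed with the same parsnip specification used earlier. A minimal sketch, assuming we settle on the most frequently selected cost value of 2; the object name `final_svm` is just for illustration.

```{r}
#| label: "final-fit-sketch"
#| eval: false
# Sketch: refit on the full training set with the selected cost using parsnip.
final_svm <-
  svm_rbf(mode = "regression", cost = 2) %>%
  set_engine("kernlab") %>%
  fit(y ~ ., data = train_dat)

# Predictions on new data would then use the standard predict() method.
predict(final_svm, new_data = large_dat %>% dplyr::select(-y) %>% head())
```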
- -## Session information {#session-info} - - - -```{r} -#| label: "si" -#| echo: false -small_session(pkgs) -``` - From f002cb7c2329f98dd12f46dc5ed6a882973e4dbc Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Mon, 24 Mar 2025 13:00:13 -0700 Subject: [PATCH 7/7] rerender --- .../learn/models/sub-sampling/index/execute-results/html.json | 2 +- _freeze/learn/statistics/infer/index/execute-results/html.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/_freeze/learn/models/sub-sampling/index/execute-results/html.json b/_freeze/learn/models/sub-sampling/index/execute-results/html.json index cbfa10e6..e1e41e78 100644 --- a/_freeze/learn/models/sub-sampling/index/execute-results/html.json +++ b/_freeze/learn/models/sub-sampling/index/execute-results/html.json @@ -2,7 +2,7 @@ "hash": "47f50af03cb12d3ec299b80c5d93b121", "result": { "engine": "knitr", - "markdown": "---\ntitle: \"Subsampling for class imbalances\"\ncategories:\n - model fitting\n - pre-processing\ntype: learn-subsection\nweight: 3\ndescription: | \n Improve model performance in imbalanced data sets through undersampling or oversampling.\ntoc: true\ntoc-depth: 2\ninclude-after-body: ../../../resources.html\n---\n\n\n\n\n\n\n\n\n\n\n## Introduction\n\nTo use code in this article, you will need to install the following packages: discrim, klaR, readr, ROSE, themis, and tidymodels.\n\nSubsampling a training set, either undersampling or oversampling the appropriate class or classes, can be a helpful approach to dealing with classification data where one or more classes occur very infrequently. In such a situation (without compensating for it), most models will overfit to the majority class and produce very good statistics for the class containing the frequently occurring classes while the minority classes have poor performance. \n\nThis article describes subsampling for dealing with class imbalances. For better understanding, some knowledge of classification metrics like sensitivity, specificity, and receiver operating characteristic curves is required. See Section 3.2.2 in [Kuhn and Johnson (2019)](https://bookdown.org/max/FES/measuring-performance.html) for more information on these metrics. \n\n## Simulated data\n\nConsider a two-class problem where the first class has a very low rate of occurrence. The data were simulated and can be imported into R using the code below:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nimbal_data <- \n readr::read_csv(\"https://tidymodels.org/learn/models/sub-sampling/imbal_data.csv\") %>% \n mutate(Class = factor(Class))\ndim(imbal_data)\n#> [1] 1200 16\ntable(imbal_data$Class)\n#> \n#> Class1 Class2 \n#> 60 1140\n```\n:::\n\n\n\n\n\nIf \"Class1\" is the event of interest, it is very likely that a classification model would be able to achieve very good _specificity_ since almost all of the data are of the second class. _Sensitivity_, however, would likely be poor since the models will optimize accuracy (or other loss functions) by predicting everything to be the majority class. \n\nOne result of class imbalance when there are two classes is that the default probability cutoff of 50% is inappropriate; a different cutoff that is more extreme might be able to achieve good performance. \n\n## Subsampling the data\n\nOne way to alleviate this issue is to _subsample_ the data. There are a number of ways to do this but the most simple one is to _sample down_ (undersample) the majority class data until it occurs with the same frequency as the minority class. 
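As a concrete point of reference, a sketch of what that undersampling could look like with `themis::step_downsample()` is shown below; the analysis in this article goes on to use an oversampling method (`step_rose()`) instead, so this recipe and the name `downsample_rec` are illustrative only.

```r
# Illustrative sketch: undersample the majority class so that both classes
# occur with equal frequency. The analysis below uses step_rose() instead.
library(tidymodels)
library(themis)

downsample_rec <-
  recipe(Class ~ ., data = imbal_data) %>%
  step_downsample(Class)
```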
While it may seem counterintuitive, throwing out a large percentage of your data can be effective at producing a useful model that can recognize both the majority and minority classes. In some cases, this even means that the overall performance of the model is better (e.g. improved area under the ROC curve). However, subsampling almost always produces models that are _better calibrated_, meaning that the distributions of the class probabilities are more well behaved. As a result, the default 50% cutoff is much more likely to produce better sensitivity and specificity values than they would otherwise. \n\nLet's explore subsampling using `themis::step_rose()` in a recipe for the simulated data. It uses the ROSE (random over sampling examples) method from [Menardi, G. and Torelli, N. (2014)](https://scholar.google.com/scholar?hl=en&q=%22training+and+assessing+classification+rules+with+imbalanced+data%22). This is an example of an oversampling strategy, rather than undersampling.\n\nIn terms of workflow:\n\n * It is extremely important that subsampling occurs _inside of resampling_. Otherwise, the resampling process can produce [poor estimates of model performance](https://topepo.github.io/caret/subsampling-for-class-imbalances.html#resampling). \n * The subsampling process should only be applied to the analysis set. The assessment set should reflect the event rates seen \"in the wild\" and, for this reason, the `skip` argument to `step_downsample()` and other subsampling recipes steps has a default of `TRUE`. \n\nHere is a simple recipe implementing oversampling: \n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(tidymodels)\nlibrary(themis)\nset.seed(1234)\n\nimbal_rec <- \n recipe(Class ~ ., data = imbal_data) %>%\n step_rose(Class)\n```\n:::\n\n\n\n\n\nFor a model, let's use a [quadratic discriminant analysis](https://en.wikipedia.org/wiki/Quadratic_classifier#Quadratic_discriminant_analysis) (QDA) model. From the discrim package, this model can be specified using:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(discrim)\nqda_mod <- \n discrim_regularized(frac_common_cov = 0, frac_identity = 0) %>% \n set_engine(\"klaR\")\n```\n:::\n\n\n\n\n\nTo keep these objects bound together, they can be combined in a [workflow](https://workflows.tidymodels.org/):\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nqda_rose_wflw <- \n workflow() %>% \n add_model(qda_mod) %>% \n add_recipe(imbal_rec)\nqda_rose_wflw\n#> ══ Workflow ══════════════════════════════════════════════════════════\n#> Preprocessor: Recipe\n#> Model: discrim_regularized()\n#> \n#> ── Preprocessor ──────────────────────────────────────────────────────\n#> 1 Recipe Step\n#> \n#> • step_rose()\n#> \n#> ── Model ─────────────────────────────────────────────────────────────\n#> Regularized Discriminant Model Specification (classification)\n#> \n#> Main Arguments:\n#> frac_common_cov = 0\n#> frac_identity = 0\n#> \n#> Computational engine: klaR\n```\n:::\n\n\n\n\n\n## Model performance\n\nStratified, repeated 10-fold cross-validation is used to resample the model:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nset.seed(5732)\ncv_folds <- vfold_cv(imbal_data, strata = \"Class\", repeats = 5)\n```\n:::\n\n\n\n\n\nTo measure model performance, let's use two metrics:\n\n * The area under the [ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) is an overall assessment of performance across _all_ cutoffs. 
Values near one indicate very good results while values near 0.5 would imply that the model is very poor. \n * The _J_ index (a.k.a. [Youden's _J_](https://en.wikipedia.org/wiki/Youden%27s_J_statistic) statistic) is `sensitivity + specificity - 1`. Values near one are once again best. \n\nIf a model is poorly calibrated, the ROC curve value might not show diminished performance. However, the _J_ index would be lower for models with pathological distributions for the class probabilities. The yardstick package will be used to compute these metrics. \n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ncls_metrics <- metric_set(roc_auc, j_index)\n```\n:::\n\n\n\n\n\nNow, we train the models and generate the results using `tune::fit_resamples()`:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nset.seed(2180)\nqda_rose_res <- fit_resamples(\n qda_rose_wflw, \n resamples = cv_folds, \n metrics = cls_metrics\n)\n\ncollect_metrics(qda_rose_res)\n#> # A tibble: 2 × 6\n#> .metric .estimator mean n std_err .config \n#> \n#> 1 j_index binary 0.777 50 0.0199 Preprocessor1_Model1\n#> 2 roc_auc binary 0.949 50 0.00508 Preprocessor1_Model1\n```\n:::\n\n\n\n\n\nWhat do the results look like without using ROSE? We can create another workflow and fit the QDA model along the same resamples:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nqda_wflw <- \n workflow() %>% \n add_model(qda_mod) %>% \n add_formula(Class ~ .)\n\nset.seed(2180)\nqda_only_res <- fit_resamples(qda_wflw, resamples = cv_folds, metrics = cls_metrics)\ncollect_metrics(qda_only_res)\n#> # A tibble: 2 × 6\n#> .metric .estimator mean n std_err .config \n#> \n#> 1 j_index binary 0.250 50 0.0288 Preprocessor1_Model1\n#> 2 roc_auc binary 0.953 50 0.00479 Preprocessor1_Model1\n```\n:::\n\n\n\n\n\nIt looks like ROSE helped a lot, especially with the J-index. Class imbalance sampling methods tend to greatly improve metrics based on the hard class predictions (i.e., the categorical predictions) because the default cutoff tends to be a better balance of sensitivity and specificity. \n\nLet's plot the metrics for each resample to see how the individual results changed. \n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nno_sampling <- \n qda_only_res %>% \n collect_metrics(summarize = FALSE) %>% \n dplyr::select(-.estimator) %>% \n mutate(sampling = \"no_sampling\")\n\nwith_sampling <- \n qda_rose_res %>% \n collect_metrics(summarize = FALSE) %>% \n dplyr::select(-.estimator) %>% \n mutate(sampling = \"rose\")\n\nbind_rows(no_sampling, with_sampling) %>% \n mutate(label = paste(id2, id)) %>% \n ggplot(aes(x = sampling, y = .estimate, group = label)) + \n geom_line(alpha = .4) + \n facet_wrap(~ .metric, scales = \"free_y\")\n```\n\n::: {.cell-output-display}\n![](figs/merge-metrics-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\n\nThis visually demonstrates that the subsampling mostly affects metrics that use the hard class predictions. 
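To see that effect expressed directly as sensitivity and specificity, one option is to extend the metric set and rerun the resampling. This is a sketch only; it reuses the objects defined above along with yardstick's `sens()` and `spec()` metrics, and the names `cls_metrics_ext` and `qda_rose_sens_spec` are illustrative.

```r
# Sketch: add sensitivity and specificity to the metric set used above.
cls_metrics_ext <- metric_set(roc_auc, j_index, sens, spec)

set.seed(2180)
qda_rose_sens_spec <- fit_resamples(
  qda_rose_wflw,
  resamples = cv_folds,
  metrics = cls_metrics_ext
)

collect_metrics(qda_rose_sens_spec)
```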
\n\n## Session information {#session-info}\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```\n#> ─ Session info ─────────────────────────────────────────────────────\n#> version R version 4.4.2 (2024-10-31)\n#> language (EN)\n#> date 2025-03-24\n#> pandoc 3.6.1\n#> quarto 1.6.42\n#> \n#> ─ Packages ─────────────────────────────────────────────────────────\n#> package version date (UTC) source\n#> broom 1.0.7 2024-09-26 CRAN (R 4.4.1)\n#> dials 1.4.0 2025-02-13 CRAN (R 4.4.2)\n#> discrim 1.0.1 2023-03-08 CRAN (R 4.4.0)\n#> dplyr 1.1.4 2023-11-17 CRAN (R 4.4.0)\n#> ggplot2 3.5.1 2024-04-23 CRAN (R 4.4.0)\n#> infer 1.0.7 2024-03-25 CRAN (R 4.4.0)\n#> klaR 1.7-3 2023-12-13 CRAN (R 4.4.0)\n#> parsnip 1.3.1 2025-03-12 CRAN (R 4.4.1)\n#> purrr 1.0.4 2025-02-05 CRAN (R 4.4.1)\n#> readr 2.1.5 2024-01-10 CRAN (R 4.4.0)\n#> recipes 1.2.0 2025-03-17 CRAN (R 4.4.1)\n#> rlang 1.1.5 2025-01-17 CRAN (R 4.4.2)\n#> ROSE 0.0-4 2021-06-14 CRAN (R 4.4.0)\n#> rsample 1.2.1 2024-03-25 CRAN (R 4.4.0)\n#> themis 1.0.3 2025-01-23 CRAN (R 4.4.1)\n#> tibble 3.2.1 2023-03-20 CRAN (R 4.4.0)\n#> tidymodels 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> tune 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> workflows 1.2.0 2025-02-19 CRAN (R 4.4.1)\n#> yardstick 1.3.2 2025-01-22 CRAN (R 4.4.1)\n#> \n#> ────────────────────────────────────────────────────────────────────\n```\n:::\n", + "markdown": "---\ntitle: \"Subsampling for class imbalances\"\ncategories:\n - model fitting\n - pre-processing\ntype: learn-subsection\nweight: 3\ndescription: | \n Improve model performance in imbalanced data sets through undersampling or oversampling.\ntoc: true\ntoc-depth: 2\ninclude-after-body: ../../../resources.html\n---\n\n\n\n\n\n\n\n\n\n## Introduction\n\nTo use code in this article, you will need to install the following packages: discrim, klaR, readr, ROSE, themis, and tidymodels.\n\nSubsampling a training set, either undersampling or oversampling the appropriate class or classes, can be a helpful approach to dealing with classification data where one or more classes occur very infrequently. In such a situation (without compensating for it), most models will overfit to the majority class and produce very good statistics for the class containing the frequently occurring classes while the minority classes have poor performance. \n\nThis article describes subsampling for dealing with class imbalances. For better understanding, some knowledge of classification metrics like sensitivity, specificity, and receiver operating characteristic curves is required. See Section 3.2.2 in [Kuhn and Johnson (2019)](https://bookdown.org/max/FES/measuring-performance.html) for more information on these metrics. \n\n## Simulated data\n\nConsider a two-class problem where the first class has a very low rate of occurrence. The data were simulated and can be imported into R using the code below:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nimbal_data <- \n readr::read_csv(\"https://tidymodels.org/learn/models/sub-sampling/imbal_data.csv\") %>% \n mutate(Class = factor(Class))\ndim(imbal_data)\n#> [1] 1200 16\ntable(imbal_data$Class)\n#> \n#> Class1 Class2 \n#> 60 1140\n```\n:::\n\n\n\n\nIf \"Class1\" is the event of interest, it is very likely that a classification model would be able to achieve very good _specificity_ since almost all of the data are of the second class. _Sensitivity_, however, would likely be poor since the models will optimize accuracy (or other loss functions) by predicting everything to be the majority class. 
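A quick sanity check makes the point concrete: given the class counts above, a model that predicts "Class2" for every sample is already 95% accurate without learning anything.

```r
# Accuracy of always predicting the majority class: 1140 / 1200
mean(imbal_data$Class == "Class2")
#> [1] 0.95
```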
\n\nOne result of class imbalance when there are two classes is that the default probability cutoff of 50% is inappropriate; a different cutoff that is more extreme might be able to achieve good performance. \n\n## Subsampling the data\n\nOne way to alleviate this issue is to _subsample_ the data. There are a number of ways to do this but the most simple one is to _sample down_ (undersample) the majority class data until it occurs with the same frequency as the minority class. While it may seem counterintuitive, throwing out a large percentage of your data can be effective at producing a useful model that can recognize both the majority and minority classes. In some cases, this even means that the overall performance of the model is better (e.g. improved area under the ROC curve). However, subsampling almost always produces models that are _better calibrated_, meaning that the distributions of the class probabilities are more well behaved. As a result, the default 50% cutoff is much more likely to produce better sensitivity and specificity values than they would otherwise. \n\nLet's explore subsampling using `themis::step_rose()` in a recipe for the simulated data. It uses the ROSE (random over sampling examples) method from [Menardi, G. and Torelli, N. (2014)](https://scholar.google.com/scholar?hl=en&q=%22training+and+assessing+classification+rules+with+imbalanced+data%22). This is an example of an oversampling strategy, rather than undersampling.\n\nIn terms of workflow:\n\n * It is extremely important that subsampling occurs _inside of resampling_. Otherwise, the resampling process can produce [poor estimates of model performance](https://topepo.github.io/caret/subsampling-for-class-imbalances.html#resampling). \n * The subsampling process should only be applied to the analysis set. The assessment set should reflect the event rates seen \"in the wild\" and, for this reason, the `skip` argument to `step_downsample()` and other subsampling recipes steps has a default of `TRUE`. \n\nHere is a simple recipe implementing oversampling: \n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(tidymodels)\nlibrary(themis)\nset.seed(1234)\n\nimbal_rec <- \n recipe(Class ~ ., data = imbal_data) %>%\n step_rose(Class)\n```\n:::\n\n\n\n\nFor a model, let's use a [quadratic discriminant analysis](https://en.wikipedia.org/wiki/Quadratic_classifier#Quadratic_discriminant_analysis) (QDA) model. 
From the discrim package, this model can be specified using:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(discrim)\nqda_mod <- \n discrim_regularized(frac_common_cov = 0, frac_identity = 0) %>% \n set_engine(\"klaR\")\n```\n:::\n\n\n\n\nTo keep these objects bound together, they can be combined in a [workflow](https://workflows.tidymodels.org/):\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nqda_rose_wflw <- \n workflow() %>% \n add_model(qda_mod) %>% \n add_recipe(imbal_rec)\nqda_rose_wflw\n#> ══ Workflow ══════════════════════════════════════════════════════════\n#> Preprocessor: Recipe\n#> Model: discrim_regularized()\n#> \n#> ── Preprocessor ──────────────────────────────────────────────────────\n#> 1 Recipe Step\n#> \n#> • step_rose()\n#> \n#> ── Model ─────────────────────────────────────────────────────────────\n#> Regularized Discriminant Model Specification (classification)\n#> \n#> Main Arguments:\n#> frac_common_cov = 0\n#> frac_identity = 0\n#> \n#> Computational engine: klaR\n```\n:::\n\n\n\n\n## Model performance\n\nStratified, repeated 10-fold cross-validation is used to resample the model:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nset.seed(5732)\ncv_folds <- vfold_cv(imbal_data, strata = \"Class\", repeats = 5)\n```\n:::\n\n\n\n\nTo measure model performance, let's use two metrics:\n\n * The area under the [ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) is an overall assessment of performance across _all_ cutoffs. Values near one indicate very good results while values near 0.5 would imply that the model is very poor. \n * The _J_ index (a.k.a. [Youden's _J_](https://en.wikipedia.org/wiki/Youden%27s_J_statistic) statistic) is `sensitivity + specificity - 1`. Values near one are once again best. \n\nIf a model is poorly calibrated, the ROC curve value might not show diminished performance. However, the _J_ index would be lower for models with pathological distributions for the class probabilities. The yardstick package will be used to compute these metrics. \n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ncls_metrics <- metric_set(roc_auc, j_index)\n```\n:::\n\n\n\n\nNow, we train the models and generate the results using `tune::fit_resamples()`:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nset.seed(2180)\nqda_rose_res <- fit_resamples(\n qda_rose_wflw, \n resamples = cv_folds, \n metrics = cls_metrics\n)\n\ncollect_metrics(qda_rose_res)\n#> # A tibble: 2 × 6\n#> .metric .estimator mean n std_err .config \n#> \n#> 1 j_index binary 0.777 50 0.0199 Preprocessor1_Model1\n#> 2 roc_auc binary 0.949 50 0.00508 Preprocessor1_Model1\n```\n:::\n\n\n\n\nWhat do the results look like without using ROSE? We can create another workflow and fit the QDA model along the same resamples:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nqda_wflw <- \n workflow() %>% \n add_model(qda_mod) %>% \n add_formula(Class ~ .)\n\nset.seed(2180)\nqda_only_res <- fit_resamples(qda_wflw, resamples = cv_folds, metrics = cls_metrics)\ncollect_metrics(qda_only_res)\n#> # A tibble: 2 × 6\n#> .metric .estimator mean n std_err .config \n#> \n#> 1 j_index binary 0.250 50 0.0288 Preprocessor1_Model1\n#> 2 roc_auc binary 0.953 50 0.00479 Preprocessor1_Model1\n```\n:::\n\n\n\n\nIt looks like ROSE helped a lot, especially with the J-index. 
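Placing the two sets of resampled estimates side by side makes the comparison easier to scan; a quick sketch using the objects created above:

```r
# Sketch: line up the resampled metric estimates with and without ROSE.
bind_rows(
  collect_metrics(qda_rose_res) %>% mutate(sampling = "rose"),
  collect_metrics(qda_only_res) %>% mutate(sampling = "no_sampling")
) %>%
  dplyr::select(sampling, .metric, mean, std_err)
```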
Class imbalance sampling methods tend to greatly improve metrics based on the hard class predictions (i.e., the categorical predictions) because the default cutoff tends to be a better balance of sensitivity and specificity. \n\nLet's plot the metrics for each resample to see how the individual results changed. \n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nno_sampling <- \n qda_only_res %>% \n collect_metrics(summarize = FALSE) %>% \n dplyr::select(-.estimator) %>% \n mutate(sampling = \"no_sampling\")\n\nwith_sampling <- \n qda_rose_res %>% \n collect_metrics(summarize = FALSE) %>% \n dplyr::select(-.estimator) %>% \n mutate(sampling = \"rose\")\n\nbind_rows(no_sampling, with_sampling) %>% \n mutate(label = paste(id2, id)) %>% \n ggplot(aes(x = sampling, y = .estimate, group = label)) + \n geom_line(alpha = .4) + \n facet_wrap(~ .metric, scales = \"free_y\")\n```\n\n::: {.cell-output-display}\n![](figs/merge-metrics-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nThis visually demonstrates that the subsampling mostly affects metrics that use the hard class predictions. \n\n## Session information {#session-info}\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```\n#> ─ Session info ─────────────────────────────────────────────────────\n#> version R version 4.4.2 (2024-10-31)\n#> language (EN)\n#> date 2025-03-24\n#> pandoc 3.6.1\n#> quarto 1.6.42\n#> \n#> ─ Packages ─────────────────────────────────────────────────────────\n#> package version date (UTC) source\n#> broom 1.0.7 2024-09-26 CRAN (R 4.4.1)\n#> dials 1.4.0 2025-02-13 CRAN (R 4.4.2)\n#> discrim 1.0.1 2023-03-08 CRAN (R 4.4.0)\n#> dplyr 1.1.4 2023-11-17 CRAN (R 4.4.0)\n#> ggplot2 3.5.1 2024-04-23 CRAN (R 4.4.0)\n#> infer 1.0.7 2024-03-25 CRAN (R 4.4.0)\n#> klaR 1.7-3 2023-12-13 CRAN (R 4.4.0)\n#> parsnip 1.3.1 2025-03-12 CRAN (R 4.4.1)\n#> purrr 1.0.4 2025-02-05 CRAN (R 4.4.1)\n#> readr 2.1.5 2024-01-10 CRAN (R 4.4.0)\n#> recipes 1.2.0 2025-03-17 CRAN (R 4.4.1)\n#> rlang 1.1.5 2025-01-17 CRAN (R 4.4.2)\n#> ROSE 0.0-4 2021-06-14 CRAN (R 4.4.0)\n#> rsample 1.2.1 2024-03-25 CRAN (R 4.4.0)\n#> themis 1.0.3 2025-01-23 CRAN (R 4.4.1)\n#> tibble 3.2.1 2023-03-20 CRAN (R 4.4.0)\n#> tidymodels 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> tune 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> workflows 1.2.0 2025-02-19 CRAN (R 4.4.1)\n#> yardstick 1.3.2 2025-01-22 CRAN (R 4.4.1)\n#> \n#> ────────────────────────────────────────────────────────────────────\n```\n:::\n", "supporting": [], "filters": [ "rmarkdown/pagebreak.lua" diff --git a/_freeze/learn/statistics/infer/index/execute-results/html.json b/_freeze/learn/statistics/infer/index/execute-results/html.json index 035adcec..915cc0ab 100644 --- a/_freeze/learn/statistics/infer/index/execute-results/html.json +++ b/_freeze/learn/statistics/infer/index/execute-results/html.json @@ -2,7 +2,7 @@ "hash": "4158a1c8a702b6966548fdb8f055388e", "result": { "engine": "knitr", - "markdown": "---\ntitle: \"Hypothesis testing using resampling and tidy data\"\ncategories:\n - statistical analysis\n - hypothesis testing\n - bootstrapping\ntype: learn-subsection\nweight: 4\ndescription: | \n Perform common hypothesis tests for statistical inference using flexible functions.\ntoc: true\ntoc-depth: 2\ninclude-after-body: ../../../resources.html\n---\n\n\n\n\n\n\n\n\n\n## Introduction\n\nThis article only requires the tidymodels package. 
\n\nThe tidymodels package [infer](https://infer.tidymodels.org/) implements an expressive grammar to perform statistical inference that coheres with the `tidyverse` design framework. Rather than providing methods for specific statistical tests, this package consolidates the principles that are shared among common hypothesis tests into a set of 4 main verbs (functions), supplemented with many utilities to visualize and extract information from their outputs.\n\nRegardless of which hypothesis test we're using, we're still asking the same kind of question: \n\n>Is the effect or difference in our observed data real, or due to chance? \n\nTo answer this question, we start by assuming that the observed data came from some world where \"nothing is going on\" (i.e. the observed effect was simply due to random chance), and call this assumption our **null hypothesis**. (In reality, we might not believe in the null hypothesis at all; the null hypothesis is in opposition to the **alternate hypothesis**, which supposes that the effect present in the observed data is actually due to the fact that \"something is going on.\") We then calculate a **test statistic** from our data that describes the observed effect. We can use this test statistic to calculate a **p-value**, giving the probability that our observed data could come about if the null hypothesis was true. If this probability is below some pre-defined **significance level** $\\alpha$, then we can reject our null hypothesis.\n\nIf you are new to hypothesis testing, take a look at \n\n* [Section 9.2 of _Statistical Inference via Data Science_](https://moderndive.com/9-hypothesis-testing.html#understanding-ht)\n* The American Statistical Association's recent [statement on p-values](https://doi.org/10.1080/00031305.2016.1154108) \n\nThe workflow of this package is designed around these ideas. Starting from some data set,\n\n+ `specify()` allows you to specify the variable, or relationship between variables, that you're interested in,\n+ `hypothesize()` allows you to declare the null hypothesis,\n+ `generate()` allows you to generate data reflecting the null hypothesis, and\n+ `calculate()` allows you to calculate a distribution of statistics from the generated data to form the null distribution.\n\nThroughout this vignette, we make use of `gss`, a data set available in infer containing a sample of 500 observations of 11 variables from the *General Social Survey*. 
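For a compact preview of how these four verbs chain together (using the `gss` data loaded just below; each step is unpacked in the sections that follow), a pipeline for one of the hypothesis tests worked through later looks like this:

```r
# Preview: the full infer pipeline for a point null about a mean.
# Each verb is explained in detail in the sections below.
gss %>%
  specify(response = hours) %>%
  hypothesize(null = "point", mu = 40) %>%
  generate(reps = 1000, type = "bootstrap") %>%
  calculate(stat = "mean")
```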
\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(tidymodels) # Includes the infer package\n\n# Set seed\nset.seed(1234)\n\n# load in the data set\ndata(gss)\n\n# take a look at its structure\ndplyr::glimpse(gss)\n#> Rows: 500\n#> Columns: 11\n#> $ year 2014, 1994, 1998, 1996, 1994, 1996, 1990, 2016, 2000, 1998, 20…\n#> $ age 36, 34, 24, 42, 31, 32, 48, 36, 30, 33, 21, 30, 38, 49, 25, 56…\n#> $ sex male, female, male, male, male, female, female, female, female…\n#> $ college degree, no degree, degree, no degree, degree, no degree, no de…\n#> $ partyid ind, rep, ind, ind, rep, rep, dem, ind, rep, dem, dem, ind, de…\n#> $ hompop 3, 4, 1, 4, 2, 4, 2, 1, 5, 2, 4, 3, 4, 4, 2, 2, 3, 2, 1, 2, 5,…\n#> $ hours 50, 31, 40, 40, 40, 53, 32, 20, 40, 40, 23, 52, 38, 72, 48, 40…\n#> $ income $25000 or more, $20000 - 24999, $25000 or more, $25000 or more…\n#> $ class middle class, working class, working class, working class, mid…\n#> $ finrela below average, below average, below average, above average, ab…\n#> $ weight 0.8960034, 1.0825000, 0.5501000, 1.0864000, 1.0825000, 1.08640…\n```\n:::\n\n\n\n\n\nEach row is an individual survey response, containing some basic demographic information on the respondent as well as some additional variables. See `?gss` for more information on the variables included and their source. Note that this data (and our examples on it) are for demonstration purposes only, and will not necessarily provide accurate estimates unless weighted properly. For these examples, let's suppose that this data set is a representative sample of a population we want to learn about: American adults.\n\n## Specify variables\n\nThe `specify()` function can be used to specify which of the variables in the data set you're interested in. If you're only interested in, say, the `age` of the respondents, you might write:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(response = age)\n#> Response: age (numeric)\n#> # A tibble: 500 × 1\n#> age\n#> \n#> 1 36\n#> 2 34\n#> 3 24\n#> 4 42\n#> 5 31\n#> 6 32\n#> 7 48\n#> 8 36\n#> 9 30\n#> 10 33\n#> # ℹ 490 more rows\n```\n:::\n\n\n\n\n\nOn the front end, the output of `specify()` just looks like it selects off the columns in the dataframe that you've specified. 
What do we see if we check the class of this object, though?\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(response = age) %>%\n class()\n#> [1] \"infer\" \"tbl_df\" \"tbl\" \"data.frame\"\n```\n:::\n\n\n\n\n\nWe can see that the infer class has been appended on top of the dataframe classes; this new class stores some extra metadata.\n\nIf you're interested in two variables (`age` and `partyid`, for example) you can `specify()` their relationship in one of two (equivalent) ways:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# as a formula\ngss %>%\n specify(age ~ partyid)\n#> Response: age (numeric)\n#> Explanatory: partyid (factor)\n#> # A tibble: 500 × 2\n#> age partyid\n#> \n#> 1 36 ind \n#> 2 34 rep \n#> 3 24 ind \n#> 4 42 ind \n#> 5 31 rep \n#> 6 32 rep \n#> 7 48 dem \n#> 8 36 ind \n#> 9 30 rep \n#> 10 33 dem \n#> # ℹ 490 more rows\n\n# with the named arguments\ngss %>%\n specify(response = age, explanatory = partyid)\n#> Response: age (numeric)\n#> Explanatory: partyid (factor)\n#> # A tibble: 500 × 2\n#> age partyid\n#> \n#> 1 36 ind \n#> 2 34 rep \n#> 3 24 ind \n#> 4 42 ind \n#> 5 31 rep \n#> 6 32 rep \n#> 7 48 dem \n#> 8 36 ind \n#> 9 30 rep \n#> 10 33 dem \n#> # ℹ 490 more rows\n```\n:::\n\n\n\n\n\nIf you're doing inference on one proportion or a difference in proportions, you will need to use the `success` argument to specify which level of your `response` variable is a success. For instance, if you're interested in the proportion of the population with a college degree, you might use the following code:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# specifying for inference on proportions\ngss %>%\n specify(response = college, success = \"degree\")\n#> Response: college (factor)\n#> # A tibble: 500 × 1\n#> college \n#> \n#> 1 degree \n#> 2 no degree\n#> 3 degree \n#> 4 no degree\n#> 5 degree \n#> 6 no degree\n#> 7 no degree\n#> 8 degree \n#> 9 degree \n#> 10 no degree\n#> # ℹ 490 more rows\n```\n:::\n\n\n\n\n\n## Declare the hypothesis\n\nThe next step in the infer pipeline is often to declare a null hypothesis using `hypothesize()`. The first step is to supply one of \"independence\" or \"point\" to the `null` argument. If your null hypothesis assumes independence between two variables, then this is all you need to supply to `hypothesize()`:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(college ~ partyid, success = \"degree\") %>%\n hypothesize(null = \"independence\")\n#> Response: college (factor)\n#> Explanatory: partyid (factor)\n#> Null Hypothesis: independence\n#> # A tibble: 500 × 2\n#> college partyid\n#> \n#> 1 degree ind \n#> 2 no degree rep \n#> 3 degree ind \n#> 4 no degree ind \n#> 5 degree rep \n#> 6 no degree rep \n#> 7 no degree dem \n#> 8 degree ind \n#> 9 degree rep \n#> 10 no degree dem \n#> # ℹ 490 more rows\n```\n:::\n\n\n\n\n\nIf you're doing inference on a point estimate, you will also need to provide one of `p` (the true proportion of successes, between 0 and 1), `mu` (the true mean), `med` (the true median), or `sigma` (the true standard deviation). 
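To take the proportion case first, a point null stating that half of the population has a college degree could be declared as below; this is a sketch building on the `college`/`success = "degree"` specification from above, and the value 0.5 is purely illustrative.

```r
# Sketch: a point null hypothesis about a proportion.
gss %>%
  specify(response = college, success = "degree") %>%
  hypothesize(null = "point", p = 0.5)
```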
For instance, if the null hypothesis is that the mean number of hours worked per week in our population is 40, we would write:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(response = hours) %>%\n hypothesize(null = \"point\", mu = 40)\n#> Response: hours (numeric)\n#> Null Hypothesis: point\n#> # A tibble: 500 × 1\n#> hours\n#> \n#> 1 50\n#> 2 31\n#> 3 40\n#> 4 40\n#> 5 40\n#> 6 53\n#> 7 32\n#> 8 20\n#> 9 40\n#> 10 40\n#> # ℹ 490 more rows\n```\n:::\n\n\n\n\n\nAgain, from the front-end, the dataframe outputted from `hypothesize()` looks almost exactly the same as it did when it came out of `specify()`, but infer now \"knows\" your null hypothesis.\n\n## Generate the distribution\n\nOnce we've asserted our null hypothesis using `hypothesize()`, we can construct a null distribution based on this hypothesis. We can do this using one of several methods, supplied in the `type` argument:\n\n* `bootstrap`: A bootstrap sample will be drawn for each replicate, where a sample of size equal to the input sample size is drawn (with replacement) from the input sample data. \n* `permute`: For each replicate, each input value will be randomly reassigned (without replacement) to a new output value in the sample. \n* `simulate`: A value will be sampled from a theoretical distribution with parameters specified in `hypothesize()` for each replicate. (This option is currently only applicable for testing point estimates.) \n\nContinuing on with our example above, about the average number of hours worked a week, we might write:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(response = hours) %>%\n hypothesize(null = \"point\", mu = 40) %>%\n generate(reps = 5000, type = \"bootstrap\")\n#> Response: hours (numeric)\n#> Null Hypothesis: point\n#> # A tibble: 2,500,000 × 2\n#> # Groups: replicate [5,000]\n#> replicate hours\n#> \n#> 1 1 58.6\n#> 2 1 35.6\n#> 3 1 28.6\n#> 4 1 38.6\n#> 5 1 28.6\n#> 6 1 38.6\n#> 7 1 38.6\n#> 8 1 57.6\n#> 9 1 58.6\n#> 10 1 38.6\n#> # ℹ 2,499,990 more rows\n```\n:::\n\n\n\n\n\nIn the above example, we take 5000 bootstrap samples to form our null distribution.\n\nTo generate a null distribution for the independence of two variables, we could also randomly reshuffle the pairings of explanatory and response variables to break any existing association. For instance, to generate 5000 replicates that can be used to create a null distribution under the assumption that political party affiliation is not affected by age:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(partyid ~ age) %>%\n hypothesize(null = \"independence\") %>%\n generate(reps = 5000, type = \"permute\")\n#> Response: partyid (factor)\n#> Explanatory: age (numeric)\n#> Null Hypothesis: independence\n#> # A tibble: 2,500,000 × 3\n#> # Groups: replicate [5,000]\n#> partyid age replicate\n#> \n#> 1 ind 36 1\n#> 2 ind 34 1\n#> 3 ind 24 1\n#> 4 rep 42 1\n#> 5 dem 31 1\n#> 6 dem 32 1\n#> 7 dem 48 1\n#> 8 rep 36 1\n#> 9 ind 30 1\n#> 10 dem 33 1\n#> # ℹ 2,499,990 more rows\n```\n:::\n\n\n\n\n\n## Calculate statistics\n\nDepending on whether you're carrying out computation-based inference or theory-based inference, you will either supply `calculate()` with the output of `generate()` or `hypothesize()`, respectively. 
The function, for one, takes in a `stat` argument, which is currently one of `\"mean\"`, `\"median\"`, `\"sum\"`, `\"sd\"`, `\"prop\"`, `\"count\"`, `\"diff in means\"`, `\"diff in medians\"`, `\"diff in props\"`, `\"Chisq\"`, `\"F\"`, `\"t\"`, `\"z\"`, `\"slope\"`, or `\"correlation\"`. For example, continuing our example above to calculate the null distribution of mean hours worked per week:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(response = hours) %>%\n hypothesize(null = \"point\", mu = 40) %>%\n generate(reps = 5000, type = \"bootstrap\") %>%\n calculate(stat = \"mean\")\n#> Response: hours (numeric)\n#> Null Hypothesis: point\n#> # A tibble: 5,000 × 2\n#> replicate stat\n#> \n#> 1 1 39.8\n#> 2 2 39.6\n#> 3 3 39.8\n#> 4 4 39.2\n#> 5 5 39.0\n#> 6 6 39.8\n#> 7 7 40.6\n#> 8 8 40.6\n#> 9 9 40.4\n#> 10 10 39.0\n#> # ℹ 4,990 more rows\n```\n:::\n\n\n\n\n\nThe output of `calculate()` here shows us the sample statistic (in this case, the mean) for each of our 1000 replicates. If you're carrying out inference on differences in means, medians, or proportions, or $t$ and $z$ statistics, you will need to supply an `order` argument, giving the order in which the explanatory variables should be subtracted. For instance, to find the difference in mean age of those that have a college degree and those that don't, we might write:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(age ~ college) %>%\n hypothesize(null = \"independence\") %>%\n generate(reps = 5000, type = \"permute\") %>%\n calculate(\"diff in means\", order = c(\"degree\", \"no degree\"))\n#> Response: age (numeric)\n#> Explanatory: college (factor)\n#> Null Hypothesis: independence\n#> # A tibble: 5,000 × 2\n#> replicate stat\n#> \n#> 1 1 -0.0378\n#> 2 2 1.55 \n#> 3 3 0.465 \n#> 4 4 1.39 \n#> 5 5 -0.161 \n#> 6 6 -0.179 \n#> 7 7 0.0151\n#> 8 8 0.914 \n#> 9 9 -1.32 \n#> 10 10 -0.426 \n#> # ℹ 4,990 more rows\n```\n:::\n\n\n\n\n\n## Other utilities\n\nThe infer package also offers several utilities to extract meaning out of summary statistics and null distributions; the package provides functions to visualize where a statistic is relative to a distribution (with `visualize()`), calculate p-values (with `get_p_value()`), and calculate confidence intervals (with `get_confidence_interval()`).\n\nTo illustrate, we'll go back to the example of determining whether the mean number of hours worked per week is 40 hours.\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# find the point estimate\npoint_estimate <- gss %>%\n specify(response = hours) %>%\n calculate(stat = \"mean\")\n\n# generate a null distribution\nnull_dist <- gss %>%\n specify(response = hours) %>%\n hypothesize(null = \"point\", mu = 40) %>%\n generate(reps = 5000, type = \"bootstrap\") %>%\n calculate(stat = \"mean\")\n```\n:::\n\n\n\n\n\n(Notice the warning: `Removed 1244 rows containing missing values.` This would be worth noting if you were actually carrying out this hypothesis test.)\n\nOur point estimate 41.382 seems *pretty* close to 40, but a little bit different. 
We might wonder if this difference is just due to random chance, or if the mean number of hours worked per week in the population really isn't 40.\n\nWe could initially just visualize the null distribution.\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nnull_dist %>%\n visualize()\n```\n\n::: {.cell-output-display}\n![](figs/visualize-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\n\nWhere does our sample's observed statistic lie on this distribution? We can use the `obs_stat` argument to specify this.\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nnull_dist %>%\n visualize() +\n shade_p_value(obs_stat = point_estimate, direction = \"two_sided\")\n```\n\n::: {.cell-output-display}\n![](figs/visualize2-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\n\nNotice that infer has also shaded the regions of the null distribution that are as (or more) extreme than our observed statistic. (Also, note that we now use the `+` operator to apply the `shade_p_value()` function. This is because `visualize()` outputs a plot object from ggplot2 instead of a dataframe, and the `+` operator is needed to add the p-value layer to the plot object.) The red bar looks like it's slightly far out on the right tail of the null distribution, so observing a sample mean of 41.382 hours would be somewhat unlikely if the mean was actually 40 hours. How unlikely, though?\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# get a two-tailed p-value\np_value <- null_dist %>%\n get_p_value(obs_stat = point_estimate, direction = \"two_sided\")\n\np_value\n#> # A tibble: 1 × 1\n#> p_value\n#> \n#> 1 0.046\n```\n:::\n\n\n\n\n\nIt looks like the p-value is 0.046, which is pretty small---if the true mean number of hours worked per week was actually 40, the probability of our sample mean being this far (1.382 hours) from 40 would be 0.046. This may or may not be statistically significantly different, depending on the significance level $\\alpha$ you decided on *before* you ran this analysis. If you had set $\\alpha = .05$, then this difference would be statistically significant, but if you had set $\\alpha = .01$, then it would not be.\n\nTo get a confidence interval around our estimate, we can write:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# start with the null distribution\nnull_dist %>%\n # calculate the confidence interval around the point estimate\n get_confidence_interval(point_estimate = point_estimate,\n # at the 95% confidence level\n level = .95,\n # using the standard error\n type = \"se\")\n#> # A tibble: 1 × 2\n#> lower_ci upper_ci\n#> \n#> 1 40.1 42.7\n```\n:::\n\n\n\n\n\nAs you can see, 40 hours per week is not contained in this interval, which aligns with our previous conclusion that this finding is significant at the confidence level $\\alpha = .05$.\n\n## Theoretical methods\n\nThe infer package also provides functionality to use theoretical methods for `\"Chisq\"`, `\"F\"` and `\"t\"` test statistics. \n\nGenerally, to find a null distribution using theory-based methods, use the same code that you would use to find the null distribution using randomization-based methods, but skip the `generate()` step. 
For example, if we wanted to find a null distribution for the relationship between age (`age`) and party identification (`partyid`) using randomization, we could write:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nnull_f_distn <- gss %>%\n specify(age ~ partyid) %>%\n hypothesize(null = \"independence\") %>%\n generate(reps = 5000, type = \"permute\") %>%\n calculate(stat = \"F\")\n```\n:::\n\n\n\n\n\nTo find the null distribution using theory-based methods, instead, skip the `generate()` step entirely:\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nnull_f_distn_theoretical <- gss %>%\n specify(age ~ partyid) %>%\n hypothesize(null = \"independence\") %>%\n calculate(stat = \"F\")\n```\n:::\n\n\n\n\n\nWe'll calculate the observed statistic to make use of in the following visualizations; this procedure is the same, regardless of the methods used to find the null distribution.\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nF_hat <- gss %>% \n specify(age ~ partyid) %>%\n calculate(stat = \"F\")\n```\n:::\n\n\n\n\n\nNow, instead of just piping the null distribution into `visualize()`, as we would do if we wanted to visualize the randomization-based null distribution, we also need to provide `method = \"theoretical\"` to `visualize()`.\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nvisualize(null_f_distn_theoretical, method = \"theoretical\") +\n shade_p_value(obs_stat = F_hat, direction = \"greater\")\n```\n\n::: {.cell-output-display}\n![](figs/unnamed-chunk-22-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\n\nTo get a sense of how the theory-based and randomization-based null distributions relate, we can pipe the randomization-based null distribution into `visualize()` and also specify `method = \"both\"`\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nvisualize(null_f_distn, method = \"both\") +\n shade_p_value(obs_stat = F_hat, direction = \"greater\")\n```\n\n::: {.cell-output-display}\n![](figs/unnamed-chunk-23-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\n\nThat's it! This vignette covers most all of the key functionality of infer. 
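One last convenience worth knowing about: recent versions of infer also export `observe()`, a wrapper around `specify()` and `calculate()` for computing an observed statistic in a single step. If it is available in your installed version, the point estimate from earlier could be written as the sketch below.

```r
# Sketch: observe() condenses specify() %>% calculate() into one call.
gss %>%
  observe(response = hours, stat = "mean")
```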
See `help(package = \"infer\")` for a full list of functions and vignettes.\n\n\n## Session information {#session-info}\n\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```\n#> ─ Session info ─────────────────────────────────────────────────────\n#> version R version 4.4.2 (2024-10-31)\n#> language (EN)\n#> date 2025-03-24\n#> pandoc 3.6.1\n#> quarto 1.6.42\n#> \n#> ─ Packages ─────────────────────────────────────────────────────────\n#> package version date (UTC) source\n#> broom 1.0.7 2024-09-26 CRAN (R 4.4.1)\n#> dials 1.4.0 2025-02-13 CRAN (R 4.4.2)\n#> dplyr 1.1.4 2023-11-17 CRAN (R 4.4.0)\n#> ggplot2 3.5.1 2024-04-23 CRAN (R 4.4.0)\n#> infer 1.0.7 2024-03-25 CRAN (R 4.4.0)\n#> parsnip 1.3.1 2025-03-12 CRAN (R 4.4.1)\n#> purrr 1.0.4 2025-02-05 CRAN (R 4.4.1)\n#> recipes 1.2.0 2025-03-17 CRAN (R 4.4.1)\n#> rlang 1.1.5 2025-01-17 CRAN (R 4.4.2)\n#> rsample 1.2.1 2024-03-25 CRAN (R 4.4.0)\n#> tibble 3.2.1 2023-03-20 CRAN (R 4.4.0)\n#> tidymodels 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> tune 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> workflows 1.2.0 2025-02-19 CRAN (R 4.4.1)\n#> yardstick 1.3.2 2025-01-22 CRAN (R 4.4.1)\n#> \n#> ────────────────────────────────────────────────────────────────────\n```\n:::\n", + "markdown": "---\ntitle: \"Hypothesis testing using resampling and tidy data\"\ncategories:\n - statistical analysis\n - hypothesis testing\n - bootstrapping\ntype: learn-subsection\nweight: 4\ndescription: | \n Perform common hypothesis tests for statistical inference using flexible functions.\ntoc: true\ntoc-depth: 2\ninclude-after-body: ../../../resources.html\n---\n\n\n\n\n\n\n\n\n## Introduction\n\nThis article only requires the tidymodels package. \n\nThe tidymodels package [infer](https://infer.tidymodels.org/) implements an expressive grammar to perform statistical inference that coheres with the `tidyverse` design framework. Rather than providing methods for specific statistical tests, this package consolidates the principles that are shared among common hypothesis tests into a set of 4 main verbs (functions), supplemented with many utilities to visualize and extract information from their outputs.\n\nRegardless of which hypothesis test we're using, we're still asking the same kind of question: \n\n>Is the effect or difference in our observed data real, or due to chance? \n\nTo answer this question, we start by assuming that the observed data came from some world where \"nothing is going on\" (i.e. the observed effect was simply due to random chance), and call this assumption our **null hypothesis**. (In reality, we might not believe in the null hypothesis at all; the null hypothesis is in opposition to the **alternate hypothesis**, which supposes that the effect present in the observed data is actually due to the fact that \"something is going on.\") We then calculate a **test statistic** from our data that describes the observed effect. We can use this test statistic to calculate a **p-value**, giving the probability that our observed data could come about if the null hypothesis was true. If this probability is below some pre-defined **significance level** $\\alpha$, then we can reject our null hypothesis.\n\nIf you are new to hypothesis testing, take a look at \n\n* [Section 9.2 of _Statistical Inference via Data Science_](https://moderndive.com/9-hypothesis-testing.html#understanding-ht)\n* The American Statistical Association's recent [statement on p-values](https://doi.org/10.1080/00031305.2016.1154108) \n\nThe workflow of this package is designed around these ideas. 
Starting from some data set,\n\n+ `specify()` allows you to specify the variable, or relationship between variables, that you're interested in,\n+ `hypothesize()` allows you to declare the null hypothesis,\n+ `generate()` allows you to generate data reflecting the null hypothesis, and\n+ `calculate()` allows you to calculate a distribution of statistics from the generated data to form the null distribution.\n\nThroughout this vignette, we make use of `gss`, a data set available in infer containing a sample of 500 observations of 11 variables from the *General Social Survey*. \n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nlibrary(tidymodels) # Includes the infer package\n\n# Set seed\nset.seed(1234)\n\n# load in the data set\ndata(gss)\n\n# take a look at its structure\ndplyr::glimpse(gss)\n#> Rows: 500\n#> Columns: 11\n#> $ year 2014, 1994, 1998, 1996, 1994, 1996, 1990, 2016, 2000, 1998, 20…\n#> $ age 36, 34, 24, 42, 31, 32, 48, 36, 30, 33, 21, 30, 38, 49, 25, 56…\n#> $ sex male, female, male, male, male, female, female, female, female…\n#> $ college degree, no degree, degree, no degree, degree, no degree, no de…\n#> $ partyid ind, rep, ind, ind, rep, rep, dem, ind, rep, dem, dem, ind, de…\n#> $ hompop 3, 4, 1, 4, 2, 4, 2, 1, 5, 2, 4, 3, 4, 4, 2, 2, 3, 2, 1, 2, 5,…\n#> $ hours 50, 31, 40, 40, 40, 53, 32, 20, 40, 40, 23, 52, 38, 72, 48, 40…\n#> $ income $25000 or more, $20000 - 24999, $25000 or more, $25000 or more…\n#> $ class middle class, working class, working class, working class, mid…\n#> $ finrela below average, below average, below average, above average, ab…\n#> $ weight 0.8960034, 1.0825000, 0.5501000, 1.0864000, 1.0825000, 1.08640…\n```\n:::\n\n\n\n\nEach row is an individual survey response, containing some basic demographic information on the respondent as well as some additional variables. See `?gss` for more information on the variables included and their source. Note that this data (and our examples on it) are for demonstration purposes only, and will not necessarily provide accurate estimates unless weighted properly. For these examples, let's suppose that this data set is a representative sample of a population we want to learn about: American adults.\n\n## Specify variables\n\nThe `specify()` function can be used to specify which of the variables in the data set you're interested in. If you're only interested in, say, the `age` of the respondents, you might write:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(response = age)\n#> Response: age (numeric)\n#> # A tibble: 500 × 1\n#> age\n#> \n#> 1 36\n#> 2 34\n#> 3 24\n#> 4 42\n#> 5 31\n#> 6 32\n#> 7 48\n#> 8 36\n#> 9 30\n#> 10 33\n#> # ℹ 490 more rows\n```\n:::\n\n\n\n\nOn the front end, the output of `specify()` just looks like it selects off the columns in the dataframe that you've specified. 
What do we see if we check the class of this object, though?\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(response = age) %>%\n class()\n#> [1] \"infer\" \"tbl_df\" \"tbl\" \"data.frame\"\n```\n:::\n\n\n\n\nWe can see that the infer class has been appended on top of the dataframe classes; this new class stores some extra metadata.\n\nIf you're interested in two variables (`age` and `partyid`, for example) you can `specify()` their relationship in one of two (equivalent) ways:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# as a formula\ngss %>%\n specify(age ~ partyid)\n#> Response: age (numeric)\n#> Explanatory: partyid (factor)\n#> # A tibble: 500 × 2\n#> age partyid\n#> \n#> 1 36 ind \n#> 2 34 rep \n#> 3 24 ind \n#> 4 42 ind \n#> 5 31 rep \n#> 6 32 rep \n#> 7 48 dem \n#> 8 36 ind \n#> 9 30 rep \n#> 10 33 dem \n#> # ℹ 490 more rows\n\n# with the named arguments\ngss %>%\n specify(response = age, explanatory = partyid)\n#> Response: age (numeric)\n#> Explanatory: partyid (factor)\n#> # A tibble: 500 × 2\n#> age partyid\n#> \n#> 1 36 ind \n#> 2 34 rep \n#> 3 24 ind \n#> 4 42 ind \n#> 5 31 rep \n#> 6 32 rep \n#> 7 48 dem \n#> 8 36 ind \n#> 9 30 rep \n#> 10 33 dem \n#> # ℹ 490 more rows\n```\n:::\n\n\n\n\nIf you're doing inference on one proportion or a difference in proportions, you will need to use the `success` argument to specify which level of your `response` variable is a success. For instance, if you're interested in the proportion of the population with a college degree, you might use the following code:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# specifying for inference on proportions\ngss %>%\n specify(response = college, success = \"degree\")\n#> Response: college (factor)\n#> # A tibble: 500 × 1\n#> college \n#> \n#> 1 degree \n#> 2 no degree\n#> 3 degree \n#> 4 no degree\n#> 5 degree \n#> 6 no degree\n#> 7 no degree\n#> 8 degree \n#> 9 degree \n#> 10 no degree\n#> # ℹ 490 more rows\n```\n:::\n\n\n\n\n## Declare the hypothesis\n\nThe next step in the infer pipeline is often to declare a null hypothesis using `hypothesize()`. The first step is to supply one of \"independence\" or \"point\" to the `null` argument. If your null hypothesis assumes independence between two variables, then this is all you need to supply to `hypothesize()`:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(college ~ partyid, success = \"degree\") %>%\n hypothesize(null = \"independence\")\n#> Response: college (factor)\n#> Explanatory: partyid (factor)\n#> Null Hypothesis: independence\n#> # A tibble: 500 × 2\n#> college partyid\n#> \n#> 1 degree ind \n#> 2 no degree rep \n#> 3 degree ind \n#> 4 no degree ind \n#> 5 degree rep \n#> 6 no degree rep \n#> 7 no degree dem \n#> 8 degree ind \n#> 9 degree rep \n#> 10 no degree dem \n#> # ℹ 490 more rows\n```\n:::\n\n\n\n\nIf you're doing inference on a point estimate, you will also need to provide one of `p` (the true proportion of successes, between 0 and 1), `mu` (the true mean), `med` (the true median), or `sigma` (the true standard deviation). 
For instance, if the null hypothesis is that the mean number of hours worked per week in our population is 40, we would write:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(response = hours) %>%\n hypothesize(null = \"point\", mu = 40)\n#> Response: hours (numeric)\n#> Null Hypothesis: point\n#> # A tibble: 500 × 1\n#> hours\n#> \n#> 1 50\n#> 2 31\n#> 3 40\n#> 4 40\n#> 5 40\n#> 6 53\n#> 7 32\n#> 8 20\n#> 9 40\n#> 10 40\n#> # ℹ 490 more rows\n```\n:::\n\n\n\n\nAgain, from the front-end, the dataframe outputted from `hypothesize()` looks almost exactly the same as it did when it came out of `specify()`, but infer now \"knows\" your null hypothesis.\n\n## Generate the distribution\n\nOnce we've asserted our null hypothesis using `hypothesize()`, we can construct a null distribution based on this hypothesis. We can do this using one of several methods, supplied in the `type` argument:\n\n* `bootstrap`: A bootstrap sample will be drawn for each replicate, where a sample of size equal to the input sample size is drawn (with replacement) from the input sample data. \n* `permute`: For each replicate, each input value will be randomly reassigned (without replacement) to a new output value in the sample. \n* `simulate`: A value will be sampled from a theoretical distribution with parameters specified in `hypothesize()` for each replicate. (This option is currently only applicable for testing point estimates.) \n\nContinuing on with our example above, about the average number of hours worked a week, we might write:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(response = hours) %>%\n hypothesize(null = \"point\", mu = 40) %>%\n generate(reps = 5000, type = \"bootstrap\")\n#> Response: hours (numeric)\n#> Null Hypothesis: point\n#> # A tibble: 2,500,000 × 2\n#> # Groups: replicate [5,000]\n#> replicate hours\n#> \n#> 1 1 58.6\n#> 2 1 35.6\n#> 3 1 28.6\n#> 4 1 38.6\n#> 5 1 28.6\n#> 6 1 38.6\n#> 7 1 38.6\n#> 8 1 57.6\n#> 9 1 58.6\n#> 10 1 38.6\n#> # ℹ 2,499,990 more rows\n```\n:::\n\n\n\n\nIn the above example, we take 5000 bootstrap samples to form our null distribution.\n\nTo generate a null distribution for the independence of two variables, we could also randomly reshuffle the pairings of explanatory and response variables to break any existing association. For instance, to generate 5000 replicates that can be used to create a null distribution under the assumption that political party affiliation is not affected by age:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(partyid ~ age) %>%\n hypothesize(null = \"independence\") %>%\n generate(reps = 5000, type = \"permute\")\n#> Response: partyid (factor)\n#> Explanatory: age (numeric)\n#> Null Hypothesis: independence\n#> # A tibble: 2,500,000 × 3\n#> # Groups: replicate [5,000]\n#> partyid age replicate\n#> \n#> 1 ind 36 1\n#> 2 ind 34 1\n#> 3 ind 24 1\n#> 4 rep 42 1\n#> 5 dem 31 1\n#> 6 dem 32 1\n#> 7 dem 48 1\n#> 8 rep 36 1\n#> 9 ind 30 1\n#> 10 dem 33 1\n#> # ℹ 2,499,990 more rows\n```\n:::\n\n\n\n\n## Calculate statistics\n\nDepending on whether you're carrying out computation-based inference or theory-based inference, you will either supply `calculate()` with the output of `generate()` or `hypothesize()`, respectively. 
The function, for one, takes in a `stat` argument, which is currently one of `\"mean\"`, `\"median\"`, `\"sum\"`, `\"sd\"`, `\"prop\"`, `\"count\"`, `\"diff in means\"`, `\"diff in medians\"`, `\"diff in props\"`, `\"Chisq\"`, `\"F\"`, `\"t\"`, `\"z\"`, `\"slope\"`, or `\"correlation\"`. For example, continuing our example above to calculate the null distribution of mean hours worked per week:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(response = hours) %>%\n hypothesize(null = \"point\", mu = 40) %>%\n generate(reps = 5000, type = \"bootstrap\") %>%\n calculate(stat = \"mean\")\n#> Response: hours (numeric)\n#> Null Hypothesis: point\n#> # A tibble: 5,000 × 2\n#> replicate stat\n#> \n#> 1 1 39.8\n#> 2 2 39.6\n#> 3 3 39.8\n#> 4 4 39.2\n#> 5 5 39.0\n#> 6 6 39.8\n#> 7 7 40.6\n#> 8 8 40.6\n#> 9 9 40.4\n#> 10 10 39.0\n#> # ℹ 4,990 more rows\n```\n:::\n\n\n\n\nThe output of `calculate()` here shows us the sample statistic (in this case, the mean) for each of our 1000 replicates. If you're carrying out inference on differences in means, medians, or proportions, or $t$ and $z$ statistics, you will need to supply an `order` argument, giving the order in which the explanatory variables should be subtracted. For instance, to find the difference in mean age of those that have a college degree and those that don't, we might write:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\ngss %>%\n specify(age ~ college) %>%\n hypothesize(null = \"independence\") %>%\n generate(reps = 5000, type = \"permute\") %>%\n calculate(\"diff in means\", order = c(\"degree\", \"no degree\"))\n#> Response: age (numeric)\n#> Explanatory: college (factor)\n#> Null Hypothesis: independence\n#> # A tibble: 5,000 × 2\n#> replicate stat\n#> \n#> 1 1 -0.0378\n#> 2 2 1.55 \n#> 3 3 0.465 \n#> 4 4 1.39 \n#> 5 5 -0.161 \n#> 6 6 -0.179 \n#> 7 7 0.0151\n#> 8 8 0.914 \n#> 9 9 -1.32 \n#> 10 10 -0.426 \n#> # ℹ 4,990 more rows\n```\n:::\n\n\n\n\n## Other utilities\n\nThe infer package also offers several utilities to extract meaning out of summary statistics and null distributions; the package provides functions to visualize where a statistic is relative to a distribution (with `visualize()`), calculate p-values (with `get_p_value()`), and calculate confidence intervals (with `get_confidence_interval()`).\n\nTo illustrate, we'll go back to the example of determining whether the mean number of hours worked per week is 40 hours.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# find the point estimate\npoint_estimate <- gss %>%\n specify(response = hours) %>%\n calculate(stat = \"mean\")\n\n# generate a null distribution\nnull_dist <- gss %>%\n specify(response = hours) %>%\n hypothesize(null = \"point\", mu = 40) %>%\n generate(reps = 5000, type = \"bootstrap\") %>%\n calculate(stat = \"mean\")\n```\n:::\n\n\n\n\n(Notice the warning: `Removed 1244 rows containing missing values.` This would be worth noting if you were actually carrying out this hypothesis test.)\n\nOur point estimate 41.382 seems *pretty* close to 40, but a little bit different. 
We might wonder if this difference is just due to random chance, or if the mean number of hours worked per week in the population really isn't 40.\n\nWe could initially just visualize the null distribution.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nnull_dist %>%\n visualize()\n```\n\n::: {.cell-output-display}\n![](figs/visualize-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nWhere does our sample's observed statistic lie on this distribution? We can use the `obs_stat` argument to specify this.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nnull_dist %>%\n visualize() +\n shade_p_value(obs_stat = point_estimate, direction = \"two_sided\")\n```\n\n::: {.cell-output-display}\n![](figs/visualize2-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nNotice that infer has also shaded the regions of the null distribution that are as (or more) extreme than our observed statistic. (Also, note that we now use the `+` operator to apply the `shade_p_value()` function. This is because `visualize()` outputs a plot object from ggplot2 instead of a dataframe, and the `+` operator is needed to add the p-value layer to the plot object.) The red bar looks like it's slightly far out on the right tail of the null distribution, so observing a sample mean of 41.382 hours would be somewhat unlikely if the mean was actually 40 hours. How unlikely, though?\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# get a two-tailed p-value\np_value <- null_dist %>%\n get_p_value(obs_stat = point_estimate, direction = \"two_sided\")\n\np_value\n#> # A tibble: 1 × 1\n#> p_value\n#> \n#> 1 0.046\n```\n:::\n\n\n\n\nIt looks like the p-value is 0.046, which is pretty small---if the true mean number of hours worked per week was actually 40, the probability of our sample mean being this far (1.382 hours) from 40 would be 0.046. This may or may not be statistically significantly different, depending on the significance level $\\alpha$ you decided on *before* you ran this analysis. If you had set $\\alpha = .05$, then this difference would be statistically significant, but if you had set $\\alpha = .01$, then it would not be.\n\nTo get a confidence interval around our estimate, we can write:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\n# start with the null distribution\nnull_dist %>%\n # calculate the confidence interval around the point estimate\n get_confidence_interval(point_estimate = point_estimate,\n # at the 95% confidence level\n level = .95,\n # using the standard error\n type = \"se\")\n#> # A tibble: 1 × 2\n#> lower_ci upper_ci\n#> \n#> 1 40.1 42.7\n```\n:::\n\n\n\n\nAs you can see, 40 hours per week is not contained in this interval, which aligns with our previous conclusion that this finding is significant at the confidence level $\\alpha = .05$.\n\n## Theoretical methods\n\nThe infer package also provides functionality to use theoretical methods for `\"Chisq\"`, `\"F\"` and `\"t\"` test statistics. \n\nGenerally, to find a null distribution using theory-based methods, use the same code that you would use to find the null distribution using randomization-based methods, but skip the `generate()` step. 
For example, if we wanted to find a null distribution for the relationship between age (`age`) and party identification (`partyid`) using randomization, we could write:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nnull_f_distn <- gss %>%\n specify(age ~ partyid) %>%\n hypothesize(null = \"independence\") %>%\n generate(reps = 5000, type = \"permute\") %>%\n calculate(stat = \"F\")\n```\n:::\n\n\n\n\nTo find the null distribution using theory-based methods, instead, skip the `generate()` step entirely:\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nnull_f_distn_theoretical <- gss %>%\n specify(age ~ partyid) %>%\n hypothesize(null = \"independence\") %>%\n calculate(stat = \"F\")\n```\n:::\n\n\n\n\nWe'll calculate the observed statistic to make use of in the following visualizations; this procedure is the same, regardless of the methods used to find the null distribution.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nF_hat <- gss %>% \n specify(age ~ partyid) %>%\n calculate(stat = \"F\")\n```\n:::\n\n\n\n\nNow, instead of just piping the null distribution into `visualize()`, as we would do if we wanted to visualize the randomization-based null distribution, we also need to provide `method = \"theoretical\"` to `visualize()`.\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nvisualize(null_f_distn_theoretical, method = \"theoretical\") +\n shade_p_value(obs_stat = F_hat, direction = \"greater\")\n```\n\n::: {.cell-output-display}\n![](figs/unnamed-chunk-22-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nTo get a sense of how the theory-based and randomization-based null distributions relate, we can pipe the randomization-based null distribution into `visualize()` and also specify `method = \"both\"`\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```{.r .cell-code}\nvisualize(null_f_distn, method = \"both\") +\n shade_p_value(obs_stat = F_hat, direction = \"greater\")\n```\n\n::: {.cell-output-display}\n![](figs/unnamed-chunk-23-1.svg){fig-align='center' width=672}\n:::\n:::\n\n\n\n\nThat's it! This vignette covers most all of the key functionality of infer. See `help(package = \"infer\")` for a full list of functions and vignettes.\n\n\n## Session information {#session-info}\n\n\n\n\n::: {.cell layout-align=\"center\"}\n\n```\n#> ─ Session info ─────────────────────────────────────────────────────\n#> version R version 4.4.2 (2024-10-31)\n#> language (EN)\n#> date 2025-03-24\n#> pandoc 3.6.1\n#> quarto 1.6.42\n#> \n#> ─ Packages ─────────────────────────────────────────────────────────\n#> package version date (UTC) source\n#> broom 1.0.7 2024-09-26 CRAN (R 4.4.1)\n#> dials 1.4.0 2025-02-13 CRAN (R 4.4.2)\n#> dplyr 1.1.4 2023-11-17 CRAN (R 4.4.0)\n#> ggplot2 3.5.1 2024-04-23 CRAN (R 4.4.0)\n#> infer 1.0.7 2024-03-25 CRAN (R 4.4.0)\n#> parsnip 1.3.1 2025-03-12 CRAN (R 4.4.1)\n#> purrr 1.0.4 2025-02-05 CRAN (R 4.4.1)\n#> recipes 1.2.0 2025-03-17 CRAN (R 4.4.1)\n#> rlang 1.1.5 2025-01-17 CRAN (R 4.4.2)\n#> rsample 1.2.1 2024-03-25 CRAN (R 4.4.0)\n#> tibble 3.2.1 2023-03-20 CRAN (R 4.4.0)\n#> tidymodels 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> tune 1.3.0 2025-02-21 CRAN (R 4.4.1)\n#> workflows 1.2.0 2025-02-19 CRAN (R 4.4.1)\n#> yardstick 1.3.2 2025-01-22 CRAN (R 4.4.1)\n#> \n#> ────────────────────────────────────────────────────────────────────\n```\n:::\n", "supporting": [], "filters": [ "rmarkdown/pagebreak.lua"