From ce4ba23b3089cf40cc505e5051a3342ee74048be Mon Sep 17 00:00:00 2001 From: Forest Fang Date: Mon, 6 Nov 2017 00:10:23 -0800 Subject: [PATCH] More Efficient `group_by(...) %>% sample_*(...)` Improves `df %>% group_by(...) %>% sample_*(...)` performance by 10-100x for datasets with a large number of groups. The motivation is that when performing stratified sampling using `group_by %>% sample_n` on 100k+ strata, it can take minutes or longer. A toy example shows that every additional 1k groups increases runtime by ~2s. Quick profiling shows that most of the time is spent in `eval_tidy(weight, ...)` for each group. This PR performs the weight calculation using `mutate` instead, which preserves the semantics while eliminating the repeated overscope lookup across groups. --- NEWS.md | 2 ++ R/grouped-df.r | 9 ++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/NEWS.md b/NEWS.md index 95866abbf0..8762713886 100644 --- a/NEWS.md +++ b/NEWS.md @@ -45,6 +45,8 @@ * Add error for `distinct()` if any of the selected columns are of type `list` (#3088, @foo-bar-baz-qux). +* `sample_n()` and `sample_frac()` on grouped data frame are now faster especially for those with large number of groups (#3193, @saurfang). + # dplyr 0.7.4 * Fix recent Fedora and ASAN check errors (#3098). 
diff --git a/R/grouped-df.r b/R/grouped-df.r index dd84e3e1d2..8f8190975a 100644 --- a/R/grouped-df.r +++ b/R/grouped-df.r @@ -265,11 +265,11 @@ sample_n.grouped_df <- function(tbl, size, replace = FALSE, inform("`.env` is deprecated and no longer has any effect") } weight <- enquo(weight) + weight <- mutate(tbl, w = !!weight)[["w"]] index <- attr(tbl, "indices") sampled <- lapply(index, sample_group, frac = FALSE, - tbl = tbl, size = size, replace = replace, weight = weight @@ -292,11 +292,11 @@ sample_frac.grouped_df <- function(tbl, size = 1, replace = FALSE, ) } weight <- enquo(weight) + weight <- mutate(tbl, w = !!weight)[["w"]] index <- attr(tbl, "indices") sampled <- lapply(index, sample_group, frac = TRUE, - tbl = tbl, size = size, replace = replace, weight = weight @@ -306,7 +306,7 @@ sample_frac.grouped_df <- function(tbl, size = 1, replace = FALSE, grouped_df(tbl[idx, , drop = FALSE], vars = groups(tbl)) } -sample_group <- function(tbl, i, frac, size, replace, weight) { +sample_group <- function(i, frac, size, replace, weight) { n <- length(i) if (frac) { check_frac(size, replace) @@ -315,9 +315,8 @@ sample_group <- function(tbl, i, frac, size, replace, weight) { check_size(size, n, replace) } - weight <- eval_tidy(weight, tbl[i + 1, , drop = FALSE]) if (!is_null(weight)) { - weight <- check_weight(weight, n) + weight <- check_weight(weight[i + 1], n) } i[sample.int(n, size, replace = replace, prob = weight)]