Skip to content

Commit

Permalink
More Efficient group_by(...) %>% sample_*(...)
Browse files Browse the repository at this point in the history
Improves `df %>% group_by(...) %>% sample_*(...)` performance by 10-100x for dataset with large number of groups.

The motivation is that when performing stratified sampling using `group_by %>% sample_n` on 100k+ strata, it can take minutes or longer. A toy example shows every 1k groups increases runtime by ~2s. A quick profiling shows most of time is spent in `eval_tidy(weight, ...)` for each group. This PR performs the weight calculation using `mutate` instead to preserve the semantics but eliminates the repeated overscope lookup across groups.
  • Loading branch information
Forest Fang committed Nov 13, 2017
1 parent a1cbc89 commit 9fc9696
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Expand Up @@ -22,6 +22,7 @@

* `select()` and `vars()` now treat `NULL` as empty inputs (#3023).

* `sample_n()` and `sample_frac()` on grouped data frame are now faster especially for those with large number of groups (#3193, @saurfang).

# dplyr 0.7.4

Expand Down
9 changes: 4 additions & 5 deletions R/grouped-df.r
Expand Up @@ -265,11 +265,11 @@ sample_n.grouped_df <- function(tbl, size, replace = FALSE,
inform("`.env` is deprecated and no longer has any effect")
}
weight <- enquo(weight)
weight <- mutate(tbl, w = !!weight)[["w"]]

index <- attr(tbl, "indices")
sampled <- lapply(index, sample_group,
frac = FALSE,
tbl = tbl,
size = size,
replace = replace,
weight = weight
Expand All @@ -292,11 +292,11 @@ sample_frac.grouped_df <- function(tbl, size = 1, replace = FALSE,
)
}
weight <- enquo(weight)
weight <- mutate(tbl, w = !!weight)[["w"]]

index <- attr(tbl, "indices")
sampled <- lapply(index, sample_group,
frac = TRUE,
tbl = tbl,
size = size,
replace = replace,
weight = weight
Expand All @@ -306,7 +306,7 @@ sample_frac.grouped_df <- function(tbl, size = 1, replace = FALSE,
grouped_df(tbl[idx, , drop = FALSE], vars = groups(tbl))
}

sample_group <- function(tbl, i, frac, size, replace, weight) {
sample_group <- function(i, frac, size, replace, weight) {
n <- length(i)
if (frac) {
check_frac(size, replace)
Expand All @@ -315,9 +315,8 @@ sample_group <- function(tbl, i, frac, size, replace, weight) {
check_size(size, n, replace)
}

weight <- eval_tidy(weight, tbl[i + 1, , drop = FALSE])
if (!is_null(weight)) {
weight <- check_weight(weight, n)
weight <- check_weight(weight[i + 1], n)
}

i[sample.int(n, size, replace = replace, prob = weight)]
Expand Down

0 comments on commit 9fc9696

Please sign in to comment.