More Efficient group_by(...) %>% sample_*(...) #3193

Merged: 2 commits merged on Dec 30, 2017.

@saurfang (Contributor) commented Nov 6, 2017

Improves `df %>% group_by(...) %>% sample_*(...)` performance by 10-100x for datasets with a large number of groups.

The motivation: when performing stratified sampling with `group_by %>% sample_n` on 100k+ strata, the call can take minutes or longer. In a toy example, each additional 1k groups adds roughly 2 s of runtime. Quick profiling shows that most of the time is spent in `eval_tidy(weight, ...)`, which is called once per group. This PR instead computes the weights with a single `mutate()` call, which preserves the semantics but eliminates the repeated overscope lookup for every group.
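The idea can be sketched in base R. This is a simplified analogue under stated assumptions, not dplyr's implementation: the weight column is looked up once for the whole table (the way a single `mutate()` call evaluates it once), and each group then only subsets that precomputed vector by row position. `stratified_sample()` and its arguments are hypothetical names introduced for illustration, and the weight is passed as a column name rather than as a quoted expression.

```r
# Hypothetical helper (not part of dplyr): stratified sampling that
# evaluates the weights once for the whole table instead of once per group.
stratified_sample <- function(df, group_col, size, weight_col = NULL,
                              replace = FALSE) {
  # Evaluate the weights a single time, over every row.
  w <- if (is.null(weight_col)) NULL else df[[weight_col]]

  # 1-based row positions belonging to each group.
  idx <- split(seq_len(nrow(df)), df[[group_col]])

  picked <- lapply(idx, function(rows) {
    # Per group, only subset the precomputed weights.
    prob <- if (is.null(w)) NULL else w[rows]
    # sample.int() avoids sample()'s special case when `rows` has length 1.
    rows[sample.int(length(rows), size, replace = replace, prob = prob)]
  })

  df[unlist(picked, use.names = FALSE), , drop = FALSE]
}

# Usage: 3 groups of 5 rows, 2 weighted draws per group.
set.seed(42)
toy <- data.frame(group = rep(c("a", "b", "c"), each = 5),
                  x = 1:15, w = rep(1:5, 3))
out <- stratified_sample(toy, "group", size = 2, weight_col = "w")
table(out$group)
```

In the PR itself, the counterpart of `w` is `mutate(tbl, w = !!weight)[["w"]]`, as the diff hunks in the review comments show.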

library(dplyr)
#>
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#>
#>     filter, lag
#> The following objects are masked from 'package:base':
#>
#>     intersect, setdiff, setequal, union

n_strata <- 1000
# number of original observations in each stratum (normally distributed)
n_per_strata_mean <- 100
n_per_strata_sd <- 10

# number of rows to sample from each stratum
n_sample <- 10

source_df <-
  data_frame(group = 1:n_strata) %>%
  group_by(group) %>%
  do(sample_n(iris, round(rnorm(1, n_per_strata_mean, n_per_strata_sd)), replace = TRUE)) %>%
  ungroup() %>%
  # shuffle
  sample_frac(1)

source_df
#> # A tibble: 100,151 x 6
#>    group Sepal.Length Sepal.Width Petal.Length Petal.Width    Species
#>    <int>        <dbl>       <dbl>        <dbl>       <dbl>     <fctr>
#>  1   406          5.0         3.6          1.4         0.2     setosa
#>  2   998          5.5         2.4          3.8         1.1 versicolor
#>  3   101          6.4         2.8          5.6         2.2  virginica
#>  4   803          5.6         2.7          4.2         1.3 versicolor
#>  5   897          6.7         3.3          5.7         2.5  virginica
#>  6   743          4.9         3.1          1.5         0.2     setosa
#>  7   885          6.2         2.2          4.5         1.5 versicolor
#>  8    37          5.0         3.6          1.4         0.2     setosa
#>  9   276          5.2         4.1          1.5         0.1     setosa
#> 10   225          6.9         3.2          5.7         2.3  virginica
#> # ... with 100,141 more rows

# dplyr master ###########################
## sample without replacement no weights
system.time({
  source_df %>%
    group_by(group) %>%
    sample_n(n_sample)
})
#>    user  system elapsed
#>   1.677   0.006   1.687
## sample without replacement with weights
system.time({
  source_df %>%
    group_by(group) %>%
    sample_n(n_sample, weight = Petal.Length)
})
#>    user  system elapsed
#>   1.827   0.017   1.857
## sample with replacement
system.time({
  source_df %>%
    group_by(group) %>%
    sample_n(n_sample, replace = TRUE)
})
#>    user  system elapsed
#>   1.860   0.015   1.895

# dplyr dev ##################################
devtools::load_all("~/workspace/dplyr")
#> Loading dplyr
## sample without replacement no weights
system.time({
  source_df %>%
    group_by(group) %>%
    sample_n(n_sample)
})
#>    user  system elapsed
#>   0.019   0.000   0.023
## sample without replacement with weights
system.time({
  source_df %>%
    group_by(group) %>%
    sample_n(n_sample, weight = Petal.Length)
})
#>    user  system elapsed
#>   0.065   0.002   0.070
## sample with replacement
system.time({
  source_df %>%
    group_by(group) %>%
    sample_n(n_sample, replace = TRUE)
})
#>    user  system elapsed
#>   0.017   0.000   0.018

Before (dplyr master): [image]

After: [image]

@hadley (Member) commented Nov 6, 2017

Great - much easier to start with this minimal change and then figure out the other stuff.

Can you please add a bullet to NEWS? It should briefly describe the change (starting with the name of the function) and credit yourself with (@yourname, #issuenumber).

@saurfang (Contributor, Author) commented:

Done. Can you please take another look?

```
@@ -265,11 +265,11 @@ sample_n.grouped_df <- function(tbl, size, replace = FALSE,
inform("`.env` is deprecated and no longer has any effect")
}
weight <- enquo(weight)
weight <- mutate(tbl, w = !!weight)[["w"]]
```

Member: Could you add a space after the `!!`, please?
Member: Never mind. `!!` is going to have very high precedence, so it now makes sense to keep it next to its argument, just like unary `-`.
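The precedence point can be checked with base R's parser alone. Unary `!` binds more loosely than `*`, so without special handling `!!x * 2` parses as `!(!(x * 2))`; unary `-` binds more tightly, so `-x * 2` parses as `(-x) * 2`. Giving `!!` the tight binding mentioned in the comment is rlang's quasiquotation at work; the sketch below only inspects base R parse trees and does not load rlang.

```r
# Base R: unary `!` has lower precedence than `*`, so the double bang
# swallows the whole product: !!x * 2 parses as !(!(x * 2)).
e <- quote(!!x * 2)
e[[1]]       # outer call: `!`
e[[2]][[2]]  # operand of the inner `!`: the call x * 2

# Unary minus binds tighter than `*`, so -x * 2 parses as (-x) * 2.
m <- quote(-x * 2)
m[[1]]       # outer call: `*`
m[[2]]       # its first argument: the call -x
```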

@saurfang (Contributor, Author): Great. Let me know if there is anything else I can do to help get this PR merged.

@krlmlr (Member) left a review comment:

Thanks! Looks good to me except for one small nit.

```
@@ -292,11 +292,11 @@ sample_frac.grouped_df <- function(tbl, size = 1, replace = FALSE,
)
}
weight <- enquo(weight)
weight <- mutate(tbl, w = !!weight)[["w"]]

index <- attr(tbl, "indices")
```

Member: It looks a bit cleaner (and maybe faster) if we assign `index <- attr(...) + 1L` here and don't add the offset in `sample_group()`.
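The suggestion amounts to doing the 0-based to 1-based shift once, before the per-group loop. Below is a base-R sketch, assuming (as with dplyr 0.7's grouped data frames) that the `indices` attribute is a list of 0-based integer vectors, one per group; because it is assumed to be a list, the shift is written with `lapply()`. The toy `indices` value and `sample_group_sketch()` are hypothetical stand-ins, not dplyr internals.

```r
# Hypothetical 0-based group indices: three groups over seven rows.
indices <- list(c(0L, 3L, 5L), c(1L, 4L), c(2L, 6L))

# Shift to R's 1-based row positions once, up front ...
index <- lapply(indices, function(i) i + 1L)

# ... so the per-group helper needs no offset arithmetic.
sample_group_sketch <- function(i, size) {
  i[sample.int(length(i), size)]
}

set.seed(1)
rows <- unlist(lapply(index, sample_group_sketch, size = 1L))
rows
```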

@krlmlr merged commit a0d0e72 into tidyverse:master on Dec 30, 2017.
@krlmlr (Member) commented Dec 30, 2017:

Thanks!

krlmlr added a commit that referenced this pull request Mar 14, 2018
The lock bot locked this conversation as resolved and limited it to collaborators on Jun 28, 2018.

4 participants