Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add butcher methods for nestedmodels #256

Merged
merged 8 commits into from Mar 19, 2023
Merged

add butcher methods for nestedmodels #256

merged 8 commits into from Mar 19, 2023

Conversation

ashbythorpe
Copy link
Contributor

A simple set of methods for nested_model_fit objects, just iterates through all the inner models and calls axe_() on them.

library(nestedmodels)
library(parsnip)

model <- linear_reg() %>%
  set_engine("lm") %>%
  nested()

nested_data <- tidyr::nest(example_nested_data, data = -id)
fit <- fit(model, z ~ x + y + a + b, nested_data)

weigh(fit)
#> # A tibble: 1,101 × 2
#>    object                  size
#>    <chr>                  <dbl>
#>  1 .model_fit.fit.terms 0.00831
#>  2 .model_fit.fit.terms 0.00831
#>  3 .model_fit.fit.terms 0.00831
#>  4 .model_fit.fit.terms 0.00831
#>  5 .model_fit.fit.terms 0.00831
#>  6 .model_fit.fit.terms 0.00831
#>  7 .model_fit.fit.terms 0.00831
#>  8 .model_fit.fit.terms 0.00831
#>  9 .model_fit.fit.terms 0.00831
#> 10 .model_fit.fit.terms 0.00831
#> # … with 1,091 more rows

res <- butcher(fit)

weigh(res)
#> # A tibble: 1,101 × 2
#>    object                                       size
#>    <chr>                                       <dbl>
#>  1 .model_fit.spec.method.pred.conf_int.post 0.00710
#>  2 .model_fit.spec.method.pred.pred_int.post 0.00710
#>  3 .model_fit.spec.method.pred.conf_int.post 0.00710
#>  4 .model_fit.spec.method.pred.pred_int.post 0.00710
#>  5 .model_fit.spec.method.pred.conf_int.post 0.00710
#>  6 .model_fit.spec.method.pred.pred_int.post 0.00710
#>  7 .model_fit.spec.method.pred.conf_int.post 0.00710
#>  8 .model_fit.spec.method.pred.pred_int.post 0.00710
#>  9 .model_fit.spec.method.pred.conf_int.post 0.00710
#> 10 .model_fit.spec.method.pred.pred_int.post 0.00710
#> # … with 1,091 more rows

data_tst <- data.frame(
  id = 1, id2 = 1, x = 1, y = 1, a = -1, b = 10
)

predict(fit, data_tst)
#> # A tibble: 1 × 1
#>   .pred
#>   <dbl>
#> 1  40.2

predict(res, data_tst)
#> # A tibble: 1 × 1
#>   .pred
#>   <dbl>
#> 1  40.2

@ashbythorpe ashbythorpe marked this pull request as ready for review March 16, 2023 14:18
Copy link
Member

@juliasilge juliasilge left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much for this contribution @ashbythorpe! 🙌 I've got a couple of questions about how you set this up.

R/nested_model_fit.R Show resolved Hide resolved
R/nested_model_fit.R Show resolved Hide resolved
R/nested_model_fit.R Show resolved Hide resolved
verbose = FALSE,
...
)
all_disabled <- purrr::map(x$fit$.model_fit, attr, "disabled")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you tell me more about how the disabled methods are intended to be found here? I don't seem to find this working for me:

library(butcher)

model <- nestedmodels::nested(
  parsnip::set_engine(parsnip::linear_reg(), "lm")
)
nested_data <- tidyr::nest(nestedmodels::example_nested_data, data = -id)
fit <- parsnip::fit(model, z ~ x + y + a + b, nested_data)

fit$fit$.model_fit <- purrr::map(
  fit$fit$.model_fit,
  axe_fitted,
  verbose = TRUE
)
#> ✔ Memory released: "80 B"
#> ✖ Disabled: `fitted()` and `summary()`
#> ✔ Memory released: "0 B"
#> ✔ Memory released: "80 B"
#> ✖ Disabled: `fitted()` and `summary()`
#> ✔ Memory released: "0 B"
#> ✔ Memory released: "80 B"
#> ✖ Disabled: `fitted()` and `summary()`
#> ✔ Memory released: "0 B"
#> ✔ Memory released: "80 B"
#> ✖ Disabled: `fitted()` and `summary()`
#> ✔ Memory released: "0 B"
#> ✔ Memory released: "80 B"
#> ✖ Disabled: `fitted()` and `summary()`
#> ✔ Memory released: "0 B"
#> ✔ Memory released: "80 B"
#> ✖ Disabled: `fitted()` and `summary()`
#> ✔ Memory released: "0 B"
#> ✔ Memory released: "80 B"
#> ✖ Disabled: `fitted()` and `summary()`
#> ✔ Memory released: "0 B"
#> ✔ Memory released: "80 B"
#> ✖ Disabled: `fitted()` and `summary()`
#> ✔ Memory released: "0 B"
#> ✔ Memory released: "80 B"
#> ✖ Disabled: `fitted()` and `summary()`
#> ✔ Memory released: "0 B"
#> ✔ Memory released: "80 B"
#> ✖ Disabled: `fitted()` and `summary()`
#> ✔ Memory released: "0 B"
#> ✔ Memory released: "80 B"
#> ✖ Disabled: `fitted()` and `summary()`
#> ✔ Memory released: "0 B"
#> ✔ Memory released: "80 B"
#> ✖ Disabled: `fitted()` and `summary()`
#> ✔ Memory released: "0 B"
#> ✔ Memory released: "80 B"
#> ✖ Disabled: `fitted()` and `summary()`
#> ✔ Memory released: "0 B"
#> ✔ Memory released: "80 B"
#> ✖ Disabled: `fitted()` and `summary()`
#> ✔ Memory released: "0 B"
#> ✔ Memory released: "80 B"
#> ✖ Disabled: `fitted()` and `summary()`
#> ✔ Memory released: "0 B"
#> ✔ Memory released: "80 B"
#> ✖ Disabled: `fitted()` and `summary()`
#> ✔ Memory released: "0 B"
#> ✔ Memory released: "80 B"
#> ✖ Disabled: `fitted()` and `summary()`
#> ✔ Memory released: "0 B"
#> ✔ Memory released: "80 B"
#> ✖ Disabled: `fitted()` and `summary()`
#> ✔ Memory released: "0 B"
#> ✔ Memory released: "80 B"
#> ✖ Disabled: `fitted()` and `summary()`
#> ✔ Memory released: "0 B"
#> ✔ Memory released: "80 B"
#> ✖ Disabled: `fitted()` and `summary()`
#> ✔ Memory released: "0 B"

purrr::map(fit$fit$.model_fit, attr, "disabled")
#> [[1]]
#> NULL
#> 
#> [[2]]
#> NULL
#> 
#> [[3]]
#> NULL
#> 
#> [[4]]
#> NULL
#> 
#> [[5]]
#> NULL
#> 
#> [[6]]
#> NULL
#> 
#> [[7]]
#> NULL
#> 
#> [[8]]
#> NULL
#> 
#> [[9]]
#> NULL
#> 
#> [[10]]
#> NULL
#> 
#> [[11]]
#> NULL
#> 
#> [[12]]
#> NULL
#> 
#> [[13]]
#> NULL
#> 
#> [[14]]
#> NULL
#> 
#> [[15]]
#> NULL
#> 
#> [[16]]
#> NULL
#> 
#> [[17]]
#> NULL
#> 
#> [[18]]
#> NULL
#> 
#> [[19]]
#> NULL
#> 
#> [[20]]
#> NULL

Created on 2023-03-17 with reprex v2.0.2

@ashbythorpe
Copy link
Contributor Author

Apologies, I think I should have been more clear about what I was doing with the verbose option. Since in the "nested_model_fit" object, all the inner models are the same type of model, the functions disabled by each axing function will be the same for every inner model. As a result, my idea was to create a summary of these models, outputting the functions disabled for every inner model and the total memory released, rather than having a separate message for each model. A similar thing is done in ipred.R (except there the disabled methods are already known).

This example demonstrates this (the previous commit should have fixed the issue with the "butcher_disabled" attribute).

library(nestedmodels)
library(parsnip)

model <- linear_reg() %>%
  set_engine("lm") %>%
  nested()

nested_data <- tidyr::nest(example_nested_data, data = -id)
fit <- fit(model, z ~ x + y + a + b, nested_data)

res <- axe_call(fit, verbose = TRUE)
#> ✔ Memory released: "8.47 kB"
#> ✖ Disabled: `print()` and `summary()`

res <- axe_fitted(fit, verbose = TRUE)
#> ✔ Memory released: "5.71 kB"
#> ✖ Disabled: `fitted()` and `summary()`

res <- butcher(fit, verbose = TRUE)
#> ✔ Memory released: "61.96 kB"
#> ✖ Disabled: `print()`, `summary()`, and `fitted()`

# To compare, this is the result of axing one inner model
inner_model <- fit$fit$.model_fit[[5]]$fit

res <- axe_call(inner_model, verbose = TRUE)
#> ✔ Memory released: "920 B"
#> ✖ Disabled: `print()` and `summary()`

res <- axe_fitted(inner_model, verbose = TRUE)
#> ✔ Memory released: "80 B"
#> ✖ Disabled: `fitted()` and `summary()`

res <- butcher(inner_model, verbose = TRUE)
#> ✔ Memory released: "3.92 kB"
#> ✖ Disabled: `print()`, `summary()`, and `fitted()`

Let me know if you think it would be better to show the full output to the user.

Copy link
Member

@juliasilge juliasilge left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes OK, makes sense to me now, and thank you for fixing the disabled methods! 🙌

@juliasilge juliasilge merged commit d7b07cc into tidymodels:main Mar 19, 2023
11 checks passed
@ashbythorpe ashbythorpe deleted the nestedmodels branch March 19, 2023 22:48
@ashbythorpe ashbythorpe restored the nestedmodels branch March 19, 2023 22:49
@ashbythorpe ashbythorpe deleted the nestedmodels branch March 19, 2023 22:49
@github-actions
Copy link

github-actions bot commented Apr 3, 2023

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Apr 3, 2023
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants