
Support for additional arguments (discrete, nthreads) when predicting from mgcv::bam #1068

Closed
chrishanretty opened this issue Apr 4, 2024 · 8 comments


@chrishanretty

Estimation of generalized additive models can be done quickly using discretization of covariates in the bam function in the mgcv package.

Prediction from bam models can be sped up by specifying that the model is discrete and specifying a number of threads. Per the documentation:

if se.fit=TRUE and discrete prediction is used then parallel computation can be used to speed up se calculation. This specifies number of threads to use.

I'm asking for the two arguments discrete and nthreads to be supported in predictions from bam models. I'm asking because I estimated a beta regression on around a million observations, and it seems to be taking more than a day to make predictions for three representative observations passed as newdata. Obviously I can set vcov = FALSE, but I need the CIs.
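For reference, the kind of call this request is about can be made by calling predict() on a bam fit directly and building the intervals by hand. This is a minimal sketch on simulated Poisson data (the data, model, and thread count are assumptions for illustration, not the beta regression described above). Note that predict.bam() spells the thread argument n.threads, whereas bam() uses nthreads.

```r
library(mgcv)

# Simulated Poisson data (an assumption for illustration)
set.seed(3)
n <- 5000
d <- data.frame(x0 = runif(n), x1 = runif(n))
d$y <- rpois(n, exp(1 + sin(2 * pi * d$x0) + 0.5 * d$x1))

# Fit with discretized covariates; bam() takes the thread count as nthreads
m <- bam(y ~ s(x0) + s(x1), data = d, family = poisson,
         discrete = TRUE, nthreads = 2)

# A few "representative" rows to predict for
nd <- data.frame(x0 = c(0.2, 0.5, 0.8), x1 = 0.3)

# predict.bam() takes the thread count as n.threads; per the docs it is
# used for the standard-error computation under discrete prediction
p <- predict(m, newdata = nd, type = "link", se.fit = TRUE,
             discrete = TRUE, n.threads = 2)

# 95% Wald intervals on the link scale, back-transformed with exp()
eta <- as.vector(p$fit)
se  <- as.vector(p$se.fit)
z   <- qnorm(0.975)
ci  <- data.frame(estimate = exp(eta),
                  lower    = exp(eta - z * se),
                  upper    = exp(eta + z * se))
print(ci)
```

These are model-based intervals from mgcv itself, not the marginaleffects estimates discussed in the rest of the thread.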

@vincentarelbundock
Owner

vincentarelbundock commented Apr 4, 2024

Thanks for raising the issue.

First, I'll note that you can already pass any extra argument you want to the underlying package's predict() function. All unknown arguments are already passed on via ...

Second, that is unfortunately not going to help much here, because marginaleffects does not use the standard errors supplied by other packages and always computes its own. So if a package offers a way to parallelize SE computation, it won't matter, because the SEs are always computed in-house anyway.

I'd be curious to try implementing parallelization in marginaleffects. Could you show me an example model with a smaller public dataset?

I can't promise a short-term solution, but I'd like to take a look at this eventually.

@chrishanretty
Author

I've put some example code in this gist.

It uses the nycflights13 data to estimate a Poisson model. It's similar to my data in that it's a nonlinear model with a mix of random effects and splines.

Computing standard errors with predictions() always seems slower than with mgcv::predict.bam, whether or not discretization is used. With standard errors, predictions() is around twenty times slower. But then, mgcv has been heavily, heavily optimized.

When I run this code, I get the following warning:

"These arguments are not supported for models of class bam: discrete, nthreads. Valid arguments include: exclude. Please file a request on Github if you believe that additional arguments should be supported: https://github.com/vincentarelbundock/marginaleffects/issues "

@vincentarelbundock
Owner

Thanks for the Gist. I'll take a look when I find some time.

I'm not surprised about the speed difference. If they can do it all with algebra, it's always going to be tons faster than with numeric differentiation. But maybe we can get some wins with parallelization. We'll see...
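The speed gap can be illustrated with a small sketch of the delta method. A closed-form approach (as predict.glm() uses below) gets prediction SEs with a little matrix algebra, whereas a numerical approach first needs a finite-difference Jacobian of the predictions with respect to the coefficients. This sketch uses a Poisson glm on simulated data; pred_fun and numeric_jacobian are hypothetical helpers, and this is not marginaleffects' actual code.

```r
# Simulated Poisson data (an assumption for illustration)
set.seed(1)
n <- 500
d <- data.frame(x = runif(n))
d$y <- rpois(n, exp(0.5 + 1.2 * d$x))
fit <- glm(y ~ x, family = poisson, data = d)
nd <- data.frame(x = c(0.1, 0.5, 0.9))

# Predictions on the response scale as a function of the coefficients
X <- model.matrix(delete.response(terms(fit)), data = nd)
pred_fun <- function(beta) as.vector(fit$family$linkinv(X %*% beta))

# Central finite-difference Jacobian: two prediction calls per coefficient
numeric_jacobian <- function(f, beta, h = 1e-6) {
  J <- matrix(NA_real_, nrow = length(f(beta)), ncol = length(beta))
  for (j in seq_along(beta)) {
    step <- replace(numeric(length(beta)), j, h)
    J[, j] <- (f(beta + step) - f(beta - step)) / (2 * h)
  }
  J
}

# Delta method: Var(pred) is approximately J V J^T; keep only the diagonal
J <- numeric_jacobian(pred_fun, coef(fit))
se_numeric <- sqrt(rowSums((J %*% vcov(fit)) * J))

# Closed-form SEs from predict.glm() for comparison
se_analytic <- predict(fit, newdata = nd, type = "response", se.fit = TRUE)$se.fit
print(cbind(se_numeric, se_analytic))
```

Each Jacobian column costs two extra prediction calls, so the numerical route scales with the number of coefficients, which is one reason it falls behind closed-form algebra on models with many spline and random-effect coefficients.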

The warning is there as a precaution. All arguments are passed automatically to the prediction function, so the arguments are supported. The warning simply indicates that the arguments are not "known" by marginaleffects. I'll try to modify the wording to make that clear.
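The "warn, but forward anyway" behavior described here can be sketched with a small wrapper. forward_predict and its known list are hypothetical names for illustration, not marginaleffects' internals: the wrapper warns about argument names it does not recognize, then passes everything in ... to predict() regardless. A Poisson glm on simulated data stands in for the model.

```r
# Hypothetical wrapper: warns about arguments it does not "know",
# but forwards everything in ... to predict() regardless
forward_predict <- function(model, newdata, ..., known = c("type")) {
  extra <- names(list(...))
  unknown <- setdiff(extra, known)
  if (length(unknown) > 0) {
    warning("Arguments not known to the wrapper (passed on anyway): ",
            paste(unknown, collapse = ", "))
  }
  predict(model, newdata = newdata, ...)
}

# Simulated data (an assumption for illustration)
set.seed(4)
d <- data.frame(x = runif(200))
d$y <- rpois(200, exp(0.5 + d$x))
fit <- glm(y ~ x, family = poisson, data = d)
nd <- data.frame(x = c(0.1, 0.9))

# se.fit is not in the wrapper's list, so it triggers a warning,
# yet predict.glm() still receives it and returns a list with se.fit
res <- withCallingHandlers(
  forward_predict(fit, nd, type = "response", se.fit = TRUE),
  warning = function(w) {
    message("Caught: ", conditionMessage(w))
    invokeRestart("muffleWarning")
  }
)
str(res)
```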

@vincentarelbundock
Owner

@chrishanretty

I made a first attempt at parallelizing standard errors. This will always be much slower than bam, and parallelization is only likely to matter when coef(mod) returns a lot of parameters.

But maybe there are still some gains to be had?
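Why the number of coefficients matters: each column of a finite-difference Jacobian needs its own pair of predictions and does not depend on the other columns, so the columns can be spread across cores. This is a minimal sketch using parallel::mclapply on a Poisson glm with several simulated covariates; parallel_jacobian is a hypothetical helper, and this is not necessarily how the PR branch below does it. mclapply relies on forking, so the sketch falls back to one core on Windows.

```r
library(parallel)

# Simulated data with several covariates (an assumption for illustration)
set.seed(2)
n <- 1000; k <- 6
d <- as.data.frame(matrix(runif(n * k), n, k))   # columns V1..V6
d$y <- rpois(n, exp(0.2 + rowSums(d[, 1:k]) / k))
fit <- glm(y ~ ., family = poisson, data = d)
nd <- d[1:3, ]

# Predictions on the response scale as a function of the coefficients
X <- model.matrix(delete.response(terms(fit)), data = nd)
pred_fun <- function(beta) as.vector(fit$family$linkinv(X %*% beta))

# Each Jacobian column is an independent central difference,
# so mclapply can compute the columns in parallel
parallel_jacobian <- function(f, beta, h = 1e-6, cores = 1L) {
  cols <- mclapply(seq_along(beta), function(j) {
    step <- replace(numeric(length(beta)), j, h)
    (f(beta + step) - f(beta - step)) / (2 * h)
  }, mc.cores = cores)
  do.call(cbind, cols)
}

# Forking is unavailable on Windows, so use a single core there
cores <- if (.Platform$OS.type == "windows") 1L else 2L
J <- parallel_jacobian(pred_fun, coef(fit), cores = cores)

# Delta-method standard errors: sqrt(diag(J V J^T))
se <- sqrt(rowSums((J %*% vcov(fit)) * J))
print(se)
```

With only seven coefficients the forking overhead outweighs any gain; the split only pays off when each prediction call is expensive and there are many coefficients.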

This is incomplete, but you can give it a shot by installing the PR branch: #1071

See below for timings with your example on my 8-core laptop.

Install

library(remotes)
install_github(repo="vincentarelbundock/marginaleffects", ref = github_pull(1071))

Prep and fit

library(mgcv)
library(marginaleffects)
library(nycflights13)
library(tictoc)
data("flights")

my_threads <- 8
set.seed(3)


flights <- flights |>
    transform(date = as.Date(paste(year, month, day, sep = "/"))) |>
    transform(date.num = as.numeric(date - min(date)))

flights <- flights |>
    transform(wday = as.POSIXlt(date)$wday)

flights <- flights |>
    transform(time = as.POSIXct(paste(hour, minute, sep = ":"), format = "%H:%M")) |>
    transform(time.dt = difftime(time,
                                 as.POSIXct('00:00', format = '%H:%M'), units = 'min')) |>
    transform(time.num = as.numeric(time.dt))

flights <- flights |>
    transform(dep_delay = ifelse(dep_delay < 0, 0, dep_delay)) |>
    transform(dep_delay = ifelse(is.na(dep_delay), 0, dep_delay))

flights <- flights |>
    transform(carrier = factor(carrier)) |>
    transform(dest = factor(dest)) |>
    transform(origin = factor(origin))


m_discrete <- bam(dep_delay ~ s(date.num, bs = "cr") +
                  s(wday, bs = "cc", k = 3) +
                  s(time.num, bs = "cr") +
                  s(carrier, bs = "re") +
                  origin +
                  s(distance, bs = "cr") + 
                  s(dest, bs = "re"),
              data = flights,
              family = poisson,
              discrete = TRUE,
              nthreads = my_threads)

Slow

tic()
options(marginaleffects_cores = 1)
p1 <- predictions(m_discrete)
toc()

    93.461 sec elapsed

Faster?

tic()
options(marginaleffects_cores = my_threads)
p8 <- predictions(m_discrete)
toc()

    31.872 sec elapsed

@chrishanretty
Author

@vincentarelbundock I can confirm a roughly 3x speedup on this data and my (Linux) machine, and it's nice to see that this automagically carries over to comparisons(), which is where I started. Thank you so much for this -- it's amazing that you were able to do this so quickly! I'm going to mark this as closed because my starting assumption about the arguments was wrong.

@vincentarelbundock
Owner

Great news!

To be clear, I think the parallel feature is far from complete. For example, it doesn't work on Windows, and there might be better implementations out there. I'll open a separate issue on parallelization to make sure I don't forget.

I'm swamped with work now, so can't promise super fast completion, unfortunately.

@Aariq

Aariq commented Aug 21, 2024

Sorry to resurrect this old closed issue, but I suspect there is a speed benefit to discrete = TRUE regardless of whether standard errors are calculated.

From predict.bam:

discrete if TRUE then discrete prediction methods used with model fitted by discrete methods. FALSE for regular prediction. See details.

Details
When discrete=TRUE the prediction data in newdata is discretized in the same way as is done when using discrete fitting methods with bam. However the discretization grids are not currently identical to those used during fitting. Instead, discretization is done afresh for the prediction data. This means that if you are predicting for a relatively small set of prediction data, or on a regular grid, then the results may in fact be identical to those obtained without discretization. The disadvantage to this approach is that if you make predictions with a large data frame, and then split it into smaller data frames to make the predictions again, the results may differ slightly, because of slightly different discretization errors.

So while n.threads may not provide any speedup, discrete = TRUE seems like it might. It would be nice to silence the warning that marginaleffects prints when discrete is passed on to predict.bam().
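This can be checked in mgcv directly, outside marginaleffects: fit with discretization, then predict the same newdata with discrete = TRUE and discrete = FALSE, without standard errors, and compare timings and values. The simulated data, model, and prediction-set size are assumptions for illustration. Per the documentation quoted above, the two sets of predictions may differ slightly because the prediction data are discretized afresh; how much time discrete = TRUE saves will depend on the model and data.

```r
library(mgcv)

# Simulated data (an assumption for illustration)
set.seed(5)
n <- 20000
d <- data.frame(x0 = runif(n), x1 = runif(n))
d$y <- rpois(n, exp(1 + sin(2 * pi * d$x0) + 0.5 * d$x1))
m <- bam(y ~ s(x0) + s(x1), data = d, family = poisson, discrete = TRUE)

# A large prediction set, where discretization has something to compress
nd <- data.frame(x0 = runif(50000), x1 = runif(50000))

# Point predictions only (no standard errors), with and without discretization
t_disc <- system.time(p_disc <- predict(m, newdata = nd, discrete = TRUE))
t_full <- system.time(p_full <- predict(m, newdata = nd, discrete = FALSE))

print(rbind(discrete = t_disc[["elapsed"]], full = t_full[["elapsed"]]))

# Largest difference on the link scale due to fresh discretization
print(max(abs(as.vector(p_disc) - as.vector(p_full))))
```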

@vincentarelbundock
Owner

Thanks @Aariq, the argument should be whitelisted in the dev version on GitHub.
