Support for additional arguments (discrete, nthreads) when predicting from mgcv::bam #1068
Comments
Thanks for raising the issue. First, I'll note that you can already pass *any* extra argument you want to the modeling package's predict() function: all unknown arguments are already forwarded to it. Second, that is unfortunately not going to be of much help here. I'd be curious to try to implement parallelization in marginaleffects. Could you show me an example model with a smaller public dataset? I can't promise a short-term solution, but I'd like to take a look at this eventually.
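A minimal base R sketch of the forwarding mechanism described above. It assumes a hypothetical wrapper, predict_wrapper(), standing in for the package internals (it is not marginaleffects' actual code), and uses glm() instead of bam() so that it runs without extra packages. Any argument the wrapper does not name is collected in `...` and handed unchanged to the model's predict() method.

```r
# Hypothetical wrapper: arguments it does not declare are captured by `...`
# and passed through unchanged to the predict() method of the fitted model.
predict_wrapper <- function(model, newdata, ...) {
  predict(model, newdata = newdata, ...)
}

set.seed(42)
d <- data.frame(x = runif(100))
d$y <- rpois(100, exp(1 + d$x))
fit <- glm(y ~ x, family = poisson, data = d)
nd <- d[1:3, , drop = FALSE]

# `type` and `se.fit` are not arguments of predict_wrapper(); they reach
# predict.glm() only through `...`
out <- predict_wrapper(fit, newdata = nd, type = "response", se.fit = TRUE)
```

The same route is what would carry a method-specific argument to a bam model's predict method; whether a wrapper also warns about arguments it does not recognise is a separate design choice.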
I've put some example code in this gist. It uses the nycflights data to estimate a Poisson model. It's similar to my data in that it's a nonlinear model with a mix of random effects and splines. Computing standard errors with predictions() always seems slower than with mgcv::predict.bam(). This is true whether or not discretization is used. When I run this code, I get the following warning:
Thanks for the gist. I'll take a look when I find some time. I'm not surprised about the speed difference. If they can do it all with algebra, it's always going to be tons faster than with numeric differentiation. But maybe we can get some wins with parallelization. We'll see... The warning is there as a precaution. All arguments are passed automatically to the prediction function, so the arguments are supported. The warning simply indicates that the arguments are not "known" by
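To make the algebra-versus-numeric-differentiation point concrete, here is a base R sketch. It uses glm() instead of bam() for brevity (an assumption), and it is a generic illustration of the delta method, not marginaleffects' actual implementation. The analytic route gets link-scale standard errors directly from predict(). The numeric route builds a forward-difference Jacobian J of the predictions with respect to the coefficients and then takes sqrt(diag(J V J')). Each Jacobian column costs one extra prediction call, so the cost grows with the number of coefficients, which is large in models with many spline and random-effect terms.

```r
set.seed(1)
d <- data.frame(x = runif(200))
d$y <- rpois(200, exp(0.5 + d$x))
fit <- glm(y ~ x, family = poisson, data = d)
nd <- d[1:5, , drop = FALSE]

# Analytic route: predict() computes the link-scale standard errors by algebra
se_analytic <- predict(fit, newdata = nd, type = "link", se.fit = TRUE)$se.fit

# Numeric route: link-scale predictions as a function of the coefficients
X <- model.matrix(~ x, data = nd)
pred_fun <- function(beta) drop(X %*% beta)

# Forward-difference Jacobian: one extra prediction per coefficient
jacobian <- function(f, beta, h = 1e-7) {
  f0 <- f(beta)
  J <- matrix(NA_real_, nrow = length(f0), ncol = length(beta))
  for (j in seq_along(beta)) {
    b <- beta
    b[j] <- b[j] + h          # perturb a single coefficient
    J[, j] <- (f(b) - f0) / h # finite-difference derivative
  }
  J
}

J <- jacobian(pred_fun, coef(fit))
# Delta method: Var(predictions) is approximately J V J'
se_numeric <- sqrt(diag(J %*% vcov(fit) %*% t(J)))
```

Both routes agree to within finite-difference error here; the difference is that the numeric route needs one prediction call per coefficient.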
I made a first attempt at parallelizing standard errors. Maybe there are still some gains to be had? This is incomplete, but you can give it a shot by installing the PR. See below for timings with your example on my 8-core laptop.

Install

```r
library(remotes)
install_github(repo = "vincentarelbundock/marginaleffects", ref = github_pull(1071))
```

Prep and fit

```r
library(mgcv)
library(marginaleffects)
library(nycflights13)
library(tictoc)

data("flights")
my_threads <- 8
set.seed(3)

flights <- flights |>
  transform(date = as.Date(paste(year, month, day, sep = "/"))) |>
  transform(date.num = as.numeric(date - min(date)))
flights <- flights |>
  transform(wday = as.POSIXlt(date)$wday)
flights <- flights |>
  transform(time = as.POSIXct(paste(hour, minute, sep = ":"), format = "%H:%M")) |>
  transform(time.dt = difftime(time,
    as.POSIXct("00:00", format = "%H:%M"), units = "min")) |>
  transform(time.num = as.numeric(time.dt))
flights <- flights |>
  transform(dep_delay = ifelse(dep_delay < 0, 0, dep_delay)) |>
  transform(dep_delay = ifelse(is.na(dep_delay), 0, dep_delay))
flights <- flights |>
  transform(carrier = factor(carrier)) |>
  transform(dest = factor(dest)) |>
  transform(origin = factor(origin))

m_discrete <- bam(dep_delay ~ s(date.num, bs = "cr") +
                    s(wday, bs = "cc", k = 3) +
                    s(time.num, bs = "cr") +
                    s(carrier, bs = "re") +
                    origin +
                    s(distance, bs = "cr") +
                    s(dest, bs = "re"),
                  data = flights,
                  family = poisson,
                  discrete = TRUE,
                  nthreads = my_threads)
```

Slow

```r
tic()
options(marginaleffects_cores = 1)
p1 <- predictions(m_discrete)
toc()
#> 93.461 sec elapsed
```

Faster?

```r
tic()
options(marginaleffects_cores = my_threads)
p8 <- predictions(m_discrete)
toc()
#> 31.872 sec elapsed
```
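Standard errors computed by numeric differentiation need one extra prediction per coefficient, and those evaluations are independent of one another, so they can be spread across cores. The base R sketch below parallelizes the columns of a forward-difference Jacobian with parallel::mclapply(). Whether the PR uses this exact approach is an assumption; jacobian_parallel() and the toy prediction function are hypothetical. mclapply() relies on forking, which is not available on Windows, so the sketch falls back to one core there.

```r
library(parallel)

# Forward-difference Jacobian whose columns are computed in parallel.
# mclapply() forks worker processes, so mc.cores must be 1 on Windows.
jacobian_parallel <- function(f, beta, h = 1e-7, cores = 1L) {
  f0 <- f(beta)
  cols <- mclapply(seq_along(beta), function(j) {
    b <- beta
    b[j] <- b[j] + h          # perturb one coefficient
    (f(b) - f0) / h           # one column of the Jacobian
  }, mc.cores = cores)
  do.call(cbind, cols)
}

# Toy prediction function: response-scale predictions of a log-link model
X <- cbind(1, seq(0, 1, length.out = 5))
f <- function(beta) drop(exp(X %*% beta))
beta <- c(0.2, 1.5)

cores <- if (.Platform$OS.type == "windows") 1L else 2L
J <- jacobian_parallel(f, beta, cores = cores)
```

For a log link the exact Jacobian is exp(X beta) times each column of X, which the finite-difference result matches closely. Forking has start-up overhead, so gains only appear when each prediction call is expensive, as with a large bam model.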
@vincentarelbundock I can confirm a roughly 3x speedup on this data and my (Linux) machine, and it's nice to see that this automagically carries over to comparisons(), which is where I started. Thank you so much for this; it's amazing that you were able to do this so quickly! I'm going to mark this as closed because my starting assumption about the arguments was wrong.
Great news! To be clear, I think the parallel feature is far from complete. For example, it doesn't work on Windows, and I think there might be better implementations out there. I'll open a separate issue about parallelization to make sure I don't forget. I'm swamped with work now, so I can't promise super fast completion, unfortunately.
Sorry to resurrect this old closed issue, but I suspect there is a speed benefit to From
So while
Thanks @Aariq, the argument should be whitelisted in the development version on GitHub.
Estimation of generalized additive models can be done quickly using discretization of covariates in the bam function in the mgcv package. Prediction from bam models can be sped up by specifying that the model is discrete and specifying a number of threads, per the documentation.

I'm asking for the two arguments discrete and nthreads to be supported in predictions from bam models. I'm asking because I estimated a beta regression on around a million observations, and it seems to be taking more than a day to make predictions for three representative observations passed as newdata. Obviously I can set vcov = FALSE, but I need the CIs.
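A small sketch of calling mgcv directly with the arguments discussed in this issue, on simulated data (the data, formula, and thread counts are illustrative assumptions, not the reporter's model). Note that bam() spells its thread argument nthreads, whereas predict.bam() spells it n.threads; discrete fitting also requires bam()'s default method = "fREML".

```r
library(mgcv)

set.seed(1)
n <- 5000
d <- data.frame(
  x = runif(n),
  g = factor(sample(letters[1:10], n, replace = TRUE))
)
d$y <- rpois(n, exp(0.3 + sin(2 * pi * d$x)))

# Discretized fit; discrete = TRUE works with method = "fREML" (the default)
m <- bam(y ~ s(x, bs = "cr") + s(g, bs = "re"),
         data = d, family = poisson,
         discrete = TRUE, nthreads = 2)

# Predictions with standard errors for a few rows, using discrete
# prediction and predict.bam()'s own thread argument, n.threads
nd <- d[1:3, ]
p <- predict(m, newdata = nd, type = "link", se.fit = TRUE,
             discrete = TRUE, n.threads = 2)
```

Since extra arguments given to predictions() are forwarded to the model's predict() method, these are the names that would need to reach predict.bam().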