Partial dependence plots for ensemble models #335

Closed
Tracked by #55
dinilu opened this issue Jun 27, 2023 · 1 comment

Comments

dinilu commented Jun 27, 2023

Any plan to make the p_pdp function work with ensemble models?

sjevelazco (Owner) commented

Hi

This would be an interesting feature to add to the p_pdp function. However, I need to think carefully about how to implement it, since there are different ensemble approaches. I will add it to my "todo" list ;)
Meanwhile, you can use the data_pdp function, which returns the partial dependence data as a tibble (the pdpdata element in the example below) that you can use to plot the curves yourself.

Here is an example:

library(terra)
library(dplyr)
library(tidyr)
library(ggplot2)
library(flexsdm)


somevar <- system.file("external/somevar.tif", package = "flexsdm")
somevar <- terra::rast(somevar) # environmental data
names(somevar) <- c("aet", "cwd", "tmx", "tmn")
data(abies)

# Keep only the coordinates and the presence/absence column
abies2 <- abies %>%
  select(x, y, pr_ab)

# Extract the environmental values at each record
abies2 <- sdm_extract(abies2,
  x = "x",
  y = "y",
  env_layer = somevar
)
# Randomly split the records into 5 folds for k-fold cross-validation
abies2 <- part_random(abies2,
  pr_ab = "pr_ab",
  method = c(method = "kfold", folds = 5)
)

# Fit SVM, GAM, and random forest models with the same data, predictors, and partition
svm_t1 <- fit_svm(
  data = abies2,
  response = "pr_ab",
  predictors = c("aet", "cwd", "tmx", "tmn"),
  partition = ".part",
  thr = c("max_sens_spec")
)

gam_t1 <- fit_gam(
  data = abies2,
  response = "pr_ab",
  predictors = c("aet", "cwd", "tmx", "tmn"),
  partition = ".part",
  thr = c("max_sens_spec"), k = -1
)

raf_t1 <- fit_raf(
  data = abies2,
  response = "pr_ab",
  predictors = c("aet", "cwd", "tmx", "tmn"),
  partition = ".part",
  thr = c("max_sens_spec")
)

# Partial dependence data for aet from each fitted model
df_svm <- data_pdp(
  model = svm_t1$model,
  predictors = c("aet"),
  resolution = 100,
  resid = TRUE,
  projection_data = somevar,
  training_data = abies2,
  clamping = FALSE
)

df_gam <- data_pdp(
  model = gam_t1$model,
  predictors = c("aet"),
  resolution = 100,
  resid = TRUE,
  projection_data = somevar,
  training_data = abies2,
  clamping = FALSE
)
df_raf <- data_pdp(
  model = raf_t1$model,
  predictors = c("aet"),
  resolution = 100,
  resid = TRUE,
  projection_data = somevar,
  training_data = abies2,
  clamping = FALSE
)

db_list <- list(SVM = df_svm$pdpdata, 
                GAM = df_gam$pdpdata, 
                RAF = df_raf$pdpdata)
db_list <- bind_rows(db_list, .id = "Algorithm")
# One Suitability column per algorithm, one row per aet value
db_list <- db_list %>%
  tidyr::pivot_wider(names_from = Algorithm, values_from = Suitability)
# Calculate the ensemble as the unweighted mean of the three algorithms
db_list <- db_list %>%
  mutate(Ensemble = (GAM + SVM + RAF) / 3)

# Back to long format: one row per aet value and curve
db_list <- db_list %>%
  tidyr::pivot_longer(c(GAM, SVM, RAF, Ensemble),
                      names_to = "Algorithm", values_to = "Suitability")
db_list$`-aet` <- NULL

# Partial dependence curves for each algorithm and the ensemble
ggplot(db_list, aes(aet, Suitability)) +
  geom_line(aes(color = Algorithm))

# Ensemble curve only
ggplot(
  db_list %>% dplyr::filter(Algorithm == "Ensemble"),
  aes(aet, Suitability)
) +
  geom_line(aes(color = Algorithm))
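
If you need this for more models or other ensemble rules, the reshape-and-average steps above could be wrapped in a small function. Below is a minimal sketch of a hypothetical helper, add_ensemble_curve(), which is not part of flexsdm. It assumes every table in the list has the predictor column and a Suitability column computed on the same predictor grid (the spread step in the example relies on the same thing). With no weights it reproduces the simple mean used above; the optional weights argument, also an assumption of this sketch, gives a weighted mean instead (for example, weighting each algorithm by a performance metric of your choice). It uses base R only.

```r
# Hypothetical helper, not part of flexsdm: stack per-algorithm partial
# dependence tables into long format and append an "Ensemble" curve.
# Assumes each table has the predictor column `var` and a `Suitability`
# column, all evaluated on the same predictor grid.
add_ensemble_curve <- function(pdp_list, var, weights = NULL) {
  algs <- names(pdp_list)
  if (is.null(algs) || any(algs == "")) {
    stop("pdp_list must be a named list, e.g. list(SVM = ..., GAM = ...)")
  }

  # Equal weights reproduce the simple mean; named weights are matched by name
  if (is.null(weights)) {
    weights <- rep(1, length(algs))
    names(weights) <- algs
  } else if (is.null(names(weights))) {
    names(weights) <- algs
  }
  weights <- weights[algs]
  if (anyNA(weights) || any(weights < 0) || sum(weights) == 0) {
    stop("weights must be non-negative, non-zero in total, and cover every algorithm")
  }

  # Predictor grid shared by all curves (taken from the first table)
  grid <- pdp_list[[1]][[var]]

  # One column of suitability values per algorithm, rows aligned with `grid`
  suit <- do.call(cbind, lapply(pdp_list, function(d) {
    if (!isTRUE(all.equal(d[[var]], grid))) {
      stop("all PDP tables must share the same predictor grid")
    }
    d[["Suitability"]]
  }))

  # Weighted mean across algorithms for each grid value
  ensemble <- drop(suit %*% weights) / sum(weights)

  # Long format: one row per (grid value, curve)
  curves <- cbind(suit, Ensemble = ensemble)
  out <- data.frame(
    x = rep(grid, times = ncol(curves)),
    Algorithm = rep(colnames(curves), each = length(grid)),
    Suitability = as.vector(curves),
    stringsAsFactors = FALSE
  )
  names(out)[1] <- var
  out
}
```

With the objects from the example, the call would be add_ensemble_curve(list(SVM = df_svm$pdpdata, GAM = df_gam$pdpdata, RAF = df_raf$pdpdata), var = "aet"), and the returned data frame has the same aet, Algorithm, and Suitability columns used in the ggplot() calls above.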

2 participants