ztluostat/BAMDT


Bayesian Additive Semi-Structured Regression Trees

R Implementation of Bayesian Additive Semi-Structured Regression Trees (BAMDT).

Reference:

Luo, Z. T., Sang, H., & Mallick, B. (2022). BAMDT: Bayesian Additive Semi-Multivariate Decision Trees for Nonparametric Regression. Proceedings of the 39th International Conference on Machine Learning (ICML 2022).

Files

  • Demo.R: Demo code for fitting BAMDT on a U-shape domain
  • Tree.R: Class of semi-structured decision trees
  • Model.R: Class of BAMDT models
  • ComplexDomainFun.R: Utility functions for U-shape domain
  • SimData.R: Code to simulate data on a U-shape domain (no need to run unless you would like to regenerate data)
  • input_U.RData: Data sets generated by running SimData.R

Dependencies

The code depends on the following R packages: R6, collections, igraph, fdaPDE, BART, sf, ggplot2. Please make sure they are installed before running the demo code.
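One way to check for the listed packages before running the demo is sketched below. It only reports what is missing and does not install anything; the commented install line assumes every package is available from your configured repository, which may not hold for all of them.

```r
# Packages listed in the Dependencies section
pkgs <- c("R6", "collections", "igraph", "fdaPDE", "BART", "sf", "ggplot2")

# requireNamespace() returns FALSE (instead of raising an error) when a package is absent
installed <- vapply(pkgs, requireNamespace, logical(1), quietly = TRUE)
missing_pkgs <- pkgs[!installed]

if (length(missing_pkgs) > 0) {
  message("Missing packages: ", paste(missing_pkgs, collapse = ", "))
  # install.packages(missing_pkgs)  # uncomment to install; assumes repository availability
} else {
  message("All dependencies are installed.")
}
```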

How to use BAMDT

model = Model$new(Y, X, graphs, projections, hyperpar, X_new, projections_new)

creates a BAMDT model object named model.

Parameters:

  • Y: Numeric response vector of length n.
  • X: Numeric matrix of unstructured training features, of size n * p.
  • graphs: List of M spatial graphs, where M is the number of trees. Each graph should be an igraph object.
  • projections: Integer matrix of size n * M, where projections[i, j] is the nearest knot index corresponding to training observation i for tree j.
  • hyperpar: Named vector of hyperparameters with the following elements
    • hyperpar['M']: Number of trees M.
    • hyperpar['sigmasq_mu']: Variance of prior for $\mu$, i.e., $\sigma^2_\mu$.
    • hyperpar['q']: Quantile used to calibrate prior for noise variance $\sigma^2$.
    • hyperpar['nu']: Degrees of freedom of the inverse-$\chi^2$ prior for noise variance $\sigma^2$.
    • hyperpar['alpha']: Hyperparameter $\alpha$ in tree generating process.
    • hyperpar['beta']: Hyperparameter $\beta$ in tree generating process.
    • hyperpar['numcut']: Number of candidate split points for unstructured features.
    • hyperpar['prob_split_by_x']: Probability of performing an unstructured split.
  • X_new: Numeric unstructured test features of size n_new * p.
  • projections_new: Integer matrix of size n_new * M, where projections_new[i, j] is the nearest knot index corresponding to test observation i for tree j.
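The argument shapes above can be illustrated with a small toy setup. This is a sketch under stated assumptions, not the repository's demo: the helper `nearest_knot`, the coordinate matrices `S` and `S_new`, and the random knot sets are hypothetical, and Euclidean distance is used purely for illustration (on a complex domain such as the U-shape, a different notion of nearness may be appropriate). The hyperparameter values follow common BART conventions and are not necessarily the settings used in Demo.R.

```r
set.seed(1)
n <- 20; n_new <- 5; p <- 3; M <- 4; n_knots <- 10

# Toy responses, unstructured features, and 2-D coordinates (all hypothetical)
Y     <- rnorm(n)
X     <- matrix(runif(n * p), n, p)
X_new <- matrix(runif(n_new * p), n_new, p)
S     <- matrix(runif(n * 2), n, 2)
S_new <- matrix(runif(n_new * 2), n_new, 2)

# Hypothetical helper: for each row of coords, index of the closest knot
# (Euclidean distance; an assumption made for this illustration only)
nearest_knot <- function(coords, knots) {
  apply(coords, 1, function(s) which.min(colSums((t(knots) - s)^2)))
}

# One knot set per tree; in the real model each set is the vertex set of an igraph graph
knots_list <- lapply(seq_len(M), function(j) matrix(runif(n_knots * 2), n_knots, 2))
projections     <- sapply(knots_list, function(k) nearest_knot(S, k))      # n x M
projections_new <- sapply(knots_list, function(k) nearest_knot(S_new, k))  # n_new x M

# Named hyperparameter vector; values are BART-style assumptions
hyperpar <- c(
  M               = M,
  sigmasq_mu      = (0.5 / (2 * sqrt(M)))^2,  # assumes Y rescaled to [-0.5, 0.5]
  q               = 0.9,
  nu              = 3,
  alpha           = 0.95,
  beta            = 2,
  numcut          = 100,
  prob_split_by_x = 0.5
)

# Shape checks matching the parameter descriptions
stopifnot(length(Y) == n, nrow(X) == n, ncol(X_new) == p)
stopifnot(all(dim(projections) == c(n, M)))
stopifnot(all(dim(projections_new) == c(n_new, M)))

# With `graphs` (a list of M igraph objects) in hand, the call would then be:
# model <- Model$new(Y, X, graphs, projections, hyperpar, X_new, projections_new)
```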

To fit a BAMDT model and predict for test data, use

model$Fit(init_val, MCMC, BURNIN, THIN, seed = 1234, save_partitions = FALSE)

Parameters:

  • init_val: Named list of initial values with the following element
    • init_val[['sigmasq_y']]: Initial value for noise variance $\sigma^2$.
  • MCMC: Number of MCMC iterations.
  • BURNIN: Number of burn-in iterations.
  • THIN: Thinning interval. One MCMC sample is retained every THIN iterations after burn-in, so the number of posterior samples is npost = (MCMC - BURNIN) / THIN.
  • seed: Random seed.
  • save_partitions: Logical value indicating whether posterior samples of partitions are saved. Default is FALSE (recommended). Setting save_partitions = TRUE is very memory-intensive.
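As a quick sanity check on the sampler settings, the number of retained samples can be computed before fitting. The helper below is a hypothetical convenience function, not part of the repository; it warns when (MCMC - BURNIN) is not a multiple of THIN and rounds down in that case, which is an assumption about how a remainder would be handled.

```r
# Hypothetical helper: number of posterior samples implied by the sampler settings
n_post_samples <- function(MCMC, BURNIN, THIN) {
  stopifnot(MCMC > BURNIN, BURNIN >= 0, THIN >= 1)
  kept <- MCMC - BURNIN
  if (kept %% THIN != 0) {
    warning("(MCMC - BURNIN) is not a multiple of THIN; rounding down")
  }
  kept %/% THIN
}

npost <- n_post_samples(MCMC = 30000, BURNIN = 10000, THIN = 10)
npost  # 2000

# The corresponding fit call would look like:
# model$Fit(init_val = list(sigmasq_y = 1), MCMC = 30000, BURNIN = 10000, THIN = 10)
```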

The model object has the following public members:

  • model$sigmasq_y_out: Posterior samples of noise variance $\sigma^2$.
  • model$g_out: npost * n * M array of posterior samples of (in-sample) fitted values from each tree.
  • model$Y_new_out: npost * n_new matrix of posterior samples of (out-of-sample) predicted values.
  • model$importance_out: npost * (p + 1) matrix of posterior samples of feature importance metrics.
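The sketch below shows one way these outputs might be summarized. It uses simulated arrays with the documented shapes in place of a fitted model, so the numbers are meaningless; only the indexing matters. Summing g_out over its third dimension to obtain the overall fit follows from the additive structure of the model, but whether the result is on the original scale of Y is not stated here and is left as an assumption.

```r
set.seed(42)
npost <- 200; n <- 6; n_new <- 4; M <- 3; p <- 2

# Stand-ins for model$g_out, model$Y_new_out, model$importance_out, model$sigmasq_y_out
g_out          <- array(rnorm(npost * n * M), dim = c(npost, n, M))
Y_new_out      <- matrix(rnorm(npost * n_new), npost, n_new)
importance_out <- matrix(rexp(npost * (p + 1)), npost, p + 1)
sigmasq_y_out  <- 1 / rgamma(npost, shape = 3, rate = 2)

# In-sample fit: sum the M tree contributions for each draw, then average over draws
fit_draws <- apply(g_out, c(1, 2), sum)  # npost x n
fit_mean  <- colMeans(fit_draws)         # length n

# Out-of-sample predictions: posterior mean and 95% credible interval per test point
pred_mean <- colMeans(Y_new_out)                                         # length n_new
pred_ci   <- apply(Y_new_out, 2, quantile, probs = c(0.025, 0.975))      # 2 x n_new

# Posterior means of the feature importance metrics and the noise variance
imp_mean     <- colMeans(importance_out)  # length p + 1
sigmasq_mean <- mean(sigmasq_y_out)
```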
