---
title: "RFCDE"
author: "Taylor Pospisil"
date: "`r Sys.Date()`"
output: rmarkdown::html_vignette
vignette: >
%\VignetteIndexEntry{RFCDE}
%\VignetteEngine{knitr::rmarkdown}
%\VignetteEncoding{UTF-8}
---
```{r setup, include = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
comment = "#>"
)
```
Random forests are a common nonparametric regression technique that
performs well with mixed-type data and irrelevant covariates and is
robust to monotonic transformations of the covariates. RFCDE fits
random forest models optimized for nonparametric conditional density
estimation.
## Example
```{r}
library(RFCDE)
set.seed(42)
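generate_data <- function(n) {
  # Only the first 10 covariates affect the response; the last 10 are noise.
  x_relevant <- matrix(runif(n * 10), n, 10)
  x_irrelevant <- matrix(runif(n * 10), n, 10)
  # z given x is normal with mean equal to the sum of the relevant covariates.
  z <- rnorm(n, rowSums(x_relevant), 1)
  return(list(x = cbind(x_relevant, x_irrelevant), z = z))
}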
n_train <- 10000
n_test <- 4
train_data <- generate_data(n_train)
x_train <- train_data$x
z_train <- train_data$z
```
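For reference, the simulation above can be written as the following
model. The subscript notation $x_j$ for the $j$-th covariate is
introduced here for illustration and does not appear in the code.

$$ x_j \sim \mathrm{Unif}(0, 1), \quad j = 1, \dots, 20, \qquad z \mid x \sim N\left( \sum_{j=1}^{10} x_j, \; 1 \right) $$

The true conditional mean therefore lies between 0 and 10, and only
the first ten covariates carry information about $z$.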
### Training
Each tree recursively partitions the covariate space to minimize the
CDE loss
$$ \int \int (\hat{f}(z \mid x) - f(z \mid x))^{2} dz dP(x) $$
where $x$ denotes the covariates and $z$ the response. This loss is
calculated efficiently using an orthogonal series representation of
the conditional densities. The resolution of this representation is
controlled by `n_basis`.
```{r}
n_trees <- 1000
mtry <- 4
node_size <- 20
n_basis <- 15
forest <- RFCDE(x_train, z_train, n_trees = n_trees, mtry = mtry,
                node_size = node_size, n_basis = n_basis)
```
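To illustrate what an orthogonal series representation looks like,
the following is a minimal sketch of a one-dimensional
(unconditional) series density estimate in base R. It assumes a
cosine basis on $[0, 1]$; the helper `cosine_basis` and the toy sample
are hypothetical and are not part of the RFCDE API. Using more basis
functions gives a finer-grained estimate, which is the kind of
resolution that `n_basis` controls.

```{r}
# Illustrative sketch only: cosine_basis() is a hypothetical helper, not RFCDE code.
# Cosine basis on [0, 1]: phi_0(z) = 1 and phi_j(z) = sqrt(2) * cos(pi * j * z).
cosine_basis <- function(z, n_basis) {
  j <- seq_len(n_basis) - 1
  basis <- sqrt(2) * cos(pi * outer(z, j))
  basis[, 1] <- 1
  basis
}

# Deterministic toy sample: evenly spaced quantiles of a Beta(2, 5) distribution.
z_toy <- qbeta(ppoints(500), 2, 5)

# Each coefficient is the sample mean of the corresponding basis function.
toy_n_basis <- 15
toy_coefs <- colMeans(cosine_basis(z_toy, toy_n_basis))

# The density estimate is the coefficient-weighted sum of the basis functions.
toy_series_grid <- seq(0, 1, length.out = 201)
toy_series_density <- as.vector(
  cosine_basis(toy_series_grid, toy_n_basis) %*% toy_coefs
)

# Trapezoid rule: the estimate integrates to one because phi_0 has coefficient 1.
toy_series_area <- sum(
  (head(toy_series_density, -1) + tail(toy_series_density, -1)) / 2
) * diff(toy_series_grid)[1]
```

A truncated series like this can dip slightly below zero, since
nothing constrains it to be non-negative.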
### Prediction
The forest structure determines the weights for a weighted kernel
density estimate of the response. The `predict` function evaluates
this density on the provided grid, using the kernel bandwidth given
by `bandwidth`.
```{r}
bandwidth <- 0.2
n_grid <- 100
z_grid <- seq(0, 10, length.out = n_grid)
x_test <- generate_data(1)$x
density <- predict(forest, x_test, "CDE", z_grid, bandwidth = bandwidth)
```
```{r}
# Green: RFCDE estimate. Black: true conditional density of z given x_test.
plot(z_grid, density, type = "l", col = "green")
lines(z_grid, dnorm(z_grid, sum(x_test[1:10]), 1))
```
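To make the weighted kernel density estimate used in prediction more
concrete, here is a minimal base R sketch under stated assumptions: a
Gaussian kernel and hand-picked weights standing in for the weights a
forest would produce. The helper `weighted_kde` and the toy values are
hypothetical and are not how RFCDE is implemented internally.

```{r}
# Illustrative sketch only: weighted_kde() is a hypothetical helper, not RFCDE code.
# Evaluates sum_i w_i * K_h(z - z_i) on a grid, using a Gaussian kernel K_h.
weighted_kde <- function(z_grid, z_obs, weights, bandwidth) {
  weights <- weights / sum(weights)  # normalize so the estimate integrates to one
  kernel_values <- outer(
    z_grid, z_obs,
    function(z, z_i) dnorm(z, mean = z_i, sd = bandwidth)
  )
  as.vector(kernel_values %*% weights)
}

# Toy responses and made-up weights (assumed values, for illustration only).
toy_z <- c(4.0, 4.5, 5.0, 6.0)
toy_weights <- c(1, 3, 3, 1)
toy_kde_grid <- seq(0, 10, length.out = 1001)
toy_kde <- weighted_kde(toy_kde_grid, toy_z, toy_weights, bandwidth = 0.2)

# Riemann-sum check that the estimate integrates to approximately one.
toy_kde_area <- sum(toy_kde) * (toy_kde_grid[2] - toy_kde_grid[1])
```

Observations with larger weights contribute taller bumps, so the
estimate concentrates near the responses the weights favor; a larger
bandwidth smooths those bumps together.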