Merge pull request #100 from richford/enh/bagging
ENH: Allow forestci to work on general Bagging estimators
arokem committed Dec 17, 2020
2 parents 5d88ca2 + 52947e7 commit 7f36ef7
Showing 4 changed files with 229 additions and 97 deletions.
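In practical terms, this change means `fci.random_forest_error` accepts any fitted scikit-learn Bagging estimator, not only random forests. A minimal sketch of the usage on synthetic data (the variable names, numbers, and base estimator below are illustrative and not part of this commit; it mirrors the new plot_mpg_svr.py example further down):

import numpy as np
from sklearn.ensemble import BaggingRegressor
from sklearn.tree import DecisionTreeRegressor
import forestci as fci

# toy regression data; shapes and coefficients are arbitrary
rng = np.random.RandomState(0)
X_train = rng.normal(size=(200, 3))
y_train = X_train @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=200)
X_test = rng.normal(size=(50, 3))

# any Bagging estimator with bootstrap=True (the default) is supported,
# not just RandomForestRegressor
bagger = BaggingRegressor(base_estimator=DecisionTreeRegressor(),
                          n_estimators=500, random_state=0)
bagger.fit(X_train, y_train)

# infinitesimal-jackknife variance of each test-set prediction;
# np.sqrt gives one-sigma error bars
V_IJ = fci.random_forest_error(bagger, X_train, X_test)
error_bars = np.sqrt(V_IJ)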
10 changes: 6 additions & 4 deletions examples/plot_mpg.py
@@ -19,7 +19,7 @@
import forestci as fci

# retrieve mpg data from machine learning library
mpg_data = fetch_openml('autompg')
mpg_data = fetch_openml(data_id=196)

# separate mpg data into predictors and outcome variable
mpg_X = mpg_data["data"]
@@ -32,9 +32,11 @@
mpg_y = mpg_y[not_null_sel]

# split mpg data into training and test set
mpg_X_train, mpg_X_test, mpg_y_train, mpg_y_test = xval.train_test_split(mpg_X, mpg_y,
test_size=0.25,
random_state=42)
mpg_X_train, mpg_X_test, mpg_y_train, mpg_y_test = xval.train_test_split(
mpg_X,
mpg_y,
test_size=0.25,
random_state=42)

# Create RandomForestRegressor
n_trees = 2000
62 changes: 62 additions & 0 deletions examples/plot_mpg_svr.py
@@ -0,0 +1,62 @@
"""
======================================
Plotting Bagging Regression Error Bars
======================================
This example demonstrates using `forestci` to calculate the error bars of
the predictions of a :class:`sklearn.ensemble.BaggingRegressor` object.
The data used here are a classic machine-learning dataset describing
various features of different cars and their MPG.
"""

# Bagging Regression Example
import numpy as np
from matplotlib import pyplot as plt
from sklearn.ensemble import BaggingRegressor
from sklearn.svm import SVR
import sklearn.model_selection as xval
from sklearn.datasets import fetch_openml
import forestci as fci

# retrieve mpg data from machine learning library
mpg_data = fetch_openml(data_id=196)

# separate mpg data into predictors and outcome variable
mpg_X = mpg_data["data"]
mpg_y = mpg_data["target"]

# remove rows where the data is nan
not_null_sel = np.invert(np.sum(np.isnan(mpg_data["data"]), axis=1).astype(bool))
mpg_X = mpg_X[not_null_sel]
mpg_y = mpg_y[not_null_sel]

# split mpg data into training and test set
mpg_X_train, mpg_X_test, mpg_y_train, mpg_y_test = xval.train_test_split(
mpg_X, mpg_y, test_size=0.25, random_state=42
)

# Create BaggingRegressor with an SVR base estimator
n_estimators = 1000
mpg_bagger = BaggingRegressor(
base_estimator=SVR(), n_estimators=n_estimators, random_state=42
)
mpg_bagger.fit(mpg_X_train, mpg_y_train)
mpg_y_hat = mpg_bagger.predict(mpg_X_test)

# Plot predicted MPG without error bars
plt.scatter(mpg_y_test, mpg_y_hat)
plt.plot([5, 45], [5, 45], "k--")
plt.xlabel("Reported MPG")
plt.ylabel("Predicted MPG")
plt.show()

# Calculate the variance
mpg_V_IJ_unbiased = fci.random_forest_error(mpg_bagger, mpg_X_train, mpg_X_test)

# Plot error bars for predicted MPG using unbiased variance
plt.errorbar(mpg_y_test, mpg_y_hat, yerr=np.sqrt(mpg_V_IJ_unbiased), fmt="o")
plt.plot([5, 45], [5, 45], "k--")
plt.xlabel("Reported MPG")
plt.ylabel("Predicted MPG")
plt.show()
138 changes: 91 additions & 47 deletions forestci/forestci.py
@@ -7,14 +7,19 @@

import numpy as np
import copy

from sklearn.ensemble._forest import BaseForest
from sklearn.ensemble._forest import _generate_sample_indices, _get_n_samples_bootstrap
from sklearn.ensemble._bagging import BaseBagging

from .calibration import calibrateEB
from sklearn.ensemble.forest import _generate_sample_indices, _get_n_samples_bootstrap
from .due import _due, _BibTeX

__all__ = ("calc_inbag", "random_forest_error", "_bias_correction",
"_core_computation")
__all__ = ("calc_inbag", "random_forest_error", "_bias_correction", "_core_computation")

_due.cite(_BibTeX("""
_due.cite(
_BibTeX(
"""
@ARTICLE{Wager2014-wn,
title = "Confidence Intervals for Random Forests: The Jackknife and the Infinitesimal Jackknife",
author = "Wager, Stefan and Hastie, Trevor and Efron, Bradley",
@@ -23,10 +28,14 @@
number = 1,
pages = "1625--1651",
month = jan,
year = 2014,}"""),
description=("Confidence Intervals for Random Forests:",
"The Jackknife and the Infinitesimal Jackknife"),
path='forestci')
year = 2014,}"""
),
description=(
"Confidence Intervals for Random Forests:",
"The Jackknife and the Infinitesimal Jackknife",
),
path="forestci",
)


def calc_inbag(n_samples, forest):
@@ -52,28 +61,42 @@ def calc_inbag(n_samples, forest):
"""

if not forest.bootstrap:
e_s = "Cannot calculate the inbag from a forest that has "
e_s = " bootstrap=False"
e_s = "Cannot calculate the inbag from a forest that has bootstrap=False"
raise ValueError(e_s)

n_trees = forest.n_estimators
inbag = np.zeros((n_samples, n_trees))
sample_idx = []
n_samples_bootstrap = _get_n_samples_bootstrap(
n_samples, forest.max_samples
)
if isinstance(forest, BaseForest):
n_samples_bootstrap = _get_n_samples_bootstrap(n_samples, forest.max_samples)

for t_idx in range(n_trees):
sample_idx.append(
_generate_sample_indices(
forest.estimators_[t_idx].random_state,
n_samples,
n_samples_bootstrap,
)
)
inbag[:, t_idx] = np.bincount(sample_idx[-1], minlength=n_samples)
elif isinstance(forest, BaseBagging):
for t_idx, estimator_sample in enumerate(forest.estimators_samples_):
sample_idx.append(estimator_sample)
inbag[:, t_idx] = np.bincount(sample_idx[-1], minlength=n_samples)

for t_idx in range(n_trees):
sample_idx.append(
_generate_sample_indices(forest.estimators_[t_idx].random_state,
n_samples, n_samples_bootstrap))
inbag[:, t_idx] = np.bincount(sample_idx[-1], minlength=n_samples)
return inbag
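For context on the new BaseBagging branch above: `estimators_samples_` stores, for each base estimator, the indices of the training rows drawn for its bootstrap sample, and `np.bincount` turns those indices into per-sample inclusion counts. A tiny hand-worked sketch (toy numbers, not taken from the commit):

import numpy as np

# stand-ins for forest.estimators_samples_ from a fitted Bagging estimator
estimators_samples_ = [np.array([0, 0, 3, 4, 1]),   # indices drawn by estimator 0
                       np.array([2, 2, 2, 0, 4])]   # indices drawn by estimator 1
n_samples = 5

inbag = np.zeros((n_samples, len(estimators_samples_)))
for t_idx, sample_idx in enumerate(estimators_samples_):
    inbag[:, t_idx] = np.bincount(sample_idx, minlength=n_samples)

# inbag[:, 0] -> [2., 1., 0., 1., 1.]  (row 0 drawn twice, row 2 never, ...)
# inbag[:, 1] -> [1., 0., 3., 0., 1.]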


def _core_computation(X_train, X_test, inbag, pred_centered, n_trees,
memory_constrained=False, memory_limit=None,
test_mode=False):
def _core_computation(
X_train,
X_test,
inbag,
pred_centered,
n_trees,
memory_constrained=False,
memory_limit=None,
test_mode=False,
):
"""
Helper function that performs the core computation
@@ -112,27 +135,32 @@ def _core_computation(X_train, X_test, inbag, pred_centered, n_trees,
return np.sum((np.dot(inbag - 1, pred_centered.T) / n_trees) ** 2, 0)

if not memory_limit:
raise ValueError('If memory_constrained=True, must provide',
'memory_limit.')
raise ValueError("If memory_constrained=True, must provide", "memory_limit.")

# Assumes double precision float
chunk_size = int((memory_limit * 1e6) / (8.0 * X_train.shape[0]))

if chunk_size == 0:
min_limit = 8.0 * X_train.shape[0] / 1e6
raise ValueError('memory_limit provided is too small.' +
'For these dimensions, memory_limit must ' +
'be greater than or equal to %.3e' % min_limit)
raise ValueError(
"memory_limit provided is too small."
+ "For these dimensions, memory_limit must "
+ "be greater than or equal to %.3e" % min_limit
)

chunk_edges = np.arange(0, X_test.shape[0] + chunk_size, chunk_size)
inds = range(X_test.shape[0])
chunks = [inds[chunk_edges[i]:chunk_edges[i+1]]
for i in range(len(chunk_edges)-1)]
chunks = [
inds[chunk_edges[i] : chunk_edges[i + 1]] for i in range(len(chunk_edges) - 1)
]
if test_mode:
print('Number of chunks: %d' % (len(chunks),))
V_IJ = np.concatenate([
np.sum((np.dot(inbag-1, pred_centered[chunk].T)/n_trees)**2, 0)
for chunk in chunks])
print("Number of chunks: %d" % (len(chunks),))
V_IJ = np.concatenate(
[
np.sum((np.dot(inbag - 1, pred_centered[chunk].T) / n_trees) ** 2, 0)
for chunk in chunks
]
)
return V_IJ
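A quick check of the chunking arithmetic above, with illustrative numbers that are not from the commit: memory_limit is interpreted in MB, and each chunk's intermediate np.dot result is an (n_train x chunk_size) float64 array.

memory_limit = 100                                        # MB
n_train = 400                                             # X_train.shape[0]
chunk_size = int((memory_limit * 1e6) / (8.0 * n_train))  # -> 31250 test rows per chunk
# 8 bytes * 400 rows * 31250 columns = 100 MB, matching the requested budget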


@@ -160,17 +188,25 @@ def _bias_correction(V_IJ, inbag, pred_centered, n_trees):
The number of trees in the forest object.
"""
n_train_samples = inbag.shape[0]
n_var = np.mean(np.square(inbag[0:n_trees]).mean(axis=1).T.view() -
np.square(inbag[0:n_trees].mean(axis=1)).T.view())
n_var = np.mean(
np.square(inbag[0:n_trees]).mean(axis=1).T.view()
- np.square(inbag[0:n_trees].mean(axis=1)).T.view()
)
boot_var = np.square(pred_centered).sum(axis=1) / n_trees
bias_correction = n_train_samples * n_var * boot_var / n_trees
V_IJ_unbiased = V_IJ - bias_correction
return V_IJ_unbiased
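In words, this is the Monte Carlo bias adjustment of Wager et al. (2014): the raw infinitesimal-jackknife estimate is reduced by roughly n_train_samples * n_var * boot_var / n_trees, where boot_var is the across-tree variance of the centered predictions and n_var is the average variance of the inbag counts (close to 1 for a standard bootstrap), so the subtracted term shrinks as the number of trees grows.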


def random_forest_error(forest, X_train, X_test, inbag=None,
calibrate=True, memory_constrained=False,
memory_limit=None):
def random_forest_error(
forest,
X_train,
X_test,
inbag=None,
calibrate=True,
memory_constrained=False,
memory_limit=None,
):
"""
Calculate error bars from scikit-learn RandomForest or Bagging estimators.
@@ -239,8 +275,9 @@ def random_forest_error(forest, X_train, X_test, inbag=None,
pred_mean = np.mean(pred, 0)
pred_centered = pred - pred_mean
n_trees = forest.n_estimators
V_IJ = _core_computation(X_train, X_test, inbag, pred_centered, n_trees,
memory_constrained, memory_limit)
V_IJ = _core_computation(
X_train, X_test, inbag, pred_centered, n_trees, memory_constrained, memory_limit
)
V_IJ_unbiased = _bias_correction(V_IJ, inbag, pred_centered, n_trees)

# Correct for cases where resampling is done without replacement:
@@ -259,19 +296,26 @@
calibration_ratio = 2
n_sample = np.ceil(n_trees / calibration_ratio)
new_forest = copy.deepcopy(forest)
new_forest.estimators_ =\
np.random.permutation(new_forest.estimators_)[:int(n_sample)]
random_idx = np.random.permutation(len(new_forest.estimators_))[: int(n_sample)]
new_forest.estimators_ = list(np.array(new_forest.estimators_)[random_idx])
if hasattr(new_forest, "_seeds"):
new_forest._seeds = new_forest._seeds[random_idx]

new_forest.n_estimators = int(n_sample)

results_ss = random_forest_error(new_forest, X_train, X_test,
calibrate=False,
memory_constrained=memory_constrained,
memory_limit=memory_limit)
results_ss = random_forest_error(
new_forest,
X_train,
X_test,
calibrate=False,
memory_constrained=memory_constrained,
memory_limit=memory_limit,
)
# Use this second set of variance estimates
# to estimate scale of Monte Carlo noise
sigma2_ss = np.mean((results_ss - V_IJ_unbiased)**2)
sigma2_ss = np.mean((results_ss - V_IJ_unbiased) ** 2)
delta = n_sample / n_trees
sigma2 = (delta**2 + (1 - delta)**2) / (2 * (1 - delta)**2) * sigma2_ss
sigma2 = (delta ** 2 + (1 - delta) ** 2) / (2 * (1 - delta) ** 2) * sigma2_ss

# Use Monte Carlo noise scale estimate for empirical Bayes calibration
V_IJ_calibrated = calibrateEB(V_IJ_unbiased, sigma2)
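One way to read the scaling above: with the default calibration_ratio of 2, delta is approximately 0.5, so the prefactor (delta**2 + (1 - delta)**2) / (2 * (1 - delta)**2) is approximately 1 and sigma2 is essentially the mean squared disagreement between the half-forest and full-forest variance estimates, which calibrateEB then uses as the Monte Carlo noise scale for empirical-Bayes calibration.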
