[MRG] Add clone_kernel option to make gaussian process models faster #14378
What does this implement/fix? Explain your changes.
I want to make GaussianProcessRegressor faster. I profiled
This result shows
Basically, I think we don't need to clone kernels. It seems to be enough that we just replace
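As a rough illustration of the cost being discussed, the sketch below compares cloning an object on every evaluation with overwriting its parameter in place. `ToyKernel`, `clone_with_theta` as written here, and `evaluate` are hypothetical stand-ins invented for this example; they are not scikit-learn code and make no claim about how the patch is implemented.

```python
import copy


class ToyKernel:
    """Hypothetical stand-in for a kernel object (not a scikit-learn class)."""

    clones_made = 0  # class-level counter of deep copies performed

    def __init__(self, theta):
        self.theta = theta

    def clone_with_theta(self, theta):
        # Deep-copy the kernel, then set the new hyperparameter on the copy.
        ToyKernel.clones_made += 1
        new_kernel = copy.deepcopy(self)
        new_kernel.theta = theta
        return new_kernel

    def value(self, x):
        # Trivial "kernel evaluation" so the two strategies can be compared.
        return self.theta * x


def evaluate(kernel, thetas, clone):
    """Evaluate the toy kernel at each theta, cloning or mutating in place."""
    results = []
    for theta in thetas:
        if clone:
            k = kernel.clone_with_theta(theta)
        else:
            kernel.theta = theta  # overwrite the hyperparameter in place
            k = kernel
        results.append(k.value(2.0))
    return results


thetas = [0.5, 1.0, 2.0]

ToyKernel.clones_made = 0
cloned = evaluate(ToyKernel(1.0), thetas, clone=True)
clones_with_cloning = ToyKernel.clones_made

ToyKernel.clones_made = 0
in_place = evaluate(ToyKernel(1.0), thetas, clone=False)
clones_in_place = ToyKernel.clones_made
```

Both strategies return the same values in this toy setting; the in-place variant simply skips one deep copy per evaluation, at the price of leaving the original object modified.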
Any other comments?
To confirm that this PR doesn't break the logic, I generated a graph of the model's prediction results (I might need to design more suitable kernel functions).
Thanks for reviewing! I'll apply this change to
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.gaussian_process import kernels as sk_kern
from sklearn.gaussian_process import GaussianProcessRegressor


def objective(x):
    return x + 20 * np.sin(x)


def main():
    kernel = sk_kern.RBF(1.0, (1e-3, 1e3)) + sk_kern.ConstantKernel(1.0, (1e-3, 1e3))
    clf = GaussianProcessRegressor(
        kernel=kernel,
        alpha=1e-10,
        optimizer="fmin_l_bfgs_b",
        n_restarts_optimizer=20,
        normalize_y=True)

    np.random.seed(0)
    x_train = np.random.uniform(-20, 20, 200)
    y_train = objective(x_train) + np.random.normal(loc=0, scale=.1, size=x_train.shape)

    clf.fit(x_train.reshape(-1, 1), y_train)

    x_test = np.linspace(-20., 20., 200).reshape(-1, 1)
    pred_mean, pred_std = clf.predict(x_test, return_std=True)

    # save pred_mean and pred_std
    #np.save("pred_mean", pred_mean)
    #np.save("pred_std", pred_std)

    # assertion
    orig_mean = np.load("pred_mean.npy")
    orig_std = np.load("pred_std.npy")
    np.testing.assert_allclose(orig_mean, pred_mean)
    np.testing.assert_allclose(orig_std, pred_std)


if __name__ == '__main__':
    main()
```
I added the patch for GaussianProcessClassifier.
The benchmark and its results are below. In this benchmark, which fits the iris dataset, GaussianProcessClassifier appears to be 25.2% faster.
I also confirmed
This PR changes the output of the third line here:
```python
gpr.fit(X, y)
gpr.log_marginal_likelihood(theta=another_value)
gpr.log_marginal_likelihood()
```
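To make the concern about the three-call sequence above concrete, here is a minimal toy sketch. `ToyModel`, its `score` method, and the `clone_kernel` flag as used here are hypothetical and written only for illustration; the sketch assumes the no-argument call recomputes its result from the current kernel state. Whether the real method is affected depends on that assumption, since a method that returns a value stored at fit time would not change.

```python
import copy


class ToyModel:
    """Hypothetical model holding a fitted kernel state (not scikit-learn)."""

    def __init__(self, theta):
        # Stands in for the fitted kernel and its optimized hyperparameter.
        self.kernel_ = {"theta": theta}

    def score(self, theta=None, clone_kernel=True):
        if theta is None:
            # Assumption: recompute from whatever the kernel currently holds.
            kernel = self.kernel_
        elif clone_kernel:
            # Work on a copy, leaving the fitted kernel untouched.
            kernel = copy.deepcopy(self.kernel_)
            kernel["theta"] = theta
        else:
            # Overwrite the fitted kernel's hyperparameter in place.
            kernel = self.kernel_
            kernel["theta"] = theta
        # Toy objective with its maximum at theta == 3.0.
        return -(kernel["theta"] - 3.0) ** 2


# Cloning: the call with an explicit theta has no lasting effect.
safe = ToyModel(theta=1.0)
safe_before = safe.score()
safe.score(theta=2.0, clone_kernel=True)
safe_after = safe.score()

# In place: the explicit theta leaks into the later no-argument call.
unsafe = ToyModel(theta=1.0)
unsafe_before = unsafe.score()
unsafe.score(theta=2.0, clone_kernel=False)
unsafe_after = unsafe.score()
```

In this toy, the third call returns the same value as the first only when the kernel is cloned. An opt-in flag keeps the existing behavior as the default while letting callers who do not reuse the kernel state skip the copy.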
You could maybe add a parameter like
Thanks for reviewing, @adrinjalali!
Could you describe a case in which the output of the third line would change? If there is one, I'll push the following change to this branch.