Skip to content

Commit

Permalink
Merge pull request #51 from xzackli/ana_cov_comp_gp
Browse files Browse the repository at this point in the history
GP smoothing of analytic covmat (refactored)
  • Loading branch information
zatkins2 committed Apr 23, 2024
2 parents 4215038 + 3470329 commit 74081ae
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
49 changes: 49 additions & 0 deletions pspipe_utils/covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from scipy.optimize import curve_fit
import pylab as plt

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel

from pixell import utils

from itertools import combinations_with_replacement as cwr
Expand Down Expand Up @@ -342,7 +345,34 @@ def skew(cov, dir=1):
return S, corrected_cov
else:
return corrected_cov


def smooth_gp_diag(lb, arr_diag, ell_cut, length_scale=500.0,
                   length_scale_bounds=(100, 1e4), noise_level=0.01,
                   noise_level_bounds=(1e-6, 1e1), n_restarts_optimizer=20):
    """Smooth a 1D array with a Gaussian process plus a low-end exponential.

    A Gaussian process (scaled RBF kernel + white-noise kernel) is fit to the
    entries of ``arr_diag`` located from the first ``lb > ell_cut`` onwards
    and evaluated at every ``lb``. For the entries before that point, an
    exponential ``exp(a * lb + b)`` is fit to the absolute residual between
    ``arr_diag`` and the GP prediction (via a straight-line fit to its log)
    and added on top of the GP prediction.

    Parameters
    ----------
    lb : array-like, 1D
        Abscissa of ``arr_diag``. The cut is located with the first entry
        satisfying ``lb > ell_cut``, so this assumes ``lb`` is sorted in
        increasing order (presumably bin centres -- confirm with callers).
    arr_diag : array-like, 1D
        Values to smooth, same length as ``lb``.
    ell_cut : float
        Entries from the first ``lb > ell_cut`` onwards train the GP; the
        entries before it receive the exponential correction.
    length_scale, length_scale_bounds
        Forwarded to the ``RBF`` kernel.
    noise_level, noise_level_bounds
        Forwarded to the ``WhiteKernel``.
    n_restarts_optimizer : int
        Forwarded to ``GaussianProcessRegressor``.

    Returns
    -------
    numpy.ndarray
        Smoothed values evaluated at every ``lb``.

    Raises
    ------
    ValueError
        If no entry of ``lb`` is greater than ``ell_cut``, i.e. there is
        nothing to train the GP on.
    """
    lb = np.asarray(lb)
    arr_diag = np.asarray(arr_diag)

    # np.argmax returns 0 for an all-False mask, which is indistinguishable
    # from "the very first bin is above the cut"; check explicitly so that we
    # fail early and clearly instead of deep inside the fitting routines.
    above_cut = lb > ell_cut
    if not above_cut.any():
        raise ValueError(
            f"no entry of lb is greater than ell_cut={ell_cut}; "
            "cannot fit the Gaussian process"
        )
    i_cut = int(np.argmax(above_cut))

    kernel = 1.0 * RBF(length_scale=length_scale,
                       length_scale_bounds=length_scale_bounds) + WhiteKernel(
        noise_level=noise_level, noise_level_bounds=noise_level_bounds
    )

    # fit the GP on the bins above the ell_cut, then predict on all bins
    gpr = GaussianProcessRegressor(kernel=kernel, alpha=0.0, normalize_y=True,
                                   n_restarts_optimizer=n_restarts_optimizer)
    gpr.fit(lb[i_cut:, np.newaxis], arr_diag[i_cut:])
    y_smooth = gpr.predict(lb[:, np.newaxis], return_std=False)

    # every bin is above the cut: there is no low end to correct (and
    # np.polyfit would raise on the empty selection)
    if i_cut == 0:
        return y_smooth

    # fit an exponential to the low-end residual: a straight line in log space
    # NOTE(review): a residual that is exactly zero gives log(0) = -inf here
    lb_low = lb[:i_cut]
    abs_resid_low = np.abs(arr_diag - y_smooth)[:i_cut]
    log_fit = np.poly1d(np.polyfit(lb_low, np.log(abs_resid_low), 1))
    y_smooth[:i_cut] += np.exp(log_fit(lb_low))
    return y_smooth


def _correct_analytical_cov_keep_res_diag(an_full_cov, mc_full_cov, return_diag=False):
sqrt_an_full_cov = utils.eigpow(an_full_cov, 0.5)
Expand All @@ -369,6 +399,25 @@ def correct_analytical_cov_keep_res_diag(an_full_cov, mc_full_cov, return_diag=F
else:
return corrected_cov

def correct_analytical_cov_keep_res_diag_gp(an_full_cov, mc_full_cov, lb, ell_cut, return_diag=False):
    """Correct an analytic covariance using a GP-smoothed residual diagonal.

    The analytic covariance is eigen-decomposed to build a factor ``A`` with
    ``A @ A.T == an_full_cov`` (for a symmetric positive-definite input) and
    its inverse. The Monte Carlo covariance is transformed with the inverse
    factor; the diagonal of the result is split into
    ``len(diag) // len(lb)`` chunks, each chunk is smoothed with
    ``smooth_gp_diag(lb, chunk, ell_cut)``, and the smoothed diagonal is
    transformed back with ``A``.

    Returns the corrected covariance, or ``(corrected covariance, raw
    unsmoothed residual diagonal)`` when ``return_diag`` is true.
    """
    eigvals, eigvecs = np.linalg.eigh(an_full_cov)

    # A = O diag(d^1/2) and A^-1 = diag(d^-1/2) O^T
    sqrt_factor = eigvecs @ np.diag(eigvals ** 0.5)
    inv_sqrt_factor = np.diag(eigvals ** -0.5) @ eigvecs.T

    # close to the identity when the analytic covariance is a good
    # description of the Monte Carlo one
    whitened_mc = inv_sqrt_factor @ mc_full_cov @ inv_sqrt_factor.T
    raw_diag = np.diag(whitened_mc)

    # smooth each chunk of the residual diagonal independently
    n_chunks = len(raw_diag) // len(lb)
    smoothed_diag = np.hstack(
        [smooth_gp_diag(lb, chunk, ell_cut)
         for chunk in np.array_split(raw_diag, n_chunks)]
    )

    corrected = sqrt_factor @ np.diag(smoothed_diag) @ sqrt_factor.T

    return (corrected, raw_diag) if return_diag else corrected


def canonize_connected_2pt(leg1, leg2, all_legs):
"""A connected 2-point term has two legs but is invariant to their
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
python_requires=">=3.9",
install_requires=[
"pspy>=1.5.3",
"scikit-learn>=1",
"mflike>=0.9.5",
],
package_data={"": ["data/**"]},
Expand Down

0 comments on commit 74081ae

Please sign in to comment.