Skip to content

Commit

Permalink
remove extraneous parameter and smoooth diag
Browse files Browse the repository at this point in the history
  • Loading branch information
xzackli committed Apr 2, 2024
1 parent b3fa51f commit 3470329
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions pspipe_utils/covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def skew(cov, dir=1):

def smooth_gp_diag(lb, arr_diag, ell_cut, length_scale=500.0,
length_scale_bounds=(100, 1e4), noise_level=0.01,
noise_level_bounds=(1e-6, 1e1), low_ell_scale=100, n_restarts_optimizer=20):
noise_level_bounds=(1e-6, 1e1), n_restarts_optimizer=20):

kernel = 1.0 * RBF(length_scale=length_scale,
length_scale_bounds=length_scale_bounds) + WhiteKernel(
Expand All @@ -367,9 +367,7 @@ def smooth_gp_diag(lb, arr_diag, ell_cut, length_scale=500.0,
# fit an exponential at the low end
i_cut = np.argmax(lb > ell_cut)
X_train = lb[:i_cut]
y_train = (arr_diag - y_mean_high)[:i_cut]
pos_el = y_train > 0
X_train, y_train = X_train[pos_el], y_train[pos_el]
y_train = np.abs(arr_diag - y_mean_high)[:i_cut]
z = np.polyfit(X_train, np.log(y_train), 1)
f = np.poly1d(z)
y_mean_high[:i_cut] += np.exp(f(lb[:i_cut]))
Expand Down Expand Up @@ -401,6 +399,25 @@ def correct_analytical_cov_keep_res_diag(an_full_cov, mc_full_cov, return_diag=F
else:
return corrected_cov

def correct_analytical_cov_keep_res_diag_gp(an_full_cov, mc_full_cov, lb, ell_cut, return_diag=False):
d_an, O_an = np.linalg.eigh(an_full_cov)
sqrt_an_full_cov = O_an @ np.diag(d_an**.5)
inv_sqrt_an_full_cov = np.diag(d_an**-.5) @ O_an.T
res = inv_sqrt_an_full_cov @ mc_full_cov @ inv_sqrt_an_full_cov.T # res should be close to the identity if an_full_cov is good
res_diag = np.diag(res)

n_spec = len(res_diag) // len(lb)
diags = np.array_split(res_diag, n_spec)
smoothed_diags = [smooth_gp_diag(lb, r, ell_cut) for r in diags]
smooth_res = np.hstack(smoothed_diags)

corrected_cov = sqrt_an_full_cov @ np.diag(smooth_res) @ sqrt_an_full_cov.T

if return_diag:
return corrected_cov, res_diag
else:
return corrected_cov


def canonize_connected_2pt(leg1, leg2, all_legs):
"""A connected 2-point term has two legs but is invariant to their
Expand Down

0 comments on commit 3470329

Please sign in to comment.