Skip to content

Commit

Permalink
covariance updates for mc_cov_analysis updates (full_cov_to_cov_dict …
Browse files Browse the repository at this point in the history
…handles remove_doublon)
  • Loading branch information
zatkins2 committed Apr 26, 2024
1 parent 74081ae commit 1923057
Showing 1 changed file with 190 additions and 89 deletions.
279 changes: 190 additions & 89 deletions pspipe_utils/covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,40 @@ def read_cov_block_and_build_dict(spec_name_list,

return cov_dict


def spec_dict_to_full_vec(spec_dict,
                          spec_name_list,
                          spectra_order=["TT", "TE", "ET", "EE"],
                          remove_doublon=False):
    """
    Concatenate the binned spectra of a spectra dict into a single vector.

    The vector is laid out spectrum-major: first every cross spectrum of
    spec_name_list for spectra_order[0], then every cross spectrum for
    spectra_order[1], and so on. Each (name, spec) pair occupies n_bins
    consecutive entries.

    Parameters
    ----------
    spec_dict: dict
      the spectra, indexed as spec_dict[name, spec]; the number of bins is
      read from the first entry of the dict
    spec_name_list: list of str
      list of the cross spectra, each name of the form "{a}x{b}"
    spectra_order: list of str
      the order of the spectra e.g ["TT", "TE", "ET", "EE"] (never modified)
    remove_doublon: bool
      if True, drop the "ET", "BT" and "BE" blocks of the auto spectra
      (names with a == b)

    Returns
    -------
    full_vec: 1d array
      the concatenated vector, shortened by n_bins for each removed block
    """
    n_cross = len(spec_name_list)
    n_spec = len(spectra_order)
    # every spectrum is assumed to share the binning of the first dict entry
    n_bins = int(spec_dict[list(spec_dict)[0]].shape[0])

    full_vec = np.zeros(n_cross * n_spec * n_bins)
    for sid, name in enumerate(spec_name_list):
        for s, spec in enumerate(spectra_order):
            start = (s * n_cross + sid) * n_bins
            full_vec[start:start + n_bins] = spec_dict[name, spec]

    if remove_doublon:
        # A boolean mask also covers the case where nothing has to be removed
        # (no auto spectrum in spec_name_list), which used to crash when the
        # still-empty python list was cast with .astype.
        keep = np.ones(full_vec.size, dtype=bool)
        for sid, name in enumerate(spec_name_list):
            na, nb = name.split("x")
            if na != nb:
                continue
            for s, spec in enumerate(spectra_order):
                if spec in ("ET", "BT", "BE"):
                    start = (s * n_cross + sid) * n_bins
                    keep[start:start + n_bins] = False
        full_vec = full_vec[keep]

    return full_vec


def cov_dict_to_full_cov(cov_dict,
spec_name_list,
spectra_order=["TT", "TE", "ET", "EE"],
Expand Down Expand Up @@ -112,7 +146,7 @@ def cov_dict_to_full_cov(cov_dict,
id_stop_1 = (sid1 + 1) * n_bins + s1 * n_cross * n_bins
id_start_2 = sid2 * n_bins + s2 * n_cross * n_bins
id_stop_2 = (sid2 + 1) * n_bins + s2 * n_cross * n_bins
full_cov[id_start_1:id_stop_1, id_start_2: id_stop_2] = cov_dict[name1, name2, spec1, spec2]
full_cov[id_start_1:id_stop_1, id_start_2:id_stop_2] = cov_dict[name1, name2, spec1, spec2]
transpose = full_cov.copy().T
transpose[full_cov != 0] = 0
full_cov += transpose
Expand All @@ -138,6 +172,76 @@ def cov_dict_to_full_cov(cov_dict,
return full_cov


def full_cov_to_cov_dict(full_cov,
                         spec_name_list,
                         spectra_order=["TT", "TE", "ET", "EE"],
                         remove_doublon=False):
    """
    Decompose the full covariance into a covariance dict.

    The full covariance is expected to be laid out spectrum-major (all cross
    spectra for spectra_order[0], then all for spectra_order[1], ...), with
    n_bins rows/columns per (name, spec) block. n_bins is inferred from the
    shape of full_cov.

    Parameters
    ----------
    full_cov: 2d array
      the full covariance to decompose
    spec_name_list: list of str
      list of the cross spectra, each name of the form "{a}x{b}"
    spectra_order: list of str
      the order of the spectra e.g ["TT", "TE", "ET", "EE"] (never modified)
    remove_doublon: bool
      set to True if the "ET", "BT" and "BE" blocks of the auto spectra
      (a == b) are absent from full_cov. For those entries the dict is filled
      with the block of the reversed spectrum ("TE", "TB", "EB"), which must
      therefore be present in spectra_order.

    Returns
    -------
    cov_dict: dict
      blocks indexed as cov_dict[name1, name2, spec1, spec2], for name1 not
      after name2 in spec_name_list; the values are slices of full_cov

    Raises
    ------
    ValueError
      if the size of full_cov is not a multiple of the number of blocks
    """
    doublon_specs = ("ET", "BT", "BE")
    n_cross = len(spec_name_list)
    n_spec = len(spectra_order)

    # (sid, s) pairs whose block is absent from full_cov
    doublons = set()
    if remove_doublon:
        for sid, name in enumerate(spec_name_list):
            na, nb = name.split("x")
            if na != nb:
                continue
            for s, spec in enumerate(spectra_order):
                if spec in doublon_specs:
                    doublons.add((sid, s))

    # rank of every stored block inside full_cov, following the
    # spectrum-major layout and skipping the removed doublons
    block_rank = {}
    for s in range(n_spec):
        for sid in range(n_cross):
            if (sid, s) not in doublons:
                block_rank[sid, s] = len(block_rank)

    n_blocks = len(block_rank)
    if n_blocks == 0 or full_cov.shape[0] % n_blocks != 0:
        raise ValueError(
            f"full covariance of size {full_cov.shape[0]} cannot be split "
            f"into {n_blocks} blocks of equal size")
    n_bins = full_cov.shape[0] // n_blocks

    # slice of full_cov holding each (sid, s); a doublon points to the block
    # of its reversed spectrum
    block_slice = {}
    for sid in range(n_cross):
        for s, spec in enumerate(spectra_order):
            s_stored = s
            if (sid, s) in doublons:
                s_stored = spectra_order.index(spec[::-1])
            start = block_rank[sid, s_stored] * n_bins
            block_slice[sid, s] = slice(start, start + n_bins)

    cov_dict = {}
    for sid1, name1 in enumerate(spec_name_list):
        for sid2, name2 in enumerate(spec_name_list):
            if sid1 > sid2:
                continue
            for s1, spec1 in enumerate(spectra_order):
                rows = block_slice[sid1, s1]
                for s2, spec2 in enumerate(spectra_order):
                    cols = block_slice[sid2, s2]
                    cov_dict[name1, name2, spec1, spec2] = full_cov[rows, cols]

    return cov_dict


def read_cov_block_and_build_full_cov(spec_name_list,
cov_dir,
cov_type,
Expand Down Expand Up @@ -184,45 +288,45 @@ def read_cov_block_and_build_full_cov(spec_name_list,
return full_cov


def full_cov_to_cov_dict(full_cov,
                         spec_name_list,
                         n_bins,
                         spectra_order=["TT", "TE", "ET", "EE"]):
    """
    Split a full covariance matrix into a dict of (n_bins x n_bins) blocks.

    The full covariance must contain every block, i.e. it should NOT have
    been produced with remove_doublon=True.

    Parameters
    ----------
    full_cov: 2d array
      the full covariance to decompose
    spec_name_list: list of str
      list of the cross spectra
    n_bins: int
      the number of bins per spectra
    spectra_order: list of str
      the order of the spectra e.g ["TT", "TE", "ET", "EE"]

    Returns
    -------
    cov_dict: dict
      blocks indexed as cov_dict[name1, name2, spec1, spec2], only for
      name1 not after name2 in spec_name_list
    """
    n_cross = len(spec_name_list)
    n_spec = len(spectra_order)
    expected_size = n_cross * n_spec * n_bins
    assert full_cov.shape[0] == expected_size, "full covariance do not have the correct shape"

    def block(cross_id, spec_id):
        # blocks are ordered spectrum-major, then by cross spectrum
        first = cross_id * n_bins + spec_id * n_cross * n_bins
        return slice(first, first + n_bins)

    cov_dict = {}
    for i, name_i in enumerate(spec_name_list):
        for j, name_j in enumerate(spec_name_list):
            if j < i:
                continue
            for p, spec_p in enumerate(spectra_order):
                rows = block(i, p)
                for q, spec_q in enumerate(spectra_order):
                    cov_dict[name_i, name_j, spec_p, spec_q] = full_cov[rows, block(j, q)]

    return cov_dict
# def full_cov_to_cov_dict(full_cov,
# spec_name_list,
# n_bins,
# spectra_order=["TT", "TE", "ET", "EE"]):

# """
# Decompose the full covariance into a covariance dict, note that
# the full covariance should NOT have been produced with remove_doublon=True

# Parameters
# ----------
# full_cov: 2d array
# the full covariance to decompose
# spec_name_list: list of str
# list of the cross spectra
# n_bins: int
# the number of bins per spectra
# spectra_order: list of str
# the order of the spectra e.g ["TT", "TE", "ET", "EE"]
# """

# n_cross = len(spec_name_list)
# n_spec = len(spectra_order)
# assert full_cov.shape[0] == n_cross * n_spec * n_bins, "full covariance do not have the correct shape"


# cov_dict = {}
# for sid1, name1 in enumerate(spec_name_list):
# for sid2, name2 in enumerate(spec_name_list):
# if sid1 > sid2: continue
# for s1, spec1 in enumerate(spectra_order):
# for s2, spec2 in enumerate(spectra_order):
# id_start_1 = sid1 * n_bins + s1 * n_cross * n_bins
# id_stop_1 = (sid1 + 1) * n_bins + s1 * n_cross * n_bins
# id_start_2 = sid2 * n_bins + s2 * n_cross * n_bins
# id_stop_2 = (sid2 + 1) * n_bins + s2 * n_cross * n_bins
# cov_dict[name1, name2, spec1, spec2] = full_cov[id_start_1:id_stop_1, id_start_2: id_stop_2]

# return cov_dict

def cov_dict_to_file(cov_dict,
spec_name_list,
Expand Down Expand Up @@ -347,78 +451,75 @@ def skew(cov, dir=1):
return corrected_cov


def smooth_gp_diag(lb, arr_diag, ell_cut, length_scale=500.0,
                   length_scale_bounds=(100, 1e4), noise_level=0.01,
                   noise_level_bounds=(1e-6, 1e1), n_restarts_optimizer=20):
    """
    Smooth arr_diag as a function of lb with a gaussian process.

    A GP (RBF + white noise kernel) is fitted on the entries from the first
    lb greater than ell_cut onwards and evaluated on every lb. For the
    entries before that index, an exponential (a straight line fitted to the
    log of the absolute residual) is added to the GP prediction.

    Parameters
    ----------
    lb: 1d array
      the abscissa of the fit
    arr_diag: 1d array
      the values to smooth, same length as lb
    ell_cut: float
      threshold on lb separating the GP-fitted range from the low end
    length_scale, length_scale_bounds:
      passed to the RBF kernel
    noise_level, noise_level_bounds:
      passed to the WhiteKernel
    n_restarts_optimizer: int
      passed to the GaussianProcessRegressor

    Returns
    -------
    1d array, the smoothed values on every lb
    """
    rbf_part = 1.0 * RBF(length_scale=length_scale,
                         length_scale_bounds=length_scale_bounds)
    white_part = WhiteKernel(noise_level=noise_level,
                             noise_level_bounds=noise_level_bounds)
    gpr = GaussianProcessRegressor(kernel=rbf_part + white_part, alpha=0.0,
                                   normalize_y=True,
                                   n_restarts_optimizer=n_restarts_optimizer)

    # index of the first entry of lb above ell_cut (0 if there is none)
    i_cut = np.argmax(lb > ell_cut)

    # GP trained on the high end only, predicted everywhere
    gpr.fit(lb[i_cut:, np.newaxis], arr_diag[i_cut:])
    smoothed = gpr.predict(lb[:, np.newaxis], return_std=False)

    # exponential correction at the low end: linear fit of log|residual|
    low_lb = lb[:i_cut]
    low_abs_res = np.abs(arr_diag - smoothed)[:i_cut]
    log_fit = np.poly1d(np.polyfit(low_lb, np.log(low_abs_res), 1))
    smoothed[:i_cut] += np.exp(log_fit(low_lb))

    return smoothed


def _correct_analytical_cov_keep_res_diag(an_full_cov, mc_full_cov, return_diag=False):
def correct_analytical_cov_keep_res_diag(an_full_cov, mc_full_cov, return_diag=False):
    """
    Correct an analytic covariance using the diagonal of the whitened
    monte-carlo covariance.

    Parameters
    ----------
    an_full_cov: 2d array
      the analytic covariance
    mc_full_cov: 2d array
      the monte-carlo covariance, same shape as an_full_cov
    return_diag: bool
      if True, also return the diagonal used for the correction

    Returns
    -------
    corrected_cov, or (corrected_cov, res_diag) if return_diag is True
    """
    # presumably the matrix square root of an_full_cov -- confirm against utils.eigpow
    an_sqrt = utils.eigpow(an_full_cov, 0.5)
    an_inv_sqrt = np.linalg.inv(an_sqrt)

    # should be close to the identity if an_full_cov is good
    whitened_mc = an_inv_sqrt @ mc_full_cov @ an_inv_sqrt
    res_diag = np.diag(whitened_mc)

    # keep only the diagonal of the whitened matrix and map it back
    corrected_cov = an_sqrt @ np.diag(res_diag) @ an_sqrt

    if not return_diag:
        return corrected_cov
    return corrected_cov, res_diag

def correct_analytical_cov_keep_res_diag(an_full_cov, mc_full_cov, return_diag=False):
    """
    Correct an analytic covariance using the diagonal of the monte-carlo
    covariance expressed in the eigenbasis of the analytic covariance.

    With an_full_cov = O diag(d) O^T, the monte-carlo covariance is whitened
    with S^-1 = diag(d^-1/2) O^T, only the diagonal of the result is kept,
    and it is mapped back with S = O diag(d^1/2).

    Parameters
    ----------
    an_full_cov: 2d array
      the analytic covariance; its eigenvalues must be strictly positive for
      the result to be finite
    mc_full_cov: 2d array
      the monte-carlo covariance, same shape as an_full_cov
    return_diag: bool
      if True, also return the diagonal used for the correction

    Returns
    -------
    corrected_cov, or (corrected_cov, res_diag) if return_diag is True
    """
    d_an, O_an = np.linalg.eigh(an_full_cov)

    # Scale columns/rows by broadcasting instead of multiplying by dense
    # np.diag matrices: same result, without the O(n^3) products.
    sqrt_an_full_cov = O_an * d_an**.5                       # O @ diag(d^1/2)
    inv_sqrt_an_full_cov = d_an[:, np.newaxis]**-.5 * O_an.T  # diag(d^-1/2) @ O^T

    # res should be close to the identity if an_full_cov is good
    res = inv_sqrt_an_full_cov @ mc_full_cov @ inv_sqrt_an_full_cov.T
    res_diag = np.diag(res)

    corrected_cov = (sqrt_an_full_cov * res_diag) @ sqrt_an_full_cov.T

    if return_diag:
        return corrected_cov, res_diag
    return corrected_cov

def correct_analytical_cov_keep_res_diag_gp(an_full_cov, mc_full_cov, lb, ell_cuts=None,
                                            return_diag=False):
    """
    Correct an analytic covariance using a smoothed version of the diagonal
    of the whitened monte-carlo covariance.

    The diagonal is split into len(res_diag) // len(lb) equal chunks, each
    chunk is smoothed with smooth_gp_diag, and the smoothed diagonal is
    mapped back with the square root of the analytic covariance.

    Parameters
    ----------
    an_full_cov: 2d array
      the analytic covariance
    mc_full_cov: 2d array
      the monte-carlo covariance, same shape as an_full_cov
    lb: 1d array
      the abscissa handed to smooth_gp_diag for every chunk
    ell_cuts: list of float or None
      one ell_cut per chunk; None means 0 for every chunk
    return_diag: bool
      if True, also return the diagonal of the whitened covariance

    Returns
    -------
    corrected_cov, or (corrected_cov, res_diag) if return_diag is True
    """
    # presumably the matrix square root of an_full_cov -- confirm against utils.eigpow
    sqrt_an_full_cov = utils.eigpow(an_full_cov, 0.5)
    inv_sqrt_an_full_cov = np.linalg.inv(sqrt_an_full_cov)

    # res should be close to the identity if an_full_cov is good
    res = inv_sqrt_an_full_cov @ mc_full_cov @ inv_sqrt_an_full_cov.T
    res_diag = np.diag(res)

    n_spec = len(res_diag) // len(lb)
    if ell_cuts is None:
        ell_cuts = [0] * n_spec
    res_diags = np.split(res_diag, n_spec)

    # Smooth each chunk with its own ell_cut. The previous loop appended a
    # generator expression instead of the smoothed array.
    # NOTE(review): the smooth_gp_diag defined further down this file takes a
    # required `var_diag` argument that this call does not supply -- confirm
    # which signature is intended before relying on this function.
    smoothed_res_diags = [
        smooth_gp_diag(lb, chunk, ell_cut=ell_cuts[i])
        for i, chunk in enumerate(res_diags)
    ]
    smooth_res_diag = np.hstack(smoothed_res_diags)

    corrected_cov = sqrt_an_full_cov @ np.diag(smooth_res_diag) @ sqrt_an_full_cov.T

    if return_diag:
        # NOTE(review): this is the raw diagonal, not the smoothed one --
        # confirm that is what callers expect
        return corrected_cov, res_diag
    return corrected_cov


def smooth_gp_diag(lb, arr_diag, var_diag, ell_cut=0, length_scale=500.0,
                   length_scale_bounds=(100, 1e4), n_restarts_optimizer=10):
    """
    Smooth arr_diag as a function of lb with a gaussian process.

    Entries before the first lb greater than ell_cut are copied unchanged.
    From that index onwards, a GP with an RBF kernel is fitted to
    (arr_diag - 1), i.e. with a prior mean of 1, using var_diag as the
    `alpha` of the regressor, and its prediction (+ 1) is returned.

    Parameters
    ----------
    lb: 1d array
      the abscissa of the fit
    arr_diag: 1d array
      the values to smooth, same length as lb
    var_diag: 1d array
      passed, restricted to the fitted range, as `alpha` to the
      GaussianProcessRegressor
    ell_cut: float
      threshold on lb below which arr_diag is left untouched
    length_scale, length_scale_bounds:
      passed to the RBF kernel
    n_restarts_optimizer: int
      passed to the GaussianProcessRegressor

    Returns
    -------
    1d array with the shape and dtype of arr_diag
    """
    kernel = 1.0 * RBF(length_scale=length_scale,
                       length_scale_bounds=length_scale_bounds)

    # idx of first bin greater than ell_cut
    i_cut = np.where(lb > ell_cut)[0][0]
    fitted = slice(i_cut, None)

    # GP on the bins above the ell_cut, around a prior mean of 1
    lb_fitted = lb[fitted, np.newaxis]
    gpr = GaussianProcessRegressor(kernel=kernel, alpha=var_diag[fitted],
                                   n_restarts_optimizer=n_restarts_optimizer)
    gpr.fit(lb_fitted, arr_diag[fitted] - 1)

    out = np.zeros_like(arr_diag)
    # below i_cut, do nothing (an exponential fit of the low end was
    # sketched here previously but is disabled)
    out[:i_cut] = arr_diag[:i_cut]
    out[fitted] = gpr.predict(lb_fitted, return_std=False) + 1

    return out


def canonize_connected_2pt(leg1, leg2, all_legs):
"""A connected 2-point term has two legs but is invariant to their
order. Thus, if we enforce a strict global order (a canonical order)
Expand Down

0 comments on commit 1923057

Please sign in to comment.