In [1]:
%reload_ext autoreload
%autoreload 2

import jax
import jax.numpy as jnp
import fitsio

import impt
from impt.fpfs import *
from impt.fpfs.default import *

In [3]:
# Number of rows in the simulated catalogue (reported below as the galaxy count).
ndat = 1_000_000
print("Simulating catalog with %d galaxies" % ndat)

# JAX carries no global RNG state: the seed lives in an explicit key.
key = jax.random.PRNGKey(212)
# NOTE(review): `ncol` is not defined in this notebook; presumably it is
# provided by one of the wildcard imports above -- confirm.
cat = jax.random.normal(key, shape=(ndat, ncol))

# Default parameter set, plus the FpfsE1 / FpfsE2 objects built from it.
params = FpfsParams()
e1F, e2F = FpfsE1(params), FpfsE2(params)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Simulating catalog with 1000000 galaxies


In [4]:
print("loading the noise covariance")
test_fname = os.path.join(
    impt.fpfs.__data_dir__,
    "fpfs-cut32-0000-g1-0000.fits",
    )
data=fitsio.read(test_fname)
%time noise_cov=impt.fpfs.utils.fpfsCov2lptCov(data)

loading the noise covariance
CPU times: user 2.05 ms, sys: 2 µs, total: 2.05 ms
Wall time: 2.01 ms


In [5]:
print("preparing the function for noise bias factor estimation")
%time e1_noi= impt.BiasNoise(e1F, noise_cov)
print("estimating noise bias")
%time e1_noi.evaluate(cat)

preparing the function for noise bias factor estimation
CPU times: user 399 µs, sys: 0 ns, total: 399 µs
Wall time: 410 µs
estimating noise bias
CPU times: user 1.87 s, sys: 12 ms, total: 1.88 s
Wall time: 1.87 s


Array([ 0.01420485, -0.01691402,  0.00158234, ..., -0.16089322,
        0.21418488,  0.00549208], dtype=float64)

In [6]:
print("preparing the function for shear response estimation")
%time e1_res=impt.RespG1(e1F)
print("estimating shear response")
%time out=e1_res.evaluate(cat)

preparing the function for shear response estimation
CPU times: user 358 µs, sys: 35 µs, total: 393 µs
Wall time: 405 µs
estimating shear response
CPU times: user 197 ms, sys: 7.93 ms, total: 205 ms
Wall time: 195 ms


In [2]:
import fpfs
def initialize_FPFS(fs, snlist):
    """Apply the selection cuts named in ``snlist`` to the summary object ``fs``.

    For each selection name, the matching lower cut and smoothing width are
    read from the notebook-level ``params`` object. ``fs`` is then reset with
    ``clear_outcomes()`` and refreshed through ``update_selection_weight``,
    ``update_selection_bias`` and ``update_ellsum``.

    Args:
        fs: object exposing ``clear_outcomes``, ``update_selection_weight``,
            ``update_selection_bias`` and ``update_ellsum`` (in this notebook,
            the result of ``fpfs.catalog.summary_stats``).
        snlist: iterable of selection names. Supported names are
            ``"detect2"``, ``"M00"`` and ``"R2"``.

    Returns:
        The same ``fs`` object, after the update methods have been called on it.

    Raises:
        ValueError: if ``snlist`` contains an unsupported name. Previously such
            names were skipped silently, which left the cut arrays shorter
            than (and misaligned with) ``snlist``.
    """
    cutsig = []
    cut = []
    for sn in snlist:
        if sn == "detect2":
            cutsig.append(params.sigma_v)
            cut.append(params.lower_v)
        elif sn == "M00":
            cutsig.append(params.sigma_m00)
            cut.append(params.lower_m00)
        elif sn == "R2":
            cutsig.append(params.sigma_r2)
            cut.append(params.lower_r2)
        else:
            # Fail before touching `fs`: a skipped name would desynchronise
            # `snlist` from the `cut` / `cutsig` arrays passed below.
            raise ValueError(
                "unsupported selection name %r; expected one of "
                "'detect2', 'M00', 'R2'" % (sn,)
            )
    cutsig = np.array(cutsig)
    cut = np.array(cut)
    fs.clear_outcomes()
    fs.update_selection_weight(snlist, cut, cutsig)
    fs.update_selection_bias(snlist, cut, cutsig)
    fs.update_ellsum()
    return fs

# Same FITS file, loaded two ways: through impt's catalogue reader and as the
# raw table returned by fitsio.
test_fname = os.path.join(impt.fpfs.__data_dir__, "fpfs-cut32-0000-g1-0000.fits")
cat = impt.fpfs.read_catalog(test_fname)
data = fitsio.read(test_fname)

2022/12/30 03:12:35 ---  Remote TPU is not linked into jax; skipping remote TPU.
2022/12/30 03:12:35 ---  Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
2022/12/30 03:12:35 ---  Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2022/12/30 03:12:35 ---  Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2022/12/30 03:12:35 ---  Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
2022/12/30 03:12:35 ---  Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
2022/12/30 03:12:35 ---  No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [11]:
# --- impt side: observables and weights built from one parameter set ---
params = FpfsParams(lower_m00=4.0, sigma_m00=0.5, lower_r2=-10.)
e1F, e2F = FpfsE1(params), FpfsE2(params)
w_sel = FpfsWeightSelect(params)
w_det = FpfsWeightDetect(params)

# --- fpfs side: summary statistics with the "M00" selection applied ---
ell_fpfs = fpfs.catalog.fpfsM2E(data, const=params.Const, noirev=False)
fs = fpfs.catalog.summary_stats(data, ell_fpfs, use_sig=False, ratio=1.0)
selnm = np.array(["M00"])
fs = initialize_FPFS(fs, selnm)

# 1) The selection weights from the two code paths must agree.
np.testing.assert_array_almost_equal(fs.ws, w_sel.evaluate(cat))

# 2) So must the summed products e1F*w_sel and e2F*w_sel.
we1 = e1F*w_sel
we2 = e2F*w_sel
for fpfs_sum, weighted_ell in ((fs.sumE1, we1), (fs.sumE2, we2)):
    np.testing.assert_array_almost_equal(
        fpfs_sum,
        jnp.sum(weighted_ell.evaluate(cat)),
    )

# 3) Summed RespG1 + RespG2 outputs versus the four fpfs response terms.
dwe1_dg1 = impt.RespG1(we1)
dwe2_dg2 = impt.RespG2(we2)
res_g1 = jnp.sum(dwe1_dg1.evaluate(cat))
res_g2 = jnp.sum(dwe2_dg2.evaluate(cat))
res_ad = res_g1 + res_g2
res_fpfs = fs.corR1 + fs.sumR1 + fs.corR2 + fs.sumR2

np.testing.assert_array_almost_equal(res_ad, res_fpfs)

In [7]:
# --- impt side: observables and weights built from one parameter set ---
params = FpfsParams(lower_m00=-4.0, sigma_m00=0.1, lower_r2=-10.)
e1F, e2F = FpfsE1(params), FpfsE2(params)
w_sel = FpfsWeightSelect(params)
w_det = FpfsWeightDetect(params)

# --- fpfs side: summary statistics with the "M00" selection applied ---
ell_fpfs = fpfs.catalog.fpfsM2E(data, const=params.Const, noirev=False)
fs = fpfs.catalog.summary_stats(data, ell_fpfs, use_sig=False, ratio=1.0)
selnm = np.array(["M00"])
fs = initialize_FPFS(fs, selnm)

# Selection weights from both code paths must agree.
np.testing.assert_array_almost_equal(fs.ws, w_sel.evaluate(cat))

# First component: compare the sums, then print both response values so they
# can be compared by eye in the cell output.
we1 = e1F*w_sel
np.testing.assert_array_almost_equal(fs.sumE1, jnp.sum(we1.evaluate(cat)))
dwe1_dg1 = impt.RespG1(we1)
print(jnp.sum(dwe1_dg1.evaluate(cat)))
print(fs.corR1 + fs.sumR1)

# Second component: only the sums are compared.
we2 = e2F*w_sel
np.testing.assert_array_almost_equal(fs.sumE2, jnp.sum(we2.evaluate(cat)))

5.491789601167456
5.491789601167457


In [4]:
# --- impt side: observables and weights built from one parameter set ---
params = FpfsParams(lower_m00=-4.0, sigma_m00=0.5, lower_r2=0.12, sigma_r2=0.2)
e1F, e2F = FpfsE1(params), FpfsE2(params)
w_sel = FpfsWeightSelect(params)
w_det = FpfsWeightDetect(params)

# --- fpfs side: summary statistics with the "R2" selection applied ---
ell_fpfs = fpfs.catalog.fpfsM2E(data, const=params.Const, noirev=False)
fs = fpfs.catalog.summary_stats(data, ell_fpfs, use_sig=False, ratio=1.0)
selnm = np.array(["R2"])
fs = initialize_FPFS(fs, selnm)

# Selection weights from both code paths must agree.
np.testing.assert_array_almost_equal(fs.ws, w_sel.evaluate(cat))

In [5]:
# --- impt side: observables and weights built from one parameter set ---
params = FpfsParams(
    lower_m00=-4.0,
    sigma_m00=0.5,
    lower_r2=-4.0,
    sigma_r2=0.2,
    sigma_v=0.2,
)
e1F, e2F = FpfsE1(params), FpfsE2(params)
w_sel = FpfsWeightSelect(params)
w_det = FpfsWeightDetect(params)

# --- fpfs side: summary statistics with the "detect2" selection applied ---
ell_fpfs = fpfs.catalog.fpfsM2E(data, const=params.Const, noirev=False)
fs = fpfs.catalog.summary_stats(data, ell_fpfs, use_sig=False, ratio=1.0)
selnm = np.array(["detect2"])
fs = initialize_FPFS(fs, selnm)

# Here the fpfs weights are compared against the FpfsWeightDetect output.
np.testing.assert_array_almost_equal(fs.ws, w_det.evaluate(cat))