In [1]:
import numpyro

# Enable double-precision (64-bit) floats in JAX/numpyro for the whole session.
numpyro.enable_x64()
# Expose 4 host (CPU) devices so the 4 MCMC chains run further down can be run in parallel.
# Per the numpyro docs this must be called before any JAX computation has happened.
numpyro.set_host_device_count(4)

In [2]:
import astropy.units as u
from xsb_fluc.data.cluster import Cluster

# MACS J0717.5+3745: sky position, redshift and r_500, plus the input files
# (counts image, exposure map, background map, source-list region file, NH map).
# NOTE(review): "700-1200" in the file names presumably denotes the energy band in eV — confirm.
cluster = Cluster(
    ra=109.402083,
    dec=37.756389,
    redshift=0.5458,
    r_500=1.5 * u.Mpc,
    imglink='MACSJ0717.5+3745/epic-obj-im-700-1200.fits.gz',
    explink='MACSJ0717.5+3745/epic-exp-im-700-1200.fits.gz',
    bkglink='MACSJ0717.5+3745/epic-back-tot-sp-700-1200.fits.gz',
    reglink='MACSJ0717.5+3745/srclist_MACSJ0717.5+3745_ib.reg',
    nhlink='MACSJ0717.5+3745/macs_j0717_nh.fits',
)

# Voronoi-binned version of the cluster; this is the object the MCMC fit below is run on.
cluster_voronoi = cluster.voronoi(
    'MACSJ0717.5+3745/voronoi.txt',
    exclusion=2,
    rebin_factor=3,
    t_500_percent=2,
)

the RADECSYS keyword is deprecated, use RADESYSa. [astropy.wcs.wcs]
a floating-point value was expected. [astropy.wcs.wcs]


In [3]:
import haiku as hk
from xsb_fluc.simulation.mock_image import MockXrayCountsBetaModel


In [5]:
import jax.numpy as jnp
import haiku as hk 

class MultipleBetaModel(hk.Module):
    """Pixel-wise sum of several independent ``MockXrayCountsBetaModel`` components.

    Every component is its own Haiku submodule and therefore owns its own
    parameter set (``mock_xray_counts_beta_model``,
    ``mock_xray_counts_beta_model_1``, ...).

    Args:
        cluster: cluster/observation object, forwarded unchanged to every component.
        n_components: number of components to sum; must be at least 1.
        name: optional Haiku module name (defaults to ``multiple_beta_model``).

    Raises:
        ValueError: if ``n_components`` is smaller than 1.
    """

    def __init__(self, cluster, n_components, name=None):
        # Fail early with a clear message; with zero components __call__ would
        # otherwise die inside jnp.stack on an empty list.
        if n_components < 1:
            raise ValueError(f"n_components must be >= 1, got {n_components}")
        super().__init__(name=name)
        self.models = [MockXrayCountsBetaModel(cluster) for _ in range(n_components)]

    def __call__(self):
        """Return the element-wise sum of all component outputs."""
        # Stack the component outputs along a new leading axis, then reduce that axis away.
        return jnp.sum(jnp.stack([model() for model in self.models]), axis=0)
        
# Two-component forward model on the Voronoi-binned cluster. Printing its default
# parameters shows the exact module/parameter names the prior dictionary has to use.
images_simulator = hk.without_apply_rng(
    hk.transform(lambda: MultipleBetaModel(cluster_voronoi, n_components=2)())
)
print(images_simulator.init(None))

{'multiple_beta_model/~/mock_xray_counts_beta_model/~/ellipse_radius': {'angle': Array(0., dtype=float32), 'eccentricity': Array(0., dtype=float32), 'x_c': Array(0., dtype=float32), 'y_c': Array(0., dtype=float32)}, 'multiple_beta_model/~/mock_xray_counts_beta_model/~/xray_surface_brightness_beta_model': {'log_bkg': Array(-5., dtype=float32), 'log_e_0': Array(-4., dtype=float32), 'log_r_c': Array(-1., dtype=float32), 'beta': Array(0.6666667, dtype=float32)}, 'multiple_beta_model/~/mock_xray_counts_beta_model_1/~/ellipse_radius': {'angle': Array(0., dtype=float32), 'eccentricity': Array(0., dtype=float32), 'x_c': Array(0., dtype=float32), 'y_c': Array(0., dtype=float32)}, 'multiple_beta_model/~/mock_xray_counts_beta_model_1/~/xray_surface_brightness_beta_model': {'log_bkg': Array(-5., dtype=float32), 'log_e_0': Array(-4., dtype=float32), 'log_r_c': Array(-1., dtype=float32), 'beta': Array(0.6666667, dtype=float32)}}


In [19]:
import numpyro
import numpyro.distributions as dist
import jax.numpy as jnp 

# Priors for the two beta-model components, keyed exactly like the Haiku
# parameter tree of MultipleBetaModel(..., 2). Entries that are plain arrays
# rather than numpyro distributions are held fixed (numpyro_model does not sample them).
# - log_bkg is fixed at -100 for both components; presumably this switches off each
#   component's own background term — confirm against MockXrayCountsBetaModel.
# - The second component's centre has a tight prior around (-0.4, -0.4).
# NOTE(review): angle is restricted to [0, pi/2]; if the ellipse parametrisation is only
# pi-periodic, this excludes half of the possible orientations — confirm.
prior_distributions = {
    'multiple_beta_model/~/mock_xray_counts_beta_model/~/ellipse_radius': dict(
        angle=dist.Uniform(0., jnp.pi/2),
        eccentricity=dist.Uniform(0, 0.99),
        x_c=dist.Normal(0, 1),
        y_c=dist.Normal(0, 1),
    ),
    'multiple_beta_model/~/mock_xray_counts_beta_model/~/xray_surface_brightness_beta_model': dict(
        log_bkg=jnp.asarray(-100.),
        log_e_0=dist.Uniform(-9, -3),
        log_r_c=dist.Uniform(-3, 0),
        beta=dist.Uniform(0, 5),
    ),
    'multiple_beta_model/~/mock_xray_counts_beta_model_1/~/ellipse_radius': dict(
        angle=dist.Uniform(0., jnp.pi/2),
        eccentricity=dist.Uniform(0, 0.99),
        x_c=dist.Normal(-0.4, 0.1),
        y_c=dist.Normal(-0.4, 0.1),
    ),
    'multiple_beta_model/~/mock_xray_counts_beta_model_1/~/xray_surface_brightness_beta_model': dict(
        log_bkg=jnp.asarray(-100.),
        log_e_0=dist.Uniform(-9, -3),
        log_r_c=dist.Uniform(-3, 0),
        beta=dist.Uniform(0, 5),
    ),
}

def numpyro_model(observed_cluster=None, n_components=2):
    """numpyro model for a multi-component beta-model fit of X-ray counts.

    Every entry of ``prior_distributions`` that is a numpyro distribution is
    sampled as a site named ``<module>_<parameter>``; any other entry is used
    as a fixed value. The summed model image is the rate of an independent
    Poisson likelihood on the observed counts ``observed_cluster.img``.

    Args:
        observed_cluster: cluster object providing the observed counts (``.img``)
            and whatever ``MockXrayCountsBetaModel`` reads from it. Required; the
            ``None`` default is kept only for signature compatibility.
        n_components: number of beta components; must match the component
            entries present in ``prior_distributions`` (two by default).

    Raises:
        ValueError: if ``observed_cluster`` is None.
    """
    # Without a cluster there is neither a forward model nor observed counts;
    # fail with a clear message instead of an AttributeError deep inside the model.
    if observed_cluster is None:
        raise ValueError("numpyro_model requires an observed_cluster")

    # Copy of the prior structure (to_haiku_dict returns a new mapping), so the global
    # prior_distributions is never overwritten with sampled values.
    samples = hk.data_structures.to_haiku_dict(prior_distributions)

    # Draw every free parameter from its prior; fixed values are passed through as-is.
    for module, parameter, prior in hk.data_structures.traverse(prior_distributions):
        if isinstance(prior, dist.Distribution):
            samples[module][parameter] = numpyro.sample(f'{module}_{parameter}', prior)
        else:
            samples[module][parameter] = prior

    # Expected counts per pixel for the sampled parameters.
    simulator = hk.without_apply_rng(
        hk.transform(lambda: MultipleBetaModel(observed_cluster, n_components)())
    )
    expected_counts = simulator.apply(samples)

    # Poisson likelihood of the actually observed counts in each pixel.
    numpyro.sample('likelihood', dist.Poisson(expected_counts), obs=observed_cluster.img)

In [None]:
from jax.random import PRNGKey
from numpyro.infer import MCMC, NUTS

# NUTS sampler; max_tree_depth=10 is the same as numpyro's documented default.
kernel = NUTS(model=numpyro_model, max_tree_depth=10)

# 4 chains (matching set_host_device_count(4)), each with 10 000 warmup + 1 000 kept draws.
# NOTE(review): the saved output shows every chain at 0/11000, i.e. this run never
# finished in the saved session, and the cells below were never executed.
mcmc = MCMC(
    kernel,
    num_warmup=10_000,
    num_samples=1_000,
    num_chains=4,
)
mcmc.run(PRNGKey(10), observed_cluster=cluster_voronoi)

  0%|          | 0/11000 [00:00<?, ?it/s]

  0%|          | 0/11000 [00:00<?, ?it/s]

  0%|          | 0/11000 [00:00<?, ?it/s]

  0%|          | 0/11000 [00:00<?, ?it/s]

In [None]:
import arviz as az 
import matplotlib.pyplot as plt 

# Convert the numpyro run to ArviZ InferenceData; the summary table is the cell's displayed output.
inference_data = az.from_numpyro(posterior=mcmc)
az.summary(inference_data)

In [None]:
# Trace plots of every sampled parameter, drawn under the ArviZ dark-grid style.
with az.style.context("arviz-darkgrid", after_reset=True):
    az.plot_trace(data=inference_data, compact=False)

plt.show();

In [None]:
import json
import numpy as np
from jax.tree_util import tree_map

# Point estimate: posterior median of every sampled site. get_samples() defaults to
# group_by_chain=False, so all chains are pooled before taking the median.
posterior_parameters = tree_map(jnp.median, mcmc.get_samples())

# Rebuild the Haiku parameter tree expected by `apply`: start from a copy of the
# priors (keeps the fixed, non-sampled values), then replace each sampled entry by
# its median. Site names follow the `<module>_<parameter>` convention of numpyro_model.
hk_posterior_parameters = hk.data_structures.to_haiku_dict(prior_distributions)

for module, parameter, prior in hk.data_structures.traverse(prior_distributions):
    if isinstance(prior, dist.Distribution):
        hk_posterior_parameters[module][parameter] = posterior_parameters[f'{module}_{parameter}']

In [None]:
import cmasher as cmr
import numpy as np
from matplotlib.colors import LogNorm, SymLogNorm

# Evaluate the best-fit model on a cluster derived from the full `cluster` (not the
# Voronoi-binned one used for fitting), reduced via reduce_to_r500 for plotting.
# NOTE(review): meaning of the 1.5 passed to reduce_to_r500 is not visible here — confirm.
cluster_to_plot = cluster.reduce_to_r500(1.5)

images_simulator_full = hk.without_apply_rng(
    hk.transform(lambda: MultipleBetaModel(cluster_to_plot, n_components=2)())
)
best_fit_image = images_simulator_full.apply(hk_posterior_parameters)

In [None]:
# Three panels: observed counts | best-fit model counts | residual "fluctuation" map.
# The axes use the WCS of the cluster actually being plotted, so that sky-coordinate
# ticks line up with its pixel grid even if reduce_to_r500 crops the field.
# NOTE(review): falls back to the full-field WCS if the reduced object has no `wcs`
# attribute — confirm what reduce_to_r500 returns.
plot_wcs = getattr(cluster_to_plot, 'wcs', cluster.wcs)

fig, axs = plt.subplots(
    figsize=(12, 5),
    nrows=1,
    ncols=3,
    subplot_kw={'projection': plot_wcs},
)

exposure = np.asarray(cluster_to_plot.exp)
observed_counts = np.asarray(cluster_to_plot.img)
model_counts = np.asarray(best_fit_image)

# Only pixels with non-zero exposure are shown; all others are NaN (rendered blank).
mask = exposure > 0

# Residual counts divided by twice the exposure. Zero-exposure pixels are skipped
# (left NaN) instead of producing inf/NaN from a division by zero.
# NOTE(review): the origin of the factor 2 is not documented here — confirm.
xsb_fluc = np.divide(
    observed_counts - model_counts,
    2 * exposure,
    out=np.full(np.shape(mask), np.nan),
    where=mask,
)

# Shared log colour scale so data and model panels are directly comparable.
img_norm = LogNorm(vmin=0.5, vmax=200)

map_img = axs[0].imshow(np.where(mask, observed_counts, np.nan), norm=img_norm, cmap=cmr.cosmic)
map_fit = axs[1].imshow(np.where(mask, model_counts, np.nan), norm=img_norm, cmap=cmr.cosmic)
map_fluc = axs[2].imshow(
    np.where(mask, xsb_fluc, np.nan),
    cmap=cmr.guppy,
    norm=SymLogNorm(vmin=-1e-6, vmax=1e-6, linthresh=1e-8),
)

plt.colorbar(map_img, ax=axs[0], location='bottom', label='Counts (True image)')
plt.colorbar(map_fit, ax=axs[1], location='bottom', label='Counts (Fitted image)')
plt.colorbar(map_fluc, ax=axs[2], location='bottom', label='Fluctuations')

plt.show();