In [1]:
import numpy as np
import astropy.units as u
import matplotlib.pyplot as plt
import pandas as pd
from chainconsumer import ChainConsumer, Chain
import torch
from sbi import utils as utils
from sbi.inference import infer
import pitszi

In [2]:
# Shared model instance: `simulator` overwrites its pressure-fluctuation
# settings on every call, so this object is mutable global state.
cluster = pitszi.Model(
    silent=True,
    redshift=0.5,
    M500=1e15 * u.Msun,
)
# `los_*` are presumably the line-of-sight grid resolution and extent
# (TODO confirm against the pitszi docs); the assignment order is kept as-is.
cluster.los_reso = 100 * u.kpc
cluster.los_size = 2 * u.Mpc

In [3]:
def simulator(theta, noise_rms=0.0):
    """Simulate one flattened SZ map for a single (norm, Linj) pair.

    The module-level ``cluster`` model is re-configured in place with a
    'CutoffPowerLaw' pressure-fluctuation model (slope -11/3, Ldis = 1 kpc),
    and the map returned by ``cluster.get_sz_map()`` is flattened to 1D.

    Parameters
    ----------
    theta : array-like holding exactly two values
        ``(norm, Linj)``. ``Linj`` is multiplied by ``u.kpc`` before being
        handed to the model. Shapes such as ``(2,)`` or ``(1, 2)`` are
        accepted; a batch of several parameter sets is rejected explicitly
        instead of being silently unpacked row-wise into array-valued
        parameters.
    noise_rms : float, optional
        Standard deviation of white Gaussian noise added to every pixel.
        The default of 0 adds no noise, which reproduces the previous
        behaviour (the noise term used to be multiplied by a hard-coded 0;
        the amplitude that was being zeroed out was 2e-5).

    Returns
    -------
    The flattened map, with noise added when ``noise_rms`` > 0.

    Raises
    ------
    ValueError
        If ``theta`` does not contain exactly two values, or if
        ``noise_rms`` is negative.
    """
    # Validate everything before touching the shared `cluster`, so a bad call
    # cannot leave the global model half re-configured.
    params = np.asarray(theta).ravel()
    if params.size != 2:
        raise ValueError(
            "theta must contain exactly two values (norm, Linj); "
            "got %d" % params.size
        )
    if noise_rms < 0:
        raise ValueError("noise_rms must be non-negative")
    norm, Linj = params

    cluster.model_pressure_fluctuation = {'name': 'CutoffPowerLaw',
                                          'statistics': 'gaussian',
                                          'Norm': norm,
                                          'slope': -11./3,
                                          'Linj': Linj*u.kpc,
                                          'Ldis': 1*u.kpc}
    img = cluster.get_sz_map()
    if noise_rms > 0:
        # Only draw random numbers when noise is actually requested.
        img = img + np.random.normal(0, noise_rms, img.shape)
    return img.flatten()

In [4]:
num_dim = 2

# Uniform box prior; the entry order matches how `simulator` unpacks theta:
# (norm, Linj), with Linj expressed in kpc.
prior = utils.BoxUniform(
    low=torch.tensor([0., 0.], dtype=torch.float32),
    high=torch.tensor([1., 2000.], dtype=torch.float32),
)

# Run 1000 simulations drawn from the prior and fit the 'SNLE' estimator.
posterior = infer(
    simulator,
    prior,
    method='SNLE',
    num_simulations=1000,
)

  0%|          | 0/1000 [00:00<?, ?it/s]

 Neural network successfully converged after 488 epochs.

  thin = _process_thin_default(thin)


In [5]:
# Mock observations generated at fixed parameters (Norm = 0.6, Linj = 500 kpc).
# Each of the 10 maps comes from a freshly built model configured like
# `cluster`; presumably every call draws a new random realisation (TODO confirm).
fake_maps = []
for realization in range(10):
    fake_cluster = pitszi.Model(silent=True, redshift=0.5, M500=1e15 * u.Msun)
    fake_cluster.los_reso = 100 * u.kpc
    fake_cluster.los_size = 2 * u.Mpc
    fake_cluster.model_pressure_fluctuation = {
        'name': 'CutoffPowerLaw',
        'statistics': 'gaussian',
        'Norm': 0.6,
        'slope': -11./3,
        'Linj': 500 * u.kpc,
        'Ldis': 1 * u.kpc,
    }
    # `fake_img` deliberately survives the loop: the last unflattened map is
    # displayed by the final plotting cell.
    fake_img = fake_cluster.get_sz_map()
    fake_maps.append(fake_img.flatten())

# One flattened map per row.
fake_img_list = np.array(fake_maps)

In [6]:
# Draw 100 posterior samples conditioned on the stack of mock maps.
# NOTE(review): the 10 rows of `fake_img_list` are presumably treated by sbi
# as i.i.d. observations — confirm. The captured run of this cell was
# interrupted during MCMC initialisation.
samples = posterior.sample(
    sample_shape=(100,),
    x=fake_img_list,
    show_progress_bars=True,
)

Generating 20 MCMC inits via resample strategy:   0%|          | 0/20 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Column order follows the parameter order used by the simulator: (norm, Linj).
samples_df = pd.DataFrame({
    'norm': np.asarray(samples[:, 0]),
    'Linj': np.asarray(samples[:, 1]),
})
chain_vanilla = Chain(samples=samples_df, name="Vanilla results")

# Corner-style summary of the posterior samples.
cc = ChainConsumer()
cc.add_chain(chain_vanilla)
cc.plotter.plot();

In [None]:
# `fake_img` is the unflattened map left over from the last iteration of the
# mock-data loop.
sz_image = plt.imshow(fake_img)
plt.colorbar(sz_image)