# <center>ASTR4004/8004 - Inference - Part 6</center> 

## Simulation-based inference

Our goal is to perform inference on a model's parameters $\theta$ given observations $D$ and learn the posterior distribution $P(\theta|D)$. Normally, we do this with Bayes' rule:
$$
P(\theta|D)=P(D|\theta)\frac{P(\theta)}{P(D)},
$$
which relies on the likelihood function $P(D|\theta)$.

<font color='red'> However, what if we don't know the likelihood or there is no functional form to evaluate the likelihood?</font>

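Instead, we can train a classifier $\tilde{P}(y = 1 \mid \theta, D)$ to distinguish whether a pair $(\theta, D)$ was drawn from the joint distribution $P(\theta, D)$ $(y=1)$ or from the product of the marginals $P(\theta)P(D)$ $(y=0)$. If both classes are equally represented, i.e. $\tilde{P}(y=1)=\tilde{P}(y=0)=1/2$, an optimally trained classifier gives the likelihood-to-evidence ratio: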

$$
\frac{P(D|\theta)}{P(D)} = \frac{P(\theta|D)}{P(\theta)} =  
\frac{P(\theta, D)}{P(\theta)P(D)}
\equiv 
\frac{\tilde{P}(\theta, D \mid y = 1)}{\tilde{P}(\theta, D \mid y = 0)} 
= 
\frac{\tilde{P}(\theta, D, y = 1)}{\tilde{P}(\theta, D, y = 0)} 
= 
\frac{\tilde{P}(y = 1 \mid \theta, D)}{\tilde{P}(y = 0 \mid \theta, D)} 
= 
\frac{\tilde{P}(y = 1 \mid \theta, D)}{1 - \tilde{P}(y = 1 \mid \theta, D)}
$$
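As a quick numerical check of the identity above, consider a toy model (an assumption for illustration, not the dark matter model used later) with $\theta \sim N(0,1)$ and $D\mid\theta \sim N(\theta, \sigma^2)$, so that the evidence is $D \sim N(0, 1+\sigma^2)$. With balanced classes, the Bayes-optimal classifier is $d = P(\theta,D)/[P(\theta,D) + P(\theta)P(D)]$, and the sketch below confirms that its logit $\log[d/(1-d)]$ reproduces $\log P(D|\theta) - \log P(D)$. The helper `log_normal` and the value of `sigma2` are hypothetical choices.

```python
import numpy as np


def log_normal(x, mean, var):
    """Log density of a 1D Gaussian N(mean, var) evaluated at x."""
    return -0.5 * np.log(2 * np.pi * var) - 0.5 * (x - mean) ** 2 / var


sigma2 = 0.25  # noise variance of the toy simulator (assumed value)
rng = np.random.default_rng(1)
theta = rng.normal(size=1000)
D = theta + np.sqrt(sigma2) * rng.normal(size=1000)

# log P(D|theta) - log P(D): the log likelihood-to-evidence ratio
log_r_true = log_normal(D, theta, sigma2) - log_normal(D, 0.0, 1.0 + sigma2)

# Bayes-optimal classifier for balanced classes (y=1 joint, y=0 marginal)
p_joint = np.exp(log_normal(theta, 0.0, 1.0) + log_normal(D, theta, sigma2))
p_marg = np.exp(log_normal(theta, 0.0, 1.0) + log_normal(D, 0.0, 1.0 + sigma2))
d = p_joint / (p_joint + p_marg)

# The logit of the classifier output recovers the log ratio
log_r_from_classifier = np.log(d / (1.0 - d))
print(np.max(np.abs(log_r_from_classifier - log_r_true)))
```

In practice the densities are unknown, which is exactly why a neural network is trained to approximate $d$ from simulated pairs.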

This notebook shows a simple example of parameter recovery using `swyft`, which is built on PyTorch (`torch`).

## Useful packages

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import swyft

## Load data and prepare the input and output pairs

In [None]:
data = np.loadtxt('../../data/samples_m_t.dat')
np.random.seed(0) # fix your seeds for reproducibility!
shuffle_index = np.random.permutation(len(data))
data = data[shuffle_index]

In this example, the first two columns hold the dark matter mass (in eV) and lifetime (in seconds), while the remaining columns are effective parameters characterizing the heating, ionization and excitation coefficients from dark matter to the intergalactic medium. In particular, the 2nd, 3rd, 4th and 5th columns (indices 1 to 4 in NumPy) can be used to compute the heating coefficient normalized by the lifetime as a function of redshift, following a Schechter-like function (a power law times an exponential in $z$).
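Written out, the `fheat` line in the next cell evaluates the expression below, where $c_j$ denotes column `j` of `data` (0-indexed, as in NumPy), so that $c_1 = \tau$ is the lifetime in seconds:
$$
\log_{10} f_{\rm heat}(z) = c_2 + \log_{10}\tau + \log_{10}\left[e^{c_3 (z-15)}\left(\frac{z}{15}\right)^{c_4}\right],
\qquad\text{i.e.}\qquad
f_{\rm heat}(z) = 10^{c_2}\,\tau\left(\frac{z}{15}\right)^{c_4} e^{c_3 (z-15)} .
$$
The factor $(z/15)^{c_4}e^{c_3(z-15)}$ is the Schechter-function form (power law times exponential), pivoted at $z=15$.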

In [None]:
# identify input parameters and output observables (i.e., fheat)
params = np.log10(data[:,0]).reshape([-1,1])

zs = np.arange(5, 35)
fheat = data[:,2]+np.log10(data[:,1]) + np.log10(np.exp(data[:,3]*(zs[:,None]-15)) * ((zs[:,None])/15)**data[:,4])
fheat = fheat.T

# We keep the first 5 samples as observations (amortization!), and use the rest for training
obs = ...
samples = ...

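# let's visualize the dataset
for i in range(1000):
    # label only the first curve of each group so the legend has one entry per group
    plt.plot(zs, samples[i]['fheat'], color='k', lw=0.1,
             label='training samples' if i == 0 else None)

for i in range(5):
    plt.plot(zs, obs[i]['fheat'], color='r', lw=2,
             label='observations' if i == 0 else None)

plt.ylabel(r"$\log_{10}\left[f_{\rm heat}\right]$")
plt.xlabel(r"$z$")
plt.legend(loc=0)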
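The `fheat` line in the cell above takes `np.log10` of an `np.exp`, which can overflow if $c_3(z-15)$ is large. An algebraically equivalent form moves the exponent out of the logarithm using $\log_{10} e^{x} = x/\ln 10$. The sketch below checks, on made-up parameter values (the array `fake` is hypothetical and not taken from `samples_m_t.dat`), that both forms agree and that the transposed result has one row per sample and one column per redshift.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 8  # number of fake samples (hypothetical)
# Hypothetical stand-in for `data`, with columns 0..4 laid out as described above
fake = np.column_stack([
    10 ** rng.uniform(6, 12, n),    # c0: mass in eV
    10 ** rng.uniform(20, 30, n),   # c1: lifetime in s
    rng.normal(-1.0, 0.5, n),       # c2: log10 normalization
    rng.normal(-0.1, 0.05, n),      # c3: exponential slope in z
    rng.normal(1.0, 0.5, n),        # c4: power-law index
])
zs = np.arange(5, 35)

# Form used in the notebook: zs[:, None] broadcasts to shape (len(zs), n)
direct = fake[:, 2] + np.log10(fake[:, 1]) + np.log10(
    np.exp(fake[:, 3] * (zs[:, None] - 15)) * (zs[:, None] / 15) ** fake[:, 4])

# Equivalent form with no exp inside the log (cannot overflow for large slopes)
stable = (fake[:, 2] + np.log10(fake[:, 1])
          + fake[:, 3] * (zs[:, None] - 15) / np.log(10)
          + fake[:, 4] * np.log10(zs[:, None] / 15))

fheat_fake = stable.T  # one row per sample, one column per redshift
print(fheat_fake.shape)  # (8, 30)
```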

It also helps to look at the joint and marginal samples to build intuition for the classification task the network will solve.

In [None]:
fig = plt.figure(figsize=(10, 5))
ax = plt.subplot(1, 1, 1)
...

plt.scatter(..., alpha=0.3, c='r', s=2., label='marginal')
plt.scatter(..., alpha=0.3, c='b', s=2., label='joint');
plt.xlabel(r'$\log_{10}(m_{\chi}/{\rm eV})$')
plt.ylabel(r"$\log_{10}\left[f_{\rm heat}(z=5)\right]$")
plt.legend()
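Marginal pairs can be built by breaking the link between parameters and data, for example by randomly permuting the parameters across samples so that each observation is paired with a parameter drawn independently from the prior. The sketch below does this with NumPy on a toy simulator (the arrays `theta` and `x` are assumptions standing in for `params` and `fheat`); it shows one common construction, not necessarily the one swyft uses internally.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2000
theta = rng.uniform(6, 12, size=(n, 1))              # toy parameters, e.g. log10 mass
x = 0.5 * theta + rng.normal(0.0, 0.3, size=(n, 1))  # toy observable depending on theta

# Joint pairs: (theta_i, x_i) come from the same simulation, label y = 1
joint = np.hstack([theta, x])

# Marginal pairs: shuffle theta so it no longer matches x, label y = 0
marginal = np.hstack([theta[rng.permutation(n)], x])

pairs = np.vstack([joint, marginal])
labels = np.concatenate([np.ones(n), np.zeros(n)])  # balanced classes

# theta and x are strongly correlated in joint pairs and nearly uncorrelated in marginal pairs
print(np.corrcoef(joint.T)[0, 1], np.corrcoef(marginal.T)[0, 1])
```

Keeping the two classes the same size is what makes the classifier's odds equal to the likelihood-to-evidence ratio, as in the derivation at the top of the notebook.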

## Inference network

Swyft comes with a few default networks. Here we use `swyft.LogRatioEstimator_1dim`, a dense network that estimates the log likelihood-to-evidence ratio for each one-dimensional marginal posterior. For higher-dimensional marginal posteriors, you can use `swyft.LogRatioEstimator_Ndim`.

In [None]:
class Network(swyft.SwyftModule):
    def __init__(self):
        super().__init__()
        ...
        
    def forward(self, data, theta):
        ...
        
        return logratios

## Training
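We train the network with the `swyft.SwyftTrainer` class. Setting `precision = 64` trains in double precision, matching the `float64` arrays produced by NumPy.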

In [None]:
model = swyft.SwyftTrainer(precision = 64)
network = Network()
model.fit(network, swyft.SwyftDataModule(samples))
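Conceptually, training minimizes a binary cross-entropy loss for the joint-versus-marginal classification, and the classifier's logit then approximates the log ratio. The sketch below replaces the neural network with a logistic regression on hand-picked quadratic features $(1, \theta^2, D^2, \theta D)$ for the toy Gaussian model $\theta\sim N(0,1)$, $D\mid\theta\sim N(\theta,\sigma^2)$. This model is an assumption for illustration; its exact log ratio happens to be linear in these features, so the fitted weights can be compared with the analytic ones. It is fitted with Newton's method (with step halving) rather than the stochastic gradient descent a PyTorch trainer would use; `features` and `cross_entropy` are hypothetical helpers.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma2 = 0.25    # toy noise variance (assumed)
n = 100_000      # number of joint pairs; the same number of marginal pairs

# Joint pairs (y=1) from the toy simulator; marginal pairs (y=0) by shuffling theta
theta = rng.normal(size=n)
D = theta + np.sqrt(sigma2) * rng.normal(size=n)
theta_all = np.concatenate([theta, rng.permutation(theta)])
D_all = np.concatenate([D, D])
y = np.concatenate([np.ones(n), np.zeros(n)])


def features(t, d):
    """Quadratic features (1, t^2, d^2, t*d); the exact log ratio is linear in them."""
    return np.column_stack([np.ones_like(t), t**2, d**2, t * d])


def cross_entropy(w, X, y):
    """Summed binary cross-entropy of a logistic classifier with logit X @ w."""
    return np.sum(np.logaddexp(0.0, -(2 * y - 1) * (X @ w)))


X = features(theta_all, D_all)
w = np.zeros(X.shape[1])

# Newton's method with step halving to minimize the cross-entropy
for _ in range(30):
    p = 0.5 * (1.0 + np.tanh(0.5 * (X @ w)))  # sigmoid, written to avoid overflow
    grad = X.T @ (p - y)
    hess = (X * (p * (1.0 - p))[:, None]).T @ X
    step = np.linalg.solve(hess, grad)
    f0 = cross_entropy(w, X, y)
    t = 1.0
    while cross_entropy(w - t * step, X, y) > f0 and t > 1e-6:
        t *= 0.5
    w = w - t * step

# Exact coefficients of log P(D|theta) - log P(D) in the same feature basis
w_true = np.array([
    0.5 * np.log((1.0 + sigma2) / sigma2),  # constant
    -0.5 / sigma2,                          # theta^2
    -0.5 / sigma2 + 0.5 / (1.0 + sigma2),   # D^2
    1.0 / sigma2,                           # theta * D
])
print(np.round(w, 2))
print(np.round(w_true, 2))
```

A neural network plays the same role when the right features are not known in advance: it learns them from the simulated pairs.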

## Inference

Since the inference network estimates the logarithm of the posterior-to-prior ratio, we can obtain weighted posterior samples by running many prior samples through the network and weighting each one by $\exp(\log r)$, where $\log r$ is the network output. To this end, we first generate prior samples.

In [None]:
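# uniform prior on log10(m/eV) over [logm_min, logm_max]
logm_min = 6
logm_max = 12
prior = np.random.rand(100000, 1)  # 100,000 draws from U(0, 1), shape (N, 1)
prior = prior * (logm_max - logm_min) + logm_min  # rescale to U(logm_min, logm_max)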

prior_samples = swyft.Samples(logm = prior)

Then we evaluate the inference network with the `infer` method of the `swyft.SwyftTrainer` object (`model`).

In [None]:
for i in range(5):
    predictions = ...
    swyft.plot_posterior(predictions, "logm[0]", truth={'logm[0]':obs[i]['logm']}, 
                    labels = [r'$\log_{10}(m_{\chi}/{\rm eV})$',]);
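The weighting step can be made concrete with the toy Gaussian model $\theta\sim N(0,1)$, $D\mid\theta\sim N(\theta,\sigma^2)$ (an assumption for illustration): here the exact log ratio stands in for the trained network's `logratios`. Each prior sample is weighted by $\exp(\log r)$, and the self-normalized weights turn the prior samples into posterior samples. For this model the exact posterior is $N\!\left(D/(1+\sigma^2),\ \sigma^2/(1+\sigma^2)\right)$, which the sketch uses as a check; `D_obs` is a hypothetical observation.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma2 = 0.25   # toy noise variance (assumed)
D_obs = 1.2     # a single hypothetical observation

# Prior samples, analogous to prior_samples in the notebook
theta_prior = rng.normal(size=200_000)

# Exact log P(D|theta) - log P(D); a trained network would supply this instead
log_r = (-0.5 * (D_obs - theta_prior) ** 2 / sigma2 - 0.5 * np.log(sigma2)
         + 0.5 * D_obs ** 2 / (1 + sigma2) + 0.5 * np.log(1 + sigma2))

# Self-normalized weights: posterior / prior is proportional to exp(log r)
weights = np.exp(log_r - log_r.max())  # subtract the max for numerical stability
weights /= weights.sum()

post_mean = np.sum(weights * theta_prior)
post_var = np.sum(weights * (theta_prior - post_mean) ** 2)
print(post_mean, D_obs / (1 + sigma2))   # weighted estimate vs exact posterior mean
print(post_var, sigma2 / (1 + sigma2))   # weighted estimate vs exact posterior variance
```

The same weights can be used to build weighted histograms or credible intervals, which is the kind of summary a posterior plot shows.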

## Coverage tests

How do we know if we are getting it right?

If the obtained posterior is accurate, then over many test observations the true parameters should fall outside the p=95% highest posterior density (HPD) region in only 1-p=5% of the cases. If this fraction exceeds 5%, the posterior is narrower than it should be and hence overconfident; if it is below 5%, the posterior is wider than it should be and hence conservative.

In [None]:
coverage_samples = model.test_coverage(network, samples[-500:], prior_samples)

# The pp plot compares the empirical 1-p with the nominal 1-p
swyft.plot_pp(coverage_samples, "logm[0]")
plt.text(0.4, 0.5, "...", rotation=45,
        color="green", fontsize=12, ha="center", va="center")

plt.text(0.45, 0.35, "...", rotation=45,
        color="green", fontsize=12, ha="center", va="center")

plt.show()

# The zz plot shows the same comparison in Gaussian "sigma" units,
# z = Phi^{-1}(1 - 0.5*(1-p)), where Phi is the standard normal CDF
swyft.plot_zz(coverage_samples, "logm[0]")
plt.text(2.5, 3, "...", rotation=45,
        color="green", fontsize=12, ha="center", va="center")

plt.text(2.7, 2.3, "...", rotation=45,
        color="green", fontsize=12, ha="center", va="center");
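The logic of the coverage test can be sketched without swyft, using the toy Gaussian model $\theta\sim N(0,1)$, $D\mid\theta\sim N(\theta,\sigma^2)$ whose exact posterior is known (an assumption for illustration). For a Gaussian posterior, the HPD region at credibility $p$ is the central interval $\mu \pm z\,s$ with $z = \Phi^{-1}\left(1 - (1-p)/2\right)$, so we can count how often the true parameter falls outside the 95% region over many simulated test cases. For the exact posterior this fraction should be close to 5%, while an artificially narrowed posterior misses far more often.

```python
import numpy as np
from statistics import NormalDist

rng = np.random.default_rng(0)
sigma2 = 0.25     # toy noise variance (assumed)
n_test = 20_000   # number of simulated test cases
p = 0.95          # credibility level of the HPD region

# Simulate (theta_true, D) pairs from the toy model
theta_true = rng.normal(size=n_test)
D = theta_true + np.sqrt(sigma2) * rng.normal(size=n_test)

# Exact Gaussian posterior for each test case
post_mean = D / (1 + sigma2)
post_sd = np.sqrt(sigma2 / (1 + sigma2))

# Gaussian "number of sigma" for credibility p (about 1.96 for p = 0.95)
z = NormalDist().inv_cdf(1 - 0.5 * (1 - p))

outside = np.abs(theta_true - post_mean) > z * post_sd
print(outside.mean())         # close to 1 - p = 0.05 for a calibrated posterior

# An overconfident posterior (width halved) misses the truth far more often
outside_narrow = np.abs(theta_true - post_mean) > z * 0.5 * post_sd
print(outside_narrow.mean())
```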