# <center>ASTR4004/8004 - Inference - Part 6</center> 

## Simulation-based inference

Our goal is to perform inference on a model's parameters $\theta$ given observations $D$ and learn the posterior distribution $P(\theta|D)$. Normally, we do this with Bayes' rule:
$$
P(\theta|D)=P(D|\theta)\frac{P(\theta)}{P(D)},
$$
which relies on the likelihood function $P(D|\theta)$.

<font color='red'> However, what if we don't know the likelihood or there is no functional form to evaluate the likelihood?</font>

Remember that our model can produce different outputs $D'$ for a given $\theta$. This means that we can measure *the frequency of those outputs that reproduce/match the real data ($D$) to obtain a probability $P(\theta|D)$*, which is the posterior probability. 
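
As a minimal sketch of this frequency argument (a rejection scheme often called approximate Bayesian computation), assume a hypothetical toy simulator that returns $D' = \theta + \epsilon$ with unit Gaussian noise, and a flat prior on $\theta$. The names `simulate`, `d_obs` and `eps` are illustrative assumptions and are not part of the course data or of `swyft`.

```python
import numpy as np

rng = np.random.default_rng(0)

def simulate(theta, rng):
    """Hypothetical toy simulator: one noisy output D' per parameter value."""
    return theta + rng.normal(size=np.shape(theta))

d_obs = 1.0                                # stands in for the real data D
theta = rng.uniform(-5, 5, size=200_000)   # draws from a flat prior
d_sim = simulate(theta, rng)               # one simulated output per draw

# Keep only the parameters whose simulated output matches the data
# to within a small tolerance; their histogram approximates P(theta|D).
eps = 0.05
accepted = theta[np.abs(d_sim - d_obs) < eps]

print(len(accepted), accepted.mean(), accepted.std())
```

For this toy model the exact posterior is close to a unit-variance Gaussian centred on `d_obs`, so the accepted samples should have a mean near 1 and a standard deviation near 1. The scheme needs no likelihood evaluation, but it wastes most simulations, which is one motivation for the neural approach used in this notebook.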

This notebook shows a simple example of linear regression using `swyft` and its backend PyTorch (`torch`).

## Useful packages

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import swyft

## Load data and prepare the input and output pairs

In [None]:
data = np.loadtxt('../../data/samples_m_t.dat')
np.random.seed(0) # fix your seeds for reproducibility!
shuffle_index = np.random.permutation(len(data))
data = data[shuffle_index]

In this example, the first two columns correspond to the dark matter mass in units of eV and the lifetime in units of seconds, while the remaining columns are effective parameters characterizing the heating, ionization and excitation coefficients from dark matter to the intergalactic medium. In particular, the third, fourth and fifth columns can be used to compute the heating coefficient normalized by the lifetime as a function of redshift, following a Schechter function.

In [None]:
# identify input parameters and output observables (i.e., fheat)
params = np.log10(data[:,:2])

zs = np.arange(5, 35)
fheat = data[:,2] * np.log10(np.exp(data[:,3]*(zs[:,None]-15)) * ((zs[:,None]+1)/16)**data[:,4])
fheat = fheat.T
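
The array shapes in the cell above rely on NumPy broadcasting. The sketch below repeats the same expression on synthetic coefficients (random stand-ins named `coeffs`, an assumption made only for illustration) so that the shapes can be checked without the data file.

```python
import numpy as np

rng = np.random.default_rng(2)
n_samples = 4
# Synthetic stand-ins for the third, fourth and fifth columns of the data file.
coeffs = rng.uniform(0.1, 1.0, size=(n_samples, 3))

zs = np.arange(5, 35)  # 30 redshifts
# zs[:, None] has shape (30, 1) and each coeffs[:, k] has shape (n_samples,),
# so every product broadcasts to shape (30, n_samples).
curve = coeffs[:, 0] * np.log10(
    np.exp(coeffs[:, 1] * (zs[:, None] - 15)) * ((zs[:, None] + 1) / 16) ** coeffs[:, 2]
)
curve = curve.T  # transpose so that each row is one sample

print(curve.shape)  # (4, 30)
```

With this parameterization the argument of the logarithm equals 1 at $z=15$, so every curve passes through zero there, which is a quick sanity check on the real `fheat` array as well.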

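We can also take a look at the joint and marginal samples to get a feeling for the classification problem that the network will be trained on, namely telling correctly paired (joint) parameters and observables apart from randomly paired (marginal) ones.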

In [None]:
# random permutation of the sample indices, used to break the parameter-observable pairing
idx_arr = np.random.permutation(len(params))
fig = plt.figure(figsize=(10, 5))
ax = plt.subplot(1, 2, 1)
plt.scatter(..., ..., alpha=0.3, c='r', s=2., label='marginal')
plt.scatter(..., ..., alpha=0.3, c='b', s=2., label='joint');
plt.xlabel(r'$\log_{10}(m_{\chi}/ {\rm eV})$')
plt.ylabel(r"$\log_{10}\left[f_{\rm heat}/\tau_{\chi} (s^{-1})\right]$")
plt.legend()

ax = plt.subplot(1, 2, 2)
plt.scatter(..., ..., alpha=0.3, c='r', s=2., label='marginal')
plt.scatter(..., ..., alpha=0.3, c='b', s=2., label='joint');
plt.xlabel(r'$\log_{10}(\tau_{\chi}/s)$')
plt.legend()
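
The difference between joint and marginal samples can be sketched on a toy problem. The arrays `theta` and `x` below are hypothetical stand-ins for a parameter and an observable, not the course data: joint pairs keep each parameter with the output it generated, while marginal pairs are formed by shuffling one of the two.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 5000
theta = rng.uniform(0, 1, size=n)             # hypothetical parameter
x = 2.0 * theta + 0.1 * rng.normal(size=n)    # hypothetical observable generated from theta

shuffled = rng.permutation(n)
joint = np.column_stack([theta, x])               # pairs (theta_i, x_i)
marginal = np.column_stack([theta[shuffled], x])  # pairs (theta_j, x_i): pairing broken

# Joint pairs stay strongly correlated; marginal pairs do not.
corr_joint = np.corrcoef(joint.T)[0, 1]
corr_marginal = np.corrcoef(marginal.T)[0, 1]
print(corr_joint, corr_marginal)
```

Both sets share the same one-dimensional distributions of `theta` and `x`; only the dependence between them differs, and that dependence is what the classifier learns.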

In [None]:
# We keep the first sample as observation, and use the rest for training
samples = swyft.Samples(fheat = fheat[1:], params = params[1:])
obs = ...

for i in range(100):
    plt.plot(zs, samples[i]['fheat'], color='k', lw=0.1)
    
plt.plot(zs, obs['fheat'], color='r', lw = 2, label = 'target obs')
plt.ylabel(r"$\log_{10}\left[f_{\rm heat}/\tau_{\chi} (s^{-1})\right]$")
plt.xlabel(r"$z$")
plt.legend(loc=0)

## Inference network

Swyft comes with a few default networks. Here we use `swyft.LogRatioEstimator_1dim`, a dense network that estimates one-dimensional marginal posteriors. You can use `LogRatioEstimator_Ndim` to estimate higher-dimensional marginal posteriors.

In [None]:
class Network(swyft.SwyftModule):
    def __init__(self):
        super().__init__()
        ...
        
    def forward(self, data, theta):
        ...
        
        return logratios
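
The idea behind a log-ratio estimator can be illustrated without `swyft`. In the sketch below, the dense network is swapped for a plain logistic regression on hand-picked quadratic features, fitted with Newton iterations in NumPy; this is a stand-in technique under a toy Gaussian model (an assumption), not what `swyft` does internally. A classifier trained to separate joint pairs (label 1) from marginal pairs (label 0) has a logit that approximates $\log[P(D|\theta)/P(D)]$, which equals the log of the posterior-to-prior ratio.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20_000
theta = rng.normal(size=n)               # toy prior: theta ~ N(0, 1)
x = theta + rng.normal(size=n)           # toy simulator: x | theta ~ N(theta, 1)
theta_marg = rng.permutation(theta)      # shuffled parameters give marginal pairs

def features(x, theta):
    """Quadratic features; the exact log ratio of this toy model is linear in them."""
    return np.stack([np.ones_like(x), x**2, x * theta, theta**2], axis=1)

X = np.concatenate([features(x, theta), features(x, theta_marg)])
y = np.concatenate([np.ones(n), np.zeros(n)])   # 1 = joint, 0 = marginal

# Logistic regression fitted with Newton's method.
w = np.zeros(X.shape[1])
for _ in range(25):
    p = 1.0 / (1.0 + np.exp(-np.clip(X @ w, -30, 30)))
    grad = X.T @ (p - y)
    hess = (X * (p * (1 - p))[:, None]).T @ X + 1e-8 * np.eye(X.shape[1])
    w -= np.linalg.solve(hess, grad)

print(w)  # fitted coefficients for [1, x^2, x*theta, theta^2]
```

For this toy model the exact log ratio is $\tfrac{1}{2}\ln 2 - \tfrac{1}{4}x^2 + x\theta - \tfrac{1}{2}\theta^2$, so the fitted coefficients can be compared against $(0.35, -0.25, 1, -0.5)$. A neural network plays the same role when no such convenient feature set is known.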

## Training
Training is now done using the `SwyftTrainer` class.

In [None]:
model = swyft.SwyftTrainer(precision = 64)
network = Network()
model.fit(network, swyft.SwyftDataModule(samples))

## Inference

Since the inference network estimates the logarithm of the posterior-to-prior ratio, we can obtain weighted posterior samples by running many prior samples through the inference network. To this end, we first generate prior samples.

In [None]:
logm_min = 6
logm_max = 12
logt_min = 26
logt_max = 33
prior = np.random.rand(100000, 2)
prior[:,0] = prior[:,0] * (logm_max - logm_min) + logm_min
prior[:,1] = prior[:,1] * (logt_max - logt_min) + logt_min

prior_samples = swyft.Samples(params = prior)

Then we evaluate the inference network on these prior samples by using the `infer` method of the `swyft.SwyftTrainer` object.

In [None]:
predictions = ...

truth = {k: v for k, v in zip(["params[%i]"%i for i in range(2)], obs['params'])}
swyft.plot_posterior(predictions, ["params[%i]"%i for i in range(2)], truth=truth, 
                    labels = [r'$\log_{10}(m_{\chi}/{\rm eV})$', 
                              r'$\log_{10}(\tau_{\chi}/{\rm s})$']);
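
The reweighting of prior samples by the estimated ratio can be sketched with a toy model in which the log ratio is known exactly. The function `log_ratio` below is a hypothetical stand-in for the output of a trained network, and the Gaussian prior and likelihood are assumptions chosen so that the answer can be checked analytically.

```python
import numpy as np

rng = np.random.default_rng(3)

def log_ratio(x, theta):
    """Exact log[P(theta|x)/P(theta)] for theta ~ N(0, 1) and x | theta ~ N(theta, 1)."""
    return -0.25 * x**2 + x * theta - 0.5 * theta**2 + 0.5 * np.log(2.0)

x_obs = 1.2                               # toy observation
prior_draws = rng.normal(size=100_000)    # samples from the prior

# Each prior sample gets a weight proportional to posterior/prior.
weights = np.exp(log_ratio(x_obs, prior_draws))
weights /= weights.sum()

post_mean = np.sum(weights * prior_draws)
post_var = np.sum(weights * (prior_draws - post_mean) ** 2)
print(post_mean, post_var)
```

The exact posterior here is Gaussian with mean `x_obs / 2` and variance 0.5, so the weighted estimates should land close to 0.6 and 0.5. A weighted histogram of the prior samples is the one-dimensional analogue of the posterior plot above.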