# Bayes MCMC spectral fitter: fit a line

In [None]:
import astropy.units as u
import numpy as np
from yaff import fitting, rebin_flux, plotting as yap
import scipy.stats as st

import matplotlib.pyplot as plt
%matplotlib inline
plt.style.use('nice.mplstyle')

## Define fake data (counts, errors, etc.) for a mock spectrum
The data are a straight line of counts made with `np.linspace`.
Each bin is then resampled from a normal distribution with $\sqrt{N}$ errors, a Gaussian approximation of Poisson noise.

The response matrix starts as diagonal and is then interpolated so that the count and photon energy bins can have different sizes.
The interpolation preserves "probability flux" along the appropriate
response matrix axis.

In [None]:
# Make some fake data: a straight line of counts in 5 count bins
cts = np.linspace(800, 3000, 5) << u.ct
cts_err = np.sqrt(cts.value) << u.ct

# Add a gaussian approximation of Poisson error onto the counts.
# A seeded generator keeps the notebook reproducible from run to run.
rng = np.random.default_rng(seed=12345)
cts = st.norm.rvs(loc=cts, scale=cts_err, random_state=rng) << u.ct
eff_exp = 2 << u.s

count_edges = [2, 4, 6, 8, 10, 12] << u.keV
photon_edges = np.linspace(1, 50, num=40) << u.keV

# Start from a diagonal (identity) response on the photon grid; the SRM
# then gets interpolated along the `target` aka counts axis.
diag_srm = np.eye(photon_edges.size - 1)

# Each row of the identity is rebinned from the photon grid onto the count
# grid. After the transpose, rows index count bins and columns index photon
# energy bins, assuming matrix multiplication S\vec{p} = \vec{c},
# where p is the photon vector and c is the counts vector.
srm = np.array([
    rebin_flux.flux_conserving_rebin(photon_edges, row, count_edges)
    for row in diag_srm
]).T << (u.ct / u.ph)
area = 1 << u.cm**2

pack = fitting.DataPacket(
    counts=cts,
    counts_error=cts_err,
    background_counts=0 * cts,
    background_counts_error=0 * cts,
    effective_exposure=eff_exp,
    count_energy_edges=count_edges,
    photon_energy_edges=photon_edges,
    response_matrix=(area * srm)
)

## Define a model to fit
Here we fit a straight line, since the fake data are a line.
The model function accepts a single `dict` argument with the keys:
- `photon_energy_edges`: the photon energy bin edges
- `parameters`: the parameters from the fitter (a `dict[str, fitting.Parameter]`)

The function uses these to evaluate the model in each photon energy bin.

Inside the model function, you can restrict or "tie" certain parameters to one another.
If you were fitting two lines and wanted to keep the intercepts the same, for instance,
this could be enforced in the model function.

The model is also just a pure Python function with very little wrapping it.
This gives flexibility to fit any kind of model you'd like.
It could even be a method bound to an instance of an object.

In [None]:
def line_model(arg_dict: dict[str, object]):
    '''Evaluate a straight line, intercept + slope * E, at photon-bin centers.

    Parameters
    ----------
    arg_dict : dict
        Must contain ``'photon_energy_edges'`` (array of N+1 bin edges) and
        ``'parameters'`` (a mapping whose ``'slope'`` and ``'intercept'``
        entries expose a ``.value`` attribute).

    Returns
    -------
    Array of N model values, one per photon energy bin.
    '''
    edges = arg_dict['photon_energy_edges']
    fit_params = arg_dict['parameters']
    slope = fit_params['slope'].value
    intercept = fit_params['intercept'].value

    # Bin centers: each left edge shifted by half of its bin width
    centers = edges[:-1] + np.diff(edges) / 2
    return intercept + slope * centers

## Probability functions: log likelihood and log priors
The next step is to define a likelihood you would like to use and enforce some prior knowledge on your parameters.

The probability function sampled by `emcee` is the [(log) posterior](https://en.wikipedia.org/wiki/Posterior_probability).
We work with the log of the probability because products of many small probabilities underflow floating-point numbers, while sums of log probabilities stay numerically stable.

Here we use a $\chi^2$ (Gaussian) log likelihood, which works well when every bin has enough counts for Poisson errors to be approximately Gaussian.
You can also use a Poisson or negative binomial likelihood by using e.g.
[`st.poisson.logpdf`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.poisson.html).

In [None]:
def log_likelihood(data: 'fitting.DataPacket', model: np.ndarray):
    '''Gaussian (chi-squared) log likelihood of the counts given a count model.

    Returns ln L = -chi2 / 2 (up to an additive constant), where
    chi2 = sum((counts - model)**2 / counts_error**2).

    Parameters
    ----------
    data : fitting.DataPacket
        Provides ``counts`` and ``counts_error`` (same shape as ``model``).
    model : array-like
        Model counts per count-energy bin.
    '''
    # Bins with zero error give 0/0 (nan) or x/0 (inf); nan_to_num maps nan
    # to 0 and inf to a large finite number, so silence the expected warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        chi2_terms = (data.counts - model)**2 / data.counts_error**2
    # The factor 1/2 comes from the Gaussian pdf exp(-chi2 / 2); without it
    # the posterior would be artificially narrow by a factor of sqrt(2).
    return -0.5 * np.nan_to_num(chi2_terms).sum()


# Define the model parameters we want to use (initial guesses, both free)
params = {
    'slope': fitting.Parameter(-10 << u.ph / u.keV, frozen=False),
    'intercept': fitting.Parameter(3000 << u.ph, frozen=False),
}

# Define the priors on those parameters (uniform from -10000 to 10000 for each)
log_priors = {
    'slope': fitting.simple_bounds(-10000, 10000),
    'intercept': fitting.simple_bounds(-10000, 10000),
}

## Actual fitting: make sure everything works
Now that the model, likelihood, and priors are defined, fitting is straightforward.
The fitter coordinates parameter variations and manages a basic `emcee.EnsembleSampler`.
It also converts the photon model into a count model using the data's response matrix.



In [None]:
# Bundle the data, the model, its parameters, the priors, and the
# likelihood into a single Bayesian fitter object
fitter = fitting.BayesFitter(
    data=pack,
    parameters=params,
    model_function=line_model,
    log_likelihood=log_likelihood,
    log_priors=log_priors,
)

In [None]:
# Look at the parameters to check if they're in good shape
# (as the last expression in the cell, the notebook displays this value)
fitter.parameters

In [None]:
# Initial comparison of the starting model to the data, before fitting.
# For a manual plot, `fitter.eval_model()` gives the model folded through
# the response matrix; plot it with `ax.stairs` against `pack.counts` over
# `pack.count_energy_edges`.
fig = plt.figure(figsize=(8, 6))
yap.plot_data_model(fitter, fig=fig)

### Notice that the initial guess is poor: the starting slope is negative, while the data rise with energy

In [None]:
# Finally, perform the fit. The two kwarg dicts are passed on to emcee
# (presumably to the `EnsembleSampler` constructor and to its run call).
fitter.perform_fit(
    emcee_constructor_kw={},
    emcee_run_kw=dict(nsteps=10000, progress=True),
)

# Optionally save the fit result to a compressed pickle file:
# import lzma
# fitter.save('test.pkl.xz', open_func=lzma.open)

## Diagnostics: autocorrelation and MCMC chains
It's always a good idea to make sure your fit has "enough" samples; this can be assessed by looking at the autocorrelation time of the parameter chains.

In X-ray spectroscopy the autocorrelation time is often very long because the parameters are strongly correlated. It is not always possible to satisfy `emcee`'s autocorrelation check, but it is worth running.

In [None]:
# emcee raises an AutocorrError when the chain is shorter than `tol`
# (default 50) autocorrelation times, so no error thrown means that the
# autocorrelation time is much shorter than the MCMC chain length.
autocorr_times = fitter.emcee_sampler.get_autocorr_time()
print('autocorrelation times:', autocorr_times)

# The raw `flatchain` still contains the initial burn-in steps; discard a
# few autocorrelation times' worth of steps before using the samples.
burn_in = int(3 * np.max(autocorr_times))
# NOTE(review): assumes the free parameters are stored in the order they were
# defined in `params` (slope, then intercept) -- confirm against yaff.
slope_chain, inter_chain = fitter.emcee_sampler.get_chain(
    discard=burn_in, flat=True
).T

In [None]:
# Plot the chains to see how well things converged: well-mixed chains should
# look like stationary noise without long-term drifts
fig = plt.figure(figsize=(8, 8))
yap.plot_parameter_chains(fitter, fig=fig)

## Now, plot some sample models over the data

In [None]:
# Overlay 100 sampled models (presumably drawn from the MCMC chain) on the data
fig = plt.figure(figsize=(8, 6))
yap.plot_data_model(fitter, num_model_samples=100, fig=fig)