# Fit a real flare: 2011 RHESSI M9
### You may download the data for this flare from this [Google Drive link](https://drive.google.com/file/d/1eL5cczLQX-VPCCEQDrW6WXByTcbhYdEx)

In [None]:
import astropy.units as u
import astropy.time as atime
import numpy as np
from sunkit_spex.extern import rhessi
import os

import matplotlib.pyplot as plt
%matplotlib qt
plt.style.use('nice.mplstyle')

os.environ["OMP_NUM_THREADS"] = "1"

from example_support import adapt_rhessi_data, thermal_and_thick
from yaff import fitting
from yaff import plotting as yap

## First, we load the data in using `sunkit_spex.extern.rhessi.RhessiLoader` and adapt it to the target format

In [None]:
# Read the spectrum and its response matrix (SRM), then choose
# which time windows count as "event" and "background"

rl = rhessi.RhessiLoader(
    srm_fn='rhessi-data/trevor-flare-30-jul-2011-logspace-bkg_srm.fits',
    spectrum_fn='rhessi-data/trevor-flare-30-jul-2011-logspace-bkg_spec.fits',
)

# Event (flare) interval
start_event_time = atime.Time('2011-07-30T02:08:20')
end_event_time = atime.Time('2011-07-30T02:10:20')
rl.update_event_times(start_event_time, end_event_time)

# Background interval, taken before the event interval
start_background_time = atime.Time('2011-07-30T01:54:00')
end_background_time = atime.Time('2011-07-30T01:56:00')
rl.update_background_times(start_background_time, end_background_time)

# Fold in a 10% systematic error
rl.systematic_error = 0.1

# Convert from the "sunkit-spex" loader into a DataPacket
dp = adapt_rhessi_data(rl)

## Next, we define the log likelihood function to use and parameters
- The `log_likelihood` enforces the "energy fitting range"---no fancy logic required to enforce this elsewhere
- The parameters are explicitly declared with units and "frozen" state. This is verbose, but the intent is clear.
- Finally, log priors are placed on the parameters. In this case, the prior is just a uniform prior (aka "bounds").

In [None]:
# Turn the background rate into counts so it can be subtracted from
# the event counts: rate x channel bin width x effective exposure
# (meaning of each factor inferred from the loader key names).
background_counts = (
    rl['extras']['background_rate']
    * rl['count_channel_binning']
    * rl['effective_exposure']
)

# Boolean mask over count bins: only bins whose midpoint lies in
# [6, 70] contribute to the likelihood (units presumably keV)
mids = dp.count_energy_edges[:-1] + np.diff(dp.count_energy_edges)/2
energy_bounds = np.logical_and(mids >= 6, mids <= 70)

def log_likelihood(data: fitting.DataPacket, model: np.ndarray):
    '''Gaussian (chi2) log likelihood of the background-subtracted data.

    Parameters
    ----------
    data : fitting.DataPacket
        Only `counts` and `counts_error` are read here.
    model : np.ndarray
        Model counts; must broadcast against `data.counts`.

    Returns
    -------
    The Gaussian log likelihood up to an additive constant, i.e.
    -chi2 / 2, summed over the bins selected by the module-level
    `energy_bounds` mask. The module-level `background_counts` is
    subtracted from the data before comparing with the model.
    '''
    residual = data.counts - background_counts - model
    chi2_per_bin = (residual**2 / data.counts_error**2)[energy_bounds]

    # Some count bins might be negative, so use nan_to_num
    # (NaN terms become 0 and do not contribute to the sum).
    # The factor 1/2 makes this ln(L) = -chi2/2 rather than -chi2;
    # without it the sampled posterior is too narrow by sqrt(2).
    return -0.5 * np.nan_to_num(chi2_per_bin).sum()

# Initial guesses for every model parameter. Each value carries
# its unit explicitly, and every parameter begins frozen; they are
# unfrozen group-by-group when fitting.
starting_parameters = dict(
    temperature=fitting.Parameter(12 << u.MK, frozen=True),
    emission_measure=fitting.Parameter(1 << (1e49 * u.cm**-3), frozen=True),
    electron_flux=fitting.Parameter(20 << (1e35 * u.electron / u.s), frozen=True),
    spectral_index=fitting.Parameter(3 << u.one, frozen=True),
    cutoff_energy=fitting.Parameter(10 << u.keV, frozen=True),
)

# Priors here are plain (lower, upper) bounds on each parameter.
# Something more informative (a truncated normal, or any other
# probability distribution) could be used in their place.
log_priors = dict(
    temperature=fitting.simple_bounds(0, 100),
    emission_measure=fitting.simple_bounds(0, 10000),
    electron_flux=fitting.simple_bounds(0, 10000),
    spectral_index=fitting.simple_bounds(2, 20),
    cutoff_energy=fitting.simple_bounds(1, 90),
)

# Parameter groups, named so they can be frozen/unfrozen
# together in loops further down
thermal_names = ['temperature', 'emission_measure']
nonthermal_names = ['electron_flux', 'spectral_index', 'cutoff_energy']

## Construct the actual fitter object with the data, model, priors, and likelihood.

In [None]:
# Bundle the data, model, parameters, priors and likelihood into one
# fitter object. `thermal_and_thick` is an ordinary function imported
# from an external support file.
fitta = fitting.BayesFitter(
    model_function=thermal_and_thick,
    data=dp,
    log_likelihood=log_likelihood,
    log_priors=log_priors,
    parameters=starting_parameters,
)

## Before proceeding to an MCMC run, we first find a good starting point for the parameters using "normal" minimization
- The "normal" minimization uses `scipy.optimize.minimize` and is a ~20 line function which uses the already-assembled `BayesFitter.log_posterior` method.
- The "normal" minimizer is left as a free function rather than a method of the `BayesFitter` class to keep the code more decoupled.

In [None]:
# Rebuild the fitter from the starting parameters
fitta = fitting.BayesFitter(
    data=dp,
    model_function=thermal_and_thick,
    parameters=starting_parameters,
    log_priors=log_priors,
    log_likelihood=log_likelihood
)

print("minimize thermal")
# Every parameter is declared frozen, so the thermal ones must be
# freed explicitly here; otherwise this first pass has nothing to
# optimize. Setting both groups also makes the cell give the same
# result when it is re-run after later cells changed the frozen flags.
for n in thermal_names:
    fitta.parameters[n].frozen = False
for n in nonthermal_names:
    fitta.parameters[n].frozen = True
fitta = fitting.normal_minimize(fitta)

print("minimize nonthermal")
# Hold the thermal component fixed and fit only the nonthermal one
for n in thermal_names:
    fitta.parameters[n].frozen = True
for n in nonthermal_names:
    fitta.parameters[n].frozen = False
fitta = fitting.normal_minimize(fitta)

print("minimize all")
# Final pass: free everything and refine both components together
for n in (nonthermal_names + thermal_names):
    fitta.parameters[n].frozen = False

fitta = fitting.normal_minimize(fitta)

print('"best-fit" parameters are:')
fitta.parameters

## We can plot the model on top of the data to see how the "normal" minimization did

In [None]:
import importlib
# Reload the plotting module before using it, so that any edits made to
# it since it was first imported in this session are picked up
importlib.reload(yap)
# Draw the data and the current model; the background counts
# computed earlier are handed to the plotting helper
yap.plot_data_model(fitta, background_counts=background_counts)
plt.show()

## Now that the parameters have been (quickly) optimized by minimizing chi2, we can explore around that solution with MCMC to get meaningful uncertainties

In [None]:
# Sample over every parameter, thermal and nonthermal alike
free_names = thermal_names + nonthermal_names
for n in free_names:
    fitta.parameters[n].frozen = False

# Half the CPU count is a convenient walker count, but it must not fall
# below two walkers per free parameter (the minimum emcee's default move
# accepts), and os.cpu_count() may return None.
# NOTE(review): assumes the first dict is forwarded to the emcee sampler
# constructor -- confirm against BayesFitter.perform_fit.
nwalkers = max((os.cpu_count() or 1) // 2, 2 * len(free_names))

fitta.perform_fit({'nwalkers': nwalkers}, {'nsteps': 1000, 'progress': True})
fitta.emplace_best_mcmc()
fitta.parameters

## Look at the parameter chains to determine the "burn-in," i.e. where the solution has converged

In [None]:
# Flattened chain, transposed so that each row holds all samples of one
# parameter. `get_chain(flat=True)` is the supported replacement for the
# deprecated `flatchain` property and returns the same array.
chain = fitta.emcee_sampler.get_chain(flat=True).T
param_names = list(fitta.parameters.keys())
param_units = [v.unit for v in fitta.parameters.values()]

# One panel per parameter
fig, axs = plt.subplots(nrows=chain.shape[0], layout='constrained', figsize=(10, 8))
# plt.subplots returns a bare Axes (not an array) when nrows == 1
axs = np.atleast_1d(axs)
for (param_chain, name, ax, unit) in zip(chain, param_names, axs, param_units):
    ax.plot(param_chain, label=name)
    ax.set(title=name, ylabel=unit)

plt.show()

## Make some corner plots of the parameters and annotate with 90% posterior intervals

In [None]:
import corner

# Discard the first `burn` samples. Note that `burn` counts rows of the
# *flattened* chain (the x axis of the chain plot), not steps per walker.
burn = 300
corner_chain = chain[:, burn:].T

# The 5% / 50% / 95% quantiles mark the median and a 90% interval
fig = plt.figure(figsize=(20, 20))
corner.corner(
    corner_chain,
    labels=param_names,
    quantiles=(0.05, 0.5, 0.95),
    show_titles=True,
    bins=20,
    fig=fig,
)

## Finally, plot a few (sample) models over the data. Fit seems to have worked.

In [None]:
# Draw a random subset of post-burn-in samples (with replacement)
num_samples = 400
rng = np.random.default_rng()
some_params = rng.choice(corner_chain, size=num_samples)

fig, ax = plt.subplots(layout='constrained')

# Evaluating the model for a sample means overwriting the fitter's
# parameter values. Remember the current values and put them back
# afterwards, so the fitter is not left sitting at a random draw.
# NOTE(review): assumes Parameter.value can be read as well as
# assigned -- confirm against yaff.fitting.Parameter.
saved_values = {k: fitta.parameters[k].value for k in param_names}
try:
    for pset in some_params:
        for k, v in zip(param_names, pset):
            fitta.parameters[k].value = v
        count_model = fitta.eval_model()
        # Low alpha so overlapping models build up a density band
        ax.stairs(count_model, dp.count_energy_edges, color='black', alpha=0.05)
finally:
    for k, v in saved_values.items():
        fitta.parameters[k].value = v

# Background-subtracted data; the background counts are added to the
# squared data error before taking the square root
yap.stairs_with_error(
    bins=dp.count_energy_edges << u.keV,
    rate=(dp.counts - background_counts) << u.ct,
    error=np.sqrt(dp.counts_error**2 + background_counts) << u.ct,
    ax=ax,
    line_kw={'color': 'blue', 'lw': 3},
    label='data with error'
)

ax.stairs(background_counts, dp.count_energy_edges, color='gray', lw=3, label='pre-flare background')

ax.legend()
ax.set(title='Data (blue) vs model samples (black)', xlabel='keV', ylabel='counts', xscale='log', yscale='log', xlim=(3.5, 120), ylim=(100, 5e5))