In [None]:
import lzma
import os

import astropy.units as u
import matplotlib.pyplot as plt
%matplotlib qt
plt.style.use('nice.mplstyle')
import numpy as np
import scipy.stats as st

from nustar_tools.spectra.grade_spectra import GradeCollection
from nustar_tools.spectra.response import ResponseHandler

import yaff.fitting as fitting
from yaff import plotting as yap

In [None]:
def thermal(arg_dict: dict[str, object]):
    """Thermal photon-flux model evaluated with sunkit-spex.

    ``arg_dict`` must provide:
      * ``'photon_energy_edges'`` -- photon energy bin edges, interpreted in keV;
      * ``'parameters'`` -- a mapping with ``'temperature'`` and
        ``'emission_measure'`` entries exposing ``.as_quantity()``.

    Returns the flux per photon bin as a plain array in ph / (s keV cm2).
    """
    # Keep the import local: the function is pickled for multiprocessing,
    # so it must carry its own dependency. The alias avoids shadowing this
    # function's own name.
    from sunkit_spex.legacy import thermal as sunkit_thermal

    # The signature's dict annotation is too loose, so type the two entries here.
    raw_edges: np.ndarray = arg_dict['photon_energy_edges']
    params: dict[str, fitting.Parameter] = arg_dict['parameters']

    flux = sunkit_thermal.thermal_emission(
        energy_edges=raw_edges << u.keV,
        temperature=params['temperature'].as_quantity(),
        emission_measure=params['emission_measure'].as_quantity(),
    )
    return flux.to_value(u.ph / u.s / u.keV / u.cm**2)

In [None]:
# Spectra live under `data_dir`. `file_format` carries {fpm}/{grade}
# placeholders that GradeCollection presumably fills in for each requested
# focal-plane module and grade range.
data_dir = 'nustar-data/'
file_format = 'fpm{fpm}_g{grade}.pha'

pha_template = data_dir + file_format
collection = GradeCollection(
    pha_template,
    fpms=['A', 'B'],
    grades=['0-4', '21-24'],
)
collection.prepare_data()

In [None]:
# Response files for the same FPM/grade combination that is fit below
# (FPM A, grades 0-4). os.path.join avoids the doubled separator that
# f'{data_dir}/...' produced when data_dir already ends with '/'.
rmf_file = os.path.join(data_dir, 'fpmA_g0-4.rmf')
arf_file = os.path.join(data_dir, 'fpmA_g0-4.arf')
handler = ResponseHandler(rmf_file, arf_file)

In [None]:
# Fit input built from FPM A, grade 0-4 events only. No background is
# modelled, so background counts and their errors are zero arrays.
grade_0_4_a = collection.data['A']['0-4']
counts, edges = grade_0_4_a.spectrum
srm = handler.srm

dp = fitting.DataPacket(
    counts=counts,
    counts_error=np.sqrt(counts.value) << u.count,  # sqrt(N) counting error
    effective_exposure=grade_0_4_a.exposure,
    background_counts=0 * counts,
    background_counts_error=0 * counts,
    count_energy_edges=edges,
    photon_energy_edges=handler.energy_edges,
    response_matrix=srm,
)

In [None]:
# Only count bins whose midpoint lies inside [FIT_ENERGY_MIN, FIT_ENERGY_MAX]
# contribute to the likelihood. Units follow dp.count_energy_edges
# (presumably keV -- confirm).
FIT_ENERGY_MIN = 3
FIT_ENERGY_MAX = 5

mids = dp.count_energy_edges[:-1] + np.diff(dp.count_energy_edges)/2
energy_bounds = (mids >= FIT_ENERGY_MIN) & (mids <= FIT_ENERGY_MAX)

def log_likelihood(data: fitting.DataPacket, model: np.ndarray):
    """Poisson log-likelihood of the observed counts given the model.

    The observed counts are the Poisson variate ``k`` and the model is the
    real-valued expected rate ``mu``. The result is
    ``sum(k*ln(mu) - mu - ln(k!))`` over the bins selected by the
    module-level ``energy_bounds`` mask.

    Parameters
    ----------
    data : fitting.DataPacket
        Supplies the observed ``counts``. Any astropy unit is dropped by
        ``np.asarray``.
    model : np.ndarray
        Expected counts per count bin. Assumed to share the binning of
        ``data.counts`` -- TODO confirm against yaff.

    Returns
    -------
    scalar
        Summed log-probability. It is ``-inf`` if the model predicts zero
        counts in a selected bin where counts were observed.
    """
    # The data are the random variable and the model is the rate. So the
    # model is NOT truncated to integers. Zero-count bins are kept because
    # they still penalise a model that predicts counts there (the -mu term).
    observed = np.asarray(data.counts)
    expected = np.asarray(model, dtype=float)
    per_bin = st.poisson(expected).logpmf(observed)
    return per_bin[energy_bounds].sum()

# Names of the thermal-model parameters, grouped so later cells can
# thaw/freeze them together.
thermal_names = ['temperature', 'emission_measure']

# Initial guesses; every parameter starts frozen.
starting_parameters = dict(
    temperature=fitting.Parameter(4 << u.MK, frozen=True),
    emission_measure=fitting.Parameter(100 << (1e42 * u.cm**-3), frozen=True),
)

# Priors are plain bounds on the physical values, presumably in each
# parameter's own units (MK and 1e42 cm^-3). A richer prior, e.g. a
# truncated normal, could be substituted here.
log_priors = dict(
    temperature=fitting.simple_bounds(0, 100),
    emission_measure=fitting.simple_bounds(0, 10000),
)

In [None]:
# Everything the Bayesian fitter needs: observed data, model function,
# parameters with starting values, priors, and the likelihood.
fitter_inputs = dict(
    data=dp,
    model_function=thermal,
    parameters=starting_parameters,
    log_priors=log_priors,
    log_likelihood=log_likelihood,
)
fitta = fitting.BayesFitter(**fitter_inputs)

In [None]:
# Compare the data with the model at the starting guesses (all parameters frozen).
yap.plot_data_model(fitta)

In [None]:
print("minimize thermal")
# Thaw every thermal parameter so the minimizer is allowed to vary it.
for param_name in thermal_names:
    fitta.parameters[param_name].frozen = False
fitta = fitting.normal_minimize(fitta)

In [None]:
# Re-plot after minimization with 30 model samples overlaid
# (how the samples are drawn is defined by yaff.plotting -- confirm).
yap.plot_data_model(fitta, num_model_samples=30)

In [None]:
# Inspect the parameter values after minimization.
fitta.parameters

In [None]:
# MCMC settings. The two dicts are presumably forwarded to emcee's sampler
# constructor and run call. NOTE(review): emcee's ensemble moves need at
# least 2 x (free parameters) walkers; 4 is the bare minimum for the two
# thawed thermal parameters -- consider more for better mixing.
N_WALKERS = 4
N_STEPS = 2000
sampler_kwargs = {'nwalkers': N_WALKERS}
run_kwargs = {'nsteps': N_STEPS, 'progress': True}
fitta.perform_fit(sampler_kwargs, run_kwargs)

In [None]:
# Persist the fitted object, xz-compressed (presumably dill-serialized, per
# the extension), so the MCMC run need not be repeated. `lzma` is imported
# in the top import cell rather than mid-notebook.
FIT_SAVE_PATH = 'nustar-fit.dill.xz'
fitta.save(FIT_SAVE_PATH, open_func=lzma.open)

In [None]:
# Load the "best" MCMC sample into the fitter's parameters
# (what "best" means is defined by yaff -- presumably highest posterior; confirm).
fitta.emplace_best_mcmc()

In [None]:
# Inspect the parameter values after emplacing the best MCMC sample.
fitta.parameters

In [None]:
# Data vs. model at the emplaced MCMC parameters, with 20 model samples overlaid.
yap.plot_data_model(fitta, num_model_samples=20)
plt.show()

In [None]:
# NOTE(review): repeats the parameter display two cells above; the plot in
# between presumably does not modify parameters -- consider removing this cell.
fitta.parameters

In [None]:
# Trace plots of the MCMC chains, to check mixing and choose a burn-in.
# `yap` comes from the top import cell; re-importing it here was redundant.
yap.plot_parameter_chains(fitta)

In [None]:
# Corner plot of the posterior. The first BURNIN_STEPS samples (presumably
# per walker -- confirm in yaff) are discarded as burn-in.
BURNIN_STEPS = 100
fig = plt.figure(figsize=(20, 20), layout='tight')
yap.corner_plot(fitta, fig=fig, burnin=BURNIN_STEPS)