# Most basic time decomposition simulation
We assume a flare with mostly fixed physical parameters and two basis light curves.
Only the emission measure and electron flux are allowed to vary.
These parameters act as scaling factors for the emission.

The instrument response matrix is ignored for now.
The background is Poisson noise.

The physical parameters are defined below:

In [None]:
from astropy import units as u
from yaff.fitting import Parameter

In [None]:
# "True" physical parameters used to simulate the flare.
# NOTE(review): the second positional argument (True) is presumably the
# `frozen` flag, since later cells pass `frozen=False` by keyword -- confirm
# against yaff.fitting.Parameter.
thermal_physical_params = {
    "temperature": Parameter(20 << u.MK, True),
    # Unit is 1e49 cm^-3, so the numeric value of this Quantity is 1
    "emission_measure": Parameter(1 << (1e49 * u.cm**-3), True),
}

nonthermal_physical_params = {
    # Unit is 1e35 electrons / s, so the numeric value of this Quantity is 2
    "electron_flux": Parameter(2 << (1e35 * u.electron / u.s), True),
    "spectral_index": Parameter(6 << u.one, True),
    "cutoff_energy": Parameter(25 << u.keV, True),
}

## Start the actual simulation process

In [None]:
import os

# Cap the number of OpenMP threads so that the matrix multiplications done
# while fitting do not take over every core on the machine.
# OpenMP reads OMP_NUM_THREADS; "OMP_NUM_CORES" is not a recognised setting
# and was silently ignored.
# NOTE(review): the limit is only honoured if it is set before the numerical
# libraries initialise their thread pools; astropy is imported in an earlier
# cell and may already have loaded numpy -- confirm the cap takes effect.
os.environ["OMP_NUM_THREADS"] = "6"

import copy

from astropy import visualization as viz
from matplotlib import pyplot as plt
import numpy as np

import scipy.stats as st
from yaff import common_models as cm
from yaff import plotting

from tedec import fractional_brownian_motion as fbm
from tedec import decomp

%matplotlib qt
plt.style.use("style.mplstyle")

## Define physical parameters for the simulation

## Generate some basis light curves to use later

In [None]:
# Seed numpy's global random state so the notebook is repeatable.
# NOTE(review): this only makes the basis light curves reproducible if
# fbm.make_timeseries draws from numpy's global random state -- confirm.
seed = 132457
np.random.seed(seed)
time_bin = 0.1 << u.s
integration = 10 << u.s
# Round before converting to int: the ratio is a float, and a value such as
# 2.9999999999999996 would otherwise be truncated to one step too few.
steps = int(np.round((integration / time_bin).to_value(u.one)))
# Two fractional-Brownian-motion series with different Hurst exponents serve
# as the thermal and nonthermal basis light curves.
thermal_basis = fbm.make_timeseries(num=steps, hurst=0.955)
nonthermal_basis = fbm.make_timeseries(num=steps, hurst=0.5)

In [None]:
def rebin_clumps(histogram, clump_size):
    """Sum consecutive groups of ``clump_size`` bins into coarser bins.

    Parameters
    ----------
    histogram
        One-dimensional sequence of bin contents.
    clump_size
        Number of adjacent input bins merged into each output bin.

    Returns
    -------
    numpy.ndarray
        Float array with ``len(histogram) // clump_size`` entries.  Trailing
        bins that do not fill a complete clump are dropped (previously they
        caused an ``IndexError``).

    Raises
    ------
    ValueError
        If ``clump_size`` is smaller than 1.
    """
    if clump_size < 1:
        raise ValueError(f"clump_size must be a positive integer, got {clump_size}")

    histogram = np.asarray(histogram)
    num_clumps = histogram.size // clump_size
    # Keep only the part that divides evenly, then let numpy do the sums
    complete = histogram[: num_clumps * clump_size]
    return complete.reshape(num_clumps, clump_size).sum(axis=1, dtype=float)


# Bin down the time granularity
real_dt = 0.2 << u.s
# Round before converting to int so a ratio like 2.9999999999999996 does not
# get truncated to a smaller rebinning factor.
bin_down_factor = int(np.round((real_dt / time_bin).to_value(u.one)))
thermal_basis = rebin_clumps(thermal_basis, bin_down_factor)
nonthermal_basis = rebin_clumps(nonthermal_basis, bin_down_factor)

In [None]:
# Plot both basis light curves on one figure, each with its own y axis
fig, ax = plt.subplots()
axx = ax.twinx()
# Bin edges in units of bin index (one more edge than there are bins);
# `t` is reused by the later plotting cells.
t = np.arange(thermal_basis.size + 1)
ax.stairs(thermal_basis, t, color="red", label="thermal")
ax.legend(loc="upper left")
ax.set(ylabel="thermal magnitude")

axx.stairs(nonthermal_basis, t, color="black", label="nonthermal")
axx.legend(loc="lower right")
axx.set(ylabel="nonthermal magnitude")

# Colour the left (thermal) axis red to match its curve. Both axes objects
# have a "left" spine, so both are set.
for a in (ax, axx):
    a.spines["left"].set_color("red")
ax.yaxis.label.set_color("red")
ax.tick_params(axis="y", colors="red", which="both")

plt.show()

In [None]:
def normalize(s):
    """Return ``s`` divided by its own sum, with NaNs replaced by zero."""
    total = s.sum()
    fractions = s / total
    return np.nan_to_num(fractions)


# Normalise each basis light curve by its own sum
norm_th = normalize(thermal_basis)
norm_nth = normalize(nonthermal_basis)

# Overlay the two normalised curves on a shared axis
fig, ax = plt.subplots()
ax.stairs(norm_th, t, label="thermal normalized", color="red")
ax.stairs(norm_nth, t, label="nonthermal normalized", color="black")
ax.legend()
plt.show()

## Take the basis lightcurves and use them to scale $\text{EM}$ and $\varphi_e$

In [None]:
# Scale the emission measure between 1x and 3x its base value:
# min-max rescale to [0, 1], stretch to [0, 2], then offset by 1.
th_scale = norm_th - norm_th.min()
th_scale /= th_scale.max()
th_scale *= 2
th_scale += 1

# Scale the electron flux between 2x and 5x its base value
# (a max-to-min ratio of 2.5).
# NOTE(review): this comment used to read "max factor of 4", which would need
# an offset of 1 instead of 2 -- confirm which range is intended.
nth_scale = norm_nth - norm_nth.min()
nth_scale /= nth_scale.max()
nth_scale *= 3
nth_scale += 2

## Generate a spectrogram using a (thermal + thick target) model

In [None]:
def model(params: cm.ArgsT):
    """Return the sum of the thermal and thick-target model evaluations."""
    thermal_part = cm.thermal(params)
    nonthermal_part = cm.thick_target(params)
    return thermal_part + nonthermal_part

In [None]:
# 40 logarithmically spaced energy bin edges from 3 to 100 keV
energies = np.geomspace(3, 100, num=40) << u.keV

# Argument bundles handed to the model functions: the energy edges as bare
# keV values, plus the parameter set each model needs.
all_args = dict(
    photon_energy_edges=energies.to_value(u.keV),
    parameters=(thermal_physical_params | nonthermal_physical_params),
)

thermal_args = dict(
    photon_energy_edges=energies.to_value(u.keV),
    parameters=thermal_physical_params,
)

nonthermal_args = dict(
    photon_energy_edges=energies.to_value(u.keV),
    parameters=nonthermal_physical_params,
)

In [None]:
# Energy bin widths and the detector area used to turn fluxes into photons
de = energies[1:] - energies[:-1]
area = 10 << u.cm**2


def flux_to_photons(func, args):
    """Evaluate ``func(args)`` as a photon flux and integrate it to photons.

    The flux is multiplied by the module-level bin widths ``de``, the time
    bin ``real_dt`` and the collecting ``area``, then converted to photons.
    """
    flux_unit = u.ph / u.cm**2 / u.s / u.keV
    flux = func(args) << flux_unit
    photons = flux * de * (real_dt << u.s) * area
    return photons.to(u.ph)


# Build one spectrum per time bin: the thermal part, the nonthermal part and
# their sum, each evaluated with that bin's scaled parameters.
thermal_truth = []
nonthermal_truth = []
spectrogram = []
for idx, em_factor in enumerate(th_scale):
    # Work on a deep copy so the base parameters are never modified
    args = copy.deepcopy(all_args)
    scaled = args["parameters"]
    scaled["emission_measure"].value *= em_factor
    scaled["electron_flux"].value *= nth_scale[idx]

    thermal_truth.append(flux_to_photons(cm.thermal, args))
    nonthermal_truth.append(flux_to_photons(cm.thick_target, args))
    spectrogram.append(flux_to_photons(model, args))

# Stack the per-time spectra and transpose so that axis 0 is energy
spectrogram = (spectrogram << u.ph).T
thermal_truth = (thermal_truth << u.ph).T
nonthermal_truth = (nonthermal_truth << u.ph).T

In [None]:
# One light curve per energy bin (axis 0 of the spectrogram)
tests = np.arange(spectrogram.shape[0])

with viz.quantity_support():
    fig, ax = plt.subplots()
    for test in tests:
        # `t` holds bin-index edges; multiplying by real_dt gives seconds
        ax.stairs(spectrogram[test], real_dt * t)
    ax.set(xlabel="time (s)", ylabel="photons incident")
    plt.show()

In [None]:
from matplotlib import colors as mcol

# Show the noise-free spectrogram as a colour map with a logarithmic colour scale
fig, ax = plt.subplots()
norm = mcol.LogNorm(vmin=spectrogram.min().value, vmax=spectrogram.max().value)
# x edges are in seconds, y edges are the energy bin edges
ax.pcolormesh((t * real_dt).value, energies.value, spectrogram.value, norm=norm)
ax.set(xlabel="time", ylabel="energy", yscale="log")
plt.show()

In [None]:
# Arithmetic midpoint of each energy bin
energy_mids = (energies[:-1] + np.diff(energies) / 2) << u.keV


def closest(a, v):
    """Return the index of the entry of ``a`` that lies nearest to ``v``."""
    distances = np.abs(a - v)
    return np.argmin(distances)


# Let's say we have something like a Ba133 source on board.
# There will be lines at 4 keV, 31 keV, and 81 keV
count_rate = 20 << (u.Hz / u.keV)
# Expected background counts in each energy bin during one time bin.
# The units are carried through the whole product and only stripped at the
# end, so an inconsistent unit raises instead of being silently dropped.
# The bin widths are recomputed from `energies` because a later cell
# overwrites the global `de` with a unitless array.
noise = np.array(
    (count_rate * real_dt * (energies[1:] - energies[:-1])).to_value(u.one),
    dtype=float,
)
# Boost the bins around each calibration line
noise[closest(energy_mids, 4 << u.keV) : closest(energy_mids, 5 << u.keV)] *= 100
noise[closest(energy_mids, 29 << u.keV) : closest(energy_mids, 32 << u.keV)] *= 10
noise[closest(energy_mids, 80 << u.keV) : closest(energy_mids, 84 << u.keV)] *= 5

# One independent Poisson draw of the whole spectrum per time bin
background = [st.poisson.rvs(noise) for _ in range(thermal_basis.shape[0])]
# Transpose so that axis 0 is energy, like the spectrogram
background = (np.array(background) << u.ct).T

noise.shape, background.shape

In [None]:
from matplotlib import colors as mcol

# Show the simulated background as a colour map.
fig, ax = plt.subplots()
# A linear colour scale is used on purpose; the LogNorm that used to be built
# here was immediately overwritten with None and never reached the plot.
norm = None
ax.pcolormesh(t, energies.value, background.value, norm=norm)
ax.set(xlabel="time", ylabel="energy")
plt.show()

In [None]:
# Add counting statistics & systematics onto the photon data
# NOTE: `systematic` is a float here; a later cell rebinds the same name to a
# function, so re-running cells out of order can break either use.
systematic = 0.02
dimensionless_spectrogram = spectrogram.to_value(u.ph)
# Gaussian draw around the noise-free photons, with variance
# N + (systematic * N)^2 in every (energy, time) bin
data = (
    st.norm.rvs(
        loc=dimensionless_spectrogram,
        scale=np.sqrt(
            dimensionless_spectrogram + (systematic * dimensionless_spectrogram) ** 2
        ),
    )
    << u.ct
)

# Insert the background from the radioactive source
noisy_data = data + background
# The Gaussian draw can go below zero; stop rather than continue with
# unphysical counts
if noisy_data.min() < 0:
    raise ValueError("Can't have negative counts")

In [None]:
from matplotlib import colors as mcol

# Colour map of the noisy data (signal plus background), logarithmic colour scale
fig, ax = plt.subplots()
# Lower colour limit fixed at 100; the upper limit is left to matplotlib
norm = mcol.LogNorm(vmin=100, vmax=None)
# NOTE: the x edges `t` are bin indices here, not seconds
pcm = ax.pcolormesh(t, energies, noisy_data.value, norm=norm)
fig.colorbar(pcm, label="photons")
ax.set(xlabel="time", ylabel="energy")
plt.show()

In [None]:
# Time-integrated spectra: sum every spectrogram over the time axis and
# compare the noise-free total, its two components, the noisy data and the
# background, all divided by the bin width.
reconstructed = spectrogram.sum(axis=1)
fig, ax = plt.subplots()

de = energies[1:] - energies[:-1]

# Placeholder photon -> count conversion (instrument response is ignored)
# fake_response = area * (u.ct / u.ph)
fake_response = 1 * (u.ct / u.ph)

with viz.quantity_support():
    ax.stairs(
        fake_response * reconstructed / de.value,
        energies,
        label="reconstructed data",
        linestyle="dashed",
        color="red",
    )
    ax.stairs(
        fake_response * thermal_truth.sum(axis=1) / de.value, energies, label="thermal"
    )
    # This curve was previously mislabelled "thermal" (copy-paste slip)
    ax.stairs(
        fake_response * nonthermal_truth.sum(axis=1) / de.value,
        energies,
        label="nonthermal",
    )
    ax.stairs(
        noisy_data.sum(axis=1) / de.value,
        energies,
        label="noisy, reconstructed data",
        linestyle="dashed",
        color="orange",
    )
    ax.stairs(background.sum(axis=1) / de.value, energies, label="background")

ax.legend()
ax.set(xscale="log", yscale="log", ylabel="photons / keV", xlabel="energy keV")
plt.show()

In [None]:
import importlib

# Re-import `decomp` so edits made to the module on disk are picked up
# without restarting the kernel
importlib.reload(decomp)


def nearest(a, v):
    """Return the index of the element of ``a`` with the smallest |a - v|."""
    offsets = np.abs(a - v)
    return np.argmin(offsets)


# Pick one energy bin whose light curve stands in for each component:
# the bin nearest 3.5 keV for thermal and nearest 23 keV for nonthermal
thermal_index = nearest(energy_mids, 3.5 << u.keV)
nonthermal_index = nearest(energy_mids, 23 << u.keV)

# Bare count values, axis 0 = energy and axis 1 = time
dat = noisy_data.value

pack = decomp.DataPacket(
    data=dat,
    basis_timeseries=[
        dat[thermal_index],
        dat[nonthermal_index],
        # dat[-1]
    ],
    # `was_offset` is remembered for the cell that unpacks the result
    constant_offset=(was_offset := True),
)

# Same error model as was used to generate the data
systematic = 0.02
# NOTE(review): `ret` is indexed later as (iteration, component, energy) --
# confirm against decomp.bootstrap.
ret = decomp.bootstrap(
    pack, errors=np.sqrt(dat + (systematic * dat) ** 2), num_iter=1000
)

In [None]:
# Mean and spread over the bootstrap iterations (axis 0) for each component;
# the results have one entry per energy bin.
th_mean = ret[:, 0, :].mean(axis=0) << u.ph
th_std = ret[:, 0, :].std(axis=0) << u.ph
nth_mean = ret[:, 1, :].mean(axis=0) << u.ph
nth_std = ret[:, 1, :].std(axis=0) << u.ph

# Scale the constant-offset component by the number of time bins.
# `dat` is laid out as (energy, time), so the time-bin count is dat.shape[1].
# The previous factor, th_mean.size, is the number of energy bins.
# NOTE(review): confirm against decomp that the offset is per time bin.
num_time_bins = dat.shape[1]
bkg_part = ret[:, 2, :] * (num_time_bins if was_offset else 1)
bkg_mean = bkg_part.mean(axis=0) << u.ph
bkg_std = bkg_part.std(axis=0) << u.ph

# Figure that the truth and decomposed spectra are drawn on further down
fig, ax = plt.subplots()


def shorthand_stairs(e, val, err, num_sigma, ax, **kw):
    """Plot ``val`` per unit energy as a stairs curve with an error band.

    Parameters
    ----------
    e
        Energy bin edges; one more entry than ``val``.
    val, err
        Bin contents and their uncertainties.  Either both are Quantities
        with the same unit, or both are unitless, in which case photons are
        assumed.
    num_sigma
        Multiplier applied to ``err`` before plotting.
    ax
        Matplotlib axes to draw on.
    **kw
        Forwarded to ``plotting.stairs_with_error``.

    Raises
    ------
    ValueError
        If ``val`` is a Quantity and ``err`` is not, or their units differ.
    """
    if isinstance(val, u.Quantity):
        if not isinstance(err, u.Quantity) or err.unit != val.unit:
            raise ValueError(
                "Error and value to plot need to both have the SAME unit, or not have units."
            )
        unit = val.unit
        val = val.value
        err = err.to_value(unit)
    else:
        unit = u.ph

    # Bin widths come from the edges that were passed in, not from the global
    # `energies`, so the helper is correct for any set of edges.
    bin_widths = e[1:] - e[:-1]
    to_plot = (val << unit) / bin_widths
    to_plot_err = num_sigma * (err << unit) / bin_widths

    plotting.stairs_with_error(e, to_plot, to_plot_err, ax=ax, **kw)


# NOTE: this rebinds the global `de` to a unitless array of bin widths
de = np.diff(energies) << u.keV
de = de.value
with viz.quantity_support():
    # Every error passed to shorthand_stairs is multiplied by this factor,
    # so with 0 the plotted error bands have zero width
    num_sigma = 0

    # Time-summed truth curves, with sqrt(N) as the nominal error
    shorthand_stairs(
        energies,
        th := thermal_truth.sum(axis=1),
        np.sqrt(th.value) << th.unit,
        num_sigma,
        ax,
        label="thermal truth",
    )
    shorthand_stairs(
        energies,
        nth := nonthermal_truth.sum(axis=1),
        np.sqrt(nth.value) << nth.unit,
        num_sigma,
        ax,
        label="nonthermal truth",
    )
    # Background is passed without units, so shorthand_stairs treats it as photons
    shorthand_stairs(
        energies,
        bkg := background.sum(axis=1).value,
        np.sqrt(bkg),
        num_sigma,
        ax,
        label="background truth",
    )

    # Bootstrap results from the decomposition
    shorthand_stairs(
        energies, nth_mean, nth_std, num_sigma, ax, label="decomposed nonthermal"
    )
    shorthand_stairs(
        energies, th_mean, th_std, num_sigma, ax, label="decomposed thermal"
    )
    shorthand_stairs(energies, bkg_mean, bkg_std, num_sigma, ax, label="decomposed bkg")

ax.legend()
ax.set(ylabel="ph / keV", xscale="log", yscale="log")  # , ylim=(1e2, None))
plt.show()

## Fit the individually decomposed components

In [None]:
# Directory that receives every saved figure and fitter from the cells below
import pathlib

output = pathlib.Path("typical-flare")
if not output.is_dir():
    output.mkdir()
else:
    # Never reuse an existing directory: earlier results would be clobbered
    raise RuntimeError(
        f"Do not overwrite your files by accident... '{output}' exists already"
    )

### Fit nonthermal decomposed data

In [None]:
from yaff import fitting
from yaff import common_likelihoods


def systematic(s, c, a):
    """Add the term ``a * c`` in quadrature to the error ``s``.

    The numeric value of ``c * a`` is re-tagged with the unit of ``s`` before
    the two are combined.
    """
    extra = (c * a).value << s.unit
    return np.sqrt(s**2 + extra**2)


# Fractional systematic error added to the decomposed spectrum.
# NOTE: the name `sys` is the same as the standard-library module; it is
# reused by the later fitting cells.
sys = 0.05
# Treat the decomposed nonthermal spectrum as the measured counts, with no
# background and an identity response scaled by the collecting area
nth_data = fitting.DataPacket(
    counts=(nth_as_cts := nth_mean.to_value(u.ph) << u.ct),
    counts_error=systematic(nth_std, nth_as_cts, sys).to_value(u.ph) << u.ct,
    background_counts=(0 * nth_as_cts),
    background_counts_error=(0 * nth_as_cts),
    effective_exposure=(integration << u.s),
    count_energy_edges=energies,
    photon_energy_edges=energies,
    response_matrix=area * (np.eye(nth_as_cts.size) << (u.ct / u.ph)),
)

# Bounds on each free parameter, in the same units as the parameter values
nonthermal_priors = {
    "electron_flux": fitting.simple_bounds(0, 20),
    "spectral_index": fitting.simple_bounds(2, 20),
    "cutoff_energy": fitting.simple_bounds(1, 80),
}

# Start every parameter within +-10% of its simulated value and free it.
# The generator is unseeded, so the starting point differs between runs.
rng = np.random.default_rng()
params = {
    k: fitting.Parameter(v.as_quantity() * rng.uniform(0.9, 1.1), frozen=False)
    for (k, v) in nonthermal_physical_params.items()
}

# Only bins whose midpoint lies between 3 and 80 keV enter the likelihood
fit_range = (energy_mids > 3 << u.keV) & (energy_mids < 80 << u.keV)
likelihood = common_likelihoods.chi_squared_factory(fit_range)

fr = fitting.BayesFitter(
    data=nth_data,
    model_function=cm.thick_target,
    parameters=params,
    log_priors=nonthermal_priors,
    log_likelihood=likelihood,
)

In [None]:
# Show the starting parameter values before any fitting
fr.parameters

In [None]:
# Run a minimiser first, then show the parameters it arrived at.
# NOTE(review): the thermal and two-model cells pass a `restriction` to this
# call and do not rebind `fr`; here neither is done although the likelihood
# is restricted to `fit_range` -- confirm this difference is intended.
fr = fitting.levenberg_minimize(fr)
fr.parameters

In [None]:
# Sample with emcee: 20 walkers, 8000 steps
fr.run_emcee(emcee_constructor_kw={"nwalkers": 20}, emcee_run_kw={"nsteps": 8000})

In [None]:
from yaff import plotting as yap

# Plot the chains of the free parameters to judge convergence by eye
yap.plot_parameter_chains(
    fr, names=fr.free_param_names, params=list(fr.free_parameters)
)
plt.show()

In [None]:
# Generate 100 model samples from the fitter and plot them with the data
samples = fr.generate_model_samples(100)
# samples=None
fig = plt.figure()
yap.plot_data_model(fr, model_samples=samples, fig=fig)
plt.show()

In [None]:
import corner

# Drop the first 50 * nwalkers rows of the flattened chain as burn-in.
# NOTE(review): this equals 50 steps per walker only if the flat chain is
# ordered step by step -- confirm for the installed emcee version.
burnin = 50 * fr.emcee_sampler.nwalkers
corner_chain = fr.emcee_sampler.flatchain[burnin:]
param_names = fr.free_param_names

fig = plt.figure(figsize=(10, 8), layout="tight")
corner.corner(
    corner_chain,
    fig=fig,
    bins=20,
    labels=param_names,
    quantiles=(0.05, 0.5, 0.95),
    show_titles=True,
    # Simulated values; the electron flux is its base value times the mean
    # of the time-dependent scale factor
    truths=(
        nonthermal_physical_params["electron_flux"].value * nth_scale.mean(),
        nonthermal_physical_params["spectral_index"].value,
        nonthermal_physical_params["cutoff_energy"].value,
    ),
    truth_color="red",
)
fig.savefig(output / "decomp nonthermal.png", dpi=300)
plt.show()

In [None]:
import gzip

# Save the fitter into the output directory, opening the file with gzip
fr.save(output / "nonthermal-only-fitter.dill.gz", open_func=gzip.open)

### Fit thermal decomposed data

In [None]:
from yaff import fitting
from yaff import common_likelihoods

# Treat the decomposed thermal spectrum as the measured counts, with no
# background and an identity response scaled by the collecting area
th_data = fitting.DataPacket(
    counts=(th_as_cts := th_mean.to_value(u.ph) << u.ct),
    counts_error=systematic(th_std, th_as_cts, sys).to_value(u.ph) << u.ct,
    background_counts=(0 * th_as_cts),
    background_counts_error=(0 * th_as_cts),
    effective_exposure=(integration << u.s),
    count_energy_edges=energies,
    photon_energy_edges=energies,
    # Sized from the thermal counts; this used to read `nth_as_cts`, a
    # variable that only exists if the nonthermal cell has been run
    response_matrix=area * (np.eye(th_as_cts.size) << (u.ct / u.ph)),
)

# Bounds on each free parameter, in the same units as the parameter values
thermal_priors = {
    "temperature": fitting.simple_bounds(10, 40),
    "emission_measure": fitting.simple_bounds(1e-4, 1e4),
}

# Start every parameter within +-10% of its simulated value and free it
rng = np.random.default_rng()
params = {
    k: fitting.Parameter(v.as_quantity() * rng.uniform(0.9, 1.1), frozen=False)
    for (k, v) in thermal_physical_params.items()
}

# Only bins whose midpoint is below 16 keV enter the likelihood; the mask is
# kept in `restriction` for the minimiser call in the next cell
likelihood = common_likelihoods.chi_squared_factory(
    restriction=(restriction := (energy_mids < 16 << u.keV))
)

fr = fitting.BayesFitter(
    data=th_data,
    model_function=cm.thermal,
    parameters=params,
    log_priors=thermal_priors,
    log_likelihood=likelihood,
)

In [None]:
# Run the minimiser on the same energy mask as the likelihood.
# NOTE(review): the return value is discarded here but assigned back to `fr`
# in the nonthermal section -- confirm whether the fitter is updated in place.
fitting.levenberg_minimize(fr, restriction=restriction)

In [None]:
# Show the parameters after the minimiser has run
fr.parameters

In [None]:
# Sample with emcee: 20 walkers, 1000 steps
fr.run_emcee(emcee_constructor_kw={"nwalkers": 20}, emcee_run_kw={"nsteps": 1000})

In [None]:
from yaff import plotting as yap

# Plot the chains of the free parameters to judge convergence by eye
yap.plot_parameter_chains(
    fr, names=fr.free_param_names, params=list(fr.free_parameters)
)
plt.show()

In [None]:
import corner

# Drop the first 100 * nwalkers rows of the flattened chain as burn-in
# (the nonthermal and two-model sections use 50 * nwalkers)
burnin = 100 * fr.emcee_sampler.nwalkers
corner_chain = fr.emcee_sampler.flatchain[burnin:]
param_names = fr.free_param_names

fig = plt.figure(figsize=(10, 8), layout="tight")
corner.corner(
    corner_chain,
    fig=fig,
    bins=20,
    labels=param_names,
    quantiles=(0.05, 0.5, 0.95),
    show_titles=True,
    # Simulated values; the emission measure is its base value times the
    # mean of the time-dependent scale factor
    truths=(
        thermal_physical_params["temperature"].value,
        thermal_physical_params["emission_measure"].value * th_scale.mean(),
    ),
    truth_color="red",
)
# Saves pyplot's current figure
plt.savefig(output / "decomp thermal.png", dpi=300)
plt.show()

In [None]:
# Generate 100 model samples from the fitter and plot them with the data
samples = fr.generate_model_samples(num=100)
fig = plt.figure()
yap.plot_data_model(fr, model_samples=samples, fig=fig)
plt.show()

In [None]:
import gzip

# Save the fitter into the output directory, opening the file with gzip
fr.save(output / "thermal-only-fitter.dill.gz", open_func=gzip.open)

## Do a traditional two-model fit

In [None]:
# Conventional approach: fit the time-integrated noisy spectrum, supplying
# the time-integrated background separately.
# Errors are sqrt(N + (sys * N)^2), with `sys` taken from the nonthermal
# fitting cell (0.05, not the 0.02 used to simulate the data).
dp = fitting.DataPacket(
    counts=(cts := noisy_data.sum(axis=1).value) << u.ct,
    counts_error=np.sqrt(cts + (sys * cts) ** 2) << u.ct,
    background_counts=(bg := background.sum(axis=1).value) << u.ct,
    background_counts_error=np.sqrt(bg + (sys * bg) ** 2) << u.ct,
    effective_exposure=(integration << u.s),
    count_energy_edges=energies,
    photon_energy_edges=energies,
    response_matrix=area * (np.eye(cts.size) << (u.ct / u.ph)),
)

In [None]:
# Combine the priors and parameters of both components.
# `rng` is the generator created in an earlier fitting cell.
priors = thermal_priors | nonthermal_priors
params = {
    k: fitting.Parameter(v.as_quantity() * rng.uniform(0.9, 1.1), frozen=False)
    for (k, v) in (thermal_physical_params | nonthermal_physical_params).items()
}


def model(args):
    """Return the thermal model plus the thick-target model for ``args``."""
    thermal_part = cm.thermal(args)
    thick_target_part = cm.thick_target(args)
    return thermal_part + thick_target_part


# Only bins whose midpoint is below 70 keV enter the likelihood; the mask is
# kept in `restriction` for the minimiser call further down
likelihood = common_likelihoods.chi_squared_factory(
    restriction := energy_mids < (70 << u.keV)
)

fr = fitting.BayesFitter(
    data=dp,
    model_function=model,
    parameters=params,
    log_priors=priors,
    log_likelihood=likelihood,
)

In [None]:
# Show the starting parameter values before any fitting
fr.parameters

In [None]:
# Run the minimiser on the same energy mask as the likelihood
# (the mask is passed positionally here, by keyword in the thermal section)
fitting.levenberg_minimize(fr, restriction)

In [None]:
# Show the parameters after the minimiser has run
fr.parameters

In [None]:
# Sample with emcee: 20 walkers, 1000 steps
fr.run_emcee(emcee_constructor_kw={"nwalkers": 20}, emcee_run_kw={"nsteps": 1000})

In [None]:
from yaff import plotting as yap

# Plot the chains of the free parameters to judge convergence by eye
yap.plot_parameter_chains(
    fr, names=fr.free_param_names, params=list(fr.free_parameters)
)
plt.show()

In [None]:
# Show the parameters after sampling
fr.parameters

In [None]:
import corner

# Drop the first 50 * nwalkers rows of the flattened chain as burn-in
burnin = 50 * fr.emcee_sampler.nwalkers
corner_chain = fr.emcee_sampler.flatchain[burnin:]
param_names = fr.free_param_names

fig = plt.figure(figsize=(20, 20), layout="tight")
corner.corner(
    corner_chain,
    fig=fig,
    bins=20,
    labels=param_names,
    quantiles=(0.05, 0.5, 0.95),
    show_titles=True,
    # Simulated values; the two scaled parameters use their base value times
    # the mean of the time-dependent scale factor.
    # NOTE(review): the order must match fr.free_param_names -- confirm.
    truths=(
        thermal_physical_params["temperature"].value,
        thermal_physical_params["emission_measure"].value * th_scale.mean(),
        nonthermal_physical_params["electron_flux"].value * nth_scale.mean(),
        nonthermal_physical_params["spectral_index"].value,
        nonthermal_physical_params["cutoff_energy"].value,
    ),
    truth_color="red",
)

# Saves pyplot's current figure
plt.savefig(output / "traditional.png", dpi=300)
plt.show()

In [None]:
# Generate 100 model samples from the fitter and plot them with the data
samples = fr.generate_model_samples(num=100)
fig = plt.figure()
yap.plot_data_model(fr, model_samples=samples, fig=fig)
plt.show()

In [None]:
import gzip

# Save the fitter into the output directory, opening the file with gzip
fr.save(output / "traditional-fitter.dill.gz", open_func=gzip.open)