In [None]:
import numpy as np
import astropy.units as u
import matplotlib.pyplot as plt
import scipy.stats as st
import simulation_support as simsup
%matplotlib qt

In [None]:
total_time = 10 << u.s
time_bin = 0.1 << u.s

# Fix the number of bins up front and let linspace place the edges exactly. np.arange with a
# non-integer step (0.1 s) can gain or lose an edge through floating-point rounding.
num_time_bins = int(round((total_time / time_bin).to_value(u.one)))
assert np.isclose(num_time_bins * time_bin.to_value(u.s), total_time.to_value(u.s)), \
    "total_time must be a whole number of time_bin widths"
time_bins = np.linspace(0, total_time.to_value(u.s), num_time_bins + 1) << u.s
time_mids = time_bins[:-1] + np.diff(time_bins)/2
assert time_mids.size == num_time_bins

temp_start, temp_end = (13, 16) << u.MK
em_start, em_end = (5, 8) << (1e48 * u.cm**-3)

In [None]:
# Temperature: a linear ramp from temp_start to temp_end, multiplied by a small periodic bump.
num_humps = 4
period = total_time / num_humps
scaling = 1 / 40
# Phase in cycles. Because (1 + sin) lies in [0, 2], the modulation factor stays in
# [1, 1 + 2 * scaling], so it never lowers the ramp.
phase = (time_mids / period).to_value(u.one)
sine_modulation = 1 + scaling * (1 + np.sin(2 * np.pi * phase))
temps = simsup.linear_slew(temp_start, temp_end, num_time_bins) * sine_modulation

# Emission measure: quadratic slew between em_start and em_end.
ems = simsup.quadratic_slew(em_start, em_end, num_time_bins)

In [None]:
import astropy.visualization as viz

edges_s = time_bins.value  # bin edges as plain floats (seconds)

fig, ax = plt.subplots()
with viz.quantity_support():
    ax.stairs(temps, edges_s, color='orange', label='temperature')
    # Emission measure shares the time axis but gets its own y axis.
    axx = ax.twinx()
    axx.stairs(ems, edges_s, color='blue', label='emission measures')

# One legend per y axis, placed so they do not overlap.
for axis, where in ((ax, 'upper left'), (axx, 'lower left')):
    axis.legend(loc=where)
plt.show()

In [None]:
num_injections = 5
injection_duration = 1 << u.s

rng = np.random.default_rng()  # not used by the code in this cell
# Evenly spaced injection bins. linspace (rather than a [::step] slice) always yields exactly
# `num_injections` indices, even when num_time_bins is not a multiple of num_injections.
injection_time_indices = np.linspace(0, num_time_bins, num_injections, endpoint=False).astype(int)
print(injection_time_indices)

# The spectral index follows a linear slew from index_start down to index_end (6 -> 3),
# with |pink noise| fluctuations added on top.
index_start, index_end = 6, 3
index_noise = np.abs(simsup.pink_noise(num_time_bins))
indices = index_noise + simsup.linear_slew(index_start, index_end, num_time_bins)

# The electron flux reuses the same noise realisation, so its fluctuations track the index's.
efluxes = index_noise + simsup.linear_slew(0.1, 2, num_time_bins)

# Spike amplitudes per injection: 20% of the current index and 100% of the current flux,
# each capped at max_injection_delta.
injection_index_change = 0.2
injection_flux_change = 1
max_injection_delta = 3

t_s = time_mids.to_value(u.s)
duration_s = injection_duration.to_value(u.s)

# The spectral indices and electron fluxes spike at the same time: beamlike injection of
# electrons. Each delta is computed from the curve value at the injection bin as it stands
# after any earlier spikes have been applied.
for it in injection_time_indices:
    index_delta = min(injection_index_change * indices[it], max_injection_delta)
    indices -= simsup.spike(t_s, t_s[it], duration_s, index_delta)

    flux_delta = min(injection_flux_change * efluxes[it], max_injection_delta)
    efluxes += simsup.spike(t_s, t_s[it], duration_s, flux_delta)

In [None]:
fig, ax = plt.subplots()
edges = time_bins.value
for series, name in ((indices, 'spectral index'), (efluxes, 'electron flux')):
    ax.stairs(series, edges, label=name)
# Mark each injection time with a vertical line.
for t_inj in time_mids[injection_time_indices].value:
    ax.axvline(t_inj, color='red')
ax.legend()
plt.show()

In [None]:
cutoff_energies = np.full(time_mids.shape, 20.0) << u.keV  # constant 20 keV cutoff in every time bin

In [None]:
from yaff import common_models as cm
from yaff.fitting import Parameter

def model(params: cm.ArgsT):
    """Sum of the thermal and thick-target models, both evaluated on the same ``params``."""
    thermal_part = cm.thermal(params)
    nonthermal_part = cm.thick_target(params)
    return thermal_part + nonthermal_part

In [None]:
# Effective area of 10cm2 with 100um of Al
geometric_area = 10 << u.cm**2
response_vector = [1.56e-03, 6.77e-03, 2.30e-02, 6.39e-02, 1.51e-01, 3.09e-01, 5.62e-01, 9.26e-01, 1.40e+00, 1.97e+00, 2.62e+00, 3.31e+00, 4.03e+00, 4.73e+00, 5.40e+00, 6.03e+00, 6.60e+00, 7.11e+00, 7.56e+00, 7.94e+00, 8.28e+00, 8.57e+00, 8.81e+00, 9.01e+00, 9.18e+00, 9.33e+00, 9.44e+00, 9.54e+00, 9.62e+00, 9.69e+00, 9.74e+00, 9.79e+00, 9.83e+00, 9.86e+00, 9.88e+00, 9.90e+00, 9.92e+00, 9.93e+00, 9.95e+00, 9.96e+00, 9.96e+00, 9.97e+00, 9.97e+00, 9.98e+00, 9.98e+00, 9.99e+00, 9.99e+00, 9.99e+00, 9.99e+00, 9.99e+00, 9.99e+00, 9.99e+00, 1.00e+01, 1.00e+01, 1.00e+01, 1.00e+01, 1.00e+01, 1.00e+01, 1.00e+01,] << (u.cm**2 * u.ct / u.ph)

In [None]:
# Log-spaced photon energy grid: 60 edges -> 59 bins spanning 4-200 keV.
energy_bins = np.geomspace(4, 200, num=60) << u.keV
exposure = total_time
de = np.diff(energy_bins) << u.keV  # bin widths
energy_mids = energy_bins[:-1] + de / 2  # arithmetic bin centres

In [None]:
# Expected (noise-free) counts per (time bin, energy bin) for each emission component.
spectrogram_shape = (num_time_bins, energy_bins.size - 1)
thermal_truth = np.zeros(spectrogram_shape) << u.ct
nonthermal_truth = np.zeros(spectrogram_shape) << u.ct

photon_flux_unit = u.ph / u.cm**2 / u.keV / u.s
electron_flux_unit = 1e35 * u.electron / u.s

for i in range(num_time_bins):
    truth_values = {
        'temperature': temps[i],
        'emission_measure': ems[i],
        'cutoff_energy': cutoff_energies[i],
        'spectral_index': indices[i] << u.one,
        'electron_flux': float(efluxes[i]) << electron_flux_unit,
    }
    # Every parameter is built with Parameter(value, False) -- presumably "fixed"; TODO confirm.
    parameters = {k: Parameter(v, False) for (k, v) in truth_values.items()}
    args = {
        'parameters': parameters,
        'photon_energy_edges': energy_bins.to_value(u.keV),
    }
    # photon flux density * bin duration * bin width = photons per bin; the response maps to counts.
    thermal_truth[i] = response_vector * ((cm.thermal(args) << photon_flux_unit) * time_bin * de)
    nonthermal_truth[i] = response_vector * ((cm.thick_target(args) << photon_flux_unit) * time_bin * de)

# Clamp any negative model output to zero: expected counts cannot be negative.
thermal_truth[thermal_truth < 0] = 0
nonthermal_truth[nonthermal_truth < 0] = 0

# Round to the nearest whole count. A bare astype(int) truncates, which biases every bin low
# and zeroes any bin with fewer than one expected count.
thermal_truth = np.rint(thermal_truth).astype(int)
nonthermal_truth = np.rint(nonthermal_truth).astype(int)

def full_spectrogram():
    """Noise-free total counts (thermal + nonthermal truth), shaped (time bin, energy bin)."""
    total = thermal_truth + nonthermal_truth
    return total

In [None]:
import matplotlib.colors as mcol

unit = u.ct

fig, ax = plt.subplots()
# Log colour scale; LogNorm masks non-positive values, so empty bins render blank.
norm = mcol.LogNorm()
cmap = plt.get_cmap('plasma').copy()
mesh = ax.pcolormesh(
    time_bins.to_value(u.s),
    energy_bins.to_value(u.keV),
    full_spectrogram().T.to_value(unit).astype(int),
    norm=norm,
    cmap=cmap
)
fig.colorbar(mesh, ax=ax, label=f'counts per bin ({unit})')
ax.set(yscale='log', xlabel='time (s)', ylabel='energy (keV)', title='noise-free spectrogram')
plt.show()

In [None]:
# Compare thermal vs nonthermal truth spectra at a couple of sample time bins.
test_idxs = (10, 20)
with viz.quantity_support():
    fig, ax = plt.subplots()
    for i in test_idxs:
        t_label = f"t = {time_mids[i]:.2f}"
        ret = ax.stairs(thermal_truth[i], energy_bins, label=f'thermal, {t_label}')
        # Reuse the thermal line's colour so both components of one time bin match.
        col = ret.get_edgecolor()
        ax.stairs(nonthermal_truth[i], energy_bins, color=col, linestyle='--',
                  label=f'nonthermal, {t_label}')
    # Axis setup is loop-invariant, so it is applied once.
    ax.set(xscale='log', yscale='log', xlabel='energy keV')
    ax.legend()
plt.show()

In [None]:
def closest(a, v):
    """Return the index of the element of ``a`` nearest to ``v`` (the first one on ties)."""
    return np.argmin(np.abs(a - v))

# Background scenario: a Ba-133 source on board. The original note reads: "For X-rays that's
# about 2e5 count/second. There will be lines at 4 keV, 31 keV, and 81 keV".
# NOTE(review): the rates used below (10-200 Hz) and the shaped regions (~6-7 keV, ~29-32 keV,
# ~80-84 keV) do not match those numbers -- confirm which is intended.


def _total_counts(rate):
    """Expected counts from ``rate`` summed over all time bins, as a plain float.

    rate * time_bin is dimensionless ('Hz s'); the unit is cancelled explicitly so the value
    can be stored safely in the unitless ``noise`` array.
    """
    return (rate * time_bin * num_time_bins).to_value(u.dimensionless_unscaled)


# Time-integrated background spectrum: expected counts per energy bin over the whole observation.
noise = np.ones_like(thermal_truth[0].value).astype(float)

baseline_rate = 10 << u.Hz
baseline_cts = _total_counts(baseline_rate)
noise *= baseline_cts

line1_rate = 200 << u.Hz
line1_cts = _total_counts(line1_rate)
noise[closest(energy_mids, 6 << u.keV):closest(energy_mids, 7 << u.keV)] = line1_cts

line2_rate = 100 << u.Hz
line2_cts = _total_counts(line2_rate)
noise[closest(energy_mids, 29 << u.keV):closest(energy_mids, 32 << u.keV)] = line2_cts

# Quieter between the lines
noise[closest(energy_mids, 35 << u.keV):closest(energy_mids, 100 << u.keV)] /= 3
noise[closest(energy_mids, 8 << u.keV):closest(energy_mids, 28 << u.keV)] /= 3
# Enhanced regions; applied after the /3 above, so they compound with it where they overlap.
noise[closest(energy_mids, 80 << u.keV):closest(energy_mids, 84 << u.keV)] *= 5
noise[closest(energy_mids, 120 << u.keV):] *= 4

mu = noise
sig = 4 * np.sqrt(noise)
# One Gaussian draw of the time-integrated background per time bin, spread evenly over the
# time bins; the transpose makes the result (energy, time).
background = ((st.norm.rvs(loc=noise, scale=sig, size=(num_time_bins, noise.size)).T) / num_time_bins).astype(int) << u.ct
# background *= 0  # uncomment to switch the background off
background[background < 0] = 0

# Poisson-sample the expected counts. `data` is (time, energy) and non-negative by construction.
mean = (thermal_truth + nonthermal_truth + background.T).to_value(u.ct).astype(int)
data = st.poisson.rvs(mean)

In [None]:
import matplotlib.colors as mcol

unit = u.ct

fig, ax = plt.subplots()
# Log colour scale; LogNorm masks non-positive values, so zero-count bins render blank.
norm = mcol.LogNorm()
cmap = plt.get_cmap('plasma').copy()
mesh = ax.pcolormesh(
    time_bins.to_value(u.s),
    energy_bins.to_value(u.keV),
    data.T,
    norm=norm,
    cmap=cmap
)
fig.colorbar(mesh, ax=ax, label=f'counts per bin ({unit})')
ax.set(yscale='log', xlabel='time (s)', ylabel='energy (keV)', title='noisy spectrogram')
plt.show()

In [None]:
# Time-integrated spectra: data, each truth component, and the background ((energy, time) -> axis=1).
spectra = (
    (data.sum(axis=0), 'data'),
    (thermal_truth.sum(axis=0).astype(int), 'thermal'),
    (nonthermal_truth.sum(axis=0).astype(int), 'non thermal'),
    (background.sum(axis=1).astype(int), 'bg'),
)

fig, ax = plt.subplots()
with viz.quantity_support():
    for counts, name in spectra:
        ax.stairs(counts, energy_bins, label=name)
    ax.set(xscale='log', yscale='log')
    ax.legend()
plt.show()

In [None]:
def nearest(a, v):
    """Index of the entry of ``a`` that is closest to ``v`` (first one on ties)."""
    distance = np.abs(a - v)
    return np.argmin(distance)

# Light curves in (possibly overlapping) energy bands.
fig, ax = plt.subplots()
energy_bounds = ((4, 10), (10, 20), (20, 40), (35, 47), (40, 80), (80, 300)) << u.keV
for (ea, eb) in energy_bounds:
    # Take every bin whose centre lies in [ea, eb), so a band reaching past the top of the
    # grid (80 -> 300 keV on a 200 keV grid) keeps all of its bins, including the last one.
    in_band = (energy_mids >= ea) & (energy_mids < eb)
    band_counts = data[:, in_band].sum(axis=1)
    ax.stairs(band_counts, time_bins.value, label=f"{ea:.0f} $\\rightarrow$ {eb:.0f}")
ax.set(yscale='log', ylim=(1e-9, 1e6), xlabel='time (s)', ylabel='counts per time bin')
ax.legend()
plt.show()

In [None]:
from tedec import decomp

In [None]:
# Transpose the (time, energy) data to (energy, time) so rows are per-energy light curves.
dimensionless_data = data.T

# Energy-index ranges for the basis light curves (upper index exclusive in the slices below).
ta, tb = nearest(energy_mids, 5 << u.keV), nearest(energy_mids, 6 << u.keV)
na, nb = nearest(energy_mids, 36 << u.keV), nearest(energy_mids, 47 << u.keV)

thermal_basis = dimensionless_data[ta:tb].sum(axis=0)
nonthermal_basis = dimensionless_data[na:nb].sum(axis=0)
background_basis = dimensionless_data[-1]  # highest-energy bin

pack = decomp.DataPacket(
    data=dimensionless_data,
    basis_timeseries=[thermal_basis, nonthermal_basis, background_basis],
    constant_offset=False
)

# sqrt(counts) plus a fractional systematic, combined in quadrature.
systematic = 0.
count_errors = np.sqrt(dimensionless_data + (systematic * dimensionless_data)**2)
ret = decomp.bootstrap(pack, errors=count_errors, num_iter=1000)

In [None]:
from yaff import plotting
from astropy import visualization as viz

# Fractional systematic for the error bars drawn below (named so it does not shadow the
# stdlib `sys` module).
sys_frac = 0.05


def err(a):
    """sqrt(a) plus a fractional-systematic uncertainty on counts ``a``, added in quadrature."""
    return np.sqrt(a + (a * sys_frac)**2)


# `ret` appears to be (bootstrap iteration, component, energy); components follow the basis
# order passed to DataPacket: 0 = thermal band, 1 = nonthermal band, 2 = top-energy-bin background.
# The decomposition runs on count data, so its outputs are presumably counts too -- labelled
# u.ct so they share an axis unit with the count-valued truth arrays.
th_mean = ret[:, 0, :].mean(axis=0) << u.ct
th_std = ret[:, 0, :].std(axis=0) << u.ct
nth_mean = ret[:, 1, :].mean(axis=0) << u.ct
nth_std = ret[:, 1, :].std(axis=0) << u.ct

# TODO: the background component (ret[:, 2, :]) still needs scaling by the number of time
# bins; until then this is a zero placeholder and is not plotted.
bkg_part = ret[:, 1, :] * 0
bkg_mean = bkg_part.mean(axis=0) * u.ct
bkg_std = bkg_part.std(axis=0) * u.ct

# Time-integrated truth components and data, all in counts.
th = thermal_truth.sum(axis=0)
nth = nonthermal_truth.sum(axis=0)
dm = dimensionless_data.sum(axis=1) << u.ct

num_sigma = 1

fig, ax = plt.subplots()
with viz.quantity_support():
    plotting.stairs_with_error(energy_bins, th, err(th.value) << u.ct, ax=ax, label='true thermal spectrum')
    plotting.stairs_with_error(energy_bins, nth, err(nth.value) << u.ct, ax=ax, label='true nonthermal spectrum')
    plotting.stairs_with_error(energy_bins, dm, err(dm.value) << u.ct, ax=ax, label='full spectrum')
    ax.stairs(background.sum(axis=1), energy_bins, label='true background')

    plotting.stairs_with_error(energy_bins, th_mean, num_sigma*th_std, ax=ax, label='decomposed thermal')
    plotting.stairs_with_error(energy_bins, nth_mean, num_sigma*nth_std, ax=ax, label='decomposed nonthermal')
    plotting.stairs_with_error(
        energy_bins,
        nth_mean + th_mean,
        num_sigma*np.sqrt(nth_std**2 + th_std**2),
        ax=ax,
        label='decomposed full'
    )

ax.legend()
ax.set(xscale='log', yscale='log', xlabel='energy (keV)', ylabel='counts')
plt.show()

In [None]:
import fathon
from fathon import fathonUtils as fu

# ...