In [54]:
import numpy as np
import astropy.units as u
import matplotlib.pyplot as plt
import scipy.stats as st
%matplotlib qt

In [81]:
def linear_slew(a, b, steps):
    """Return ``steps`` evenly spaced samples ramping from ``a`` towards ``b``.

    The first sample is exactly ``a`` and each subsequent sample advances by
    ``(b - a) / steps``; the end value ``b`` itself is never reached (the last
    sample is ``a + (b - a) * (steps - 1) / steps``). Only generic arithmetic
    is used, so this also works with the astropy Quantities passed below.
    """
    increment = (b - a) / steps
    ramp = np.arange(steps) * increment
    return a + ramp

total_time = 20 << u.s
time_bin = 0.1 << u.s

# Derive the number of bins from the two durations instead of using a
# float-step np.arange, which can gain or lose its final edge through
# floating-point rounding. Edges are then exact multiples of `time_bin`.
num_time_bins = int(round((total_time / time_bin).to_value(u.one)))
assert np.isclose(num_time_bins * time_bin.to_value(u.s), total_time.to_value(u.s)), \
    'total_time must be an integer multiple of time_bin'

time_bins = np.arange(num_time_bins + 1) * time_bin
time_mids = time_bins[:-1] + np.diff(time_bins)/2

temp_start, temp_end = (10, 22) << u.MK
em_start, em_end = (1, 2) << (1e49 * u.cm**-3)

In [82]:
# Linear temperature and emission-measure ramps, one sample per time bin.
temps, ems = (
    linear_slew(start, end, num_time_bins)
    for (start, end) in ((temp_start, temp_end), (em_start, em_end))
)

In [83]:
def spike(t, location, duration, height):
    """Triangular pulse sampled on the time grid ``t``.

    The pulse is zero outside the half-open window
    ``[location - duration/2, location + duration/2)``. Inside it, it rises
    linearly from 0 to ``height`` at ``location`` and falls linearly back
    towards 0 at the end of the window.

    Parameters
    ----------
    t : array_like
        Sample times as plain numbers (callers below pass seconds).
    location : float
        Time of the pulse peak.
    duration : float
        Full width of the triangle's base.
    height : float
        Peak value, reached at ``location``.

    Returns
    -------
    numpy.ndarray
        Float array shaped like ``t``; zero outside the window.
    """
    t = np.asarray(t)
    window_start = location - duration / 2
    in_window = (t >= window_start) & (t < location + duration / 2)

    # Measure elapsed time from the true window start, not from the first
    # sample inside it. This keeps the peak at `location` even when the window
    # is clipped by the edge of `t`, and does not assume a uniform grid.
    elapsed = t[in_window] - window_start
    profile = np.where(
        elapsed < duration / 2,
        elapsed * (2 * height / duration),      # rising edge
        2 * height * (1 - elapsed / duration),  # falling edge
    )

    # Always float so integer time grids do not truncate the pulse.
    pulse = np.zeros(t.shape, dtype=float)
    pulse[in_window] = profile
    return pulse

In [119]:
def noise_psd(N, psd = lambda f: 1):
    """Generate ``N`` samples of Gaussian noise with a prescribed spectral shape.

    White Gaussian noise (from numpy's global random state) is transformed to
    the frequency domain. Each frequency bin is multiplied by ``psd(f)`` and the
    result is transformed back. ``psd`` scales the Fourier *amplitudes*, so the
    resulting power spectrum goes as ``psd(f)**2``.

    Parameters
    ----------
    N : int
        Number of output samples.
    psd : callable
        Maps the frequencies from ``np.fft.rfftfreq(N)`` (cycles per sample)
        to an amplitude weight. It may return a scalar or an array
        broadcastable against those frequencies. The default is a flat (white)
        spectrum.

    Returns
    -------
    numpy.ndarray
        Real array of length ``N``.
    """
    X_white = np.fft.rfft(np.random.randn(N))
    S = psd(np.fft.rfftfreq(N))
    # Normalize S to unit RMS so the output does not depend on the overall
    # scale of the chosen spectrum, only on its shape.
    S = S / np.sqrt(np.mean(S**2))
    X_shaped = X_white * S
    # Pass n=N explicitly: by default irfft returns 2*(len(X)-1) samples,
    # which is N - 1 when N is odd.
    return np.fft.irfft(X_shaped, n=N)

def PSDGenerator(f):
    """Decorator: turn a spectral-shape function into a noise generator.

    ``f`` maps frequencies to amplitude weights. The returned callable takes
    a sample count ``N`` and returns ``noise_psd(N, f)``.
    """
    def generate(N):
        return noise_psd(N, f)
    return generate

@PSDGenerator
def white_noise(f):
    """Flat amplitude spectrum: every frequency bin is weighted equally.

    Returns an array shaped like ``f`` (instead of the scalar ``1``), so the
    spectrum has the same form as the other noise colours defined here.
    """
    return np.ones_like(f, dtype=float)

@PSDGenerator
def blue_noise(f):
    """Amplitude rising as sqrt(f), i.e. a power spectrum proportional to f."""
    return f ** 0.5

@PSDGenerator
def violet_noise(f):
    """Amplitude rising linearly with f, i.e. a power spectrum proportional to f**2."""
    amplitude = np.asarray(f)
    return amplitude

@PSDGenerator
def brownian_noise(f):
    """Amplitude falling as 1/f, i.e. a power spectrum proportional to 1/f**2.

    The zero-frequency (DC) bin gets weight 0 instead of a division by zero.
    """
    return np.divide(1, f, out=np.zeros_like(f, dtype=float), where=(f != 0))

@PSDGenerator
def pink_noise(f):
    """Amplitude falling as 1/sqrt(f), i.e. a power spectrum proportional to 1/f.

    The zero-frequency (DC) bin gets weight 0 instead of a division by zero.
    """
    root_f = np.sqrt(f)
    return np.divide(1, root_f, out=np.zeros_like(root_f, dtype=float), where=(f != 0))

In [201]:
# figure out the spectral slope time shape

# dominate this with a linear shift from high to low electron index
# aka a soft to hard spectrum
index_start, index_end = 8, 3

# we have particle injections across the interval
# which occur in blips where the index hardens
# significantly
num_injections = 10

# triangular injection rise/fall within this time
injection_duration = 1 << u.s

# how much harder does the spectrum become per injection,
# as a proportion?
injection_index_change = 0.4

# how much does the electron flux grow per injection, as a proportion of
# the (offset) flux at the injection time?
injection_flux_change = 1

# Temporary offset added to the flux while the injections are applied.
# Spike heights scale with the current flux, so this keeps injections
# visible where the noise-only flux is near zero (presumed intent).
# It is removed again at the end of the cell.
flux_offset = 2

# Deterministic rng, if we want.
# NOTE: pink_noise draws from numpy's *global* random state, not from `rng`,
# so seeding `rng` only fixes the injection times.
rng = np.random.default_rng()
# rng = np.random.default_rng(np.random.MT19937(1275830))
# Generator.integers excludes `high`, so pass num_time_bins to make every
# bin (including the last one) a possible injection time.
injection_time_indices = rng.integers(0, num_time_bins, size=num_injections)

# The spectral index slews linearly from index_start down to index_end
# (soft to hard), with pink-noise fluctuations on top.
indices = np.abs(pink_noise(time_mids.size)) + linear_slew(index_start, index_end, num_time_bins)

# The electron flux is (non-negative) pink noise, shifted up by flux_offset
# while the injection spikes are added.
efluxes = np.abs(pink_noise(time_mids.size))
efluxes += flux_offset

# The spectral indices and electron fluxes spike
# at the same time: beamlike injection of electrons
t_seconds = time_mids.to_value(u.s)
duration_seconds = injection_duration.to_value(u.s)
for it in injection_time_indices:
    # Spike heights are proportional to the current value at the injection
    # time, which already includes any earlier injections.
    indices -= spike(
        t_seconds,
        t_seconds[it],
        duration_seconds,
        injection_index_change * indices[it]
    )
    efluxes += spike(
        t_seconds,
        t_seconds[it],
        duration_seconds,
        injection_flux_change * efluxes[it]
    )

efluxes -= flux_offset

In [202]:
fig, ax = plt.subplots()
ax.stairs(indices, time_bins.value, label='spectral index')
ax.stairs(efluxes, time_bins.value, label='electron flux')
# Both series share one y axis: the index is dimensionless and the flux is
# later interpreted in units of 1e35 electron / s.
ax.set(
    xlabel='time (s)',
    ylabel='index  |  flux (1e35 electron / s)',
    title='Simulated nonthermal electron parameters'
)
ax.legend()
plt.show()

In [203]:
# Constant 20 keV low-energy cutoff, one value per time bin.
cutoff_energies = np.full(time_mids.shape, 20.0) << u.keV

In [204]:
from yaff import common_models as cm
from yaff.fitting import Parameter

def model(params: cm.ArgsT):
    """Sum of the yaff thermal and thick-target model outputs for ``params``."""
    thermal_component = cm.thermal(params)
    nonthermal_component = cm.thick_target(params)
    return thermal_component + nonthermal_component

In [205]:
energy_bins = np.geomspace(4, 500, num=60)

# One row per time bin, one column per energy bin.
spectrogram_shape = (time_mids.size, energy_bins.size - 1)
thermal_truth = np.zeros(spectrogram_shape)
nonthermal_truth = np.zeros(spectrogram_shape)

for i in range(cutoff_energies.size):
    true_values = dict(
        temperature=temps[i],
        emission_measure=ems[i],
        cutoff_energy=cutoff_energies[i],
        spectral_index=indices[i] << u.one,
        electron_flux=(float(efluxes[i]) << (1e35 * u.electron / u.s)),
    )
    # NOTE(review): see yaff docs for the meaning of Parameter's second argument.
    parameters = {name: Parameter(value, False) for name, value in true_values.items()}
    args = dict(parameters=parameters, photon_energy_edges=energy_bins)
    thermal_truth[i] = cm.thermal(args)
    nonthermal_truth[i] = cm.thick_target(args)

# Attach the photon-flux unit to the model outputs.
photon_flux_unit = u.ph / u.cm**2 / u.keV / u.s
thermal_truth <<= photon_flux_unit
nonthermal_truth <<= photon_flux_unit

def full_spectrogram():
    """Return the total true spectrogram: thermal plus nonthermal truth.

    The sum is recomputed from the module-level arrays on every call.
    """
    total = thermal_truth + nonthermal_truth
    return total



In [206]:
import matplotlib.colors as mcol

unit = u.ph / u.keV / u.cm**2 / u.s

# Evaluate the spectrogram once; it is used for both the norm and the image.
spectrogram_values = full_spectrogram().to_value(unit)

fig, ax = plt.subplots()
norm = mcol.SymLogNorm(linthresh=1e-4, vmin=0, vmax=spectrogram_values.max())
cmap = plt.get_cmap('plasma').copy()
mesh = ax.pcolormesh(
    time_bins.to_value(u.s),
    energy_bins,
    spectrogram_values.T,
    norm=norm,
    cmap=cmap
)
fig.colorbar(mesh, ax=ax, label=f'true photon flux ({unit.to_string()})')
ax.set(yscale='log', xlabel='time (s)', ylabel='energy (keV)')
plt.show()

In [207]:
test_idxs = (10, 20)
fig, ax = plt.subplots()
for i in test_idxs:
    th_example = thermal_truth[i]
    nth_example = nonthermal_truth[i]
    # One colour per time bin; solid = thermal, dashed = nonthermal.
    th_artist = ax.stairs(th_example, energy_bins, label=f'thermal, t = {time_mids[i]:.1f}')
    col = th_artist.get_edgecolor()
    ax.stairs(
        nth_example, energy_bins, color=col, linestyle='--',
        label=f'nonthermal, t = {time_mids[i]:.1f}'
    )
ax.set(
    xscale='log', yscale='log',
    xlabel='energy (keV)', ylabel=f'flux ({thermal_truth.unit})'
)
ax.legend()
plt.show()

In [208]:
energy_mids = (energy_bins[:-1] + np.diff(energy_bins)/2)


def closest(a, v):
    """Index of the element of ``a`` nearest to ``v``."""
    return np.argmin(np.abs(a - v))


# Let's say we have a Ba133 source on board.
# For X-rays that's about 2e5 count/second.
# There will be lines at 4 keV, 31 keV, and 81 keV.
# The simulated background below uses `count_rate` counts per second in
# every energy bin, boosted in the bins around those lines.
count_rate = 10
# Expected background counts per energy bin over the whole observation.
# `count_rate` is a bare number, so use the exposure in plain seconds;
# multiplying by the `time_bin` Quantity would give `noise` a unit of seconds.
noise = (count_rate * num_time_bins * time_bin.to_value(u.s)) * np.ones_like(thermal_truth[0].value)
noise[closest(energy_mids, 4):closest(energy_mids, 6)] *= 100
noise[closest(energy_mids, 29):closest(energy_mids, 32)] *= 10
noise[closest(energy_mids, 80):closest(energy_mids, 84)] *= 5

# Each (time, energy) cell gets a Gaussian draw of the whole-observation
# background divided by the number of time bins. Its mean is the per-bin
# expectation, but its scatter is smaller than per-bin counting noise
# (counting noise is added to `data` below). Shape: (energy, time).
background = (st.norm.rvs(loc=noise, scale=np.sqrt(noise), size=(num_time_bins, noise.size)).T / num_time_bins) * u.ph

area = 10 << u.cm**2
exposure = total_time
de = np.diff(energy_bins) << u.keV
# Round (rather than truncate) to whole photons: astype(int) alone floors
# positive values and biases every bin low by ~0.5 count.
thermal_photons = np.rint((thermal_truth * area * time_bin * de).to(u.ph)).astype(int)
nonthermal_photons = np.rint((nonthermal_truth * area * time_bin * de).to(u.ph)).astype(int)

systematic = 0.05
# Gaussian approximation to counting noise plus a fractional systematic term.
tot = (thermal_photons + nonthermal_photons + background.T).to_value(u.ph)
data = (np.rint(st.norm.rvs(
    loc=tot,
    scale=np.sqrt(tot + (systematic * tot)**2)
)) << u.ph).astype(int)

# Counts cannot be negative.
data[data < 0] = 0

In [209]:
import matplotlib.colors as mcol

unit = u.ph

fig, ax = plt.subplots()
# LogNorm masks non-positive values, so zero-count cells are left unfilled.
norm = mcol.LogNorm()
cmap = plt.get_cmap('plasma').copy()
mesh = ax.pcolormesh(
    time_bins.to_value(u.s),
    energy_bins,
    data.T.to_value(unit),
    norm=norm,
    cmap=cmap
)
fig.colorbar(mesh, ax=ax, label='photon counts per (time, energy) bin')
ax.set(yscale='log', xlabel='time (s)', ylabel='energy (keV)')
plt.show()

In [210]:
fig, ax = plt.subplots()
# data / *_photons are (time, energy); background is (energy, time).
ax.stairs(data.sum(axis=0), energy_bins, label='data')
ax.stairs(thermal_photons.sum(axis=0), energy_bins, label='thermal')
ax.stairs(nonthermal_photons.sum(axis=0), energy_bins, label='nonthermal')
ax.stairs(background.sum(axis=1), energy_bins, label='bg')
ax.set(
    xscale='log', yscale='log',
    xlabel='energy (keV)', ylabel='counts (summed over time)'
)
ax.legend()
plt.show()

In [195]:
# One stepped light curve per energy bin.
fig, ax = plt.subplots()
edges_s = time_bins.value
for light_curve in data.value.T:
    ax.stairs(light_curve, edges_s)
ax.set(yscale='log', ylim=(1e-9, 1e6))
plt.show()

In [211]:
from tedec import decomp

In [221]:
# The decomposition works on plain counts laid out as (energy, time).
dimensionless_data = data.to_value(u.ph).T


def nearest(a, v):
    """Index of the element of ``a`` closest to ``v``."""
    return np.argmin(np.abs(a - v))


# Energy windows (keV) whose summed light curves are the basis time series:
# first the 4-8 keV band, then the 37-42 keV band.
ta, tb = (nearest(energy_mids, e) for e in (4, 8))
na, nb = (nearest(energy_mids, e) for e in (37, 42))

basis = [dimensionless_data[lo:hi].sum(axis=0) for (lo, hi) in ((ta, tb), (na, nb))]
pack = decomp.DataPacket(
    data=dimensionless_data,
    basis_timeseries=basis,
    constant_offset=True
)

# Uncertainty model: counting error plus a 10 % systematic term.
# NOTE(review): this differs from the 0.05 systematic used to generate `data`.
# Bins with zero counts get zero error; confirm tedec handles that.
systematic = 0.1
errors = np.sqrt(dimensionless_data + (systematic * dimensionless_data)**2)
ret = decomp.bootstrap(pack, errors=errors, num_iter=1000)

In [222]:
from yaff import plotting
from astropy import visualization as viz

# `ret` is indexed below as (bootstrap iteration, component, energy bin), with
# components 0/1/2 plotted as thermal / nonthermal / background (constant
# offset). TODO confirm this layout against tedec.decomp.bootstrap.
th_mean = ret[:, 0, :].mean(axis=0) << u.ph
th_std = ret[:, 0, :].std(axis=0) << u.ph
nth_mean = ret[:, 1, :].mean(axis=0) << u.ph
nth_std = ret[:, 1, :].std(axis=0) << u.ph

# TODO: the constant-offset component may need scaling by the number of
# time bins before it is comparable with the true background.
bkg_part = ret[:, 2, :]
bkg_mean = bkg_part.mean(axis=0) << u.ph
bkg_std = bkg_part.std(axis=0) << u.ph

fig, ax = plt.subplots()

# True expected counts per energy bin, summed over all time bins.
thermal_ph = (thermal_truth.sum(axis=0) * time_bin * area * de).to_value(u.ph)
nonthermal_ph = (nonthermal_truth.sum(axis=0) * time_bin * area * de).to_value(u.ph)

num_sigma = 2
with viz.quantity_support():
    ax.stairs(thermal_ph, energy_bins, label='true thermal spectrum')
    ax.stairs(nonthermal_ph, energy_bins, label='true nonthermal spectrum')
    ax.stairs(dimensionless_data.sum(axis=1), energy_bins, label='full spectrum')
    ax.stairs(background.sum(axis=1), energy_bins, label='true background')

    plotting.stairs_with_error(energy_bins << u.keV, th_mean, num_sigma*th_std, ax=ax, label='decomposed thermal')
    plotting.stairs_with_error(energy_bins << u.keV, nth_mean, num_sigma*nth_std, ax=ax, label='decomposed nonthermal')
    plotting.stairs_with_error(energy_bins << u.keV, bkg_mean, num_sigma*bkg_std, ax=ax, label='decomposed background')

ax.legend(title=f'error bands: ±{num_sigma}σ')
ax.set(xscale='log', yscale='log', xlabel='energy (keV)', ylabel='counts (summed over time)')
plt.show()