Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes STIX file loading and lightcurve/spectrogram plotting #98

Merged
merged 6 commits into from
Jun 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/98.breaking.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update STIX spectrogram and srm loading and plotting functions for the new STIX file format.
17 changes: 10 additions & 7 deletions sunxspex/sunxspex_fitting/instruments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,7 +1110,7 @@ def _atimes2mdates(self, astrotimes):
List of matplotlib dates.
"""
# convert astro time to datetime then use list comprehension to convert to matplotlib dates
return [mdates.date2num(dt.tt.datetime) for dt in astrotimes]
return [mdates.date2num(dt.utc.datetime) for dt in astrotimes]

def _mdates_minute_locator(self, _obs_dt=None):
""" Try to determine a nice tick separation for time axis on the lightcurve.
Expand Down Expand Up @@ -1290,12 +1290,12 @@ def lightcurve(self, energy_ranges=None, axes=None, rebin_time=1):
_y_pos = ax.get_ylim()[0] + (ax.get_ylim()[1]-ax.get_ylim()[0])*0.95 # stop region label overlapping axis spine
if hasattr(self, "_start_background_time") and (type(self._start_background_time) != type(None)) and hasattr(self, "_end_background_time") and (type(self._end_background_time) != type(None)):
ax.axvspan(*self._atimes2mdates([self._start_background_time, self._end_background_time]), alpha=0.1, color='orange')
ax.annotate("BG", (self._atimes2mdates([self._start_background_time])[0], _y_pos), color='orange', va="top", size=_def_fs-8)
ax.annotate("BG", (self._atimes2mdates([self._start_background_time])[0], _y_pos), color='orange', va="top", size=_def_fs-2)

# plot event time range
if hasattr(self, "_start_event_time") and hasattr(self, "_end_event_time"):
ax.axvspan(*self._atimes2mdates([self._start_event_time, self._end_event_time]), alpha=0.1, color='purple')
ax.annotate("Evt", (self._atimes2mdates([self._start_event_time])[0], _y_pos), color='purple', va="top", size=_def_fs-8)
ax.annotate("Evt", (self._atimes2mdates([self._start_event_time])[0], _y_pos), color='purple', va="top", size=_def_fs-2)

self._lightcurve_data = {"mdtimes": _ts, "lightcurves": _lcs, "lightcurve_error": _lcs_err, "energy_ranges": energy_ranges}

Expand Down Expand Up @@ -1395,16 +1395,19 @@ def spectrogram(self, axes=None, rebin_time=1, rebin_energy=1, **kwargs):

ax.set_title(self._instrument()+"Spectrogram [Counts s$^{-1}$]")

# change event and background start and end times from astropy dates to matplotlib dates
start_evt_time, end_evt_time, start_bg_time, end_bg_time = self._atimes2mdates([self._start_event_time, self._end_event_time, self._start_background_time, self._end_background_time])

# plot background time range if there is one
_y_pos = ax.get_ylim()[0] + (ax.get_ylim()[1]-ax.get_ylim()[0])*0.95 # stop region label overlapping axis spine
if hasattr(self, "_start_background_time") and (type(self._start_background_time) != type(None)) and hasattr(self, "_end_background_time") and (type(self._end_background_time) != type(None)):
ax.plot(self._atimes2mdates([self._start_background_time, self._end_background_time]), [etop, etop], alpha=0.9, color='orange', lw=10)
ax.annotate("BG", (self._atimes2mdates([self._start_background_time])[0], _y_pos), color='orange', va="top", size=_def_fs-8)
ax.hlines(y=etop, xmin=start_bg_time, xmax=end_bg_time, alpha=0.9, color='orange', capstyle='butt', lw=10)
ax.annotate("BG", (start_bg_time, _y_pos), color='orange', va="top", size=_def_fs-2)

# plot event time range
if hasattr(self, "_start_event_time") and hasattr(self, "_end_event_time"):
ax.plot(self._atimes2mdates([self._start_event_time, self._end_event_time]), [etop, etop], alpha=0.9, color='#F37AFF', lw=10)
ax.annotate("Evt", (self._atimes2mdates([self._start_event_time])[0], _y_pos), color='#F37AFF', va="top", size=_def_fs-8)
ax.hlines(y=etop, xmin=start_evt_time, xmax=end_evt_time, alpha=0.9, color='#F37AFF', capstyle='butt', lw=10)
ax.annotate("Evt", (start_evt_time, _y_pos), color='#F37AFF', va="top", size=_def_fs-2)

self._spectrogram_data = {"spectrogram": _spect, "extent": _ext}

Expand Down
17 changes: 12 additions & 5 deletions sunxspex/sunxspex_fitting/io.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""
The ``io`` module contains code to read instrument specific spectral data.
"""
from astropy.io import fits
import numpy as np

from sunpy.io.special.genx import read_genx
from astropy.io import fits

__all__ = ["_read_pha", "_read_arf", "_read_rmf", "_read_rhessi_spec_file", "_read_rhessi_srm_file",
"_read_stix_spec_file", "_read_stix_srm_file"]
Expand Down Expand Up @@ -150,6 +150,13 @@ def _read_stix_srm_file(srm_file):
`dict`
STIX SRM data (photon bins, count bins, and SRM in units of [counts/keV/photons]).
"""
contents = read_genx(srm_file)
return {"photon_energy_bin_edges": contents["DRM"]['E_2D'], "count_energy_bin_edges": contents["DRM"]['EDGES_OUT'],
"drm": contents['DRM']['SMATRIX']}
with fits.open(srm_file) as hdul:
d0 = hdul[1].header
d1 = hdul[1].data
d3 = hdul[2].data

pcb = np.concatenate((d1['ENERG_LO'][:, None], d1['ENERG_HI'][:, None]), axis=1)

return {"photon_energy_bin_edges": pcb,
"count_energy_bin_edges": np.concatenate((d3['E_MIN'][:, None], d3['E_MAX'][:, None]), axis=1),
"drm": d1['MATRIX']*d0["GEOAREA"]}
25 changes: 5 additions & 20 deletions sunxspex/sunxspex_fitting/stix_spec_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def _get_spec_file_info(spec_file):
_plus_half_bin_width = np.ceil(time_deltas/2)
t_hi = times_mids + _plus_half_bin_width

spec_stimes = [Time(sdict["0"][0]["DATE_BEG"], format='isot', scale='utc')+TimeDelta(time_diff_so2e * u.s)+TimeDelta(dt * u.ds) for dt in t_lo]
spec_etimes = [Time(sdict["0"][0]["DATE_BEG"], format='isot', scale='utc')+TimeDelta(time_diff_so2e * u.s)+TimeDelta(dt * u.ds) for dt in t_hi]
spec_stimes = [Time(sdict["0"][0]["DATE-BEG"], format='isot', scale='utc')+TimeDelta(time_diff_so2e * u.s)+TimeDelta(dt * u.cs) for dt in t_lo]
spec_etimes = [Time(sdict["0"][0]["DATE-BEG"], format='isot', scale='utc')+TimeDelta(time_diff_so2e * u.s)+TimeDelta(dt * u.cs) for dt in t_hi]
time_bins = np.concatenate((np.array(spec_stimes)[:, None], np.array(spec_etimes)[:, None]), axis=1)

channel_bins_inds, channel_bins = _return_masked_bins(sdict)
Expand All @@ -54,21 +54,6 @@ def _get_spec_file_info(spec_file):
return channel_bins, channel_bins_inds, time_bins, lvt, counts, counts_err, cts_rates, cts_rate_err


def _ds_times2s_times(ds_times):
""" STIX time might be in ds (deci-seconds). Convert to seconds.

Parameters
----------
ds_times : time bins
A 2D array of time bins.

Returns
-------
A 2d array of the time bin edges.
"""
return ds_times


def _return_masked_bins(sdict):
""" Return the energy bins where there is data.

Expand All @@ -85,9 +70,9 @@ def _return_masked_bins(sdict):
e_bins = np.concatenate((sdict["4"][1]['e_low'][:, None], sdict["4"][1]['e_high'][:, None]), axis=1)

# get all indices of the energy bins needed
mask_inds = sdict["1"][1]['energy_bin_mask'][0].astype(bool)
mask_inds = sdict["1"][1]['energy_bin_edge_mask'][0].astype(bool)

return mask_inds, e_bins[mask_inds]
return mask_inds, e_bins


def _spec_file_units_check(stix_dict, time_dels):
Expand All @@ -110,7 +95,7 @@ def _spec_file_units_check(stix_dict, time_dels):
"""
# stix can be saved out with counts, counts/sec, or counts/sec/cm^2/keV using counts, rate, or flux, respectively
if stix_dict["0"][0]["BUNIT"] == "counts":
counts = stix_dict["2"][1]["counts"][:, 1:]
counts = stix_dict["2"][1]["counts"][:, :]
counts_err = np.sqrt(counts) # should this be added in quadrature to the estimated count compression error
cts_rates = counts / time_dels[:, None]
cts_rate_err = counts_err / time_dels[:, None]
Expand Down
25 changes: 18 additions & 7 deletions sunxspex/sunxspex_fitting/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,26 @@ def test_read_stix_spec_file(mock_open):
assert np.array_equal(v[1][k], np.arange(k))


@patch('sunxspex.sunxspex_fitting.io.read_genx')
def test_read_stix_srm_file(mock_read_genx):
@patch('astropy.io.fits.open')
def test_read_stix_srm_file(mock_open):

photon_bins = np.array([[1, 2], [2, 3]])
count_bins = np.array([[2, 3], [4, 5]])
drm = np.eye(2)
ret = {"DRM": {'E_2D': photon_bins, 'EDGES_OUT': count_bins, 'SMATRIX': drm}}
count_bins = np.array([[2, 3], [3, 4]])
drm = 2*np.eye(2)

hdul = []
headers = [{}, {'GEOAREA': 2}, {}]
data = [{}, {'MATRIX': np.eye(2), 'ENERG_LO': np.array([1, 2]), 'ENERG_HI': np.array([2, 3])}, {'E_MIN': np.array([2, 3]), 'E_MAX':np.array([3, 4])}]

for i in range(len(headers)):
m = MagicMock()
m.header = headers[i]
m.data = data[i]
hdul.append(m)

mock_open.return_value.__enter__.return_value = hdul
res = _read_stix_srm_file('test.fits')

mock_read_genx.return_value = ret
res = _read_stix_srm_file('test.genx')
assert np.array_equal(res['photon_energy_bin_edges'], photon_bins)
assert np.array_equal(res['count_energy_bin_edges'], count_bins)
assert np.array_equal(res['drm'], drm)
Loading