# Spectral Fit

In [None]:
import sherpa.astro.ui as sau
import matplotlib.pyplot as plt
from sherpa_contrib.chart import save_chart_spectrum

In [None]:
import logging
from pathlib import Path

import yaml
from astropy.table import Table

log = logging.getLogger(__name__)

# Formatting options passed to ``yaml.dump``: keep dict insertion order,
# use block (not flow) style, 4-space indentation and wrap at 80 columns.
YAML_DUMP_KWARGS = dict(
    sort_keys=False,
    indent=4,
    width=80,
    default_flow_style=False,
)


def convert_spectrum_chart_to_rdb(filename, overwrite=False):
    """Re-save a ChaRT spectrum text file as an RDB table.

    The input is read with astropy's generic ``"ascii"`` reader. Its first
    three columns (expected to be auto-named ``col1``, ``col2``, ``col3``,
    i.e. the file presumably has no header) are renamed to ``emin``,
    ``emax`` and ``flux``. The table is written next to the input, with the
    file suffix replaced by ``.rdb``.

    Parameters
    ----------
    filename : str or `~pathlib.Path`
        Input spectrum file.
    overwrite : bool
        Forwarded to ``Table.write``; when False an existing ``.rdb`` file
        is not replaced.
    """
    spectrum = Table.read(filename, format="ascii")

    column_names = {"col1": "emin", "col2": "emax", "col3": "flux"}
    for old_name, new_name in column_names.items():
        spectrum.rename_column(old_name, new_name)

    rdb_path = Path(filename).with_suffix(".rdb")
    log.info(f"Writing {rdb_path}")
    spectrum.write(rdb_path, format="ascii.rdb", overwrite=overwrite)


def sherpa_parameter_to_dict(par):
    """Serialise a Sherpa parameter into a plain dict.

    Every field is coerced to a builtin type (str, float, bool). This
    presumably keeps the later YAML output free of library-specific types.
    Keys appear in the order name, value, min, max, frozen, unit.
    """
    return {
        "name": str(par.name),
        "value": float(par.val),
        "min": float(par.min),
        "max": float(par.max),
        "frozen": bool(par.frozen),
        "unit": str(par.units),
    }


def sherpa_model_to_dict(model):
    """Convert a Sherpa model expression into a nested, YAML-friendly dict.

    Every node records its ``name`` and ``type``. Composite nodes are
    expanded recursively:

    * ``binaryopmodel`` (e.g. ``a * b``): ``operator``, ``lhs`` and ``rhs``.
    * ``unaryopmodel`` (e.g. ``-a``): ``operator`` and ``arg``.

    Any other model is treated as a leaf and gets a ``parameters`` list,
    one entry per ``sherpa_parameter_to_dict`` result, in ``model.pars`` order.
    """
    data = {
        "name": model.name,
        "type": model.type,
    }

    if model.type == "binaryopmodel":
        data["operator"] = str(model.opstr)
        data["lhs"] = sherpa_model_to_dict(model.lhs)
        data["rhs"] = sherpa_model_to_dict(model.rhs)
        return data

    if model.type == "unaryopmodel":
        # Without this branch the node would fall through to the leaf case:
        # the operator would be dropped and the wrapped model's parameters
        # (exposed via ``model.pars``) would be listed as if they were this
        # node's own, e.g. silently losing the sign in ``-a``.
        data["operator"] = str(model.opstr)
        data["arg"] = sherpa_model_to_dict(model.arg)
        return data

    data["parameters"] = [sherpa_parameter_to_dict(par) for par in model.pars]
    return data


def write_sherpa_model_to_yaml(model, filename, overwrite=True):
    """Write a Sherpa model to a YAML file.

    The model is converted with ``sherpa_model_to_dict`` and dumped using
    ``YAML_DUMP_KWARGS``.

    Parameters
    ----------
    model : Sherpa model
        Model (or model expression) to serialise.
    filename : str or `~pathlib.Path`
        Output YAML file.
    overwrite : bool
        If False, refuse to replace an existing file.

    Raises
    ------
    IOError
        If ``filename`` already exists and ``overwrite`` is False.
    """
    data = sherpa_model_to_dict(model)

    # Exclusive-create mode ("x") makes the existence check and the file
    # creation one atomic step, so a file appearing between a separate
    # check and the open cannot be clobbered.
    mode = "w" if overwrite else "x"

    try:
        fh = open(filename, mode, encoding="utf-8")
    except FileExistsError as err:
        raise IOError(f"File exists: {filename}") from err

    log.info(f"Writing {filename}")

    with fh:
        yaml.dump(data, fh, **YAML_DUMP_KWARGS)


In [None]:
# Start from an empty Sherpa session: clean() discards all loaded data,
# source/instrument models and settings. Note that reset() would only restore
# model parameters to their defaults, which does nothing before any source
# model has been defined.
sau.clean()

In [None]:
# Print a summary of the current Sherpa session state (data, models, fit
# settings); at this point in the notebook the session is expected to be empty.
sau.show_all()

In [None]:
# One Sherpa dataset identifier ("obs-id-<obs_id>") per configured observation.
DATASET_IDS = list(map("obs-id-{}".format, snakemake.config["obs_ids"]))

In [None]:
# Each configured obs_id must have exactly one matching input file, in the
# same order. zip() would silently drop any unmatched tail, leaving datasets
# unloaded (or files ignored), so fail loudly instead.
if len(DATASET_IDS) != len(snakemake.input):
    raise ValueError(
        f"Got {len(snakemake.input)} input files for "
        f"{len(DATASET_IDS)} configured obs_ids"
    )

for dataset_id, filename in zip(DATASET_IDS, snakemake.input):
    sau.load_data(dataset_id, filename)
    # Group channels so every bin holds at least 10 counts.
    sau.group_counts(dataset_id, 10)

# Called without a dataset id, notice() restricts the fit range of all loaded
# datasets; the 0.5-7 bounds are presumably in keV.
sau.notice(0.5, 7)

In [None]:
# Fit statistic: "cstat", a Poisson (Cash-based) likelihood suitable for the
# grouped counts data loaded above. Optimiser: "simplex" (Nelder-Mead).
sau.set_stat("cstat")
sau.set_method("simplex")

In [None]:
# Absorbed power law as the source model for every dataset. The expression
# refers to the named components "absorption" and "pwl" on every iteration,
# so all datasets share the same components and are fit jointly.
for dataset_id in DATASET_IDS:
    sau.set_source(dataset_id, sau.xsphabs.absorption * sau.powlaw1d.pwl)

# Hydrogen column density is held fixed at 0.09 (xsphabs nH is in units of
# 1e22 cm^-2); NOTE(review): confirm the source of this value.
sau.xsphabs.absorption.nh.val = 0.09
sau.xsphabs.absorption.nh.frozen = True
# Starting values for the free power-law parameters.
sau.powlaw1d.pwl.ampl.val = 0.001
sau.powlaw1d.pwl.gamma.val = 1.5

In [None]:
# Joint fit of all loaded datasets (no id given), without a pileup model.
sau.fit()

In [None]:
# Data, fitted model and residuals for each dataset, one figure per dataset.
for dataset_id in DATASET_IDS:
    sau.plot_fit_resid(dataset_id)
    plt.show()

In [None]:
# Attach the same named jdpileup component "jdp" to every dataset, so its
# parameters are shared across the joint fit.
for dataset_id in DATASET_IDS:
    sau.set_pileup_model(dataset_id, sau.jdpileup.jdp)

# Relax the lower bound of f (presumably the fraction of flux in the piled-up
# region) to 0.85. NOTE(review): document why 0.85 was chosen.
sau.jdpileup.jdp.f.min = 0.85
# Assigning a number to a model parameter attribute sets the parameter value.
# ftime is presumably the frame time in seconds and fracexp the fractional
# exposure. TODO confirm these match the observations' readout setup.
sau.jdpileup.jdp.ftime = 0.6
sau.jdpileup.jdp.fracexp = 0.987

In [None]:
# Refit all datasets jointly, now including the pileup model.
sau.fit()

In [None]:
# Data, fitted model (with pileup) and residuals for each dataset.
for dataset_id in DATASET_IDS:
    sau.plot_fit_resid(dataset_id)
    plt.show()

In [None]:
# Express every dataset in energy units with rate values; factor=1
# presumably scales the y values by energy (see set_analysis docs).
for dataset_id in DATASET_IDS:
    sau.set_analysis(dataset_id, "energy", "rate", factor=1)

# Energy range of the exported spectrum (mirrors the fit range set earlier).
e_min, e_max = 0.5, 7

# Only the first dataset is exported, to the first Snakemake output path,
# replacing any existing file.
save_chart_spectrum(
    str(snakemake.output[0]),
    elow=e_min,
    ehigh=e_max,
    clobber=True,
    id=DATASET_IDS[0],
)

In [None]:
# Serialise the fitted source model to the second Snakemake output path. Every
# dataset was given the same source expression, so the first dataset's model
# represents them all.
write_sherpa_model_to_yaml(sau.get_source(id=DATASET_IDS[0]), filename=str(snakemake.output[1]))