In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import astropy.io.fits as fits
import sys
import builtins
import corner
import astromodels
import yaml
import os
import re
from ultranest.integrator import read_file
from ultranest.plot import runplot, traceplot

from pathlib import Path
from time import perf_counter
from emcee.autocorr import integrated_time

from setup_define_scripts import (load_data_from_yaml, load_params_priors_from_yaml, 
                                  build_components_from_yaml, build_spectrum, build_model_and_data_from_yaml,
                                  kev_plot_xylike_with_ind_model, hz_plot_xylike_data, 
                                  plot_ogip_with_model, hz_plot_model_space_sed,quick_eval,
                                  hz_eval_and_plot_sed, add_bhjet_radiative_components_to_plot, 
                                  )


from threeML.io.logging import silence_warnings
silence_warnings()

from threeML import *
from threeML import (
    SmoothlyBrokenPowerLaw,TbAbs,ZDust,
    Constant, PointSource,Model,DataList,Log_uniform_prior,
    Uniform_prior,Truncated_gaussian,Log_normal,
)

from threeML.utils.OGIP.response import OGIPResponse
%matplotlib inline
_suffix_re = re.compile(r"_(\d+)$")

sys.path.append("PyBHJet/")
from pybhjet_3ml import BHJetModel
from importing_data import * #converts normal data into expected threeml units 
from unit_conversion import * #unit conversions after bhjet is run
from bhjet_plotting import * #controlling the output 

def _bhjet_safe_clone_model(model_instance):
    """Replacement for ``astromodels.clone_model`` that does not clone.

    The object passed in is handed straight back: no copy is made, so any
    caller that "clones" a model and then mutates the clone mutates the
    original model as well.
    """
    # NOTE(review): presumably installed because copying a model containing
    # BHJetModel fails -- confirm. The patch only affects lookups made through
    # the `astromodels.clone_model` attribute; any module that already did
    # `from astromodels import clone_model` (threeML is imported above this
    # point) keeps the original function.
    same_model = model_instance
    return same_model


setattr(astromodels, "clone_model", _bhjet_safe_clone_model)


# Plot labels keyed by the bare parameter names produced by reduce_to_tail_key
# (e.g. "src_xray.spectrum.main.composite.jetrat_1" -> "jetrat"). Lookups use
# latex_map.get(key, key), so parameters missing here fall back to the raw key.
latex_map = {
    "pspec": r"$p$",
    "z_diss": r"$z_{\rm diss}$",
    "z_acc": r"$z_{\rm acc}$",
    "jetrat": r"$N_j$",
    "r": r"$r$",
    "t_e": r"$T_e$",
    "f_heat": r"$f_{\rm heat}$",
    "f_sc": r"$f_{\rm sc}$",
    # ZDust's e_bmv is the colour excess E(B-V), not the extinction A_V
    # (A_V = R_V * E(B-V)); label it as what it is.
    "e_bmv": r"$E(B-V)$",
}


##### 0. If importing the BHJet model or any other script in the PyBHJet directory fails, check that the path added with `sys.path.append` points to the main PyBHJet directory.


##### Step 1: First, to load the data, this is done via a yaml file: ```model_data_load.yaml```
 
 The first section is ```data```, which accepts two kinds of data: XYLike (flux) and OGIP (counts). The section names you give these two entries are the names used to refer to and plot them later (e.g. currently "flux_data_catgegory" and "ogip_data_category").

 _________________

 ##### **Loading flux data**: 
1. ```kind```: should always be "combine_dataframes". This passes the data to a function in importing_data.py that expects each file to contain the columns (hz, mJy, mJy_err).
    1. It also expects that your data is divided into files (can be multiple) for radio, IR, and Opt-UV, (e.g. NGC4594_radio.dat, NGC4594_IR.dat, NGC4594_UV.dat). 
    2. It converts the data to the units threeML expects (differential photon flux vs. keV) and returns three XYLike objects, named rad, ir, and uv in the code.
2. ```directory```: "path/to/files" (from where the yaml is located) 
3. ```extension```: suffix of your file type 
4. ```columns```: this can be left as is, if your data is correctly supplied as (hz, mJy, mJy_err) 

_________________

##### **Loading OGIP data** 
1. ```kind```: will always be set to "ogip" 
2. ```observation```: "path/to/files" (from where the yaml is located)
3. ```arf_file```:     "path/to/files" 
4. ```response```:     "path/to/files" 
5. ```background```:   "path/to/files" 
6. ```energy_range```: input 
7. ```rebin_on_source```: input 

_________________

##### **Assigning Sources**: 

Each dataset needs to be turned into its own "source" in threeML, so the next section creates a source for each one (e.g. src_radio) and assigns it the loaded dataset(s) (named rad, ir, uv) and a "spectrum". RA and Dec can be left at 0, 0.

ra: 0.0 \
dec: 0.0 \
spectrum: radio_model \
datasets: [rad] 


ra: 0.0 \
dec: 0.0 \
spectrum: iruv_model \
datasets: [ir, uv]



##### Step 2: Model Definition, this is done via a yaml file: ```model_data_load.yaml```

The second section is ```model```. This will include as few/many components as you plan to use for fitting all your data. This is where the model objects are assigned to names. 

In ```components```, you must use the exact name provided by PyBHJet/XSPEC/threeML. Do not include parentheses after the name! You can then assign each model component to a "colloquial" name that you'll use, e.g.:

jet: BHJetModel \
gal_ext: TbAbs \
intr_ext: TbAbs \
dust_ext: ZDust \

_________________

Then, in ```spectra```, you'll decide which components you want applied to which datasets. The way that you form these model compositions will depend on the names that you gave them in ```components```. 

radio_model: jet \
iruv_model: dust_ext * jet \
xray_model: gal_ext * intr_ext * jet
_________________

Third, ```sed_components``` repeats the ```spectra``` definitions but is used for plotting. Eventually this will be made consistent so that the earlier definitions are simply reused:

jet: jet \
dust_times_jet: dust_ext * jet \
abs_times_jet: gal_ext * intr_ext * jet

_________________

##### Step 3: Model Parameters, this is also done via a yaml file: ```model_data_load.yaml```

Create a section under ```parameters``` for each model ```component``` created earlier. This is where each individual parameter's value, free/fixed status, bounds, and prior are set:

value: 5.9 \
free: true \
bounds: [2.0, 10.0] \
prior: {type: uniform, min: 2.0, max: 10.0}


##### **Actually Loading Data & Model**: 

In [None]:
# A single YAML file holds both the `data` and the `model` sections, so both
# paths point at the same file.
data_yaml_path = model_yaml_path = "ngc4594/model_data_load.yaml"

(
    model_obj,
    data_obj,
    model_components,
    data_dict,
    sources,
    sed_comp,
) = build_model_and_data_from_yaml(data_yaml_path, model_yaml_path)

A shortcut for inspecting the parameters of each component you assigned:

In [None]:
# Direct handles to the components named in the YAML `components` section, so
# their parameters can be inspected and tweaked in the following cells.
jet, gal_ext, dust_ext, intr_ext = [
    model_components[component_name]
    for component_name in ("jet", "gal_ext", "dust_ext", "intr_ext")
]

In [None]:
# Show the jet component's free parameters (last expression -> rendered as the cell output).
jet.free_parameters

In [None]:
# Model-space SED: the data together with the `sed_components` compositions,
# plus the BHJet radiative components overlaid on the same axes.
SED_XLIM_HZ = (1e8, 1e20)
SED_YLIM = (1e-17, 1e-10)

fig, ax = hz_plot_model_space_sed(data_dict, model_components, sed_comp)
add_bhjet_radiative_components_to_plot(model_components, ax)

# Set limits on the returned axes explicitly: plt.xlim/plt.ylim act on whatever
# axes pyplot considers "current", which is not guaranteed to be `ax`.
ax.set_xlim(*SED_XLIM_HZ)
ax.set_ylim(*SED_YLIM)
plt.show()

In [None]:
# Overlay the current model on the OGIP (counts) dataset, labelled as Chandra.
fig = plot_ogip_with_model(
    data_dict["ogip_data_category"],
    model_obj=model_obj,
    model_labels=["Chandra"],
)

In [None]:
# Evaluate the current model against the data and plot the SED. Going by the
# names, stat_total / per_stats are presumably the total and per-dataset fit
# statistics -- confirm against hz_eval_and_plot_sed.
SED_XLIM_HZ = (1e8, 1e20)
SED_YLIM = (1e-17, 1e-10)

fig, ax, stat_total, per_stats = hz_eval_and_plot_sed(
    model_obj, data_dict, model_components, sed_comp
)

# Explicit axes instead of pyplot's "current axes" state machine.
ax.set_xlim(*SED_XLIM_HZ)
ax.set_ylim(*SED_YLIM)
plt.show()

#### Testing MCMC 

In [None]:
# emcee smoke test: the burn-in / iteration counts are deliberately tiny. They
# only check that the sampler can be configured and run; they are not enough
# for a converged posterior.
RUN_EMCEE = False            # set True to actually run the sampler
EMCEE_SEED = 42
EMCEE_BURN_IN = 5
EMCEE_N_SAMPLES = 10
EMCEE_OUTPUT = "emcee_bhjet_test.fits"   # distinct from the UltraNest output file

bhjet_ba = BayesianAnalysis(model_obj, data_obj)
bhjet_ba.set_sampler("emcee")

n_dim = len(model_obj.free_parameters)
# emcee's default stretch move needs at least 2 * n_dim walkers; 3 * n_dim leaves margin.
n_walkers = 3 * n_dim

burn_in = EMCEE_BURN_IN
n_samples = EMCEE_N_SAMPLES

bhjet_ba.sampler.setup(
    n_walkers=n_walkers,
    share_spectrum=True,
    n_burn_in=burn_in,
    n_iterations=n_samples,
    seed=EMCEE_SEED,
)

if RUN_EMCEE:
    # Use this analysis' own results (the old commented code read `ba`, which
    # is the UltraNest analysis defined in the next cell).
    bhjet_ba.sample()
    res_emcee = bhjet_ba.results
    res_emcee.write_to(EMCEE_OUTPUT, overwrite=True)

Testing Nested Sampling

In [None]:
import ultranest.stepsampler

RUN_ULTRANEST = False                         # set True to actually run nested sampling
ULTRANEST_OUTPUT = "ultranest_bhjet_test.fits"  # distinct from the emcee output file

ba = BayesianAnalysis(model_obj, data_obj)
ba.set_sampler("ultranest")

# Slice step sampler with 2 * n_free_parameters steps and mixed random directions.
n_steps = 2 * len(model_obj.free_parameters)
ba.sampler.setup(
    stepsampler=ultranest.stepsampler.SliceSampler(
        nsteps=n_steps,
        generate_direction=ultranest.stepsampler.generate_mixture_random_direction,
    )
)

if RUN_ULTRANEST:
    ba.sample()
    res_ultranest = ba.results
    res_ultranest.write_to(ULTRANEST_OUTPUT, overwrite=True)

Loading and applying the UltraNest output:

In [None]:
def _strip_threeml_index_from_tail(full_name: str) -> str:
    """Drop the trailing ``_<digits>`` index threeML adds to a parameter name.

    Only the last ``.``-separated component is touched, e.g.
    ``src.spectrum.main.composite.jetrat_1`` -> ``src.spectrum.main.composite.jetrat``;
    names without a dot are treated as a single component.
    """
    prefix, separator, last_piece = full_name.rpartition(".")
    return prefix + separator + _suffix_re.sub("", last_piece, count=1)

def reduce_to_tail_key(full_name: str) -> str:
    """Collapse a full threeML parameter path to its bare parameter name.

    The result is the key used for plot labels (see ``latex_map``), e.g.
    ``src_xray.spectrum.main.composite.jetrat_1`` -> ``jetrat``.
    """
    index_free = _strip_threeml_index_from_tail(full_name)
    return index_free.rpartition(".")[2]

def set_model_un_output(model_obj, post_path, value_col: str = "median"):
    """Set ``model_obj``'s free parameters from a posterior summary table.

    Parameters
    ----------
    model_obj
        Model whose ``free_parameters`` are updated in place.
    post_path : pandas.DataFrame
        Despite its name, this is the posterior *summary table* (the first
        return value of ``return_chains_eq_post(..., summary=True)``). It must
        have a ``"name"`` column and a ``value_col`` column.
    value_col : str
        Summary column copied into the parameters (e.g. "median", "mean").

    Each row is matched to a not-yet-assigned free parameter, in model order.
    The match is first by full name (ignoring threeML's ``_<n>`` index suffix).
    Failing that, it is by bare parameter name (the form
    ``return_chains_eq_post`` stores in ``"name"``). Because each parameter is
    used at most once, repeated bare names are assigned in order of
    appearance. Rows that match nothing are reported with a warning.
    """
    import warnings

    # Use the table that was passed in (previously the global `summary_eq` was
    # read instead, so the argument was silently ignored).
    summary = post_path
    free_params = model_obj.free_parameters  # ordered mapping: full name -> parameter

    # Free parameters still waiting for a value, in model order.
    # NOTE(review): duplicate bare names rely on the posterior columns being in
    # the same order as model_obj.free_parameters -- confirm for your setup.
    unassigned = list(free_params.keys())
    unmatched = []

    for name, val in zip(summary["name"].values, summary[value_col].values):
        base = _strip_threeml_index_from_tail(name)
        tail = reduce_to_tail_key(name)

        full = next(
            (f for f in unassigned if _strip_threeml_index_from_tail(f) == base),
            None,
        )
        if full is None:
            # Summary names are usually bare keys like "jetrat", which never equal
            # a full path such as "src_xray.spectrum.main.composite.jetrat".
            full = next((f for f in unassigned if reduce_to_tail_key(f) == tail), None)
        if full is None:
            unmatched.append(name)
            continue

        unassigned.remove(full)
        free_params[full].value = float(val)

    if unmatched:
        warnings.warn(
            "set_model_un_output: no free parameter matched summary rows: "
            + ", ".join(map(str, unmatched))
        )


In [None]:
def return_chains_eq_post(path, summary=False, x_dim=None):
    """Load UltraNest's equal-weighted posterior samples from ``path + "chains"``.

    Parameters
    ----------
    path : str
        Output prefix. ``path + "chains"`` must be the UltraNest log directory,
        i.e. the one containing ``chains/equal_weighted_post.txt``.
    summary : bool
        If True, return per-parameter summary statistics instead of raw samples.
    x_dim : int, optional
        Number of sampled dimensions passed to ``ultranest.integrator.read_file``
        (summary mode only). Defaults to the number of columns in the
        equal-weighted posterior file. That assumes the run has no derived
        parameters; pass it explicitly otherwise.

    Returns
    -------
    summary=False : (samples, labels, chains_path)
        ``samples``: float array with one row per posterior sample and one
        column per parameter. ``labels``: plot labels from ``latex_map``, falling
        back to the bare parameter name. ``chains_path``: ``pathlib.Path`` of the
        log directory.
    summary=True : (summary_eq, sequence, chains_path)
        ``summary_eq``: DataFrame with columns name/min/max/mean/median/std, one
        row per parameter. ``sequence``: first value returned by ``read_file``.
        ``chains_path``: as above.
    """
    chains_path = Path(path + "chains")
    eq_wpost_path = chains_path / "chains" / "equal_weighted_post.txt"

    # Raw string: "\s" in a plain literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning on Python 3.12+).
    equ_post_df = pd.read_csv(eq_wpost_path, sep=r"\s+", comment="#")
    # Collapse full threeML parameter paths to bare keys so they match latex_map.
    equ_post_df.columns = [reduce_to_tail_key(c) for c in equ_post_df.columns]

    if summary:
        # Derive the dimensionality from the file itself instead of the global
        # `model_obj`, so the function does not depend on notebook state.
        ndim = equ_post_df.shape[1] if x_dim is None else x_dim
        sequence, _final = read_file(chains_path, x_dim=ndim)

        summary_eq = pd.DataFrame(
            {
                "name": equ_post_df.columns,
                "min": equ_post_df.min(axis=0).values,
                "max": equ_post_df.max(axis=0).values,
                "mean": equ_post_df.mean(axis=0).values,
                "median": equ_post_df.median(axis=0).values,
                "std": equ_post_df.std(axis=0).values,
            }
        )
        return summary_eq, sequence, chains_path

    labels = [latex_map.get(c, c) for c in equ_post_df.columns]
    samples = equ_post_df.to_numpy(dtype=float)
    return samples, labels, chains_path


In [None]:
# `path` is not defined in any earlier cell. It must be the UltraNest output
# prefix such that `path + "chains"` is the log directory. Fail with a clear
# message on a fresh kernel instead of an opaque NameError.
if "path" not in globals():
    raise NameError(
        "Define `path` (UltraNest output prefix; `path + 'chains'` must be the "
        "UltraNest log directory) before loading the posterior."
    )

summary_eq, sequence, chains_path = return_chains_eq_post(path, summary=True)
set_model_un_output(model_obj, summary_eq)