## This notebook accompanies the `trial_nondetection_MAXI.py` file included in this directory and adds further comments.

### This analysis can produce $\sim 330$ GB of data so be sure that there is enough storage on your computer.
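Before starting, it may be worth confirming that the target disk has room for the data products. The following is a minimal sketch using only the standard library. The helper name `has_enough_space` and the choice of the current directory are assumptions made for illustration, and the 330 GB figure is simply the estimate quoted above.

```python
import shutil


def has_enough_space(path=".", required_gb=330):
    """Return True if the filesystem holding `path` has at least `required_gb` GB free."""
    free_gb = shutil.disk_usage(path).free / 1e9
    return free_gb >= required_gb


# hypothetical check against the current working directory
if not has_enough_space("."):
    print("Warning: less than ~330 GB free in the current directory.")
```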

In this notebook, we will go through the code to produce Figure 6 of the associated BAT survey paper. This example outlines how to analyze BAT survey data to obtain a light curve/flux upper limits for a newly identified source, such as MAXI J0637-430. 

First, we need to import our usual Python packages.

In [None]:
import glob
import os
import sys
import batanalysis as ba
import matplotlib.pyplot as plt
import numpy as np
import astropy.units as u
from astropy.time import Time, TimeDelta
from astropy.io import fits
from pathlib import Path
import swiftbat
import pickle
from matplotlib import ticker

Then we need to create a custom catalog file with the MAXI source. 

In [1]:
object_name='MAXI J0637-430'

#define the coordinates in RA/Dec & galactic lon/lat
object_ra=99.09830
object_dec=-42.86781
#galactic longitude spans 0-360 deg and latitude spans -90 to 90 deg, so 251.5 must be the longitude
object_glon=251.51841
object_glat=-20.67087

incat=ba.create_custom_catalog(object_name, object_ra, object_dec, object_glon, object_glat)

If we were continuing an analysis and the above cell was already run, we do not need to create the custom catalog. Instead we can simply do:
```
incat=Path("path/to/custom_catalog.cat")
```

Now that we have our catalog of sources with the recently identified MAXI source, we can search for all the BAT survey datasets that have this object in the BAT FOV with a partial coding fraction of at least ~19% (i.e., the area of the detector plane that is exposed to the source coordinates is $> 1000$ cm$^2$).
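The correspondence between ~19% and 1000 cm$^2$ can be checked with a short calculation. This sketch assumes the nominal BAT detector plane of 32,768 CZT elements of 4 mm x 4 mm each, and it ignores disabled detectors and the cosine projection that the `minexposure` cut accounts for, so the result is only approximate.

```python
# nominal BAT detector plane (assumed: all detectors enabled, on-axis source)
N_DETECTORS = 32768
DETECTOR_AREA_CM2 = 0.4 * 0.4  # each CZT element is 4 mm x 4 mm

total_area = N_DETECTORS * DETECTOR_AREA_CM2  # roughly 5243 cm^2
min_exposed_area = 1000.0  # cm^2, the threshold used in this notebook

fraction = min_exposed_area / total_area
print(f"Exposed fraction at threshold: {fraction:.1%}")
```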

We query HEASARC for all observations from 2019-11-01 to 2020-01-30, which is when the source was first found and was undergoing spectral transitions. Then, we filter these observations to determine which meet our `minexposure` criterion and download them into the current working directory. We then obtain the observation IDs for the successfully downloaded datasets and exclude some problematic observation IDs.

In [None]:
object_batsource = swiftbat.source(ra=object_ra, dec=object_dec, name=object_name)
table_everything, query = ba.from_heasarc(time_range=Time(["2019-11-01","2020-01-30"]), return_query=True)


minexposure = 1000     # cm^2 after cos adjust
exposures = u.Quantity([object_batsource.exposure(ra=row['ra'], dec=row['dec'], roll=row['roll_angle'])[0] for row in table_everything])
table_exposed = table_everything[exposures.value > minexposure]

result = ba.download_swiftdata(table_exposed)

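# observation IDs that are known to be problematic and are excluded from the analysis
bad_obs_ids = ["00012012026", "00012172020", "00035344062", "00045604023",
               "00095400024", "03102102001", "03109915005", "03110367008"]

obs_ids=[i for i in table_exposed['obsid'] if result[i]['success'] and i not in bad_obs_ids]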

If the user is continuing an analysis, they do not need to do the whole querying/downloading part of the workflow again. They can simply loop through the observation IDs that have already been downloaded and analyzed by doing:
```
obs_ids=[i.parent.name.split("_")[0] for i in sorted(ba.datadir().glob("*_surveyresult/batsurvey.pickle"))]
```
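To make the one-liner above easier to follow, here is how the observation ID is recovered from a single result path. The directory path is a hypothetical example, and `obsid_from_result` is a helper name introduced only for this sketch.

```python
from pathlib import Path


def obsid_from_result(pickle_path):
    """Extract the observation ID from a '<obsid>_surveyresult/batsurvey.pickle' path."""
    # the parent directory is named '<obsid>_surveyresult'; keep the part before '_'
    return Path(pickle_path).parent.name.split("_")[0]


# hypothetical location of a previously analyzed observation
example = Path("/data/00012012026_surveyresult/batsurvey.pickle")
print(obsid_from_result(example))
```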


With the data downloaded, we can now craft our dictionary of `batsurvey` key/value pairs that we want passed to the HEASoft function, then set the path to where the pattern maps are located and then call the parallelized analysis function. 

In setting our `batsurvey` key/value pairs, we specify the detector thresholds such that the number of active BAT detectors is $>8000$. This value was obtained by analyzing the observations with `detthresh` and `detthresh2` set to 9000, which is the default value, and looking through the failure messages that were saved to each BatSurvey object.

In [None]:
input_dict=dict(cleansnr=6,cleanexpr='ALWAYS_CLEAN==T', incatalog=f"{incat}", detthresh=8000, detthresh2=8000)
noise_map_dir=Path("/path/to/PATTERN_MAPS/")
batsurvey_obs=ba.parallel.batsurvey_analysis(obs_ids, input_dict=input_dict, patt_noise_dir=noise_map_dir, nprocs=10)

Similar to our other analyses, we can now calculate the spectrum for each pointing and the detector response function, and subsequently fit the spectra with the default `cflux*po` model in XSPEC. Here, we set `use_cstat=True` since we expect our MAXI source to have low enough counts that XSPEC needs to take Poisson statistics into account. We also set `ul_pl_index=2`, which fixes the photon index of the power law used to obtain a flux upper limit at 2. Both values are the defaults; we pass them explicitly for clarity.
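For readers unfamiliar with the C-statistic, the sketch below evaluates its textbook form for Poisson data without background, $C = 2\sum_i \left[m_i - d_i + d_i \ln(d_i/m_i)\right]$, where $d_i$ are observed counts and $m_i$ are model-predicted counts. This is only an illustration of the statistic; the function name `cash_statistic` is an assumption, and the exact variant XSPEC applies depends on the spectrum and background files it is given.

```python
import math


def cash_statistic(observed, model):
    """Textbook C-statistic for Poisson counts with no background term."""
    total = 0.0
    for d, m in zip(observed, model):
        if m <= 0:
            raise ValueError("model-predicted counts must be positive")
        total += m - d
        if d > 0:
            # the d * ln(d / m) term vanishes in the limit d -> 0
            total += d * math.log(d / m)
    return 2.0 * total


# the statistic is zero when the model matches the data exactly
print(cash_statistic([3, 1, 5], [3.0, 1.0, 5.0]))
```

Unlike $\chi^2$, this statistic remains well defined for bins with zero counts, which is why it is preferred in the low-count regime.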

In [None]:
batsurvey_obs=ba.parallel.batspectrum_analysis(batsurvey_obs, object_name, use_cstat=True, ul_pl_index=2, nprocs=14)

Next, we can create our outventory file and define our time bins for mosaicing the BAT survey data. Notice that we pass our custom catalog file into `batmosaic_analysis` so that it can search for the new MAXI source that we are interested in.

Remember that when continuing an analysis, the outventory file can just be set to the path of the previously created outventory file (similar to defining the custom catalog when continuing an analysis).

We only use `nprocs=3` here since each process uses ~10 GB and our laptop has limited memory.
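A rough way to pick the number of processes for a given machine is to divide the available memory by the per-process footprint. In this sketch, the 32 GB total and the 2 GB reserve are illustrative assumptions, and `max_processes` is a hypothetical helper; only the ~10 GB per process comes from the text above.

```python
def max_processes(available_memory_gb, per_process_gb=10, reserve_gb=2):
    """Estimate how many mosaic processes fit in memory, keeping a small reserve."""
    usable = available_memory_gb - reserve_gb
    # always allow at least one process, even on small machines
    return max(1, int(usable // per_process_gb))


# assumed example: a laptop with 32 GB of RAM
print(max_processes(32))
```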

In [None]:
outventory_file=ba.merge_outventory(batsurvey_obs)
time_bins=ba.group_outventory(outventory_file, np.timedelta64(1, "W"))
mosaic_list, total_mosaic=ba.parallel.batmosaic_analysis(batsurvey_obs, outventory_file, time_bins, catalog_file=incat, nprocs=3)
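To get a feel for how many weekly mosaics to expect, the sketch below counts one-week bins across the queried date range using only the standard library. It is an independent illustration: `weekly_edges` is a hypothetical helper, and `ba.group_outventory` works from the pointings actually present in the outventory file, so its bin edges may differ.

```python
from datetime import date, timedelta


def weekly_edges(start, stop):
    """Return one-week bin edges starting at `start` and covering `stop`."""
    edges = [start]
    while edges[-1] < stop:
        edges.append(edges[-1] + timedelta(weeks=1))
    return edges


edges = weekly_edges(date(2019, 11, 1), date(2020, 1, 30))
print(f"{len(edges) - 1} weekly bins from {edges[0]} to {edges[-1]}")
```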

Now we can conduct our spectral analyses for the source of interest in the weekly mosaics and in the total mosaic image, which combines the full time range.

In [None]:
mosaic_list=ba.parallel.batspectrum_analysis(mosaic_list, object_name, use_cstat=True, nprocs=5)
total_mosaic=ba.parallel.batspectrum_analysis(total_mosaic, object_name, use_cstat=True, nprocs=1)

We can also now take a look at the results of the survey analyses and the mosaic analyses (although we could have looked at the results of the survey analyses earlier to get some insight into the analyses and any issues that may have popped up).

In [None]:
fig, axes=ba.plot_survey_lc([batsurvey_obs,mosaic_list], id_list= object_name, time_unit="UTC", values=["rate","snr", "flux", "PhoIndex", "exposure"], same_figure=True)

The next cell concatenates information from the BAT survey analyses and the mosaic analyses, saves it to disk, and plots it in a publication-quality figure.

In [None]:
# save the data in a dictionary for convenient custom plotting for publication quality figures
all_data=ba.concatenate_data(batsurvey_obs, object_name, ["met_time", "utc_time", "exposure", "rate","rate_err","snr", "flux", "PhoIndex"])

with open('all_data_dictionary.pkl', 'wb') as f:
     pickle.dump(all_data, f)
     
all_data_weekly=ba.concatenate_data(mosaic_list, object_name, ["user_timebin/met_time", "user_timebin/met_stop_time", "user_timebin/utc_time", "user_timebin/utc_stop_time", "exposure", "rate","rate_err","snr", "flux", "PhoIndex"])

with open('weekly_mosaic_dictionary.pkl', 'wb') as f:
     pickle.dump(all_data_weekly, f)

#make the plot
energy_range=None
time_unit="MET"
values=["rate", "snr", "flux"]

survey_obsid_list=["all_data_dictionary", "weekly_mosaic_dictionary"]

obs_list_count=0
for observation_list in survey_obsid_list:

    with open(observation_list+".pkl", 'rb') as f:
        all_data=pickle.load(f)
        data=all_data[object_name]

    # get the time centers and errors
    if "mosaic" in observation_list:

        if "MET" in time_unit:
            t0 = TimeDelta(data["user_timebin/met_time"], format='sec')
            tf = TimeDelta(data["user_timebin/met_stop_time"], format='sec')
        else:
            # MJD and UTC both start from the stored UTC bin edges;
            # the conversion to MJD happens below
            t0 = Time(data["user_timebin/utc_time"])
            tf = Time(data["user_timebin/utc_stop_time"])
    else:
        if "MET" in time_unit:
            t0 = TimeDelta(data["met_time"], format='sec')
        else:
            t0 = Time(data["utc_time"])
        tf = t0 + TimeDelta(data["exposure"], format='sec')

    dt = tf - t0

    if "MET" in time_unit:
        time_center = 0.5 * (tf + t0).value
        time_diff = 0.5 * (tf - t0).value
    elif "MJD" in time_unit:
        time_diff = 0.5 * (tf - t0)
        time_center = (t0 + time_diff).mjd
        time_diff = time_diff.jd  # half-width of the bin in days

    else:
        time_diff = TimeDelta(0.5 * dt)  # dt.to_value('datetime')
        time_center = t0 + time_diff

        time_center = np.array([i.to_value('datetime64') for i in time_center])
        time_diff = np.array([np.timedelta64(0.5 * i.to_datetime()) for i in dt])

    x = time_center
    xerr = time_diff

    if obs_list_count == 0:
        fig, axes = plt.subplots(len(values), sharex=True) #, figsize=(10,12))

    axes_queue = [i for i in range(len(values))]
    # plot_value=[i for i in values]

    e_range_str = f"{14}-{195} keV"
    #axes[0].set_title(object_name + '; survey data from ' + e_range_str)

    for i in values:
        ax = axes[axes_queue[0]]
        axes_queue.pop(0)

        y = data[i]
        yerr = np.zeros(x.size)
        y_upperlim = np.zeros(x.size)

        label = i

        if "rate" in i:
            yerr = data[i + "_err"]
            label = "Count rate (cts/s)"
        elif i + "_lolim" in data.keys():
            # get the errors
            lolim = data[i + "_lolim"]
            hilim = data[i + "_hilim"]

            yerr = np.array([lolim, hilim])
            y_upperlim = data[i + "_upperlim"]

            # find where we have upper limits and set the error bar to 20% of the value,
            # since the nan error value isn't compatible with upper limits
            yerr[:, y_upperlim] = 0.2 * y[y_upperlim]

        if "mosaic" in observation_list:
            if "weekly" in observation_list:
                zorder = 9
                c = "blue"
                m = "o"
                l="Weekly Mosaic"
                ms=5
                a=0.8
            else:
                zorder = 9
                c='green'
                m = "s"
                l = "Monthly Mosaic"
                ms=7
                a = 1
        else:
            zorder = 4
            c = "gray"
            m = "."
            l = "Survey Snapshot"
            ms=3
            a = 0.3

        ax.errorbar(x, y, xerr=xerr, yerr=yerr, uplims=y_upperlim, linestyle="None", marker=m, markersize=ms,
                    zorder=zorder, color=c, label=l, alpha=a)
                    
        #plt.gca().ticklabel_format(useMathText=True)
        ax.xaxis.set_major_formatter(ticker.ScalarFormatter(useMathText=True))


        if ("flux" in i.lower()):
            ax.set_yscale('log')

        if ("snr" in i.lower()):
            ax.set_yscale('log')

        ax.set_ylabel(label)

    # if T0==0:
    if "MET" in time_unit:
        label_string = 'MET Time (s)'
    elif "MJD" in time_unit:
        label_string = 'MJD Time (days)'
    else:
        label_string = 'UTC Time'

    axes[-1].set_xlabel(label_string)
    
    obs_list_count += 1


#add the UTC times as well
utc_time=Time(["2019-11-01", "2019-12-01", "2020-01-01", "2020-01-30"])
met_time=[]
for i in utc_time:
    met_time.append(swiftbat.datetime2met(i.datetime, correct=True))

for i,j in zip(met_time, utc_time.ymdhms):
    for ax in axes:
        ax.axvline(i, 0, 1, ls='--', color='k')
        if ax==axes[0]:
            ax.text(i, ax.get_ylim()[1]*1.03, f'{j["year"]}-{j["month"]}-{j["day"]}', fontsize=10, ha='center')

axes[1].set_ylabel("SNR")
axes[2].set_ylabel(r"Flux (erg/s/cm$^2$)")

axes[1].legend(loc= "lower center", ncol=2)

for ax, l in zip(axes, ["a","b","c","d"]):
    ax.text(1.0, .95, f"({l})", ha='right', va='top', transform=ax.transAxes,  fontsize=12)

fig.tight_layout()
plot_filename = object_name + '_survey_lc.pdf'
fig.savefig(plot_filename, bbox_inches="tight")
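Since the concatenated results are pickled, the figure can be remade later without rerunning the analysis. The sketch below shows the round trip on a stand-in dictionary with the same layout used above (`all_data[object_name][key]`). The numbers in `example` are placeholders invented for illustration, not results, and a temporary directory is used so that no real file is overwritten.

```python
import pickle
import tempfile
from pathlib import Path

object_name = "MAXI J0637-430"

# placeholder values only, standing in for the real concatenated data
example = {object_name: {"rate": [0.001, 0.002], "snr": [1.2, 2.5]}}

with tempfile.TemporaryDirectory() as tmp:
    pkl = Path(tmp) / "all_data_dictionary.pkl"
    with open(pkl, "wb") as f:
        pickle.dump(example, f)
    with open(pkl, "rb") as f:
        reloaded = pickle.load(f)

# the per-source dictionary is indexed by the quantity names
data = reloaded[object_name]
print(sorted(data))
```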