Make a cumulative corner plot by reading the MCMC posteriors of all the fitted bursts.

In [None]:
from emcee.autocorr import AutocorrError
import pylab as plt
import numpy as np
%matplotlib inline
import glob
import emcee
import pandas as pd
from tqdm import tqdm 

In [None]:
def radiometer(tsys, gain, bandwidth, time, npol=2):
    """Radiometer equation: noise level tsys / gain / sqrt(npol * bandwidth * time).

    Works element-wise on numpy arrays (e.g. an array of bandwidths).
    Callers in this notebook pass tsys=27 and gain=10, presumably in K and K/Jy,
    with bandwidth in Hz and time in seconds — TODO confirm units.
    """
    system_equivalent = tsys / gain
    return system_equivalent / np.sqrt(npol * bandwidth * time)

In [None]:
def samples2params(samples, meta_info):
    """Convert one component's raw MCMC samples into physical parameters.

    A scattering timescale tau (column 5) is kept when more than half of the
    samples satisfy ``samples[:, 4] / samples[:, 5] < 6``. In that case only
    those samples are retained. Otherwise column 5 is deleted and only samples
    with a ratio ``> 6`` are retained.

    Parameters
    ----------
    samples : np.ndarray
        Flat chain with 7 columns for one component. The column order is
        presumably (mu_f, sigma_f, S, mu_t, sigma_t, tau, DM) in fit units,
        matching the returned labels — TODO confirm.
    meta_info : dict
        Burst metadata. Reads ``fileheader`` (``fch1``, ``native_foff``,
        ``native_tsamp``, ``tstart``), ``nstart`` and ``mask``.

    Returns
    -------
    samples : np.ndarray
        Converted samples (7 columns with tau, 6 without).
    param_list : list of str
        LaTeX labels for the returned columns.
    mask : np.ndarray of bool
        Which rows of the input were kept.
    """
    ratio_below = samples[:, 4] / samples[:, 5] < 6
    fraction = np.sum(ratio_below) / samples.shape[0]
    print(f"tau fraction {fraction:.3f}")
    use_tau = fraction > 0.5
    if use_tau:
        print("Using tau")
        mask = ratio_below
        samples = samples[mask, :]
    else:
        # Ratio computed on the original 7-column array, before tau is dropped.
        mask = samples[:, 4] / samples[:, 5] > 6
        samples = np.delete(samples, 5, 1)[mask, :]

    header = meta_info["fileheader"]
    tsamp = header["native_tsamp"]

    # Frequency axis: presumably channel units -> MHz via the filterbank header.
    samples[:, 0] = header["fch1"] + samples[:, 0] * header["native_foff"]
    samples[:, 1] *= np.abs(header["native_foff"])
    # Fluence scaling uses the radiometer noise over 2.355 * sigma_f (in Hz, a
    # Gaussian-FWHM-like width). It must run after sigma_f was converted to MHz above.
    # NOTE(review): 81.92e-3 and 64 - sum(mask) (presumably the unmasked channel
    # count) are unexplained constants — confirm.
    samples[:, 2] *= (
        radiometer(27, 10, 2.355 * samples[:, 1] * 1e6, tsamp)
        * 81.92e-3
        / np.sqrt(64 - sum(meta_info["mask"]))
    )
    # Arrival time: (offset + nstart) * tsamp seconds -> days, plus tstart.
    # NOTE(review): the label below says "ms" but this value is in days (MJD-like).
    samples[:, 3] = (
        (samples[:, 3] + meta_info["nstart"]) * tsamp / 3600 / 24
    ) + header["tstart"]
    samples[:, 4] *= tsamp * 1e3
    if use_tau:
        samples[:, 5] *= tsamp * 1e3 * 81.92e-3
        # Presumably rescales tau from fch1 to a 1000 MHz reference (tau ∝ nu^-4).
        samples[:, 5] *= (1000 / header["fch1"]) ** (-4)

    labels = [
        r"$\mu_f$ (MHz)",
        r"$\sigma_f$ (MHz)",
        r"$S$ (Jy ms)",
        r"$\mu_t$ (ms)",
        r"$\sigma_t$ (ms)",
    ]
    if use_tau:
        labels.append(r"$\tau$ (ms)")
    labels.append(r"DM (pc cm$^{-3}$)")
    return samples, labels, mask


def get_chains_and_parameters(h5_filename, json_filename, thin=1):
    """Load an emcee HDF5 chain and convert it into physical parameters.

    Burn-in is ``2 * max(autocorrelation time)`` steps. When that cannot be
    estimated (``AutocorrError``/``ValueError``), only the last 25% of the
    flattened chain is kept.

    Each component is converted with ``samples2params``. A fit is treated as
    having 1, 2 or 3 components when the chain has 7, 14 or any other number of
    columns; the last component takes all remaining columns. Every component
    may drop rows, so the rows of all components are kept aligned. Only samples
    retained by every component survive.

    Parameters
    ----------
    h5_filename : str
        Path to the emcee HDF5 backend file.
    json_filename : str
        Path to the JSON metadata passed to ``samples2params``.
    thin : int, optional
        Thinning factor forwarded to ``get_chain``.

    Returns
    -------
    samples : np.ndarray
        Converted samples, components stacked column-wise.
    param_list : list of str
        Column labels. In multi-component fits each label ends with the 1-based
        component index.
    """
    reader = emcee.backends.HDFBackend(h5_filename)

    try:
        tau = reader.get_autocorr_time()
        burnin = int(2 * np.max(tau))
        print(f"burnin using tau is: {burnin}")
        samples = reader.get_chain(discard=burnin, flat=True, thin=thin)
    except (AutocorrError, ValueError):
        # Autocorrelation time unavailable: keep the last 25% of the flat chain.
        # Here burnin counts flat samples rather than steps.
        samples = reader.get_chain(discard=0, flat=True, thin=thin)
        burnin = int(samples.shape[0] * 0.75)
        samples = samples[burnin:, :]

    print("burn-in: {0}".format(burnin))
    print("flat chain shape: {0}".format(samples.shape))

    with open(json_filename, "r") as f:
        meta_info = json.load(f)

    ndim = samples.shape[-1]
    if ndim == 7:
        samples, param_list, _ = samples2params(samples, meta_info)
        return samples, param_list

    n_components = 2 if ndim == 14 else 3
    column_slices = [slice(7 * k, 7 * (k + 1)) for k in range(n_components - 1)]
    column_slices.append(slice(7 * (n_components - 1), None))

    remaining = samples  # raw rows still kept by every component so far
    converted = []  # converted blocks of earlier components, row-aligned
    param_list = []
    for index, columns in enumerate(column_slices, start=1):
        comp_samples, comp_params, mask = samples2params(
            remaining[:, columns], meta_info
        )
        # The mask refers to rows of `remaining`; apply it to the earlier blocks
        # and to the raw rows so later components stay aligned.
        converted = [block[mask] for block in converted]
        converted.append(comp_samples)
        remaining = remaining[mask]
        param_list.extend(param + str(index) for param in comp_params)
    return np.hstack(converted), param_list

In [None]:
import json 

In [None]:
PATH = "121102_paper/mcmc_final/"
# Candidate ids: HDF5 file names with the 11-character "_samples.h5" suffix removed.
cids = [h5_path.rsplit("/", 1)[-1][:-11] for h5_path in glob.glob(PATH + "*.h5")]

In [None]:
def _select_tau_components(samples, param_list):
    """Stack the 7 columns of every component that includes tau into (rows, 7)."""
    tau_label = r"$\tau$ (ms)"
    if samples.shape[1] <= 7:
        # Single-component fit: labels carry no component suffix, and the fit
        # includes tau exactly when it has 7 columns.
        return samples if samples.shape[1] == 7 else np.empty((0, 7))
    # Multi-component fit: every label ends with its 1-based component index.
    columns_by_component = {}
    for column, name in enumerate(param_list):
        columns_by_component.setdefault(name[-1], []).append(column)
    blocks = [
        samples[:, columns]
        for suffix, columns in columns_by_component.items()
        if any(param_list[c] == tau_label + suffix for c in columns)
    ]
    if not blocks:
        return np.empty((0, 7))
    return np.concatenate(blocks, axis=0)


def try_or_move_ahead_thin(cand_id, n_draws=1000):
    """Draw posterior samples for every tau component of one burst.

    Loads ``PATH + cand_id`` chain/metadata and draws up to ``n_draws`` rows
    without replacement. The draw uses a fixed seed, so it is reproducible.
    Components fitted with tau are stacked row-wise into a single
    (k * n, 7) array.

    Parameters
    ----------
    cand_id : str
        Candidate id; files are ``<cand_id>_samples.h5`` and ``<cand_id>.json``.
    n_draws : int, optional
        Maximum number of posterior rows to draw (default 1000).

    Returns
    -------
    np.ndarray
        Array with 7 columns. It is empty (0, 7) when the files are missing,
        the chain is empty or all zero, or no component has tau.
    """
    rng = np.random.RandomState(2021)
    h5_filename = PATH + cand_id + "_samples.h5"
    json_filename = PATH + cand_id + ".json"
    try:
        samples, param_list = get_chains_and_parameters(
            h5_filename, json_filename, thin=1
        )
    except FileNotFoundError:
        # Return an empty array (not a tuple) so results can still be concatenated.
        print(cand_id, "FileNotFoundError")
        return np.empty((0, 7))

    if samples.shape[0] == 0:
        return np.empty((0, 7))

    # Draw whole rows so each draw stays a joint posterior sample. Drawing each
    # column independently would destroy the parameter correlations.
    n_rows = min(n_draws, samples.shape[0])
    rows = rng.choice(samples.shape[0], size=n_rows, replace=False)
    samples = samples[rows]

    if not np.any(samples):
        return np.empty((0, 7))

    s = _select_tau_components(samples, param_list)
    print(cand_id, samples.shape, s.shape)
    return s

In [None]:
# Burst catalogue; per-burst fit bookkeeping columns are used below.
all_bursts_bary = pd.read_csv(filepath_or_buffer="../data/all_bursts_bary.csv")

In [None]:
# Rows flagged use_fluence whose fit_method is 'mcmc'.
m = all_bursts_bary["use_fluence"].eq(True) & all_bursts_bary["fit_method"].eq("mcmc")

In [None]:
# Unique candidate ids in order of first appearance. Unlike list(set(...)), this
# order does not depend on hash randomization, so the stacked samples are
# reproducible across runs.
cids_use_f = all_bursts_bary.loc[m, "cand_id"].unique().tolist()

In [None]:
from joblib import Parallel, delayed

# One job per candidate; tqdm tracks dispatch of the jobs.
samples_thinned = Parallel(n_jobs=20)(
    delayed(try_or_move_ahead_thin)(cand_id) for cand_id in tqdm(cids_use_f)
)

In [None]:
# A job may return a (cand_id, error) tuple instead of an array. Report and skip
# those so the concatenation does not fail, and handle the case of no arrays at all.
failed = [r for r in samples_thinned if not isinstance(r, np.ndarray)]
if failed:
    print(f"Skipping {len(failed)} bursts without samples: {failed}")
arrays = [r for r in samples_thinned if isinstance(r, np.ndarray)]
samples_thinned = np.concatenate(arrays, axis=0) if arrays else np.empty((0, 7))

In [None]:
# Cache the stacked draws so plotting can be redone without re-reading chains.
np.save("samples_thousand.npy", samples_thinned)

In [None]:
# Reload the cached draws.
a = np.load("samples_thousand.npy")

In [None]:
# Drop column 3 (mu_t, an absolute arrival time) from the plotted parameters.
samples_plot = a[:, [0, 1, 2, 4, 5, 6]]

In [None]:
# Axis labels for the six plotted columns of samples_plot.
param_list = [
    r"$\mu_f$ (MHz)",
    r"$\sigma_f$ (MHz)",
    r"$\log_{10}$(S (Jy ms))",
    r"$\log_{10}$($\sigma_t$ (ms))",
    r"$\tau$ (ms)",
    r"DM (pc cm$^{-3}$)",
]

In [None]:
from chainconsumer import ChainConsumer

In [None]:
np.shape(samples_plot)

In [None]:
samples_plot_new = samples_plot.copy()
# The corner-plot labels in param_list and the ranges in ext treat columns 2
# (fluence) and 3 (sigma_t) as log10 quantities, so transform them here.
# samples_plot itself stays linear for the summary statistics.
samples_plot_new[:, 2] = np.log10(samples_plot_new[:, 2])
samples_plot_new[:, 3] = np.log10(samples_plot_new[:, 3])

In [None]:
# Plot range for each corner-plot axis, keyed by its label.
ext = {
    r"$\mu_f$ (MHz)": (800, 2200),
    r"$\sigma_f$ (MHz)": (0, 1000),
    r"$\log_{10}$(S (Jy ms))": (-2, 1.5),
    r"$\log_{10}$($\sigma_t$ (ms))": (-1.5, 1),
    r"$\tau$ (ms)": (0, 4),
    r"DM (pc cm$^{-3}$)": (500, 650),
}

In [None]:
corner_plot_path = "./"
c = ChainConsumer()
c.add_chain(samples_plot_new, parameters=param_list)
# 'science' is presumably the SciencePlots matplotlib style; it must be installed.
with plt.style.context(["science"]):
    fig = c.plotter.plot(
        filename=corner_plot_path + 'cumulative_corner_plot' + ".pdf",
        figsize="grow",
        extents=ext,
        display=False,
    )

In [None]:
def get_number(x, dx):
    """Determine the rounding precision for a measurement from its uncertainty.

    The uncertainty is rounded to one significant digit (half up). The decimal
    place of that digit is the precision used for the value. When the leading
    digit is 9 and rounds up, the precision drops by one place (e.g. 0.96 -> 1).

    Parameters
    ----------
    x : float
        Measurement value; returned unchanged.
    dx : float
        Measurement uncertainty; must be positive and finite.

    Returns
    -------
    tuple
        ``(x, err, prec)``. ``err`` is the rounded leading digit of ``dx``.
        ``prec`` is the number of decimal places to round to (negative for
        tens, hundreds, ...).

    Raises
    ------
    ValueError
        If ``dx`` is not a positive finite number.
    """
    if not np.isfinite(dx) or dx <= 0:
        raise ValueError(f"uncertainty must be positive and finite, got {dx!r}")

    # Decimal exponent of dx and its mantissa in [1, 10)
    exponent = np.floor(np.log10(dx))
    n_err = dx / (10**exponent)
    prec = int(-exponent)

    if n_err % 1 >= 0.5:
        if int(n_err) == 9:
            # 9.5..9.99 rounds up to 10, i.e. 1 at the next power of ten.
            return x, 1, prec - 1
        return x, np.ceil(n_err), prec
    return x, np.floor(n_err), prec

def get_err_string(x, le, ue):
    """Format ``x`` with asymmetric errors as the LaTeX string ``$x^{+ue}_{-le}$``.

    The precision comes from the smaller error (see ``get_number``). The value
    and both errors are rounded to that many decimal places. Positive
    precisions keep trailing zeros. Non-positive precisions are shown as
    integers rounded (not truncated) to the tens, hundreds, ... place.
    """
    min_err = min(le, ue)
    x, _, prec = get_number(x, min_err)

    def fmt(value):
        if prec > 0:
            return f"{value:.{prec}f}"
        return str(int(np.round(value, prec)))

    return f"${fmt(x)}^{{+{fmt(ue)}}}_{{-{fmt(le)}}}$"

In [None]:
# Independent (linear-unit) copy used for the summary statistics.
samples_plot_summary = np.copy(samples_plot)

In [None]:
# 16th/50th/84th percentiles per column -> median and asymmetric 1-sigma-like errors.
a = np.quantile(samples_plot_summary, q=[0.16, 0.5, 0.84], axis=0)
median_values = a[1]
upper_errors, lower_error = a[2] - median_values, median_values - a[0]

In [None]:
(median_values, upper_errors, lower_error)

In [None]:
# One LaTeX-formatted summary line per plotted parameter.
for x, le, ue in zip(median_values, lower_error, upper_errors):
    print(get_err_string(x, le, ue))