In [None]:
import numpy as np
from mpmath import mp
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import pesummary
from pesummary.gw.fetch import fetch_open_samples
from pesummary.io import read
import seaborn as sns
import h5py
import sys
import warnings

In [None]:
def calculate_pn_terms(q, chi1, chi2, pn_order, pn_orders=None):
    """Return the GR TaylorF2 phase coefficient of an aligned-spin binary at one PN order.

    PN coefficients and expressions are taken from https://arxiv.org/pdf/1508.07253
    (Appendix B). Only the orders handled below (1, 1.5, 2, 3, 3.5) are implemented.

    Parameters
    ----------
    q : float or array_like
        Mass ratio. Used through eta = q / (1 + q)^2 and delta = (1 - q) / (1 + q).
    chi1, chi2 : float or array_like
        Aligned spin components of the two bodies.
    pn_order : float
        PN order whose coefficient is requested.
    pn_orders : sequence of float, optional
        Orders the caller allows. Defaults to every implemented order.

    Returns
    -------
    float or ndarray
        The GR coefficient at ``pn_order``.

    Raises
    ------
    ValueError
        If ``pn_order`` is not in ``pn_orders``, or has no expression implemented here.
    """
    supported = (1, 1.5, 2, 3, 3.5)
    allowed = supported if pn_orders is None else pn_orders

    # Input check. Raise instead of sys.exit so a notebook caller gets a normal traceback.
    if pn_order not in allowed:
        raise ValueError(f"Wrong PN value: {pn_order!r} (allowed: {list(allowed)})")

    gamma = np.euler_gamma
    eta = q / (1 + q)**2
    delta = (1 - q) / (1 + q)
    chi_s = 0.5 * (chi1 + chi2)
    chi_a = 0.5 * (chi1 - chi2)

    if pn_order == 1:  # 1PN
        return 3715 / 756 + (55 / 9) * eta
    elif pn_order == 1.5:  # 1.5PN
        return -16 * np.pi + (113 / 3) * delta * chi_a + (113 / 3 - (76 / 3) * eta) * chi_s
    elif pn_order == 2:  # 2PN
        return (15293365 / 508032 + (27145 / 504) * eta + (3085 / 72) * eta**2 +
                (-405 / 8 + 200 * eta) * chi_a**2 - (405 / 4) * delta * chi_a * chi_s +
                (-405 / 8 + (5 / 2) * eta) * chi_s**2)
    elif pn_order == 3:  # 3PN
        return (11583231236531 / 4694215680 - (6848 / 21) * gamma - (640 / 3) * np.pi**2 +
                (-15737765635 / 3048192 + (2255 / 12) * np.pi**2) * eta +
                (76055 / 1728) * eta**2 - (127825 / 1296) * eta**3 +
                (2270 / 3) * np.pi * delta * chi_a +
                ((2270 / 3) * np.pi - 520 * np.pi * eta) * chi_s)
    elif pn_order == 3.5:  # 3.5PN
        # Non-spinning eta coefficient is 378515/1512 (previously mistyped as /1525).
        return ((77096675 / 254016) * np.pi + (378515 / 1512) * np.pi * eta
                - (74045 / 756) * np.pi * eta**2
                + delta * (-25150083775 / 3048192 + (26804935 / 6048) * eta
                           - (1985 / 48) * eta**2) * chi_a
                + (-25150083775 / 3048192 + (10566655595 / 762048) * eta
                   - (1042165 / 3024) * eta**2 + (5345 / 36) * eta**3) * chi_s)

    # Allowed by the caller but not implemented (e.g. 2.5PN): fail loudly instead of returning None.
    raise ValueError(f"No GR coefficient implemented for PN order {pn_order!r}")


# Defining condition equation for horizon
def horizon_eq(r, a, b, n, M):
    """Evaluate the horizon condition of eq (4) at radius ``r``.

    Computes r^2 + a^2 - 2 M (1 - b (M / r)^n) r; a zero of this expression in r
    marks a horizon. Works element-wise when ``r`` is a NumPy array.
    """
    # Deformation factor multiplying the Kerr mass term.
    correction = 1 - b * (M / r)**n
    mass_term = 2*M * correction * r
    return r**2 + a**2 - mass_term

# Searching real solutions 
def horizon_eq_real_sol(b, a, n, M):
    """Report whether ``horizon_eq`` changes sign on a 100-point grid r in [0.001, 5].

    A change in sign between neighbouring grid points (including hitting an
    exact zero) is taken as evidence of a real horizon root. The grid does not
    scale with ``M``. Returns a NumPy boolean.
    """
    grid = np.linspace(0.001, 5, 100)
    signs = np.sign(horizon_eq(grid, a, b, n, M))
    # Compare each sign with its neighbour; any mismatch is a sign change.
    return np.any(signs[1:] != signs[:-1])

# Finding b_c where we would meet a sign change
def find_min_b(a, n, M):
    """Bisect for the critical b_c on [0, 2] until the bracket is narrower than 1e-4.

    The lower end of the bracket moves up while ``horizon_eq_real_sol`` still
    finds a sign change, otherwise the upper end moves down. This presumes a
    single transition in b on [0, 2]. Returns the midpoint of the final bracket.
    """
    lo, hi = 0.0, 2.0
    while hi - lo > 1e-4:
        mid = (lo + hi) / 2
        lo, hi = (mid, hi) if horizon_eq_real_sol(mid, a, n, M) else (lo, mid)
    return (lo + hi) / 2

# Loading relevant parameters of event(i)
def parameters_call(samples):
    """Extract the columns used by the analysis from one event's posterior samples.

    Returns ``(mass_ratio, final_spin, spin_1z, spin_2z, number_of_samples)``,
    where the sample count is the length of the mass-ratio column.
    """
    wanted = ("mass_ratio", "final_spin", "spin_1z", "spin_2z")
    q, a, chi1, chi2 = (samples[key] for key in wanted)
    return q, a, chi1, chi2, len(q)

# Loading events
# Files are read from local disk below; they can also be fetched online with PESummary's fetch_open_samples.
def events_load(events=None, data_dir="files", label="C01:Mixed"):
    """Read posterior samples for each event from local GWTC-3 PE data-release files.

    Parameters
    ----------
    events : iterable of str, optional
        Event names such as "GW191129_134029". Defaults to the module-level
        ``event_list`` from the configuration cell.
    data_dir : str, optional
        Directory holding ``IGWN-GWTC3p0-v2-<event>_PEDataRelease_mixed_cosmo.h5``.
    label : str, optional
        Analysis label taken from each file's ``samples_dict``.

    Returns
    -------
    dict
        ``{event: samples}`` with the samples stored under ``label``.
    """
    if events is None:
        # Fall back to the global defined in the configuration cell.
        events = event_list
    all_samples = {}
    for event in events:
        file_name = f"{data_dir}/IGWN-GWTC3p0-v2-{event}_PEDataRelease_mixed_cosmo.h5"
        data = read(file_name)
        all_samples[event] = data.samples_dict[label]
    return all_samples

# Calculating main results
def results_calc(all_samples,event_list, pn_orders):
    """Compute, per event and PN order, the ratio of the parameterized-correction
    (PC) PN coefficient to the GR coefficient, and summarise its distribution.

    For every posterior sample and every PN order n, the critical value b_c is
    found by bisection (``find_min_b``), turned into the PC coefficient via
    eq (7), and divided by the GR coefficient from ``calculate_pn_terms`` (eq 8).

    Parameters
    ----------
    all_samples : dict
        ``{event: samples}`` as returned by ``events_load``. Each ``samples`` must
        provide "mass_ratio", "final_spin", "spin_1z" and "spin_2z" (read through
        ``parameters_call``).
    event_list : iterable of str
        Events to process (keys of ``all_samples``). Shadows the global of the
        same name.
    pn_orders : sequence of float
        PN orders n to evaluate. Also passed to ``calculate_pn_terms`` as the
        allowed set.

    Returns
    -------
    results : dict
        ``{event: {order: (mean, lower_50, upper_50, lower_90, upper_90)}}``.
        All five numbers are computed on the per-sample ratios after
        ``filter_percentiles`` has removed values outside its 0.5th-99.5th
        percentile range, so the intervals describe the trimmed distribution.
    b_c_total : dict
        ``{event: {order: list of b_c}}``, one b_c per posterior sample.

    Notes
    -----
    Reads the module-level ``M`` (normalized mass, set in the configuration
    cell), so that cell must have run first. Prints each event name as it
    starts processing it. The eq (7) denominator (n - 4)(2n - 5) vanishes at
    n = 4 and n = 2.5, so those orders must not appear in ``pn_orders``.
    """

    # Dictionary to store results
    results = {}
    b_c_total = {}

    # Loop through each event
    for event in event_list:
        samples = all_samples[event]
        q_samples, a_samples, chi1_samples, chi2_samples, num_samples = parameters_call(samples)

        # Initialize b_c_total for each event inside the loop, based on num_samples
        b_c_total[event] = {n: [None] * num_samples for n in pn_orders}

        event_results = {}
        print(event)

        # Loop through each PN order
        for order in pn_orders:
            rel_errors = np.zeros(num_samples)

            # Compute rel_error for each sample
            for i in range(num_samples):
                q, a, chi1, chi2 = q_samples[i], a_samples[i], chi1_samples[i], chi2_samples[i]
                # b_c depends on the sample's final spin and on the order, so it is recomputed per order
                b_c = find_min_b(a, order, M)
                b_c_total[event][order][i] = b_c
                
                # eq (7)
                Q = (1 + q**order) / (1 + q)**order
                B = ((((order + 2) * (order + 1) * Q) / 3)) * b_c
                pc_pn = (20 * B) / ((order - 4) * (2 * order - 5))
                gr_pn = calculate_pn_terms(q, chi1, chi2, order, pn_orders)
                
                # eq (8)
                rel_errors[i] = pc_pn / gr_pn  

            # Compute average over the tails-trimmed ratios (0.5th-99.5th percentile)
            filtered_rel_error = filter_percentiles(rel_errors)
            avg_rel_error = np.mean(filtered_rel_error)

            # Compute 90% credible interval (5th to 95th percentile)
            lower_90 = np.percentile(filtered_rel_error, 5)
            upper_90 = np.percentile(filtered_rel_error, 95)

            # Compute 50% credible interval (25th to 75th percentile)
            lower_50 = np.percentile(filtered_rel_error, 25)
            upper_50 = np.percentile(filtered_rel_error, 75)

            event_results[order] = (avg_rel_error, lower_50, upper_50, lower_90, upper_90)

        results[event] = event_results
    return results, b_c_total

def event_stats(results, pn_orders=None):
    """Reshape ``results_calc`` output into the nested-list layout used by ``pn_plot``.

    Parameters
    ----------
    results : dict
        ``{event: {order: (mean, lower_50, upper_50, lower_90, upper_90)}}``.
    pn_orders : sequence of float, optional
        Orders to keep. Defaults to every order stored for each event, so the
        driver's one-argument call ``event_stats(results)`` works.

    Returns
    -------
    dict
        ``{event: {order: [mean, [lower_50, upper_50], [lower_90, upper_90]]}}``.
    """
    events_data = {}
    for event, pn_data in results.items():
        orders = pn_data.keys() if pn_orders is None else pn_orders
        events_data[event] = {}
        for order in orders:
            avg, lower_50, upper_50, lower_90, upper_90 = pn_data[order]
            events_data[event][order] = [avg, [lower_50, upper_50], [lower_90, upper_90]]
    return events_data

# Plot fig (1)
def b_values_analytical(pn_orders):
    """Plot the analytical critical value b_c(chi) of eq (5) for each PN order (Fig. 1).

    For each order n this evaluates
    upsilon = (n + sqrt(n^2 - (n^2 - 1) chi^2)) / (n + 1) and
    b_c = upsilon^n (1 - chi^2 / (2 upsilon) - upsilon / 2)
    on 100 spins chi in [0, 1], prints b_c at chi = 0, and draws one curve per order.

    Parameters
    ----------
    pn_orders : sequence of float
        PN orders n to plot. Any number of orders is accepted; line styles cycle.
    """
    plt.rcParams['font.family'] = 'Calibri'

    chi = np.linspace(0, 1, 100)

    fig, ax = plt.subplots(figsize=(8, 6), dpi=200)

    colors = plt.cm.rainbow(np.linspace(0, 1, len(pn_orders)))
    line_styles = ['-', '--', '-.', ':', (0, (3, 1, 1, 1, 1, 1))]

    for i, n in enumerate(pn_orders):
        # eq (5)
        upsilon = (n + (n**2 - (n**2 - 1) * chi**2)**0.5) / (n + 1)
        b = upsilon**n * (1 - chi**2 / (2 * upsilon) - upsilon / 2)
        print(f"value of b for {n} is {b[0]}")

        # Cycle styles so more than five orders do not raise IndexError.
        ax.plot(chi, b, linewidth=2.5, color=colors[i],
                linestyle=line_styles[i % len(line_styles)], label=f"n = {n}")

    ax.set_xlabel(r"$\chi$", fontsize=16, fontweight='bold')
    ax.set_ylabel(r"$b_{c}(\chi)$", fontsize=16, fontweight='bold')

    ax.legend(loc="upper right", fontsize=12)

    # Raw string: "\c" is not a valid escape sequence in a normal string literal.
    ax.text(0.5, 0.9, r"$b_c(\chi)$ for Different PNs", fontsize=18, fontweight='bold',
            ha='center', transform=ax.transAxes)

    fig.tight_layout()
    ax.grid(False)
    plt.show()


# Plot fig (2) 
def b_hist_plot(events_shortened, pn_orders, b_c_total):
    """Plot histograms (with KDE) of the per-sample b_c values, one panel per PN order (Fig. 2).

    Panels are laid out on a two-column grid and unused panels are hidden.
    Within a panel every event is overlaid in its own viridis colour.

    Parameters
    ----------
    events_shortened : sequence of str
        Display labels matched by position to the events (keys) of ``b_c_total``.
        Extra trailing labels are ignored.
    pn_orders : sequence of float
        Orders to plot; each must be a key of ``b_c_total[event]``.
    b_c_total : dict
        ``{event: {order: list of b_c}}`` as returned by ``results_calc``.
    """
    plt.rcParams['font.family'] = 'Calibri'

    # Events come from the data itself rather than the global event_list.
    events = list(b_c_total)
    labels = [events_shortened[i] for i in range(len(events))]
    colors = plt.cm.viridis(np.linspace(0, 1, len(events)))

    ncols = 2
    nrows = max(1, -(-len(pn_orders) // ncols))  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(10, 4 * nrows), dpi=200, squeeze=False)
    ax_list = axes.flatten()

    # One legend handle per event, labelled with its short name.
    handles = [mlines.Line2D([], [], color=colors[i], label=labels[i]) for i in range(len(events))]

    # Suppress the FutureWarning related to 'use_inf_as_na' only while plotting.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="use_inf_as_na option is deprecated")

        for n, ax in zip(pn_orders, ax_list):
            for event_idx, event in enumerate(events):
                # Replace infinite values with NaN before plotting
                b_c_data = np.asarray(b_c_total[event][n], dtype=float)
                b_c_data = np.where(np.isinf(b_c_data), np.nan, b_c_data)

                sns.histplot(b_c_data, bins=70, kde=True, label=labels[event_idx],
                             stat="density", color=colors[event_idx], alpha=0.7, ax=ax)

            ax.set_title(f"n = {n}")
            ax.set_xlabel("$b_{c}$")
            ax.set_ylabel("Density")
            ax.grid(False)

        # Hide grid cells that have no PN order assigned.
        for ax in ax_list[len(pn_orders):]:
            ax.axis('off')

        fig.suptitle("Histograms of $b_{c}$ for Different PC-PN Values and Events\n \n", fontsize=14, y=1.02)
        fig.legend(handles=handles, loc='upper center', bbox_to_anchor=(0.5, 0.985), ncol=3)

        plt.tight_layout()
        plt.show()
    
# Plot fig (3)
def pn_plot(events_data, pn_order):
    """Plot each event's delta_phi against the general and restricted bound bands (Fig. 3).

    There is one stacked panel per PN order. Each event is drawn at its mean,
    with a translucent 50% and an opaque 90% credible-interval error bar. A
    hard-coded reference row labelled "chi = 0.7" is appended after the events
    and drawn in black with only its 90% interval. Green (general) and red
    (restricted) bands mark fixed delta_phi bounds for each order.

    Parameters
    ----------
    events_data : dict
        ``{event: {order: [mean, [lo50, hi50], [lo90, hi90]]}}`` as returned by
        ``event_stats``. It is not modified; the reference row is added to a copy.
    pn_order : sequence of float
        PN orders to plot, one panel each. Every order must be a key of the
        bound tables below (1, 1.5, 2, 3, 3.5).
    """
    plt.rcParams['font.family'] = 'Calibri'

    # Reference row for chi = 0.7, concatenated artificially after the real events.
    chi_data = {
        "$\\chi = 0.7$": {
            1: [0.18, [0, 0], [0.13, 0.24]],
            1.5: [-0.046, [0, 0], [-0.065, -0.037]],
            2: [0.138, [0, 0], [0.10, 0.242]],
            3: [0.032, [0, 0], [-0.091, 0.162]],
            3.5: [-0.021, [0, 0], [-0.146, 0.1]]
        }
    }

    # Build a new dict instead of updating the caller's in place.
    plot_data = {**events_data, **chi_data}
    events = list(plot_data)
    events_shortened = events_shortened_names(plot_data)

    # Bound bands on delta_phi in percent, keyed by PN order (G = general, R = restricted).
    bounds_g_pct = {1: (-4, 14), 1.5: (-3, 3), 2: (-18, 73), 3: (-33, 60), 3.5: (-102, 119)}
    bounds_r_pct = {1: (0, 9), 1.5: (-5, -1), 2: (2, 47), 3: (-4, 38), 3.5: (-95, 83)}

    markers = ['o', 's', '^', 'D', 'v', '<', '>', 'p', '*', 'X', 'P']

    # One viridis colour per real event; the reference row is drawn in black.
    real_events = [event for event in events if event not in chi_data]
    event_colors = dict(zip(real_events, plt.cm.viridis(np.linspace(0, 1, len(real_events)))))

    fig, axes = plt.subplots(len(pn_order), 1, figsize=(8, 12), dpi=300, sharex=True, squeeze=False)
    axes = axes[:, 0]
    x_positions = range(len(events))
    x_last = len(events) - 1

    for idx, order in enumerate(pn_order):
        ax = axes[idx]

        for event_idx, event in enumerate(events):
            mean, (err_50_lower, err_50_upper), (err_90_lower, err_90_upper) = plot_data[event][order]
            marker = markers[event_idx % len(markers)]  # cycle so >11 events do not fail
            yerr_90 = [[mean - err_90_lower], [err_90_upper - mean]]

            if event in chi_data:
                # Reference row: only the 90% interval is available.
                ax.errorbar(event_idx, mean, yerr=yerr_90, fmt=marker, capsize=5,
                            color="black", label=event if idx == 0 else None)
            else:
                color = event_colors[event]
                ax.errorbar(event_idx, mean, yerr=[[mean - err_50_lower], [err_50_upper - mean]],
                            fmt=marker, capsize=5, color=color, alpha=0.6, label=None)
                ax.errorbar(event_idx, mean, yerr=yerr_90, fmt=marker, capsize=5,
                            color=color, label=event[:8] if idx == 0 else None)

        # Fill boundary regions (bounds are stored in percent)
        lowerG, upperG = (v / 100 for v in bounds_g_pct[order])
        lowerR, upperR = (v / 100 for v in bounds_r_pct[order])
        ax.fill_between(x_positions, lowerG, upperG, color='green', alpha=0.1, hatch='//',
                        label='Boundary G' if idx == 0 else "")
        ax.fill_between(x_positions, lowerR, upperR, color='red', alpha=0.1, hatch='\\',
                        label='Boundary R' if idx == 0 else "")

        # Draw boundary lines
        ax.hlines([lowerG, upperG], xmin=0, xmax=x_last, colors='black', alpha=0.5, linestyles='dashed')
        ax.hlines([lowerR, upperR], xmin=0, xmax=x_last, colors='black', alpha=0.5, linestyles='dotted')

        ax.set_ylabel(rf'{order}PN $\delta_\phi$')
        ax.grid(False)

    # Set x-ticks and labels
    axes[-1].set_xticks(range(len(events)))
    axes[-1].set_xticklabels(events_shortened, rotation=30, ha="right")

    fig.suptitle(r'$\delta_{\phi}$ Across Events', fontsize=16, y=0.945)

    # Add legend for boundaries
    green_patch = plt.Line2D([0], [0], color='green', alpha=0.5, lw=4, label=r"General $\delta_\phi$")
    red_patch = plt.Line2D([0], [0], color='red', alpha=0.5, lw=4, label=r"Restricted $\delta_\phi$")
    fig.legend(handles=[green_patch, red_patch], loc='upper right', fontsize=8, bbox_to_anchor=(0.9875, 0.955))

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    fig.subplots_adjust(hspace=0)
    plt.show()

def events_shortened_names(events_data):
    """Return a display label for each key of ``events_data``, in key order.

    The label is the text before the first '_' (e.g. "GW191129_134029" ->
    "GW191129"), or the whole key when it contains no underscore.
    """
    return [name.partition('_')[0] for name in events_data]

    
# Function to highlight the rows where delta_phi is outside either boundary
def highlight_delta_phi(row, general_bound, restricted_bound):
    """Return one CSS string per cell of ``row`` for pandas ``Styler.apply``.

    The whole row is shaded light green when its delta_phi value lies outside
    either the general or the restricted (low, high) bound (inclusive checks);
    otherwise every cell gets an empty style.
    """
    value = row[r"$\delta_\phi$ (%)"]

    inside_general = general_bound[0] <= value <= general_bound[1]
    inside_restricted = restricted_bound[0] <= value <= restricted_bound[1]

    css = '' if (inside_general and inside_restricted) else 'background-color: lightgreen'
    return [css] * len(row)

# Function to filter data between 0.5th and 99.5th percentiles
def filter_percentiles(data, lower_pct=0.5, upper_pct=99.5):
    """Return the sorted values of ``data`` that lie between two percentiles (inclusive).

    Parameters
    ----------
    data : array_like
        One-dimensional sample values.
    lower_pct, upper_pct : float, optional
        Percentile cut-offs in [0, 100]. The defaults keep the 0.5th-99.5th range.

    Returns
    -------
    ndarray
        Sorted values within the cut-offs. An empty input yields an empty array.
    """
    sorted_data = np.sort(data)
    if sorted_data.size == 0:
        # np.percentile errors on empty input; there is nothing to trim.
        return sorted_data
    lower, upper = np.percentile(sorted_data, [lower_pct, upper_pct])
    return sorted_data[(sorted_data >= lower) & (sorted_data <= upper)]

# Apply highlighting to the table
def highlight_table(table, title="Summary of PN Terms and Boundary Errors"):
    """Return a pandas Styler of a copy of ``table`` with out-of-bound rows highlighted.

    Each row's own general and restricted GWTC-3 bound columns are passed to
    ``highlight_delta_phi``. The table caption is set to ``title``.
    """
    general_col = r"general $\delta_\phi$ (%) (GWTC-3)"
    restricted_col = r"restricted $\delta_\phi$ (%) (GWTC-3)"

    def _row_style(row):
        # Bounds are read per row, so each row can carry different limits.
        return highlight_delta_phi(row, row[general_col], row[restricted_col])

    styler = table.copy().style.apply(_row_style, axis=1)
    styler.set_caption(title)
    return styler

# Table (1)
def table(results, pn_orders=None):
    """Print a plain-text table of the mean deviation and its credible intervals per event.

    Parameters
    ----------
    results : dict
        ``{event: {order: (mean, lower_50, upper_50, lower_90, upper_90)}}`` as
        returned by ``results_calc``. The bounds are absolute percentile values,
        so they are printed as ``[lower, upper]`` ranges.
    pn_orders : sequence of float, optional
        Orders (columns) to print. Defaults to the orders stored for the first
        event, which removes the dependence on a global ``pn_orders``.
    """
    if pn_orders is None:
        pn_orders = list(next(iter(results.values()), {}))

    header = "Event            |  " + "|  ".join(f"{order}PN".ljust(15) for order in pn_orders).rstrip()

    print("\nRel Error Table with 50% and 90% Credible Intervals")
    print("------------------------------------------------------------")
    print(header)
    print("------------------------------------------------------------")
    for event, pn_data in results.items():
        row = f"{event[:8]}  "  # Shortened event name for formatting
        for order in pn_orders:
            avg, lower_50, upper_50, lower_90, upper_90 = pn_data[order]
            row += (f"| {avg:.3f} [{lower_50:.3f}, {upper_50:.3f}] (50%), "
                    f"[{lower_90:.3f}, {upper_90:.3f}] (90%) ")
        print(row)


In [None]:
# Nine parameterized events from GWTC-3

event_list = [
    "GW191129_134029", "GW191204_171526", "GW191216_213338", "GW200316_215756",
    "GW200129_065458", "GW200202_154313", "GW200225_060421", "GW200311_115853",
    "GW200115_042309"
]

# PN orders
pn_orders = [1, 1.5, 2, 3, 3.5]

# Normalized mass (read as a global by results_calc)
M = 1

# Load data
all_samples = events_load()

# Results
results, b_c_total = results_calc(all_samples, event_list, pn_orders)
table(results)

# Plots
events_data = event_stats(results, pn_orders)
pn_plot(events_data, pn_orders)
b_values_analytical(pn_orders)
b_hist_plot(events_shortened_names(events_data), pn_orders, b_c_total)