In [None]:
import os, pickle
from itertools import zip_longest
from typing import Callable, Iterable, Iterator, List, Optional, Tuple

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pymbar import bar
from pymbar.other_estimators import exp

from endstate_correction.constant import zinc_systems, blacklist
from endstate_correction.protocol import Results

# --- Configuration -----------------------------------------------------------
FIGURES_PATH = '../figures'
SWITCHING_RESULTS_PATH = '../data/switching_results'
BOOTSTRAP_ITERATIONS = 1000
CONFIDENCE_LEVEL = 90.0
SWITCHING_LENGTHS = [5_000, 10_000, 20_000, 50_000]
RANDOM_SEED = 42

# The bootstrap resampling and the scatter jitter below draw from NumPy's
# global RNG; seed it once so that "Restart & Run All" reproduces the figures.
np.random.seed(RANDOM_SEED)

os.makedirs(FIGURES_PATH, exist_ok=True)
colors = sns.color_palette("flare", n_colors=201)

# Keep every system whose first tuple field is not blacklisted.
all_used_systems = [x for x in zinc_systems if x[0] not in blacklist]

# Load one results pickle per switching length. Loading is best-effort: a
# missing or unreadable file is reported and simply left out of the dict.
all_systems_results = {}
for neq_switching_length in SWITCHING_LENGTHS:
    filename = f'{SWITCHING_RESULTS_PATH}/all_results_{neq_switching_length}.pickle'
    try:
        with open(filename, 'rb') as f:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # load result files that were produced by a trusted run.
            all_systems_results[neq_switching_length] = pickle.load(f)
    except FileNotFoundError:
        print(f"File {filename} not found.")
    except (pickle.UnpicklingError, EOFError):
        # A truncated/empty pickle raises EOFError rather than UnpicklingError.
        print(f"Error unpickling file {filename}.")


In [None]:
def calculate_confidence_interval(
    bootstrap_results: List[float],
    confidence_level: Optional[float] = None,
) -> Tuple[float, float]:
    """
    Calculate lower and upper percentiles for a confidence interval from bootstrap results.

    Args:
        bootstrap_results: Bootstrap replicates of the statistic.
        confidence_level: Width of the central interval in percent (e.g. 90.0).
            When omitted, the module-level ``CONFIDENCE_LEVEL`` is used.

    Returns:
        Tuple: (lower, upper) percentile bounds of the central interval.
    """
    if confidence_level is None:
        confidence_level = CONFIDENCE_LEVEL
    alpha = 100 - confidence_level
    # Split the excluded probability mass evenly between both tails.
    lower_p = alpha / 2.0
    upper_p = (100 - alpha) + (alpha / 2.0)
    lower, upper = np.percentile(bootstrap_results, [lower_p, upper_p])
    return lower, upper


def bootstrap_exp(
    fct: Callable,
    dE_values: List[float],
    rng: Optional[np.random.Generator] = None,
) -> Tuple[float, float, float]:
    """
    Perform bootstrapping using an exponential average estimator.

    Args:
        fct: Estimator to bootstrap. It is called with a 1-D array of values
            and must return a mapping that has a ``"Delta_f"`` entry.
        dE_values: Energies to bootstrap
        rng: Optional NumPy ``Generator`` used for resampling. When omitted,
            NumPy's global random state is used (previous behaviour).

    Returns:
        Tuple: (estimate on the full, un-resampled data, lower bound, upper
        bound), where the bounds are percentiles of the bootstrap replicates.
    """
    assert callable(fct), "Provided function is not callable."
    sampler = np.random if rng is None else rng
    # Convert once instead of on every iteration.
    values = np.asarray(dE_values)
    n_samples = len(dE_values)

    bootstrap_metrics = []
    for _ in range(BOOTSTRAP_ITERATIONS):
        # Resample with replacement, keeping the original sample size.
        indices = sampler.choice(n_samples, size=n_samples, replace=True)
        bootstrap_metrics.append(fct(np.take(values, indices))["Delta_f"])

    mean = fct(dE_values)["Delta_f"]
    lower, upper = calculate_confidence_interval(bootstrap_metrics)
    return mean, lower, upper


def bootstrap_bar(
    fct: Callable,
    fw_f: List[float],
    rv_f: List[float],
    rng: Optional[np.random.Generator] = None,
) -> Tuple[float, float, float]:
    """
    Perform bootstrapping using a BAR estimator.

    Args:
        fct: Estimator to bootstrap. It is called with (forward, reverse)
            arrays and must return a mapping that has a ``"Delta_f"`` entry.
        fw_f: Forward work values to bootstrap
        rv_f: Reverse work values to bootstrap
        rng: Optional NumPy ``Generator`` used for resampling. When omitted,
            NumPy's global random state is used (previous behaviour).

    Returns:
        Tuple: (estimate on the full, un-resampled data, lower bound, upper
        bound), where the bounds are percentiles of the bootstrap replicates.
    """
    assert callable(fct), "Provided function is not callable."
    sampler = np.random if rng is None else rng
    # Convert once instead of on every iteration.
    forward, reverse = np.asarray(fw_f), np.asarray(rv_f)
    n_forward, n_reverse = len(fw_f), len(rv_f)

    bootstrap_metrics = []
    for _ in range(BOOTSTRAP_ITERATIONS):
        # Forward and reverse samples are resampled independently, each with
        # replacement and each keeping its own sample size.
        indices_fw = sampler.choice(n_forward, size=n_forward, replace=True)
        selection_fw = np.take(forward, indices_fw)
        indices_rv = sampler.choice(n_reverse, size=n_reverse, replace=True)
        selection_rv = np.take(reverse, indices_rv)
        bootstrap_metrics.append(fct(selection_fw, selection_rv)["Delta_f"])

    mean = fct(fw_f, rv_f)["Delta_f"]
    lower, upper = calculate_confidence_interval(bootstrap_metrics)
    return mean, lower, upper


def calculate_bootstrap_estimates(w_F, w_R, equ_mbar):
    """
    Bootstrap the one-directional (EXP) and bidirectional (BAR) estimates and
    collect the equilibrium reference values.

    Args:
        w_F: Forward work/energy values.
        w_R: Reverse work/energy values.
        equ_mbar: Iterable of objects exposing
            ``compute_free_energy_differences()`` (presumably pymbar MBAR
            instances -- confirm against the producer of ``Results``).

    Returns:
        A flat 10-tuple: (value, lower, upper) for EXP forward, then for EXP
        reverse, then for BAR, followed by the list of equilibrium estimates.
    """
    # Keep this call order: each bootstrap consumes the global random stream.
    forward_exp = bootstrap_exp(exp, w_F)
    reverse_exp = bootstrap_exp(exp, w_R)
    bidirectional = bootstrap_bar(bar, w_F, w_R)

    # One equilibrium free energy per entry: first row, last column of Delta_f.
    equ_rs = []
    for equ in equ_mbar:
        delta_f = equ.compute_free_energy_differences()["Delta_f"]
        equ_rs.append(delta_f[0][-1])

    exp_f, exp_f_lo, exp_f_hi = forward_exp
    exp_r, exp_r_lo, exp_r_hi = reverse_exp
    bar_est, bar_lo, bar_hi = bidirectional
    return (
        exp_f, exp_f_lo, exp_f_hi,
        exp_r, exp_r_lo, exp_r_hi,
        bar_est, bar_lo, bar_hi,
        equ_rs,
    )


def calculate_offsets_and_adjust_values(
    min_F:float,
    rexp_f:float,
    rexp_f_lower:float,
    rexp_f_upper:float,
    rexp_r:float,
    rexp_r_lower:float,
    rexp_r_upper:float,
    bar_bi:float,
    bar_bi_lower:float,
    bar_bi_upper:float,
    equ_rs:List[float],
):
    """
    Shift all estimates by ``min_F`` and print a one-line summary.

    The reverse EXP values are negated before shifting. Because of that sign
    flip the returned ``rexp_r_lower`` is numerically the LARGER bound and
    ``rexp_r_upper`` the smaller one; the returned order is kept as-is for
    callers, while the printed summary shows the bounds as [smaller;larger].

    Args:
        min_F: Offset subtracted from every value.
        rexp_f, rexp_f_lower, rexp_f_upper: Forward EXP estimate and bounds.
        rexp_r, rexp_r_lower, rexp_r_upper: Reverse EXP estimate and bounds.
        bar_bi, bar_bi_lower, bar_bi_upper: BAR estimate and bounds.
        equ_rs: Equilibrium free-energy estimates (one per replicate).

    Returns:
        An 11-tuple: the nine shifted estimates in input order, the shifted
        average of ``equ_rs`` and the half-width of its confidence interval.
    """
    rexp_f, rexp_f_lower, rexp_f_upper = (
        rexp_f - min_F,
        rexp_f_lower - min_F,
        rexp_f_upper - min_F,
    )
    # Reverse values are sign-flipped so they are comparable to forward ones.
    rexp_r, rexp_r_lower, rexp_r_upper = (
        (rexp_r * -1) - min_F,
        (rexp_r_lower * -1) - min_F,
        (rexp_r_upper * -1) - min_F,
    )
    bar_bi, bar_bi_lower, bar_bi_upper = (
        bar_bi - min_F,
        bar_bi_lower - min_F,
        bar_bi_upper - min_F,
    )
    equ_r = np.average(equ_rs) - min_F
    # 90% CI: https://sphweb.bumc.bu.edu/otlt/MPH-Modules/PH717-QuantCore/PH717-Module6-RandomError/PH717-Module6-RandomError11.html
    # CI = +- t-score * (sigma / sqrt(n))
    # NOTE(review): 2.920 is the two-sided 90% t-score for 2 degrees of
    # freedom, i.e. it presumes len(equ_rs) == 3 -- confirm. np.std uses its
    # default (population) normalisation here.
    equ_dDG = 2.920 * (np.std(equ_rs) / np.sqrt(len(equ_rs)))

    # Summary line: "value [low;high]," per estimator, all on one line.
    summary = (
        (equ_r, equ_r - equ_dDG, equ_r + equ_dDG),
        (rexp_f, rexp_f_lower, rexp_f_upper),
        # swapped on purpose: after negation "upper" holds the smaller bound
        (rexp_r, rexp_r_upper, rexp_r_lower),
        (bar_bi, bar_bi_lower, bar_bi_upper),
    )
    for value, low, high in summary:
        print(
            f"{np.round(value, 2)} [{np.round(low, 2)};{np.round(high, 2)}],",
            end=" ",
        )

    return (
        rexp_f,
        rexp_f_lower,
        rexp_f_upper,
        rexp_r,
        rexp_r_lower,
        rexp_r_upper,
        bar_bi,
        bar_bi_lower,
        bar_bi_upper,
        equ_r,
        equ_dDG,
    )


def grouper(n: int, iterable: Iterable, padvalue=None) -> Iterator[Tuple]:
    """
    Chunk ``iterable`` into consecutive tuples of ``n`` items.

    The last tuple is completed with ``padvalue`` when the number of items is
    not a multiple of ``n``.

    Args:
        n: Number of items per output tuple
        iterable: Source of items
        padvalue: Filler used to complete the final tuple

    Returns:
        Lazy iterator over the tuples
    """
    # All n "columns" pull from the same iterator, so each output tuple
    # consumes the next n items of the input.
    shared = iter(iterable)
    columns = [shared for _ in range(n)]
    return zip_longest(*columns, fillvalue=padvalue)


def plot_kde_and_scatter(axs, w_F: list, w_R: list, min_F: float, colors):
    """
    Draw the forward and (negated) reverse distributions on one axes.

    Each distribution is shown as a filled KDE plus a strip of randomly
    jittered points placed below the baseline (y <= 0).

    Args:
        axs: Axes to draw on.
        w_F: Forward values (any array-like).
        w_R: Reverse values (any array-like); plotted with flipped sign.
        min_F: Offset subtracted from both series.
        colors: Sequence of colors; the first entry is used for the forward
            series and the last entry for the reverse series.
    """
    # np.asarray lets plain lists work too (list - float would raise).
    forward = np.asarray(w_F) - min_F
    reverse = (np.asarray(w_R) * -1) - min_F

    # Peak histogram density sets how far below zero the jitter strip reaches.
    density_fw, _ = np.histogram(forward, density=True)
    density_rv, _ = np.histogram(reverse, density=True)
    strip_floor = (max([density_fw.max(), density_rv.max()]) / 6) * -1

    series = (
        (forward, "f_forw", colors[0]),
        (reverse, "-f_rev", colors[-1]),
    )
    for shifted, label, color in series:
        sns.kdeplot(shifted, label=label, ax=axs, fill=True, alpha=0.5, color=color)
        axs.scatter(
            shifted,
            np.random.uniform(low=strip_floor, high=0.0, size=len(shifted)),
            s=0.4,
            color=color,
            alpha=0.8,
        )


def add_estimates_to_plot(
    axs, rexp_f: float, rexp_r: float, bar_bi: float, equ_r: float, colors
):
    """
    Mark the four free-energy estimates as vertical lines on ``axs``.

    Args:
        axs: Axes to draw on.
        rexp_f: Forward EXP estimate (drawn in ``colors[0]``).
        rexp_r: Reverse EXP estimate (drawn in ``colors[-1]``).
        bar_bi: Bidirectional estimate (drawn in red).
        equ_r: Equilibrium estimate (drawn as a solid black line).
        colors: Sequence of colors.
    """

    def _vline(position, color, label, **style):
        # every marker shares the same line width
        axs.axvline(x=position, color=color, lw=3, label=label, **style)

    _vline(bar_bi, "red", r"$\Delta G_{bid}$", ls=":")
    _vline(rexp_f, colors[0], r"$\Delta G_{forw}$", ls=":", alpha=0.5)
    _vline(rexp_r, colors[-1], r"$\Delta G_{rev}$", ls=":", alpha=0.5)
    _vline(equ_r, "black", r"$\Delta G_{equ}$", ls="-", alpha=0.5)


def add_legend_to_plot(
    rexp_f: float,
    rexp_f_lower: float,
    rexp_f_upper: float,
    rexp_r: float,
    rexp_r_lower: float,
    rexp_r_upper: float,
    bar_bi: float,
    bar_bi_lower: float,
    bar_bi_upper: float,
    equ_r: float,
    equ_dDG: float,
):
    """
    Build the four-line summary text for a plot's text box.

    Despite its name this function draws nothing; it only formats and returns
    a string, one line per estimator, as ``value [low; high]`` with two
    decimals.

    Args:
        rexp_f, rexp_f_lower, rexp_f_upper: Forward estimate and bounds.
        rexp_r, rexp_r_lower, rexp_r_upper: Reverse estimate and bounds. They
            are printed as ``[upper; lower]``.
        bar_bi, bar_bi_lower, bar_bi_upper: Bidirectional estimate and bounds.
        equ_r: Equilibrium estimate.
        equ_dDG: Half-width of the equilibrium interval; the text shows
            ``equ_r - equ_dDG`` and ``equ_r + equ_dDG``.

    Returns:
        str: The formatted multi-line text.
    """
    # Raw f-strings: "\D" is an invalid escape sequence in a normal string
    # literal and triggers a (Deprecation/Syntax)Warning on recent Pythons.
    lines = (
        rf"$\Delta G_{{forw}}$ = {rexp_f:.2f} [{rexp_f_lower:.2f}; {rexp_f_upper:.2f}]",
        rf"$\Delta G_{{rev}}$ = {rexp_r:.2f} [{rexp_r_upper:.2f}; {rexp_r_lower:.2f}]",
        rf"$\Delta G_{{bid}}$ = {bar_bi:.2f} [{bar_bi_lower:.2f}; {bar_bi_upper:.2f}]",
        rf"$\Delta G_{{equ}}$ = {equ_r:.2f} [{equ_r-equ_dDG:.2f}; {equ_r+equ_dDG:.2f}]",
    )
    return "\n".join(lines)


def _plot_dist(min_F, w_F, w_R, equ_mbar, colors, axs):
    """
    Compute all estimates for one forward/reverse pair, draw them on ``axs``
    and return the summary text for the panel's text box.

    Args:
        min_F: Offset subtracted from all values before plotting.
        w_F: Forward values.
        w_R: Reverse values.
        equ_mbar: Equilibrium estimator objects, forwarded unchanged to
            ``calculate_bootstrap_estimates``.
        colors: Color sequence forwarded to the plotting helpers.
        axs: Single axes to draw on.

    Returns:
        str: Text produced by ``add_legend_to_plot``.
    """
    raw_estimates = calculate_bootstrap_estimates(w_F, w_R, equ_mbar)
    shifted = calculate_offsets_and_adjust_values(min_F, *raw_estimates)

    # Positions of the point estimates inside the 11-tuple:
    # 0 = forward, 3 = reverse, 6 = bidirectional, 9 = equilibrium.
    forward_est, reverse_est, bid_est, equ_est = (
        shifted[0],
        shifted[3],
        shifted[6],
        shifted[9],
    )

    plot_kde_and_scatter(axs, w_F, w_R, min_F, colors)
    add_estimates_to_plot(axs, forward_est, reverse_est, bid_est, equ_est, colors)
    return add_legend_to_plot(*shifted)


# Per system: instantaneous (FEP) and non-equilibrium (NEQ) work distributions
def plot_dist_for_fep_and_neq(r: Results, system_name: str, axs: plt.Axes):
    """Plot the FEP and NEQ panels for a single system.

    Both panels share one offset, taken as the minimum over the FEP forward
    values and the negated FEP reverse values.

    Args:
        r (Results): results object; its ``dE_*`` fields feed the FEP panel
            and its ``W_*`` fields the NEQ panel
        system_name (str): label used in both panel titles
        axs: indexable pair of axes -- ``axs[0]`` receives the FEP panel and
            ``axs[1]`` the NEQ panel

    Returns:
        None. The axes are modified in place.
    """
    palette = sns.color_palette("flare", n_colors=201)
    fep_axis, neq_axis = axs[0], axs[1]
    equ_mbar = r.equ_mbar

    # --- FEP panel ---------------------------------------------------------
    fep_forward, fep_reverse = r.dE_mm_to_qml, r.dE_qml_to_mm
    # common offset for both panels
    min_F = min(np.append(fep_forward, fep_reverse * -1))
    textstr_fep = _plot_dist(min_F, fep_forward, fep_reverse, equ_mbar, palette, fep_axis)
    fep_axis.set_title(f"FEP - {system_name}", fontsize=15)

    # --- NEQ panel ---------------------------------------------------------
    textstr_neq = _plot_dist(
        min_F, r.W_mm_to_qml, r.W_qml_to_mm, equ_mbar, palette, neq_axis
    )
    neq_axis.set_title(f"NEQ - {system_name}", fontsize=15)

    # --- summary text boxes (upper left, axes coordinates) -----------------
    box_style = dict(boxstyle="round", facecolor="wheat", alpha=0.5)
    shared_text_kwargs = dict(fontsize=15, verticalalignment="top", bbox=box_style)
    fep_axis.text(
        0.05,
        0.95,
        textstr_fep,
        transform=fep_axis.transAxes,
        horizontalalignment="left",
        **shared_text_kwargs,
    )
    neq_axis.text(
        0.05,
        0.95,
        textstr_neq,
        transform=neq_axis.transAxes,
        **shared_text_kwargs,
    )

    # --- axis cosmetics; only the FEP panel carries the legend -------------
    for panel in (fep_axis, neq_axis):
        panel.get_yaxis().set_ticks([])
        panel.set_ylabel("Density", fontsize=14)
    fep_axis.legend(loc="upper right", fontsize=14, fancybox=True, framealpha=0.5)


def plot_dist_for_neq(rs, system_name, axs):
    """Plot one NEQ panel per results object for a single system.

    All panels share one offset, taken as the minimum over the FEP forward
    values and the negated FEP reverse values of every entry in ``rs``. The
    equilibrium reference is taken from the first entry only.

    Args:
        rs: Sequence of results objects, one per switching protocol, in the
            same order as the panel labels ('5 ps', '10 ps', '20 ps', '50 ps').
        system_name: Label used in the panel titles.
        axs: Indexable collection of axes; ``axs[idx]`` receives ``rs[idx]``.

    Returns:
        None. The axes are modified in place.
    """
    # common offset across all switching protocols
    dEs = []
    for r in rs:
        dEs.extend(r.dE_mm_to_qml)
        dEs.extend(r.dE_qml_to_mm * -1)
    min_F = min(dEs)

    # Build the palette locally (same as plot_dist_for_fep_and_neq) instead of
    # relying on a module-level variable defined in another notebook cell.
    palette = sns.color_palette("flare", n_colors=201)
    # NOTE(review): labels are hard-coded, so at most four entries in `rs`
    # are supported -- confirm they match SWITCHING_LENGTHS.
    switching_times = ['5 ps', '10 ps', '20 ps', '50 ps']
    equ_mbar = rs[0].equ_mbar
    # these are matplotlib.patch.Patch properties
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)

    for idx, r in enumerate(rs):
        w_F, w_R = r.W_mm_to_qml, r.W_qml_to_mm
        textstr_neq = _plot_dist(min_F, w_F, w_R, equ_mbar, palette, axs[idx])

        axs[idx].set_title(f'NEQ - {system_name} - {switching_times[idx]}', fontsize=15)

        # place a text box in upper left in axes coords
        axs[idx].text(0.05, 0.95, textstr_neq, transform=axs[idx].transAxes, fontsize=15,
                      verticalalignment='top', bbox=props)

        axs[idx].get_yaxis().set_ticks([])
        axs[idx].set_ylabel('Density', fontsize=14)

    # Only the first panel carries the legend; build it once, after plotting.
    axs[0].legend(loc='upper right', fontsize=14, fancybox=True, framealpha=0.5)

    

def plot_distributions(
    all_systems_results, all_used_systems, batch_size=6, neq_switching_length=5000, only_neq:bool=False
):
    """Plot distributions for each system, one figure per batch of systems.

    Every figure is written to ``FIGURES_PATH``, shown, and then closed.

    Args:
        all_systems_results (dict): Results indexed first by switching length
            and then by the third field of each system tuple.
        all_used_systems (list): 3-tuples describing the systems; the first
            field decides whether a row is drawn, the third one is used as
            lookup key and plot label.
        batch_size (int, optional): The number of systems to plot in a batch. Defaults to 6.
        neq_switching_length (int, optional): Non-equilibrium switching length
            used when ``only_neq`` is False. Defaults to 5000.
        only_neq (bool, optional): If True, draw one NEQ panel per entry of
            ``SWITCHING_LENGTHS`` for each system instead of the FEP/NEQ pair.
            Defaults to False.
    """
    if only_neq:
        n_columns, figsize = len(SWITCHING_LENGTHS), (25.0, 13.0)
    else:
        n_columns, figsize = 2, (13.0, 17.0)

    # The last batch is padded with empty placeholders so that every figure
    # has the same number of rows; placeholder rows stay empty.
    for batch_idx, list_of_systems in enumerate(
        grouper(batch_size, all_used_systems, ("", "", ""))
    ):
        print(list_of_systems)
        # squeeze=False keeps `axs` two-dimensional, so axs[idx] is a row of
        # axes even when a batch consists of a single system.
        fig, axs = plt.subplots(
            len(list_of_systems), n_columns, figsize=figsize, dpi=600, squeeze=False
        )

        for idx, (system_name, smiles, hipen_id) in enumerate(list_of_systems):
            # skip if no system name is provided
            if not system_name:
                continue
            if only_neq:
                rs = [all_systems_results[length][hipen_id] for length in SWITCHING_LENGTHS]
                plot_dist_for_neq(rs, hipen_id, axs[idx])
            else:
                r = all_systems_results[neq_switching_length][hipen_id]
                plot_dist_for_fep_and_neq(r, hipen_id, axs[idx])

        fig.tight_layout()
        if only_neq:
            figure_name = f'{batch_idx}_batch_neq_dist_for_each_switching_length.png'
        else:
            figure_name = f"{batch_idx}_batch_dist_{neq_switching_length=}.png"
        fig.savefig(os.path.join(FIGURES_PATH, figure_name))
        plt.show()
        # Release the 600-dpi figure before the next batch is created.
        plt.close(fig)
        
    

In [None]:
# Default mode: for each batch of systems, draw the FEP and NEQ panels side by
# side using the default switching length (neq_switching_length=5000).
plot_distributions(all_systems_results, all_used_systems)

In [None]:
# NEQ-only mode: for each system, draw one panel per entry in SWITCHING_LENGTHS.
plot_distributions(all_systems_results, all_used_systems, only_neq=True)