In [None]:

from tqdm import tqdm
import numpy as np
import pandas as pd
import KDEpy
from sklearn.neighbors import KernelDensity
from scipy.interpolate import interp1d

import dill as pickle

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import seaborn as sns
sns.set(rc={"figure.dpi": 150})

import importlib
import fresco
importlib.reload(fresco)


# Feature-pair labels. They are used as keys into the distance / KDE
# dictionaries in the cells below, and the list order fixes the subplot row
# order in the figures.
pairs = [
    'Donor-Aromatic',
    'Aromatic-Acceptor',
    'Aromatic-Aromatic',
    'Donor-Donor',
    'Donor-Acceptor',
    'Acceptor-Acceptor',
]


In [None]:
# Overlay the stored Mpro KDE curves on density-normalised histograms of the
# measured pair distances, one subplot per entry in `pairs`.
nx = 500      # number of x grid points used to draw each KDE curve
n_bins = 20   # histogram bins per subplot

cmap = plt.get_cmap("tab10")
sns.set_style('white')

# `pickle_dir` is assigned here so this cell runs on a fresh kernel; it is
# otherwise only set in a later cell, which made this cell depend on
# out-of-order execution.
pickle_dir = '/home/wjm41/ml_physics/frag-pcore-screen/notebooks/pickles'

# NOTE(review): unpickling can execute arbitrary code - only load trusted files.
with open(pickle_dir + '/frag_pair_distance_dict_mpro.pickle', 'rb') as fh:
    frag_2body_dict = pickle.load(fh)

kde_dir = '/rds-d7/project/rds-ZNFRY9wKoeE/EnamineREAL/pickles/'
kdes = {'mpro': 'kde_dict_spl_mpro.pickle',
        'mac1': 'kde_dict_spl_mac1.pickle',
        'dpp11': 'kde_dict_spl_dpp11.pickle'}
with open(kde_dir + kdes['mpro'], 'rb') as fh:
    kde_dicts = pickle.load(fh)

# One row per pair type; 4 inches of height per row (6 rows -> 24 inches).
fig, axs = plt.subplots(nrows=len(pairs), figsize=(6, 4 * len(pairs)), dpi=200)
for i, combo in enumerate(pairs):
    x = np.linspace(0, np.amax(frag_2body_dict[combo]), nx)
    kde_pair = kde_dicts[combo]
    axs[i].set_title('{} distance distribution'.format(
        combo))
    axs[i].hist(frag_2body_dict[combo], bins=n_bins,
                alpha=0.5, density=True, color='#1E85FC')
    # The stored entry is called directly and exponentiated, so it is
    # presumably a log-density callable - TODO confirm against the code that
    # wrote the pickle.
    axs[i].plot(x, np.exp(kde_pair(x)), color=cmap(1))
    axs[i].set_xlim(left=0)
    axs[i].set_ylabel('Probability Density')
    axs[i].set_xlabel(r'$d (\AA)$')

# Lay out and show the figure once, after every subplot has been drawn.
fig.tight_layout()
fig.show()


In [None]:
def plot_dists_and_kdes(kde_dicts, frag_2body_dict, rand_kde_dicts, rand_2body_dicts, name, c=0, weights=None):
    """Plot pair-distance histograms beside several KDE curves and return spline KDEs.

    One row of two panels is drawn for every entry of the module-level
    ``pairs`` list:

    * left panel: density-normalised histogram of ``frag_2body_dict[combo]``
      overlaid with faint grey histograms of up to 10 random sets taken from
      ``rand_2body_dicts``;
    * right panel: four curves evaluated from 0 to the largest measured
      distance - the ``score_samples`` output of ``kde_dicts[combo]``, the log
      of a ``KDEpy.FFTKDE`` fit (gaussian kernel, ``bw='ISJ'``), the
      ``score_samples`` output of a scikit-learn ``KernelDensity`` refitted
      with the FFTKDE bandwidth, and an ``interp1d`` interpolant through the
      FFTKDE grid.

    Parameters
    ----------
    kde_dicts : mapping
        ``combo -> fitted model`` exposing ``score_samples``; its output is
        plotted as returned.
    frag_2body_dict : mapping
        ``combo -> distances``. Values must support ``.reshape(-1, 1)``
        (presumably 1-D NumPy arrays - TODO confirm).
    rand_kde_dicts : sequence
        Currently unused: the random-KDE overlay is disabled, so the models
        are no longer evaluated. Kept so existing callers keep working.
    rand_2body_dicts : sequence of mappings
        ``rand_2body_dicts[n][combo]`` holds the distances of random set ``n``.
    name : str
        Figure super-title.
    c : int, optional
        Base index into the ``tab10`` colour map; colours ``c`` .. ``c + 3``
        are used for the four curves.
    weights : mapping, optional
        ``combo -> per-sample weights`` forwarded to both KDE fits. ``None``
        fits unweighted models.

    Returns
    -------
    dict
        ``combo -> scipy.interpolate.interp1d`` of the FFTKDE log-density,
        extrapolating outside the FFTKDE grid.
    """
    cmap = plt.get_cmap("tab10")

    n_bins = 30
    n_rand = 10
    nx = 500
    n_pairs = len(pairs)

    # FFT round-off can leave zeros or tiny negative values in the density
    # tails; flooring them keeps the log (and hence the spline) finite.
    density_floor = np.finfo(float).tiny

    # squeeze=False keeps `axs` two-dimensional whatever the number of rows.
    fig, axs = plt.subplots(nrows=n_pairs, ncols=2, sharex=True,
                            figsize=(16, 4 * n_pairs), dpi=200, squeeze=False)
    fig.suptitle(name)

    # Never index past the random sets that were actually supplied.
    n_rand_used = min(n_rand, len(rand_2body_dicts))

    new_kdes = {}
    for i, combo in tqdm(enumerate(pairs), total=n_pairs):
        hist_ax, kde_ax = axs[i, 0], axs[i, 1]
        distances = frag_2body_dict[combo]
        # Both fit() methods default their weight argument to None, so one
        # code path covers the weighted and unweighted cases.
        pair_weights = None if weights is None else weights[combo]

        hist_ax.set_title('{} histogram, N={}'.format(combo, len(distances)))
        hist_ax.hist(distances, bins=n_bins,
                     alpha=0.5, density=True, color=cmap(c))

        kde_ax.set_title('{} KDE'.format(combo))

        x = np.linspace(0, np.amax(distances), nx)
        x_col = x.reshape(-1, 1)

        # Curve 1: the model passed in by the caller.
        pair_dist = kde_dicts[combo].score_samples(x_col).flatten()
        kde_ax.plot(x, pair_dist, color=cmap(c))

        # Curve 2: FFT-based KDE with bw='ISJ', evaluated on its own grid.
        kdepy_model = KDEpy.FFTKDE(kernel='gaussian', bw='ISJ').fit(
            distances, weights=pair_weights)
        kdepy_x, kdepy_y = kdepy_model.evaluate()
        kdepy_bw = kdepy_model.bw
        kdepy_log_y = np.log(np.clip(kdepy_y, density_floor, None))
        kde_ax.plot(kdepy_x, kdepy_log_y, color=cmap(c + 1))

        # Curve 3: scikit-learn KDE refitted with the FFTKDE bandwidth.
        kde_new = KernelDensity(
            kernel='gaussian', bandwidth=kdepy_bw, rtol=1e-4).fit(
                distances.reshape(-1, 1), sample_weight=pair_weights)
        new_dist = kde_new.score_samples(x_col).flatten()
        kde_ax.plot(x, new_dist, color=cmap(c + 2))

        # Curve 4: use interpolated spline to speed up future scoring.
        spl = interp1d(kdepy_x, kdepy_log_y, fill_value='extrapolate')
        kde_ax.plot(kdepy_x, spl(kdepy_x), color=cmap(c + 3))
        kde_ax.set_ylim(top=0, bottom=-16)

        new_kdes[combo] = spl

        # Random reference sets are shown as histograms only.
        for n in range(n_rand_used):
            hist_ax.hist(rand_2body_dicts[n][combo], bins=n_bins,
                         alpha=0.1, density=True, color='grey')

    hist_legend = [Rectangle((0, 0), 1, 1, color=cmap(c), label='Measured'),
                   Rectangle((0, 0), 1, 1, color='grey', label='Random')]
    axs[0, 0].legend(handles=hist_legend, loc='upper right')

    # No grey curve is drawn on the KDE panels, so the legend lists only the
    # four curves that are actually plotted.
    kde_legend = [Rectangle((0, 0), 1, 1, color=cmap(c), label='sklearn'),
                  Rectangle((0, 0), 1, 1, color=cmap(c + 1), label='kdepy'),
                  Rectangle((0, 0), 1, 1, color=cmap(c + 2),
                            label='sklearn new bw'),
                  Rectangle((0, 0), 1, 1, color=cmap(c + 3),
                            label='interpolation')]
    axs[0, 1].legend(handles=kde_legend, loc='upper right')

    # Invisible full-figure axes that only carries the shared axis labels.
    # NOTE(review): the right-hand panels plot score_samples / log values, so
    # the shared 'Probability Density' label only strictly fits the left column.
    label_ax = fig.add_subplot(111, frameon=False)
    label_ax.tick_params(labelcolor='none', which='both', top=False,
                         bottom=False, left=False, right=False)
    label_ax.set_xlabel('Distance (angstrom)', labelpad=30)
    label_ax.set_ylabel('Probability Density', labelpad=30)

    fig.tight_layout()
    fig.show()
    return new_kdes


In [None]:
def _load_pickle(path):
    """Return the object pickled at ``path``, closing the file afterwards.

    NOTE(review): unpickling can execute arbitrary code - only load trusted files.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# Locations of the Mpro inputs.
# NOTE(review): absolute, user-specific paths - consider a configurable base dir.
data_dir = '/home/wjm41/ml_physics/frag-pcore-screen/data/Mpro'

pickle_dir = '/home/wjm41/ml_physics/frag-pcore-screen/notebooks/pickles'

frags = _load_pickle(data_dir + '/frags_mpro.pickle')

# Measured and random pair distances.
frag_pair_distance_dict = _load_pickle(
    pickle_dir + '/frag_pair_distance_dict_mpro.pickle')
rand_pair_dicts = _load_pickle(
    pickle_dir + '/rand_pair_dicts_mpro.pickle')

# Previously fitted KDE models for the measured and random distances.
kde_dict_opt = _load_pickle(
    pickle_dir + '/kde_dict_opt_mpro.pickle')
rand_kde_dicts = _load_pickle(
    pickle_dir + '/rand_kde_dicts_mpro.pickle')

# Draw the comparison figure and keep the returned per-pair interpolants.
new_kdes = plot_dists_and_kdes(
    kde_dict_opt, frag_pair_distance_dict, rand_kde_dicts, rand_pair_dicts, name='Mpro', c=0)
