In [None]:
from rdkit import Chem
from rdkit.Chem.Draw import SimilarityMaps
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
IPythonConsole.ipython_useSVG=True
IPythonConsole.drawOptions.addAtomIndices=True
from rdkit.Chem.rdDepictor import Compute2DCoords
from rdkit.Chem.Draw.IPythonConsole import drawMol3D
import py3Dmol
from tqdm import tqdm
import numpy as np
import pandas as pd
import time
import pickle
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.patches import Rectangle, Circle, ConnectionPatch
import datashader as ds
from datashader.mpl_ext import dsshow
from serenityff.charge.tree.tree import tree
from serenityff.charge.tree.atom_features import AtomFeatures
from openff.toolkit.topology import Molecule, Topology
from openff.toolkit.typing.engines.smirnoff import ForceField
from openff.toolkit.utils.toolkits import RDKitToolkitWrapper, AmberToolsToolkitWrapper

In [None]:
# Default font size (16 pt) for every matplotlib figure drawn after this cell.
plt.rcParams["font.size"] = 16

In [None]:
# Absolute locations of the input data on the author's machine. Every folder string
# ends with a trailing "/".
tree_folder="/localhome/mlehner/dash_data/tree/"
folder_train = "/localhome/mlehner/dash_data/test147_train/"
folder_explain = "/localhome/mlehner/dash_data/test154_explain/"
folder_4charges = "/localhome/mlehner/dash_data/test168_4charges/"
folder_attThresh="/localhome/mlehner/dash_data/test171_attThresh/"
folder_depth="/localhome/mlehner/dash_data/test172_depth/"
folder_aa = "/localhome/mlehner/dash_data/test144_aa/outs/"

# The folders already end in "/", so these f-strings contain a doubled slash
# (".../test168_4charges//test.sdf").
sdf_test_file_path = f"{folder_4charges}/test.sdf"
sdf_aa_file_path = f"{folder_aa}/../all_aa.sdf"

In [None]:
# Supplier over combined_multi.sdf in folder_train; removeHs=False keeps the hydrogens
# of every molecule that is read from it.
mol_sup_comb = Chem.SDMolSupplier(f"{folder_train}combined_multi.sdf", removeHs=False)

In [None]:
def rmse(x, y):
    """Return the root-mean-square error between ``x`` and ``y``.

    Both arguments must support element-wise subtraction (e.g. NumPy arrays or
    pandas Series).
    """
    residual = x - y
    mean_squared_error = np.mean(np.square(residual))
    return np.sqrt(mean_squared_error)
def r2_correlation(x, y):
    """Return the squared Pearson correlation coefficient of ``x`` and ``y``."""
    correlation_matrix = np.corrcoef(x, y)
    # the off-diagonal element is the x-y correlation
    pearson_r = correlation_matrix[0, 1]
    return pearson_r ** 2
def ratio_over_0_05(x, y, threshold=0.05):
    """Return the fraction of elements whose absolute difference exceeds ``threshold``.

    Parameters
    ----------
    x, y : array-like supporting element-wise subtraction
    threshold : float, optional
        Absolute-difference cut-off; the default 0.05 reproduces the original
        hard-coded behaviour.

    Returns
    -------
    Number of elements with ``|x - y| > threshold`` divided by ``len(x)``.
    """
    exceeds_threshold = np.abs(x - y) > threshold
    return np.sum(exceeds_threshold) / len(x)
def percentile(x, y, p=90):
    """Return the ``p``-th percentile of ``|x - y|``, ignoring NaN entries."""
    absolute_error = np.abs(x - y)
    return np.nanpercentile(absolute_error, p)

## Attention Threshold

In [None]:
# Load the result tables df_1.csv ... df_99.csv of the attention-threshold scan, keyed
# by their file index. Indices without a result file are skipped and reported once;
# any other read/parse error is raised instead of being silently swallowed.
df_dict_attThresh = {}
missing_attThresh_indices = []
for df_idx in range(1,100):
    try:
        df = pd.read_csv(f"{folder_attThresh}/df_{df_idx}.csv")
    except FileNotFoundError:
        missing_attThresh_indices.append(df_idx)
        continue
    df_dict_attThresh[df_idx] = df
if missing_attThresh_indices:
    print(f"No attention-threshold result file for indices: {missing_attThresh_indices}")

In [None]:
# Scan grid: 100 evenly spaced attention thresholds from 0.1 to 6 (both ends included).
attention_thresholds_to_test = np.linspace(start=0.1, stop=6, num=100)

In [None]:
# Score every loaded attention-threshold run: tree charges vs. MBIS charges.
df_att_vs_rmse_temp, df_att_vs_r2_temp = {}, {}
df_att_vs_ratio_temp, df_att_vs_90p_temp = {}, {}
for df_idx, df in df_dict_attThresh.items():
    mbis_ref, tree_pred = df["mbis_charge"], df["tree_charge"]
    df_att_vs_rmse_temp[df_idx] = rmse(mbis_ref, tree_pred)
    df_att_vs_r2_temp[df_idx] = r2_correlation(mbis_ref, tree_pred)
    df_att_vs_ratio_temp[df_idx] = ratio_over_0_05(mbis_ref, tree_pred)
    df_att_vs_90p_temp[df_idx] = percentile(mbis_ref, tree_pred, p=90)

# One single-column table per metric, one row per run index.
df_att_vs_rmse = pd.DataFrame.from_dict(df_att_vs_rmse_temp, orient="index", columns=["rmse"])
df_att_vs_r2 = pd.DataFrame.from_dict(df_att_vs_r2_temp, orient="index", columns=["r2"])
df_att_vs_ratio = pd.DataFrame.from_dict(df_att_vs_ratio_temp, orient="index", columns=["ratio"])
df_att_vs_90p = pd.DataFrame.from_dict(df_att_vs_90p_temp, orient="index", columns=["90p"])

# Replace each run index by the entry of the threshold grid at that position.
# This needs attention_thresholds_to_test to still be the NumPy grid.
for metric_table in (df_att_vs_rmse, df_att_vs_r2, df_att_vs_ratio, df_att_vs_90p):
    metric_table.index = attention_thresholds_to_test[metric_table.index]

In [None]:
# Quick look: RMSE column plotted against the table index (the attention threshold).
df_att_vs_rmse.plot(figsize=(3,3))

In [None]:
# Plain-list views of the RMSE table for the figures below.
rmse_list_attthresh = df_att_vs_rmse.iloc[:, 0].tolist()
# NOTE: this rebinds attention_thresholds_to_test to the list of thresholds found in
# the table index (one entry per loaded run).
attention_thresholds_to_test = df_att_vs_rmse.index.to_list()

In [None]:
# Figure: RMSE vs. attention threshold (left) next to MBIS-vs-tree charge correlation
# histograms (right) for three hand-picked runs.
font = {'family' : 'DejaVu Sans',
        'weight' : 'normal',
        'size'   : 16}
plt.rc('font', **font)
xy_range = [[-1.6,2.2],[-1.6,2.2]]
vmax = 1000
vmin = 0.6
# runs picked for the correlation plots, given as df_idx keys of df_dict_attThresh
picked_p1 = 1
picked_p2 = 20
picked_p3 = 60

df_p1 = df_dict_attThresh[picked_p1]
df_p2 = df_dict_attThresh[picked_p2]
df_p3 = df_dict_attThresh[picked_p3]

# attention_thresholds_to_test and rmse_list_attthresh are ordered like the keys of
# df_dict_attThresh (position 0 belongs to the first loaded run). A df_idx key therefore
# has to be translated into its list position before the threshold/RMSE of that run is
# looked up; indexing the lists with the key itself labels the histogram with the
# values of a different run.
loaded_run_keys = list(df_dict_attThresh.keys())

fig = plt.figure(figsize=(12,12))
axs = fig.subplot_mosaic([["Left", "TopRight"],["Left", "BottomRight"],["Left", "BottomRight2"]], gridspec_kw={"width_ratios":[1, 1]})
axs["Left"].plot(attention_thresholds_to_test, rmse_list_attthresh, label="RMSE", color="C2")
axs["Left"].set_xlabel("Attention threshold")
axs["Left"].set_ylabel("RMSE")
axs["Left"].legend(loc="upper left")

hists = []
for picked_key, panel_name in [(picked_p1, "TopRight"), (picked_p2, "BottomRight"), (picked_p3, "BottomRight2")]:
    list_pos = loaded_run_keys.index(picked_key)
    picked_threshold = attention_thresholds_to_test[list_pos]
    picked_rmse = rmse_list_attthresh[list_pos]
    df_picked = df_dict_attThresh[picked_key]
    panel = axs[panel_name]

    # correlation plot of the picked run, with the identity line as a guide
    hists.append(panel.hist2d(df_picked["mbis_charge"], df_picked["tree_charge"], bins=200, range=xy_range, cmap="Greens", norm=LogNorm(vmin=vmin, vmax=vmax)))
    panel.set_xlabel("MBIS charge [e]")
    panel.set_ylabel("Tree charge [e]")
    panel.plot(xy_range[0], xy_range[1], color="grey", linestyle="--")
    panel.set_aspect("equal")
    panel.set_title(f"Attention threshold = {picked_threshold:.2f}")
    cbar = fig.colorbar(hists[-1][3], ax=panel, label="Counts (log scale) [a.u.]")

    # mark the run in the main plot
    axs["Left"].scatter(picked_threshold, picked_rmse, color="C1", marker="o", s=100)

    # fine grey lines from the mark to the lower-left and upper-left corner of the panel
    for corner_y in xy_range[1]:
        con = ConnectionPatch(xyA=(picked_threshold, picked_rmse), xyB=(xy_range[1][0], corner_y), coordsA="data", coordsB="data", axesA=axs["Left"], axesB=panel, color="grey", linestyle="--", linewidth=1)
        axs["Left"].add_artist(con)
h1, h2, h3 = hists

# final adjustments
fig.tight_layout(pad=2)
plt.savefig("test_123_attthreshRMSE+Corr.pdf")
plt.show()

In [None]:
# Figure: RMSE vs. attention threshold (left) next to MBIS-vs-tree charge correlation
# histograms (right) for two hand-picked runs.
font = {'family' : 'DejaVu Sans',
        'weight' : 'normal',
        'size'   : 16}
plt.rc('font', **font)
xy_range = [[-1.6,2.2],[-1.6,2.2]]
vmax = 1000
vmin = 0.6
# runs picked for the correlation plots, given as df_idx keys of df_dict_attThresh
picked_p1 = 9
picked_p2 = 65

df_p1 = df_dict_attThresh[picked_p1]
df_p2 = df_dict_attThresh[picked_p2]

# attention_thresholds_to_test and rmse_list_attthresh are ordered like the keys of
# df_dict_attThresh (position 0 belongs to the first loaded run). A df_idx key therefore
# has to be translated into its list position before the threshold/RMSE of that run is
# looked up; indexing the lists with the key itself labels the histogram with the
# values of a different run.
loaded_run_keys = list(df_dict_attThresh.keys())

fig = plt.figure(figsize=(9,9))
axs = fig.subplot_mosaic([["Left", "TopRight"],["Left", "BottomRight"]], width_ratios=[1,1])
axs["Left"].plot(attention_thresholds_to_test, rmse_list_attthresh, label="RMSE", color="C2")
axs["Left"].set_xlabel("Attention threshold")
axs["Left"].set_ylabel("RMSE [e]")
axs["Left"].legend(loc="upper left")

hists, cbars = [], []
for picked_key, panel_name in [(picked_p1, "TopRight"), (picked_p2, "BottomRight")]:
    list_pos = loaded_run_keys.index(picked_key)
    picked_threshold = attention_thresholds_to_test[list_pos]
    picked_rmse = rmse_list_attthresh[list_pos]
    df_picked = df_dict_attThresh[picked_key]
    panel = axs[panel_name]

    # correlation plot of the picked run, with the identity line as a guide
    hists.append(panel.hist2d(df_picked["mbis_charge"], df_picked["tree_charge"], bins=200, range=xy_range, cmap="Greens", norm=LogNorm(vmin=vmin, vmax=vmax)))
    panel.set_xlabel("MBIS charge [e]")
    panel.set_ylabel("Tree charge [e]")
    panel.plot(xy_range[0], xy_range[1], color="grey", linestyle="--")
    panel.set_aspect("equal")
    panel.set_title(f"Attention threshold = {picked_threshold:.2f}")
    cbars.append(fig.colorbar(hists[-1][3], ax=panel, label="Counts (log scale) [a.u.]", shrink=0.7))

    # mark the run in the main plot
    axs["Left"].scatter(picked_threshold, picked_rmse, color="C1", marker="o", s=100)

    # fine grey lines from the mark to the lower-left and upper-left corner of the panel
    for corner_y in xy_range[1]:
        con = ConnectionPatch(xyA=(picked_threshold, picked_rmse), xyB=(xy_range[1][0], corner_y), coordsA="data", coordsB="data", axesA=axs["Left"], axesB=panel, color="grey", linestyle="--", linewidth=1)
        axs["Left"].add_artist(con)
h1, h2 = hists
cbar1, cbar2 = cbars

# final adjustments
plt.subplots_adjust(wspace=0.3, hspace=0.1)
plt.savefig("test_123_attthreshRMSE+Corr2.pdf")
plt.show()

In [None]:
# Row(s) of the RMSE table with the smallest RMSE; the index shows the attention threshold.
lowest_rmse = df_att_vs_rmse["rmse"].min()
df_att_vs_rmse[df_att_vs_rmse["rmse"] == lowest_rmse]

In [None]:
# Peek at the first two rows of the result table stored under key 78.
df_dict_attThresh[78].head(2)

In [None]:
# Add a "tree_norm" column to run 78: the raw tree charges shifted per molecule so that
# they sum to zero (each atom's charge minus sum/number-of-atoms of its molecule).
df_test = df_dict_attThresh[78]
# groupby(...).transform returns values on the original row index, so every result lands
# on its own source row. Building one flat list molecule by molecule and assigning it
# positionally misaligns charges as soon as a molecule's rows are not contiguous.
df_test["tree_norm"] = df_test.groupby("mol_index")["tree_raw"].transform(
    lambda charges_raw: charges_raw - (charges_raw.sum() / len(charges_raw))
)
# kept for cells that still expect the flat list
tree_norm_list = df_test["tree_norm"].values.tolist()

In [None]:
# Compare three variants of the tree charges of run 78 against the MBIS charges.
df_test = df_dict_attThresh[78]
charge_column_by_variant = {"raw": "tree_raw", "normalized": "tree_norm", "std. weighted": "tree_charge"}
metric_by_name = {"RMSE": rmse, "R2": r2_correlation, "Percentile": percentile, "Ratio over 0.05": ratio_over_0_05}

# rows: charge variants, columns: metrics (each called as metric(tree variant, MBIS))
df_norms = pd.DataFrame(
    [
        [metric(df_test[charge_column], df_test["mbis_charge"]) for metric in metric_by_name.values()]
        for charge_column in charge_column_by_variant.values()
    ],
    columns=list(metric_by_name),
    index=list(charge_column_by_variant),
)

# individual values under their original names
rmse_raw, r2_raw, percentile_raw, ratio_over_0_05_raw = df_norms.loc["raw"]
rmse_norm, r2_norm, percentile_norm, ratio_over_0_05_norm = df_norms.loc["normalized"]
rmse_std, r2_std, percentile_std, ratio_over_0_05_std = df_norms.loc["std. weighted"]

In [None]:
# Display the metric table (rows: charge variants, columns: metrics).
df_norms

In [None]:
# Bar chart of RMSE (left y-axis) and R2 (right y-axis) for the three charge variants.
fig, ax = plt.subplots(figsize=(6,6))
# position=1 / position=0 with width 0.4 put the two bars side by side at each tick
df_norms["RMSE"].plot.bar(ax=ax, color="lime", position=1, width=0.4)
ax.set_ylabel("RMSE [e]")
# hard-coded, truncated y-range (does not start at 0)
ax.set_ylim([0.028,0.0293])
twin_ax = ax.twinx()
df_norms["R2"].plot.bar(ax=twin_ax, color="darkgreen", position=0, width=0.4)
twin_ax.set_ylabel("R2")
# hard-coded, truncated y-range for R2 as well
twin_ax.set_ylim([0.99,0.994])
ax.set_xlim([-0.5,2.5])
fig.legend(loc="upper center", ncol=2, bbox_to_anchor=(0.5, 0.85))
plt.savefig("test_123_norms.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Two-panel figure for df_test: MBIS-vs-tree charge correlation (left) and RMSE per
# element for the "tree_charge" ("norm") and "tree_raw" ("raw") columns (right).
fig, ax = plt.subplots(figsize=(12,6), ncols=2, nrows=1, width_ratios=[1.7,1])
# Both per-element series go into ONE table that is sorted once. Sorting each series on
# its own can order the elements differently, and since every pandas bar plot places its
# bars at positions 0..n-1 and rewrites the tick labels, the bars drawn first would then
# end up under the labels of other elements.
rmse_per_element = pd.DataFrame({
    "norm": df_test.groupby("element").apply(lambda x: rmse(x["tree_charge"], x["mbis_charge"])),
    "raw": df_test.groupby("element").apply(lambda x: rmse(x["tree_raw"], x["mbis_charge"])),
}).sort_values("norm", ascending=False)
rmse_per_element["norm"].plot.bar(ax=ax[1], color="lime", position=1, width=0.4)
rmse_per_element["raw"].plot.bar(ax=ax[1], color="darkgreen", position=0, width=0.4)
ax[1].set_ylabel("RMSE [e]")
ax[1].set_xlabel("Element")
ax[1].set_xlim([-0.5, 10.5])
ax[1].legend(["norm", "raw"])
# 2d histogram tree_charge vs mbis_charge
ax[0].hist2d(df_test["mbis_charge"], df_test["tree_charge"], bins=100, cmap="Greens", norm=LogNorm(0.6, 1000))
ax[0].set_ylabel("Tree charge [e]")
ax[0].set_xlabel("MBIS charge [e]")
ax[0].set_xlim([-2, 2.5])
ax[0].set_ylim([-2, 2.5])
ax[0].plot([-2, 2.5], [-2, 2.5], color="black", linestyle="--")
ax[0].set_aspect("equal")
# colour bar built from a stand-alone mappable with the same norm/cmap as the histogram
cbar = fig.colorbar(plt.cm.ScalarMappable(norm=LogNorm(0.6, 1000), cmap="Greens"), ax=ax[0], shrink=0.75)
cbar.set_label("Counts [a.u.]")
fig.set_tight_layout(True)
plt.savefig("test_123_rmsePerElement.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Count the distinct molecule indices in df_test.
unique_mol_indices = df_test["mol_index"].unique()
num_unique_mols = len(unique_mol_indices)
print(f"Number of unique molecules: {num_unique_mols}")


In [None]:
# Drop the name df_dict_attThresh; df_test still refers to the table of run 78.
del df_dict_attThresh

## Depth test

In [None]:
# Load the result tables df_1.csv ... df_17.csv of the tree-depth scan, keyed by their
# file index. Indices without a result file are skipped and reported once; any other
# read/parse error is raised instead of being silently swallowed.
attention_thresholds_to_test = list(range(1, 18))
df_dict_depth = {}
missing_depth_indices = []
for df_idx in range(1,18):
    try:
        df = pd.read_csv(f"{folder_depth}/df_{df_idx}.csv")
    except FileNotFoundError:
        missing_depth_indices.append(df_idx)
        continue
    df_dict_depth[df_idx] = df
if missing_depth_indices:
    print(f"No depth result file for indices: {missing_depth_indices}")

In [None]:
# Score every loaded depth run: tree charges vs. MBIS charges.
# (The *_temp dictionaries reuse the names from the attention-threshold section.)
df_att_vs_rmse_temp, df_att_vs_r2_temp = {}, {}
df_att_vs_ratio_temp, df_att_vs_90p_temp = {}, {}
for df_idx, df in df_dict_depth.items():
    mbis_ref, tree_pred = df["mbis_charge"], df["tree_charge"]
    df_att_vs_rmse_temp[df_idx] = rmse(mbis_ref, tree_pred)
    df_att_vs_r2_temp[df_idx] = r2_correlation(mbis_ref, tree_pred)
    df_att_vs_ratio_temp[df_idx] = ratio_over_0_05(mbis_ref, tree_pred)
    df_att_vs_90p_temp[df_idx] = percentile(mbis_ref, tree_pred, p=90)

# One single-column table per metric, indexed by the run (file) index.
df_depth_vs_rmse = pd.DataFrame.from_dict(df_att_vs_rmse_temp, orient="index", columns=["rmse"])
df_depth_vs_r2 = pd.DataFrame.from_dict(df_att_vs_r2_temp, orient="index", columns=["r2"])
df_depth_vs_ratio = pd.DataFrame.from_dict(df_att_vs_ratio_temp, orient="index", columns=["ratio"])
df_depth_vs_90p = pd.DataFrame.from_dict(df_att_vs_90p_temp, orient="index", columns=["90p"])

In [None]:
# Plain-list views of the depth RMSE table: RMSE values and the matching run indices.
rmse_list_depth = df_depth_vs_rmse.iloc[:, 0].tolist()
depth_to_test = df_depth_vs_rmse.index.to_list()

In [None]:
# Figure: RMSE vs. tree depth (left, with a zoomed inset) next to MBIS-vs-tree charge
# correlation histograms (right) for two hand-picked depths.
font = {'family' : 'DejaVu Sans',
        'weight' : 'normal',
        'size'   : 16}
plt.rc('font', **font)
xy_range = [[-1.6,2.2],[-1.6,2.2]]
vmax = 1000
vmin = 0.1
# runs picked for the correlation plots, given as df_idx keys of df_dict_depth
picked_p1 = 1
picked_p2 = 11

df_p1 = df_dict_depth[picked_p1]
df_p2 = df_dict_depth[picked_p2]

fig = plt.figure(figsize=(9,9))
axs = fig.subplot_mosaic([["Left", "TopRight"],["Left", "BottomRight"]], width_ratios=[1,1])
axs["Left"].plot(depth_to_test, rmse_list_depth, label="RMSE", color="C2")
axs["Left"].set_xlabel("Tree depth")
axs["Left"].set_ylabel("RMSE [e]")

# zoom in on the tail of the curve: x in (10, 17), y in (0.028, 0.029)
axins = axs["Left"].inset_axes([0.5, 0.7, 0.45, 0.28])
axins.plot(depth_to_test, rmse_list_depth, color="C2")
axins.set_xlim(10, 17)
axins.set_ylim(0.028, 0.029)
axs["Left"].indicate_inset_zoom(axins)

hists, cbars = [], []
for picked_key, panel_name in [(picked_p1, "TopRight"), (picked_p2, "BottomRight")]:
    # depth_to_test holds the df_idx keys and rmse_list_depth is ordered like it, so the
    # key must be translated into its list position. Using the key itself as a list
    # index would pick the values of the following run (position 1 holds key 2, ...).
    list_pos = depth_to_test.index(picked_key)
    picked_depth = depth_to_test[list_pos]
    picked_rmse = rmse_list_depth[list_pos]
    df_picked = df_dict_depth[picked_key]
    panel = axs[panel_name]

    # correlation plot of the picked run, with the identity line as a guide
    hists.append(panel.hist2d(df_picked["mbis_charge"], df_picked["tree_charge"], bins=200, range=xy_range, cmap="Greens", norm=LogNorm(vmin=vmin, vmax=vmax)))
    panel.set_xlabel("MBIS charge [e]")
    panel.set_ylabel("Tree charge [e]")
    panel.plot(xy_range[0], xy_range[1], color="grey", linestyle="--")
    panel.set_aspect("equal")
    panel.set_title(f"Tree depth = {picked_depth:.2f}")
    cbars.append(fig.colorbar(hists[-1][3], ax=panel, label="Counts (log scale) [a.u.]", shrink=0.7))

    # mark the run in the main plot
    axs["Left"].scatter(picked_depth, picked_rmse, color="C1", marker="o", s=100)

    # fine grey lines from the mark to the lower-left and upper-left corner of the panel
    for corner_y in xy_range[1]:
        con = ConnectionPatch(xyA=(picked_depth, picked_rmse), xyB=(xy_range[1][0], corner_y), coordsA="data", coordsB="data", axesA=axs["Left"], axesB=panel, color="grey", linestyle="--", linewidth=1)
        axs["Left"].add_artist(con)
h1, h2 = hists
cbar1, cbar2 = cbars

# final adjustments
plt.subplots_adjust(wspace=0.3, hspace=0.1)
plt.savefig("test_123_depthRMSE+Corr2.pdf")
plt.show()

In [None]:
# Read one timing per depth run from its slurm log: the number after the last "=" on
# the third line of slurm-16595052_<i>.out.
time_per_depth = {}
for i in range(1, 18):
    try:
        # the context manager closes the log again, also when reading fails
        with open(f"{folder_depth}/slurm-16595052_{i}.out", "r") as filei:
            lines = filei.readlines()
    except FileNotFoundError:
        print(f"File {i} not found")
        continue
    try:
        time_per_depth[i] = float(lines[2].split("=")[-1])
    except (IndexError, ValueError):
        # the log exists but has fewer than three lines or no number after "="
        print(f"File {i}: no timing could be parsed from its third line")

In [None]:
# RMSE (left y-axis) and run time (right y-axis) as a function of tree depth.
fig, ax = plt.subplots(figsize=(6,6))
ax.plot(depth_to_test, rmse_list_depth, label="RMSE", color="limegreen")
# twin ax for time
ax2 = ax.twinx()
# Pair every timing with its own depth key. Plotting list(time_per_depth.values())
# against depth_to_test positionally fails (length mismatch) or shifts the curve as soon
# as a result table or a slurm log is missing for one depth.
timed_depths = [depth for depth in depth_to_test if depth in time_per_depth]
ax2.plot(timed_depths, [time_per_depth[depth] for depth in timed_depths], label="Time", color="darkgreen")

ax.set_xlabel("Tree depth")
ax.set_ylabel("RMSE [e]")
ax2.set_ylabel("Time [s]")

fig.legend(loc="upper center", ncol=2, bbox_to_anchor=(0.5, 0.85))
plt.savefig("test_123_depthVsTime.pdf", bbox_inches='tight')

In [None]:
# Drop the name df_dict_depth now that the depth figures are done.
del df_dict_depth

## 4 charges

In [None]:
# Charge table for the model comparison; the cells below read its columns
# "mbis", "tree", "am1Bcc", "gasteiger" and "mmff".
df_4charges = pd.read_csv(f"{folder_4charges}/df_charges.csv")

In [None]:
# Print the RMSE of every charge model against the MBIS column, three decimals each.
for label, model_column in [
    ("RMSE Tree to MBIS: \t", "tree"),
    ("RMSE am1Bcc to MBIS: \t", "am1Bcc"),
    ("RMSE gasteiger to MBIS: ", "gasteiger"),
    ("RMSE mmff to MBIS: \t", "mmff"),
]:
    model_rmse = np.sqrt(np.mean((df_4charges[model_column]-df_4charges['mbis'])**2))
    print(f"{label}{model_rmse:0.3f}")

In [None]:
# Turn AM1-BCC charges that are exactly 0.0 into NaN and drop those rows.
# NOTE(review): presumably 0.0 marks atoms without an AM1-BCC result -- confirm.
# np.nan is used because the np.NAN alias no longer exists in NumPy 2.0.
df_4charges["am1Bcc"] = df_4charges["am1Bcc"].apply(lambda x: np.nan if x == 0. else x)
# dropna removes every row that has a NaN in ANY column, not only in "am1Bcc"
df_4charges.dropna(inplace=True)

In [None]:
# RMSE and squared Pearson correlation of each charge model against MBIS, in the order
# in which the models are drawn in the 2x2 figure below.
charge_model_columns = ["am1Bcc", "tree", "mmff", "gasteiger"]
rmse_list = [
    np.sqrt(df_4charges["mbis"].sub(df_4charges[col]).pow(2).mean())
    for col in charge_model_columns
]
r2_list = [
    np.corrcoef(df_4charges["mbis"], df_4charges[col])[0,1]**2
    for col in charge_model_columns
]
print(rmse_list)
print(r2_list)

In [None]:
# 2x2 figure: 2D histograms of four charge models against the MBIS charges, sharing
# both axes and one colour bar, each annotated with its RMSE and R2.
font = {'family' : 'DejaVu Sans',
        'weight' : 'normal',
        'size'   : 14}
plt.rc('font', **font)
xy_range = [[-1.5,2.3],[-1.5,2.3]]
fig, axs = plt.subplots(2, 2, figsize=(10,10), sharex=True, sharey=True)
vmin=0.15
vmax=10000

# plot 2d histograms, one model per panel in row-major order
h1, h2, h3, h4 = [
    panel.hist2d(df_4charges["mbis"], df_4charges[model_column], bins=100, range=xy_range, cmap="Greens", norm=LogNorm(vmin=vmin, vmax=vmax))
    for panel, model_column in zip(axs.flat, ["am1Bcc", "tree", "mmff", "gasteiger"])
]

# top row: no x label, ticks on the upper edge
for panel in axs[0]:
    panel.set_xlabel("")
    panel.xaxis.set_label_position('top')
    panel.xaxis.tick_top()
# bottom row: labelled x axis
for panel in axs[1]:
    panel.set_xlabel("MBIS charge [e]")
# left column: labelled y axis
for panel in axs[:, 0]:
    panel.set_ylabel("charge [e]")
# right column: no y label, ticks on the right edge
for panel in axs[:, 1]:
    panel.set_ylabel("")
    panel.yaxis.set_label_position('right')
    panel.yaxis.tick_right()

# draw diagonal lines
for ax in axs.flat:
    ax.plot(xy_range[0], xy_range[1], color="grey", linestyle="--")
plt.subplots_adjust(wspace=0, hspace=0)

# add colorbar in its own axes to the right of the grid
cbar_ax = fig.add_axes([0.95, 0.15, 0.05, 0.7])
cbar = fig.colorbar(h1[3], cax=cbar_ax, label="Counts (log scale) [a.u.]")

# write panel name and metrics as text in each plot
panel_names = ["A) AM1-BCC","B) DASH","C) MMFF","D) Gasteiger"]
for name, ax, rmse_i, r2_i in zip(panel_names, axs.flat, rmse_list, r2_list):
    ax.text(-1.15, 1.8, f"{name}", fontsize=16, fontweight="bold")
    ax.text(-0.9, 1.4, f"RMSE: {rmse_i:.2f}\u2009e\nR2: {r2_i:.2f}", fontsize=12)
plt.savefig("test_123_4charges2.pdf", bbox_inches='tight')
plt.show()

In [None]:
# Drop the name df_4charges now that the comparison figure is saved.
del df_4charges

## Import Tree

In [None]:
# Create an empty tree object and populate it from tree_folder.
# NOTE(review): the method name suggests pickle files are read -- only point
# tree_folder at trusted data.
test_tree_pruned = tree()
test_tree_pruned.from_folder_pickle(tree_folder)

## Atom Features and Attention

In [None]:
# For every child of the tree root: the atom-feature strings of its atoms, obtained by
# looking up the first entry of each atom record.
af_list_in_branches = [
    [AtomFeatures.lookup_int(atom_record[0]) for atom_record in branch.atoms]
    for branch in test_tree_pruned.root.children
]

In [None]:
def shorten_AF(af):
    """Return a compact form of an atom-feature string.

    Parameters
    ----------
    af : str
        Exactly five space-separated fields: element, number of bonds, charge,
        conjugation flag ("True" or anything else) and number of hydrogens.

    Returns
    -------
    str
        The same fields with the flag abbreviated to "T"/"F" and non-negative
        charges padded with a leading blank so they align with negative ones.

    Raises
    ------
    ValueError
        If ``af`` does not have five fields or the charge field is not a number.
    """
    elem, nBonds, charge, isConj, nHs = af.split(" ")
    is_ConjNew = "T" if isConj == "True" else "F"
    # parse the charge as a number; eval() would execute arbitrary text found in the field
    charge_new = " "+charge if float(charge) >= 0 else charge
    return f"{elem} {nBonds} {charge_new} {is_ConjNew} {nHs}"

In [None]:
# Table with one row per child of the tree root: branch index, attention and count,
# plus the atom feature of the branch's first atom in three spellings.
root_branches = test_tree_pruned.root.children
max_attention_lvl_0_list = [
    [branch_pos, branch.attention, branch.count]
    for branch_pos, branch in enumerate(root_branches)
]
df_max_attention_lvl_0_list = pd.DataFrame(max_attention_lvl_0_list, columns=["branch_idx", "attention", "count"])
df_max_attention_lvl_0_list.dropna(inplace=True)
df_max_attention_lvl_0_list.sort_values("attention", ascending=False, inplace=True)

# full atom-feature string and its shortened form
df_max_attention_lvl_0_list["af"] = df_max_attention_lvl_0_list["branch_idx"].apply(
    lambda x: AtomFeatures.lookup_int(root_branches[x].atoms[0][0])
)
df_max_attention_lvl_0_list["af_short"] = df_max_attention_lvl_0_list["af"].apply(shorten_AF)

# sort again (as before) and renumber the rows 0..n-1
df_max_attention_lvl_0_list.sort_values("attention", ascending=False, inplace=True)
df_max_attention_lvl_0_list.reset_index(inplace=True, drop=True)

# integer code of the atom feature, stored as text
df_max_attention_lvl_0_list["af_idx"] = df_max_attention_lvl_0_list["af"].apply(
    lambda x: str(AtomFeatures.lookup_str(x))
)

In [None]:
# Show the first rows, i.e. the branches with the highest attention.
df_max_attention_lvl_0_list.head()

In [None]:
# Attention (line, left y-axis) and count (bars, logarithmic right y-axis) per root branch.
# The font size is lowered to 11 for this figure only and restored at the end.
plt.rcParams.update({'font.size': 11})
fig, ax = plt.subplots(figsize=(15,5))
# use af, af_short or af_idx
df_max_attention_lvl_0_list.plot.line(x="af_idx", y="attention", ax=ax, color="lime", linewidth=3)
df_max_attention_lvl_0_list.plot.bar(x="af_idx", y="count", ax=ax, secondary_y=True, log=True, color="darkgreen")
ax.set_xlabel("atom feature", fontsize=16)
ax.set_ylabel("attention per atom feature in layer 0", fontsize=16)
# right_ax is the secondary axis that pandas attaches for secondary_y=True
ax.right_ax.set_ylabel("count per atom feature in layer 0", fontsize=16)
fig.savefig("test_123_AttentionPerFeature.pdf", bbox_inches="tight")
fig.show()
plt.rcParams.update({'font.size': 16})

In [None]:
def _first_branch_with(atom_feature):
    """Return the index of the first root branch whose atom features contain ``atom_feature``.

    Raises IndexError when no branch matches, like the former ``np.where(...)[0][0]``.
    """
    # Plain list search: np.array(af_list_in_branches) cannot be built when the branches
    # hold different numbers of atoms (ragged nested lists raise in current NumPy), and
    # the array was rebuilt for every single lookup.
    for branch_idx, branch_afs in enumerate(af_list_in_branches):
        if atom_feature in branch_afs:
            return branch_idx
    raise IndexError(f"no root branch contains atom feature {atom_feature!r}")

# root-branch indices of hydrogen and of the four halogens
h_idx = _first_branch_with("H 1 0 False 0")
f_idx = _first_branch_with("F 1 0 False 0")
cl_idx = _first_branch_with("Cl 1 0 False 0")
br_idx = _first_branch_with("Br 1 0 False 0")
i_idx = _first_branch_with("I 1 0 False 0")
halogen_indices = [f_idx, cl_idx, br_idx, i_idx]
print(halogen_indices)

In [None]:
collected_attention_per_node = []
def collect_attention(node):
    """Record ``[level, attention]`` for ``node`` and all of its descendants.

    The records are appended, in depth-first pre-order, to whatever list the
    module-level name ``collected_attention_per_node`` refers to at call time.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        collected_attention_per_node.append([current.level, current.attention])
        # push the children reversed so the first child is handled next
        pending.extend(reversed(current.children))

In [None]:
# For every halogen branch: mean and standard deviation of the node attention per tree
# level, stored as one table per branch index.
collected_attention_per_node_for_branches_halogens = {}
for branch_idx in halogen_indices:
    # collect_attention appends to this module-level list, so start empty for each branch
    collected_attention_per_node = []
    collect_attention(test_tree_pruned.root.children[branch_idx])
    df_tmp = pd.DataFrame(collected_attention_per_node, columns=["level", "attention"])
    df_tmp_mean = df_tmp.groupby("level").mean()
    df_tmp_mean["std"] = df_tmp.groupby("level").std()
    # add an all-zero row labelled -1, then shift every level label up by one so that
    # the added row becomes label 0 after sorting
    df_tmp_mean.loc[-1] = [0.0, 0.0]
    df_tmp_mean.index = df_tmp_mean.index + 1
    df_tmp_mean.sort_index(inplace=True)
    # replace missing values (e.g. the std of a level with a single node) by 0
    df_tmp_mean.fillna(0, inplace=True)
    collected_attention_per_node_for_branches_halogens[branch_idx] = df_tmp_mean

In [None]:
# Cumulative sum of the mean attention over the level index, one line per halogen branch.
fig, ax = plt.subplots(figsize=(5,5))
for branch_idx in halogen_indices:
    # legend label: atom-feature string of the first atom of the branch
    collected_attention_per_node_for_branches_halogens[branch_idx]["attention"].cumsum().plot.line(ax=ax, label=AtomFeatures.lookup_int(test_tree_pruned.root.children[branch_idx].atoms[0][0]), linewidth=3)
ax.set_xlabel("tree depth")
ax.set_ylabel("cumulative attention")
# hard-coded axis ranges
ax.set_ylim(0.4,1.7)
ax.set_xlim(0,12)
ax.legend()
fig.savefig("test_123_cumulativeAttentionPerLevelForHalogenBranches_noFill.pdf", bbox_inches="tight")
fig.show()

## charge distribution with depth

In [None]:
value_level_af_list = []

def get_value_level_af(tree_node):
    """Collect ``[result, level, atoms]`` for ``tree_node`` and all of its descendants.

    Records are appended, in depth-first pre-order, to the module-level list
    ``value_level_af_list``. A node that lacks one of the three attributes adds
    no record, but its children are still visited.
    """
    try:
        record = [tree_node.result, tree_node.level, tree_node.atoms]
    except AttributeError:
        # only a missing attribute is tolerated; any other error is raised
        record = None
    if record is not None:
        value_level_af_list.append(record)
    for child in tree_node.children:
        get_value_level_af(child)

# Walk the whole tree, starting at its root, to fill value_level_af_list.
get_value_level_af(test_tree_pruned.root)

In [None]:
# split list by level
# Bucket i of the outer list receives the records whose level (entry[1]) is i + 1.
value_level_af_list_by_level = [[]]
for entry in value_level_af_list:
    # Grow by one bucket when a deeper level shows up; one new bucket is enough only
    # while new levels first appear in increasing order without gaps.
    if len(value_level_af_list_by_level) < entry[1]:
        value_level_af_list_by_level.append([])
    # NOTE(review): a record with level 0 is filed under index -1, i.e. in whatever
    # bucket is currently last -- confirm that no level-0 record can occur.
    value_level_af_list_by_level[entry[1]-1].append(entry)

In [None]:
# 2x2 grid of log-scaled histograms of the node values (first record entry) taken from
# four buckets of value_level_af_list_by_level.
fig, ax = plt.subplots(ncols=2, nrows=2, figsize=(10,10))
xy_range = [-1.5, 2]
# (axes, bucket index, panel letter, y-limits or None, y-ticks or None)
hist_panels = [
    (ax[0,0], 0, "A", (0.7, 2.6), [1e0, 2e0]),
    (ax[0,1], 1, "B", (0.7, 6), [1e0, 2e0, 3e0, 4e0, 5e0, 6e0]),
    (ax[1,0], 5, "C", None, None),
    (ax[1,1], 10, "D", None, None),
]
for panel, bucket_idx, letter, y_limits, y_ticks in hist_panels:
    bucket_values = [record[0] for record in value_level_af_list_by_level[bucket_idx]]
    panel.hist(bucket_values, bins=2000, label=f"level {bucket_idx}", range=xy_range, log=True, color="darkgreen")
    panel.set_xlim(xy_range)
    if y_limits is not None:
        # the two top panels get a fixed y-range with plain integer tick labels
        panel.set_ylim(*y_limits)
        panel.set_yticks(y_ticks)
        panel.set_yticklabels([str(int(tick)) for tick in y_ticks])
    panel.set_xlabel("charge [e]")
    panel.set_ylabel("count")
    panel.text(.1,.9,letter, fontsize=20, ha='left', va='top', transform=panel.transAxes)
fig.tight_layout()
fig.savefig("test_123_ChgDistrDepth.pdf", bbox_inches="tight")

## CNF Diff

In [None]:
# Conformer-difference table; the cells below read its columns "gnn", "mbis_charges",
# "delta_cnf", "mol_idx", "atom_idx" and "mol_with_cnf_idx".
df_cnf = pd.read_csv(f"{folder_explain}df_conf_diff.csv")

In [None]:
# Agreement between the "gnn" and "mbis_charges" columns: RMSE, squared Pearson
# correlation and mean absolute error.
gnn_minus_mbis = df_cnf["gnn"] - df_cnf["mbis_charges"]
rmse_gnnVSmbis = np.sqrt(np.mean(gnn_minus_mbis**2))
r2_gnnVSmbis = np.corrcoef(df_cnf["gnn"], df_cnf["mbis_charges"])[0,1]**2
mae_gnnVSmbis = np.mean(np.abs(gnn_minus_mbis))
for metric_label, metric_value in [("RMSE", rmse_gnnVSmbis), ("R2", r2_gnnVSmbis), ("MAE", mae_gnnVSmbis)]:
    print(f"{metric_label} GNN vs MBIS: {metric_value:.5f}")

In [None]:
# Absolute differences used by the histogram below.
df_cnf["d_cnf"] = df_cnf["delta_cnf"].abs()
df_cnf["d_gnn"] = (df_cnf["mbis_charges"] - df_cnf["gnn"]).abs()

In [None]:
# Peek at the first two rows, including the two new absolute-difference columns.
df_cnf.head(2)

In [None]:
# GNN to MBIS and MBIS to CNF as line plot of histogram.
# Both series share binning and range so the two step curves are directly comparable.
fig, ax = plt.subplots(figsize=(5,5))
hist_kwargs = dict(bins=100, ax=ax, logy=True, range=[0, 0.4], histtype="step")
df_cnf["d_gnn"].plot.hist(color="C2", **hist_kwargs)
df_cnf["d_cnf"].plot.hist(color="C0", **hist_kwargs)
ax.set_xlabel("Absolute difference [e]")
ax.set_ylabel("Count")
ax.set_xlim(0, 0.4)
# Legend entries are matched to the curves by plotting order (d_gnn first, d_cnf second).
ax.legend(["GNN to MBIS reference", "MBIS to CNF median"])
fig.tight_layout()
fig.savefig("test_123_abs_diff_cnfs_line.pdf", bbox_inches="tight")
# plt.show() instead of fig.show(): Figure.show() needs a GUI backend and only
# emits a warning with the inline notebook backend.
plt.show()

In [None]:
df_cnf.sort_values("d_cnf", ascending=False, inplace=True)

In [None]:
# Take the row at position 9899 of df_cnf, which the previous cell sorted by descending d_cnf.
# NOTE(review): despite the name this is not the top-ranked row; the position 9899
# looks hand-picked for a specific example — confirm and document why.
max_outlier = df_cnf.iloc[9899]
mol_idx_max_outlier = int(max_outlier["mol_idx"])
atom_idx_max_outlier = int(max_outlier["atom_idx"])

In [None]:
mol_with_cnf_idx_max_outlier = int(max_outlier["mol_with_cnf_idx"])

In [None]:
# Unique mol_idx values that share the selected mol_with_cnf_idx.
# The three-way unpacking raises ValueError unless exactly three distinct values are found.
cnf_mol_idx1, cnf_mol_idx2, cnf_mol_idx3 = [int(x) for x in df_cnf[df_cnf["mol_with_cnf_idx"] == mol_with_cnf_idx_max_outlier].mol_idx.unique()]

In [None]:
# Rows of the selected atom in each of the three entries found above.
df_line_cnf1, df_line_cnf2, df_line_cnf3 = (
    df_cnf[(df_cnf["mol_idx"] == conformer_idx) & (df_cnf["atom_idx"] == atom_idx_max_outlier)]
    for conformer_idx in (cnf_mol_idx1, cnf_mol_idx2, cnf_mol_idx3)
)

In [None]:
# Fetch the three molecules from the combined SDF supplier by their mol_idx.
outlier_mol_conf1, outlier_mol_conf2, outlier_mol_conf3 = (
    mol_sup_comb[conformer_idx]
    for conformer_idx in (cnf_mol_idx1, cnf_mol_idx2, cnf_mol_idx3)
)

In [None]:
df_line_cnf1

In [None]:
df_line_cnf2

In [None]:
df_line_cnf3

In [None]:
def draw_mol_with_highlights(mol, hit_ats, style=None):
    """Draw a molecule in 3D and mark selected atoms with green spheres.

    Parameters
    ----------
    mol : RDKit molecule
        Must carry a conformer; its coordinates are used both for the model
        (via a mol block) and for placing the highlight spheres.
    hit_ats : iterable of int, or iterable of iterables of int
        Atom indices to highlight. Both a flat sequence such as ``[3, 7]`` and
        nested tuples as returned by RDKit's GetSubstructMatches are accepted.
    style : dict, optional
        drawing style, see https://3dmol.csb.pitt.edu/doc/$3Dmol.GLViewer.html for some examples.
        Defaults to a stick representation with the 'grayCarbon' colour scheme.

    Returns
    -------
    py3Dmol viewer
    """
    if style is None:
        style = {'stick':{'colorscheme':'grayCarbon', "linewidth": 0.1}}

    # Flatten one level so nested match tuples and flat index lists behave the same.
    # int() also converts NumPy integers into plain Python ints for RDKit.
    atom_indices = []
    for entry in hit_ats:
        if hasattr(entry, "__iter__"):
            atom_indices.extend(int(idx) for idx in entry)
        else:
            atom_indices.append(int(entry))

    v = py3Dmol.view()
    v.addModel(Chem.MolToMolBlock(mol), "mol")
    v.setStyle({'model':0},style)
    if atom_indices:
        # Look up the conformer once instead of once per highlighted atom.
        conformer = mol.GetConformer()
        for atom in atom_indices:
            p = conformer.GetAtomPosition(atom)
            v.addSphere({"center":{"x":p.x,"y":p.y,"z":p.z},"radius":0.9,"color":'green', "alpha": 0.8})
    v.zoomTo()
    return v

In [None]:
draw_mol_with_highlights(outlier_mol_conf1, [atom_idx_max_outlier])

In [None]:
draw_mol_with_highlights(outlier_mol_conf2, [atom_idx_max_outlier])

In [None]:
draw_mol_with_highlights(outlier_mol_conf3, [atom_idx_max_outlier])

## Amino Acids

In [None]:
mol_sup_aa = Chem.SDMolSupplier(sdf_aa_file_path, removeHs=False)
print(len(mol_sup_aa))

In [None]:
ff = ForceField("openff_unconstrained-2.0.0.offxml")

In [None]:
# Assign partial charges with the OpenFF force field `ff` for every amino-acid
# molecule and record the wall-clock time of the whole loop.
start_time = time.time()
matching_time_off = 0
am1Bcc_charges = []
for aa_mol in tqdm(mol_sup_aa):
    off_molecule = Molecule.from_rdkit(aa_mol, allow_undefined_stereo=True)
    partial_charges = ff.get_partial_charges(off_molecule)
    # strip the unit, then round each charge to 5 decimals
    unitless_charges = partial_charges.value_in_unit(partial_charges.unit).tolist()
    am1Bcc_charges.append([round(float(q), 5) for q in unitless_charges])
matching_time_off = time.time() - start_time

In [None]:
# Compute RDKit Gasteiger charges for every amino-acid molecule and record the
# wall-clock time of the whole loop.
start_time = time.time()
matching_time_gasteiger = 0
gasteiger_charges = []
for aa_mol in tqdm(mol_sup_aa):
    AllChem.ComputeGasteigerCharges(aa_mol)
    # charges are read back from the per-atom "_GasteigerCharge" property
    per_atom_charges = []
    for atom in aa_mol.GetAtoms():
        per_atom_charges.append(round(float(atom.GetProp("_GasteigerCharge")), 5))
    gasteiger_charges.append(per_atom_charges)
matching_time_gasteiger = time.time() - start_time

In [None]:
# calculate rdkit MMFF charges and time it
start_time = time.time()
matching_time_mmff = 0
mmff_charges = []
mmff_failed_mols = []  # (position in mol_sup_aa, error) for molecules that fell back to zeros
for aa_idx, mol in enumerate(tqdm(mol_sup_aa)):
    try:
        mm = AllChem.MMFFGetMoleculeProperties(mol)
        # NOTE(review): if RDKit cannot set up MMFF for a molecule, `mm` is presumably None
        # and the attribute access below raises; that case takes the fallback.
        charges = [round(float(mm.GetMMFFPartialCharge(x)),5) for x in range(mol.GetNumAtoms())]
    except Exception as err:  # not a bare except: Ctrl-C / kernel interrupt must still work
        mmff_failed_mols.append((aa_idx, repr(err)))
        # Zero placeholders keep this list aligned with the other per-molecule charge
        # lists; be aware they enter the error statistics computed from mmff_charges.
        charges = [0 for x in range(mol.GetNumAtoms())]
    mmff_charges.append(charges)
matching_time_mmff = time.time() - start_time
if mmff_failed_mols:
    print(f"MMFF charges failed for {len(mmff_failed_mols)} molecule(s), zeros used: {mmff_failed_mols}")

In [None]:
# calculate tree charges and time it
start_time = time.time()
matching_time_tree = 0
tree_charges = []
tree_failed_mols = []  # (position in mol_sup_aa, error) for molecules that fell back to zeros
for aa_idx, mol in enumerate(tqdm(mol_sup_aa)):
    try:
        # the first element of the match result is used as the per-atom charge list
        charges = test_tree_pruned.match_molecule_atoms(mol, attention_threshold=5.5)[0]
    except Exception as err:  # not a bare except: Ctrl-C / kernel interrupt must still work
        tree_failed_mols.append((aa_idx, repr(err)))
        # Zero placeholders keep this list aligned with the other per-molecule charge
        # lists; be aware they enter the error statistics computed from tree_charges.
        charges = [0 for x in range(mol.GetNumAtoms())]
    tree_charges.append(charges)
matching_time_tree = time.time() - start_time
if tree_failed_mols:
    print(f"Tree matching failed for {len(tree_failed_mols)} molecule(s), zeros used: {tree_failed_mols}")

In [None]:
# Wall-clock durations (time.time() differences, i.e. seconds) of the four
# charge-assignment loops timed in the cells above.
print(f"matching_time_off: {matching_time_off}")
print(f"matching_time_gasteiger: {matching_time_gasteiger}")
print(f"matching_time_mmff: {matching_time_mmff}")
print(f"matching_time_tree: {matching_time_tree}")

In [None]:
# Read the reference data stored in the amino-acid SDF: "|"-separated MBIS and xTB
# charges from the molecule properties, plus formal charges, element symbols and
# (molecule, atom) index pairs. Every list holds one sub-list per molecule.
mbis_charges, xtb_charges, formal_charges = [], [], []
elements, mol_idx, atom_idx = [], [], []
for mol_number, mol in enumerate(tqdm(mol_sup_aa)):
    n_atoms = mol.GetNumAtoms()
    atoms = list(mol.GetAtoms())
    mbis_charges.append([float(x) for x in mol.GetProp("MBIS_CHARGES").split("|")])
    xtb_charges.append([float(x) for x in mol.GetProp("XTB_MulikenCharge").split("|")])
    formal_charges.append([atom.GetFormalCharge() for atom in atoms])
    elements.append([atom.GetSymbol() for atom in atoms])
    # running molecule number repeated once per atom, and the atom positions 0..n-1
    mol_idx.append([mol_number] * n_atoms)
    atom_idx.append(list(range(n_atoms)))

In [None]:
# Load the pre-computed RESP charges.
# NOTE: unpickling can execute arbitrary code; only load this file from a trusted source.
# The context manager closes the file handle, which the bare open() call left dangling.
with open(f"{folder_aa}../resp_charges_dict.pickle", "rb") as resp_file:
    resp_charges = pickle.load(resp_file)

In [None]:
# One row per atom: every per-molecule list of lists is flattened in molecule order.
# The dict order below fixes the column order of df_aa.
nested_columns = {
    "mol_idx": mol_idx,
    "atom_idx": atom_idx,
    "element": elements,
    "mbis_charge": mbis_charges,
    "xtb_charge": xtb_charges,
    "am1Bcc_charge": am1Bcc_charges,
    "gasteiger_charge": gasteiger_charges,
    "mmff_charge": mmff_charges,
    "tree_charge": tree_charges,
    "formal_charge": formal_charges,
}
df_aa = pd.DataFrame({
    column: [value for per_mol in nested for value in per_mol]
    for column, nested in nested_columns.items()
})
# RESP charges are looked up as resp_charges[mol_idx][0][atom_idx] for each row.
df_aa["resp_charge"] = df_aa.apply(lambda row: resp_charges[row["mol_idx"]][0][row["atom_idx"]], axis=1)
# drop every atom row that has a missing value in any column
df_aa.dropna(inplace=True)

In [None]:
df_aa.head(2)

In [None]:
# Mean absolute deviation from the MBIS charges: RESP (first line), tree charges (second line).
print(np.abs(df_aa["mbis_charge"] - df_aa["resp_charge"]).mean())
print(np.abs(df_aa["mbis_charge"] - df_aa["tree_charge"]).mean())

In [None]:
# Root-mean-square deviation of the tree charges from the MBIS charges.
print(np.sqrt(np.mean((df_aa["mbis_charge"] - df_aa["tree_charge"])**2)))

In [None]:
# Error metrics of every charge model against the MBIS reference, in the order of col_names.
rmse_list_aa = []
r2_list_aa = []
mae_list_aa = []
col_names = ["am1Bcc_charge", "tree_charge", "mmff_charge", "gasteiger_charge", "xtb_charge", "resp_charge"]
# Titles: "A) AM1-BCC","B) DASH","C) MMFF","D) Gasteiger", "E) XTB", "F) RESP"
for col in col_names:
    deviation = df_aa["mbis_charge"] - df_aa[col]
    rmse_list_aa.append(np.sqrt(np.mean(deviation**2)))
    r2_list_aa.append(np.corrcoef(df_aa["mbis_charge"], df_aa[col])[0,1]**2)
    mae_list_aa.append(np.abs(deviation).mean())

# The plotting cell annotates the panels from `rmse_list` / `mae_list`. The original
# code initialised the *_aa lists but appended to rmse_list / r2_list, which either
# fails on a fresh kernel or extends stale lists from an earlier analysis (the plot
# would then show the wrong numbers). Rebind the names to the fresh results instead.
rmse_list = rmse_list_aa
r2_list = r2_list_aa
mae_list = mae_list_aa

In [None]:
# Figure: 3x2 grid of 2D histograms (log colour scale), one charge model per panel,
# each plotted against the MBIS reference charge on the x axis.
font = {'family' : 'DejaVu Sans',
        'weight' : 'normal',
        'size'   : 12}
plt.rc('font', **font)
xy_range = [[-1.2,1.2],[-1.2,1.2]]
fig, axs = plt.subplots(3, 2, figsize=(10,15), sharex=True, sharey=True)

panel_columns = ["am1Bcc_charge", "tree_charge", "mmff_charge", "gasteiger_charge", "xtb_charge", "resp_charge"]
panel_names = ["A) AM1-BCC","B) DASH","C) MMFF","D) Gasteiger", "E) XTB", "F) RESP"]
n_grid_rows, n_grid_cols = axs.shape
panel_hists = []
for panel_no, column in enumerate(panel_columns):
    grid_row, grid_col = divmod(panel_no, n_grid_cols)
    panel = axs[grid_row, grid_col]
    panel_hists.append(
        panel.hist2d(df_aa["mbis_charge"], df_aa[column], bins=100, range=xy_range, cmap="Greens", norm=LogNorm(vmin=0.1, vmax=10))
    )
    # only the bottom row carries the x label; the top row moves its ticks to the top edge
    panel.set_xlabel("MBIS charge [e]" if grid_row == n_grid_rows - 1 else "")
    if grid_row == 0:
        panel.xaxis.set_label_position('top')
        panel.xaxis.tick_top()
    # only the left column carries the y label; the right column moves its ticks to the right edge
    panel.set_ylabel("charge [e]" if grid_col == 0 else "")
    if grid_col == n_grid_cols - 1:
        panel.yaxis.set_label_position('right')
        panel.yaxis.tick_right()

# grey dashed diagonal: model charge equal to the MBIS charge
for panel in axs.flat:
    panel.plot(xy_range[0], xy_range[1], color="grey", linestyle="--")
plt.subplots_adjust(wspace=0, hspace=0)
# one shared colorbar, taken from the image of the first panel
cbar_ax = fig.add_axes([0.95, 0.15, 0.05, 0.7])
cbar = fig.colorbar(panel_hists[0][3], cax=cbar_ax, label="Counts (log scale) [a.u.]")
# panel title plus RMSE / MAE annotation, read from rmse_list and mae_list
for name, panel, rmse_i, mae_i in zip(panel_names, axs.flat, rmse_list, mae_list):
    panel.text(-1.15, 1, f"{name}", fontsize=16, fontweight="bold")
    panel.text(-0.9, 0.7, f"RMSE: {rmse_i:.2f}\u2009e\nMAE: {mae_i:.2f}\u2009e", fontsize=12)
plt.savefig("test_123_aa_charges.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Hard-coded timings in seconds for the methods that are not timed in the cells above.
# NOTE(review): presumably measured in separate runs on the same amino-acid set — confirm provenance.
matching_time_gnn = 1.021888
matching_time_mbis = 8490
matching_time_resp = 3*3600 + 2*60 + 56  # 3 h 2 min 56 s

In [None]:
# Timing table, one row per method, ordered from fastest to slowest.
timings_by_method = {
    "AM1-BCC": matching_time_off,
    "Gasteiger": matching_time_gasteiger,
    "MMFF": matching_time_mmff,
    "DASH": matching_time_tree,
    "RESP": matching_time_resp,
    "GNN": matching_time_gnn,
    "MBIS": matching_time_mbis,
}
df_time = pd.DataFrame(
    {"method": list(timings_by_method.keys()), "time [s]": list(timings_by_method.values())}
).sort_values(by="time [s]")

In [None]:
# Bar chart of the per-method timings on a logarithmic y axis.
# The figure was created at 4x4 inches and then resized by the conflicting
# figsize=(5, 5) passed to pandas together with ax; state the effective size once.
fig, ax = plt.subplots(figsize=(5, 5))
df_time.plot.bar(x="method", y="time [s]", logy=True, legend=False, color="C2", ax=ax)
ax.set_ylabel("time [s]")
fig.savefig("test_123_aa_time.pdf", bbox_inches="tight")


## Symmetry

In [None]:
def symmetrizeTerminalAtoms(mol,atomPattern='O,N;D1'):
    """Return an editable copy of `mol` in which resonance-like end groups are made equivalent.

    The query matches an atom fitting `atomPattern` (default: O or N with degree 1)
    that is either single-bonded to an atom carrying a double bond to another
    `atomPattern` atom, or double-bonded to an atom carrying a single bond to one,
    together with its neighbour. For each match the bond between the two atoms is
    replaced by a "single or double" query bond and the formal charge of the matched
    `atomPattern` atom is set to 0.

    Parameters
    ----------
    mol : RDKit molecule
        Input molecule; it is not modified.
    atomPattern : str, optional
        SMARTS atom expression (without brackets) selecting the atoms to symmetrize.

    Returns
    -------
    Chem.RWMol
        Modified copy of `mol`.
    """
    pattern_atom = f'[{atomPattern}]'
    query = Chem.MolFromSmarts(
        f'[{atomPattern};$({pattern_atom}-[*]={pattern_atom}),$({pattern_atom}=[*]-{pattern_atom})]~[*]'
    )
    single_or_double = Chem.BondFromSmarts('-,=')
    editable = Chem.RWMol(mol)
    # matches are computed on the untouched input, edits go to the copy
    for pattern_idx, neighbor_idx in mol.GetSubstructMatches(query):
        bond_idx = editable.GetBondBetweenAtoms(pattern_idx, neighbor_idx).GetIdx()
        editable.ReplaceBond(bond_idx, single_or_double)
        # adjust charges:
        editable.GetAtomWithIdx(pattern_idx).SetFormalCharge(0)
    return editable

In [None]:
def compare_charges_on_symmetric_atoms(mol, mol_idx):
    """Measure how much the tree charges differ between symmetry-equivalent atoms.

    Atoms are grouped by their canonical rank (ties not broken) computed on the
    output of symmetrizeTerminalAtoms(mol). For every group with more than one
    atom, one record is produced.

    Parameters
    ----------
    mol : RDKit molecule
    mol_idx : int
        Identifier copied into every record.

    Returns
    -------
    list of dict
        Keys: "mol_idx", "atom_idx" (NumPy array of the group's atom indices),
        "std_sym_charges", "element" and "af" (both taken from the first atom of
        the group) and "max_diff_sym_charges" (max minus min charge in the group).
    """
    symmetrized = symmetrizeTerminalAtoms(mol)
    rank_list = list(Chem.CanonicalRankAtoms(symmetrized,breakTies=False))
    # first element of the match result is used as the per-atom charge list
    tree_charges = test_tree_pruned.match_molecule_atoms(mol)[0]
    rank_array = np.array(rank_list)  # built once, compared against every rank

    records = []
    for rank in set(rank_list):
        members = np.flatnonzero(rank_array == rank)
        if len(members) <= 1:
            continue  # atom has no symmetry partner
        member_charges = [tree_charges[atom_idx] for atom_idx in members]
        first_member = int(members[0])
        records.append({
            "mol_idx": mol_idx,
            "atom_idx": members,
            "std_sym_charges": np.std(member_charges),
            "element": mol.GetAtomWithIdx(first_member).GetSymbol(),
            "af": AtomFeatures.atom_features_from_molecule(mol, first_member),
            "max_diff_sym_charges": np.max(member_charges) - np.min(member_charges),
        })
    return records

In [None]:
# The symmetric-atom table is loaded from a CSV file. The commented-out lines below
# show how that file was generated (one compare_charges_on_symmetric_atoms call per
# molecule in mol_sup_comb); uncomment them to regenerate it.
# Note that array-valued entries such as "atom_idx" come back from the CSV as strings.
#df_entries = []
#for mol_idx, mol in tqdm(enumerate(mol_sup_comb), total=len(mol_sup_comb)):
#    df_entries.extend(compare_charges_on_symmetric_atoms(mol, mol_idx))
#df_sym = pd.DataFrame(df_entries)
#df_sym.to_csv(f"{folder_4charges}test_113_symmetric_atom_charges.csv", index=False)
df_sym = pd.read_csv(f"{folder_4charges}test_113_symmetric_atom_charges.csv")

In [None]:
# Log-count histogram of the largest charge difference within each group of symmetric atoms.
fig, ax = plt.subplots(figsize=(6,6))
df_sym["max_diff_sym_charges"].hist(ax=ax, bins=100, log=True, color="C2")
ax.set(
    xlabel="Absolute difference of symmetric atoms [e]",
    ylabel="Count",
    xlim=(0,0.3),
)
fig.savefig("test_123_symmetric_atom_charges_max_diff.pdf", bbox_inches="tight")

In [None]:
df_sym.head(2)

In [None]:
df_sym.iloc[0].atom_idx.strip("[]").split(" ")

In [None]:
df_sym.sort_values(by="max_diff_sym_charges", ascending=False, inplace=True)

In [None]:
# Collect molecule, highlighted atoms and legend for the first 51 rows of df_sym
# (sorted by descending max_diff_sym_charges in the previous cell).
mol_list = []
highlight_list = []
legend_list = []
for line in df_sym.head(51).itertuples():
    try:
        mol = mol_sup_comb[int(line.mol_idx)]
        # local name on purpose: do not overwrite the per-molecule `mbis_charges`
        # list that was built for the amino-acid analysis
        mol_mbis_charges = [float(x) for x in mol.GetProp("MBIScharge").split("|")]
        AllChem.Compute2DCoords(mol)
        # atom_idx is read from the CSV as the string form of an index array, e.g. "[ 3 12]".
        # split() without argument splits on any run of whitespace; split(" ") produced
        # empty tokens for padded entries, int('') failed, and the row was silently dropped.
        atom_idx_pars = [int(x) for x in line.atom_idx.strip("[]").split()]
        highlighted_charges = [mol_mbis_charges[atom_idx] for atom_idx in atom_idx_pars]
        legend = f"af: {line.af}, std: {line.std_sym_charges:.3f}\n MBIS: "
        legend += "".join(f" {chg:.3f}" for chg in highlighted_charges)
    except Exception as err:
        # keep the best-effort behaviour, but say which rows were skipped and why
        print(f"Skipping mol_idx {line.mol_idx}: {err!r}")
        continue
    legend_list.append(legend)
    mol_list.append(mol)
    highlight_list.append(atom_idx_pars)

In [None]:
len(mol_list)

In [None]:
print(legend_list[0])
print(highlight_list[0])
mol_list[0]

In [None]:
Chem.Draw.MolsToGridImage(mol_list)

In [None]:
# Render the collected molecules as an SVG grid (3 per row) with the symmetric atoms
# highlighted, write the SVG text to disk and display the grid as the cell output.
plotGrid = Chem.Draw.MolsToGridImage(
    mol_list,
    molsPerRow=3,
    subImgSize=(300,300),
    legends=legend_list,
    highlightAtomLists=highlight_list,
    useSVG=True,
)
with open("test_123_symmetric_atom_charges_max_MolGrid.svg", "w") as svg_file:
    svg_file.write(plotGrid.data)
plotGrid

In [None]:
# Ask the notebook frontend to save a checkpoint by running a JavaScript snippet.
# NOTE(review): this relies on a JavaScript object named `IPython.notebook` being
# available in the frontend — confirm it works in the environment used.
from IPython.display import display, Javascript
display(Javascript('IPython.notebook.save_checkpoint();'))