In [None]:
import pandas as pd
import numpy as np
import io
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import rdMolDraw2D
import IPython.display
from collections import defaultdict
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from serenityff.charge.tree.dash_tree import DASHTree
from rdkit.Chem.Draw.IPythonConsole import drawMol3D
from rdkit.Chem.MolStandardize import rdMolStandardize
import pickle
from rdkit.Chem import rdDetermineBonds
import glob
import py3Dmol
plt.rcParams.update({'font.size': 16})
try:
    import IPython.display
except ImportError:
    pass
from PIL import Image
from rdkit.Chem.Draw import rdMolDraw2D
from collections import defaultdict

In [None]:
def dash_corr_plot(df, x, y, fig, ax, xy_range, xlabel=None, ylabel=None, vmin=0.1, vmax=10000, text=None, err_range=(-0.5,0.5), stats=False):
    """Draw column ``y`` against column ``x`` as a log-coloured 2D histogram with an error inset.

    Rows in which either column is NaN are dropped before anything is plotted.

    Parameters
    ----------
    df : DataFrame containing the columns ``x`` and ``y``.
    x, y : column names plotted on the horizontal / vertical axis.
    fig, ax : matplotlib figure and the axes of that figure to draw into.
    xy_range : (min, max) used for both histogram axes and for the dotted diagonal.
    xlabel, ylabel : axis labels; default to the column names.
    vmin, vmax : limits of the logarithmic colour scale.
    text : optional bold label written into the top-left corner (e.g. a panel letter).
    err_range : (min, max) of the inset histogram of ``x - y``.
    stats : if True, write RMSE, squared Pearson correlation and Kendall tau into the panel.

    Side effects
    ------------
    Every call shrinks the subplot area with ``fig.subplots_adjust(right=0.9)`` and adds a
    new colorbar axes at ``[0.91, 0.15, 0.05, 0.7]``, so a figure with several panels ends
    up with one colorbar axes per call. Returns None.
    """
    pairs = df[[x, y]].dropna()
    x_vals = pairs[x]
    y_vals = pairs[y]
    residuals = x_vals - y_vals

    # Main panel: 100 x 100 density map plus the identity line.
    density = ax.hist2d(
        x_vals,
        y_vals,
        bins=100,
        cmap="Greens",
        norm=LogNorm(vmin=vmin, vmax=vmax),
        range=(xy_range, xy_range),
    )
    ax.set_xlabel(f"{x if xlabel is None else xlabel}")
    ax.set_ylabel(f"{y if ylabel is None else ylabel}")
    ax.plot(xy_range, xy_range, color='grey', linestyle=':')

    if text is not None:
        ax.text(0.05, 0.95, text, transform=ax.transAxes, fontsize=16, fontweight='bold', va='top')

    if stats:
        rmse = np.sqrt(np.mean(residuals ** 2))
        r2 = pairs[[x, y]].corr().iloc[0, 1] ** 2
        tau = pairs[[x, y]].corr(method="kendall").iloc[0, 1]
        summary = f"RMSE: {rmse:.3f}\nR2: {r2:.3f}\nTau: {tau:.3f}"
        ax.text(0.05, 0.85, summary, transform=ax.transAxes, fontsize=12, va='top')

    # Inset in the lower-right corner: distribution of x - y without a y scale.
    inset = ax.inset_axes([0.69 , 0.06, 0.3, 0.3], frameon=True)
    inset.hist(residuals, bins=100, color="C2", range=err_range)
    inset.set_xlabel("error", fontsize=11)
    inset.tick_params(axis='both', which='major', labelsize=11)
    inset.xaxis.set_label_position("top")
    inset.yaxis.set_ticklabels([])
    inset.set_yticks([])

    # Colorbar for the density map (element 3 of the hist2d result) in its own axes.
    fig.subplots_adjust(right=0.9)
    colorbar_axes = fig.add_axes([0.91, 0.15, 0.05, 0.7])
    colorbar = fig.colorbar(density[3], cax=colorbar_axes)
    colorbar.ax.set_ylabel('Counts', rotation=270, labelpad=15, fontsize=14)

In [None]:
# Input files. `sdf_file` is only assigned here and is not referenced again in the
# visible cells; the two CSVs are the atom-level and molecule-level tables that the
# correlation plots below read from.
sdf_file = f"./sdf_qmugs500_mbis_collect.sdf"
df_atom = pd.read_csv("./test_184_atomData_grouped.csv")
df_mol = pd.read_csv("./test_184_molData_withMBIS_ref.csv")

In [None]:
# Expose "MBIScharge" under the name "mbis" so it follows the `<prop>` / `<prop>_pred`
# column naming that the per-property plotting loop relies on.
df_atom["mbis"] = df_atom["MBIScharge"]

In [None]:
# Properties plotted by the per-property loop; each key is a column of df_atom and has a
# matching "<key>_pred" column.
prop_keys = ['mbis', 'mulliken', 'resp1', 'resp2', 'dual', 'mbis_dipole_strength', 'dipole_bond_1']
# Shared axis range for all partial-charge properties.
charge_range = (-2.5, 2.5)
# Axis range (used for both x and y) per property.
prop_range = {"mbis": charge_range, "mulliken": charge_range, "resp1": charge_range, "resp2": charge_range, "dual": (-1,1), "mbis_dipole_strength": (0, 1), "dipole_bond_1": (0, 1)}
# x-axis label (reference value) per property.
prop_labels_x = {"mbis": "MBIS charge [e]", "mulliken": "Mulliken charge [e]", "resp1": "RESP1 charge [e]", "resp2": "RESP2 charge [e]", "dual": "Dual Descriptor", "mbis_dipole_strength": "MBIS dipole strength [eA]", "dipole_bond_1": "Dipole bond 1 [eA]"}
# y-axis label (DASH prediction) per property.
prop_labels_y = {"mbis": "DASH prediction [e]", "mulliken": "DASH prediction [e]", "resp1": "DASH prediction [e]", "resp2": "DASH prediction [e]", "dual": "DASH prediction", "mbis_dipole_strength": "DASH prediction [eA]", "dipole_bond_1": "DASH prediction [eA]"}

In [None]:
# One correlation figure per property: reference column `<prop>` on x, DASH prediction
# `<prop>_pred` on y. Each figure is written as PDF and SVG before it is displayed.
for prop in prop_keys:
    x_prop = prop
    y_prop = f"{prop}_pred"
    fig, ax = plt.subplots(1,1, figsize=(6,6))
    dash_corr_plot(df_atom, x_prop, y_prop, fig, ax, prop_range[prop], stats=True,
                   xlabel=prop_labels_x[prop], ylabel=prop_labels_y[prop], vmin=0.2, vmax=1e5)
    fig.savefig(f"./test_185/test_185_{prop}_corr.pdf", bbox_inches="tight")
    fig.savefig(f"./test_185/test_185_{prop}_corr.svg", bbox_inches="tight")
    # Figure.show() needs a GUI backend and only emits a warning under the notebook
    # inline backend; plt.show() renders the figure at this point of the loop instead.
    plt.show()
    

In [None]:
# Display the molecule-level table for inspection.
df_mol

In [None]:
# Rescale both molecular dipole columns by a fixed factor and store them as "*_eA".
# NOTE(review): 0.393430307 is the Debye -> atomic-unit (e*Bohr) factor; Debye -> e*Angstrom
# would be about 0.2082. Confirm the unit of the source columns and whether the "_eA"
# suffix / "[eA]" axis labels are correct.
df_mol["mol_dipole_with_atomic_eA"] = df_mol["mol_dipole_with_atomic"]*0.393430307
df_mol["mol_dipole_no_atomic_eA"] = df_mol["mol_dipole_no_atomic"]*0.393430307

In [None]:
# Median of every numeric column per DASH_IDX (presumably collapsing several rows, e.g.
# conformers, of the same molecule into one - confirm against the CSV).
# numeric_only=True keeps the aggregation working when the table also holds text columns:
# recent pandas versions raise on those instead of silently dropping them.
df_mol_cnf_grouped = df_mol.groupby("DASH_IDX").median(numeric_only=True)

In [None]:
# Molecular dipole: per-DASH_IDX median of "mol_dipole_with_atomic_eA" (x) against
# "mol_dipole_from_mbis_ref" (y), saved as PDF and SVG.
fig, ax = plt.subplots(1,1, figsize=(6,6))
dash_corr_plot(df_mol_cnf_grouped, "mol_dipole_with_atomic_eA", "mol_dipole_from_mbis_ref", fig, ax, (0, 5), stats=True, 
               xlabel="Molecular dipole [eA]", ylabel="DASH prediction [eA]", vmin=0.2, vmax=100, err_range=(-1,1))
fig.savefig(f"./test_185/test_185_mol_dipole_corr.pdf", bbox_inches="tight")
fig.savefig(f"./test_185/test_185_mol_dipole_corr.svg", bbox_inches="tight")

In [None]:
# AM1-BCC table (columns "am1bcc" / "am1bcc_pred" are used further down), stored under
# the key "df" of the HDF5 file.
df_am1bcc = pd.read_hdf("./test_145_am1bcc_prediction_df.h5", key="df")

In [None]:
# Atom-level ("c6", "polar") and molecule-level ("polarization") tables with their
# "*_pred" counterparts, as used by the plots below.
df_c6_atom = pd.read_csv("./test_143_c6_prediction_df.csv")
df_c6_mol = pd.read_csv("./test_143_c6_mol_prediction_df.csv")

In [None]:
# DASH_IDX values present in df_mol; the other data sets are filtered down to these IDs.
selected_dash_idx_set = set(df_mol["DASH_IDX"].unique())

In [None]:
# SDF supplier used to rebuild per-molecule / per-atom selection masks; hydrogens are kept
# so that atom counts include them.
# NOTE(review): absolute, user-specific path - the notebook only runs on a machine that has it.
dftd4_sdf_path = "/localhome/mlehner/test170_dftd4/mols_comb_dftd4.sdf"
mol_sup_c6 = Chem.SDMolSupplier(dftd4_sdf_path, removeHs=False)

In [None]:
# Build one boolean per molecule and one per atom telling whether the molecule's
# reconstructed DASH index is part of `selected_dash_idx_set`.
# The index is rebuilt from running counters: molecules carrying a "CHEMBL_ID" property
# are numbered as "QMUGS500_<n>", all others as "REST_<n>" (both starting at 1).
selected_atom_mask = []
selected_mol_mask = []
counter_qmugs = 0
counter_rest = 0
for mol_idx, mol in tqdm(enumerate(mol_sup_c6), total=len(mol_sup_c6)):
    if mol.HasProp("CHEMBL_ID"):
        counter_qmugs += 1
        tmp_dash_idx = f"QMUGS500_{counter_qmugs}"
    else:
        counter_rest += 1
        tmp_dash_idx = f"REST_{counter_rest}"

    is_selected = tmp_dash_idx in selected_dash_idx_set
    selected_mol_mask.append(is_selected)
    # every atom of the molecule inherits the molecule's flag
    selected_atom_mask.extend([is_selected] * mol.GetNumAtoms())


In [None]:
# Manual sanity check: mask lengths next to the lengths of the tables they are applied to
# (atom mask vs. NaN-free atom table, molecule mask vs. molecule table).
print(len(selected_atom_mask))
print(len(df_c6_atom.dropna()))
print("----")
print(len(selected_mol_mask))
print(len(df_c6_mol))

In [None]:
# Keep only the selected molecules. The name is reassigned, so re-running this cell
# without reloading the CSV applies the mask to the already filtered table.
df_c6_mol = df_c6_mol[selected_mol_mask]

In [None]:
# Drop rows with NaN, then keep the atoms of the selected molecules. The mask is cut to
# the length of the NaN-free table before it is applied.
# NOTE(review): the mask has one entry per atom of the SDF; applying it positionally after
# dropna() assumes dropped rows (if any) sit at the end of the table - confirm, otherwise
# the mask is shifted against the rows.
df_c6_atom = df_c6_atom.dropna()[selected_atom_mask[:len(df_c6_atom.dropna())]]

In [None]:
# Display the filtered atom-level C6 / polarizability table.
df_c6_atom

In [None]:
# Atomic C6: "c6" (x) against "c6_pred" (y), saved as PDF and SVG.
fig, ax = plt.subplots(1,1, figsize=(6,6))
dash_corr_plot(df_c6_atom, "c6", "c6_pred", fig, ax, (0, 200), stats=True,
                xlabel=r"C6 [au Bohr$^6$]", ylabel=r"DASH prediction [au Bohr$^6$]", vmin=0.2, vmax=1000, err_range=(-5,5))
fig.savefig(f"./test_185/test_185_c6_corr.pdf", bbox_inches="tight")
fig.savefig(f"./test_185/test_185_c6_corr.svg", bbox_inches="tight")

In [None]:
# Atomic polarizability: "polar" (x) against "polar_pred" (y), saved as PDF and SVG.
fig, ax = plt.subplots(1,1, figsize=(6,6))
dash_corr_plot(df_c6_atom, "polar", "polar_pred", fig, ax, (0, 40), stats=True,
                xlabel=r"Polarizability [au]", ylabel=r"DASH prediction [au]", vmin=0.2, vmax=1000, err_range=(-1,1))
fig.savefig(f"./test_185/test_185_polar_corr.pdf", bbox_inches="tight")
fig.savefig(f"./test_185/test_185_polar_corr.svg", bbox_inches="tight")

In [None]:
# Molecule-level "polarization" (x) against "polarization_pred" (y), saved as PDF and SVG.
# NOTE(review): the x label reads "Polarization" here, while the combined two-panel figure
# labels the same column "Polarizability" - pick one.
fig, ax = plt.subplots(1,1, figsize=(6,6))
dash_corr_plot(df_c6_mol, "polarization", "polarization_pred", fig, ax, (0, 320), stats=True,
                xlabel=r"Polarization [au]", ylabel=r"DASH prediction [au]", vmin=0.2, vmax=100, err_range=(-10,10))
fig.savefig(f"./test_185/test_185_polarization_corr_mol.pdf", bbox_inches="tight")
fig.savefig(f"./test_185/test_185_polarization_corr_mol.svg", bbox_inches="tight")

In [None]:
# SDF supplier (hydrogens kept) whose molecules carry a "DASH_IDX" property; used for the
# AM1-BCC atom mask and for picking the example molecule further down.
# NOTE(review): absolute, user-specific path - the notebook only runs on a machine that has it.
mol_sup_am1bcc = Chem.SDMolSupplier("/localhome/mlehner/dash_data/sdf_qmugs500_mbis_collect.sdf", removeHs=False)

In [None]:
# Show the "DASH_IDX" property of the first molecule to see the ID format.
mol_sup_am1bcc[0].GetProp("DASH_IDX")

In [None]:
# Per-atom boolean mask: True for every atom of a molecule whose "DASH_IDX" property is in
# `selected_dash_idx_set`, False otherwise.
selected_am1bcc_mask = []
for mol_idx, mol in tqdm(enumerate(mol_sup_am1bcc), total=len(mol_sup_am1bcc)):
    tmp_dash_idx = mol.GetProp("DASH_IDX")
    nAtoms = mol.GetNumAtoms()
    selected_am1bcc_mask.extend([tmp_dash_idx in selected_dash_idx_set] * nAtoms)

In [None]:
# Manual sanity check: the atom mask and the AM1-BCC table should have the same length.
print(len(selected_am1bcc_mask))
print(len(df_am1bcc))

In [None]:
# Keep only atoms of the selected molecules. The name is reassigned, so re-running this
# cell without reloading the table applies the mask to the already filtered frame.
df_am1bcc = df_am1bcc[selected_am1bcc_mask]

In [None]:
# Preview the filtered AM1-BCC table.
df_am1bcc.head()

In [None]:
# Drop all rows where am1bcc is exactly 0.
# NOTE(review): presumably 0 marks atoms without an AM1-BCC value - confirm with the
# producer of the table, since a genuine charge of exactly 0 would be removed as well.
df_am1bcc = df_am1bcc[df_am1bcc["am1bcc"] != 0]

In [None]:
# AM1-BCC: "am1bcc" (x) against "am1bcc_pred" (y); this single-panel figure is only
# displayed, not saved.
fig, ax = plt.subplots(1,1, figsize=(6,6))
dash_corr_plot(df_am1bcc, "am1bcc", "am1bcc_pred", fig, ax, charge_range, stats=True,
                xlabel="AM1-BCC charge [e]", ylabel="DASH prediction [e]", vmin=0.2, vmax=1000, err_range=(-0.5,0.5))

In [None]:
# Three-panel figure: AM1-BCC (A), Mulliken (B) and RESP2 (C), each against its DASH
# prediction, all with the same colour-scale limits.
fig, ax = plt.subplots(1,3, figsize=(18,6))
dash_corr_plot(df_am1bcc, "am1bcc", "am1bcc_pred", fig, ax[0], charge_range, stats=True,
                xlabel="AM1-BCC charge [e]", ylabel="DASH prediction [e]", vmin=0.2, vmax=1e4, err_range=(-0.5,0.5), text="A")
dash_corr_plot(df_atom, "mulliken", "mulliken_pred", fig, ax[1], charge_range, stats=True,
                xlabel="Mulliken charge [e]", ylabel="DASH prediction [e]", vmin=0.2, vmax=1e4, err_range=(-0.5,0.5), text="B")
dash_corr_plot(df_atom, "resp2", "resp2_pred", fig, ax[2], charge_range, stats=True,
                xlabel="RESP2 charge [e]", ylabel="DASH prediction [e]", vmin=0.2, vmax=1e4, err_range=(-0.5,0.5), text="C")
# Every dash_corr_plot call added its own colorbar axes; they follow the three panel axes
# in fig.axes. Move all of them to one narrower rectangle so they coincide.
cbar_axes = fig.axes[3:]
for cbar_ax in cbar_axes:
    cbar_ax.set_position([0.91, 0.15, 0.02, 0.7])


fig.savefig(f"./test_185/test_185_3charges_corr.pdf", bbox_inches="tight")
fig.savefig(f"./test_185/test_185_3charges_corr.svg", bbox_inches="tight")


In [None]:
# Two-panel figure: atomic C6 (A) and atomic polarizability (B) against their DASH
# predictions.
fig, ax = plt.subplots(1,2, figsize=(12,6))
dash_corr_plot(df_c6_atom, "c6", "c6_pred", fig, ax[0], (0, 200), stats=True,
                xlabel=r"C6 [au Bohr$^6$]", ylabel=r"DASH prediction [au Bohr$^6$]", vmin=0.2, vmax=1e4, err_range=(-5,5), text="A")
dash_corr_plot(df_c6_atom, "polar", "polar_pred", fig, ax[1], (0, 40), stats=True,
                xlabel=r"Polarizability [au]", ylabel=r"DASH prediction [au]", vmin=0.2, vmax=1e4, err_range=(-1,1), text="B")
# Each dash_corr_plot call added a colorbar axes after the two panel axes; place them all
# on the same narrower rectangle.
cbar_axes = fig.axes[2:]
for cbar_ax in cbar_axes:
    cbar_ax.set_position([0.91, 0.15, 0.025, 0.7])

fig.savefig(f"./test_185/test_185_C6+polar_corr.pdf", bbox_inches="tight")
fig.savefig(f"./test_185/test_185_C6+polar_corr.svg", bbox_inches="tight")

In [None]:
# Table used for the molecular-dipole panel. The disabled variant below would restrict it
# to rows where both dipole columns are < 5; currently the grouped table is used unfiltered.
#df_mol_cnf_grouped_plot = df_mol_cnf_grouped[(df_mol_cnf_grouped["mol_dipole_from_mbis_ref"] < 5) & (df_mol_cnf_grouped["mol_dipole_with_atomic_eA"] < 5)]
df_mol_cnf_grouped_plot = df_mol_cnf_grouped

In [None]:
# Display the table that feeds the molecular-dipole panel.
df_mol_cnf_grouped_plot

In [None]:
# Two-panel figure: molecule-level "polarization" (A) and molecular dipole (B) against
# their DASH predictions.
fig, ax = plt.subplots(1,2, figsize=(12,6))
dash_corr_plot(df_c6_mol, "polarization", "polarization_pred", fig, ax[0], (0, 320), stats=True,
                xlabel=r"Polarizability [au]", ylabel=r"DASH prediction [au]", vmin=0.2, vmax=100, err_range=(-10,10), text="A")
dash_corr_plot(df_mol_cnf_grouped_plot, "mol_dipole_with_atomic_eA", "mol_dipole_from_mbis_ref", fig, ax[1], (0, 5), stats=True, 
               xlabel="Molecular dipole [eA]", ylabel="DASH prediction [eA]", vmin=0.2, vmax=100, err_range=(-1,1), text="B")
# Disabled alternative for panel B using the "mol_dipole_no_atomic_eA" column as reference:
# dash_corr_plot(df_mol_cnf_grouped_plot, "mol_dipole_no_atomic_eA", "mol_dipole_from_mbis_ref", fig, ax[1], (0, 5), stats=True, 
#                xlabel="Molecular dipole [eA]", ylabel="DASH prediction [eA]", vmin=0.2, vmax=100, err_range=(-1,1), text="B")
# Each dash_corr_plot call added a colorbar axes after the two panel axes; place them all
# on the same narrower rectangle.
cbar_axes = fig.axes[2:]
for cbar_ax in cbar_axes:
    cbar_ax.set_position([0.91, 0.15, 0.025, 0.7])

fig.savefig(f"./test_185/test_185_mol_dipole+polarization_corr.pdf", bbox_inches="tight")
fig.savefig(f"./test_185/test_185_mol_dipole+polarization_corr.svg", bbox_inches="tight")

In [None]:
def dual_confusion_4(x, x_pred, threshold=0.1):
    """Return the normalised 4x4 confusion matrix of reference vs. predicted values.

    Both inputs are sorted into four bins separated by ``-threshold``, ``0`` and
    ``+threshold``:

        0: v < -threshold
        1: -threshold <= v < 0
        2: 0 <= v < threshold
        3: v >= threshold

    The bins are half-open, so a value lying exactly on a boundary is counted in the bin
    above it instead of being lost (the previous strict ``<`` / ``>`` comparisons on both
    sides excluded such pairs from every cell).

    Parameters
    ----------
    x : array-like of reference values; selects the row of the matrix.
    x_pred : array-like of predicted values, same length and order as ``x``; selects the column.
    threshold : positive bin boundary.

    Returns
    -------
    numpy.ndarray of shape (4, 4) whose entries are fractions of all counted pairs and sum
    to 1. Pairs in which either value is NaN are not counted. If no pair can be counted
    the all-zero matrix is returned instead of dividing by zero.
    """
    ref = np.asarray(x, dtype=float)
    pred = np.asarray(x_pred, dtype=float)

    # NaN never satisfied any comparison before, so NaN pairs stay excluded.
    valid = ~(np.isnan(ref) | np.isnan(pred))
    edges = np.array([-threshold, 0.0, threshold])
    rows = np.digitize(ref[valid], edges)
    cols = np.digitize(pred[valid], edges)

    counts = np.zeros((4, 4))
    # unbuffered add so that repeated (row, col) pairs are all accumulated
    np.add.at(counts, (rows, cols), 1)

    total = counts.sum()
    if total == 0:
        return counts
    return counts / total

In [None]:
# Preview the atom-level table.
df_atom.head()

In [None]:
# 4x4 confusion matrix of the "dual" column against "dual_pred", with the bin boundary
# set to 0.05 instead of the function default.
confusion_thresh = 0.05
confusion_4 = dual_confusion_4(df_atom["dual"], df_atom["dual_pred"], threshold=confusion_thresh)

In [None]:
fig, ax = plt.subplots(figsize=(6,6))
# scale confusion matrix logarithmically
# The matrix is transposed so that the reference bin runs along x and the predicted bin
# along y; origin="lower" puts the most negative bin in the bottom-left corner.
im = ax.imshow(confusion_4.T, cmap="Greens", norm=LogNorm(vmin=0.001, vmax=1), origin="lower")
# Bin labels in the order produced by dual_confusion_4: below -t, -t to 0, 0 to t, above t.
# The two negative bins need the minus sign, otherwise they read like the positive ones.
bin_labels = [f"<-{confusion_thresh}", f"-{confusion_thresh}-0", f"0-{confusion_thresh}", f">{confusion_thresh}"]
ax.set_xticks(np.arange(4))
ax.set_yticks(np.arange(4))
ax.set_xticklabels(bin_labels, fontsize=16)
ax.set_yticklabels(bin_labels, fontsize=16)
ax.set_xlabel("TPSSh Dual Descriptor", fontsize=16)
ax.set_ylabel("DASH Dual Descriptor", fontsize=16)
# Write the fraction into every cell (text x = column index, y = row index of the image).
for i in range(4):
    for j in range(4):
        ax.text(j, i, f"{confusion_4.T[i, j]:.2f}", ha="center", va="center", color="black", fontsize=16)
fig.savefig(f"./test_185/test_185_confusion_4.pdf", bbox_inches="tight", dpi=400)
fig.savefig(f"./test_185/test_185_confusion_4.svg", bbox_inches="tight")
fig.savefig(f"./test_185/test_185_confusion_4.png", bbox_inches="tight", dpi=400)

In [None]:
# Keep only atoms whose "dipole_bond_1" value and prediction both lie strictly inside (-1, 1).
df_atom_plot_dipole_bonds = df_atom[(df_atom["dipole_bond_1"] < 1) & (df_atom["dipole_bond_1_pred"] < 1) & (df_atom["dipole_bond_1"] > -1 ) & (df_atom["dipole_bond_1_pred"] > -1)]

In [None]:
# Two-panel figure: MBIS dipole strength (A, all atoms) and "dipole_bond_1" (B, atoms
# pre-filtered to the (-1, 1) range) against their DASH predictions.
fig, ax = plt.subplots(1,2, figsize=(14,6))
dash_corr_plot(df_atom, "mbis_dipole_strength", "mbis_dipole_strength_pred", fig, ax[0], (0, 1), stats=True,
                xlabel=r"MBIS dipole strength [eA]", ylabel=r"DASH prediction [eA]", vmin=0.2, vmax=1e4, err_range=(-0.01,0.01), text="A")
dash_corr_plot(df_atom_plot_dipole_bonds, "dipole_bond_1", "dipole_bond_1_pred", fig, ax[1], (-1, 1), stats=True,
                xlabel=r"Dipole bond 1 [eA]", ylabel=r"DASH prediction [eA]", vmin=0.2, vmax=1e4, err_range=(-0.01,0.01), text="B")
# Each dash_corr_plot call added a colorbar axes after the two panel axes; place them all
# on the same narrower rectangle.
cbar_axes = fig.axes[2:]
for cbar_ax in cbar_axes:
    cbar_ax.set_position([0.91, 0.15, 0.025, 0.7])
# Square panels, with extra horizontal room between them.
ax[0].set_aspect('equal')
ax[1].set_aspect('equal')
plt.subplots_adjust(wspace=0.3)
fig.savefig(f"./test_185/test_185_atomic_dipole_corr_mag_dir.pdf", bbox_inches="tight")
fig.savefig(f"./test_185/test_185_atomic_dipole_corr_mag_dir.svg", bbox_inches="tight")
plt.show()

In [None]:
# DASH tree built from the property files under the given relative directory.
# NOTE(review): preload=False presumably defers reading the tree data until it is needed -
# confirm against the DASHTree documentation.
tree = DASHTree("../../serenityff/charge/data/dashProps/", preload=False)

In [None]:
# Example used for the tree-explanation figure: entry 42755 of the SDF supplier and atom
# index 7 within that molecule (hand-picked values).
rmol = mol_sup_am1bcc[42755]
atom_idx = 7

In [None]:
# Match the example atom in the DASH tree. node_path[0] selects the table inside
# tree.data_storage, the remaining entries are row positions in that table along the path.
node_path, match_indices = tree.match_new_atom(mol=rmol, atom=atom_idx, return_atom_indices=True)
# Fetch every row on the path once, then read the individual columns from those rows.
path_rows = [tree.data_storage[node_path[0]].iloc[i] for i in node_path[1:]]
mbis = [row["result"] for row in path_rows]
mbis_std = [row["std"] for row in path_rows]
mulliken = [row["mulliken"] for row in path_rows]
resp2 = [row["resp2"] for row in path_rows]
am1bcc = [row["AM1BCC"] for row in path_rows]
am1bcc_std = [row["AM1BCC_std"] for row in path_rows]

In [None]:
# Inspect the raw Mulliken values along the path before gaps are filled.
print(mulliken)

In [None]:
# fill nans with last non-nan value
# Series.ffill() replaces fillna(method="ffill"), which is deprecated in pandas 2.x.
# A NaN at the very start of a list has no earlier value and therefore stays NaN.
mbis = pd.Series(mbis).ffill().to_numpy()
mulliken = pd.Series(mulliken).ffill().to_numpy()
resp2 = pd.Series(resp2).ffill().to_numpy()
am1bcc = pd.Series(am1bcc).ffill().to_numpy()

In [None]:
# Re-run the tree match for the example atom. Keyword arguments are used exactly as in
# the earlier lookup cell, so the call does not depend on the positional parameter order.
node_path, match_indices = tree.match_new_atom(mol=rmol, atom=atom_idx, return_atom_indices=True)

In [None]:
def draw_mol_with_highlights_in_order(
    mol,
    highlight_atoms=None,
    highlight_bonds=None,
    text_per_atom=None,
    plot_title: str = None,
    plot_size=(600, 400),
    useSVG=False,
):
    """Draw ``mol`` without hydrogens, highlighting atoms/bonds with fading opacity.

    The k-th highlighted atom gets opacity ``1 - k / (len(highlight_atoms) + 4)`` and the
    k-th highlighted bond the opacity of rank ``k + 1``, so earlier entries appear stronger.

    Parameters
    ----------
    mol : RDKit molecule. Its highlighted atoms receive an "atomNote" property
        (side effect on the caller's object, as before).
    highlight_atoms : atom indices of ``mol`` in the order they should fade.
    highlight_bonds : bond indices of ``mol`` in the order they should fade.
    text_per_atom : one label per highlighted atom; if fewer labels than atoms are
        given, the atom indices themselves are used as labels.
    plot_title : optional legend text drawn with the molecule.
    plot_size : (width, height) of the drawing in pixels.
    useSVG : return an ``IPython.display.SVG`` instead of a PIL image.

    Returns
    -------
    ``IPython.display.SVG`` if ``useSVG`` else ``PIL.Image.Image``.

    Raises
    ------
    ImportError : if ``useSVG`` is True and IPython cannot be imported.

    Notes
    -----
    The picture is drawn from a hydrogen-stripped copy whose atom and bond indices can
    differ from those of ``mol``. The given indices are therefore translated to the copy;
    highlights on atoms or bonds that were removed with the hydrogens are skipped.
    """
    # None defaults instead of shared mutable default lists.
    highlight_atoms = [] if highlight_atoms is None else list(highlight_atoms)
    highlight_bonds = [] if highlight_bonds is None else list(highlight_bonds)
    text_per_atom = [] if text_per_atom is None else list(text_per_atom)
    if len(text_per_atom) < len(highlight_atoms):
        text_per_atom = [str(idx) for idx in highlight_atoms]

    color = (0.9, 0.7, 0.25)
    fade_steps = len(highlight_atoms) + 4

    def _faded(rank):
        # Clamp at 0 so that more bonds than atoms no longer raises an IndexError.
        return (color[0], color[1], color[2], max(1 - rank / fade_steps, 0.0))

    # Labels are attached to the caller's molecule; the copies made below inherit them.
    for rank, atom_idx in enumerate(highlight_atoms):
        mol.GetAtomWithIdx(int(atom_idx)).SetProp("atomNote", f"{text_per_atom[rank]}")

    # Remember each atom's original index on a private copy, strip the hydrogens and
    # build the original-index -> drawn-index lookup from the surviving atoms.
    tagged = Chem.Mol(mol)
    for atom in tagged.GetAtoms():
        atom.SetIntProp("_origAtomIdx", atom.GetIdx())
    mol_pic = Chem.RemoveHs(tagged)
    drawn_idx = {atom.GetIntProp("_origAtomIdx"): atom.GetIdx() for atom in mol_pic.GetAtoms()}

    athighlights = defaultdict(list)
    bthighlights = defaultdict(list)
    arads = {}
    brads = {}
    for rank, atom_idx in enumerate(highlight_atoms):
        pic_atom = drawn_idx.get(int(atom_idx))
        if pic_atom is None:
            continue  # atom was removed together with the hydrogens
        athighlights[pic_atom].append(_faded(rank))
        arads[pic_atom] = 0.75
    for rank, bond_idx in enumerate(highlight_bonds):
        bond = mol.GetBondWithIdx(int(bond_idx))
        begin = drawn_idx.get(bond.GetBeginAtomIdx())
        end = drawn_idx.get(bond.GetEndAtomIdx())
        if begin is None or end is None:
            continue  # bond to a removed hydrogen
        pic_bond = mol_pic.GetBondBetweenAtoms(begin, end)
        if pic_bond is None:
            continue
        bthighlights[pic_bond.GetIdx()].append(_faded(rank + 1))
        brads[pic_bond.GetIdx()] = 100

    if useSVG:
        d2d = rdMolDraw2D.MolDraw2DSVG(plot_size[0], plot_size[1])
    else:
        d2d = rdMolDraw2D.MolDraw2DCairo(plot_size[0], plot_size[1])
    dopts = d2d.drawOptions()
    dopts.scaleHighlightBondWidth = False
    AllChem.Compute2DCoords(mol_pic)
    if plot_title is not None:
        dopts.legendFontSize = 30
    legend = "" if plot_title is None else plot_title
    d2d.DrawMoleculeWithHighlights(mol_pic, legend, dict(athighlights), dict(bthighlights), arads, brads)
    d2d.FinishDrawing()

    if useSVG:
        # Import here so that a missing IPython really produces the intended error;
        # testing the truthiness of an imported module can never fail.
        try:
            from IPython.display import SVG
        except ImportError as exc:
            raise ImportError("IPython is not available, cannot use SVG") from exc
        return SVG(data=d2d.GetDrawingText().replace("svg:", ""))
    return Image.open(io.BytesIO(d2d.GetDrawingText()))

In [None]:
# Left: the four charge types along the matched DASH tree path. Right: the example
# molecule with the matched atoms highlighted in matching order.
fig, ax = plt.subplots(1, 2, figsize=(16, 6), gridspec_kw={'width_ratios': [1, 1.5]})
# slightly staggered to avoid overlap
node_positions = np.arange(len(mbis))
x_axis_mulliken = node_positions - 0.15
x_axis_mbis = node_positions - 0.05
x_axis_resp2 = node_positions + 0.05
x_axis_am1bcc = node_positions + 0.15
# Each series carries its own label, so the legend cannot get out of step with the plotting
# order (the former positional label list swapped "Mulliken" and "MBIS").
# Error bars are drawn only where a standard deviation was read from the tree (MBIS and
# AM1-BCC); RESP2 and Mulliken were previously drawn with the AM1-BCC spread.
ax[0].errorbar(x_axis_mbis, mbis, yerr=mbis_std, fmt="o", color="#1f77b4", label="MBIS")
ax[0].errorbar(x_axis_resp2, resp2, fmt="o", color="#ff7f0e", label="RESP2")
ax[0].errorbar(x_axis_mulliken, mulliken, fmt="o", color="#2ca02c", label="Mulliken")
ax[0].errorbar(x_axis_am1bcc, am1bcc, yerr=am1bcc_std, fmt="o", color="#9467bd", label="AM1-BCC")
ax[0].set_xticks(range(len(mbis)))
ax[0].set_xlabel("DASH tree node")
ax[0].set_ylabel("Charge [e]")
ax[0].legend()

ax[1].axis('off')
ax[1].imshow(draw_mol_with_highlights_in_order(rmol, highlight_atoms=match_indices, text_per_atom=range(len(match_indices)), plot_size=(900,600), useSVG=False), resample=False, interpolation='bilinear')

fig.savefig("./test_185/test_185_dashTree_explain4Chg.pdf", bbox_inches="tight")
fig.savefig("./test_185/test_185_dashTree_explain4Chg.svg", bbox_inches="tight")