# Evaluate metrics for different methods

In [None]:
import os
import re

import pandas as pd
import numpy as np
from rdkit import Chem
import pickle
import matplotlib.pyplot as plt
import seaborn as sns

## Load data

Set params

In [None]:
# Name of the reference dataset; it must be one of the keys of `df_path_dict`.
ref = 'test'

# Metric CSV of each dataset, keyed by dataset name.
df_path_dict = dict(
    test='../data/geom_drug/metrics/test.csv',
)

# Sampling-output directory of each generative method, keyed by method name.
generated_dict = dict(
    MolDiff='../outputs/sample_MolDiff_20230101_0000',
    MolDiff_simple='../outputs/sample_MolDiff_simple_20230101_0000',
)

Prepare

In [None]:
# Register each method's metric table: the `mols.csv` inside its sampling directory.
df_path_dict.update(
    {name: os.path.join(out_dir, 'mols.csv') for name, out_dir in generated_dict.items()}
)
df_path_dict

In [None]:
# Generated methods only, and every entry (datasets + methods) in dict order.
compare_list = list(generated_dict)
method_list = list(df_path_dict)

# Position of the reference dataset inside `method_list`.
idx_ref = method_list.index(ref)
# Summary table with one row per entry; metric columns are appended later.
df_all = pd.DataFrame(index=method_list)

print('Ref is', ref, 'idx', idx_ref)
print('methods:', method_list)
print('compare:', compare_list)

Load

In [None]:
# Read the metric table of every entry, using the first CSV column as index.
df_dict = {}
for key, path in df_path_dict.items():
    df_dict[key] = pd.read_csv(path, index_col=0)

# Read the local-3D pickle stored next to each table, when it exists.
local3d_dict = {}
for key, csv_path in df_path_dict.items():
    path = csv_path.replace('.csv', '_local3d.pkl')
    if not os.path.exists(path):
        print(f'No local3d for {key}')
        continue
    with open(path, 'rb') as f:
        local3d_dict[key] = pickle.load(f)

In [None]:
# Report the (rows, columns) shape of every loaded table.
print('dataframe shape:')
for key in df_dict:
    print(key, df_dict[key].shape)

Define functions

In [None]:
from scipy.spatial.distance import jensenshannon
def get_jsd(p, q):
    """Return the Jensen-Shannon distance between distributions `p` and `q`.

    Thin wrapper around `scipy.spatial.distance.jensenshannon`. Note that
    SciPy returns the Jensen-Shannon *distance*, i.e. the square root of the
    Jensen-Shannon divergence, not the divergence itself and not a KL
    divergence.
    """
    distance = jensenshannon(p, q)
    return distance

In [None]:
def compare_with_ref(value_list, width=None, num_bins=50, discrete=False, ref_idx=None):
    """Histogram every sample on a shared grid and score it against the reference.

    Args:
        value_list: one sequence of numeric values per entry; all of them are
            pooled to choose a common bin range.
        width: bin width for continuous data. When None, `num_bins` evenly
            spaced bin edges are used instead.
        num_bins: number of bin edges passed to `np.linspace` when `width` is
            None. Ignored otherwise.
        discrete: when True, use unit-width bins centred on consecutive
            values starting at the lower bound; `width` and `num_bins` are
            then ignored.
        ref_idx: position of the reference sample inside `value_list`.
            Defaults to the notebook-level `idx_ref`.

    Returns:
        Tuple `(jsd_list, bins, hist_list)`: the `get_jsd` score of each entry
        against the reference, the shared bin edges, and the normalised
        histogram of each entry.

    Raises:
        ValueError: if the pooled values contain nothing but NaN.
    """
    if ref_idx is None:
        ref_idx = idx_ref

    # set distribution ranges
    pooled = np.concatenate(value_list)
    pooled = np.sort(pooled[~np.isnan(pooled)])
    if pooled.size == 0:
        raise ValueError('compare_with_ref: no non-NaN values to build bins from')
    if pooled.size > 10:
        # Trim extremes: up to 5 smallest and 4 largest pooled values fall
        # outside [min_value, max_value]. The slight asymmetry is kept on
        # purpose so that previously computed scores stay reproducible.
        max_value = pooled[-5]
        min_value = pooled[5]
    else:
        # With 10 or fewer values the trimmed bounds would cross (or the
        # indexing would fail), so fall back to the full observed range.
        min_value, max_value = pooled[0], pooled[-1]

    if discrete:
        bins = np.arange(min_value, max_value + 1.5) - 0.5
    elif width is not None:
        bins = np.arange(min_value, max_value + width, width)
    else:
        bins = np.linspace(min_value, max_value, num_bins)

    # calculate distributions
    hist_list = []
    for values in value_list:
        hist, _ = np.histogram(values, bins=bins, density=True)
        # Epsilon keeps every bin strictly positive before re-normalising
        # the histogram so that it sums to 1.
        hist = hist + 1e-10
        hist_list.append(hist / hist.sum())

    # calculate jsd
    ref_hist = hist_list[ref_idx]
    jsd_list = [get_jsd(ref_hist, hist) for hist in hist_list]

    return (jsd_list, bins, hist_list)

In [None]:
from rdkit.Chem.Draw import IPythonConsole, MolsToGridImage
from rdkit import Chem
def show(x):
    """Print the SMILES of molecule `x`, render it with `drawMol3D`, and return `x`."""
    smiles = Chem.MolToSmiles(x)
    print(smiles)
    IPythonConsole.drawMol3D(x)
    return x

def show_mols(mols):
    """Return a grid image of `mols`, 8 per row.

    Every molecule is written to SMILES and parsed back before drawing, so
    the drawn copy is rebuilt from the SMILES string alone.
    """
    mols2d = []
    for mol in mols:
        smiles = Chem.MolToSmiles(mol)
        mols2d.append(Chem.MolFromSmiles(smiles))
    return Chem.Draw.MolsToGridImage(mols2d, molsPerRow=8, subImgSize=(250,200))

## Generation ability

validity, connectivity

In [None]:
# Per-method validity metrics, read from `<mols>_validity.pkl` when available.
metrics_list = ['validity', 'connectivity']
df_metrics = pd.DataFrame(index=compare_list, columns=metrics_list)
for method in compare_list:
    path = df_path_dict[method].replace('.csv', '_validity.pkl')
    if os.path.exists(path):
        with open(path, 'rb') as f:
            values = pickle.load(f)
        df_metrics.loc[method] = values
    else:
        # Leave the row empty for methods without a validity file.
        print(f'No validity file for {method}')

In [None]:
df_metrics

In [None]:
df_all = pd.concat([df_all, df_metrics], axis=1)

novelty, uniqueness, diversity, sim_with_val

In [None]:
# Per-method similarity metrics, read from `<mols>_similarity.pkl` when available.
metrics_list = ['novelty', 'uniqueness', 'diversity', 'sim_with_val']
df_metrics = pd.DataFrame(index=compare_list, columns=metrics_list)
for method in compare_list:
    path = df_path_dict[method].replace('.csv', '_similarity.pkl')
    # Same handling as for the validity files: a method without this file
    # keeps an empty row instead of aborting the whole cell.
    if not os.path.exists(path):
        print(f'No similarity file for {method}')
        continue
    with open(path, 'rb') as f:
        values = pickle.load(f)
    df_metrics.loc[method] = values

In [None]:
df_metrics

In [None]:
df_all = pd.concat([df_all, df_metrics], axis=1)

## Drug-likeness properties

qed sa logp lipinski

In [None]:
metrics_list

In [None]:
metrics_list = ['qed', 'sa', 'logp', 'lipinski']

# Column mean of each property, one row per generated method.
df_value = pd.DataFrame(index=compare_list, columns=metrics_list)
for method in compare_list:
    method_df = df_dict[method]
    for metric in metrics_list:
        df_value.loc[method, metric] = method_df[metric].mean()
print('value mean')
df_value

In [None]:
df_value.columns = ['mean_{}'.format(x) for x in df_value.columns]
df_all = pd.concat([df_all, df_value], axis=1)

## Bonds

Distributions of bonds

In [None]:
metric = ['cnt_bond1', 'cnt_bond2', 'cnt_bond3', 'cnt_bond4']
df_metrics = pd.DataFrame(index=method_list, columns=['dist_bond'])

# Distribution over the four count columns for each entry: column sums,
# plus an epsilon so that no bin is exactly zero, normalised to sum to 1.
hist_list = []
for method in method_list:
    count = df_dict[method][metric].values.sum(axis=0) + 1e-10
    hist_list.append(count / count.sum())
bins = np.arange(len(metric)+1) - 0.5

# Score every entry against the reference entry.
ref_hist = hist_list[idx_ref]
jsd_list = [get_jsd(hist_, ref_hist) for hist_ in hist_list]
df_metrics['dist_bond'] = jsd_list
print(['{}:{:.4f}'.format(key, value) for key, value in zip(method_list, jsd_list)])

# plot
plt.figure(figsize=(15, 5))

n_hist = len(hist_list)
bar_width = 1 / (n_hist+1)
for i, hist_ in enumerate(hist_list):
    plt.bar(bins[:-1]+bar_width*(i+1), hist_, label=method_list[i], width=bar_width)
# The bars of one group are centred on bins[:-1] + 0.5; label each group
# with the column it represents.
plt.xticks(bins[:-1] + 0.5, metric)
plt.legend()
plt.show()


In [None]:
df_metrics

In [None]:
df_all = pd.concat([df_all, df_metrics], axis=1)

Count of bonds/atoms, rings

In [None]:
metrics_list = ['n_rings', 'n_bonds_per_atom']

df_metrics = pd.DataFrame(index=method_list, columns=metrics_list)
for metric in metrics_list:
    if 'per' in metric:
        # Ratio metric 'n_<a>_per_<b>': column 'n_<a>' divided by column
        # 'n_<b>s', compared as a continuous quantity.
        width, discrete = 0.01, False
        name_parts = metric.split('_')
        m1 = 'n_' + name_parts[1]
        m2 = 'n_' + name_parts[-1] + 's'
        values_list = [frame[m1].values / frame[m2].values for frame in df_dict.values()]
    else:
        # Plain column, compared as a discrete quantity.
        width, discrete = 0.5, True
        values_list = [frame[metric].values for frame in df_dict.values()]

    # get jsd
    jsd_list, bins, hist_list = compare_with_ref(values_list, width=width,
                                                 discrete=discrete)
    df_metrics[metric] = jsd_list
    print(['{}:{:.4f}'.format(key, value) for key, value in zip(method_list, jsd_list)])
    print('num bins', len(bins), 'width', bins[1]-bins[0])

    # plot: grouped bars for discrete metrics, lines for continuous ones
    plt.figure(figsize=(10, 5))
    bins_center = (bins[:-1] + bins[1:]) / 2
    n_hist = len(hist_list)
    bar_width = 1 / (n_hist+1)
    for i, hist_ in enumerate(hist_list):
        if discrete:
            plt.bar(bins[:-1]+bar_width*(i+1), hist_, label=method_list[i], width=bar_width)
        else:
            plt.plot(bins_center, hist_, label=method_list[i])
    plt.legend()
    plt.title(metric)
    plt.show()


In [None]:
df_metrics

In [None]:
df_all = pd.concat([df_all, df_metrics], axis=1)

## 3D structure

global rmsd

In [None]:
metrics_list = ['rmsd_min']

# Column mean of each listed metric, one row per entry (reference included).
df_values = pd.DataFrame(index=method_list, columns=metrics_list)
for method in method_list:
    method_df = df_dict[method]
    for metric in metrics_list:
        df_values.loc[method, metric] = method_df[metric].mean()
print('value mean')
df_values

In [None]:
df_values.columns = ['mean_{}'.format(x) for x in df_values.columns]
df_all = pd.concat([df_all, df_values], axis=1)

local 3D: bond lengths

In [None]:
# bond lengths
metric_base = 'lengths'
# The sub-keys present for the reference entry define what is compared.
metrics_list = list(local3d_dict[ref][metric_base].keys())
print(metric_base, ':', metrics_list, '\n')

df_metrics = pd.DataFrame(index=method_list, columns=metrics_list)
# Same binning for every sub-key: continuous values with bins of width 0.02.
width, discrete = 0.02, False
for metric in metrics_list:
    # get jsd
    values_list = [local3d_dict[name][metric_base][metric] for name in method_list]
    jsd_list, bins, hist_list = compare_with_ref(values_list, width=width,
                                                 discrete=discrete)
    df_metrics[metric] = jsd_list
    print('num bins', len(bins), 'width', bins[1]-bins[0])
    print('num values', [len(values) for values in values_list])
    print(['{}:{:.4f}'.format(key, value) for key, value in zip(method_list, jsd_list)])

    # plot
    plt.figure(figsize=(10, 5))
    bins_center = (bins[:-1] + bins[1:]) / 2
    for i, hist_ in enumerate(hist_list):
        if discrete:
            plt.bar(bins_center, hist_, label=method_list[i], )
        else:
            plt.plot(bins_center, hist_, label=method_list[i])
    plt.legend()
    plt.title(metric)
    plt.show()


In [None]:
df_metrics

In [None]:
df_all = pd.concat([df_all, df_metrics], axis=1)

In [None]:
df_mean = df_metrics.mean(1).to_frame(name='length_jsd_mean')
df_mean

In [None]:
df_all = pd.concat([df_all, df_mean], axis=1)

BTW: JS. of frequent bond types

In [None]:
metric_base = 'lengths'

# get jsd: compare how many values each entry has per sub-key
# (`metrics_list` is the one left behind by the bond-length cell).
values_list = []
for name in method_list:
    sizes = [len(local3d_dict[name][metric_base][metric]) + 1e-10 for metric in metrics_list]
    values_list.append(np.array(sizes))
hist_list = [val/np.sum(val) for val in values_list]
ref_hist = hist_list[idx_ref]
jsd_list = [get_jsd(ref_hist, hist) for hist in hist_list]

print(['{}:{:.4f}'.format(key, value) for key, value in zip(method_list, jsd_list)])

# plot: one group of bars per sub-key, one bar per entry
bins = np.arange(len(metrics_list)+1)
plt.figure(figsize=(10, 5))
n_hist = len(hist_list)
bar_width = 1 / (n_hist+1)
for i, hist_ in enumerate(hist_list):
    plt.bar(bins[:-1]+bar_width*(i+1), hist_, label=method_list[i], width=bar_width)
plt.legend()
plt.title(metric_base)
plt.show()
df_metric = pd.DataFrame(jsd_list, index=method_list, columns=['js_bond_type'])
df_metric

In [None]:
df_all = pd.concat([df_all, df_metric], axis=1)

local 3D: bond angles

In [None]:
metric_base = 'angles'
# The sub-keys present for the reference entry define what is compared.
metrics_list = list(local3d_dict[ref][metric_base].keys())
print(metric_base, ':', metrics_list, '\n')

df_metrics = pd.DataFrame(index=method_list, columns=metrics_list)
# Same binning for every sub-key: continuous values with bins of width 5.
width, discrete = 5, False
for metric in metrics_list:
    # get jsd
    values_list = [local3d_dict[name][metric_base][metric] for name in method_list]
    jsd_list, bins, hist_list = compare_with_ref(values_list, width=width,
                                                 discrete=discrete)
    df_metrics[metric] = jsd_list
    print('num bins', len(bins), 'width', bins[1]-bins[0])
    print('num values', [len(values) for values in values_list])
    print(['{}:{:.4f}'.format(key, value) for key, value in zip(method_list, jsd_list)])

    # plot
    plt.figure(figsize=(10, 5))
    bins_center = (bins[:-1] + bins[1:]) / 2
    for i, hist_ in enumerate(hist_list):
        if discrete:
            plt.bar(bins_center, hist_, label=method_list[i], )
        else:
            plt.plot(bins_center, hist_, label=method_list[i])
    plt.legend()
    plt.title(metric)
    plt.show()


In [None]:
df_metrics

In [None]:
df_all = pd.concat([df_all, df_metrics], axis=1)

In [None]:
df_mean = df_metrics.mean(1).to_frame(name='angle_jsd_mean')
df_mean

In [None]:
df_all = pd.concat([df_all, df_mean], axis=1)

BTW: JS. of frequent bond pairs

In [None]:
metric_base = 'angles'

# get jsd: compare how many values each entry has per sub-key
# (`metrics_list` is the one left behind by the bond-angle cell).
values_list = []
for name in method_list:
    sizes = [len(local3d_dict[name][metric_base][metric]) + 1e-10 for metric in metrics_list]
    values_list.append(np.array(sizes))
hist_list = [val/np.sum(val) for val in values_list]
ref_hist = hist_list[idx_ref]
jsd_list = [get_jsd(ref_hist, hist) for hist in hist_list]

print(['{}:{:.4f}'.format(key, value) for key, value in zip(method_list, jsd_list)])

# plot: one group of bars per sub-key, one bar per entry
bins = np.arange(len(metrics_list)+1)
plt.figure(figsize=(10, 5))
n_hist = len(hist_list)
bar_width = 1 / (n_hist+1)
for i, hist_ in enumerate(hist_list):
    plt.bar(bins[:-1]+bar_width*(i+1), hist_, label=method_list[i], width=bar_width)
plt.legend()
plt.title(metric_base)
plt.show()
df_metric = pd.DataFrame(jsd_list, index=method_list, columns=['JS_bond_pair'])
df_metric

In [None]:
df_all = pd.concat([df_all, df_metric], axis=1)

local 3D: dihedral angles

In [None]:
# dihedral angles
metric_base = 'dihedral'
# The sub-keys present for the reference entry define what is compared.
metrics_list = list(local3d_dict[ref][metric_base].keys())
print(metric_base, ':', metrics_list, '\n')

df_metrics = pd.DataFrame(index=method_list, columns=metrics_list)
# Same binning for every sub-key: continuous values with bins of width 5.
width, discrete = 5, False
for metric in metrics_list:
    # get jsd
    values_list = [local3d_dict[name][metric_base][metric] for name in method_list]
    jsd_list, bins, hist_list = compare_with_ref(values_list, width=width,
                                                 discrete=discrete)
    df_metrics[metric] = jsd_list
    print('num bins', len(bins), 'width', bins[1]-bins[0])
    print('num values', [len(values) for values in values_list])
    print(['{}:{:.4f}'.format(key, value) for key, value in zip(method_list, jsd_list)])

    # plot
    plt.figure(figsize=(10, 5))
    bins_center = (bins[:-1] + bins[1:]) / 2
    for i, hist_ in enumerate(hist_list):
        if discrete:
            plt.bar(bins_center, hist_, label=method_list[i], )
        else:
            plt.plot(bins_center, hist_, label=method_list[i])
    plt.legend()
    plt.title(metric)
    plt.show()


In [None]:
df_metrics

In [None]:
df_all = pd.concat([df_all, df_metrics], axis=1)

In [None]:
df_mean = df_metrics.mean(1).to_frame(name='dihedral_jsd_mean')
df_mean

In [None]:
df_all = pd.concat([df_all, df_mean], axis=1)

BTW: JS. of frequent bond triplets

In [None]:
metric_base = 'dihedral'

# get jsd: compare how many values each entry has per sub-key
# (`metrics_list` is the one left behind by the dihedral-angle cell).
values_list = []
for name in method_list:
    sizes = [len(local3d_dict[name][metric_base][metric]) + 1e-10 for metric in metrics_list]
    values_list.append(np.array(sizes))
hist_list = [val/np.sum(val) for val in values_list]
ref_hist = hist_list[idx_ref]
jsd_list = [get_jsd(ref_hist, hist) for hist in hist_list]

print(['{}:{:.4f}'.format(key, value) for key, value in zip(method_list, jsd_list)])

# plot: one group of bars per sub-key, one bar per entry
bins = np.arange(len(metrics_list)+1)
plt.figure(figsize=(10, 5))
n_hist = len(hist_list)
bar_width = 1 / (n_hist+1)
for i, hist_ in enumerate(hist_list):
    plt.bar(bins[:-1]+bar_width*(i+1), hist_, label=method_list[i], width=bar_width)
plt.legend()
plt.title(metric_base)
plt.show()
df_metric = pd.DataFrame(jsd_list, index=method_list, columns=['JS_bond_triplet'])
df_metric

In [None]:
df_all = pd.concat([df_all, df_metric], axis=1)

## Rings

counts of n-sized rings

In [None]:
# One column per ring size: cnt_ring3 ... cnt_ring9.
metrics_list = [f'cnt_ring{i}' for i in range(3, 10)]
df_metrics = pd.DataFrame(index=method_list, columns=metrics_list)
# Counts are compared as discrete values; `width` is passed along but
# compare_with_ref does not use it when discrete=True.
width, discrete = 0.01, True
for metric in metrics_list:
    # get jsd
    values_list = [frame[metric].values for frame in df_dict.values()]
    jsd_list, bins, hist_list = compare_with_ref(values_list, width=width,
                                                 discrete=discrete)
    df_metrics[metric] = jsd_list
    print(['{}:{:.4f}'.format(key, value) for key, value in zip(method_list, jsd_list)])
    print('num bins', len(bins), 'width', bins[1]-bins[0])

    # plot: grouped bars for discrete metrics, lines for continuous ones
    plt.figure(figsize=(10, 5))
    bins_center = (bins[:-1] + bins[1:]) / 2
    n_hist = len(hist_list)
    bar_width = 1 / (n_hist+1)
    for i, hist_ in enumerate(hist_list):
        if discrete:
            plt.bar(bins[:-1]+bar_width*(i+1), hist_, label=method_list[i], width=bar_width)
        else:
            plt.plot(bins_center, hist_, label=method_list[i])
    plt.legend()
    plt.title(metric)
    plt.show()


In [None]:
df_metrics

In [None]:
df_all = pd.concat([df_all, df_metrics], axis=1)

In [None]:
df_mean =  df_metrics.mean(1).to_frame('cnt_ringn_mean')
df_mean

In [None]:
df_all = pd.concat([df_all, df_mean], axis=1)

Freq ring

In [None]:
# Load the `<mols>_freq_ring_type.pkl` of every entry, in `method_list` order.
freq_list = []
for method in method_list:
    path = df_path_dict[method].replace('.csv', '_freq_ring_type.pkl')
    with open(path, 'rb') as f:
        freq_list.append(pickle.load(f))

In [None]:
len(freq_list)

In [None]:
# intersection with ref: how many of each entry's frequent ring types
# (original note: the top 10) also appear among the reference's
ref_rings = freq_list[idx_ref]['freq_rings']
df_inter = pd.DataFrame(index=method_list, columns=['intersect_ring_types'])
for method, rings in zip(method_list, freq_list):
    inter = len(np.intersect1d(rings['freq_rings'], ref_rings))
    df_inter.loc[method, 'intersect_ring_types'] = inter

In [None]:
df_inter

In [None]:
df_all = pd.concat([df_all, df_inter], axis=1)

In [None]:
# Flatten the frequent rings of all entries into one list for a single grid.
# The first ring of each entry is labelled '<entry>_<count>', the others with
# the count only. Legends are built per entry, so they stay correct even when
# an entry does not have exactly 10 rings (the 10-per-row layout then no
# longer lines up one entry per row).
ring_list_flat = []
count_list_flat = []
legends = []
for method, freq_ring in zip(method_list, freq_list):
    pairs = zip(freq_ring['freq_rings'], freq_ring['counts'])
    for j, (ring, count) in enumerate(pairs):
        ring_list_flat.append(ring)
        count_list_flat.append(count)
        legends.append(f'{method}_{count}' if j == 0 else str(count))
Chem.Draw.MolsToGridImage([Chem.MolFromSmiles(ring) for ring in ring_list_flat],
                          molsPerRow=10, subImgSize=(250,200), legends=legends, maxMols=200)


## Save results

In [None]:
df_all

In [None]:
# Write the summary table (one row per entry, one column per metric).
# NOTE(review): 'metics' in the file name looks like a typo for 'metrics';
# it is kept unchanged so anything reading the existing path keeps working.
df_all.to_csv('../outputs/metics_all_methods.csv')