In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import bootstrap
from tqdm import tqdm

tqdm.pandas()

In [19]:
# Example of a test dataframe:
pd.read_pickle('../data/test_results/retrieval/rebuttal_MIST_test_formula_2024-08-13_15-07-19.pkl')

Unnamed: 0,identifier,sorted_candidate_smiles,test_hit_rate@1,test_hit_rate@5,test_hit_rate@20,test_mces@1
0,MassSpecGymID0000201,[O=C(NC1CC2(CC(O)C2)C1)OCc1ccccc1.O=C(NC1CC2(C...,0.0,0.0,1.0,18.0
1,MassSpecGymID0000202,[O=C(NC1CC2(CC(O)C2)C1)OCc1ccccc1.O=C(NC1CC2(C...,0.0,0.0,0.0,18.0
2,MassSpecGymID0000203,[O=C(NC1CC2(CC(O)C2)C1)OCc1ccccc1.O=C(NC1CC2(C...,0.0,0.0,1.0,18.0
3,MassSpecGymID0000204,[O=C(NC1CC2(CC(O)C2)C1)OCc1ccccc1.O=C(NC1CC2(C...,0.0,0.0,1.0,18.0
4,MassSpecGymID0000205,[O=C(NC1CC2(CC(O)C2)C1)OCc1ccccc1.O=C(NC1CC2(C...,0.0,0.0,1.0,18.0
...,...,...,...,...,...,...
17551,MassSpecGymID0414164,[CCCN1C(=O)C(=O)/C(=C(/O)c2cc(Cl)c(OC)cc2OC)C1...,0.0,0.0,0.0,15.5
17552,MassSpecGymID0414165,[COc1c2c(c(CNC(=O)COc3ccc4c(c3)OCO4)c3c1C(=O)N...,0.0,0.0,0.0,18.5
17553,MassSpecGymID0414166,[CCCC(=O)O[C@H]1CC[C@@]2(C)C(=CC[C@@H]3C2CC[C@...,0.0,0.0,0.0,15.0
17554,MassSpecGymID0414167,[CCCC(=O)O[C@H]1CC[C@@]2(C)C(=CC[C@@H]3C2CC[C@...,0.0,0.0,0.0,15.0


In [2]:
def evaluate(dir_results, task, seed=0):
    np.random.seed(seed)

    # Prepare
    if task == 'retrieval':
        metric_cols = ['test_hit_rate@1', 'test_hit_rate@5', 'test_hit_rate@20', 'test_mces@1']
    elif task == 'de_novo':
        metric_cols = [
            'test_top_1_accuracy', 'test_top_1_mces_dist', 'test_top_1_max_tanimoto_sim',
            'test_top_10_accuracy', 'test_top_10_mces_dist', 'test_top_10_max_tanimoto_sim'
        ]

    # Load all data into a single data frame
    df = []
    for path in dir_results.glob('*.pkl'):
        df_method = pd.read_pickle(path)
        df_method['method'] = path.stem
        df_method = df_method.rename(columns={'test_mces_at_1': 'test_mces@1'})  # compatibility
        df.append(df_method)
    df = pd.concat(df)

    # Preprocess data frame
    for col in [c for c in df.columns if ('hit_rate' in c or 'accuracy' in c)]:
        df[col] *= 100

    # Calculate means for all metrics into a single table
    df_mean = df.groupby('method')[metric_cols].mean().round(2)

    # Calculate confidence intervals for all metrics into a single table
    def get_ci(col_vals, confidence_level=0.999, n_resamples=20_000):
        res = bootstrap((col_vals,), np.mean, confidence_level=confidence_level, n_resamples=n_resamples)
        ci = res.confidence_interval
        return f'{ci.low:.2f}-{ci.high:.2f}'
    def get_ci_for_each_col(df_method):
        return df_method.apply(get_ci, axis=0)
    df_ci = df.groupby('method')[metric_cols].progress_apply(lambda df_method: get_ci_for_each_col(df_method))

    # Merge tables with means and confidence intervals
    for col in metric_cols:
        df_mean[col] = df_mean[col].astype(str) + ' (' + df_ci[col] + ')'
    return df_mean

In [3]:
dir_results = Path('../data/test_results/retrieval')
task = 'retrieval'

df = evaluate(dir_results, task)
df

100%|██████████| 9/9 [05:55<00:00, 39.47s/it]


Unnamed: 0_level_0,test_hit_rate@1,test_hit_rate@5,test_hit_rate@20,test_mces@1
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
rebuttal_MIST_test_formula_2024-08-13_15-07-19,9.57 (8.88-10.30),22.11 (21.13-23.24),41.12 (39.91-42.29),12.75 (12.58-12.92)
rebuttal_deepsets_test_formula_2024-08-15_16-45-06,4.42 (3.91-4.93),14.46 (13.60-15.39),30.76 (29.64-31.90),15.04 (14.89-15.19)
rebuttal_deepsets_test_mass_2024-08-14_22-51-05,1.47 (1.20-1.79),6.21 (5.63-6.84),19.23 (18.27-20.22),25.11 (24.84-25.38)
rebuttal_enhanced_MIST_test_mass_2024-08-13_01-18-44,14.64 (13.78-15.53),34.87 (33.70-36.06),59.15 (57.95-60.33),15.37 (15.13-15.62)
rebuttal_fingerprint_ffn_test_formula_2024-08-15_15-45-02,5.09 (4.57-5.62),14.69 (13.83-15.57),31.97 (30.80-33.13),14.94 (14.79-15.10)
rebuttal_fingerprint_ffn_test_mass_2024-08-15_15-39-32,2.54 (2.16-2.97),7.59 (6.93-8.27),20.0 (19.06-21.05),24.66 (24.37-24.95)
rebuttal_random_test_formula_2024-08-13_16-14-07,3.06 (2.67-3.51),11.35 (10.59-12.14),27.74 (26.62-28.94),13.87 (13.70-14.03)
rebuttal_random_test_formula_2024-08-13_17-08-09,3.06 (2.64-3.50),11.35 (10.58-12.13),27.74 (26.66-28.80),13.87 (13.70-14.03)
rebuttal_random_test_mass_2024-08-13_17-08-09,0.37 (0.24-0.54),2.01 (1.68-2.38),8.22 (7.57-8.93),30.81 (30.43-31.24)


In [14]:
df_paper = df.reset_index()
df_paper = df_paper[df_paper['method'].str.contains('formula')]
df_paper = df_paper.sort_values('test_hit_rate@1', ascending=True, key=lambda x: x.str.split(' ').str[0].astype(float))
print(df_paper.to_markdown(index=False))

| method                                                    | test_hit_rate@1   | test_hit_rate@5     | test_hit_rate@20    | test_mces@1         |
|:----------------------------------------------------------|:------------------|:--------------------|:--------------------|:--------------------|
| rebuttal_random_test_formula_2024-08-13_16-14-07          | 3.06 (2.67-3.51)  | 11.35 (10.59-12.14) | 27.74 (26.62-28.94) | 13.87 (13.70-14.03) |
| rebuttal_random_test_formula_2024-08-13_17-08-09          | 3.06 (2.64-3.50)  | 11.35 (10.58-12.13) | 27.74 (26.66-28.80) | 13.87 (13.70-14.03) |
| rebuttal_deepsets_test_formula_2024-08-15_16-45-06        | 4.42 (3.91-4.93)  | 14.46 (13.60-15.39) | 30.76 (29.64-31.90) | 15.04 (14.89-15.19) |
| rebuttal_fingerprint_ffn_test_formula_2024-08-15_15-45-02 | 5.09 (4.57-5.62)  | 14.69 (13.83-15.57) | 31.97 (30.80-33.13) | 14.94 (14.79-15.10) |
| rebuttal_MIST_test_formula_2024-08-13_15-07-19            | 9.57 (8.88-10.30) | 22.11 (21.13-23.24) | 41.12 (3

|                                                  | Hit rate @ 1 ↑    | Hit rate @ 5 ↑    | Hit rate @ 20 ↑   | MCES @ 1 ↓        |
|:-------------------------------------------------------|:--------------------:|:--------------------:|:--------------------:|:--------------------:|
| Random          | 0.37 (0.24-0.54)    | 2.01 (1.68-2.38)    | 8.22 (7.57-8.93)    | 30.81 (30.43-31.24) |
| DeepSets        | 1.47 (1.20-1.79)    | 6.21 (5.63-6.84)    | 19.23 (18.27-20.22) | 25.11 (24.84-25.38) |
| FingerprintFFN | 2.54 (2.16-2.97)    | 7.59 (6.93-8.27)    | 20.0 (19.06-21.05)  | 24.66 (24.37-24.95) |
| MIST   | **14.64** (13.78-15.53) | **34.87** (33.70-36.06) | **59.15** (57.95-60.33) | **15.37** (15.13-15.62) |
| *Bonus chemical formulae challenge*                                                    |    |      |     |          |
| Random          | 3.06 (2.64-3.50)  | 11.35 (10.58-12.13) | 27.74 (26.66-28.80) | 13.87 (13.70-14.03) |
| DeepSets        | 4.42 (3.91-4.93)  | 14.46 (13.60-15.39) | 30.76 (29.64-31.90) | 15.04 (14.89-15.19) |
| FingerprintFFN | 5.09 (4.57-5.62)  | 14.69 (13.83-15.57) | 31.97 (30.80-33.13) | 14.94 (14.79-15.10) |
| MIST            | **9.57** (8.88-10.30) | **22.11** (21.13-23.24) | **41.12** (39.91-42.29) | **12.75** (12.58-12.92) |

In [15]:
dir_results = Path('../data/test_results/de_novo')
task = 'de_novo'

df = evaluate(dir_results, task)
df

  a_hat = 1/6 * sum(nums) / sum(dens)**(3/2)
100%|██████████| 4/4 [04:13<00:00, 63.32s/it]


Unnamed: 0_level_0,test_top_1_accuracy,test_top_1_mces_dist,test_top_1_max_tanimoto_sim,test_top_10_accuracy,test_top_10_mces_dist,test_top_10_max_tanimoto_sim
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
random_baseline_formula,0.0 (nan-nan),21.11 (20.97-21.26),0.08 (0.08-0.08),0.0 (nan-nan),18.25 (18.14-18.35),0.11 (0.11-0.11)
random_baseline_no_formula,0.0 (nan-nan),28.59 (28.33-28.84),0.07 (0.07-0.07),0.0 (nan-nan),25.72 (25.48-25.96),0.1 (0.10-0.10)
rebuttal_selfies_transformer_test_2024-08-15_16-05-36,0.0 (nan-nan),33.28 (32.98-33.58),0.1 (0.10-0.10),0.0 (nan-nan),21.84 (21.67-22.00),0.15 (0.15-0.15)
rebuttal_smiles_transformer_test_2024-08-15_17-11-37,0.0 (nan-nan),53.8 (52.95-54.65),0.07 (0.07-0.08),0.0 (nan-nan),21.97 (21.78-22.16),0.17 (0.17-0.17)


In [17]:
df_paper = df.reset_index()
# df_paper = df_paper[~df_paper['method'].str.contains('formula')]
df_paper = df_paper.sort_values('test_top_1_mces_dist', ascending=True, key=lambda x: x.str.split(' ').str[0].astype(float))
print(df_paper.to_markdown(index=False))

| method                                                | test_top_1_accuracy   | test_top_1_mces_dist   | test_top_1_max_tanimoto_sim   | test_top_10_accuracy   | test_top_10_mces_dist   | test_top_10_max_tanimoto_sim   |
|:------------------------------------------------------|:----------------------|:-----------------------|:------------------------------|:-----------------------|:------------------------|:-------------------------------|
| random_baseline_formula                               | 0.0 (nan-nan)         | 21.11 (20.97-21.26)    | 0.08 (0.08-0.08)              | 0.0 (nan-nan)          | 18.25 (18.14-18.35)     | 0.11 (0.11-0.11)               |
| random_baseline_no_formula                            | 0.0 (nan-nan)         | 28.59 (28.33-28.84)    | 0.07 (0.07-0.07)              | 0.0 (nan-nan)          | 25.72 (25.48-25.96)     | 0.1 (0.10-0.10)                |
| rebuttal_selfies_transformer_test_2024-08-15_16-05-36 | 0.0 (nan-nan)         | 33.28 (32.98-33.58)    | 0

|                                                 | Top-1 Accuracy ↑   | Top-1 MCES ↓   | Top-1 Tanimoto ↑   | Top-10 Accuracy ↑   | Top-10 MCES ↓   | Top-10 Tanimoto ↑   |
|:------------------------------------------------------|:----------------------:|:-----------------------:|:------------------------------:|:-----------------------:|:------------------------:|:-------------------------------:|
| Random chemical generation                            | 0.0         | **28.59** (28.33-28.84)    | 0.07 (0.07-0.07)              | 0.0          | 25.72 (25.48-25.96)     | 0.1 (0.10-0.10)                |
| SMILES Transformer  | 0.0         | 53.8 (52.95-54.65)     | 0.07 (0.07-0.08)              | 0.0          | 21.97 (21.78-22.16)     | **0.17** (0.17-0.17)               |
| SELFIES Transformer | 0.0         | 33.28 (32.98-33.58)    | **0.1** (0.10-0.10)               | 0.0          | **21.84** (21.67-22.00)     | 0.15 (0.15-0.15)               |
| *Bonus chemical formulae challenge*
| Random chemical generation                               | 0.0         | **21.11** (20.97-21.26)    | **0.08** (0.08-0.08)              | 0.0          | **18.25** (18.14-18.35)     | **0.11** (0.11-0.11)               |