In [1]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [1]:
import pickle
import os
from smart_open import open
import smart_open
import numpy as np
import random
import torch
import itertools
from tqdm import tqdm
from coatiLDM.common.utils import utc_epoch_now, batch_iterable, tanimoto_distance_torch
import rdkit
from rdkit import Chem
from rdkit.Chem.Crippen import MolLogP
from rdkit.Chem import Draw
from coatiLDM.common.utils import mol_to_morgan
from coatiLDM.common.s3 import load_figure_file


In [3]:
base_path = './figure_data/gen_comparison_data/'

## Get Walters

In [5]:
walters_data = load_figure_file('walters_hcaii.pkl', base_path,filetype='pkl')
walters_stacked = torch.from_numpy(np.stack([r['morgan'] for r in walters_data])).to(dtype=torch.float)

# Eval Methods

In [5]:
import torch

from collections import defaultdict


from coatiLDM.common.utils import batch_iterable

def _eval_on_recs(_eval_recs, _test_stacked, cutoffs=None, prefix='', stack_field='morgan', batch_size=100_000):
    if cutoffs is None:
        cutoffs = (np.arange(0, 9)+1)/10
    
    # test_stacked = torch.from_numpy(test_stacked)

    cutoff_fields = {ct: f'{prefix}neighbors_>={ct}' for ct in cutoffs}
    test_coverage = defaultdict(set)
    
    for _batch in tqdm(batch_iterable(_eval_recs, batch_size), total=-(len(_eval_recs) // -batch_size)):
        mask = torch.tensor([True if (stack_field in r) and (r[stack_field] is not None) else False for r in _batch])
        mask_indices = np.where(mask)[0]
        stacked_batch = torch.stack([torch.from_numpy(r[stack_field]) for r, msk in zip(_batch, mask)]).to(torch.float)
        print(stacked_batch.sum(axis=1))
        dists = tanimoto_distance_torch(stacked_batch, _test_stacked)
        
        sims = 1 - dists
    
        assert sims.shape[0] == len(stacked_batch)
        assert sims.shape[1] == len(_test_stacked)
        
        
        for cutoff in cutoffs:
            _matches = (sims>=cutoff)
            _neighbors = _matches.sum(axis=1)
            
            _match_idx_map = {m: np.where(_matches[m, :])[0] for m in range(len(stacked_batch))}
            
            
            for n_idx in torch.arange(len(_test_stacked))[
                ((sims>=cutoff).sum(axis=0)>0).cpu()
            ].detach().tolist():
                assert isinstance(n_idx, int)
                test_coverage[cutoff].add(n_idx)
            
            neighbors = torch.zeros(len(mask) ,dtype=_neighbors.dtype)
            neighbors[mask]=_neighbors
            
            # Extract the original indices that were True in the mask


            # Map these original indices to their corresponding values
            match_idx_map = {original_index: _match_idx_map[i] for i, original_index in enumerate(mask_indices)}

            for _idx, r in enumerate(_batch):
                r[cutoff_fields[cutoff]] = neighbors[_idx].cpu().tolist()
                if _idx in match_idx_map:
                    r[cutoff_fields[cutoff]+'_indices'] = match_idx_map[_idx]
    

    test_coverage_percents = {
        k: len(test_coverage[k]) / len(_test_stacked) if k in test_coverage else 0
        for k in cutoffs
    }
    
        
            
    return cutoff_fields, test_coverage_percents

### Copy underlying data to one construct

In [20]:
all_model_data = load_figure_file('all_model_data.pkl', base_path,filetype='pkl')


In [8]:
test_set = all_model_data['test']

In [9]:
results_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))

## Find coverage

This will take a bit...

In [10]:
import pandas as pd
from collections import defaultdict

result_recs = []
desired_models = list(all_model_data)
desired_models = [model_name for model_name in all_model_data if '__' not in model_name]
print(desired_models)


all_test_data = {
    'test': test_set,
    'walters': walters_data,
}
stack_field = 'morgan'
neighbor_vals = [1]

existing_models = defaultdict(set)
for t in results_dict:
    for c in results_dict[t]:
        for m in results_dict[t][c]:
            existing_models[m].add(t)

for test_name in all_test_data:
    test_prefix = test_name + '_'
    _test_data = all_test_data[test_name]
    _stacked_test = torch.from_numpy(np.stack([r[stack_field] for r in _test_data])).to(dtype=torch.float)
    for model_name in desired_models:
        if model_name in existing_models and test_name in existing_models[model_name]:
            print(f'Skipping {test_name} {model_name}')
#             continue
       
        model_data = all_model_data[model_name]
        
        if '__' in model_name:
            _model_name, rep = model_name.split('__')
        else:
            _model_name, rep = model_name, 0
        
        


        for r in model_data:
            if 'morgan' not in r:
                r['morgan'] = mol_to_morgan(r['smiles'])

        cutoff_fields, test_coverage = _eval_on_recs(model_data, _stacked_test, prefix=test_prefix)
        
        for cutoff in cutoff_fields:
            cutoff_field = cutoff_fields[cutoff]

            values = np.array([r[cutoff_field] for r in model_data])
            for n_neighbor in neighbor_vals:
                
                
                result_info = {
                    'test_name': test_name,
                    'model_name': _model_name,
                    'repetition': rep,
                    'tanimoto_cutoff': cutoff,
                    'gt_n_neighbors': n_neighbor,
                    'gt_n_neighbors_proportion': len(values[values>=n_neighbor]) / len(values),
                    'model_data_size': len(model_data),
                    'test_data_size': len(_test_data),
                    'test_coverage': test_coverage[cutoff],
                }
                
                results_dict[test_name][cutoff][_model_name][rep] = result_info


['CG-20-50k', 'QED-and-CG-20', 'CFG-50k', 'Genetic A', 'Genetic B', 'Genetic D', 'test', 'train', 'Genetic C']


  0%|          | 0/1 [00:00<?, ?it/s]

tensor([55., 76., 55.,  ..., 72., 21., 46.])


  A = torch.tensor(A, dtype=torch.float32).to(B.device)
100%|██████████| 1/1 [01:27<00:00, 87.55s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

tensor([43., 42., 56.,  ..., 64., 50., 46.])


100%|██████████| 1/1 [01:44<00:00, 104.72s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

tensor([92., 68., 78.,  ..., 62., 65., 47.])


100%|██████████| 1/1 [01:38<00:00, 98.54s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

tensor([ 75.,  81.,  83.,  ..., 102.,  81., 105.])


100%|██████████| 1/1 [01:41<00:00, 101.68s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

tensor([78., 57., 85.,  ..., 81., 75., 79.])


100%|██████████| 1/1 [01:36<00:00, 96.82s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

tensor([47., 48., 46.,  ..., 46., 49., 46.])


100%|██████████| 1/1 [01:33<00:00, 93.31s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

tensor([ 61., 102.,  60.,  ...,  87.,  59.,  73.])


100%|██████████| 1/1 [00:50<00:00, 50.22s/it]
  0%|          | 0/5 [00:00<?, ?it/s]

tensor([57., 86., 70.,  ..., 68., 94., 96.])


 20%|██        | 1/5 [03:17<13:08, 197.21s/it]

tensor([86., 55., 64.,  ..., 68., 66., 51.])


 40%|████      | 2/5 [06:42<10:06, 202.06s/it]

tensor([66., 70., 74.,  ..., 48., 83., 75.])


 60%|██████    | 3/5 [10:05<06:44, 202.26s/it]

tensor([52., 77., 69.,  ..., 57., 65., 57.])


 80%|████████  | 4/5 [13:33<03:24, 204.60s/it]

tensor([72., 64., 72.,  ..., 67., 67., 75.])


100%|██████████| 5/5 [16:13<00:00, 194.69s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

tensor([53., 74., 57.,  ..., 85., 89., 73.])


100%|██████████| 1/1 [01:38<00:00, 98.96s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

tensor([55., 76., 55.,  ..., 72., 21., 46.])


100%|██████████| 1/1 [00:27<00:00, 27.49s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

tensor([43., 42., 56.,  ..., 64., 50., 46.])


100%|██████████| 1/1 [00:31<00:00, 31.72s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

tensor([92., 68., 78.,  ..., 62., 65., 47.])


100%|██████████| 1/1 [00:29<00:00, 29.94s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

tensor([ 75.,  81.,  83.,  ..., 102.,  81., 105.])


100%|██████████| 1/1 [00:31<00:00, 31.30s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

tensor([78., 57., 85.,  ..., 81., 75., 79.])


100%|██████████| 1/1 [00:31<00:00, 31.16s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

tensor([47., 48., 46.,  ..., 46., 49., 46.])


100%|██████████| 1/1 [00:31<00:00, 31.33s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

tensor([ 61., 102.,  60.,  ...,  87.,  59.,  73.])


100%|██████████| 1/1 [00:15<00:00, 15.80s/it]
  0%|          | 0/5 [00:00<?, ?it/s]

tensor([57., 86., 70.,  ..., 68., 94., 96.])


 20%|██        | 1/5 [01:03<04:12, 63.09s/it]

tensor([86., 55., 64.,  ..., 68., 66., 51.])


 40%|████      | 2/5 [02:06<03:09, 63.21s/it]

tensor([66., 70., 74.,  ..., 48., 83., 75.])


 60%|██████    | 3/5 [03:09<02:06, 63.38s/it]

tensor([52., 77., 69.,  ..., 57., 65., 57.])


 80%|████████  | 4/5 [04:12<01:03, 63.07s/it]

tensor([72., 64., 72.,  ..., 67., 67., 75.])


100%|██████████| 5/5 [05:00<00:00, 60.18s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

tensor([53., 74., 57.,  ..., 85., 89., 73.])


100%|██████████| 1/1 [00:31<00:00, 31.27s/it]


In [14]:
import pandas as pd

result_recs = [
    results_dict[t][c][m][r]
    for t in results_dict
    for c in results_dict[t]
    for m in results_dict[t][c]
    for r in results_dict[t][c][m]
]

_res_df = pd.DataFrame(result_recs)
_res_df

Unnamed: 0,test_name,model_name,repetition,tanimoto_cutoff,gt_n_neighbors,gt_n_neighbors_proportion,model_data_size,test_data_size,test_coverage
0,test,CG-20-50k,0,0.1,1,0.999908,43536,25007,1.000000
1,test,QED-and-CG-20,0,0.1,1,0.999980,50400,25007,1.000000
2,test,CFG-50k,0,0.1,1,0.999937,47758,25007,1.000000
3,test,Genetic A,0,0.1,1,1.000000,50000,25007,1.000000
4,test,Genetic B,0,0.1,1,1.000000,50000,25007,1.000000
...,...,...,...,...,...,...,...,...,...
157,walters,Genetic B,0,0.9,1,0.000000,50000,544,0.000000
158,walters,Genetic D,0,0.9,1,0.000000,50000,544,0.000000
159,walters,test,0,0.9,1,0.000000,25007,544,0.000000
160,walters,train,0,0.9,1,0.000000,476810,544,0.000000


## Results Table

In [15]:
import pandas as pd

# Set new Pandas display settings to avoid truncation
pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', 100)
pd.set_option('display.max_rows', 100)

In [16]:
result_data = pd.DataFrame(result_recs)



do_10k_instead_of_50k = False

if do_10k_instead_of_50k:
    print(set(result_data['model_name']))
    model_renames = {m: m.split('_10k')[0] for m in set(result_data['model_name']) if m.endswith('_10k')}
    print(model_renames)
    result_data = result_data[
        result_data['model_name'].isin(set(model_renames))
    ]
    result_data['model_name'] = result_data['model_name'].map(model_renames)

    print(set(result_data['model_name']))




tc_col = 'Tanimoto Cutoff'
model_col = 'Model'
partition_col = 'Partition'
neigh_col = 'Neighbors'
cov_col = 'test_coverage'

result_data = result_data.rename(columns={
    'tanimoto_cutoff': tc_col,
    'model_name': model_col,
    'test_name': partition_col,
    'gt_n_neighbors': neigh_col,
})

# cutoffs = [0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
cutoffs = [0.5, 0.6, 0.7]
# Filter data based on valid tuples
filtered_data = result_data[
    (
    ((result_data[partition_col] == 'walters') & (result_data[tc_col].isin(cutoffs)))
    | ((result_data[partition_col] == 'test') & (result_data[tc_col].isin(cutoffs)))
    )
    & (result_data[neigh_col] == 1)
    & (~result_data[model_col].isin(['Genetic B','test', 'CG-PG', 'CG-20-QED', 'test_10k',
                                     'CG-PG_10k', 'Genetic B_10k',
                                     'dflow_high_dec_high_10k','dflow_higher_dec_low_10k','dflow_low_dec_10k','dflow_low_dec_low_10k','single_opt_high_10k','single_opt_high']))
#     & (result_data[model_col].isin(['CFG 250m', 'CG 250m',  'Genetic 1', 'Genetic 4', 'Genetic 5', 'Genetic 7']))
].sort_values(model_col)

filtered_data[partition_col] = filtered_data[partition_col].map({
    'test': 'Binding Hits',
    'walters': 'Public Hits'
})
model_renames = {
        'train': 'hCAII Train',
        'CFG-50k': 'CFG',
        'CG-20-50k': 'CG',
        'QED-and-CG-20': 'CG w/ Joint QED',
        'Genetic A': 'Genetic',
        'Genetic C': 'Genetic w/ QED Cutoff',
        'Genetic D': 'Genetic w/ Joint QED',
    }
filtered_data[model_col] = filtered_data[model_col].map(
    { k: (model_renames[k] if k in model_renames else k) for k in set(filtered_data[model_col])
        
    }
)

filtered_data[tc_col] = filtered_data[tc_col].apply(lambda x: str(x))
filtered_data[cov_col] = filtered_data[cov_col].apply(lambda x: round(100*x, 1))


percision = 1

def mean_std(x):
    mean = x.mean()
    std = x.std()
    if mean == 0 and std == 0:
        return '0'
    elif np.isnan(std):
        return f'{mean:.1f}'
    return f'{mean:.1f} $\\pm$ {std:.1f}'

table = filtered_data.pivot_table(
    index=[partition_col, tc_col],
    columns=[model_col],
    values="test_coverage",
    aggfunc=mean_std  # Use the custom aggregation function
)

def apply_latex_formatting(df):
    for row in df.index:
        
        
        best_val = max([float(df.at[row, col].split(' $\\pm$ ')[0]) for col in df.columns if col!='hCAII Train'])

        for col in df.columns:
            if col in [partition_col, tc_col]:
                continue
            value = df.at[row, col]
            _val = float(value.split(" $\\pm$ ")[0])

            if "\\pm" in value:  # Check if the cell contains the mean ± std format
                mean, std = value.split(" $\\pm$ ")
                mean = float(mean)
                # Update the formatting logic as per your requirements
                if mean == 0:
                    df.at[row, col] = '0'
                else:
                    df.at[row, col] = f'{mean:.1f} $\\pm$ {std}'  # Include std in the formatting
                    if _val == best_val:
                        df.at[row, col] = '\\textbf{{' + df.at[row, col] + '}}'

                    
            else:

                if float(value)==0:
                    df.at[row, col] = '0'
                elif float(value) == best_val:
                    df.at[row, col] = '\\textbf{{' + df.at[row, col] + '}}'
                else:
                    df.at[row, col] = f'{value}'
    return df

ordered_cols = list(
    c
    for c in list(model_renames.values()) + [c for c in sorted(list(table.columns)) if c not in model_renames.values()]
    if c in table
)
print(ordered_cols)
table = table[ordered_cols]


# Apply the updated formatting function to the table DataFrame
formatted_table = apply_latex_formatting(table.copy())

formatted_table = formatted_table.reset_index().sort_values(by=partition_col)
formatted_table[partition_col] = formatted_table[partition_col].where(
    formatted_table[partition_col] != formatted_table[partition_col].shift(),
    ''
)
formatted_table.index.name=None
formatted_table

['hCAII Train', 'CFG', 'CG', 'CG w/ Joint QED', 'Genetic', 'Genetic w/ QED Cutoff', 'Genetic w/ Joint QED']


Model,Partition,Tanimoto Cutoff,hCAII Train,CFG,CG,CG w/ Joint QED,Genetic,Genetic w/ QED Cutoff,Genetic w/ Joint QED
0,Binding Hits,0.5,99.6,\textbf{{42.7}},5.6,10.9,0.9,1.4,1.5
1,,0.6,75.7,\textbf{{15.3}},0.9,2.1,0.1,0.2,0.2
2,,0.7,28.6,\textbf{{4.3}},0.1,0.3,0.0,0.0,0.0
3,Public Hits,0.5,0.6,1.1,6.1,\textbf{{15.1}},7.2,7.2,5.0
4,,0.6,0.0,0,2.4,\textbf{{5.7}},1.8,2.6,1.3
5,,0.7,0.0,0,0.6,\textbf{{2.0}},0.0,0.2,0.0
