## Preliminary Filtering and Analysis

In [None]:
import subprocess
import os
# .get(): CONDA_DEFAULT_ENV only exists inside an activated conda environment;
# indexing os.environ directly raised KeyError everywhere else.
print('Current conda environment:', os.environ.get('CONDA_DEFAULT_ENV', '<not set>'))
os.environ['TOKENIZERS_PARALLELISM'] = "false"

cwd = os.getcwd()
print('Working directory:', cwd)

import warnings
# NOTE: this silences every warning for the rest of the session, including pandas warnings.
warnings.filterwarnings('ignore')

import random
# Seeds only Python's `random` module; pandas sampling further down passes its own random_state.
random.seed(42)



In [None]:
import numpy as np
import pandas as pd

from rdkit import Chem
import useful_rdkit_utils as uru

import pickle

import seaborn as sns

In [None]:
# Shared look-and-feel for every seaborn figure plus pandas table formatting.

_BLACK = "#000000"       # axes edges, axis labels and tick marks
_BACKGROUND = "#f7f9fc"  # same colour for the axes and the figure canvas

rc = {
    "axes.facecolor": _BACKGROUND,
    "figure.facecolor": _BACKGROUND,
    "axes.edgecolor": _BLACK,
    "grid.color": "#EBEBE7",
    "font.family": "serif",
    "axes.labelcolor": _BLACK,
    "xtick.color": _BLACK,
    "ytick.color": _BLACK,
    "grid.alpha": 0.4
}

# Palette name passed to every bar plot below
default_palette = 'tab10'

sns.set(rc=rc)

# Show at most 35 columns, and print floats with two decimals and thousands separators
pd.options.display.max_columns = 35
pd.set_option('display.float_format', '{:,.2f}'.format)

## Generating Data

In [None]:
# Read up to the first N_FRAGMENTS lines of the fragment file, one SMILES per line.
N_FRAGMENTS = 50

fragments = []

# `with` closes the file once reading is done (the handle used to stay open).
with open("data/fragments.smi", "r") as f:
    # zip() stops at whichever runs out first, so a file with fewer lines
    # no longer pads the list with empty strings.
    for line_no, line in zip(range(N_FRAGMENTS), f):
        # strip() instead of line[:-1]: the slice removed a real character
        # from a final line that has no trailing newline.
        smiles = line.strip()
        if smiles:  # ignore blank lines
            fragments.append(smiles)

### Generate a distribution from each fragment for each model

In [None]:
# Model names. The order matters: `model_dfs` is built in this order and later
# cells index it by position (e.g. model_dfs[3] is used for 'safe').
models = ['reinvent', 'crem', 'coati', 'safe']

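#### Warning: the following 3 cells are commented out on purpose. Running them re-generates the analogs for every fragment/model pair by calling `generate_analogs.py` and then writes a new pickle file. Only un-comment them if you really want to rebuild the data.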

In [None]:
# %%capture

# for fragment in fragments:
#     for model in models:

#         DF_FILEPATH = f'data/{model}_dataframe.csv'

#         arg1 = '--model'
#         arg2 = '--input_frag'
#         arg3 = '--sample'

#         args = ['python3', 'generate_analogs.py',
#                 arg1, model,
#                 arg2, fragment,
#                 arg3, '200']

#         # Change directory to generate analogs with python script
#         %cd ..

#         subprocess.run(args,
#                     stdout=subprocess.DEVNULL,
#                     stderr=subprocess.STDOUT)
                
#         # Change directory back to that of the current notebook
#         %cd experiments

#         df = pd.read_csv(DF_FILEPATH, index_col=0)

#         df['Model'] = model

#         if model == 'reinvent':
#             reinvent_distributions.append(df)
#         elif model == 'crem':
#             crem_distributions.append(df)
#         elif model == 'coati':
#             coati_distributions.append(df)
#         elif model == 'safe':
#             safe_distributions.append(df)

In [None]:
# data = {'reinvent' : reinvent_distributions,
#         'crem' : crem_distributions,
#         'coati' : coati_distributions,
#         'safe' : safe_distributions}

In [None]:
# with open('lists.pkl', 'wb') as file:
#     pickle.dump(data, file)

In [None]:
# Load the cached analog DataFrames; `data` is indexed by model name below (data[model]).
# SECURITY: pickle.load can execute arbitrary code, so only load a file you created yourself.
# NOTE(review): the commented-out dump cell writes 'lists.pkl', while this cell reads
# 'data/lists.pkl'. Confirm the file is saved/moved to the path used here.
with open('data/lists.pkl', 'rb') as file:
    data = pickle.load(file)

### Concatenate data for each model

In [None]:
# One DataFrame per model, in the same order as `models`, holding the analogs
# from all of that model's per-fragment DataFrames.
model_dfs = []

for model in models:

    distributions = data[model]

    # Concatenate once per model: growing a frame with pd.concat inside a loop
    # re-copies all accumulated rows on every iteration.
    # pd.concat raises on an empty list, so keep the old "empty frame" result for that case.
    model_df = pd.concat(distributions) if len(distributions) else pd.DataFrame()

    model_dfs.append(model_df)

### Filtering invalid molecules

In [None]:
# Keep only the rows whose SMILES RDKit can parse and sanitize.
for i, df in enumerate(model_dfs):

    smiles_list = df['SMILES'].to_list()

    valid_smiles = []    # boolean mask, exactly one entry per row of df
    invalid_smiles = []  # rejected SMILES, kept for inspection

    for smiles in smiles_list:

        try:
            molecule = Chem.MolFromSmiles(smiles, sanitize=True)
        except Exception:
            # e.g. the entry is not a string; treat it as invalid
            molecule = None

        # Always append to the mask. The old `except` branch appended nothing,
        # which left the mask shorter than the frame.
        is_valid = molecule is not None
        valid_smiles.append(is_valid)
        if not is_valid:
            invalid_smiles.append(smiles)

    # .copy(): later cells add columns to these frames, which should not be slices of the originals
    model_dfs[i] = df[valid_smiles].copy()

In [None]:
# Number of rows left for models[3] ('safe') after the validity filter above
len(model_dfs[3])

## Filtering
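Initialise the `scores` table: one column per model and one row per metric. Every entry starts at 0 and is filled in by the cells below.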

In [None]:
# Row labels of the score table, one per metric computed in this notebook
score_rows = ['valid', 'duplicates', 'good rings', 'Dundee', 'scaffold novelty', 'skeleton novelty', "Lilly_chem", "filters_NIBR", "rule_of_five"]

# One zero-filled column per model
d = {name: np.zeros(len(score_rows)) for name in ('reinvent', 'crem', 'coati', 'safe')}

scores = pd.DataFrame(d, index=score_rows)

### Filter #1: Invalid SMILES strings

In [None]:
def num_valid(df):
    """Return the fraction of rows in ``df`` whose SMILES RDKit can parse.

    Parameters
    ----------
    df : DataFrame with a ``SMILES`` column.

    Returns
    -------
    float
        Number of SMILES for which ``Chem.MolFromSmiles`` returns a molecule,
        divided by the number of rows. An empty frame gives 0.0 instead of
        raising ZeroDivisionError.
    """
    size = len(df)

    if size == 0:
        return 0.0

    # MolFromSmiles returns None when the string cannot be turned into a molecule
    count = sum(Chem.MolFromSmiles(smi) is not None for smi in df.SMILES)

    return count / size

In [None]:
for model, df in zip(models, model_dfs):

    # .loc writes into `scores` itself. `scores[model]['valid'] = ...` is chained
    # assignment: it writes through an intermediate Series, which pandas warns about
    # and which does not update `scores` when Copy-on-Write is enabled.
    scores.loc['valid', model] = num_valid(df)

In [None]:
scores

### Filter #2: Duplicates

In [None]:
# Handle on the models[3] ('safe') frame, displayed in the next cell
safe_df = model_dfs[3]

In [None]:
safe_df

In [None]:
# De-duplicate each model's analogs by InChIKey and record the fraction that is unique.
for i, (model, df) in enumerate(zip(models, model_dfs)):

    # Work on a copy so the new columns are not written into a slice of an earlier frame
    df = df.copy()

    df['ROMol'] = [Chem.MolFromSmiles(smi) for smi in df.SMILES]

    df['inchi'] = df.ROMol.apply(Chem.MolToInchiKey)

    # Drop duplicates once and reuse the result (this used to run twice, the second time in place)
    unique_df = df.drop_duplicates(subset="inchi")

    # Despite the row name, the stored value is the fraction of rows that remain after de-duplication.
    # .loc avoids the chained assignment `scores[model]['duplicates'] = ...`.
    scores.loc['duplicates', model] = len(unique_df) / len(df) if len(df) else 0.0

    model_dfs[i] = unique_df

In [None]:
scores

### Filter #3: Odd ring systems

In [None]:
def number_odd_rings(df, min_freq_threshold=100):
    """Return the fraction of molecules whose rarest ring system is common enough.

    Despite the name, the result is the share of rows that PASS: those whose
    ``min_freq`` is strictly greater than ``min_freq_threshold``.

    Side effect: adds the columns ``ring_systems``, ``min_ring`` and ``min_freq``
    to ``df`` in place.

    Parameters
    ----------
    df : DataFrame with a ``SMILES`` column.
    min_freq_threshold : minimum-ring-frequency cutoff (default 100, the value
        that used to be hard-coded).

    Returns
    -------
    float
        Passing rows divided by all rows; 0.0 for an empty frame (which used to
        raise ZeroDivisionError).
    """
    if len(df) == 0:
        return 0.0

    ring_system_lookup = uru.RingSystemLookup.default()
    df['ring_systems'] = df.SMILES.apply(ring_system_lookup.process_smiles)
    df[['min_ring','min_freq']] = df.ring_systems.apply(uru.get_min_ring_frequency).to_list()

    # Count the passing rows directly instead of materialising a filtered copy of the frame.
    # NOTE(review): molecules without rings presumably get a sentinel frequency and
    # therefore fail this test; confirm that this is intended.
    n_good = int((df['min_freq'] > min_freq_threshold).sum())

    return n_good / len(df)

In [None]:
for model, df in zip(models, model_dfs):

    # .loc instead of the chained `scores[model]['good rings'] = ...`, which does not
    # update `scores` when pandas Copy-on-Write is enabled.
    scores.loc['good rings', model] = number_odd_rings(df)

0.892	0.575	0.692	0.983333

In [None]:
scores

### Filter #4: Chemical Stability

In [None]:
# REOS object from useful_rdkit_utils with only the "Dundee" rule set active.
# `num_stable` in the next cell reads this module-level `reos`.
reos = uru.REOS()
reos.set_active_rule_sets(["Dundee"])

In [None]:
def num_stable(df):
    """Return the fraction of molecules that pass the active REOS rule set.

    Uses the module-level ``reos`` object.

    Parameters
    ----------
    df : DataFrame with a ``ROMol`` column of RDKit molecules.

    Returns
    -------
    float
        Molecules reported as "ok" divided by all rows; 0.0 for an empty frame.
    """
    if len(df) == 0:
        return 0.0

    # Count the "ok" outcomes explicitly. The previous version took the count of the
    # MOST FREQUENT outcome, which is only the pass count when passing is the most
    # common result.
    # NOTE(review): relies on reos.process_mol returning "ok" as its second element
    # for a molecule that triggers no rule; confirm against useful_rdkit_utils.
    n_ok = sum(1 for mol in df.ROMol if list(reos.process_mol(mol))[1] == "ok")

    return n_ok / len(df)

In [None]:
for model, df in zip(models, model_dfs):

    # .loc instead of the chained `scores[model]['Dundee'] = ...`, which does not
    # update `scores` when pandas Copy-on-Write is enabled.
    scores.loc['Dundee', model] = num_stable(df)

In [None]:
scores

### Scaffold and Skeleton Novelty

In [None]:
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Scaffolds import MurckoScaffold
from rdkit.DataStructs import BulkTanimotoSimilarity

In [None]:
def extract_scaffold(smiles: str):
    """Return the Murcko scaffold of ``smiles`` as a SMILES string."""
    parsed = Chem.MolFromSmiles(smiles)
    return Chem.MolToSmiles(MurckoScaffold.GetScaffoldForMol(parsed))


def extract_scaffold_skeleton(smiles: str):
    """Return the generic form (``MakeScaffoldGeneric``) of the Murcko scaffold as SMILES."""
    parsed = Chem.MolFromSmiles(smiles)
    murcko = MurckoScaffold.GetScaffoldForMol(parsed)
    skeleton = MurckoScaffold.MakeScaffoldGeneric(murcko)
    return Chem.MolToSmiles(skeleton)



In [None]:
def get_scaffold_novelty(df, initial_frags):
    """Return the fractions of generated scaffolds and skeletons absent from the input fragments.

    Parameters
    ----------
    df : DataFrame with a ``SMILES`` column of generated molecules.
    initial_frags : iterable of SMILES strings of the starting fragments.

    Returns
    -------
    (float, float)
        ``(scaffold_novelty, skeleton_novelty)``. Each value is the number of distinct
        scaffolds (or skeletons) in ``df`` that no starting fragment has, divided by
        the number of distinct scaffolds (or skeletons) in ``df``. An empty ``df``
        gives ``(0.0, 0.0)`` instead of raising ZeroDivisionError.
    """
    # With no generated molecules both sets would be empty and the divisions would fail
    if len(df) == 0:
        return 0.0, 0.0

    initial_scaffolds = {extract_scaffold(frag) for frag in initial_frags}
    initial_skeletons = {extract_scaffold_skeleton(frag) for frag in initial_frags}

    model_scaffolds = set(df.SMILES.apply(extract_scaffold))
    model_skeletons = set(df.SMILES.apply(extract_scaffold_skeleton))

    novel_scaffolds = model_scaffolds - initial_scaffolds
    novel_skeletons = model_skeletons - initial_skeletons

    return len(novel_scaffolds) / len(model_scaffolds), len(novel_skeletons) / len(model_skeletons)


In [None]:
for model, df in zip(models, model_dfs):

    scaffold_novelty, skeleton_novelty = get_scaffold_novelty(df, fragments)

    # .loc instead of chained `scores[model][...] = ...`, which does not update
    # `scores` when pandas Copy-on-Write is enabled.
    scores.loc['scaffold novelty', model] = scaffold_novelty
    scores.loc['skeleton novelty', model] = skeleton_novelty

In [None]:
print(scores)

### Define functions for Medchem Filters

In [None]:

import matplotlib.pyplot as plt
import matplotlib.colors
import seaborn as sns

In [None]:
import medchem as mc

In [None]:
from rdkit.Chem import PandasTools
import datamol as dm

In [None]:
# Convert the first model's SMILES with dm.to_mol and store the result in a "mol" column.
# The loop over models[:3] further down assigns this same column again for this frame.
model_dfs[0]["mol"] = model_dfs[0]["SMILES"].apply(dm.to_mol)

In [None]:
def process_filters(data):
    """Run the medchem filters and store each result as a new column of ``data``.

    ``data`` must already have a ``SMILES`` column (used for the rule of five)
    and a ``mol`` column (used by every other filter). The frame is modified in
    place and nothing is returned.

    Columns written, in this order: ``rule_of_five``, ``alerts_Dundee``,
    ``alerts_SureChEMBL``, ``filters_NIBR``, ``filter_molecular_graph``,
    ``filter_lilly_demerit``.
    """
    data["rule_of_five"] = data["SMILES"].apply(mc.rules.basic_rules.rule_of_five)

    # Keyword arguments that every medchem.functional filter call below shares
    shared_kwargs = {"n_jobs": -1, "progress": True, "return_idx": False}

    # (output column, filter function, filter-specific keyword arguments)
    filter_table = [
        ("alerts_Dundee", mc.functional.alert_filter, {"alerts": ["Dundee"]}),
        ("alerts_SureChEMBL", mc.functional.alert_filter, {"alerts": ["SureChEMBL"]}),
        ("filters_NIBR", mc.functional.nibr_filter, {}),
        ("filter_molecular_graph", mc.functional.molecular_graph_filter, {"max_severity": 5}),
        ("filter_lilly_demerit", mc.functional.lilly_demerit_filter, {}),
    ]

    for column, run_filter, extra_kwargs in filter_table:
        data[column] = run_filter(
            mols=data["mol"].tolist(),
            **extra_kwargs,
            **shared_kwargs,
        )

In [None]:

import matplotlib.pyplot as plt
import matplotlib.colors
import seaborn as sns

In [None]:
def plot_filters(data, model):
    """Draw a pass/fail heatmap of the medchem filter columns for one model.

    Parameters
    ----------
    data : DataFrame that already has the six filter columns listed below
        (as written by ``process_filters``). A ``n_filters_pass`` column is
        added to it in place.
    model : label inserted into the x-axis caption.

    Nothing is returned; the heatmap is drawn on a newly created figure.
    """
    filter_columns = [
        "rule_of_five",
        "alerts_Dundee",
        "alerts_SureChEMBL",
        "filters_NIBR",
        "filter_molecular_graph",
        "filter_lilly_demerit",
    ]

    # Count the filters each molecule passes and order the molecules by that count
    data["n_filters_pass"] = data[filter_columns].sum(axis=1)
    ordered = data.sort_values("n_filters_pass", ascending=True)
    n_mols = len(ordered)

    _, ax = plt.subplots(figsize=(14, 4), constrained_layout=True)

    # Two colours only; the colorbar ticks below label them "Don't Pass" and "Pass"
    pass_fail_cmap = matplotlib.colors.ListedColormap(["#EF6262", "#1D5B79"], None)

    sns.heatmap(
        ordered[filter_columns].T,
        annot=False,
        ax=ax,
        xticklabels=False,  # type: ignore
        yticklabels=True,  # type: ignore
        cbar=True,
        cmap=pass_fail_cmap,
    )

    colorbar = ax.collections[0].colorbar
    colorbar.set_ticks([0.25, 0.75])
    colorbar.set_ticklabels(["Don't Pass", "Pass"], fontsize=14)

    ax.set_xlabel(f"Analogs from {model} sorted (n={n_mols})", fontsize=14)
    ax.set_ylabel("Medchem Filters", fontsize=18)

    # Append the percentage of passing molecules to each filter's y tick label
    tick_names = [tick.get_text() for tick in ax.yaxis.get_ticklabels()]
    labelled = [f"{name} ({ordered[name].sum() / n_mols * 100:.0f}%)" for name in tick_names]
    ax.yaxis.set_ticklabels(labelled, fontsize=12)


In [None]:
# Medchem filters and heatmap for the first three models; models[3] is handled
# separately in the following cells on a random sample.
first_three = zip(models[:3], model_dfs[:3])

for model, df in first_three:
    df["mol"] = df["SMILES"].apply(dm.to_mol)  # the filters read this column
    process_filters(df)                        # adds the filter columns in place
    plot_filters(df, model)

In [None]:
# Same dm.to_mol conversion for models[3] ('safe'); stored in a "mol" column
model_dfs[3]["mol"] = model_dfs[3]["SMILES"].apply(dm.to_mol)
    

In [None]:
# Random subset of the 'safe' analogs used for the medchem filters below.
# NOTE: the variable is called safe_1000, but 1900 rows are requested.
# min(...) keeps .sample() from raising when fewer than 1900 rows are left.
safe_1000 = model_dfs[3].sample(min(1900, len(model_dfs[3])), random_state=42)

In [None]:
# Adds the filter columns to safe_1000 in place (process_filters returns nothing)
process_filters(safe_1000)

In [None]:
# models[3] ('safe'), not model[3]: `model` is the leftover variable of the loop over
# models[:3], so model[3] was just the fourth character of the last model name.
plot_filters(safe_1000, models[3])

In [None]:
for model, df in zip(models[:3], model_dfs[:3]):

    # .loc instead of chained `scores[model][...] = ...`, which does not update
    # `scores` when pandas Copy-on-Write is enabled.
    # NOTE(review): the "filters_NIBR" row of `scores` is not filled in by any cell
    # shown here, although process_filters computes that column.
    scores.loc["Lilly_chem", model] = df["filter_lilly_demerit"].sum()/len(df)
    scores.loc["rule_of_five", model] = df["rule_of_five"].sum()/len(df)

In [None]:
# Scores for 'safe' come from the random sample, not from the full frame.
# .loc instead of chained `scores['safe'][...] = ...`, which does not update
# `scores` when pandas Copy-on-Write is enabled.
scores.loc["Lilly_chem", 'safe'] = safe_1000["filter_lilly_demerit"].sum()/len(safe_1000)
scores.loc["rule_of_five", 'safe'] = safe_1000["rule_of_five"].sum()/len(safe_1000)



### Now we can make Barplots using the scores

In [None]:
# Bar plots of the medchem pass rates: one panel per metric, one bar per model

fig, ax = plt.subplots(1, 2, figsize=(10, 5))

# The loop variable is `metric`, not `filter`: at notebook top level the old name
# replaced the builtin filter() for every cell run afterwards.
for i, metric in enumerate(["Lilly_chem", "rule_of_five"]):

    sns.barplot(x=models, y=[scores.loc[metric, model] for model in models], ax=ax[i], palette=default_palette)

    ax[i].set_title(metric)
    ax[i].set_ylabel('Fraction of Molecules Passing Filter')

    # A little headroom above 1.0 so the value labels are not clipped
    ax[i].set_ylim(0, 1.05)

    # Write each bar's height (two decimals) just above the bar
    for p in ax[i].patches:
        ax[i].annotate(f'{p.get_height():.2f}', (p.get_x() + p.get_width() / 2., p.get_height()),
                       ha='center', va='center', fontsize=12, color='black', xytext=(0, 5),
                       textcoords='offset points')

plt.tight_layout()
plt.show()

In [None]:
# Bar plots for 'duplicates', 'good rings' and 'Dundee': one panel per metric, one bar per model

fig, ax = plt.subplots(1, 3, figsize=(15, 5))

# The loop variable is `metric`, not `filter`: at notebook top level the old name
# replaced the builtin filter() for every cell run afterwards.
for i, metric in enumerate(["duplicates", "good rings", "Dundee"]):

    sns.barplot(x=models, y=[scores.loc[metric, model] for model in models], ax=ax[i], palette=default_palette)

    ax[i].set_title(metric)
    ax[i].set_ylabel('Fraction of Molecules Passing Filter')

    # A little headroom above 1.0 so the value labels are not clipped
    ax[i].set_ylim(0, 1.05)

    # Write each bar's height (two decimals) just above the bar
    for p in ax[i].patches:
        ax[i].annotate(f'{p.get_height():.2f}', (p.get_x() + p.get_width() / 2., p.get_height()),
                       ha='center', va='center', fontsize=12, color='black', xytext=(0, 5),
                       textcoords='offset points')

plt.tight_layout()
# Added to match the other plotting cells, which all end with plt.show()
plt.show()

In [None]:
# Scaffold and skeleton novelty: one panel each, one bar per model

fig, ax = plt.subplots(1, 2, figsize=(11, 4))
ax = ax.flatten()

scaffold_novelty = [scores[model]['scaffold novelty'] for model in models]
skeleton_novelty = [scores[model]['skeleton novelty'] for model in models]

# (bar heights, panel title, y-axis label) for the left and the right panel
novelty_panels = (
    (scaffold_novelty, 'Scaffold Novelty', 'Fraction of Novel Scaffolds'),
    (skeleton_novelty, 'Skeleton Novelty', 'Fraction of Novel Skeletons'),
)

for panel_ax, (values, title, ylabel) in zip(ax, novelty_panels):
    sns.barplot(x=models, y=values, ax=panel_ax, palette=default_palette)
    panel_ax.set_title(title)
    panel_ax.set_ylabel(ylabel)

plt.tight_layout()
plt.show()

In [None]:
# Keep only the four filter rows that are shown in the grouped bar chart at the end
scores_display = scores.loc[['good rings', 'Dundee', 'Lilly_chem', 'rule_of_five']]

In [None]:
# Long format for seaborn: one row per (filter, model) pair with columns index / Model / Score.
# Built directly from `scores` rather than from the previous value of `scores_display`,
# so re-running this cell gives the same result instead of failing on the already-melted frame.
scores_display = pd.melt(
    scores.loc[['good rings', 'Dundee', 'Lilly_chem', 'rule_of_five']].reset_index(),
    id_vars='index', var_name='Model', value_name='Score')



In [None]:
# Grouped bar chart: one group per filter, one bar per model
g = sns.catplot(
    data=scores_display,
    x='index',
    y='Score',
    hue='Model',
    kind='bar',
    palette=default_palette,
    height=5,
    aspect=2,
)

g.set_axis_labels('Filter', 'Fraction of Molecules Passing Filter')

# Write each bar's height (a fraction, two decimals) just above the bar
for facet_ax in g.axes.flatten():
    for bar in facet_ax.patches:
        bar_top = bar.get_height()
        bar_centre = bar.get_x() + bar.get_width() / 2.
        facet_ax.annotate(f'{bar_top:.2f}', (bar_centre, bar_top),
                          ha='center', va='center', fontsize=12, color='black', xytext=(0, 5),
                          textcoords='offset points')
        
        