In [None]:
import pandas as pd
import numpy as np
from collections import Counter
import pickle
from tqdm import tqdm
tqdm.pandas()
import statistics
from pubchempy import Compound
from rdkit import Chem, DataStructs
from rdkit.Chem import SaltRemover, QED, rdMolDescriptors
from molvs import Standardizer
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

# Datasets Task 1: Prepare PubChem Datasets

Author: Kaan Donbekci (donbekci@stanford.edu)

## Contents
* [1.1 Assemble Dataset](#1.1-Assemble-Dataset)
* [1.2 Sanitize Molecules](#1.2-Sanitize-Molecules)
* [1.3 Remove non-druglike molecules](#1.3-Remove-non-druglike-molecules)
* [1.4 Resolve errors](#1.4-Resolve-errors)
    * [1.4.1 Calculate pairwise similarities using fingerprints](#1.4.1-Calculate-pairwise-similarities-using-fingerprints)
    * [1.4.2 Find and remove duplicates](#1.4.2-Find-and-remove-duplicates) # TODO
    * [1.4.3 Find and remove activity cliffs](#1.4.3-Find-and-remove-activity-cliffs) # TODO
* [Exports](#Exports)

In [None]:
# Set to True to render the diagnostic figures (slow for large datasets).
plots = bool(False)

In [None]:
# Targets for which a PubChem export exists under ../data/<name>.csv.
AVAILABLE_DATASETS = [
    'ST14',
    'KLKB1',
    'TMPRSS11D',
    'TMPRSS6',
]

In [None]:
# Dataset processed by this run; must be one of AVAILABLE_DATASETS.
ds_name = "TMPRSS6"

In [None]:
# Fail fast, with a readable message, if `ds_name` is not a known dataset.
assert ds_name in AVAILABLE_DATASETS, f'Unknown dataset {ds_name!r}; expected one of {AVAILABLE_DATASETS}'

In [None]:
def load_pickle(filename):
    """Load and return the object pickled at ``../dumps/<filename>.pkl``.

    Only use this on files the pipeline wrote itself: unpickling can execute
    arbitrary code.
    """
    path = f'../dumps/{filename}.pkl'
    with open(path, 'rb') as handle:
        obj = pickle.load(handle)
    return obj

In [None]:
# Per-assay table used below to convert IC50 measurements to Ki
# (columns read later: assay_id, IC50, Ki).
IC50_to_Ki_df = pd.read_csv(filepath_or_buffer='../data/combined_IC50.csv')

In [None]:
# Map assay id -> factor that `fix_IC50` multiplies an IC50 value by to get Ki,
# for the rows whose IC50 column equals 1. `pd.to_numeric` makes the check work
# whether pandas parsed that column as numbers or as strings; comparing against
# the string '1' never matches a numeric column, leaving the mapping empty.
ic50_flag = pd.to_numeric(IC50_to_Ki_df.IC50, errors='coerce')
convertible = IC50_to_Ki_df[ic50_flag == 1]
IC50_to_Ki = {int(aid): float(ki) for aid, ki in zip(convertible.assay_id, convertible.Ki)}

## 1.1 Assemble Dataset

The first step is to read the dataset exported from the "Tested Compounds" section of the target's
[PubChem gene page](https://pubchem.ncbi.nlm.nih.gov/gene/6768#section=Tested-Compounds&fullscreen=true).


In [None]:
# Raw PubChem bioactivity export for the selected target.
df = pd.read_csv(filepath_or_buffer=f'../data/{ds_name}.csv')

Convert IC50 measurements to Ki for the assays that have a known conversion factor, then drop the rows whose activity value is not measured in Ki.

In [None]:
# How many rows use each activity measure (Ki, IC50, ...).
Counter(df['acname'])

In [None]:
def fix_IC50(row):
    """Relabel a row as Ki if its assay has a known IC50 -> Ki factor.

    When ``row.aid`` is a key of ``IC50_to_Ki``, ``row.acname`` becomes ``'Ki'``
    and ``row.acvalue`` is multiplied by that factor; other rows are returned
    unchanged. The row is modified in place and returned (for ``DataFrame.apply``).
    """
    if row.aid not in IC50_to_Ki:
        return row
    factor = IC50_to_Ki[row.aid]
    row.acname = 'Ki'
    row.acvalue *= factor
    return row

In [None]:
# Row-wise conversion of convertible IC50 rows to Ki.
df = df.apply(fix_IC50, axis='columns')

In [None]:
# Keep only Ki measurements (original or converted) and renumber the rows.
df = df.loc[df['acname'].eq('Ki')].reset_index(drop=True)

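Relabel every compound with a measured value by thresholding it: values below `activity_threshold` are labelled Active and all others Inactive. This also replaces "Unspecified" labels wherever a value is available.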

In [None]:
# Distribution of the activity labels before relabelling by threshold.
Counter(df['activity'])

In [None]:
# Ki cut-off: values strictly below are labelled Active, the rest Inactive.
# NOTE(review): units follow the PubChem `acvalue` column — confirm.
activity_threshold = int(50)

In [None]:
below_threshold = df.acvalue < activity_threshold
# Sanity check: no value below the threshold is already labelled Inactive.
assert not (below_threshold & (df.activity == 'Inactive')).any()
df.loc[below_threshold, 'activity'] = 'Active'

In [None]:
at_or_above_threshold = df.acvalue >= activity_threshold
# Sanity check: no value at/above the threshold is already labelled Active.
assert not (at_or_above_threshold & (df.activity == 'Active')).any()
df.loc[at_or_above_threshold, 'activity'] = 'Inactive'

Remove rows that have neither a measured activity value (NaN) nor a usable activity label ("Unspecified").

In [None]:
# Rows with neither a measured value nor a usable label carry no information.
uninformative = df.acvalue.isna() & (df.activity == 'Unspecified')
df = df.loc[~uninformative]

Some compounds have multiple rows in the dataset. Reduce each compound to a single row using the median activity value; compounds whose rows disagree on the activity label are set aside and dropped.

In [None]:
# Group the measurement rows by PubChem compound id.
cid_to_rows = {}
for _, row in df.iterrows():
    cid_to_rows.setdefault(row.cid, []).append(row)

In [None]:
# Labels still present after relabelling.
df['activity'].unique()

In [None]:
# Labels a compound must carry to be kept.
valid_activities = {'Active', 'Inactive'}

In [None]:
# Collapse each compound's measurements into one row with columns
# cid / acvalue / activity:
#   * compounds whose rows disagree on the label go to `problem_rows` and are dropped;
#   * compounds whose (single, agreed) label is not in `valid_activities` are
#     dropped — for single-measurement compounds as well as multi-measurement ones;
#   * multi-measurement compounds get the NaN-ignoring median acvalue.
cleaned_rows = []
problem_rows = []
for cid, rows in tqdm(cid_to_rows.items()):
    activities = {row.activity for row in rows}
    if len(activities) != 1:
        problem_rows.append(rows)
        continue
    activity = next(iter(activities))
    if activity not in valid_activities:
        continue
    if len(rows) == 1:
        cleaned_rows.append(rows[0][['cid', 'acvalue', 'activity']])
    else:
        acvalue = np.nanmedian([row.acvalue for row in rows])
        cleaned_rows.append(pd.Series({'cid': cid, 'acvalue': acvalue, 'activity': activity}))

In [None]:
# One row per compound, renumbered from 0.
df = (
    pd.DataFrame(cleaned_rows)
    .reset_index(drop=True)
)

In [None]:
# Number of compounds after collapsing duplicate measurements.
df.shape[0]

In [None]:
# Reuse the PubChem compounds cached by a previous run, if any; otherwise start
# with an empty cache and let the next cell fetch everything.
try:
    cid_to_pubchem = load_pickle(f'{ds_name}_cid_to_pubchem')
except (FileNotFoundError, EOFError, pickle.UnpicklingError):
    # Only a missing or unreadable cache is expected here; any other error
    # (a typo, a NameError, ...) should surface instead of being swallowed.
    cid_to_pubchem = {}
    print('will send requests to pubchempy, might take a while')

Query PubChem (via pubchempy) for any compounds not already cached, and keep a CID → compound dictionary from which the SMILES strings are read.

In [None]:
# Fetch only the compounds missing from the cache. A per-CID membership test
# replaces the old `assert len(...) == len(df)` guard, which is skipped under
# `python -O` and is fooled by a cache of the right size holding other CIDs.
missing_cids = [cid for cid in df.cid if cid not in cid_to_pubchem]
for cid in tqdm(missing_cids):
    cid_to_pubchem[cid] = Compound.from_cid(cid)

In [None]:
# Placeholder column, filled row by row by `set_smiles`.
df = df.assign(smiles=None)

In [None]:
def set_smiles(row):
    """Store the isomeric SMILES of the row's cached PubChem compound in ``row.smiles``.

    Looks the compound up in ``cid_to_pubchem`` by ``row.cid``; modifies and
    returns the row (for ``DataFrame.apply``).
    """
    row.smiles = cid_to_pubchem[row.cid].isomeric_smiles
    return row

In [None]:
# Row-wise SMILES lookup with a tqdm progress bar (enabled by `tqdm.pandas()`).
df = df.progress_apply(func=set_smiles, axis=1)

In [None]:
# Preview the first rows.
df.head(n=5)

## 1.2 Sanitize Molecules

Sanitization has two steps: first standardize the molecule, then remove salts from it. We use MolVS (built on RDKit) for standardization and RDKit's `SaltRemover` for salt stripping.

In [None]:
# EXPORT
# CID -> sanitized RDKit molecule, filled by the sanitization loop.
cid_to_rdkit = dict()

In [None]:
# RDKit's salt remover (default salt definitions) and the MolVS standardizer.
remover = SaltRemover.SaltRemover()
s = Standardizer()
print(f'len(remover.salts) = {len(remover.salts)}')

In [None]:
# Parse, standardize (MolVS) and strip salts (RDKit) from every compound.
# Compounds with a missing or unparseable SMILES (MolFromSmiles returns None)
# are reported and left out of `cid_to_rdkit`; the drug-likeness filter below
# rebuilds `df` from this dict's keys, so they are dropped consistently.
unparseable_cids = []
for _, row in tqdm(df.iterrows(), total=len(df)):
    mol = Chem.MolFromSmiles(row.smiles) if isinstance(row.smiles, str) else None
    if mol is None:
        unparseable_cids.append(row.cid)
        continue
    mol = s.standardize(mol)
    mol = remover.StripMol(mol)
    cid_to_rdkit[row.cid] = mol
if unparseable_cids:
    print(f'Skipped {len(unparseable_cids)} compounds with missing/unparseable SMILES: {unparseable_cids}')

## 1.3 Remove non-druglike molecules

In [None]:
# Display label -> attribute name on the object returned by `QED.properties`.
property_keys = {
    'molecular weight': 'MW',
    'polar surface area': 'PSA',
    'LogP': 'ALOGP',
    'rotateable bonds': 'ROTB',
    'h-bond donors': 'HBD',
    'h-bond acceptors': 'HBA',
}

In [None]:
# EXPORT
# One independent {cid: value} dict per property label.
qed_properties = {label: dict() for label in property_keys.keys()}

In [None]:
# Compute QED descriptors per molecule and tabulate them (index = CID,
# one column per property label).
for cid, mol in tqdm(cid_to_rdkit.items()):
    descriptors = QED.properties(mol)
    for label, attr in property_keys.items():
        qed_properties[label][cid] = getattr(descriptors, attr)
qed_properties_df = pd.DataFrame(qed_properties)

In [None]:
# Label the index so it is clear each row is one compound.
qed_properties_df = qed_properties_df.rename_axis('cid')
qed_properties_df.head()

In [None]:
if plots:
    # One figure per QED property: distribution (left) and box plot (right).
    for key, prop in qed_properties.items():
        values = list(prop.values())
        fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
        fig.suptitle(key)
        # `sns.distplot` is deprecated since seaborn 0.11; a density-normalised
        # `histplot` with a KDE overlay is its documented replacement.
        sns.histplot(values, kde=True, stat='density', ax=axes[0])
        sns.boxplot(x=values, ax=axes[1])

In [None]:
# Per-property quartiles and interquartile range; `threshold` is the IQR
# multiplier used for the outlier fences.
Q1, Q3 = qed_properties_df.quantile(.25), qed_properties_df.quantile(.75)
IQR = Q3 - Q1
threshold = 1.5

In [None]:
# A compound is an outlier if ANY of its properties lies outside that
# property's [Q1 - threshold*IQR, Q3 + threshold*IQR] fence.
lower_fence = Q1 - threshold * IQR
upper_fence = Q3 + threshold * IQR
is_outlier = ((qed_properties_df < lower_fence) | (qed_properties_df > upper_fence)).any(axis=1)
qed_properties_outliers_removed_df = qed_properties_df[~is_outlier]
print(f'{len(qed_properties_df) - len(qed_properties_outliers_removed_df)} outliers removed.')
qed_properties_df = qed_properties_outliers_removed_df

In [None]:
# Back to the {property: {cid: value}} layout used for plotting, plus the
# compounds that survived the outlier filter.
qed_properties = qed_properties_df.to_dict(orient='dict')
cids_to_keep = qed_properties_df.index.tolist()

In [None]:
# Restrict the molecules and the table to the drug-like compounds.
cid_to_rdkit = dict((cid, cid_to_rdkit[cid]) for cid in cids_to_keep)
df = df[df['cid'].isin(cids_to_keep)].reset_index(drop=True)

In [None]:
if plots:
    # Same per-property figures as before, after outlier removal.
    for key, prop in qed_properties.items():
        values = list(prop.values())
        fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
        fig.suptitle(f'{key} (w/o outliers)')
        # `sns.distplot` is deprecated since seaborn 0.11; a density-normalised
        # `histplot` with a KDE overlay is its documented replacement.
        sns.histplot(values, kde=True, stat='density', ax=axes[0])
        sns.boxplot(x=values, ax=axes[1])

## 1.4 Resolve errors

### 1.4.1 Calculate pairwise similarities using fingerprints

In [None]:
# Number of compounds left after the drug-likeness filter.
N = cids_to_keep.__len__()

In [None]:
# The molecule dict and the table must describe the same N compounds.
assert N == len(cid_to_rdkit) == len(df)

In [None]:
from functools import partial

In [None]:
# Chirality-aware Morgan bit-vector fingerprint of radius 2.
# NOTE(review): recent RDKit releases deprecate GetMorganFingerprintAsBitVect in
# favour of rdFingerprintGenerator — confirm against the installed version.
fp_name = 'morgan'
fingperint_function = partial(
    rdMolDescriptors.GetMorganFingerprintAsBitVect,
    radius=2,
    useChirality=True,
)

In [None]:
# EXPORT
# Fingerprint every compound and build the N x N matrix of pairwise Tanimoto
# similarities; row/column i corresponds to the i-th key of `cid_to_fingerprint`.
cid_to_fingerprint = {cid: fingperint_function(mol) for cid, mol in cid_to_rdkit.items()}
fingerprints = list(cid_to_fingerprint.values())
# BulkTanimotoSimilarity computes a whole row in C++, replacing N^2 Python-level
# FingerprintSimilarity calls (whose default metric is also Tanimoto).
fingerprint_similarity_matrix = np.array(
    [DataStructs.BulkTanimotoSimilarity(fp, fingerprints) for fp in tqdm(fingerprints)],
    dtype=float,
).reshape(len(fingerprints), len(fingerprints))

In [None]:
if plots:
    # Heat map of the raw (unclustered) similarity matrix, ticks labelled by CID.
    fig, ax = plt.subplots(figsize=(25, 25))
    image = ax.matshow(fingerprint_similarity_matrix, interpolation='nearest')
    ax.grid(False)
    plt.title('RDKIT fingerprint similarity matrix')
    tick_positions = range(N)
    plt.xticks(tick_positions, cids_to_keep, rotation=90)
    plt.yticks(tick_positions, cids_to_keep)
    ax.tick_params(axis='both', which='major', labelsize=4)
    colorbar_ticks = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, .75, .8, .85, .90, .95, 1]
    fig.colorbar(image, ticks=colorbar_ticks)
    plt.savefig(f'../dumps/{ds_name}_fingerprint_similarity_matrix.png', dpi=400)

In [None]:
if plots:
    # Hierarchically clustered version of the same matrix.
    grid = sns.clustermap(fingerprint_similarity_matrix, cbar_pos=None, figsize=(15, 15))
    plt.xticks(rotation=90)
    fig = grid.fig
    fig.suptitle('RDKIT fingerprint similarity matrix (clustered)')
    plt.savefig(f'../dumps/{ds_name}_fingerprint_similarity_matrix_clustered', dpi=400)

In [None]:
def fingerprint_to_np(fp):
    """Convert a fingerprint exposing ``ToBitString()`` to a 0/1 ``uint8`` array.

    One array element per character of the bit string, in order.
    """
    bits = fp.ToBitString()
    return np.fromiter(map(int, bits), dtype=np.uint8, count=len(bits))

In [None]:
def add_fingerprint(row):
    """Return ``row`` with an added ``rdkit_fingerprint`` entry.

    The entry holds ``fingerprint_to_np`` applied to the fingerprint stored in
    ``cid_to_fingerprint`` for ``row.cid`` (a 0/1 uint8 array).
    """
    # Item assignment adds a new label to the Series. Attribute assignment
    # (`row.rdkit_fingerprint = ...`) only sets a plain Python attribute when the
    # label does not exist yet, so the value never reached the row's data.
    row['rdkit_fingerprint'] = fingerprint_to_np(cid_to_fingerprint[row.cid])
    return row

### 1.4.2 Find and remove duplicates

In [None]:
# Strict upper triangle (diagonal and lower half zeroed) so each pair is seen once.
upper_triangle = np.triu(fingerprint_similarity_matrix, k=1)

In [None]:
# Pairs more similar than this are treated as near-duplicates.
similarity_threshold = 0.9

In [None]:
# Greedy near-duplicate removal. Compound pairs with similarity above
# `similarity_threshold` are edges of a graph; repeatedly drop the compound with
# the most remaining edges (ties go to the higher index) until no edges are
# left. `duplicates` holds matrix indices here; they are mapped to CIDs below.
# Vectorised with numpy instead of re-scanning every pair in Python each round.
adjacency = upper_triangle > similarity_threshold
adjacency = adjacency | adjacency.T  # count every edge for both endpoints
alive = np.ones(len(adjacency), dtype=bool)
duplicates = set()
ix_to_cid = {i: key for i, key in enumerate(cid_to_fingerprint.keys())}
while True:
    # Number of still-present neighbours of each still-present compound.
    degrees = adjacency[:, alive].sum(axis=1) * alive
    if degrees.max(initial=0) == 0:
        break
    # argmax over the reversed array picks the highest index among ties,
    # i.e. the first entry of a descending sort by (count, index).
    victim = len(degrees) - 1 - int(np.argmax(degrees[::-1]))
    duplicates.add(victim)
    alive[victim] = False
print('Will remove {} ({:.1%}) compounds.'.format(len(duplicates), len(duplicates) / max(N, 1)))

In [None]:
# Sanity check: matrix row i corresponds to the i-th key of `cid_to_fingerprint`.
assert all(ix_to_cid[ix] == cid for ix, cid in enumerate(cid_to_fingerprint))

In [None]:
# Translate matrix indices to compound ids.
duplicates = set(ix_to_cid[ix] for ix in duplicates)

In [None]:
# Preserve order while removing the near-duplicates.
cids_to_keep = [cid for cid in cids_to_keep if cid not in duplicates]

In [None]:
# Keep only the surviving compounds; renumber rows as in the drug-likeness
# filtering step so the index stays contiguous.
df = df.query('cid in @cids_to_keep').reset_index(drop=True)

In [None]:
# Keep the per-compound lookups in sync with the filtered compound list.
cid_to_fingerprint = {cid: cid_to_fingerprint[cid] for cid in cids_to_keep}
cid_to_rdkit = {cid: cid_to_rdkit[cid] for cid in cids_to_keep}

In [None]:
# Table, molecule dict and id list must still agree.
assert len(df) == len(cid_to_rdkit) == len(cids_to_keep)

### 1.4.3 Find and remove activity cliffs

In [None]:
# Number of compounds left after duplicate removal.
N = df.shape[0]

In [None]:
# EXPORT
# Recompute the pairwise Tanimoto similarity matrix for the de-duplicated set,
# plus its strict upper triangle (diagonal and lower half zeroed).
fingerprints = list(cid_to_fingerprint.values())
# BulkTanimotoSimilarity computes a whole row in C++, replacing N^2 Python-level
# FingerprintSimilarity calls (whose default metric is also Tanimoto).
fingerprint_similarity_matrix = np.array(
    [DataStructs.BulkTanimotoSimilarity(fp, fingerprints) for fp in tqdm(fingerprints)],
    dtype=float,
).reshape(len(fingerprints), len(fingerprints))
upper_triangle = np.triu(fingerprint_similarity_matrix, k=1)

In [None]:
# Activity-cliff criteria: pairs more similar than `similarity_threshold` whose
# activity values differ by more than `activitiy_ratio_threshold`-fold.
activitiy_ratio_threshold = 100
similarity_threshold = 0.85

In [None]:
# CID -> activity value, for the pairwise comparison below.
cid_to_acvalue = dict(zip(df.cid, df.acvalue))

In [None]:
# Activity cliffs: pairs of compounds more similar than `similarity_threshold`
# whose activity values differ by more than `activitiy_ratio_threshold`-fold.
# Both members of every such pair are flagged for removal.
irregularities = set()
ix_to_cid = {i: key for i, key in enumerate(cid_to_fingerprint.keys())}
for i, cid1 in enumerate(cid_to_fingerprint):
    acvalue1 = cid_to_acvalue[cid1]
    for j in np.flatnonzero(upper_triangle[i] > similarity_threshold):
        cid2 = ix_to_cid[j]
        acvalue2 = cid_to_acvalue[cid2]
        # Cross-multiplied form of `a / b > ratio`: identical for positive
        # values, but a zero activity value no longer raises ZeroDivisionError.
        # NaN values compare False, so pairs involving them are never flagged.
        if (acvalue2 > activitiy_ratio_threshold * acvalue1
                or acvalue1 > activitiy_ratio_threshold * acvalue2):
            irregularities.add(cid1)
            irregularities.add(cid2)

print('Will remove {} ({:.1%}) compounds.'.format(len(irregularities), len(irregularities) / max(N, 1)))

In [None]:
# Preserve order while removing both members of every activity cliff.
cids_to_keep = [cid for cid in cids_to_keep if cid not in irregularities]

In [None]:
# Keep only the surviving compounds; renumber rows as in the drug-likeness
# filtering step so the exported table has a contiguous index.
df = df.query('cid in @cids_to_keep').reset_index(drop=True)

In [None]:
# Keep the per-compound lookups in sync with the final compound list.
cid_to_fingerprint = {cid: cid_to_fingerprint[cid] for cid in cids_to_keep}
cid_to_rdkit = {cid: cid_to_rdkit[cid] for cid in cids_to_keep}

In [None]:
# Table, molecule dict and id list must still agree before exporting.
assert len(df) == len(cid_to_rdkit) == len(cids_to_keep)

## Exports

In [None]:
def save_pickle(obj, filename):
    """Pickle ``obj`` to ``../dumps/<filename>.pkl``, the path ``load_pickle`` reads."""
    path = f'../dumps/{filename}.pkl'
    with open(path, 'wb') as handle:
        pickle.dump(obj, handle)

In [None]:
# Persist the PubChem cache (reloaded at the start of the next run) and the
# sanitized RDKit molecules of the final compound set.
for obj, dump_name in ((cid_to_pubchem, f'{ds_name}_cid_to_pubchem'),
                       (cid_to_rdkit, f'{ds_name}_cid_to_rdkit')):
    save_pickle(obj, dump_name)

In [None]:
# Final table (cid, acvalue, activity, smiles) without the pandas index.
processed_path = f'../dumps/{ds_name}_processed.csv'
df.to_csv(processed_path, index=False)

In [None]:
# Fingerprints as 0/1 arrays keyed by the CID string (npz keys must be strings).
np_fingerprints = {}
for cid in df.cid:
    np_fingerprints[str(cid)] = fingerprint_to_np(cid_to_fingerprint[cid])

In [None]:
# One array per compound in a single uncompressed .npz archive.
fingerprints_path = f'../dumps/{ds_name}_{fp_name}_fingerprints.npz'
np.savez(fingerprints_path, **np_fingerprints)