In [None]:
%matplotlib inline
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Scaffolds import MurckoScaffold
from rdkit.Chem import rdMolDescriptors
from multiprocessing import Pool, cpu_count
import numpy as np
import matplotlib.pyplot as plt 
import itertools
from rdkit.Chem import PandasTools
import glob
from pathlib import Path
import os

# When False, the processing cells below skip any run/target whose output file
# already exists; set to True to regenerate everything.
OVERWRITE_FILES = False

def normalize_smiles(smiles):
    """Canonicalise a SMILES string without stereochemistry.

    Args:
        smiles: SMILES string (other values, e.g. NaN cells, are tolerated).

    Returns:
        RDKit canonical SMILES written with ``isomericSmiles=False``, or
        ``np.nan`` if ``smiles`` is not a string or cannot be parsed.
    """
    # MolFromSmiles raises (rather than returning None) for non-str input
    # such as the float NaN that read_csv produces for empty cells.
    if not isinstance(smiles, str):
        return np.nan
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return np.nan
    return Chem.MolToSmiles(mol, isomericSmiles=False)

def normalize_smiles_series(smiles_series):
    """Return ``smiles_series`` with ``normalize_smiles`` applied to each element.

    Passed as the worker to ``parallelize_dataframe``, so it stays a
    module-level (picklable) function.
    """
    normalized = smiles_series.map(normalize_smiles)
    return normalized


    
def split_frame(df, n_parts):
    """Split a Series/DataFrame row-wise into ``n_parts`` contiguous pieces.

    Piece sizes match ``np.array_split``: they differ by at most one, larger
    pieces come first, and pieces are empty when ``n_parts > len(df)``.
    Slicing uses ``.iloc`` so pandas objects are never routed through
    ``np.array_split``, which calls the deprecated ``swapaxes`` on them.
    """
    n_parts = max(1, int(n_parts))
    base, extra = divmod(len(df), n_parts)
    parts = []
    start = 0
    for k in range(n_parts):
        stop = start + base + (1 if k < extra else 0)
        parts.append(df.iloc[start:stop])
        start = stop
    return parts


def parallelize_dataframe(df, func, n_workers=None):
    """Apply ``func`` to row-wise chunks of ``df`` in a process pool and concatenate.

    Args:
        df: pandas Series or DataFrame.
        func: callable taking a chunk of ``df`` and returning a pandas object.
            It is sent to the workers via pickle, so it must be picklable
            (e.g. a module-level function, not a lambda).
        n_workers: pool size; defaults to ``cpu_count()``.

    Returns:
        ``pd.concat`` of the per-chunk results, in chunk order. This keeps the
        original row order and index when ``func`` preserves them.
    """
    if len(df) == 0:
        # Nothing to distribute: skip starting a pool and pd.concat([]).
        return func(df)
    n_workers = n_workers or cpu_count()
    parts = [part for part in split_frame(df, n_workers) if len(part) > 0]
    # The context manager tears the pool down even if a worker raises.
    with Pool(min(n_workers, len(parts))) as pool:
        results = pool.map(func, parts)
    return pd.concat(results)
 

def generate_murcko_scaffold(smile):
    """Return the Murcko scaffold of a molecule as non-isomeric SMILES.

    The scaffold comes from ``MurckoScaffold.GetScaffoldForMol``.

    Args:
        smile: SMILES string.

    Returns:
        Canonical SMILES of the scaffold (``isomericSmiles=False``), or
        ``np.nan`` for non-string input, unparsable SMILES, or an RDKit error.
    """
    if not isinstance(smile, str):
        return np.nan
    try:
        mol = Chem.MolFromSmiles(smile)
        if not mol:
            return np.nan
        scaffold = MurckoScaffold.GetScaffoldForMol(mol)
        return Chem.MolToSmiles(scaffold, isomericSmiles=False)
    except Exception:
        # Per-molecule best effort; narrower than a bare except so that
        # KeyboardInterrupt/SystemExit still stop pool workers.
        return np.nan
    

def generate_topological_scaffold(smile):
    """Return the generic (topological) Murcko scaffold as non-isomeric SMILES.

    The Murcko scaffold from ``MurckoScaffold.GetScaffoldForMol`` is passed
    through ``MurckoScaffold.MakeScaffoldGeneric`` before being written out.

    Args:
        smile: SMILES string.

    Returns:
        Canonical SMILES of the generic scaffold (``isomericSmiles=False``), or
        ``np.nan`` for non-string input, unparsable SMILES, or an RDKit error.
    """
    if not isinstance(smile, str):
        return np.nan
    try:
        mol = Chem.MolFromSmiles(smile)
        if not mol:
            return np.nan
        scaffold = MurckoScaffold.MakeScaffoldGeneric(MurckoScaffold.GetScaffoldForMol(mol))
        return Chem.MolToSmiles(scaffold, isomericSmiles=False)
    except Exception:
        # Per-molecule best effort; narrower than a bare except so that
        # KeyboardInterrupt/SystemExit still stop pool workers.
        return np.nan


    
                   
def generate_murcko_scaffold_series(data):
    """Map ``generate_murcko_scaffold`` over every element of ``data``.

    Passed as a worker to ``parallelize_dataframe``, so it stays module-level.
    """
    scaffolds = data.map(generate_murcko_scaffold)
    return scaffolds


def generate_topological_scaffold_series(data):
    """Map ``generate_topological_scaffold`` over every element of ``data``.

    Passed as a worker to ``parallelize_dataframe``, so it stays module-level.
    """
    scaffolds = data.map(generate_topological_scaffold)
    return scaffolds


def calcLogP(smiles):
    """Crippen logP of a molecule given as SMILES.

    Uses the first element of ``rdMolDescriptors.CalcCrippenDescriptors``.

    Returns:
        float logP, or ``np.nan`` when ``smiles`` is not a string or RDKit
        cannot parse it. This matches the NaN convention of the other
        per-molecule helpers instead of raising on ``None``.
    """
    if not isinstance(smiles, str):
        return np.nan
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return np.nan
    return rdMolDescriptors.CalcCrippenDescriptors(mol)[0]


def calcLogP_series(df):
    """Apply ``calcLogP`` element-wise to a pandas object of SMILES via ``.map``."""
    return df.map(calcLogP)

        

In [None]:
filternames = ["NoFilter", "CompoundSimilarity", "IdenticalMurckoScaffold", "IdenticalTopologicalScaffold", "ScaffoldSimilarity"]
filternames_in_plots = ["No memory", "CompoundSimilarity memory", "IdenticalMurckoScaffold memory", "IdenticalTopologicalScaffold memory", "ScaffoldSimilarity memory"]

# Per-target settings. In this cell only "maxstep" is used: generated rows with
# step >= maxstep are dropped. NOTE(review): "minactivity" is not read here;
# presumably later cells use it.
target_params = {
    "DRD2":      {"maxstep": 300,
                  "minactivity": 0.7},
    "HTR1A":     {"maxstep": 300,
                  "minactivity": 0.7},
    "clogP":     {"maxstep": 150,
                  "minactivity": 1.}
}


def parse_run_folder(folder):
    """Parse a REINVENT run-folder name into typed run parameters.

    The name must have exactly seven ``_``-separated fields:
    ``target_filtername_minsimilarity_bucketsize_outputmode_temperature_experiencereplay``.

    Returns:
        tuple ``(target, filtername, minsimilarity: float, bucket_size: int,
        outputmode, temperature: float, experience_replay: bool)``, or ``None``
        if the field count is wrong or a numeric field does not parse.
    """
    elements = folder.split("_")
    if len(elements) != 7:
        return None
    target, filtername, minsimilarity, bucket_size, outputmode, temperature, experience_replay = elements
    try:
        minsimilarity = float(minsimilarity)
        bucket_size = int(bucket_size)
        temperature = float(temperature)
    except ValueError:
        return None
    # bool("False") is True, so the flag text has to be compared explicitly.
    # NOTE(review): assumes folders spell the flag as True/False (or 1/0).
    experience_replay = experience_replay.strip().lower() in ("true", "1")
    return target, filtername, minsimilarity, bucket_size, outputmode, temperature, experience_replay


pathname = f"{Path.home()}/REINVENT/results/*/scaffold_memory.csv"
for path in glob.glob(pathname):
    folder = Path(path).parent.name.replace(" ", "_")
    output_file = f"data/memories/{folder}/memory_preprocessed.csv.gz"
    if os.path.exists(output_file) and not OVERWRITE_FILES:
        print(f"Skipping {folder} as it seems to already be processed")
        continue

    run = parse_run_folder(folder)
    if run is None:
        print(f"Skipping {folder}: folder name is not a 7-field run description")
        continue
    target, filtername, minsimilarity, bucket_size, outputmode, temperature, experience_replay = run
    # Check before the expensive processing instead of failing with KeyError afterwards.
    if target not in target_params:
        print(f"Skipping {folder}: no target_params entry for {target}")
        continue

    memory = pd.read_csv(path)
    if len(memory) <= 1:
        print(f"{path} contains nothing")
        continue
    memory = memory.rename(columns={'SMILES': 'GENERATED_SMILES'})
    memory['SMILES'] = parallelize_dataframe(memory["GENERATED_SMILES"], normalize_smiles_series)
    # Stable sort so that, within a step, the first-listed duplicate is the one kept.
    memory = (memory.dropna()
                    .sort_values(by=['step'], kind="stable")
                    .drop_duplicates('SMILES', keep="first"))
    memory["Murcko Scaffold"] = parallelize_dataframe(memory["SMILES"], generate_murcko_scaffold_series)
    memory["Topological Scaffold"] = parallelize_dataframe(memory["SMILES"], generate_topological_scaffold_series)
    memory = memory.dropna()

    id_prefix = "generated_{}_{}_{}_{}_{}_{}_{}".format(target.replace(" ", "_"), filtername, minsimilarity,
                                                        bucket_size, outputmode, temperature, experience_replay)
    memory['ID'] = [f"{id_prefix}_{i}" for i in range(len(memory))]

    maxstep = target_params[target]["maxstep"]
    memory = memory.query("step < @maxstep")
    os.makedirs(f"data/memories/{folder}", exist_ok=True)
    memory.to_csv(output_file, index=False)

    os.makedirs(f"to_fragment/{folder}", exist_ok=True)
    memory[["SMILES", "ID"]].to_csv(f"to_fragment/{folder}/generated_to_fragment.smi", sep=",", index=False, header=False)


In [None]:
targets = ["DRD2"]
# Columns kept from the classifier dataframe of known actives.
ACTIVE_COLUMNS = ["Original_Entry_ID", "DB", "RDKIT_SMILES", "trainingset_class", "cluster_id", "activity_label", "cfp"]
# Only rows in these splits get an ID; any other trainingset_class value gets None.
ID_SPLITS = ("training", "test", "validation")


def make_active_ids(actives, target):
    """Build IDs of the form ``{Original_Entry_ID}_{target}_{trainingset_class}_{n}``.

    ``n`` counts from 1 separately within each ``trainingset_class``, in row
    order. Rows whose class is not in ``ID_SPLITS`` get ``None``.

    Returns:
        object-dtype Series aligned with ``actives.index``.
    """
    split = actives["trainingset_class"]
    # Running position within each split; dropna=False keeps the count integral
    # even if some rows have a missing class.
    running = actives.groupby(split, dropna=False, sort=False).cumcount() + 1
    ids = [f"{entry}_{target}_{cls}_{n}" if cls in ID_SPLITS else None
           for entry, cls, n in zip(actives["Original_Entry_ID"], split, running)]
    return pd.Series(ids, index=actives.index, dtype=object)


for target in targets:
    actives_file = f"data/{target}/actives.pkl.gz"
    # Check before loading so already-processed targets skip the expensive read.
    if os.path.exists(actives_file) and not OVERWRITE_FILES:
        print(f"Skipping {target} as it seems to already be processed")
        continue

    # NOTE: read_pickle can execute arbitrary code; only load trusted files.
    actives = (pd.read_pickle(f"{Path.home()}/projects/reinvent-classifiers/{target}_df.pkl.gz")
                 .query("activity_label == 1")[ACTIVE_COLUMNS]
                 .copy())
    actives['ID'] = make_active_ids(actives, target)
    actives["Murcko Scaffold"] = parallelize_dataframe(actives["RDKIT_SMILES"], generate_murcko_scaffold_series)
    actives["Topological Scaffold"] = parallelize_dataframe(actives["RDKIT_SMILES"], generate_topological_scaffold_series)

    os.makedirs(f"data/{target}/", exist_ok=True)
    actives.to_pickle(actives_file)
    os.makedirs(f"to_fragment/{target}", exist_ok=True)
    actives[["RDKIT_SMILES", "ID"]].to_csv(f"to_fragment/{target}/actives_to_fragment.smi", sep=",", index=False, header=False)


In [None]:

from joblib import Parallel, delayed
from rdkit import Chem
from rdkit import rdBase
from rdkit.Chem import AllChem
from rdkit.Chem import SaltRemover
from rdkit.Chem import rdmolops

# Silence RDKit's error-level log channel (e.g. messages for SMILES that fail
# to parse) so bulk standardisation below does not flood the notebook output.
rdBase.DisableLog('rdApp.error')


def _initialiseNeutralisationReactions():
    """Build the (query, replacement) pairs used to neutralise charged groups.

    Each SMARTS pattern matches a charged group. The paired SMILES, parsed
    without sanitisation, is the neutral group substituted in its place.

    Returns:
        list of (query Mol, replacement Mol) tuples, in application order.
    """
    smarts_to_smiles = [
        ('[n+;H]', 'n'),                       # imidazoles
        ('[N+;!H0]', 'N'),                     # amines
        ('[$([O-]);!$([O-][#7])]', 'O'),       # carboxylic acids and alcohols
        ('[S-;X1]', 'S'),                      # thiols
        ('[$([N-;X2]S(=O)=O)]', 'N'),          # sulfonamides
        ('[$([N-;X2][C,N]=C)]', 'N'),          # enamines
        ('[n-]', '[nH]'),                      # tetrazoles
        ('[$([S-]=O)]', 'S'),                  # sulfoxides
        ('[$([N-]C=O)]', 'N'),                 # amides
    ]
    reactions = []
    for smarts, smiles in smarts_to_smiles:
        # Second positional argument False = do not sanitise the replacement.
        reactions.append((Chem.MolFromSmarts(smarts), Chem.MolFromSmiles(smiles, False)))
    return reactions


# Compiled once at import; _neutraliseCharges uses it when no reactions are passed.
_reactions = _initialiseNeutralisationReactions()


def _neutraliseCharges(mol, reactions=None):
    """Replace charged groups in ``mol`` with their neutral equivalents.

    Args:
        mol: RDKit molecule.
        reactions: sequence of (query, replacement) Mol pairs; defaults to the
            module-level ``_reactions``.

    Returns:
        ``(mol, changed)``, where ``changed`` is True if any replacement was made.
    """
    pairs = _reactions if reactions is None else reactions
    changed = False
    for query, replacement in pairs:
        # ReplaceSubstructs returns a tuple of products; keep the first and
        # repeat until this pattern no longer matches.
        while mol.HasSubstructMatch(query):
            mol = AllChem.ReplaceSubstructs(mol, query, replacement)[0]
            changed = True
    return mol, changed


def _getlargestFragment(mol):
    """Return the fragment of ``mol`` with the most heavy atoms.

    Fragments come from ``rdmolops.GetMolFrags`` with sanitisation enabled.
    On a tie, the earliest fragment wins. Returns None when there is no
    usable fragment.
    """
    fragments = [frag for frag in rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
                 if frag is not None]
    # max() keeps the first of equally large items, matching a strict-greater scan.
    return max(fragments, key=lambda frag: frag.GetNumHeavyAtoms(), default=None)


# Single SaltRemover built once with no arguments (library-default salt
# definitions) and reused by standardize_smiles for every molecule.
_saltremover = SaltRemover.SaltRemover()


def valid_size(mol, min_heavy_atoms, max_heavy_atoms, element_list, remove_long_side_chains):
    """Filter a molecule on heavy-atom count, atom types and (optionally) side chains.

    Args:
        mol: RDKit molecule, or a falsy value such as None from a failed parse.
        min_heavy_atoms: exclusive lower bound on the heavy-atom count.
        max_heavy_atoms: exclusive upper bound on the heavy-atom count.
        element_list: allowed atomic numbers; every atom must be in it.
        remove_long_side_chains: if True, reject molecules matching
            ``[CR0]-[CR0]-[CR0]-[CR0]`` (four chained aliphatic non-ring carbons).

    Returns:
        bool: True only if every check passes. It is always a bool; previously
        some failure paths returned None.
    """
    if not mol:
        return False
    if not min_heavy_atoms < mol.GetNumHeavyAtoms() < max_heavy_atoms:
        return False
    if not all(atom.GetAtomicNum() in element_list for atom in mol.GetAtoms()):
        return False
    if remove_long_side_chains:
        # Aliphatic side chain of at least 4 carbons, none of them in a ring.
        long_chain = Chem.MolFromSmarts('[CR0]-[CR0]-[CR0]-[CR0]')
        if mol.HasSubstructMatch(long_chain):
            return False
    return True


def standardize_smiles(smiles, min_heavy_atoms=10, max_heavy_atoms=50, element_list=(6, 7, 8, 9, 16, 17, 35),
                       remove_long_side_chains=False, neutralise_charges=True):
    """Standardise one SMILES string and return its canonical non-isomeric form.

    Pipeline:
      1. Parse and keep the largest fragment.
      2. Remove explicit Hs.
      3. Strip salts (``dontRemoveEverything=True``).
      4. Optionally neutralise charges.
      5. Clean up, re-sanitise and remove Hs again.
      6. Filter with ``valid_size``.

    Args:
        smiles: input SMILES string.
        min_heavy_atoms, max_heavy_atoms: exclusive heavy-atom bounds for valid_size.
        element_list: allowed atomic numbers (default C, N, O, F, S, Cl, Br).
        remove_long_side_chains: forwarded to valid_size.
        neutralise_charges: apply _neutraliseCharges when True.

    Returns:
        str, or ``np.nan`` if the input is not a string, cannot be parsed,
        raises during standardisation, or fails the size/element filter.
    """
    if not isinstance(smiles, str):
        return np.nan
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol:
            mol = _getlargestFragment(mol)
        if mol:
            mol = rdmolops.RemoveHs(mol, implicitOnly=False, updateExplicitCount=False, sanitize=True)
        if mol:
            mol = _saltremover.StripMol(mol, dontRemoveEverything=True)
        if mol and neutralise_charges:
            mol, _ = _neutraliseCharges(mol)
        if mol:
            rdmolops.Cleanup(mol)
            rdmolops.SanitizeMol(mol)
            mol = rdmolops.RemoveHs(mol, implicitOnly=False, updateExplicitCount=False, sanitize=True)
        if mol and valid_size(mol, min_heavy_atoms, max_heavy_atoms, element_list, remove_long_side_chains):
            return Chem.MolToSmiles(mol, isomericSmiles=False)
    except Exception:
        # Sanitisation steps (GetMolFrags, RemoveHs, SanitizeMol) can raise for
        # individual molecules; treat those as invalid rather than aborting a
        # whole parallel batch.
        return np.nan
    return np.nan


def read_smiles_file(fname, sep=None):
    """Return the first field of every non-blank line of a SMILES file.

    Args:
        fname: path to a file with one molecule per line, SMILES first.
        sep: field separator for ``str.split``. The default None splits on any
            run of whitespace, so space- and tab-separated files both work.
    """
    with open(fname, 'r') as handle:
        fields_per_line = (line.strip().split(sep) for line in handle)
        return [fields[0] for fields in fields_per_line if fields and fields[0]]


def standardize_smiles_from_file(fname, sep=None):
    """Read a SMILES file and return its unique standardised RDKit SMILES.

    Args:
        fname: path to the SMILES file.
        sep: field separator (see ``read_smiles_file``).
    """
    return standardize_smiles_list(read_smiles_file(fname, sep))


def standardize_smiles_list(smiles_list):
    """Standardise SMILES in parallel and return the unique valid results.

    Each entry goes through ``standardize_smiles`` via joblib on all cores
    (``n_jobs=-1``). Failures are dropped and duplicates removed.

    Returns:
        list of str: unique standardised SMILES, sorted for a deterministic order.
    """
    import logging  # local import: the notebook does not import logging before this cell

    standardized = Parallel(n_jobs=-1, verbose=0)(delayed(standardize_smiles)(smiles) for smiles in smiles_list)
    # standardize_smiles signals failure with np.nan (a float), so keep strings only.
    unique_smiles = sorted({smiles for smiles in standardized if isinstance(smiles, str)})
    logging.debug("{} unique SMILES retrieved".format(len(unique_smiles)))
    return unique_smiles