# Prepare Training Data  
  
In this step we read the data from [Maier, L., et al (2018)](https://www.nature.com/articles/nature25979) and prepare it to be used for training, testing and validation of the XGBoost model. 

In [1]:
# Read Libraries
import os

import numpy as np
import pandas as pd
import pubchempy as pcp
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem.SaltRemover import SaltRemover
from tqdm import tqdm

from dataset.dataset_representation import process_dataset

## Global variables  

In [2]:
# --- Paths ---
# The supplementary tables of the screen are read from this folder
INPUT_DIR = '../raw_data/maier_microbiome'
# Files produced by this notebook are written to this folder
OUTPUT_DIR = "../data/01.prepare_training_data"

# --- Labelling ---
# Values at or below this cutoff are labelled 1 when the screen is binarized
PVAL_CUTOFF = 0.05

# Make sure the output folder exists before any cell writes to it
os.makedirs(OUTPUT_DIR, exist_ok=True)

## Read data and binarize
  
Here we read in Supplementary Table 3 from the study.

In [3]:
# Load the screen results, discard the annotation columns and index the
# remaining columns by Prestwick compound ID
screen_df = (
    pd.read_excel(os.path.join(INPUT_DIR, "screen_results_info_SF3.xlsx"))
    .drop(columns=["chemical_name", "drug_class", "n_hit"])
    .set_index("prestwick_ID")
)

# Binarize: 1 where the value is at or below PVAL_CUTOFF, 0 otherwise.
# Comparisons against NaN evaluate to False, so missing values become 0.
screen_df = (screen_df <= PVAL_CUTOFF).astype(int)
screen_df.head()



Unnamed: 0_level_0,Akkermansia muciniphila (NT5021),Bacteroides caccae (NT5050),Bacteroides fragilis (ET) (NT5033),Bacteroides fragilis (NT) (NT5003),Bacteroides ovatus (NT5054),Bacteroides thetaiotaomicron (NT5004),Bacteroides uniformis (NT5002),Bacteroides vulgatus (NT5001),Bacteroides xylanisolvens (NT5064),Bifidobacterium adolescentis (NT5022),...,Parabacteroides merdae (NT5071),Prevotella copri (NT5019),Roseburia hominis (NT5079),Roseburia intestinalis (NT5011),Ruminococcus bromii (NT5045),Ruminococcus gnavus (NT5046),Ruminococcus torques (NT5047),Streptococcus parasanguinis (NT5072),Streptococcus salivarius (NT5038),Veillonella parvula (NT5017)
prestwick_ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Prestw-1109,1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,1,1
Prestw-1399,1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,1,1
Prestw-145,1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,1,1
Prestw-1464,1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,1,1
Prestw-31,1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,1,1,1


## Gather SMILES
  
We use the chemical names to gather the SMILES from PubChem using [PubChemPy](https://pubchempy.readthedocs.io/en/latest/)


In [None]:
# FUNCTIONS FOR OBTAINING AND PROCESSING SMILES

def clean_names_chemlibrary(original_name):
    """
    Strip trailing annotations from a chemical-library name.

    Everything from the first " (" onwards is discarded, then everything
    from the first " [" onwards, and finally any trailing whitespace.

    Parameters:
    - original_name (str): Name as it appears in the chemical library.

    Returns:
    - str: The shortened name.
    """
    before_parenthesis, _, _ = original_name.partition(" (")
    before_bracket, _, _ = before_parenthesis.partition(" [")
    return before_bracket.rstrip()

def get_pubchemid(name):
    """
    Look a compound up in PubChem by name and return its identifiers.

    The name is queried as given. If that yields no hits, the query is
    repeated with the name returned by ``clean_names_chemlibrary``. Only the
    first hit is used.

    Parameters:
    - name (str): The name of the compound to search for.

    Returns:
    - result_df (pandas.DataFrame): Single-row DataFrame (index 0) with columns:
        - 'name': The name that was passed in (not the cleaned name).
        - 'cid': PubChem Compound ID of the first hit.
        - 'pchem_canonical_smile': Canonical SMILES of the first hit.
        - 'pchem_isomeric_smile': Isomeric SMILES of the first hit.
        - 'pchem_inchi': InChI of the first hit.
        - 'pchem_inchikey': InChIKey of the first hit.
      When neither query returns a hit, every column except 'name' holds
      the string 'not_found'.
    """

    # First attempt: the name exactly as given
    hits = pcp.get_compounds(name, 'name')

    # Second attempt: the name with its trailing annotations stripped
    if len(hits) == 0:
        hits = pcp.get_compounds(clean_names_chemlibrary(name), "name")

    if len(hits) > 0:
        first_hit = hits[0]
        record = {
            "name": name,
            "cid": first_hit.cid,
            "pchem_canonical_smile": first_hit.canonical_smiles,
            "pchem_isomeric_smile": first_hit.isomeric_smiles,
            "pchem_inchi": first_hit.inchi,
            "pchem_inchikey": first_hit.inchikey,
        }
    else:
        # Keep the row so that missing compounds can be filtered out later
        record = {"name": name}
        for field in ("cid", "pchem_canonical_smile", "pchem_isomeric_smile",
                      "pchem_inchi", "pchem_inchikey"):
            record[field] = "not_found"

    return pd.DataFrame(record, index=[0])


def cid_info(cid, df):
    """
    Fetch a compound from PubChem by CID and return its identifiers.

    Parameters:
    - cid: PubChem Compound ID of the compound to query. It must appear in
        the 'cid' column of ``df``.
    - df (pandas.DataFrame): Table with 'name' and 'cid' columns. The name of
        the first row whose 'cid' equals ``cid`` is used for the output.

    Returns:
    - result_df (pandas.DataFrame): Single-row DataFrame (index 0) with columns:
        - 'name': Name taken from ``df``.
        - 'cid': The CID that was passed in.
        - 'pchem_canonical_smile': Canonical SMILES of the compound.
        - 'pchem_isomeric_smile': Isomeric SMILES of the compound.
        - 'pchem_inchi': InChI of the compound.
        - 'pchem_inchikey': InChIKey of the compound.
    """

    # Query PubChem
    compound = pcp.Compound.from_cid(cid)

    # Assemble the output row; the name comes from the lookup table
    record = {
        "name": df.loc[df["cid"] == cid, "name"].values[0],
        "cid": cid,
        "pchem_canonical_smile": compound.canonical_smiles,
        "pchem_isomeric_smile": compound.isomeric_smiles,
        "pchem_inchi": compound.inchi,
        "pchem_inchikey": compound.inchikey,
    }

    return pd.DataFrame(record, index=[0])

In [7]:
# READ THE DATA

# Chemical library table, indexed by Prestwick compound ID
maier_chemicals = (
    pd.read_excel(os.path.join(INPUT_DIR, "chem_library_info_SF1.xlsx"))
    .set_index("prestwick_ID")
)

I will use the **chemical name** column to query PubChem (via pubchempy) to find the SMILES and other relevant information for each compound.

In [None]:
# Query PubChem once per distinct chemical name; each call yields a one-row frame
cid_search = pd.concat(
    [get_pubchemid(name) for name in tqdm(maier_chemicals["chemical name"].unique())]
)

# These chemicals could not be found by name through PubChemPy, so their CIDs
# are supplied by hand. The names are written exactly as they are looked up later.
manual_cid = pd.DataFrame(
    {
        "name": [
            "(-)-Eseroline fumarate salt",
            "Clonixin Lysinate",
            "Ziprasidone  Hydrochloride",
            "Clavulanate potassium salt",
            "Oxibendazol",
            "Morpholinoethylamino-3-benzocyclohepta-(5,6-c)-pyridazine dihydrochloride",
            "Gabazine bromide",
            "Colistin sulfate",
            "Bacitracin",
        ],
        "cid": [
            16219298,
            3080836,
            219099,
            23665591,
            4622,
            195164,
            71316800,
            91885449,
            11980094,
        ],
    }
)

# Fetch the same identifier fields for the hand-curated CIDs
manual_try = pd.concat([cid_info(c, df=manual_cid) for c in tqdm(manual_cid["cid"].tolist())])

# Combine the successful name lookups with the manual ones and drop exact duplicate rows
found_by_name = cid_search.loc[cid_search["cid"] != "not_found"]
chem_smiles = pd.concat([found_by_name, manual_try]).drop_duplicates()


We can now get the RDKit versions of these SMILES

In [None]:
# Parse every PubChem SMILES once and reuse the molecule objects below.
# Plain lists are used so that the new columns are assigned by position,
# independent of the (non-unique) index of chem_smiles at this point.
canonical_mols = [Chem.MolFromSmiles(smi) for smi in chem_smiles["pchem_canonical_smile"]]
isomeric_mols = [Chem.MolFromSmiles(smi) for smi in chem_smiles["pchem_isomeric_smile"]]

# Obtain canonical SMILES (isomeric information is not written)
chem_smiles["rdkit_canonical_smile"] = [
    Chem.MolToSmiles(mol, canonical=True, isomericSmiles=False) for mol in canonical_mols
]

# Obtain isomeric SMILES. These are built from PubChem's *isomeric* SMILES:
# building them from the canonical SMILES (as was done before) cannot produce
# stereo information, because that input does not contain any.
chem_smiles["rdkit_isomeric_smile"] = [
    Chem.MolToSmiles(mol, canonical=True, isomericSmiles=True) for mol in isomeric_mols
]

# Obtain chemical metadata from the molecules parsed from the canonical SMILES
chem_smiles["n_atoms"] = [mol.GetNumAtoms() for mol in canonical_mols]
chem_smiles["n_bonds"] = [mol.GetNumBonds() for mol in canonical_mols]
chem_smiles["ExactMolWt"] = [Descriptors.ExactMolWt(mol) for mol in canonical_mols]


# Remove the salts, starting from the RDKit canonical SMILES
# (SaltRemover comes from rdkit.Chem.SaltRemover, imported in the first cell)
remover = SaltRemover()
chem_smiles["rdkit_no_salt"] = [
    Chem.MolToSmiles(remover.StripMol(Chem.MolFromSmiles(smi)))
    for smi in chem_smiles["rdkit_canonical_smile"]
]


Now we can combine this information with the provided data

In [None]:
# Key the library table by chemical name; prestwick_ID becomes a regular column
maier_chemicals = maier_chemicals.reset_index().set_index("chemical name")

# Key the SMILES table by the queried name. The guard makes the cell safe to
# re-run: once "name" is the index it is no longer a column, and calling
# set_index("name") a second time would raise a KeyError.
if chem_smiles.index.name != "name":
    chem_smiles = chem_smiles.set_index("name")

# Left join on the index: every library row is kept, and the SMILES columns
# are attached where the chemical name matches
chemical_metadata = maier_chemicals.join(chem_smiles)
chemical_metadata.head()

With that, we can write the final output

In [None]:
# Write below OUTPUT_DIR (defined in the global-variables cell) instead of
# repeating the folder path as a literal
os.makedirs(OUTPUT_DIR, exist_ok=True)
chemical_metadata.to_csv(os.path.join(OUTPUT_DIR, "prestwick_library.tsv.gz"), sep='\t')

## Molecular representation and data splitting. 
  
Now that we have determined the labels, we can represent the chemical library using MolE, ECFP4 and Chemical Descriptors. At the same time, we split the dataset using scaffold splitting.

In [8]:
# Keep only the library compounds whose Prestwick ID appears in the screen results
was_screened = chemical_metadata["prestwick_ID"].isin(screen_df.index)
chemical_metadata_screened = chemical_metadata.loc[was_screened]

# Save the subset; the representation cells below read it back from this file
screened_library_path = os.path.join(OUTPUT_DIR, "prestwick_library_screened.tsv.gz")
chemical_metadata_screened.to_csv(screened_library_path, sep='\t')

In [3]:
# MOLE REPRESENTATION

maier_scaffold_split, mole_representation = process_dataset(
    dataset_path=os.path.join(OUTPUT_DIR, "prestwick_library_screened.tsv.gz"),
    # Pre-trained model used to compute the representation
    pretrain_architecture="gin_concat",
    pretrained_model="model_ginconcat_btwin_100k_d8000_l0.0001",
    # Scaffold split with 10% validation and 10% test
    split_approach="scaffold",
    validation_proportion=0.1,
    test_proportion=0.1,
    # Columns of the input file holding the SMILES and the compound ID
    smile_column_str="rdkit_no_salt",
    id_column_str="prestwick_ID",
)

About to generate scaffolds
About to sort in scaffold sets
Representation dimension (1000) - Embedding dimension (8000)
../pretrained_model/model_ginconcat_btwin_100k_d8000_l0.0001/model.pth
x_embedding1.weight
x_embedding2.weight
gnns.0.mlp.0.weight
gnns.0.mlp.0.bias
gnns.0.mlp.1.weight
gnns.0.mlp.1.bias
gnns.0.mlp.1.running_mean
gnns.0.mlp.1.running_var
gnns.0.mlp.1.num_batches_tracked
gnns.0.mlp.3.weight
gnns.0.mlp.3.bias
gnns.0.edge_embedding1.weight
gnns.0.edge_embedding2.weight
gnns.1.mlp.0.weight
gnns.1.mlp.0.bias
gnns.1.mlp.1.weight
gnns.1.mlp.1.bias
gnns.1.mlp.1.running_mean
gnns.1.mlp.1.running_var
gnns.1.mlp.1.num_batches_tracked
gnns.1.mlp.3.weight
gnns.1.mlp.3.bias
gnns.1.edge_embedding1.weight
gnns.1.edge_embedding2.weight
gnns.2.mlp.0.weight
gnns.2.mlp.0.bias
gnns.2.mlp.1.weight
gnns.2.mlp.1.bias
gnns.2.mlp.1.running_mean
gnns.2.mlp.1.running_var
gnns.2.mlp.1.num_batches_tracked
gnns.2.mlp.3.weight
gnns.2.mlp.3.bias
gnns.2.edge_embedding1.weight
gnns.2.edge_embedding2.we

In [4]:
# ECFP4

ecfp4_representation = process_dataset(
    dataset_path=os.path.join(OUTPUT_DIR, "prestwick_library_screened.tsv.gz"),
    pretrained_model="ECFP4",
    # No split is requested here; only the representation is returned
    dataset_split=False,
    # Columns of the input file holding the SMILES and the compound ID
    smile_column_str="rdkit_no_salt",
    id_column_str="prestwick_ID",
)

In [6]:
# CHEMICAL DESCRIPTORS
# NOTE(review): this call reads "pchem_isomeric_smile", whereas the MolE and
# ECFP4 calls read "rdkit_no_salt" -- confirm that the difference is intended.
chemdesc_representation = process_dataset(
    dataset_path=os.path.join(OUTPUT_DIR, "prestwick_library_screened.tsv.gz"),
    pretrained_model="ChemDesc",
    # No split is requested here; only the representation is returned
    dataset_split=False,
    # Columns of the input file holding the SMILES and the compound ID
    smile_column_str="pchem_isomeric_smile",
    id_column_str="prestwick_ID",
)
  


Could not compute descriptors for Prestw-919


## Prepare representations for 100K molecules from PubChem

In [6]:
# Read the PubChem MolE representation from OUTPUT_DIR rather than from a
# hard-coded copy of the same folder path.
# NOTE(review): no visible cell of this notebook writes this file; presumably
# it is produced elsewhere -- TODO confirm and document its origin.
mole_pubchem = pd.read_csv(os.path.join(OUTPUT_DIR, "pubchem_mole_representation.tsv.gz"), sep='\t', index_col=0)

In [16]:
# Split the table into four files of 25,000 rows each; the last file takes
# every remaining row, whatever the total number of rows is.
PUBCHEM_ROWS_PER_PART = 25_000
PUBCHEM_N_PARTS = 4

for part in range(PUBCHEM_N_PARTS):
    first_row = part * PUBCHEM_ROWS_PER_PART
    # The final part is left open-ended so that no row is dropped
    last_row = first_row + PUBCHEM_ROWS_PER_PART if part < PUBCHEM_N_PARTS - 1 else None
    part_path = os.path.join(OUTPUT_DIR, f"pubchem_mole_representation_p{part + 1}.tsv.gz")
    mole_pubchem.iloc[first_row:last_row, :].to_csv(part_path, sep='\t')

## Write files

In [18]:
# NOTE(review): the writes below are commented out, so running this notebook
# from a fresh kernel does not (re)create these four files, although they
# appear in the directory listing at the end of the notebook. That listing also
# shows 'maier_screening_results.tsv.gz', for which no write is visible in this
# notebook -- TODO confirm where it is produced.
#maier_scaffold_split.to_csv(os.path.join(OUTPUT_DIR, "maier_scaffold_split.tsv.gz"), sep='\t')
#mole_representation.to_csv(os.path.join(OUTPUT_DIR, "maier_mole_representation.tsv.gz"), sep='\t')
#ecfp4_representation.to_csv(os.path.join(OUTPUT_DIR, "maier_ecfp4_representation.tsv.gz"), sep='\t')
#chemdesc_representation.to_csv(os.path.join(OUTPUT_DIR, "maier_chemdesc_representation.tsv.gz"), sep='\t', index=False)

## Written files

In [10]:
# Show the names of the entries currently present in the output folder
os.listdir(OUTPUT_DIR)

['maier_ecfp4_representation.tsv.gz',
 'prestwick_library.tsv.gz',
 'prestwick_library_screened.tsv.gz',
 'maier_scaffold_split.tsv.gz',
 'maier_chemdesc_representation.tsv.gz',
 'maier_mole_representation.tsv.gz',
 'maier_screening_results.tsv.gz']