# Standardization
This notebook standardizes the structures of the ligand and the protein in order to deliver consistent data for training.
Special care is taken to AVOID altering the original conformation, pose, and defined stereochemistry. To pass this checkpoint, each compound is required to have a valid 2D SMILES, a valid protein .pdb, and a valid ligand .sdf. The ligand .sdf is important because its charges, aromaticity, and coordinates are used by the GVP featurization step. 


# Function

In [None]:
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors, Crippen, Lipinski, rdMolDescriptors, QED, Draw
from rdkit.Chem.rdPartialCharges import ComputeGasteigerCharges
from rdkit.RDLogger import DisableLog
from Bio.PDB import PDBParser, PDBIO
from joblib import Parallel, delayed
from tqdm import tqdm
import pandas as pd
import os
import io
import matplotlib.pyplot as plt
from multiprocessing import Pool, cpu_count
import os
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
from Bio.PDB import PDBParser, PDBIO


import os
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdPartialCharges
from Bio.PDB import PDBParser, PDBIO


import io
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.warning')

    
import os
import pandas as pd
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
from rdkit import Chem

# Number of worker processes: leave one core free, but never drop below 1
# (cpu_count() - 1 is 0 on a single-core machine and Pool(0) raises ValueError).
N_PROC = max(1, cpu_count() - 1)



from rdkit import Chem
from rdkit.Chem import Descriptors, Crippen, Lipinski, rdMolDescriptors, QED
import pandas as pd
from joblib import Parallel, delayed
from tqdm import tqdm

tqdm.pandas()

def compute_props(smiles):
    """Compute the descriptor set for a single SMILES string.

    Returns a dict with the keys InChIKey, MolWt, HeavyAtomCount, QED,
    NumHDonors, NumHAcceptors, NumRotatableBonds, TPSA and LogP. All values
    are None when `smiles` is not a non-blank string or RDKit cannot parse it.
    """
    blank = dict.fromkeys((
        'InChIKey', 'MolWt', 'HeavyAtomCount', 'QED', 'NumHDonors',
        'NumHAcceptors', 'NumRotatableBonds', 'TPSA', 'LogP',
    ))

    # Guard clauses: reject non-strings and blank strings before parsing.
    if not isinstance(smiles, str) or not smiles.strip():
        return blank

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return blank

    props = {}
    props['InChIKey'] = Chem.MolToInchiKey(mol)
    props['MolWt'] = Descriptors.MolWt(mol)
    props['HeavyAtomCount'] = mol.GetNumHeavyAtoms()
    props['QED'] = QED.qed(mol)
    props['NumHDonors'] = Lipinski.NumHDonors(mol)
    props['NumHAcceptors'] = Lipinski.NumHAcceptors(mol)
    props['NumRotatableBonds'] = Lipinski.NumRotatableBonds(mol)
    props['TPSA'] = rdMolDescriptors.CalcTPSA(mol)
    props['LogP'] = Crippen.MolLogP(mol)
    return props

def add_molecular_properties_parallel(df, smiles_col='std_smiles', n_jobs=-1):
    """Append the `compute_props` descriptors of `df[smiles_col]` as new columns.

    The descriptors are computed with joblib using `n_jobs` workers. The
    returned frame is `df` with its index reset, concatenated column-wise
    with the descriptor table.
    """
    jobs = (delayed(compute_props)(smi) for smi in tqdm(df[smiles_col].tolist()))
    props_df = pd.DataFrame(Parallel(n_jobs=n_jobs)(jobs))
    base = df.reset_index(drop=True)
    return pd.concat([base, props_df], axis=1)

def compute_ligand_efficiency(df):
    """Add an `LE_<activity>` column for each activity column present in `df`.

    For every column among pKi, pKd, pEC50 and pIC50, the ligand efficiency
    is the activity divided by `HeavyAtomCount`. Rows with a missing activity,
    a missing heavy-atom count or a non-positive count get None.
    `df` is modified in place and returned.
    """
    def _efficiency(row, activity):
        # Checks are evaluated lazily, left to right.
        if (pd.notnull(row[activity])
                and pd.notnull(row['HeavyAtomCount'])
                and row['HeavyAtomCount'] > 0):
            return row[activity] / row['HeavyAtomCount']
        return None

    for col in ['pKi', 'pKd', 'pEC50', 'pIC50']:
        if col not in df.columns:
            continue
        df[f'LE_{col}'] = df.apply(_efficiency, axis=1, activity=col)
    return df

def compute_mean_ligand_efficiency(df):
    """Add an `LE` column holding the row-wise mean of the available LE_* columns.

    Only the columns among LE_pKi, LE_pKd, LE_pEC50 and LE_pIC50 that exist
    in `df` are averaged (NaN values are skipped). `compute_ligand_efficiency`
    creates a column only when its activity is present, so selecting all
    four unconditionally would raise KeyError. When none exist, `LE` is NaN.
    `df` is modified in place and returned.
    """
    le_cols = [c for c in ('LE_pKi', 'LE_pKd', 'LE_pEC50', 'LE_pIC50') if c in df.columns]
    if le_cols:
        df['LE'] = df[le_cols].mean(axis=1, skipna=True)
    else:
        df['LE'] = float('nan')
    return df

def display_top_bottom_le(df, n=20):
    """Return `(top, bottom)`: the `n` rows with the highest and lowest `LE`.

    Rows whose `LE` is missing are ignored.
    """
    ranked = df[df['LE'].notna()]
    highest = ranked.sort_values(by='LE', ascending=False).head(n)
    lowest = ranked.sort_values(by='LE', ascending=True).head(n)
    return highest, lowest

from rdkit import Chem
from rdkit.Chem import Draw

def show_molecules(smiles_list, title='Molecules'):
    """Render the given SMILES as a grid image (5 per row, 200x200 tiles).

    Each tile is captioned with its SMILES. `title` is accepted but not used.
    """
    return Draw.MolsToGridImage(
        list(map(Chem.MolFromSmiles, smiles_list)),
        molsPerRow=5,
        subImgSize=(200, 200),
        legends=smiles_list,
    )

import matplotlib.pyplot as plt

def plot_property_distributions(df):
    """Show one 50-bin histogram per property column, ignoring missing values."""
    for name in ('LogP', 'QED', 'MolWt', 'HeavyAtomCount', 'LE'):
        plt.figure()
        values = df[name].dropna()
        values.hist(bins=50)
        plt.title(f'Distribution of {name}')
        plt.xlabel(name)
        plt.ylabel('Frequency')
        plt.grid(True)
        plt.tight_layout()
        plt.show()

        
        
DisableLog('rdApp.warning')
# Keep one core free for the notebook itself, but always use at least one
# worker: Pool(0) raises ValueError on single-core machines.
N_PROC = max(1, cpu_count() - 1)

# === STANDARDIZE PROTEINS ===
def standardize_protein(pdb_path, output_path):
    """Parse `pdb_path` with Biopython and write the structure to `output_path`."""
    structure = PDBParser(QUIET=True).get_structure("prot", pdb_path)
    writer = PDBIO()
    writer.set_structure(structure)
    writer.save(output_path)




def standardize_ligand(path, output_path):
    """Standardize one ligand file and write it as a mol file.

    `path` must end in .sdf or .pdb (case-insensitive). The molecule is read
    with its hydrogens, reduced to polar hydrogens, sanitized, has its
    stereochemistry assigned and Gasteiger charges computed, then written to
    `output_path`. Coordinates are never regenerated.

    Returns True on success, False for an unsupported extension or a file
    RDKit cannot parse.
    """
    ext = os.path.splitext(path)[1].lower()
    if ext not in ('.sdf', '.pdb'):
        return False

    # removeHs=False keeps hydrogens and their coordinates.
    if ext == '.sdf':
        mol = Chem.MolFromMolFile(path, removeHs=False)
    else:
        mol = Chem.MolFromPDBFile(path, removeHs=False)
    if mol is None:
        return False

    # Enforce the polar-hydrogen-only rule.
    mol = keep_only_polar_H_rdkit(mol)

    Chem.SanitizeMol(mol)
    Chem.AssignStereochemistry(mol, cleanIt=False, force=True)
    # No embedding or minimization, so the original pose is preserved.
    # NOTE(review): confirm that MolToMolFile persists the computed charges.
    ComputeGasteigerCharges(mol)
    Chem.MolToMolFile(mol, output_path)
    return True


def _process_ligand(args):
    """Pool worker: standardize one ligand into `<dir>/<idx>.sdf`.

    `args` is `(idx, lig_path, smiles, input_sdf_subpath)`; the SMILES is not
    used. Returns the output path, or None when standardization fails.
    """
    idx, lig_path, _smiles, input_sdf_subpath = args
    target = f"{input_sdf_subpath}/{idx}.sdf"
    if standardize_ligand(lig_path, target):
        return target
    return None

def standardize_all_ligands(df, input_sdf_subpath):
    """Standardize every ligand of `df` in parallel.

    Reads the `ligand_sdf_path` and `smiles` columns, writes the outputs under
    `input_sdf_subpath` (created if missing) and stores the resulting paths
    (None on failure) in `df['standardized_ligand_sdf']`. Returns `df`.
    """
    os.makedirs(input_sdf_subpath, exist_ok=True)

    jobs = []
    for idx, row in df.iterrows():
        jobs.append((idx, row['ligand_sdf_path'], row['smiles'], input_sdf_subpath))

    with Pool(N_PROC) as pool:
        progress = tqdm(pool.imap(_process_ligand, jobs), total=len(jobs))
        new_paths = list(progress)

    df['standardized_ligand_sdf'] = new_paths
    return df



# === MOLECULAR PROPERTIES + LE ===
def compute_props(smiles):
    """Return the molecular descriptors of `smiles` as a dict.

    Keys: InChIKey, MolWt, HeavyAtomCount, QED, NumHDonors, NumHAcceptors,
    NumRotatableBonds, TPSA, LogP. Every value is None when `smiles` is not a
    non-blank string or cannot be parsed by RDKit.
    """
    usable = isinstance(smiles, str) and smiles.strip() != ''
    mol = Chem.MolFromSmiles(smiles) if usable else None
    if mol is None:
        names = ('InChIKey', 'MolWt', 'HeavyAtomCount', 'QED', 'NumHDonors',
                 'NumHAcceptors', 'NumRotatableBonds', 'TPSA', 'LogP')
        return dict.fromkeys(names)

    # (column name, descriptor function), evaluated in this order.
    descriptors = (
        ('InChIKey', Chem.MolToInchiKey),
        ('MolWt', Descriptors.MolWt),
        ('HeavyAtomCount', lambda m: m.GetNumHeavyAtoms()),
        ('QED', QED.qed),
        ('NumHDonors', Lipinski.NumHDonors),
        ('NumHAcceptors', Lipinski.NumHAcceptors),
        ('NumRotatableBonds', Lipinski.NumRotatableBonds),
        ('TPSA', rdMolDescriptors.CalcTPSA),
        ('LogP', Crippen.MolLogP),
    )
    return {name: fn(mol) for name, fn in descriptors}

def add_molecular_properties_parallel(df, smiles_col='std_smiles', n_jobs=-1):
    """Compute `compute_props` for each SMILES in `df[smiles_col]` with joblib.

    Returns a new frame: `df` with a fresh 0..n-1 index plus one column per
    descriptor.
    """
    smiles_list = df[smiles_col].tolist()
    runner = Parallel(n_jobs=n_jobs)
    task = delayed(compute_props)
    records = runner(task(smi) for smi in tqdm(smiles_list))
    return pd.concat(
        [df.reset_index(drop=True), pd.DataFrame(records)],
        axis=1,
    )

def compute_ligand_efficiency(df):
    """Add `LE_<col>` and `LEnorm_<col>` for every activity column of `df`.

    Activity columns are those whose name, stripped and lower-cased, is one
    of pki, pkd, pec50 or pic50. `LE_<col>` is the activity divided by
    `HeavyAtomCount`; `LEnorm_<col>` is `LE_<col>` divided by `MolWt`.
    A missing operand or a non-positive denominator yields None.
    `df` is modified in place and returned.
    """
    def _per_heavy_atom(row, src):
        if (pd.notnull(row[src])
                and pd.notnull(row['HeavyAtomCount'])
                and row['HeavyAtomCount'] > 0):
            return row[src] / row['HeavyAtomCount']
        return None

    def _per_mol_weight(row, le_name):
        if (pd.notnull(row.get(le_name))
                and pd.notnull(row.get('MolWt'))
                and row['MolWt'] > 0):
            return row[le_name] / row['MolWt']
        return None

    activity_cols = [
        name for name in df.columns
        if name.strip().lower() in ['pki', 'pkd', 'pec50', 'pic50']
    ]

    for col in activity_cols:
        le_col = f'LE_{col}'
        df[le_col] = df.apply(_per_heavy_atom, axis=1, src=col)
        df[f'LEnorm_{col}'] = df.apply(_per_mol_weight, axis=1, le_name=le_col)

    return df



def compute_mean_ligand_efficiency(df):
    """Add `LE` and `LE_norm`: row-wise means of the LE_* and LEnorm_* columns.

    NaN values are skipped when averaging. When no matching column exists,
    the corresponding mean column is filled with NaN.
    `df` is modified in place and returned.
    """
    # 'LE_norm' is this function's own output and also starts with "LE_".
    # It must be excluded, otherwise calling the function a second time
    # (e.g. re-running the notebook cell) would average it into 'LE'.
    le_cols = [c for c in df.columns
               if isinstance(c, str) and c.startswith("LE_") and c != "LE_norm"]
    le_norm_cols = [c for c in df.columns
                    if isinstance(c, str) and c.startswith("LEnorm_")]

    df['LE'] = df[le_cols].mean(axis=1, skipna=True) if le_cols else float('nan')
    df['LE_norm'] = (df[le_norm_cols].mean(axis=1, skipna=True)
                     if le_norm_cols else float('nan'))
    return df


# === FILTER & VISUALIZE ===
def filter_invalid_ligands(df):
    """Keep only rows with a non-null `ligand_sdf_path`, re-indexed from 0."""
    has_path = df['ligand_sdf_path'].notna()
    kept = df.loc[has_path]
    return kept.reset_index(drop=True)

def display_top_bottom_le(df, n=10):
    """Return `(top, bottom)`: the `n` rows with the highest and lowest `LE_norm`.

    Rows whose `LE_norm` is missing are excluded from both selections.
    """
    df_valid = df.loc[df['LE_norm'].notna()]
    return tuple(
        df_valid.sort_values('LE_norm', ascending=ascending).head(n)
        for ascending in (False, True)
    )

def show_molecules(smiles_list, title='Molecules'):
    """Return an RDKit grid image of the molecules parsed from `smiles_list`.

    Five 200x200 tiles per row, each labelled with its SMILES. The `title`
    argument is accepted but not used.
    """
    mols = []
    for smi in smiles_list:
        mols.append(Chem.MolFromSmiles(smi))
    grid_options = dict(molsPerRow=5, subImgSize=(200, 200), legends=smiles_list)
    return Draw.MolsToGridImage(mols, **grid_options)

    
def plot_property_distributions(df):
    """Plot and show a 50-bin histogram for each property column of `df`."""
    def _histogram(col):
        # One figure per column; missing values are dropped before binning.
        plt.figure()
        df[col].dropna().hist(bins=50)
        plt.title(f'Distribution of {col}')
        plt.xlabel(col)
        plt.ylabel('Frequency')
        plt.grid(True)
        plt.tight_layout()
        plt.show()

    columns = ['LogP', 'QED', 'MolWt', 'HeavyAtomCount', 'LE', "LE_norm"]
    for col in columns:
        _histogram(col)

        
        

def standardize_smiles_from_sdf(sdf_path):
    """Read a ligand mol/SDF file and return its standardized isomeric SMILES.

    Pipeline: RDKit cleanup, stereo assignment from the 3D structure, explicit
    hydrogens added on polar atoms (N, O, F, P, S) that still carry implicit
    ones, normalization, parent-fragment selection, tautomer canonicalization
    and isotope removal.

    Returns the canonical SMILES, or None when the file cannot be parsed or
    any step raises (the exception is printed).
    """
    from rdkit import Chem
    from rdkit.Chem import AllChem
    from rdkit.Chem.MolStandardize import rdMolStandardize
    POLAR = {7,8,9,15,16}  # N,O,F,P,S

    try:
        mol = Chem.MolFromMolFile(sdf_path, removeHs=False)
        if mol is None:
            return None
        mol = rdMolStandardize.Cleanup(mol)
        AllChem.AssignAtomChiralTagsFromStructure(mol, replaceExistingTags=False)
        Chem.AssignStereochemistry(mol, force=True, cleanIt=False)
        # RDKit atoms expose GetNumImplicitHs(); GetImplicitHCount() does not
        # exist and raised AttributeError, which made this function return None.
        targets = [a.GetIdx() for a in mol.GetAtoms()
                   if a.GetAtomicNum() in POLAR
                   and not a.GetNoImplicit()
                   and a.GetNumImplicitHs() > 0]
        if targets:
            mol = Chem.AddHs(mol, addCoords=mol.GetNumConformers() > 0, onlyOnAtoms=targets)
        mol = rdMolStandardize.Normalizer().normalize(mol)
        mol = rdMolStandardize.FragmentParent(mol)
        # mol = rdMolStandardize.Uncharger().uncharge(mol)
        mol = rdMolStandardize.TautomerEnumerator().Canonicalize(mol)
        for atom in mol.GetAtoms():
            atom.SetIsotope(0)
        return Chem.MolToSmiles(mol, isomericSmiles=True, canonical=True)
    except Exception as e:
        print(e)
        return None  # or return original path or fallback string if needed


    
    
def standardize_all_smiles(df):
    """Fill `df['std_smiles']` from the files listed in `standardized_ligand_sdf`.

    Each path is processed by `standardize_smiles_from_sdf` in a pool of
    `N_PROC` processes. Returns `df`.
    """
    sdf_paths = df['standardized_ligand_sdf']
    with Pool(N_PROC) as pool:
        progress = tqdm(pool.imap(standardize_smiles_from_sdf, sdf_paths), total=len(df))
        new_smiles = list(progress)
    df['std_smiles'] = new_smiles
    return df

from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')


import os
import shutil
import warnings
from rdkit import Chem
from multiprocessing import Pool
from tqdm import tqdm
from pdbfixer import PDBFixer
from openmm.app import PDBFile
from rdkit import Chem

# Atomic numbers of the heavy atoms treated as polar by keep_only_polar_H_rdkit:
# hydrogens bonded to any other element are removed.
POLAR_HEAVY = {7, 8, 15, 16}  # N,O,P,S

warnings.filterwarnings("ignore", category=UserWarning, module="openmm.app")


def sdf_to_pdb_block_preserve_coords(sdf_path):
    """Return the PDB block of the ligand stored in `sdf_path`, read with removeHs=True."""
    return Chem.MolToPDBBlock(Chem.MolFromMolFile(sdf_path, removeHs=True))


def write_temp_complex(protein_pdb_path, ligand_sdf_path, temp_complex_path):
    """Write a protein + ligand complex PDB to `temp_complex_path`.

    The output contains the protein's ATOM records, followed by the ligand's
    PDB block, followed by an END line.
    """
    ligand_pdb_block = sdf_to_pdb_block_preserve_coords(ligand_sdf_path)

    # Only ATOM records are taken from the protein file.
    atom_records = []
    with open(protein_pdb_path, 'r') as src:
        for record in src:
            if record.startswith('ATOM'):
                atom_records.append(record)

    with open(temp_complex_path, 'w') as dst:
        dst.write("".join(atom_records))
        dst.write(ligand_pdb_block)
        dst.write("END\n")


import uuid

def fix_complex_and_extract_protein(protein_pdb, ligand_sdf, output_protein_pdb):
    """Build a temporary protein-ligand complex, strip heterogens, save the protein.

    The complex is written to a uniquely named file under ./temp, loaded with
    PDBFixer, stripped of all heterogens (water included) and written to
    `output_protein_pdb`. The temporary file is removed even when a step fails.
    """
    os.makedirs("./temp", exist_ok=True)
    tmp_complex = f"./temp/temp_complex_{uuid.uuid4().hex}.pdb"
    try:
        write_temp_complex(protein_pdb, ligand_sdf, tmp_complex)

        fixer = PDBFixer(filename=tmp_complex)
        # The repair steps below are disabled; only heterogens are removed.
        # fixer.findMissingResidues()
        # fixer.findMissingAtoms()
        # fixer.addMissingAtoms()
        # fixer.addMissingHydrogens(pH=7.4)
        fixer.removeHeterogens(keepWater=False)

        with open(output_protein_pdb, 'w') as f:
            PDBFile.writeFile(fixer.topology, fixer.positions, f)
    finally:
        # Without this, a failure above would leave the temp file behind.
        if os.path.exists(tmp_complex):
            os.remove(tmp_complex)


def process_protein_entry(args):
    """Pool worker: clean one protein structure.

    `args` is `(protein_path, ligand_path, out_prot_path, out_lig_path)`.
    Only the protein is processed; the ligand path is passed through without
    copying the file. Returns `(out_prot_path, out_lig_path)`, or
    `(None, None)` after printing the error when cleaning fails.
    """
    protein_path, _ligand_path, out_prot_path, out_lig_path = args
    try:
        clean_structure(protein_path, out_prot_path)
    except Exception as e:
        print(f"Error on {protein_path}: {e}")
        return None, None
    return out_prot_path, out_lig_path


def standardize_all_proteins(df, protein_out_dir, ligand_out_dir=None, n_proc=4):
    """Clean every protein of `df` in parallel and record the output paths.

    Reads `protein_pdb_path` and `ligand_sdf_path`. Cleaned proteins are
    written to `<protein_out_dir>/<idx>.pdb`. Sets
    `df['standardized_protein_pdb']` and `df['standardized_ligand_sdf']`
    (both None for rows that failed) and returns `df`.

    NOTE(review): when `ligand_out_dir` is given, the recorded ligand path
    points into that directory, but the worker does not appear to write a
    file there -- confirm that the ligands are produced elsewhere.
    """
    os.makedirs(protein_out_dir, exist_ok=True)
    if ligand_out_dir:
        os.makedirs(ligand_out_dir, exist_ok=True)

    tasks = []
    for idx, row in df.iterrows():
        prot_path = row["protein_pdb_path"]
        lig_path = row["ligand_sdf_path"]
        out_prot = os.path.join(protein_out_dir, f"{idx}.pdb")
        out_lig = os.path.join(ligand_out_dir, f"{idx}.sdf") if ligand_out_dir else lig_path
        tasks.append((prot_path, lig_path, out_prot, out_lig))

    if tasks:
        with Pool(n_proc) as pool:
            results = list(tqdm(pool.imap(process_protein_entry, tasks), total=len(tasks)))
        prot_paths, lig_paths = zip(*results)
    else:
        # zip(*[]) yields nothing to unpack, so an empty frame is handled here.
        prot_paths, lig_paths = (), ()

    df["standardized_protein_pdb"] = list(prot_paths)
    df["standardized_ligand_sdf"] = list(lig_paths)
    return df

def clean_structure(pdb_path, output_path, ph=7.4, remove_water=True):
    """Repair a protein PDB with PDBFixer and keep only polar hydrogens.

    Non-standard residues are replaced, heterogens removed (water too when
    `remove_water` is True), missing atoms added and hydrogens added at pH
    `ph`. Hydrogens bonded to carbon are then deleted and the result is
    written to `output_path`.
    """
    from pdbfixer import PDBFixer
    from openmm.app import PDBFile, Modeller, element as elem

    fixer = PDBFixer(filename=pdb_path)
    fixer.findMissingResidues()
    fixer.findNonstandardResidues()
    fixer.replaceNonstandardResidues()
    fixer.removeHeterogens(keepWater=not remove_water)
    fixer.findMissingAtoms()
    fixer.addMissingAtoms()
    fixer.addMissingHydrogens(pH=ph)

    mod = Modeller(fixer.topology, fixer.positions)

    # A hydrogen bonded to a carbon is non-polar; the bond may list the two
    # atoms in either order.
    nonpolar_h = []
    for first, second in mod.topology.bonds():
        for candidate, partner in ((first, second), (second, first)):
            if candidate.element == elem.hydrogen and partner.element == elem.carbon:
                nonpolar_h.append(candidate)
                break

    if nonpolar_h:
        mod.delete(nonpolar_h)

    with open(output_path, 'w') as handle:
        PDBFile.writeFile(mod.topology, mod.positions, handle)

def sdf_to_pdb_block_preserve_coords(sdf_path):
    """Return the PDB block of the ligand in `sdf_path`, keeping only polar hydrogens."""
    ligand = keep_only_polar_H_rdkit(Chem.MolFromMolFile(sdf_path, removeHs=False))
    return Chem.MolToPDBBlock(ligand)


def keep_only_polar_H_rdkit(mol: Chem.Mol) -> Chem.Mol:
    """Return a molecule whose explicit hydrogens all sit on polar heavy atoms.

    Explicit hydrogens whose first neighbour is not in POLAR_HEAVY are
    removed, then hydrogens are added on the POLAR_HEAVY atoms that still
    have implicit ones (with coordinates when the molecule has a conformer).
    Hydrogens without any neighbour are left untouched.
    """
    # 1) remove explicit H attached to non-polar heavy atoms
    h_to_del = []
    for a in mol.GetAtoms():
        if a.GetAtomicNum() != 1:
            continue
        nbrs = a.GetNeighbors()
        if not nbrs:
            continue
        if nbrs[0].GetAtomicNum() not in POLAR_HEAVY:
            h_to_del.append(a.GetIdx())
    em = Chem.EditableMol(mol)
    # Remove from the highest index down so pending indices stay valid.
    for idx in sorted(h_to_del, reverse=True):
        em.RemoveAtom(idx)
    mol = em.GetMol()

    # 2) add ONLY missing H on polar atoms (don't touch existing H)
    mol.UpdatePropertyCache(strict=False)
    # GetNumImplicitHs() is the RDKit accessor; GetImplicitHCount() does not
    # exist on Chem.Atom and raised AttributeError.
    targets = [a.GetIdx() for a in mol.GetAtoms()
               if a.GetAtomicNum() in POLAR_HEAVY and a.GetNumImplicitHs() > 0]
    if targets:
        mol = Chem.AddHs(mol,
                         addCoords=(mol.GetNumConformers() > 0),
                         onlyOnAtoms=targets)
    return mol


from rdkit import Chem

# Atomic numbers of the elements whose hydrogens are kept (used by
# _num_missing_polar_H and keep_only_polar_H_rdkit).
POLAR_HEAVY = {7, 8, 15, 16}  # N,O,P,S

def _num_missing_polar_H(a: Chem.Atom) -> int:
    """Return the number of implicit hydrogens on `a` if it is a polar heavy atom, else 0."""
    if a.GetAtomicNum() in POLAR_HEAVY:
        try:
            implicit = a.GetNumImplicitHs()
        except AttributeError:
            # Fallback when the accessor above is unavailable.
            implicit = a.GetTotalNumHs() - a.GetNumExplicitHs()
        return implicit if implicit > 0 else 0
    return 0

def keep_only_polar_H_rdkit(mol: Chem.Mol) -> Chem.Mol:
    """Strip non-polar explicit hydrogens and add the missing polar ones.

    A hydrogen is removed when its first neighbour is not a POLAR_HEAVY
    element. Hydrogens are then added on the polar atoms reported by
    `_num_missing_polar_H`, with coordinates when a conformer exists.
    """
    def _is_nonpolar_h(atom):
        if atom.GetAtomicNum() != 1:
            return False
        neighbours = atom.GetNeighbors()
        return bool(neighbours) and neighbours[0].GetAtomicNum() not in POLAR_HEAVY

    # Step 1: drop explicit hydrogens bound to non-polar heavy atoms.
    doomed = [atom.GetIdx() for atom in mol.GetAtoms() if _is_nonpolar_h(atom)]
    if doomed:
        editable = Chem.EditableMol(mol)
        # Highest index first, so earlier removals do not shift later ones.
        for atom_idx in sorted(doomed, reverse=True):
            editable.RemoveAtom(atom_idx)
        mol = editable.GetMol()

    # Step 2: add only the hydrogens still missing on polar atoms.
    mol.UpdatePropertyCache(strict=False)
    targets = []
    for atom in mol.GetAtoms():
        if _num_missing_polar_H(atom) > 0:
            targets.append(atom.GetIdx())
    if targets:
        has_coords = mol.GetNumConformers() > 0
        mol = Chem.AddHs(mol, addCoords=has_coords, onlyOnAtoms=targets)
    return mol


def standardize_ligand(path, output_path):
    """Standardize a ligand (.sdf or .pdb) and write it to `output_path`.

    The molecule is reduced to polar hydrogens, sanitized, has its
    stereochemistry assigned, is embedded only when it has no conformer, and
    gets Gasteiger charges before being written as a mol file.
    Returns True on success, False when the extension is unsupported or the
    file cannot be parsed.
    """
    ext = os.path.splitext(path)[1].lower()
    if ext == '.sdf':
        mol = Chem.MolFromMolFile(path, removeHs=False)
    elif ext == '.pdb':
        mol = Chem.MolFromPDBFile(path, removeHs=False)
    else:
        mol = None
    if mol is None:
        return False

    # Enforce the polar-hydrogen-only rule.
    mol = keep_only_polar_H_rdkit(mol)

    Chem.SanitizeMol(mol)
    Chem.AssignStereochemistry(mol, cleanIt=False, force=True)
    has_coords = mol.GetNumConformers() != 0
    if not has_coords:
        # Coordinates are generated only when the input carried none.
        AllChem.EmbedMolecule(mol, randomSeed=42)
    ComputeGasteigerCharges(mol)
    Chem.MolToMolFile(mol, output_path)
    return True

def standardize_smiles_from_sdf(sdf_path):
    """Read a ligand mol/SDF file and return its standardized isomeric SMILES.

    Steps: drop hydrogens bound to non-polar atoms, assign stereo from the 3D
    structure, add hydrogens on polar atoms, run the rdMolStandardize pipeline
    (cleanup, normalize, parent fragment, canonical tautomer), clear isotopes
    and re-assign stereochemistry.

    Returns None when the file cannot be parsed or any step raises; the
    exception is printed.
    """
    from rdkit import Chem
    from rdkit.Chem import AllChem
    from rdkit.Chem.MolStandardize import rdMolStandardize

    POLAR = {7, 8, 15, 16}  # N,O,P,S

    try:
        mol = Chem.MolFromMolFile(sdf_path, removeHs=False)
        if mol is None:
            return None

        # keep existing polar H; drop non-polar H
        # (a hydrogen is classified by its first neighbour only)
        to_del = []
        for a in mol.GetAtoms():
            if a.GetAtomicNum() != 1:
                continue
            nbs = a.GetNeighbors()
            if nbs and nbs[0].GetAtomicNum() not in POLAR:
                to_del.append(a.GetIdx())
        if to_del:
            em = Chem.EditableMol(mol)
            # remove from the highest index down so remaining indices stay valid
            for idx in sorted(to_del, reverse=True):
                em.RemoveAtom(idx)
            mol = em.GetMol()

        # refresh valence information after editing, without strict checks
        mol.UpdatePropertyCache(strict=False)
        AllChem.AssignAtomChiralTagsFromStructure(mol, replaceExistingTags=False)
        Chem.AssignStereochemistry(mol, force=True, cleanIt=False)

        # add ONLY missing H on polar atoms
        # NOTE(review): `onlyOnPolarAtoms` does not appear to be a documented
        # Chem.AddHs argument; the TypeError fallback below presumably is the
        # path that actually runs -- confirm against the installed RDKit.
        try:
            mol = Chem.AddHs(mol, addCoords=False, onlyOnPolarAtoms=True)
        except TypeError:
            targets = [a.GetIdx() for a in mol.GetAtoms()
                       if a.GetAtomicNum() in POLAR and a.GetNumImplicitHs() > 0]
            if targets:
                mol = Chem.AddHs(mol, addCoords=False, onlyOnAtoms=targets)

        # standardize
        mol = rdMolStandardize.Cleanup(mol)
        mol = rdMolStandardize.Normalizer().normalize(mol)
        mol = rdMolStandardize.FragmentParent(mol)
        mol = rdMolStandardize.TautomerEnumerator().Canonicalize(mol)

        # clear isotope labels, then recompute stereo on the final molecule
        for atom in mol.GetAtoms():
            atom.SetIsotope(0)
        Chem.AssignStereochemistry(mol, force=True, cleanIt=True)

        return Chem.MolToSmiles(mol, isomericSmiles=True, canonical=True)
    except Exception as e:
        # best effort: report the error and signal failure with None
        print(e)
        return None


In [None]:
# === PARALLELIZED STANDARDIZE LIGANDS ===
def standardize_ligand_worker(args):
    """Pool worker: standardize one ligand file.

    `args` is `(input_path, output_path)`. Returns `output_path` on success,
    or None when the extension is unsupported, the file cannot be parsed, or
    any step raises (the error is printed).
    """
    input_path, output_path = args
    from rdkit import Chem
    from rdkit.Chem import AllChem
    from rdkit.Chem.rdPartialCharges import ComputeGasteigerCharges
    import os

    try:
        ext = os.path.splitext(input_path)[1].lower()
        if ext == '.sdf':
            mol = Chem.MolFromMolFile(input_path, removeHs=False)
        elif ext == '.pdb':
            mol = Chem.MolFromPDBFile(input_path, removeHs=False)
        else:
            return None
        if mol is None:
            return None

        # Enforce the polar-hydrogen-only rule before sanitizing.
        mol = keep_only_polar_H_rdkit(mol)
        Chem.SanitizeMol(mol)
        Chem.AssignStereochemistry(mol, cleanIt=False, force=True)
        needs_coords = mol.GetNumConformers() == 0
        if needs_coords:
            # Embed only when the input carried no coordinates.
            AllChem.EmbedMolecule(mol, randomSeed=42)
        ComputeGasteigerCharges(mol)
        Chem.MolToMolFile(mol, output_path)
    except Exception as e:
        print(f"Error processing {input_path}: {e}")
        return None
    return output_path

def standardize_ligand(path, output_path):
    """Standardize a single ligand; returns True on success, False otherwise.

    Thin wrapper around `standardize_ligand_worker`.
    """
    return standardize_ligand_worker((path, output_path)) is not None

def standardize_all_ligands(df, ligand_out_dir):
    """Standardize every ligand of `df` in parallel.

    Each `ligand_sdf_path` is written to `<ligand_out_dir>/<idx>.sdf`. The
    output directory is created if missing.

    Returns the list of output paths in row order, with None for ligands
    that failed. `df` itself is not modified.
    """
    # Ensure the target directory exists so the function does not depend on
    # an earlier cell having created it.
    os.makedirs(ligand_out_dir, exist_ok=True)

    args_list = [
        (row["ligand_sdf_path"], os.path.join(ligand_out_dir, f"{idx}.sdf"))
        for idx, row in df.iterrows()
    ]
    if not args_list:
        return []

    with Pool(N_PROC) as pool:
        ligand_out_paths = list(tqdm(
            pool.imap(standardize_ligand_worker, args_list),
            total=len(args_list),
            desc="Standardizing Ligands"
        ))

    return ligand_out_paths

# === PARALLELIZED SMILES STANDARDIZATION (IMPROVED) ===
def standardize_smiles_from_sdf(sdf_path):
    """Worker function for parallel SMILES standardization.

    Reads the ligand in `sdf_path`, drops hydrogens bound to non-polar atoms,
    assigns stereo from the 3D structure, adds the missing hydrogens on polar
    atoms (N, O, P, S), runs the rdMolStandardize pipeline, clears isotopes
    and returns the canonical isomeric SMILES.

    Returns None when `sdf_path` is None or the file cannot be parsed.
    Exceptions raised by RDKit during processing are not caught here.
    """
    # Checked before the imports so a missing path is rejected cheaply.
    if sdf_path is None:
        return None

    from rdkit import Chem
    from rdkit.Chem import AllChem
    from rdkit.Chem.MolStandardize import rdMolStandardize

    POLAR = {7, 8, 15, 16}  # N,O,P,S

    mol = Chem.MolFromMolFile(sdf_path, removeHs=False)
    if mol is None:
        return None

    # keep existing polar H; drop non-polar H
    to_del = []
    for a in mol.GetAtoms():
        if a.GetAtomicNum() != 1:
            continue
        nbs = a.GetNeighbors()
        if nbs and nbs[0].GetAtomicNum() not in POLAR:
            to_del.append(a.GetIdx())

    if to_del:
        em = Chem.EditableMol(mol)
        # highest index first so the remaining indices stay valid
        for idx in sorted(to_del, reverse=True):
            em.RemoveAtom(idx)
        mol = em.GetMol()

    mol.UpdatePropertyCache(strict=False)
    AllChem.AssignAtomChiralTagsFromStructure(mol, replaceExistingTags=False)
    Chem.AssignStereochemistry(mol, force=True, cleanIt=False)

    # add ONLY missing H on polar atoms.
    # Chem.AddHs has no `onlyOnPolarAtoms` argument (passing it raises), so
    # the polar atoms are selected explicitly through `onlyOnAtoms`.
    targets = [a.GetIdx() for a in mol.GetAtoms()
               if a.GetAtomicNum() in POLAR and a.GetNumImplicitHs() > 0]
    if targets:
        mol = Chem.AddHs(mol, addCoords=False, onlyOnAtoms=targets)

    # standardize
    mol = rdMolStandardize.Cleanup(mol)
    mol = rdMolStandardize.Normalizer().normalize(mol)
    mol = rdMolStandardize.FragmentParent(mol)
    mol = rdMolStandardize.TautomerEnumerator().Canonicalize(mol)

    for atom in mol.GetAtoms():
        atom.SetIsotope(0)

    Chem.AssignStereochemistry(mol, force=True, cleanIt=True)
    return Chem.MolToSmiles(mol, isomericSmiles=True, canonical=True)

def standardize_all_smiles(df):
    """Compute `df['std_smiles']` in parallel from `df['standardized_ligand_sdf']`.

    Uses a pool of `N_PROC` processes running `standardize_smiles_from_sdf`.
    Returns `df`.
    """
    paths = df['standardized_ligand_sdf']
    with Pool(N_PROC) as pool:
        ordered_results = pool.imap(standardize_smiles_from_sdf, paths)
        progress = tqdm(ordered_results, total=len(df), desc="Standardizing SMILES")
        df['std_smiles'] = list(progress)
    return df



# === HELPER FUNCTION (needs to be defined elsewhere in your code) ===
def keep_only_polar_H_rdkit(mol):
    """Remove explicit hydrogens that are not bonded to N, O, P or S.

    A hydrogen is classified by its first neighbour; hydrogens without a
    neighbour are kept. This definition only removes atoms -- it does not add
    missing hydrogens. The input molecule is returned unchanged when nothing
    has to be removed.
    """
    POLAR = {7, 8, 15, 16}  # N,O,P,S

    def _is_nonpolar_h(atom):
        if atom.GetAtomicNum() != 1:
            return False
        neighbours = atom.GetNeighbors()
        return bool(neighbours) and neighbours[0].GetAtomicNum() not in POLAR

    doomed = [atom.GetIdx() for atom in mol.GetAtoms() if _is_nonpolar_h(atom)]
    if not doomed:
        return mol

    editable = Chem.EditableMol(mol)
    # Delete from the highest index down so pending indices remain valid.
    for atom_idx in sorted(doomed, reverse=True):
        editable.RemoveAtom(atom_idx)
    return editable.GetMol()


def _standardize_protein_worker(args):
    idx, in_path, out_dir = args
    try:
        out_path = os.path.join(out_dir, f"{idx}.pdb")
        if os.path.exists(out_path) == False:
            clean_structure(in_path, out_path)
        return (out_path, idx)
    except:
        return (None, idx)

# Input data

In [None]:
# Curated input table (parquet) with one row per protein-ligand entry.
experimental_data_path = "../data/curated/combined/df_combined.parquet"
# NOTE(review): the two sub-path prefixes below are not referenced by the
# cells visible in this notebook -- confirm whether they are still needed.
input_pdb_subpath = 'prepared/pdb_protein/protein_'
input_sdf_subpath = 'prepared/sdf_ligand/ligand_'
# Destination of the final prepared table.
output_path = '../prepared_df.parquet'
# Switches: run the ligand / protein standardization steps, or reuse the
# original file paths when False.
lig_STD = True
prot_STD = True

df = pd.read_parquet(experimental_data_path)
# df = df[df["is_experimental"]==True]
print(len(df))

# === SET OUTPUT DIRS ===
protein_out_dir = "../data/standardized_clean/protein"
ligand_out_dir = "../data/standardized_clean/ligand"
os.makedirs(protein_out_dir, exist_ok=True)
os.makedirs(ligand_out_dir, exist_ok=True)

# Standardize Protein

In [None]:
import os
from tqdm import tqdm
from multiprocessing import Pool, cpu_count





# === Main code ===
if prot_STD:
    os.makedirs(protein_out_dir, exist_ok=True)

    # Use a positional 0..n-1 index so each worker's idx addresses `results` directly.
    df = df.reset_index(drop=True)

    # Tasks are submitted in reverse row order; every result is stored at its own idx.
    args_list = [(idx, row["protein_pdb_path"], protein_out_dir) for idx, row in df.iterrows()][::-1]
    results = [None] * len(df)

    # Always start at least one worker: cpu_count() - 1 is 0 on a single-core
    # machine and Pool(0) raises ValueError.
    with Pool(max(1, cpu_count() - 1)) as pool:
        # The worker skips structures whose output file already exists.
        for out_path, idx in tqdm(pool.imap_unordered(_standardize_protein_worker, args_list), total=len(args_list)):
            results[idx] = out_path

    df["standardized_protein_pdb"] = results
else:
    df["standardized_protein_pdb"] = df["protein_pdb_path"].tolist()


Timing notes: 12k structures ≈ 2 h,
i.e. 6k structures ≈ 1 h;
500k structures / 6k per hour ≈ 83 h, i.e. more than 3 days.

# Standardize Ligand

In [None]:

# === USAGE FOR LIGANDS ===
# Standardize the ligands when enabled; otherwise reuse the original SDF paths.
if lig_STD:
    # Returns one output path per row (None where standardization failed).
    ligand_out_paths = standardize_all_ligands(df, ligand_out_dir)
    df["standardized_ligand_sdf"] = ligand_out_paths
else:
    df["standardized_ligand_sdf"] = df["ligand_sdf_path"].tolist()



# Standardize SMILES

In [None]:
import os
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.MolStandardize import rdMolStandardize
from multiprocessing import Pool, cpu_count
from tqdm import tqdm

def validate_mol_safe(mol):
    """Return True when `mol` is a non-empty molecule with no negative implicit valence.

    None, a molecule without atoms, or any atom reporting an implicit valence
    below zero makes the check fail.
    """
    if mol is None or mol.GetNumAtoms() == 0:
        return False
    return all(atom.GetImplicitValence() >= 0 for atom in mol.GetAtoms())

def standardize_smiles_from_sdf_safe(sdf_path):
    """Worker: standardized isomeric SMILES of the ligand in `sdf_path`, or None.

    The molecule is read unsanitized, sanitized with error capture (with one
    cleanup-and-retry attempt), reduced to polar hydrogens, stereo-assigned,
    completed with the missing polar hydrogens, and passed through cleanup,
    normalization and parent-fragment selection. Isotopes are cleared and the
    canonical SMILES is returned. Tautomer canonicalization is disabled.

    Each failed stage prints a "fail..." message with the path and returns
    None. A `sdf_path` that is not a string or path-like object (None, NaN)
    returns None without a message.
    """
    POLAR = {7, 8, 15, 16}  # N,O,P,S

    # Validate input: a missing path (None / NaN) cannot be parsed.
    if not isinstance(sdf_path, (str, os.PathLike)):
        return None

    # Read molecule without sanitization to avoid immediate errors
    mol = Chem.MolFromMolFile(sdf_path, removeHs=False, sanitize=False)
    if not validate_mol_safe(mol):
        print("fail...Chem.MolFromMolFile(sdf_path, removeHs= ", sdf_path)
        return None

    # Attempt sanitization - this returns a status code instead of raising
    sanitize_result = Chem.SanitizeMol(mol, catchErrors=True)
    if sanitize_result != Chem.SanitizeFlags.SANITIZE_NONE:
        # Sanitization failed, try cleanup first.
        # Cleanup() has no `catchErrors` argument (passing it raises), so a
        # failing cleanup is caught here and reported like any other stage.
        try:
            mol_cleaned = rdMolStandardize.Cleanup(mol)
        except Exception:
            mol_cleaned = None
        if not validate_mol_safe(mol_cleaned):
            print("fail....SanitizeMol(mol, catchErrors= ", sdf_path)
            return None
        mol = mol_cleaned
        # Try sanitization again
        sanitize_result = Chem.SanitizeMol(mol, catchErrors=True)
        if sanitize_result != Chem.SanitizeFlags.SANITIZE_NONE:
            print("fail... ", sdf_path)
            return None

    # Remove non-polar hydrogens
    to_del = []
    for a in mol.GetAtoms():
        if a.GetAtomicNum() != 1:
            continue
        nbs = a.GetNeighbors()
        if len(nbs) > 0 and nbs[0].GetAtomicNum() not in POLAR:
            to_del.append(a.GetIdx())

    if to_del:
        em = Chem.EditableMol(mol)
        # Highest index first so the remaining indices stay valid.
        for idx in sorted(to_del, reverse=True):
            em.RemoveAtom(idx)
        mol = em.GetMol()
        if not validate_mol_safe(mol):
            print("fail... em = Chem.EditableMol(mol)", sdf_path)
            return None

    # Update property cache
    mol.UpdatePropertyCache(strict=False)

    # Assign stereochemistry from the 3D structure
    AllChem.AssignAtomChiralTagsFromStructure(mol, replaceExistingTags=False)
    Chem.AssignStereochemistry(mol, force=True, cleanIt=False)

    # Add hydrogens on polar atoms that still carry implicit ones
    targets = []
    for a in mol.GetAtoms():
        if a.GetAtomicNum() in POLAR and a.GetNumImplicitHs() > 0:
            targets.append(a.GetIdx())

    if targets:
        mol = Chem.AddHs(mol, addCoords=False, onlyOnAtoms=targets)
        if not validate_mol_safe(mol):
            print("fail... ol = Chem.AddHs(mol, addCoords=False, onlyOn", sdf_path)
            return None

    # Standardization pipeline - each step validates the molecule
    mol = rdMolStandardize.Cleanup(mol)
    if not validate_mol_safe(mol):
        print("fail... = rdMolStandardize.Cleanup ", sdf_path)
        return None

    normalizer = rdMolStandardize.Normalizer()
    mol = normalizer.normalize(mol)
    if not validate_mol_safe(mol):
        print("fail... alizer = rdMolStandardize.Norm", sdf_path)
        return None

    mol = rdMolStandardize.FragmentParent(mol)
    if not validate_mol_safe(mol):
        print("fail... MolStandardize.FragmentParent(m", sdf_path)
        return None

    # Tautomer canonicalization is disabled:
    # tautomer_enumerator = rdMolStandardize.TautomerEnumerator()
    # mol = tautomer_enumerator.Canonicalize(mol)
    # if not validate_mol_safe(mol):
    #     return None

    # Clear isotopes
    for atom in mol.GetAtoms():
        atom.SetIsotope(0)

    # Final stereochemistry
    Chem.AssignStereochemistry(mol, force=True, cleanIt=True)

    # Generate SMILES; an empty string is reported as a failure
    smiles = Chem.MolToSmiles(mol, isomericSmiles=True, canonical=True)
    if smiles:
        return smiles
    else:
        print("fail... m.MolToSmiles(mol, isomericSmiles=True, ca", sdf_path)
        return None


def process_smiles_chunk(chunk_data):
    """Standardize one chunk of SDF paths sequentially inside a single process.

    `chunk_data` is `(chunk_idx, sdf_paths)`. Returns `(chunk_idx, results)`
    with one entry per path, in the same order.
    """
    chunk_idx, sdf_paths = chunk_data
    return chunk_idx, [standardize_smiles_from_sdf_safe(p) for p in sdf_paths]


def standardize_all_smiles_chunked(df, chunk_size=1000):
    """Standardize SMILES in parallel, `chunk_size` ligands per task.

    The paths in `df['standardized_ligand_sdf']` are split into consecutive
    chunks; each worker processes a whole chunk sequentially. Chunks may
    finish in any order and are reassembled by chunk index. The result is
    stored in `df['std_smiles']` and `df` is returned.

    Raises:
        ValueError: if `chunk_size` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    # At least one worker: cpu_count() - 1 is 0 on a single-core machine and
    # Pool(0) raises ValueError.
    n_proc = max(1, cpu_count() - 1)

    sdf_paths = df['standardized_ligand_sdf'].tolist()

    # (chunk index, paths of that chunk)
    chunks = [
        (start // chunk_size, sdf_paths[start:start + chunk_size])
        for start in range(0, len(sdf_paths), chunk_size)
    ]

    # Process chunks in parallel; no pool is started when there is no work.
    results_dict = {}
    if chunks:
        with Pool(n_proc) as pool:
            for chunk_idx, chunk_results in tqdm(
                pool.imap_unordered(process_smiles_chunk, chunks),
                total=len(chunks),
                desc="Processing SMILES chunks"
            ):
                results_dict[chunk_idx] = chunk_results

    # Reassemble results in the original row order
    all_smiles = []
    for i in range(len(chunks)):
        all_smiles.extend(results_dict[i])

    df['std_smiles'] = all_smiles
    return df


# Alternative: one task per ligand, results collected in input order
def standardize_all_smiles_simple(df):
    """Standardize SMILES in parallel, one ligand per task.

    Every path in `df['standardized_ligand_sdf']` is sent to
    `standardize_smiles_from_sdf_safe` through `Pool.imap`, which returns the
    results in input order and allows a progress bar. The result is stored in
    `df['std_smiles']` and `df` is returned.
    """
    # At least one worker: cpu_count() - 1 is 0 on a single-core machine and
    # Pool(0) raises ValueError.
    n_proc = max(1, cpu_count() - 1)

    sdf_paths = df['standardized_ligand_sdf'].tolist()

    new_smiles = []
    if sdf_paths:
        with Pool(n_proc) as pool:
            # imap keeps the input order and lets tqdm track progress
            new_smiles = list(tqdm(
                pool.imap(standardize_smiles_from_sdf_safe, sdf_paths),
                total=len(sdf_paths),
                desc="Standardizing SMILES"
            ))

    df['std_smiles'] = new_smiles
    return df


# Sequential version as ultimate fallback
def standardize_all_smiles_sequential(df):
    """
    Standardize the ligand SMILES one SDF path at a time in this process.

    No worker pool is involved; each path of ``df['standardized_ligand_sdf']``
    goes through ``standardize_smiles_from_sdf_safe`` in row order and the
    results are stored in ``df['std_smiles']``. Returns the same ``df``.
    """
    from tqdm import tqdm

    progress = tqdm(df['standardized_ligand_sdf'], desc="Standardizing SMILES")
    df['std_smiles'] = [standardize_smiles_from_sdf_safe(path) for path in progress]
    return df

In [None]:
# Keep only the rows that have both a standardized protein and ligand file.
has_structures = (
    df["standardized_protein_pdb"].notna()
    & df["standardized_ligand_sdf"].notna()
)
# Explicit copy: standardize_all_smiles_simple assigns a new column on the
# frame it receives, which would otherwise write into a slice of `df`.
df_in = df[has_structures].copy()

# === USAGE FOR SMILES ===
df_in = standardize_all_smiles_simple(df_in)
df_in

In [None]:
# Discard the rows whose std_smiles is null
df_in = df_in[df_in["std_smiles"].notna()]
len(df_in)

# Save

In [None]:
# === SAVE TO DISK ===
# Checkpoint df_in as a parquet file, without writing the index
output_path = "../prepared_df.parquet"
df_in.to_parquet(output_path, index=False)

In [None]:
# Reload the parquet checkpoint into `df`
output_path = "../prepared_df.parquet"
df = pd.read_parquet(output_path)

# Missing files?

In [None]:
def _path_missing(path):
    """Return True when *path* is null, empty, or not present on disk."""
    # Null cells may come back as None or float NaN; os.path.exists raises
    # TypeError on a float, so anything that is not path-like counts as missing.
    if not isinstance(path, (str, bytes, os.PathLike)):
        return True
    return not path or not os.path.exists(path)


# Count the rows whose structure files cannot be found on disk
missing_prot = df['standardized_protein_pdb'].apply(_path_missing).sum()
missing_lig = df['standardized_ligand_sdf'].apply(_path_missing).sum()
print("Missing protein:", missing_prot)
print("Missing ligand:", missing_lig)
print(len(df))
# Keep only the rows that have a standardized SMILES
df = df[df["std_smiles"].notna()]
len(df)

In [None]:
# Report again the missing-file counts computed in the previous cell
for label, n_missing in (("Missing protein:", missing_prot),
                         ("Missing ligand:", missing_lig)):
    print(label, n_missing)


# Compute properties

In [None]:
# Run the three property steps in sequence, each receiving the previous frame
df = (
    df.pipe(add_molecular_properties_parallel)
      .pipe(compute_ligand_efficiency)
      .pipe(compute_mean_ligand_efficiency)
)


In [None]:
# Get the top and bottom molecules by LE from the unfiltered `df`.
# NOTE(review): the names say 5 while the display cells are titled
# 'Top 20 LE' / 'Bottom 20 LE' — the actual count is decided inside
# display_top_bottom_le; confirm.
top5, bottom5 = display_top_bottom_le(df)

# Plot

In [None]:
# Number of rows currently in df
len(df)

In [None]:
# Rows per originating source file
df["source_file"].value_counts()

In [None]:
# Plot the property distributions of `df` (helper defined outside this section)
plot_property_distributions(df)

In [None]:
# Draw the molecules held in `top5` under the title 'Top 20 LE'.
# NOTE(review): variable name (5) and title (20) disagree — confirm how many rows top5 holds.
show_molecules(top5['std_smiles'].tolist(), title='Top 20 LE')

In [None]:
# Draw the molecules held in `bottom5` under the title 'Bottom 20 LE'.
# NOTE(review): variable name (5) and title (20) disagree — confirm how many rows bottom5 holds.
show_molecules(bottom5['std_smiles'].tolist(), title='Bottom 20 LE')

# Atom count identification

In [None]:
# Show every standardized SMILES as a plain Python list
df['std_smiles'].tolist()

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter
import re


# Atoms written in square brackets are captured whole (group 1). Bare atoms
# belong to the SMILES "organic subset" (group 2); the two-letter halogens
# come first in the alternation so "Cl" / "Br" are never split into C / B.
_SMILES_ATOM_TOKEN = re.compile(r'\[([^\]]*)\]|(Cl|Br|[BCNOPSFI]|[bcnops])')
# Inside brackets: optional isotope digits, then the element symbol.
# Aromatic symbols are lowercase; the two-letter ones are tried first.
_BRACKET_SYMBOL = re.compile(r'\d*([A-Z][a-z]?|se|si|as|te|[bcnops])')
# Elements that are reported; everything else (H, metals, ...) is ignored.
_COUNTED_HEAVY_ATOMS = frozenset(
    ['C', 'N', 'O', 'S', 'P', 'F', 'Cl', 'Br', 'I', 'B', 'Si', 'Se']
)


def count_atoms_in_smiles(smiles):
    """
    Count the heavy atoms of a SMILES string, per element.

    The string is tokenised atom by atom instead of counting substrings, so a
    bracket atom such as ``[Na+]``, ``[Sn]`` or ``[Sc]`` is read as one
    element and is not mistaken for N, S or C. Aromatic (lowercase) atoms are
    counted under their element symbol.

    Args:
        smiles: SMILES string. Any non-string value (``None``, NaN) yields an
            empty result instead of raising.

    Returns:
        ``collections.Counter`` mapping element symbol to number of atoms,
        restricted to C, N, O, S, P, F, Cl, Br, I, B, Si and Se.
    """
    atom_counts = Counter()
    if not isinstance(smiles, str):
        return atom_counts

    for token in _SMILES_ATOM_TOKEN.finditer(smiles):
        bracket_body, bare_symbol = token.groups()
        if bracket_body is None:
            symbol = bare_symbol
        else:
            inner = _BRACKET_SYMBOL.match(bracket_body)
            if inner is None:
                # No element symbol, e.g. the wildcard atom "[*]"
                continue
            symbol = inner.group(1)

        # Aromatic "c" / "se" -> "C" / "Se"; regular symbols are unchanged
        symbol = symbol.capitalize()
        if symbol in _COUNTED_HEAVY_ATOMS:
            atom_counts[symbol] += 1

    return atom_counts

# Bracket atoms are captured whole (group 1); a bare "Cl" is matched before
# "C" so the chlorine symbol is consumed instead of being read as a carbon.
_CARBON_SCAN = re.compile(r'\[([^\]]*)\]|(Cl|C|c)')
# A bracket atom is carbon when, after optional isotope digits, its symbol is
# "c" or a "C" that is not followed by a lowercase letter (Ca, Cl, Co, Cu, ...).
_BRACKET_CARBON = re.compile(r'\d*(?:C(?![a-z])|c)')


def count_carbon_atoms(smiles):
    """
    Count the carbon atoms (aliphatic and aromatic) in a SMILES string.

    Counting the letters 'C' and 'c' directly also counts every chlorine
    ("Cl") and any bracket element containing those letters (Ca, Cu, Sc, ...),
    so the string is scanned atom token by atom token instead.

    Args:
        smiles: SMILES string.

    Returns:
        Number of carbon atoms as an ``int``.
    """
    n_carbon = 0
    for token in _CARBON_SCAN.finditer(smiles):
        bracket_body, bare_symbol = token.groups()
        if bracket_body is not None:
            if _BRACKET_CARBON.match(bracket_body):
                n_carbon += 1
        elif bare_symbol != 'Cl':
            n_carbon += 1
    return n_carbon

# One element tally (Counter) per SMILES, in row order
all_atom_counts = [count_atoms_in_smiles(smi) for smi in df['std_smiles']]

# For each element, the number of SMILES in which it occurs at least once
atom_presence_count = Counter()
for per_molecule in all_atom_counts:
    atom_presence_count.update(per_molecule.keys())

# Flag the SMILES with at most one carbon
df['carbon_le_1'] = df['std_smiles'].apply(count_carbon_atoms) <= 1

# Preview of the annotated frame
print("DataFrame with carbon ≤ 1 annotation:")
print(df[['std_smiles', 'carbon_le_1']].head())
print(f"\nTotal SMILES with ≤1 carbon: {df['carbon_le_1'].sum()}")

# Element presence table, most frequent element first
print("\n" + "="*50)
print("Number of SMILES containing each atom type:")
print("="*50)
for atom, count in atom_presence_count.most_common():
    print(f"{atom}: {count:3d} SMILES ({count/len(df)*100:.1f}%)")


In [None]:

# Bar chart: number of SMILES in which each element occurs
plt.figure(figsize=(12, 6))
atoms = list(atom_presence_count.keys())
counts = list(atom_presence_count.values())

# Most frequent element first
sorted_items = sorted(zip(atoms, counts), key=lambda pair: pair[1], reverse=True)
atoms_sorted = [symbol for symbol, _ in sorted_items]
counts_sorted = [n_smiles for _, n_smiles in sorted_items]

bars = plt.bar(
    atoms_sorted,
    counts_sorted,
    color='steelblue',
    edgecolor='black',
    linewidth=1.5,
)

# Write each bar's value just above it
for bar, count in zip(bars, counts_sorted):
    label_x = bar.get_x() + bar.get_width()/2
    label_y = bar.get_height() + 0.1
    plt.text(label_x, label_y, str(count),
             ha='center', va='bottom', fontsize=10, fontweight='bold')

plt.xlabel('Atom Type', fontsize=12, fontweight='bold')
plt.ylabel('Number of SMILES', fontsize=12, fontweight='bold')
plt.title('Heavy Atoms Distribution in SMILES Dataset', fontsize=14, fontweight='bold')
plt.grid(axis='y', alpha=0.3, linestyle='--')
# 10% headroom above the tallest bar
plt.ylim(0, max(counts_sorted) * 1.1)

# Dashed reference line at the total number of SMILES
plt.axhline(y=len(df), color='red', linestyle='--', alpha=0.5, label=f'Total SMILES: {len(df)}')
plt.legend()

plt.tight_layout()
plt.show()


In [None]:

# Summary statistics of the per-SMILES carbon count
carbon_counts = df['std_smiles'].map(count_carbon_atoms)
print("\n" + "="*50)
print("Carbon atom distribution:")
print("="*50)
for summary_line in (
    f"Min carbons: {carbon_counts.min()}",
    f"Max carbons: {carbon_counts.max()}",
    f"Mean carbons: {carbon_counts.mean():.1f}",
    f"Median carbons: {carbon_counts.median():.1f}",
):
    print(summary_line)


In [None]:

# Histogram of the carbon count per SMILES, one bin per integer value
plt.figure(figsize=(10, 6))
fewest_carbons = carbon_counts.min()
most_carbons = carbon_counts.max()
plt.hist(
    carbon_counts,
    bins=range(fewest_carbons, most_carbons + 2),
    color='coral',
    edgecolor='black',
    linewidth=1.2,
)
plt.xlabel('Number of Carbon Atoms', fontsize=12, fontweight='bold')
plt.ylabel('Number of SMILES', fontsize=12, fontweight='bold')
plt.title('Distribution of Carbon Atoms per SMILES', fontsize=14, fontweight='bold')
plt.grid(axis='y', alpha=0.3, linestyle='--')
# One tick per integer carbon count
plt.xticks(range(fewest_carbons, most_carbons + 1))
plt.tight_layout()
plt.show()

# Filter

In [None]:
# TODO: consider removing Si- and Se-containing ligands; decide once the full-size dataset has been inspected


In [None]:
# Recompute the per-SMILES element tallies for the current df
all_atom_counts = [count_atoms_in_smiles(smi) for smi in df['std_smiles']]

# For each element, the number of SMILES in which it occurs at least once
atom_presence_count = Counter()
for per_molecule in all_atom_counts:
    atom_presence_count.update(per_molecule.keys())

# Flag the SMILES with fewer than 4 carbons (i.e. at most 3).
# NOTE(review): the column name says "lt_3" but the threshold is `< 4` —
# confirm which of the two is intended.
df['carbon_lt_3'] = df['std_smiles'].apply(count_carbon_atoms) < 4

In [None]:
# Exclusion criteria: a row is rejected as soon as any one of them holds
carbon_atom = df['carbon_lt_3'].eq(True)
low_heavy = df["HeavyAtomCount"] < 5
high_MW = df["MolWt"] > 1000
high_heavy = df["HeavyAtomCount"] > 75
low_le = df["LE"] <= 0.05
# Upper cut on ligand efficiency. This cell used to assign
# `high_le = df["LE"] >= .7` first and then overwrite it with the LE_norm cut
# below, so the LE >= 0.7 criterion never reached the filter; the dead
# assignment is dropped and the effective behaviour is unchanged.
# NOTE(review): if both cuts are wanted, give them distinct names and OR
# both into bad_filter.
high_le = df["LE_norm"] >= 0.003
# Combine with OR logic
bad_filter = low_heavy | high_heavy | low_le | high_le | high_MW | carbon_atom

# Split into rejected and retained rows
df_bad = df[bad_filter]
df_good = df[~bad_filter]

In [None]:
# Get the top and bottom molecules by LE, this time from the filtered `df_good`.
# NOTE(review): the names say 5 while the display cells are titled
# 'Top 20 LE' / 'Bottom 20 LE' — the actual count is decided inside
# display_top_bottom_le; confirm.
top5, bottom5 = display_top_bottom_le(df_good)

In [None]:
# Draw the molecules held in `top5` under the title 'Top 20 LE'.
# NOTE(review): variable name (5) and title (20) disagree — confirm how many rows top5 holds.
show_molecules(top5['std_smiles'].tolist(), title='Top 20 LE')

In [None]:
# Draw the molecules held in `bottom5` under the title 'Bottom 20 LE'.
# NOTE(review): variable name (5) and title (20) disagree — confirm how many rows bottom5 holds.
show_molecules(bottom5['std_smiles'].tolist(), title='Bottom 20 LE')

In [None]:
# Plot the property distributions of the retained rows (`df_good`)
plot_property_distributions(df_good)

In [None]:
# Rejected rows per originating source file
df_bad["source_file"].value_counts()

In [None]:
# Retained rows per originating source file
df_good["source_file"].value_counts()

# Save

In [None]:
output_path = '../data/standardized/standardized_input.parquet'
# Create the target directory up front: to_parquet does not create missing
# parent directories.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# Expose the standardized SMILES under a `smiles` column as well. assign()
# builds a new frame, so nothing is written into a slice of an earlier df.
df = df.assign(smiles=df["std_smiles"].tolist())
# Keep only the rows with both structure files and a standardized SMILES
df = df.dropna(
    subset=['standardized_ligand_sdf', 'standardized_protein_pdb', 'std_smiles']
)

# NOTE(review): this writes `df`, not the filtered `df_good` built in the
# Filter section — confirm that saving the unfiltered rows is intended.
df.to_parquet(output_path, index=False)

In [None]:
# Column names of the saved frame
list(df.columns)

In [1]:
1

1

# Extra