# 1 : Imports and Setup

In [1]:
import os
import sys
import json
import gc
import warnings
warnings.filterwarnings('ignore')

import pandas as pd
import numpy as np
import torch
import torch_geometric
from pathlib import Path
from sklearn.model_selection import KFold, train_test_split
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# Add the parent directory to the Python path
# This allows importing the gnn_dta_mtl package
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(''))))

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")

# Import your package - use absolute import instead of relative
from gnn_dta_mtl import (
    MTL_DTAModel, DTAModel,
    MTL_DTA, DTA,
    CrossValidator, MTLTrainer,
    StructureStandardizer, StructureProcessor, StructureChunkLoader,
    ESMEmbedder,
    add_molecular_properties_parallel,
    compute_ligand_efficiency,
    compute_mean_ligand_efficiency,
    filter_by_properties,
    prepare_mtl_experiment,
    build_mtl_dataset, build_mtl_dataset_optimized,
    evaluate_model,
    plot_results, plot_predictions, create_summary_report,
    ExperimentLogger,
    save_model, save_results, create_output_dir
)

# Set random seeds for reproducibility
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

Using device: cuda
GPU: NVIDIA A100-SXM4-40GB
Number of GPUs: 16


# 2: Configuration


In [2]:
import os
import json
from pathlib import Path
from datetime import datetime

# Create necessary directories first
base_dirs = [
    '../input/combined',
    '../input/chunk',
    '../input/embeddings',
    '../output/protein',
    '../output/ligand',
    '../output/experiments',
]

for dir_path in base_dirs:
    Path(dir_path).mkdir(parents=True, exist_ok=True)
    print(f"✓ Created: {dir_path}")

CONFIG = {
    # Data paths
    'data_path': '../data/curated/combined/df_combined.parquet',
    'protein_out_dir': '../output/protein',
    'ligand_out_dir': '../output/ligand',
    'structure_chunks_dir': '../input/chunk/',
    'embeddings_dir': '../input/embeddings/',
    'output_dir': '../output/experiments/',
    
    # Task configuration
    'task_cols': ['pKi', 'pEC50', 'pKd (Wang, FEP)', 'pKd', 'pIC50', 'potency'],
    
    # Model configuration
    'model_config': {
        'prot_emb_dim': 1280,
        'prot_gcn_dims': [128, 256, 256],
        'prot_fc_dims': [1024, 128],
        'drug_node_in_dim': [66, 1],
        'drug_node_h_dims': [128, 64],
        'drug_edge_in_dim': [16, 1],
        'drug_edge_h_dims': [32, 1],
        'drug_fc_dims': [1024, 128],
        'mlp_dims': [1024, 512],
        'mlp_dropout': 0.25
    },
    
    # Training configuration
    'training_config': {
        'batch_size': 128,
        'n_epochs': 200,
        'learning_rate': 0.0005,
        'patience': 100,
        'n_folds': 3
    },
    
    # Data filtering
    'filter_config': {
        'min_heavy_atoms': 5,
        'max_heavy_atoms': 75,
        'max_mw': 1000,
        'min_carbons': 4,
        'min_le': 0.05,
        'max_le_norm': 0.003
    },
    
    # Processing
    'n_workers': os.cpu_count() - 1 if os.cpu_count() else 112,
    'chunk_size': 50000,
    'sample_size': None,  # Set to integer to limit data size for testing
    
    # ESM model
    'esm_model_name': 'facebook/esm2_t33_650M_UR50D'
}

# Create experiment directory with timestamp
experiment_name = 'gnn_dta_mtl_experiment'
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
experiment_dir = Path(CONFIG['output_dir']) / f"{experiment_name}_{timestamp}"

# Create subdirectories
(experiment_dir / 'models').mkdir(parents=True, exist_ok=True)
(experiment_dir / 'results').mkdir(parents=True, exist_ok=True)
(experiment_dir / 'figures').mkdir(parents=True, exist_ok=True)
(experiment_dir / 'logs').mkdir(parents=True, exist_ok=True)

CONFIG['experiment_dir'] = str(experiment_dir)

# Save configuration
config_path = experiment_dir / 'config.json'
with open(config_path, 'w') as f:
    json.dump(CONFIG, f, indent=2)

print(f"✓ Experiment directory: {CONFIG['experiment_dir']}")
print(f"✓ Configuration saved to: {config_path}")

✓ Created: ../input/combined
✓ Created: ../input/chunk
✓ Created: ../input/embeddings
✓ Created: ../output/protein
✓ Created: ../output/ligand
✓ Created: ../output/experiments
✓ Experiment directory: ../output/experiments/gnn_dta_mtl_experiment_20250929_175416
✓ Configuration saved to: ../output/experiments/gnn_dta_mtl_experiment_20250929_175416/config.json


# 3 : Load Data

In [3]:
print("Loading data...")
df = pd.read_parquet(CONFIG['data_path'])
print(f"Initial data shape: {df.shape}")
df.head().style

Loading data...
Initial data shape: (550663, 14)


Unnamed: 0,protein_pdb_path,ligand_sdf_path,smiles,pKi,source_file,is_experimental,resolution,pEC50,"pKd (Wang, FEP)",pKd,pIC50,SMILES,potency,assay
0,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL60581/protein.pdb,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL60581/ligand.sdf,CCCCCCSCC(NC(=O)CCC(N)C(=O)O)C(=O)NCCC(=O)O,3.259637,BindingNetv2,False,,,,,,,,
1,../data/raw/BindingNetv2/moderate/target_CHEMBL3902/CHEMBL58951/protein.pdb,../data/raw/BindingNetv2/moderate/target_CHEMBL3902/CHEMBL58951/ligand.sdf,NC(CCC(=O)NC(CSCc1ccccc1)C(=O)NC(C(=O)O)c1ccccc1)C(=O)O,6.376751,BindingNetv2,False,,,,,,,,
2,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL301229/protein.pdb,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL301229/ligand.sdf,Cc1ccc(CSCC(NC(=O)CCC(N)C(=O)O)C(=O)NCCC(=O)O)cc1,4.39794,BindingNetv2,False,,,,,,,,
3,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL442360/protein.pdb,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL442360/ligand.sdf,NC(CCC(=O)NC(CSCc1ccc(Cl)cc1)C(=O)NC(C(=O)O)c1ccccc1)C(=O)O,6.920819,BindingNetv2,False,,,,,,,,
4,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL58451/protein.pdb,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL58451/ligand.sdf,NC(CCC(=O)NC(CSCc1ccccc1)C(=O)NCCC(=O)O)C(=O)O,3.148742,BindingNetv2,False,,,,,,,,


In [4]:
# Quick one-liner to get all non-NaN counts
df[['pKi', 'resolution', 'pEC50', 'pKd (Wang, FEP)', 'pKd', 'pIC50', 'potency']].notna().sum()

# Or to see the info for all columns at once
df.info()  # This shows non-null count for all columns

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 550663 entries, 0 to 550662
Data columns (total 14 columns):
 #   Column            Non-Null Count   Dtype  
---  ------            --------------   -----  
 0   protein_pdb_path  550663 non-null  object 
 1   ligand_sdf_path   550663 non-null  object 
 2   smiles            477203 non-null  object 
 3   pKi               115551 non-null  float64
 4   source_file       550663 non-null  object 
 5   is_experimental   550663 non-null  bool   
 6   resolution        9486 non-null    float64
 7   pEC50             67187 non-null   float64
 8   pKd (Wang, FEP)   1894 non-null    float64
 9   pKd               20890 non-null   float64
 10  pIC50             271665 non-null  float64
 11  SMILES            73460 non-null   object 
 12  potency           73460 non-null   float64
 13  assay             73460 non-null   object 
dtypes: bool(1), float64(7), object(6)
memory usage: 55.1+ MB


In [5]:
df.source_file.value_counts()

source_file
BindingNetv2                        392967
processed_data                       73460
BindingNetv1                         68738
PDBbind2020                           5118
HiQBind                               4429
BioLip2                               4057
FEP_Zariquiey_extended_Wang_2015      1651
FEP_Wang_2015                          243
Name: count, dtype: int64

In [6]:
df = df[df["source_file"]!="processed_data"]

# 4 : Data pack for reduced size

In [7]:
# Quick one-liner to get all non-NaN counts

# Or to see the info for all columns at once
df.info()  # This shows non-null count for all columns

<class 'pandas.core.frame.DataFrame'>
Index: 477203 entries, 0 to 550662
Data columns (total 14 columns):
 #   Column            Non-Null Count   Dtype  
---  ------            --------------   -----  
 0   protein_pdb_path  477203 non-null  object 
 1   ligand_sdf_path   477203 non-null  object 
 2   smiles            477203 non-null  object 
 3   pKi               115551 non-null  float64
 4   source_file       477203 non-null  object 
 5   is_experimental   477203 non-null  bool   
 6   resolution        9486 non-null    float64
 7   pEC50             67187 non-null   float64
 8   pKd (Wang, FEP)   1894 non-null    float64
 9   pKd               20890 non-null   float64
 10  pIC50             271665 non-null  float64
 11  SMILES            0 non-null       object 
 12  potency           0 non-null       float64
 13  assay             0 non-null       object 
dtypes: bool(1), float64(7), object(6)
memory usage: 51.4+ MB


In [8]:
df.source_file.value_counts()

source_file
BindingNetv2                        392967
BindingNetv1                         68738
PDBbind2020                           5118
HiQBind                               4429
BioLip2                               4057
FEP_Zariquiey_extended_Wang_2015      1651
FEP_Wang_2015                          243
Name: count, dtype: int64

# 5 : Standardize data
- identify errors, most of bindingnet 1 no protein in the pdb

In [9]:
# 0.1/sec (10k sec) + 5/sec (2000)  + 4.5/sec (105000

In [None]:
# Cell 4: Complete UltraFastStructureStandardizer Implementation

import os
import gc
import time
from pathlib import Path
from typing import Optional, List, Dict, Tuple, Any
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem, rdPartialCharges
from rdkit.Chem.MolStandardize import rdMolStandardize
from multiprocessing import Pool, cpu_count, Queue, Process, Manager
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm.auto import tqdm
import logging
from dataclasses import dataclass

# Suppress warnings
logging.basicConfig(level=logging.ERROR)
os.environ['OMP_NUM_THREADS'] = '1'

# Constants
POLAR_HEAVY = {7, 8, 15, 16}  # N, O, P, S

@dataclass
class WorkerTask:
    """Task for a single worker"""
    worker_id: int
    items: List[Tuple]
    start_idx: int

def keep_only_polar_H_rdkit(mol: Chem.Mol) -> Chem.Mol:
    """Keep only hydrogens bonded to polar atoms (N, O, P, S)."""
    if mol is None:
        return None
    
    h_to_del = []
    for a in mol.GetAtoms():
        if a.GetAtomicNum() != 1:
            continue
        nbrs = a.GetNeighbors()
        if not nbrs:
            continue
        if nbrs[0].GetAtomicNum() not in POLAR_HEAVY:
            h_to_del.append(a.GetIdx())
    
    if h_to_del:
        em = Chem.EditableMol(mol)
        for idx in sorted(h_to_del, reverse=True):
            em.RemoveAtom(idx)
        mol = em.GetMol()
    
    mol.UpdatePropertyCache(strict=False)
    targets = [a.GetIdx() for a in mol.GetAtoms()
               if a.GetAtomicNum() in POLAR_HEAVY and a.GetNumImplicitHs() > 0]
    if targets:
        mol = Chem.AddHs(mol, addCoords=(mol.GetNumConformers() > 0), onlyOnAtoms=targets)
    return mol

def standardize_ligand(path: str, output_path: str) -> bool:
    """Standardize a ligand structure from SDF or PDB file."""
    try:
        if path.endswith('.sdf'):
            mol = Chem.MolFromMolFile(path, removeHs=False)
        elif path.endswith('.pdb'):
            mol = Chem.MolFromPDBFile(path, removeHs=False)
        else:
            mol = Chem.MolFromMolFile(path, removeHs=False)
            
        if mol is None:
            return False
        
        mol = keep_only_polar_H_rdkit(mol)
        Chem.SanitizeMol(mol)
        Chem.AssignStereochemistry(mol, cleanIt=False, force=True)
        
        if mol.GetNumConformers() == 0:
            AllChem.EmbedMolecule(mol, randomSeed=42)
            
        rdPartialCharges.ComputeGasteigerCharges(mol)
        
        Chem.MolToSmiles(mol)
        if mol is None:
            return False
        else:
            Chem.MolToMolFile(mol, output_path)
            return True
    except:
        return False

def clean_protein_structure_minimal(pdb_path: str, output_path: str) -> None:
    """Minimal protein cleaning - fast version without expensive operations."""
    try:
        from pdbfixer import PDBFixer
        from openmm.app import PDBFile, Modeller, element as elem
        
        fixer = PDBFixer(filename=pdb_path)
        fixer.findMissingResidues()
        fixer.findNonstandardResidues()
        fixer.replaceNonstandardResidues()
        fixer.removeHeterogens(keepWater=False)
        
        # Skip expensive operations:
        # fixer.findMissingAtoms()
        # fixer.addMissingAtoms()
        # fixer.addMissingHydrogens(pH=ph)
        
        mod = Modeller(fixer.topology, fixer.positions)
        
        # Remove non-polar hydrogens
        to_delete = []
        for bond in mod.topology.bonds():
            a1, a2 = bond
            if a1.element == elem.hydrogen and a2.element == elem.carbon:
                to_delete.append(a1)
            elif a2.element == elem.hydrogen and a1.element == elem.carbon:
                to_delete.append(a2)
        
        if to_delete:
            mod.delete(to_delete)
        
        with open(output_path, 'w') as f:
            PDBFile.writeFile(mod.topology, mod.positions, f)
            
    except Exception as e:
        print(f"Error processing {pdb_path}: {e}")
        pass

def standardize_smiles_from_sdf(sdf_path: str) -> Optional[str]:
    """Generate standardized SMILES from SDF file."""
    try:
        mol = Chem.MolFromMolFile(sdf_path, removeHs=False)
        if mol is None:
            return None
        
        to_del = []
        for a in mol.GetAtoms():
            if a.GetAtomicNum() != 1:
                continue
            nbs = a.GetNeighbors()
            if nbs and nbs[0].GetAtomicNum() not in POLAR_HEAVY:
                to_del.append(a.GetIdx())
        
        if to_del:
            em = Chem.EditableMol(mol)
            for idx in sorted(to_del, reverse=True):
                em.RemoveAtom(idx)
            mol = em.GetMol()
        
        mol.UpdatePropertyCache(strict=False)
        AllChem.AssignAtomChiralTagsFromStructure(mol, replaceExistingTags=False)
        Chem.AssignStereochemistry(mol, force=True, cleanIt=False)
        
        mol = rdMolStandardize.Cleanup(mol)
        mol = rdMolStandardize.Normalizer().normalize(mol)
        mol = rdMolStandardize.FragmentParent(mol)
        mol = rdMolStandardize.TautomerEnumerator().Canonicalize(mol)
        
        for atom in mol.GetAtoms():
            atom.SetIsotope(0)
        
        Chem.AssignStereochemistry(mol, force=True, cleanIt=True)
        return Chem.MolToSmiles(mol, isomericSmiles=True, canonical=True)
        
    except Exception as e:
        print(f"Error processing {sdf_path}: {e}")
        return None

def independent_ligand_worker(task: WorkerTask, progress_queue: Optional[Queue] = None) -> Dict:
    """Independent worker for ligand processing."""
    results = {}
    
    for i, (in_path, out_path) in enumerate(task.items):
        global_idx = task.start_idx + i
        success = standardize_ligand(in_path, out_path)
        results[global_idx] = (out_path if success else None, success)
        
        if progress_queue is not None and (i + 1) % 10 == 0:
            progress_queue.put((task.worker_id, 10))
    
    if progress_queue is not None:
        remaining = len(task.items) % 10
        if remaining > 0:
            progress_queue.put((task.worker_id, remaining))
    
    return results

def independent_protein_worker(task: WorkerTask, progress_queue: Optional[Queue] = None) -> Dict:
    """Independent worker for protein processing."""
    results = {}
    
    for i, (in_path, out_path) in enumerate(task.items):
        global_idx = task.start_idx + i
        clean_protein_structure_minimal(in_path, out_path)
        success = os.path.exists(out_path)
        results[global_idx] = (out_path if success else None, success)
        
        if progress_queue is not None and (i + 1) % 10 == 0:
            progress_queue.put((task.worker_id, 10))
    
    if progress_queue is not None:
        remaining = len(task.items) % 10
        if remaining > 0:
            progress_queue.put((task.worker_id, remaining))
    
    return results

def independent_smiles_worker(task: WorkerTask, progress_queue: Optional[Queue] = None) -> Dict:
    """Independent worker for SMILES generation."""
    results = {}
    
    for i, sdf_path in enumerate(task.items):
        global_idx = task.start_idx + i
        smiles = standardize_smiles_from_sdf(sdf_path)
        results[global_idx] = smiles
        
        if progress_queue is not None and (i + 1) % 10 == 0:
            progress_queue.put((task.worker_id, 10))
    
    if progress_queue is not None:
        remaining = len(task.items) % 10
        if remaining > 0:
            progress_queue.put((task.worker_id, remaining))
    
    return results

class UltraFastStructureStandardizer:
    """Ultra-fast standardization with independent worker processing."""
    
    def __init__(self, n_workers: Optional[int] = None, show_worker_progress: bool = False):
        self.n_workers = n_workers or min(cpu_count() - 1, 32)
        self.show_worker_progress = show_worker_progress
        print(f"🚀 Ultra-fast mode: {self.n_workers} independent workers")
    
    def _distribute_work(self, items: List[Any]) -> List[WorkerTask]:
        """Distribute work evenly among workers."""
        n_items = len(items)
        chunk_size = n_items // self.n_workers
        remainder = n_items % self.n_workers
        
        tasks = []
        start_idx = 0
        
        for worker_id in range(self.n_workers):
            size = chunk_size + (1 if worker_id < remainder else 0)
            
            if size > 0:
                end_idx = start_idx + size
                task = WorkerTask(
                    worker_id=worker_id,
                    items=items[start_idx:end_idx],
                    start_idx=start_idx
                )
                tasks.append(task)
                start_idx = end_idx
        
        return tasks
    
    def _progress_monitor(self, queue: Queue, total: int, desc: str):
        """Monitor progress from all workers."""
        pbar = tqdm(total=total, desc=desc)
        worker_progress = {}
        
        while True:
            try:
                worker_id, count = queue.get(timeout=0.1)
                if worker_id == -1:
                    break
                    
                if self.show_worker_progress:
                    if worker_id not in worker_progress:
                        worker_progress[worker_id] = 0
                    worker_progress[worker_id] += count
                    pbar.set_postfix({f"W{k}": v for k, v in sorted(worker_progress.items())})
                
                pbar.update(count)
            except:
                continue
        
        pbar.close()
    
    def standardize_ligands(self, df: pd.DataFrame, input_col: str, output_dir: str, 
                           show_progress: bool = True) -> pd.DataFrame:
        """Standardize ligands with independent worker processing."""
        start_time = time.time()
        os.makedirs(output_dir, exist_ok=True)
        
        args = [(row[input_col], os.path.join(output_dir, f"{idx}.sdf"))
                for idx, row in df.iterrows()]
        
        tasks = self._distribute_work(args)
        print(f"📦 Distributed {len(args)} ligands to {len(tasks)} workers")
        
        manager = Manager() if show_progress else None
        progress_queue = manager.Queue() if show_progress else None
        
        if show_progress:
            monitor = Process(target=self._progress_monitor, 
                            args=(progress_queue, len(args), "Standardizing ligands"))
            monitor.start()
        
        with ProcessPoolExecutor(max_workers=self.n_workers) as executor:
            futures = [executor.submit(independent_ligand_worker, task, progress_queue) 
                      for task in tasks]
            
            all_results = {}
            for future in as_completed(futures):
                worker_results = future.result()
                all_results.update(worker_results)
        
        if show_progress:
            progress_queue.put((-1, 0))
            monitor.join()
        
        sorted_results = [all_results[i] for i in sorted(all_results.keys())]
        df['standardized_ligand_sdf'] = [r[0] for r in sorted_results]
        
        success_count = sum(1 for r in sorted_results if r[1])
        elapsed = time.time() - start_time
        rate = len(args) / elapsed if elapsed > 0 else 0
        
        print(f"✅ Ligands: {success_count}/{len(args)} succeeded in {elapsed:.1f}s ({rate:.1f} structures/sec)")
        
        return df
    
    def standardize_proteins(self, df: pd.DataFrame, input_col: str, output_dir: str,
                            show_progress: bool = True) -> pd.DataFrame:
        """Standardize proteins with independent worker processing."""
        start_time = time.time()
        os.makedirs(output_dir, exist_ok=True)
        
        args = [(row[input_col], os.path.join(output_dir, f"{idx}.pdb"))
                for idx, row in df.iterrows()]
        
        tasks = self._distribute_work(args)
        print(f"📦 Distributed {len(args)} proteins to {len(tasks)} workers")
        
        manager = Manager() if show_progress else None
        progress_queue = manager.Queue() if show_progress else None
        
        if show_progress:
            monitor = Process(target=self._progress_monitor,
                            args=(progress_queue, len(args), "Standardizing proteins"))
            monitor.start()
        
        with ProcessPoolExecutor(max_workers=self.n_workers) as executor:
            futures = [executor.submit(independent_protein_worker, task, progress_queue)
                      for task in tasks]
            
            all_results = {}
            for future in as_completed(futures):
                worker_results = future.result()
                all_results.update(worker_results)
        
        if show_progress:
            progress_queue.put((-1, 0))
            monitor.join()
        
        sorted_results = [all_results[i] for i in sorted(all_results.keys())]
        df['standardized_protein_pdb'] = [r[0] for r in sorted_results]
        
        success_count = sum(1 for r in sorted_results if r[1])
        elapsed = time.time() - start_time
        rate = len(args) / elapsed if elapsed > 0 else 0
        
        print(f"✅ Proteins: {success_count}/{len(args)} succeeded in {elapsed:.1f}s ({rate:.1f} structures/sec)")
        
        return df
    
    def standardize_smiles(self, df: pd.DataFrame, sdf_col: str = 'standardized_ligand_sdf',
                          show_progress: bool = True) -> pd.DataFrame:
        """Generate SMILES with independent worker processing."""
        start_time = time.time()
        sdf_paths = df[sdf_col].tolist()
        
        tasks = self._distribute_work(sdf_paths)
        print(f"📦 Distributed {len(sdf_paths)} SMILES tasks to {len(tasks)} workers")
        
        manager = Manager() if show_progress else None
        progress_queue = manager.Queue() if show_progress else None
        
        if show_progress:
            monitor = Process(target=self._progress_monitor,
                            args=(progress_queue, len(sdf_paths), "Generating SMILES"))
            monitor.start()
        
        with ProcessPoolExecutor(max_workers=self.n_workers) as executor:
            futures = [executor.submit(independent_smiles_worker, task, progress_queue)
                      for task in tasks]
            
            all_results = {}
            for future in as_completed(futures):
                worker_results = future.result()
                all_results.update(worker_results)
        
        if show_progress:
            progress_queue.put((-1, 0))
            monitor.join()
        
        sorted_smiles = [all_results[i] for i in sorted(all_results.keys())]
        df['std_smiles'] = sorted_smiles
        
        success_count = sum(1 for s in sorted_smiles if s is not None)
        elapsed = time.time() - start_time
        rate = len(sdf_paths) / elapsed if elapsed > 0 else 0
        
        print(f"✅ SMILES: {success_count}/{len(sdf_paths)} succeeded in {elapsed:.1f}s ({rate:.1f} SMILES/sec)")
        
        return df

    
from rdkit import RDLogger

# Disable RDKit warnings
RDLogger.DisableLog('rdApp.*')

# Initialize and run
standardizer = UltraFastStructureStandardizer(
    n_workers=64,  # Reduced from 64 for better efficiency
    show_worker_progress=True
)

# Standardize proteins
print("Standardizing proteins...")
df = standardizer.standardize_proteins(df, 'protein_pdb_path', './output/proteins/')
df = df.dropna(subset=['standardized_protein_pdb'])

# Standardize ligands
print("Standardizing ligands...")
df = standardizer.standardize_ligands(df, 'ligand_sdf_path', './output/ligands/')
df = df.dropna(subset=['standardized_ligand_sdf'])

# Generate standardized SMILES
print("Generating standardized SMILES...")
df = standardizer.standardize_smiles(df)

# Remove entries with missing standardized structures
print(f"Valid structures: {len(df)}")

🚀 Ultra-fast mode: 64 independent workers
Standardizing proteins...
📦 Distributed 477203 proteins to 64 workers
Error processing ../data/raw/BindingNetv2/moderate/target_CHEMBL5149/CHEMBL479209/protein.pdb: invalid literal for int() with base 16: 'A02FY'


In [1]:
1

1

In [2]:
1

1

In [None]:
df = df.dropna(subset=['standardized_protein_pdb', 'standardized_ligand_sdf'])


In [None]:
df.to_parquet('binding_standardized_v1_save.parquet', index = False)

In [None]:
# kill if nothing wrote at the end of the day




In [None]:
1

In [None]:

# Sample if specified
if CONFIG['sample_size']:
    df = df.sample(n=CONFIG['sample_size'], random_state=SEED).reset_index(drop=True)
    print(f"Sampled to {len(df)} entries")

# Filter for entries with required columns
required_cols = ['standardized_protein_pdb', 'standardized_ligand_sdf'] + CONFIG['task_cols']
df = df.dropna(how='all', subset=required_cols)
print(f"After filtering: {df.shape}")

# Add protein ID if not present
if 'protein_id' not in df.columns:
    df['protein_id'] = df['standardized_protein_pdb'].apply(
        lambda p: os.path.splitext(os.path.basename(p))[0] if pd.notnull(p) else None
    )

In [None]:
# Quick one-liner to get all non-NaN counts
df[['pKi', 'resolution', 'pEC50', 'pKd (Wang, FEP)', 'pKd', 'pIC50', 'potency']].notna().sum()

# Or to see the info for all columns at once
df.info()  # This shows non-null count for all columns

In [None]:
df.to_parquet('binding_standardized.parquet', index = False)

# 6 : Filter complex

In [None]:
df = pd.read_parquet('binding_standardized.parquet')

In [None]:
df = df.dropna(subset=['standardized_protein_pdb', 'standardized_ligand_sdf'])

df

In [None]:
# Cell 5: Calculate Molecular Properties
if 'MolWt' not in df.columns:
    print("Calculating molecular properties...")
    df = add_molecular_properties_parallel(df, smiles_col='std_smiles')
    df = compute_ligand_efficiency(df, CONFIG['task_cols'])
    df = compute_mean_ligand_efficiency(df)
    print("Properties calculated")

# Display statistics
print("\nProperty Statistics:")
property_cols = ['MolWt', 'HeavyAtomCount', 'LogP', 'QED', 'LE', 'LE_norm']
for col in property_cols:
    if col in df.columns:
        print(f"{col}: {df[col].mean():.2f} ± {df[col].std():.2f}")

In [None]:
# Cell 6: Filter Data
print("Filtering data...")

# Apply property filters
df_filtered = filter_by_properties(
    df,
    min_heavy_atoms=CONFIG['filter_config']['min_heavy_atoms'],
    max_heavy_atoms=CONFIG['filter_config']['max_heavy_atoms'],
    max_mw=CONFIG['filter_config']['max_mw'],
    min_carbons=CONFIG['filter_config']['min_carbons'],
    min_le=CONFIG['filter_config']['min_le'] if 'LE' in df.columns else None,
    max_le_norm=CONFIG['filter_config']['max_le_norm'] if 'LE_norm' in df.columns else None
)

print(f"After filtering: {len(df)} -> {len(df_filtered)}")
df = df_filtered

# Remove duplicates
from gnn_dta_mtl.data.preprocessing import remove_duplicates
df = remove_duplicates(df, subset=['protein_id', 'std_smiles'])

print(f"Final dataset size: {len(df)}")

In [None]:
df.to_parquet("./featurization_set.parquet", index = False)

# 7 : Process Protein Structures

In [3]:
df = pd.read_parquet("./featurization_set.parquet")

In [4]:
import torch
import gc

# Clear cache
torch.cuda.empty_cache()

# Force garbage collection
gc.collect()

# If you have variables holding tensors
torch.cuda.empty_cache()

In [None]:
# Cell 7: Process Protein Structures

print("Processing protein structures and generating ESM embeddings...")

# Initialize structure processor
processor = StructureProcessor(
    esm_model_name=CONFIG['esm_model_name'],
    chunk_size=CONFIG['chunk_size'],
    max_workers=CONFIG['n_workers'],
    embed_dir=CONFIG['embeddings_dir'],
    out_dir=CONFIG['structure_chunks_dir']
)

# Process structures
metadata = processor.process_dataframe(df, pdb_col='standardized_protein_pdb')


Processing protein structures and generating ESM embeddings...
Processing 474647 unique PDBs in 10 chunks

[Chunk 0] Processing 50000 structures


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.weight', 'esm.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Chunk 0 - PDB parsing:   0%|          | 0/50000 [00:00<?, ?it/s]

[Chunk 0] Generating embeddings for 50000 proteins


  0%|          | 0/50000 [00:00<?, ?it/s]

[Chunk 0] ✅ Saved 50000 structures

[Chunk 1] Processing 50000 structures


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.weight', 'esm.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Chunk 1 - PDB parsing:   0%|          | 0/50000 [00:00<?, ?it/s]

In [None]:

# Create chunk loader
chunk_loader = StructureChunkLoader(
    chunk_dir=CONFIG['structure_chunks_dir'],
    cache_size=10
)

# Verify available structures
available_pdb_ids = chunk_loader.get_available_pdb_ids()
available_pdb_ids = [i.replace('@','/') for i in available_pdb_ids]

In [None]:
df = df[df['standardized_protein_pdb'].isin(available_pdb_ids)].reset_index(drop=True)
print(f"Structures available for {len(df)} entries")

In [None]:
df.to_parquet("./binding_set.parquet", index = False)

In [None]:
metadata

In [None]:
import pandas as pd
df = pd.read_parquet("./binding_set.parquet")

In [14]:
df.head().style

Unnamed: 0,protein_pdb_path,ligand_sdf_path,smiles,pKi,source_file,is_experimental,resolution,pEC50,"pKd (Wang, FEP)",pKd,pIC50,SMILES,potency,assay,standardized_protein_pdb,standardized_ligand_sdf,std_smiles,protein_id,InChIKey,MolWt,HeavyAtomCount,QED,NumHDonors,NumHAcceptors,NumRotatableBonds,TPSA,LogP,LE_pKi,LEnorm_pKi,LE_pEC50,LEnorm_pEC50,"LE_pKd (Wang, FEP)","LEnorm_pKd (Wang, FEP)",LE_pKd,LEnorm_pKd,LE_pIC50,LEnorm_pIC50,LE_potency,LEnorm_potency,LE,LE_norm,carbon_count
0,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL60581/protein.pdb,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL60581/ligand.sdf,CCCCCCSCC(NC(=O)CCC(N)C(=O)O)C(=O)NCCC(=O)O,3.259637,BindingNetv2,False,,,,,,,,,./output/proteins/0.pdb,./output/ligands/0.sdf,CCCCCCSCC(NC(=O)CCC([NH3+])C(=O)[O-])C(=O)NCCC(=O)[O-],0,RILVFYFKIDXJNY-UHFFFAOYSA-M,404.509,27,0.233282,3,7,16,166.1,-2.8185,0.120727,0.000298,,,,,,,,,,,0.120727,0.000298,17
1,../data/raw/BindingNetv2/moderate/target_CHEMBL3902/CHEMBL58951/protein.pdb,../data/raw/BindingNetv2/moderate/target_CHEMBL3902/CHEMBL58951/ligand.sdf,NC(CCC(=O)NC(CSCc1ccccc1)C(=O)NC(C(=O)O)c1ccccc1)C(=O)O,6.376751,BindingNetv2,False,,,,,,,,,./output/proteins/1.pdb,./output/ligands/1.sdf,[NH3+]C(CCC(=O)NC(CSCc1ccccc1)C(=O)NC(C(=O)[O-])c1ccccc1)C(=O)[O-],1,ZPSKWMFLCHMEOY-UHFFFAOYSA-M,472.543,33,0.307698,3,7,13,166.1,-1.8474,0.193235,0.000409,,,,,,,,,,,0.193235,0.000409,23
2,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL301229/protein.pdb,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL301229/ligand.sdf,Cc1ccc(CSCC(NC(=O)CCC(N)C(=O)O)C(=O)NCCC(=O)O)cc1,4.39794,BindingNetv2,False,,,,,,,,,./output/proteins/2.pdb,./output/ligands/2.sdf,Cc1ccc(CSCC(NC(=O)CCC([NH3+])C(=O)[O-])C(=O)NCCC(=O)[O-])cc1,2,MBXWAPNNAOGFPH-UHFFFAOYSA-M,424.499,29,0.305487,3,7,13,166.1,-2.89018,0.151653,0.000357,,,,,,,,,,,0.151653,0.000357,19
3,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL442360/protein.pdb,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL442360/ligand.sdf,NC(CCC(=O)NC(CSCc1ccc(Cl)cc1)C(=O)NC(C(=O)O)c1ccccc1)C(=O)O,6.920819,BindingNetv2,False,,,,,,,,,./output/proteins/3.pdb,./output/ligands/3.sdf,[NH3+]C(CCC(=O)NC(CSCc1ccc(Cl)cc1)C(=O)NC(C(=O)[O-])c1ccccc1)C(=O)[O-],3,BXJSPWKYSSRFEB-UHFFFAOYSA-M,506.988,34,0.306713,3,7,13,166.1,-1.194,0.203553,0.000401,,,,,,,,,,,0.203553,0.000401,24
4,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL58451/protein.pdb,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL58451/ligand.sdf,NC(CCC(=O)NC(CSCc1ccccc1)C(=O)NCCC(=O)O)C(=O)O,3.148742,BindingNetv2,False,,,,,,,,,./output/proteins/4.pdb,./output/ligands/4.sdf,[NH3+]C(CCC(=O)NC(CSCc1ccccc1)C(=O)NCCC(=O)[O-])C(=O)[O-],4,QLVGMERIDWMEBM-UHFFFAOYSA-M,410.472,28,0.307814,3,7,13,166.1,-3.1986,0.112455,0.000274,,,,,,,,,,,0.112455,0.000274,18


# TO DO : 
- speed up the standardization (less complex)
- error with some structure, mismatch of aa ? need better standardization, simple and faster

In [None]:
# 48 hours for 100k to standardized, can do better....and only 15k passed...

In [15]:
1

1

In [12]:
1

1

In [13]:
1

1