# 1 : Imports and Setup

In [12]:
import os
import sys
import json
import gc
import warnings
warnings.filterwarnings('ignore')

import pandas as pd
import numpy as np
import torch
import torch_geometric
from pathlib import Path
from sklearn.model_selection import KFold, train_test_split
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# Add the parent directory to the Python path
# This allows importing the gnn_dta_mtl package
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(''))))

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")

# Import your package - use absolute import instead of relative
from gnn_dta_mtl import (
    MTL_DTAModel, DTAModel,
    MTL_DTA, DTA,
    CrossValidator, MTLTrainer,
    StructureStandardizer, StructureProcessor, StructureChunkLoader,
    ESMEmbedder,
    add_molecular_properties_parallel,
    compute_ligand_efficiency,
    compute_mean_ligand_efficiency,
    filter_by_properties,
    prepare_mtl_experiment,
    build_mtl_dataset, build_mtl_dataset_optimized,
    evaluate_model,
    plot_results, plot_predictions, create_summary_report,
    ExperimentLogger,
    save_model, save_results, create_output_dir
)

# Set random seeds for reproducibility
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

Using device: cuda
GPU: Tesla V100-SXM2-16GB
Number of GPUs: 4


# 2: Configuration


In [13]:
import os
import json
from pathlib import Path
from datetime import datetime

# Create necessary directories first
base_dirs = [
    '../input/combined',
    '../input/chunk',
    '../input/embeddings',
    '../output/protein',
    '../output/ligand',
    '../output/experiments',
]

for dir_path in base_dirs:
    Path(dir_path).mkdir(parents=True, exist_ok=True)
    print(f"✓ Created: {dir_path}")

CONFIG = {
    # Data paths
    'data_path': '../data/curated/combined/df_combined.parquet',
    'protein_out_dir': '../output/protein',
    'ligand_out_dir': '../output/ligand',
    'structure_chunks_dir': '../input/chunk/',
    'embeddings_dir': '../input/embeddings/',
    'output_dir': '../output/experiments/',
    
    # Task configuration
    'task_cols': ['pKi', 'pEC50', 'pKd (Wang, FEP)', 'pKd', 'pIC50', 'potency'],
    
    # Model configuration
    'model_config': {
        'prot_emb_dim': 1280,
        'prot_gcn_dims': [128, 256, 256],
        'prot_fc_dims': [1024, 128],
        'drug_node_in_dim': [66, 1],
        'drug_node_h_dims': [128, 64],
        'drug_edge_in_dim': [16, 1],
        'drug_edge_h_dims': [32, 1],
        'drug_fc_dims': [1024, 128],
        'mlp_dims': [1024, 512],
        'mlp_dropout': 0.25
    },
    
    # Training configuration
    'training_config': {
        'batch_size': 128,
        'n_epochs': 200,
        'learning_rate': 0.0005,
        'patience': 100,
        'n_folds': 3
    },
    
    # Data filtering
    'filter_config': {
        'min_heavy_atoms': 5,
        'max_heavy_atoms': 75,
        'max_mw': 1000,
        'min_carbons': 4,
        'min_le': 0.05,
        'max_le_norm': 0.003
    },
    
    # Processing
    'n_workers': os.cpu_count() - 1 if os.cpu_count() else 112,
    'chunk_size': 50000,
    'sample_size': None,  # Set to integer to limit data size for testing
    
    # ESM model
    'esm_model_name': 'facebook/esm2_t33_650M_UR50D'
}

# Create experiment directory with timestamp
experiment_name = 'gnn_dta_mtl_experiment'
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
experiment_dir = Path(CONFIG['output_dir']) / f"{experiment_name}_{timestamp}"

# Create subdirectories
(experiment_dir / 'models').mkdir(parents=True, exist_ok=True)
(experiment_dir / 'results').mkdir(parents=True, exist_ok=True)
(experiment_dir / 'figures').mkdir(parents=True, exist_ok=True)
(experiment_dir / 'logs').mkdir(parents=True, exist_ok=True)

CONFIG['experiment_dir'] = str(experiment_dir)

# Save configuration
config_path = experiment_dir / 'config.json'
with open(config_path, 'w') as f:
    json.dump(CONFIG, f, indent=2)

print(f"✓ Experiment directory: {CONFIG['experiment_dir']}")
print(f"✓ Configuration saved to: {config_path}")

✓ Created: ../input/combined
✓ Created: ../input/chunk
✓ Created: ../input/embeddings
✓ Created: ../output/protein
✓ Created: ../output/ligand
✓ Created: ../output/experiments
✓ Experiment directory: ../output/experiments/gnn_dta_mtl_experiment_20250919_102448
✓ Configuration saved to: ../output/experiments/gnn_dta_mtl_experiment_20250919_102448/config.json


# 3 : Load Data

In [14]:
print("Loading data...")
df = pd.read_parquet(CONFIG['data_path'])
print(f"Initial data shape: {df.shape}")
df.head().style

Loading data...
Initial data shape: (550663, 14)


Unnamed: 0,protein_pdb_path,ligand_sdf_path,smiles,pKi,source_file,is_experimental,resolution,pEC50,"pKd (Wang, FEP)",pKd,pIC50,SMILES,potency,assay
0,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL60581/protein.pdb,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL60581/ligand.sdf,CCCCCCSCC(NC(=O)CCC(N)C(=O)O)C(=O)NCCC(=O)O,3.259637,BindingNetv2,False,,,,,,,,
1,../data/raw/BindingNetv2/moderate/target_CHEMBL3902/CHEMBL58951/protein.pdb,../data/raw/BindingNetv2/moderate/target_CHEMBL3902/CHEMBL58951/ligand.sdf,NC(CCC(=O)NC(CSCc1ccccc1)C(=O)NC(C(=O)O)c1ccccc1)C(=O)O,6.376751,BindingNetv2,False,,,,,,,,
2,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL301229/protein.pdb,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL301229/ligand.sdf,Cc1ccc(CSCC(NC(=O)CCC(N)C(=O)O)C(=O)NCCC(=O)O)cc1,4.39794,BindingNetv2,False,,,,,,,,
3,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL442360/protein.pdb,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL442360/ligand.sdf,NC(CCC(=O)NC(CSCc1ccc(Cl)cc1)C(=O)NC(C(=O)O)c1ccccc1)C(=O)O,6.920819,BindingNetv2,False,,,,,,,,
4,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL58451/protein.pdb,../data/raw/BindingNetv2/high/target_CHEMBL3902/CHEMBL58451/ligand.sdf,NC(CCC(=O)NC(CSCc1ccccc1)C(=O)NCCC(=O)O)C(=O)O,3.148742,BindingNetv2,False,,,,,,,,


In [15]:
# Quick one-liner to get all non-NaN counts
df[['pKi', 'resolution', 'pEC50', 'pKd (Wang, FEP)', 'pKd', 'pIC50', 'potency']].notna().sum()

# Or to see the info for all columns at once
df.info()  # This shows non-null count for all columns

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 550663 entries, 0 to 550662
Data columns (total 14 columns):
 #   Column            Non-Null Count   Dtype  
---  ------            --------------   -----  
 0   protein_pdb_path  550663 non-null  object 
 1   ligand_sdf_path   550663 non-null  object 
 2   smiles            477203 non-null  object 
 3   pKi               115551 non-null  float64
 4   source_file       550663 non-null  object 
 5   is_experimental   550663 non-null  bool   
 6   resolution        9486 non-null    float64
 7   pEC50             67187 non-null   float64
 8   pKd (Wang, FEP)   1894 non-null    float64
 9   pKd               20890 non-null   float64
 10  pIC50             271665 non-null  float64
 11  SMILES            73460 non-null   object 
 12  potency           73460 non-null   float64
 13  assay             73460 non-null   object 
dtypes: bool(1), float64(7), object(6)
memory usage: 55.1+ MB


In [16]:
df.source_file.value_counts()

source_file
BindingNetv2                        392967
processed_data                       73460
BindingNetv1                         68738
PDBbind2020                           5118
HiQBind                               4429
BioLip2                               4057
FEP_Zariquiey_extended_Wang_2015      1651
FEP_Wang_2015                          243
Name: count, dtype: int64

# 4 : Data pack for reduced size

In [17]:
# df = df[df["source_file"].isin(['BioLip2', 'HiQBind','FEP_Zariquiey_extended_Wang_2015', 'FEP_Wang_2015', 'PDBbind2020'])]
df = df[df["source_file"].isin(['BioLip2', 'HiQBind','FEP_Zariquiey_extended_Wang_2015', 'FEP_Wang_2015', 'PDBbind2020'])]

df

Unnamed: 0,protein_pdb_path,ligand_sdf_path,smiles,pKi,source_file,is_experimental,resolution,pEC50,"pKd (Wang, FEP)",pKd,pIC50,SMILES,potency,assay
93335,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,[H]C1:C([H]):C([H]):C2:C(:C:1[H])C(=O)N=C(C([H...,3.004365,PDBbind2020,True,1.55,,,,,,,
93336,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,[H]ON1C(=S)C([H])=C(C([H])([H])[H])C([H])=C1C(...,3.022276,PDBbind2020,True,1.55,,,,,,,
93337,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,[H]OC([H])([H])C1:C([H]):C([H]):C2:C(:C:1[H])O...,3.026872,PDBbind2020,True,1.55,,,,,,,
93338,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,[H]/N=C(/C1:C([H]):C([H]):C([H]):C([H]):C:1[H]...,3.040959,PDBbind2020,True,2.30,,,,,,,
93339,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,[H]OC([H])([H])[C@@]1([H])O[C@@]([H])(SC([H])(...,3.040959,PDBbind2020,True,2.00,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
550658,../data/raw/BioLip2/biolip_downloads/biolip_re...,../data/raw/BioLip2/biolip_downloads/biolip_re...,OC[C@H]1O[C@H](O[PH](O)(O)O[PH](O)(O)OC[C@H]2O...,,BioLip2,True,,,,3.124939,,,,
550659,../data/raw/BioLip2/biolip_downloads/biolip_re...,../data/raw/BioLip2/biolip_downloads/biolip_re...,C[C@H]1S[C@H]2NC(N)N[C@@H](O)[C@@H]2C1SC1CCC(C...,,BioLip2,True,,,,7.698970,,,,
550660,../data/raw/BioLip2/biolip_downloads/biolip_re...,../data/raw/BioLip2/biolip_downloads/biolip_re...,CCCCCCC1CCC(C(O)NNC(S)NC)O1,,BioLip2,True,,,,4.096910,,,,
550661,../data/raw/BioLip2/biolip_downloads/biolip_re...,../data/raw/BioLip2/biolip_downloads/biolip_re...,CCOC(O)C1NNN2C3CCCCC3C(S[C@H](C(C)O)C(O)OCC)NC12,,BioLip2,True,,,,3.920819,,,,


In [18]:
# Quick one-liner to get all non-NaN counts

# Or to see the info for all columns at once
df.info()  # This shows non-null count for all columns

<class 'pandas.core.frame.DataFrame'>
Index: 15498 entries, 93335 to 550662
Data columns (total 14 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   protein_pdb_path  15498 non-null  object 
 1   ligand_sdf_path   15498 non-null  object 
 2   smiles            15498 non-null  object 
 3   pKi               5286 non-null   float64
 4   source_file       15498 non-null  object 
 5   is_experimental   15498 non-null  bool   
 6   resolution        9486 non-null   float64
 7   pEC50             263 non-null    float64
 8   pKd (Wang, FEP)   1894 non-null   float64
 9   pKd               5073 non-null   float64
 10  pIC50             2978 non-null   float64
 11  SMILES            0 non-null      object 
 12  potency           0 non-null      float64
 13  assay             0 non-null      object 
dtypes: bool(1), float64(7), object(6)
memory usage: 1.7+ MB


In [19]:
df.source_file.value_counts()

source_file
PDBbind2020                         5118
HiQBind                             4429
BioLip2                             4057
FEP_Zariquiey_extended_Wang_2015    1651
FEP_Wang_2015                        243
Name: count, dtype: int64

# 5 : Standardize data
- identify errors, most of bindingnet 1 no protein in the pdb

In [None]:
1

In [21]:
df["std_smiles"] = df["smiles"].tolist()
df["standardized_protein_pdb"] = df["protein_pdb_path"].tolist()
df["standardized_ligand_sdf"] = df["ligand_sdf_path"].tolist()

In [22]:
1

1

In [23]:
11

11

In [24]:
1

1

In [25]:

# Sample if specified
if CONFIG['sample_size']:
    df = df.sample(n=CONFIG['sample_size'], random_state=SEED).reset_index(drop=True)
    print(f"Sampled to {len(df)} entries")

# Filter for entries with required columns
required_cols = ['standardized_protein_pdb', 'standardized_ligand_sdf'] + CONFIG['task_cols']
df = df.dropna(how='all', subset=required_cols)
print(f"After filtering: {df.shape}")

# Add protein ID if not present
if 'protein_id' not in df.columns:
    df['protein_id'] = df['standardized_protein_pdb'].apply(
        lambda p: os.path.splitext(os.path.basename(p))[0] if pd.notnull(p) else None
    )

After filtering: (15498, 17)


In [26]:
# Quick one-liner to get all non-NaN counts
df[['pKi', 'resolution', 'pEC50', 'pKd (Wang, FEP)', 'pKd', 'pIC50', 'potency']].notna().sum()

# Or to see the info for all columns at once
df.info()  # This shows non-null count for all columns

<class 'pandas.core.frame.DataFrame'>
Index: 15498 entries, 93335 to 550662
Data columns (total 18 columns):
 #   Column                    Non-Null Count  Dtype  
---  ------                    --------------  -----  
 0   protein_pdb_path          15498 non-null  object 
 1   ligand_sdf_path           15498 non-null  object 
 2   smiles                    15498 non-null  object 
 3   pKi                       5286 non-null   float64
 4   source_file               15498 non-null  object 
 5   is_experimental           15498 non-null  bool   
 6   resolution                9486 non-null   float64
 7   pEC50                     263 non-null    float64
 8   pKd (Wang, FEP)           1894 non-null   float64
 9   pKd                       5073 non-null   float64
 10  pIC50                     2978 non-null   float64
 11  SMILES                    0 non-null      object 
 12  potency                   0 non-null      float64
 13  assay                     0 non-null      object 
 14  std_sm

In [27]:
df.to_parquet('binding_standardized.parquet', index = False)

In [28]:
df = pd.read_parquet('binding_standardized.parquet')

In [29]:
df = df.dropna(subset=['standardized_protein_pdb', 'standardized_ligand_sdf'])

df

Unnamed: 0,protein_pdb_path,ligand_sdf_path,smiles,pKi,source_file,is_experimental,resolution,pEC50,"pKd (Wang, FEP)",pKd,pIC50,SMILES,potency,assay,std_smiles,standardized_protein_pdb,standardized_ligand_sdf,protein_id
0,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,[H]C1:C([H]):C([H]):C2:C(:C:1[H])C(=O)N=C(C([H...,3.004365,PDBbind2020,True,1.55,,,,,,,,[H]C1:C([H]):C([H]):C2:C(:C:1[H])C(=O)N=C(C([H...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,4llk_protein
1,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,[H]ON1C(=S)C([H])=C(C([H])([H])[H])C([H])=C1C(...,3.022276,PDBbind2020,True,1.55,,,,,,,,[H]ON1C(=S)C([H])=C(C([H])([H])[H])C([H])=C1C(...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,4q81_protein
2,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,[H]OC([H])([H])C1:C([H]):C([H]):C2:C(:C:1[H])O...,3.026872,PDBbind2020,True,1.55,,,,,,,,[H]OC([H])([H])C1:C([H]):C([H]):C2:C(:C:1[H])O...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,4lm2_protein
3,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,[H]/N=C(/C1:C([H]):C([H]):C([H]):C([H]):C:1[H]...,3.040959,PDBbind2020,True,2.30,,,,,,,,[H]/N=C(/C1:C([H]):C([H]):C([H]):C([H]):C:1[H]...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,1rtf_protein
4,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,[H]OC([H])([H])[C@@]1([H])O[C@@]([H])(SC([H])(...,3.040959,PDBbind2020,True,2.00,,,,,,,,[H]OC([H])([H])[C@@]1([H])O[C@@]([H])(SC([H])(...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,../data/raw/PDBbind2020/PDBbind2020/main/refin...,3t08_protein
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15493,../data/raw/BioLip2/biolip_downloads/biolip_re...,../data/raw/BioLip2/biolip_downloads/biolip_re...,OC[C@H]1O[C@H](O[PH](O)(O)O[PH](O)(O)OC[C@H]2O...,,BioLip2,True,,,,3.124939,,,,,OC[C@H]1O[C@H](O[PH](O)(O)O[PH](O)(O)OC[C@H]2O...,../data/raw/BioLip2/biolip_downloads/biolip_re...,../data/raw/BioLip2/biolip_downloads/biolip_re...,3gf4A
15494,../data/raw/BioLip2/biolip_downloads/biolip_re...,../data/raw/BioLip2/biolip_downloads/biolip_re...,C[C@H]1S[C@H]2NC(N)N[C@@H](O)[C@@H]2C1SC1CCC(C...,,BioLip2,True,,,,7.698970,,,,,C[C@H]1S[C@H]2NC(N)N[C@@H](O)[C@@H]2C1SC1CCC(C...,../data/raw/BioLip2/biolip_downloads/biolip_re...,../data/raw/BioLip2/biolip_downloads/biolip_re...,3ghwA
15495,../data/raw/BioLip2/biolip_downloads/biolip_re...,../data/raw/BioLip2/biolip_downloads/biolip_re...,CCCCCCC1CCC(C(O)NNC(S)NC)O1,,BioLip2,True,,,,4.096910,,,,,CCCCCCC1CCC(C(O)NNC(S)NC)O1,../data/raw/BioLip2/biolip_downloads/biolip_re...,../data/raw/BioLip2/biolip_downloads/biolip_re...,3gk1A
15496,../data/raw/BioLip2/biolip_downloads/biolip_re...,../data/raw/BioLip2/biolip_downloads/biolip_re...,CCOC(O)C1NNN2C3CCCCC3C(S[C@H](C(C)O)C(O)OCC)NC12,,BioLip2,True,,,,3.920819,,,,,CCOC(O)C1NNN2C3CCCCC3C(S[C@H](C(C)O)C(O)OCC)NC12,../data/raw/BioLip2/biolip_downloads/biolip_re...,../data/raw/BioLip2/biolip_downloads/biolip_re...,3gk4X


# 6 : Filter complex

In [30]:
# Cell 5: Calculate Molecular Properties
if 'MolWt' not in df.columns:
    print("Calculating molecular properties...")
    df = add_molecular_properties_parallel(df, smiles_col='std_smiles')
    df = compute_ligand_efficiency(df, CONFIG['task_cols'])
    df = compute_mean_ligand_efficiency(df)
    print("Properties calculated")

# Display statistics
print("\nProperty Statistics:")
property_cols = ['MolWt', 'HeavyAtomCount', 'LogP', 'QED', 'LE', 'LE_norm']
for col in property_cols:
    if col in df.columns:
        print(f"{col}: {df[col].mean():.2f} ± {df[col].std():.2f}")

Calculating molecular properties...


Computing properties:   0%|          | 16/15498 [00:00<13:00, 19.84it/s][10:27:12] Can't kekulize mol.  Unkekulized atoms: 3 4 5 6 7
[10:27:12] Can't kekulize mol.  Unkekulized atoms: 4 5 6 8 9
[10:27:12] Can't kekulize mol.  Unkekulized atoms: 0 1 5 6 7 8 9 10 11
Computing properties:   0%|          | 38/15498 [00:01<04:19, 59.54it/s][10:27:12] Can't kekulize mol.  Unkekulized atoms: 0 1 2 3 4 5 9 10 11
[10:27:12] Can't kekulize mol.  Unkekulized atoms: 14 15 16 17 18
[10:27:12] Can't kekulize mol.  Unkekulized atoms: 9 10 11 12 13
[10:27:12] Can't kekulize mol.  Unkekulized atoms: 1 2 3 5 6 7 8 9 10
Computing properties:   0%|          | 54/15498 [00:01<05:03, 50.83it/s][10:27:13] Can't kekulize mol.  Unkekulized atoms: 0 1 2 3 4 6 7 9 10
Computing properties:   0%|          | 69/15498 [00:01<06:02, 42.62it/s][10:27:13] Can't kekulize mol.  Unkekulized atoms: 0 1 5 6 7 8 9 10 11
[10:27:13] Can't kekulize mol.  Unkekulized atoms: 4 5 6 7 8 9 11
Computing properties:   0%|          | 7

Properties calculated

Property Statistics:
MolWt: 388.23 ± 153.36
HeavyAtomCount: 26.61 ± 10.62
LogP: 1.08 ± 2.99
QED: 0.47 ± 0.22
LE: 0.29 ± 0.13
LE_norm: 0.00 ± 0.00


In [31]:
# Cell 6: Filter Data
print("Filtering data...")

# Apply property filters
df_filtered = filter_by_properties(
    df,
    min_heavy_atoms=CONFIG['filter_config']['min_heavy_atoms'],
    max_heavy_atoms=CONFIG['filter_config']['max_heavy_atoms'],
    max_mw=CONFIG['filter_config']['max_mw'],
    min_carbons=CONFIG['filter_config']['min_carbons'],
    min_le=CONFIG['filter_config']['min_le'] if 'LE' in df.columns else None,
    max_le_norm=CONFIG['filter_config']['max_le_norm'] if 'LE_norm' in df.columns else None
)

print(f"After filtering: {len(df)} -> {len(df_filtered)}")
df = df_filtered

# Remove duplicates
from gnn_dta_mtl.data.preprocessing import remove_duplicates
df = remove_duplicates(df, subset=['protein_id', 'std_smiles'])

print(f"Final dataset size: {len(df)}")

Filtering data...
After filtering: 15498 -> 14034
Removed 2695 duplicates
Final dataset size: 11339


In [32]:
df.to_parquet("./featurization_set.parquet", index = False)

# 7 : Process Protein Structures

In [33]:
df = pd.read_parquet("./featurization_set.parquet")

In [34]:
import torch
import gc

# Clear cache
torch.cuda.empty_cache()

# Force garbage collection
gc.collect()

# If you have variables holding tensors
torch.cuda.empty_cache()

In [35]:
# Cell 7: Process Protein Structures

print("Processing protein structures and generating ESM embeddings...")

# Initialize structure processor
processor = StructureProcessor(
    esm_model_name=CONFIG['esm_model_name'],
    chunk_size=CONFIG['chunk_size'],
    max_workers=CONFIG['n_workers'],
    embed_dir=CONFIG['embeddings_dir'],
    out_dir=CONFIG['structure_chunks_dir']
)

# Process structures
metadata = processor.process_dataframe(df, pdb_col='standardized_protein_pdb')

# Create chunk loader
chunk_loader = StructureChunkLoader(
    chunk_dir=CONFIG['structure_chunks_dir'],
    cache_size=2
)

# Verify available structures
available_pdb_ids = chunk_loader.get_available_pdb_ids()
available_pdb_ids = [i.replace('@','/') for i in available_pdb_ids]

Processing protein structures and generating ESM embeddings...
Processing 11202 unique PDBs in 1 chunks

[Chunk 0] Processing 11202 structures


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Chunk 0 - PDB parsing:   0%|          | 0/11202 [00:00<?, ?it/s]

[Chunk 0] Generating embeddings for 11202 proteins


  0%|          | 0/11202 [00:00<?, ?it/s]

[Chunk 0] ✅ Saved 11202 structures

✅ Processing complete!
  Total structures: 11202
  Metadata saved: ../input/chunk/chunk_metadata.json
Loaded 11202 structures from 1 chunks


In [36]:
df = df[df['standardized_protein_pdb'].isin(available_pdb_ids)].reset_index(drop=True)
print(f"Structures available for {len(df)} entries")

Structures available for 11339 entries


In [37]:
df.to_parquet("./binding_set.parquet", index = False)

In [None]:
metadata

{'num_chunks': 1,
 'chunk_size': 50000,
 'total_structures': 11202,
 'chunks': [{'chunk_idx': 0,
   'filename': 'structures_chunk_0000.json',
   'path': '../input/chunk/structures_chunk_0000.json',
   'num_structures': 11202,
   'num_errors': 0}]}

In [None]:
import pandas as pd
df = pd.read_parquet("./binding_set.parquet")

In [None]:
df.head().style

Unnamed: 0,protein_pdb_path,ligand_sdf_path,smiles,pKi,source_file,is_experimental,resolution,pEC50,"pKd (Wang, FEP)",pKd,pIC50,SMILES,potency,assay,std_smiles,standardized_protein_pdb,standardized_ligand_sdf,protein_id,InChIKey,MolWt,HeavyAtomCount,QED,NumHDonors,NumHAcceptors,NumRotatableBonds,TPSA,LogP,LE_pKi,LEnorm_pKi,LE_pEC50,LEnorm_pEC50,"LE_pKd (Wang, FEP)","LEnorm_pKd (Wang, FEP)",LE_pKd,LEnorm_pKd,LE_pIC50,LEnorm_pIC50,LE_potency,LEnorm_potency,LE,LE_norm,carbon_count
0,../data/raw/PDBbind2020/PDBbind2020/main/refined-set/4llk/4llk_protein.pdb,../data/raw/PDBbind2020/PDBbind2020/main/refined-set/4llk/4llk_ligand.sdf,[H]C1:C([H]):C([H]):C2:C(:C:1[H])C(=O)N=C(C([H])([H])[H])N2[H],3.004365,PDBbind2020,True,1.55,,,,,,,,[H]C1:C([H]):C([H]):C2:C(:C:1[H])C(=O)N=C(C([H])([H])[H])N2[H],../data/raw/PDBbind2020/PDBbind2020/main/refined-set/4llk/4llk_protein.pdb,../data/raw/PDBbind2020/PDBbind2020/main/refined-set/4llk/4llk_ligand.sdf,4llk_protein,FIEYHAAMDAPVCH-UHFFFAOYSA-N,160.176,12.0,0.629623,1.0,2.0,0.0,45.75,1.23152,0.250364,0.001563,,,,,,,,,,,0.250364,0.001563,9
1,../data/raw/PDBbind2020/PDBbind2020/main/refined-set/4q81/4q81_protein.pdb,../data/raw/PDBbind2020/PDBbind2020/main/refined-set/4q81/4q81_ligand.sdf,[H]ON1C(=S)C([H])=C(C([H])([H])[H])C([H])=C1C([H])([H])[H],3.022276,PDBbind2020,True,1.55,,,,,,,,[H]ON1C(=S)C([H])=C(C([H])([H])[H])C([H])=C1C([H])([H])[H],../data/raw/PDBbind2020/PDBbind2020/main/refined-set/4q81/4q81_protein.pdb,../data/raw/PDBbind2020/PDBbind2020/main/refined-set/4q81/4q81_ligand.sdf,4q81_protein,OZNBIIYHOPDSLX-UHFFFAOYSA-N,155.222,10.0,0.458055,1.0,3.0,0.0,25.16,2.07173,0.302228,0.001947,,,,,,,,,,,0.302228,0.001947,7
2,../data/raw/PDBbind2020/PDBbind2020/main/refined-set/4lm2/4lm2_protein.pdb,../data/raw/PDBbind2020/PDBbind2020/main/refined-set/4lm2/4lm2_ligand.sdf,[H]OC([H])([H])C1:C([H]):C([H]):C2:C(:C:1[H])OC([H])([H])C([H])([H])O2,3.026872,PDBbind2020,True,1.55,,,,,,,,[H]OC([H])([H])C1:C([H]):C([H]):C2:C(:C:1[H])OC([H])([H])C([H])([H])O2,../data/raw/PDBbind2020/PDBbind2020/main/refined-set/4lm2/4lm2_protein.pdb,../data/raw/PDBbind2020/PDBbind2020/main/refined-set/4lm2/4lm2_ligand.sdf,4lm2_protein,FFLHNBGNAWYMRH-UHFFFAOYSA-N,166.176,12.0,0.674658,1.0,3.0,1.0,38.69,0.9501,0.252239,0.001518,,,,,,,,,,,0.252239,0.001518,9
3,../data/raw/PDBbind2020/PDBbind2020/main/refined-set/1rtf/1rtf_protein.pdb,../data/raw/PDBbind2020/PDBbind2020/main/refined-set/1rtf/1rtf_ligand.sdf,[H]/N=C(/C1:C([H]):C([H]):C([H]):C([H]):C:1[H])N([H])[H],3.040959,PDBbind2020,True,2.3,,,,,,,,[H]/N=C(/C1:C([H]):C([H]):C([H]):C([H]):C:1[H])N([H])[H],../data/raw/PDBbind2020/PDBbind2020/main/refined-set/1rtf/1rtf_protein.pdb,../data/raw/PDBbind2020/PDBbind2020/main/refined-set/1rtf/1rtf_ligand.sdf,1rtf_protein,PXXJHWLDUBFPOL-UHFFFAOYSA-N,120.155,9.0,0.420789,2.0,1.0,1.0,49.87,0.97067,0.337884,0.002812,,,,,,,,,,,0.337884,0.002812,7
4,../data/raw/PDBbind2020/PDBbind2020/main/refined-set/3t08/3t08_protein.pdb,../data/raw/PDBbind2020/PDBbind2020/main/refined-set/3t08/3t08_ligand.sdf,[H]OC([H])([H])[C@@]1([H])O[C@@]([H])(SC([H])(C([H])([H])[H])C([H])([H])[H])[C@]([H])(O[H])[C@@]([H])(O[H])[C@@]1([H])O[H],3.040959,PDBbind2020,True,2.0,,,,,,,,[H]OC([H])([H])[C@@]1([H])O[C@@]([H])(SC([H])(C([H])([H])[H])C([H])([H])[H])[C@]([H])(O[H])[C@@]([H])(O[H])[C@@]1([H])O[H],../data/raw/PDBbind2020/PDBbind2020/main/refined-set/3t08/3t08_protein.pdb,../data/raw/PDBbind2020/PDBbind2020/main/refined-set/3t08/3t08_ligand.sdf,3t08_protein,BPHPUYQFMNQIOC-NXRLNHOXSA-N,238.305,15.0,0.500348,4.0,6.0,3.0,90.15,-1.0721,0.202731,0.000851,,,,,,,,,,,0.202731,0.000851,9


# TO DO : 
- speed up the standardization (less complex)
- error with some structure, mismatch of aa ? need better standardization, simple and faster

In [None]:
# 48 hours for 100k to standardized, can do better....and only 15k passed...

In [None]:
1

In [None]:
1

In [None]:
1