# Setup: imports, device selection and random seeds

In [1]:
import os
import random
import sys
import warnings
from pathlib import Path

# Silence warnings globally (set before the heavy third-party imports below)
# to keep the notebook output readable.
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import torch
from rdkit import RDLogger
from tqdm.auto import tqdm
from transformers import EsmModel, EsmTokenizer

# Make the project package importable by adding the directory two levels above
# the current working directory (normally the notebook's folder) to sys.path.
# The membership check keeps repeated runs of this cell from stacking duplicates.
PROJECT_ROOT = str(Path(os.path.abspath('')).parent.parent)
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Select the compute device; replace with e.g. torch.device('cuda:1') to use another GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    # Report the GPU actually selected above, not always GPU 0.
    print(f"GPU: {torch.cuda.get_device_name(device)}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")

# Project-local imports; must come after the sys.path update above.
from gnn_dta_mtl import (
    MTL_DTAModel,
    DTAPredictor,
    predict_affinity,
)

# Disable RDKit warnings
RDLogger.DisableLog('rdApp.*')

# Seed every RNG in use (Python, NumPy, PyTorch CPU and all CUDA devices) for reproducibility.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

Using device: cuda
GPU: NVIDIA A100-SXM4-40GB
Number of GPUs: 16


# Load models (ESM-2 protein encoder and the APEX affinity model)

In [2]:
# Uses `device` from the setup cell. It is intentionally NOT redefined here:
# hard-coding 'cuda' would crash on CPU-only machines and discard a chosen GPU.
print("Loading ESM model...")
model_name = "facebook/esm2_t33_650M_UR50D"
model_checkpoint = './models/best_model.pt'
# Fail fast, before downloading the large ESM weights, if the APEX checkpoint is missing.
if not Path(model_checkpoint).is_file():
    raise FileNotFoundError(f"APEX checkpoint not found: {model_checkpoint}")
tokenizer = EsmTokenizer.from_pretrained(model_name)
# NOTE(review): the "pooler weights newly initialized" warning is presumably harmless
# if only token/residue embeddings are used downstream — confirm in gnn_dta_mtl.
esm_model = EsmModel.from_pretrained(model_name).to(device)
esm_model.eval()  # inference only: disables dropout
print("✓ ESM model loaded")

print("Loading APEX model...")
# Default config. Assumed to have to match the architecture stored in the checkpoint — TODO confirm.
config = {
    # One prediction head per task name passed to MTL_DTAModel.
    'task_cols': ['pKi', 'pEC50', 'pKd', 'pIC50', 'pKd (Wang, FEP)', 'potency'],
    'model_config': {
        'prot_emb_dim': 1280,  # presumably the ESM-2 650M embedding width — verify if model_name changes
        'prot_gcn_dims': [128, 256, 256],
        'prot_fc_dims': [1024, 128],
        'drug_node_in_dim': [66, 1],
        'drug_node_h_dims': [128, 64],
        'drug_edge_in_dim': [16, 1],
        'drug_edge_h_dims': [32, 1],
        'drug_fc_dims': [1024, 128],
        'mlp_dims': [1024, 512],
        'mlp_dropout': 0.25
    }
}
# Load model
model = MTL_DTAModel(
    task_names=config['task_cols'],
    **config['model_config']
).to(device)
predictor = DTAPredictor(model, model_checkpoint, device=device, esm_model=esm_model)
print("✓ APEX model loaded")

Loading ESM model...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✓ ESM model loaded
Loading APEX model...
✓ Model loaded from ./models/best_model.pt
✓ Using device: cuda
✓ APEX model loaded


# Data (replace with your own protein–ligand complexes)

In [10]:
# Test complexes as (protein .pdb path, ligand .sdf path) pairs, one per complex ID.
# Replace the IDs (or the whole list) with your own complexes.
test_complexes = [
    (f'./test/{cid}.pdb', f'./test/{cid}.sdf')
    for cid in ('245158', '279794', '203418', '48008', '434516',
                '3705', '323023', '251630', '472746', '24704')
]

# OR load the pairs from a table with a 'target' column (protein path) and a 'ligand'
# column (ligand path). For a .csv file use pd.read_csv as below; for .parquet use pd.read_parquet.
# test_complexes = list(pd.read_csv('./input/examples.csv')[['target', 'ligand']].itertuples(index=False, name=None))


[('./test/245158.pdb', ' ./test/245158.sdf'),
 ('./test/279794.pdb', ' ./test/279794.sdf'),
 ('./test/203418.pdb', ' ./test/203418.sdf'),
 ('./test/48008.pdb', ' ./test/48008.sdf'),
 ('./test/434516.pdb', ' ./test/434516.sdf'),
 ('./test/3705.pdb', ' ./test/3705.sdf'),
 ('./test/323023.pdb', ' ./test/323023.sdf'),
 ('./test/251630.pdb', ' ./test/251630.sdf'),
 ('./test/472746.pdb', ' ./test/472746.sdf'),
 ('./test/24704.pdb', ' ./test/24704.sdf')]

# Inference

In [4]:
# Get predictions with real ESM embeddings.
# Above this many complexes the `fast` mode is recommended for scaling.
FAST_MODE_THRESHOLD = 1000
predictions_path = './predictions/affinity_predictions.csv'
# Create the output folder up front so saving cannot fail on a fresh checkout
# (harmless if predict_affinity also creates it).
Path(predictions_path).parent.mkdir(parents=True, exist_ok=True)
predictions = predict_affinity(
    protein_ligand_pairs=test_complexes,
    output_path=predictions_path,
    device=device,
    predictor=predictor,
    esm_model=esm_model,
    fast=len(test_complexes) > FAST_MODE_THRESHOLD,  # automatic: True only for >1000 complexes
)

Featurizing all protein-ligand pairs...


Featurizing: 100%|██████████| 10/10 [00:02<00:00,  3.92it/s]


Running batch prediction...
Predictions saved to affinity_predictions.csv


In [5]:
# Display the predictions table (one predicted value per task column).
# Ignore the `potency` column: it is a future head that will be covered once SAIR is integrated.
predictions

Unnamed: 0,protein_path,ligand_path,pKi,pEC50,pKd,pIC50,"pKd (Wang, FEP)",potency
0,./test/245158.pdb,./test/245158.sdf,7.637315,6.891932,6.341615,6.878548,7.291972,0.053844
1,./test/279794.pdb,./test/279794.sdf,8.614182,7.789924,8.435829,7.966605,8.19705,0.069151
2,./test/203418.pdb,./test/203418.sdf,8.009085,6.813055,7.741957,7.76549,7.860967,0.123765
3,./test/48008.pdb,./test/48008.sdf,8.159952,7.957793,8.316508,7.484304,7.904311,0.013412
4,./test/434516.pdb,./test/434516.sdf,6.489398,6.41966,6.763867,6.206766,6.37801,0.074012
5,./test/3705.pdb,./test/3705.sdf,4.400119,3.93528,3.692814,3.779052,3.783771,0.096013
6,./test/323023.pdb,./test/323023.sdf,7.812621,7.259854,7.734333,7.102062,7.284111,0.022825
7,./test/251630.pdb,./test/251630.sdf,7.984521,7.570292,7.843274,7.310745,7.629936,0.032966
8,./test/472746.pdb,./test/472746.sdf,4.777402,4.56088,3.996355,4.426298,4.796299,0.189638
9,./test/24704.pdb,./test/24704.sdf,7.860795,7.160041,7.636954,7.765124,7.64901,0.021086
