In [1]:
from rdkit import Chem
from ligand_features import LigandFeature
from load_receptor import ReceptorFile
from pocket_features import PocketFeatures
from cplx_graph import ComplexGraph
import torch  # used directly (torch.load); do not rely on it leaking from `from model import *`
import dgl
from dgl.dataloading import GraphDataLoader
from model import *
from torch.utils.data import DataLoader, Dataset
from utils import run_an_eval_epoch, sdf_split, mol2_split
import pandas as pd
import numpy as np

In [2]:
class MyDataset(Dataset):
    """Positional pairing of pocket graphs with complex graphs.

    Item ``i`` is the tuple ``(i, rec_gs[i], cplx_gs[i])``. The dataset length
    is taken from ``rec_gs``; both sequences are stored as given, without
    copying or validation.
    """

    def __init__(self, rec_gs, cplx_gs):
        self.rec_gs, self.cplx_gs = rec_gs, cplx_gs

    def __len__(self):
        return len(self.rec_gs)

    def __getitem__(self, idx):
        pocket_graph = self.rec_gs[idx]
        complex_graph = self.cplx_gs[idx]
        return idx, pocket_graph, complex_graph

# Create Dataset

In [3]:
# Target to score. Every sample path is derived from `prefix`, so switching
# targets only requires changing this one value.
prefix = "1bcu"
sample_dir = f"../samples/{prefix}"
rec_fpath = f"{sample_dir}/{prefix}_protein_atom_noHETATM.pdb"
ref_lig_fpath = f"{sample_dir}/{prefix}_ligand.sdf"
pose_fpath = f"{sample_dir}/{prefix}_decoys.sdf"

# NOTE(review): this unpickles a whole model object, which can execute arbitrary
# code -- only load checkpoints from a trusted source. On PyTorch >= 2.6,
# torch.load defaults to weights_only=True, so loading a full pickled model
# there needs weights_only=False (or re-save the checkpoint as a state_dict).
model = torch.load("../models/saved_model.pth", map_location=torch.device('cpu'))

In [4]:
# Split the multi-pose file into one text block per pose. The format is chosen
# from the file extension; any other extension is a hard error, because every
# later cell needs `poses_content`.
if pose_fpath.endswith("sdf"):
    poses_content = sdf_split(pose_fpath)
elif pose_fpath.endswith("mol2"):
    poses_content = mol2_split(pose_fpath)
else:
    # Previously this only printed, and the print below then died with a
    # NameError on `poses_content`; fail with the intended message instead.
    raise ValueError("InputError: Please input the pose file with .sdf or .mol2 format.")

print(f"Number of poses is {len(poses_content)}")

Number of poses is 10


In [5]:
# generate pocket graph
# Build the receptor once; the resulting pocket graph `pock_g` is reused for
# every pose in the next cell.
rec = ReceptorFile(rec_fpath=rec_fpath, ref_lig_fpath=ref_lig_fpath)
# Order matters: `rec.pock_center` is read below, after define_pocket() has run.
# NOTE(review): presumably clip_rec() trims the receptor around the reference
# ligand and define_pocket() sets `pock_center` -- confirm in load_receptor.
rec.clip_rec()
rec.define_pocket()
print("receptor clipped ...")

# parse pocket
pock_feat = PocketFeatures(rec, pock_center=rec.pock_center)
pock_g = pock_feat.pock_to_graph()
# dgl.add_self_loop returns a new graph with a self-loop edge on every node.
pock_g = dgl.add_self_loop(pock_g)
print("protein pocket graph generated ...")

receptor clipped ...
Warnning: psi feats calculate failed ...
protein pocket graph generated ...


The valence field specifies a valence 3 that is
less than the observed explicit valence 4.

The valence field specifies a valence 3 that is
less than the observed explicit valence 4.
The valence field specifies a valence 3 that is
less than the observed explicit valence 4.

The valence field specifies a valence 3 that is
less than the observed explicit valence 4.
The valence field specifies a valence 3 that is
less than the observed explicit valence 4.
The valence field specifies a valence 3 that is
less than the observed explicit valence 4.

The valence field specifies a valence 3 that is
less than the observed explicit valence 4.
The valence field specifies a valence 3 that is
less than the observed explicit valence 4.
The valence field specifies a valence 3 that is
less than the observed explicit valence 4.
The valence field specifies a valence 3 that is
less than the observed explicit valence 4.

The valence field specifies a valence 3 that is
less than the observed explicit valenc

In [6]:
# Build one protein-ligand complex graph per docked pose. Poses that RDKit cannot
# parse, or that fail featurization, are reported and skipped. `keys` (1-based
# labels such as "1bcu-1") therefore stays aligned with `cplx_graphs`.
if pose_fpath.endswith("sdf"):
    parse_pose = Chem.MolFromMolBlock
elif pose_fpath.endswith("mol2"):
    parse_pose = Chem.MolFromMol2Block
else:
    # Fail once up front, instead of printing the same error for every pose
    # and silently producing an empty dataset.
    raise ValueError("InputError: Please input the pose file with .sdf or .mol2 format.")

keys = []
cplx_graphs = []
for idx, p in enumerate(poses_content):
    print(f"pose-{idx}")
    basename = prefix + "-" + str(idx+1)

    mol = parse_pose(p)
    if mol is None:
        # RDKit reports a parse/sanitization failure by returning None rather
        # than raising. Label the error the same way the pose is keyed.
        print(f"Error: {basename}", "RDKit failed to parse the pose")
        continue

    try:
        # get ligand subgraph
        lig = LigandFeature(mol=mol)
        lig.lig_to_graph()

        # create cplx graph
        cplx = ComplexGraph(rec, lig)
        cplx_graph = cplx.get_cplx_graph()
    except Exception as e:
        # Best-effort: one bad pose should not abort scoring of the others.
        print(f"Error: {basename}", e)
        continue

    # Append only after the whole pose succeeded, so both lists stay aligned.
    cplx_graphs.append(cplx_graph)
    keys.append(basename)

if not keys:
    raise RuntimeError("No pose could be converted into a complex graph.")

# Every complex is paired with the same pocket graph object.
pock_graphs = [pock_g] * len(keys)

# create datasets
dataset = MyDataset(pock_graphs, cplx_graphs)
# shuffle=False keeps predictions in the same order as `keys`.
test_loader = GraphDataLoader(dataset, batch_size=32, shuffle=False)

pose-0
pose-1
pose-2
pose-3
pose-4
pose-5
pose-6
pose-7
pose-8
pose-9


# Scoring

In [7]:
# Score every pose in `test_loader` on the CPU.
# NOTE(review): presumably returns array-likes with one predicted RMSD and one
# predicted pKd per pose, in loader order (the loader is not shuffled), which
# the next cell assumes matches `keys` -- confirm in utils.run_an_eval_epoch.
pred_rmsd, pred_pkd = run_an_eval_epoch(model, test_loader, device="cpu")

In [8]:
# Assemble the per-pose predictions into a two-column table labelled by `keys`,
# then rank the poses by predicted pKd, highest first.
rmsd_col = pred_rmsd.reshape(-1, 1)
pkd_col = pred_pkd.reshape(-1, 1)
values = np.hstack((rmsd_col, pkd_col))
df = (
    pd.DataFrame(values, index=keys, columns=["pred_rmsd", "pred_pkd"])
    .sort_values(by="pred_pkd", ascending=False)
)
df

Unnamed: 0,pred_rmsd,pred_pkd
1bcu-1,0.405091,4.184812
1bcu-2,0.64705,4.100025
1bcu-3,2.626471,2.096566
1bcu-8,3.763524,1.701476
1bcu-5,3.84161,1.556922
1bcu-4,3.638276,1.068059
1bcu-6,4.976994,-0.403105
1bcu-7,5.079344,-0.465075
1bcu-10,7.302489,-2.573426
1bcu-9,7.514856,-2.740049
