In [1]:
import os
import shutil
import sys

# Put the repository root on sys.path so the project packages below can be
# imported. The root is taken to be the parent of the folder holding this
# notebook. A plain Jupyter kernel does not define `__file__`, so fall back to
# the current working directory there (assumed to be the notebook's folder).
_notebook_dir = (os.path.dirname(os.path.abspath(__file__))
                 if '__file__' in globals() else os.getcwd())
sys.path.append(os.path.dirname(_notebook_dir))

from collections import defaultdict
from tqdm.notebook import tqdm

from einops import rearrange

# `np` and `torch` are used below; they presumably also arrive via the
# `utils.common` star import, but importing them explicitly makes the
# dependency visible.
import numpy as np
import torch

from model.layers import LigPoseStruct
from utils.pdbbind_utils import ComplexStructDataset, collate_struct, batch_index_select, pred_ens, calc_rmsd
from utils.common import *

In [2]:
# Run configuration: edit these before executing the evaluation cell.
weight_path = '../example/LigPose_param.chk'  # trained checkpoint, relative to the working directory
seed = 7           # forwarded to set_all_seed
ens = 5            # ensemble size: copies of each complex, also used as the batch size
device = 'cuda:0'  # torch device string; the evaluation cell has a dedicated 'cpu' branch

In [5]:
# --- Runtime setup -----------------------------------------------------------
# Dispatch on the parsed device type rather than the raw string: CPU runs cap
# torch's intra-op thread pool, CUDA runs select the current GPU, and other
# backends (e.g. 'mps') need neither call.
device_type = torch.device(device).type
if device_type == 'cpu':
    torch.set_num_threads(16)  # intra-op CPU threads used for inference
elif device_type == 'cuda':
    torch.cuda.set_device(device)
set_all_seed(seed)

# --- Load the trained structure model ----------------------------------------
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if this
# checkpoint pickles non-tensor objects (e.g. whatever is stored under
# 'struct_args'), loading may fail there. Pass weights_only=False only for
# checkpoints you trust.
chk = torch.load(weight_path, map_location=device)
args = chk['struct_args']
model = LigPoseStruct(args).to(device)
model.load_state_dict(chk['struct_state_dict'], strict=True)
model.eval()  # equivalent to model.train(False)

# --- Build the test loader ---------------------------------------------------
# Every complex id is repeated `ens` times and batches are `ens` long with no
# shuffling, so each batch holds the ensemble copies of a single complex
# (assuming ComplexStructDataset keeps one item per entry of test_list).
test_list = [idx for idx in load_idx_list('pdbbind/core_test/test_list.txt') for _ in range(ens)]
test_dataset = ComplexStructDataset('test', args, test_list)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=ens, num_workers=2,
                                          shuffle=False, persistent_workers=False,
                                          collate_fn=collate_struct)

# --- Evaluate ----------------------------------------------------------------
rmsd_cutoff = 2  # success threshold in Å, matching the printed "RMSD < 2A" label

with torch.no_grad():

    dic_eval = defaultdict(list)

    for dic_data in tqdm(test_loader):
        dic_data = dic_data.to(device)

        tup_pred = model(dic_data)
        # tup_pred[0][-1] is presumably the final refinement step's coordinates
        # -- confirm against LigPoseStruct.forward.
        ens_pred = pred_ens(tup_pred[0][-1], dic_data)

        # Gather ground-truth coordinates at the sampled nodes, keep only the
        # ligand atoms, and split per ensemble copy -> (ens, n_atom, 3).
        coor_true = batch_index_select(dic_data.coor_true, dic_data.node_sampling_loc[dic_data.cycle_i])
        coor_true = rearrange(coor_true, 'b n c -> (b n) c')[dic_data.ligand_node_loc_after_sampling_flat].reshape(ens, -1, 3)
        # Presumably undoes the training-time coordinate scaling so RMSD is in Å.
        ens_pred, coor_true = ens_pred * args.coor_scale, coor_true * args.coor_scale
        # Every copy in the batch is the same complex, so the first copy's atom
        # matching is used for all of them.
        ligand_match = dic_data.ligand_match.reshape(ens, -1)[0]

        # Score all ensemble predictions against the first copy's ground truth
        # (calc_rmsd returns [match, 1]) and keep the minimum for this complex.
        rmsd_match_ens, _, _ = calc_rmsd(ens_pred.unsqueeze(0), coor_true[[0]], match=ligand_match)
        rmsd_value = rmsd_match_ens.min()

        dic_eval['idx'].append(dic_data['idx'][0])
        dic_eval['rmsd_value'].append(rmsd_value.item())


# --- Summarize ---------------------------------------------------------------
dic_eval['rmsd_value'] = np.array(dic_eval['rmsd_value'])
n_total = len(dic_eval['rmsd_value'])
n_succ = int((dic_eval['rmsd_value'] < rmsd_cutoff).sum())
# Report NaN instead of dividing 0 by 0 when the loader produced nothing.
succ_rate = n_succ / n_total if n_total else float('nan')
print(f"RMSD < {rmsd_cutoff}A: {n_succ}/{n_total}, {succ_rate*100:.2f}%")

  0%|          | 0/285 [00:00<?, ?it/s]

SystemExit: 