In [1]:
import argparse
import os
import random
import time
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from dgl.dataloading import GraphDataLoader
from sklearn.metrics import mean_absolute_error, mean_squared_error

import sys
sys.path.append("/home/yujie/AIcode/")
from Dataloader.dataloader import collate_fn, LeadOptDataset, LeadOptDataset_test
from ReadoutModel.readout_bind import DMPNN
from utilis.function import get_loss_func
from utilis.initial import initialize_weights
from utilis.scalar import StandardScaler
from utilis.scheduler import NoamLR_shan
from utilis.trick import Writer
from utilis.utilis import gm_process


In [2]:
import torch.nn.functional as F

In [3]:
# Load the serialized model from disk, mapping its tensors onto the CPU.
# The loaded object is used directly as a callable module in a later cell
# (model.eval() / model(graph1, graph2, pock)).
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Presumably the model class (DMPNN, imported in the first
# cell) must be importable for the unpickling to succeed - confirm.
model = torch.load("/home/yujie/leadopt/result_final/xiaorongshiyan_kejieshixing/model_11_115000_1.pth",map_location="cpu")
# model.to('cuda:1')

In [4]:
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from rdkit.Chem import BRICS
# from rdkit.Chem.Draw import IPythonConsole #Needed to show molecules
from rdkit.Chem.Draw.MolDrawing import MolDrawing, DrawingOptions #Only needed if modifying defaults
from rdkit.Chem.Draw import rdMolDraw2D
from IPython.display import Image

import copy
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
from IPython.display import SVG
from IPython.display import display
import matplotlib
import matplotlib.cm as cm

In [5]:
def brics_decomp(mol):
    """Split a molecule into BRICS fragments plus the fragment-level edges.

    Every bond starts out as a two-atom clique. Each bond reported by
    ``BRICS.FindBRICSBonds`` is cut: its clique is removed and both of its
    end atoms are re-added as singleton cliques, so that they still end up in
    a fragment. Overlapping cliques are then merged until no two cliques
    share an atom, i.e. until they are the connected components of the
    molecule with the BRICS bonds removed.

    Args:
        mol: molecule exposing ``GetNumAtoms()`` and ``GetBonds()`` whose
            bonds expose ``GetBeginAtom().GetIdx()`` / ``GetEndAtom().GetIdx()``
            (an RDKit ``Mol`` in this notebook).

    Returns:
        A tuple ``(cliques, edges)``:
            cliques: list of lists of atom indices, one list per fragment.
            edges: list of ``(c1, c2)`` index pairs into ``cliques``, one per
                cut BRICS bond, in the order the bonds were reported.
        A single-atom molecule yields ``([[0]], [])``; a molecule without
        BRICS bonds yields one clique holding every atom index and no edges.
    """
    n_atoms = mol.GetNumAtoms()
    if n_atoms == 1:
        return [[0]], []

    # One two-atom clique per bond, in the molecule's bond order.
    cliques = [
        [bond.GetBeginAtom().GetIdx(), bond.GetEndAtom().GetIdx()]
        for bond in mol.GetBonds()
    ]

    res = list(BRICS.FindBRICSBonds(mol))
    if len(res) == 0:
        return [list(range(n_atoms))], []

    # Cut every BRICS bond. The bond may be stored in either atom order.
    for bond in res:
        begin, end = bond[0][0], bond[0][1]
        if [begin, end] in cliques:
            cliques.remove([begin, end])
        else:
            cliques.remove([end, begin])
        # Keep both end atoms represented even if they have no other bond.
        cliques.append([begin])
        cliques.append([end])

    # Merge overlapping cliques. A single forward scan is not enough: after
    # cliques[c] absorbs a later clique it may start overlapping a clique
    # that the scan has already passed, so rescan until nothing is absorbed.
    # When one scan already suffices the result is identical to a single pass.
    c = 0
    while c < len(cliques) - 1:
        absorbed = True
        while absorbed:
            absorbed = False
            for k in range(c + 1, len(cliques)):
                if len(set(cliques[c]) & set(cliques[k])) > 0:
                    cliques[c] = list(set(cliques[c]) | set(cliques[k]))
                    cliques[k] = []
                    absorbed = True
        cliques = [clique for clique in cliques if len(clique) > 0]
        c += 1

    # After the merge every atom belongs to exactly one clique, so a plain
    # atom -> clique-index lookup identifies both ends of each cut bond.
    owner = {
        atom: idx
        for idx, clique in enumerate(cliques)
        for atom in clique
    }
    edges = [(owner[bond[0][0]], owner[bond[0][1]]) for bond in res]

    return cliques, edges

In [6]:
# define the dataset class and the collate function used by the dataloader
import dgl

from utilis.utilis import Extend, pkl_load



def collate_fn(samples):
    """Turn a list of one-row DataFrame slices into one batch of ligand pairs.

    Each sample must provide the columns ``Ligand1``, ``Ligand2``, ``Lable``,
    ``Lable1``, ``Lable2`` and ``Rank1``. ``Ligand1`` / ``Ligand2`` hold paths
    that are read with ``pkl_load``; the pocket is read from ``pocket.pkl``
    in the directory of the ``Ligand2`` path.

    Returns:
        An 8-tuple: batched ligand-1 graph, batched ligand-2 graph, batched
        pocket graph, ``Lable`` tensor, ``Lable1`` tensor, ``Lable2`` tensor,
        ``Rank1`` tensor, and the list of parent-directory names taken from
        the ``Ligand1`` paths.
    """
    def first_value(column):
        # First (and, for a one-row slice, only) entry of a column per sample.
        return [getattr(row, column).values[0] for row in samples]

    ligand1_paths = first_value("Ligand1")
    ligand2_paths = first_value("Ligand2")
    pocket_paths = []
    for row in samples:
        folder = row.Ligand2.values[0].rsplit("/", 1)[0]
        pocket_paths.append(folder + "/pocket.pkl")

    # Read everything from disk first: ligand-1 graphs, ligand-2 graphs, pockets.
    ligand1_graphs = list(map(pkl_load, ligand1_paths))
    ligand2_graphs = list(map(pkl_load, ligand2_paths))
    pocket_graphs = list(map(pkl_load, pocket_paths))

    # Disabled hook that zeroes the atom features of one fragment of ligand 2.
#     idx_ = 4
#     graph2_list[0].ndata['atom_feature'][bri[idx_]] = torch.zeros(graph2_list[0].ndata['atom_feature'][bri[idx_]].size())

    batched_ligand1 = dgl.batch(ligand1_graphs)
    batched_ligand2 = dgl.batch(ligand2_graphs)
    batched_pocket = dgl.batch(pocket_graphs)
    # index_kj1, index_ji1 = triplets(g1)
    # index_kj2, index_ji2 = triplets(g2)

    delta_labels = first_value("Lable")      # delta
    valid_labels = first_value("Lable1")     # validation samples' labels
    reference_labels = first_value("Lable2")  # reference train samples' labels

    # Used to identify which validation sample each pair belongs to.
    rank_ids = first_value("Rank1")
    folder_names = []
    for path in ligand1_paths:
        folder_names.append(path.rsplit("/", 2)[1])

    return (
        batched_ligand1,
        batched_ligand2,
        batched_pocket,
        torch.tensor(delta_labels),
        torch.tensor(valid_labels),
        torch.tensor(reference_labels),
        torch.tensor(rank_ids),
        folder_names,
    )


class LeadOptDataset():
    """Ligand-pair table read from a CSV file, served as one-row DataFrame slices.

    The CSV must provide the columns ``Ligand1``, ``Ligand2`` and ``Lable``.
    Indexing returns a one-row DataFrame slice rather than a tuple, which is
    the form the notebook's ``collate_fn`` consumes.

    Args:
        df_path: path of the CSV file describing the ligand pairs.
        label_scalar: how to rescale the ``Lable`` column.
            ``None`` leaves it untouched; the string ``"finetune"`` applies
            the fixed shift/scale stored on the class; any other value is
            treated as a scaler object providing ``fit(label)`` (whose return
            value replaces ``self.label_scalar``) and ``transform(label)``,
            both called with an ``(n, 1)`` array.
        max_rows: number of leading rows to keep after rescaling. The default
            of 1 keeps only the first pair, as this notebook inspects a single
            pair; pass ``None`` to keep every row.

    Raises:
        ValueError: if ``max_rows`` is negative.
    """

    # Fixed constants used by label_scalar="finetune": (label - SHIFT) / SCALE.
    # NOTE(review): presumably the label statistics of the training set the
    # model was fitted on - confirm.
    FINETUNE_LABEL_SHIFT = 0.04191832
    FINETUNE_LABEL_SCALE = 1.34086546

    def __init__(self, df_path, label_scalar=None, max_rows=1):
        super().__init__()
        if max_rows is not None and max_rows < 0:
            raise ValueError(f"max_rows must be None or >= 0, got {max_rows}")

        self.df_path = df_path
        self.df = pd.read_csv(self.df_path)
        self.label_scalar = label_scalar

        if self.label_scalar == "finetune":
            label = np.array(self.df.Lable.values).astype(float)
            self.df["Lable"] = (label - self.FINETUNE_LABEL_SHIFT) / self.FINETUNE_LABEL_SCALE

        elif self.label_scalar is not None:
            # The scaler is fitted on the full table, before any truncation.
            label = np.reshape(self.df.Lable.values, (-1, 1))
            self.label_scalar = self.label_scalar.fit(label)
            label = self.label_scalar.transform(label)
            self.df["Lable"] = label.flatten()

        if max_rows is not None:
            self.df = self.df.iloc[:max_rows]
        print(self.df.Ligand1.values)
        print(self.df.Ligand2.values)

    def file_names_(self):
        """Return the distinct parent-directory names of the ``Ligand1`` paths (unordered)."""
        ligand_dir = self.df.Ligand1.values
        file_names = [s.rsplit("/", 2)[1] for s in ligand_dir]
        return list(set(file_names))

    def __getitem__(self, idx):
        """Return row ``idx`` as a one-row DataFrame slice."""
        return self.df[idx:idx + 1]

    def __len__(self):
        """Return the number of rows kept in the table."""
        return len(self.df)

In [7]:
# Name of the system under study; it selects the pair CSV and the ligand-graph
# directory in the cells that follow.
file_name = "Thrombin"

In [8]:
# Build the dataset from the pair CSV of the selected system. The notebook's
# LeadOptDataset prints the Ligand1 / Ligand2 paths it keeps (see output below).
test_dataset = LeadOptDataset(f"/home/yujie/data_for_fep_for_att/test_set_fep_graph_rmH_/resutls/0_reference/train_files/{file_name}.csv")
# One pair per batch, in file order, using the collate_fn defined in this notebook.
test_dataloader = GraphDataLoader(test_dataset, collate_fn=collate_fn, batch_size=1,
                                           drop_last=False, shuffle=False)

['/home/yujie/leadopt/data/test_set_fep_graph_rmH_I/Thrombin/Thrombin_1a.pkl']
['/home/yujie/leadopt/data/test_set_fep_graph_rmH_I/Thrombin/Thrombin_6a.pkl']


In [9]:
# File stem of the second ligand; used in the next cell to build its graph path.
name = "Thrombin_6a"

In [10]:
# Paths of the two ligand graphs, built by hand instead of read from the CSV.
# NOTE(review): the first ligand's file name is hard-coded to Thrombin_1a.pkl,
# so it does not follow `file_name` if another system is selected.
ligand1_dir = [f'/home/yujie/leadopt/data/test_set_fep_graph_rmH_I/{file_name}/Thrombin_1a.pkl']
ligand2_dir = [f'/home/yujie/leadopt/data/test_set_fep_graph_rmH_I/{file_name}/{name}.pkl']

# Load the pickled graphs (one-element lists) with the project's pkl_load helper.
graph1_list = [pkl_load(s) for s in ligand1_dir]
graph2_list = [pkl_load(s) for s in ligand2_dir]

In [11]:
# Earlier function-style draft of this cell, kept commented out for reference.
# @torch.no_grad()
# def predict(args, model, loader, device):
#     model.eval()

#     # if args.loss_function == 'mve':
#     #     uncertainty = []
#     # elif args.loss_function == "evidential":
#     #     uncertainty = []

#     valid_prediction = []
#     valid_labels = []
#     valid_1_labels = []
#     ref_2_labels = []
#     rank = []
#     file = []

#     att__1 = []
#     att__2 = []
# NOTE(review): `device` is assigned but not used by the active code in this
# cell - the lines that would move the batch to it are commented out below.
device = "cuda:1"
# Switch the model to evaluation mode before the forward pass.
model.eval()
# NOTE(review): the forward pass runs with autograd enabled (the
# @torch.no_grad() decorator above is commented out), which is why the
# displayed logits carry a grad_fn. Confirm whether gradients are needed later.
for batch_data in test_dataloader:

    # Unpack the 8-tuple produced by collate_fn; the file-name list is discarded.
    graph1, graph2, pock, label, label1, label2, rank1, _ = batch_data

#     graph1, graph2, pock, label, label1, label2 = graph1.to(device), graph2.to(device), pock.to(device), label.to(device), label1.to(
#                 device), label2.to(device)
    # Every iteration overwrites these names, so only the outputs of the last
    # batch remain available after the loop.
    logits,_,att1,att2,a1,a2 = model(graph1,
                   graph2, pock)


In [12]:
# Display the model output kept from the last batch of the inference loop.
logits

tensor([[0.6090]], grad_fn=<AddmmBackward0>)