In [1]:
import pandas as pd
from collections import Counter
import os 
from tqdm import trange
import warnings
warnings.filterwarnings('ignore')
from feature_fusion import ProteinFeatureExtractor,SMILESFeatureExtractor
import esm
print(os.getcwd())
import torch
import numpy as np
from scipy.sparse import coo_matrix
#from utils import *
import yaml
from Structe_DPP_HyperGraph import HyGraph_Matrix_DPP_Structure
import torch.nn.functional as F
#import wandb
from model import *
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from hypergraph_utils import generate_G_from_H
from hypergraph_utils import construct_H_with_KNN
from sklearn.metrics import roc_auc_score, f1_score
from scipy.spatial.distance import pdist, squareform
from sklearn.metrics.pairwise import cosine_similarity as cos
from rdkit import Chem
from rdkit.Chem import AllChem, DataStructs

/data/zyf/HyperGCN-DTI/codes


In [2]:

def calculate_similarity_filter(data, similarity_threshold=0.8):
    """
    根据化合物 SMILES 计算相似度并过滤
    
    参数:
    - data: pd.DataFrame，必须包含 'SMILES' 列
    - similarity_threshold: float，相似度阈值
    
    返回:
    - pd.DataFrame：过滤后的 DataFrame
    """
    # 将 SMILES 转换为 RDKit 分子对象
    data['mol'] = data['SMILES'].map(lambda x: Chem.MolFromSmiles(x))
    
    # 计算分子指纹
    data['fingerprint'] = data['mol'].map(lambda x: AllChem.GetMorganFingerprintAsBitVect(x, radius=2) if x else None)
    
    # 删除无法解析的分子
    data = data[~data['fingerprint'].isna()].reset_index(drop=True)

    # 计算相似度矩阵
    fingerprints = list(data['fingerprint'])
    num_molecules = len(fingerprints)
    similarity_matrix = np.zeros((num_molecules, num_molecules))

    for i in range(num_molecules):
        for j in range(i + 1, num_molecules):
            similarity = DataStructs.FingerprintSimilarity(fingerprints[i], fingerprints[j])
            similarity_matrix[i, j] = similarity
            similarity_matrix[j, i] = similarity

    # 过滤相似度低于阈值的化合物
    to_keep = []
    for i in range(num_molecules):
        if all(similarity_matrix[i, j] < similarity_threshold for j in range(i)):
            to_keep.append(i)

    filtered_data = data.iloc[to_keep].drop(['mol', 'fingerprint'], axis=1)
    return filtered_data




In [3]:

new_bdb = pd.read_csv('../data/BindingDB/bdb_202501.csv',low_memory=False)
old_bdb = pd.read_csv('../data/BindingDB/bdb_202310.csv',low_memory=False) 
old_bdb = old_bdb[old_bdb['Curation/DataSource'] == 'Curated from the literature by BindingDB']## 28456


In [4]:
# 确保列不为空
new_bdb = new_bdb[new_bdb['IC50 (nM)'].notna()]

# 遍历数据
for i in range(len(new_bdb)):
    try:
        value = str(new_bdb.iloc[i, 5])  # 修改列索引为正确值
        if '>' in value or '<' in value:
            new_bdb.iloc[i, 5] = float(value[1:])  # 去掉符号后转换
        else:
            new_bdb.iloc[i, 5] = float(value)  # 正常转换
    except ValueError:
        print(f"错误值：{new_bdb.iloc[i, 5]}")  # 输出无法转换的值


In [5]:
new_bdb = new_bdb[(new_bdb['IC50 (nM)'] <= 100) | (new_bdb['IC50 (nM)'] >= 10000)] # 37349
new_bdb ['label'] = 0
new_bdb.loc[new_bdb['IC50 (nM)'] <= 100, 'label'] = 1
new_bdb = new_bdb.rename(columns={'Ligand SMILES': 'SMILES','BindingDB Target Chain Sequence':'protein_sequence'})
old_bdb = old_bdb.rename(columns={'mol': 'SMILES','BindingDB Target Chain Sequence':'protein_sequence'})

In [6]:
def stat_count(BindingDB):
    BindingDB_mol = list(BindingDB['SMILES'].unique())
    BindingDB_mol = list(BindingDB['PubChem CID of Ligand'].unique())
    print('numbers of BindingDB mol:',len(BindingDB_mol))
    BindingDB_target = list(BindingDB['UniProt (SwissProt) Primary ID of Target Chain'].unique())
    print('numbers of BindingDB targets:',len(BindingDB_target))

In [7]:
#数据预处理
def data_proprecessing(data):
    data['protein_sequence']=data['protein_sequence'].map(lambda x: x.replace(' ',''))
    data['protein_sequence']=data['protein_sequence'].map(lambda x: x.upper())
    data['counts'] = data['SMILES'].map(lambda x: x.lower().count('c'))
    #去掉无机物
    
    data = data[data.counts > 3 ].reset_index(drop = True)
    data  = data.drop('counts',axis=1)
    # 调用过滤函数
    filtered_data = calculate_similarity_filter(data, similarity_threshold=0.7)
    return filtered_data

In [None]:
stat_count(new_bdb)
stat_count(old_bdb)
new_bdb = data_proprecessing(new_bdb)
old_bdb = data_proprecessing(old_bdb)
old_bdb.to_csv('../data/BindingDB/old_bdb.csv',index= False)
new_bdb.to_csv('../data/BindingDB/new_bdb.csv',index= False)
mol = pd.DataFrame(new_bdb['SMILES'].unique(),columns=['SMILES']) #24243
target = pd.DataFrame(new_bdb['protein_sequence'].unique(),columns=['protein_sequence']) #1338
mol.to_csv('../data/BindingDB/drug_smiles.csv',index=False)
target.to_csv('../data/BindingDB/protein_seq.csv',index=False)
extractor = ProteinFeatureExtractor(model_name='esm2_t6_8M_UR50D')
extractor.extract_features_from_csv('../data/BindingDB/protein_seq.csv')
extractor = SMILESFeatureExtractor(model_name='DeepChem/ChemBERTa-77M-MTR')
extractor.process_and_save_features('../data/BindingDB/drug_smiles.csv')

In [8]:
def construct_bdb_dataset(new_bdb):
    # Step 1: 获取唯一药物和蛋白质
    unique_drugs = new_bdb['SMILES'].unique()
    unique_proteins =new_bdb['protein_sequence'].unique()

    # Step 2: 为药物和蛋白质分配索引
    drug_to_index = {drug: i for i, drug in enumerate(unique_drugs)}
    protein_to_index = {protein: i for i, protein in enumerate(unique_proteins)}

    # Step 3: 替换药物和蛋白质为索引
    new_bdb['drug_index'] = new_bdb['SMILES'].map(drug_to_index)
    new_bdb['protein_index'] = new_bdb['protein_sequence'].map(protein_to_index)
    # Step 1: 获取唯一药物和蛋白质
    unique_drugs = new_bdb['SMILES'].unique()
    unique_proteins =new_bdb['protein_sequence'].unique()

    # Step 2: 为药物和蛋白质分配索引
    drug_to_index = {drug: i for i, drug in enumerate(unique_drugs)}
    protein_to_index = {protein: i for i, protein in enumerate(unique_proteins)}

    # Step 3: 替换药物和蛋白质为索引
    new_bdb['drug_index'] = new_bdb['SMILES'].map(drug_to_index)
    new_bdb['protein_index'] = new_bdb['protein_sequence'].map(protein_to_index)
    
    num_drug = len(unique_drugs)
    num_protein = len(unique_proteins)
    interaction_matrix = np.zeros((num_drug, num_protein), dtype=int)
    for _, row in new_bdb.iterrows():
        interaction_matrix[row['drug_index'], row['protein_index']] = row['label']
    
    # 将数据集写入文件
    data_set = np.array(new_bdb[['drug_index','protein_index','label']])
    with open("../data/BindingDB/dti_index.txt", "w", encoding="utf-8") as f:
        for i in data_set:
            f.write(f"{i[0]}\t{i[1]}\t{i[2]}\n")


    col1 = data_set[:, 0]
    col2 = data_set[:, 1]


    col1_dict = {}
    col2_dict = {}
    for i, val in enumerate(col1):
        col1_dict.setdefault(val, []).append(i)
    for i, val in enumerate(col2):
        col2_dict.setdefault(val, []).append(i)


    rows, cols = [], []
    for indices in col1_dict.values():
        for i in indices:
            for j in indices:
                if i <= j:
                    rows.append(i)
                    cols.append(j)

    for indices in col2_dict.values():
        for i in indices:
            for j in indices:
                if i <= j:
                    rows.append(i)
                    cols.append(j)

    # 去重
    edges = set(zip(rows, cols))

    # 写入文件
    with open("../data/BindingDB/dtiedge.txt", "w", encoding="utf-8") as f:
        for i, j in edges:
            f.write(f"{i}\t{j}\n")

    node_num = [num_drug, num_protein]
    drug_protein_tensor = torch.Tensor(interaction_matrix)
    protein_drug_tensor = drug_protein_tensor.t()
    return node_num, drug_protein_tensor, protein_drug_tensor, data_set

In [9]:
def get_train_test_index(new_bdb,old_bdb):
    # 创建 drug-protein pairs 列
    new_bdb['drug_protein_pair'] = new_bdb['PubChem CID of Ligand'].astype(str)+ '_' + new_bdb['UniProt (SwissProt) Primary ID of Target Chain']
    old_bdb['drug_protein_pair'] = old_bdb['PubChem CID of Ligand'].astype(str)+ '_' + old_bdb['UniProt (SwissProt) Primary ID of Target Chain']

    #标注测试集
    # 1. 共有的 drug-protein pairs
    common_pairs = new_bdb[new_bdb['drug_protein_pair'].isin(old_bdb['drug_protein_pair'])]
    print("common drug-protein pairs:",len(common_pairs))

    # 2. old drug-new protein pairs
    new_bdb_without_common = new_bdb[~new_bdb['drug_protein_pair'].isin(common_pairs['drug_protein_pair'])]
    old_drugs = old_bdb['PubChem CID of Ligand'].unique()
    new_proteins = new_bdb['UniProt (SwissProt) Primary ID of Target Chain'].unique()
    old_drug_new_protein_pairs = new_bdb_without_common[new_bdb_without_common['PubChem CID of Ligand'].isin(old_drugs) & new_bdb_without_common['UniProt (SwissProt) Primary ID of Target Chain'].isin(new_proteins)]
    print("old drug-new protein pairs:",len(old_drug_new_protein_pairs))
    print()

    # 3. new drug-old protein pairs
    new_drugs = new_bdb['PubChem CID of Ligand'].unique()
    old_proteins = old_bdb['UniProt (SwissProt) Primary ID of Target Chain'].unique()
    new_drug_old_protein_pairs = new_bdb_without_common[new_bdb_without_common['PubChem CID of Ligand'].isin(new_drugs) & new_bdb_without_common['UniProt (SwissProt) Primary ID of Target Chain'].isin(old_proteins)]
    print("new drug-old protein pairs:",len(new_drug_old_protein_pairs))
    # 4. new drug-new protein pairs
    new_drug_new_protein_pairs =new_bdb_without_common[~new_bdb['PubChem CID of Ligand'].isin(old_bdb['PubChem CID of Ligand']) & 
                                                        ~new_bdb_without_common['UniProt (SwissProt) Primary ID of Target Chain'].isin(old_bdb['UniProt (SwissProt) Primary ID of Target Chain'])]
    print("new drug-new protein pairs:",len(new_drug_new_protein_pairs))
    
    train_indeces = np.array(common_pairs.index)
    test_tp_indeces = np.array(old_drug_new_protein_pairs.index)
    test_td_indeces = np.array(new_drug_old_protein_pairs.index)
    test_tn_indeces = np.array(new_drug_new_protein_pairs.index)
    return train_indeces,test_tp_indeces,test_td_indeces,test_tn_indeces

In [18]:
with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)
#config['device'] = "cuda" if torch.cuda.is_available() else "cpu"
#config['device'] = "cpu"

setup(config['seed'])


reg_loss_co = 0.0002
fold = 0

torch.set_default_dtype(torch.float32)

In [11]:

'''1:HyperDrug & HyperProtein features
    2:sequence feature from pre-trained LLM model
3:HyperDrug-Disease & HyperProtein-Disease features
    input:list,contains types of features ''' 

'为了后续图架构，保证使用HyperDrug和HyperProtein特征' 
new_bdb = pd.read_csv('../data/BindingDB/new_bdb.csv')
old_bdb = pd.read_csv('../data/BindingDB/old_bdb.csv')
train_indeces,test_tp_indeces,test_td_indeces,test_tn_indeces = get_train_test_index(new_bdb,old_bdb)
node_num, drug_protein, protein_drug, dtidata = construct_bdb_dataset(new_bdb)

drug_protein_eye = torch.cat((drug_protein, torch.eye(node_num[0])), dim=1)
protein_drug_eye = torch.cat((protein_drug, torch.eye(node_num[1])), dim=1)
HyGraph_Drug = generate_G_from_H(drug_protein_eye).to(config['device'])
HyGraph_protein = generate_G_from_H(protein_drug_eye).to(config['device'])
#print('HyGraph_Drug,HyGraph_protein:',HyGraph_Drug.shape,HyGraph_protein.shape)

if  config['feature_list'] ==[1,2]:
    
    hd = pd.read_csv(f'/data/zyf/HyperGCN-DTI/data/BindingDB/drug_smiles.csv')
    hp = pd.read_csv(f'/data/zyf/HyperGCN-DTI/data/BindingDB/protein_seq.csv')

    features_d = torch.tensor(hd.iloc[:,1:].values,dtype=torch.float32).to(config['device'])
    features_p = torch.tensor(hp.iloc[:,1:].values,dtype=torch.float32).to(config['device'])
    print('load LLM features')

'''common drug-protein pairs: 6881
old drug-new protein pairs: 40

new drug-old protein pairs: 2631
new drug-new protein pairs: 108'''

common drug-protein pairs: 6881
old drug-new protein pairs: 40

new drug-old protein pairs: 2631
new drug-new protein pairs: 108
load LLM features


'common drug-protein pairs: 6881\nold drug-new protein pairs: 40\n\nnew drug-old protein pairs: 2631\nnew drug-new protein pairs: 108'

In [12]:
dti_label = torch.tensor(dtidata[:, 2:3]).to(config['device'])
drug_protein = drug_protein.to(config['device'])
protein_drug = protein_drug.to(config['device'])
HyGraph_Structure_DPP = HyGraph_Matrix_DPP_Structure(dtidata, node_num[0], node_num[1],'BindingDB' )
HyGraph_Structure_DPP = HyGraph_Structure_DPP.to(config['device'])
data = dtidata
label = dti_label

In [13]:
print(node_num,drug_protein.shape,protein_drug.shape,dtidata.shape,features_d.shape,features_p.shape,HyGraph_Drug.shape,HyGraph_protein.shape)
print(type(drug_protein),type(protein_drug),type(dtidata),type(features_d),type(features_p),type(HyGraph_Drug),type(HyGraph_protein))

[9623, 832] torch.Size([9623, 832]) torch.Size([832, 9623]) (9623, 3) torch.Size([9623, 384]) torch.Size([832, 320]) torch.Size([9623, 9623]) torch.Size([832, 832])
<class 'torch.Tensor'> <class 'torch.Tensor'> <class 'numpy.ndarray'> <class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'> <class 'torch.Tensor'>


In [14]:
print(len(train_indeces),len(test_tp_indeces),len(test_td_indeces),len(test_tn_indeces),len(new_bdb))

6881 40 2631 108 9623


In [15]:


def train(model, optim, train_index, epoch):
    model.train()
    out, d, p = model(node_num, features_d, features_p, protein_drug, drug_protein, HyGraph_Drug, HyGraph_protein, train_index, data, HyGraph_Structure_DPP)
    tr_acc = (out.argmax(dim=1) == label[train_index].reshape(-1).long()).sum(dtype=float) / torch.tensor(len(train_index), dtype=float)
    tr_task1_roc = get_roc(out, label[train_index])
    reg = get_L2reg(model.parameters())
    loss = F.nll_loss(out, label[train_index].reshape(-1).long()) + reg_loss_co * reg
    optim.zero_grad()
    loss.backward()
    optim.step()
    tr_acc, tr_task1_roc1, tr_task1_pr, tr_task_precision, tr_task_recall, tr_task1_f1 = main_test(model, d, p, train_indeces)
    return loss.item(), tr_acc, tr_task1_roc1, tr_task1_pr, tr_task_precision, tr_task_recall, tr_task1_f1, d, p



def main_test(model, d, p, test_index):
    model.eval()
    out = model(node_num, features_d, features_p, protein_drug, drug_protein, HyGraph_Drug, HyGraph_protein, test_index, data, HyGraph_Structure_DPP, iftrain=False, d=d, p=p)
    acc1 = (out.argmax(dim=1) == label[test_index].reshape(-1).long()).sum(dtype=float) / torch.tensor(len(test_index), dtype=float)
    task_roc = get_roc(out, label[test_index])
    task_precision,task_recall,task_pr = get_pr(out, label[test_index])
    task_f1 = get_f1score(out, label[test_index])
    return acc1, task_roc, task_pr, task_precision,task_recall,task_f1

def main(train_index, test_tp_indeces,test_td_indeces,test_tn_indeces, seed):

    model = HyperGCNDTI(
        num_protein=node_num[1],
        num_drug=node_num[0],
        num_hidden1=config['in_size'],
        num_hidden2=config['hidden_size'],
        num_out=config['out_size'],
        feature_list= config['feature_list']
    ).to(config['device'])
        
    # model.load_state_dict(torch.load(f"{dir}/net{i}.pth"))
    optim = torch.optim.Adam(lr=config['lr'], weight_decay= float(config['weight_decay']), params=model.parameters())
    best_roc =0
    best_results = []
    model_path = os.path.join(config['save_dir'], f"{config['feature_list']}_dataset_BindingDB_best_model_roc.pth")

    for epoch in tqdm(range(config['epochs'])):
        loss, acc, task1_roc1, task1_pr, task1_precision, task1_recall, task1_f1, d, p = train(model, optim, train_index, epoch)
        if task1_roc1 > best_roc:
            best_roc = task1_roc1
            best_model_state = model.state_dict()  # Update the best model state
            torch.save(best_model_state, model_path)
            best_results = acc, task1_roc1, task1_pr, task1_precision, task1_recall, task1_f1
            best_results = list(tuple(f"{value:.4f}" for value in best_results))

    # 加载最佳模型并测试
    print("Training finished!", best_results)
    best_model_state = torch.load(model_path)
    model.load_state_dict(best_model_state)
    model.eval()

    # 测试结果
    results = {}

    # 训练集结果
    tr_acc, tr_task1_roc1, tr_task1_pr, tr_task_precision, tr_task_recall, tr_task1_f1 = main_test(model, d, p, train_index)
    results["train"] = {
        "accuracy": tr_acc,
        "roc": tr_task1_roc1,
        "pr": tr_task1_pr,
        "precision": tr_task_precision,
        "recall": tr_task_recall,
        "f1": tr_task1_f1
    }

    # 测试集 tp
    tp_acc, tp_task1_roc1, tp_task1_pr, tp_task_precision, tp_task_recall, tp_task1_f1 = main_test(model, d, p, test_tp_indeces)
    results["test_tp"] = {
        "accuracy": tp_acc,
        "roc": tp_task1_roc1,
        "pr": tp_task1_pr,
        "precision": tp_task_precision,
        "recall": tp_task_recall,
        "f1": tp_task1_f1
    }

    # 测试集 td
    td_acc, td_task1_roc1, td_task1_pr, td_task_precision, td_task_recall, td_task1_f1 = main_test(model, d, p, test_td_indeces)
    results["test_td"] = {
        "accuracy": td_acc,
        "roc": td_task1_roc1,
        "pr": td_task1_pr,
        "precision": td_task_precision,
        "recall": td_task_recall,
        "f1": td_task1_f1
    }

    # 测试集 tn
    tn_acc, tn_task1_roc1, tn_task1_pr, tn_task_precision, tn_task_recall, tn_task1_f1 = main_test(model, d, p, test_tn_indeces)
    results["test_tn"] = {
        "accuracy": tn_acc,
        "roc": tn_task1_roc1,
        "pr": tn_task1_pr,
        "precision": tn_task_precision,
        "recall": tn_task_recall,
        "f1": tn_task1_f1
    }

    # 保存结果
    df_results = pd.DataFrame(results).T
    df_results.index.name = "dataset"
    df_results.to_csv(os.path.join(config['results_dir'], f"BindingDB_{config['feature_list']}_results.csv"), index=False)
    return df_results



In [20]:
main(train_indeces, test_tp_indeces,test_td_indeces,test_tn_indeces, config['seed'])

100%|██████████| 100/100 [20:07<00:00, 12.08s/it]


Training finished! ['0.9943', '1.0000', '1.0000', '1.0000', '0.9854', '0.9927']


Unnamed: 0_level_0,accuracy,roc,pr,precision,recall,f1
dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
train,"tensor(0.9951, device='cuda:0', dtype=torch.fl...",1.0,1.0,1.0,0.987313,0.993616
test_tp,"tensor(1., device='cuda:0', dtype=torch.float64)",1.0,1.0,1.0,1.0,1.0
test_td,"tensor(0.9962, device='cuda:0', dtype=torch.fl...",0.999946,0.999824,0.993506,0.990291,0.991896
test_tn,"tensor(1., device='cuda:0', dtype=torch.float64)",1.0,1.0,1.0,1.0,1.0
