In [76]:

import os
import pandas as pd
import sys
import os
from utils import process
import warnings
from torch_geometric.data import Data, HeteroData
from torch_geometric.utils.convert import to_dgl
from torch_geometric.loader import LinkNeighborLoader, HGTLoader
from torch_geometric.sampler import NeighborSampler,NegativeSampling
import dgl
import torch
from torch_geometric.utils import degree

class PreData:

    def __init__(self, data_folder_path):
        if not os.path.exists(data_folder_path):
            os.mkdir(data_folder_path)

        self.data_folder = data_folder_path  # the data folder, contains the kg.csv

    def prepare_split(self, split='complex_disease',
                      disease_eval_idx=None,
                      seed=42,
                      no_kg=False,
                      test_size=0.05,
                      mask_ratio=0.1,
                      one_hop=False):

        if split not in ['random', 'complex_disease', 'complex_disease_cv', 'disease_eval', 'cell_proliferation',
                         'mental_health', 'cardiovascular', 'anemia', 'adrenal_gland', 'autoimmune',
                         'metabolic_disorder', 'diabetes', 'neurodigenerative', 'full_graph', 'downstream_pred',
                         'few_edeges_to_kg', 'few_edeges_to_indications']:
            raise ValueError(
                "Please select one of the following supported splits: 'random', 'complex_disease', 'disease_eval', 'cell_proliferation', 'mental_health', 'cardiovascular', 'anemia', 'adrenal_gland'")

        if disease_eval_idx is not None:
            split = 'disease_eval'
            print('disease eval index is not none, use the individual disease split...')
        self.split = split

        if split in ['cell_proliferation', 'mental_health', 'cardiovascular', 'anemia', 'adrenal_gland', 'autoimmune',
                     'metabolic_disorder', 'diabetes', 'neurodigenerative']:

            if test_size != 0.05:
                folder_name = split + '_kg' + '_frac' + str(test_size)
            elif one_hop:
                folder_name = split + '_kg' + '_one_hop_ratio' + str(mask_ratio)
            else:
                folder_name = split + '_kg'

            if not os.path.exists(os.path.join(self.data_folder, folder_name)):
                os.mkdir(os.path.join(self.data_folder, folder_name))
            kg_path = os.path.join(self.data_folder, folder_name, 'kg_directed.csv')
        else:
            kg_path = os.path.join(self.data_folder, 'kg_directed.csv')

        if os.path.exists(kg_path):
            print('Found saved processed KG... Loading...')
            df = pd.read_csv(kg_path)
        else:
            if os.path.exists(os.path.join(self.data_folder, 'kg.csv')):
                print('First time usage... Mapping TxData raw KG to directed csv... it takes several minutes...')
                process.preprocess_kg(self.data_folder, split, test_size, one_hop, mask_ratio)
                df = pd.read_csv(kg_path)
            else:
                raise ValueError("KG file path does not exist...")

        if split == 'disease_eval':
            split_data_path = os.path.join(self.data_folder, self.split + '_' + str(disease_eval_idx))
        elif split == 'downstream_pred':
            split_data_path = os.path.join(self.data_folder, self.split + '_downstream_pred')
            disease_eval_idx = [11394., 6353., 12696., 14183., 12895., 9128., 12623., 15129.,
                                12897., 12860., 7611., 13113., 4029., 14906., 13438., 13177.,
                                13335., 12896., 12879., 12909., 4815., 12766., 12653.]
        elif no_kg:
            split_data_path = os.path.join(self.data_folder, self.split + '_no_kg_' + str(seed))
        elif test_size != 0.05:
            split_data_path = os.path.join(self.data_folder, self.split + '_' + str(seed)) + '_frac' + str(test_size)
        elif one_hop:
            split_data_path = os.path.join(self.data_folder, self.split + '_' + str(seed)) + '_one_hop_ratio' + str(
                mask_ratio)
        else:
            split_data_path = os.path.join(self.data_folder, self.split + '_' + str(seed))

        if no_kg:
            sub_kg = ['off-label use', 'indication', 'contraindication']
            df = df[df.relation.isin(sub_kg)].reset_index(drop=True)

        if not os.path.exists(os.path.join(split_data_path, 'train.csv')):
            if not os.path.exists(split_data_path):
                os.mkdir(split_data_path)
            print('Creating splits... it takes several minutes...')
            df_train, df_valid, df_test = process.create_split(df, split, disease_eval_idx, split_data_path, seed)
        else:
            print('Splits detected... Loading splits....')
            df_train = pd.read_csv(os.path.join(split_data_path, 'train.csv'))
            df_valid = pd.read_csv(os.path.join(split_data_path, 'valid.csv'))
            df_test = pd.read_csv(os.path.join(split_data_path, 'test.csv'))

        if split not in ['random', 'complex_disease', 'complex_disease_cv', 'disease_eval', 'full_graph',
                         'downstream_pred', 'few_edeges_to_indications', 'few_edeges_to_kg']:
  
            df_test = process.process_disease_area_split(self.data_folder, df, df_test, split)

        print('Creating PyG graph....')
        #create pyg graph
        g = process.create_pyg_graph(df_test,df)


        self.G = g
        self.df, self.df_train, self.df_valid, self.df_test = df, df_train, df_valid, df_test
        self.disease_eval_idx = disease_eval_idx
        self.no_kg = no_kg
        self.seed = seed
        print('Done!')
        return g,df, df_train, df_valid, df_test,disease_eval_idx,no_kg


    def retrieve_id_mapping(self):
        df = self.df
        df['x_id'] = df.x_id.apply(lambda x: process.convert2str(x))
        df['y_id'] = df.y_id.apply(lambda x: process.convert2str(x))

        idx2id_drug = dict(df[df.x_type == 'drug'][['x_idx', 'x_id']].drop_duplicates().values)
        idx2id_drug.update(dict(df[df.y_type == 'drug'][['y_idx', 'y_id']].drop_duplicates().values))

        idx2id_disease = dict(df[df.x_type == 'disease'][['x_idx', 'x_id']].drop_duplicates().values)
        idx2id_disease.update(dict(df[df.y_type == 'disease'][['y_idx', 'y_id']].drop_duplicates().values))

        df_ = pd.read_csv(os.path.join(self.data_folder, 'kg.csv'))
        df_['x_id'] = df_.x_id.apply(lambda x: process.convert2str(x))
        df_['y_id'] = df_.y_id.apply(lambda x: process.convert2str(x))

        id2name_disease = dict(df_[df_.x_type == 'disease'][['x_id', 'x_name']].drop_duplicates().values)
        id2name_disease.update(dict(df_[df_.y_type == 'disease'][['y_id', 'y_name']].drop_duplicates().values))

        id2name_drug = dict(df_[df_.x_type == 'drug'][['x_id', 'x_name']].drop_duplicates().values)
        id2name_drug.update(dict(df_[df_.y_type == 'drug'][['y_id', 'y_name']].drop_duplicates().values))

        return {'id2name_drug': id2name_drug,
                'id2name_disease': id2name_disease,
                'idx2id_disease': idx2id_disease,
                'idx2id_drug': idx2id_drug
                }


In [77]:
PreData = PreData(data_folder_path='/data/zhaojingtong/PrimeKG/data_all/')
g,df, df_train, df_valid, df_test,disease_eval_idx,no_kg = PreData.prepare_split(split='random', seed=42, no_kg=False)
g = process.initialize_node_embedding(g, 128)
data = g

Found saved processed KG... Loading...


  df = pd.read_csv(kg_path)


Splits detected... Loading splits....


  df_train = pd.read_csv(os.path.join(split_data_path, 'train.csv'))
  df_valid = pd.read_csv(os.path.join(split_data_path, 'valid.csv'))
  df_test = pd.read_csv(os.path.join(split_data_path, 'test.csv'))


Creating PyG graph....
Done!


In [89]:
import torch.nn.functional as F
import torch.nn as nn

def compute_edge_score(src_h, dst_h, rel_emb):
    # 计算边上的得分，类似于 apply_edges 的功能
    score = torch.sum(src_h * rel_emb * dst_h, dim=1)
    return score
w_rels = nn.Parameter(torch.Tensor(len(data.edge_types), 128))
W = w_rels
nn.init.xavier_uniform_(w_rels, gain=nn.init.calculate_gain('relu'))
rel2idx = dict(zip(data.edge_types, list(range(len(data.edge_types)))))
lp_loss = 0
from models import HeteroConvLayers
hetero_conv = HeteroConvLayers( data, 128,1,0.2)
for edge_type in data.edge_types:
    edge_type_loss = 0
    scores_list = []
    label_list = []
    neg_sampling_config = NegativeSampling(
        mode='triplet',  # 使用 binary 模式
        dst_weight=torch.pow(degree(data[edge_type].edge_index[1]), 0.75).float()
    )
    loader = LinkNeighborLoader(
        data,
        num_neighbors=[30],  # 为每个关系采样邻居个数
        batch_size=1024,
        edge_label_index=(edge_type, data[edge_type].edge_index),
        neg_sampling=neg_sampling_config,
        neg_sampling_ratio=1.0,  # 正负样本比例
        edge_label=torch.ones(data[edge_type].edge_index.size(1))
    )
    for batch in loader:
        h_1 = hetero_conv(data, LP=False)
        h1  = {key: F.relu(h)  for key, h in h_1.items()}
        
        src_h = data[edge_type[0]].x[batch[edge_type]['edge_label_index'][0]]  # 源节点的嵌入
        dst_h = data[edge_type[2]].x[batch[edge_type]['edge_label_index'][1]]  # 目标节点的嵌入
        print(h1)
        # 获取当前关系类型的嵌入
        rel_idx = rel2idx[edge_type]
        rel_emb = W[rel_idx]  # 关系嵌入
        
        # 计算得分
        score = compute_edge_score(src_h, dst_h, rel_emb)
        scores_list.append(score)
        label_list.append(batch[edge_type]['edge_label'])
    all_scores = torch.cat(scores_list)
    all_label = torch.cat(label_list)
    loss = F.binary_cross_entropy(torch.sigmoid(all_scores), all_label)
    edge_type_loss += loss
    edge_type_loss = edge_type_loss/len(loader)
    # lp_loss+=edge_type_loss
    
    print(edge_type_loss)

{'gene/protein': tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0143, 0.0000, 0.0217,  ..., 0.0000, 0.0177, 0.0057],
         [0.0000, 0.0050, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0157, 0.0000, 0.0000,  ..., 0.0043, 0.0456, 0.0069],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0113, 0.0000,  ..., 0.0347, 0.0000, 0.0165]]],
       grad_fn=<ReluBackward0>), 'disease': tensor([[[0.0116, 0.0000, 0.0011,  ..., 0.0069, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0052, 0.0108, 0.0115],
         [0.0112, 0.0000, 0.0029,  ..., 0.0010, 0.0000, 0.0187],
         [0.0009, 0.0000, 0.0070,  ..., 0.0000, 0.0000, 0.0000]]],
       grad_fn=<ReluBackward0>), 'drug': tensor([[[0.0142, 0.0000, 0.0167,  ..., 0.0000, 0.0000, 0.0077],
         [0.0000, 0.0

KeyboardInterrupt: 

In [96]:
data['drug'].x.shape

torch.Size([7957, 128])

In [100]:
h1['drug'][0].shape

torch.Size([7957, 128])

In [79]:
batch

HeteroData(
  anatomy={
    num_nodes=0,
    x=[0, 128],
    n_id=[0],
  },
  biological_process={
    num_nodes=0,
    x=[0, 128],
    n_id=[0],
  },
  cellular_component={
    num_nodes=0,
    x=[0, 128],
    n_id=[0],
  },
  disease={
    num_nodes=613,
    x=[613, 128],
    n_id=[613],
  },
  drug={
    num_nodes=4559,
    x=[4559, 128],
    n_id=[4559],
  },
  effect/phenotype={
    num_nodes=403,
    x=[403, 128],
    n_id=[403],
  },
  exposure={
    num_nodes=0,
    x=[0, 128],
    n_id=[0],
  },
  gene/protein={
    num_nodes=347,
    x=[347, 128],
    n_id=[347],
  },
  molecular_function={
    num_nodes=0,
    x=[0, 128],
    n_id=[0],
  },
  pathway={
    num_nodes=0,
    x=[0, 128],
    n_id=[0],
  },
  (gene/protein, protein_protein, gene/protein)={
    edge_index=[2, 0],
    e_id=[0],
  },
  (drug, drug_protein, gene/protein)={
    edge_index=[2, 0],
    e_id=[0],
  },
  (drug, contraindication, disease)={
    edge_index=[2, 0],
    e_id=[0],
  },
  (drug, indication, di

In [53]:
batch[edge_type]

{'edge_index': tensor([[ 499,  500,  178,  ...,   90, 2101, 1447],
        [   1,    1,    1,  ...,  496,  497,  497]]), 'e_id': tensor([27828, 18582, 21752,  ..., 30595, 29188, 10695]), 'input_id': tensor([14464, 14465, 14466, 14467, 14468, 14469, 14470, 14471, 14472, 14473,
        14474, 14475, 14476, 14477, 14478, 14479, 14480, 14481, 14482, 14483,
        14484, 14485, 14486, 14487, 14488, 14489, 14490, 14491, 14492, 14493,
        14494, 14495, 14496, 14497, 14498, 14499, 14500, 14501, 14502, 14503,
        14504, 14505, 14506, 14507, 14508, 14509, 14510, 14511, 14512, 14513,
        14514, 14515, 14516, 14517, 14518, 14519, 14520, 14521, 14522, 14523,
        14524, 14525, 14526, 14527, 14528, 14529, 14530, 14531, 14532, 14533,
        14534, 14535, 14536, 14537, 14538, 14539, 14540, 14541, 14542, 14543,
        14544, 14545, 14546, 14547, 14548, 14549, 14550, 14551, 14552, 14553,
        14554, 14555, 14556, 14557, 14558, 14559, 14560, 14561, 14562, 14563,
        14564, 14565,

In [65]:
batch[edge_type]['edge_label'].shape

torch.Size([256])

In [72]:
len(data.edge_types)

48