In [None]:
import torch
import numpy as np
from sklearn import metrics
import torch.nn as nn
from scipy import sparse
from torch.utils.data import DataLoader,Dataset
import torch.nn.functional as F
import torch.nn.init as init
from scipy.linalg import expm
import math
torch.cuda.empty_cache()

In [None]:
class EarlyStopping:
    """Stop training once the monitored loss has stopped improving.

    Call the instance once per epoch with the current loss and the model.
    Whenever the loss improves, the model's ``state_dict`` is written to
    ``save_path``; after ``patience`` consecutive calls without improvement
    ``early_stop`` is set to ``True``.
    """
    def __init__(self, patience, verbose=False, delta=0, save_path='checkpoint.pt'):
        """
        Args:
            patience (int): Number of consecutive non-improving calls tolerated
                            before ``early_stop`` is set. Required (no default).
            verbose (bool): If True, prints a message each time a checkpoint is saved.
                            Default: False
            delta (float): Minimum decrease of the loss that counts as an improvement.
                            Default: 0
            save_path (str): File the best ``state_dict`` is written to.
                            Default: 'checkpoint.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # float('inf') instead of np.Inf: the np.Inf alias was removed in NumPy 2.0.
        self.val_loss_min = float('inf')
        self.delta = delta
        self.save_path = save_path

    def __call__(self, val_loss, model):
        """Record one epoch's loss, checkpointing ``model`` if it improved."""
        # Scores are negated losses, so "higher is better".
        score = -val_loss

        if self.best_score is None:
            # First call: always checkpoint.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # Not better than the best score by at least `delta`: no improvement.
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Save the model's ``state_dict`` to ``save_path`` and remember ``val_loss``."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.save_path)
        self.val_loss_min = val_loss


In [None]:
class GCN(nn.Module):
    """Parameter-free propagation step that multiplies ``adj`` into ``x``."""

    def __init__(self):
        super().__init__()

    def forward(self, adj, x):
        """Return the matrix product of ``adj`` and ``x``."""
        propagated = torch.matmul(adj, x)
        return propagated
    
class HyperGraph(nn.Module):
    """One hypergraph convolution layer with a learned per-hyperedge gate.

    Node features are projected, aggregated onto hyperedges, rescaled by a
    scalar gate per hyperedge, propagated back to the nodes and batch-normalised.

    Args:
        in_dim: Width of the incoming node features.
        hidden_dim: Width of the projected features (also the output width).
        out_dim: Number of features of the final batch norm. The layer has no
            projection between the hidden features and the batch norm, so this
            must equal ``hidden_dim``.

    Raises:
        ValueError: If ``hidden_dim != out_dim``.
    """

    def __init__(self,in_dim,hidden_dim,out_dim):
        super().__init__()
        if hidden_dim != out_dim:
            # The batch norm is applied directly to hidden_dim-wide features;
            # fail here with a clear message instead of deep inside forward().
            raise ValueError(
                f'hidden_dim ({hidden_dim}) must equal out_dim ({out_dim}): '
                'the output batch norm is applied to the hidden features'
            )

        self.weight1 = nn.Linear(in_dim,hidden_dim)
        self.dropout = nn.Dropout(0.3)
        self.relu = nn.LeakyReLU(0.2)
        self.batch_norm1 = nn.BatchNorm1d(out_dim)

        self.gate = nn.Linear(hidden_dim,1)
        self.sigmoid = nn.Sigmoid()
        self.tanh_f = nn.Tanh()

        nn.init.xavier_normal_(self.weight1.weight,gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_normal_(self.gate.weight,gain=nn.init.calculate_gain('relu'))

    def forward(self,x,hyper_adj,hyper_degree,node_degree):
        """Run one round of node -> hyperedge -> node message passing.

        Args:
            x: Node features; the last dimension must be ``in_dim``.
            hyper_adj: Incidence matrix, multiplied as ``hyper_adj.T @ node_degree @ x``
                (so rows index nodes and columns index hyperedges).
            hyper_degree: Accepted for interface compatibility; not used by this layer.
            node_degree: Square node-scaling matrix applied on both aggregation steps.

        Returns:
            Batch-normalised node embeddings with ``hidden_dim`` features.
        """
        x = self.weight1(x)
        x = self.dropout(x)
        # Aggregate node features onto hyperedges: [num_hyperedges, hidden_dim].
        hyper_edge_emb = hyper_adj.T@node_degree@x
        # Gate: one scalar per hyperedge that rescales that hyperedge's features.
        hyper_edge_weight = self.tanh_f(self.sigmoid(self.gate(hyper_edge_emb)))
        # Broadcasting the [num_hyperedges, 1] gate is equivalent to left-multiplying
        # by diag(gate) but avoids building a dense square matrix and also works
        # when there is only a single hyperedge.
        hyper_edge_emb = hyper_edge_weight*hyper_edge_emb
        # Propagate the gated hyperedge features back to the nodes.
        hyper_emb = node_degree@hyper_adj@hyper_edge_emb
        hyper_emb = self.batch_norm1(hyper_emb)
        return hyper_emb

class HyperGraphModel(nn.Module):
    """Two stacked ``HyperGraph`` layers over a learned ("dynamic") incidence matrix.

    Args:
        num_nodes: Number of nodes; also the expected feature width of ``x``,
            because ``x`` is multiplied by a ``num_nodes x num_nodes`` weight.
        hidden_dim: Output width of the first hypergraph layer.
        out_dim: Output width of the second hypergraph layer.

    The defaults reproduce the original hard-coded sizes (1527, 600, 300).
    """

    def __init__(self, num_nodes=1527, hidden_dim=600, out_dim=300):
        super().__init__()
        self.gcn = GCN()  # not used in forward(); kept as an attribute
        self.hyper_weight = nn.Parameter(torch.ones(num_nodes,num_nodes))
        self.hgcn1 = HyperGraph(num_nodes,hidden_dim,hidden_dim)
        self.hgcn2 = HyperGraph(hidden_dim,out_dim,out_dim)
        self.dropout = nn.Dropout(0.3)
        # Not used in forward(); kept so existing state_dict keys still load.
        self.batch_norm = nn.BatchNorm1d(num_nodes)

        self.relu = nn.LeakyReLU()

    @staticmethod
    def _safe_reciprocal(values):
        """Element-wise ``1 / values`` with zeros mapped to zero."""
        is_zero = values == 0
        return values.masked_fill(is_zero, 1.0).reciprocal().masked_fill(is_zero, 0.0)

    def forward(self,adj,x,degree_matrix):
        """Compute hypergraph node embeddings.

        Args:
            adj: Accepted for interface compatibility; not used here.
            x: Node feature matrix whose feature width is ``num_nodes``.
            degree_matrix: Node degrees; squeezed to a vector before use
                (presumably shaped ``(num_nodes, 1)`` — verify against the caller).

        Returns:
            Node embeddings with ``out_dim`` features.
        """
        # Dynamic hypergraph: learned incidence matrix.
        hyper_adj = x@self.hyper_weight
        # Inverting a diagonal matrix is just an element-wise reciprocal, so no
        # dense torch.inverse is needed; a zero degree maps to a zero entry
        # instead of raising a singular-matrix error.
        hyper_degree = torch.diag(self._safe_reciprocal(torch.sum(hyper_adj,dim=0)))
        node_degree = torch.diag(torch.sqrt(self._safe_reciprocal(torch.squeeze(degree_matrix))))
        # Hypergraph convolution.
        hyper_emb = self.hgcn1(x,hyper_adj,hyper_degree,node_degree)
        hyper_emb = self.relu(hyper_emb)
        hyper_emb = self.hgcn2(hyper_emb,hyper_adj,hyper_degree,node_degree)
        return hyper_emb

In [None]:
class TopoRep(nn.Module):
    """Topology encoder: fuses 1- and 2-step adjacency, then applies self-attention.

    Args:
        num_nodes: Number of nodes (the adjacency matrix is ``num_nodes x num_nodes``).
        emb_dim: Width of the output embedding.

    The defaults reproduce the original hard-coded sizes (1527 nodes, 300 features).
    """

    def __init__(self, num_nodes=1527, emb_dim=300):
        super().__init__()
        self.num_nodes = num_nodes
        # 1x1 convolution that mixes the two "channels" (A and A @ A) into one map.
        self.rw_fuse = nn.Conv2d(in_channels=2,out_channels=1,kernel_size=1)
        # Learnable residual weight, initialised to the identity matrix.
        self.beita = nn.Parameter(torch.diag(torch.ones(num_nodes)))

        self.q_linear = nn.Linear(num_nodes,emb_dim)
        self.k_linear = nn.Linear(num_nodes,emb_dim)
        self.v_linear = nn.Linear(num_nodes,emb_dim)
        self.scale = emb_dim ** -0.5
        self.norm1 = nn.LayerNorm(emb_dim)
        self.dropout = nn.Dropout(0.3)

        self.align = nn.Linear(num_nodes,emb_dim)

    def forward(self,adj_norm):
        """Encode the graph topology.

        Args:
            adj_norm: Square ``(num_nodes, num_nodes)`` adjacency matrix.

        Returns:
            Tensor of shape ``(num_nodes, emb_dim)``.
        """
        rw_one = adj_norm[None,:,:]
        rw_two = adj_norm@adj_norm[None,:,:]
        # Stack the one- and two-step matrices as channels and fuse them.
        rw_matrix = torch.cat([rw_one,rw_two],dim=0)
        multi_step_topo = self.rw_fuse(rw_matrix).reshape(-1,self.num_nodes)  # (N, N)

        residual = self.beita@adj_norm
        multi_step_topo = multi_step_topo+residual  # (N, N)

        # Single-head scaled dot-product self-attention over the nodes.
        Q = self.q_linear(multi_step_topo)
        K = self.k_linear(multi_step_topo)
        V = self.v_linear(multi_step_topo)
        attn = (Q @ K.T)*self.scale
        attn = F.softmax(attn,dim=-1)  # (N, N) attention scores
        attn = self.dropout(attn)
        topo_feature = torch.matmul(attn,V)  # (N, emb_dim)

        # Project the fused topology to emb_dim so it can be added as a skip connection.
        multi_step_topo = self.align(multi_step_topo)
        topo_feature = self.norm1(topo_feature+multi_step_topo)  # (N, emb_dim)
        return topo_feature

In [None]:
class MetaPath(nn.Module):
    """Meta-path encoder producing one embedding per circRNA and per disease node.

    Six meta-path matrices are read from the ``meta_paths`` dict: ``'cc'``,
    ``'ccc'``, ``'cdc'`` (circRNA side) and ``'dd'``, ``'dcd'``, ``'ddd'``
    (disease side). Each path is processed the same way:

    1. every matrix entry is embedded to 50 features and fused across rows by a
       1x1 convolution, giving a 50-d semantic vector per node;
    2. that vector is concatenated with the node's raw features;
    3. a two-layer GCN-style block propagates over the symmetrically
       normalised meta-path matrix.

    The three per-path outputs of each side are concatenated and projected to
    300 features.

    Args:
        device: Stored on the module as ``self.device``.
        num_circ: Number of circRNA nodes (rows ``[0, num_circ)`` of ``node_fea``).
        num_dise: Number of disease nodes (the rows immediately after).
        feat_dim: Width of ``node_fea``.

    The defaults reproduce the original hard-coded sizes (834, 138, 1527).
    """

    def __init__(self, device, num_circ=834, num_dise=138, feat_dim=1527):
        super().__init__()
        self.num_circ = num_circ
        self.num_dise = num_dise

        # Scalar matrix entry -> 50-d embedding, one layer per meta-path.
        self.cc_emb = nn.Linear(1,50)
        self.ccc_emb = nn.Linear(1,50)
        self.cdc_emb = nn.Linear(1,50)
        self.dd_emb = nn.Linear(1,50)
        self.dcd_emb = nn.Linear(1,50)
        self.ddd_emb = nn.Linear(1,50)

        # 1x1 convolutions that fuse the embedded entries across the first axis.
        self.cc_fuse = nn.Conv2d(in_channels=num_circ,out_channels=1,kernel_size=1)
        self.ccc_fuse = nn.Conv2d(in_channels=num_circ,out_channels=1,kernel_size=1)
        self.cdc_fuse = nn.Conv2d(in_channels=num_circ,out_channels=1,kernel_size=1)
        self.dd_fuse = nn.Conv2d(in_channels=num_dise,out_channels=1,kernel_size=1)
        self.dcd_fuse = nn.Conv2d(in_channels=num_dise,out_channels=1,kernel_size=1)
        self.ddd_fuse = nn.Conv2d(in_channels=num_dise,out_channels=1,kernel_size=1)

        # The *_meta layers, relu1 and norm1 are not used in forward(); they are
        # kept so existing state_dict keys still load and so the order in which
        # parameters are initialised (and hence seeded runs) is unchanged.
        self.cc_meta = nn.Linear(650,300)
        self.ccc_meta = nn.Linear(650,300)
        self.cdc_meta = nn.Linear(650,300)
        self.dd_meta = nn.Linear(650,300)
        self.ddd_meta = nn.Linear(650,300)
        self.dcd_meta = nn.Linear(650,300)

        self.relu = nn.LeakyReLU()
        self.relu1 = nn.ReLU()
        self.norm1 = nn.BatchNorm1d(300)
        self.norm2 = nn.BatchNorm1d(300)
        self.device = device

        # First GCN layer input = raw node features + 50-d semantic vector.
        gcn_in_dim = feat_dim + 50
        self.cc_gcn1 = nn.Linear(gcn_in_dim,900)
        self.ccc_gcn1 = nn.Linear(gcn_in_dim,900)
        self.cdc_gcn1 = nn.Linear(gcn_in_dim,900)
        self.dd_gcn1 = nn.Linear(gcn_in_dim,900)
        self.ddd_gcn1 = nn.Linear(gcn_in_dim,900)
        self.dcd_gcn1 = nn.Linear(gcn_in_dim,900)

        self.cc_gcn2 = nn.Linear(900,600)
        self.ccc_gcn2 = nn.Linear(900,600)
        self.cdc_gcn2 = nn.Linear(900,600)
        self.dd_gcn2 = nn.Linear(900,600)
        self.ddd_gcn2 = nn.Linear(900,600)
        self.dcd_gcn2 = nn.Linear(900,600)

        self.circ_fuse = nn.Linear(600*3,300)
        self.dise_fuse = nn.Linear(600*3,300)
        self.dropout = nn.Dropout(0.3)

        # Xavier initialisation, applied in the same order as before so that
        # seeded runs draw identical random numbers.
        relu_gain = nn.init.calculate_gain('relu')
        xavier_layers = (
            self.cc_emb, self.ccc_emb, self.cdc_emb, self.dd_emb, self.dcd_emb, self.ddd_emb,
            self.cc_fuse, self.ccc_fuse, self.cdc_fuse, self.dd_fuse, self.dcd_fuse, self.ddd_fuse,
            self.cc_gcn1, self.ccc_gcn1, self.cdc_gcn1, self.dd_gcn1, self.ddd_gcn1, self.dcd_gcn1,
        )
        for layer in xavier_layers:
            nn.init.xavier_normal_(layer.weight,gain=relu_gain)

    @staticmethod
    def _normalize(adj, add_self_loops=False):
        """Return ``(adj, D^-1/2 @ adj @ D^-1/2)`` where ``D`` holds the row sums.

        With ``add_self_loops`` the identity is added first and the returned
        ``adj`` includes it. The scaling is done element-wise, which equals the
        dense diagonal-inverse product but needs no matrix inversion; rows with
        zero degree are scaled by zero instead of raising a singular-matrix error.
        """
        if add_self_loops:
            # Build the identity on the input's own device and dtype.
            adj = adj + torch.eye(adj.shape[0], device=adj.device, dtype=adj.dtype)
        degree = torch.sum(adj,dim=1)
        is_zero = degree == 0
        inv_sqrt = torch.sqrt(degree.masked_fill(is_zero, 1.0).reciprocal().masked_fill(is_zero, 0.0))
        return adj, inv_sqrt[:, None] * adj * inv_sqrt[None, :]

    def _encode_path(self, adj, adj_norm, base_fea, emb, fuse, gcn1, gcn2):
        """Run one meta-path through embedding, fusion and the two-layer GCN block."""
        edge_fea = emb(adj[:, :, None])           # (n, n, 50): one embedding per entry
        semantic = fuse(edge_fea).view(-1,50)     # (n, 50): fused over the first axis
        node = adj_norm @ torch.cat([base_fea,semantic],dim=1)
        node = self.dropout(self.relu(gcn1(node)))
        return gcn2(node)

    def forward(self,meta_paths,node_fea):
        """Encode all six meta-paths.

        Args:
            meta_paths: Dict with keys ``'cc'``, ``'ccc'``, ``'cdc'``
                (``num_circ x num_circ``) and ``'dd'``, ``'dcd'``, ``'ddd'``
                (``num_dise x num_dise``).
            node_fea: Node feature matrix; rows ``[0, num_circ)`` are used as
                circRNA features and the next ``num_dise`` rows as disease features.

        Returns:
            Tuple ``(circ_sema, dise_sema)`` with 300 features per node.
        """
        # Self-loops are added only to the 'cdc' and 'dcd' paths.
        c_c, cc_norm = self._normalize(meta_paths['cc'])
        c_c_c, ccc_norm = self._normalize(meta_paths['ccc'])
        c_d_c, cdc_norm = self._normalize(meta_paths['cdc'], add_self_loops=True)
        d_d, dd_norm = self._normalize(meta_paths['dd'])
        d_d_d, ddd_norm = self._normalize(meta_paths['ddd'])
        d_c_d, dcd_norm = self._normalize(meta_paths['dcd'], add_self_loops=True)

        circ_fea = node_fea[0:self.num_circ]
        dise_fea = node_fea[self.num_circ:self.num_circ+self.num_dise]

        # The call order below (cc, ccc, cdc, dd, dcd, ddd) matches the order in
        # which dropout was applied before, so seeded training runs are unchanged.
        cc_node = self._encode_path(c_c, cc_norm, circ_fea, self.cc_emb, self.cc_fuse, self.cc_gcn1, self.cc_gcn2)
        ccc_node = self._encode_path(c_c_c, ccc_norm, circ_fea, self.ccc_emb, self.ccc_fuse, self.ccc_gcn1, self.ccc_gcn2)
        cdc_node = self._encode_path(c_d_c, cdc_norm, circ_fea, self.cdc_emb, self.cdc_fuse, self.cdc_gcn1, self.cdc_gcn2)
        dd_node = self._encode_path(d_d, dd_norm, dise_fea, self.dd_emb, self.dd_fuse, self.dd_gcn1, self.dd_gcn2)
        dcd_node = self._encode_path(d_c_d, dcd_norm, dise_fea, self.dcd_emb, self.dcd_fuse, self.dcd_gcn1, self.dcd_gcn2)
        ddd_node = self._encode_path(d_d_d, ddd_norm, dise_fea, self.ddd_emb, self.ddd_fuse, self.ddd_gcn1, self.ddd_gcn2)

        circ_sema = torch.cat([cc_node,ccc_node,cdc_node],dim=1)
        dise_sema = torch.cat([dd_node,ddd_node,dcd_node],dim=1)

        # NOTE(review): both outputs go through norm2 while norm1 is never used.
        # This looks like a copy-paste slip, but switching the circRNA branch to
        # norm1 would change the behaviour of already-trained checkpoints, so it
        # is left as is — confirm the intent.
        circ_sema = self.norm2(self.circ_fuse(circ_sema))  # (num_circ, 300)
        dise_sema = self.norm2(self.dise_fuse(dise_sema))  # (num_dise, 300)
        return circ_sema,dise_sema


In [None]:
class RepModel(nn.Module):
    """Bundles the topology, meta-path and hypergraph encoders.

    ``forward`` returns the three embeddings, each restricted to the first 972
    rows, together with the first 972 rows of the raw node features.
    """

    def __init__(self, device):
        super().__init__()
        self.longDistance = TopoRep()

        self.meta_path = MetaPath(device)
        self.hyper_model = HyperGraphModel()

        # `par` and `align` are not used in forward().
        self.par = nn.Parameter(torch.diag(torch.ones(972)))
        self.align = nn.Linear(1527, 300)
        self.device = device
        self.tan_h = nn.Tanh()
        self.relu = nn.LeakyReLU()

    def forward(self, adj_norm, node_fea, rowD, meta_paths):
        """Compute all per-node representations.

        Returns:
            Tuple ``(topo_emb, meta_emb, hyper_emb, node_rep)``.
        """
        # Hypergraph embedding; the degree vector is passed as a 1527-row column.
        degree_column = rowD.reshape(1527, -1)
        hyper_emb = self.hyper_model(adj_norm, node_fea, degree_column)[:972, :]

        # Topology embedding.
        topo_emb = self.longDistance(adj_norm)[:972, :]

        # Meta-path embedding: circRNA rows first, then disease rows.
        circ_emb, d_emb = self.meta_path(meta_paths, node_fea)
        stacked = torch.cat([circ_emb, d_emb], dim=0)
        meta_emb = self.relu(self.tan_h(stacked))

        node_rep = node_fea[:972, :]
        return topo_emb, meta_emb, hyper_emb, node_rep

In [None]:
class GateAdjust(nn.Module):
    """Fuses three 972 x 300 views with element-wise learned gates.

    Each view is multiplied by its own weight matrix, passed through tanh and
    LeakyReLU, then the three gated views are concatenated and projected from
    900 to 300 features.
    """

    def __init__(self):
        super().__init__()
        self.proj_s = nn.Linear(900, 300)
        self.relu = nn.LeakyReLU()

        self.topo_w = nn.Parameter(torch.ones(972, 300))
        self.hyper_w = nn.Parameter(torch.ones(972, 300))
        self.meta_w = nn.Parameter(torch.ones(972, 300))

        self.tan_h = nn.Tanh()
        # The ones above are placeholders; the weights are Xavier-initialised here.
        relu_gain = nn.init.calculate_gain('relu')
        for gate_weight in (self.topo_w, self.hyper_w, self.meta_w):
            nn.init.xavier_normal_(gate_weight, gain=relu_gain)

    def _gate(self, weight, features):
        """Apply one element-wise gate followed by tanh and LeakyReLU."""
        return self.relu(self.tan_h(weight * features))

    def forward(self, semantic, topo, hyper_semantic):
        """Return the fused 300-feature representation of the three views."""
        gated_semantic = self._gate(self.meta_w, semantic)
        gated_topo = self._gate(self.topo_w, topo)
        gated_hyper = self._gate(self.hyper_w, hyper_semantic)

        # Concatenation order: topology, hypergraph, meta-path semantics.
        stacked = torch.cat([gated_topo, gated_hyper, gated_semantic], dim=1)
        return self.relu(self.proj_s(stacked))

In [None]:
class my_model(nn.Module):
    """End-to-end classifier for (circRNA, disease) index pairs.

    Node representations come from ``RepModel`` and ``GateAdjust``; the two
    rows selected for each pair are stacked into a 2-row map and classified by
    two conv/pool stages followed by four linear layers (2 output logits).
    """

    def __init__(self, device):
        super().__init__()

        self.mp_model = RepModel(device)
        self.gate_adjust = GateAdjust()

        # Convolutional feature extractor over the stacked pair representation.
        self.c1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(2, 13), stride=1, padding=0)
        self.p1 = nn.MaxPool2d(kernel_size=(1, 5))

        self.c2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(1, 4), stride=1, padding=0)
        self.p2 = nn.MaxPool2d(kernel_size=(1, 10))

        # Classifier head; l1 expects 64 channels x 36 positions after pooling.
        self.l1 = nn.Linear(64 * 36, 1200)
        self.d1 = nn.Dropout(0.5)
        self.l2 = nn.Linear(1200, 700)
        self.l3 = nn.Linear(700, 200)
        self.l4 = nn.Linear(200, 2)
        self.LR = nn.LeakyReLU()

        self.regularizer = nn.MSELoss()  # not used in forward()

        relu_gain = nn.init.calculate_gain('relu')
        for layer in (self.c1, self.c2, self.l1, self.l2, self.l3):
            nn.init.xavier_normal_(layer.weight, gain=relu_gain)
        nn.init.xavier_normal_(self.l4.weight)

    def forward(self, adj_norm, node_fea, x, y, rowD, meta_paths):
        """Return 2-class logits for each pair ``(x[i], y[i])``.

        ``x`` indexes rows of the node representation directly; ``y`` is
        shifted by 834 before indexing.
        """
        topo_emb, semantic_emb, hyper_emb, node_rep = self.mp_model(adj_norm, node_fea, rowD, meta_paths)
        fuse_fea = self.gate_adjust(semantic_emb, topo_emb, hyper_emb)

        # Raw node features concatenated with the fused embedding.
        represent = torch.cat([node_rep, fuse_fea], dim=1)

        dise_index = y + 834
        circ_fea = represent[x].unsqueeze(1).unsqueeze(1)
        dise_fea = represent[dise_index].unsqueeze(1).unsqueeze(1)

        # Stack the two rows along the height axis: (batch, 1, 2, features).
        pair_map = torch.cat([circ_fea, dise_fea], dim=2).to(torch.float32)

        hidden = self.p1(self.LR(self.c1(pair_map)))
        hidden = self.p2(self.LR(self.c2(hidden)))

        hidden = hidden.reshape(hidden.shape[0], -1)
        # Three identical linear -> LeakyReLU -> dropout stages.
        for dense in (self.l1, self.l2, self.l3):
            hidden = self.d1(self.LR(dense(hidden)))

        return self.l4(hidden)

In [None]:
def train(model,adj_norm,node_fea,train_dataset,test_dataset,fold_num,epoch,device,degree,meta_paths):
    """Train ``model`` on one fold, then run it on the test loader and save the raw outputs.

    Args:
        model: Network called as ``model(adj_norm, node_fea, x, y, degree, meta_paths)``.
        adj_norm, node_fea, degree, meta_paths: Graph inputs forwarded to the model unchanged.
        train_dataset: DataLoader yielding ``(x, y, label)`` batches for training.
        test_dataset: DataLoader yielding ``(x, y, label)`` batches for evaluation.
        fold_num: Fold identifier interpolated into every output file name.
        epoch: Maximum number of training epochs.
        device: Device the batches are moved to.

    Side effects:
        Writes the early-stopping checkpoint to ``output\\pt\\dict{fold_num}.pth``,
        the final-epoch weights to ``output\\dict{fold_num}.pth`` (only when all
        epochs complete), and the concatenated test logits and labels to
        ``output\\output{fold_num}`` and ``output\\label{fold_num}``.
    """
    cros_loss = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
    best_path = rf'output\pt\dict{fold_num}.pth'
    final_path = rf'output\dict{fold_num}.pth'
    early_stopping= EarlyStopping(patience=6, verbose= True, save_path=best_path)
    # Checkpoint to restore before evaluation; None until something has been saved.
    eval_checkpoint = None
    for e in range(epoch):
        model.train()
        correct = 0
        LOSS = 0
        for x,y,label in train_dataset:
            x = x.to(torch.long).to(device)
            y = y.to(torch.long).to(device)
            label = label.to(torch.long).to(device)
            output = model(adj_norm,node_fea,x,y,degree,meta_paths)
            loss = cros_loss(output,label)
            LOSS += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate as a Python number rather than a device tensor.
            correct += (torch.argmax(output,dim=1)==label).sum().item()
        # Early stopping monitors the summed training loss of the epoch.
        early_stopping(LOSS, model)
        correct_percent = correct/len(train_dataset.dataset)
        print(f'第{e}个eopch的正确率为{correct_percent}')
        if early_stopping.early_stop:
            print(f'early_stopping!')
            # The final-epoch file is never written on this path, so evaluate
            # with the best checkpoint saved by EarlyStopping instead.
            eval_checkpoint = best_path
            break
        if e + 1 == epoch:
            torch.save(model.state_dict(), final_path)
            eval_checkpoint = final_path
    torch.cuda.empty_cache()

    output_all = torch.tensor([]).to(device)
    label_all = torch.tensor([]).to(device)
    if eval_checkpoint is not None:
        model.load_state_dict(torch.load(eval_checkpoint, map_location=device))
    model.eval()
    with torch.no_grad():
        for x,y,label in test_dataset:
            x = x.to(torch.long).to(device)
            y = y.to(torch.long).to(device)
            label = label.to(torch.long).to(device)
            t_output = model(adj_norm,node_fea,x,y,degree,meta_paths)
            # Concatenate every batch's logits and labels for later metric computation.
            output_all = torch.cat([output_all,t_output],dim=0)
            label_all = torch.cat([label_all,label],dim=0)
        torch.save(output_all,RF'output\output{fold_num}')
        torch.save(label_all,RF'output\label{fold_num}')

In [None]:
class MyDataset(Dataset):
    """Dataset of index pairs stored column-wise, labelled by a lookup matrix.

    ``t_dataset`` holds one ``(x, y)`` pair per column; the label of a pair is
    ``c_d[x][y]``.
    """

    def __init__(self, t_dataset, c_d):
        super().__init__()
        self.t_dataset = t_dataset
        self.c_d = c_d

    def __len__(self):
        # One sample per column.
        return self.t_dataset.shape[1]

    def __getitem__(self, index):
        """Return ``(x, y, label)`` for the pair stored in column ``index``."""
        pair = self.t_dataset[:, index]
        x, y = pair
        label_row = self.c_d[x]
        return x, y, label_row[y]

In [None]:
# Driver: load the data, then train and evaluate one fresh model per fold.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Association matrix used to look up the label of an index pair.
c_d = torch.tensor(np.load(r'data\circRNA_disease.npy')).to(torch.long)

# Each loaded object is indexed by fold number below.
train_data = torch.load(r'data\train_dataset.pth')
test_data = torch.load(r'data\test_data.pth')
adj_matrix_list = torch.load(r'data\cover_feature_matrix.pth')
cc_list = torch.load(r'data\meta_path\cc.pth')
ccc_list = torch.load(r'data\meta_path\ccc.pth')
cdc_list = torch.load(r'data\meta_path\cdc.pth')
dd_list = torch.load(r'data\meta_path\dd.pth')
dcd_list = torch.load(r'data\meta_path\dcd.pth')
ddd_list = torch.load(r'data\meta_path\ddd.pth')
torch.cuda.empty_cache()

for i in range(5):
    # Meta-path matrices of this fold, as float32 on the target device.
    cc = cc_list[i].to(torch.float32).to(device)
    ccc = ccc_list[i].to(torch.float32).to(device)
    cdc = cdc_list[i].to(torch.float32).to(device)
    dd = dd_list[i].to(torch.float32).to(device)
    ddd = ddd_list[i].to(torch.float32).to(device)
    dcd = dcd_list[i].to(torch.float32).to(device)
    node_fea = adj_matrix_list[i].to(torch.float32).to(device)
    meta_paths = {
        'cc': cc,
        'ccc': ccc,
        'cdc': cdc,
        'dd': dd,
        'ddd': ddd,
        'dcd': dcd,
    }

    adj_matrix = adj_matrix_list[i].to(torch.float32).to(device)
    rowD = torch.sum(adj_matrix, dim=1)
    colD = torch.sum(adj_matrix, dim=0)
    avgD = (rowD + colD) / 2.0
    # Symmetric normalisation with the averaged row/column degree.
    deg_inv_sqrt = torch.sqrt(torch.inverse(torch.diag(avgD)))
    adj_norm = (deg_inv_sqrt @ adj_matrix @ deg_inv_sqrt).to(torch.float32)

    print(f'cross:{i}')

    model = my_model(device).to(device)
    train_dataset = DataLoader(
        dataset=MyDataset(t_dataset=train_data[i].to(device), c_d=c_d),
        batch_size=16,
        shuffle=True,
    )
    test_dataset = DataLoader(
        dataset=MyDataset(t_dataset=test_data[i].to(device), c_d=c_d),
        batch_size=16,
        shuffle=False,
    )
    epoch = 70
    # The adjacency matrix doubles as the node feature matrix; rowD is the degree input.
    train(model, adj_norm, adj_matrix, train_dataset, test_dataset, i, epoch, device, rowD, meta_paths)
