In [None]:
import os
import datetime
import uuid
from tqdm import tqdm

import torch
import torch.nn    as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

from data_declaration import Task
from loader_helper    import LoaderHelper


from vapformer.model_components import thenet,UnetrPP
from evaluation import evaluate_model
import matplotlib.pyplot as plt
import numpy as np
import random
from torchmetrics.classification import BinaryAUROC
from vapformer.dynunet_block import get_conv_layer, UnetResBlock
from monai.networks.layers.utils import get_norm_layer

from sklearn import neighbors
import scipy.sparse as sp
from models import tabular_net
from torch_geometric.data import Data,Batch

In [None]:
# Run configuration: target GPU, global seed and the sMCI vs pMCI data split.
DEVICE = torch.device("cuda:1")
SEED = 2024

# Seed every RNG this notebook draws from (torch, numpy, stdlib `random`) before any
# loader or model is built, so shuffling / masking / random walks are reproducible.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Only the sMCI vs pMCI task is used in this notebook.
ld_helper = LoaderHelper(task=Task.sMCI_v_pMCI)
train_dl = ld_helper.get_train_dl(0, batch_size = 16)
test_dl = ld_helper.get_test_dl(0, batch_size = 64, shuffle=False)

In [None]:
from sklearn.metrics import accuracy_score, roc_auc_score
def sen_spe(a, b):
    """Return (sensitivity, specificity) for binary labels `a` and predictions `b`.

    Positions where the label/prediction pair is not drawn from {0, 1} are ignored.
    A 1e-6 term in each denominator keeps the rates finite when a class is absent.
    """
    tp = fn = fp = tn = 0
    for idx in range(len(a)):
        truth, guess = a[idx], b[idx]
        if truth == 1:
            if guess == 1:
                tp += 1
            elif guess == 0:
                fn += 1
        elif truth == 0:
            if guess == 1:
                fp += 1
            elif guess == 0:
                tn += 1

    sensitivity = tp / (tp + fn + 1e-6)  # true positive rate
    specificity = tn / (tn + fp + 1e-6)  # true negative rate
    return sensitivity, specificity
def test_GNN(model, mask, tensor_x, adj, y=None):
    """Evaluate `model` on the nodes selected by `mask`.

    Args:
        model: module called as ``model(x, adj)``; assumed to return per-node logits of
            shape (num_nodes, 2) -- TODO confirm for other models.
        mask: boolean or index selector of the nodes to score.
        tensor_x: node features; cast to float and moved to DEVICE.
        adj: adjacency / propagation matrix; cast to float and moved to DEVICE.
        y: optional node labels. Defaults to the notebook-global ``labels`` so existing
            4-argument calls keep working.

    Returns:
        (accuracy, sensitivity, specificity, roc_auc) on the masked nodes.
    """
    model.eval()
    with torch.no_grad():
        logits = model(tensor_x.to(DEVICE).float(), adj.to(DEVICE).float())
    mask_logits = logits[mask]

    y_all = labels if y is None else y
    label = y_all[mask].cpu().numpy().reshape(-1)

    predicted = mask_logits.argmax(dim=1).cpu().numpy()
    # roc_auc_score needs ONE score per node for binary targets; passing the full
    # (n, 2) logit matrix makes sklearn raise. Use the class-1 probability instead.
    pos_score = torch.softmax(mask_logits, dim=1)[:, 1].cpu().numpy()

    accuracy = accuracy_score(label, predicted)
    sensitivity, specificity = sen_spe(label, predicted)
    roc_auc = roc_auc_score(label, pos_score)
    return accuracy, sensitivity, specificity, roc_auc
def evaluate_gnn(total_label, total_pre, threshold=None):
    """Threshold-based binary metrics plus ROC AUC from two-column class scores.

    Args:
        total_label: ground-truth labels in {0, 1}; shape (n,) or (n, 1).
        total_pre: score array of shape (n, 2); column 1 is the positive-class score.
        threshold: decision threshold applied to column 1. When None (the default,
            matching the previous behaviour) the threshold maximising Youden's J
            (tpr - fpr) is chosen on *these same labels*. If `total_label` is a held-out
            test set, that makes the threshold-dependent metrics optimistic; pass a
            threshold chosen on training/validation data to avoid it.

    Returns:
        (accuracy, sensitivity, specificity, roc_auc).
        NOTE: `accuracy` is the *balanced* accuracy, (sensitivity + specificity) / 2.
        The three rates are rounded to 5 decimals, and a 1e-6 pseudo-count is added to
        every confusion-matrix cell so an absent class does not divide by zero.
    """
    # Local import: in the notebook these names were otherwise only imported in a
    # much later cell, so this function failed if called earlier.
    from sklearn.metrics import roc_curve, auc

    y_true = np.asarray(total_label).reshape(-1)
    pos_score = np.asarray(total_pre)[:, 1]

    fpr, tpr, thresholds = roc_curve(y_true, pos_score)
    roc_auc = auc(fpr, tpr)

    if threshold is None:
        # Youden's J: the ROC point farthest above the chance diagonal.
        threshold = thresholds[np.argmax(tpr - fpr)]
    print(threshold)

    y_pred = (pos_score > threshold).astype(int)
    eps = 0.000001
    TP = eps + np.count_nonzero((y_pred == 1) & (y_true == 1))
    TN = eps + np.count_nonzero((y_pred == 0) & (y_true == 0))
    FP = eps + np.count_nonzero((y_pred == 1) & (y_true == 0))
    FN = eps + np.count_nonzero((y_pred == 0) & (y_true == 1))

    sensitivity = round(TP / (TP + FN), 5)
    specificity = round(TN / (TN + FP), 5)
    accuracy = round((sensitivity + specificity) / 2, 5)
    return accuracy, sensitivity, specificity, roc_auc

In [None]:
def create_graph_from_embedding(embedding, name, n = 30):
    """Build a dense symmetric adjacency matrix over the rows of `embedding`.

    Args:
        embedding: 2-D array of shape (num_samples, num_features).
        name: construction method; only 'knn' is supported.
        n: number of neighbours per sample for the kNN graph.

    Returns:
        (num_samples, num_samples) array: sklearn's directed kNN graph averaged with
        its transpose, so the result is symmetric.

    Raises:
        ValueError: for an unsupported `name` (previously this silently returned None).
    """
    if name != 'knn':
        raise ValueError(f"Unsupported graph construction method: {name!r}")
    A = neighbors.kneighbors_graph(embedding, n_neighbors = n).toarray()
    return (A + np.transpose(A)) / 2
from utils import *

In [None]:
# Frozen feature extractors trained earlier.
# NOTE: the checkpoint files are ResNet-50 models despite the `*_resnet18` variable
# names; the names are kept because later cells refer to them.
# SECURITY: torch.load unpickles whole modules -- only load trusted checkpoints. (On
# recent PyTorch versions full-module loads may additionally need weights_only=False.)
# map_location puts the tensors straight onto DEVICE instead of onto whichever GPU the
# checkpoint was saved from.
mri_resnet18 = torch.load('./weights/mri_resnet50_mci_new.pth', map_location=DEVICE).to(DEVICE)
pet_resnet18 = torch.load('./weights/pet_resnet50_mci.pth', map_location=DEVICE).to(DEVICE)
tab_net = torch.load('./weights/tb_net_MCI.pth', map_location=DEVICE).to(DEVICE)
# Replace the last layers with Identity so each model returns the features that fed
# them (assumes fc / linear3 are the final classification layers -- TODO confirm).
mri_resnet18.fc = torch.nn.Identity()
pet_resnet18.fc = torch.nn.Identity()
tab_net.linear3 = torch.nn.Identity()

In [None]:


import torch
from torch.utils.data import ConcatDataset, DataLoader

# Put the train and test splits into one node set (train first, then test) so the
# whole cohort can be embedded into a single population graph.
# shuffle=False is required: `train_mask` below depends on this fixed ordering.
combined_dataset = ConcatDataset([train_dl.dataset, test_dl.dataset])
combined_dl = DataLoader(combined_dataset, batch_size=64, shuffle=False)

train_len = len(train_dl.dataset)
test_len = len(test_dl.dataset)

# Boolean node mask: True for the first `train_len` positions (training samples),
# False for the following `test_len` positions (test samples).
train_mask = torch.arange(train_len + test_len) < train_len

print(f"Length of train_mask: {len(train_mask)}")
print(f"Number of True values in train_mask: {train_mask.sum().item()}")

# The mask must line up one-to-one with the combined dataset.
assert len(train_mask) == len(combined_dataset)


In [None]:
# Extract frozen per-modality features for every subject (train + test, in
# combined_dl order) and concatenate them into one node-feature matrix.
all_outputs = []
label = []

mri_resnet18.eval()
pet_resnet18.eval()
tab_net.eval()
with torch.no_grad():
    for sample_batched in tqdm(combined_dl):
        batch_mri = sample_batched['mri'].to(DEVICE).float()
        batch_pet = sample_batched['pet'].to(DEVICE).float()
        # Missing clinical values are zero-filled before the tabular network.
        batch_clinical = torch.nan_to_num(sample_batched['clin_t'].to(DEVICE).float(), nan=0.0)
        mri_outputs = mri_resnet18(batch_mri)
        pet_outputs = pet_resnet18(batch_pet)
        tab_outputs = tab_net(batch_clinical)
        # Column layout: [MRI | PET | clinical]; later cells slice features by offset.
        outputs = torch.cat([mri_outputs, pet_outputs, tab_outputs], dim=1)
        all_outputs.append(outputs)
        label.append(sample_batched['label'])

# Concatenate once after the loop. Re-concatenating the growing lists on every batch
# (as before) is quadratic in the number of batches.
labels = torch.cat(label, dim=0)
stacked_outputs = torch.cat(all_outputs, dim=0)
    

In [None]:
from hypergraph_utils import *

In [None]:
# Build a hypergraph over all subjects from the fused features:
# pairwise distances -> incidence matrix H -> propagation matrix G.
# NOTE(review): the helpers come from `from hypergraph_utils import *`. 50 is presumably
# the neighbour count per kNN hyperedge and is_probH=False presumably binary
# membership -- confirm against hypergraph_utils.
dis_mat = Eu_dis(stacked_outputs.cpu())
H = construct_H_with_KNN_from_distance(dis_mat, 50, is_probH=False, m_prob=1)
G = generate_G_from_H(H)

# GraphMAE

In [None]:
def sce_loss(x, y, alpha=3):
    """Scaled cosine error: mean over rows of (1 - cos(x_i, y_i)) ** alpha.

    Cosine similarity is taken along the last dimension after L2-normalising both inputs.
    """
    # Local import: `F` was otherwise only imported in a later notebook cell.
    import torch.nn.functional as F

    x = F.normalize(x, p=2, dim=-1)
    y = F.normalize(y, p=2, dim=-1)
    loss = (1 - (x * y).sum(dim=-1)).pow_(alpha)
    return loss.mean()

# loss function: sig
def sig_loss(x, y):
    """Mean over rows of sigmoid(-cos(x_i, y_i)), after L2-normalising both inputs."""
    # Local import: `F` was otherwise only imported in a later notebook cell.
    import torch.nn.functional as F

    x = F.normalize(x, p=2, dim=-1)
    y = F.normalize(y, p=2, dim=-1)
    loss = (x * y).sum(1)
    loss = torch.sigmoid(-loss)
    return loss.mean()

def mask_edge(graph, mask_prob):
    """Sample which edges survive an edge-dropping step.

    Each edge is kept independently with probability ``1 - mask_prob``.

    Args:
        graph: graph exposing its edge count either as a method ``num_edges()``
            (DGL style) or as an int attribute/property ``num_edges`` (PyG ``Data``).
        mask_prob: probability of dropping each edge, in [0, 1].

    Returns:
        1-D LongTensor with the indices of the *kept* edges.
    """
    num_edges = graph.num_edges
    if callable(num_edges):
        num_edges = num_edges()
    keep_prob = torch.full((int(num_edges),), 1.0 - mask_prob)
    keep = torch.bernoulli(keep_prob)
    return keep.nonzero().squeeze(1)


# graph transformation: drop edge
def drop_edge(graph, drop_rate, return_edges=False):
    """Randomly remove edges from `graph`.

    Args:
        graph: PyG ``Data`` with a (2, num_edges) ``edge_index``, or a DGL-style graph
            whose ``edges()`` returns ``(src, dst)``.
        drop_rate: probability of dropping each edge; <= 0 returns `graph` unchanged.
        return_edges: if True, also return the dropped edges as ``(src, dst)``.

    Returns:
        A new ``Data`` whose ``edge_index`` is the (2, num_kept) kept edge list,
        optionally followed by the dropped ``(src, dst)`` tensors.
    """
    if drop_rate <= 0:
        return graph

    edge_index = getattr(graph, "edge_index", None)
    if edge_index is not None:
        src, dst = edge_index[0], edge_index[1]
    else:
        src, dst = graph.edges()

    keep_idx = mask_edge(graph, drop_rate).to(src.device)
    # Boolean mask over all edges; the dropped edges are its complement. (Applying `~`
    # to the *index* tensor, as before, produced -(i + 1) instead of the dropped edges.)
    keep = torch.zeros(src.shape[0], dtype=torch.bool, device=src.device)
    keep[keep_idx] = True

    # edge_index must be (2, E): stack, not concatenate, the 1-D src/dst tensors.
    ng = Data(edge_index=torch.stack((src[keep], dst[keep]), 0))
    if return_edges:
        return ng, (src[~keep], dst[~keep])
    return ng

def initialize_gnn_decoder(gnn_type, input_dim, hid_dim, num_layer,device):
    """Instantiate a GNN of the requested family and move it to `device`.

    Supported `gnn_type` values: 'GAT', 'GCN', 'GraphSAGE', 'GIN', 'GCov',
    'GraphTransformer', 'HGNN'. Every class is built with the same
    ``input_dim`` / ``hid_dim`` / ``num_layer`` keyword arguments.

    NOTE(review): a later cell runs ``from models import HGNN``, which rebinds `HGNN`
    to a class with a different constructor; calling this with 'HGNN' after that
    cell will fail.

    Raises:
        ValueError: if `gnn_type` is not one of the supported names.
    """
    # The lambdas defer each class lookup until its name is selected.
    factories = (
        ('GAT', lambda: GAT),
        ('GCN', lambda: GCN),
        ('GraphSAGE', lambda: GraphSAGE),
        ('GIN', lambda: GIN),
        ('GCov', lambda: GCov),
        ('GraphTransformer', lambda: GraphTransformer),
        ('HGNN', lambda: HGNN),
    )
    for type_name, get_cls in factories:
        if gnn_type == type_name:
            gnn = get_cls()(input_dim = input_dim, hid_dim = hid_dim, num_layer = num_layer)
            break
    else:
        raise ValueError(f"Unsupported GNN type: {gnn_type}")
    gnn.to(device)
    return gnn

In [None]:
from torch import nn
class GraphMAELoss(nn.Module):
    """GraphMAE-style masked node-feature reconstruction objective.

    Each call masks a random subset of nodes, encodes the corrupted graph, projects the
    hidden states with a bias-free linear layer, zeroes the masked nodes' projections,
    decodes, and scores the reconstruction of the masked nodes' original features.

    Args:
        encoder: module called as ``encoder(x=..., edge_index=...)``.
        decoder: module called as ``decoder(x, edge_index)``; its output is compared
            with the input features, so it should produce `enc_in_dim` features.
        hidden_dim: stored as ``self.hidden_dim``; not otherwise used here.
        enc_in_dim: input feature width (size of the learnable mask token).
        dec_in_dim: encoder output width (in/out width of ``encoder_to_decoder``).
        mask_rate: fraction of nodes masked per call.
        drop_edge_rate: per-edge drop probability before encoding (0 disables).
        replace_rate: fraction of masked nodes that receive the features of random
            nodes instead of the mask token.
        loss_fn: 'sce' (scaled cosine error) or 'mse'.
        alpha_l: exponent for the 'sce' loss.
    """

    def __init__(self, encoder, decoder, hidden_dim, enc_in_dim, dec_in_dim, mask_rate=0.75, drop_edge_rate=0.0, replace_rate=0.1, loss_fn='sce', alpha_l=2):
        super(GraphMAELoss, self).__init__()
        self._mask_rate = mask_rate
        self._drop_edge_rate = drop_edge_rate
        self._replace_rate = replace_rate
        self._mask_token_rate = 1 - self._replace_rate
        self.hidden_dim = hidden_dim

        self.encoder = encoder
        self.decoder = decoder

        # Learnable token added to the (zeroed) features of masked "token" nodes.
        self.enc_mask_token = nn.Parameter(torch.zeros(1, enc_in_dim))
        self.encoder_to_decoder = nn.Linear(dec_in_dim, dec_in_dim, bias=False)
        self.criterion = self.setup_loss_fn(loss_fn, alpha_l)

    def forward(self, data):
        """Return ``(loss, {"loss": float}, encoder hidden states)`` for one graph."""
        loss, x_hidden = self.mask_attr_prediction(data)
        loss_item = {"loss": loss.item()}
        return loss, loss_item, x_hidden

    def setup_loss_fn(self, loss_fn, alpha_l):
        """Return the reconstruction criterion; NotImplementedError for unknown names."""
        # Local import: in the notebook `partial` was only imported in a later cell.
        from functools import partial

        if loss_fn == "mse":
            return nn.MSELoss()
        if loss_fn == "sce":
            return partial(sce_loss, alpha=alpha_l)
        raise NotImplementedError

    def encoding_mask_noise(self, g, x, mask_rate=0.3):
        """Corrupt node features for masked auto-encoding.

        ``int(mask_rate * num_nodes)`` random nodes are masked. With replace_rate > 0,
        ``int(mask_token_rate * n_mask)`` of them are zeroed and get ``enc_mask_token``
        added, and ``int(replace_rate * n_mask)`` get the features of random nodes
        (int() rounding can leave a few masked nodes unchanged). With replace_rate == 0
        every masked node gets the token.

        Returns:
            (clone of g, corrupted features, (mask_nodes, keep_nodes)).
        """
        num_nodes = g.num_nodes

        perm = torch.randperm(num_nodes, device=x.device)
        num_mask_nodes = int(mask_rate * num_nodes)
        mask_nodes = perm[: num_mask_nodes]
        keep_nodes = perm[num_mask_nodes:]

        out_x = x.clone()
        if self._replace_rate > 0:
            num_noise_nodes = int(self._replace_rate * num_mask_nodes)
            num_token_nodes = int(self._mask_token_rate * num_mask_nodes)
            perm_mask = torch.randperm(num_mask_nodes, device=x.device)
            token_nodes = mask_nodes[perm_mask[:num_token_nodes]]
            # Slice from an explicit start: the old `perm_mask[-num_noise_nodes:]`
            # becomes `perm_mask[-0:]` (= ALL masked nodes) whenever num_noise_nodes
            # rounds down to 0, and the assignment below then fails on a shape mismatch.
            noise_nodes = mask_nodes[perm_mask[num_mask_nodes - num_noise_nodes:]]
            noise_to_be_chosen = torch.randperm(num_nodes, device=x.device)[:num_noise_nodes]
            out_x[token_nodes] = 0.0
            out_x[noise_nodes] = x[noise_to_be_chosen]
        else:
            token_nodes = mask_nodes
            out_x[mask_nodes] = 0.0
        out_x[token_nodes] += self.enc_mask_token
        use_g = g.clone()
        return use_g, out_x, (mask_nodes, keep_nodes)

    def mask_attr_prediction(self, data, pretrain_method='graphmae'):
        """Mask, encode, decode and return ``(reconstruction loss, encoder hidden states)``.

        `pretrain_method` is accepted for compatibility but unused.
        """
        g = data
        x = data.x

        pre_use_g, use_x, (mask_nodes, keep_nodes) = self.encoding_mask_noise(g, x, self._mask_rate)

        if self._drop_edge_rate > 0:
            use_g, masked_edges = drop_edge(pre_use_g, self._drop_edge_rate, return_edges=True)
        else:
            use_g = pre_use_g

        # Encode the corrupted features (use `self.encoder(data.x, data.edge_index)`
        # instead to encode without noise nodes).
        all_hidden = self.encoder(x=use_x, edge_index=use_g.edge_index)

        # ---- attribute reconstruction ----
        node_reps = self.encoder_to_decoder(all_hidden)
        # Re-mask: the decoder must rebuild masked nodes from their neighbours only.
        node_reps[mask_nodes] = 0

        recon_graph = Data(x=node_reps, edge_index=pre_use_g.edge_index).to(data.x.device)
        recon_node_reps = self.decoder(recon_graph.x, recon_graph.edge_index)

        # Loss only on the masked nodes.
        x_init = x[mask_nodes]
        x_rec = recon_node_reps[mask_nodes]
        loss = self.criterion(x_rec, x_init)
        return loss, all_hidden

    def embed(self, g, x):
        """Encode `x` over ``g.edge_index`` without masking."""
        rep = self.encoder(x=x, edge_index=g.edge_index)
        return rep

    @property
    def enc_params(self):
        """Encoder parameters."""
        return self.encoder.parameters()

    @property
    def dec_params(self):
        """Decoder-side parameters (projection + decoder)."""
        # Local import: `chain` was otherwise only imported in a later notebook cell.
        from itertools import chain
        return chain(self.encoder_to_decoder.parameters(), self.decoder.parameters())


In [None]:
# NOTE(review): this `data` is overwritten by the next cell before it is used. It also
# passes the dense propagation matrix G as `edge_index`, while PyG expects a (2, E)
# integer edge list -- presumably a leftover experiment; confirm before relying on it.
data = Data(x=stacked_outputs.to(DEVICE), edge_index=torch.Tensor(G).to(DEVICE), y=labels.squeeze().long(), train_mask=train_mask, val_mask = ~train_mask, test_mask = ~train_mask)

In [None]:
# (2, num_nonzero) coordinates of the nonzero entries of H, used as the PyG edge list.
edge_indices = np.stack(np.nonzero(H))
data = Data(
    x=stacked_outputs.to(DEVICE),
    edge_index=torch.as_tensor(edge_indices, dtype=torch.long).to(DEVICE),
    y=labels.squeeze().long(),
    train_mask=train_mask,
    val_mask=~train_mask,
    test_mask=~train_mask,
)

In [None]:
import torch
import torch.optim as optim
from torch.autograd import Variable
from torch_geometric.loader import DataLoader
from torch.utils.data import TensorDataset
import sys
sys.path.append('/home/cyliu/code/ProG/')
from prompt_graph.data import load4link_prediction_multi_graph, load4link_prediction_single_graph
from torch.optim import Adam
import time
from prompt_graph.model import GAT, GCN, GCov, GIN, GraphSAGE, GraphTransformer, HGNN
from prompt_graph.data import load4node, load4graph, NodePretrain
import os


# Move the graph to the GPU and record the dimensions used by the GNN constructors
# below (sMCI vs pMCI is a binary task, hence 2 output classes).
data = data.to(DEVICE)
input_dim = data.x.shape[1]
output_dim = 2
in_node_feat_dim = input_dim


In [None]:
import numpy as np
import random

def hypergraph_random_walk_matrix(H, start_nodes, walk_length, device=None):
    """Run one uniform random walk per start node over a dense matrix.

    From the current node ``u`` the walk moves to a uniformly random column ``v`` with
    ``H[u, v] != 0``; a walk stops early when the current row has no nonzero entry.

    Parameters:
    H (np.array): 2-D matrix; the nonzero columns of row ``u`` are the nodes reachable
        from ``u``. Column indices are reused as row indices, so H is assumed square
        (num_nodes, num_nodes) -- TODO confirm for non-square incidence matrices.
    start_nodes: iterable of start node ids (Python ints, numpy ints or 0-d tensors).
    walk_length (int): maximum number of steps per walk.
    device: device of the returned tensors; defaults to the notebook-global DEVICE.

    Returns:
    walk_list (list of torch.LongTensor): visited node ids per start node, each of
        length 1 to ``walk_length + 1``.
    """
    if device is None:
        device = DEVICE

    walk_list = []
    for start_node in start_nodes:
        # Plain ints keep numpy indexing off (GPU) tensors, avoiding a device sync per step.
        current_node = int(start_node)
        walk = [current_node]
        for _ in range(walk_length):
            nbrs = np.nonzero(H[current_node])[0]
            if len(nbrs) == 0:
                break  # dead end: no outgoing entry
            current_node = int(random.choice(nbrs))
            walk.append(current_node)
        # Build the index tensor directly as int64 (no float32 round-trip).
        walk_list.append(torch.tensor(walk, dtype=torch.long, device=device))

    return walk_list

# Helper: turn a (2, E) edge list into a dense symmetric adjacency matrix.
def edge_index_to_hypergraph_adjacency(edge_index, num_nodes):
    """Convert an edge list into a dense symmetric 0/1 adjacency matrix.

    Parameters:
    edge_index (np.array | torch.Tensor): shape (2, num_edges); column ``i`` is the
        edge ``(edge_index[0, i], edge_index[1, i])``. Tensors (including CUDA tensors)
        are copied to host memory once instead of being read element by element.
    num_nodes (int): number of nodes.

    Returns:
    H (np.array): int matrix of shape (num_nodes, num_nodes) with
        ``H[u, v] = H[v, u] = 1`` for every edge and 0 elsewhere.
    """
    H = np.zeros((num_nodes, num_nodes), dtype=int)

    if isinstance(edge_index, torch.Tensor):
        edge_index = edge_index.detach().cpu().numpy()
    edge_index = np.asarray(edge_index)
    if edge_index.size == 0:
        return H

    src, dst = edge_index
    # Vectorised scatter; duplicate edges simply write 1 again.
    H[src, dst] = 1
    H[dst, src] = 1  # undirected: mirror every edge
    return H

# Sample random-walk subgraphs for GraphMAE pre-training: start one walk of up to
# `walk_length` steps from a random `split_ratio` fraction of the nodes, and keep the
# induced subgraph of every walk that visits at least `min_subgraph_nodes` nodes.
split_ratio = 0.1
walk_length = 30
min_subgraph_nodes = 5
all_random_node_list = torch.randperm(data.num_nodes)
selected_node_num_for_random_walk = int(split_ratio * data.num_nodes)
random_node_list = all_random_node_list[:selected_node_num_for_random_walk]
walk_list = hypergraph_random_walk_matrix(H, start_nodes=random_node_list.tolist(), walk_length=walk_length)

data = data.to(DEVICE)  # hoisted: previously re-run on every loop iteration
graph_list = []
skip_num = 0
for walk in walk_list:
    subgraph_nodes = torch.unique(walk).to(DEVICE)
    if len(subgraph_nodes) < min_subgraph_nodes:
        skip_num += 1
        continue
    subgraph_data = data.subgraph(subgraph_nodes)
    # NOTE: edge_index is replaced by a dense (n, n) float adjacency matrix -- presumably
    # the format the HGNN modules expect; it is no longer a PyG (2, E) edge list.
    subgraph_data.edge_index = torch.Tensor(edge_index_to_hypergraph_adjacency(subgraph_data.edge_index, len(subgraph_data.x))).to(DEVICE)
    graph_list.append(subgraph_data)



In [None]:
# Scratch inspection: display a summary of the full-cohort PyG graph.
data

In [None]:
from functools import partial
from prompt_graph.model import GAT, GCN, GCov, GIN, GraphSAGE, GraphTransformer, HGNN
# GraphMAE pre-training set-up #1: GCN encoder, HGNN decoder, one subgraph per batch.
graph_dataloader = DataLoader(graph_list, batch_size=1, shuffle=True)
graph_n_feat_dim = input_dim
hid_dim = 128
gnn = GCN(input_dim = input_dim, hid_dim = hid_dim, num_layer = 2)

# NOTE(review): positional args map to input_dim=hid_dim, hid_dim=input_dim, num_layer=2.
decoder = initialize_gnn_decoder("HGNN",hid_dim,input_dim,2,DEVICE)
mask_rate = 0.75
drop_edge_rate=0.0
replace_rate=0.
loss_fn='sce'
alpha_l=2
learning_rate = 0.001
weight_decay = 0.00005
epochs = 1000

mae_loss = GraphMAELoss(gnn, decoder, hid_dim, graph_n_feat_dim, hid_dim, mask_rate, drop_edge_rate, replace_rate, loss_fn, alpha_l).to(DEVICE)

# Optimise ALL trainable GraphMAE parameters (encoder, decoder, mask token and the
# encoder->decoder projection). Listing only gnn + decoder parameters left the mask
# token and the projection frozen at their initial values.
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, mae_loss.parameters()),
    lr=learning_rate,
    weight_decay=weight_decay,
)

# Pretrain AD

In [None]:

from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn.inits import reset, uniform
from torch.optim import Adam
import torch
from torch import nn
import time
from prompt_graph.utils import generate_corrupted_graph
from prompt_graph.data import load4node, load4graph, NodePretrain
import os
import torch.nn.functional as F
from itertools import chain
from functools import partial
import numpy as np
from prompt_graph.model import GAT, GCN, GCov, GIN, GraphSAGE, GraphTransformer
from torchmetrics import MeanMetric
import numpy as np

# GraphMAE pre-training over the random-walk subgraphs, with loss-based early stopping.
loss_metric = MeanMetric()
train_loss_min = np.inf
patience = 50
cnt_wait = 0
for epoch in range(epochs):
    st_time = time.time()
    loss_metric.reset()

    for step, batch in enumerate(graph_dataloader):
        optimizer.zero_grad()
        batch = batch.to(DEVICE)
        loss, loss_item, x_hidden = mae_loss(batch)
        loss.backward()
        optimizer.step()
        # Weighted by batch.size(0), i.e. the node count of the batch.
        loss_metric.update(loss.item(), batch.size(0))

    epoch_loss = float(loss_metric.compute())
    print(f"GraphMAE [Pretrain] Epoch {epoch}/{epochs} | Train Loss {epoch_loss:.5f} | "
          f"Cost Time {time.time() - st_time:.3}s")

    if epoch_loss < train_loss_min:
        train_loss_min = epoch_loss
        cnt_wait = 0
    else:
        cnt_wait += 1
        if cnt_wait == patience:
            print('-' * 100)
            print('Early stopping at ' + str(epoch) + ' epoch!')
            break

# Save the final encoder weights, and report the path that was actually written
# (the message previously named a different file, "...HGCN...", than the one saved).
folder_path = "./Experiment/pre_trained_model/ADNI"
os.makedirs(folder_path, exist_ok=True)
save_path = "{}/{}.{}.{}.pth".format(folder_path, 'GraphMAE', "HGCN1", str(hid_dim) + 'hidden_dim')
torch.save(gnn.state_dict(), save_path)
print("+++model saved ! {}".format(save_path))

In [None]:
from functools import partial
from prompt_graph.model import GAT, GCN, GCov, GIN, GraphSAGE, GraphTransformer, HGNN
# GraphMAE pre-training set-up #2: HGNN encoder and HGNN decoder.
graph_dataloader = DataLoader(graph_list, batch_size=1, shuffle=True)
graph_n_feat_dim = input_dim
hid_dim = 128
gnn = HGNN(input_dim = input_dim, hid_dim = hid_dim, out_dim=hid_dim, num_layer = 2)
decoder = HGNN(input_dim = hid_dim, hid_dim = hid_dim, out_dim=input_dim, num_layer = 2)
mask_rate = 0.75
drop_edge_rate=0.0
replace_rate=0.1
loss_fn='sce'
alpha_l=2
learning_rate = 0.001
weight_decay = 0.00005
epochs = 1000
# dec_in_dim is the encoder's output width (out_dim=hid_dim above).
mae_loss = GraphMAELoss(gnn, decoder, hid_dim, graph_n_feat_dim, hid_dim, mask_rate, drop_edge_rate, replace_rate, loss_fn, alpha_l).to(DEVICE)

# Optimise ALL trainable GraphMAE parameters, including the learnable mask token and
# the encoder->decoder projection (previously left frozen at their initial values).
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, mae_loss.parameters()),
    lr=learning_rate,
    weight_decay=weight_decay,
)

In [None]:

from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn.inits import reset, uniform
from torch.optim import Adam
import torch
from torch import nn
import time
from prompt_graph.utils import generate_corrupted_graph
from prompt_graph.data import load4node, load4graph, NodePretrain
import os
import torch.nn.functional as F
from itertools import chain
from functools import partial
import numpy as np
from prompt_graph.model import GAT, GCN, GCov, GIN, GraphSAGE, GraphTransformer
from torchmetrics import MeanMetric
import numpy as np

# GraphMAE pre-training (HGNN encoder) with loss-based early stopping.
loss_metric = MeanMetric()
train_loss_min = np.inf
patience = 50
cnt_wait = 0
for epoch in range(epochs):
    st_time = time.time()
    loss_metric.reset()

    for step, batch in enumerate(graph_dataloader):
        optimizer.zero_grad()
        batch = batch.to(DEVICE)
        loss, loss_item, x_hidden = mae_loss(batch)
        loss.backward()
        optimizer.step()
        # Weighted by batch.size(0), i.e. the node count of the batch.
        loss_metric.update(loss.item(), batch.size(0))

    epoch_loss = float(loss_metric.compute())
    print(f"GraphMAE [Pretrain] Epoch {epoch}/{epochs} | Train Loss {epoch_loss:.5f} | "
          f"Cost Time {time.time() - st_time:.3}s")

    if epoch_loss < train_loss_min:
        train_loss_min = epoch_loss
        cnt_wait = 0
    else:
        cnt_wait += 1
        if cnt_wait == patience:
            print('-' * 100)
            print('Early stopping at ' + str(epoch) + ' epoch!')
            break

# Save the final encoder weights, and report the path that was actually written
# (the message previously named "...HGNN..." instead of "...128HGNN_mri...").
folder_path = "./Experiment/pre_trained_model/ADNI_MCI"
os.makedirs(folder_path, exist_ok=True)
save_path = "{}/{}.{}.{}.pth".format(folder_path, 'GraphMAE', "128HGNN_mri", str(hid_dim) + 'hidden_dim')
torch.save(gnn.state_dict(), save_path)
print("+++model saved ! {}".format(save_path))

In [None]:
# Scratch inspection: display the last mini-batch seen by the pre-training loop.
batch

# GPF-plus prompt tuning

In [None]:
from prompt_graph.tasker import NodeTask, GraphTask
from prompt_graph.utils import seed_everything
from torchsummary import summary
from prompt_graph.utils import print_model_parameters
from prompt_graph.utils import  get_args
from prompt_graph.data import load4node,load4graph, split_induced_graphs
import pickle
import random
import numpy as np
import os
import pandas as pd

def load_induced_graph(dataset_name, data, device):
    """Return the cached induced subgraphs for `dataset_name`, moved to `device`.

    The cache lives at
    ``./Experiment/induced_graph/<dataset_name>/induced_graph_min100_max300.pkl``.
    If it is missing, ``split_induced_graphs`` (subgraph sizes 100-300) is run first
    and is expected to write that file.

    SECURITY: the cache is read with ``pickle``; only use trusted cache files.
    """
    folder_path = './Experiment/induced_graph/' + dataset_name
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
    file_path = folder_path + '/induced_graph_min100_max300.pkl'

    cache_hit = os.path.exists(file_path)
    if cache_hit:
        print('loading induced graph...')
    else:
        print('Begin split_induced_graphs.')
        split_induced_graphs(data, folder_path, device, smallest_size=100, largest_size=300)

    with open(file_path, 'rb') as handle:
        cached_graphs = pickle.load(handle)
    if cache_hit:
        print('Done!!!')

    return [subgraph.to(device) for subgraph in cached_graphs]


seed_everything(42)

In [None]:
# NOTE(review): switches the notebook to another GPU mid-run. Tensors and models created
# earlier stay on cuda:1 unless moved again, and the NodeTask cells below pass their own
# device strings ('2', '3'). Confirm the intended device layout.
DEVICE = torch.device("cuda:3")

In [None]:
# (2, num_nonzero) coordinates of the nonzero entries of H, used as the PyG edge list.
edge_indices = np.stack(np.nonzero(H))
data = Data(
    x=stacked_outputs.to(DEVICE),
    edge_index=torch.as_tensor(edge_indices, dtype=torch.long).to(DEVICE),
    y=labels.squeeze().long(),
    train_mask=train_mask,
    val_mask=~train_mask,
    test_mask=~train_mask,
)

In [None]:
def edge_index_to_hypergraph_adjacency(edge_index, num_nodes):
    """Convert an edge list into a dense symmetric 0/1 adjacency matrix.

    Parameters:
    edge_index (np.array | torch.Tensor): shape (2, num_edges); column ``i`` is the
        edge ``(edge_index[0, i], edge_index[1, i])``. Tensors (including CUDA tensors)
        are copied to host memory once instead of being read element by element.
    num_nodes (int): number of nodes.

    Returns:
    H (np.array): int matrix of shape (num_nodes, num_nodes) with
        ``H[u, v] = H[v, u] = 1`` for every edge and 0 elsewhere.
    """
    H = np.zeros((num_nodes, num_nodes), dtype=int)

    if isinstance(edge_index, torch.Tensor):
        edge_index = edge_index.detach().cpu().numpy()
    edge_index = np.asarray(edge_index)
    if edge_index.size == 0:
        return H

    src, dst = edge_index
    # Vectorised scatter; duplicate edges simply write 1 again.
    H[src, dst] = 1
    H[dst, src] = 1  # undirected: mirror every edge
    return H
dataset_name = "ADNI_MCI"
prompt_type = 'GPF-plus'
# These prompt methods operate on induced subgraphs; other methods do not need them.
if prompt_type in ['Gprompt', 'All-in-one', 'GPF', 'GPF-plus']:
    graphs_list = load_induced_graph(dataset_name, data, DEVICE)
else:
    graphs_list = None

# Replace each induced subgraph's (2, E) edge list with a dense symmetric adjacency
# matrix. Skipped when there are no induced graphs (iterating None used to raise).
for graph in graphs_list or []:
    graph.edge_index = torch.Tensor(edge_index_to_hypergraph_adjacency(graph.edge_index, len(graph.x))).to(DEVICE)

In [None]:

# Persist the fixed train/test split (node indices + labels) as .pt files, presumably
# read by the prompt-tuning NodeTask below (dataset ADNI_MCI, 11-shot) -- confirm.
labels = data.y
train_mask = data.train_mask
test_mask = ~train_mask
train_indices = train_mask.nonzero().flatten()
test_indices = test_mask.nonzero().flatten()
train_labels = labels[train_indices]
test_labels = labels[test_indices]

folder = "./Experiment/sample_data/Node/ADNI_MCI/11_shot/1"
if not os.path.exists(folder):
    os.makedirs(folder)
torch.save(train_indices, os.path.join(folder, 'train_idx.pt'))
torch.save(train_labels, os.path.join(folder, 'train_labels.pt'))
torch.save(test_indices, os.path.join(folder, 'test_idx.pt'))
torch.save(test_labels, os.path.join(folder, 'test_labels.pt'))

In [None]:
# GPF-plus prompt tuning on a GraphMAE-pre-trained GCN.
# NOTE(review): the GCN pre-training cell writes ".../ADNI/GraphMAE.HGCN1.128hidden_dim.pth";
# this path has no "1" -- confirm which checkpoint is intended.
pre_train_model_path = "./Experiment/pre_trained_model/ADNI/GraphMAE.HGCN.128hidden_dim.pth"
num_layer = 2
gnn_type = "GCN"
hid_dim = 128
prompt_type = "GPF-plus"
epochs = 1000
# Single source of truth: `shot_num` used to be 10 while the literal 11 was passed below.
shot_num = 11
lr = 0.00001
decay = 2e-6
batch_size = 128
# NOTE(review): device='2' differs from the notebook DEVICE -- confirm.
tasker = NodeTask(pre_train_model_path = pre_train_model_path,
                    dataset_name = dataset_name, num_layer = num_layer,
                    gnn_type = gnn_type, hid_dim = hid_dim, prompt_type = prompt_type,
                    epochs = epochs, shot_num = shot_num, device='2', lr = lr, wd = decay,
                    batch_size = batch_size, data = data, input_dim = input_dim, output_dim = output_dim, graphs_list = graphs_list)

pre_train_type = tasker.pre_train_type

In [None]:
# Run prompt tuning and report mean ± std of accuracy, F1 and AUROC as returned by
# NodeTask.run (the other returned values are ignored).
_, test_acc, std_test_acc, f1, std_f1, roc, std_roc, _, _= tasker.run()
  
print("Final Accuracy {:.4f}±{:.4f}(std)".format(test_acc, std_test_acc)) 
print("Final F1 {:.4f}±{:.4f}(std)".format(f1,std_f1)) 
print("Final AUROC {:.4f}±{:.4f}(std)".format(roc, std_roc)) 

In [None]:
# GPF-plus prompt tuning on a GraphMAE-pre-trained HGNN (hid_dim 32).
pre_train_model_path = "./Experiment/pre_trained_model/ADNI_MCI/GraphMAE.32HGNN.128hidden_dim.pth"
num_layer = 2
gnn_type = "HGNN"
hid_dim = 32

prompt_type = "GPF-plus"
epochs = 1000
# Single source of truth: `shot_num` used to be 10 while the literal 11 was passed below.
shot_num = 11
lr = 0.00001
decay = 2e-6
batch_size = 1
# NOTE(review): device='3' is passed as a string, separately from DEVICE -- confirm.
tasker = NodeTask(pre_train_model_path = pre_train_model_path,
                    dataset_name = dataset_name, num_layer = num_layer,
                    gnn_type = gnn_type, hid_dim = hid_dim, prompt_type = prompt_type,
                    epochs = epochs, shot_num = shot_num, device='3', lr = lr, wd = decay,
                    batch_size = batch_size, data = data, input_dim = input_dim, output_dim = output_dim, graphs_list = graphs_list)

pre_train_type = tasker.pre_train_type

In [None]:

# Run prompt tuning and report mean ± std of accuracy, F1 and AUROC as returned by
# NodeTask.run (the other returned values are ignored).
_, test_acc, std_test_acc, f1, std_f1, roc, std_roc, _, _= tasker.run()
print("Final Accuracy {:.4f}±{:.4f}(std)".format(test_acc, std_test_acc)) 
print("Final F1 {:.4f}±{:.4f}(std)".format(f1,std_f1)) 
print("Final AUROC {:.4f}±{:.4f}(std)".format(roc, std_roc)) 

In [None]:
# NOTE(review): replaces the tasker's GNN after tasker.run() has already finished, and
# the new module is not moved to any device; no later cell uses it. Looks like leftover
# experimentation -- confirm or remove.
tasker.gnn = HGNN(input_dim = input_dim, hid_dim = hid_dim, out_dim=2, num_layer = 2)

# HGNN

In [None]:
# Split the fused node features back into per-modality blocks, following the
# [MRI | PET | clinical] concatenation order used during feature extraction.
# NOTE(review): assumes each imaging backbone outputs exactly 512 features -- confirm
# against the loaded checkpoints.
mri_feature = stacked_outputs[:,:512]
pet_feature = stacked_outputs[:,512:1024]
clinical_feature = stacked_outputs[:,1024:]

from hypergraph_utils import *
# Per-modality distance matrices, computed on CPU.
mri_dis_mat = Eu_dis(mri_feature.cpu())
pet_dis_mat = Eu_dis(pet_feature.cpu())
clinical_dis_mat = Eu_dis(clinical_feature.cpu())

k_distance = 30  # presumably the neighbour count per kNN hyperedge

# One hypergraph incidence matrix per modality.
mri_H = construct_H_with_KNN_from_distance(mri_dis_mat, k_distance, is_probH=False, m_prob=1)
pet_H = construct_H_with_KNN_from_distance(pet_dis_mat, k_distance, is_probH=False, m_prob=1)
clinical_H = construct_H_with_KNN_from_distance(clinical_dis_mat, k_distance, is_probH=False, m_prob=1)

In [None]:
# Fuse the selected incidence matrices by concatenating their hyperedges (columns).
# Only the MRI hypergraph is used at the moment; the full choice would be
# [mri_H, pet_H, clinical_H].
H_list = [mri_H]
H = None
for h in H_list:
    if h is None:
        continue
    if H is None:
        # The first incidence matrix seeds the fused hypergraph.
        H = h
    elif type(h) == list:
        # List-of-matrices form: concatenate element-wise.
        tmp = [np.hstack((a, b)) for a, b in zip(H, h)]
        H = tmp
    else:
        H = np.hstack((H, h))

G = torch.Tensor(generate_G_from_H(H)).to(DEVICE)

# HPrompt

In [None]:
class HGPrompt(torch.nn.Module):
    """Learnable prompt tokens that later cells append as extra hypergraph nodes.

    Holds one parameter, ``tokens``, of shape (token_num, token_dim), initialised with
    Kaiming-uniform (leaky-ReLU gain with a=0.01, fan-in mode).
    """

    def __init__(self, token_num: int, token_dim: int):
        super().__init__()
        self.tokens = torch.nn.Parameter(torch.empty(token_num, token_dim))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the prompt tokens in place."""
        torch.nn.init.kaiming_uniform_(self.tokens, a=0.01, mode='fan_in', nonlinearity='leaky_relu')


In [None]:
# Two learnable prompt tokens. token_dim must equal the width of the features they are
# concatenated with in the prompt-tuning cell (1024 presumably matches PET + clinical
# features -- TODO confirm).
token_num = 2
token_dim = 1024
prompt = HGPrompt(token_num, token_dim).to(DEVICE)

In [None]:
# New dimensions: one extra row per prompt token and three groups of token_num extra
# columns (assuming H is a nodes x hyperedges incidence matrix, these are prompt nodes
# and prompt hyperedges).
new_rows = H.shape[0] + token_num
new_cols = H.shape[1] + token_num * 3

# Start from an all-ones matrix, so every entry outside the original H block is 1:
# prompt rows join every original hyperedge and original rows join every new column
# (the prompt-row x new-column block is overwritten in the next cell).
H_expanded = np.ones((new_rows, new_cols))

# Copy the original incidence matrix into the top-left block.
H_expanded[:H.shape[0], :H.shape[1]] = H

In [None]:
# Prompt-to-prompt block: a kNN hypergraph over the prompt tokens themselves, written
# into each of the three groups of new columns of H_expanded.
prompt_dis_mat = Eu_dis(prompt.tokens.detach().cpu())

k_distance = token_num // 2

prompt_H = construct_H_with_KNN_from_distance(prompt_dis_mat, k_distance, is_probH=False, m_prob=1)
for offset in (0, token_num, 2 * token_num):
    col_start = H.shape[1] + offset
    H_expanded[H.shape[0]:, col_start:col_start + token_num] = prompt_H

In [None]:
# Propagation matrix of the prompt-expanded hypergraph, as a float tensor on DEVICE.
G = generate_G_from_H(H_expanded)
G = torch.Tensor(G).to(DEVICE)

In [None]:
from sklearn.metrics import roc_curve, auc
from collections import OrderedDict
from models import HGNN

# Prompt tuning on a frozen HGNN: append `token_num` learnable prompt nodes to the
# subject hypergraph and optimise only their features.
# Input: PET + clinical features (the checkpoint below, judging by its file name, was
# trained on this pair).
prompt_input = torch.cat([pet_feature, clinical_feature], dim=1).to(DEVICE)

# Frozen backbone. SECURITY: torch.load unpickles a full module -- trusted files only.
hgnn = torch.load("./weights/concat_pet_clinical_resnet50_hgnn_mci.pth", map_location=DEVICE)
hgnn = hgnn.to(DEVICE)
hgnn.eval()
# Only the prompt is optimised. Freezing the backbone stops gradients being computed
# (and accumulated forever, since the optimizer never zeroes them) for its weights.
hgnn.requires_grad_(False)

# Token width follows the input width so the node concatenation always lines up.
prompt = HGPrompt(token_num, prompt_input.shape[1]).to(DEVICE)
optimizer = optim.AdamW(prompt.parameters(), lr=3e-3, weight_decay=0)
loss_function = nn.CrossEntropyLoss()

# `labels` is (N,) or (N, 1) depending on which earlier cell last set it; reshape(-1)
# works for both, whereas squeeze(1) fails on 1-D labels.
train_targets = labels[train_mask].reshape(-1).long().to(DEVICE)
eval_labels = labels[~train_mask].reshape(-1).cpu().numpy()

k_distance = token_num // 2  # neighbour count for the prompt-token kNN hypergraph
n_nodes, n_edges = H.shape


def refresh_prompt_graph():
    """Rewrite the prompt hyperedges of H_expanded from the current tokens; return G on DEVICE."""
    token_dis = Eu_dis(prompt.tokens.detach().cpu())
    token_H = construct_H_with_KNN_from_distance(token_dis, k_distance, is_probH=False, m_prob=1)
    for offset in (0, token_num, 2 * token_num):
        H_expanded[n_nodes:, n_edges + offset:n_edges + offset + token_num] = token_H
    return torch.Tensor(generate_G_from_H(H_expanded)).to(DEVICE)


loss_fig = []
eva_fig = []
epochs = 1000
best_auc = 0
best_epoch = 0

G = refresh_prompt_graph()
for epoch in range(epochs):
    prompt.train()
    # The last `token_num` rows are the prompt nodes; drop them before scoring.
    pred = hgnn(torch.cat([prompt_input, prompt.tokens]), G)[:-token_num]
    loss = loss_function(pred[train_mask], train_targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    tqdm.write("Epoch: {}/{}, train loss: {}".format(epoch, epochs, round(loss.item(), 5)))
    loss_fig.append(round(loss.item(), 5))

    # The tokens just changed: rebuild G once and use it both for evaluation and for the
    # next training step, so the saved prompt matches the score that triggered saving.
    G = refresh_prompt_graph()
    prompt.eval()
    with torch.no_grad():
        logits = hgnn(torch.cat([prompt_input, prompt.tokens]), G)[:-token_num]
        predicted = F.softmax(logits[~train_mask], dim=1).cpu().numpy()

    # NOTE(review): checkpoints are selected on the held-out (~train_mask) nodes, and
    # evaluate_gnn picks its decision threshold on them too -- scores are optimistic.
    accuracy, sensitivity, specificity, roc_auc = evaluate_gnn(eval_labels, predicted)
    eva_fig.append(accuracy)
    tqdm.write("Epoch: {}/{}, evaluation loss: {}".format(epoch, epochs,(accuracy, sensitivity, specificity, roc_auc)))
    if roc_auc > best_auc:
        print(f"save pth in epoch {epoch}")
        best_auc = roc_auc
        best_epoch = epoch
        torch.save(hgnn, "./weights/prompt_concat_pet_clinical_resnet50_hgnn_mci.pth")
        # The backbone is frozen, so the learned state is the prompt: persist it too.
        torch.save(prompt.state_dict(), "./weights/prompt_tokens_concat_pet_clinical_resnet50_hgnn_mci.pth")

print(best_epoch)
print(best_auc)


In [None]:
# Scratch check: shape of the MRI and PET feature blocks stacked along dim 0 (rows).
torch.cat([mri_feature, pet_feature], dim=0).shape

In [None]:
# Scratch check: shape of the MRI feature block.
mri_feature.shape

## HGNN baseline (MRI features only, no prompt tokens)

In [None]:
from sklearn.metrics import roc_curve, auc
from collections import OrderedDict
from models import HGNN

# Baseline: train an HGNN from scratch on the MRI features over the subject
# hypergraph (no prompt tokens).
hgnn_input = mri_feature.to(DEVICE)

# Rebuild G from the prompt-free incidence matrix H. Reusing the global `G` breaks when
# the prompt cells ran in between: they overwrite it with the prompt-expanded matrix,
# which has token_num more rows than there are subjects.
G = torch.Tensor(generate_G_from_H(H)).to(DEVICE)

hgnn = HGNN(in_ch=hgnn_input.shape[1],
            n_class=2,
            n_hid=256,
            dropout=0.0).to(DEVICE)
optimizer = optim.AdamW(hgnn.parameters(), lr=3e-4, weight_decay=5e-4)
loss_function = nn.CrossEntropyLoss()

# `labels` is (N,) or (N, 1) depending on which earlier cell last set it; reshape(-1)
# works for both, whereas squeeze(1) fails on 1-D labels.
train_targets = labels[train_mask].reshape(-1).long().to(DEVICE)
eval_labels = labels[~train_mask].reshape(-1).cpu().numpy()

loss_fig = []
eva_fig = []
epochs = 8000
best_auc = 0
best_epoch = 0

for epoch in range(epochs):
    hgnn.train()
    pred = hgnn(hgnn_input, G)
    loss = loss_function(pred[train_mask], train_targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    tqdm.write("Epoch: {}/{}, train loss: {}".format(epoch, epochs, round(loss.item(), 5)))
    loss_fig.append(round(loss.item(), 5))

    hgnn.eval()
    with torch.no_grad():
        logits = hgnn(hgnn_input, G)
        predicted = F.softmax(logits[~train_mask], dim=1).cpu().numpy()

    # NOTE(review): checkpoints are selected on the held-out (~train_mask) nodes, and
    # evaluate_gnn picks its decision threshold on them too -- scores are optimistic.
    accuracy, sensitivity, specificity, roc_auc = evaluate_gnn(eval_labels, predicted)
    eva_fig.append(accuracy)
    tqdm.write("Epoch: {}/{}, evaluation loss: {}".format(epoch, epochs,(accuracy, sensitivity, specificity, roc_auc)))
    if roc_auc > best_auc:
        print(f"save pth in epoch {epoch}")
        best_auc = roc_auc
        best_epoch = epoch
        torch.save(hgnn, "./weights/concat_mri_resnet50_hgnn_mci.pth")

print(best_epoch)
print(best_auc)

