In [1]:
import pandas as pd
import numpy as np
import torch
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch
from torch.optim import RMSprop
from GATv2_model import DualCNNandGATv2
from functions import train_model, evaluate_model,load_subgraph_loader, CombinedLoss
from dataset import balance_and_split_by_random_state, load_datasets_and_node_features, pre_data_deal
import matplotlib.pyplot as plt
import gc  # 导入垃圾收集模块
import random

seed = 42  # 可以是任意数字，使用相同的数字确保每次运行的初始化相同
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)  # 对所有GPU设置种子


# 初始化参数
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
base_dir = '/home/tjzhang03/zxj/deal_data/data_output'
cell_type = "NHEK"
input_dir = f'{base_dir}/{cell_type}'
output_dir = f'{input_dir}/random_state_split_dataset_CNN_GAT'
non_gene_ids_file = '/home/tjzhang03/zxj/deal_data/data_input/BENGI/non_gene_ids.txt'
geometric_batch_size = 1
node_height = 7 #行
node_width =21 #列
cnn_out_channels = 64  #节点通过CNN后的特征维度
gat1_out_channels = 32  #第一层GAT输出维度
gat2_out_channels = 16  #第二层GAT输出节点维度
Edge_hidden_dim = 16
num_heads=6
lr=0.0005
weight_decay=1e-4
step_size=30
gamma=0.1
num_epochs = 100
threshold = 0.5
weight_bce = 0.8

# 循环不同的 random_state 值
random_states = range(30)  # 示例：遍历5个不同的random_state
# random_states = [8] # 示例：遍历5个不同的random_state

aucs_per_state = []
# standardized_df, enhancer_features, promoter_features = pre_data_deal(input_dir, cell_type, non_gene_ids_file)
for random_state in random_states:  
    # print(random_state)
    #过滤、平衡、划分训练集验证测试、转换数据格式
    # train_edges_df, val_edges_df, test_edges_df, train_nodes_dict, val_nodes_dict, test_nodes_dict = balance_and_split_by_random_state(standardized_df, enhancer_features, promoter_features, random_state, output_dir)
    train_edges_df, val_edges_df, test_edges_df, train_nodes_dict, val_nodes_dict, test_nodes_dict = load_datasets_and_node_features(output_dir, random_state)
    train_loader = load_subgraph_loader(train_nodes_dict, train_edges_df, geometric_batch_size)
    val_loader = load_subgraph_loader(val_nodes_dict, val_edges_df, geometric_batch_size, is_training=False) #val_loader是字典

    model = DualCNNandGATv2(node_height, node_width, cnn_out_channels, gat1_out_channels, gat2_out_channels, Edge_hidden_dim, 2, num_heads).to(device)
    # optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    optimizer = RMSprop(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
    # criterion = torch.nn.BCELoss()
    # criterion = torch.nn.BCEWithLogitsLoss()
    criterion = CombinedLoss(weight_bce=weight_bce)
    
    # 保存每个 epoch 的
    best_val_auc = float('inf')
    last_train_loss = float('inf')
    last_val_loss = float('inf')

    train_losses = []
    val_losses = []
    val_aucs = []

    for epoch in range(num_epochs):
        # 训练和验证模型的代码
        train_loss, train_acc = train_model(model, train_loader, criterion, optimizer, threshold, device)
        train_losses.append(train_loss)
        val_loss, val_auc, _, _ = evaluate_model(model, val_loader, criterion, threshold, device)
        val_losses.append(val_loss)
        val_aucs.append(val_auc)
        scheduler.step()  # 更新学习率

    # 选取此 random_state 下的最佳 AUC
    best_val_auc = max(val_aucs)
    aucs_per_state.append(best_val_auc)
    # 计算当前 random_state 下所有epoch的AUC值的方差
    current_state_auc_variance = np.var(val_aucs)
    print(f"Current random state {random_state}, MAX AUC:{best_val_auc}, AUC variance: {current_state_auc_variance}")
    

    # # 循环结束，不再需要模型和数据时
    del model  # 删除模型实例
    del train_loader, val_loader  # 删除数据变量
    gc.collect()  # 显式调用垃圾收集器
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # 清理CUDA缓存

  from .autonotebook import tqdm as notebook_tqdm


Current random state 0, MAX AUC:0.7296220351454294, AUC variance: 0.0009893836902026734
Current random state 1, MAX AUC:0.7042194674012856, AUC variance: 0.0009576234111367874
Current random state 2, MAX AUC:0.7111567283163266, AUC variance: 0.0003645186874664097
Current random state 3, MAX AUC:0.7103699100171563, AUC variance: 0.00033938857394167623
Current random state 4, MAX AUC:0.7248421266086251, AUC variance: 0.0004799223183996303
Current random state 5, MAX AUC:0.6765762639202196, AUC variance: 0.0004024677816143521
Current random state 6, MAX AUC:0.7148274116228615, AUC variance: 0.0007452039532925622
Current random state 7, MAX AUC:0.6960855958514996, AUC variance: 0.0004974169405861248
Current random state 8, MAX AUC:0.6613781363806464, AUC variance: 0.00035006204441296126
Current random state 9, MAX AUC:0.7277099808726972, AUC variance: 0.0005955528741595392
Current random state 10, MAX AUC:0.7304944670121865, AUC variance: 0.0007190241272699818
Current random state 11, MAX 