In [2]:
import torch
import biom
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from sklearn.metrics import confusion_matrix, f1_score, matthews_corrcoef, mean_absolute_error, r2_score, accuracy_score


In [3]:
from membed.Attention_embedding import load_data_imdb, DataLoader, TransformerEncoder, evaluate_auc_gpu, set_seed

membed 0.1.0 initialized.


In [48]:
metadata_df = pd.read_csv("/home/dongbiao/all_study/data/metadata_t.csv",
                               sep="\t",
                               index_col=0,
                               low_memory=False, dtype={0:str})
existing_studies = metadata_df.study.unique()

In [4]:
set_seed(11)

In [5]:
num_steps=600
p_drop=0.4
batch_size=4000
d_model=100
n_layers=1
n_heads=1
group = "group"

In [8]:

import os
import torch
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from sklearn import metrics
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score, matthews_corrcoef
from biom import load_table
 
# 配置路径
base_dir = "../../data/Disease_classification_loo/disease_data"
 
# 初始化结果容器
results = {
    'disease': [], 'test_study': [], 'auc': [], 'f1': [],
    'mcc': [], 'acc': [], 'confusion_matrix': []
}
 
def evaluate_auc_gpu(Y, prob):
    """综合性能评估函数"""
    Y = Y.numpy().astype('int')
    prob = prob.cpu().detach().numpy()
    
    # 计算ROC曲线
    fpr, tpr, _ = metrics.roc_curve(Y, prob)
    roc_auc = metrics.auc(fpr, tpr)
    
    # 计算PR曲线
    precision, recall, _ = metrics.precision_recall_curve(Y, prob)
    pr_auc = metrics.auc(recall, precision)
    
    # 计算分类指标
    y_pred = (prob > 0.5).astype(int)
    cm = confusion_matrix(Y, y_pred)
    f1 = f1_score(Y, y_pred, average='macro')
    acc = accuracy_score(Y, y_pred)
    mcc = matthews_corrcoef(Y, y_pred)
    
    return roc_auc, pr_auc, f1, mcc, acc, cm, fpr, tpr
 

 
def save_results(disease, study, roc_auc, pr_auc, f1, mcc, acc, cm, fpr, tpr):
    """保存分析结果"""
    # 保存指标
    results['disease'].append(disease)
    results['test_study'].append(study)
    results['auc'].append(roc_auc)
    results['f1'].append(f1)
    results['mcc'].append(mcc)
    results['acc'].append(acc)
    results['confusion_matrix'].append(str(cm))  # 将矩阵转为字符串保存
    
    # 保存ROC曲线
    roc_df = pd.DataFrame({'FPR': fpr, 'TPR': tpr})
    output_dir = os.path.join(base_dir, disease, study, "results")
    os.makedirs(output_dir, exist_ok=True)
    roc_df.to_csv(os.path.join(output_dir, "roc_curve.csv"), index=False)
 
def process_disease(disease_path):
    """处理单个疾病目录"""
    disease_name = os.path.basename(disease_path)
    print(f"\n{'='*30}\n处理疾病: {disease_name}\n{'='*30}")
    
    # 获取所有研究目录
    study_dirs = [
        d for d in os.listdir(disease_path) 
        if os.path.isdir(os.path.join(disease_path, d))
    ]
    
    # 跳过单个研究的疾病
    if len(study_dirs) <= 1:
        print(f"!! 跳过 {disease_name}，仅有 {len(study_dirs)} 个研究")
        return
    
    # 检查元数据文件
    metadata_path = os.path.join(disease_path, "metadata.tsv")
    if not os.path.exists(metadata_path):
        print(f"!! 严重错误：{disease_name} 缺失元数据文件")
        return
    
    # 遍历每个研究
    for study in study_dirs:
        study_path = os.path.join(disease_path, study)
        print(f"\n-- 处理研究: {study}")
        
        # 验证必需文件
        required_files = {
            'train': 'train_loo.biom',
            'test': 'test_loo.biom',
            'model': os.path.join('results', 'attention_loo.pt')
        }
        
        missing_files = [
            f for f in required_files.values() 
            if not os.path.exists(os.path.join(study_path, f))
        ]
        
        if missing_files:
            print(f"!! 文件缺失: {', '.join(missing_files)}")
            continue
        
        # 加载数据
        train_data = load_data_imdb(f'{study_path}/train_loo.biom',
                                    metadata_path,
                                    'group',
                                    'sample',
                                    num_steps)
        test_data = load_data_imdb(f'{study_path}/test_loo.biom',
                                   metadata_path,
                                    'group',
                                    'sample',
                                    num_steps)
        
        train_iter = DataLoader(train_data,
                                batch_size=batch_size,
                                shuffle=False)
        
        test_iter = DataLoader(test_data,
                                batch_size=batch_size,
                                shuffle=False)
        
        # 加载模型
        fid_dict = train_data()
        net = TransformerEncoder(otu_size=len(fid_dict),
                                    seq_len=num_steps+1,
                                    d_model=d_model,
                                    n_layers=n_layers,
                                    n_heads=n_heads,
                                    p_drop=p_drop,
                                    pad_id=fid_dict['<pad>'])
        
        net.load_state_dict( 
                torch.load(
                    f'{study_path}/results/attention_loo.pt',
                    map_location=torch.device('cpu')))
        
        print(net.embedding)  # 确认num_embeddings参数
        
        with torch.no_grad():
            net.eval()
            for i, (features, abundance, group_lables, mask) in enumerate(test_iter):
                pred, _ = net(features, abundance, mask)
                # 评估性能
                roc_auc, pr_auc, f1, mcc, acc, cm, fpr, tpr = evaluate_auc_gpu(group_lables, pred)
                
                # 保存结果
                save_results(disease_name, study, roc_auc, pr_auc, f1, mcc, acc, cm, fpr, tpr)
                print(f"√ 成功保存 {study} 结果")
        

for disease_folder in os.listdir(base_dir):
    disease_path = os.path.join(base_dir, disease_folder)
    if os.path.isdir(disease_path):
        process_disease(disease_path)

# 汇总结果
summary_df = pd.DataFrame(results)
summary_path = os.path.join( "summary_ROC_results.csv")
summary_df.to_csv(summary_path, index=False)



处理疾病: CAD

-- 处理研究: PRJDB6472


FileNotFoundError: [Errno 2] No such file or directory: 'group'

In [6]:
model_res = pd.DataFrame({"disease":results['disease'],"study":results['test_study'],"f1":results['f1'], "auc":results['auc'], "cms":results['confusion_matrix']})
model_res.to_csv("attention_model_loo.csv", index=None)
# print(model_res)
model_res

Unnamed: 0,disease,study,f1,auc,cms
0,CAD,PRJDB6472,0.350000,0.687500,[[28 0]\n [24 0]]
1,CAD,PRJDB7456,0.322581,0.783838,[[30 0]\n [33 0]]
2,HTN,PRJNA762195,0.259542,0.556956,[[34 0]\n [63 0]]
3,HTN,PRJNA722359,0.400000,0.543478,[[46 0]\n [23 0]]
4,OB,PRJEB11419,0.340566,0.679973,[[361 0]\n [338 0]]
...,...,...,...,...,...
57,CRC,PRJNA430990,0.718861,0.809213,[[117 2]\n [ 17 10]]
58,CRC,PRJEB36789,0.773333,0.898501,[[99 3]\n [24 27]]
59,ASD,PRJNA932561,0.333333,0.826250,[[40 0]\n [40 0]]
60,ASD,PRJNA578223,0.729049,0.792101,[[34 14]\n [12 36]]


In [7]:
import re
import csv
from pathlib import Path
import pandas as pd

def extract_metrics(text_path):
    # 正则匹配模式
    disease_pattern = re.compile(r'^Processing disease:\s+(\w+)')  # 疾病行
    study_pattern = re.compile(r'^(\w+)\s+K_')                     # studyId行
    metric_pattern = re.compile(r'(AUC ROC|Average Precision|F1 value|mcc value|aupr value|acc value):\s+([0-9.]+)')  # 指标行
    
    # 混淆矩阵相关变量
    matrix_lines_remaining = 0
    current_matrix = []
    current_confusion_matrix = None

    records = []
    current_study = None
    current_disease = None

    with open(text_path, 'r') as f:
        for line in f:
            line = line.rstrip()  # 去除行尾换行符

            # 匹配疾病行
            disease_match = disease_pattern.match(line)
            if disease_match:
                current_disease = disease_match.group(1)
                matrix_lines_remaining = 0  # 重置状态
                current_confusion_matrix = None
                continue

            # 匹配study行
            study_match = study_pattern.match(line)
            if study_match:
                current_study = study_match.group(1)
                matrix_lines_remaining = 0  # 重置状态
                current_confusion_matrix = None
                continue

            # 匹配混淆矩阵标题行
            if "conflusion Matrix" in line:  # 容忍拼写错误
                matrix_lines_remaining = 2
                current_matrix = []
                continue

            # 提取混淆矩阵数据
            if matrix_lines_remaining > 0:
                row_match = re.match(r'\s*\[\s*(\d+)\s+(\d+)\s*\]', line)
                if row_match:
                    current_matrix.append([
                        int(row_match.group(1)),
                        int(row_match.group(2))
                    ])
                    matrix_lines_remaining -= 1
                    if matrix_lines_remaining == 0:
                        current_confusion_matrix = current_matrix
                else:
                    matrix_lines_remaining = 0  # 格式不符则放弃

            # 匹配指标行
            if current_study:
                metrics = dict(metric_pattern.findall(line))
                if metrics:
                    record = {
                        "studyId": current_study,
                        "disease": current_disease,
                        **{k: float(v) for k, v in metrics.items()}
                    }
                    # 添加混淆矩阵
                    # print(current_confusion_matrix)s
                    if current_confusion_matrix:
                        record.update({
                            "tn": current_confusion_matrix[0][0],
                            "fp": current_confusion_matrix[0][1],
                            "fn": current_confusion_matrix[1][0],
                            "tp": current_confusion_matrix[1][1],
                            "confusion_matrix": str(current_confusion_matrix)
                        })
                        current_confusion_matrix = None  # 重置
                    records.append(record)

    # 创建DataFrame
    df = pd.DataFrame(records)
    
    # 美化显示格式
    pd.options.display.float_format = '{:.4f}'.format
    display(df)
    return df

# 使用示例
input_file = Path("/home/dongbiao/all_study/script/script_leave_one_out_new/log/4_21_RF.out")
RF_model_res = extract_metrics(input_file)
RF_model_res.to_csv("RF_model_loo.csv", index=None)

Unnamed: 0,studyId,disease,AUC ROC,Average Precision,F1 value,aupr value,acc value,mcc value
0,PRJCA007396,AS,0.6042,0.8923,0.4434,0.8929,0.7966,
1,SRP332779,AS,0.7087,0.8404,0.4126,0.8351,0.7024,
2,PRJNA932561,ASD,0.6381,0.6457,0.5429,0.6410,0.5875,0.2242
3,PRJNA578223,ASD,0.5202,0.5608,0.3287,0.5543,0.4896,
4,PRJNA1103839,ASD,0.5140,0.5672,0.3222,0.5526,0.4754,
...,...,...,...,...,...,...,...,...
57,PRJNA978958,T2DM,0.6079,0.8036,0.5538,0.7970,0.6897,0.2300
58,PRJNA719138,T2DM,0.3090,0.5134,0.3729,0.4890,0.5946,
59,PRJNA646010,T2DM,0.3260,0.5018,0.3594,0.4782,0.5610,
60,PRJNA819279,T2DM,0.4655,0.4810,0.3051,0.4479,0.4390,


In [8]:
# 提取 studyId 和 study 列的唯一值
unique_studyId = set(RF_model_res['studyId'].unique())
unique_study = set(model_res['study'].unique())
 
# 找出不同的研究
different_studies = unique_studyId.symmetric_difference(unique_study)
 
print("不同的研究有:", different_studies)

不同的研究有: set()


In [9]:
import pandas as pd
 
# 读取两个CSV文件
df_attention = pd.read_csv("attention_model_loo.csv")
df_rf = pd.read_csv("RF_model_loo.csv")
 
# 统一研究ID列名
df_attention = df_attention.rename(columns={"studyId": "study","disease":'disease_name_ab'})
df_rf = df_rf.rename(columns={"studyId": "study","disease":'disease_name_ab'})
# 按study列进行内连接合并（只保留双方共有的study）
merged_df = pd.merge(
    df_attention, 
    df_rf,
    on=["study","disease_name_ab"],
    how="inner"
)
 
# 保存结果
merged_df.to_csv("all_model_loo.csv", index=False)
merged_df

Unnamed: 0,disease_name_ab,study,f1,auc,cms,AUC ROC,Average Precision,F1 value,aupr value,acc value,mcc value
0,CAD,PRJDB6472,0.3500,0.6875,[[28 0]\n [24 0]],0.5759,0.5553,0.5772,0.5536,0.5962,0.2673
1,CAD,PRJDB7456,0.3226,0.7838,[[30 0]\n [33 0]],0.7944,0.7987,0.7232,0.7979,0.7302,0.4652
2,HTN,PRJNA762195,0.2595,0.5570,[[34 0]\n [63 0]],0.4750,0.6339,0.4156,0.6249,0.6392,
3,HTN,PRJNA722359,0.4000,0.5435,[[46 0]\n [23 0]],0.4806,0.3567,0.3803,0.3377,0.4058,0.0913
4,OB,PRJEB11419,0.3406,0.6800,[[361 0]\n [338 0]],0.5295,0.5254,0.3253,0.5243,0.4821,
...,...,...,...,...,...,...,...,...,...,...,...
57,CRC,PRJNA430990,0.7189,0.8092,[[117 2]\n [ 17 10]],0.6857,0.4133,0.6141,0.4120,0.7329,0.2401
58,CRC,PRJEB36789,0.7733,0.8985,[[99 3]\n [24 27]],0.9168,0.8524,0.8334,0.8521,0.8431,0.6825
59,ASD,PRJNA932561,0.3333,0.8263,[[40 0]\n [40 0]],0.6381,0.6457,0.5429,0.6410,0.5875,0.2242
60,ASD,PRJNA578223,0.7290,0.7921,[[34 14]\n [12 36]],0.5202,0.5608,0.3287,0.5543,0.4896,


In [11]:
import pandas as pd

# 读取两个文件
# ------------------------------------------------------------
# 读取模型结果文件
model_df = pd.read_csv("all_model_loo.csv")

# 读取metadata文件（处理制表符分隔和多余列）
metadata_df = pd.read_csv(
    "/home/dongbiao/all_study/data/metadata_t.csv",
    sep='\t',
    usecols=['study', 'diagnosis', 'disease_name', 'disease_name_ab'],  # 只读取需要的列
    dtype={'study': str}  # 确保study列为字符串
).drop_duplicates('study')  # 去重，每个study保留第一条记录

# 数据合并
# ------------------------------------------------------------
# 左连接合并（保留所有模型结果）
merged_df = pd.merge(
    model_df,
    metadata_df,
    left_on='study',
    right_on='study',
    how='left'
)

# 重排列顺序（可选）
column_order = [
    'study', 'diagnosis', 'disease_name','disease_name_ab',  # 新增列在前
    'train_studies', 'fold', 'auc', 'cms',  # 原列在后
    'AUC ROC', 'Average Precision', 'F1 value', 
    'mcc value', 'aupr value', 'acc value'
]
merged_df = merged_df[column_order]

# 保存结果（覆盖原文件）
# ------------------------------------------------------------
merged_df.to_csv(
    "all_model_loo.csv", 
    index=False,
    float_format='%.4f'  # 统一小数精度
)

print("合并完成，结果已写回原文件")

合并完成，结果已写回原文件


In [12]:
model_df = pd.read_csv("all_model_loo.csv")
display(model_df)

Unnamed: 0,study,diagnosis,disease_name,disease_name_ab,train_studies,fold,auc,cms,AUC ROC,Average Precision,F1 value,mcc value,aupr value,acc value
0,PRJCA007396,C,Ankylosing Spondylitis,AS,['SRP332779'],loo,0.5606,[[ 0 11]\n [ 1 47]],0.572,0.8598,0.4434,,0.8564,0.7966
1,SRP332779,AS,Ankylosing Spondylitis,AS,['PRJCA007396'],loo,0.8715,[[20 4]\n [ 8 51]],0.7433,0.877,0.5118,0.1298,0.8747,0.7108
2,Agp_Austim,Austim,Austim,ASD,"['PRJNA1103839', 'PRJNA578223', 'PRJNA932561']",loo,0.443,[[ 1 73]\n [ 1 73]],0.4181,0.4727,0.3303,,0.4666,0.4932
3,PRJNA1103839,Healthy,Autism Spectrum Disorder,ASD,"['Agp_Austim', 'PRJNA578223', 'PRJNA932561']",loo,0.6492,[[15 13]\n [ 6 22]],0.3119,0.4379,0.3253,,0.4205,0.4821
4,PRJNA578223,ASD,Autism Spectrum Disorder,ASD,"['Agp_Austim', 'PRJNA1103839', 'PRJNA932561']",loo,0.7784,[[44 3]\n [13 28]],0.2169,0.339,0.3575,0.0499,0.3227,0.4773
5,PRJNA932561,ASD,Autism Spectrum Disorder,ASD,"['Agp_Austim', 'PRJNA1103839', 'PRJNA578223']",loo,0.84,[[29 11]\n [ 6 34]],0.4725,0.4822,0.3543,0.0,0.47,0.5
6,Agp_Bipolar,Bipolar,Bipolar,BD,"['PRJEB23500', 'PRJNA1068750']",loo,0.7961,[[13 4]\n [ 4 11]],0.8118,0.7859,0.7758,0.5659,0.7872,0.7812
7,PRJEB23500,H,Bipolar disorder,BD,"['Agp_Bipolar', 'PRJNA1068750']",loo,0.6732,[[19 28]\n [11 61]],0.5792,0.6693,0.3737,,0.6675,0.5966
8,PRJNA1068750,BD,Bipolar disorder,BD,"['Agp_Bipolar', 'PRJEB23500']",loo,0.751,[[23 23]\n [ 6 36]],0.7365,0.74,0.7044,0.4154,0.7357,0.7045
9,PRJDB6472,C,coronary artery disease,CAD,['PRJDB7456'],loo,0.6057,[[17 11]\n [ 7 17]],0.5863,0.5537,0.3796,0.0636,0.5504,0.4808
