In [5]:
from adversarial import load_pretrained_mlp, extract_microbe_features
from MLP import MicrobeDataLoader, SparseMLP
import numpy as np
import pandas as pd
import torch
from biom import load_table
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split

# Pick the compute device first, and only touch the CUDA runtime when a GPU
# is actually present: calling torch.cuda.set_device() on a CPU-only machine
# raises, which would make the "cpu" fallback below unreachable.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.cuda.set_device(device)

class MicrobeDataLoader:
    """Load a BIOM feature table with its CSV metadata and serve train/val/test loaders.

    The CSV must contain an ``ID`` column (sample identifiers) and a
    ``DX_GROUP`` column (integer class labels). Only samples present in both
    the CSV and the BIOM table are kept.

    Note: this class is defined in the notebook itself and therefore shadows
    the ``MicrobeDataLoader`` imported from ``MLP`` at the top of the cell.

    Args:
        csv_path: Path to the metadata CSV file.
        biom_path: Path to the BIOM feature table.
        batch_size: Batch size used by every DataLoader returned by
            ``get_loaders``.

    Raises:
        AssertionError: If no sample survives the matching step or if the
            number of feature rows differs from the number of labels.
    """

    def __init__(self, csv_path, biom_path, batch_size=32):
        self.batch_size = batch_size
        self.features_tensor, self.labels_tensor = self._load_from_files(csv_path, biom_path)

        # Sanity-check the loaded data before splitting it.
        assert len(self.features_tensor) > 0, "特征数据为空！"
        assert len(self.labels_tensor) > 0, "标签数据为空！"
        assert len(self.features_tensor) == len(self.labels_tensor), "特征与标签数量不匹配！"

        self._split_data()
        self.feature_names = self._get_feature_names(biom_path)

    @staticmethod
    def _most_specific_taxon(taxonomy):
        """Return the last non-empty taxonomy level, or 'Unclassified' if there is none.

        Accepts either a sequence of levels (e.g. ``['k__Bacteria', 'p__Firmicutes']``)
        or a single ``;``-separated string; a plain string would otherwise be
        reversed character by character.
        """
        if taxonomy is None:
            return 'Unclassified'
        if isinstance(taxonomy, str):
            taxonomy = [level.strip() for level in taxonomy.split(';')]
        return next((t for t in reversed(list(taxonomy)) if t != ''), 'Unclassified')

    def _get_feature_names(self, biom_path):
        """Build one display name per observation (OTU) of the BIOM table.

        Each name has the form ``"<taxon> (OTU-<id>)"`` where ``<taxon>`` is the
        most specific non-empty taxonomy level found in the observation
        metadata, or ``Unclassified`` when no taxonomy is available.

        If anything goes wrong (including a failure to open the BIOM file),
        the error is printed and generic ``Feature_<i>`` names are returned,
        one per column of ``self.features_tensor``.
        """
        try:
            table = load_table(biom_path)

            feature_names = []
            for obs_id in table.ids(axis='observation'):
                # Taxonomy information lives in the observation metadata.
                metadata = table.metadata(obs_id, axis='observation')

                if metadata and 'taxonomy' in metadata:
                    species = self._most_specific_taxon(metadata['taxonomy'])
                    feature_names.append(f"{species} (OTU-{obs_id})")
                else:
                    feature_names.append(f"Unclassified (OTU-{obs_id})")

            return feature_names
        except Exception as e:
            print(f"特征名称获取失败: {str(e)}")
            # `table` may not exist here (load_table itself can be what
            # failed), so derive the feature count from the already-loaded
            # feature matrix instead of from the table.
            features = getattr(self, "features_tensor", None)
            n_features = features.shape[1] if features is not None else 0
            return [f"Feature_{i}" for i in range(n_features)]

    def get_feature_names(self):
        """Return the list of feature display names computed at construction time."""
        return self.feature_names

    def _load_from_files(self, csv_path, biom_path):
        """Read both files and return ``(features, labels)`` tensors.

        Returns:
            A float32 tensor with one row per matched sample and one column per
            observation, and an int64 tensor holding the ``DX_GROUP`` label of
            the sample in the same row.
        """
        metadata = pd.read_csv(csv_path)
        table = load_table(biom_path)

        # Keep only samples that appear in both the table and the metadata.
        sample_ids = set(table.ids(axis="sample"))
        metadata_ids = set(metadata["ID"])
        matched_ids = sample_ids & metadata_ids

        metadata = metadata[metadata["ID"].isin(matched_ids)].set_index("ID")
        filtered_table = table.filter(matched_ids, axis="sample", inplace=False)

        # Dense DataFrame, transposed so that rows are samples.
        df = filtered_table.to_dataframe(dense=True).T

        # The metadata rows follow the CSV order while the feature rows follow
        # the BIOM sample order; re-index the labels by the feature rows so
        # that row i of the features and label i describe the same sample.
        labels = metadata.loc[df.index, "DX_GROUP"]

        # NOTE(review): the raw table values are used as features. The
        # original code also computed _add_feature_engineering() here but
        # overwrote its result, so the log1p transform was never applied;
        # that dead computation has been removed, the output is unchanged.
        feature_data = df.values.astype(np.float32)
        return (
            torch.tensor(feature_data, dtype=torch.float32),
            torch.tensor(labels.values, dtype=torch.long),
        )

    def _add_feature_engineering(self, table):
        """Return ``log1p`` of the table's transposed data matrix as a dense array."""
        data = table.matrix_data.T.toarray()
        return np.log1p(data)

    def _split_data(self):
        """Split the data into train/val/test TensorDatasets.

        20% of the samples are held out for testing, then 10% of the remainder
        for validation; both splits use ``random_state=42``. Any failure is
        printed and re-raised.
        """
        try:
            train_features, test_features, train_labels, test_labels = train_test_split(
                self.features_tensor, self.labels_tensor, test_size=0.2, random_state=42
            )
            train_features, val_features, train_labels, val_labels = train_test_split(
                train_features, train_labels, test_size=0.1, random_state=42
            )

            self.train_dataset = TensorDataset(train_features, train_labels)
            self.val_dataset = TensorDataset(val_features, val_labels)
            self.test_dataset = TensorDataset(test_features, test_labels)
        except Exception as e:
            print(f"数据划分失败: {str(e)}")
            raise

    def get_loaders(self):
        """Return freshly built ``(train, val, test)`` DataLoaders.

        Only the training loader shuffles. New DataLoader objects are created
        on every call, so call this once and unpack the tuple.
        """
        return (
            DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4),
            DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4),
            DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4)
        )


def get_top_microbes(model, feature_names, top_k=100):
    """
    Print the top_k most important microbiome input features.

    Importance of an input feature is the sum of the absolute values of the
    weights connecting it to the first layer (``model.feature_extractor[0]``).

    Args:
        model: Model whose ``feature_extractor[0]`` exposes a ``weight``
            tensor with one column per input feature.
        feature_names: Display names, indexed like the input features.
        top_k: Number of features to print. Values larger than the number of
            input features are clamped instead of raising IndexError.
    """
    # Weights of the first fully connected layer.
    first_layer = model.feature_extractor[0]
    weights = first_layer.weight.detach().cpu().numpy()

    # Sum of absolute weights per input feature (column).
    importance_scores = np.abs(weights).sum(axis=0)

    # Feature indices ordered from most to least important.
    sorted_indices = np.argsort(importance_scores)[::-1]

    # Never iterate past the last available feature.
    n_shown = min(top_k, len(sorted_indices))

    print("\n微生物组重要特征 Top", top_k)
    for rank, idx in enumerate(sorted_indices[:max(n_shown, 0)], start=1):
        print(f"{rank}. {feature_names[idx]} ({importance_scores[idx]:.4f})")


# Locations of the sample metadata (CSV) and the feature table (BIOM).
csv_path = "/home/yangzongxian/xlz/ASD_GCN/main/data2/microbe_data.csv"
biom_path = "/home/yangzongxian/xlz/ASD_GCN/main/data2/feature-table.biom"

mlp_model = load_pretrained_mlp()

# Load and split the data once, and build the three loaders with a single
# get_loaders() call (every call creates a fresh set of DataLoaders).
microbe_loader = MicrobeDataLoader(csv_path=csv_path, biom_path=biom_path, batch_size=32)
microbe_train_loader, microbe_val_loader, microbe_test_loader = microbe_loader.get_loaders()

microbe_train_features, microbe_train_labels = extract_microbe_features(microbe_train_loader)
microbe_val_features, microbe_val_labels = extract_microbe_features(microbe_val_loader)
microbe_test_features, microbe_test_labels = extract_microbe_features(microbe_test_loader)

# Feature names come from the BIOM file; the loader built above already holds
# them, so the dataset does not need to be loaded a second time.
feature_names = microbe_loader.get_feature_names()
get_top_microbes(mlp_model, feature_names, top_k=100)


未匹配样本ID: {'SRR9666981', 'SRR9666855', 'SRR9666852', 'SRR9666976', 'SRR9666949', 'SRR9666805', 'SRR9666825', 'SRR9666973'}
数据加载验证通过


  model.load_state_dict(torch.load("/home/yangzongxian/xlz/ASD_GCN/main/down/sparse_mlp.pth"))
  checkpoint = torch.load("fmri_gnn.pth", map_location=device)


GNN feature dim: 512
未匹配样本ID: {'SRR9666981', 'SRR9666855', 'SRR9666852', 'SRR9666976', 'SRR9666949', 'SRR9666805', 'SRR9666825', 'SRR9666973'}
数据加载验证通过
微生物特征维度验证: 2503
MLP输入层维度: 2503
fMRI特征维度验证: 40000
GNN期望输入维度: 40000
GNN feature dim: 512


  best_model.load_state_dict(torch.load("/home/yangzongxian/xlz/ASD_GCN/main/pre_train/adversarial.pth"))


测试准确率（最佳模型）: 65.45%
未匹配样本ID: {'SRR9666981', 'SRR9666855', 'SRR9666852', 'SRR9666976', 'SRR9666949', 'SRR9666805', 'SRR9666825', 'SRR9666973'}
数据加载验证通过


  model.load_state_dict(torch.load("/home/yangzongxian/xlz/ASD_GCN/main/down/sparse_mlp.pth"))



微生物组重要特征 Top 100
1. Unclassified (OTU-53ce252658bde2080c0ea32fd9c4554d) (21.3102)
2. Unclassified (OTU-feee161174b5fc83306033c80168aa6c) (21.2928)
3. Unclassified (OTU-0c87c8c664745907a533973b395c37bb) (21.2827)
4. Unclassified (OTU-c0604b5eb4d65e2224470fac9eb9ac41) (21.2796)
5. Unclassified (OTU-9ef6efd0e49ef5e1efe3613b0e427312) (21.2638)
6. Unclassified (OTU-4663e5f7d08dbc76e525d3b28030712c) (21.2014)
7. Unclassified (OTU-140fd4635b672501cf3ec9e6594b6d8e) (21.1887)
8. Unclassified (OTU-29c4bf94133271c19eee54d4561cfc2a) (21.1640)
9. Unclassified (OTU-5fdacbfd74abb06f94f87f7c539c6246) (21.1459)
10. Unclassified (OTU-f629c1c70c26ae5e8c4f061251e0e020) (21.1371)
11. Unclassified (OTU-63521cc8c33ebf066b2d8c3fc5c210f3) (21.1357)
12. Unclassified (OTU-60f0a102e1b3118286adc774264b5b7d) (21.1335)
13. Unclassified (OTU-429799da351a899f9d472ef9e28c1028) (21.1320)
14. Unclassified (OTU-7fa2be302641261c9d6bbaba04331945) (21.1302)
15. Unclassified (OTU-dc14fe8fe01091477545525b8265b0f7) (21.1271)
1

In [6]:
def get_top_microbes_signed(model, feature_names, top_k=100):
    """Print the top_k input features ranked by the magnitude of their signed weight sum.

    The score of a feature is the plain (signed) sum of the first-layer
    weights attached to it; features are ranked by the absolute value of that
    sum. A positive score is labelled "过表达" and any other score "欠表达".
    The direction label is derived from the sign of the weight sum only.

    Args:
        model: Model whose ``feature_extractor[0]`` exposes a ``weight``
            tensor with one column per input feature.
        feature_names: Display names, indexed like the input features.
        top_k: Number of features to print. Values larger than the number of
            input features are clamped instead of raising IndexError.
    """
    first_layer = model.feature_extractor[0]
    weights = first_layer.weight.detach().cpu().numpy()

    # Signed sum per input feature (column).
    importance_scores = weights.sum(axis=0)
    sorted_indices = np.argsort(np.abs(importance_scores))[::-1]

    # Never iterate past the last available feature.
    n_shown = min(top_k, len(sorted_indices))

    print("\nTop", top_k, "microbes (signed importance):")
    for rank, idx in enumerate(sorted_indices[:max(n_shown, 0)], start=1):
        score = importance_scores[idx]
        direction = "过表达" if score > 0 else "欠表达"
        print(f"{rank}. {feature_names[idx]} ({score:.4f}, {direction} in ASD)")

In [8]:
# Rank the features of the pretrained MLP by signed first-layer weight sum.
get_top_microbes_signed(
    model=mlp_model,
    feature_names=feature_names,
    top_k=100,
)


Top 100 microbes (signed importance):
1. Unclassified (OTU-27702ffe2bcdd63a9e9bb27d55cc77b0) (-2.0175, 欠表达 in ASD)
2. Unclassified (OTU-2a0fd49b9d4e82bca347ba4d87539250) (1.8943, 过表达 in ASD)
3. Unclassified (OTU-d6b84b115708da78933a7b5650f82e35) (1.8503, 过表达 in ASD)
4. Unclassified (OTU-c0604b5eb4d65e2224470fac9eb9ac41) (1.8338, 过表达 in ASD)
5. Unclassified (OTU-648f305e9fe2e4c2b34fb139f01ecfe9) (1.6676, 过表达 in ASD)
6. Unclassified (OTU-5160485a724e51fcd3e4ebe658012aa8) (1.6552, 过表达 in ASD)
7. Unclassified (OTU-6db42349da85d5d75ac387674377482c) (-1.6221, 欠表达 in ASD)
8. Unclassified (OTU-b1f2adb7ce649a26acfe88c4b1f9b37e) (1.5391, 过表达 in ASD)
9. Unclassified (OTU-74ed22fa0968e4ae74b21689790bd405) (-1.5170, 欠表达 in ASD)
10. Unclassified (OTU-372868c395b5f16f7c05413fa1475251) (-1.5032, 欠表达 in ASD)
11. Unclassified (OTU-c700d4fbc650299a2151329fb490f975) (-1.4875, 欠表达 in ASD)
12. Unclassified (OTU-b6ebcc56b4bccfd9f259c274bd121479) (-1.4620, 欠表达 in ASD)
13. Unclassified (OTU-25dddcc9adcba42cb0

In [None]:
# Per-feature two-sided Mann-Whitney U test between the two label groups of
# the training split. The raw p-value of every feature is appended to
# p_values, in feature (column) order.
from scipy.stats import mannwhitneyu
# Row indices of each group.
# NOTE(review): assumes labels are encoded as 1 = ASD and 0 = control — confirm
# against the DX_GROUP coding used in the metadata CSV.
asd_indices = np.where(microbe_train_labels == 1)[0]
control_indices = np.where(microbe_train_labels == 0)[0]
p_values = []
# One test per feature column.
# NOTE(review): assumes microbe_train_features is a 2-D (samples, features)
# array-like that supports this kind of indexing — confirm what
# extract_microbe_features returns.
for i in range(microbe_train_features.shape[1]):
    asd_values = microbe_train_features[asd_indices, i]
    control_values = microbe_train_features[control_indices, i]
    # `stat` is not used below; only the p-value is kept.
    stat, p = mannwhitneyu(asd_values, control_values, alternative='two-sided')
    p_values.append(p)