In [1]:
from torch_geometric.data import DataLoader
import torch
from torch_geometric.data import Data
import numpy as np
import pandas as pd
import os
from sklearn.model_selection import train_test_split

from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as tg_nn
import h5py
import torch_geometric.transforms as T
from torch_geometric.data import Batch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
graph_type = "cc200"

class fMRIDataLoader:
    def __init__(self, file_path, graph_type, test_size=0.15, val_size=0.15, batch_size=32):
        self.file_path = file_path
        self.graph_type = graph_type
        self.test_size = test_size
        self.val_size = val_size
        self.batch_size = batch_size
        self.data_splits = self._load_and_split_data()
    
    def _load_and_split_data(self):
        
        graph_dataset = []
        labels = []
        with h5py.File(self.file_path, "r") as f:
            patients_group = f["/patients"]
            for subject_id in patients_group.keys():
                subject_group = patients_group[subject_id]
                if self.graph_type in subject_group and "y" in subject_group.attrs:
                    triu_vector = subject_group[self.graph_type][:]
                    matrix = reconstruct_fc(triu_vector)
                    # 生成40000维输入特征（展平整个矩阵）
                    node_features = matrix.flatten()  # 形状 (200x200=40000,)
                    edge_index = self._get_brain_connectivity_edges(matrix)
                    try:
                        label_value = subject_group.attrs["y"]
                        if isinstance(label_value, (int, np.number)):
                            graph_data = Data(
                                x=torch.FloatTensor(node_features).view(1, -1),  # 直接使用展平后的矩阵
                                edge_index=edge_index,
                                y=torch.tensor([label_value], dtype=torch.long)
                            )
                            graph_dataset.append(graph_data)
                            labels.append(label_value)
                    except KeyError:
                        pass
                        
                else:
                    print(f"Warning: Subject {subject_id} missing {self.graph_type} or label.")
        
        # Data splitting logic remains the same
        train_val_data, test_data, train_val_labels, test_labels = train_test_split(
            graph_dataset, labels, test_size=self.test_size, random_state=42
        )
        train_data, val_data, train_labels, val_labels = train_test_split(
            train_val_data, train_val_labels, test_size=self.val_size/(1-self.test_size), random_state=42
        )
        
        return {
            "train": (train_data, train_labels),
            "valid": (val_data, val_labels),
            "test": (test_data, test_labels),
        }
    
    def _get_brain_connectivity_edges(self, matrix, threshold=0.3):
        """生成边索引（优化版本）"""
        # 创建全连接（考虑对称性）
        rows, cols = np.triu_indices_from(matrix, k=1)
        mask = matrix[rows, cols] > threshold
        edge_index = np.array([rows[mask], cols[mask]])
        
        # 添加反向边
        edge_index = np.concatenate([edge_index, edge_index[::-1]], axis=1)
        
        return torch.tensor(edge_index, dtype=torch.long)

    def _create_dataloader(self, data, labels):
        return DataLoader(
            data, 
            batch_size=self.batch_size,
            shuffle=True
        )

    '''def _create_dataloader(self, data, labels):
        return DataLoader(
            data, 
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=lambda batch: (
                torch.stack([d.x.view(1, 200, 2) for d in batch]),  # 处理特征
                torch.cat([d.y for d in batch])  # 正确拼接1维标签
            )
        )'''
    
    def get_dataloaders(self):
        return {
            "train": self._create_dataloader(*self.data_splits["train"]),
            "valid": self._create_dataloader(*self.data_splits["valid"]),
            "test": self._create_dataloader(*self.data_splits["test"]),
        }

    def get_num_classes(self):
        return len(set(self.data_splits["train"][1]))

class fMRI3DGNN(nn.Module):
    
    def __init__(self, config):
        super().__init__()
        self.graph_builder = nn.Sequential(
            nn.Linear(40000, 200*200),  # 将40000维输入转换为200x200矩阵
            nn.Sigmoid()
        )
        # 修改后的卷积层定义
        self.convs = nn.ModuleList([
            # 第一层GAT：输入特征维度需与增强后的节点特征匹配
            tg_nn.GATv2Conv(
                in_channels=16,  # 修改为特征增强后的维度
                out_channels=128,
                heads=8,
                dropout=config['dropout'],
                add_self_loops=False
            ),
            # 第二层GAT
            tg_nn.GATv2Conv(
                in_channels=128*8,  # 多头注意力的输出维度
                out_channels=256,
                heads=4,
                dropout=config['dropout']
            ),
            # 第三层GAT
            tg_nn.GATv2Conv(
                in_channels=256*4,
                out_channels=512,
                heads=1,
                dropout=config['dropout']
            )
        ])
        
        # 新增特征增强层
        self.feature_enhancer = nn.Sequential(
            nn.Linear(2, 8),  # 扩展节点特征维度
            nn.GELU(),
            nn.Linear(8, 16),
            nn.LayerNorm(16)
        )

        # 更新分类器输入维度
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(256, config['num_classes'])
        )
    def build_graph(self, fc_matrix):
        """修正后的图构建方法"""
        batch_size = fc_matrix.size(0)
        
        # 通过graph_builder生成邻接矩阵
        adj = self.graph_builder(fc_matrix).view(batch_size, 200, 200).float()
        adj = (adj + adj.transpose(1,2)) / 2  # 确保对称性

        # 生成增强节点特征（维度验证）
        node_features = []
        edge_indices = []
        for b in range(batch_size):
            # 基础统计特征
            
            means = adj[b].mean(dim=1, keepdim=True)  # (200,1)
            stds = adj[b].std(dim=1, keepdim=True)    # (200,1)
            base_feat = torch.cat([means, stds], dim=1)  # (200,2)
            
            # 特征增强（输出维度16）
            enhanced_feat = self.feature_enhancer(base_feat)  # (200,16)
            
            # 动态边生成（带阈值限制）
            assert adj[b].dtype == torch.float32, f"邻接矩阵数据类型错误: {adj[b].dtype}"
            threshold = torch.quantile(adj[b].flatten(), 0.75)
            mask = (adj[b] > threshold).float()
            row, col = mask.nonzero(as_tuple=False).t()
            edge_index = torch.stack([row, col], dim=0)
            
            node_features.append(enhanced_feat)
            edge_indices.append(edge_index)

        return Batch.from_data_list([
            Data(x=feat, edge_index=edge) 
            for feat, edge in zip(node_features, edge_indices)
        ])
    

    def _generate_adaptive_edges(self, node_feat):
        """动态生成边连接"""
        # 空间约束
        spatial_dist = torch.cdist(self.spatial_emb.weight, self.spatial_emb.weight)
        
        # 特征相似性
        feat_sim = torch.mm(node_feat, node_feat.t())
        
        # 综合边权重
        combined = (feat_sim * (1 / (spatial_dist + 1e-6)))
        
        # 生成邻接矩阵
        adj = (combined > self.threshold).float()
        
        # 确保最小连接数
        topk = torch.topk(combined, self.k_neighbors, dim=1)
        adj[topk.indices] = 1.0
        
        return adj
    def forward(self, raw_fc):
        # 输入应为 [batch_size, 40000]
        if raw_fc.dim() == 1:
            raw_fc = raw_fc.unsqueeze(0)  # 添加批次维度 [1, 40000]
        assert raw_fc.dim() == 2, f"输入应为二维张量 [batch, 40000]，当前维度：{raw_fc.dim()}"
        batch_size = raw_fc.size(0)
        
        # 构建动态图
        batch_graph = self.build_graph(raw_fc)
        
        # 图卷积处理
        x = batch_graph.x  # [batch*200, 16] (经过特征增强)
        edge_index = batch_graph.edge_index
        
        for conv in self.convs:
            x = conv(x, edge_index)
            x = F.gelu(x)
            x = F.dropout(x, training=self.training)
        
        # 全局池化
        x = tg_nn.global_mean_pool(x, batch_graph.batch)  # [batch, hidden_dim]
        return self.classifier(x)
    
    def classify(self, x):
        return self.classifier(x)
    
    def _adjust_model_parameters(self, new_dim):
        """动态调整模型参数"""
        # 调整空间嵌入维度
        old_emb = self.spatial_emb
        self.spatial_emb = nn.Embedding(new_dim, 3)
        with torch.no_grad():
            min_dim = min(old_emb.weight.size(0), new_dim)
            self.spatial_emb.weight[:min_dim] = old_emb.weight[:min_dim]
        
        # 调整图卷积层输入维度
        if self.convs[0].in_channels != new_dim:
            first_conv = self.convs[0]
            new_conv = tg_nn.GATv2Conv(
                new_dim, 
                first_conv.out_channels,
                heads=first_conv.heads,
                dropout=first_conv.dropout
            )
            self.convs[0] = new_conv

def load_pretrained_gnn():
    GNN_CONFIG = {
        "gnn_layers": 3,
        "hidden_channels": 128,
        "num_classes": 2,
        "dropout": 0.4
    }
    model = fMRI3DGNN(GNN_CONFIG)
    checkpoint = torch.load("fmri_gnn.pth", map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model.to(device)

def reconstruct_fc(vector):
    """将上三角向量重建为对称矩阵"""
    # 创建空矩阵
    matrix = np.zeros((200, 200))
    # 提取上三角索引（不包括对角线）
    triu_indices = np.triu_indices(200, k=1)
    # 填充上三角
    matrix[triu_indices] = vector
    # 对称复制到下三角
    matrix = matrix + matrix.T - np.diag(matrix.diagonal())    
    return matrix

def get_top_rois(model, dataloader, device, top_k=100, roi_names=None):
    model.eval()
    gradients = []
    
    progress = tqdm(dataloader, desc="计算ROI重要性", unit="batch")
    
    try:
        for batch in progress:
            # 解包批次数据为 (input_features, labels)
            x_batch = batch.x.to(device)
            print("输入维度检查:", x_batch.shape) # [batch_size, 40000]
            x_batch = x_batch.requires_grad_(True)  # [batch_size, 40000]
            
            y_batch = batch.y  # [batch_size]
            
            
            # 前向传播
            outputs = model(x_batch)  # 确保输入是 [batch_size, 40000]
            
            # 梯度计算
            grads = torch.autograd.grad(
                outputs.sum(),
                x_batch,
                retain_graph=False,
                create_graph=False,
                allow_unused=True
            )
            
            # 记录梯度
            gradients.append(grads[0].abs().mean(dim=0).cpu())
 
        # 合并梯度
        all_grads = torch.stack(gradients).mean(dim=0).numpy()
        
        # 转换为连接矩阵
        conn_matrix = all_grads.reshape(200, 200)
        roi_importance = conn_matrix.sum(0) + conn_matrix.sum(1)
        
        # 结果处理
        sorted_indices = np.argsort(roi_importance)[::-1]
        roi_names = roi_names or [f"ROI_{i+1:03d}" for i in range(200)]
        
        '''print(f"\nTop {top_k} fMRI ROIs:")
        for i in range(top_k):
            idx = sorted_indices[i]
            print(f"{i+1}. {roi_names[idx]} ({roi_importance[idx]:.4f})")
            
        return roi_importance, roi_names'''
        print(f"\nTop {top_k} fMRI ROIs (按索引显示):")
        for i in range(top_k):
            idx = sorted_indices[i]
            print(f"{i+1}. ROI_{idx+1:03d} 重要性分数: {roi_importance[idx]:.4f}")
            
        return roi_importance
    
    except Exception as e:
        print(f"分析失败: {str(e)}")
        return None, None

# 使用示例 --------------------------------------------------
if __name__ == "__main__":
    # 加载预训练模型
    gnn_model = load_pretrained_gnn().to(device)
    file_path = "/home/yangzongxian/xlz/ASD_GCN/main/data2/abide.hdf5"
    graph_type = "cc200"
    fmri_loader = fMRIDataLoader(file_path = file_path, graph_type="cc200", batch_size=32)
    cc200_names = [
        "Precentral_L", "Precentral_R",  # 实际应使用完整的200个名称
        # ... 补充完整名称列表
    ]
    
    # 获取测试集数据加载器
    fmri_train_loader = fmri_loader.get_dataloaders()["train"]
    fmri_val_loader = fmri_loader.get_dataloaders()["valid"]
    fmri_test_loader = fmri_loader.get_dataloaders()["test"]
    
    # 运行分析
    importance_scores, roi_names = get_top_rois(
        model=gnn_model,
        dataloader=fmri_train_loader,
        device=device,
        top_k=100,
        roi_names=cc200_names
    )
    
    
    '''if importance_scores is not None:
        import matplotlib.pyplot as plt
        
        plt.figure(figsize=(12, 6))
        plt.bar(range(len(importance_scores)), importance_scores)
        plt.title("ROI Importance Distribution")
        plt.xlabel("ROI Index")
        plt.ylabel("Importance Score")
        plt.show()'''

  checkpoint = torch.load("fmri_gnn.pth", map_location=device)
计算ROI重要性:   0%|          | 0/23 [00:00<?, ?batch/s]

输入维度检查: torch.Size([32, 40000])


计算ROI重要性:   4%|▍         | 1/23 [00:01<00:42,  1.93s/batch]

输入维度检查: torch.Size([32, 40000])


计算ROI重要性:   9%|▊         | 2/23 [00:02<00:19,  1.08batch/s]

输入维度检查: torch.Size([32, 40000])


计算ROI重要性:  13%|█▎        | 3/23 [00:03<00:18,  1.09batch/s]

输入维度检查: torch.Size([32, 40000])


计算ROI重要性:  22%|██▏       | 5/23 [00:04<00:15,  1.14batch/s]

输入维度检查: torch.Size([32, 40000])
输入维度检查: torch.Size([32, 40000])


计算ROI重要性:  26%|██▌       | 6/23 [00:06<00:18,  1.07s/batch]

输入维度检查: torch.Size([32, 40000])


计算ROI重要性:  30%|███       | 7/23 [00:07<00:15,  1.04batch/s]

输入维度检查: torch.Size([32, 40000])


计算ROI重要性:  35%|███▍      | 8/23 [00:07<00:10,  1.37batch/s]

输入维度检查: torch.Size([32, 40000])


计算ROI重要性:  39%|███▉      | 9/23 [00:08<00:12,  1.09batch/s]

输入维度检查: torch.Size([32, 40000])


计算ROI重要性:  48%|████▊     | 11/23 [00:11<00:13,  1.13s/batch]

输入维度检查: torch.Size([32, 40000])
输入维度检查: torch.Size([32, 40000])


计算ROI重要性:  52%|█████▏    | 12/23 [00:12<00:11,  1.01s/batch]

输入维度检查: torch.Size([32, 40000])


计算ROI重要性:  57%|█████▋    | 13/23 [00:12<00:08,  1.19batch/s]

输入维度检查: torch.Size([32, 40000])


计算ROI重要性:  61%|██████    | 14/23 [00:13<00:08,  1.06batch/s]

输入维度检查: torch.Size([32, 40000])


计算ROI重要性:  65%|██████▌   | 15/23 [00:14<00:08,  1.01s/batch]

输入维度检查: torch.Size([32, 40000])


计算ROI重要性:  70%|██████▉   | 16/23 [00:16<00:07,  1.08s/batch]

输入维度检查: torch.Size([32, 40000])


计算ROI重要性:  74%|███████▍  | 17/23 [00:17<00:06,  1.12s/batch]

输入维度检查: torch.Size([32, 40000])


计算ROI重要性:  78%|███████▊  | 18/23 [00:17<00:04,  1.09batch/s]

输入维度检查: torch.Size([32, 40000])


计算ROI重要性:  87%|████████▋ | 20/23 [00:18<00:01,  1.53batch/s]

输入维度检查: torch.Size([32, 40000])
输入维度检查: torch.Size([32, 40000])


计算ROI重要性:  91%|█████████▏| 21/23 [00:19<00:01,  1.26batch/s]

输入维度检查: torch.Size([32, 40000])


计算ROI重要性: 100%|██████████| 23/23 [00:20<00:00,  1.10batch/s]

输入维度检查: torch.Size([19, 40000])






Top 100 fMRI ROIs (按索引显示):
1. ROI_094 重要性分数: 0.2219
2. ROI_036 重要性分数: 0.2193
3. ROI_102 重要性分数: 0.2187
4. ROI_071 重要性分数: 0.2182
5. ROI_038 重要性分数: 0.2170
6. ROI_057 重要性分数: 0.2141
7. ROI_031 重要性分数: 0.2134
8. ROI_068 重要性分数: 0.2126
9. ROI_013 重要性分数: 0.2125
10. ROI_025 重要性分数: 0.2120
11. ROI_073 重要性分数: 0.2118
12. ROI_080 重要性分数: 0.2115
13. ROI_010 重要性分数: 0.2114
14. ROI_017 重要性分数: 0.2107
15. ROI_064 重要性分数: 0.2105
16. ROI_059 重要性分数: 0.2097
17. ROI_104 重要性分数: 0.2081
18. ROI_043 重要性分数: 0.2072
19. ROI_099 重要性分数: 0.2065
20. ROI_048 重要性分数: 0.2063
21. ROI_001 重要性分数: 0.2059
22. ROI_079 重要性分数: 0.2050
23. ROI_030 重要性分数: 0.2044
24. ROI_083 重要性分数: 0.2043
25. ROI_158 重要性分数: 0.2042
26. ROI_028 重要性分数: 0.2041
27. ROI_161 重要性分数: 0.2039
28. ROI_026 重要性分数: 0.2037
29. ROI_009 重要性分数: 0.2036
30. ROI_103 重要性分数: 0.2032
31. ROI_077 重要性分数: 0.2031
32. ROI_067 重要性分数: 0.2030
33. ROI_070 重要性分数: 0.2029
34. ROI_046 重要性分数: 0.2027
35. ROI_004 重要性分数: 0.2025
36. ROI_177 重要性分数: 0.2020
37. ROI_141 重要性分数: 0.2013
38. ROI_183 重要性分数: 

ValueError: too many values to unpack (expected 2)

In [2]:
def get_top_rois_signed(model, dataloader, device, top_k=100, roi_names=None):
    model.eval()
    gradients = []
    for batch in dataloader:
        x_batch = batch.x.to(device).requires_grad_(True)
        y_batch = batch.y.to(device)
        outputs = model(x_batch)
        # 假设二分类，计算ASD类（假设为1）的梯度
        grads = torch.autograd.grad(
            outputs[:, 1].sum(),  # ASD类的概率和
            x_batch,
            retain_graph=False,
            create_graph=False
        )[0]
        gradients.append(grads.mean(dim=0).cpu())  # 批次平均，符号保留
    all_grads = torch.stack(gradients).mean(dim=0).numpy()
    conn_matrix = all_grads.reshape(200, 200)
    roi_importance = conn_matrix.sum(0) + conn_matrix.sum(1)  # 每个ROI的连接梯度和
    sorted_indices = np.argsort(np.abs(roi_importance))[::-1]
    roi_names = roi_names or [f"ROI_{i+1:03d}" for i in range(200)]
    print(f"\nTop {top_k} fMRI ROIs (signed importance):")
    for i in range(top_k):
        idx = sorted_indices[i]
        score = roi_importance[idx]
        direction = "高连接性" if score > 0 else "低连接性"
        print(f"{i+1}. {roi_names[idx]} ({score:.4f}, {direction} in ASD)")
    return roi_importance, roi_names

In [3]:
gnn_model = load_pretrained_gnn().to(device)
gnn_model = load_pretrained_gnn().to(device)
file_path = "/home/yangzongxian/xlz/ASD_GCN/main/data2/abide.hdf5"
graph_type = "cc200"
fmri_loader = fMRIDataLoader(file_path = file_path, graph_type="cc200", batch_size=32)
fmri_train_loader = fmri_loader.get_dataloaders()["train"]
fmri_val_loader = fmri_loader.get_dataloaders()["valid"]
fmri_test_loader = fmri_loader.get_dataloaders()["test"]

get_top_rois_signed(gnn_model, fmri_train_loader, device, top_k = 100, roi_names=None)

  checkpoint = torch.load("fmri_gnn.pth", map_location=device)



Top 100 fMRI ROIs (signed importance):
1. ROI_078 (-0.5556, 低连接性 in ASD)
2. ROI_077 (-0.5048, 低连接性 in ASD)
3. ROI_070 (-0.4873, 低连接性 in ASD)
4. ROI_003 (0.4195, 高连接性 in ASD)
5. ROI_038 (0.4012, 高连接性 in ASD)
6. ROI_105 (-0.3547, 低连接性 in ASD)
7. ROI_074 (0.3320, 高连接性 in ASD)
8. ROI_056 (0.3180, 高连接性 in ASD)
9. ROI_018 (0.3104, 高连接性 in ASD)
10. ROI_112 (-0.3073, 低连接性 in ASD)
11. ROI_151 (-0.3042, 低连接性 in ASD)
12. ROI_164 (-0.2995, 低连接性 in ASD)
13. ROI_076 (-0.2982, 低连接性 in ASD)
14. ROI_021 (-0.2939, 低连接性 in ASD)
15. ROI_023 (0.2882, 高连接性 in ASD)
16. ROI_122 (0.2860, 高连接性 in ASD)
17. ROI_050 (0.2818, 高连接性 in ASD)
18. ROI_120 (0.2785, 高连接性 in ASD)
19. ROI_194 (-0.2748, 低连接性 in ASD)
20. ROI_093 (0.2680, 高连接性 in ASD)
21. ROI_035 (0.2679, 高连接性 in ASD)
22. ROI_099 (0.2661, 高连接性 in ASD)
23. ROI_107 (-0.2601, 低连接性 in ASD)
24. ROI_072 (-0.2585, 低连接性 in ASD)
25. ROI_073 (0.2555, 高连接性 in ASD)
26. ROI_171 (0.2537, 高连接性 in ASD)
27. ROI_106 (-0.2531, 低连接性 in ASD)
28. ROI_055 (-0.2479, 低连接性 in ASD)
29.

(array([-1.11101106e-01,  8.69523734e-02,  4.19450223e-01,  1.97669297e-01,
        -2.72064507e-02, -1.10024229e-01,  1.62590012e-01, -3.79957855e-02,
        -6.04615994e-02, -6.33042585e-03, -1.54018074e-01,  1.16986372e-01,
        -1.08430833e-02,  9.13844705e-02,  1.39984697e-01,  7.01157898e-02,
         5.66040464e-02,  3.10375750e-01,  2.38823235e-01, -9.14340168e-02,
        -2.93935835e-01,  2.33978152e-01,  2.88170934e-01,  8.11973512e-02,
        -2.14589223e-01, -2.32651979e-01,  2.07617521e-01, -2.72190943e-02,
         4.89774086e-02, -1.90916628e-01,  6.58501536e-02,  2.07267888e-02,
        -1.80101488e-04,  2.03146115e-02,  2.67891377e-01, -1.76919878e-01,
         5.34295961e-02,  4.01191175e-01, -9.18172300e-02, -2.16218352e-01,
         2.39921845e-02, -1.27599031e-01, -1.15855806e-01, -6.44879490e-02,
         2.99030561e-02, -4.85641323e-02,  1.90229088e-01, -1.67647377e-01,
        -1.56337559e-01,  2.81822443e-01,  2.41671354e-01,  8.97357762e-02,
         1.2

In [None]:
import time
from Bio import Entrez
from Bio import Medline

Entrez.email = "xvlizhao@gmail.com"

brain_regions = [
    "Temporal", "Cerebelum", "Calcarine", "Precuneus", "Frontal",
    "Cingulum", "Parietal", "Thalamus", "Occipital"
]

# 清理微生物名称的空格并去重
microbes = list(set([m.strip() for m in [
    "Roseburia intestinalis","Roseburia hominis","Parabacteroides chongii","Parabacteroides faecis",
    "Parabacteroides timonensis","Ruminococcus torques", "Mediterraneibacter catenae",
    "Ruminococcus torques" ,"Butyricicoccus pullicaecorum","Butyricicoccus porcorum ",
    "Agathobaculum desmolans","Paraprevotella xylaniphila", "Paraprevotella xylaniphila",
    "Paraprevotella clara","Oribacterium sinus","Oribacterium parvum",
    "Oribacterium asaccharolyticum","Enterocloster homin","Lacrimispora indolis",
    "Kineothrix alysoides","Fusicatenibacter saccharivorans","Clostridium porci",
    "Lacrimispora amygdalina","Intestinimonas butyriciproducens","Intestinimonas timonensis",
    "Clostridium phoceensis"
]]))

def search_asd_studies(region, microbe):
    """专注搜索ASD领域的三重组合：脑区+微生物+ASD关键词"""
    asd_keywords = '(autism OR ASD OR "autism spectrum disorder")'
    query = f'({region}[Title/Abstract] AND {microbe}[Title/Abstract]) AND {asd_keywords}'  # 标题/摘要限定
    
    try:
        handle = Entrez.esearch(db="pubmed", term=query, retmax=2, sort="relevance")  # 限制2篇最相关结果
        record = Entrez.read(handle)
        handle.close()
        return record["IdList"]
    except Exception as e:
        print(f"Error searching {region} & {microbe}: {e}")
        return []

def fetch_paper_details(id_list):
    """获取论文详细信息"""
    if not id_list:
        return []
    try:
        handle = Entrez.efetch(db="pubmed", id=id_list, rettype="medline", retmode="text")
        records = list(Medline.parse(handle))
        handle.close()
        return records
    except Exception as e:
        print(f"Error fetching papers: {e}")
        return []
# 以脑区为主轴遍历
for region in brain_regions:
    region_save_path = f"ASD_BrainRegion.txt"  
    
    for microbe in microbes:
        time.sleep(1.5)  
        paper_ids = search_asd_studies(region, microbe)
        
        if not paper_ids:
            continue  
        
        papers = fetch_paper_details(paper_ids)
        with open(region_save_path, 'a') as f:  
            f.write(f"\n### {region} & {microbe} ###\n")
            
            for paper in papers:
                title = paper.get("TI", "No title")
                pmid = paper.get("PMID", "")
                url = f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/" if pmid else "No URL"
                abstract = paper.get("AB", "No abstract available")[:500] + "..."  
                
              
                f.write(f"Region:{region}\nMicrobe:{microbe}\nTitle: {title}\nURL: {url}\nAbstract: {abstract}\n\n")
                print(f"Found in {region}: {title[:50]}...")
    
    print(f"Completed {region}. Saved to {region_save_path}")

In [1]:
import time
from Bio import Entrez
from Bio import Medline

Entrez.email = "xvlizhao@gmail.com"

# 参数配置
SAVE_FILE = "/home/yangzongxian/xlz/ASD_GCN/main/ASD_Brain_Gut_Studies_AllRegions.txt"  

BRAIN_REGIONS = [
    "Temporal", "Cerebellum", "Calcarine", "Precuneus", 
    "Frontal", "Cingulum", "Parietal", "Thalamus", "Occipital"
]
MICROBES = [
    "Roseburia intestinalis","Roseburia hominis","Parabacteroides chongii","Parabacteroides faecis",
    "Parabacteroides timonensis","Ruminococcus torques", "Mediterraneibacter catenae",
    "Ruminococcus torques" ,"Butyricicoccus pullicaecorum","Butyricicoccus porcorum ",
    "Agathobaculum desmolans","Paraprevotella xylaniphila", "Paraprevotella xylaniphila",
    "Paraprevotella clara","Oribacterium sinus","Oribacterium parvum",
    "Oribacterium asaccharolyticum","Enterocloster homin","Lacrimispora indolis",
    "Kineothrix alysoides","Fusicatenibacter saccharivorans","Clostridium porci",
    "Lacrimispora amygdalina","Intestinimonas butyriciproducens","Intestinimonas timonensis",
    "Clostridium phoceensis"
]



def search_asd_studies(region, microbe):
    """精准检索ASD相关研究：脑区+微生物+ASD关键词"""
    query = (
        f'("{region}"[Title/Abstract] OR "brain"[Title/Abstract]) '
        f'AND "{microbe}"[Title/Abstract] '
        'AND (autism OR ASD OR "autism spectrum disorder") '
        'AND ("human"[MeSH Terms] OR "clinical trial"[Publication Type])'  # 限制临床研究
    )
    try:
        handle = Entrez.esearch(db="pubmed", term=query, retmax=3, sort="relevance")
        results = Entrez.read(handle)
        handle.close()
        return results.get("IdList", [])
    except Exception as e:
        print(f"检索失败 {region}-{microbe}: {str(e)}")
        return []

def save_results(data):
    """统一保存结果到文件"""
    with open(SAVE_FILE, "a", encoding="utf-8") as f:
        f.write(data)

# 初始化输出文件
with open(SAVE_FILE, "w") as f:
    f.write("ASD脑-肠轴研究文献汇总\n\n")

# 遍历所有脑区与微生物组合
for idx, region in enumerate(BRAIN_REGIONS, 1):
    print(f"\n正在处理脑区 ({idx}/{len(BRAIN_REGIONS)}): {region}")
    
    for microbe in MICROBES:
        paper_ids = search_asd_studies(region, microbe)
        time.sleep(1.2)  # 遵守NCBI API速率限制
        
        if not paper_ids:
            continue
        
        # 获取文献详细信息
        papers = fetch_paper_details(paper_ids)  # 复用之前的fetch函数
        
        # 构建保存内容
        output = [
            f"\n### {region} & {microbe} ###",
            f"找到 {len(papers)} 篇相关文献:"
        ]
        
        for p in papers:
            title = p.get("TI", "无标题")
            pmid = p.get("PMID", "")
            authors = "; ".join(p.get("AU", []))[:50] + "..."  # 显示前50字符作者
            journal = p.get("TA", "未知期刊")
            year = p.get("DP", "未知年份").split()[0]  # 提取年份
            
            output.append(
                f"\n标题: {title}\n"
                f"PMID: {pmid}\n"
                f"链接: https://pubmed.ncbi.nlm.nih.gov/{pmid}/\n"
                f"作者: {authors}\n"
                f"期刊: {journal} ({year})"
            )
        
        # 保存结果并打印进度
        save_results("\n".join(output))
        print(f"│ 发现 {microbe} 的 {len(papers)} 篇文献")

print(f"\n所有结果已保存至: {SAVE_FILE}")


正在处理脑区 (1/9): Temporal

正在处理脑区 (2/9): Cerebellum

正在处理脑区 (3/9): Calcarine

正在处理脑区 (4/9): Precuneus

正在处理脑区 (5/9): Frontal

正在处理脑区 (6/9): Cingulum

正在处理脑区 (7/9): Parietal

正在处理脑区 (8/9): Thalamus

正在处理脑区 (9/9): Occipital

所有结果已保存至: /home/yangzongxian/xlz/ASD_GCN/main/ASD_Brain_Gut_Studies_AllRegions.txt
