In [20]:
import torch
from torch_cluster import radius_graph


def remove_subset_edges(main_edge_index, subset_edge_index):
# 将边对转换为集合进行操作
    main_edges = set(map(tuple, main_edge_index.t().tolist()))
    subset_edges = set(map(tuple, subset_edge_index.t().tolist()))

    # 从 main_edges 中移除 subset_edges
    filtered_edges = main_edges - subset_edges

    # 转换回 tensor 格式
    filtered_edge_index = torch.tensor(list(filtered_edges), dtype=torch.long, device=main_edge_index.device).t()
    return filtered_edge_index
def multi_radius_graph(x, batch, radii):
    """
    根据多个半径阈值从节点坐标和批次信息中生成多级别的边集合。
    只调用一次 radius_graph，然后根据距离进行分级筛选。

    Args:
        x (Tensor): 节点坐标, [N, D]
        batch (LongTensor): 节点的批次信息, [N]
        radii (list or tuple): 递增顺序的半径列表 (例如 [1.35, 1.7, 2.5])

    Returns:
        edges_list (list of Tensors): 对应 radii 阈值范围划分的边集列表
    """

    # 确保 radii 已经排序
    radii = sorted(radii)
    r_max = radii[-1]

    # 使用最大半径一次性构建边
    edge_index_full = radius_graph(x, r=r_max, batch=batch, flow='source_to_target')

    # 计算每条边的距离
    row, col = edge_index_full
    # 假设 x 为 (N, D)
    # 计算欧式距离的平方（避免调用 sqrt 提高性能）
    diff = x[row] - x[col]
    dist_sq = (diff * diff).sum(dim=-1)  # dist^2

    # 将半径也转为平方方便比较
    radii_sq = [r*r for r in radii]

    edges_list = []
    # 上一个半径区间的上界
    prev_r_sq = 0.0  
    for r_sq in radii_sq:
        # 选取 dist_sq 在 (prev_r_sq, r_sq] 区间内的边
        # 如果希望第一个区间包括从0到r1的所有边，那么 prev_r_sq可设0
        mask = (dist_sq <= r_sq) & (dist_sq > prev_r_sq)
        edges_list.append(edge_index_full[:, mask])
        prev_r_sq = r_sq

    return edges_list

def normalize_edges(edges):
    edges = edges.clone()
    # 确保每条边(i,j)中 i<j
    sorted_edges = torch.sort(edges, dim=0)[0]
    # (2,E) -> (E,2)
    sorted_edges = sorted_edges.t()
    arr = sorted_edges.cpu().numpy()
    # 使用numpy lexsort进行双列排序
    idx = np.lexsort((arr[:,1], arr[:,0]))
    arr = arr[idx]
    return torch.from_numpy(arr).to(edges.device)
if __name__ == "__main__":
    import numpy as np
    torch.manual_seed(42)

    # 模拟数据
    N = 50
    x = torch.randn(N, 3)
    batch = torch.randint(0, 3, (N,))  # 3个batch
    radii = [1.35, 1.7, 2.5]

    # 方法一：multi_radius_graph
    edge_index_levels_multi = multi_radius_graph(x, batch, radii)
    edge_index_2p7_multi, edge_index_3p4_multi, edge_index_4p9_multi = edge_index_levels_multi

    # 方法二：多次 radius_graph + remove_subset_edges
    edge_index_2p7 = radius_graph(x, r=1.35, batch=batch, flow='source_to_target')
    edge_index_3p4 = radius_graph(x, r=1.7, batch=batch, flow='source_to_target')
    edge_index_4p9 = radius_graph(x, r=2.5, batch=batch, flow='source_to_target')
    edge_index_4p9 = remove_subset_edges(edge_index_4p9, edge_index_3p4)
    edge_index_3p4 = remove_subset_edges(edge_index_3p4, edge_index_2p7)

    # 标准化
    norm_2p7_multi = normalize_edges(edge_index_2p7_multi)
    norm_2p7 = normalize_edges(edge_index_2p7)
    norm_3p4_multi = normalize_edges(edge_index_3p4_multi)
    norm_3p4 = normalize_edges(edge_index_3p4)
    norm_4p9_multi = normalize_edges(edge_index_4p9_multi)
    norm_4p9 = normalize_edges(edge_index_4p9)

    # 检查两种方法输出是否在集合意义上一致
    eq_2p7 = torch.equal(norm_2p7_multi, norm_2p7)
    eq_3p4 = torch.equal(norm_3p4_multi, norm_3p4)
    eq_4p9 = torch.equal(norm_4p9_multi, norm_4p9)

    if eq_2p7 and eq_3p4 and eq_4p9:
        print("测试通过：两种方法的结果在集合意义上等价。")
    else:
        print("测试失败：存在边集合不一致的情况。")
        if not eq_2p7:
            print("半径1.35对应的边集不一致")
            print("multi:\n", norm_2p7_multi)
            print("独立构图:\n", norm_2p7)
        if not eq_3p4:
            print("1.35~1.7对应的边集不一致")
            print("multi:\n", norm_3p4_multi)
            print("独立构图+去重:\n", norm_3p4)
        if not eq_4p9:
            print("1.7~2.5对应的边集不一致")
            print("multi:\n", norm_4p9_multi)
            print("独立构图+去重:\n", norm_4p9)

测试通过：两种方法的结果在集合意义上等价。


In [19]:
import numpy as np
import torch
def connect_within_batch(x, virtual_mask, batch, mask_ligand):
    """
    建立同一批次中分子与分子、虚拟原子与虚拟原子、分子与虚拟原子之间的连接。
    """

    device = x.device
    non_protein_mask = virtual_mask | mask_ligand
    non_protein_indices = torch.where(non_protein_mask)[0]
    unique_batches = batch[non_protein_indices].unique()
    edge_index = []
    for b in unique_batches:
        nodes_in_batch = non_protein_indices[batch[non_protein_indices] == b]
        if nodes_in_batch.size(0) > 1:
            edges = torch.combinations(nodes_in_batch, r=2)
            edge_index.append(edges)
            edge_index.append(edges.flip(dims=[1]))
    edge_index = torch.cat(edge_index, dim=0).t().to(device)
    return edge_index
def connect_within_batch_optimized(x, virtual_mask, batch, mask_ligand):
    """
    优化后的函数：建立同一批次中分子与分子、虚拟原子与虚拟原子、分子与虚拟原子之间的连接。
    
    所有计算均在 GPU (CUDA) 上进行。
    
    Args:
        x: 输入特征张量 (在 GPU 上)
        virtual_mask: 虚拟原子的掩码 (在 GPU 上)
        batch: 批次索引，用于区分不同批次的原子 (N_protein,) (在 GPU 上)
        mask_ligand: 配体掩码，用于区分配体原子 (在 GPU 上)
    
    Returns:
        edge_index: 边索引张量，形状为 (2, E)，在 GPU 上
    """
    device = x.device
    # 确保所有输入张量都在 GPU 上
    # 假设在函数外部已经保证 x, virtual_mask, batch, mask_ligand 均在 GPU 上
    # 如果不确定，可使用以下语句（根据需要取消注释）：
    # virtual_mask = virtual_mask.to(device)
    # mask_ligand = mask_ligand.to(device)
    # batch = batch.to(device)
    
    non_protein_mask = virtual_mask | mask_ligand
    non_protein_indices = torch.where(non_protein_mask)[0]  # 位于 GPU 上
    
    # 如果非蛋白原子数量小于等于1，则无需建立边
    if non_protein_indices.size(0) <= 1:
        return torch.empty((2, 0), dtype=torch.long, device=device)
    
    batch_non_protein = batch[non_protein_indices]  # 在 GPU 上
    N = batch_non_protein.size(0)
    
    # 在 GPU 上生成所有上三角（不含对角线）的索引对
    idx_i, idx_j = torch.triu_indices(N, N, offset=1, device=device)  # 直接在 GPU 上生成
    # 筛选出同一批次内的原子对
    same_batch_mask = (batch_non_protein[idx_i] == batch_non_protein[idx_j])
    
    # 获取符合条件的原子索引对
    selected_idx_i = idx_i[same_batch_mask]
    selected_idx_j = idx_j[same_batch_mask]
    
    edges = torch.stack([
        non_protein_indices[selected_idx_i], 
        non_protein_indices[selected_idx_j]
    ], dim=0)
    
    # 生成双向边
    edges = torch.cat([edges, edges.flip(dims=[0])], dim=1)
    return edges

def normalize_edges(edges):
    # edges 形状为 (2, E) 或 (E, 2)
    # 确保 edges 是 (E,2)
    if edges.size(0) == 2:
        edges = edges.t()

    # 对每条边内排序，使 (i, j) 满足 i <= j
    sorted_edges = torch.sort(edges, dim=1)[0]

    # 转换到 CPU，使用 numpy 进行多列排序
    arr = sorted_edges.cpu().numpy()
    # 使用 np.lexsort：np.lexsort((keys...)) 按最后一个 key 先排序
    # 我们想按第一列再第二列排序，所以 keys 的顺序需要是 (第二列, 第一列)
    idx = np.lexsort((arr[:,1], arr[:,0]))

    arr = arr[idx]
    sorted_edges = torch.from_numpy(arr).to(edges.device)

    return sorted_edges

if __name__ == "__main__":
    torch.manual_seed(42)
    device = 'cpu'

    # 模拟数据
    N = 10
    x = torch.randn(N, 3, device=device)              # 节点特征
    batch = torch.tensor([0,0,0,0,1,1,1,2,2,2], device=device)  # batch索引
    virtual_mask = torch.zeros(N, dtype=torch.bool, device=device)
    mask_ligand = torch.zeros(N, dtype=torch.bool, device=device)

    # 随机指定一些原子为virtual和ligand以测试功能
    virtual_mask[2] = True
    virtual_mask[5] = True
    mask_ligand[6] = True
    mask_ligand[8] = True

    # 调用两个函数
    edge_orig = connect_within_batch(x, virtual_mask, batch, mask_ligand)
    edge_opt = connect_within_batch_optimized(x, virtual_mask, batch, mask_ligand)

    # 标准化
    norm_orig = normalize_edges(edge_orig)
    norm_opt = normalize_edges(edge_opt)

    # 检查无序条件下是否相同
    if torch.equal(norm_orig, norm_opt):
        print("测试通过：两种实现结果在集合意义上相同。")
    else:
        print("测试失败：两种实现结果存在差异。")
        print("原始结果：\n", norm_orig)
        print("优化结果：\n", norm_opt)


测试通过：两种实现结果在集合意义上相同。


In [17]:
import torch
import torch_scatter

class ProteinAggregator:
    def aggregate_to_virtual_center_by_batch(self, h_protein, group_indices, virtual_mask, batch):
        # 原始版本函数
        updated_h_protein = h_protein.clone()
        unique_batches = torch.unique(batch)
        for b in unique_batches:
            batch_mask = (batch == b)
            batch_group_indices = group_indices[batch_mask]
            batch_virtual_mask = virtual_mask[batch_mask]
            batch_h_protein = h_protein[batch_mask]

            valid_clusters = torch.unique(batch_group_indices)
            valid_clusters = valid_clusters[valid_clusters != -1]

            for cluster in valid_clusters:
                cluster_indices = (batch_group_indices == cluster).nonzero(as_tuple=True)[0]
                virtual_center_idx = cluster_indices[batch_virtual_mask[cluster_indices]][0]

                real_atom_indices = cluster_indices[~batch_virtual_mask[cluster_indices]]
                cluster_features_mean = batch_h_protein[real_atom_indices].mean(dim=0)
                updated_h_protein[virtual_center_idx] += cluster_features_mean
        return updated_h_protein

    def aggregate_to_virtual_center_by_batch_optimized(self, h_protein, group_indices, virtual_mask, batch):
        # 优化版本函数（需要确保 group_indices >= 0 且从0开始计数）
        valid_mask = group_indices != -1
        if valid_mask.sum() == 0:
            return h_protein.clone()

        combined = batch * (group_indices.max()+1) + group_indices
        valid_combined = combined[valid_mask]
        real_mask = valid_mask & ~virtual_mask

        unique_groups, inverse_indices = torch.unique(valid_combined, sorted=True, return_inverse=True)
        num_unique_groups = unique_groups.size(0)

        real_indices = torch.nonzero(real_mask, as_tuple=False).squeeze(1)
        real_combined = combined[real_indices]
        real_inverse = torch.searchsorted(unique_groups, real_combined)
        real_features = h_protein[real_indices]

        sum_features = torch_scatter.scatter_add(real_features, real_inverse, dim=0, dim_size=num_unique_groups)
        counts = torch_scatter.scatter_add(torch.ones_like(real_inverse, dtype=torch.float32), real_inverse, dim=0, dim_size=num_unique_groups)
        group_mean = sum_features / counts.unsqueeze(1).clamp(min=1.0)

        virtual_mask_valid = valid_mask & virtual_mask
        if virtual_mask_valid.sum() == 0:
            return h_protein.clone()

        virtual_indices = torch.nonzero(virtual_mask_valid, as_tuple=False).squeeze(1)
        virtual_combined = combined[virtual_indices]
        virtual_inverse = torch.searchsorted(unique_groups, virtual_combined)

        updated_h_protein = h_protein.clone()
        updated_h_protein[virtual_indices] += group_mean[virtual_inverse]

        return updated_h_protein


# 测试用例
if __name__ == "__main__":
    # 设置随机种子确保可重复性
    torch.manual_seed(42)

    aggregator = ProteinAggregator()

    N = 20  
    feature_dim = 5
    num_batches = 3
    h_protein = torch.randn(N, feature_dim)

    batch = torch.tensor([0,0,0,0,0,0,1,1,1,1,0,1,2,2,2,2,2,2,2,2])


    group_indices = torch.full((N,), -1, dtype=torch.long)


    group_indices[0] = 0; group_indices[1] = 0; group_indices[2] = 0  # batch0 cluster0
    group_indices[3] = 1; group_indices[4] = 1; group_indices[5] = 1  # batch0 cluster1
    group_indices[6] = 0; group_indices[7] = 0; group_indices[8] = 0; group_indices[9] = 0  # batch1 cluster0
    # 10,11,12为batch1的虚拟原子，不属于任何簇（-1）
    group_indices[13] = 0; group_indices[14] = 0; group_indices[15] = 0  # batch2 cluster0
    group_indices[16] = 1; group_indices[17] = 1; group_indices[18] = 1; group_indices[19] = 1 # batch2 cluster1


    virtual_mask = torch.zeros(N, dtype=torch.bool)
    virtual_mask[2] = True
    virtual_mask[5] = True
    virtual_mask[9] = True
    virtual_mask[15] = True
    virtual_mask[19] = True
    h_protein[virtual_mask] = 0.0
    # 再定义mask_ligand，随便定义几个配体原子（不影响聚合结果）
    mask_ligand = torch.zeros(N, dtype=torch.bool)
    mask_ligand[10] = True  # 虚拟、配体
    mask_ligand[11] = True
    mask_ligand[12] = True

    # 调用原始版本与优化版本
    res_original = aggregator.aggregate_to_virtual_center_by_batch(h_protein, group_indices, virtual_mask, batch)
    res_optimized = aggregator.aggregate_to_virtual_center_by_batch_optimized(h_protein, group_indices, virtual_mask, batch)

    # 检查结果是否一致
    # 使用allclose以允许浮点数上的微小差别
    if torch.allclose(res_original, res_optimized, atol=1e-3):
        print("测试通过：两种实现结果基本一致。")
    else:
        print("测试失败：两种实现结果存在差异。")
        print("原始结果：", res_original)
        print("优化结果：", res_optimized)


测试失败：两种实现结果存在差异。
原始结果： tensor([[ 1.9269e+00,  1.4873e+00,  9.0072e-01, -2.1055e+00,  6.7842e-01],
        [-1.2345e+00, -4.3067e-02, -1.6047e+00, -7.5214e-01,  1.6487e+00],
        [ 3.4619e-01,  7.2211e-01, -3.5197e-01, -1.4288e+00,  1.1636e+00],
        [ 6.3276e-01,  4.1151e+00, -8.8044e-01, -4.5107e-01, -3.8814e-01],
        [-7.5813e-01,  1.0783e+00,  8.0080e-01,  1.6806e+00,  1.2791e+00],
        [ 2.1571e-03,  1.3603e+00,  3.2060e-01,  5.9161e-01,  8.5936e-01],
        [-2.5158e-01,  8.5986e-01, -1.3847e+00, -8.7124e-01, -2.2337e-01],
        [ 1.9780e+00,  4.4240e-01, -9.3316e-01,  3.0874e-01, -3.7001e-01],
        [-1.5576e+00,  9.9564e-01, -8.7979e-01, -6.0114e-01, -1.2742e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 7.8024e-02,  5.2581e-01, -4.8799e-01,  1.1914e+00, -8.1401e-01],
        [-7.3599e-01, -1.4032e+00,  3.6004e-02, -6.3477e-02,  6.7561e-01],
        [-9.7807e-02,  1.8446e+00, -1.1845e+00,  1.3835e+00,  1.4451e+00],
  