In [1]:
import torch

# 假设我们有一个简单的例子
# 2帧，3个原子，嵌入维度为2
node_ebd = torch.tensor([
    # 第1帧
    [[1.0, 0.0],   # 原子0
     [2.0, 0.0],   # 原子1  
     [5.0, 0.0]],  # 原子2
    # 第2帧
    [[1.1, 0.1],   # 原子0
     [2.1, 0.1],   # 原子1
     [5.1, 0.1]]   # 原子2
])  # shape: [2, 3, 2]

# 邻居列表：原子0的邻居是原子1，原子1的邻居是原子0和2
nlist = torch.tensor([
    # 第1帧
    [[1, -1],      # 原子0的邻居：原子1
     [0, 2],       # 原子1的邻居：原子0, 2
     [1, -1]],     # 原子2的邻居：原子1
    # 第2帧  
    [[1, -1],      # 原子0的邻居：原子1
     [0, 2],       # 原子1的邻居：原子0, 2
     [1, -1]]      # 原子2的邻居：原子1
])  # shape: [2, 3, 2]

# 计算成对距离
node_ebd_i = node_ebd.unsqueeze(2)  # [2, 3, 1, 2]
node_ebd_j = node_ebd.unsqueeze(1)  # [2, 1, 3, 2]
distances = torch.norm(node_ebd_i - node_ebd_j, dim=-1)
print("成对距离矩阵:")
print(distances[0])  # 第1帧的距离矩阵
# tensor([[0.0000, 1.0000, 4.0000],  # 原子0到所有原子的距离
#         [1.0000, 0.0000, 3.0000],  # 原子1到所有原子的距离  
#         [4.0000, 3.0000, 0.0000]]) # 原子2到所有原子的距离

# 构建邻居掩码
neighbor_mask = torch.zeros_like(distances, dtype=torch.bool)
for frame_idx in range(2):
    for atom_idx in range(3):
        neighbors = nlist[frame_idx, atom_idx]
        valid_neighbors = neighbors[(neighbors >= 0) & (neighbors < 3)]
        if len(valid_neighbors) > 0:
            neighbor_mask[frame_idx, atom_idx, valid_neighbors] = True

print("邻居掩码:")
print(neighbor_mask[0])
# tensor([[False,  True, False],  # 原子0的邻居：原子1
#         [ True, False,  True],  # 原子1的邻居：原子0, 2
#         [False,  True, False]]) # 原子2的邻居：原子1

# 排除对角线
eye_mask = torch.eye(3, dtype=torch.bool).unsqueeze(0).expand(2, -1, -1)
valid_mask = ~eye_mask

# 计算MAD_neighbor（邻居间距离）
neighbor_mask_valid = neighbor_mask & valid_mask
neighbor_distances = distances[neighbor_mask_valid]
print("邻居间距离:", neighbor_distances)
# tensor([1.0000, 1.0000, 3.0000, 3.0000])  # 0-1, 1-0, 1-2, 2-1的距离
mad_neighbor = neighbor_distances.mean()
print("MAD_neighbor:", mad_neighbor)  # 2.0

# 计算MAD_remote（非邻居间距离）
remote_mask = ~neighbor_mask & valid_mask
remote_distances = distances[remote_mask]
print("非邻居间距离:", remote_distances)
# tensor([4.0000, 4.0000])  # 0-2, 2-0的距离
mad_remote = remote_distances.mean()
print("MAD_remote:", mad_remote)  # 4.0

# 计算MADGap
mad_gap = mad_remote - mad_neighbor
print("MADGap:", mad_gap)  # 2.0

成对距离矩阵:
tensor([[0., 1., 4.],
        [1., 0., 3.],
        [4., 3., 0.]])
邻居掩码:
tensor([[False,  True, False],
        [ True, False,  True],
        [False,  True, False]])
邻居间距离: tensor([1.0000, 1.0000, 3.0000, 3.0000, 1.0000, 1.0000, 3.0000, 3.0000])
MAD_neighbor: tensor(2.)
非邻居间距离: tensor([4., 4., 4., 4.])
MAD_remote: tensor(4.)
MADGap: tensor(2.)


In [25]:
import torch
import torch.nn.functional as F
# 假设我们有一个简单的例子
# 2帧，3个原子，嵌入维度为2
node_ebd = torch.tensor([
    # 第1帧
    [[1.0, 0.0],   # 原子0
     [2.0, 0.0],   # 原子1  
     [5.0, 0.0]],  # 原子2
    # 第2帧
    [[1.1, 0.1],   # 原子0
     [2.1, 0.1],   # 原子1
     [5.1, 0.1]]   # 原子2
])  # shape: [2, 3, 2]

# 邻居列表：原子0的邻居是原子1，原子1的邻居是原子0和2
nlist = torch.tensor([
    # 第1帧
    [[1, -1],      # 原子0的邻居：原子1
     [0, 2],       # 原子1的邻居：原子0, 2
     [1, -1]],     # 原子2的邻居：原子1
    # 第2帧  
    [[1, -1],      # 原子0的邻居：原子1
     [0, 2],       # 原子1的邻居：原子0, 2
     [1, -1]]      # 原子2的邻居：原子1
])  # shape: [2, 3, 2]
    # 计算向量范数 |Hi|, |Hj|
norms = torch.norm(node_ebd, p=2, dim=-1, keepdim=True)  # [nf, nloc, 1]
print("norms", norms)
# 计算点积矩阵 Hi · Hj
dot_products = torch.bmm(node_ebd, node_ebd.transpose(-1, -2))  # [nf, nloc, nloc]
print("dot_products", dot_products)
# 计算范数乘积矩阵 |Hi| * |Hj|
norm_products = torch.bmm(norms, norms.transpose(-1, -2))  # [nf, nloc, nloc]
print("norm_products", norm_products)
# 余弦距离 = 1 - 余弦相似度
cosine_dist = 1.0 - (dot_products / (norm_products + 1e-8))  # 加小值避免除零
print("cosine_dist", cosine_dist)

# 基于nlist创建neighbor mask
nframes, nloc, embed_dim = node_ebd.shape
device = node_ebd.device
print(nlist)
neighbor_mask = torch.zeros_like(cosine_dist, dtype=torch.bool, device=device)
print("neighbor_mask", neighbor_mask)
print("berghbor mask shape", neighbor_mask.shape)
# 利用现有的nlist填充neighbor关系
for frame_idx in range(nframes):
    for atom_idx in range(nloc):
        print("frame_idx", frame_idx)
        print("atom_idx", atom_idx)

        neighbors = nlist[frame_idx, atom_idx]  # [nnei]
        # 过滤有效邻居（排除padding的-1值）
        valid_neighbors = neighbors[(neighbors >= 0) & (neighbors < nloc)]
        if len(valid_neighbors) > 0:
            neighbor_mask[frame_idx, atom_idx, valid_neighbors] = True
print("neighbor_mask", neighbor_mask)


# 排除对角线元素（i=j的情况）
eye_mask = torch.eye(nloc, dtype=torch.bool, device=device).unsqueeze(0).expand(nframes, -1, -1)
neighbor_mask = neighbor_mask & (~eye_mask)  # M^neighbor_ii = 0 恒成立

# 创建远程目标掩码：M^remote = ¬M^neighbor ∧ ¬I (非邻居且非对角线)
remote_mask = (~neighbor_mask) & (~eye_mask)

# === 公式3: D̄ᵢᵗᵍᵗ = Σⱼ Dᵢⱼᵗᵍᵗ / Σⱼ 1(Dᵢⱼᵗᵍᵗ) - 计算每个节点的平均距离 ===

# 计算每个节点的平均邻居距离
mad_neighbor_per_node = []  # 存储每个节点的D̄ᵢⁿᵉⁱᵍʰᵇᵒʳ

for f in range(nframes):
    for i in range(nloc):
        # 获取节点i的邻居距离 Dᵢⱼⁿᵉⁱᵍʰᵇᵒʳ (j是i的邻居)
        neighbor_distances_i = cosine_dist[f, i][neighbor_mask[f, i]]
        
        if len(neighbor_distances_i) > 0:
            # D̄ᵢⁿᵉⁱᵍʰᵇᵒʳ = Σⱼ Dᵢⱼⁿᵉⁱᵍʰᵇᵒʳ / Σⱼ 1(Dᵢⱼⁿᵉⁱᵍʰᵇᵒʳ > 0)
            avg_neighbor_dist_i = neighbor_distances_i.mean()
            mad_neighbor_per_node.append(avg_neighbor_dist_i)

# 计算每个节点的平均远程距离
mad_remote_per_node = []  # 存储每个节点的D̄ᵢʳᵉᵐᵒᵗᵉ

for f in range(nframes):
    for i in range(nloc):
        # 获取节点i的远程距离 Dᵢⱼʳᵉᵐᵒᵗᵉ (j不是i的邻居且j≠i)
        remote_distances_i = cosine_dist[f, i][remote_mask[f, i]]
        
        if len(remote_distances_i) > 0:
            # D̄ᵢʳᵉᵐᵒᵗᵉ = Σⱼ Dᵢⱼʳᵉᵐᵒᵗᵉ / Σⱼ 1(Dᵢⱼʳᵉᵐᵒᵗᵉ > 0)
            avg_remote_dist_i = remote_distances_i.mean()
            mad_remote_per_node.append(avg_remote_dist_i)

# === 公式4: MADᵗᵍᵗ = Σᵢ D̄ᵢᵗᵍᵗ / Σᵢ 1(D̄ᵢᵗᵍᵗ) - 对节点平均距离再求平均 ===

# MADⁿᵉⁱᵍʰᵇᵒʳ = Σᵢ D̄ᵢⁿᵉⁱᵍʰᵇᵒʳ / Σᵢ 1(D̄ᵢⁿᵉⁱᵍʰᵇᵒʳ)
if len(mad_neighbor_per_node) > 0:
    mad_neighbor = torch.stack(mad_neighbor_per_node).mean()
else:
    mad_neighbor = torch.tensor(0.0, device=device)

# MADʳᵉᵐᵒᵗᵉ = Σᵢ D̄ᵢʳᵉᵐᵒᵗᵉ / Σᵢ 1(D̄ᵢʳᵉᵐᵒᵗᵉ)
if len(mad_remote_per_node) > 0:
    mad_remote = torch.stack(mad_remote_per_node).mean()
else:
    mad_remote = torch.tensor(0.0, device=device)

# MADGap = MADʳᵉᵐᵒᵗᵉ - MADⁿᵉⁱᵍʰᵇᵒʳ
mad_gap = mad_remote - mad_neighbor
print("mad_gap", mad_gap)

norms tensor([[[1.0000],
         [2.0000],
         [5.0000]],

        [[1.1045],
         [2.1024],
         [5.1010]]])
dot_products tensor([[[ 1.0000,  2.0000,  5.0000],
         [ 2.0000,  4.0000, 10.0000],
         [ 5.0000, 10.0000, 25.0000]],

        [[ 1.2200,  2.3200,  5.6200],
         [ 2.3200,  4.4200, 10.7200],
         [ 5.6200, 10.7200, 26.0200]]])
norm_products tensor([[[ 1.0000,  2.0000,  5.0000],
         [ 2.0000,  4.0000, 10.0000],
         [ 5.0000, 10.0000, 25.0000]],

        [[ 1.2200,  2.3222,  5.6342],
         [ 2.3222,  4.4200, 10.7242],
         [ 5.6342, 10.7242, 26.0200]]])
cosine_dist tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00]],

        [[-1.1921e-07,  9.2763e-04,  2.5232e-03],
         [ 9.2763e-04,  0.0000e+00,  3.9136e-04],
         [ 2.5232e-03,  3.9136e-04,  5.9605e-08]]])
tensor([[[ 1, -1],
         [ 0,  2],
         [ 1, -1]],

        [[

In [None]:
import torch
data = torch.load("debug_mad_gap.pt")
node_ebd_pt = data["node_ebd"]
print("node_ebd_pt", node_ebd_pt)
print("node_ebd_pt shape", node_ebd_pt.shape)
nframes, nloc, nembed = node_ebd_pt.shape
device = node_ebd_pt.device
eye_mask = torch.eye(nloc, dtype=torch.bool, device=device).unsqueeze(0).expand(nframes, -1, -1)
print("eye_mask", eye_mask)
print("eye_mask shape", eye_mask.shape)
valid_mask = ~eye_mask
print("valid_mask", valid_mask)
print("valid_mask shape", valid_mask.shape)

    # 2. 添加简化的 _compute_mad 方法

"""计算基础MAD (Mean Average Distance) 用于正则化

MAD使用余弦距离衡量节点嵌入表征之间的平均距离:
余弦距离 = 1 - 余弦相似度 = 1 - (Hi · Hj) / (|Hi| · |Hj|)

Parameters
---------- 
node_ebd : torch.Tensor
    节点嵌入表征，形状 [nframes, nloc, embed_dim]
    
Returns
-------
torch.Tensor
    所有节点对之间的平均余弦距离
"""
nframes, nloc, embed_dim = node_ebd.shape
device = node_ebd.device

# 计算向量范数 |Hi|, |Hj|
norms = torch.norm(node_ebd, p=2, dim=-1, keepdim=True)  # [nf, nloc, 1]

# 计算点积矩阵 Hi · Hj  
dot_products = torch.bmm(node_ebd, node_ebd.transpose(-1, -2))  # [nf, nloc, nloc]

# 计算范数乘积矩阵 |Hi| * |Hj|
norm_products = torch.bmm(norms, norms.transpose(-1, -2))  # [nf, nloc, nloc]

# 余弦距离 = 1 - 余弦相似度
cosine_dist = 1.0 - (dot_products / (norm_products + 1e-8))  # 加小值避免除零
print("cosine_dist", cosine_dist)
print("cosine_dist shape", cosine_dist.shape)
# 排除对角线（自己与自己的距离为0）
eye_mask = torch.eye(nloc, dtype=torch.bool, device=device).unsqueeze(0).expand(nframes, -1, -1)
valid_mask = ~eye_mask

# 计算所有有效节点对的平均距离
valid_distances = cosine_dist[valid_mask]
mad = valid_distances.mean() if len(valid_distances) > 0 else torch.tensor(0.0, device=device)




print("mad", mad)

node_ebd_pt tensor([[[0.2028, 0.1567, 0.5198,  ..., 0.5832, 0.0840, 0.3441],
         [0.2053, 0.1584, 0.5246,  ..., 0.5857, 0.0838, 0.3596],
         [0.1926, 0.1502, 0.5139,  ..., 0.5794, 0.0786, 0.3462],
         ...,
         [0.2044, 0.1533, 0.4108,  ..., 0.4040, 0.1300, 0.3748],
         [0.2143, 0.1613, 0.4193,  ..., 0.4090, 0.1362, 0.3807],
         [0.2219, 0.1731, 0.4327,  ..., 0.4144, 0.1316, 0.3942]]],
       device='cuda:0', requires_grad=True)
node_ebd_pt shape torch.Size([1, 192, 128])
eye_mask tensor([[[ True, False, False,  ..., False, False, False],
         [False,  True, False,  ..., False, False, False],
         [False, False,  True,  ..., False, False, False],
         ...,
         [False, False, False,  ...,  True, False, False],
         [False, False, False,  ..., False,  True, False],
         [False, False, False,  ..., False, False,  True]]], device='cuda:0')
eye_mask shape torch.Size([1, 192, 192])
valid_mask tensor([[[False,  True,  True,  ...,  True,  T

In [27]:
import torch
node_ebd = torch.tensor([[[3.0,4.0],[4.0,3.0],[3.0,4.0]],[[3.0,4.0],[4.0,3.0],[3.0,4.0]]])
print("node_ebd", node_ebd)
print("node_ebd shape", node_ebd.shape)

"""计算基础MAD (Mean Average Distance) 用于正则化

MAD使用余弦距离衡量节点嵌入表征之间的平均距离:
余弦距离 = 1 - 余弦相似度 = 1 - (Hi · Hj) / (|Hi| · |Hj|)

Parameters
---------- 
node_ebd : torch.Tensor
    节点嵌入表征，形状 [nframes, nloc, embed_dim]
    
Returns
-------
torch.Tensor
    所有节点对之间的平均余弦距离
"""
nframes, nloc, embed_dim = node_ebd.shape
device = node_ebd.device
print("nframes", nframes)
print("nloc", nloc)
print("embed_dim", embed_dim)
print("device", device)
# 计算向量范数 |Hi|, |Hj|
norms = torch.norm(node_ebd, p=2, dim=-1, keepdim=True)  # [nf, nloc, 1]
print("norms", norms)
print("norms shape", norms.shape)
# 计算点积矩阵 Hi · Hj  
dot_products = torch.bmm(node_ebd, node_ebd.transpose(-1, -2))  # [nf, nloc, nloc]
print("dot_products", dot_products)
print("dot_products shape", dot_products.shape)
# 计算范数乘积矩阵 |Hi| * |Hj|
norm_products = torch.bmm(norms, norms.transpose(-1, -2))  # [nf, nloc, nloc]
print("norm_products", norm_products)
print("norm_products shape", norm_products.shape)
# 余弦距离 = 1 - 余弦相似度
cosine_dist = 1.0 - (dot_products / (norm_products + 1e-8))  # 加小值避免除零
print("cosine_dist", cosine_dist)
print("cosine_dist shape", cosine_dist.shape)
# 排除对角线（自己与自己的距离为0）
eye_mask = torch.eye(nloc, dtype=torch.bool, device=device).unsqueeze(0).expand(nframes, -1, -1)
print("eye_mask", eye_mask)
print("eye_mask shape", eye_mask.shape)
valid_mask = ~eye_mask
print("valid_mask", valid_mask)
print("valid_mask shape", valid_mask.shape)

# 计算所有有效节点对的平均距离
valid_distances = cosine_dist[valid_mask]
print("valid_distances", valid_distances)
print("valid_distances shape", valid_distances.shape)
print(len(valid_distances))
mad = valid_distances.mean() 
print("mad", mad)
print("mad shape", mad.shape)

print(cosine_dist.sum()/(nframes*nloc*(nloc-1)))
def _compute_mad_v1_manual(node_ebd: torch.Tensor) -> torch.Tensor:
    """手动实现版本（当前的高效版）"""
    # 标准化嵌入向量
    node_ebd_norm = F.normalize(node_ebd, p=2, dim=-1)  # [nf, nloc, embed_dim]
    
    # 计算余弦相似度矩阵
    cosine_sim = torch.bmm(node_ebd_norm, node_ebd_norm.transpose(-1, -2))
    
    # 余弦距离 = 1 - 余弦相似度
    cosine_dist = 1.0 - cosine_sim
    
    # Global MAD
    global_mad = cosine_dist.sum()/(nframes*nloc*(nloc-1))
    
    return global_mad

mad = _compute_mad_v1_manual(node_ebd)
print("mad", mad)

node_ebd tensor([[[3., 4.],
         [4., 3.],
         [3., 4.]],

        [[3., 4.],
         [4., 3.],
         [3., 4.]]])
node_ebd shape torch.Size([2, 3, 2])
nframes 2
nloc 3
embed_dim 2
device cpu
norms tensor([[[5.],
         [5.],
         [5.]],

        [[5.],
         [5.],
         [5.]]])
norms shape torch.Size([2, 3, 1])
dot_products tensor([[[25., 24., 25.],
         [24., 25., 24.],
         [25., 24., 25.]],

        [[25., 24., 25.],
         [24., 25., 24.],
         [25., 24., 25.]]])
dot_products shape torch.Size([2, 3, 3])
norm_products tensor([[[25., 25., 25.],
         [25., 25., 25.],
         [25., 25., 25.]],

        [[25., 25., 25.],
         [25., 25., 25.],
         [25., 25., 25.]]])
norm_products shape torch.Size([2, 3, 3])
cosine_dist tensor([[[0.0000, 0.0400, 0.0000],
         [0.0400, 0.0000, 0.0400],
         [0.0000, 0.0400, 0.0000]],

        [[0.0000, 0.0400, 0.0000],
         [0.0400, 0.0000, 0.0400],
         [0.0000, 0.0400, 0.0000]]])
cosine

In [6]:
torch.eye(2)

tensor([[1., 0.],
        [0., 1.]])

# nlist 是DeepMD-kit中描述原子邻居关系的关键数据结构：
1. 形状: [nframes, nloc, nnei]
2. 内容: 每个原子的邻居原子索引
3. $填充: 不足的邻居用-1填充$
4. 作用: 用于构建邻居掩码，区分邻居和非邻居原子
5. 物理意义: 基于距离截断的原子间相互作用关系

In [28]:
import torch
import torch.nn.functional as F

def _compute_mad_v1_manual(self, node_ebd: torch.Tensor) -> torch.Tensor:
    """手动实现版本（当前的高效版）"""
    # 标准化嵌入向量
    node_ebd_norm = F.normalize(node_ebd, p=2, dim=-1)  # [nf, nloc, embed_dim]
    
    # 计算余弦相似度矩阵
    cosine_sim = torch.bmm(node_ebd_norm, node_ebd_norm.transpose(-1, -2))
    
    # 余弦距离 = 1 - 余弦相似度
    cosine_dist = 1.0 - cosine_sim
    
    # Global MAD
    global_mad = cosine_dist.mean()
    
    return global_mad


def _compute_mad_v2_builtin(self, node_ebd: torch.Tensor) -> torch.Tensor:
    """使用内置cosine_similarity函数"""
    nframes, nloc, embed_dim = node_ebd.shape
    
    # 扩展维度以计算所有节点对的相似度
    node1 = node_ebd.unsqueeze(2)  # [nf, nloc, 1, embed_dim]  
    node2 = node_ebd.unsqueeze(1)  # [nf, 1, nloc, embed_dim]
    
    # 使用内置函数计算余弦相似度
    cosine_sim = F.cosine_similarity(node1, node2, dim=-1)  # [nf, nloc, nloc]
    
    # 余弦距离 = 1 - 余弦相似度
    cosine_dist = 1.0 - cosine_sim
    
    # Global MAD
    global_mad = cosine_dist.mean()
    
    return global_mad


def _compute_mad_v3_most_efficient(self, node_ebd: torch.Tensor) -> torch.Tensor:
    """最简洁版本（推荐）"""
    nframes, nloc, embed_dim = node_ebd.shape
    
    # 直接计算所有对的余弦距离
    cosine_dist = 1.0 - F.cosine_similarity(
        node_ebd.unsqueeze(2),   # [nf, nloc, 1, embed_dim]
        node_ebd.unsqueeze(1),   # [nf, 1, nloc, embed_dim]  
        dim=-1                   # [nf, nloc, nloc]
    )
    
    return cosine_dist.mean()


# 性能测试代码
def benchmark_methods():
    """简单的性能对比"""
    import time
    import torch
    data = torch.load("debug_mad_gap.pt")
    node_ebd = data["node_ebd"]
    # 创建测试数据
    nframes, nloc, embed_dim = node_ebd.shape
    #node_ebd = torch.randn(nframes, nloc, embed_dim, device='cuda' if torch.cuda.is_available() else 'cpu')
    
    methods = [
        ("手动实现", _compute_mad_v1_manual),
        ("内置函数", _compute_mad_v2_builtin), 
        ("最简洁版", _compute_mad_v3_most_efficient)
    ]
    
    # 预热
    for _, method in methods:
        _ = method(None, node_ebd)
    
    # 测试
    for name, method in methods:
        torch.cuda.synchronize() if torch.cuda.is_available() else None
        start = time.time()
        
        for _ in range(100):
            result = method(None, node_ebd)
        
        torch.cuda.synchronize() if torch.cuda.is_available() else None
        end = time.time()
        
        print(f"{name}: {(end-start)*1000:.2f}ms, result: {result:.6f}")

benchmark_methods()  # 取消注释来运行测试

手动实现: 12.06ms, result: 0.004947
内置函数: 46.14ms, result: 0.004947
最简洁版: 45.38ms, result: 0.004947


In [34]:
import torch
data = torch.load("debug_mad_gap.pt")
node_ebd = data["node_ebd"]
# 创建测试数据
nframes, nloc, embed_dim = node_ebd.shape
print("node_ebd.min()", node_ebd.min())
print("node_ebd.max()", node_ebd.max())
node_ebd_norm = F.normalize(node_ebd, p=2, dim=-1)
print("node_ebd_norm.min()", node_ebd_norm.min())
print("node_ebd_norm.max()", node_ebd_norm.max())
cosine_sim = torch.bmm(node_ebd_norm, node_ebd_norm.transpose(-1, -2))
cosine_sim.min()
print("cosine_sim.max()", cosine_sim.max())
cosine_sim.max()
print("cosine_sim.mean()", cosine_sim.mean())

cosine_dist = 1.0 - cosine_sim
cosine_dist.min()
cosine_dist.max()

node_ebd.min() tensor(-0.1624, device='cuda:0', grad_fn=<MinBackward1>)
node_ebd.max() tensor(1.1052, device='cuda:0', grad_fn=<MaxBackward1>)
node_ebd_norm.min() tensor(-0.0370, device='cuda:0', grad_fn=<MinBackward1>)
node_ebd_norm.max() tensor(0.2571, device='cuda:0', grad_fn=<MaxBackward1>)
cosine_sim.max() tensor(1.0000, device='cuda:0', grad_fn=<MaxBackward1>)
cosine_sim.mean() tensor(0.9951, device='cuda:0', grad_fn=<MeanBackward0>)


tensor(0.0131, device='cuda:0', grad_fn=<MaxBackward1>)

In [15]:
nlist
nlist[0,0]

tensor([ 1, -1])

In [11]:
# 使用PyTorch内置函数
cosine_sim = F.cosine_similarity(
    node_ebd.unsqueeze(2),  # [nf, nloc, 1, embed_dim] 
    node_ebd.unsqueeze(1),  # [nf, 1, nloc, embed_dim]
    dim=-1  # 在embed_dim维度计算
)  # [nf, nloc, nloc]

cosine_dist = 1.0 - cosine_sim
print("cosine_dist", cosine_dist)

cosine_dist tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00]],

        [[-1.1921e-07,  9.2763e-04,  2.5233e-03],
         [ 9.2763e-04,  5.9605e-08,  3.9136e-04],
         [ 2.5233e-03,  3.9136e-04,  0.0000e+00]]])
