<a href="https://colab.research.google.com/github/wangbxj1234/offset_pt_try/blob/main/%E7%BD%91%E7%BB%9C%E5%B1%82%E7%9A%84%E5%8F%82%E6%95%B0%E4%BC%A0%E5%85%A5%E8%A7%84%E5%88%99.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def square_distance(src, dst):
    """
    Compute the squared Euclidean distance between every pair of points.

    Every point of ``src`` is compared with every point of ``dst`` that
    belongs to the same batch element. The coordinate difference is formed
    by broadcasting and its square is summed over the last axis:
        dist[b, n, m] = sum_c (src[b, n, c] - dst[b, m, c]) ** 2

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    # [B, N, 1, C] - [B, 1, M, C] broadcasts to [B, N, M, C].
    pairwise_diff = src.unsqueeze(2) - dst.unsqueeze(1)
    return pairwise_diff.pow(2).sum(dim=-1)

def index_points(points, idx):
    """
    Gather points by index along the point axis (sampling and/or grouping).

    Picks S points (2-D ``idx``) or S*K points (3-D ``idx``) out of the N
    points of each batch element, so it serves both plain sampling and
    sampling followed by grouping.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S, [K]]
    Return:
        new_points: indexed points data, [B, S, [K], C]
    """
    raw_size = idx.size()  # torch.Size, the counterpart of numpy's .shape
    num_channels = points.size(-1)
    flat_idx = idx.reshape(raw_size[0], -1)  # [B, S] or [B, S*K]
    # Repeat each index across the channel axis so gather pulls whole points.
    gather_idx = flat_idx[..., None].expand(-1, -1, num_channels)
    res = torch.gather(points, 1, gather_idx)
    # Spell out the channel size instead of -1: reshape cannot infer a -1
    # dimension when idx holds no elements, so an empty selection would
    # otherwise raise instead of returning an empty [B, 0, C] result.
    return res.reshape(*raw_size, num_channels)
    
class TransformerBlock(nn.Module):
    """
    Local vector self-attention over a point cloud.

    Each point attends only to its ``k`` nearest neighbours (found from the
    pairwise squared distances of ``xyz``); there is no sampling step, only
    grouping. Attention weights are computed per feature channel from
    ``query - key + positional_encoding`` and normalised over the neighbour
    axis.

    Args:
        d_points: channel size of the input (and output) features.
        d_model: channel size used inside the attention computation.
        k: number of nearest neighbours each point attends to.
    """

    def __init__(self, d_points, d_model, k) -> None:
        super().__init__()
        # Project features into the attention width and back out again.
        self.fc1 = nn.Linear(d_points, d_model)
        self.fc2 = nn.Linear(d_model, d_points)
        # Positional encoding: maps a 3-channel coordinate offset to d_model.
        self.fc_delta = nn.Sequential(
            nn.Linear(3, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model)
        )
        # Maps (query - key + positional encoding) to attention logits.
        self.fc_gamma = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model)
        )
        self.w_qs = nn.Linear(d_model, d_model, bias=False)
        self.w_ks = nn.Linear(d_model, d_model, bias=False)
        self.w_vs = nn.Linear(d_model, d_model, bias=False)
        self.k = k
        # Not used by forward(); kept so the module's parameter names and
        # initialisation order stay the same as before.
        self.fc_sub = nn.Sequential(
            nn.Linear(d_points, d_points),
            nn.ReLU(),
            nn.Linear(d_points, d_points),
            nn.BatchNorm1d(d_points),
            nn.ReLU()
        )

    def forward(self, xyz, features):
        """
        Apply k-nearest-neighbour attention with a residual connection.

        Args:
            xyz: point coordinates, b x n x 3.
            features: point features, b x n x d_points.

        Returns:
            A tuple ``(res, attn)``: ``res`` is the attended features added
            to the input features (b x n x d_points) and ``attn`` holds the
            attention weights (b x n x k x d_model when n >= k).
        """
        dists = square_distance(xyz, xyz)
        knn_idx = dists.argsort()[:, :, :self.k]  # b x n x k
        knn_xyz = index_points(xyz, knn_idx)  # b x n x k x 3

        x_in = features
        x = self.fc1(features)
        # Queries cover every point; keys and values are gathered per
        # neighbourhood because the attention is local.
        q = self.w_qs(x)
        keys = index_points(self.w_ks(x), knn_idx)
        values = index_points(self.w_vs(x), knn_idx)

        # Encode the offset from each point to each of its neighbours.
        pos_enc = self.fc_delta(xyz[:, :, None] - knn_xyz)  # b x n x k x d_model

        attn = self.fc_gamma(q[:, :, None] - keys + pos_enc)
        # Softmax over the neighbour axis, separately for every channel.
        attn = F.softmax(attn / np.sqrt(keys.size(-1)), dim=-2)  # b x n x k x d_model

        # Weighted sum over the neighbour axis: b x n x k x f -> b x n x f.
        x_r = torch.einsum('bmnf,bmnf->bmf', attn, values + pos_enc)
        res = self.fc2(x_r) + x_in
        return res, attn

# Step 1: construct the module, supplying the arguments that __init__ expects.
attn = TransformerBlock(d_points=32, d_model=512, k=16)

# Random features; their first three channels are reused as the coordinates.
x = torch.randn(1, 1024, 32)
xyz = x[:, :, :3]

In [None]:
attn(xyz,x)  # Step 2: call the module, supplying the arguments that forward() expects.

torch.Size([1, 1024, 512])


(tensor([[[ 0.7942, -1.1685,  0.8907,  ..., -2.1958,  0.3499,  0.5752],
          [ 0.5355, -0.4327, -0.2372,  ...,  1.1746, -0.4953, -0.8143],
          [-0.3290, -0.3095, -0.6630,  ...,  0.6544,  0.6003, -0.5973],
          ...,
          [ 1.2142,  1.6940, -1.2997,  ...,  0.3245, -0.2688, -0.2004],
          [-0.6360, -1.1067,  0.3465,  ..., -0.5920,  2.4559,  1.5122],
          [ 0.0665,  0.4303, -1.1062,  ...,  1.5498,  2.0652, -0.4984]]],
        grad_fn=<AddBackward0>),
 tensor([[[[0.0624, 0.0625, 0.0624,  ..., 0.0626, 0.0626, 0.0626],
           [0.0622, 0.0624, 0.0626,  ..., 0.0624, 0.0625, 0.0624],
           [0.0626, 0.0623, 0.0630,  ..., 0.0624, 0.0625, 0.0625],
           ...,
           [0.0624, 0.0625, 0.0624,  ..., 0.0626, 0.0623, 0.0624],
           [0.0628, 0.0626, 0.0623,  ..., 0.0626, 0.0624, 0.0627],
           [0.0624, 0.0622, 0.0624,  ..., 0.0626, 0.0630, 0.0624]],
 
          [[0.0626, 0.0623, 0.0623,  ..., 0.0623, 0.0625, 0.0623],
           [0.0623, 0.0625, 0.