<a href="https://colab.research.google.com/github/wangbxj1234/annotated_pt_codes/blob/main/%E4%B8%8B%E9%87%87%E6%A0%B7.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def square_distance(src, dst):
    """
    Pairwise squared Euclidean distance between two batched point sets.

    Each source point is compared with each target point through
    broadcasting and the squared coordinate differences are summed:
        dist[b, n, m] = sum_c (src[b, n, c] - dst[b, m, c]) ** 2
    No square root is taken, so the result is the *squared* distance.
    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    # [B, N, 1, C] - [B, 1, M, C] broadcasts to [B, N, M, C].
    diff = src.unsqueeze(2) - dst.unsqueeze(1)
    return diff.pow(2).sum(dim=-1)

def index_points(points, idx):
    """
    Gather points (or per-point features) by index.

    Picks S entries -- or S groups of K entries -- out of the N available
    ones, so the same helper serves plain sampling (2-D ``idx``) and
    sampling followed by grouping (3-D ``idx``).

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S, [K]]
    Return:
        new_points:, indexed points data, [B, S, [K], C]
    """
    raw_size = idx.size()  # torch.Size, the counterpart of a numpy shape
    channels = points.size(-1)
    # Flatten every index dimension after the batch one: [B, S] or [B, S*K].
    flat_idx = idx.reshape(raw_size[0], -1)
    # Repeat each index across the channel axis so gather copies whole points.
    gather_idx = flat_idx.unsqueeze(-1).expand(-1, -1, channels)
    res = torch.gather(points, 1, gather_idx)
    # Restore the layout of idx and append the channel axis. The channel count
    # is spelled out because reshape cannot infer a trailing -1 when idx
    # selects zero points.
    return res.reshape(*raw_size, channels)
    
class TransformerBlock(nn.Module):
    """
    Point-cloud transformer block with local (k-nearest-neighbour) attention.

    Every point attends only to its ``k`` nearest points, chosen by squared
    distance on ``xyz``; nothing is down-sampled, points are only grouped.
    Attention weights are computed per feature channel from
    ``query - key + positional encoding`` and applied to
    ``value + positional encoding``.

    Args:
        d_points: channel count of the incoming and outgoing point features
        d_model: channel count used inside the attention
        k: number of neighbours each point attends to
    """

    def __init__(self, d_points, d_model, k) -> None:
        super().__init__()

        def two_layer_mlp(d_in):
            # Linear -> ReLU -> Linear, ending at d_model channels.
            return nn.Sequential(
                nn.Linear(d_in, d_model),
                nn.ReLU(),
                nn.Linear(d_model, d_model),
            )

        # Creation order is kept on purpose: it fixes both the parameter
        # registration order and the order in which random init is drawn.
        self.fc1 = nn.Linear(d_points, d_model)   # lift features to d_model
        self.fc2 = nn.Linear(d_model, d_points)   # project back to d_points
        self.fc_delta = two_layer_mlp(3)          # positional encoding from xyz offsets
        self.fc_gamma = two_layer_mlp(d_model)    # maps q - k + pos_enc to attention logits
        self.w_qs = nn.Linear(d_model, d_model, bias=False)
        self.w_ks = nn.Linear(d_model, d_model, bias=False)
        self.w_vs = nn.Linear(d_model, d_model, bias=False)
        self.k = k
        # Registered but not called by forward().
        self.fc_sub = nn.Sequential(
            nn.Linear(d_points, d_points),
            nn.ReLU(),
            nn.Linear(d_points, d_points),
            nn.BatchNorm1d(d_points),
            nn.ReLU(),
        )

    # xyz: b x n x 3, features: b x n x f
    def forward(self, xyz, features):
        """
        Apply local attention and return ``(res, attn)``.

        res:  fc2(attended features) + features, same shape as ``features``
        attn: per-channel attention weights, b x n x k x d_model
        """
        # Neighbour search: indices of the k closest points for every point.
        pairwise = square_distance(xyz, xyz)
        knn_idx = pairwise.argsort()[:, :, :self.k]   # b x n x k
        knn_xyz = index_points(xyz, knn_idx)          # b x n x k x 3

        x_in = features
        x = self.fc1(features)
        # Queries use every point; keys and values are gathered per neighbourhood.
        q = self.w_qs(x)
        k = index_points(self.w_ks(x), knn_idx)
        v = index_points(self.w_vs(x), knn_idx)

        # Positional encoding from the offset of each point to its neighbours.
        pos_enc = self.fc_delta(xyz[:, :, None] - knn_xyz)   # b x n x k x d_model

        logits = self.fc_gamma(q[:, :, None] - k + pos_enc)
        # Softmax over the neighbour axis (dim=-2), scaled by sqrt(d_model).
        attn = F.softmax(logits / np.sqrt(k.size(-1)), dim=-2)   # b x n x k x d_model

        # Weighted sum over the neighbour axis: back to b x n x d_model.
        x_r = torch.einsum('bmnf,bmnf->bmf', attn, v + pos_enc)
        print(x_r.shape)
        # Residual connection around the attention.
        res = self.fc2(x_r) + x_in
        return res, attn

# Step 1: build the block by passing the arguments that __init__ expects.
attn = TransformerBlock(d_points=32, d_model=512, k=16)

# Random demo input: 1 cloud, 1024 points, 32 feature channels.
x = torch.randn(1, 1024, 32)
# The first three feature channels double as the point coordinates.
xyz = x[..., :3] 

In [13]:
# Step 2: call the module with the arguments that forward() expects.
# forward() runs twice here, so its shape print appears twice in the output.
attn(xyz,x)
points = attn(xyz,x)[0]  # keep only the transformed features, drop the attention weights

torch.Size([1, 1024, 512])
torch.Size([1, 1024, 512])


In [14]:
# The block output has the same shape as the input features.
points.shape

torch.Size([1, 1024, 32])

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np


# reference https://github.com/yanx27/Pointnet_Pointnet2_pytorch, modified by Yang You


def timeit(tag, t):
    """
    Print the seconds elapsed since ``t`` and hand back a fresh timestamp.

    Input:
        tag: label printed in front of the elapsed time
        t: earlier timestamp, as returned by ``time()``
    Return:
        the current ``time()``, taken after printing, so calls can be chained
    """
    elapsed = time() - t
    print("{}: {}s".format(tag, elapsed))
    return time()

def pc_normalize(pc):
    """
    Centre a point cloud on its centroid and scale it into the unit sphere.

    Input:
        pc: numpy array with one point per row, [N, C]
    Return:
        pc: a new array holding the centred points divided by the largest
            point-to-centroid distance. If every point coincides with the
            centroid that distance is zero; the centred (all-zero) cloud is
            then returned unscaled.
    """
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
    # Dividing by a zero radius would fill the result with NaN.
    if m > 0:
        pc = pc / m
    return pc

def square_distance(src, dst):
    """
    Pairwise squared Euclidean distance between two batched point sets.

    Each source point is compared with each target point through
    broadcasting and the squared coordinate differences are summed:
        dist[b, n, m] = sum_c (src[b, n, c] - dst[b, m, c]) ** 2
    No square root is taken, so the result is the *squared* distance.
    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    # [B, N, 1, C] - [B, 1, M, C] broadcasts to [B, N, M, C].
    diff = src.unsqueeze(2) - dst.unsqueeze(1)
    return diff.pow(2).sum(dim=-1)


def index_points(points, idx):
    """
    Gather points (or per-point features) by index.

    Picks S entries -- or S groups of K entries -- out of the N available
    ones, so the same helper serves plain sampling (2-D ``idx``) and
    sampling followed by grouping (3-D ``idx``).

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S, [K]]
    Return:
        new_points:, indexed points data, [B, S, [K], C]
    """
    raw_size = idx.size()  # torch.Size, the counterpart of a numpy shape
    channels = points.size(-1)
    # Flatten every index dimension after the batch one: [B, S] or [B, S*K].
    flat_idx = idx.reshape(raw_size[0], -1)
    # Repeat each index across the channel axis so gather copies whole points.
    gather_idx = flat_idx.unsqueeze(-1).expand(-1, -1, channels)
    res = torch.gather(points, 1, gather_idx)
    # Restore the layout of idx and append the channel axis. The channel count
    # is spelled out because reshape cannot infer a trailing -1 when idx
    # selects zero points.
    return res.reshape(*raw_size, channels)


def farthest_point_sample(xyz, npoint):
    """
    Iterative farthest point sampling.

    Starting from one randomly drawn point per batch element, repeatedly
    picks the point whose squared distance to the already-chosen set is
    largest.

    Input:
        xyz: pointcloud data, [B, N, C] (C = 3 for ordinary point clouds)
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    # Bookkeeping tensors are created directly on the input's device.
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Squared distance from every point to its nearest chosen centroid; the
    # huge initial value makes the first real distance always win the min.
    distance = torch.full((B, N), 1e10, device=device)
    # The starting point is still drawn from the default (CPU) generator and
    # then moved, so seeded runs select the same points as before.
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)
    for i in range(npoint):
        centroids[:, i] = farthest
        # Use the actual coordinate count instead of a hard-coded 3.
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        distance = torch.min(distance, dist)
        # Next pick: the point farthest from everything chosen so far.
        farthest = torch.max(distance, -1)[1]
    return centroids





def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False, knn=False):
    """
    Sample ``npoint`` centres and collect ``nsample`` neighbours around each.

    Input:
        npoint: number of centres picked by farthest point sampling
        radius: search radius, only forwarded to query_ball_point (knn=False)
        nsample: number of neighbours gathered per centre
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D], or None when there are no features
        returnfps: also return the grouped coordinates and the sampled indices
        knn: pick the nsample nearest points instead of a ball query
    Return:
        new_xyz: sampled centre positions, [B, npoint, 3]
        new_points: grouped data, [B, npoint, nsample, 3+D]
                    (only the 3 centre-relative coordinates when points is None)
        with returnfps=True additionally:
        grouped_xyz: neighbour positions, [B, npoint, nsample, 3]
        fps_idx: indices of the sampled centres, [B, npoint]
    """
    B, _, C = xyz.shape

    # --- sampling: choose the centres -------------------------------------
    fps_idx = farthest_point_sample(xyz, npoint)
    torch.cuda.empty_cache()  # the CUDA cache is released after every step
    new_xyz = index_points(xyz, fps_idx)
    torch.cuda.empty_cache()

    # --- grouping: neighbour indices for every centre ----------------------
    if not knn:
        # NOTE(review): query_ball_point is not defined in the visible part of
        # this file -- confirm it exists before calling with knn=False.
        idx = query_ball_point(radius, nsample, xyz, new_xyz)
    else:
        center_to_all = square_distance(new_xyz, xyz)       # B x npoint x N
        idx = center_to_all.argsort()[:, :, :nsample]       # B x npoint x nsample
    torch.cuda.empty_cache()
    grouped_xyz = index_points(xyz, idx)
    torch.cuda.empty_cache()
    # Express every neighbour relative to the centre of its group.
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, npoint, 1, C)
    torch.cuda.empty_cache()

    if points is None:
        # No extra features: the relative coordinates are the only channels.
        new_points = grouped_xyz_norm
    else:
        # Append the gathered features after the relative coordinates.
        new_points = torch.cat([grouped_xyz_norm, index_points(points, idx)], dim=-1)

    if returnfps:
        return new_xyz, new_points, grouped_xyz, fps_idx
    return new_xyz, new_points


###channel = 32 * 2 ** (i + 1)
class PointNetSetAbstraction(nn.Module):
    """
    Set-abstraction layer: sample centres, group neighbours, apply a shared
    per-point MLP, then max-pool over each neighbourhood.

    Args:
        npoint: number of centres to sample
        radius: ball-query radius (forwarded to sample_and_group)
        nsample: neighbours gathered per centre
        in_channel: channel count entering the MLP (feature channels + 3 coordinates)
        mlp: output channel count of each MLP layer, in order
        group_all: group the whole cloud via sample_and_group_all instead of sampling
        knn: group by nearest neighbours instead of a ball query
    """

    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all, knn=False):
        super().__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.knn = knn
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        # Chain the layer widths: each layer consumes what the previous one produced.
        widths = [in_channel, *mlp]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            # Kernel size 1: the same linear map is applied to every point independently.
            self.mlp_convs.append(nn.Conv2d(c_in, c_out, 1))
            self.mlp_bns.append(nn.BatchNorm2d(c_out))
        self.group_all = group_all

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, N, C]
            points: input points data, [B, N, D]
        Return:
            new_xyz: sampled points position data, [B, S, C]
            new_points: pooled feature of every sampled point, [B, S, D']
        """
        if not self.group_all:
            new_xyz, new_points = sample_and_group(
                self.npoint, self.radius, self.nsample, xyz, points, knn=self.knn
            )
        else:
            new_xyz, new_points = sample_and_group_all(xyz, points)
        # new_xyz: [B, npoint, C]; new_points: [B, npoint, nsample, C+D]
        # Channels go first for Conv2d: [B, C+D, nsample, npoint].
        new_points = new_points.permute(0, 3, 2, 1)
        print(new_points.shape)
        for i, (conv, bn) in enumerate(zip(self.mlp_convs, self.mlp_bns)):
            new_points = F.relu(bn(conv(new_points)))
            print(i, new_points.shape)
        # Max over the neighbour axis, then move channels last: [B, npoint, D'].
        pooled = torch.max(new_points, 2)[0]
        return new_xyz, pooled.transpose(1, 2)

channel = 32 * 2 ** (0 + 1)  # 32 * 2: the output channel count of this stage
channels=[channel // 2 + 3, channel, channel]  # input channels = 32 features + the 3 xyz coordinates
# First down-sampling stage: 1024 // 4 centres, radius 0 (kNN grouping is used), 16 neighbours each.
transition_downs_0 = PointNetSetAbstraction(1024 // 4 ** (0 + 1), 0, 16, channels[0], channels[1:], group_all=False, knn=True)

In [16]:
# Run the down-sampling stage; xyz and points are rebound to the sampled centres and their features.
xyz, points = transition_downs_0(xyz, points)

torch.Size([1, 35, 16, 256])
0 torch.Size([1, 64, 16, 256])
1 torch.Size([1, 64, 16, 256])


In [17]:
# Shapes after down-sampling: centre coordinates and their pooled features.
print(xyz.shape,points.shape)

torch.Size([1, 256, 3]) torch.Size([1, 256, 64])


In [18]:
# Show the channel plan: [input channels, first MLP width, second MLP width].
channels

[35, 64, 64]