<a href="https://colab.research.google.com/github/wangbxj1234/xformerforpt/blob/main/withop_local_linear_Attention%E4%BB%A3%E7%A0%81.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import torch
import torch.nn.functional as F
import numpy as np

from torch import Tensor
from typing import Optional
from torch import nn
def square_distance(src, dst):
    """
    Pairwise squared Euclidean distance between two point sets.

    The result is obtained by broadcasting ``src`` against ``dst``, squaring
    the coordinate-wise differences and summing over the last axis:

        dist[n, m] = sum_c (src[n, c] - dst[m, c]) ** 2

    Note that the broadcast difference materialises an intermediate tensor of
    shape [N, M, C] before the reduction.

    Input:
        src: source points, [N, C]
        dst: target points, [M, C]
    Output:
        dist: squared distance for every (source, target) pair, [N, M]
    """
    # [N, 1, C] - [1, M, C] -> [N, M, C]
    diff = src.unsqueeze(1) - dst.unsqueeze(0)
    return diff.pow(2).sum(dim=-1)

def index_points(points, idx):
    """
    Pick rows of ``points`` according to ``idx`` (used for grouping: extract
    S, or S*K, points out of the N available ones).

    Input:
        points: input points data, [N, C]
        idx: sample index data, [N, [K]]
    Return:
        new_points: indexed points data, [N, [K], C] -- the shape of ``idx``
            with the channel axis appended
    """
    out_shape = idx.size()  # torch.Size, the analogue of numpy's ``shape``
    flat_idx = idx.reshape(-1)  # N*K
    channels = points.size(-1)
    # Repeat every index across the channel axis so that ``gather`` along
    # dim 0 copies whole rows.
    row_selector = flat_idx.unsqueeze(-1).expand(-1, channels)
    picked = torch.gather(points, 0, row_selector)
    # Restore the (2-D or 3-D) layout of ``idx``; the trailing -1 resolves to
    # the channel count.
    return picked.reshape(*out_shape, -1)




class local_linear_Attention(nn.Module):
    """Multi-head kernelised attention over each point's k nearest neighbours.

    For every point, the keys and values of its ``k`` nearest points (by
    squared Euclidean distance between the rows of ``p``) are gathered.
    Queries and keys go through a sigmoid feature map, and the output for a
    point is the average of its neighbours' values weighted by
    ``phi(q) . phi(k)`` and normalised by the sum of those weights.

    Args:
        in_planes: feature width of the input and of the output.
        k: number of nearest neighbours each point attends to.
        n_heads: number of attention heads; must divide ``in_planes``.
        dropout: dropout probability applied to the projected output.

    Raises:
        ValueError: if ``in_planes`` is not divisible by ``n_heads``.
    """

    def __init__(self, in_planes, k, n_heads=8, dropout=0.05):
        super().__init__()
        if in_planes % n_heads != 0:
            # Fail here with a clear message instead of at the first
            # ``view`` call inside ``forward``.
            raise ValueError(
                f"in_planes ({in_planes}) must be divisible by "
                f"n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.query_projection = nn.Linear(in_planes, in_planes)
        self.key_projection = nn.Linear(in_planes, in_planes)
        self.value_projection = nn.Linear(in_planes, in_planes)
        self.out_projection = nn.Linear(in_planes, in_planes)
        self.dropout = nn.Dropout(dropout)
        self.eps = 1e-6  # keeps the normaliser away from zero
        self.k = k

    def kernel_method(self, x):
        """Feature map applied to queries and keys (element-wise sigmoid)."""
        return torch.sigmoid(x)

    def dot_product(self, q, k, v):
        """Un-normalised linear attention ``q @ (k^T @ v)`` per head.

        Shapes follow the einsum signatures: ``q`` and ``k`` are [H, L, D],
        ``v`` is [H, L, M]; the result is [H, L, M]. Not called by
        ``forward``.
        """
        kv = torch.einsum("hld,hlm->hdm", k, v)
        qkv = torch.einsum("hld,hdm->hlm", q, kv)
        return qkv

    def _local_attention(self, q, k_nbr, v_nbr):
        """Normalised kernel attention of each query over its own neighbours.

        Args:
            q: feature-mapped queries, [L, H, D].
            k_nbr: feature-mapped neighbour keys, [L, K, H, D].
            v_nbr: neighbour values, [L, K, H, M].

        Returns:
            Tensor of shape [L, H, M].
        """
        # The numerator and the normaliser must range over the SAME neighbour
        # set of point l; contracting keys/values over all points (as a single
        # global KV matrix) would break that and make the output scale with
        # the total number of points.
        weights = torch.einsum("lhd,lkhd->lkh", q, k_nbr)
        normaliser = weights.sum(dim=1) + self.eps
        mixed = torch.einsum("lkh,lkhm->lhm", weights, v_nbr)
        return mixed / normaliser.unsqueeze(-1)

    def forward(self, pxo) -> torch.Tensor:
        """Apply local attention to a point set.

        Args:
            pxo: sequence ``(p, x, o)`` where ``p`` holds point coordinates
                ([N, C]) and ``x`` holds point features ([N, in_planes]).
                ``o`` is accepted for interface compatibility but not read.

        Returns:
            Tensor of shape [N, in_planes].
        """
        p, x, _ = pxo
        n_points, embed_dim = x.size()  # (N, d)
        head_dim = embed_dim // self.n_heads
        # A cloud with fewer than k points attends to all of its points.
        n_neighbors = min(self.k, n_points)

        # NOTE(review): neighbours are searched among ALL rows of ``p``; if
        # ``o`` carries per-sample offsets of a packed batch, neighbourhoods
        # can cross sample boundaries -- confirm whether that is intended.
        dists = square_distance(p, p)
        # Only the k smallest distances are needed, so avoid a full sort.
        knn_idx = dists.topk(n_neighbors, dim=-1, largest=False).indices

        q = self.query_projection(x).view(n_points, self.n_heads, head_dim)
        k_nbr = index_points(self.key_projection(x), knn_idx).view(
            n_points, n_neighbors, self.n_heads, head_dim
        )
        v_nbr = index_points(self.value_projection(x), knn_idx).view(
            n_points, n_neighbors, self.n_heads, head_dim
        )

        q = self.kernel_method(q)
        k_nbr = self.kernel_method(k_nbr)

        attended = self._local_attention(q, k_nbr, v_nbr)

        # Merge the heads back into one feature axis.
        output = attended.reshape(n_points, embed_dim)
        output = self.out_projection(output)
        output = self.dropout(output)

        return output


In [2]:
# Synthetic inputs for a quick shape check: 7077 random points with 3-D
# coordinates and 128-D features. The last entry of ``o`` equals the point
# count (presumably cumulative per-sample offsets -- confirm against caller).
o = torch.tensor(
    [838, 2088, 2317, 3567, 4541, 5561, 6427, 7077],
    dtype=torch.int32,
)
x = torch.rand(7077, 128)
p = torch.rand(7077, 3)

In [5]:
# Smoke run: build the layer for 128-wide features with 16 neighbours per
# point and print the output shape.
planes = 128
transformer2 = local_linear_Attention(in_planes=planes, k=16)
right_res = transformer2((p, x, o))
print(right_res.shape)

torch.Size([7077, 128])
