In [1]:
%load_ext autoreload
%autoreload 2

In [6]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch_cluster import knn_graph, knn
from torch_geometric.nn import MLP

In [26]:
class Global_Encoder(nn.Module):

    def __init__(self, num_neighbors, num_samples, mlp, witht=True) -> None:
        super().__init__()
        self.mlp = mlp
        self.num_neighbors = num_neighbors
        self.num_samples = num_samples
        self.witht = witht

    def forward(self, f1_list, xyz1_list, dbatch):
        fea_source, fea_tar = self.wighted_sample(f1_list, xyz1_list, dbatch)
        print("fea_source shape", fea_source.shape)
        print("fea_tar shape", fea_tar.shape)
        tmp = torch.cat([fea_tar, fea_tar-fea_source], dim=2)
        print("tmp shape", tmp.shape)
        out = torch.max(self.mlp(tmp), dim=-2)[0]
        print("out shape", out.shape)
        return torch.tanh(out)

    def wighted_sample(self, f1_list, xyz1_list, dbatch):
        #input: xyz1_list, f1_list
        l = len(f1_list)
        n, h = f1_list[-1].shape

        fea_set = torch.zeros(l-1, n, self.num_neighbors, h)
        delta_r_set = torch.zeros(l-1, n, self.num_neighbors)
        
        anchor = xyz1_list[-1]
        for i in range(0, l-1):
            idx_anchor, idx_i= knn(x=xyz1_list[i],y=anchor, k=self.num_neighbors, batch_x=dbatch, batch_y=dbatch)
            # 每个frame先找 num_neighbors 个点

            delta_xyz = (anchor[idx_anchor,:] - xyz1_list[i][idx_i,:]).reshape(-1, self.num_neighbors, 3).contiguous()
            delta_r = torch.sqrt(torch.sum(delta_xyz**2, dim=2))
            delta_r_set[i,:,:] = delta_r
            #可能会用index_operation的问题，待定

            fea_set[i,:,:] = f1_list[i][idx_i,:].reshape(-1, self.num_neighbors, h).contiguous()

        fea_set = fea_set.view(1, n, -1, h).contiguous().squeeze(0) #[B*N, num_neighbors, h]
        delta_r_set = delta_r_set.view(n, -1).contiguous() 

        #根据delta_r采样， P 正比于 exp(-belta * delta_r)
        alpha, belta = 0.5 ,0.5
        # belta越小，分布越平滑, 
        # belta太大的话，跟第一步直接取num_samples个点的效果差不多

        print("delta_r_set shape", delta_r_set.shape)
        pr_r = torch.softmax(delta_r_set* (-1.*belta), dim=1) # [B*N, num_neighbors * (T-1)]

        if self.witht:
            delta_t = torch.repeat_interleave(torch.arange(l-1, 0, -1), self.num_neighbors)
            delta_t_set = delta_t.unsqueeze(0).repeat(n, 1).contiguous()
            print("delta_t_set shape", delta_t_set.shape)
            pr_t = torch.softmax(delta_t_set* (-1.*0.5), dim=1) # [B*N, num_neighbors * (T-1)]
            p = 0.5 * pr_r + 0.5 * pr_t
        else:
            p = pr_r
        
        print("p shape", p.shape)
        print(p[0].sum())

        idx = torch.multinomial(p, self.num_samples, replacement=False) # [B*N, num_samples * (T-1)]

        fea_source = fea_set[torch.arange(n).unsqueeze(1),idx,:] # [B*N, num_samples * (T-1), 64]
        _, sum_n, _ = fea_source.shape
        fea_tar = f1_list[-1].unsqueeze(1).repeat(1, sum_n, 1) # [B*N, num_samples * (T-1), 64]
    
        return fea_source, fea_tar


In [27]:
f1_list = [torch.randn(2*512, 64) for _ in range(5)]
xyz1_list = [torch.randn(2*512, 3) for _ in range(5)]

t = torch.LongTensor([0,1])
dbatch = torch.repeat_interleave(t, 512)

In [29]:
gmodel = Global_Encoder(10, 4, MLP([128, 128,128], norm=None), witht=True)

out = gmodel(f1_list, xyz1_list, dbatch)

delta_r_set shape torch.Size([1024, 40])
delta_t_set shape torch.Size([1024, 40])
p shape torch.Size([1024, 40])
tensor(1.0000)
fea_source shape torch.Size([1024, 4, 64])
fea_tar shape torch.Size([1024, 4, 64])
tmp shape torch.Size([1024, 4, 128])
out shape torch.Size([1024, 128])


In [None]:
import os
import sys

# 获取当前 Notebook 文件的路径
notebook_path = os.path.abspath('')
# 获取父文件夹的路径
parent_folder = os.path.dirname(notebook_path)

# 使用 IPython 魔术命令设置环境变量
%env PYTHONPATH=$PYTHONPATH:{parent_folder}


In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_cluster import knn_graph, knn
from torch import Tensor
from typing import Optional, Tuple
from torch_geometric.nn import global_max_pool, global_mean_pool, GATv2Conv
import sys

from torch.utils.data import DataLoader

from dataset.kitti_dataset import KittiDataset

from pugcn_lib.feature_extractor import InceptionFeatureExtractor
from pugcn_lib.upsample import GeneralUpsampler
from pugcn_lib.models import Refiner

from models.encoder_pn2 import PNT_2
from dataset.kitti_dataset_v2 import KittiDataset_2
from models.time_series_models import ScaledDotProductAttention, MultiHeadAttention

In [None]:
#LSTM
ft3_list = torch.stack([torch.rand(2*512, 256) for i in range(4)], dim=0)
print(ft3_list.shape)
lstm_3 = nn.LSTM(input_size=256, hidden_size=256*2, num_layers=2, batch_first=False)

In [None]:
ft3_list[0:-1].shape

In [None]:
import torch
import torch.nn as nn
import numpy as np

def model_summary(model, input_size):
    # 打印模型信息
    print(f"{'='*30} Model Summary {'='*30}")
    print("{:<30} {:<25} {:<20}".format("Layer", "Parameter Size", "Parameters"))
    print("="*75)
    
    # 将输入传递到模型，以便计算参数大小
    device = next(model.parameters()).device
    input_tensor = torch.randn(*input_size).to(device)
    
    # 遍历模型的所有层
    total_params = 0
    total_size = 0
    for name, param in model.named_parameters():
        layer_params = np.prod(param.size())
        total_params += layer_params
        total_size += param.numel() * param.element_size()
        print("{:<30} {:<25} {:<20}".format(name, str(param.size()), f"{layer_params:,}"))
    
    total_size_mb = total_size / (1024 ** 2)
    
    print("="*75)
    print(f"Total parameters: {total_params:,}")
    print(f"Total size: {total_size_mb:.2f} MB")

In [None]:
model_summary(lstm_3, (4, 1, 256))

In [None]:
out, hc = lstm_3(ft3_list)
print("out.shape: ", out.shape)
out[-1].shape

In [None]:
hc[0].shape, hc[1].shape

## model_ts

In [None]:
bpos_1 = torch.randn(2*1024,3)
bpos_5 =torch.randn(2*1024,3)
bfea_1 = torch.randn(2*1024,256)
bfea_5 =torch.randn(2*1024,256)
t = torch.LongTensor([0,1])
dbatch = torch.repeat_interleave(t, 1024)

In [None]:
# Finds for each element in :obj:`y` the :obj:`k` nearest points in :obj:`x`.
index_ex = knn(x=bpos_1,y=bpos_5, k=10, batch_x=dbatch, batch_y=dbatch)
index_ex.shape, index_ex

In [None]:
index_ex[0]

In [None]:
a = torch.randn(2,3)
b = torch.randn(2,1)

c = torch.cat((a,b), dim=-1)
c.shape

In [None]:
class Local_Point_Trans(nn.Module):
    def __init__(self,k) -> None:
        super().__init__()
        self.k = int(k)
        self.lin_p = nn.Sequential(
            nn.Linear(4, 64),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(),
            nn.Linear(64, 256),
        )

        self.lin_q = nn.Linear(256, 256)
        self.lin_k = nn.Linear(256, 256)
        self.lin_v = nn.Linear(256, 256)
        self.lin_w = nn.Sequential(
            nn.BatchNorm1d(256),
            nn.LeakyReLU(),
            nn.Linear(256, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU())
        
        self.att = nn.Softmax(dim=1)


    def forward(self, f0, fea_cur, xyz_i, xyz_last, batch, t_i, t_cur=1.0):

        index = knn(x=xyz_i,y=xyz_last, k=self.k, batch_x=batch, batch_y=batch)

        idx_last, idx_i = index[0], index[1]

        fea_cur = fea_cur[idx_last].reshape(-1, self.k, 256).contiguous()  # [N, k, 256]
        f0 = f0[idx_i].reshape(-1, self.k, 256).contiguous()  # [N, k, 256]

        # print("fea_cur.shape",fea_cur.shape)
        # print("f0.shape",f0.shape)

        xyzt_last = torch.cat((xyz_last[idx_last], t_cur * torch.ones_like(xyz_last[idx_last][:,:1])), 
                               dim=-1).reshape(-1, self.k, 4).contiguous()
        xyzt_i = torch.cat((xyz_i[idx_i], t_i * torch.ones_like(xyz_i[idx_i][:,:1])), 
                           dim=-1).reshape(-1, self.k, 4).contiguous()
        # print("xyzt_last.shape",xyzt_last.shape)
        # print("xyzt_i.shape",xyzt_i.shape)
        pe = xyzt_i - xyzt_last
        for i, layer in enumerate(self.lin_p): pe = layer(pe.transpose(1, 2).contiguous()).transpose(1, 2).contiguous() if i == 1 else layer(pe)
        w = self.lin_q(fea_cur) - self.lin_k(f0) + pe
        # print("w.shape",w.shape)
        for i, layer in enumerate(self.lin_w): w = layer(w.transpose(1, 2).contiguous()).transpose(1, 2).contiguous() if i % 3 == 0 else layer(w)
        w = self.att(w)
        v = self.lin_v(f0) + pe

        res = torch.sum((w * v), dim=1)

        return res


In [None]:
lpt = Local_Point_Trans(10)
res = lpt(bfea_1, bfea_5, bpos_1, bpos_5, dbatch, 0.2)
res.shape


In [None]:
K = torch.randn(5, 512, 256)
Q = torch.randn(1, 512, 256)
V = torch.randn(5, 512, 256)


multihead_attn = nn.MultiheadAttention(256, 4)
attn_output, attn_output_weights = multihead_attn(Q, K, V)


In [None]:
attn_output.shape, attn_output_weights.shape

In [None]:
attn_output.squeeze(0).shape

## complete test

In [None]:
from models.time_series_model_v2 import PC_forecasting_model_0_1
encoder_args = {'df':3}
att_args = {'num_heads': 4, 'num_neighs': 20}

upsampler_args = {
    "upsampler": "nodeshuffle",
    "in_channels": 256,
    "out_channels": 64,
    "k": 40,
    "r": 16, 
    "conv":"gatv2",
    "heads":4
}

refiner_kargs = {
    "in_channels": 64,
    "out_channels": 3,
    "k": 30,
    "dilations": (1,2),
    "add_points": True
}


data_config = {'root': "/home/stud/ding/PC_FC/PC_forecasting/kittiraw/dataset/sequences", 'npoints': 4096, 'input_num': 5, 'pred_num': 5, 'tr_seqs': ['00']}
demo_dataset = KittiDataset_2(root=data_config['root'], npoints=data_config['npoints'], input_num=data_config['input_num'], pred_num=data_config['pred_num'], seqs=data_config['tr_seqs'])
train_dataloader = DataLoader(demo_dataset, batch_size=4) # collate_fn=custom_collate_fn
bin_xyz, bin_fea, bgt_xyz = next(iter(train_dataloader))

pc_fc_model = PC_forecasting_model_0_1(encoder_args, att_args, upsampler_args, refiner_kargs, data_config['npoints'])



## model

In [None]:
bin_xyz[0].shape, bin_fea[0].shape, bgt_xyz[0].shape

In [None]:
coa_list, det_list = pc_fc_model(bin_xyz, 5, 'cpu')
len(coa_list), len(det_list)


In [None]:
coa_list[-1].shape, det_list[0].shape

In [None]:
len(coa_list), len(det_list)

## loss