# Models
- https://github.com/POSTECH-CVLab/FastPointTransformer
    - This also includes MinkowskiEngine
- https://github.com/PRBonn/Mask4D

# Inspiration
- https://arxiv.org/pdf/2309.16133.pdf
- https://arxiv.org/pdf/2305.16404.pdf 
    - progressive superpoints, unsupervised

In [8]:
%load_ext autoreload
%autoreload 2

import os
os.chdir('/home/vacekpa2/4D-RNSFP')
import torch
import time
import numpy as np
from pytorch3d.ops.knn import knn_points

from data.dataloader import SFDataset4D
from vis.deprecated_vis import *
from loss.flow import *

device = torch.device('cuda:0')

dataset = SFDataset4D(dataset_type='waymo', n_frames=1)
data = dataset.__getitem__(80)

device = torch.device('cuda:0')
pc1 = data['pc1'].to(device)
pc2 = data['pc2'].to(device)
id_mask1 = data['id_mask1'].to(device)
mos1 = data['mos1'].to(device)

id_mask1[mos1==False] = 0

K = len(torch.unique(id_mask1)) + 5 # number of objects to infer

# visualize_points3D(pc1.view(-1,3), id_mask1.view(-1))
# os.system('nvidia-smi')
# init by dbscan
from sklearn.cluster import DBSCAN, KMeans
numpy_st_pc = np.concatenate([np.insert(pc1[i].cpu().numpy(), 3, i * 0.20, axis=1) for i in range(pc1.shape[0])])
# numpy_st_pc = np.concatenate((pc1[0].cpu().numpy(), data['gt_flow'][0].cpu().numpy()), axis=1)
%timeit init_cluster_ids = DBSCAN(eps=0.2, min_samples=3).fit_predict(numpy_st_pc)
# init_cluster_ids = KMeans(n_clusters=K, n_init=5).fit_predict(numpy_st_pc)
init_cluster_ids = torch.tensor(init_cluster_ids, dtype=torch.long, device=device)
visualize_points3D(pc1.view(-1,3), init_cluster_ids.view(-1))

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
780 ms ± 7.63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


  init_cluster_ids = torch.tensor(init_cluster_ids, dtype=torch.long, device=device)


### Framework for joint trajectory, instance segmentation and flow estimation

### Input: Sequence of point clouds as a batch of size (B, N, 3) and a mask of size (B, N, K), where K is the number of objects in the scene.
### Output: Flow, instance weights, and trajectories

### Features
- [x] fitting flow from each frame to the next
- [x] Cyclic smoothness
- [ ] fitting instance segmentation from each frame to the next
- [ ] model plug and play
- [ ] losses inside model

### Do not Generate:


### Representation

### PseudoCode

In [2]:


def cyclic_smoothness(pc1, est_flow, pc2, NN_pc2, NN_forward=None, pc2_smooth=True):
    """Cycle-consistency / smoothness loss for a forward flow estimate.

    Each point of ``pc1`` is assigned to its nearest neighbour in ``pc2`` after
    being moved by ``est_flow``. The flows of all source points that land on the
    same ``pc2`` point are averaged. Each point's flow is then pulled towards that
    average (forward term). Optionally, the averaged flows are also compared
    across the ``pc2`` neighbourhoods listed in ``NN_pc2`` (neighbourhood term).

    Only batch element 0 of every input is used.

    Args:
        pc1: source points; ``pc1.shape[1]`` is the number of source points
            (assumed ``(1, N, 3)`` -- TODO confirm).
        est_flow: flow of ``pc1``, same shape as ``pc1``.
        pc2: target points; ``pc2.shape[1]`` is the number of target points.
        NN_pc2: neighbour indices inside ``pc2``, read as ``NN_pc2[0, i, :]``
            (assumed ``(1, M, K)``).
        NN_forward: optional precomputed nearest-neighbour indices of
            ``pc1 + est_flow`` in ``pc2``, shape ``(1, N, k)``. Only the first
            column is used. When ``None``, it is computed with ``knn_points``
            (``norm=1``).
        pc2_smooth: whether to add the ``NN_pc2`` neighbourhood term.

    Returns:
        Scalar tensor: mean forward loss over the real ``pc1`` points plus the
        neighbourhood term (zero when disabled or when no slot is occupied).
    """
    if NN_forward is None:
        _, NN_forward, _ = knn_points(pc1 + est_flow, pc2, lengths1=None, lengths2=None, K=1, norm=1)

    n_src = pc1.shape[1]
    n_tgt = pc2.shape[1]
    # Extra "dummy" slot just past the last pc2 point; padded rows are routed here.
    dummy_idx = n_tgt

    flow = est_flow[0]
    # Keep only the nearest neighbour so the index is (L, 1) for the scatter below.
    ind = NN_forward[0, :, :1]

    if n_src < n_tgt:
        # Pad so the scatter target has a row for every pc2 point plus the dummy
        # slot; NN_pc2 indices can then index `vec` safely.
        # F.pad is differentiable, so gradients reach est_flow without retain_grad()
        # (which also raised whenever est_flow did not require grad).
        shape_diff = n_tgt - ind.shape[0] + 1
        flow = torch.nn.functional.pad(flow, (0, 0, 0, shape_diff), mode='constant', value=0)
        ind = torch.nn.functional.pad(ind, (0, 0, 0, shape_diff), mode='constant', value=dummy_idx)

    # Per-target mean of the flows of all source rows that map onto it.
    vec = flow.new_zeros(ind.shape[0], 3)
    vec = vec.scatter_reduce_(0, ind.repeat(1, 3), flow, reduce='mean', include_self=False)

    # Pull each point's flow towards the mean flow of its target. Padded rows
    # (zero flow routed to the dummy slot) are dropped so they do not dilute the mean.
    per_point = torch.nn.functional.mse_loss(vec[ind[:, 0]], flow, reduction='none').mean(dim=-1)
    forward_flow_loss = per_point[:n_src]

    if pc2_smooth:
        # Targets hit by real (non-padded) source points, in source order.
        keep_ind = ind[ind[:, 0] != dummy_idx, 0]

        # Each estimated-flow point maps to these pc2 indices, and the pc2
        # neighbourhood of that target should carry the same flow.
        n = NN_pc2[0, keep_ind, :]
        connected_flow = vec[n]  # (N, K, 3)

        prep_flow = est_flow[0].unsqueeze(1).repeat_interleave(repeats=NN_pc2.shape[-1], dim=1)
        flow_diff = prep_flow - connected_flow

        # Slots whose averaged flow has any exactly-zero component are excluded;
        # this is meant to skip slots no source point was scattered into.
        occupied_mask = connected_flow.all(dim=2, keepdim=True).expand_as(flow_diff)
        per_flow_dim_diff = torch.masked_select(flow_diff, occupied_mask)

        if per_flow_dim_diff.numel() == 0:
            # e.g. an all-zero flow: mean() of an empty tensor would be NaN.
            NN_pc2_loss = flow_diff.new_zeros(())
        else:
            # Squared per-dimension differences.
            NN_pc2_loss = (per_flow_dim_diff ** 2).mean()
    else:
        NN_pc2_loss = forward_flow_loss.new_zeros(())

    return forward_flow_loss.mean() + NN_pc2_loss

def get_instance(mask_weights):
    """Convert raw instance logits into soft instance probabilities.

    The softmax is taken over dimension 2, so that dimension sums to one.
    """
    probabilities = torch.softmax(mask_weights, dim=2)
    return probabilities



In [3]:
# Optimisation variables:
#   flow_weights - flow for every pc1 point, zero-initialised
#   mask_weights - per-point instance logits over K slots, uniform-random initialised
flow_weights = torch.zeros(pc1.shape, device=device, requires_grad=True)
mask_weights = torch.rand((pc1.shape[0], pc1.shape[1], K), device=device, requires_grad=True)

# Batched fast KNN: one FastNN module per frame pair (cell_size=0.075).
NN_modules = [
    FastNN(pc1[frame:frame + 1], pc2[frame:frame + 1], cell_size=0.075)
    for frame in range(pc1.shape[0])
]
# Stack every module's full_ids, then move the stacked axis 1 to the end.
full_ids = torch.stack([nn_module.full_ids for nn_module in NN_modules]).permute(0, 2, 3, 4, 1)


In [6]:
optimizer = torch.optim.Adam([flow_weights, mask_weights], lr=0.008)

# Smoothness regulariser, applied to both the flow and the instance logits.
SmoothModule = SmoothnessLoss(pc1=pc1, pc2=pc2, K=12, max_radius=1, pc2_smooth=True)
NN_pc2 = SmoothModule.NN_pc2

bs = pc1.shape[0]      # number of frame pairs; loop-invariant
static_thresh = 0.05   # flow magnitude below which a point counts as static

for e in range(300):
    # Per-frame distance of pc1 + flow to pc2 via the precomputed FastNN modules.
    dist_list = []
    for t in range(bs):
        dist, NN_index = NN_modules[t](pc1[t:t+1], flow_weights[t:t+1], pc2[t:t+1])
        dist_list.append(dist)
    # bs only 1
    # cyclic_smooth_loss = cyclic_smoothness(pc1, flow_weights, pc2, NN_pc2, NN_forward=NN_index.unsqueeze(0).unsqueeze(2), pc2_smooth=True)

    loss_dist = torch.cat(dist_list)
    smooth_loss = SmoothModule(pc1, flow_weights, pc2)
    mask_smooth_loss = SmoothModule(pc1, mask_weights, pc2)

    # Pseudo-label loss: points with (almost) zero estimated flow should favour
    # instance slot 0, moving points should not. The comparison produces no
    # gradient, so this term only steers mask_weights.
    # TODO: reweight this loss dynamically?
    static_mask = flow_weights.norm(dim=-1) < static_thresh
    # Softmax over the instance dimension, then slot 0. The previous form took a
    # softmax over the point dimension and averaged it, which is identically 1/N
    # for both terms, giving a constant zero loss with no gradient.
    bg_prob = get_instance(mask_weights)[..., 0]
    static_w = static_mask.float()
    dynamic_w = 1.0 - static_w
    # Masked means; clamp avoids 0/0 = NaN when every point is static (e.g. zero-flow init).
    label_loss = (
        - (bg_prob * static_w).sum() / static_w.sum().clamp(min=1)
        + (bg_prob * dynamic_w).sum() / dynamic_w.sum().clamp(min=1)
    )

    loss = bs * loss_dist.mean() + bs * smooth_loss.mean() + mask_smooth_loss.mean() + label_loss  #+ cyclic_smooth_loss.mean()
    loss.backward()

    optimizer.step()
    optimizer.zero_grad()

    print(e, 'Flow Loss: ', f"{loss_dist.mean().item():.4f}", 'Smooth Loss: ', f"{smooth_loss.mean().item():.4f}")


vis_pc1 = pc1.view(-1, 3).cpu()
vis_pc2 = pc2.view(-1, 3).cpu()
vis_flow = flow_weights.detach().view(-1, 3).cpu()  # detach: flow_weights is a grad-requiring leaf

visualize_flow3d(vis_pc1, vis_pc2, vis_flow)
pred_inst_ids = torch.argmax(mask_weights, dim=2).view(-1).cpu()
# visualize_points3D(vis_pc1, pred_inst_ids)
# visualize_points3D(vis_pc1, pred_inst_ids > 0)

0 Flow Loss:  0.0189 Smooth Loss:  0.0019
1 Flow Loss:  0.0285 Smooth Loss:  0.0182
2 Flow Loss:  0.0264 Smooth Loss:  0.0143
3 Flow Loss:  0.0243 Smooth Loss:  0.0091
4 Flow Loss:  0.0228 Smooth Loss:  0.0091
5 Flow Loss:  0.0218 Smooth Loss:  0.0098
6 Flow Loss:  0.0221 Smooth Loss:  0.0095
7 Flow Loss:  0.0225 Smooth Loss:  0.0086
8 Flow Loss:  0.0221 Smooth Loss:  0.0077
9 Flow Loss:  0.0211 Smooth Loss:  0.0068
10 Flow Loss:  0.0205 Smooth Loss:  0.0062
11 Flow Loss:  0.0215 Smooth Loss:  0.0060
12 Flow Loss:  0.0217 Smooth Loss:  0.0059
13 Flow Loss:  0.0210 Smooth Loss:  0.0057
14 Flow Loss:  0.0201 Smooth Loss:  0.0053
15 Flow Loss:  0.0202 Smooth Loss:  0.0049
16 Flow Loss:  0.0207 Smooth Loss:  0.0047
17 Flow Loss:  0.0206 Smooth Loss:  0.0045
18 Flow Loss:  0.0200 Smooth Loss:  0.0043
19 Flow Loss:  0.0199 Smooth Loss:  0.0041
20 Flow Loss:  0.0203 Smooth Loss:  0.0039
21 Flow Loss:  0.0202 Smooth Loss:  0.0038
22 Flow Loss:  0.0198 Smooth Loss:  0.0036
23 Flow Loss:  0.0198

In [5]:
# Visualise the flow of one frame pair (index 2; needs at least 3 frames in pc1/pc2).
# Passes CPU tensors like the training cell's visualisation, with the flow detached from autograd.
vis_frame = 2
visualize_flow3d(pc1[vis_frame].cpu(), pc2[vis_frame].cpu(), flow_weights[vis_frame].detach().cpu())

In [6]:
# Debug: recompute a pc1 neighbourhood smoothness term by hand from SmoothModule's indices.
est_flow = flow_weights
NN_idx = SmoothModule.NN_pc1
n_neigh = 12  # matches K=12 given to SmoothnessLoss

# Gather the flow of every neighbour: one row of n_neigh indices per point.
# NOTE(review): est_flow is flattened across the batch, so this assumes NN_idx
# indexes that flattened array (or that there is a single frame) -- confirm for B > 1.
neigh_idx = NN_idx.reshape(-1, n_neigh)
est_flow_neigh = est_flow.reshape(-1, 3)[neigh_idx]
print(est_flow_neigh.shape)

# Column 0 (presumably the point itself -- confirm) is compared with the remaining columns.
anchor_flow = est_flow_neigh[:, :1, :]
other_flow = est_flow_neigh[:, 1:, :]
flow_diff = anchor_flow - other_flow

smooth_flow_loss = flow_diff.norm(dim=2).mean()
print(smooth_flow_loss)


torch.Size([199958, 12, 3])
tensor(0., device='cuda:0', grad_fn=<MeanBackward0>)
