##  Model Training 

This code is to understand `engine.py`.

This code is to load `PTVflow3D` dataset.

In [1]:
import os
import logging
import warnings
import numpy as np
from tqdm import tqdm
from datetime import datetime

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

#from datasets.generic import Batch
from model.RAFTSceneFlow import RSF_DGCNN
from tools.loss import sequence_loss

from tools.metric import compute_epe_train, compute_epe2, compute_rmse_train, compute_rmse
from tools.utils import save_checkpoint
from deepptv.data import FluidflowDataset, FluidflowDataset3D

## 1. Load dataset

In [2]:
num_points = 512
folder = 'PTVflow3D_norm' # 'data_sample'
dataset_path = os.path.join('data/training_set', folder)
train_dataset = FluidflowDataset3D(npoints=num_points, root = dataset_path, partition='train')
val_dataset = FluidflowDataset3D(npoints=num_points, root = dataset_path, partition='test')
test_dataset = val_dataset

train :  1
test :  10


In [4]:
batch_size = 16
test_batch_size = 16
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True,
                                           num_workers=4, drop_last=True)
val_dataloader = DataLoader(val_dataset, test_batch_size, shuffle=False, num_workers=4,
                                         drop_last=False)
test_dataloader = val_dataloader

## 2. Load the model

The model is `RSF_DGCNN model`

In [6]:
import torch
import torch.nn as nn

from model.extractor import FlotEncoder, FlotGraph
from model.corr2 import CorrBlock2
from model.update import UpdateBlock
from model.scale import KnnDistance
import model.ot as ot
from model.model_dgcnn import GeoDGCNN_flow2
import os
import sys
import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

from model.flot.gconv import GeoSetConv
from model.flot.graph import ParGraph

## 2.1 Feature Extration Network

In [5]:
# from `model_dgcnn.py`
def knn(x, k):
    # return the index of k-nearest data of x
    inner = -2*torch.matmul(x.transpose(2, 1), x)
    xx = torch.sum(x**2, dim=1, keepdim=True)
    pairwise_distance = -xx - inner - xx.transpose(2, 1)
 
    idx = pairwise_distance.topk(k=k, dim=-1)[1]   
    return idx

In [11]:
batch_size,num_points

(16, 512)

In [14]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# 0,512,1024
idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1)*num_points

## 2.2 CorrBlock 2

In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [23]:
class CorrBlock2(nn.Module):
    def __init__(self, num_levels=3, base_scale=0.25, resolution=3, truncate_k=128, knn=32):
        super(CorrBlock2, self).__init__()
        self.truncate_k = truncate_k
        self.num_levels = num_levels
        self.resolution = resolution  # local resolution
        self.base_scale = base_scale  # search (base_sclae * resolution)^3 cube
        self.out_conv = nn.Sequential(
            nn.Conv1d((self.resolution ** 3) * self.num_levels, 128, 1),
            nn.GroupNorm(8, 128),
            nn.PReLU(),
            nn.Conv1d(128, 64, 1)
        )
        self.knn = knn # 32
        
        self.knn_conv = nn.Sequential(
            nn.Conv2d(4, 64, 1),
            nn.GroupNorm(8, 64),
            nn.PReLU()
        )

        self.knn_out = nn.Conv1d(64, 64, 1)
    def init_module(self, fmap1, fmap2, xyz2, transport):
        b, n_p, _ = xyz2.size()
        xyz2 = xyz2.view(b, 1, n_p, 3).expand(b, n_p, n_p, 3)
        # corr = self.calculate_corr(fmap1, fmap2)
        # corr = self.calculate_ncc(fmap1, fmap2)  
        corr = transport  
        corr_topk = torch.topk(corr.clone(), k=self.truncate_k, dim=2, sorted=True)
        self.truncated_corr = corr_topk.values
        indx = corr_topk.indices.reshape(b, n_p, self.truncate_k, 1).expand(b, n_p, self.truncate_k, 3)
        self.ones_matrix = torch.ones_like(self.truncated_corr)
        
        self.truncate_xyz2 = torch.gather(xyz2, dim=2, index=indx)  # b, n_p1, k, 3

    def __call__(self, coords, all_delta_flow, num_iters, scale):
        
        return self.get_adaptive_voxel_feature2(coords, all_delta_flow, num_iters, scale) + self.get_dynamic_knn_feature(coords, all_delta_flow)  ######## modified ########
    
    def get_adaptive_voxel_feature2(self, coords, all_delta_flow, num_iters, scale):
        b, n_p, _ = coords.size()
        corr_feature = []
        from torch_scatter import scatter_add
        for i in range(self.num_levels):
            r = scale * (2 ** i)
            dis_voxel = torch.round((self.truncate_xyz2 - coords.unsqueeze(dim=-2)) / r) # [B, N, truncate_k, 3]
            ### TO DO ###
            ids = len(all_delta_flow)
            if ids >= 2 and ids < (num_iters-2):
                r1_norm = torch.norm(all_delta_flow[-1], dim=-1, keepdim=True)  # [B, N, 1]
                r1_max = torch.max(r1_norm, dim=1, keepdim=True)[0] # [B, 1, 1]
                cos1 = torch.cosine_similarity(all_delta_flow[-1], all_delta_flow[-2], dim=-1).unsqueeze(dim=-1)  # [B, N, 1]
                r1_len = all_delta_flow[-1] / r1_norm * (0.125*r1_norm/r1_max + 0.125*cos1 + 1) * r # self.base_scale
                r1 = r1_len.unsqueeze(dim=-2).repeat(1, 1, self.truncate_xyz2.shape[2], 1) # [B, N, truncate_k, 3]
                r2_p = all_delta_flow[-2] - cos1 * torch.norm(all_delta_flow[-2], dim=-1, keepdim=True) * all_delta_flow[-1] / r1_norm
                r2_norm = torch.norm(r2_p, dim=-1, keepdim=True)  # [B, N, 1]
                r2_max = torch.max(r2_norm, dim=1, keepdim=True)[0] # [B, 1, 1]
                r2_len = r2_p / r2_norm * (0.125*r2_norm/r2_max + 0.9) * r # self.base_scale
                r2 = r2_len.unsqueeze(dim=-2).repeat(1, 1, self.truncate_xyz2.shape[2], 1)
                r3_cross = torch.cross(all_delta_flow[-1], all_delta_flow[-2])
                r3_len = r3_cross / torch.norm(r3_cross, dim=-1, keepdim=True) * 0.9 * r # self.base_scale
                r3 = r3_len.unsqueeze(dim=-2).repeat(1, 1, self.truncate_xyz2.shape[2], 1)
                dis_voxel[:, :, :, 0] = torch.round(torch.sum((self.truncate_xyz2 - coords.unsqueeze(dim=-2)) * r1, dim=-1) / torch.norm(r1, dim=-1))
                dis_voxel[:, :, :, 1] = torch.round(torch.sum((self.truncate_xyz2 - coords.unsqueeze(dim=-2)) * r2, dim=-1) / torch.norm(r2, dim=-1))
                dis_voxel[:, :, :, 2] = torch.round(torch.sum((self.truncate_xyz2 - coords.unsqueeze(dim=-2)) * r3, dim=-1) / torch.norm(r3, dim=-1))
            ### TO DO ###
            valid_scatter = (torch.abs(dis_voxel) <= np.floor(self.resolution / 2)).all(dim=-1) # [B, N, truncate_k]
            dis_voxel = dis_voxel - (-1)
            cube_idx = dis_voxel[:, :, :, 0] * (self.resolution ** 2) +\
                dis_voxel[:, :, :, 1] * self.resolution + dis_voxel[:, :, :, 2] # [B, N, truncate_k]
            cube_idx_scatter = cube_idx.type(torch.int64) * valid_scatter

            valid_scatter = valid_scatter.detach()
            cube_idx_scatter = cube_idx_scatter.detach()

            corr_add = scatter_add(self.truncated_corr * valid_scatter, cube_idx_scatter)
            corr_cnt = torch.clamp(scatter_add(self.ones_matrix * valid_scatter, cube_idx_scatter), 1, n_p)
            corr = corr_add / corr_cnt
            if corr.shape[-1] != self.resolution ** 3:
                repair = torch.zeros([b, n_p, self.resolution ** 3 - corr.shape[-1]], device=coords.device)
                corr = torch.cat([corr, repair], dim=-1)

            corr_feature.append(corr.transpose(1, 2).contiguous())

        return self.out_conv(torch.cat(corr_feature, dim=1))

    ######## modified ########
    def get_dynamic_knn_feature(self, coords, all_delta_flow):
        b, n_p, _ = coords.size()

        dist = self.truncate_xyz2 - coords.view(b, n_p, 1, 3)
        dist = torch.sum(dist ** 2, dim=-1)     # b, 8192, 512

        if len(all_delta_flow) < 8: ## *modified 20220401* ##
            dynamic_k = self.knn - 2 * len(all_delta_flow)
        else:
            dynamic_k = self.knn - 2 * 8
        # dynamic_k = self.knn - 2 * len(all_delta_flow)
        neighbors = torch.topk(-dist, k=dynamic_k, dim=2).indices

        b, n_p, _ = coords.size()
        knn_corr = torch.gather(self.truncated_corr.view(b * n_p, self.truncate_k), dim=1,
                                index=neighbors.reshape(b * n_p, dynamic_k)).reshape(b, 1, n_p, dynamic_k)

        neighbors = neighbors.view(b, n_p, dynamic_k, 1).expand(b, n_p, dynamic_k, 3)
        knn_xyz = torch.gather(self.truncate_xyz2, dim=2, index=neighbors).permute(0, 3, 1, 2).contiguous()
        knn_xyz = knn_xyz - coords.transpose(1, 2).reshape(b, 3, n_p, 1)

        knn_feature = self.knn_conv(torch.cat([knn_corr, knn_xyz], dim=1)) # [B, C, N, K]
        knn_feature = torch.max(knn_feature, dim=3)[0] # [B, C, N]
        return self.knn_out(knn_feature)

    @staticmethod
    def calculate_corr(fmap1, fmap2):
        batch, dim, num_points = fmap1.shape
        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr / torch.sqrt(torch.tensor(dim).float())
        return corr

    ###### *modified* ######
    @staticmethod
    def calculate_ncc(fmap1, fmap2):
        batch, dim, num_points = fmap1.shape
        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        # corr = corr / torch.sqrt(torch.tensor(dim).float())
        n1 = torch.norm(fmap1, dim=1, keepdim=True)
        n2 = torch.norm(fmap2, dim=1, keepdim=True)
        ncc = corr / torch.matmul(n1.transpose(1, 2), n2)
        return ncc

In [16]:
# from `model_dgcnn.py`
def get_graph_feature(x, k=20, idx=None, dim9=False):
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    if idx is None:
        if dim9 == False:
            idx = knn(x, k=k)   
        else:
            idx = knn(x[:, 6:], k=k)
    # device = torch.device('cuda')
    # device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    
    # idx_base = 0,512,1024
    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1)*num_points
    
    # idx of k-nearest data
    idx = idx + idx_base

    idx = idx.view(-1)
     
    _, num_dims, _ = x.size()

    x = x.transpose(2, 1).contiguous() # 
    
    # # find k-nearest features in k
    feature = x.view(batch_size*num_points, -1)[idx, :] 
    feature = feature.view(batch_size, num_points, k, num_dims) 
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
    # feature-x: 以x为原点，check相对距离。
    feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous()
    return feature 

In [17]:
class GeoDGCNN_flow2(nn.Module):
    def __init__(self, k, emb_dims, dropout):
        super(GeoDGCNN_flow2, self).__init__()
        # self.args = args
        self.k = k
        self.emb_dims = emb_dims
        self.dropout = dropout
        
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(96)
        self.bn4 = nn.BatchNorm2d(96)
        self.bn5 = nn.BatchNorm1d(self.emb_dims)
        self.bn6 = nn.BatchNorm1d(512)
        self.bn7 = nn.BatchNorm1d(256)
        
        self.conv1 = nn.Sequential(nn.Conv2d(32*2, 64, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv2d(64*2, 96, kernel_size=1, bias=False),
                                   self.bn3,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv4 = nn.Sequential(nn.Conv2d(96, 96, kernel_size=1, bias=False),
                                   self.bn4,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv5 = nn.Sequential(nn.Conv1d(352, self.emb_dims, kernel_size=1, bias=False),
                                   self.bn5,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv6 = nn.Sequential(nn.Conv1d(1376, 512, kernel_size=1, bias=False),
                                   self.bn6,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv7 = nn.Sequential(nn.Conv1d(512, 256, kernel_size=1, bias=False),
                                   self.bn7,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.dp1 = nn.Dropout(p=self.dropout)
        self.conv8 = nn.Conv1d(256, 128, kernel_size=1, bias=False)

        self.feat_conv1 = GeoSetConv(3, 32)
        self.feat_conv2 = GeoSetConv(32, 64)
        self.feat_conv3 = GeoSetConv(64, 96)
        

    def forward(self, x):
        geo_graph = ParGraph.construct_graph(x, self.k)
        g1 = self.feat_conv1(x, geo_graph)     # B x nb_feat_out x N
        g2 = self.feat_conv2(g1, geo_graph)
        g3 = self.feat_conv3(g2, geo_graph)
        g1 = g1.transpose(1, 2).contiguous() 
        g2 = g2.transpose(1, 2).contiguous() 
        g3 = g3.transpose(1, 2).contiguous() 
        
        print('Content of input of get_graph_feature:',g1)
        
        x = get_graph_feature(g1, k=self.k)  # get graph features   
        x = self.conv1(x)                       
        x = self.conv2(x)  
        
        print('Content of x', x)
        x2 = x.max(dim=-1, keepdim=False)[0]    

        x = get_graph_feature(x2, k=self.k)     
        x = self.conv3(x)                       
        x = self.conv4(x)                       
        x3 = x.max(dim=-1, keepdim=False)[0]    

        mid = torch.cat((g1, x2, x3, g2, g3), dim=1)      

        x = self.conv5(mid)                       
        x = torch.cat((x, mid), dim=1)   
        x = self.conv6(x)                       
        x = self.conv7(x)                       
        x = self.dp1(x)
        x = self.conv8(x)                       
        
        return x

In [18]:
class FlotEncoder(nn.Module):
    def __init__(self, num_neighbors=32):
        super(FlotEncoder, self).__init__()
        n = 32
        self.num_neighbors = num_neighbors
        self.feat_conv1 = SetConv(3, n)
        self.feat_conv2 = SetConv(n, 2 * n)
        self.feat_conv3 = SetConv(2 * n, 4 * n)

    def forward(self, pc):
        # pc
        graph = Graph.construct_graph(pc, self.num_neighbors)
        x = self.feat_conv1(pc, graph)
        x = self.feat_conv2(x, graph)
        x = self.feat_conv3(x, graph)
        x = x.transpose(1, 2).contiguous() # B,C,N
        return x, graph

## 2.4 DGCNN

In [19]:
class RSF_DGCNN(nn.Module):
    def __init__(self, args):
        super(RSF_DGCNN, self).__init__()
        base_scales = 0.25
        truncate_k = 2048
        self.hidden_dim = 64
        self.context_dim = 64
        
        self.feature_extractor = GeoDGCNN_flow2(k=32, emb_dims=1024, dropout=0.5)
        
        self.context_extractor = FlotEncoder()
        # self.graph_extractor = FlotGraph()
        self.corr_block = CorrBlock2(num_levels=args.corr_levels, base_scale=base_scales,
                                    resolution=3, truncate_k=truncate_k)
        self.update_block = UpdateBlock(hidden_dim=self.hidden_dim)

        self.scale_offset = nn.Parameter(torch.ones(1)/2.0) # torch.ones(1)/10.0
        self.gamma = nn.Parameter(torch.zeros(1))
        self.epsilon = nn.Parameter(torch.zeros(1))

    def forward(self, p, num_iters=12):
        # feature extraction
        [xyz1, xyz2] = p # B x N x 3
        fmap1 = self.feature_extractor(p[0])
        fmap2 = self.feature_extractor(p[1])
        ## modified scale ##
        nn_distance = KnnDistance(p[0], 3)
        voxel_scale = self.scale_offset * nn_distance

        # correlation matrix
        transport = ot.sinkhorn(fmap1.transpose(1,-1), fmap2.transpose(1,-1), xyz1, xyz2, 
            epsilon=torch.exp(self.epsilon) + 0.03, 
            gamma=self.gamma, #torch.exp(self.gamma), 
            max_iter=1)
        self.corr_block.init_module(fmap1, fmap2, xyz2, transport)
        
        print('input of flow encoder:',p)
        fct1, graph_context = self.context_extractor(p[0]) # Flot Encoder

        net, inp = torch.split(fct1, [self.hidden_dim, self.context_dim], dim=1)
        net = torch.tanh(net)
        inp = torch.relu(inp)

        coords1, coords2 = xyz1, xyz1
        flow_predictions = []
        all_delta_flow = []  

        for itr in range(num_iters):
            coords2 = coords2.detach()
            corr = self.corr_block(coords=coords2, all_delta_flow=all_delta_flow, num_iters=num_iters, scale=voxel_scale)  
            flow = coords2 - coords1
            net, delta_flow = self.update_block(net, inp, corr, flow, graph_context)
            all_delta_flow.append(delta_flow)  
            coords2 = coords2 + delta_flow
            flow_predictions.append(coords2 - coords1)

        return flow_predictions

## 2.5 Main training part

In [26]:
model = RSF_DGCNN(args)

NameError: name 'args' is not defined