In [1]:
import os
import random
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d
import torch
import torch_geometric
from einops import rearrange
from open3d.web_visualizer import draw
from torch import nn
from torch_geometric.datasets import ModelNet, ShapeNet
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing, knn, knn_graph
from tqdm.notebook import tqdm

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
[Open3D INFO] Resetting default logger to print to terminal.


In [3]:
import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [4]:
def viz_pcd_graph(points: np.array, edge_list: np.array):
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    ls = o3d.geometry.LineSet.create_from_point_cloud_correspondences(
        pcd, pcd, edge_list
    )
    ls.paint_uniform_color([0.5, 0.5, 0.5])
    o3d.visualization.draw_geometries([pcd, ls])

In [5]:
def get_batch(pcd):
    assert pcd.ndim == 3
    b, n, _ = pcd.shape
    batch = torch.repeat_interleave(torch.arange(b), repeats=n).type(torch.long)
    return batch.to(pcd.device)


def fps_subsample(pcd, n_points: int, random_start: bool = False):
    b, n, _ = pcd.shape
    batch = get_batch(pcd)
    pcd = rearrange(pcd, "b n c -> (b n) c")
    idxs = fps(pcd, batch=batch, ratio=n_points / n, random_start=False)
    return to_dense_batch(pcd[idxs], batch=batch[idxs])[0]

# Cross transformer

In [67]:
p1 = torch.tensor([[0, 0, 0], [10, 10, 10], [20, 20, 20]]).type(torch.float32)
p1.shape

torch.Size([3, 3])

In [68]:
p2 = torch.tensor(
    [
        [1, 1, 1],
        [-1, -1, -1],
        [11, 11, 11],
        [9, 9, 9],
        [21, 21, 21],
        [19, 19, 19],
    ]
).type(torch.float32)
p2.shape

torch.Size([6, 3])

In [69]:
p3 = torch.rand(2048, 3)

In [70]:
f1 = torch.ones(3, 2)
f2 = torch.ones(6, 2) * 2
f3 = torch.ones(2048, 2)

In [71]:
p1_ = rearrange(p1, "n d -> 1 d n").cuda()
p2_ = rearrange(p2, "n d -> 1 d n").cuda()
p3_ = rearrange(p3, "n d -> 1 d n").cuda()
f1_ = rearrange(f1, "n d -> 1 d n").cuda()
f2_ = rearrange(f2, "n d -> 1 d n").cuda()
f3_ = rearrange(f3, "n d -> 1 d n").cuda()

## FBNet Cross Transfromer

In [11]:
from einops import rearrange
from FBNet import CrossTransformer as FBCrossTransformer
from FBNet import grouping_operation, query_knn
from torch import einsum

In [12]:
class FBCrossTransformer(nn.Module):
    def __init__(
        self,
        in_channel,
        dim=256,
        n_knn=16,
        pos_hidden_dim=64,
        attn_hidden_multiplier=4,
    ):
        super().__init__()
        self.n_knn = n_knn

        self.pos_mlp = nn.Sequential(
            nn.Conv2d(3, pos_hidden_dim, 1),
            nn.BatchNorm2d(pos_hidden_dim),
            nn.ReLU(),
            nn.Conv2d(pos_hidden_dim, in_channel, 1),
        )

        self.attn_mlp = nn.Sequential(
            nn.Conv2d(in_channel, dim * attn_hidden_multiplier, 1),
            nn.BatchNorm2d(dim * attn_hidden_multiplier),
            nn.ReLU(),
            nn.Conv2d(dim * attn_hidden_multiplier, in_channel, 1),
        )

    def forward(self, pcd, feat, pcd_feadb, feat_feadb):
        """
        Args:
            pcd: (B, 3, N)
            feat: (B, in_channel, N)
            pcd_feadb: (B, 3, N2)
            feat_feadb: (B, in_channel, N2)

        Returns:
            Tensor: (B, in_channel, N), shape context feature
        """
        b, _, num_point = pcd.shape

        fusion_pcd = torch.cat((pcd, pcd_feadb), dim=2)
        fusion_feat = torch.cat((feat, feat_feadb), dim=2)

        key_point = pcd
        key_feat = feat

        # Preception processing between pcd and fusion_pcd
        key_point_idx = query_knn(
            self.n_knn,
            fusion_pcd.transpose(2, 1).contiguous(),
            key_point.transpose(2, 1).contiguous(),
            include_self=True,
        )

        group_point = grouping_operation(fusion_pcd, key_point_idx)
        group_feat = grouping_operation(fusion_feat, key_point_idx)

        # print(f"{group_point=}")
        # print(f"{key_point=}")

        qk_rel = key_feat.reshape((b, -1, num_point, 1)) - group_feat
        pos_rel = key_point.reshape((b, -1, num_point, 1)) - group_point

        pos_embedding = self.pos_mlp(pos_rel)
        sample_weight = self.attn_mlp(
            qk_rel + pos_embedding
        )  # b, in_channel + 3, n, n_knn
        sample_weight = torch.softmax(
            sample_weight, -1
        )  # b, in_channel + 3, n, n_knn

        group_feat = group_feat + pos_embedding
        refined_feat = einsum(
            "b c i j, b c i j -> b c i", sample_weight, group_feat
        )

        return refined_feat

In [13]:
knn(p2, p1, k=2)

tensor([[0, 0, 1, 1, 2, 2],
        [0, 1, 2, 3, 4, 5]])

In [14]:
query_knn(
    2,
    p2_.transpose(2, 1).contiguous(),
    p1_.transpose(2, 1).contiguous(),
    include_self=True,
)

tensor([[[1, 0],
         [3, 2],
         [5, 4]]], device='cuda:0', dtype=torch.int32)

In [15]:
cross_transformer_fb = FBCrossTransformer(
    in_channel=2, dim=16, attn_hidden_multiplier=1, pos_hidden_dim=8, n_knn=2
)

In [16]:
cross_transformer_fb.cuda()

FBCrossTransformer(
  (pos_mlp): Sequential(
    (0): Conv2d(3, 8, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(8, 2, kernel_size=(1, 1), stride=(1, 1))
  )
  (attn_mlp): Sequential(
    (0): Conv2d(2, 16, kernel_size=(1, 1), stride=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(16, 2, kernel_size=(1, 1), stride=(1, 1))
  )
)

In [17]:
out = cross_transformer_fb(p1_, f1_, p2_, f2_)

In [18]:
out

tensor([[[1.6808, 1.6808, 1.6808],
         [1.5930, 1.5930, 1.5930]]], device='cuda:0', grad_fn=<ViewBackward0>)

## My Cross transfomer 

In [2]:
class CrossTransformer(MessagePassing):
    def __init__(
        self,
        in_channel: int,
        pos_hidden_dim: int,
        attn_hidden_dim: int,
        k: int = 2,
    ):
        super().__init__(aggr="add", flow="target_to_source")
        self.k = k
        self.pos_mlp = nn.Sequential(
            nn.Linear(in_features=3, out_features=pos_hidden_dim),
            nn.BatchNorm1d(pos_hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=pos_hidden_dim, out_features=in_channel),
        )

        self.attn_mlp = nn.Sequential(
            nn.Linear(in_channel, attn_hidden_dim),
            nn.BatchNorm1d(attn_hidden_dim),
            nn.ReLU(),
            nn.Linear(attn_hidden_dim, in_channel),
        )

    def forward(
        self, px, py, fx, fy, edge_index=None, batch_x=None, batch_y=None
    ):

        # Include self in target point cloud
        p_fusion = torch.cat([px, py])
        f_fusion = torch.cat([fx, fy])
        if batch_x is not None:
            fusion_batch = torch.cat([batch_x, batch_y])
        else:
            fusion_batch = None

        if edge_index is None:
            edge_index = knn(
                x=p_fusion,
                y=px,
                batch_x=fusion_batch,
                batch_y=batch_x,
                k=self.k,
            )

        # flow = "target_to_source" => (x_i, x_j), (pos_i, pos_j)
        out = self.propagate(edge_index, x=(fx, f_fusion), pos=(px, p_fusion))
        return out

    def message(self, x_i, x_j, pos_i, pos_j, index):
        # Positional embedding
        delta_ij = self.pos_mlp(pos_i - pos_j)
        # Attention embedding
        attn_weights = self.attn_mlp(x_i - x_j + delta_ij)
        # Normalize attention
        attn_weights = torch_geometric.utils.softmax(
            attn_weights, index=index, num_nodes=None
        )

        # Multiply with the attention weights
        out = attn_weights * (x_j + delta_ij)
        return out

In [7]:
cross_transformer = CrossTransformer(
    in_channel=2, pos_hidden_dim=8, attn_hidden_dim=16, k=2
)

In [8]:
out = cross_transformer(p1, p2, f1, f2)
out, out.shape

NameError: name 'p1' is not defined

In [9]:
tpx = torch.rand(30, 3)
tpy = torch.rand(60, 3)
tfx = torch.rand(30, 2)
tfy = torch.rand(60, 2)
x_batch = torch.repeat_interleave(torch.arange(3), 10)
y_batch = torch.repeat_interleave(torch.arange(3), 20)

In [10]:
res = cross_transformer(tpx, tpy, tfx, tfy, batch_x=x_batch, batch_y=y_batch)
res.shape

torch.Size([30, 2])

# Adapt graph pooling

## FB Net

In [22]:
from FBNet import (
    furthest_point_sample,
    gather_operation,
    grouping_operation,
    query_knn,
)

In [23]:
class FBAdaptGraphPooling(nn.Module):
    def __init__(self, pooling_rate, in_channel, neighbor_num, dim=64):
        super().__init__()
        self.pooling_rate = pooling_rate
        self.neighbor_num = neighbor_num

        self.pos_mlp = nn.Sequential(
            nn.Conv2d(3, 64, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(64, in_channel, 1),
        )

        self.attn_mlp = nn.Sequential(
            nn.Conv2d(in_channel, dim, 1),
            nn.BatchNorm2d(dim),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(dim, 3 + in_channel, 1),
        )

    def forward(
        self,
        vertices: "(bs, 3, vertice_num)",
        feature_map: "(bs, channel_num, vertice_num)",
        idx=False,
    ):
        """

        Return:
            vertices_pool: (bs, 3, pool_vertice_num),
            feature_map_pool: (bs, channel_num, pool_vertice_num)
        """

        bs, _, vertice_num = vertices.size()
        new_npoints = int(vertice_num * 1.0 / self.pooling_rate + 0.5)
        key_points_idx = furthest_point_sample(
            vertices.transpose(2, 1).contiguous(), new_npoints
        )
        key_point = gather_operation(vertices.contiguous(), key_points_idx)
        key_feat = gather_operation(feature_map.contiguous(), key_points_idx)

        key_point_idx = query_knn(
            self.neighbor_num,
            vertices.transpose(2, 1).contiguous(),
            key_point.transpose(2, 1).contiguous(),
            include_self=True,
        )

        group_point = grouping_operation(vertices.contiguous(), key_point_idx)
        group_feat = grouping_operation(feature_map.contiguous(), key_point_idx)

        qk_rel = key_feat.reshape((bs, -1, new_npoints, 1)) - group_feat
        pos_rel = key_point.reshape((bs, -1, new_npoints, 1)) - group_point

        pos_embedding = self.pos_mlp(pos_rel)
        sample_weight = self.attn_mlp(
            qk_rel + pos_embedding
        )  # b, in_channel + 3, n, n_knn
        sample_weight = torch.softmax(
            sample_weight, -1
        )  # b, in_channel + 3, n, n_knn
        new_xyz_weight = sample_weight[:, :3, :, :]  # b, 3, n, n_knn
        new_feture_weight = sample_weight[
            :, 3:, :, :
        ]  # b, in_channel, n, n_knn

        group_feat = group_feat + pos_embedding  #
        new_feat = einsum(
            "b c i j, b c i j -> b c i", new_feture_weight, group_feat
        )
        new_point = einsum(
            "b c i j, b c i j -> b c i", new_xyz_weight, group_point
        )

        return new_point, new_feat

In [26]:
agp = FBAdaptGraphPooling(in_channel=2, pooling_rate=4, neighbor_num=3).cuda()

In [27]:
agp(p2_, f2_)

(tensor([[[ 1.3120, 15.3659],
          [ 2.6902, 14.9962],
          [ 2.8790, 17.4803]]], device='cuda:0', grad_fn=<ViewBackward0>),
 tensor([[[2.1771, 1.8801],
          [2.1400, 1.9103]]], device='cuda:0', grad_fn=<ViewBackward0>))

In [28]:
tp = torch.rand(3, 3, 10).cuda()
tf = torch.rand(3, 2, 10).cuda()
# x_batch = torch.repeat_interleave(torch.arange(3), 10)

In [29]:
res1, res2 = agp(tp, tf)
res1.shape, res2.shape

(torch.Size([3, 3, 3]), torch.Size([3, 2, 3]))

## My Adapt Graph Pooling

In [3]:
import math

from torch_geometric.nn import fps

In [4]:
class AdaptGraphPooling(MessagePassing):
    def __init__(
        self,
        in_channels: int,
        pooling_rate: int = 0.25,
        pos_hidden_dim: int = 64,
        attn_hidden_dim: int = 64,
        k=3,
    ):
        super().__init__(aggr="add", flow="target_to_source")
        self.pooling_rate = pooling_rate
        self.k = k
        self.pos_mlp = nn.Sequential(
            nn.Linear(in_features=3, out_features=pos_hidden_dim),
            nn.BatchNorm1d(pos_hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(in_features=pos_hidden_dim, out_features=in_channels),
        )

        self.feat_attn_mlp = nn.Sequential(
            nn.Linear(in_channels, attn_hidden_dim),
            nn.BatchNorm1d(attn_hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(attn_hidden_dim, in_channels),
        )

        self.pos_attn_mlp = nn.Sequential(
            nn.Linear(in_channels, attn_hidden_dim),
            nn.BatchNorm1d(attn_hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(attn_hidden_dim, 3),
        )

    def forward(self, x, pos, batch=None, return_batch=True):

        n_points = len(pos)  # careful with batching
        idxs = fps(
            pos, batch=batch, ratio=self.pooling_rate, random_start=False
        )
        pos_ = pos[idxs]
        x_ = x[idxs]

        if batch is not None:
            batch_ = batch[idxs]
        else:
            batch_ = None

        edge_index = knn(pos, pos_, batch_x=batch, batch_y=batch_, k=self.k)
        out = self.propagate(edge_index, x=(x_, x), pos=(pos_, pos))

        new_pos, new_feat = out[:, :3], out[:, 3:]

        if return_batch:
            return new_pos, new_feat, batch_
        else:
            return new_pos, new_feat

    def message(self, x_i, x_j, pos_i, pos_j, index):
        # Positional embedding
        delta_ij = self.pos_mlp(pos_i - pos_j)

        # Positional weights
        pos_weights = self.pos_attn_mlp(x_i - x_j + delta_ij)
        pos_weights = torch_geometric.utils.softmax(
            pos_weights, index=index, num_nodes=None
        )

        # Feature weights
        feat_weights = self.feat_attn_mlp(x_i - x_j + delta_ij)
        feat_weights = torch_geometric.utils.softmax(
            feat_weights, index=index, num_nodes=None
        )

        # Concatenate to return
        out = torch.cat(
            [pos_weights * pos_j, feat_weights * (x_j + delta_ij)], dim=-1
        )

        return out

In [86]:
agp = AdaptGraphPooling(2, k=2)

In [87]:
np, nf = agp(f2, p2, return_batch=False)

In [13]:
tp = torch.rand(30, 3)
tf = torch.rand(30, 2)
x_batch = torch.repeat_interleave(torch.arange(3), 10)

In [16]:
agp = AdaptGraphPooling(2, k=2)
res1, res2, res_batch = agp(tf, tp, batch=x_batch)
res1.shape, res2.shape, res_batch

(torch.Size([9, 3]), torch.Size([9, 2]), tensor([0, 0, 0, 1, 1, 1, 2, 2, 2]))

# HGNet

## FBNet

In [17]:
import torch.nn.functional as F

In [25]:
from FBNet import (
    furthest_point_sample,
    gather_operation,
    group_local,
    grouping_operation,
    query_knn,
)
from torch import einsum

In [26]:
class FBEdgeConv(torch.nn.Module):
    """
    Input:
        x: point cloud, [B, C1, N]
    Return:
        x: point cloud, [B, C2, N]
    """

    def __init__(self, input_channel, output_channel, k):
        super().__init__()
        self.num_neigh = k

        self.conv = nn.Sequential(
            nn.Conv2d(2 * input_channel, output_channel // 2, kernel_size=1),
            nn.BatchNorm2d(output_channel // 2),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(output_channel // 2, output_channel // 2, kernel_size=1),
            nn.BatchNorm2d(output_channel // 2),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(output_channel // 2, output_channel, kernel_size=1),
        )

    def forward(self, inputs):
        batch_size, dims, num_points = inputs.shape
        if self.num_neigh is not None:
            neigh_feature = group_local(inputs, k=self.num_neigh).contiguous()
            central_feat = inputs.unsqueeze(dim=3).repeat(
                1, 1, 1, self.num_neigh
            )
        else:
            central_feat = torch.zeros(batch_size, dims, num_points, 1).to(
                inputs.device
            )
            neigh_feature = inputs.unsqueeze(-1)
        edge_feature = central_feat - neigh_feature
        feature = torch.cat((edge_feature, central_feat), dim=1)
        feature = self.conv(feature)
        central_feature = feature.max(dim=-1, keepdim=False)[0]
        return central_feature

In [46]:
# Hierarchical Graph-based Network
class FBHGNet(nn.Module):
    def __init__(self, num_pc=128, g_feat_dim=1024, using_max=True, k=3):
        super().__init__()

        self.using_max = using_max
        self.num_pc = num_pc
        pool_num = 2048

        self.out_channel = g_feat_dim // 2

        # HGNet econder
        self.gcn_1 = FBEdgeConv(3, 64, k)

        self.graph_pooling_1 = FBAdaptGraphPooling(4, 64, k)
        self.gcn_2 = FBEdgeConv(64, 128, k)
        self.graph_pooling_2 = FBAdaptGraphPooling(2, 128, k)
        self.gcn_3 = FBEdgeConv(128, 512, k)

        # Fully-connected decoder
        self.fc = nn.Sequential(
            nn.Linear(512 * 2, 1024),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(1024, 1024),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(1024, 3 * num_pc),
        )

    def forward(self, inputs):
        device = inputs.device
        batch_size = inputs.size(0)
        x1 = self.gcn_1(inputs)
        print(x1.shape)

        vertices_pool_1, x1 = self.graph_pooling_1(inputs, x1)

        # B x 128 x 512
        x2 = self.gcn_2(x1)
        print(x2.shape)

        vertices_pool_2, x2 = self.graph_pooling_2(vertices_pool_1, x2)

        # B x 256 x 256
        x3 = self.gcn_3(x2)

        print(x3.shape)
        # Global feature generating B*1024
        feat_max = F.adaptive_max_pool1d(x3, 1).view(batch_size, -1)
        feat_avg = F.adaptive_avg_pool1d(x3, 1).view(batch_size, -1)
        feat_gf = torch.cat((feat_max, feat_avg), dim=1)
        print(feat_gf.shape)

        # Decoder coarse input
        print(self.fc(feat_gf).shape)
        coarse_pcd = self.fc(feat_gf).reshape(batch_size, -1, self.num_pc)

        return coarse_pcd, feat_max

In [38]:
net = FBHGNet(128, k=16).cuda()

In [42]:
cpcd, fmax = net(p3_)

torch.Size([1, 512, 256])
torch.Size([1, 1024])


In [43]:
cpcd.shape, fmax.shape

(torch.Size([1, 3, 128]), torch.Size([1, 512]))

In [57]:
net = FBHGNet(33, k=2).cuda()

# tp = torch.rand(3, 3, 2048).cuda()
tp = torch.rand(4, 3, 10).cuda()
res1, res2 = net(tp)
res1.shape, res2.shape

torch.Size([4, 64, 10])
torch.Size([4, 128, 3])
torch.Size([4, 512, 2])
torch.Size([4, 1024])
torch.Size([4, 99])


(torch.Size([4, 3, 33]), torch.Size([4, 512]))

## My HGNet

In [5]:
from torch_geometric.nn import (
    DynamicEdgeConv,
    EdgeConv,
    global_max_pool,
    global_mean_pool,
)

In [6]:
def edge_conv_nn(in_channels, out_channels):
    return nn.Sequential(
        nn.Linear(2 * in_channels, out_channels // 2),
        nn.BatchNorm1d(out_channels // 2),
        nn.LeakyReLU(0.2),
        nn.Linear(out_channels // 2, out_channels),
    )

In [62]:
class HGNet(nn.Module):
    def __init__(self, num_pc: int = 128, k=3):
        super().__init__()

        # HGNet econder
        self.num_pc = num_pc

        self.gcn_1 = DynamicEdgeConv(nn=edge_conv_nn(3, 64), k=k)
        self.graph_pooling_1 = AdaptGraphPooling(
            in_channels=64, pooling_rate=0.25, k=k
        )
        self.gcn_2 = DynamicEdgeConv(nn=edge_conv_nn(64, 128), k=k)
        self.graph_pooling_2 = AdaptGraphPooling(
            in_channels=128, pooling_rate=0.5, k=k
        )
        self.gcn_3 = DynamicEdgeConv(nn=edge_conv_nn(128, 512), k=k)

        # Fully-connected decoder
        self.fc = nn.Sequential(
            nn.Linear(in_features=512 * 2, out_features=1024),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(in_features=1024, out_features=1024),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(in_features=1024, out_features=3 * num_pc),
        )

    def forward(self, pos, batch=None, return_batch=True):
        device = pos.device
        x1 = self.gcn_1(pos, batch=batch)
        print(x1.shape)

        vertices_pool_1, x1, batch = self.graph_pooling_1(x1, pos, batch=batch)

        x2 = self.gcn_2(x1, batch=batch)
        print(x2.shape)

        vertices_pool_2, x2, batch = self.graph_pooling_2(
            x2, vertices_pool_1, batch=batch
        )

        x3 = self.gcn_3(x2, batch=batch)

        # Global feature generating B*1024
        print(x3.shape)
        feat_max = global_max_pool(x3, batch=batch)
        feat_avg = global_mean_pool(x3, batch=batch)
        feat_gf = torch.cat((feat_max, feat_avg), dim=1)
        # Decoder coarse input
        print(feat_gf.shape)
        coarse_pcd = self.fc(feat_gf)
        print(coarse_pcd.shape)

        coarse_pcd = rearrange(coarse_pcd, "b (d n)-> (b n) d", d = 3)

        if return_batch:
            batch = torch.repeat_interleave(
                torch.arange(feat_max.shape[0]), self.num_pc
            ).to(device)
            return coarse_pcd, feat_max, batch
        else:
            return coarse_pcd, feat_max

In [20]:
net = HGNet(128, k=2)

In [45]:
cpcd, feat_max = net(p3, return_batch=False)

NameError: name 'p3' is not defined

In [None]:
cpcd.shape, feat_max.shape

In [60]:
tp = torch.rand(40, 3).cuda()
tf = torch.rand(40, 2).cuda()
x_batch = torch.repeat_interleave(torch.arange(4), 10).cuda()

In [63]:
net = HGNet(33, k=2).cuda()
res1, res2, batch_res = net(tp, batch=x_batch)
res1.shape, res2.shape, batch_res.shape

torch.Size([40, 64])
torch.Size([12, 128])
torch.Size([8, 512])
torch.Size([4, 1024])
torch.Size([4, 99])


(torch.Size([132, 3]), torch.Size([4, 512]), torch.Size([132]))

# FBAC block

## FBNet

In [227]:
from FBNet import MLP_CONV
from FBNet import CrossTransformer as FBCrossTransformer
from FBNet import EdgeConv as FBEdgeConv
from FBNet import NodeShuffle as FBNodeShuffle

In [228]:
class FBAC_BLOCK(nn.Module):
    def __init__(self, up_factor=2, cycle_num=1, k=16):
        """
        des: Feedback-Aware Completion block
        input: point cloud: B, 3, N
        param: up_factor: up-sampling ratio
               cycle_num: number of time steps
        return: point cloud: B, 3, N * up_factor

        """
        super(FBAC_BLOCK, self).__init__()
        # self.cyc_num = cyc_num
        self.up_factor = up_factor
        # self.gf_mode = gf_mode
        # self.weight = weight

        self.nodeshuffle = FBNodeShuffle(
            128, 128, neighbor_num=8, scale=up_factor
        )
        self.mlp_delta = MLP_CONV(in_channel=128, layer_dims=[128, 64, 3])

        self.ext = FBEdgeConv(3, 128, k)

        self.mlp = MLP_CONV(in_channel=128 * 2, layer_dims=[256, 128])

        self.fb_exploit = FBCrossTransformer(in_channel=128, dim=64)

        self.up_sampler = nn.Upsample(scale_factor=up_factor)

        # self.alphas = nn.Embedding(cycle_num,1,_weight=torch.ones(cycle_num,1))
        # self.sigmoid = nn.Sigmoid()

    def forward(self, pcd, pcd_next, feat_next, cycle=0):
        """
        Args:
            pcd: Tensor, (B, 3, N_prev)
            pcd_next: Tensor, (B, 3, N_next)
            K_next: Tensor, (B, 128, N_next)

        Returns:
            pcd_child: Tensor, up sampled point cloud, (B, 3, N_prev * up_factor)
        """

        b, C, n_prev = pcd.shape

        # Step 1: Feature Extraction
        feat = self.ext(pcd)
        print(feat.shape)
        feat = self.mlp(
            torch.cat(
                [
                    feat,
                    torch.max(feat, 2, keepdim=True)[0].repeat(
                        (1, 1, feat.size(2))
                    ),
                ],
                1,
            )
        )
        print(
            torch.cat(
                [
                    feat,
                    torch.max(feat, 2, keepdim=True)[0].repeat(
                        (1, 1, feat.size(2))
                    ),
                ],
                1,
            ).shape
        )

        # Step 2: Feedback Exploitation
        if pcd_next is None:
            pcd_next, feat_next = pcd, feat
        feat = self.fb_exploit(pcd, feat, pcd_next, feat_next)

        # Step 3: Feature Expansion
        feat = self.nodeshuffle(feat)

        # Step 4: Coordinate Generation
        delta = self.mlp_delta(feat)
        pcd_child = self.up_sampler(pcd) + delta

        return pcd_child, feat

In [229]:
fbac_block = FBAC_BLOCK().cuda()

In [230]:
f3_ = torch.ones(1, 128, 2048).cuda()

In [231]:
r1, r2 = fbac_block(p3_, p3_, f3_)
r1.shape, r2.shape

torch.Size([1, 128, 2048])
torch.Size([1, 256, 2048])


(torch.Size([1, 3, 4096]), torch.Size([1, 128, 4096]))

In [232]:
tpx = torch.rand(3, 3, 10).cuda()
tpy = torch.rand(3, 3, 20).cuda()
tfx = torch.rand(3, 128, 10).cuda()
tfy = torch.rand(3, 128, 20).cuda()

In [233]:
fbac_block = FBAC_BLOCK(k=2).cuda()
res1, res2 = fbac_block(tpx, tpy, tfy)
res1.shape, res2.shape

torch.Size([3, 128, 10])
torch.Size([3, 256, 10])


(torch.Size([3, 3, 20]), torch.Size([3, 128, 20]))

## My fbac

In [13]:
import upsample
from einops import rearrange, reduce, repeat
from torch_geometric.nn import MLP, DynamicEdgeConv
from torch_geometric.utils import to_dense_batch

In [14]:
import importlib

importlib.reload(upsample)

<module 'upsample' from 'G:\\Knowledge\\Faculta-Trash\\Master-ML2022\\RTML\\3DVision\\PointCompletion\\FBNet\\upsample.py'>

In [15]:
class FbacBlock(nn.Module):
    def __init__(self, up_factor: int = 2):
        super().__init__()
        self.up_factor = up_factor

        # Feature extraction
        self.gcn = DynamicEdgeConv(nn=edge_conv_nn(3, 128), k=16)
        self.mlp = MLP([128 * 2, 256, 128])

        # Node expansion
        self.nodeshuffle = upsample.NodeShuffle(128, 128, k=8, r=up_factor)

        # Coordinate generation
        self.mlp_delta = MLP([128, 128, 64, 3])

        # Feedback exploitation
        self.cross_transformer = CrossTransformer(
            in_channel=128, pos_hidden_dim=64, attn_hidden_dim=64
        )

        self.up_sampler = nn.Upsample(scale_factor=up_factor)

    def forward(
        self,
        pcd,
        pcd_next=None,
        feat_next=None,
        batch_current=None,
        batch_next=None,
        return_batch: bool = True,
    ):
        # b, C, n_prev = pcd.shape

        # Step 1: Feature Extraction
        feat = self.gcn(pcd, batch=batch_current)
        feat = to_dense_batch(feat, batch=batch_current)[0]
        feat = torch.cat(
            [
                feat,
                repeat(
                    reduce(feat, "b n c -> b n", "max"),
                    "b n -> b n new_axis",
                    new_axis=feat.shape[-1],
                ),
            ],
            -1,
        )
        feat = rearrange(feat, "b n c -> (b n) c")
        feat = self.mlp(feat)

        # Step 2: Feedback Exploitation
        if pcd_next is None:
            pcd_next, feat_next, batch_next = pcd, feat, batch_current
        feat = self.cross_transformer(
            pcd,
            pcd_next,
            feat,
            feat_next,
            batch_x=batch_current,
            batch_y=batch_next,
        )

        # Step 3: Feature Expansion
        feat, batch = self.nodeshuffle(
            feat, batch=batch_current, return_batch=True
        )

        # Step 4: Coordinate Generation
        delta = self.mlp_delta(feat)
        u = repeat(pcd, "n c -> (n d) c", d=self.up_factor)
        pcd_child = u + delta

        if return_batch:
            return pcd_child, feat, batch
        else:
            return pcd_child, feat

In [13]:
p = torch.rand(2048, 3)
p_next = torch.rand(2048, 3)
feat_next = torch.rand(2048, 128)

In [240]:
fbac_block = FbacBlock()

In [241]:
f3 = torch.rand(2048, 128)

In [242]:
o, t = fbac_block(p3, p3, f3, return_batch=False)
o.shape, t.shape

(torch.Size([4096, 3]), torch.Size([4096, 128]))

In [243]:
tpx = torch.rand(30, 3)
tpy = torch.rand(60, 3)
tfx = torch.rand(30, 128)
tfy = torch.rand(60, 128)
x_batch = torch.repeat_interleave(torch.arange(3), 10)
y_batch = torch.repeat_interleave(torch.arange(3), 20)

In [244]:
fbac_block = FbacBlock()
res1, res2, batch_res = fbac_block(
    tpx, tpy, tfy, batch_current=x_batch, batch_next=y_batch
)
res1.shape, res2.shape, batch_res.shape

(torch.Size([60, 3]), torch.Size([60, 128]), torch.Size([60]))

# Fbac Refinement

In [51]:
from einops.layers.torch import Rearrange

In [52]:
class FeedbackRefinementNet(nn.Module):
    def __init__(
        self,
        up_factors=None,
        cycle_num=1,
        n_points_start=512,
        return_all: bool = False,
    ):
        super().__init__()
        self.return_all = return_all
        self.n_points_start = n_points_start
        self.cycle_num = cycle_num
        if up_factors is None:
            up_factors = [1]

        self.uppers = nn.ModuleList(
            [FbacBlock(up_factor=factor) for factor in up_factors]
        )

        self.flatten_batch = Rearrange("b n c -> (b n) c")

    def forward(self, pcd, partial):

        # Init input
        arr_pcd = []
        pcd = fps_subsample(
            torch.cat([pcd, partial], dim=1), self.n_points_start
        )  # [b n_start 3]

        feat_state = []
        pcd_state = []

        for cycle in range(self.cycle_num):
            pcd_list = []
            feat_list = []
            for upper_idx, upper in enumerate(self.uppers):
                # First timestep
                if cycle == 0:
                    # Add partial and fps only when they're available
                    if upper_idx > 0:
                        n_points = pcd.shape[1]
                        # Concatenate pcd and partial
                        pcd = torch.cat([pcd, partial], dim=1)
                        # Sample back with fps
                        pcd = fps_subsample(pcd, n_points)

                    batch_current = get_batch(pcd)
                    pcd, feat, b_ = upper(
                        self.flatten_batch(pcd), batch_current=batch_current
                    )
                    pcd = to_dense_batch(pcd, batch=b_)[0]

                    print(pcd.shape, feat.shape)
                # Next timesteps
                else:
                    pcd_next = pcd_state[cycle - 1][upper_idx]
                    feat_next = feat_state[cycle - 1][upper_idx]

                    # First fbac block
                    if upper_idx == 0:
                        pcd = pcd_state[cycle - 1][0]
                        pcd = torch.cat([pcd, partial], dim=1)
                        pcd = fps_subsample(pcd, self.n_points_start)
                    else:
                        pcd = pcd_list[upper_idx - 1]  # take last pcd
                        n_points = pcd_state[cycle - 1][upper_idx - 1].shape[1]
                        pcd = torch.cat([pcd, partial], dim=1)
                        pcd = fps_subsample(pcd, n_points)

                    batch_current = get_batch(pcd)
                    batch_next = get_batch(pcd_next)
                    pcd, feat, b_ = upper(
                        self.flatten_batch(pcd),
                        self.flatten_batch(pcd_next),
                        feat_next,
                        batch_current=batch_current,
                        batch_next=batch_next,
                    )
                    pcd = to_dense_batch(pcd, batch=b_)[0]

                pcd_list.append(pcd)
                feat_list.append(feat)

                if self.return_all:
                    arr_pcd.append(pcd)
                else:
                    if cycle == self.cycle_num - 1:
                        arr_pcd.append(pcd)

            # Saving present time step states
            pcd_state.append(pcd_list)
            feat_state.append(feat_list)
        return arr_pcd

In [53]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()

        n_pc = 128
        n_points_start = 512
        n_points = 2048

        self.coarse_net = HGNet(num_pc=n_pc)

        up_factors = [1, 2, 2]

        cycle_num = 3
        self.refiner = FeedbackRefinementNet(
            up_factors=up_factors,
            cycle_num=cycle_num,
            return_all=True,
            n_points_start=n_points_start,
        )
        self.flatten_batch = Rearrange("b n c -> (b n) c")

    def forward(self, x):

        # Coarse generation
        coarse_pcd, _, batch = self.coarse_net(
            self.flatten_batch(x), batch=get_batch(x)
        )
        p
        print(coarse_pcd.shape, batch.shape)
        # feedback refinement stage
        coarse_pcd_dense = to_dense_batch(coarse_pcd, batch=batch)[0]
        res_pcds = self.refiner(coarse_pcd_dense, x)

        fine = res_pcds[-1]
        return fine

In [54]:
model = Model().cuda()

RuntimeError: CUDA error: device-side assert triggered

In [55]:
p = torch.rand(3, 2048, 3).cuda()

RuntimeError: CUDA error: device-side assert triggered

In [40]:
get_batch(p).shape

torch.Size([6144])

In [41]:
r1, r2, b = model.coarse_net(model.flatten_batch(p), batch=get_batch(p))

torch.Size([1, 512])


In [32]:
r1.shape

torch.Size([384, 1])

In [33]:
b.shape

torch.Size([128])

In [34]:
p.shape

torch.Size([3, 2048, 3])

In [20]:
#%pdb

In [21]:
out = model(p)

torch.Size([384, 1]) torch.Size([128])


RuntimeError: The expanded size of the tensor (384) must match the existing size (128) at non-singleton dimension 0.  Target sizes: [384].  Tensor sizes: [128]

In [1]:
import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [2]:
%env CUDA_LAUNCH_BLOCKING=1

env: CUDA_LAUNCH_BLOCKING=1
