In [1]:
import os
from pathlib import Path

import h5py
import numpy as np
import open3d as o3d
import torch
from einops import rearrange
from FBNet import Model as FBModel
from fbnet_geometric import (
    AdaptGraphPooling,
    CrossTransformer,
    FbacBlock,
    FeedbackRefinementNet,
    HGNet,
    Model,
    get_batch,
)
from torch_geometric.nn import knn, knn_graph
from torch_geometric.utils import to_dense_batch
from torchinfo import summary
from tqdm.notebook import tqdm

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


# Utils

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
def viz_many(clouds: list):
    pcds = []
    for i, p in enumerate(clouds):
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(
            p + i * np.array([3, 0, 0])
        )  # shift to see them side by side
        pcds.append(pcd)
    o3d.visualization.draw_geometries(pcds)

In [4]:
p = torch.rand(3, 2048, 3).cuda()
p_ = rearrange(p, "b n c -> (b n) c")
batch = get_batch(p).cuda()

In [5]:
p_ = torch.rand(3 * 2048, 3).cuda()
batch = torch.repeat_interleave(torch.arange(3), 2048).cuda()

In [6]:
model = Model().cuda()
summary(model, depth=4)

Layer (type:depth-idx)                                  Param #
Model                                                   --
├─HGNet: 1-1                                            --
│    └─DynamicEdgeConv: 2-1                             --
│    │    └─MaxAggregation: 3-1                         --
│    │    └─Sequential: 3-2                             --
│    │    │    └─Linear: 4-1                            224
│    │    │    └─BatchNorm1d: 4-2                       64
│    │    │    └─LeakyReLU: 4-3                         --
│    │    │    └─Linear: 4-4                            2,112
│    └─AdaptGraphPooling: 2-2                           --
│    │    └─SumAggregation: 3-3                         --
│    │    └─Sequential: 3-4                             --
│    │    │    └─Linear: 4-5                            64
│    │    │    └─BatchNorm1d: 4-6                       32
│    │    │    └─LeakyReLU: 4-7                         --
│    │    │    └─Linear: 4-8                   

In [7]:
summary(model, depth=1)

Layer (type:depth-idx)                                  Param #
Model                                                   --
├─HGNet: 1-1                                            820,742
├─FeedbackRefinementNet: 1-2                            113,625
├─Rearrange: 1-3                                        --
Total params: 934,367
Trainable params: 934,367
Non-trainable params: 0

In [8]:
coarse_pcd, pcds, _ = model(p)

In [9]:
coarse_pcd.shape, pcds[-1].shape

(torch.Size([3, 128, 3]), torch.Size([3, 2048, 3]))

# Data

In [10]:
import torch.utils.data as data


class MvpDataset(data.Dataset):
    def __init__(self, prefix="train"):
        if prefix == "train":
            self.file_path = "./data/MVP_Train_CP.h5"
        elif prefix == "val":
            self.file_path = "./data/MVP_Test_CP.h5"
        elif prefix == "test":
            self.file_path = "./data/MVP_ExtraTest_Shuffled_CP.h5"
        else:
            raise ValueError("ValueError prefix should be [train/val/test] ")

        self.prefix = prefix

        input_file = h5py.File(self.file_path, "r")
        self.input_data = np.array(input_file["incomplete_pcds"][()])

        print(self.input_data.shape)

        if prefix != "test":
            self.gt_data = np.array(input_file["complete_pcds"][()])
            self.labels = np.array(input_file["labels"][()])

            print(self.gt_data.shape, self.labels.shape)
            c_idxs = np.where((self.labels == 0) | (self.labels == 1))[0]
            print(c_idxs)
            self.input_data = self.input_data[c_idxs]
            self.gt_data = self.gt_data[c_idxs // 26]
            self.labels = self.labels[c_idxs]

            print(self.gt_data.shape, self.labels.shape)

        input_file.close()
        self.len = self.input_data.shape[0]

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        partial = torch.from_numpy((self.input_data[index]))

        if self.prefix != "test":
            complete = torch.from_numpy((self.gt_data[index // 26]))
            label = self.labels[index]
            return label, partial, complete
        else:
            return partial

In [11]:
dataset = MvpDataset(prefix="train")

(62400, 2048, 3)
(2400, 2048, 3) (62400,)
[    0     1     2 ... 10397 10398 10399]
(10400, 2048, 3) (10400,)


In [12]:
i = 1
pcd, gt = dataset[i][1], dataset[i][2]
viz_many([pcd, gt])

In [13]:
batch_size = 4
epochs = 10

In [14]:
trainloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [15]:
# _, pcd, gt = dataset[3]
# viz_many([pcd, gt])

# Train

In [16]:
from ChamferDistancePytorch.chamfer3D.dist_chamfer_3D import (
    chamfer_3DDist as ChamferLoss,
)


def chamfer_dist(output, gt, return_raw=False):
    # https://github.com/wutong16/Density_aware_Chamfer_Distance

    # cham_loss = dist_chamfer_3D.chamfer_3DDist()
    cham_loss = ChamferLoss()
    dist1, dist2, idx1, idx2 = cham_loss(gt, output)
    cd_p = (torch.sqrt(dist1).mean(1) + torch.sqrt(dist2).mean(1)) / 2
    cd_t = dist1.mean(1) + dist2.mean(1)

    res = [cd_p, cd_t]
    if return_raw:
        res.extend([dist1, dist2, idx1, idx2])
    return res

Loaded compiled 3D CUDA chamfer distance


In [27]:
points1 = torch.rand(32, 1000, 3).cuda()
points2 = torch.rand(32, 2000, 3, requires_grad=True).cuda()

chamfer_dist(points1, points2)

[tensor([0.0513, 0.0519, 0.0520, 0.0514, 0.0506, 0.0514, 0.0510, 0.0512, 0.0511,
         0.0520, 0.0515, 0.0508, 0.0509, 0.0510, 0.0509, 0.0509, 0.0519, 0.0518,
         0.0508, 0.0515, 0.0511, 0.0516, 0.0506, 0.0506, 0.0498, 0.0515, 0.0504,
         0.0514, 0.0510, 0.0512, 0.0502, 0.0515], device='cuda:0',
        grad_fn=<DivBackward0>),
 tensor([0.0061, 0.0063, 0.0062, 0.0061, 0.0060, 0.0061, 0.0061, 0.0060, 0.0061,
         0.0063, 0.0062, 0.0060, 0.0060, 0.0060, 0.0060, 0.0061, 0.0062, 0.0062,
         0.0060, 0.0062, 0.0061, 0.0062, 0.0059, 0.0060, 0.0058, 0.0062, 0.0059,
         0.0061, 0.0060, 0.0060, 0.0058, 0.0062], device='cuda:0',
        grad_fn=<AddBackward0>)]

In [17]:
def train(model, trainloader, loss_fn, optimizer):

    total_loss = 0.0
    total_proportion = 0.0
    for (i, d) in (t := tqdm(enumerate(trainloader), total=len(trainloader))):
        # Extract source and target point clouds and batches
        p, gt = d[1].to(device), d[2].to(device)

        # Train step
        optimizer.zero_grad()
        pred_coarse, pred_pcds, _ = model(p)

        # Calculate loss
        _, cd_t = loss_fn(pred_coarse, gt)
        loss = cd_t.mean()
        for pcd in pred_pcds:
            _, cd_t = chamfer_dist(pcd, gt)
            loss += cd_t.mean()

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        t.set_description(f"loss = {loss:.8f}")
        # t.set_description(f"loss = {total_loss / ((i+1) * trainloader.batch_size) :.8f}")
    return total_loss / len(trainloader.dataset)

In [18]:
# model = Model().to(device)
fbmodel = FBModel().to(device)
optimizer = torch.optim.Adam(
    params=model.parameters(), lr=0.001, betas=[0.9, 0.999]
)
loss_fn = chamfer_dist

#Time steps:3


In [19]:
train_losses = []
for epoch in tqdm(range(1, epochs + 1)):
    train_loss = train(model, trainloader, loss_fn, optimizer)
    # train_loss = train_w_refiner(model, trainloader, loss_fn, optimizer, alpha=0.5)
    train_losses.append(train_loss)
    # val_loss = evaluate(model, valloader, loss_fn)
    # history.val_loss.append(val_loss)
    print(f"{epoch=} \t {train_loss=:.6f}")

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/2600 [00:00<?, ?it/s]

epoch=1 	 train_loss=0.141331


  0%|          | 0/2600 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [20]:
input_list = []
gt_list = []
data_dir = Path("data")
for _, _, files in os.walk(data_dir / "partial_input"):
    for file in files:
        prefix = file.replace(".npy", "")
        input_file = data_dir / "partial_input" / file
        gt_file = data_dir / "gt" / file

        inputs = torch.from_numpy(np.load(input_file)).unsqueeze(0).contiguous()
        gt = torch.from_numpy(np.load(gt_file)).unsqueeze(0).contiguous()

        input_list.append(inputs)
        gt_list.append(gt)

In [21]:
input_list[0].shape

torch.Size([1, 2048, 3])

In [22]:
preds = []
with torch.no_grad():
    for X in tqdm(input_list):
        X = X.cuda()
        _, _, fine = model(X)
        preds.append(fine.cpu())

  0%|          | 0/16 [00:00<?, ?it/s]

In [23]:
i = 6
viz_many([input_list[i][0].numpy(), preds[i][0].numpy(), gt_list[i][0].numpy()])