In [None]:
import os
import sys

# Number CUDA devices by PCI bus order instead of CUDA's default ordering, so
# that a device string such as 'cuda:1' refers to a fixed physical GPU.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID")

# Make the parent directory of the notebook's working directory importable
# (the local `core` package is imported from there in the next cell).
module_path = os.path.abspath(os.pardir)
already_importable = module_path in sys.path
if not already_importable:
    sys.path.append(module_path)

In [None]:
import numpy as np
import torch
from skimage import measure
from tqdm import tqdm, trange

from core.dataset import ScanNet
from core.integrate import  FeatureFusionScalableTSDFVolume

In [None]:
# Root of the ScanNet dataset. Set the SCANNET_ROOT environment variable to
# point elsewhere instead of editing this machine-specific default.
dataset = ScanNet(os.environ.get('SCANNET_ROOT', '/home/quanta/Datasets/ScanNet/'))
scan_id = 'scene0000_00'
# Position of the scan inside the dataset. Named `scan_idx` rather than `id` so
# the Python builtin `id()` is not shadowed for the rest of the notebook.
scan_idx = dataset.scan_id_list.index(scan_id)
single_instance = dataset[scan_idx]['scan_dataset']

In [None]:
# Colour-image resolution. The fake features are generated at this size, and
# the colour-fusion call uses it as the feature image's original size.
H = single_instance.color_height
W = single_instance.color_width

In [None]:
# TSDF fusion runs on device index 1 (PCI-bus ordering, set in the first cell).
tsdf_device = 'cuda:1'
# Volume hyper-parameters; sdf_trunc is 5 x voxel_size.
# NOTE(review): units presumably metres, matching ScanNet depth — confirm.
tsdf_config = dict(voxel_size=0.015, sdf_trunc=0.075, margin=0.08)
tsdf_volume = FeatureFusionScalableTSDFVolume(device=tsdf_device, **tsdf_config)


In [None]:
# Experiment output directory; set FIF_SAVE_DIR to override the machine-specific
# default. Later cells build paths with plain concatenation (save_dir + 'tsdf/...'),
# so os.path.join(..., '') guarantees exactly one trailing separator.
DEFAULT_SAVE_DIR = '/home/quanta/Experiments/feature-instance-fusion/scannet_scene0000_00/'
save_dir = os.path.join(os.environ.get('FIF_SAVE_DIR', DEFAULT_SAVE_DIR), '')
tsdf_volume.load(save_dir + 'tsdf/tsdf_vol.pt')

In [None]:
# Allocate a 1024-channel float16 feature field on the voxels of the loaded TSDF.
# `dim` and `dtype` are also read by the fake-feature loop in the next cell.
dim = 1024
dtype = torch.float16
tsdf_volume.reset_feature(dim=dim, include_var=False, dtype=dtype)
# Empty the CUDA cache *after* the reset. Any feature storage that
# reset_feature drops can only be handed back to the driver once it is no longer
# referenced; calling empty_cache beforehand cannot release it.
torch.cuda.empty_cache()

In [None]:
# VRAM benchmark: fuse a random (H, W, dim) feature map into the existing voxels
# for every frame of the scan. The feature values are pure noise; the VRAM notes
# recorded after this cell are the point of the exercise.
frame_step = 1  # frame stride; increase to subsample frames
for idx in trange(0, len(single_instance), frame_step):
    inputs = single_instance.get_torch_tensor(
        idx,
        device=tsdf_device,
        keys={"depth", "depth_intr", "pose", "color_intr"},
    )
    # Generated at colour-image resolution, so color_intr is the matching intrinsic.
    fake_feat = torch.randn(size=(H, W, dim), dtype=dtype, device=tsdf_device)

    tsdf_volume.integrate_feature_with_exsisting_voxel(
        feat_img=fake_feat,
        feat_intr=inputs["color_intr"],
        # Pass the reference resolution explicitly, as the colour-fusion loop does.
        feat_original_h=H,
        feat_original_w=W,
        depth=inputs["depth"],
        depth_intr=inputs["depth_intr"],
        cam_pose=inputs["pose"],
    )

# Drop the last random tensor (H * W * dim elements) so it does not stay
# resident on the GPU as a notebook global after the benchmark finishes.
fake_feat = None

In [None]:
# Observed VRAM usage for the fake-feature integration:
# - dim=512,  float32: 11.303 GB (good)
# - dim=1024, float32: 19.304 GB
# - dim=1024, float16: 10.542 GB

In [None]:
# integrate color
# Re-initialise the per-voxel feature field as 3-channel float32 to fuse RGB.
dim = 3
dtype = torch.float32
tsdf_volume.reset_feature(dim=dim, include_var=False, dtype=dtype)
# Empty the CUDA cache only *after* the reset. Before it, the previous
# (dim=1024) feature field is presumably still held by tsdf_volume, so
# empty_cache could not return that memory to the driver.
torch.cuda.empty_cache()

In [None]:
# Fuse the RGB image of every frame into the 3-channel feature field, using the
# colour intrinsics and the full colour resolution (H, W) as the feature geometry.
frame_step = 1  # frame stride; increase to subsample frames
color_keys = ("depth", "depth_intr", "pose", "color_intr", "color")
for frame_idx in trange(0, len(single_instance), frame_step):
    frame = single_instance.get_torch_tensor(
        frame_idx, device=tsdf_device, keys=set(color_keys)
    )
    tsdf_volume.integrate_feature_with_exsisting_voxel(
        feat_img=frame["color"],
        feat_intr=frame["color_intr"],
        feat_original_h=H,
        feat_original_w=W,
        depth=frame["depth"],
        depth_intr=frame["depth_intr"],
        cam_pose=frame["pose"],
    )

In [None]:
# Persist the fused per-voxel colour features under <save_dir>/color/.
color_dir = save_dir + 'color'
os.makedirs(color_dir, exist_ok=True)
tsdf_volume.save_feats(color_dir + '/color_feats.pt')

In [None]:
# Reload the mesh vertices and triangle indices stored next to the TSDF.
verts, faces = (np.load(f'{save_dir}tsdf/{name}.npy') for name in ('verts', 'faces'))

In [None]:
# Sample the fused colour features at every mesh vertex and keep the first
# element of the result as the per-vertex colour.
# NOTE(review): assumed to be an array-like of shape (num_verts, 3) — TODO confirm
# its type (NumPy vs. torch, CPU vs. GPU) and value range, since it is passed
# to np.save and to Open3D's Vector3dVector in later cells.
color = tsdf_volume.extract_feat_on_grid(verts=verts)[0]

In [None]:
# Store the per-vertex colours next to the saved colour features.
np.save(file=f'{save_dir}color/color.npy', arr=color)

In [None]:
import open3d as o3d

# Wrap the extracted surface as an Open3D triangle mesh and paint it with the
# fused per-vertex colours (Open3D expects colour values in [0, 1]).
as_vec3d = o3d.utility.Vector3dVector
mesh = o3d.geometry.TriangleMesh(
    vertices=as_vec3d(verts),
    triangles=o3d.utility.Vector3iVector(faces),
)
mesh.vertex_colors = as_vec3d(color)

# NOTE(review): o3d.visualization.EV is presumably Open3D's default
# ExternalVisualizer instance (sends geometry to a running viewer) — confirm.
draw = o3d.visualization.EV.draw

In [None]:
# Display the coloured mesh via the `draw` handle (o3d.visualization.EV.draw)
# bound in the previous cell.
draw([mesh])