In [None]:
import sys
sys.path.append('../')
import numpy as np
import matplotlib.pyplot as plt
import open3d as o3d
import torch
import seaborn as sns
import json

from scene.artgs import ArtGS
from scene.gaussian_model import GaussianModel

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
def visualize_point_cloud(xyz, rgb=None, save_path=None):
    """Show a point cloud in an Open3D viewer and optionally write it to disk.

    Args:
        xyz: (N, 3) array of point positions.
        rgb: optional (N, 3) array of per-point colors (Open3D expects floats in [0, 1]).
        save_path: if given, the point cloud (including colors when provided) is
            written to this path with ``o3d.io.write_point_cloud``.
    """
    point_cloud = o3d.geometry.PointCloud()
    # Vector3dVector is typed for float64; converting up front avoids its slow
    # conversion path for other dtypes (e.g. the float32 arrays coming from torch).
    point_cloud.points = o3d.utility.Vector3dVector(np.asarray(xyz, dtype=np.float64))
    if rgb is not None:
        point_cloud.colors = o3d.utility.Vector3dVector(np.asarray(rgb, dtype=np.float64))
    o3d.visualization.draw_geometries([point_cloud])
    if save_path is not None:
        o3d.io.write_point_cloud(save_path, point_cloud)

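## Show the Initialized Canonical Gaussians and Slot Centers

Load the coarse-stage Gaussians and overlay each slot center as a colored cluster of jittered points.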

In [None]:
# Choose the trained scene to inspect. Only the configuration left uncommented is
# active; the alternatives are kept commented out for quick switching.
# dataset = 'paris'
# subset = 'sapien'
# scenes = 'foldchair_102255 washer_103776 fridge_10905 blade_103706 storage_45135 oven_101917 stapler_103111 USB_100109 laptop_10211 scissor_11100'.split(' ')
# subset = 'realscan'
# scenes = 'real_fridge real_storage'.split(' ')

# dataset = 'dta'
# subset = 'sapien'
# scenes = 'fridge_10489 storage_47254'.split(' ')

dataset = 'artgs'
subset = 'sapien'
# scenes = 'oven_101908 table_25493 storage_45503 storage_47648 table_31249'.split(' ')
scene = 'table_31249'

# Coarse-stage checkpoint directory; both the Gaussians and the center file are read from here.
ckpt_dir = f'../outputs/{dataset}/{subset}/{scene}/coarse_gs/point_cloud/iteration_10000'
# Each center is drawn as N_MARKER points jittered with Gaussian noise of std MARKER_STD.
N_MARKER, MARKER_STD = 1000, 0.01
rng = np.random.default_rng(0)  # fixed seed so the marker jitter is reproducible

with torch.no_grad():
    path = f'{ckpt_dir}/point_cloud.ply'
    gs = GaussianModel(3)
    gs.load_ply(path)
    # detach() so .numpy() also works if these are leaf tensors with requires_grad=True
    xyz = gs.get_xyz.detach().cpu().numpy()
    color = gs.get_rgb.detach().cpu().numpy()
print(xyz.shape, color.shape)

# The first three columns of center_info are the center positions; one row per slot.
center_info = np.load(f'{ckpt_dir}/center_info.npy')
center = center_info[:, :3]
num_slots = center.shape[0]

# Manually correct the centers if needed: edit the offsets, then uncomment to apply and save.
# center[1] += np.array([0., 0., 0.])
# center[2] += np.array([0., 0., 0.])
# center_info[:, :3] = center
# center_info[:, 3] /= 4
# np.save(f'{ckpt_dir}/center_info.npy', center_info)

# One distinct color per slot; the swatch is also saved to pallete.png for reference.
pallete = np.array(sns.color_palette("hls", num_slots))
plt.imsave('pallete.png', pallete[None])
xyz_center = (center[None] + rng.standard_normal((N_MARKER, num_slots, 3)) * MARKER_STD).reshape(-1, 3)
rgb_center = pallete[None].repeat(N_MARKER, 0).reshape(-1, 3)
print(xyz_center.shape, rgb_center.shape)
xyz_vis = np.concatenate([xyz, xyz_center])
rgb_vis = np.concatenate([color, rgb_center])
visualize_point_cloud(xyz_vis, rgb_vis)

(26571, 3) (26571, 3)
(5000, 3) (5000, 3)


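## Show the Initialized Canonical Gaussians and Slot Centers, Colored by Segmentation

Each Gaussian is colored by the slot with the highest mask score; slot 0 is drawn in black.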

In [None]:
class Args:
    """Lightweight stand-in for the option namespace handed to ``ArtGS(args)``.

    Holds fixed defaults; ``joint_types`` and ``num_slots`` are placeholders
    that get replaced with per-scene values after construction.
    """

    def __init__(self):
        # Slot / segmentation settings.
        self.slot_size = 32
        self.joint_types = 's,r'
        # One slot per comma-separated joint type (count of commas + 1).
        self.num_slots = self.joint_types.count(',') + 1
        self.gumbel = True
        # Remaining options are consumed by ArtGS; values kept as-is.
        self.scale_factor = 1.
        self.use_art_type_prior = True
        self.shift_weight = 0.5
        self.tau_decay_steps = 10000
args = Args()
# Replace the placeholder joint types with this scene's entry from the config;
# the number of slots follows from the comma-separated list.
with open('../arguments/joint_types_cgs.json', 'r') as f:
    joint_types = json.load(f)[dataset][subset][scene]
args.joint_types = joint_types
args.num_slots = len(joint_types.split(','))
num_slots = args.num_slots
pallete = np.array(sns.color_palette("hls", num_slots))
pallete[0] = [0, 0, 0]  # slot 0 is drawn in black

deform = ArtGS(args).cuda()
center, scale = deform.seg_model.init_from_file(
    f'../outputs/{dataset}/{subset}/{scene}/coarse_gs/point_cloud/iteration_10000/center_info.npy')
# NOTE(review): 20000 is passed as the step to update(); the captured output shows it sets the
# HashGrid level — confirm this matches the checkpoint being inspected.
deform.update(20000)
deform.eval()

with torch.no_grad():
    # Hard segmentation: index of the highest-scoring slot for each Gaussian
    # (assumes get_mask returns (N, num_slots) — TODO confirm).
    mask_raw = deform.get_mask(torch.tensor(xyz).cuda()).argmax(1).cpu().numpy()
    # Stable reorder that puts slot-0 points first. New names are used so the RGB
    # cloud (`xyz`, `color`) loaded earlier is not overwritten by this cell.
    m = mask_raw == 0
    order = np.concatenate([np.flatnonzero(m), np.flatnonzero(~m)])
    xyz_seg = xyz[order]
    mask = mask_raw[order]
    seg_color = pallete[mask]

    center = center.detach().cpu().numpy()
    assert center.shape[0] == num_slots, 'number of centers must match the number of slots'
    rng = np.random.default_rng(0)  # fixed seed so the marker jitter is reproducible
    # Draw each center as 1000 points jittered with Gaussian noise of std 0.01.
    xyz_center = (center[None] + rng.standard_normal((1000, num_slots, 3)) * 0.01).reshape(-1, 3)
    rgb_center = pallete[None].repeat(1000, 0).reshape(-1, 3)
    xyz_vis = np.concatenate([xyz_seg, xyz_center])
    rgb_vis = np.concatenate([seg_color, rgb_center])
    visualize_point_cloud(xyz_vis, rgb_vis)

Update current level of HashGrid to 12
