In [1]:
import os
import sys
sys.path.insert(0, '../')
import torch
import numpy as np
import imageio

from matplotlib import pyplot as plt

try:
    import piplite
    await piplite.install(['ipywidgets'])
except ImportError:
    pass
import ipywidgets as widgets

In [2]:
from engine.trainer import Trainer
from engine.eval import evaluation_path
from engine.get_point_cloud import write_point_cloud,read_point_cloud
from data import dataset_dict
from utils.opt import config_parser


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
path_redirect = [
    # option name, path in the config, redirected path
    ('palette_path', './data_palette', '../data_palette')
]

In [4]:
run_dir = '../logs/chair'
ckpt_path = None
out_dir = os.path.join(run_dir,'demo_out')

In [5]:
#读取数据
def read_data(dataset_type='train'):
    parser = config_parser()
    config_path = os.path.join(run_dir,'args.txt')

    if os.path.exists(config_path):
        with open(config_path,'r') as f:
            args,remainings = parser.parse_known_args(args=[],config_file_contents=f.read())

            if ckpt_path is not None:
                setattr(args,'ckpt',ckpt_path)

            for entry in path_redirect:
                setattr(args,entry[0],getattr(args,entry[0]).replace(entry[1],entry[2]))

            print("Args loaded:", args)
    else:
        print(f"ERROR : cannot read args in {run_dir}.")
    print()

    dataset = dataset_dict[args.dataset_name]
    # train_dataset
    train_dataset = dataset(args.datadir,split='train',downsample=args.downsample_train * 2.,is_stack=True)
    # test_dataset
    test_dataset = dataset(args.datadir,split='test',downsample=args.downsample_test*2., is_stack=True)
    if dataset_type =='train':
        return args,train_dataset
    else:
        return args,test_dataset


In [6]:
from utils.color_decomposition import color_decomposition

def plot_color_decomposition(ds_test_dataset,rgbs,palette_rgb,plot_palette_color_idx=0,dataset=None,fg=None):
    w, h = ds_test_dataset.img_wh
    # palette_rgb = palette_rgb.clip(0.,1.)
    rgbs = torch.tensor(rgbs/255)
    rgbs = torch.reshape(rgbs,(-1,3))
    palette_rgb = torch.tensor(palette_rgb)
    palette_number = palette_rgb.shape[0]
    color_deco = color_decomposition(rgbs,palette_rgb) # bs * h * w
    # print(rgbs)
    print(palette_rgb)
    plot_palette_color =  0
    true_idx = (color_deco != plot_palette_color) # bs * 1

    if dataset is not None and dataset.white_bg and fg is not None:
        all_rgb_cp = dataset.all_rgbs.clone().cpu()
        all_rgb_cp_original = dataset.all_rgbs.clone().cpu()
        fg[fg==True] = true_idx

        all_rgb_cp[fg] = 1.
    else:
        all_rgb_cp = torch.clone(rgbs)
        all_rgb_cp_original = torch.clone(rgbs)
        all_rgb_cp[true_idx] = 1.



    all_rgb_maps = torch.reshape(all_rgb_cp,(-1, h, w, 3))
    # print(all_rgb_maps)
    all_rgb_cp_original = torch.reshape(all_rgb_cp_original,(-1,h,w,3))




    fig,axes = plt.subplots(1,2)
    axes[0].imshow(all_rgb_maps[plot_palette_color_idx].clone().numpy())
    axes[1].imshow(all_rgb_cp_original[plot_palette_color_idx].clone().numpy())

In [7]:
def write_pointcloud(dataset_type='train',):

    #读取数据
    args,dataset = read_data(dataset_type=dataset_type,)

    print("Initializing trainer and model...")
    ckpt_dir = os.path.join(run_dir,"checkpoints")
    tb_dir = os.path.join(run_dir,"tensorboard")

    trainer = Trainer(args,run_dir,ckpt_dir, tb_dir)

    model = trainer.build_network()
    model.eval()
    print()

    #调色板提取
    palette_prior = trainer.palette_prior.detach().cpu().numpy()
    palette = model.renderModule.palette.get_palette_array().detach().cpu().numpy()

    print("==============*****************==================")
    write_point_cloud(dataset, model, args, trainer.renderer, savePath=None, N_vis=5, N_samples=-1, white_bg=False,
               ndc_ray=False, palette=palette, new_palette=None,device='cuda',filename=None)



In [10]:
#写点云
write_pointcloud()

Args loaded: Namespace(L1_weight_inital=8e-05, L1_weight_rest=4e-05, N_vis=5, N_voxel_final=27000000, N_voxel_init=2097156, Ortho_weight=0.0, Plt_bd_weight=1.0, Plt_opaque_conv_weight=0.0, Plt_opaque_sps_weight=0.001, TV_weight_app=0.0, TV_weight_density=0.0, alpha_mask_thre=0.0001, basedir='./logs', batch_size=4096, ckpt=None, config='configs/chair.txt', data_dim_color=27, datadir='/home/ubuntu/Rencq/nerf_data/nerf_synthetic/chair', dataset_name='blender', density_shift=-10.0, distance_scale=25.0, downsample_test=1.0, downsample_train=1.0, expname='chair', export_mesh=0, fea2denseAct='softplus', fea_pe=2, featureC=128, learn_palette=True, lindisp=False, lr_basis=0.001, lr_decay_iters=-1, lr_decay_target_ratio=0.1, lr_init=0.02, lr_upsample_reset=1, model_name='PaletteTensorVM', nSamples=1000000, n_iters=30000, n_lamb_sh=[48, 48, 48], n_lamb_sigma=[16, 16, 16], ndc_ray=0, no_reload=0, palette_init='userinput', palette_path='../data_palette/chair/rgb_palette.npy', perturb=1.0, pos_pe=6,

Loading data train (100): 100%|██████████| 100/100 [00:03<00:00, 32.85it/s]
Loading data test (200): 100%|██████████| 200/200 [00:06<00:00, 33.10it/s]


Initializing trainer and model...


Loading data train (100): 100%|██████████| 100/100 [00:01<00:00, 57.25it/s]
Loading data test (200): 100%|██████████| 200/200 [00:03<00:00, 56.62it/s]


[trainer init] aabb [[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]]
[trainer init] num of render samples 443
[trainer init] palette shape torch.Size([4, 3])
[update_stepSize] aabb tensor([-0.7441, -0.7205, -1.0276,  0.6732,  0.7441,  1.0748], device='cuda:0')
[update_stepSize] grid size [260, 268, 385]
[update_stepSize] sampling step size:  tensor(0.0027, device='cuda:0')
[update_stepSize] sampling number:  1070
[init_render_func] shadingMode=PLT_AlphaBlend pos_pe=6 view_pe=2 fea_pe=2 learn_palette=True palette_init=userinput
[TensorBase init] renderModule: PLTRender(
  (palette): FreeformPalette()
  (mlp): Sequential(
    (0): Linear(in_features=150, out_features=128, bias=True)
    (1): LeakyReLU(negative_slope=0.01, inplace=True)
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): LeakyReLU(negative_slope=0.01, inplace=True)
    (4): Linear(in_features=128, out_features=3, bias=True)
  )
)
[TensorBase init] render buffer layout: [RenderBufferProp(name='rgb', len=3, detach_w

  0%|          | 0/5 [00:00<?, ?it/s]



AttributeError: 'numpy.ndarray' object has no attribute 'to'