In [1]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np

import sys
sys.path.append('../')
from rkhs_splatting.utils.data_utils import read_all
from rkhs_splatting.utils.camera_utils import parse_camera
from rkhs_splatting.rkhs_model import RKHSModel
from rkhs_splatting.rkhs_render import RKHSRenderer
from rkhs_splatting.rkhs_model_global_scale import RKHSModelGlobalScale
from rkhs_splatting.rkhs_render_global_scale import RKHSRendererGlobalScale
from rkhs_splatting.utils.point_utils import get_point_clouds

import datetime
import pathlib
from icecream import ic
from spatialmath import SE3
import plotly.graph_objects as go

from plotly_utils import *
from test_utils import *
from experiments.points_trainer import GSSTrainer

In [55]:
"""
global scale, single frame, single point
"""

# Scene: one target point and a two-point initial map.
# Each row is one point laid out as position (x, y, z) followed by r, g, b, a.
train_points = np.array([[0, 0, 0, 1, 0, 0, 1]])
init_points = np.array([
    [3, 1, 0, 0.5, 0.5, 0, 1],
    [-2, 2, 0, 0.5, 0.5, 0, 1],
])
train_pc = create_pc(train_points)
init_map = create_pc(init_points)

# Camera pose: pure translation of -10 along z (SE3.Rx(-0.5) was an alternative pose).
c2w = SE3.Tz(-10)
# NOTE(review): arguments presumably are width, height, fx, fy, pose — confirm in test_utils.
camera_data = create_camera(256, 256, 355, 355, c2w)
camera = to_viewpoint_camera(camera_data)

# Both models share every constructor setting except `trainable`:
# the map is the one being optimised, the input model only provides the target frame.
shared_model_kwargs = dict(sh_degree=4, debug=False)
map_model = RKHSModelGlobalScale(trainable=True, **shared_model_kwargs)
map_model.create_from_pcd(init_map, initial_scaling=0.5)
input_model = RKHSModelGlobalScale(trainable=False, **shared_model_kwargs)
renderer = RKHSRendererGlobalScale(white_bkgd=True)
input_model.create_from_pcd(train_pc, initial_scaling=0.5)

# Render the target frame from the (fixed) input model.
input_splat_args = (
    input_model.get_xyz,
    input_model.get_opacity,
    input_model.get_scaling,
    input_model.get_features,
)
input_frame = renderer(camera, *input_splat_args)

# Show the camera together with both point clouds.
fig = go.Figure()
plot_camera(fig, c2w.R, c2w.t, 3, 'camera0', True)
for cloud, trace_name in ((train_pc, 'train_pc'), (init_map, 'init_map')):
    plot_pc(fig, cloud, trace_name)
fig.show()

In [4]:
"""
global scale, single frame, multiple points
"""

# Seed the global NumPy and PyTorch RNGs so the random clouds below, and any
# later stochastic step in this kernel, repeat across restarts.
# NOTE(review): assumes create_random_pc draws from one of these global RNGs —
# confirm in test_utils; if it builds its own generator, pass the seed there instead.
RANDOM_SEED = 0
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# pos, rgba
# Ten random target points and ten random initial map points; both use the
# same rgba, and the initial map is drawn with a larger sigma (5 vs 2).
train_pc = create_random_pc(n=10, mu=0, sigma=2, rgba=np.array([1,0,0,1]))
init_map = create_random_pc(n=10, mu=0, sigma=5, rgba=np.array([1,0,0,1]))

# Camera pose: pure translation of -10 along z (SE3.Rx(-0.5) was an alternative pose).
c2w = SE3.Tz(-10)
# NOTE(review): arguments presumably are width, height, fx, fy, pose — confirm in test_utils.
camera_data = create_camera(256, 256, 355, 355, c2w)
camera = to_viewpoint_camera(camera_data)

# Trainable map model, initialised from the random initial map.
map_model = RKHSModelGlobalScale(sh_degree=4, debug=False, trainable=True)
map_model.create_from_pcd(init_map, initial_scaling=0.5)

# Non-trainable input model; it is only used to render the target frame.
input_model = RKHSModelGlobalScale(sh_degree=4, debug=False, trainable=False)
renderer = RKHSRendererGlobalScale(white_bkgd=True)
input_model.create_from_pcd(train_pc, initial_scaling=0.5)
input_frame = renderer(
    camera,
    input_model.get_xyz,
    input_model.get_opacity,
    input_model.get_scaling,
    input_model.get_features
)

# Show the camera together with both point clouds.
fig = go.Figure()
plot_camera(fig, c2w.R, c2w.t, 3, 'camera0', True)
plot_pc(fig, train_pc, 'train_pc')
plot_pc(fig, init_map, 'init_map')
fig.show()

In [9]:

# This cell trains `map_model` against `input_frame`. It reads the globals
# (map_model, input_model, renderer, input_frame, init_map) left behind by
# whichever scene-setup cell was executed most recently.

# Output folder for this run. The fixed name means every run reuses
# ../result/test; set the flag to True for a fresh timestamped folder per run.
use_timestamp_folder = False
if use_timestamp_folder:
    folder_name = datetime.datetime.now().strftime("%Y-%m-%d__%H-%M-%S")
else:
    folder_name = 'test'
results_folder = pathlib.Path('../result') / folder_name
results_folder.mkdir(parents=True, exist_ok=True)

# Number of optimisation steps; also sets the image interval below (every 10% of the run).
n_train = 1000

trainer = GSSTrainer(
    model=map_model,
    input_model=input_model,
    renderer=renderer,
    # data=data,
    use_input_frames=True,
    input_frames=[input_frame],
    train_batch_size=1, 
    train_num_steps=n_train,
    i_image=n_train//10,
    train_lr=1e-2,#3e-3
    amp=True,
    fp16=False,
    results_folder=results_folder,
    use_rkhs_rgb=False,
    use_rkhs_geo=True
)

# Run one evaluation step before any training, then train.
# NOTE(review): presumably this records the untrained state for comparison — confirm in GSSTrainer.
trainer.on_evaluate_step()
trainer.train()

# Dump the target model, the trained map and the initial map for comparison.
ic(input_model.get_xyz, input_model.get_features)
ic(map_model.get_xyz, map_model.get_features)
ic(init_map.coords, init_map.select_channels(['R', 'G', 'B', 'A']))

  0%|          | 0/1000 [00:00<?, ?it/s]

ic| scale3d: Parameter containing:
             tensor(0.5000, device='cuda:0', requires_grad=True)
ic| scale3d: Parameter containing:
             tensor(0.5000, device='cuda:0', requires_grad=True)
ic| scale3d: Parameter containing:
             tensor(0.5000, device='cuda:0', requires_grad=True)
ic| scale3d: Parameter containing:
             tensor(0.5000, device='cuda:0', requires_grad=True)
ic| scale3d: Parameter containing:
             tensor(0.5000, device='cuda:0', requires_grad=True)
ic| scale3d: Parameter containing:
             tensor(0.5000, device='cuda:0', requires_grad=True)
ic| scale3d: Parameter containing:
             tensor(0.5000, device='cuda:0', requires_grad=True)
ic| scale3d: Parameter containing:
             tensor(0.5000, device='cuda:0', requires_grad=True)
ic| scale3d: Parameter containing:
             tensor(0.5000, device='cuda:0', requires_grad=True)
ic| scale3d: Parameter containing:
             tensor(0.5000, device='cuda:0', requires_grad=True)


KeyboardInterrupt: 