In [None]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np

import sys
sys.path.append('../')
from rkhs_splatting.utils.data_utils import read_all
from rkhs_splatting.utils.camera_utils import parse_camera
from rkhs_splatting.rkhs_model import RKHSModel
from rkhs_splatting.rkhs_render import RKHSRenderer
from rkhs_splatting.rkhs_model_global_scale import RKHSModelGlobalScale
from rkhs_splatting.rkhs_render_global_scale import RKHSRendererGlobalScale
from rkhs_splatting.utils.point_utils import get_point_clouds

import datetime
import pathlib
from icecream import ic
from spatialmath import SE3
import plotly.graph_objects as go

from plotly_utils import *
from test_utils import *
from experiments.points_trainer import GSSTrainer

In [None]:
"""
global scale
"""

# Seed the RNGs so the randomly generated point clouds (and any torch-side
# randomness) are reproducible under Restart Kernel -> Run All.
# NOTE(review): assumes create_random_pc draws from numpy's global RNG — confirm.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# map: manual input (alternative to the random clouds below; uncomment to use)
# train_pc = create_pc(np.array([
#     [1,1,0, 1,0,0,1],
#     [-1,-1,0, 1,0,0,1],
# ]))
# init_map = create_pc(np.array([
#     [1.5,1.5,0, 1,0,0,1],
#     [-1.5,-1.5,0, 1,0,0,1],
#     [0,0,0, 1,0,0,1],
#     [0,0.3,0, 1,0,0,1],
# ]))

# map: random points
# train_pc is the target cloud rendered into the input frames; init_map is the
# starting point of the trainable map.
train_pc = create_random_pc(n=100, mu=[0,0,0], sigma=5, rgba=np.array([1,0,0,1]), shape='line')
init_map = create_random_pc(n=25, mu=0, sigma=2, rgba=np.array([1,0,0,1]))

# Shared by the renderer calls here and by the trainer in the next cell.
radii_multiplier = 5

# Trainable map model (its scale is held fixed) initialised from init_map.
map_model = RKHSModelGlobalScale(sh_degree=4, debug=False, trainable=True, scale_trainable=False)
map_model.create_from_pcd(init_map, initial_scaling=0.5)

# Non-trainable model built from train_pc; only used to render the input frames.
input_model = RKHSModelGlobalScale(sh_degree=4, debug=False, trainable=False)
renderer = RKHSRendererGlobalScale(white_bkgd=True)
input_model.create_from_pcd(train_pc, initial_scaling=0.1)

fig = go.Figure()
input_frames = []

camera_intrinsic = [256, 256, 355, 355] # cx, cy, fx, fy
n_cameras = 1
delta_deg = 30.0  # rotation about the y axis between consecutive cameras, in degrees
camera_c2w_init = SE3.Tz(-10)
for i in range(n_cameras):
    # Camera i is the initial pose rotated by i * delta_deg about y.
    yaw_rad = delta_deg*i / 180 * np.pi
    camera_c2w = SE3.Ry(yaw_rad)@camera_c2w_init
    camera_data = create_camera(*camera_intrinsic, camera_c2w)
    camera = to_viewpoint_camera(camera_data)
    # Render the input model from this camera; the result is stored as a
    # training input frame.
    input_frame = renderer(
        camera,
        input_model.get_xyz,
        input_model.get_opacity,
        input_model.get_scaling,
        input_model.get_features,
        radii_multiplier=radii_multiplier
    )
    input_frames.append(input_frame)
    plot_camera(fig, camera_c2w.R, camera_c2w.t, 3, f'camera {i}', True)

plot_pc(fig, train_pc, 'train_pc', marker_line_width=2)
plot_pc(fig, init_map, 'init_map')
fig.show()

In [None]:
# Output directory for this run. The original code computed a timestamped name
# and immediately overwrote it with 'test'; the choice is now an explicit
# switch. Set USE_TIMESTAMP_FOLDER = True to give each run its own folder.
USE_TIMESTAMP_FOLDER = False
if USE_TIMESTAMP_FOLDER:
    folder_name = datetime.datetime.now().strftime("%Y-%m-%d__%H-%M-%S")
else:
    folder_name = 'test'
results_folder = pathlib.Path('../result') / folder_name
results_folder.mkdir(parents=True, exist_ok=True)

n_train = 300

trainer = GSSTrainer(
    model=map_model,
    input_model=input_model,
    renderer=renderer,
    # data=data,
    use_input_frames=True,
    input_frames=input_frames,
    train_batch_size=1, 
    train_num_steps=n_train,
    # Clamp to >= 1 so a small n_train (< 40) cannot produce 0.
    # NOTE(review): presumably an interval in training steps — confirm in GSSTrainer.
    i_image=max(1, n_train//40),
    train_lr=1e-2,#3e-3
    amp=True,
    fp16=False,
    results_folder=results_folder,
    use_rkhs_rgb=True,
    use_rkhs_geo=True,
    min_scale=0.1,
    radii_multiplier=radii_multiplier,
    writer=False
)

# Run one evaluation step before training starts, then train.
trainer.on_evaluate_step()
trainer.train()

# ic(input_model.get_xyz, input_model.get_features)
# ic(map_model.get_xyz, map_model.get_features)
# ic(init_map.coords, init_map.select_channels(['R', 'G', 'B', 'A']))

In [None]:
# Overlay the training cloud and the current map in a fresh figure.
fig = go.Figure()

outline_kwargs = dict(marker_line_width=2)
plot_pc(fig, train_pc, 'train_pc', marker_line_color='black', **outline_kwargs)
plot_pc(fig, map_model.to_pc(), 'rkhs_map', marker_line_color='yellow', **outline_kwargs)

# Manual scene aspect ratio: x keeps full extent while y and z are collapsed.
# NOTE(review): this is NOT an equal-scale view; if equal scaling of x, y, z
# is what is wanted, use fig.update_scenes(aspectmode='data') instead — confirm intent.
fig.update_scenes(aspectratio={'x': 1, 'y': 0.001, 'z': 0.001})
fig.show()