In [None]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np

import sys
sys.path.append('../')
from rkhs_splatting.utils.camera_utils import parse_camera
from rkhs_splatting.rkhs_model import RKHSModel
from rkhs_splatting.rkhs_render import RKHSRenderer
from rkhs_splatting.rkhs_model_global_scale import RKHSModelGlobalScale
from rkhs_splatting.rkhs_render_global_scale import RKHSRendererGlobalScale
from rkhs_splatting.utils.dataloader import TartanAirLoader

import datetime
import pathlib
from icecream import ic
from spatialmath import SE3
import plotly.graph_objects as go

from plotly_utils import *
from test_utils import *
from experiments.points_trainer import GSSTrainer

# Make GPU `cuda_device_index` the current CUDA device. Guarded so the notebook
# still starts on machines exposing fewer GPUs than that (set_device would raise).
cuda_device_index = 1
if torch.cuda.device_count() > cuda_device_index:
    torch.cuda.set_device(cuda_device_index)
else:
    print(f'CUDA device {cuda_device_index} unavailable '
          f'({torch.cuda.device_count()} visible); keeping the default device')

In [None]:
"""
global scale
"""

#config
input_source = 'tartanair' # tartanair, sample, manual, random
scale = 0.1
input_initial_scaling = scale
map_initial_scaling = scale
map_minimum_scaling = scale
radii_multiplier = 6
scale_trainable = False
tile_size = 64

# input
train_pcs = []
cameras = []
if input_source=='sample':
    train_pcs, cameras = load_sample_dataset('../data/B075X65R3X', [0,1], 0.5)
elif input_source=='tartanair':
    dataset = TartanAirLoader('/home/junzhewu/dataset/rzh/tartanair/soulcity/Easy/P001')
    train_pcs, cameras = load_custom_dataset(dataset, [0,1,30], 0.5)
    tile_size = 80
elif input_source=='manual':
    train_pcs = [create_pc(np.array([
        [1,1,0, 1,0,0,1],
        [-1,-1,0, 1,0,0,1],
    ]))]
elif input_source=='random':
    train_pcs = [create_random_pc(n=4000, mu=[0,0,0], sigma=5, rgba=np.array([1,0,0,1]), shape='line')]

# initial map
# init_map = create_pc(np.array([
#     [1.5,1.5,0, 1,0,0,1],
#     [-1.5,-1.5,0, 1,0,0,1],
#     [0,0,0, 1,0,0,1],
#     [0,0.3,0, 1,0,0,1],
# ]))
# init_map = create_random_pc(n=2**10, mu=0, sigma=1, rgba=np.array([1,0,0,1]))
init_map = train_pcs[0].random_sample(2**14)
# init_map = train_pcs[0].generate_random_noise(2**12)
# init_map = train_pcs[0].generate_random_color(2**10)

# cameras
if len(cameras)==0:
    camera_intrinsic = [355, 355, 128, 128] # fx, fy, cx, cy
    n_cameras = 1
    delta_deg = 30.0
    camera_c2w_init = SE3.Tz(-10)
    for i in range(n_cameras):
        camera_c2w = SE3.Ry(delta_deg*i / 180 * np.pi)@camera_c2w_init
        camera_data = create_camera(*camera_intrinsic, camera_c2w)
        cameras.append(to_viewpoint_camera(camera_data))

# render
map_model = RKHSModelGlobalScale(sh_degree=4, debug=False, trainable=True, scale_trainable=scale_trainable)
map_model.create_from_pcd(init_map, initial_scaling=map_initial_scaling)
input_model = RKHSModelGlobalScale(sh_degree=4, debug=False, trainable=False)
renderer = RKHSRendererGlobalScale(white_bkgd=True)

input_frames = []
for i, camera in enumerate(cameras):
    # single frame or multiple frames
    if len(train_pcs)==1:
        if i==0:
            input_model.create_from_pcd(train_pcs[0], initial_scaling=input_initial_scaling)
    else:
        input_model.create_from_pcd(train_pcs[i], initial_scaling=input_initial_scaling)
    # render
    input_frame = renderer(
        camera,
        input_model.get_xyz,
        input_model.get_opacity,
        input_model.get_scaling,
        input_model.get_features,
        radii_multiplier=radii_multiplier,
        tile_size=tile_size
    )
    input_frames.append(input_frame)


# plot
fig = go.Figure()
fig.update_layout(scene=dict(aspectmode='data'))
for i in range(len(cameras)):
    camera_c2w = cameras[i].c2w.cpu().detach().numpy()
    plot_camera(fig, camera_c2w[:3,:3], camera_c2w[:3,3], 3, f'camera {i}', True)

    train_pc = train_pcs[i].random_sample(2**12)
    plot_pc(fig, train_pc, 'train_pc', marker_line_width=0, marker_size=2)
    map_pc = init_map
    plot_pc(fig, map_pc, 'map_pc', marker_line_width=0, marker_size=2)
fig.show()

# show_image(dataset.read_current_rgbd()[0])
# show_image(dataset.read_current_rgbd()[1].clip(0,100))

In [None]:
folder_name = datetime.datetime.now().strftime("%Y-%m-%d__%H-%M-%S")
folder_name = 'test'
results_folder = pathlib.Path('../result/'+folder_name)
results_folder.mkdir(parents=True, exist_ok=True)

n_train = 1000

trainer = GSSTrainer(
    model=map_model,
    input_model=input_model,
    renderer=renderer,
    # data=data,
    use_input_frames=True,
    input_frames=input_frames,
    train_batch_size=1, 
    train_num_steps=n_train,
    i_image=n_train//40,
    train_lr=1e-2,#3e-3
    amp=True,
    fp16=False,
    results_folder=results_folder,
    use_rkhs_rgb=True,
    use_rkhs_geo=True,
    min_scale=map_minimum_scaling,
    radii_multiplier=radii_multiplier,
    tile_size=tile_size,
    writer=False
)

trainer.on_evaluate_step()
trainer.train()

# ic(input_model.get_xyz, input_model.get_features)
# ic(map_model.get_xyz, map_model.get_features)
# ic(init_map.coords, init_map.select_channels(['R', 'G', 'B', 'A']))

In [None]:
# Overlay a 2**14-point subsample of the first training cloud with the trained
# RKHS map so the coverage of the input geometry can be checked visually.
fig = go.Figure()
fig.update_layout(scene=dict(aspectmode='cube'))

train_pc = train_pcs[0].random_sample(2**14)
for trace_cloud, trace_label, trace_edge_color in (
    (train_pc, 'train_pc', 'black'),
    (map_model.to_pc(), 'rkhs_map', 'yellow'),
):
    plot_pc(fig, trace_cloud, trace_label,
            marker_line_width=0, marker_line_color=trace_edge_color, marker_size=1)

# To squash the view along y and z, uncomment:
# fig.update_scenes(aspectratio=dict(x=1, y=0.001, z=0.001))
fig.show()