In [None]:
%load_ext autoreload
%autoreload 2

import torch
import numpy as np

import sys
sys.path.append('../')
from rkhs_splatting.utils.data_utils import read_all
from rkhs_splatting.utils.camera_utils import parse_camera
from rkhs_splatting.rkhs_model import RKHSModel
from rkhs_splatting.rkhs_render import RKHSRenderer
from rkhs_splatting.rkhs_model_global_scale import RKHSModelGlobalScale
from rkhs_splatting.rkhs_render_global_scale import RKHSRendererGlobalScale
from rkhs_splatting.utils.point_utils import get_point_clouds

import datetime
import pathlib
from icecream import ic
from spatialmath import SE3
import plotly.graph_objects as go

from plotly_utils import *
from test_utils import *
from experiments.points_trainer import GSSTrainer

In [None]:
# Train an RKHS global-scale map model against a single view of one object.

# Select the second GPU; the bare 'cuda' device string used below then refers
# to the current CUDA device set here.
device = 'cuda'
torch.cuda.set_device(1)
folder = '../data/B075X65R3X'
data = read_all(folder, resize_factor=0.5)

# Use only one training image. Slice BEFORE moving to the GPU: slicing a CUDA
# tensor returns a view that shares the original storage, so moving everything
# first and slicing afterwards kept the whole dataset resident on the GPU.
# NOTE(review): assumes every value returned by read_all is a tensor whose
# first axis indexes images - confirm against read_all.
data = {key: value[0:1].to(device) for key, value in data.items()}

# ic(data['camera'].shape)
# ic(data['depth'].shape)
# ic(data['alpha'].shape)


points = get_point_clouds(data['camera'], data['depth'], data['alpha'], data['rgb'])
# Initialisation alternatives (toggle by (un)commenting):
# raw_points = points.random_sample(2**14)
raw_points = points.generate_random_noise(2**14)
# raw_points = points.generate_random_color(2**14)
# raw_points.write_ply(open('points.ply', 'wb'))

# The trainable model is optimised; the non-trainable one is handed to the
# trainer as `input_model` (not used for frames here: use_input_frames=False).
map_model = RKHSModelGlobalScale(sh_degree=4, debug=False, trainable=True)
map_model.create_from_pcd(pcd=raw_points, initial_scaling=0.01)
input_model = RKHSModelGlobalScale(sh_degree=4, debug=False, trainable=False)
renderer = RKHSRendererGlobalScale(white_bkgd=True)

# folder_name = datetime.datetime.now().strftime("%Y-%m-%d__%H-%M-%S")
folder_name = 'test_sofa'
results_folder = pathlib.Path('../result') / folder_name
results_folder.mkdir(parents=True, exist_ok=True)


n_train = 10000

trainer = GSSTrainer(
    model=map_model,
    input_model=input_model,
    renderer=renderer,
    data=data,
    use_input_frames=False,
    train_batch_size=1,
    train_num_steps=n_train,
    i_image=n_train // 400,  # image-related interval, scaled with run length
    train_lr=1e-4,  # previously tried: 3e-3
    amp=True,
    fp16=False,
    results_folder=results_folder,
    use_rkhs_rgb=True,
    use_rkhs_geo=True,
    min_scale=0.005,
    radii_multiplier=3,
    writer=False,
)

# Evaluate once before training to get a baseline, then train.
trainer.on_evaluate_step()
trainer.train()

# Optional GPU-memory profiling of the renderer (requires LineProfiler):
# open('../result/rkhs_global_scale_loss.txt', 'w').write('')
# with LineProfiler(trainer.gauss_render.render) as lp:
#     try:
#         trainer.train()
#     except torch.cuda.OutOfMemoryError as e:
#         print(e)
#         print('done')
# lp.print_stats(
#     stream=open('../result/rkhs_global_scale_loss.txt', 'a'),
#     columns=('active_bytes.all.allocated', 'active_bytes.all.freed','active_bytes.all.current','active_bytes.all.peak', 'reserved_bytes.all.current')
# )