# Test RKHS Loss

In [1]:
import torch
import numpy as np
import rkhs_splatting.utils as utils
from rkhs_splatting.trainer import Trainer
import rkhs_splatting.utils.loss_utils as loss_utils
from rkhs_splatting.utils.data_utils import read_all
from rkhs_splatting.utils.camera_utils import to_viewpoint_camera
from rkhs_splatting.utils.point_utils import get_point_clouds, get_point_clouds_tiles
from rkhs_splatting.gauss_model import GaussModelGlobalScale
from rkhs_splatting.gauss_render import GaussRendererGlobalScale
import datetime
import pathlib
from icecream import ic

import contextlib

from pytorch_memlab import LineProfiler
from torch.profiler import profile, ProfilerActivity
from torch.utils.tensorboard import SummaryWriter

USE_GPU_PYTORCH = True
USE_PROFILE = False

class GSSTrainer(Trainer):
    """Trainer that optimises a global-scale Gaussian model with the RKHS loss.

    In addition to whatever the base ``Trainer`` sets up, the instance holds
    the training data dict, a ``GaussRendererGlobalScale`` renderer, an
    ``eval.csv`` log file and a TensorBoard writer. Both logs are written to
    ``self.results_folder`` (presumably set by ``Trainer.__init__`` from the
    ``results_folder`` keyword argument -- confirm against the base class).
    """

    def __init__(self, **kwargs):
        """Create the renderer and the log sinks.

        All keyword arguments are forwarded to ``Trainer.__init__``. Two of
        them are additionally read here:

        * ``data``: dict of training tensors; ``on_train_step`` and
          ``on_evaluate_step`` index its ``'camera'``, ``'rgb'``, ``'depth'``
          and ``'alpha'`` entries.
        * ``render_kwargs`` (optional): keyword arguments for
          ``GaussRendererGlobalScale``.
        """
        super().__init__(**kwargs)
        self.data = kwargs.get('data')
        self.gaussRender = GaussRendererGlobalScale(**kwargs.get('render_kwargs', {}))
        # Weights of the photometric objective; only referenced by the
        # commented-out alternative loss in on_train_step.
        self.lambda_dssim = 0.2
        self.lambda_depth = 0.0
        # (Re)create eval.csv. The header must list exactly the six values
        # that on_train_step appends per row; an extra 'loss' column here
        # previously shifted every label after 'iter' by one.
        with open(self.results_folder / 'eval.csv', 'w') as f:
            f.write('iter,total,l1,ssim,depth,psnr\n')
        self.tensorboard_writer = SummaryWriter(log_dir=self.results_folder)

    def on_train_step(self):
        """Run one training iteration on a fixed view and log its metrics.

        Returns:
            tuple: ``(total_loss, log_dict)`` where ``total_loss`` is the
            value to back-propagate and ``log_dict`` maps ``'total'``,
            ``'l1'``, ``'ssim'``, ``'depth'`` and ``'psnr'`` to their values.
        """
        # ind = np.random.choice(len(self.data['camera']))
        ind = 0  # always train on the first view
        camera = self.data['camera'][ind]
        rgb = self.data['rgb'][ind]
        depth = self.data['depth'][ind]
        alpha = self.data['alpha'][ind]
        mask = alpha > 0.5
        if USE_GPU_PYTORCH:
            camera = to_viewpoint_camera(camera)

        if USE_PROFILE:
            prof = profile(activities=[ProfilerActivity.CUDA], with_stack=True)
        else:
            prof = contextlib.nullcontext()

        with prof:
            # Cap the model's global scale at 0.025 before rendering.
            max_scaling = torch.scalar_tensor(0.025, device="cuda")
            if self.model.get_scaling > max_scaling:
                self.model.set_scaling(max_scaling)
            out = self.gaussRender(pc=self.model, camera=camera)

        if USE_PROFILE:
            print(prof.key_averages(group_by_stack_n=True).table(sort_by='self_cuda_time_total', row_limit=20))

        # Image-space metrics: computed for logging only, they do not enter
        # the optimised loss below.
        l1_loss = loss_utils.l1_loss(out['render'], rgb)
        depth_loss = loss_utils.l1_loss(out['depth'][..., 0][mask], depth[mask])
        ssim_loss = 1.0-loss_utils.ssim(out['render'], rgb)

        # Ground-truth point clouds split into tiles, built from the raw
        # (unconverted) camera tensor with a leading batch dimension added.
        points = get_point_clouds_tiles(self.data['camera'][ind].unsqueeze(0), depth.unsqueeze(0), alpha.unsqueeze(0), rgb.unsqueeze(0))

        rkhs_loss = loss_utils.rkhs_global_scale_loss(out['tiles'], points, rgb, self.model.get_scaling)

        # total_loss = (1-self.lambda_dssim) * l1_loss + self.lambda_dssim * ssim_loss + depth_loss * self.lambda_depth
        # Presumably ||map||^2 + ||frame||^2 - 2 <map, frame>; confirm the
        # order of the three terms against rkhs_global_scale_loss.
        total_loss = rkhs_loss[0] + rkhs_loss[1] - 2*rkhs_loss[2]
        psnr = utils.img2psnr(out['render'], rgb)
        log_dict = {'total': total_loss,'l1':l1_loss, 'ssim': ssim_loss, 'depth': depth_loss, 'psnr': psnr}

        with open(self.results_folder / 'eval.csv', 'a') as f:
            f.write(f'{self.step},{total_loss},{l1_loss},{ssim_loss},{depth_loss},{psnr}\n')

        self.tensorboard_writer.add_scalar('loss/total', total_loss, self.step)
        self.tensorboard_writer.add_scalar('loss/rkhs_loss0', rkhs_loss[0], self.step)
        self.tensorboard_writer.add_scalar('loss/rkhs_loss1', rkhs_loss[1], self.step)
        self.tensorboard_writer.add_scalar('loss/rkhs_loss2', rkhs_loss[2], self.step)
        self.tensorboard_writer.add_scalar('loss/l1', l1_loss, self.step)
        self.tensorboard_writer.add_scalar('loss/ssim', ssim_loss, self.step)
        self.tensorboard_writer.add_scalar('loss/depth', depth_loss, self.step)
        self.tensorboard_writer.add_scalar('params/scaling', self.model.get_scaling, self.step)

        self.tensorboard_writer.add_scalar('psnr', psnr, self.step)

        return total_loss, log_dict

    def on_evaluate_step(self, **kwargs):
        """Render a randomly chosen view and save a comparison image.

        The saved ``image-{step}.png`` stacks two rows: the reference RGB next
        to the rendered RGB, and below them the reference depth next to the
        rendered depth, both colour-mapped with matplotlib's ``'jet'`` map.
        ``kwargs`` is accepted but not used.
        """
        import matplotlib.pyplot as plt
        ind = np.random.choice(len(self.data['camera']))
        camera = self.data['camera'][ind]
        if USE_GPU_PYTORCH:
            camera = to_viewpoint_camera(camera)

        rgb = self.data['rgb'][ind].detach().cpu().numpy()
        # NOTE(review): the render runs with autograd enabled although every
        # output is detached right after; wrapping it in torch.no_grad() may
        # save memory if the renderer supports it -- confirm.
        out = self.gaussRender(pc=self.model, camera=camera)
        rgb_pd = out['render'].detach().cpu().numpy()
        depth_pd = out['depth'].detach().cpu().numpy()[..., 0]
        depth = self.data['depth'][ind].detach().cpu().numpy()
        # Reference and predicted depth side by side, normalised by the joint
        # maximum and inverted before colour-mapping; the alpha channel of the
        # colour map is dropped.
        depth = np.concatenate([depth, depth_pd], axis=1)
        depth = (1 - depth / depth.max())
        depth = plt.get_cmap('jet')(depth)[..., :3]
        image = np.concatenate([rgb, rgb_pd], axis=1)
        image = np.concatenate([image, depth], axis=0)
        utils.imwrite(str(self.results_folder / f'image-{self.step}.png'), image)


# ---------------------------------------------------------------------------
# Experiment setup: load the dataset, seed the Gaussian model from a sampled
# point cloud, build the trainer and run one evaluation + one training step.
# ---------------------------------------------------------------------------
device = 'cuda'
folder = './data/B075X65R3X'

# Read the dataset (resize_factor=0.5) and move every entry onto `device`.
data = {
    name: value.to(device)
    for name, value in read_all(folder, resize_factor=0.5).items()
}

# One [1, 3] row per entry of data['rgb'].
num_views = len(data['rgb'])
data['depth_range'] = torch.Tensor([[1, 3]] * num_views).to(device)

# Point cloud built from all views, then randomly sub-sampled to 2**14 points
# that initialise the Gaussian model.
points = get_point_clouds(
    data['camera'],
    data['depth'],
    data['alpha'],
    data['rgb'],
)
raw_points = points.random_sample(2**14)
# raw_points.write_ply(open('points.ply', 'wb'))

gaussModel = GaussModelGlobalScale(sh_degree=4, debug=False)
gaussModel.create_from_pcd(pcd=raw_points)

render_kwargs = dict(white_bkgd=True)

# Fixed output directory; the timestamped alternative is kept for reference.
# folder_name = datetime.datetime.now().strftime("%Y-%m-%d__%H-%M-%S")
folder_name = 'test'
results_folder = pathlib.Path('result') / folder_name
results_folder.mkdir(parents=True, exist_ok=True)

trainer_config = dict(
    model=gaussModel,
    data=data,
    train_batch_size=1,
    train_num_steps=25000,
    i_image=100,
    train_lr=1e-3,
    amp=False,
    fp16=False,
    results_folder=results_folder,
    render_kwargs=render_kwargs,
)
trainer = GSSTrainer(**trainer_config)

# Smoke test: write one comparison image, then run a single training step.
# The train step is the last expression so the notebook displays its result.
trainer.on_evaluate_step()
trainer.on_train_step()

  from .autonotebook import tqdm as notebook_tqdm
ic| scales: tensor(0.0100, device='cuda:0')


Number of points at initialisation :  16384


dataloader_config = DataLoaderConfiguration(split_batches=False)


(tensor(-1278.4956, device='cuda:0', grad_fn=<SubBackward0>),
 {'total': tensor(-1278.4956, device='cuda:0', grad_fn=<SubBackward0>),
  'l1': tensor(0.0925, device='cuda:0', grad_fn=<MeanBackward0>),
  'ssim': tensor(0.0887, device='cuda:0', grad_fn=<RsubBackward1>),
  'depth': tensor(0.4860, device='cuda:0', grad_fn=<MeanBackward0>),
  'psnr': 14.633354327967334})

In [2]:
from gaussian_splatting.utils.camera_utils import parse_camera
from icecream import ic
from plotly_utils import *

# Rebuild by hand the arguments of the commented-out
# loss_utils.rkhs_global_scale_loss(...) call below for one training view, so
# the individual loss terms can be inspected in the following cells.
import numpy as np

ind = 0

# Every per-view tensor gets a leading batch dimension of size 1.
camera, rgb, depth, alpha = (
    trainer.data[key][ind].unsqueeze(0)
    for key in ('camera', 'rgb', 'depth', 'alpha')
)
mask = (trainer.data['alpha'][ind] > 0.5).unsqueeze(0)

Hs, Ws, intrinsics, c2ws = parse_camera(camera)
W = int(Ws[0].item())
H = int(Hs[0].item())

# Ground-truth point clouds arranged as a grid of tiles.
points = get_point_clouds_tiles(camera, depth, alpha, rgb)
N = len(points)
M = len(points[0])
P = points[0][0].coords.shape[0]

viewpoint_camera = to_viewpoint_camera(trainer.data['camera'][ind])
out = trainer.gaussRender(pc=trainer.model, camera=viewpoint_camera)
# rkhs_loss = loss_utils.rkhs_global_scale_loss(out['tiles'], points, rgb, trainer.model.get_scaling)

# Same names as the arguments of the call above.
prediction_tiles = out['tiles']
gt_points = points
gt_rgb = rgb
scale3d = trainer.model.get_scaling

# Per-tile predictions, each indexed as [v][u].
mean2d_tiles = prediction_tiles['mean2d']
mean3d_tiles = prediction_tiles['mean3d']
scale2d_tiles = prediction_tiles['scale2d']
label_tiles = prediction_tiles['label']

# Grid size, now taken from the prediction rather than the ground truth.
N = len(mean2d_tiles)
M = len(mean2d_tiles[0])
# NOTE(review): gt_rgb carries a leading batch dimension of 1 here, so this
# integer division yields 0 (see the ic output further down) -- confirm which
# dimension T was meant to describe.
T = gt_rgb.shape[0] // N

scale3d_squared = scale3d.pow(2)

# Average number of entries per tile.
tile_sizes = [
    scale2d_tiles[row][col].shape[0]
    for row in range(N)
    for col in range(M)
]
mean_tile_number = np.mean(tile_sizes)

# local map norm, training image norm, inner product
loss = [0] * 3
init = False
# for v in range(N):
#     for u in range(M):
        
#         B = mean2d_tiles[v][u].shape[0]
#         if B == 0 and init:
#             continue
#         init = True
#         print(v,u)

#         pc_tile = gt_points[v][u]
#         pc_tile = pc_tile.random_sample(300)
#         gt_label_tile = torch.from_numpy(pc_tile.select_channels(['R', 'G', 'B'])).to(scale3d.device)
#         gt_points_tile = torch.from_numpy(pc_tile.coords).to(scale3d.device)
#         P = gt_label_tile.shape[0]

#         # ic(M, N, B, P, T)

#         gt_label_tile_unsq0 = gt_label_tile.unsqueeze(0)
#         gt_label_tile_unsq1 = gt_label_tile.unsqueeze(1)
#         gt_points_tile_unsq0 = gt_points_tile.unsqueeze(0)
#         gt_points_tile_unsq1 = gt_points_tile.unsqueeze(1)

#         label_tile = label_tiles[v][u][0][:300] # only rgb for now
#         label_tile_unsq0 = label_tile.unsqueeze(0)
#         label_tile_unsq1 = label_tile.unsqueeze(1)
#         mean3d_tile = mean3d_tiles[v][u][:300]
#         mean3d_tile_unsq0 = mean3d_tile.unsqueeze(0)
#         mean3d_tile_unsq1 = mean3d_tile.unsqueeze(1)

#         # inner product between local map and current frame
#         label2 = (label_tile_unsq1 - gt_label_tile_unsq0).pow(2).sum(-1)
#         point2 = (-0.5 * (mean3d_tile_unsq1 - gt_points_tile_unsq0).pow(2).sum(-1) / scale3d_squared).exp()
#         # point2 = (-0.5 * (mean3d_tile_unsq1 - gt_points_tile_unsq0).abs().pow(3).sum(-1).pow(2/3) / scale3d_squared).exp()
#         loss2 = label2 * point2

#         # local map inner product
#         label0 = (label_tile_unsq1 - label_tile_unsq0).pow(2).sum(-1)
#         point0 = (-0.5 * (mean3d_tile_unsq1 - mean3d_tile_unsq0).pow(2).sum(-1) / scale3d_squared).exp()
#         loss0 = label0 * point0

#         # current frame inner product
#         label1 = (gt_label_tile_unsq1 - gt_label_tile_unsq0).pow(2).sum(-1)
#         point1 = (-0.5 * (gt_points_tile_unsq1 - gt_points_tile_unsq0).pow(2).sum(-1) / scale3d_squared).exp()
#         loss1 = label1 * point1

#         # ic(loss0.sum(), loss1.sum(), loss2.sum())

#         loss[0] = loss0.sum() + loss[0]
#         loss[1] = loss1.sum() + loss[1]
#         loss[2] = loss2.sum() + loss[2]


# ic(loss)

In [3]:
# Pick a single tile (indices v, u) and prepare broadcastable views of its
# predicted entries and its ground-truth points for pairwise comparisons.
v, u = 2, 2

# Number of predicted entries that fall into this tile.
B = mean2d_tiles[v][u].shape[0]

pc_tile = gt_points[v][u]
# pc_tile = pc_tile.random_sample(300)
target_device = scale3d.device
# Colours are divided by 255.0 before being compared with the predictions.
gt_label_tile = torch.from_numpy(
    pc_tile.select_channels(['R', 'G', 'B']) / 255.0
).to(target_device)
gt_points_tile = torch.from_numpy(pc_tile.coords).to(target_device)
# Number of ground-truth points in this tile.
P = gt_label_tile.shape[0]

# Ground truth: "*_unsq0" adds a leading axis, "*_unsq1" a middle axis.
gt_label_tile_unsq0 = torch.unsqueeze(gt_label_tile, 0)
gt_label_tile_unsq1 = torch.unsqueeze(gt_label_tile, 1)
gt_points_tile_unsq0 = torch.unsqueeze(gt_points_tile, 0)
gt_points_tile_unsq1 = torch.unsqueeze(gt_points_tile, 1)

# Prediction: same pair of views for colours and 3D means.
label_tile = label_tiles[v][u][0]  # only rgb for now
label_tile_unsq0 = torch.unsqueeze(label_tile, 0)
label_tile_unsq1 = torch.unsqueeze(label_tile, 1)
mean3d_tile = mean3d_tiles[v][u]
mean3d_tile_unsq0 = torch.unsqueeze(mean3d_tile, 0)
mean3d_tile_unsq1 = torch.unsqueeze(mean3d_tile, 1)

In [10]:
# Debug cell: inspect the colour ("label") term of the RKHS inner product for
# the tile selected above. ic() echoes each argument expression together with
# its value, so the expressions are deliberately written out in full.

# Grid size (M, N), predicted entries B and ground-truth points P of the tile,
# plus T as computed in the setup cell.
ic(M, N, B, P, T)

# Shapes: predictions with a middle axis against ground truth with a leading
# axis broadcast to one colour difference per (prediction, GT point) pair.
ic(label_tile_unsq1.shape)
ic(gt_label_tile_unsq0.shape)
ic((label_tile_unsq1 - gt_label_tile_unsq0).shape)

# print first 3
ic(label_tile_unsq1[:3])
ic(gt_label_tile_unsq0[0, :3])
ic((label_tile_unsq1 - gt_label_tile_unsq0).shape)
ic((label_tile_unsq1 - gt_label_tile_unsq0)[:3, :3])
# Squared colour distance: sum of squared differences over the last axis.
ic((label_tile_unsq1 - gt_label_tile_unsq0).pow(2).sum(-1).shape)
# Last expression of the cell, so its value is also shown as the cell output.
ic((label_tile_unsq1 - gt_label_tile_unsq0).pow(2).sum(-1)[:3, :3])

ic| M: 4, N: 4, B: 3013, P: 4096, T: 0
ic| label_tile_unsq1.shape: torch.Size([3013, 1, 3])
ic| gt_label_tile_unsq0.shape: torch.Size([1, 4096, 3])
ic| (label_tile_unsq1 - gt_label_tile_unsq0).shape: torch.Size([3013, 4096, 3])
ic| label_tile_unsq1[:3]: tensor([[[0.6902, 0.4627, 0.3529]],
                          
                                  [[0.7529, 0.5137, 0.4157]],
                          
                                  [[0.7333, 0.4941, 0.3765]]], device='cuda:0', grad_fn=<SliceBackward0>)
ic| gt_label_tile_unsq0[0, :3]: tensor([[0.7412, 0.4784, 0.3608],
                                        [0.7020, 0.4431, 0.3255],
                                        [0.6784, 0.4196, 0.3020]], device='cuda:0')
ic| (label_tile_unsq1 - gt_label_tile_unsq0).shape: torch.Size([3013, 4096, 3])
ic| (label_tile_unsq1 - gt_label_tile_unsq0)[:3, :3]: tensor([[[-0.0510, -0.0157, -0.0078],
                                                               [-0.0118,  0.0196,  0.0275],
        

tensor([[0.0029, 0.0013, 0.0046],
        [0.0044, 0.0157, 0.0273],
        [0.0006, 0.0062, 0.0141]], device='cuda:0', grad_fn=<SliceBackward0>)

In [23]:
# Debug cell: inspect the geometric term of the RKHS inner product for the
# selected tile, i.e. the exponent of the expression quoted on the next line.
# (-0.5 * (mean3d_tile_unsq1 - gt_points_tile_unsq0).pow(2).sum(-1) / scale3d_squared).exp()

# First predicted 3D means and first ground-truth points of the tile.
ic(mean3d_tile_unsq1[:3])
ic(gt_points_tile_unsq0[0, :3])
# Pairwise differences between every predicted mean and every GT point.
ic((mean3d_tile_unsq1 - gt_points_tile_unsq0).shape)
ic((mean3d_tile_unsq1 - gt_points_tile_unsq0)[:3,:3])
# Squared Euclidean distance per pair.
ic((mean3d_tile_unsq1 - gt_points_tile_unsq0).pow(2).sum(-1)[:3,:3])
ic(scale3d_squared)
# Squared distance divided by the squared scale; the kernel is
# exp(-0.5 * this value). Last expression, so it is also the cell output.
ic(((mean3d_tile_unsq1 - gt_points_tile_unsq0).pow(2).sum(-1) / scale3d_squared))

ic| mean3d_tile_unsq1[:3]: tensor([[[0.3049, 0.5026, 0.4256]],
                           
                                   [[0.2670, 0.4231, 0.4998]],
                           
                                   [[0.3084, 0.4932, 0.4271]]], device='cuda:0', grad_fn=<SliceBackward0>)
ic| gt_points_tile_unsq0[0, :3]: tensor([[0.2326, 0.3935, 0.5118],
                                         [0.2278, 0.3926, 0.5082],
                                         [0.2229, 0.3917, 0.5045]], device='cuda:0')
ic| (mean3d_tile_unsq1 - gt_points_tile_unsq0).shape: torch.Size([3013, 4096, 3])
ic| (mean3d_tile_unsq1 - gt_points_tile_unsq0)[:3,:3]: tensor([[[ 0.0723,  0.1090, -0.0863],
                                                                [ 0.0771,  0.1100, -0.0826],
                                                                [ 0.0820,  0.1109, -0.0789]],
                                                       
                                                               [[ 0.0344, 

 device='cuda:0', grad_fn=<SliceBackward0>)
ic| scale3d_squared: tensor(1.0000e-04, device='cuda:0', grad_fn=<PowBackward0>)
ic| (mean3d_tile_unsq1 - gt_points_tile_unsq0).pow(2).sum(-1) / scale3d_squared: tensor([[  245.5807,   248.6852,   252.5532,  ...,  6337.1797,  6453.3540,
                                                                                           6728.5181],
                                                                                         [   21.9938,    25.3738,    29.5216,  ...,  6931.5171,  7050.3867,
                                                                                           7331.9058],
                                                                                         [  228.6542,   232.0342,   236.1821,  ...,  6371.8721,  6488.8457,
                                                                                           6764.6738],
                                                                                         ...,
    

tensor([[  245.5807,   248.6852,   252.5532,  ...,  6337.1797,  6453.3540,
          6728.5181],
        [   21.9938,    25.3738,    29.5216,  ...,  6931.5171,  7050.3867,
          7331.9058],
        [  228.6542,   232.0342,   236.1821,  ...,  6371.8721,  6488.8457,
          6764.6738],
        ...,
        [15080.5449, 14954.1230, 14828.4131,  ...,  5082.9565,  5067.2808,
          4968.5518],
        [15146.9541, 15018.6699, 14891.0811,  ...,  5011.4873,  4992.7109,
          4890.0830],
        [15301.5107, 15171.3252, 15041.8262,  ...,  4958.8579,  4937.0625,
          4829.7988]], device='cuda:0', grad_fn=<DivBackward0>)

: 

In [20]:
# Hand check of one squared distance from the inspection cell above: the
# vector is approximately the difference written out in the commented line
# below (third predicted mean minus first ground-truth point). Dividing the
# result by scale3d_squared (1e-4) gives roughly 228.6, which matches the
# first entry of the third row of the printed distance table.
import numpy as np
np.sum(np.array([0.0758,  0.0996, -0.0848])**2)
# np.array([0.3084, 0.4932, 0.4271])-np.array([0.2326, 0.3935, 0.5118])

0.02285684