In [22]:
import torch
from torch import nn

from gaussian_splatting.utils.graphics_utils import getProjectionMatrix2, getWorld2View2
from utils.slam_utils import image_gradient, image_gradient_mask


class Camera(nn.Module):
    """One camera frame: observation, intrinsics, pose estimate and corrections.

    Stores the colour/depth observation, the current pose estimate (``R``,
    ``T``), the ground-truth pose (``R_gt``, ``T_gt``) and four
    ``nn.Parameter`` tensors: the pose deltas (``cam_rot_delta``,
    ``cam_trans_delta``) and the exposure terms (``exposure_a``,
    ``exposure_b``).
    """

    def __init__(
        self,
        uid,
        color,
        depth,
        gt_T,
        projection_matrix,
        fx,
        fy,
        cx,
        cy,
        fovx,
        fovy,
        image_height,
        image_width,
        device="cuda:0",
    ):
        """Create a camera whose pose estimate starts at the identity.

        Args:
            uid: Identifier of the frame.
            color: Colour image, or ``None``. ``compute_grad_mask`` unpacks its
                shape as ``(_, h, w)``.
            depth: Depth image, or ``None``.
            gt_T: Ground-truth pose, indexed as a 4x4 matrix.
            projection_matrix: Projection matrix tensor; moved to ``device``.
                ``init_from_gui`` passes it already transposed.
            fx, fy, cx, cy: Pinhole intrinsics.
            fovx, fovy: Fields of view, stored as ``FoVx`` / ``FoVy``.
            image_height, image_width: Image size.
            device: Device for the pose and the learnable parameters.
        """
        super().__init__()
        self.uid = uid
        self.device = device

        # Current pose estimate starts as identity; update_RT overwrites it.
        T = torch.eye(4, device=device)
        self.R = T[:3, :3]
        self.T = T[:3, 3]
        self.R_gt = gt_T[:3, :3]
        self.T_gt = gt_T[:3, 3]

        self.original_image = color
        self.depth = depth
        self.grad_mask = None

        self.fx = fx
        self.fy = fy
        self.cx = cx
        self.cy = cy
        self.FoVx = fovx
        self.FoVy = fovy
        self.image_height = image_height
        self.image_width = image_width

        # Filled by get_pyramid(). intrinsics_pyramid is reserved but not yet
        # populated anywhere in this class.
        self.rgb_pyramid = None
        self.depth_pyramid = None
        self.intrinsics_pyramid = None

        self.cam_rot_delta = nn.Parameter(
            torch.zeros(3, requires_grad=True, device=device)
        )
        self.cam_trans_delta = nn.Parameter(
            torch.zeros(3, requires_grad=True, device=device)
        )

        self.exposure_a = nn.Parameter(
            torch.tensor([0.0], requires_grad=True, device=device)
        )
        self.exposure_b = nn.Parameter(
            torch.tensor([0.0], requires_grad=True, device=device)
        )

        self.projection_matrix = projection_matrix.to(device=device)

    @staticmethod
    def init_from_dataset(dataset, idx, projection_matrix):
        """Build a camera for frame ``idx`` of ``dataset``.

        ``dataset[idx]`` must yield ``(color, depth, pose)``; intrinsics, image
        size and device are read from attributes of ``dataset``.
        """
        gt_color, gt_depth, gt_pose = dataset[idx]
        return Camera(
            idx,
            gt_color,
            gt_depth,
            gt_pose,
            projection_matrix,
            dataset.fx,
            dataset.fy,
            dataset.cx,
            dataset.cy,
            dataset.fovx,
            dataset.fovy,
            dataset.height,
            dataset.width,
            device=dataset.device,
        )

    @staticmethod
    def init_from_gui(uid, T, FoVx, FoVy, fx, fy, cx, cy, H, W):
        """Build an image-less camera (no colour, no depth) for pose ``T``.

        The projection matrix is created here and transposed before being
        handed to the constructor. The camera uses the constructor's default
        device.
        """
        projection_matrix = getProjectionMatrix2(
            znear=0.01, zfar=100.0, fx=fx, fy=fy, cx=cx, cy=cy, W=W, H=H
        ).transpose(0, 1)
        return Camera(
            uid, None, None, T, projection_matrix, fx, fy, cx, cy, FoVx, FoVy, H, W
        )

    @property
    def world_view_transform(self):
        """Transposed world-to-view matrix for the current ``R`` and ``T``."""
        return getWorld2View2(self.R, self.T).transpose(0, 1)

    @property
    def full_proj_transform(self):
        """Product of ``world_view_transform`` and ``projection_matrix``."""
        return (
            self.world_view_transform.unsqueeze(0).bmm(
                self.projection_matrix.unsqueeze(0)
            )
        ).squeeze(0)

    @property
    def camera_center(self):
        """First three entries of the last row of the inverted view transform."""
        return self.world_view_transform.inverse()[3, :3]

    def update_RT(self, R, t):
        """Replace the current pose estimate, moving both parts to ``self.device``."""
        self.R = R.to(device=self.device)
        self.T = t.to(device=self.device)

    def compute_grad_mask(self, config):
        """Compute ``self.grad_mask`` from the gradient magnitude of the image.

        The image is averaged over its first dimension, its vertical and
        horizontal gradients are masked and combined into a magnitude, and the
        magnitude is thresholded against ``median * edge_threshold``.

        * ``config["Dataset"]["type"] == "replica"``: the median is taken per
          cell of a 32x32 grid and the mask is stored as 0/1 values in the
          gradient tensor itself.
        * otherwise: one global median is used and the mask is boolean.

        Args:
            config: Mapping providing ``["Training"]["edge_threshold"]`` and
                ``["Dataset"]["type"]``.
        """
        edge_threshold = config["Training"]["edge_threshold"]

        gray_img = self.original_image.mean(dim=0, keepdim=True)
        gray_grad_v, gray_grad_h = image_gradient(gray_img)
        mask_v, mask_h = image_gradient_mask(gray_img)
        gray_grad_v = gray_grad_v * mask_v
        gray_grad_h = gray_grad_h * mask_h
        img_grad_intensity = torch.sqrt(gray_grad_v**2 + gray_grad_h**2)

        if config["Dataset"]["type"] == "replica":
            row, col = 32, 32
            multiplier = edge_threshold
            _, h, w = self.original_image.shape
            block_h, block_w = h // row, w // col
            # NOTE(review): when h or w is not a multiple of the grid size the
            # trailing rows/columns are not covered by any block and keep their
            # raw gradient magnitude.
            for r in range(row):
                for c in range(col):
                    # Basic slicing returns a view, so the writes below modify
                    # img_grad_intensity in place.
                    block = img_grad_intensity[
                        :,
                        r * block_h : (r + 1) * block_h,
                        c * block_w : (c + 1) * block_w,
                    ]
                    threshold = block.median() * multiplier
                    # Decide every pixel before writing: comparing again after
                    # the first write would re-test the freshly written 1s and
                    # erase them whenever the threshold is >= 1.
                    is_edge = block > threshold
                    block[is_edge] = 1
                    block[~is_edge] = 0
            self.grad_mask = img_grad_intensity
        else:
            median_img_grad_intensity = img_grad_intensity.median()
            self.grad_mask = (
                img_grad_intensity > median_img_grad_intensity * edge_threshold
            )

    @staticmethod
    def _halving_pyramid(image, n_levels):
        """Return ``n_levels`` copies of ``image``, each half the size of the next.

        The list is ordered coarse to fine; the last entry has the input
        resolution. ``image`` may be ``(H, W)`` or ``(C, H, W)`` and must have
        a floating dtype (required by bilinear interpolation).
        """
        n_dims = image.dim()
        # interpolate expects (N, C, H, W): pad with leading singleton dims.
        batched = image.reshape((1,) * (4 - n_dims) + tuple(image.shape))
        levels = [batched]
        for _ in range(n_levels - 1):
            levels.append(
                nn.functional.interpolate(
                    levels[-1],
                    scale_factor=0.5,
                    mode="bilinear",
                    align_corners=False,
                )
            )
        levels.reverse()
        # Strip exactly the singleton dims added above (never a real channel dim).
        return [level.reshape(level.shape[4 - n_dims :]) for level in levels]

    def get_pyramid(self, n_levels):
        """Build the RGB and depth pyramids used for tracking.

        Every level halves the resolution of the next finer one with bilinear
        interpolation. The results are stored in ``self.rgb_pyramid`` and
        ``self.depth_pyramid`` as lists sorted from coarse to fine levels, so
        the last entry is the full-resolution input. A pyramid is set to
        ``None`` when its source (``original_image`` / ``depth``) is ``None``.
        ``intrinsics_pyramid`` is not touched.

        Args:
            n_levels: Number of levels (>= 1); 1 yields only full resolution.

        Returns:
            The tuple ``(rgb_pyramid, depth_pyramid)``.

        Raises:
            ValueError: If ``n_levels`` is smaller than 1.
        """
        if n_levels < 1:
            raise ValueError(f"n_levels must be >= 1, got {n_levels}")

        self.rgb_pyramid = None
        if self.original_image is not None:
            self.rgb_pyramid = self._halving_pyramid(self.original_image, n_levels)

        self.depth_pyramid = None
        if self.depth is not None:
            # Depth may arrive as a NumPy array; interpolation needs a float tensor.
            # NOTE(review): bilinear resampling blends depth across object
            # boundaries and invalid pixels — confirm this is acceptable.
            depth = torch.as_tensor(
                self.depth, dtype=torch.float32, device=self.device
            )
            self.depth_pyramid = self._halving_pyramid(depth, n_levels)

        return self.rgb_pyramid, self.depth_pyramid

    def clean(self):
        """Drop image data, pyramids and learnable parameters to free memory."""
        self.original_image = None
        self.depth = None
        self.grad_mask = None

        # The pyramids hold resampled copies of the images, so release them too.
        self.rgb_pyramid = None
        self.depth_pyramid = None
        self.intrinsics_pyramid = None

        self.cam_rot_delta = None
        self.cam_trans_delta = None

        self.exposure_a = None
        self.exposure_b = None


In [21]:
from utils.pose_utils import update_pose
from utils.slam_utils import get_loss_tracking, get_median_depth
from gaussian_splatting.gaussian_renderer import render

class FrontEnd:
    """Front end that tracks the pose of incoming camera frames.

    ``tracking`` reads ``self.cameras``, ``self.gaussians``,
    ``self.pipeline_params``, ``self.background`` and ``self.tracking_itr_num``.
    None of them is assigned in ``__init__``, so they must be set on the
    instance before ``tracking`` is called.
    """

    def __init__(self, config) -> None:
        """Store ``config`` and read the number of pyramid levels from it.

        Args:
            config: Mapping providing ``["Training"]["n_pyramid_levels"]``;
                ``tracking`` also reads the learning rates under
                ``["Training"]["lr"]``.
        """
        # Offset (in frames) to the camera whose pose seeds the tracker.
        self.use_every_n_frames = 1
        # Backward-compatible alias for the previous, misspelled attribute name.
        self.use_every_f_frames = self.use_every_n_frames
        self.config = config
        self.n_levels = config["Training"]["n_pyramid_levels"]

    def tracking(self, cur_frame_idx, viewpoint):
        """Optimise the pose and exposure parameters of ``viewpoint``.

        The pose is initialised from the camera stored
        ``use_every_n_frames`` frames earlier. Adam then updates the pose
        deltas and exposure terms against the tracking loss of the rendered
        frame until ``update_pose`` reports convergence or
        ``tracking_itr_num`` iterations have run.

        Args:
            cur_frame_idx: Index of the current frame in ``self.cameras``.
            viewpoint: Camera to track; modified in place.

        Returns:
            The render package of the last iteration.

        Raises:
            RuntimeError: If ``tracking_itr_num`` allows no iteration, so
                nothing was rendered.
        """
        prev = self.cameras[cur_frame_idx - self.use_every_n_frames]
        viewpoint.update_RT(prev.R, prev.T)
        viewpoint.get_pyramid(self.n_levels)

        learning_rates = self.config["Training"]["lr"]
        opt_params = [
            {
                "params": [viewpoint.cam_rot_delta],
                "lr": learning_rates["cam_rot_delta"],
                "name": "rot_{}".format(viewpoint.uid),
            },
            {
                "params": [viewpoint.cam_trans_delta],
                "lr": learning_rates["cam_trans_delta"],
                "name": "trans_{}".format(viewpoint.uid),
            },
            {
                "params": [viewpoint.exposure_a],
                "lr": 0.01,
                "name": "exposure_a_{}".format(viewpoint.uid),
            },
            {
                "params": [viewpoint.exposure_b],
                "lr": 0.01,
                "name": "exposure_b_{}".format(viewpoint.uid),
            },
        ]
        pose_optimizer = torch.optim.Adam(opt_params)

        # NOTE(review): the Camera class in this notebook defines neither
        # `tracking_itr_splits` nor `down_level()`. Without the attribute the
        # tracker stays on a single level instead of raising AttributeError.
        level_switch_iterations = getattr(viewpoint, "tracking_itr_splits", ())

        render_pkg = None
        for tracking_itr in range(self.tracking_itr_num):
            if tracking_itr in level_switch_iterations:
                viewpoint.down_level()
            render_pkg = render(
                viewpoint, self.gaussians, self.pipeline_params, self.background
            )
            image, depth, opacity = (
                render_pkg["render"],
                render_pkg["depth"],
                render_pkg["opacity"],
            )
            pose_optimizer.zero_grad()
            loss_tracking = get_loss_tracking(
                self.config, image, depth, opacity, viewpoint
            )
            loss_tracking.backward()

            with torch.no_grad():
                pose_optimizer.step()
                converged = update_pose(viewpoint)

            if converged:
                break

        if render_pkg is None:
            # Without an iteration there is no depth/opacity to summarise below.
            raise RuntimeError(
                "tracking_itr_num must be >= 1, got {}".format(self.tracking_itr_num)
            )

        self.median_depth = get_median_depth(depth, opacity)
        return render_pkg

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [19]:
from utils.config_utils import load_config
from utils.dataset import load_dataset

# Load the configuration and dataset for the TUM fr1_desk RGB-D sequence.
path_config = "./configs/rgbd/tum/fr1_desk.yaml"
config = load_config(path_config)
dataset = load_dataset(args=None, path=None, config=config)

# Camera.full_proj_transform multiplies the *transposed* world-to-view matrix
# by the stored projection matrix, so the projection matrix must be transposed
# too — the same convention Camera.init_from_gui applies.
projection_matrix = getProjectionMatrix2(
    znear=0.01,
    zfar=100,
    cx=dataset.cx,
    cy=dataset.cy,
    fx=dataset.fx,
    fy=dataset.fy,
    W=dataset.width,
    H=dataset.height,
).transpose(0, 1)

# First frame of the sequence, used as the test camera in the cells below.
viewpoint = Camera.init_from_dataset(
    dataset=dataset, idx=0, projection_matrix=projection_matrix
)

In [35]:
import torch.nn.functional as F
import matplotlib.pyplot as plt 

# Prototype of one pyramid step: halve the resolution of the colour and depth
# images with bilinear interpolation. F.interpolate needs a 4-D (N, C, H, W)
# input, hence the leading singleton dimensions, which squeeze() drops again.
downsample_rgb = F.interpolate(
    viewpoint.original_image[None],
    scale_factor=0.5,
    mode="bilinear",
    align_corners=False,
).squeeze()

# The depth image is a NumPy array here, so convert it to a tensor first.
down_depth = F.interpolate(
    torch.from_numpy(viewpoint.depth)[None, None],
    scale_factor=0.5,
    mode="bilinear",
    align_corners=False,
).squeeze()

# Uncomment to preview the half-resolution depth map:
# plt.imshow(down_depth.cpu())