Purpose: Render multi-view images at given camera poses from a pretrained Gaussian Splatting (GS) model using gsplat, and save them to disk.

In [1]:
import json
import math
import os
import time
from dataclasses import dataclass, field
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

import imageio
import nerfview
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
import tyro
import viser
import yaml
from datasets.colmap import Dataset, Parser
from datasets.traj import (
    generate_interpolated_path,
    generate_ellipse_path_z,
    generate_spiral_path,
)
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from fused_ssim import fused_ssim
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from typing_extensions import Literal, assert_never
from utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed
from lib_bilagrid import (
    BilateralGrid,
    slice,
    color_correct,
    total_variation_loss,
)

from gsplat.compression import PngCompression
from gsplat.distributed import cli
from gsplat.rendering import rasterization
from gsplat.strategy import DefaultStrategy, MCMCStrategy


In [2]:
from simple_trainer import Runner, Config

# NOTE: relative paths below (e.g. result_dir) assume the current working
# directory is ./examples.

# Root folder of the CO3D "apple" sequences. The default is the original
# author's local path; set the CO3D_APPLE_ROOT environment variable to point
# at your own copy instead of editing this cell.
CO3D_APPLE_ROOT = os.environ.get(
    "CO3D_APPLE_ROOT", "/home/percool/CLab/ggs/datasets/co3d/apple"
)

# cfg1: the partial sequence; its result_dir is where the rendered images
# are written by Runner2.render_by_poses.
cfg1 = Config(
    data_dir=f"{CO3D_APPLE_ROOT}/189_20393_38136_part/",
    result_dir="results/apple_189_part_1_unnorm",
)
# cfg2: the full sequence; only used to build parser2, whose camera poses
# and image names drive the rendering.
cfg2 = Config(data_dir=f"{CO3D_APPLE_ROOT}/189_20393_38136/")

# Parse the full sequence using the settings carried by cfg2.
parser2 = Parser(
    data_dir=cfg2.data_dir,
    factor=cfg2.data_factor,
    normalize=cfg2.normalize_world_space,
    test_every=cfg2.test_every,
)

[Parser] 202 images, taken by 1 cameras.


In [3]:
# Display cfg1.data_factor as the cell output; this is the value passed to
# Parser(factor=...) -- presumably the image downsampling factor, TODO confirm.
cfg1.data_factor
# parser2.imsize_dict

1

In [4]:
class Runner2(Runner):
    """Inference-only runner: load a splat checkpoint and render it at given poses.

    Typical use::

        runner = Runner2(cfg)
        runner.load_Gaussian_splats("<result_dir>/ckpts/ckpt_XXXX_rank0.pt")
        runner.render_by_poses(other_parser)
    """

    def __init__(self, cfg: Config) -> None:
        """Initialise the base ``Runner`` with fixed rank/world arguments.

        Args:
            cfg: Configuration object forwarded to ``Runner.__init__``.
        """
        # The positional values 0, 0, 1 are presumably local_rank, world_rank
        # and world_size (single process) -- TODO confirm against
        # simple_trainer.Runner.__init__.
        super().__init__(0, 0, 1, cfg)

    def load_Gaussian_splats(self, ckpt_path: str) -> None:
        """Load splat parameters from a ``.pt`` checkpoint onto ``self.device``.

        The checkpoint must be a mapping with a ``"splats"`` entry holding the
        keys ``means``, ``quats``, ``scales``, ``opacities``, ``sh0`` and
        ``shN``. Activations are applied here once (quaternion normalisation,
        ``exp`` on scales, ``sigmoid`` on opacities) and the results are stored
        as ``self.means``, ``self.quats``, ``self.scales``, ``self.opacities``,
        ``self.sh0``, ``self.shN``, ``self.colors`` and ``self.sh_degree``.

        Args:
            ckpt_path: Path to the checkpoint file; must end with ``.pt``.

        Raises:
            ValueError: If ``ckpt_path`` does not end with ``.pt``.
        """
        print("Loading splats from %s..." % ckpt_path)
        if not ckpt_path.endswith(".pt"):
            # Previously this case fell through to torch.cat([]) and failed
            # with an unrelated-looking error message.
            raise ValueError(
                "Unsupported checkpoint format (expected a .pt file): %s" % ckpt_path
            )

        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources.
        splats = torch.load(ckpt_path, map_location=self.device)["splats"]

        self.means = splats["means"]
        self.quats = F.normalize(splats["quats"], p=2, dim=-1)
        self.scales = torch.exp(splats["scales"])
        self.opacities = torch.sigmoid(splats["opacities"])
        self.sh0 = splats["sh0"]
        self.shN = splats["shN"]
        # Concatenate the sh0 and shN coefficients along the second-to-last axis.
        self.colors = torch.cat([self.sh0, self.shN], dim=-2)
        # A degree-d SH basis has (d + 1)^2 coefficients; recover d from the
        # number of coefficients actually stored in the checkpoint.
        self.sh_degree = int(math.sqrt(self.colors.shape[-2]) - 1)

    def rasterize_splats2(
        self,
        camtoworlds: Tensor,
        Ks: Tensor,
        width: int,
        height: int,
        **kwargs,
    ) -> Tuple[Tensor, Tensor, Dict]:
        """Rasterize the splats loaded by ``load_Gaussian_splats``.

        Args:
            camtoworlds: Camera-to-world matrices; they are inverted here to
                obtain the view matrices.  # [C, 4, 4]
            Ks: Camera intrinsics.  # [C, 3, 3]
            width: Output image width in pixels.
            height: Output image height in pixels.
            **kwargs: Extra keyword arguments forwarded to ``rasterization``.
                ``image_ids`` is consumed here and never forwarded;
                ``sh_degree`` is consumed here only when ``cfg.app_opt`` is set.

        Returns:
            The ``(render_colors, render_alphas, info)`` tuple returned by
            ``rasterization``.
        """
        means = self.means  # [N, 3]
        # Already normalised in load_Gaussian_splats.
        quats = self.quats  # [N, 4]
        scales = self.scales  # [N, 3]
        opacities = self.opacities  # [N,]

        image_ids = kwargs.pop("image_ids", None)
        if self.cfg.app_opt:
            # NOTE(review): this branch reads self.splats and self.app_module,
            # which load_Gaussian_splats does not populate -- they hold
            # whatever the base Runner set up. Verify before enabling app_opt.
            colors = self.app_module(
                features=self.splats["features"],
                embed_ids=image_ids,
                dirs=means[None, :, :] - camtoworlds[:, None, :3, 3],
                sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree),
            )
            colors = colors + self.splats["colors"]
            colors = torch.sigmoid(colors)
        else:
            colors = self.colors  # [N, K, 3]

        rasterize_mode = "antialiased" if self.cfg.antialiased else "classic"
        render_colors, render_alphas, info = rasterization(
            means=means,
            quats=quats,
            scales=scales,
            opacities=opacities,
            colors=colors,
            viewmats=torch.linalg.inv(camtoworlds),  # [C, 4, 4]
            Ks=Ks,  # [C, 3, 3]
            width=width,
            height=height,
            packed=self.cfg.packed,
            absgrad=(
                self.cfg.strategy.absgrad
                if isinstance(self.cfg.strategy, DefaultStrategy)
                else False
            ),
            sparse_grad=self.cfg.sparse_grad,
            rasterize_mode=rasterize_mode,
            distributed=self.world_size > 1,
            **kwargs,
        )
        return render_colors, render_alphas, info

    @staticmethod
    def _to_uint8_image(batch: Tensor) -> np.ndarray:
        """Convert a ``[1, H, W, C]`` tensor with values in [0, 1] to a uint8 array."""
        return (batch.squeeze(0).cpu().numpy() * 255).astype(np.uint8)

    @torch.no_grad()
    def render_by_poses(self, parser_new: Parser) -> None:
        """Render one RGB and one depth image per pose of ``parser_new`` and save them.

        Poses and output file names come from ``parser_new.camtoworlds`` and
        ``parser_new.image_names``. Intrinsics and image size are taken from
        the first entry of ``self.parser.Ks_dict`` / ``self.parser.imsize_dict``
        and reused for every pose.

        Images are written to ``<cfg.result_dir>/images_poses/{rgb,depth,combined}``;
        ``combined`` holds the RGB and depth images side by side. Each frame is
        written as soon as it is rendered, so memory use does not grow with
        the number of poses.

        Args:
            parser_new: Parser providing ``camtoworlds`` (NumPy array of
                camera-to-world matrices) and ``image_names``.

        Raises:
            RuntimeError: If no splats have been loaded yet.
            ValueError: If the numbers of poses and image names differ.
        """
        print("Running poses-based rendering...")
        cfg = self.cfg
        device = self.device

        if not hasattr(self, "means"):
            raise RuntimeError(
                "No splats loaded; call load_Gaussian_splats() before render_by_poses()."
            )

        image_names = list(parser_new.image_names)
        camtoworlds_all = torch.from_numpy(parser_new.camtoworlds).float().to(device)
        if len(image_names) != len(camtoworlds_all):
            raise ValueError(
                "Number of poses (%d) and image names (%d) differ."
                % (len(camtoworlds_all), len(image_names))
            )

        # NOTE(review): a single intrinsics matrix / image size (first camera
        # of the runner's own parser) is used for all poses of parser_new;
        # this assumes both datasets share the same camera -- confirm.
        K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device)
        Ks = K[None]
        width, height = list(self.parser.imsize_dict.values())[0]

        # Never request more SH bands than the checkpoint provides; the
        # rasterizer rejects sh_degree values exceeding the stored coefficients.
        sh_degree = min(cfg.sh_degree, self.sh_degree)

        images_dir = f"{cfg.result_dir}/images_poses"
        rgb_dir = images_dir + "/rgb"
        depth_dir = images_dir + "/depth"
        combined_dir = images_dir + "/combined"
        for folder in (rgb_dir, depth_dir, combined_dir):
            os.makedirs(folder, exist_ok=True)

        for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"):
            renders, _, _ = self.rasterize_splats2(
                camtoworlds=camtoworlds_all[i : i + 1],
                Ks=Ks,
                width=width,
                height=height,
                sh_degree=sh_degree,
                near_plane=cfg.near_plane,
                far_plane=cfg.far_plane,
                render_mode="RGB+ED",
            )  # [1, H, W, 4]
            colors = torch.clamp(renders[..., 0:3], 0.0, 1.0)  # [1, H, W, 3]
            depths = renders[..., 3:4]  # [1, H, W, 1]

            # Per-frame min-max normalisation to [0, 1]. The range is clamped
            # away from zero so a constant depth map yields 0 instead of NaN.
            depth_min = depths.min()
            depth_range = torch.clamp(depths.max() - depth_min, min=1e-10)
            depths = (depths - depth_min) / depth_range

            rgb_image = self._to_uint8_image(colors)
            depth_image = self._to_uint8_image(depths.repeat(1, 1, 1, 3))
            # Side-by-side RGB | depth (concatenated along the width axis).
            combined_image = np.concatenate([rgb_image, depth_image], axis=1)

            name = image_names[i]
            for folder, image in (
                (rgb_dir, rgb_image),
                (depth_dir, depth_image),
                (combined_dir, combined_image),
            ):
                out_path = f"{folder}/{name}"
                # Image names may contain sub-directories; create them on demand.
                os.makedirs(os.path.dirname(out_path), exist_ok=True)
                imageio.imwrite(out_path, image)

        print(f"Images saved to {images_dir}")

In [5]:
# The checkpoint lives under cfg1.result_dir; derive the path from the config
# so the two cannot drift apart if result_dir is changed.
ckpt_path = f"{cfg1.result_dir}/ckpts/ckpt_29999_rank0.pt"

runner2 = Runner2(cfg1)
runner2.load_Gaussian_splats(ckpt_path)
# Render the splats trained on the partial sequence (cfg1) at the camera
# poses of the full sequence parsed by parser2.
runner2.render_by_poses(parser2)

[Parser] 30 images, taken by 1 cameras.
Scene scale: 1.63017629678518
Model initialized. Number of GS: 10071


Loading splats from results/apple_189_part_1_unnorm/ckpts/ckpt_29999_rank0.pt...
Running poses-based rendering...


Rendering trajectory: 100%|██████████| 202/202 [00:02<00:00, 81.84it/s] 


Images saved to results/apple_189_part_1_unnorm/images_poses
