In [None]:
import os
import torch
import json
from tqdm.auto import tqdm
from gsplat.sh import spherical_harmonics
from internal.models.gaussian_model_simplified import GaussianModelSimplified
from internal.renderers.gsplat_renderer import GSPlatRenderer
from internal.renderers.gsplat_hit_pixel_count_renderer import GSplatHitPixelCountRenderer
from internal.dataparsers.colmap_dataparser import Colmap, ColmapDataParser
from internal.configs.dataset import ColmapParams
from internal.utils.sh_utils import RGB2SH
from internal.utils.gaussian_model_loader import GaussianModelLoader
from internal.utils.light_gaussian import get_count_and_score, calculate_v_imp_score, get_prune_mask

In [None]:
# This notebook only scores and prunes Gaussians, so autograd is switched off globally.
# `torch.set_grad_enabled` is the same object as `torch.autograd.set_grad_enabled`.
torch.set_grad_enabled(False)

In [None]:
# Directory holding the partitioning results (`partitions.pt` is read from here).
# Its parent directory is the dataset root handed to the COLMAP data parser.
partition_base_dir = os.path.expanduser(
    "~/data/image_set/JNUCar_undistorted/colmap/drone/dense_max_2048/0/partitions-threshold_0.2/"
)

In [None]:
# Parse the COLMAP dataset that lives one level above the partition directory.
# `eval_step=32` is forwarded to the parser as-is (it controls the train/eval split there).
colmap_params = Colmap(
    appearance_groups="appearance_image_dedicated",
    eval_step=32,
)
dataparser = ColmapDataParser(
    os.path.join(partition_base_dir, ".."),
    output_path=os.getcwd(),
    global_rank=0,
    params=colmap_params,
)
dataparser_outputs = dataparser.get_outputs()

In [None]:
# Map each training image name to its position in `dataparser_outputs.train_set.cameras`.
# If a name occurs more than once, the last occurrence wins.
image_name_to_camera_idx = {
    image_name: camera_index
    for camera_index, image_name in enumerate(dataparser_outputs.train_set.image_names)
}
len(image_name_to_camera_idx)

In [None]:
def get_cameras_by_image_list(image_list: list, train_set=None, name_to_idx=None) -> list:
    """Look up training cameras by image name.

    Args:
        image_list: iterable of image names, e.g. the ``img_name`` entries of a
            partition's ``cameras.json``.
        train_set: object whose ``cameras`` attribute is indexable by camera index.
            Defaults to ``dataparser_outputs.train_set``.
        name_to_idx: mapping from image name to camera index.
            Defaults to ``image_name_to_camera_idx``.

    Returns:
        List of cameras in the same order as ``image_list``.

    Raises:
        KeyError: if any image name is absent from ``name_to_idx`` (e.g. an image that
            is not in the training split). The message reports how many names are
            missing and shows the first few, instead of failing on the first one.
    """
    if train_set is None:
        train_set = dataparser_outputs.train_set
    if name_to_idx is None:
        name_to_idx = image_name_to_camera_idx

    # Materialize once so generators work and can be scanned twice.
    image_names = list(image_list)

    missing = [name for name in image_names if name not in name_to_idx]
    if missing:
        raise KeyError(
            f"{len(missing)} of {len(image_names)} image(s) not found among the "
            f"training cameras, e.g. {missing[:5]}"
        )

    return [train_set.cameras[name_to_idx[name]] for name in image_names]

In [None]:
partitions = torch.load(
    os.path.join(partition_base_dir, "partitions.pt"),
    map_location="cpu",
)

orientation_transformation = partitions["orientation_transformation"]

# Partitions with fewer images than this are skipped.
MIN_IMAGES_PER_PARTITION = 32
# Base directory of the per-partition training outputs.
# NOTE(review): each output directory name ends in `.txt` (apparently named after the
# partition's image-list file) — confirm this matches the training run.
MODEL_OUTPUT_BASE_DIR = "../outputs/JNUAerial-0526/"

# (partition_id, model_output_dir) for every partition with enough images
model_paths = [
    (
        partition_id,
        os.path.join(
            MODEL_OUTPUT_BASE_DIR,
            f"P_{partition_id[0]:03d}_{partition_id[1]:03d}.txt",
        ),
    )
    for partition_index, partition_id in enumerate(partitions["ids"])
    if len(partitions["image_indices"][partition_index]) >= MIN_IMAGES_PER_PARTITION
]
partitions.keys(), model_paths

In [None]:
# Position of each partition id in `partitions["ids"]`, used to index per-partition
# arrays such as `partitions["xys"]`.
partition_id_to_index = dict(
    (partition_id, partition_index)
    for partition_index, partition_id in enumerate(partitions["ids"])
)
partition_id_to_index

In [None]:
# For each selected partition:
#   1. keep only the Gaussians whose reoriented x/y falls in the partition,
#   2. score them over the partition's cameras and prune a fraction of them,
#   3. write the pruned checkpoint (state_dict, optimizer moments, extra state) to
#      `<model_output_path>/pruned_checkpoints/`.

dtype = torch.float  # NOTE(review): not used in this cell
device = torch.device("cpu")  # checkpoints are loaded onto this device

n_sh_degrees = 0  # NOTE(review): not used in this cell; the SH degree is read from each checkpoint

# state_dict keys with this prefix hold per-Gaussian properties
DICT_KEY_PREFIX = "gaussian_model._"

# forwarded to `get_prune_mask`; also part of the output checkpoint file name
prune_percent = 0.6

# Only partitions listed here are processed. Set to None to process every entry of `model_paths`.
PARTITION_IDS_TO_PROCESS = [(0, 0)]

n_gaussians_before_pruning = 0
n_gaussians_after_pruning = 0

with tqdm(model_paths) as t:
    for partition_id, model_output_path in t:
        if PARTITION_IDS_TO_PROCESS is not None and partition_id not in PARTITION_IDS_TO_PROCESS:
            continue

        partition_xy = partitions["xys"][partition_id_to_index[partition_id]]
        load_file = GaussianModelLoader.search_load_file(model_output_path)
        t.set_description(f"{partition_xy}: {load_file}")
        ckpt = torch.load(load_file, map_location=device)

        # Mask of Gaussians whose reoriented x/y lies in the partition
        # (include min bound, exclude max bound).
        xyz = ckpt["state_dict"]["gaussian_model._xyz"]
        reoriented_xyz = xyz @ orientation_transformation[:3, :3].T
        partition_max_xy = partition_xy + 2 * partitions["radius"]
        is_in_partition = torch.logical_and(
            torch.ge(reoriented_xyz[:, :2], partition_xy),
            torch.lt(reoriented_xyz[:, :2], partition_max_xy),
        )
        is_in_partition = torch.logical_and(is_in_partition[:, 0], is_in_partition[:, 1])

        # Per-Gaussian properties restricted to the in-partition Gaussians
        state_dict = {
            key: value[is_in_partition]
            for key, value in ckpt["state_dict"].items()
            if key.startswith(DICT_KEY_PREFIX)
        }
        gaussian_model = GaussianModelSimplified.construct_from_state_dict(
            state_dict,
            active_sh_degree=ckpt["hyper_parameters"]["gaussian"].sh_degree,
            device="cuda",
        )

        n_gaussians_before_pruning += gaussian_model.get_xyz.shape[0]

        # Cameras this partition was trained with, as listed in its `cameras.json`
        with open(os.path.join(model_output_path, "cameras.json"), "r") as f:
            cameras_json = json.load(f)
        image_list = [camera["img_name"] for camera in cameras_json]
        cameras = get_cameras_by_image_list(image_list)

        # Score the in-partition Gaussians over those cameras
        hit_camera_count, opacity_score, alpha_score, visibility_score = get_count_and_score(
            gaussian_model,
            cameras,
            anti_aliased=True,
        )

        # Prune by the importance score built from the Gaussians' scaling and opacity score.
        # (A disabled alternative ranked Gaussians by `visibility_score` instead.)
        v_imp_score = calculate_v_imp_score(gaussian_model.get_scaling, opacity_score, 0.1)
        prune_mask = get_prune_mask(prune_percent, v_imp_score)
        # indices (relative to the in-partition Gaussians) of the Gaussians to keep
        local_indices_to_preserved = (~prune_mask).nonzero().squeeze(-1).cpu()

        for key in state_dict:
            state_dict[key] = state_dict[key][local_indices_to_preserved]

        # Update the checkpoint's state_dict.
        # NOTE: the code related to the `static part` below has not been released yet.
        # Rather than moving Gaussians to the `static part`, you should prune the
        # `state_dict`, `optimizer_states` and `gaussian_model_extra_state_dict` of `ckpt`
        # according to `local_indices_to_preserved`.
        for key in state_dict:
            # Move the Gaussians outside the partition to the static part,
            # which is not optimized during fine-tuning.
            static_gaussian_property_key = f"static_{key}"
            static_gaussian_property = ckpt["state_dict"][key][~is_in_partition]
            if static_gaussian_property_key in ckpt["state_dict"]:
                original_static_gaussian_num = ckpt["state_dict"][static_gaussian_property_key].shape[0]
                new_static_gaussian_num = static_gaussian_property.shape[0]
                static_gaussian_property = torch.concat([
                    ckpt["state_dict"][static_gaussian_property_key],
                    static_gaussian_property,
                ], dim=0)
                print(f"#{partition_id}: {original_static_gaussian_num} static Gaussians exists, merge with {new_static_gaussian_num} new static Gaussians, total {static_gaussian_property.shape[0]} after merging")
            ckpt["state_dict"][static_gaussian_property_key] = static_gaussian_property
            # optimizable Gaussians are only the kept ones located in the partition
            ckpt["state_dict"][key] = state_dict[key]

        # Prune the optimizer moments the same way: first to the partition, then to the kept Gaussians.
        # NOTE(review): assumes every optimizer state tensor is indexed per Gaussian — confirm.
        for param_state in ckpt["optimizer_states"][0]["state"].values():
            for moment_key in ["exp_avg", "exp_avg_sq"]:
                param_state[moment_key] = param_state[moment_key][is_in_partition][local_indices_to_preserved]

        # Prune the extra per-Gaussian state the same way
        extra_state_dict = ckpt["gaussian_model_extra_state_dict"]
        for extra_key in ["max_radii2D", "xyz_gradient_accum", "denom"]:
            extra_state_dict[extra_key] = extra_state_dict[extra_key][is_in_partition][local_indices_to_preserved]

        n_gaussians_after_pruning += local_indices_to_preserved.shape[0]

        # save checkpoint
        checkpoint_save_dir = os.path.join(model_output_path, "pruned_checkpoints")
        os.makedirs(checkpoint_save_dir, exist_ok=True)
        torch.save(ckpt, os.path.join(checkpoint_save_dir, f"latest-opacity_pruned-{prune_percent}.ckpt"))

f"{n_gaussians_after_pruning} / {n_gaussians_before_pruning}"