In [None]:
import os
import torch
import json
from tqdm.auto import tqdm
from gsplat.sh import spherical_harmonics
from internal.models.gaussian_model_simplified import GaussianModelSimplified
from internal.renderers.gsplat_renderer import GSPlatRenderer
from internal.renderers.gsplat_hit_pixel_count_renderer import GSplatHitPixelCountRenderer
from internal.dataparsers.colmap_dataparser import Colmap, ColmapDataParser
from internal.utils.sh_utils import RGB2SH
from internal.utils.gaussian_model_loader import GaussianModelLoader

In [None]:
# This notebook only runs inference / tensor surgery, so turn autograd off globally.
torch.set_grad_enabled(False)

In [None]:
# Directory holding the partitioning results (`partitions.pt` and one `XXX_YYY.txt` image list
# per partition); its parent directory is used as the COLMAP dataset root.
partition_base_dir = os.path.expanduser(
    "~/data/image_set/JNUCar_undistorted/colmap/drone/dense_max_2048/0/partitions-threshold_0.2/")

In [None]:
# Parse the COLMAP scene that the partitions were created from (parent of `partition_base_dir`).
colmap_params = Colmap(
    appearance_groups="appearance_image_dedicated",
    eval_step=32,
)
colmap_parser = ColmapDataParser(
    os.path.join(partition_base_dir, ".."),
    output_path=os.getcwd(),
    global_rank=0,
    params=colmap_params,
)
dataparser_outputs = colmap_parser.get_outputs()

In [None]:
# image name -> index into `dataparser_outputs.train_set.cameras` (a later duplicate name wins)
image_name_to_camera_idx = {
    image_name: camera_idx
    for camera_idx, image_name in enumerate(dataparser_outputs.train_set.image_names)
}
image_name_to_camera_idx

In [None]:
def get_cameras_by_image_list(image_list: list):
    """Return the training cameras for the given image names, in the same order.

    Looks each name up in `image_name_to_camera_idx`; raises KeyError for a name that
    is not part of the training set.
    """
    return [
        dataparser_outputs.train_set.cameras[image_name_to_camera_idx[image_name]]
        for image_name in image_list
    ]

In [None]:
partitions = torch.load(os.path.join(partition_base_dir, "partitions.pt"), map_location="cpu")

orientation_transformation = partitions["orientation_transformation"]

# (partition id, model output directory) for every partition that was assigned at least
# 32 images; each output directory is named after its partition id, e.g. `P_000_001.txt`.
model_paths = [
    (partition_id, os.path.join(
        "../outputs/JNUAerial-0526/",
        f"P_{partition_id[0]:03d}_{partition_id[1]:03d}.txt"))
    for partition_index, partition_id in enumerate(partitions["ids"])
    if len(partitions["image_indices"][partition_index]) >= 32
]
partitions.keys(), model_paths

In [None]:
# partition id -> position of that partition in the per-partition arrays of `partitions`
partition_id_to_index = {}
for partition_index, partition_id in enumerate(partitions["ids"]):
    partition_id_to_index[partition_id] = partition_index
partition_id_to_index

In [None]:
# group name -> list of image names
with open(os.path.join(partition_base_dir, "..", "appearance_image_dedicated.json"), "r") as f:
    appearance_groups = json.load(f)
# image name -> group name (inverse of the mapping above; a later group wins on duplicates)
image_name_to_group_name = {
    image_name: group_name
    for group_name, image_names in appearance_groups.items()
    for image_name in image_names
}

In [None]:
# def calculate_partition_features_dc(model_output_path, partition_id, ckpt, is_in_partition):
#     # the ids of the groups in this model
#     appearance_group_ids = torch.load(os.path.join(model_output_path, "appearance_group_ids.pth"))
# 
#     # load partition image list, then get their correspond appearance ids
#     partition_used_appearance_ids = []
#     with open(os.path.join(partition_base_dir, f"{partition_id[0]:03d}_{partition_id[1]:03d}.txt"), "r") as f:
#         for row in f:
#             image_group_name = image_name_to_group_name[row.rstrip("\n")]
#             partition_used_appearance_ids.append(appearance_group_ids[image_group_name][0])
# 
#     # get average appearance embeddings of these used images
#     average_appearance_embedding = ckpt["hyper_parameters"]["renderer"].model.embedding(
#         torch.tensor(partition_used_appearance_ids, dtype=torch.int)
#     ).mean(dim=0)
# 
#     # calculate rgb_offset from appearance embedding
#     partition_features_extra = ckpt["state_dict"]["gaussian_model._features_extra"][is_in_partition]
#     appearance_mlp_input = torch.concat([partition_features_extra, average_appearance_embedding.unsqueeze(0).repeat(
#         partition_features_extra.shape[0], 1)], dim=-1)
#     rgb_offset = ckpt["hyper_parameters"]["renderer"].model.network.to("cuda")(appearance_mlp_input.to("cuda")) * 2 - 1.
# 
#     # calculate base_rgb from features_dc
#     base_rgb = spherical_harmonics(0, torch.ones(partition_features_extra.shape[0], 3).to("cuda"),
#                                    ckpt["state_dict"]["gaussian_model._features_dc"][is_in_partition].to("cuda")) + 0.5
# 
#     # calculate final rgb values
#     rgbs = (base_rgb + rgb_offset).clamp(0., 1.)
# 
#     # convert to SHs
#     partition_features_dc = RGB2SH(rgbs).unsqueeze(1)
# 
#     return partition_features_dc

In [None]:
def calculate_gaussian_scores(cameras, gaussian_model):
    """Compute each Gaussian's visibility score for every camera.

    Calls `GSplatHitPixelCountRenderer.hit_pixel_count` once per camera (camera moved to CUDA)
    and keeps only the visibility score it returns; the hit count, opacity and alpha scores
    are discarded.

    Args:
        cameras: sequence of cameras exposing `to_device("cuda")`.
        gaussian_model: model exposing `get_xyz`, `get_opacity`, `get_scaling`, `get_rotation`.

    Returns:
        float32 CPU tensor of shape [N_cameras, N_gaussians].
        NOTE(review): this is a dense matrix, so memory grows as N_cameras * N_gaussians * 4 bytes.
    """
    # The Gaussian parameters do not change between cameras; read them once.
    means3D = gaussian_model.get_xyz
    opacities = gaussian_model.get_opacity
    scales = gaussian_model.get_scaling
    rotations = gaussian_model.get_rotation

    all_visibility_score = torch.zeros((len(cameras), means3D.shape[0]), dtype=torch.float, device="cpu")
    for idx, camera in tqdm(enumerate(cameras), total=len(cameras)):
        _, _, _, visibility_score = GSplatHitPixelCountRenderer.hit_pixel_count(
            means3D=means3D,
            opacities=opacities,
            scales=scales,
            rotations=rotations,
            viewpoint_camera=camera.to_device("cuda"),
        )
        all_visibility_score[idx] = visibility_score.cpu()

    torch.cuda.empty_cache()

    return all_visibility_score

In [None]:
def calculate_partition_features_dc(image_list, ckpt, is_in_partition, n_average_cameras: int = 32):
    """Bake the appearance-embedding colour offset of one partition into `features_dc`.

    Steps:
      1. Keep only the Gaussians selected by `is_in_partition`.
      2. Compute each kept Gaussian's visibility score for every camera in `image_list`.
      3. Drop Gaussians whose total visibility is (numerically) zero.
      4. For each remaining Gaussian, average the appearance embeddings of its most visible
         cameras (at most `n_average_cameras`), weighted by visibility.
      5. Run the renderer's appearance network on `[features_extra, averaged embedding]` and add
         the output, converted with `RGB2SH`, to `features_dc`.

    Args:
        image_list: image names of the cameras used to train this partition.
        ckpt: partition checkpoint dict; reads `ckpt["state_dict"]["gaussian_model._*"]`,
            `ckpt["hyper_parameters"]["gaussian"].sh_degree` and
            `ckpt["hyper_parameters"]["renderer"].model.{embedding, network}`.
        is_in_partition: boolean mask over the checkpoint's Gaussians.
        n_average_cameras: maximum number of cameras averaged per Gaussian; clamped to the
            number of cameras actually available.

    Returns:
        (new_features_dc, gaussian_to_preserve):
          - new_features_dc: updated DC features of the preserved Gaussians (CPU).
          - gaussian_to_preserve: boolean mask over the in-partition Gaussians marking which
            ones were preserved.
    """
    # cameras in the same order as `image_list`; appearance ids are indexed by camera position
    cameras = get_cameras_by_image_list(image_list)
    camera_index_to_appearance_id = torch.tensor([camera.appearance_id for camera in cameras], dtype=torch.int)

    # Gaussians located in the partition
    state_dict = {
        key: value[is_in_partition]
        for key, value in ckpt["state_dict"].items()
        if key.startswith("gaussian_model._")
    }

    gaussian_model = GaussianModelSimplified.construct_from_state_dict(
        state_dict,
        active_sh_degree=ckpt["hyper_parameters"]["gaussian"].sh_degree,
        device="cuda",
    )

    # visibility score of each Gaussian for each camera
    visibility_score = calculate_gaussian_scores(cameras, gaussian_model).T  # [N_gaussians, N_cameras]
    del gaussian_model

    # prune Gaussians whose total visibility is close to zero
    gaussian_to_preserve = ~torch.isclose(torch.sum(visibility_score, dim=-1), torch.tensor(0.))
    state_dict = {key: value[gaussian_to_preserve] for key, value in state_dict.items()}
    visibility_score_pruned = visibility_score[gaussian_to_preserve]
    del visibility_score

    # top-k most visible cameras per Gaussian; topk raises if k exceeds the number of cameras
    k = min(n_average_cameras, visibility_score_pruned.shape[-1])
    top_k = torch.topk(visibility_score_pruned, k=k, dim=-1)
    # normalise the top-k visibilities into per-camera weights
    top_k_pdf = top_k.values / torch.sum(top_k.values, dim=-1, keepdim=True)
    assert torch.all(torch.isclose(top_k_pdf.sum(dim=-1), torch.tensor(1.)))

    # appearance embeddings of the selected cameras
    n_preserved = top_k.indices.shape[0]
    appearance_ids = camera_index_to_appearance_id[top_k.indices.reshape(-1)]  # [N_gaussians * k]
    flat_embeddings = ckpt["hyper_parameters"]["renderer"].model.embedding(appearance_ids)
    # explicit last dim so the reshape also works when no Gaussian survived pruning
    appearance_embeddings = flat_embeddings.reshape(
        (n_preserved, k, flat_embeddings.shape[-1]))  # [N_gaussians, k, N_embedding_dims]
    # visibility-weighted average -> one embedding per Gaussian
    final_appearance_embeddings = torch.sum(appearance_embeddings * top_k_pdf.unsqueeze(-1), dim=1)

    # appearance network forward -> colour offset
    embedding_network = ckpt["hyper_parameters"]["renderer"].model.network.to("cuda")
    input_tensor = torch.concat([
        state_dict["gaussian_model._features_extra"],
        final_appearance_embeddings,
    ], dim=-1).to("cuda")
    rgb_offset = embedding_network(input_tensor)
    # NOTE(review): assuming `RGB2SH(x)` is the usual `(x - 0.5) / C0`, this treats the network
    # output as a colour that already contains the +0.5 DC bias. If the appearance renderer adds
    # the output on top of `SH + 0.5` (possibly after rescaling it, e.g. `* 2 - 1`), an additive
    # RGB offset `d` must instead be converted with `RGB2SH(d + 0.5)`. Confirm against
    # GSplatAppearanceEmbeddingRenderer.
    sh_offset = RGB2SH(rgb_offset)

    new_features_dc = state_dict["gaussian_model._features_dc"] + sh_offset.unsqueeze(1).cpu()

    return new_features_dc, gaussian_to_preserve

In [None]:
dtype = torch.float
device = torch.device("cpu")

# SH degree of the partition models. Only used to shape the empty result when no partition
# is merged; otherwise every tensor's shape comes from the partition checkpoints.
n_sh_degrees = 0


def _concat_or_empty(parts, empty_shape):
    """Concatenate `parts` along dim 0, or return an empty `dtype`/`device` tensor if there are none."""
    if len(parts) == 0:
        return torch.zeros(empty_shape, dtype=dtype, device=device)
    return torch.concat(parts)


# Per-partition tensors are collected in lists and concatenated once after the loop;
# concatenating inside the loop would re-copy everything merged so far on every iteration.
xyz_parts = []
features_dc_parts = []
features_rest_parts = []
scale_parts = []
rotation_parts = []
opacity_parts = []

with tqdm(model_paths) as t:
    for partition_id, model_output_path in t:
        # names of the images this partition model was trained on
        with open(os.path.join(model_output_path, "cameras.json"), "r") as f:
            cameras_json = json.load(f)
        image_list = [camera_json["img_name"] for camera_json in cameras_json]

        partition_xy = partitions["xys"][partition_id_to_index[partition_id]]
        load_file = GaussianModelLoader.search_load_file(model_output_path)
        t.set_description(f"{partition_xy}: {load_file}")
        # NOTE: `ckpt` stays bound to the last partition's checkpoint after the loop; later cells
        # reuse it as the template for the merged checkpoint.
        ckpt = torch.load(load_file, map_location=device)
        partition_state_dict = ckpt["state_dict"]
        xyz = partition_state_dict["gaussian_model._xyz"]
        # only the 3x3 rotation part of `orientation_transformation` is applied
        reoriented_xyz = xyz @ orientation_transformation[:3, :3].T
        # include min bound, exclude max bound
        is_in_partition = torch.logical_and(torch.ge(reoriented_xyz[:, :2], partition_xy),
                                            torch.lt(reoriented_xyz[:, :2], partition_xy + 2 * partitions["radius"]))
        is_in_partition = torch.logical_and(is_in_partition[:, 0], is_in_partition[:, 1])

        # [NOTE]: `calculate_partition_features_dc` is related to `GSplatAppearanceEmbeddingRenderer`.
        # If that renderer was not used, replace the call below with the original features, e.g.
        #   new_features_dc = partition_state_dict["gaussian_model._features_dc"][is_in_partition]
        #   gaussian_to_preserve = torch.ones(new_features_dc.shape[0], dtype=torch.bool)
        new_features_dc, gaussian_to_preserve = calculate_partition_features_dc(image_list, ckpt, is_in_partition)
        gaussian_to_preserve = gaussian_to_preserve.to(device)

        xyz_parts.append(partition_state_dict["gaussian_model._xyz"][is_in_partition][gaussian_to_preserve])
        features_dc_parts.append(new_features_dc.to(device))
        features_rest_parts.append(
            partition_state_dict["gaussian_model._features_rest"][is_in_partition][gaussian_to_preserve])
        scale_parts.append(
            partition_state_dict["gaussian_model._scaling"][is_in_partition][gaussian_to_preserve])
        rotation_parts.append(
            partition_state_dict["gaussian_model._rotation"][is_in_partition][gaussian_to_preserve])
        opacity_parts.append(
            partition_state_dict["gaussian_model._opacity"][is_in_partition][gaussian_to_preserve])

        torch.cuda.empty_cache()

xyzs = _concat_or_empty(xyz_parts, (0, 3))
features_dc = _concat_or_empty(features_dc_parts, (0, 1, 3))
features_rest = _concat_or_empty(features_rest_parts, (0, ((n_sh_degrees + 1) ** 2 - 1), 3))
scales = _concat_or_empty(scale_parts, (0, 3))
rotations = _concat_or_empty(rotation_parts, (0, 4))
opacities = _concat_or_empty(opacity_parts, (0, 1))

In [None]:
# model_output_path = i[1]
# model_output_path

In [None]:
# # load partition image list, then get their correspond appearance ids
# partition_id = i[0]
# partition_used_appearance_ids = []
# image_list = []
# with open(os.path.join(partition_base_dir, f"{partition_id[0]:03d}_{partition_id[1]:03d}.txt"), "r") as f:
#     for row in f:
#         image_list.append(row.rstrip("\n"))
# 
# cameras = get_cameras_by_image_list(image_list)
# state_dict = {}
# for i in ckpt["state_dict"]:
#     if i.startswith("gaussian_model._"):
#         state_dict[i] = ckpt["state_dict"][i][is_in_partition]
# gaussian_model = GaussianModelSimplified.construct_from_state_dict(
#     state_dict,
#     active_sh_degree=ckpt["hyper_parameters"]["gaussian"].sh_degree,
#     device="cuda",
# )

In [None]:
# for i in state_dict:
#     ckpt["state_dict"][i] = state_dict[i]
# torch.save(ckpt, os.path.join(model_output_path, "checkpoints", "pruned.ckpt"))

In [None]:
# Force a garbage-collection pass after the merge loop.
# NOTE(review): presumably meant to release per-partition checkpoint objects; effect not verified.
import gc

gc.collect()

In [None]:
# # get the contribution to every camera of each Gaussian
# visibility_score = calculate_gaussian_scores(cameras, gaussian_model)  # [N_cameras, N_gaussians]

In [None]:
# visibility_score = visibility_score.T

In [None]:
# visibility_score.shape  # [N_gaussians, N_cameras]

In [None]:
# visibility_score_acc = torch.sum(visibility_score, dim=-1)

In [None]:
# visibility_score_acc_is_close_to_zero = torch.isclose(visibility_score_acc, torch.tensor(0.))
# visibility_score_acc_is_close_to_zero.sum()

In [None]:
# visibility_score_acc_is_close_to_zero.shape

In [None]:
# for i in state_dict:
#     state_dict[i] = state_dict[i][~visibility_score_acc_is_close_to_zero]
# for i in state_dict:
#     ckpt["state_dict"][i] = state_dict[i]
# torch.save(ckpt, os.path.join(model_output_path, "checkpoints", "pruned-closed-to-zero.ckpt"))

In [None]:
# visibility_score_pruned = visibility_score[~visibility_score_acc_is_close_to_zero]

In [None]:
# # get top `n_average_cameras` camera by contribution
# n_average_cameras = 32
# visibility_score_pruned_sorted = torch.topk(visibility_score_pruned, k=n_average_cameras, dim=-1)

In [None]:
# visibility_score_pruned_sorted.indices.shape

In [None]:
# visibility_score_pruned_sorted.indices[0], visibility_score_pruned_sorted.values[0]

In [None]:
# If the visibility of any of the 2nd..K-th cameras is close to zero, fall back to the 1st camera
# visibility_score_sorted.indices[:, 1:] = torch.where(
#     torch.isclose(visibility_score_sorted.values[:, 1:], torch.tensor(0.)),
#     visibility_score_sorted.indices[:, :1],
#     visibility_score_sorted.indices[:, 1:],
# )
# visibility_score_sorted.indices[0], visibility_score_sorted.values[0]

In [None]:
# visibility_score_pruned_top_k_acc = torch.sum(visibility_score_pruned_sorted.values, dim=-1, keepdim=True)
# visibility_score_pruned_top_k_acc.shape

In [None]:
# calculate the weight of each camera
# visibility_score_pruned_top_k_pdf = visibility_score_pruned_sorted.values / visibility_score_pruned_top_k_acc
# visibility_score_pruned_top_k_pdf.shape, torch.all(
#     torch.isclose(visibility_score_pruned_top_k_pdf.sum(dim=-1), torch.tensor(1.)))

In [None]:
# camera_index_to_appearance_id = torch.tensor([i.appearance_id for i in cameras], dtype=torch.int)
# camera_index_to_appearance_id

In [None]:
# appearance_ids = camera_index_to_appearance_id[visibility_score_pruned_sorted.indices.reshape(-1)]
# appearance_ids.shape

In [None]:
# appearance_embeddings = ckpt["hyper_parameters"]["renderer"].model.embedding(appearance_ids).reshape(
#     (visibility_score_pruned_sorted.indices.shape[0], n_average_cameras, -1))

In [None]:
# weighted_appearance_embeddings = appearance_embeddings * visibility_score_pruned_top_k_pdf.unsqueeze(-1)

In [None]:
# final_appearance_embeddings = torch.sum(weighted_appearance_embeddings, dim=1)
# final_appearance_embeddings.shape

In [None]:
# embedding_network = ckpt["hyper_parameters"]["renderer"].model.network.to("cuda")

In [None]:
# input_tensor = torch.concat([state_dict["gaussian_model._features_extra"], final_appearance_embeddings], dim=-1).to(
#     "cuda")

In [None]:
# rgb_offset = embedding_network(input_tensor)
# rgb_offset

In [None]:
# sh_offset = RGB2SH(rgb_offset)

In [None]:
# state_dict["gaussian_model._features_dc"].shape, sh_offset.unsqueeze(1).cpu().shape

In [None]:
# new_features_dc = state_dict["gaussian_model._features_dc"] + sh_offset.unsqueeze(1).cpu()

In [None]:
# state_dict["gaussian_model._features_dc"] = new_features_dc

In [None]:
# Drop the appearance-renderer weights (`renderer.*`) from the state_dict; the merged
# checkpoint is switched to a plain GSPlatRenderer.
state_dict_key_to_delete = [key for key in ckpt["state_dict"] if key.startswith("renderer.")]
for key in state_dict_key_to_delete:
    del ckpt["state_dict"][key]

In [None]:
# ckpt["hyper_parameters"]["renderer"] = GSPlatRenderer()

In [None]:
# for i in state_dict:
#     ckpt["state_dict"][i] = state_dict[i]
# torch.save(ckpt, os.path.join(model_output_path, "checkpoints", "pruned-gsplat-vanilla-renderer.ckpt"))

In [None]:
# Replace the Gaussian parameters of the (last loaded) checkpoint with the merged ones.
# `_features_extra` becomes a zero-width placeholder because appearance is baked into
# `features_dc`; the renderer becomes the vanilla GSPlatRenderer.
merged_gaussians = {
    "gaussian_model._xyz": xyzs,
    "gaussian_model._features_dc": features_dc,
    "gaussian_model._features_rest": features_rest,
    "gaussian_model._features_extra": torch.empty((features_rest.shape[0], 0)),
    "gaussian_model._scaling": scales,
    "gaussian_model._rotation": rotations,
    "gaussian_model._opacity": opacities,
}
for key, value in merged_gaussians.items():
    ckpt["state_dict"][key] = value
ckpt["hyper_parameters"]["renderer"] = GSPlatRenderer()

In [None]:
# Save the merged checkpoint (last partition's checkpoint with the merged Gaussians and the
# vanilla renderer) to the current working directory.
torch.save(ckpt, "jnu_aerial-0526.ckpt")

Merging is complete at this point; the merged checkpoint has been saved as `jnu_aerial-0526.ckpt`.

# [NOTE] The cells below are optional and for experimental purposes only

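## Update a specific partition

Replace the Gaussians of one partition in an existing checkpoint with those from a retrained partition checkpoint.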

In [None]:
dtype = torch.float
device = torch.device("cpu")

# Checkpoint whose target partition will be replaced (presumably a previously merged and
# edited checkpoint -- confirm).
input_path = "../edited/20240415_232032.ckpt"
ckpt = torch.load(input_path, map_location=device)

In [None]:
# Gaussian positions rotated into the partitioning frame (3x3 rotation part only).
xyz = ckpt["state_dict"]["gaussian_model._xyz"]
reoriented_xyz = xyz @ orientation_transformation[:3, :3].T

In [None]:
# Partition to replace, and the checkpoint of its retrained model.
target_partition_id = (0, 0)
target_partition_ckpt = "../outputs/JNUAerial/P_000_000.txt-random_background/checkpoints/epoch=19-step=30000.ckpt"

In [None]:
# Minimum (x, y) corner of the target partition in the reoriented frame.
partition_xy = partitions["xys"][partition_id_to_index[target_partition_id]]
partition_xy

In [None]:
# Mask of the Gaussians inside the partition: include min bound, exclude max bound.
partition_max_xy = partition_xy + 2 * partitions["radius"]
reoriented_xy = reoriented_xyz[:, :2]
is_in_partition = (torch.ge(reoriented_xy, partition_xy) & torch.lt(reoriented_xy, partition_max_xy)).all(dim=-1)
is_in_partition.sum()

In [None]:
# Keep only the Gaussians outside the target partition; the retrained partition's
# Gaussians are appended in a later cell.
not_in_partition = ~is_in_partition
edited_state_dict = ckpt["state_dict"]
xyzs = edited_state_dict["gaussian_model._xyz"][not_in_partition]
features_dc = edited_state_dict["gaussian_model._features_dc"][not_in_partition]
features_rest = edited_state_dict["gaussian_model._features_rest"][not_in_partition]
scales = edited_state_dict["gaussian_model._scaling"][not_in_partition]
rotations = edited_state_dict["gaussian_model._rotation"][not_in_partition]
opacities = edited_state_dict["gaussian_model._opacity"][not_in_partition]
xyzs.shape

In [None]:
# Load the retrained partition model and rotate its Gaussians into the partitioning frame.
partition_ckpt = torch.load(target_partition_ckpt, map_location=device)
reoriented_partition_xyz = partition_ckpt["state_dict"]["gaussian_model._xyz"] @ orientation_transformation[:3, :3].T

In [None]:
# Mask of the retrained model's Gaussians inside the partition (min bound included, max excluded).
partition_max_xy = partition_xy + 2 * partitions["radius"]
reoriented_partition_xy = reoriented_partition_xyz[:, :2]
is_in_partition = (
    torch.ge(reoriented_partition_xy, partition_xy) & torch.lt(reoriented_partition_xy, partition_max_xy)
).all(dim=-1)
is_in_partition.sum()

In [None]:
# Append the retrained model's Gaussians that lie inside the target partition.
partition_state_dict = partition_ckpt["state_dict"]
xyzs = torch.concat([xyzs, partition_state_dict["gaussian_model._xyz"][is_in_partition]])
features_dc = torch.concat([features_dc, partition_state_dict["gaussian_model._features_dc"][is_in_partition]])
features_rest = torch.concat([features_rest, partition_state_dict["gaussian_model._features_rest"][is_in_partition]])
scales = torch.concat([scales, partition_state_dict["gaussian_model._scaling"][is_in_partition]])
rotations = torch.concat([rotations, partition_state_dict["gaussian_model._rotation"][is_in_partition]])
opacities = torch.concat([opacities, partition_state_dict["gaussian_model._opacity"][is_in_partition]])

In [None]:
# Write the updated Gaussians back into the checkpoint.
ckpt["state_dict"]["gaussian_model._xyz"] = xyzs
ckpt["state_dict"]["gaussian_model._features_dc"] = features_dc
ckpt["state_dict"]["gaussian_model._features_rest"] = features_rest
ckpt["state_dict"]["gaussian_model._scaling"] = scales
ckpt["state_dict"]["gaussian_model._rotation"] = rotations
ckpt["state_dict"]["gaussian_model._opacity"] = opacities

# Every per-Gaussian tensor needs one row per Gaussian. Merged checkpoints store `_features_extra`
# as a zero-width placeholder; resize it to the new Gaussian count instead of keeping the stale one.
features_extra = ckpt["state_dict"].get("gaussian_model._features_extra")
if features_extra is not None:
    assert features_extra.shape[-1] == 0, \
        "non-empty `_features_extra` must be merged like the other Gaussian attributes"
    ckpt["state_dict"]["gaussian_model._features_extra"] = features_extra.new_empty((xyzs.shape[0], 0))

In [None]:
# Save the checkpoint with the replaced partition to the current working directory.
torch.save(ckpt, "jnu_aerial_new.ckpt")

## Pruning with LightGaussian scores

In [None]:
from internal.utils.light_gaussian import get_count_and_score, calculate_v_imp_score, get_prune_mask

In [None]:
# get partition image list
partition_id = (0, 0)
image_list = []
with open(os.path.join(partition_base_dir, f"{partition_id[0]:03d}_{partition_id[1]:03d}.txt"), "r") as f:
    for row in f:
        image_list.append(row.rstrip("\n"))

# get camera correspond to the image list
cameras = get_cameras_by_image_list(image_list)

In [None]:
# Checkpoint of the model trained for `partition_id` (the `P_XXX_YYY.txt` directory must match it).
ckpt = torch.load("../outputs/JNUAerial-0526/P_000_000.txt/checkpoints/epoch=100-step=132700.ckpt", map_location="cpu")

In [None]:
# Minimum (x, y) corner of the partition in the reoriented frame.
partition_xy = partitions["xys"][partition_id_to_index[partition_id]]
partition_xy

In [None]:
# Gaussian positions rotated into the partitioning frame (3x3 rotation part only).
xyz = ckpt["state_dict"]["gaussian_model._xyz"]
reoriented_xyz = xyz @ orientation_transformation[:3, :3].T
# Mask of the Gaussians inside the partition: include min bound, exclude max bound.
partition_max_xy = partition_xy + 2 * partitions["radius"]
reoriented_xy = reoriented_xyz[:, :2]
is_in_partition = (torch.ge(reoriented_xy, partition_xy) & torch.lt(reoriented_xy, partition_max_xy)).all(dim=-1)

In [None]:
# Gaussians located in the partition: every `gaussian_model._*` tensor masked by `is_in_partition`.
state_dict = {
    key: value[is_in_partition]
    for key, value in ckpt["state_dict"].items()
    if key.startswith("gaussian_model._")
}

In [None]:
# Build a CUDA model from the in-partition Gaussians; only its geometry is used below.
# NOTE(review): the positional `0` is presumably `active_sh_degree` (passed by keyword elsewhere).
gaussian_model = GaussianModelSimplified.construct_from_state_dict(state_dict, 0, device="cuda")

In [None]:
# Per-Gaussian accumulators for the scores returned by `hit_pixel_count`, summed over cameras.
n_gaussians = gaussian_model.get_xyz.shape[0]
score_device = gaussian_model.get_xyz.device
hit_camera_count_total = torch.zeros((n_gaussians,), dtype=torch.int, device=score_device)
opacity_score_total = torch.zeros((n_gaussians,), dtype=torch.float, device=score_device)
alpha_score_total = torch.zeros((n_gaussians,), dtype=torch.float, device=score_device)
visibility_score_total = torch.zeros((n_gaussians,), dtype=torch.float, device=score_device)

In [None]:
# Sum each score over all cameras of the partition.
for camera in tqdm(cameras, total=len(cameras)):
    scores = GSplatHitPixelCountRenderer.hit_pixel_count(
        means3D=gaussian_model.get_xyz,
        opacities=gaussian_model.get_opacity,
        scales=gaussian_model.get_scaling,
        rotations=gaussian_model.get_rotation,
        viewpoint_camera=camera.to_device("cuda"),
    )
    hit_count, opacity_score, alpha_score, visibility_score = scores
    hit_camera_count_total.add_(hit_count)
    opacity_score_total.add_(opacity_score)
    alpha_score_total.add_(alpha_score)
    visibility_score_total.add_(visibility_score)

In [None]:
# Gaussians never (numerically) visible from any camera, and the total Gaussian count.
zero_visibility = torch.zeros_like(visibility_score_total)
visibility_close_to_zero = torch.isclose(visibility_score_total, zero_visibility)
visibility_close_to_zero.sum(), visibility_score_total.shape[0]

In [None]:
# gaussian_model.delete_gaussians(visibility_close_to_zero)

In [None]:
# visible_gaussians = torch.bitwise_not(visibility_close_to_zero)
# hit_camera_count_total = hit_camera_count_total[visible_gaussians]
# opacity_score_total = opacity_score_total[visible_gaussians]
# alpha_score_total = alpha_score_total[visible_gaussians]
# visibility_score_total = visibility_score_total[visible_gaussians]

In [None]:
# (visibility_score_total > 1).sum(), visibility_score_total.shape[0]

In [None]:
# Gaussians ordered from most to least visible.
visibility_score_total_sorted = visibility_score_total.sort(descending=True)
visibility_score_total_sorted

In [None]:
# Keep the 10% most visible Gaussians.
n_preserve = int(visibility_score_total_sorted.indices.shape[0] * 0.1)
preserve_indices = visibility_score_total_sorted.indices[:n_preserve]
preserve_indices.shape[0]

In [None]:
# Write the kept in-partition Gaussians back into the checkpoint, on the checkpoint's device.
for key, partition_value in state_dict.items():
    target_device = ckpt["state_dict"][key].device
    ckpt["state_dict"][key] = partition_value[preserve_indices.to(target_device)].to(target_device)

In [None]:
# Save the checkpoint containing only the 10% most visible in-partition Gaussians.
torch.save(ckpt, "visibility_pruned.ckpt")

### Pruning by `v_imp_score`

In [None]:
# LightGaussian importance score from the Gaussians' scaling and the accumulated opacity score.
# NOTE(review): the meaning of `0.1` is defined by `calculate_v_imp_score` (presumably the volume
# exponent) -- confirm in internal.utils.light_gaussian.
v_imp_score = calculate_v_imp_score(gaussian_model.get_scaling, opacity_score_total, 0.1)
v_imp_score, v_imp_score.shape

In [None]:
# Mask of Gaussians to prune (presumably the lowest-scoring 90% -- see `get_prune_mask`),
# inverted into the mask of Gaussians to keep.
prune_mask = get_prune_mask(0.9, v_imp_score)
preserve_mask = ~prune_mask
preserve_mask.sum(), preserve_mask.shape

In [None]:
# Write the Gaussians kept by `preserve_mask` back into the checkpoint, on the checkpoint's device.
for key, partition_value in state_dict.items():
    target_device = ckpt["state_dict"][key].device
    ckpt["state_dict"][key] = partition_value[preserve_mask.to(target_device)].to(target_device)

In [None]:
# Save the checkpoint pruned by v_imp_score.
torch.save(ckpt, "opacity_score_pruned.ckpt")