In [None]:
import os, sys
os.environ['CUDA_VISIBLE_DEVICES'] = '3,'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from pathlib import Path
import random
from argparse import ArgumentParser, Namespace
from tqdm import tqdm

import numpy as np
np.set_printoptions(suppress=True)
from PIL import Image
import lpips
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
torch.cuda.empty_cache()

from gaussian_renderer import render
from scene import Scene, GaussianModel, colmap_loader
from scene.cameras import Camera
from scene.colmap_loader import read_extrinsics_binary, read_intrinsics_binary
from scene.multipleview_dataset import multipleview_dataset_kubric
from scene.dataset_readers import sceneLoadTypeCallbacks, format_infos, getNerfppNorm, fetchPly, SceneInfo
from scene.dataset import FourDGSdataset
from utils.graphics_utils import focal2fov
from utils.loss_utils import l1_loss

def setup_seed(seed):
    """Seed every random number generator this notebook uses.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's legacy global RNG and
    Python's ``random`` module, then switches cuDNN to its deterministic
    configuration.

    Args:
        seed: Integer seed applied to all generators.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may select a different convolution algorithm from
    # run to run; turning it off is the other half of PyTorch's
    # reproducibility recipe alongside `deterministic = True`.
    torch.backends.cudnn.benchmark = False

setup_seed(6666)
# torch.cuda.set_device(0)
# torch.autograd.set_detect_anomaly(False)

In [2]:
# Experiment configuration, kept as plain argparse Namespaces so the values can
# be handed to project code that expects parsed command-line arguments.

# Scene / data-loading options.
dataset = Namespace(
    sh_degree=3,
    source_path=None,
    model_path=None,
    images='images',
    resolution=-1,
    white_background=True,
    data_device="cuda",
    eval=True,
    render_process=True,
    add_points=False,
    extension=".png",
    llffhold=8,
)

# Hyper-parameters passed to GaussianModel (deformation network and k-planes grid).
hyper = Namespace(
    net_width=128,
    timebase_pe=4,
    defor_depth=0,
    posebase_pe=10,
    scale_rotation_pe=2,
    opacity_pe=2,
    timenet_width=64,
    timenet_output=32,
    bounds=1.6,
    plane_tv_weight=0.0002,
    time_smoothness_weight=0.001,
    l1_time_planes=0.0001,
    kplanes_config={
        'grid_dimensions': 2,
        'input_coordinate_dim': 4,
        'output_coordinate_dim': 16,
        'resolution': [64, 64, 64, 150],
    },
    multires=[1, 2],
    no_dx=False,
    no_grid=False,
    no_ds=False,
    no_dr=False,
    no_do=False,
    no_dshs=False,
    empty_voxel=False,
    grid_pe=0,
    static_mlp=False,
    apply_rotation=False,
)

# Optimiser settings consumed by gaussians.training_setup(opt).
opt = Namespace(
    position_lr_init=1.6e-4,
    position_lr_final=1.6e-06,
    position_lr_delay_mult=0.01,
    position_lr_max_steps=20000,
    deformation_lr_init=1.6e-4,
    deformation_lr_final=1.6e-05,
    deformation_lr_delay_mult=0.01,
    grid_lr_init=1.6e-3,
    grid_lr_final=1.6e-4,
    feature_lr=0.0025,
    opacity_lr=0.05,
    scaling_lr=0.005,
    rotation_lr=0.001,
    percent_dense=0.01,
)

# Rendering pipeline flags passed to render(...).
pipe = Namespace(
    convert_SHs_python=False,
    compute_cov3D_python=False,
    debug=False,
)

In [None]:
# Exploratory load of one scene: read the COLMAP camera files, keep only the
# cameras whose image name contains 'yaw-0', and build a training dataset.
datadir = Path("samples/ParallelDomain/scene_000000/colmap")
dataset_type = "Kubric"

# scene_info = sceneLoadTypeCallbacks[dataset_type](datadir)
cam_extrinsics = read_extrinsics_binary(os.path.join(datadir, "sparse/0/images.bin"))
cam_intrinsics = read_intrinsics_binary(os.path.join(datadir, "sparse/0/cameras.bin"))

# COLMAP image id -> image file name, then the ids of the selected cameras.
cam_dict = {cam_id: extrinsic.name for cam_id, extrinsic in cam_extrinsics.items()}
cam_ids = [cam_id for cam_id, image_name in cam_dict.items() if 'yaw-0' in image_name]

train_cam_infos = multipleview_dataset_kubric(
    cam_extrinsics=cam_extrinsics,
    cam_intrinsics=cam_intrinsics,
    image_folder=os.path.join(datadir, "../images"),
    split="train",
    cam_ids=cam_ids,
    image_length=50,
    factor=3,
)
train_cams = FourDGSdataset(train_cam_infos, None, dataset_type)

# test_cam_infos = multipleview_dataset_kubric(
#     cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, 
#     image_folder=os.path.join(datadir, "../images"), split="test", cam_ids=cam_ids
# )
# test_cams = FourDGSdataset(test_cam_infos, None, dataset_type)

# Quick sanity checks: dataset size, the image paths it resolved, one sample.
print(len(train_cams))
print(train_cam_infos.image_paths)
train_cam_infos[16]

In [None]:
# Build everything the training cells need: output folders, the Gaussian model,
# the training dataset and the initial point cloud.
datadir = Path("samples/parallel_scene_000000/colmap")
dataset_type = "Kubric"

expname = Path("samples/parallel_scene_000000/outputs_v2")
for out_dir in (expname, expname / "point_cloud"):
    out_dir.mkdir(parents=True, exist_ok=True)

gaussians = GaussianModel(3, hyper)

# scene_info = sceneLoadTypeCallbacks[dataset_type](datadir)
cam_extrinsics = read_extrinsics_binary(os.path.join(datadir, "sparse/0/images.bin"))
cam_intrinsics = read_intrinsics_binary(os.path.join(datadir, "sparse/0/cameras.bin"))

# COLMAP image id -> image file name, then the ids of the cameras to train on.
cam_dict = {cam_id: extrinsic.name for cam_id, extrinsic in cam_extrinsics.items()}
cam_ids = [cam_id for cam_id, image_name in cam_dict.items() if 'yaw-0' in image_name]

# The normalisation is computed from the dataset built without a cam_ids
# filter, not from the filtered subset that is used for training below.
total_cam_infos = multipleview_dataset_kubric(
    cam_extrinsics=cam_extrinsics,
    cam_intrinsics=cam_intrinsics,
    image_folder=os.path.join(datadir, "../images"),
    split="train",
)
nerf_normalization = getNerfppNorm(format_infos(total_cam_infos,"train"))

train_cam_infos = multipleview_dataset_kubric(
    cam_extrinsics=cam_extrinsics,
    cam_intrinsics=cam_intrinsics,
    image_folder=os.path.join(datadir, "../images"),
    split="train",
    cam_ids=cam_ids,
)
train_cams = FourDGSdataset(train_cam_infos, None, dataset_type)

# test_cam_infos = multipleview_dataset_kubric(
#     cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, 
#     image_folder=os.path.join(datadir, "../images"), split="test", cam_ids=cam_ids
# )
# test_cams = FourDGSdataset(test_cam_infos, None, dataset_type)
print(len(train_cams))

# Initialise the Gaussians from the point cloud; its axis-aligned bounds are
# handed to the deformation network before the points are loaded.
pcd = fetchPly(os.path.join(datadir, "points3D_multipleview.ply"))
xyz_max = pcd.points.max(axis=0)
xyz_min = pcd.points.min(axis=0)
gaussians._deformation.deformation_net.set_aabb(xyz_max, xyz_min)
gaussians.create_from_pcd(pcd, nerf_normalization["radius"], 0)

In [4]:
@torch.no_grad()
def render_imgs(model_path, gaussians, viewpoints, render_func, pipe, background, stage, iteration, dataset_type):
    """Render each viewpoint and save a "GT | render | depth" strip as a JPEG.

    Images are written to ``model_path / f"{stage}_render/images"`` under the
    name ``f"{iteration}_{idx}.jpg"``, where ``idx`` is the position of the
    viewpoint in ``viewpoints``.

    Args:
        model_path: ``pathlib.Path`` of the experiment output directory.
        gaussians: Model object, forwarded unchanged to ``render_func``.
        viewpoints: Sized, indexable collection of cameras. Each item must
            expose ``original_image`` as a channels-first tensor.
        render_func: Callable invoked as
            ``render_func(view, gaussians, pipe, background, stage=..., cam_type=...)``;
            it must return a mapping with channels-first ``"render"`` and
            ``"depth"`` tensors.
        pipe: Pipeline options, forwarded unchanged to ``render_func``.
        background: Background colour, forwarded unchanged to ``render_func``.
        stage: Stage label; used in the output folder name and forwarded to
            ``render_func``.
        iteration: Training iteration, used as the file name prefix.
        dataset_type: Forwarded to ``render_func`` as ``cam_type``.
    """
    image_path = model_path / f"{stage}_render/images"
    image_path.mkdir(parents=True, exist_ok=True)
    for idx in range(len(viewpoints)):
        render_pkg = render_func(viewpoints[idx], gaussians, pipe, background, stage=stage, cam_type=dataset_type)
        image, depth, gt = render_pkg["render"], render_pkg["depth"], viewpoints[idx].original_image

        # Keep only the first three channels of the ground truth, as the
        # training loops do, so a GT with an extra channel can still be
        # concatenated with the 3-channel render.
        gt_np = gt[:3].permute(1, 2, 0).cpu().numpy()
        image_np = image.permute(1, 2, 0).cpu().numpy()  # (H, W, 3)
        depth_np = depth.permute(1, 2, 0).cpu().numpy()

        # Scale depth by its maximum. An all-zero depth map is left as-is
        # instead of being divided by zero (which would fill the panel with
        # NaNs), and the division is not done in place so that a CPU tensor
        # sharing memory with this array is never modified.
        depth_max = depth_np.max()
        if depth_max > 0:
            depth_np = depth_np / depth_max
        # Repeat the depth channel three times so it stacks next to the RGB panels.
        depth_np = np.repeat(depth_np, 3, axis=2)

        image_np = np.concatenate((gt_np, image_np, depth_np), axis=1)
        image_with_labels = Image.fromarray((np.clip(image_np,0,1) * 255).astype('uint8'))  
        image_with_labels.save(image_path / f"{iteration}_{idx}.jpg")

In [None]:
# Coarse stage: render with stage='coarse' and optimise a plain L1 loss
# (no regularisation term) for `final_iter` iterations.
gaussians.training_setup(opt)
background = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda")
ema_loss_for_log = 0.0

first_iter = 0
final_iter = 3000
progress_bar = tqdm(range(first_iter, final_iter), desc="Training progress")
first_iter += 1

batch_size = 1
print("data loading done")
viewpoint_stack_loader = DataLoader(train_cams, batch_size=batch_size, shuffle=True, num_workers=16, collate_fn=list)
loader = iter(viewpoint_stack_loader)

for iteration in range(first_iter, final_iter+1):        
    gaussians.update_learning_rate(iteration)
    if iteration % 1000 == 0:
        gaussians.oneupSHdegree()

    try:
        viewpoint_cams = next(loader)
    except StopIteration:
        # Epoch exhausted: restart the loader AND draw a batch from it.
        # Without the second line the previous batch would be trained on again.
        loader = iter(viewpoint_stack_loader)
        viewpoint_cams = next(loader)

    # Render every camera of the batch and collect what the loss and the
    # densification bookkeeping need.
    images, gt_images, radii_list, visibility_filter_list, viewspace_point_tensor_list = [], [], [], [], []
    for viewpoint_cam in viewpoint_cams:
        render_pkg = render(
            viewpoint_cam, gaussians, pipe, background, 
            stage='coarse', cam_type='MultipleView')
        images.append(render_pkg["render"].unsqueeze(0))
        gt_images.append(viewpoint_cam.original_image.unsqueeze(0).cuda())
        radii_list.append(render_pkg["radii"].unsqueeze(0))
        visibility_filter_list.append(render_pkg["visibility_filter"].unsqueeze(0))
        viewspace_point_tensor_list.append(render_pkg["viewspace_points"])
    image_tensor = torch.cat(images,0)
    gt_image_tensor = torch.cat(gt_images,0)
    # Per Gaussian: largest radius over the batch, and "visible in any view".
    radii = torch.cat(radii_list,0).max(dim=0).values
    visibility_filter = torch.cat(visibility_filter_list).any(dim=0)

    # Only the first three channels of the ground truth enter the loss.
    Ll1 = l1_loss(image_tensor, gt_image_tensor[:,:3,:,:])
    loss = Ll1
    loss.backward()

    # Sum the view-space gradients of all cameras in the batch.
    viewspace_point_tensor_grad = torch.zeros_like(render_pkg["viewspace_points"])
    for idx in range(0, len(viewspace_point_tensor_list)):
        viewspace_point_tensor_grad = viewspace_point_tensor_grad + viewspace_point_tensor_list[idx].grad

    with torch.no_grad():
        # Progress bar
        ema_loss_for_log = 0.4 * Ll1.item() + 0.6 * ema_loss_for_log
        total_point = gaussians._xyz.shape[0]
        if iteration % 10 == 0:
            progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}", "point":f"{total_point}"})
            progress_bar.update(10)
        if (iteration % 1000 == 0):
            print("\n[ITER {}] Saving Gaussians".format(iteration))
            (expname / f"point_cloud/coarse_iteration_{iteration}").mkdir(parents=True, exist_ok=True)
            gaussians.save_ply(expname / f"point_cloud/coarse_iteration_{iteration}/point_cloud.ply")
            gaussians.save_deformation(expname / f"point_cloud/coarse_iteration_{iteration}")
        if iteration % 300 == 299:
            # render_imgs(expname, gaussians, [ test_cams[iteration% len(test_cams)]], 
            #     render, pipe, background, "coarse_test", iteration, 'MultipleView')
            render_imgs(expname, gaussians, [train_cams[iteration%len(train_cams)]], 
                render, pipe, background, "coarse_train", iteration, 'MultipleView')

        # Densification
        # NOTE: with final_iter = 3000 the `iteration < 10000` test is always
        # true and `iteration > 3000` never is, so size_threshold stays None here.
        if iteration < 10000 :
            gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
            gaussians.add_densification_stats(viewspace_point_tensor_grad, visibility_filter)

            size_threshold = 20 if iteration > 3000 else None
            if iteration > 500 and iteration % 100 == 0 and gaussians.get_xyz.shape[0] < 360000:
                gaussians.densify(0.0002, 0.005, nerf_normalization["radius"], size_threshold, 5, 5, expname, iteration, "coarse")
            if iteration > 500 and iteration % 100 == 0 and gaussians.get_xyz.shape[0] > 200000:
                gaussians.prune(0.0002, 0.005, nerf_normalization["radius"], size_threshold)
            # if iteration % 3000 == 0:
            #     gaussians.reset_opacity()

        # No optimiser step on the very last iteration (same as before; the
        # literal 3000 is now tied to final_iter).
        if iteration < final_iter:
            gaussians.optimizer.step()
        # Always clear gradients, including after the un-stepped last
        # iteration, so they cannot be accumulated into by a later stage's
        # first backward pass.
        gaussians.optimizer.zero_grad(set_to_none = True)
        if (iteration in [1000, 2000, 3000]):
            print("\n[ITER {}] Saving Checkpoint".format(iteration))
            # No leading "/" in the file name: pathlib drops everything left of
            # an absolute segment, which would send the checkpoint to the
            # filesystem root instead of `expname`.
            torch.save((gaussians.capture(), iteration), expname / f"chkpnt_coarse_{iteration}.pth")
progress_bar.close()

In [None]:
# Fine stage: render with stage='fine' and optimise L1 plus the regularisation
# term returned by gaussians.compute_regulation for `final_iter` iterations.
torch.cuda.empty_cache()

gaussians.training_setup(opt)
background = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda")
ema_loss_for_log = 0.0

first_iter = 0
final_iter = 15000
progress_bar = tqdm(range(first_iter, final_iter), desc="Training progress")
first_iter += 1

batch_size = 1
print("data loading done")
viewpoint_stack_loader = DataLoader(train_cams, batch_size=batch_size, shuffle=True, num_workers=16, collate_fn=list)
loader = iter(viewpoint_stack_loader)

# The test split is only built by code that is currently commented out in the
# data-loading cell, so `test_cams` may not exist. Look it up once and skip the
# test-view previews when it is missing instead of raising NameError.
held_out_cams = globals().get("test_cams")

for iteration in range(first_iter, final_iter+1):        
    gaussians.update_learning_rate(iteration)
    if iteration % 1000 == 0:
        gaussians.oneupSHdegree()

    try:
        viewpoint_cams = next(loader)
    except StopIteration:
        # Epoch exhausted: restart the loader AND draw a batch from it.
        # Without the second line the previous batch would be trained on again.
        loader = iter(viewpoint_stack_loader)
        viewpoint_cams = next(loader)

    # Render every camera of the batch and collect what the loss and the
    # densification bookkeeping need.
    images, gt_images, radii_list, visibility_filter_list, viewspace_point_tensor_list = [], [], [], [], []
    for viewpoint_cam in viewpoint_cams:
        render_pkg = render(
            viewpoint_cam, gaussians, pipe, background, 
            stage='fine', cam_type='MultipleView')
        images.append(render_pkg["render"].unsqueeze(0))
        gt_images.append(viewpoint_cam.original_image.unsqueeze(0).cuda())
        radii_list.append(render_pkg["radii"].unsqueeze(0))
        visibility_filter_list.append(render_pkg["visibility_filter"].unsqueeze(0))
        viewspace_point_tensor_list.append(render_pkg["viewspace_points"])
    image_tensor = torch.cat(images,0)
    gt_image_tensor = torch.cat(gt_images,0)
    # Per Gaussian: largest radius over the batch, and "visible in any view".
    radii = torch.cat(radii_list,0).max(dim=0).values
    visibility_filter = torch.cat(visibility_filter_list).any(dim=0)

    # Only the first three channels of the ground truth enter the loss.
    Ll1 = l1_loss(image_tensor, gt_image_tensor[:,:3,:,:])
    tv_loss = gaussians.compute_regulation(hyper.time_smoothness_weight, hyper.l1_time_planes, hyper.plane_tv_weight)
    loss = Ll1 + tv_loss
    loss.backward()
    
    # Sum the view-space gradients of all cameras in the batch.
    viewspace_point_tensor_grad = torch.zeros_like(render_pkg["viewspace_points"])
    for idx in range(0, len(viewspace_point_tensor_list)):
        viewspace_point_tensor_grad = viewspace_point_tensor_grad + viewspace_point_tensor_list[idx].grad

    with torch.no_grad():
        # Progress bar
        ema_loss_for_log = 0.4 * Ll1.item() + 0.6 * ema_loss_for_log
        total_point = gaussians._xyz.shape[0]
        if iteration % 10 == 0:
            progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}", "point":f"{total_point}"})
            progress_bar.update(10)
        if (iteration in [1000, 2000, 5000, 10000, 15000]):
            print("\n[ITER {}] Saving Gaussians".format(iteration))
            (expname / f"point_cloud/fine_iteration_{iteration}").mkdir(parents=True, exist_ok=True)
            gaussians.save_ply(expname / f"point_cloud/fine_iteration_{iteration}/point_cloud.ply")
            gaussians.save_deformation(expname / f"point_cloud/fine_iteration_{iteration}")
        # Preview renders every 10 iterations (the coarse loop does this every 300).
        if iteration % 10 == 0:
            if held_out_cams is not None and len(held_out_cams) > 0:
                render_imgs(expname, gaussians, [held_out_cams[iteration % len(held_out_cams)]],
                    render, pipe, background, "fine_test", iteration, 'MultipleView')
            render_imgs(expname, gaussians, [train_cams[iteration%len(train_cams)]], 
                render, pipe, background, "fine_train", iteration, 'MultipleView')

        # Densification
        if iteration < 10000 :
            gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
            gaussians.add_densification_stats(viewspace_point_tensor_grad, visibility_filter)

            # NOTE(review): inside `iteration < 10000` the condition
            # `iteration > 15000` is never true, so size_threshold is always
            # None here — confirm this is intended.
            size_threshold = 20 if iteration > 15000 else None
            # Use the `nerf_normalization` computed in the data-loading cell,
            # as the coarse loop does; `scene_info` is never defined because
            # its creation is commented out.
            if iteration > 500 and iteration % 100 == 0 and gaussians.get_xyz.shape[0] < 360000:
                gaussians.densify(0.0002, 0.005, nerf_normalization["radius"], size_threshold, 5, 5, expname, iteration, "fine")
            if iteration > 500 and iteration % 100 == 0 and gaussians.get_xyz.shape[0] > 200000:
                gaussians.prune(0.0002, 0.005, nerf_normalization["radius"], size_threshold)
            # NOTE(review): unreachable for the same reason — no multiple of
            # 15000 lies in [1, 10000), so the opacity is never reset.
            if iteration % 15000 == 0:
                gaussians.reset_opacity()

        # No optimiser step on the very last iteration (unchanged).
        if iteration < final_iter:
            gaussians.optimizer.step()
        # Always clear gradients, including after the un-stepped last iteration.
        gaussians.optimizer.zero_grad(set_to_none = True)
        if (iteration in [1000, 2000, 5000, 10000, 15000]):
            print("\n[ITER {}] Saving Checkpoint".format(iteration))
            # No leading "/" in the file name: pathlib drops everything left of
            # an absolute segment, which would send the checkpoint to the
            # filesystem root instead of `expname`.
            torch.save((gaussians.capture(), iteration), expname / f"chkpnt_fine_{iteration}.pth")
progress_bar.close()

In [None]:
# Scratch cell: builds by hand the rasterizer settings and the per-point time
# tensor for one camera, instead of going through render(...). Kept for reference:
# render_pkg = render(
#     viewpoint_cam, gaussians, pipe, background, 
#     stage='fine', cam_type='MultipleView')
# def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None, stage="fine", cam_type=None):

# `math` is needed for the tan(FoV / 2) terms below and is not imported by the
# notebook's import cell.
import math

# NOTE(review): `GaussianRasterizationSettings` is not imported in any visible
# cell, so this cell raises NameError as written. It has to be imported from
# the rasterizer package the project's renderer uses — confirm which one.
# NOTE(review): `viewpoint_cam`, `gaussians`, `background` and `pipe` are
# whatever earlier cells left behind; `viewpoint_cam` in particular is the loop
# variable of the last training iteration that ran.

# Zero tensor shaped like the Gaussian positions. The trailing `+ 0` makes it
# the output of an operation rather than a leaf tensor — presumably mirroring
# the project's render(); confirm.
screenspace_points = torch.zeros_like(gaussians.get_xyz, dtype=gaussians.get_xyz.dtype, requires_grad=True, device="cuda") + 0
means3D = gaussians.get_xyz
raster_settings = GaussianRasterizationSettings(
    image_height=int(viewpoint_cam.image_height),
    image_width=int(viewpoint_cam.image_width),
    tanfovx=math.tan(viewpoint_cam.FoVx * 0.5),
    tanfovy=math.tan(viewpoint_cam.FoVy * 0.5),
    bg=background,
    scale_modifier=1.0,
    viewmatrix=viewpoint_cam.world_view_transform.cuda(),
    projmatrix=viewpoint_cam.full_proj_transform.cuda(),
    sh_degree=gaussians.active_sh_degree,
    campos=viewpoint_cam.camera_center.cuda(),
    prefiltered=False,
    debug=pipe.debug
)
# One row per Gaussian, each holding this camera's timestamp.
time = torch.tensor(viewpoint_cam.time).to(means3D.device).repeat(means3D.shape[0],1)
