In [None]:
from numba.cuda import detect
%load_ext autoreload
%autoreload 2
from typing import List, Optional
from omegaconf import OmegaConf
import os
import time
import json
import wandb
import logging
import argparse

import torch

from datasets.my_dataset import MyDataset
from utils.misc import import_str
from models.trainers import BasicTrainer
# GPU sanity check. torch.cuda.get_device_name(0) raises when no CUDA device is
# visible, so only query it when CUDA is available; otherwise report instead of
# crashing (the trainer cell selects "cpu" in that case).
print(torch.cuda.is_available())
print(torch.cuda.device_count())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
else:
    print("CUDA not available; no device name to report")


In [None]:
data_root = "/mnt/e/Output/background/023_test"

cfg = OmegaConf.load(os.path.join(data_root, "config.yaml"))
%cd /home/a/drivestudio
cfg.data.data_root = "data/waymo/processed/training"

dataset = MyDataset(cfg.data)

In [None]:
# Instantiate the trainer class named in the config (cfg.trainer.type) for this
# dataset, then restore its weights from the saved checkpoint.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

trainer_cls = import_str(cfg.trainer.type)
trainer = trainer_cls(
    **cfg.trainer,
    num_timesteps=dataset.num_img_timesteps,
    model_config=cfg.model,
    num_train_images=len(dataset.train_image_set),
    num_full_images=len(dataset.full_image_set),
    test_set_indices=dataset.test_timesteps,
    scene_aabb=dataset.get_aabb().reshape(2, 3),
    device=device,
)

ckpt_path = os.path.join(data_root, "checkpoint_10000.pth")
# load_only_model=True: presumably restores model weights only (no optimizer /
# step state) — TODO confirm against resume_from_checkpoint.
trainer.resume_from_checkpoint(ckpt_path=ckpt_path, load_only_model=True)

In [None]:
import pyiqa
import random

device = trainer.device
render_dir = os.path.join(cfg.log_dir, "render")
pred_dir = os.path.join(cfg.log_dir, "pred")

# Start every run with an empty render directory: create it if missing and
# delete PNGs left over from a previous run (other files are kept).
os.makedirs(render_dir, exist_ok=True)
stale_pngs = [name for name in os.listdir(render_dir) if name.endswith(".png")]
for name in stale_pngs:
    os.remove(os.path.join(render_dir, name))

# pred_dir is only read (by the cell that attaches predictions); it is not
# created or cleared here — it appears to be filled externally with NNN.png files.

# Optional, currently disabled: BRISQUE quality metric for filtering renders.
# iqa_metric = pyiqa.create_metric('brisque', device=device)

In [None]:
import torch

# Repeat of the GPU sanity check, guarded so it reports rather than raising
# when no CUDA device is visible (get_device_name(0) raises in that case).
print(torch.cuda.is_available())
print(torch.cuda.device_count())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
else:
    print("CUDA not available; no device name to report")

In [None]:
from datasets.my_dataset import get_fake_gt_samples

# Sample candidate views for fake ground truth. Python's `random` is seeded
# first so the draw is repeatable (assumes the sampler uses `random` — TODO confirm).
random.seed(0)
fake_gt_sampling = dict(min_coverage=0.6, max_coverage=0.8, num_points=100)
(
    cam2worlds,
    intrinsics,
    norm_times,
    step_times,
    depth_maps,
) = get_fake_gt_samples(dataset, **fake_gt_sampling)

In [None]:
from matplotlib import pyplot as plt

# Visual check of the depth map for the first sampled view.
# detach().cpu() makes .numpy() work even if the tensor is on the GPU or
# tracks gradients (both are no-ops for a plain CPU tensor).
d = depth_maps[0].detach().cpu().numpy()

fig, ax = plt.subplots()
im = ax.imshow(d)
fig.colorbar(im, ax=ax, label="depth")
ax.set_title("Fake-GT view 0: depth map")
plt.show()

In [None]:
from utils.visualization import to8b
from datasets.base.pixel_source import get_rays
import cv2
import random

# Render every sampled fake-GT view with the restored model. For each view this
# saves the RGB render to render_dir/NNN.png (NNN = the view's position in the
# lists) and collects the ray dict (image_info) and camera dict (cam_info) that
# the next cell pairs with the predicted images.
image_info_list, cam_info_list = [], []
width, height = 960, 640

# The pixel grid depends only on the output resolution, so build it once
# instead of once per view.
x, y = torch.meshgrid(
    torch.arange(width),
    torch.arange(height),
    indexing="xy",
)
x, y = x.flatten().to(device), y.flatten().to(device)
pixel_coords = (
    torch.stack([y / height, x / width], dim=-1)
    .float()
    .reshape(height, width, 2)
)

for idx in range(len(cam2worlds)):
    c2w = cam2worlds[idx].to(device)
    norm_time = norm_times[idx]

    cam_info = {
        "camera_to_world": c2w,
        "intrinsics": intrinsics[idx].to(device),
        "height": torch.tensor(height, dtype=torch.long, device=device),
        "width": torch.tensor(width, dtype=torch.long, device=device),
    }

    # NOTE(review): cam_info above gets the intrinsics as sampled, while the rays
    # below use them scaled by downscale_factor. The two only agree when
    # downscale_factor == 1 — confirm which scale the 960x640 render expects.
    ray_intrinsic = intrinsics[idx] * dataset.pixel_source.downscale_factor
    ray_intrinsic[2, 2] = 1.0
    origins, viewdirs, direction_norm = get_rays(x, y, c2w, ray_intrinsic.to(device))

    image_info = {
        "origins": origins.to(device),
        "direction_norm": direction_norm.to(device),
        "viewdirs": viewdirs.reshape(height, width, 3).to(device),
        "img_idx": torch.zeros((height, width), dtype=torch.long, device=device),
        # clone so each view owns its own tensor rather than sharing one
        "pixel_coords": pixel_coords.clone(),
        "normed_time": torch.full(
            (height, width), norm_time, dtype=torch.float32, device=device
        ),
        "depth_map": depth_maps[idx].to(device),
    }

    output = trainer(image_info, cam_info, False)

    # Binary mask of pixels where the rendered opacity exceeds 0.5.
    # NOTE(review): this marks *opaque* pixels but is stored as "sky_masks". If the
    # loss treats sky_masks == 1 as sky, the mask is inverted — confirm against
    # compute_losses / MyDataset.load_fake_gt.
    image_info["sky_masks"] = (
        (output["opacity"].detach().reshape(height, width) > 0.5).float().to(device)
    )

    # to8b output is treated as an RGB image; cv2.imwrite expects BGR channel order.
    render_bgr = cv2.cvtColor(to8b(output["rgb"]), cv2.COLOR_RGB2BGR)
    cv2.imwrite(os.path.join(render_dir, f"{idx:03d}.png"), render_bgr)

    # A pyiqa BRISQUE filter (skip views scoring > 60) was tried here. If it is
    # re-enabled, note that the next cell relies on NNN.png == list position.
    image_info_list.append(image_info)
    cam_info_list.append(cam_info)

In [None]:
# Pair each rendered view with its externally produced prediction in
# pred_dir/NNN.png (stored as "pixels", RGB floats in [0, 1]), drop views that
# have no prediction, and register the remaining pairs with the dataset.
#
# NNN is the view's position in the lists built by the render cell, so those
# lists must still be unfiltered here. Re-running this cell after some views
# were dropped would otherwise pair images with the wrong views.
assert len(image_info_list) == len(cam_info_list) == len(cam2worlds), (
    "image_info_list/cam_info_list were already filtered; re-run the render cell first"
)

for idx, image_info in enumerate(image_info_list):
    render_img_path = os.path.join(render_dir, f"{idx:03d}.png")
    pred_img_path = os.path.join(pred_dir, f"{idx:03d}.png")

    if not (os.path.exists(render_img_path) and os.path.exists(pred_img_path)):
        print(pred_img_path)  # no prediction (or render) for this view; dropped below
        continue

    pred_bgr = cv2.imread(pred_img_path)
    if pred_bgr is None:  # cv2.imread returns None instead of raising on unreadable files
        raise ValueError(f"Could not read predicted image {pred_img_path}")
    # OpenCV decodes to BGR; convert so the stored pixels are in RGB order.
    pred_rgb = cv2.cvtColor(pred_bgr, cv2.COLOR_BGR2RGB)
    image_info["pixels"] = (torch.from_numpy(pred_rgb).float() / 255.0).to(device)

# Keep only views that received a prediction, filtering both lists with the
# same indices so they stay aligned (slice assignment mutates them in place).
keep = [i for i, info in enumerate(image_info_list) if "pixels" in info]
image_info_list[:] = [image_info_list[i] for i in keep]
cam_info_list[:] = [cam_info_list[i] for i in keep]

torch.cuda.empty_cache()
dataset.load_fake_gt(image_info_list, cam_info_list, True)

In [None]:
# Cache the fake-GT inputs so rendering/prediction can be skipped next time.
# Paths are relative to the repository root selected with %cd earlier; the file
# tag is the run directory name (e.g. "023_test").
fake_gt_tag = os.path.basename(os.path.normpath(data_root))
fake_gt_dir = "notebook/data"
os.makedirs(fake_gt_dir, exist_ok=True)  # torch.save does not create parent directories

image_info_path = os.path.join(fake_gt_dir, f"image_info_list_{fake_gt_tag}.pth")
cam_info_path = os.path.join(fake_gt_dir, f"cam_info_list_{fake_gt_tag}.pth")
torch.save(image_info_list, image_info_path)
torch.save(cam_info_list, cam_info_path)

# To reload instead of re-rendering (tensors were saved on `device`; pass
# map_location=... to torch.load on a machine without that device):
# image_info_list = torch.load(image_info_path)
# cam_info_list = torch.load(cam_info_path)
# dataset.load_fake_gt(image_info_list, cam_info_list, True)


In [None]:
# Set up a training run: initialise the Gaussians from the dataset, build the
# optimizer, and start the interactive viewer.
# NOTE(review): re-initialising from the dataset presumably replaces the weights
# restored from the checkpoint, i.e. training restarts from scratch — confirm intended.
VIEWER_PORT = 8080

trainer.init_gaussians_from_dataset(dataset=dataset)
trainer.initialize_optimizer()
trainer.init_viewer(VIEWER_PORT)

In [None]:

# Training loop mixing real frames with fake-GT views: after a warm-up period,
# each step uses a fake-GT sample with probability FAKE_GT_PROB.
NUM_TRAIN_STEPS = 10000
LOG_EVERY = 100
FAKE_GT_PROB = 0.1
FAKE_GT_WARMUP_STEPS = 500  # fake GT is only used when step > this

for step in range(NUM_TRAIN_STEPS):
    if step % LOG_EVERY == 0:
        print(step)

    # prepare for training
    trainer.set_train()
    trainer.preprocess_per_train_step(step=step)
    trainer.optimizer_zero_grad()

    # get data; the random draw happens every step (even during warm-up) so the
    # RNG stream does not depend on the warm-up length
    use_fake_gt = random.random() < FAKE_GT_PROB
    if use_fake_gt and step > FAKE_GT_WARMUP_STEPS:
        image_infos, cam_infos = dataset.fake_gt_next()
    else:
        train_step_camera_downscale = trainer._get_downscale_factor()
        image_infos, cam_infos = dataset.train_image_set.next(train_step_camera_downscale)

    # Move tensors to the trainer's device (in place, replacing existing keys only).
    for batch in (image_infos, cam_infos):
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(trainer.device, non_blocking=True)

    # forward & backward
    outputs = trainer(image_infos, cam_infos, False)
    trainer.update_visibility_filter()
    loss_dict = trainer.compute_losses(
        outputs=outputs,
        image_infos=image_infos,
        cam_infos=cam_infos,
    )
    # Fail fast on non-finite losses before they corrupt the parameters.
    for k, v in loss_dict.items():
        if torch.isnan(v).any():
            raise ValueError(f"NaN detected in loss {k} at step {step}")
        if torch.isinf(v).any():
            raise ValueError(f"Inf detected in loss {k} at step {step}")
    trainer.backward(loss_dict)

    # after training step
    trainer.postprocess_per_train_step(step=step)

In [None]:
# Write a checkpoint into the run's log directory.
# save_only_model=True presumably stores weights only (no optimizer state) — TODO confirm.
trainer.save_checkpoint(save_only_model=True, log_dir=cfg.log_dir, is_final=False)

In [None]:
# Debug: display the camera data held by the dataset's pixel source.
dataset.pixel_source.camera_data