In [None]:
%load_ext autoreload
%autoreload 2
import argparse
import json
import logging
import os
import time
from typing import List, Optional

import cv2
import torch
import wandb
from omegaconf import OmegaConf

from datasets.base.pixel_source import get_rays
from datasets.driving_dataset import DrivingDataset
from datasets.my_dataset import MyDataset
from models.trainers import BasicTrainer
from utils.misc import import_str
from utils.visualization import to8b


In [None]:
cfg = OmegaConf.load(os.path.join("/mnt/e/Output/background/149", "config.yaml"))
%cd /home/a/drivestudio

dataset = MyDataset(cfg.data)

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Look up the trainer class named by cfg.trainer.type and construct it with the
# dataset-derived counts/indices it requires.
trainer_cls = import_str(cfg.trainer.type)
trainer = trainer_cls(
    **cfg.trainer,
    num_timesteps=dataset.num_img_timesteps,
    model_config=cfg.model,
    num_train_images=len(dataset.train_image_set),
    num_full_images=len(dataset.full_image_set),
    test_set_indices=dataset.test_timesteps,
    # reshaped to (2, 3); presumably (min_xyz, max_xyz) — TODO confirm
    scene_aabb=dataset.get_aabb().reshape(2, 3),
    device=device,
)

# Restore the trained model; load_only_model=True presumably skips optimizer /
# step state — verify against the trainer implementation.
checkpoint_path = os.path.join("/mnt/e/Output/background/149", "checkpoint_final.pth")
trainer.resume_from_checkpoint(
    ckpt_path=checkpoint_path,
    load_only_model=True,
)
# NOTE(review): the trainer is never switched to evaluation mode here; if the
# trainer exposes an eval toggle it should probably be called before rendering.

In [16]:

# Render every saved camera pose with the restored trainer and write one PNG per
# view. The relative "notebook/data/..." paths rely on the cwd set by `%cd` above.
render_dir = "/mnt/e/Output/background/149/render"
# cv2.imwrite returns False (instead of raising) when the directory is missing.
os.makedirs(render_dir, exist_ok=True)

cam2worlds = torch.load("notebook/data/cam2worlds.pth")
intrinsics = torch.load("notebook/data/intrinsics.pth")
width, height = 960, 640
device = trainer.device

# Everything below up to the loop is independent of the camera, so build it once.
x, y = torch.meshgrid(
    torch.arange(width),
    torch.arange(height),
    indexing="xy",
)
x, y = x.flatten().to(device), y.flatten().to(device)

# (height, width, 2) grid of (row, col) pixel coordinates normalised to [0, 1).
pixel_coords = (
    torch.stack([y / height, x / width], dim=-1)
    .float()
    .reshape(height, width, 2)
    .to(device)
)
# Every pixel is attributed to image index 0 and the first normalised timestamp.
image_id = torch.full((height, width), 0, dtype=torch.long).to(device)
normalized_time = torch.full(
    (height, width),
    dataset.pixel_source.normalized_time[0],
    dtype=torch.float32,
).to(device)

with torch.no_grad():  # pure inference: don't record an autograd graph
    for idx, c2w in enumerate(cam2worlds):
        c2w = c2w.to(device)

        # Scale the intrinsics by the pixel source's downscale factor; the blanket
        # multiply also scales K[2, 2], so reset it to 1. The product is a new
        # tensor, so the loaded `intrinsics` are not modified.
        intrinsic = intrinsics[idx] * dataset.pixel_source.downscale_factor
        intrinsic[2, 2] = 1.0
        intrinsic = intrinsic.to(device)

        _, viewdirs, _ = get_rays(x, y, c2w, intrinsic)
        viewdirs = viewdirs.reshape(height, width, 3)

        # The renderer must use the same (scaled) intrinsics as the ray directions,
        # otherwise view-dependent terms and projection disagree.
        cam_info = {
            "camera_to_world": c2w,
            "intrinsics": intrinsic,
            "height": torch.tensor(height, dtype=torch.long, device=device),
            "width": torch.tensor(width, dtype=torch.long, device=device),
        }
        image_info = {
            "viewdirs": viewdirs,
            "img_idx": image_id,
            "pixel_coords": pixel_coords,
            "normed_time": normalized_time,
        }

        output = trainer(image_info, cam_info, False)

        # The render is RGB; OpenCV writes BGR.
        img = cv2.cvtColor(to8b(output["rgb"]), cv2.COLOR_RGB2BGR)
        out_path = os.path.join(render_dir, f"{idx}.png")
        if not cv2.imwrite(out_path, img):
            raise OSError(f"cv2.imwrite failed to write {out_path}")

In [None]:
# Save the sky-blended render of the view held in `output`, i.e. the last
# iteration of the render loop above.
sky_dir = "/mnt/e/Output/background/149/render"
# cv2.imwrite returns False (instead of raising) when the directory is missing.
os.makedirs(sky_dir, exist_ok=True)

# The render is RGB; OpenCV writes BGR.
sky_img = cv2.cvtColor(to8b(output["rgb_sky_blend"]), cv2.COLOR_RGB2BGR)
sky_path = os.path.join(sky_dir, "sky.png")
if not cv2.imwrite(sky_path, sky_img):
    raise OSError(f"cv2.imwrite failed to write {sky_path}")

In [None]:
# Display the raw sky-blended RGB tensor from the most recent render.
sky_rgb_blend = output["rgb_sky_blend"]
sky_rgb_blend