In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import sys
import os
import json
import torch
import yaml
import numpy as np
import cv2
from pathlib import Path
import matplotlib.pyplot as plt
from copy import deepcopy 
from PIL import Image
import torchvision.transforms as transforms
from nerfstudio.engine.trainer import Trainer
from nerfstudio.utils.eval_utils import eval_setup
from nerfstudio.data.dataparsers.base_dataparser import transform_poses_to_original_space
from plane_nerf.plane_nerf_utils import invert_SO3xR3_pose, transform_original_space_to_pose


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
os.chdir('/workspace')
#Path params
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "/workspace/outputs/jackal_four_spins_noise/plane-nerf/2024-01-08_142518"
DATA_PATH = "/stored_data/jackal_four_spins"
INPUT_FRAME = 0
TARGET_FRAME = 6
#Model params
L_RATE = 1e-2
N_STEP = 500

TRANSFORM_PATH = os.path.join(DATA_PATH, "transforms.json")
with open(TRANSFORM_PATH) as f:
    TRANSFORM_JSON = json.load(f)

WIDTH = TRANSFORM_JSON["w"]
HEIGHT = TRANSFORM_JSON["h"]

print(TRANSFORM_JSON)
INPUT_PATH = {
    "tf" : TRANSFORM_JSON["frames"][INPUT_FRAME]["transform_matrix"],
    "images" : os.path.join(DATA_PATH, "images", f"{INPUT_FRAME}.png"),
    "masks" : os.path.join(DATA_PATH, "masks", f"{INPUT_FRAME}.png")
}
TARGET_PATH = {
    "tf" : TRANSFORM_JSON["frames"][TARGET_FRAME]["transform_matrix"],
    "images" : os.path.join(DATA_PATH, "images", f"{TARGET_FRAME}.png"),
    "masks" : os.path.join(DATA_PATH, "masks", f"{TARGET_FRAME}.png")
}

{'focal_length': 50.0, 'fov_x': 0.85172, 'w': 1640, 'h': 1232, 'pixel_width': 0.06965178305458342, 'pixel_height': 0.06965178305458342, 'fl_x': 717.8567124522416, 'fl_y': 717.8567124522416, 'cx': 820, 'cy': 616, 'camera_model': 'OPENCV', 'frames': [{'file_path': 'images/0.png', 'transform_matrix': [[1.0, 0.0, 0.0, 0.0], [0.0, 0.7648421872844884, -0.644217687237691, -1.5], [0.0, 0.644217687237691, 0.7648421872844884, 1.0], [0.0, 0.0, 0.0, 1.0]], 'mask_path': 'masks/0.png'}, {'file_path': 'images/1.png', 'transform_matrix': [[0.9921147013144778, -0.0958601444987601, 0.08074188586081742, 0.18799985034645633], [0.12533323356430423, 0.7588111781904621, -0.6391378383553254, -1.4881720519717168], [0.0, 0.644217687237691, 0.7648421872844884, 1.0], [0.0, 0.0, 0.0, 1.0]], 'mask_path': 'masks/1.png'}, {'file_path': 'images/2.png', 'transform_matrix': [[0.9685831611286312, -0.19020851725470017, 0.1602104239487451, 0.3730348307472822], [0.24868988716485482, 0.7408132635245464, -0.6239784039596586, 

In [4]:
config_path = os.path.join(MODEL_PATH, "config.yml")

config = yaml.load(Path(config_path).read_text(), Loader=yaml.Loader)

config, pipeline, checkpoint_path, _ = eval_setup(
                        Path(config_path),
                        test_mode="inference",
                    )



In [5]:
#Create Custom Camera and Batch

pipeline.eval()
pipeline.datamanager.setup_eval()

Output()

In [8]:
#Spoof dataparser with a new image and pose

custom_train_dataparser_outputs = deepcopy(pipeline.datamanager.train_dataparser_outputs)
#Load new image through image path
custom_train_dataparser_outputs.image_filenames = [Path(TARGET_PATH["images"]).as_posix()]
custom_train_dataparser_outputs.mask_filenames = [Path(TARGET_PATH["masks"]).as_posix()]

#Load new pose through Camera object
custom_cameras = pipeline.datamanager.train_dataparser_outputs.cameras[0]
custom_camera_to_worlds = torch.tensor([INPUT_PATH["tf"]]).float()
#Convert from SE3 to SO3xR3 by removing last row in tensor
custom_camera_to_worlds = custom_camera_to_worlds[:,:3, :]
print(custom_camera_to_worlds)

custom_cameras.camera_to_worlds = transform_original_space_to_pose(custom_camera_to_worlds,
                                                                   custom_train_dataparser_outputs.dataparser_transform,
                                                                   custom_train_dataparser_outputs.dataparser_scale,
                                                                   "opengl")[0]
custom_train_dataparser_outputs.cameras = custom_cameras
print(custom_cameras)


tensor([[[ 1.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.7648, -0.6442, -1.5000],
         [ 0.0000,  0.6442,  0.7648,  1.0000]]])
Cameras(camera_to_worlds=tensor([[ 1.0000e+00, -1.4901e-09, -1.7691e-09, -3.9736e-10],
        [ 0.0000e+00,  7.6484e-01, -6.4422e-01, -1.0000e+00],
        [ 2.3131e-09,  6.4422e-01,  7.6484e-01,  0.0000e+00]]), fx=tensor([280.1392]), fy=tensor([280.1392]), cx=tensor([320.]), cy=tensor([210.]), width=tensor([640]), height=tensor([420]), distortion_params=tensor([0., 0., 0., 0., 0., 0.]), camera_type=tensor([1]), times=None, metadata=None)
Cameras(camera_to_worlds=tensor([[ 1.0000e+00, -1.4901e-09, -1.7691e-09, -3.9736e-10],
        [ 0.0000e+00,  7.6484e-01, -6.4422e-01, -1.0000e+00],
        [ 2.3131e-09,  6.4422e-01,  7.6484e-01,  0.0000e+00]]), fx=tensor([280.1392]), fy=tensor([280.1392]), cx=tensor([320.]), cy=tensor([210.]), width=tensor([640]), height=tensor([420]), distortion_params=tensor([0., 0., 0., 0., 0., 0.]), camera_type=tensor([1