In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import os
import json
import torch
import numpy as np
import cv2
from pathlib import Path
import matplotlib.pyplot as plt
from nerfstudio.utils.eval_utils import eval_setup
from plane_nerf.inerf_trainer import load_data_into_trainer
from plane_nerf.inerf_utils import get_corrected_pose, load_eval_image_into_pipeline, get_relative_pose, get_absolute_diff_for_pose, get_image
from scipy.spatial.transform import Rotation



Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
os.chdir('/workspace/plane-nerf/')
MODEL_PATH = "/workspace/plane-nerf/outputs/jackal_floor_training_data_1/plane-nerf/2024-03-11_145657"
DATA_PATH = "/workspace/plane-nerf/data/jackal_one_frame_floor"
TRANSFORM_FILE = "transforms.json"
with open(os.path.join(DATA_PATH, "ground_truth.json")) as f:
    GROUND_TRUTH = json.load(f)
with open(os.path.join(DATA_PATH, TRANSFORM_FILE)) as f:
    TRANSFORM = json.load(f)

In [4]:
config_path = os.path.join(MODEL_PATH, "config.yml")
config, pipeline, _, _ = eval_setup(
                        Path(config_path),
                        test_mode="inference",
                    )

pipeline = load_eval_image_into_pipeline(pipeline,DATA_PATH,TRANSFORM)

config.pipeline.datamanager.pixel_sampler.num_rays_per_batch = 4096 

trainer = load_data_into_trainer(
    config,
    pipeline,
    plane_optimizer = True
)

trainer.pipeline.datamanager.KERNEL_SIZE = 5
trainer.pipeline.datamanager.THRESHOLD = 40
trainer.pipeline.datamanager.METHOD = "sift"

trainer.pipeline.datamanager.get_inerf_batch()  
trainer.pipeline.datamanager.inerf_batch["image"] = trainer.pipeline.datamanager.inerf_batch["image"].to(trainer.pipeline.device)
trainer.pipeline.train()




  camera_to_worlds = torch.cat([camera_to_worlds, tensor([tf]).float()], 0)


Loading PlaneNerfCameraOptimizer


Loading latest Nerfstudio checkpoint from load_dir...


Number of keypoints:  116
Number of rays:  5092
Reduce the number of rays
Final number of rays:  4096


PlaneNerfPipeline(
  (datamanager): PlaneNerfDataManager(
    (train_ray_generator): RayGenerator()
  )
  (_model): PlaneNerfModel(
    (collider): NearFarCollider()
    (field): PlaneNerfField(
      (spatial_distortion): SceneContraction()
      (embedding_appearance): Embedding(
        (embedding): Embedding(300, 32)
      )
      (direction_encoding): SHEncoding(
        (tcnn_encoding): Encoding(n_input_dims=3, n_output_dims=16, seed=1337, dtype=torch.float32, hyperparams={'degree': 4, 'otype': 'SphericalHarmonics'})
      )
      (position_encoding): NeRFEncoding(
        (tcnn_encoding): Encoding(n_input_dims=3, n_output_dims=12, seed=1337, dtype=torch.float32, hyperparams={'n_frequencies': 2, 'otype': 'Frequency'})
      )
      (mlp_base_grid): HashEncoding(
        (tcnn_encoding): Encoding(n_input_dims=3, n_output_dims=32, seed=1337, dtype=torch.float32, hyperparams={'base_resolution': 16, 'hash': 'CoherentPrime', 'interpolation': 'Linear', 'log2_hashmap_size': 19, 'n_featu

In [15]:
FRAMES = 360
poses = []
imgs = []

r = - 2.5
z = 2.5
pitch = 0.785

start_transform_matrix = np.eye(4)
start_transform_matrix[:3, :3] = Rotation.from_rotvec(np.array([pitch, 0, 0])).as_matrix()
start_transform_matrix[1, 3] = r
start_transform_matrix[2, 3] = z

for i in range(FRAMES):
    angle = i * 2 * np.pi / FRAMES
    rot = Rotation.from_rotvec(np.array([0, 0, 1]) * angle).as_matrix()
    rot_matrix = np.eye(4)
    rot_matrix[:3, :3] = rot
    transform_matrix = np.matmul(rot_matrix, start_transform_matrix)
    poses.append(transform_matrix)
    #Convert transform_matrix to tensor
    transform_matrix = torch.tensor([transform_matrix[:3]], device=trainer.pipeline.device).float()
    print(transform_matrix)
    rendered_img = get_image(trainer.pipeline, transform_matrix.float())
    imgs.append(rendered_img)
print(poses)

tensor([[[ 1.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.7074, -0.7068, -2.5000],
         [ 0.0000,  0.7068,  0.7074,  2.5000]]], device='cuda:0')
tensor([[[ 0.8090, -0.4158,  0.4155,  1.4695],
         [ 0.5878,  0.5723, -0.5718, -2.0225],
         [ 0.0000,  0.7068,  0.7074,  2.5000]]], device='cuda:0')
tensor([[[ 0.3090, -0.6728,  0.6722,  2.3776],
         [ 0.9511,  0.2186, -0.2184, -0.7725],
         [ 0.0000,  0.7068,  0.7074,  2.5000]]], device='cuda:0')
tensor([[[-0.3090, -0.6728,  0.6722,  2.3776],
         [ 0.9511, -0.2186,  0.2184,  0.7725],
         [ 0.0000,  0.7068,  0.7074,  2.5000]]], device='cuda:0')
tensor([[[-0.8090, -0.4158,  0.4155,  1.4695],
         [ 0.5878, -0.5723,  0.5718,  2.0225],
         [ 0.0000,  0.7068,  0.7074,  2.5000]]], device='cuda:0')
tensor([[[-1.0000e+00, -8.6630e-17,  8.6561e-17,  3.0616e-16],
         [ 1.2246e-16, -7.0739e-01,  7.0683e-01,  2.5000e+00],
         [ 0.0000e+00,  7.0683e-01,  7.0739e-01,  2.5000e+00]]],
       de

In [22]:
import cv2
from PIL import Image

cv2_images = [img["rgb"].numpy() for img in imgs]
H,W,_ = cv2_images[0].shape
fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Use 'mp4v' codec for MP4 format
out = cv2.VideoWriter('output_video.mp4', fourcc, 10.0, (W, H))
for i in range(len(cv2_images)):
    #min_pil = cv2_images[i].min()
    #max_pil = cv2_images[i].max()
    #cv2_images[i] = (cv2_images[i] - min_pil) / (max_pil - min_pil)
    cv2_images[i] = (cv2_images[i] * 255).astype(np.uint8)
    cv2_images[i] = cv2.cvtColor(cv2_images[i], cv2.COLOR_RGB2BGR)
    out.write(cv2_images[i])
out.release()

In [18]:
print(imgs)

[{'rgb': tensor([[[0.4457, 0.4435, 0.2823],
         [0.4295, 0.4165, 0.4268],
         [0.4630, 0.4298, 0.3993],
         ...,
         [0.4200, 0.4537, 0.3399],
         [0.3308, 0.3646, 0.4447],
         [0.4435, 0.4386, 0.2990]],

        [[0.3692, 0.3877, 0.4454],
         [0.3923, 0.3687, 0.4302],
         [0.3651, 0.3307, 0.3518],
         ...,
         [0.3843, 0.4016, 0.3837],
         [0.4109, 0.4081, 0.2789],
         [0.4203, 0.4227, 0.3695]],

        [[0.3144, 0.3181, 0.3948],
         [0.4246, 0.3644, 0.3101],
         [0.3895, 0.3842, 0.4194],
         ...,
         [0.4490, 0.4353, 0.4461],
         [0.3494, 0.3684, 0.3969],
         [0.4064, 0.4203, 0.3249]],

        ...,

        [[0.4020, 0.3980, 0.3454],
         [0.4007, 0.4062, 0.4065],
         [0.4087, 0.3992, 0.3981],
         ...,
         [0.4027, 0.4070, 0.3693],
         [0.4220, 0.4181, 0.4168],
         [0.3284, 0.3625, 0.4383]],

        [[0.4289, 0.4189, 0.4046],
         [0.3704, 0.3505, 0.2904],
   