In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import json
import os
import sys
from copy import deepcopy
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import yaml
from PIL import Image
from scipy.spatial.transform import Rotation
from torch import tensor
from nerfstudio.cameras.camera_optimizers import CameraOptimizer
from nerfstudio.utils.eval_utils import eval_setup

from inerf.inerf_trainer import INerfTrainer
from inerf.inerf_utils import get_corrected_pose, load_eval_image_into_pipeline, get_relative_pose, get_absolute_diff_for_pose, get_image
from plane_nerf.plane_nerf_optimizer import PlaneNerfCameraOptimizer

In [None]:
# All later relative paths resolve from /workspace.
os.chdir('/workspace')

# Output directory of the trained run (its config.yml is loaded below).
MODEL_PATH = "/workspace/outputs/jackal_training_data_1/plane-nerf/2024-01-14_115715"
# Evaluation data: images, transforms file and ground-truth poses.
EVAL_PATH = "/stored_data/jackal_one_frame"
TRANSFORM_FILE = "transforms.json"

GROUND_TRUTH_PATH = os.path.join(EVAL_PATH, "ground_truth.json")
GROUND_TRUTH = json.loads(Path(GROUND_TRUTH_PATH).read_text())

In [None]:
# Load the trained run's config and pipeline in inference mode via nerfstudio's
# eval_setup; its fourth return value is not needed here.
config_path = os.path.join(MODEL_PATH, "config.yml")
config, pipeline, checkpoint_path, _ = eval_setup(Path(config_path), test_mode="inference")

In [None]:
# Initial camera-pose guess as a 4x4 homogeneous transform: rotated by `pitch`
# radians about the x axis, translated to x = -r and z = z.
z = 2.5
pitch = 0.785  # ~pi/4 rad
r = -2.5
NUM_INIT_POSES = 300

base_tf = np.eye(4)
base_tf[:3, :3] = Rotation.from_rotvec(np.array([pitch, 0, 0])).as_matrix()
base_tf[0, 3] = -r
base_tf[2, 3] = z

# Independent copies: `[base_tf] * N` would make every entry alias one array, so an
# in-place edit of any single pose would silently change all of them.
init_tf = [base_tf.copy() for _ in range(NUM_INIT_POSES)]

In [None]:
# Register the evaluation frame(s) with the pipeline; the camera optimizer below is
# sized from the resulting train dataset.
pipeline = load_eval_image_into_pipeline(pipeline, EVAL_PATH, transform_file=TRANSFORM_FILE)

# Zero the rotation/translation L2 penalties *before* building the optimizer so it
# is constructed with its final settings.
# NOTE(review): this edits the model's existing camera-optimizer config object in
# place (presumably the same object referenced from `config`); deepcopy it first if
# the trainer config should keep the original penalties.
camera_optimizer_config = pipeline.model.camera_optimizer.config
camera_optimizer_config.rot_l2_penalty = 0
camera_optimizer_config.trans_l2_penalty = 0

custom_camera_optimizer = PlaneNerfCameraOptimizer(
    config=camera_optimizer_config,
    num_cameras=len(pipeline.datamanager.train_dataset),
    device=pipeline.device,
)
pipeline.model.camera_optimizer = custom_camera_optimizer

trainer = INerfTrainer(config)
trainer.setup_inerf(pipeline)

In [None]:
# Ground-truth pose (top 3x4 block of each frame's transform_matrix) for every frame,
# in the order the fixed-index train dataloader yields them; the error metrics
# below compare the optimised poses against this ordering.
gt_pose_rows = []
for _, batch in pipeline.datamanager.fixed_indices_train_dataloader:
    frame = GROUND_TRUTH["frames"][int(batch["image_idx"])]
    gt_pose_rows.append(np.asarray(frame["transform_matrix"])[:3, :4])
if not gt_pose_rows:
    raise ValueError("fixed_indices_train_dataloader yielded no frames")
# Stack into one ndarray first: torch.tensor() on a list of ndarrays is slow and
# warns. The dtype numpy inferred from the JSON values is preserved.
ground_truth_poses = torch.from_numpy(np.stack(gt_pose_rows)).to(pipeline.device)

In [None]:
# iNeRF pose refinement: `train_loop` rounds of `n` optimisation steps each, with the
# pose error against ground truth evaluated before training and after every round.
train_loop = 10
n = 100
lr_max = 5e-3
lr_min = 1e-4


def evaluate_pose_error(inerf_trainer, gt_poses):
    """Compare the trainer's current camera poses with ground truth and print the error.

    Args:
        inerf_trainer: trainer whose optimised poses are read via `get_corrected_pose`.
        gt_poses: ground-truth poses, in the same frame order as the trainer's cameras.

    Returns:
        (corrected_pose, mean translation error, mean rotation error), where the two
        errors are Python floats averaged over the per-pose absolute differences.
    """
    corrected = get_corrected_pose(inerf_trainer)
    relative = get_relative_pose(gt_poses, corrected)
    t_err, r_err = get_absolute_diff_for_pose(relative)
    mean_t = torch.mean(t_err).item()
    mean_r = torch.mean(r_err).item()
    print("Average translation error: ", mean_t)
    print("Average rotation error: ", mean_r)
    return corrected, mean_t, mean_r


# Rows of [round, mean translation error, mean rotation error]; round 0 is measured
# before any optimisation step.
corrected_pose, mean_t_err, mean_r_err = evaluate_pose_error(trainer, ground_truth_poses)
history = [[0, mean_t_err, mean_r_err]]

for i in range(train_loop):
    # Linear LR schedule over rounds, constant within a round.
    # NOTE(review): this ramps *up* from lr_min towards lr_max; pose refinement
    # usually decays the LR instead - confirm the intended direction.
    lr = lr_min + (lr_max - lr_min) * (i / train_loop)
    for j in range(n):
        trainer.pipeline.train()
        loss, loss_dict, metrics_dict = trainer.train_iteration_inerf(i * n + j, optimizer_lr=lr)
    corrected_pose, mean_t_err, mean_r_err = evaluate_pose_error(trainer, ground_truth_poses)
    history.append([i + 1, mean_t_err, mean_r_err])

# (train_loop + 1, 3) tensor of the rows collected above.
store = torch.tensor(history)

In [None]:
# Load evaluation frame 0; cv2 reads BGR, so convert to RGB for matplotlib.
original_img_path = os.path.join(EVAL_PATH, GROUND_TRUTH["frames"][0]["file_path"])
original_img = cv2.imread(original_img_path)
if original_img is None:
    # cv2.imread reports a missing/unreadable file by returning None, not raising.
    raise FileNotFoundError(f"Could not read image: {original_img_path}")
original_img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)

# Render only camera 0 so the comparison matches frame 0 above; `[0:1]` keeps the
# leading batch dimension (`[0:]` would pass every camera's pose).
# Assumes camera index 0 corresponds to GROUND_TRUTH frame 0 - TODO confirm.
rendered_img = get_image(trainer.pipeline, corrected_pose[0:1])["rgb"]

In [None]:
# Plot the captured frame and the render at the optimised pose side by side.
fig, ax = plt.subplots(1, 2, figsize=(20, 10))
panels = ((original_img, "Original image"), (rendered_img, "Rendered image"))
for panel_ax, (image, title) in zip(ax, panels):
    panel_ax.imshow(image)
    panel_ax.set_title(title)
plt.show()  # display now instead of echoing the last set_title's Text(...) repr

In [None]:
# Overlay the render on the captured frame at 50% opacity to eyeball residual misalignment.
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
original_artist = ax.imshow(original_img)
# Stretch the render over the photo's extent so the two layers line up even if the
# render was produced at a different resolution than the photo.
ax.imshow(rendered_img, alpha=0.5, extent=original_artist.get_extent())
ax.set_title("Original image overlaid with rendered image")
plt.show()


In [None]:
# Mean translation / rotation error after each evaluation round. Column 0 of `store`
# is the round index, so scale it by `n` (steps per round) to put the x-axis in
# optimisation iterations, matching the axis label.
plotting_data = store.cpu().numpy()
iterations = plotting_data[:, 0] * n

fig, (ax_t, ax_r) = plt.subplots(1, 2, figsize=(15, 5))
ax_t.plot(iterations, plotting_data[:, 1])
ax_t.set_xlabel("Training iteration")
ax_t.set_ylabel("Absolute translation error")
ax_r.plot(iterations, plotting_data[:, 2])
ax_r.set_xlabel("Training iteration")
ax_r.set_ylabel("Absolute rotation error")
plt.show()