*Exploratory Data Analysis*

# Visualizing the Training Data

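In this notebook we visualize the training data: for each training frame we show the image, its depth map, and the corresponding camera pose (coordinate frame and viewing frustum) in 3D.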

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from mpl_toolkits.mplot3d import Axes3D
import torch
from ipywidgets import interact

from run_dnerf_helpers import get_rays
from utils import load_owndataset_data
from utils import Arrow3D, draw_transformed, draw_cam

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [3]:
scene_name = "johannes"
images, depth_maps, poses, times, render_poses, render_times, hwff, i_split = load_owndataset_data(
    f"./data/{scene_name}", True, 1, render_pose_type="spiral")

# Only the first of the three index sets is used here; the other two are
# presumably the validation/test indices (unused in this notebook).
# Named placeholders avoid rebinding `_`, which IPython uses for the last cell output.
i_train, _i_val, _i_test = i_split

# Restrict EVERY per-frame sequence to the training frames. `series` addresses
# poses, images, depth_maps and times with one and the same position, so all
# four have to be filtered consistently -- otherwise the depth map / time shown
# for a frame would belong to a different frame whenever i_train is not 0..N-1.
poses = [poses[idx] for idx in i_train]
images = [images[idx] for idx in i_train]
depth_maps = [depth_maps[idx] for idx in i_train]
times = [times[idx] for idx in i_train]

In [10]:
# Axis limits ([min, max]) applied to the X, Y and Z axes of the 3D pose plot.
# Alternative (currently disabled) limits: x = [-3, 3], y = [-1, 1], z = [0, 6].
xlim, ylim, zlim = [-1., 1.], [-1., 1.], [-1, 1]


def _camera_rays(pose, hwff):
    """Compute world-space ray origins and directions for every pixel of one camera.

    Args:
        pose: camera-to-world matrix. Only its first three rows are used: the
            left 3x3 block rotates the directions, the last column is the ray origin.
        hwff: sequence unpacked as ``(H, W, focal_x, focal_y)``.

    Returns:
        Tuple ``(rays_o, rays_d)``, each of shape ``[H, W, 3]``.
    """
    c2w = torch.Tensor(pose[:3])
    H, W, focal_x, focal_y = hwff
    # 'xy' indexing directly yields [H, W] grids: `i` holds the pixel column
    # (X-dir), `j` the pixel row (Y-dir). Same result as 'ij' followed by .t().
    i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H), indexing='xy')
    # Per-pixel directions in the camera frame; the Z component is -1 and the
    # row coordinate is negated, i.e. the camera looks along -Z with Y up.
    dirs = torch.stack([(i-W*.5)/focal_x, -(j-H*.5)/focal_y, -torch.ones_like(i)], -1)   # [H, W, 3]
    # Rotate into world coordinates; equals [c2w[:3, :3].dot(d) for d in dirs].
    rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1)
    rays_o = c2w[:3, -1].expand(rays_d.shape)
    return rays_o, rays_d


def series(ith_frame):
    """Plot one training frame: its image, its depth map and its camera pose in 3D.

    Reads the notebook-level ``poses``, ``images``, ``depth_maps``, ``times``,
    ``hwff`` and the axis limits ``xlim``, ``ylim``, ``zlim``.

    Args:
        ith_frame: position of the frame inside ``poses``. The same position is
            used for ``images``, ``depth_maps`` and ``times``, so all four are
            assumed to be aligned with each other -- verify where they are loaded.
    """
    pose = poses[ith_frame]

    fig = plt.figure(figsize=(10, 10))

    ax_img = fig.add_subplot(131)
    ax_img.set_title("Image")
    ax_img.imshow(images[ith_frame])

    ax_depth = fig.add_subplot(132)
    ax_depth.set_title("Depth map")
    ax_depth.imshow(depth_maps[ith_frame])

    ax3d = fig.add_subplot(133, projection='3d')
    ax3d.set_title(f"Time = {times[ith_frame]:.2f}\n")

    ax3d.set_xlabel('X')
    ax3d.set_xlim(*xlim)
    ax3d.set_ylabel('Y')
    ax3d.set_ylim(*ylim)
    ax3d.set_zlabel('Z')
    ax3d.set_zlim(*zlim)
    # Box aspect proportional to the axis ranges -> a length of 1 looks the same
    # in every dimension. `zoom` replaces the former `ax.dist = 6.2`: Axes3D.dist
    # is deprecated since matplotlib 3.6 and ignored once removed; with the
    # default distance of 10, zoom = 10 / dist gives the same view.
    ax3d.set_box_aspect((xlim[1]-xlim[0], ylim[1]-ylim[0], zlim[1]-zlim[0]), zoom=10. / 6.2)

    # The world coordinate system: X red, Y blue, Z green, origin marked with "0".
    arrow_prop_dict = dict(mutation_scale=10, arrowstyle='simple', shrinkA=0, shrinkB=0)
    ax3d.add_artist(Arrow3D([0, .5], [0, 0], [0, 0], **arrow_prop_dict, color='r'))
    ax3d.add_artist(Arrow3D([0, 0], [0, .5], [0, 0], **arrow_prop_dict, color='b'))
    ax3d.add_artist(Arrow3D([0, 0], [0, 0], [0, .5], **arrow_prop_dict, color='g'))
    ax3d.text(-.1, -.1, 0.0, r'$0$')

    # Coordinate frame of the training camera (dashed). draw_transformed is a
    # project helper; presumably it draws the axes transformed by `pose`.
    draw_transformed(pose, ax3d, linestyle="--", axes_len=0.5, mutation_scale=10)

    # Proxy handles: the legend describes line styles, not the artists themselves.
    ax3d.legend(handles=[Line2D([0], [0], color='r', ls="--"),
                         Line2D([0], [0], color='b', ls="--"),
                         Line2D([0], [0], color='g', ls="--"),
                         Line2D([0], [0], color='black', ls="-")],
                labels=["X", "Y", "Z", "Camera frustum"],
                title="Training view camera",
                bbox_to_anchor=(1.7, 0.3))

    # Draw the camera rays; rays_o and rays_d are already in world coordinates.
    rays_o, rays_d = _camera_rays(pose, hwff)
    draw_cam(rays_o, rays_d, ax3d, focal_dist=0.5)

    ax3d.view_init(elev=30., azim=20., vertical_axis='y')        # vertical_axis needs matplotlib >= 3.5

    fig.tight_layout()

# Slider over all frames in `poses`: (min, max, step) with inclusive bounds.
inter = interact(
    series,
    ith_frame=(0, len(poses) - 1, 1),
)

interactive(children=(IntSlider(value=34, description='ith_frame', max=69), Output()), _dom_classes=('widget-i…