In [None]:
import os
import urllib.request

import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm

# Dataset

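Tiny Lego dataset: a single `.npz` archive holding the `images`, one camera `pose` per image, and the shared `focal` length.
- https://people.eecs.berkeley.edu/~bmild/nerf/tiny_nerf_data.npz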

In [None]:
# Download the tiny Lego scene once; later runs reuse the local copy.
# A pure-Python download replaces the `!wget` shell magic so the cell is plain
# Python and works without wget installed. Writing to a temporary ".part" file
# first means an interrupted download is never mistaken for a complete archive
# on the next run.
DATA_URL = "https://people.eecs.berkeley.edu/~bmild/nerf/tiny_nerf_data.npz"
DATA_PATH = "tiny_nerf_data.npz"
if not os.path.exists(DATA_PATH):
    urllib.request.urlretrieve(DATA_URL, DATA_PATH + ".part")
    os.replace(DATA_PATH + ".part", DATA_PATH)

dataset = np.load(DATA_PATH)
# Read the image stack once: every `dataset[key]` access on an .npz archive
# loads that array from the archive again.
all_images = dataset["images"]
print(all_images.shape)
print(dataset["poses"].shape)
print(dataset['focal'])

# Preview up to the first 10 views, one panel each. squeeze=False keeps `axs`
# a 2-D array even when only a single panel is drawn.
n_preview = min(10, len(all_images))
fig, axs = plt.subplots(nrows=1, ncols=n_preview, figsize=(3 * n_preview, 4), squeeze=False)
for i, ax in enumerate(axs.flatten()):
    ax.imshow(all_images[i])
    ax.set_title('Image: {}'.format(i + 1))
    ax.axis("off")
plt.show()

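# Data Loading (Pinhole Camera Model)

Each pose's translation column gives the camera position. Its rotation maps the camera's viewing axis $(0, 0, -1)$ into world space. The plot below draws one arrow per camera.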

In [None]:
poses = dataset['poses']
# A camera looks down its local -z axis; multiplying that axis by each pose's
# 3x3 rotation block (batched over all poses) gives world-space view directions.
view_axis = np.array([0, 0, -1])
dirs = poses[:, :3, :3] @ view_axis
# The translation column of each pose is the camera position.
origins = poses[:, :3, 3]

fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(projection="3d")
# Unpack the x/y/z columns of positions and directions as quiver's six arguments.
_ = ax.quiver(*origins.T, *dirs.T, length=0.5, normalize=True)
plt.show()

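# Functions

Helpers for building one ray per pixel, sampling 3D points along those rays, and the NeRF positional encoding.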


In [None]:
def get_rays(h: int, w: int, focal_length: float, pose: torch.Tensor):
  """Build one ray per pixel for a pinhole camera.

  Args:
    h: Image height in pixels.
    w: Image width in pixels.
    focal_length: Focal length, in the same (pixel) units as the image grid.
    pose: Camera transform; only pose[:3, :3] (rotation) and pose[:3, -1]
      (camera position) are read. Presumably a 4x4 camera-to-world matrix.

  Returns:
    (rays_o, rays_d), each shaped (h, w, 3): the camera position broadcast to
    every pixel, and the unnormalized per-pixel ray directions rotated by
    pose[:3, :3].
  """
  # Pixel grids laid out as (row, col): xs[r, c] == c and ys[r, c] == r.
  xs, ys = torch.meshgrid(
      torch.arange(w, dtype=torch.float32).to(pose),
      torch.arange(h, dtype=torch.float32).to(pose),
      indexing='xy')
  # Camera-frame directions: x to the right, y up (rows grow downward, hence
  # the minus sign), and the camera looking along -z.
  cam_dirs = torch.stack(
      [(xs - w * .5) / focal_length,
       -(ys - h * .5) / focal_length,
       -torch.ones_like(xs)],
      dim=-1)
  # R @ d for every pixel, written as a broadcast multiply followed by a sum.
  rotation = pose[:3, :3]
  rays_d = (cam_dirs[..., None, :] * rotation).sum(dim=-1)
  rays_o = pose[:3, -1].expand(rays_d.shape)
  return rays_o, rays_d

In [None]:
def stratified_sampling(
    rays_o,
    rays_d,
    near,
    far,
    n,
):
  """Draw `n` jittered depths per ray and the 3D points at those depths.

  Stratified sampling as in NeRF (Eq. 2): the interval [near, far] is split
  into `n` equal-width bins and one depth is drawn uniformly inside each bin.
  Every sample therefore lies in [near, far], and depths are non-decreasing
  along the last axis.

  Args:
    rays_o: Ray origins, shape (..., 3), e.g. (h, w, 3) as built by get_rays.
    rays_d: Ray directions, broadcastable with rays_o.
    near: Start of the sampling interval along each ray.
    far: End of the sampling interval along each ray.
    n: Number of samples (bins) per ray; must be positive.

  Returns:
    x: Sample points rays_o + t * rays_d, shape (..., n, 3).
    t: Sampled depths, shape (..., n), with independent jitter for every ray.
  """
  bin_width = (far - near) / n
  # Left edge of each bin: near, near + bin_width, ..., near + (n - 1) * bin_width.
  # Using n bins of width (far - near) / n, rather than jittering
  # linspace(near, far, n) by that width, keeps the last sample from
  # overshooting `far` and makes the strata tile [near, far] without gaps.
  bin_starts = near + bin_width * torch.arange(n, dtype=rays_o.dtype, device=rays_o.device)

  # One uniform draw per (ray, bin), created directly on rays_o's device/dtype.
  noise_shape = list(rays_o.shape[:-1]) + [n]
  jitter = torch.rand(noise_shape, dtype=rays_o.dtype, device=rays_o.device)
  t = bin_starts + jitter * bin_width

  # (..., 1, 3) + (..., 1, 3) * (..., n, 1) -> (..., n, 3)
  x = rays_o[..., None, :] + rays_d[..., None, :] * t[..., :, None]
  return x, t

In [None]:
def positional_encoding(
    x, L=6, include_input=True
) -> torch.Tensor:
  """Map `x` to sinusoidal features at octave-spaced frequencies.

  For k = 0 .. L-1 the features sin(2**k * pi * x) and cos(2**k * pi * x) are
  appended, each pair in that order, after `x` itself when `include_input` is
  True. All pieces are concatenated along the last axis, so the last dimension
  grows from D to D * (2 * L + int(include_input)).

  Args:
    x: Input tensor; the encoding is applied elementwise.
    L: Number of frequency octaves.
    include_input: Whether to prepend the raw `x` to the features.

  Returns:
    The concatenated encoding, with the same dtype and device as `x`.
  """
  freqs = 2.0 ** torch.linspace(0.0, L - 1, L, dtype=x.dtype, device=x.device)
  parts = [x] if include_input else []
  for f in freqs:
    # Evaluate the scaled argument once and reuse it for both sin and cos.
    scaled = x * f * np.pi
    parts += [torch.sin(scaled), torch.cos(scaled)]
  return torch.cat(parts, dim=-1)