In [None]:
import os
import urllib.request

# Fetch the tiny NeRF dataset once; later runs (and Restart & Run All) reuse the local copy.
# Pure-Python download instead of `!wget`, so the cell also works where wget is not installed.
if not os.path.exists('tiny_nerf_data.npz'):
    # Download under a temporary name and rename only after the transfer finished, so an
    # interrupted download never leaves a truncated .npz that the guard above would accept.
    _partial_path = 'tiny_nerf_data.npz.part'
    urllib.request.urlretrieve(
        'http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz',
        _partial_path,
    )
    os.replace(_partial_path, 'tiny_nerf_data.npz')

In [2]:
import numpy as np

data = np.load('tiny_nerf_data.npz')
# The archive stores three arrays:
#   images - one image per camera view; every view shows the same scene
#   poses  - one 4x4 camera transformation matrix per view (passed as `c2w` to get_rays)
#   focal  - the camera focal length, a single value shared by all views
# Although focal is a constant, it is still required: get_rays divides each pixel's offset
# from the image centre by it to obtain the ray direction. It is therefore measured in pixels
# and determines the camera's field of view.
imgs, poses, focal = data['images'], data['poses'], data['focal']

In [3]:
import tensorflow as tf

def get_rays(H, W, focal, c2w):
    """TensorFlow reference: one world-space ray (origin, direction) per pixel.

    get_rays_torch below is checked against this function.

    Args:
        H, W: image height and width in pixels.
        focal: focal length in pixels.
        c2w: camera-to-world matrix; its top-left 3x3 block rotates directions into world
            space and its last column (first three rows) is the ray origin.

    Returns:
        (rays_o, rays_d), each of shape (H, W, 3).
    """
    # px[r, c] == c (column / x), py[r, c] == r (row / y)
    px, py = tf.meshgrid(tf.range(W, dtype=tf.float32), tf.range(H, dtype=tf.float32), indexing='xy')
    x_dir = (px - W * .5) / focal
    y_dir = -(py - H * .5) / focal  # image rows grow downwards, camera y points up
    z_dir = -tf.ones_like(px)       # camera looks along -z
    cam_dirs = tf.stack([x_dir, y_dir, z_dir], -1)
    rotation = c2w[:3, :3]
    # Per-pixel R @ d, written as broadcast-multiply then sum over the last axis.
    world_dirs = tf.reduce_sum(cam_dirs[..., tf.newaxis, :] * rotation, -1)
    world_origins = tf.broadcast_to(c2w[:3, -1], tf.shape(world_dirs))
    return world_origins, world_dirs

(100, 100, 3) (100, 100, 3)


In [39]:
import torch
from einops import rearrange, repeat

def get_rays_torch(H, W, focal, c2w):
    """Generate one world-space ray per pixel of an H x W pinhole camera (PyTorch port of get_rays).

    Args:
        H: image height in pixels.
        W: image width in pixels.
        focal: focal length in pixels (Python number or 0-d tensor).
        c2w: camera-to-world transform whose top 3x4 block is [R | t]; a 4x4 homogeneous
            matrix or a 3x4 matrix both work. Anything accepted by torch.as_tensor is allowed.
            The pixel grid is built on c2w's device and in c2w's dtype (non-floating c2w is
            converted to float32 first).

    Returns:
        (origin_points_of_rays, direction_of_rays), each of shape (H, W, 3). Origins are the
        translation t copied to every pixel; directions are the (un-normalised) per-pixel
        camera-space directions rotated into world space by R.
    """
    c2w = torch.as_tensor(c2w)
    if not c2w.is_floating_point():
        c2w = c2w.to(torch.float32)
    # Building the grid in c2w's dtype/device keeps the `@` below valid for float64 or GPU poses.
    dtype, device = c2w.dtype, c2w.device

    # Pixel grid, see https://stackoverflow.com/questions/36013063/what-is-the-purpose-of-meshgrid-in-python-numpy
    # With indexing='xy' both outputs have shape (H, W): i[r, c] == c is the pixel's x (column)
    # coordinate and j[r, c] == r its y (row) coordinate. Row 0 is the top of the image, so y
    # grows downwards.
    i, j = torch.meshgrid(
        torch.arange(W, dtype=dtype, device=device),
        torch.arange(H, dtype=dtype, device=device),
        indexing='xy',
    )
    assert i.shape == (H, W) and j.shape == (H, W)
    # Broadcast comparisons check every row/column and hold for any H, W (including H != W, H == 1).
    assert (i == torch.arange(W, dtype=dtype, device=device)).all()
    assert (j == torch.arange(H, dtype=dtype, device=device)[:, None]).all()

    # Camera-space direction through each pixel: its offset from the image centre (W/2, H/2)
    # divided by the focal length. y is negated because image rows grow downwards while camera
    # y points up; z is -1 because the camera looks towards the negative z-axis.
    ray_dirs = torch.stack([(i - W * 0.5) / focal, -(j - H * 0.5) / focal, -torch.ones_like(i)], dim=-1)
    assert ray_dirs.shape == (H, W, 3)

    # c2w is a rigid camera-to-world transform (not a pure rotation):
    # [ R  t ]   R rotates camera-space directions into world space,
    # [ 0  1 ]   t is where the camera-space origin lands in world space, i.e. every ray's origin.
    R, t = c2w[:3, :3], c2w[:3, -1]

    # Row-vector form of R @ d for every pixel direction d.
    direction_of_rays = ray_dirs @ R.T
    # Tensor.repeat materialises a copy, so the returned origins never alias c2w.
    origin_points_of_rays = t.repeat(H, W, 1)
    assert origin_points_of_rays.shape == (H, W, 3) and direction_of_rays.shape == (H, W, 3)
    return origin_points_of_rays, direction_of_rays

In [40]:
# test that the two functions are equivalent
# Check that the PyTorch port reproduces the TensorFlow reference rays.
H, W = imgs[0].shape[:2]
tf_rays_o, tf_rays_d = get_rays(H, W, focal, poses[0])
torch_rays_o, torch_rays_d = get_rays_torch(H, W, torch.from_numpy(focal), torch.from_numpy(poses[0]))

# get_rays returns TensorFlow tensors: go through NumPy to get torch tensors for comparison.
tf_rays_o = torch.from_numpy(tf_rays_o.numpy())
tf_rays_d = torch.from_numpy(tf_rays_d.numpy())

# assert_close compares shape and dtype as well as values (with float tolerances).
torch.testing.assert_close(tf_rays_o, torch_rays_o)
torch.testing.assert_close(tf_rays_d, torch_rays_d)

# A single pose can hide rotation/translation mistakes, so repeat the comparison for every pose.
for _pose in poses:
    _ref_o, _ref_d = get_rays(H, W, focal, _pose)
    _port_o, _port_d = get_rays_torch(H, W, torch.from_numpy(focal), torch.from_numpy(_pose))
    torch.testing.assert_close(torch.from_numpy(_ref_o.numpy()), _port_o)
    torch.testing.assert_close(torch.from_numpy(_ref_d.numpy()), _port_d)