In [13]:
import torch
import numpy as np
import os
import imageio
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

import dataset

Instead of a sphere with only a few parameters, we'll use voxels, which have far more parameters.
Our sphere model isn't complex enough to fit enough scenes.

Therefore we'll move to voxels to represent more scenes, and finally to neural networks, which can represent even more complex scenes.


## 1. Camera / Dataset

- Intrinsics: the camera's intrinsic parameters, e.g. focal length. (The camera's position and orientation are its extrinsic parameters, i.e. the pose.)
For NeRF we make sure the intrinsic parameters, like focal length, are the same for every input image, while the camera position changes over time.

- Do not use autofocus on your phone; use the same focal length for all images.
- Most NeRF code assumes that the intrinsic parameters stay the same across input images.

In [8]:
def _read_matrix_4x4(path):
    """Read 16 whitespace-separated numbers from ``path`` as a 4x4 float array."""
    # Context manager so the file handle is closed even if parsing fails.
    with open(path) as fh:
        return np.array(fh.read().split(), dtype=float).reshape(4, 4)


def get_rays(datapath, mode='train'):
    """Load poses, intrinsics and images for one split and build per-pixel rays.

    Expected layout under ``datapath``::

        <datapath>/<mode>/pose/*.txt        one 4x4 matrix per view
        <datapath>/<mode>/intrinsics/*.txt  one 4x4 matrix per view
        <datapath>/imgs/*                   images whose file name contains <mode>

    Parameters
    ----------
    datapath : str
        Dataset root directory (with or without a trailing separator).
    mode : str, default 'train'
        Split name; selects the pose/intrinsics sub-directory and filters
        the image files by name.

    Returns
    -------
    rays_origin : np.ndarray, shape (N, H*W, 3)
        Translation column of each view's pose, repeated for every pixel.
    rays_direction : np.ndarray, shape (N, H*W, 3)
        Unit direction of the ray through each pixel, rotated by the pose.
    target_px_values : np.ndarray, shape (N, H*W, 3)
        Pixel colours scaled by each image's maximum intensity; RGBA inputs
        are composited onto a white background.

    Raises
    ------
    AssertionError
        If the numbers of pose, intrinsics and image files differ.
    ValueError
        If no views are found for ``mode``.
    """
    # Build paths with os.path.join: the original used the literal text
    # '{mode}' (missing f-string prefix) and only worked when ``datapath``
    # happened to end with a separator.
    pose_dir = os.path.join(datapath, mode, 'pose')
    intrinsics_dir = os.path.join(datapath, mode, 'intrinsics')
    img_dir = os.path.join(datapath, 'imgs')

    # os.listdir returns names in arbitrary order; sort every list so that
    # index i refers to the same view in all three.
    # NOTE(review): assumes pose/intrinsics/image files share a common stem
    # (e.g. train_0.txt / train_0.png) so sorted orders line up - confirm.
    pose_file_names = sorted(f for f in os.listdir(pose_dir) if f.endswith('.txt'))
    intrinsics_file_names = sorted(f for f in os.listdir(intrinsics_dir) if f.endswith('.txt'))
    # Filter by the requested split rather than a hard-coded 'train'.
    img_file_names = sorted(f for f in os.listdir(img_dir) if mode in f)

    print(f"Image files length: {len(img_file_names)}")
    print(f"Pose files length: {len(pose_file_names)}")
    print(f"Intrinsic files length: {len(intrinsics_file_names)}")

    assert len(pose_file_names) == len(intrinsics_file_names)
    assert len(img_file_names) == len(pose_file_names)

    N = len(pose_file_names)
    if N == 0:
        raise ValueError(f"No '{mode}' views found under {datapath!r}")

    poses = np.zeros((N, 4, 4))  # N homogeneous 4x4 matrices
    intrinsics = np.zeros((N, 4, 4))
    images = []

    for i in range(N):
        poses[i] = _read_matrix_4x4(os.path.join(pose_dir, pose_file_names[i]))
        intrinsics[i] = _read_matrix_4x4(os.path.join(intrinsics_dir, intrinsics_file_names[i]))

        img = imageio.imread(os.path.join(img_dir, img_file_names[i]))
        # Scale intensities into [0, 1] by the image's own maximum; fall back
        # to 1.0 for an all-black image to avoid dividing by zero.
        max_img_intensity = float(img.max()) or 1.0
        img = img / max_img_intensity

        images.append(img[None, ...])  # add a leading view axis

    print(f"Image size: {img.shape}")
    images = np.concatenate(images)
    print(images.shape)

    H = images.shape[1]
    W = images.shape[2]

    # RGBA -> RGB: alpha-blend the colour channels onto a white background.
    if images.shape[3] == 4:
        images = images[..., :3] * images[..., -1:] + (1 - images[..., -1:])

    rays_origin = np.zeros((N, H * W, 3))
    rays_direction = np.zeros((N, H * W, 3))
    target_px_values = images.reshape((N, H * W, 3))

    # The pixel grid is identical for every view, so build it once.
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    x_cam = u - W / 2
    y_cam = -(v - H / 2)  # image rows grow downwards; flip so +y points up

    for i in range(N):
        camera2world = poses[i]
        f = intrinsics[i, 0, 0]  # focal length read from the top-left entry

        # Camera-space direction through each pixel; z = -f, i.e. the camera
        # looks down its negative z axis.
        dirs = np.stack((x_cam, y_cam, np.full(u.shape, -f)), axis=-1)

        # Rotate into the world frame with the pose's 3x3 block, then normalise.
        dirs = (camera2world[:3, :3] @ dirs[..., None]).squeeze(-1)
        dirs = dirs / np.linalg.norm(dirs, axis=-1, keepdims=True)

        rays_direction[i] = dirs.reshape(-1, 3)
        rays_origin[i] += camera2world[:3, 3]  # broadcast over all H*W rays

    return rays_origin, rays_direction, target_px_values
    
    
    

In [12]:
BATCH_SIZE = 1024  # number of rays per training batch
origin, direction, target_px_values = dataset.get_rays('fox/fox/', mode='train')

# Each array is (N, H*W, 3). Flatten the view and pixel axes so every row is
# one ray, then concatenate along dim=1 to get rows of
# [origin(3) | direction(3) | colour(3)]. Without the reshape, torch.cat on
# dim=1 produces (N, 3*H*W, 3) and the DataLoader would batch whole images
# rather than BATCH_SIZE individual rays.
# NOTE(review): the tensors stay float64 (NumPy's default); cast with .float()
# here if the model is float32 - confirm once the model is defined.
train_rays = torch.cat((torch.from_numpy(origin).reshape(-1, 3),
                        torch.from_numpy(direction).reshape(-1, 3),
                        torch.from_numpy(target_px_values).reshape(-1, 3)), dim=1)
dataloader = DataLoader(train_rays, batch_size=BATCH_SIZE, shuffle=True)

test_origin, test_direction, test_target_px_values = dataset.get_rays('fox/fox/', mode='test')


print(f"origin shape: {origin.shape}")
print(f"direction shape: {direction.shape}")

(90, 160000, 3)

## 2. Rendering

## 3. Model

## 4. Training