In [10]:
import sys

# Make the parent directory importable (presumably where train.py, models/ and
# datasets/ live, given the imports in the next cell). Guarded so re-running
# this cell does not keep appending duplicate "../" entries to sys.path.
if "../" not in sys.path:
    sys.path.append("../")

In [3]:
import os

from torch.utils.data import DataLoader
from easydict import EasyDict as edict

from train import NeRFSystem
from models.nerf import Embedding, NeRF
from models.rendering import render_rays
from datasets.blender import BlenderDataset

In [4]:
# Hyperparameters for this forward-pass walkthrough.
# root_dir may be overridden with the NERF_DATA_ROOT environment variable so the
# notebook is not tied to one user's scratch directory; the default keeps the
# original path.
hparams = edict(
    root_dir=os.environ.get(
        "NERF_DATA_ROOT", "/scratch/saksham/data/nerf_synthetic/lego/"
    ),
    dataset_name="blender",
    N_samples=64,        # coarse samples per ray
    N_importance=128,    # fine samples per ray; the fine model is only built when > 0
    loss_type="mse",
    batch_size=1024,     # rays per DataLoader batch
    chunk=32 * 1024,     # max rays handed to render_rays in one call
    use_disp=True,       # sample in disparity (inverse depth) space
    perturb=True,
    noise_std=1.0,
)

In [5]:
# data paths 
blender_data_path = "/scratch/saksham/data/nerf_synthetic/lego/"
train_dataset = BlenderDataset(root_dir = blender_data_path, split = "train")
train_loader = DataLoader(train_dataset,
                          shuffle=True,
                          num_workers=4,
                          batch_size= hparams.batch_size,
                          pin_memory=True)

In [6]:
# Iterator over the (shuffled) training loader so individual batches can be
# pulled with next() for inspection.
data_iter = iter(train_loader)

In [7]:
# Pull a single batch; it is indexed by the 'rays' and 'rgbs' keys in the next cell.
batch = next(data_iter)

In [8]:
# rays: (N_rays, 3+3+2) -> origin, direction, near, far per ray.
# rgbs: (N_rays, 3) -> presumably the ground-truth pixel colour for each ray.
rays = batch['rays']
rgbs = batch['rgbs']
rays.shape, rgbs.shape

(torch.Size([1024, 8]), torch.Size([1024, 3]))

### Build Model

In [35]:
# Encoders for 3-D sample positions ('xyz') and 3-D view directions ('dir').
# The second argument (10 / 4) presumably sets the number of frequency bands;
# TODO confirm against models/nerf.py.
embedding_xyz = Embedding(3, 10)
embedding_dir = Embedding(3, 4)
embeddings = dict(xyz=embedding_xyz, dir=embedding_dir)

# The coarse network is always built; the fine one only when importance sampling is on.
nerf_coarse = NeRF()
models = dict(coarse=nerf_coarse)
if hparams.N_importance > 0:
    nerf_fine = NeRF()
    models['fine'] = nerf_fine

### Model Forward

In [36]:
B = rays.size(0)  # number of rays in this batch

In [39]:
rays.size()  # (B, 3+3+2)

torch.Size([1024, 8])

In [37]:
from collections import defaultdict

# Accumulator for render_rays outputs: each key maps to a list holding one
# value per rendered ray chunk (filled by the rendering loop).
results = defaultdict(list)

**Inputs to `render_rays`:**
- models: dict of NeRF models defined in nerf.py (`'coarse'`, plus `'fine'` when `N_importance > 0`)
- embeddings: dict of embedding models defined in nerf.py, for sample positions (`'xyz'`) and viewing directions (`'dir'`)
- rays: (N_rays, 3+3+2), ray origins and directions, near and far depths 
- N_samples: number of coarse samples per ray 
- use_disp: whether to sample in disparity space (inverse depth) 
- perturb: factor to perturb the sampling position on the ray (for coarse model only)
- noise_std: factor to perturb the model's prediction of sigma
- N_importance: number of fine samples per ray
- chunk: the chunk size in batched inference 
- white_back: whether the background is white (dataset dependent) 
- test_time: whether this is test time (inference only). If True, the coarse RGB
                   is not computed, to save time

In [38]:
# Render the batch in chunks of at most hparams.chunk rays and collect each
# chunk's outputs under its key (one entry per chunk in each list).
# Clear the accumulator first: it is created in an earlier cell, so without this
# re-running this cell would append duplicate chunks to every list.
results.clear()
for i in range(0, B, hparams.chunk):
    rendered_ray_chunks = render_rays(
        models,
        embeddings,
        rays[i:i + hparams.chunk],
        hparams.N_samples,
        hparams.use_disp,
        hparams.perturb,
        hparams.noise_std,
        hparams.N_importance,
        hparams.chunk,  # chunk size is effective in val mode
        train_dataset.white_back,
    )

    for k, v in rendered_ray_chunks.items():
        results[k].append(v)

In [63]:
# Output names returned by render_rays for this batch (coarse and fine variants).
results.keys()

dict_keys(['weights_coarse', 'opacity_coarse', 'z_vals_coarse', 'rgb_coarse', 'depth_coarse', 'weights_fine', 'opacity_fine', 'z_vals_fine', 'rgb_fine', 'depth_fine'])

In [64]:
# Coarse weights of the first chunk; with B <= hparams.chunk it is the only chunk,
# shaped (rays in chunk, N_samples).
first_chunk_weights = results['weights_coarse'][0]
first_chunk_weights.size()

torch.Size([1024, 64])