In [1]:
import os

import torch
import taichi as ti

from opt import get_opts
from train import NeRFSystem
from datasets.ray_utils import get_rays
from modules.intersection import RayAABBIntersector as ray_intersect

def taichi_init(args):
    """Initialise the Taichi runtime on the CUDA backend.

    Args:
        args: Options object. Only ``args.half2_opt`` is read; when it is
            truthy, ``half2_vectorization=True`` is additionally passed to
            ``ti.init``.
    """
    # The vectorization flag is only forwarded when it was requested.
    optional_flags = {"half2_vectorization": True} if args.half2_opt else {}
    ti.init(arch=ti.cuda, device_memory_GB=8.0, **optional_flags)

[Taichi] version 1.4.1, llvm 15.0.4, commit e67c674e, linux, python 3.9.16
[I 03/29/23 20:19:03.819 2437685] [shell.py:_shell_pop_print@23] Graphical python shell detected, using wrapped sys.stdout


In [2]:
# Dataset location. Set the NGP_ROOT_DIR environment variable to run this
# notebook on another machine; the fallback is the path originally used here.
ROOT_DIR = os.environ.get(
    "NGP_ROOT_DIR",
    "/home/loyot/workspace/code/ngp_pl_gui/Synthetic_NeRF/Lego",
)

# Command-line style options handed to the project's option parser.
prefix_args = [
    "--root_dir", ROOT_DIR,
    "--exp_name", "Lego", "--perf",
    "--num_epochs", "20",
    "--batch_size", "8192",
    "--lr", "1e-2", "--no_save_test",
]
hparams = get_opts(prefix_args)

device = torch.device('cuda')
# Taichi has to be initialised before the system (and its Taichi model) is built.
taichi_init(hparams)
system = NeRFSystem(hparams).to(device)
system.setup("train")
# The return value (optimizers, schedulers) is shown as this cell's output.
# NOTE(review): presumably this call also creates system.net_opt, which a
# later cell reads -- confirm in train.py.
system.configure_optimizers()

[Taichi] Starting on arch=cuda
GridEncoding: Nmin=16 b=1.31951 F=2 T=2^19 L=16
per_level_scale:  1.3195079107728942
offset_:  5710032
total_hash_size:  11420064
Loading 100 train images ...


100it [00:01, 54.64it/s]


Loading 200 test images ...


200it [00:03, 55.70it/s]


([Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      capturable: False
      differentiable: False
      eps: 1e-15
      foreach: None
      fused: None
      initial_lr: 0.01
      lr: 0.01
      maximize: False
      weight_decay: 0
  )],
 [<torch.optim.lr_scheduler.CosineAnnealingLR at 0x7f26f00c0250>])

## Setup

In [3]:
dataset = system.train_dataset
# Move the per-image poses and per-pixel directions onto the training device
# so the indexing in the train step happens there.
for _name in ("poses", "directions"):
    setattr(dataset, _name, getattr(dataset, _name).to(device))
model = system.model  # Taichi NGP
optimizer = system.net_opt

## Train Step

In [4]:
# Get data from dataset
batch = dataset[0]
# The batch carries index tensors into the dataset-wide pose / direction tables.
img_idxs, pix_idxs = batch['img_idxs'], batch['pix_idxs']
poses = dataset.poses[img_idxs]
directions = dataset.directions[pix_idxs]
# Ground-truth colours for the selected pixels.
rgb_gt = batch['rgb'].to(device)

In [5]:
# Update occupancy grid
# Called with no arguments here.
# NOTE(review): presumably this refreshes model.density_grid and
# model.density_bitfield, which later cells read -- confirm in the model code.
model.update_density_grid()

In [6]:
# Sanity check: compare the second dimension of density_grid with 128^3.
print("occ_grid shape: ", model.density_grid.shape)
print("128 x 128 x 128: ", 128 **3)

occ_grid shape:  torch.Size([1, 2097152])
128 x 128 x 128:  2097152


In [7]:
# Generate rays
# Turns the selected pixel directions and camera poses into ray origins
# (rays_o) and ray directions (rays_d).
rays_o, rays_d = get_rays(directions, poses)

In [8]:
# One 3-vector origin per ray in the batch.
print("ray shape: ", rays_o.shape)

ray shape:  torch.Size([8192, 3])


### Render

In [9]:
# Ray intersection with box
# The intersector is handed contiguous buffers.
rays_o, rays_d = rays_o.contiguous(), rays_d.contiguous()

NEAR_DISTANCE = 0.01
box_center, box_size = model.center, model.half_size

_, hits_t, _ = ray_intersect.apply(rays_o, rays_d, box_center, box_size, 1)

# Entry distances that fall in [0, NEAR_DISTANCE) are raised to NEAR_DISTANCE.
first_entry = hits_t[:, 0, 0]
hits_mask = (first_entry >= 0) & (first_entry < NEAR_DISTANCE)
hits_t[hits_mask, 0, 0] = NEAR_DISTANCE
# Drop the middle axis, leaving one (near, far) pair per ray.
hits_t = hits_t[:, 0]

In [10]:
# Ray marching
# Upper bound on samples, passed as the last argument of ray_marching.
MAX_SAMPLES = 1024

# All arguments are positional, so their order matters.
marching_results = model.ray_marching(
    rays_o, rays_d, # ray
    hits_t, # near & far
    model.density_bitfield, # occupancy grid
    model.cascades, 
    model.scale, 
    0, # NOTE(review): meaning of this constant is not visible here -- confirm against ray_marching
    model.grid_size,
    MAX_SAMPLES
)

# rays_a is per ray; xyzs is per sample (see the shapes printed in the next cells).
rays_a, xyzs, dirs, deltas, ts, n_samples = marching_results

In [12]:
# Ray level -> Sample level
# xyzs holds one 3-vector per sample, so its first dimension is the total
# sample count over all rays rather than the number of rays.
print("xyzs shape: ", xyzs.shape)

xyzs shape:  torch.Size([2055705, 3])


In [13]:
# rays_a keeps one row of three values per ray.
print("rays_a shape: ", rays_a.shape)

rays_a shape:  torch.Size([8192, 3])


In [14]:
# Keep only the rays whose third rays_a column is positive, then look at one
# of them (row 40 of the filtered result).
# NOTE(review): the columns look like (ray index, first sample offset,
# sample count) -- confirm against the ray_marching implementation.
non_zero_mask = rays_a[:, 2] > 0
rays_a[non_zero_mask][40]

tensor([7849, 9796,  253], device='cuda:0', dtype=torch.int32)

In [None]:
# Ray marching pseudo-code
# Illustrative only -- this cell is not meant to be executed. Helper names
# such as ray_ids, ceil, occupancy_grid, calculate_dt, advance_to_next_cell
# and val_samples are placeholders.
# For each ray
for ray_id in ray_ids:
    r = ray_id
    ray_o, ray_d = rays_o[r], rays_d[r]  # origin / direction of this ray
    n = 0  # number of samples stored for this ray so far
    t1, t2 = hits_t[r, 0], hits_t[r, 1]  # near / far distance along the ray
    t = t1
    while (0 <= t) & (t < t2) & (n < MAX_SAMPLES):
        xyz = ray_o + t * ray_d
        # Grid cell looked up for this position (simplified); it is derived
        # from xyz, the sample just computed.
        nxyz = ceil(xyz)
        occ = occupancy_grid[nxyz]

        if occ:
            # Store first, then count: samples fill slots 0 .. n-1 and the
            # write index never reaches MAX_SAMPLES.
            val_samples[r, n] = xyz # save sample
            n += 1 # add sample
            t += calculate_dt(t) # next step
        else:
            # Empty cell: skip ahead without emitting a sample.
            t += advance_to_next_cell()


In [15]:
# run radiance field
# Evaluate the model on every sample position / direction produced by ray
# marching; no extra keyword arguments are passed (kwargs is empty).
kwargs = {}
sigmas, rgbs = model(xyzs, dirs, **kwargs)

In [16]:
# Volume rendering
# NOTE(review): presumably the transmittance cut-off used for early
# termination inside render_func -- confirm against its implementation.
T_threshold = 1e-4
# All arguments are positional, so their order matters.
render_results = model.render_func(
    sigmas, 
    rgbs, 
    deltas, 
    ts, 
    rays_a, 
    T_threshold
)
# render_func returns five values; only the fourth (rgb) is used below.
_, _, _, rgb, _ = render_results

In [17]:
# Sample level -> Ray level
# After rendering there is again one colour per ray.
print("rgb shape: ", rgb.shape)

rgb shape:  torch.Size([8192, 3])


In [18]:
# Loss
# Mean squared error between the rendered and the ground-truth colours.
loss = ((rgb - rgb_gt)**2).mean()
# PyTorch accumulates into .grad on every backward(); clear what a previous
# train step left behind so that this update uses only the current batch.
optimizer.zero_grad()
loss.backward()
optimizer.step()

In [19]:
# Prints the loss tensor itself (value, device and grad_fn), not a Python float.
print("loss: ", loss)

loss:  tensor(0.5511, device='cuda:0', grad_fn=<MeanBackward0>)


In [20]:
# Ask PyTorch to release cached, currently unused CUDA memory before the
# second set of encoders / MLPs is built in the Model section.
torch.cuda.empty_cache()

# Model

In [35]:
import numpy as np
from modules.networks import MLP
from modules.hash_encoder import HashEncoder
from modules.spherical_harmonics import DirEncoder

In [37]:
# Get box size
# Corners of the scene bounding box, taken from the Taichi NGP model.
box_min, box_max = model.xyz_min, model.xyz_max

# Encoders
L = 16
min_res = 16
max_res = 1024
# Per-level growth factor chosen so that min_res * b**(L - 1) == max_res.
b = np.exp(np.log(max_res / min_res) / (L - 1))
# Hash-grid position encoder (or Triplane); the second argument is the batch
# size used for pre-allocation.
pos_encoder = HashEncoder(b, hparams.batch_size).to(device)
dir_encoder = DirEncoder(hparams.batch_size).to(device)

# MLP
# Density network: 32 inputs -> 16 outputs.
sigma_net = MLP(
    input_dim=32,
    output_dim=16,
    net_depth=1,
    net_width=64,
    bias_enabled=False,
).to(device)

# Colour network: 32 inputs -> 3 outputs passed through a sigmoid.
rgb_net = MLP(
    input_dim=32,
    output_dim=3,
    net_depth=2,
    net_width=64,
    bias_enabled=False,
    output_activation=torch.nn.Sigmoid(),
).to(device)

per_level_scale:  1.3195079107728942
offset_:  5710032
total_hash_size:  11420064


In [40]:
# Model forward
class TruncExp(torch.autograd.Function):
    """Exponential whose backward pass clamps its input to [-15, 15].

    The forward value is the plain ``exp(x)``; only the gradient uses the
    clamped input.
    """

    @staticmethod
    def forward(ctx, x):
        """Return ``exp(x)`` and keep ``x`` for the backward pass."""
        ctx.save_for_backward(x)
        return x.exp()

    @staticmethod
    def backward(ctx, dL_dout):
        """Scale the incoming gradient by ``exp(clamp(x, -15, 15))``."""
        (x,) = ctx.saved_tensors
        clamped = torch.clamp(x, min=-15, max=15)
        return dL_dout * clamped.exp()

# Get sigma
# Rescale the sample positions relative to the bounding box before encoding.
box_extent = box_max - box_min
x = (xyzs - box_min) / box_extent
pos_embedding = pos_encoder(x)
h = sigma_net(pos_embedding)
# The first channel of h is turned into the density.
sigmas = TruncExp.apply(h[:, 0])

# Get rgb
# Unit-length view directions.
d = dirs / dirs.norm(dim=1, keepdim=True)
dir_embedding = dir_encoder((d + 1) / 2)  # [-1, 1] -> [0, 1]
# The colour network sees the direction embedding concatenated with all of h.
rgbs = rgb_net(torch.cat((dir_embedding, h), dim=1))

# return sigmas, rgbs

In [42]:
# One density value and one RGB triple per sample.
print("sigmas shape: ", sigmas.shape)
print("rgbs shape: ", rgbs.shape)

sigmas shape:  torch.Size([2074017])
rgbs shape:  torch.Size([2074017, 3])
