# NeRF Tutorial

This notebook was created for the AI Expert course on August 22nd.
This material explores the basics and applications of NeRF. For theoretical background on NeRF, please refer to the following paper: [NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](https://arxiv.org/abs/2003.08934).

## Table of Contents
* TinyNeRF
* NeRFStudio

# 1. TinyNeRF

A scaled-down NeRF model that trades rendering quality for fast training and visualization:
* Approximately 20 times fewer parameters than the original NeRF
* The input is only the 3D position; the viewing direction from NeRF's 5D input (position + direction) is omitted
* No hierarchical (coarse-to-fine) sampling
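Without the viewing direction, each sample is just a 3D point, which the notebook's `positional_encoder` (default `L_embed=6`) lifts to 39 features. The following is a NumPy mirror of that function, written only for illustration. The name `positional_encoding` and the random sample points are hypothetical.

```python
import numpy as np

def positional_encoding(x, num_freqs=6):
    """NumPy mirror of positional_encoder: [x, sin(2^0 x), cos(2^0 x), ..., sin(2^(L-1) x), cos(2^(L-1) x)]."""
    features = [x]
    for i in range(num_freqs):
        for fn in (np.sin, np.cos):
            features.append(fn((2.0 ** i) * x))
    return np.concatenate(features, axis=-1)

points = np.random.rand(5, 3)           # 5 hypothetical 3D sample positions (no view direction)
encoded = positional_encoding(points)
print(encoded.shape)                    # (5, 39): 3 raw coords + 3 * 2 * 6 sin/cos features
```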

In [None]:
import os,sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

if not os.path.exists('tiny_nerf_data.npz'):
    !wget http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz

In [None]:
#Search for GPU to run on
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#Load in data
rawData = np.load("tiny_nerf_data.npz")
images = rawData["images"]
poses = rawData["poses"]
focal = float(rawData["focal"])  # stored as a 0-d array; convert to a plain Python float
H, W = images.shape[1:3]
H = int(H)
W = int(W)
print(images.shape, poses.shape, focal)

# Hold out one view for testing and train only on the first 100 views,
# so the test image is never seen during training
testimg, testpose = images[101], poses[101]
images, poses = images[:100], poses[:100]
plt.imshow(testimg)
plt.show()
images = torch.Tensor(images).to(device)
poses = torch.Tensor(poses).to(device)
testimg = torch.Tensor(testimg).to(device)
testpose = torch.Tensor(testpose).to(device)

In [None]:
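def get_rays(H, W, focal, pose):
  # Pixel grid: i holds the column (x) index and j the row (y) index, both of shape (H, W)
  i, j = torch.meshgrid(
      torch.arange(W, dtype=torch.float32, device=device),
      torch.arange(H, dtype=torch.float32, device=device),
      indexing="xy"
      )
  # Camera-space ray directions of a pinhole camera looking down -z (x right, y up)
  dirs = torch.stack(
      [(i-W*0.5)/focal,
       -(j-H*0.5)/focal,
       -torch.ones_like(i)], -1)
  # Rotate directions into world space (R @ d for every pixel)
  rays_d = torch.sum(dirs[..., np.newaxis, :] * pose[:3, :3], -1)
  # All rays start at the camera center (translation column of the pose)
  rays_o = pose[:3,-1].expand(rays_d.shape)
  return rays_o, rays_d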
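`get_rays` builds one ray per pixel of a pinhole camera. The NumPy sketch below mirrors it so the geometry can be checked without a GPU. `get_rays_np`, the 4×6 image size, the focal length of 5, and the camera placed at z = 4 are illustrative assumptions. With an identity rotation, the center pixel's ray should point straight down -z, and every ray should start at the camera center.

```python
import numpy as np

def get_rays_np(H, W, focal, c2w):
    """NumPy mirror of get_rays: one ray origin and direction per pixel."""
    # i holds the column (x) index and j the row (y) index, both of shape (H, W)
    i, j = np.meshgrid(np.arange(W, dtype=np.float32),
                       np.arange(H, dtype=np.float32), indexing="xy")
    # Camera-space directions: x right, y up, camera looking down -z
    dirs = np.stack([(i - W * 0.5) / focal,
                     -(j - H * 0.5) / focal,
                     -np.ones_like(i)], axis=-1)
    # Rotate into world space: d_world = R @ d_cam for every pixel
    rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], axis=-1)
    # Every ray starts at the camera center (translation column of c2w)
    rays_o = np.broadcast_to(c2w[:3, -1], rays_d.shape)
    return rays_o, rays_d

# Hypothetical 4x6 image, focal length 5, identity-rotation camera at z = 4
img_h, img_w, focal_px = 4, 6, 5.0
c2w_demo = np.eye(4)
c2w_demo[:3, 3] = [0.0, 0.0, 4.0]
o_demo, d_demo = get_rays_np(img_h, img_w, focal_px, c2w_demo)
print(d_demo.shape)                      # (4, 6, 3)
print(d_demo[img_h // 2, img_w // 2])    # center pixel: approximately (0, 0, -1)
```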

In [None]:
def positional_encoder(x, L_embed=6):
  rets = [x]
  for i in range(L_embed):
    for fn in [torch.sin, torch.cos]:
      rets.append(fn(2.**i *x))#(2^i)*x
  return torch.cat(rets, -1)

def cumprod_exclusive(tensor: torch.Tensor) -> torch.Tensor:
  # Exclusive cumulative product: out[..., i] = prod_{j<i} tensor[..., j], with out[..., 0] = 1
  cumprod = torch.cumprod(tensor, -1)
  cumprod = torch.roll(cumprod, 1, -1)
  cumprod[..., 0] = 1.
  return cumprod

def render(model, rays_o, rays_d, near, far, n_samples, rand=False):
  def batchify(fn, chunk=1024*32):
      # Run the model in chunks to bound peak memory
      return lambda inputs: torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)

  # Sample depths along every ray: shape (..., n_samples)
  z = torch.linspace(near, far, n_samples, device=device)
  z = z.expand(list(rays_o.shape[:-1]) + [n_samples])
  if rand:
    # Stratified sampling: jitter each ray's samples independently within their bins
    mids = 0.5 * (z[..., 1:] + z[...,:-1])
    upper = torch.cat([mids, z[...,-1:]], -1)
    lower = torch.cat([z[...,:1], mids], -1)
    t_rand = torch.rand(z.shape, device=device)
    z = lower + (upper-lower)*t_rand

  # 3D sample positions: (..., n_samples, 3)
  points = rays_o[..., None,:] + rays_d[..., None,:] * z[...,:,None]

  flat_points = torch.reshape(points, [-1, points.shape[-1]])
  flat_points = positional_encoder(flat_points)
  raw = batchify(model)(flat_points)
  raw = torch.reshape(raw, list(points.shape[:-1]) + [4])

  #Compute opacities and colors
  sigma = F.relu(raw[..., 3])
  rgb = torch.sigmoid(raw[..., :3])

  #Volume Rendering
  one_e_10 = torch.tensor([1e10], dtype=rays_o.dtype, device=device)
  dists = torch.cat((z[..., 1:] - z[..., :-1],
                  one_e_10.expand(z[..., :1].shape)), dim=-1)
  alpha = 1. - torch.exp(-sigma * dists)
  weights = alpha * cumprod_exclusive(1. - alpha + 1e-10)

  rgb_map = (weights[...,None]* rgb).sum(dim=-2)
  depth_map = (weights * z).sum(dim=-1)
  acc_map = weights.sum(dim=-1)
  return rgb_map, depth_map, acc_map
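The two key steps in `render`, stratified depth sampling (`rand=True`) and alpha compositing with an exclusive cumulative product, can be checked on a single ray. The NumPy sketch below mirrors those steps. The helper names and the synthetic ray (empty space with a dense red slab around depth 4) are hypothetical. If compositing works, nearly all weight should land on the slab, giving a red pixel at a depth close to 4 with weights summing to about 1.

```python
import numpy as np

def stratified_depths(near, far, n_samples, rng):
    """Draw one depth uniformly inside each of n_samples evenly spaced bins (as with rand=True)."""
    z = np.linspace(near, far, n_samples)
    mids = 0.5 * (z[1:] + z[:-1])
    upper = np.concatenate([mids, z[-1:]])
    lower = np.concatenate([z[:1], mids])
    return lower + (upper - lower) * rng.random(n_samples)

def composite(sigma, rgb, z):
    """Alpha-composite the samples of one ray, mirroring the volume-rendering step in render."""
    dists = np.append(z[1:] - z[:-1], 1e10)        # last segment extends to "infinity"
    alpha = 1.0 - np.exp(-sigma * dists)            # opacity of each segment
    # Exclusive cumulative product: T_i = prod_{j<i} (1 - alpha_j), with T_0 = 1
    trans = np.cumprod(np.concatenate([[1.0], 1.0 - alpha[:-1] + 1e-10]))
    weights = alpha * trans
    rgb_map = (weights[:, None] * rgb).sum(axis=0)
    depth_map = (weights * z).sum()
    return rgb_map, depth_map, weights

rng = np.random.default_rng(0)
z_demo = stratified_depths(2.0, 6.0, 64, rng)
# Synthetic ray: empty space except a dense red slab for depths in (3.9, 4.1)
sigma_demo = np.where(np.abs(z_demo - 4.0) < 0.1, 100.0, 0.0)
rgb_demo = np.tile([1.0, 0.0, 0.0], (64, 1))
rgb_map, depth_map, weights = composite(sigma_demo, rgb_demo, z_demo)
print(rgb_map, depth_map, weights.sum())   # close to pure red, depth near 4, weight sum near 1
```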


In [None]:
#helper functions
mse2psnr = lambda x : -10. * torch.log10(x)

def train(model, optimizer, n_iters = 3001):
  #Track loss over time for graphing
  psnrs = []
  iternums = []
  plot_step = 500
  n_samples = 64
  for i in range(n_iters):
    #Choose random image and use it for training
    images_idx = np.random.randint(images.shape[0])
    target = images[images_idx]
    pose = poses[images_idx]

    #Core optimizer loop
    rays_o, rays_d = get_rays(H, W, focal, pose)
    rgb, depth, acc = render(model, rays_o, rays_d, near=2., far=6., n_samples=n_samples, rand=True)
    optimizer.zero_grad()
    image_loss = torch.nn.functional.mse_loss(rgb, target)
    image_loss.backward()
    optimizer.step()

    if i%plot_step==0:
      #Render the held-out test view to track progress as the model learns
      with torch.no_grad():
        rays_o, rays_d = get_rays(H, W, focal, testpose)
        rgb, depth, acc = render(model, rays_o, rays_d, near=2., far=6., n_samples=n_samples)
        loss = torch.nn.functional.mse_loss(rgb, testimg)
        psnr = mse2psnr(loss).item()

        psnrs.append(psnr)
        iternums.append(i)

        plt.figure(figsize=(10,5))
        plt.subplot(121)
        #copy from gpu memory to cpu
        picture = rgb.cpu()
        plt.imshow(picture)
        plt.title(f'Iterations: {i}')
        plt.subplot(122)
        plt.plot(iternums, psnrs)
        plt.title('PSNR')
        plt.show()
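`mse2psnr` converts the mean squared error between images with values in [0, 1] into peak signal-to-noise ratio, PSNR = -10 log10(MSE). A quick plain-Python check of that formula (the helper name `mse_to_psnr` is illustrative) shows that every tenfold drop in MSE adds 10 dB to the PSNR curve plotted during training:

```python
import math

def mse_to_psnr(mse: float) -> float:
    """PSNR in dB for images with pixel values in [0, 1] (peak signal = 1)."""
    return -10.0 * math.log10(mse)

for mse in (1e-1, 1e-2, 1e-3):
    print(f"MSE {mse:g} -> PSNR {mse_to_psnr(mse):.1f} dB")
# MSE 0.1 -> PSNR 10.0 dB
# MSE 0.01 -> PSNR 20.0 dB
# MSE 0.001 -> PSNR 30.0 dB
```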

In [None]:
class VeryTinyNerfModel(torch.nn.Module):
  def __init__(self, filter_size=128, num_encoding_functions=6):
    super().__init__()
    # Input layer (default: 39 -> 128); 39 = 3 + 3 * 2 * 6 positional-encoding features
    self.layer1 = torch.nn.Linear(3 + 3 * 2 * num_encoding_functions, filter_size)
    # Layer 2 (default: 128 -> 128)
    self.layer2 = torch.nn.Linear(filter_size, filter_size)
    # Output layer (default: 128 -> 4): raw RGB (3) + density sigma (1)
    self.layer3 = torch.nn.Linear(filter_size, 4)
    # Shorthand for torch.nn.functional.relu
    self.relu = torch.nn.functional.relu

  def forward(self, x):
    x = self.relu(self.layer1(x))
    x = self.relu(self.layer2(x))
    x = self.layer3(x)
    return x
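As a sanity check on how small TinyNeRF is, the sketch below counts the parameters of the three `Linear` layers by hand (weights plus biases, using the default `filter_size=128` and `num_encoding_functions=6`). The helper `linear_params` is hypothetical; in the notebook, `sum(p.numel() for p in VeryTinyNerfModel().parameters())` should give the same number.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer (torch.nn.Linear with bias=True)."""
    return n_in * n_out + n_out

num_encoding_functions = 6
in_dim = 3 + 3 * 2 * num_encoding_functions   # 39 encoded features per 3D point
filter_size = 128

total = (linear_params(in_dim, filter_size)        # layer1: 39 -> 128
         + linear_params(filter_size, filter_size)  # layer2: 128 -> 128
         + linear_params(filter_size, 4))           # layer3: 128 -> 4 (RGB + sigma)
print(total)  # 22148
```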

In [None]:
#Create the model and optimizer, then train
nerf = VeryTinyNerfModel()
nerf = nn.DataParallel(nerf).to(device)
optimizer = torch.optim.Adam(nerf.parameters(), lr=5e-3, eps = 1e-7)
train(nerf, optimizer)

In [None]:
%matplotlib inline
from ipywidgets import interactive, widgets


trans_t = lambda t : torch.tensor([
    [1,0,0,0],
    [0,1,0,0],
    [0,0,1,t],
    [0,0,0,1],
], dtype=torch.float32)

rot_phi = lambda phi : torch.tensor([
    [1,0,0,0],
    [0,np.cos(phi),-np.sin(phi),0],
    [0,np.sin(phi), np.cos(phi),0],
    [0,0,0,1],
], dtype=torch.float32)

rot_theta = lambda th : torch.tensor([
    [np.cos(th),0,-np.sin(th),0],
    [0,1,0,0],
    [np.sin(th),0, np.cos(th),0],
    [0,0,0,1],
], dtype=torch.float32)


def pose_spherical(theta, phi, radius):
    c2w = trans_t(radius)
    c2w = rot_phi(phi/180.*np.pi) @ c2w
    c2w = rot_theta(theta/180.*np.pi) @ c2w
    c2w = torch.tensor([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]], dtype=torch.float32) @ c2w
    return c2w


def f(**kwargs):
    c2w = pose_spherical(**kwargs).to(device)
    rays_o, rays_d = get_rays(H, W, focal, c2w[:3,:4])
    # Inference only: no gradients needed for rendering
    with torch.no_grad():
        rgb, depth, acc = render(nerf, rays_o, rays_d, near=2., far=6., n_samples=64)
    img = np.clip(rgb.cpu().numpy(),0,1)

    plt.figure(2, figsize=(20,6))
    plt.imshow(img)
    plt.show()


sldr = lambda v, mi, ma: widgets.FloatSlider(
    value=v,
    min=mi,
    max=ma,
    step=.01,
)

names = [
    ['theta', [100., 0., 360]],
    ['phi', [-30., -90, 0]],
    ['radius', [4., 3., 5.]],
]

interactive_plot = interactive(f, **{s[0] : sldr(*s[1]) for s in names})
output = interactive_plot.children[-1]
output.layout.height = '350px'
interactive_plot
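`pose_spherical` places the camera on a sphere around the object: it translates by `radius` along z, tilts by `phi`, spins by `theta`, then permutes axes into the dataset's world frame. The NumPy mirror below (all names ending in `_np`, plus `c2w_demo`, are hypothetical) checks two properties that follow from this construction: the camera center lies at distance `radius` from the origin, and its viewing axis (local -z) points at the origin.

```python
import numpy as np

def trans_t_np(t):
    return np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, t], [0, 0, 0, 1]], dtype=np.float64)

def rot_phi_np(phi):
    c, s = np.cos(phi), np.sin(phi)
    return np.array([[1, 0, 0, 0], [0, c, -s, 0], [0, s, c, 0], [0, 0, 0, 1]])

def rot_theta_np(th):
    c, s = np.cos(th), np.sin(th)
    return np.array([[c, 0, -s, 0], [0, 1, 0, 0], [s, 0, c, 0], [0, 0, 0, 1]])

def pose_spherical_np(theta, phi, radius):
    """NumPy mirror of pose_spherical (angles in degrees)."""
    c2w = trans_t_np(radius)                       # move back along z
    c2w = rot_phi_np(np.deg2rad(phi)) @ c2w        # tilt (elevation)
    c2w = rot_theta_np(np.deg2rad(theta)) @ c2w    # spin (azimuth)
    # Fixed axis permutation into the dataset's world convention (a proper rotation, det = +1)
    swap = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=np.float64)
    return swap @ c2w

c2w_demo = pose_spherical_np(theta=100.0, phi=-30.0, radius=4.0)
cam_pos = c2w_demo[:3, 3]
forward = -c2w_demo[:3, 2]                              # the camera looks down its local -z axis
print(np.linalg.norm(cam_pos))                          # ≈ 4.0
print(forward @ (-cam_pos / np.linalg.norm(cam_pos)))   # ≈ 1.0, i.e. looking at the origin
```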

#### Exercise 6: Custom TinyNeRF

In [None]:
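class CustomTinyNerfModel(torch.nn.Module):
    
    ######## Implement from here ########
    def __init__(self):
        super().__init__()
        # TODO: define your layers. The input size must match positional_encoder's output
        # (3 + 3 * 2 * L_embed = 39 for the default L_embed=6); the output size must be 4 (RGB + sigma)
        raise NotImplementedError

    def forward(self, x):
        # TODO: map encoded points of shape (N, 39) to raw outputs of shape (N, 4)
        raise NotImplementedError
    ####### End of Implementation #######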

nerf = CustomTinyNerfModel()
nerf = nn.DataParallel(nerf).to(device)
optimizer = torch.optim.Adam(nerf.parameters(), lr=5e-3, eps = 1e-7)
train(nerf, optimizer)

# 2. NeRFStudio

A framework that makes it easy to use a wide range of NeRF models.

* Includes many kinds of NeRF models (dynamic NeRFs, editable NeRFs, 3D generation with diffusion models, fast NeRFs)
* Provides a powerful interactive viewer for rendering images from any desired viewpoint
* Simplifies new model development: NeRF modules such as the data loader, ray sampler, and encoder can be easily swapped or modified

In [None]:
NERFSTUDIO_PORT = os.environ["NERFSTUDIO_PORT"]
!ns-train nerfacto --output-dir "results/nerfstudio/nerfacto" --viewer.websocket-port "{NERFSTUDIO_PORT}" nerfstudio-data --data /data/nerfstudio/dozer --downscale-factor 4 

In [None]:
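#@title # Resume stopped training
import glob
# Locate the checkpoint directory (.../nerfstudio_models) written by the training run above,
# picking the most recently modified one if there are several runs
ckpt_dirs = glob.glob("results/nerfstudio/nerfacto/**/nerfstudio_models", recursive=True)
training_run_dir = max(ckpt_dirs, key=os.path.getmtime)

!ns-train nerfacto --output-dir "results/nerfstudio/nerfacto" --load-dir {training_run_dir} --viewer.websocket-port "{NERFSTUDIO_PORT}" nerfstudio-data --data /data/nerfstudio/dozer --downscale-factor 4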

### Utilizing Various Models with NeRFStudio
* Nerfacto is a pipeline that combines components from several published NeRF methods
* nerfstudio supports many models, including:
  - Instant-NGP
  - [Instruct-NeRF2NeRF](https://docs.nerf.studio/en/latest/nerfology/methods/in2n.html)
  - K-Planes
  - [LERF](https://docs.nerf.studio/en/latest/nerfology/methods/lerf.html)
  - Mip-NeRF
  - NeRF
  - Nerfacto
  - Nerfbusters
  - NeRFPlayer
  - Tetra-NeRF
  - TensoRF
  - [Generfacto](https://docs.nerf.studio/en/latest/nerfology/methods/generfacto.html)

In [None]:
!ns-train --help

#### Exercise 7: Train a model other than Nerfacto (recommended: Instant-NGP)

In [None]:
######## Implement from here ########
####### End of Implementation #######