# CS 180 Project 5 - Neural Radiance Fields

# Part 1 - Fit a Neural Field to a 2D image

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import skimage.io as skio
import skimage as sk
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
# Build Model
class NF_MLP(nn.Module):
  """MLP that maps coordinates to RGB through a sinusoidal positional encoding.

  Args:
    channel_size: width of every hidden linear layer (default 256).
    L: number of frequency bands used by the positional encoding (default 10).

  The defaults are the baseline configuration, so ``NF_MLP()`` is valid.
  """

  def __init__(self, channel_size=256, L=10):
    super().__init__()
    self.L = L
    # get_PE emits the 2 raw coordinates plus a sine and a cosine of each
    # coordinate at L frequencies: 2 + 2 * 2 * L = 4 * L + 2 input features.
    self.mlp = nn.Sequential(
        nn.Linear(4*L+2, channel_size),
        nn.ReLU(),
        nn.Linear(channel_size, channel_size),
        nn.ReLU(),
        nn.Linear(channel_size, channel_size),
        nn.ReLU(),
        nn.Linear(channel_size, 3),
        nn.Sigmoid()  # keeps the three outputs inside (0, 1)
    )

  def forward(self, x):
    """Encode the coordinates in `x` and return the network output (last axis of size 3)."""
    return self.mlp(self.get_PE(x))

  def get_PE(self, x):
    """Return `x` followed by sin(2^i * pi * x) and cos(2^i * pi * x) for i in [0, L).

    The pieces are joined along the last axis, so any number of leading
    batch axes is accepted.
    """
    sin_dims = [torch.sin(((2**i)*np.pi)*x) for i in range(self.L)]
    cos_dims = [torch.cos(((2**i)*np.pi)*x) for i in range(self.L)]
    return torch.cat([x, *sin_dims, *cos_dims], dim=-1)


In [None]:
# Create Custom Dataset
class NeRF_Dataset(Dataset):
  """Dataset whose every item is a fresh random batch of pixels from one image.

  `image` is indexed as [row, column, channel], `N` is the number of pixels
  drawn per item and `numIt` is the length reported to the DataLoader.
  """

  def __init__(self, image, N, numIt):
    self.image = image
    self.N = N
    self.numIt = numIt

  def __len__(self):
    return self.numIt

  def __getitem__(self, idx):
    # `idx` is ignored: each call draws a new random set of pixels.
    height, width, _ = self.image.shape
    flat = torch.randint(0, height * width, (self.N,))
    rows = flat // width
    cols = flat % width
    # Coordinates are scaled by the image size so that they fall in [0, 1).
    coords = torch.stack((rows / height, cols / width), dim=1)
    values = self.image[rows, cols, :]
    return coords, values


In [None]:
# Function to run model on image
def getImage(model, s):
  """Evaluate `model` on every pixel coordinate of an image of shape `s`.

  Args:
    model: module mapping (n, 2) normalised (row, col) coordinates to (n, 3).
    s: image shape as (rows, cols, channels); only rows and cols are used.

  Returns:
    NumPy array of shape (rows, cols, 3) holding the model output.

  The model is evaluated in eval mode without gradient tracking, and its
  previous train/eval mode is restored afterwards so that calling this in the
  middle of training does not leave the model in eval mode.
  """
  was_training = model.training
  model.eval()
  N, M, _ = s
  # Run on whatever device the model's weights live on (CPU if it has none).
  first_param = next(model.parameters(), None)
  target = first_param.device if first_param is not None else torch.device("cpu")
  # indexing="ij" makes the first output vary along rows, the second along columns.
  row_coords, col_coords = torch.meshgrid(torch.arange(N), torch.arange(M), indexing="ij")
  pixel_coords = torch.stack((row_coords/N, col_coords/M), dim=2).reshape(N*M, 2)
  with torch.no_grad():
    img = model(pixel_coords.to(target)).reshape(N, M, 3).cpu().numpy()
  model.train(was_training)
  return img

In [None]:
from tqdm import tqdm

# Function to train model to reconstruct image
def train(model, train_loader, s):
  """Fit `model` to an image, recording the loss curve and preview renders.

  Args:
    model: network to optimise.
    train_loader: iterable of (coords, colours) pairs; only element 0 along
      the leading axis of each is used.
    s: image shape forwarded to getImage for the preview renders.

  Returns:
    (losses, ims): the per-iteration loss values and the full-image renders
    captured at the iterations listed in `steps`.
  """
  losses = []
  ims = []

  model.train()

  def PSNR_loss(prediction, reference):
    # 10*log10(MSE) is the negated PSNR for unit-peak signals, so minimising
    # it maximises PSNR.
    return 10*torch.log10(nn.functional.mse_loss(prediction.float(), reference.float()))

  optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
  steps = {1, 20, 100, 500, 1000, 2000}
  for i, (inputs, labels) in enumerate(tqdm(train_loader)):
    if i+1 in steps:
      # The preview is rendered before the update of iteration i+1.
      ims.append(getImage(model, s))
      # getImage switches the model to eval mode; resume training mode.
      model.train()
    inputs = inputs[0].to(device)
    labels = labels[0].to(device)
    loss = PSNR_loss(model(inputs), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

  return losses, ims

## Run NF_MLP model on images

### Fox Image Results

In [None]:
# Read the fox picture as floating-point pixels and wrap it in a sampling dataset:
# 10000 random pixels per item, 2000 items.
fox_pixels = sk.img_as_float(skio.imread("fox.jpg"))
fox_image = torch.tensor(fox_pixels)
foxDataset = NeRF_Dataset(image=fox_image, N=10000, numIt=2000)
fox_dataloader = DataLoader(foxDataset)

In [None]:
# NF_MLP takes (channel_size, L). The baseline uses 256 channels and L = 10,
# the values that the later "increase channel size" and "decrease encoding
# size" experiments each vary one of.
model = NF_MLP(256, 10).to(device)
fox_losses, fox_ims = train(model, fox_dataloader, fox_image.shape)

In [None]:
# The recorded loss is 10*log10(MSE), so its negation is the PSNR.
fox_psnr = -np.array(fox_losses)
plt.plot(range(1, 2001), fox_psnr, '-bo', markersize=0.5)
plt.xlabel('Number of Iterations')
plt.ylabel('PSNR')
plt.title('Fox Image Training PSNR')
plt.show()

In [None]:
# Show the renders captured during training side by side, without axes.
fig, axes = plt.subplots(1, len(fox_ims), figsize=(30, 30))
plt.subplots_adjust(wspace=0.05)
for ax, snapshot in zip(axes, fox_ims):
  ax.imshow(snapshot)
  ax.axis('off')
plt.show()

### Fox Image Results - Increase Channel Size

In [None]:
# Same encoding as the other fox runs (L = 10) but with 420-wide hidden layers.
model2 = NF_MLP(channel_size=420, L=10).to(device)
fox_losses2, fox_ims2 = train(model2, fox_dataloader, fox_image.shape)

In [None]:
# The recorded loss is 10*log10(MSE), so its negation is the PSNR.
fox_psnr2 = -np.array(fox_losses2)
plt.plot(range(1, 2001), fox_psnr2, '-bo', markersize=0.5)
plt.xlabel('Number of Iterations')
plt.ylabel('PSNR')
plt.title('Fox Image Training PSNR with Channel Size = 420')
plt.show()

In [None]:
# Show the renders captured during training side by side, without axes.
fig, axes = plt.subplots(1, len(fox_ims2), figsize=(30, 30))
plt.subplots_adjust(wspace=0.05)
for ax, snapshot in zip(axes, fox_ims2):
  ax.imshow(snapshot)
  ax.axis('off')
plt.show()

### Fox Image Results - Decrease Encoding Size

In [None]:
# 256-wide hidden layers with only L = 3 positional-encoding frequencies.
model3 = NF_MLP(channel_size=256, L=3).to(device)
fox_losses3, fox_ims3 = train(model3, fox_dataloader, fox_image.shape)

In [None]:
# The recorded loss is 10*log10(MSE), so its negation is the PSNR.
fox_psnr3 = -np.array(fox_losses3)
plt.plot(range(1, 2001), fox_psnr3, '-bo', markersize=0.5)
plt.xlabel('Number of Iterations')
plt.ylabel('PSNR')
plt.title('Fox Image Training PSNR with L = 3')
plt.show()

In [None]:
# Show the renders captured during training side by side, without axes.
fig, axes = plt.subplots(1, len(fox_ims3), figsize=(30, 30))
plt.subplots_adjust(wspace=0.05)
for ax, snapshot in zip(axes, fox_ims3):
  ax.imshow(snapshot)
  ax.axis('off')
plt.show()

### Results for El Capitan image

In [None]:
# Read the El Capitan picture as floating-point pixels and wrap it in a
# sampling dataset: 30000 random pixels per item, 2000 items.
capitan_pixels = sk.img_as_float(skio.imread("capitan.jpeg"))
capitan_image = torch.tensor(capitan_pixels)
capitanDataset = NeRF_Dataset(image=capitan_image, N=30000, numIt=2000)
capitan_dataloader = DataLoader(capitanDataset)

In [None]:
# 300-wide hidden layers with L = 15 positional-encoding frequencies.
capitanModel = NF_MLP(channel_size=300, L=15).to(device)
capitan_losses, capitan_ims = train(capitanModel, capitan_dataloader, capitan_image.shape)

In [None]:
# The recorded loss is 10*log10(MSE), so its negation is the PSNR.
capitan_psnr = -np.array(capitan_losses)
plt.plot(range(1, 2001), capitan_psnr, '-bo', markersize=0.5)
plt.xlabel('Number of Iterations')
plt.ylabel('PSNR')
plt.title('El Capitan Image Training PSNR')
plt.show()

In [None]:
# Show the renders captured during training side by side, without axes.
fig, axes = plt.subplots(1, len(capitan_ims), figsize=(30, 30))
plt.subplots_adjust(wspace=0.05)
for ax, snapshot in zip(axes, capitan_ims):
  ax.imshow(snapshot)
  ax.axis('off')
plt.show()

## Part 2: Fit a Neural Radiance Field from Multi-view images

In [None]:
# Load the multi-view Lego archive; the arrays below are read from it by key.
data = np.load(f"lego_200x200.npz")

# Training images: [100, 200, 200, 3]
# Dividing by 255 rescales the stored values (presumably 8-bit, 0-255) to [0, 1].
images_train = data["images_train"] / 255.0

# Cameras for the training images
# (camera-to-world transformation matrix): [100, 4, 4]
c2ws_train = data["c2ws_train"]

# Validation images: [10, 200, 200, 3]
images_val = data["images_val"] / 255.0

# Cameras for the validation images
# (camera-to-world transformation matrix): [10, 4, 4]
c2ws_val = data["c2ws_val"]

# Test cameras for novel-view video rendering:
# (camera-to-world transformation matrix): [60, 4, 4]
c2ws_test = data["c2ws_test"]

# Camera focal length
focal = data["focal"]  # float

# Intrinsic matrix
# `sigma` is the principal-point offset, taken as half of images_train.shape[1].
# The same value is used for both axes, so this assumes square images.
sigma = images_train.shape[1]/2
K = np.array([[focal,0,sigma],[0,focal,sigma],[0,0,1]])

### Part 2.1: Create rays from cameras

In [None]:
def transform(c2w, x_c):
  """Apply the 4x4 homogeneous transform `c2w` to the (N, 3) points `x_c`.

  Returns an (N, 3) array; each result is divided by its homogeneous coordinate.
  """
  ones = np.ones((x_c.shape[0], 1))
  homogeneous = np.concatenate([x_c, ones], axis=1)
  mapped = c2w @ homogeneous.T
  return (mapped[:3] / mapped[3]).T

def pixel_to_camera(K, uv, s):
  """Back-project pixel coordinates to camera-space points at depth `s`.

  Args:
    K: 3x3 intrinsic matrix.
    uv: (N, 2) array of pixel coordinates.
    s: depth along the camera z axis; a scalar shared by all pixels or an
      array holding one depth per pixel.

  Returns:
    (N, 3) array ``s * K^-1 [u, v, 1]^T``, one row per pixel.
  """
  uv_h = np.hstack([uv, np.ones((uv.shape[0], 1))])
  directions = (np.linalg.inv(K) @ uv_h.T).T
  depth = np.asarray(s, dtype=float)
  if depth.ndim > 0:
    # One depth per pixel: scale each row. Multiplying the 3x3 inverse by the
    # depth array instead would broadcast it across matrix columns.
    depth = depth.reshape(-1, 1)
  return directions * depth

def pixel_to_ray(K, c2w, uv):
  """Return world-space ray origins and unit directions for the pixels `uv`.

  Both outputs have one row per pixel. The origin is the camera-space origin
  mapped through `c2w`; the direction points from it towards the pixel centre
  back-projected at depth 1.
  """
  pixel_centers = uv + 0.5  # shift from the pixel corner to its centre
  n_rays = pixel_centers.shape[0]
  r_o = transform(c2w, np.zeros((n_rays, 3)))
  X_w = transform(c2w, pixel_to_camera(K, pixel_centers, 1))
  offsets = X_w - r_o
  r_d = offsets / np.linalg.norm(offsets, axis=1, keepdims=True)
  return r_o, r_d

### Parts 2.2 & 2.3: Sampling and putting Dataloading together

In [None]:
def sample_along_rays(rays_o, rays_d, perturb, near=2.0, far=6.0, n_samples=64, t_width=0.1):
    """Sample 3D points at evenly spaced depths along each ray.

    Args:
      rays_o: (N, 3) ray origins.
      rays_d: (N, 3) ray directions.
      perturb: when true, add uniform noise in [0, t_width) to every depth.
        The same jittered depths are shared by all rays of the call.
      near, far: first and last sampling depth (inclusive).
      n_samples: number of depths per ray.
      t_width: upper bound of the random jitter.

    Returns:
      (N, n_samples, 3) array of points ``rays_o + rays_d * t``.
    """
    t = np.linspace(near, far, n_samples)
    if perturb:
        t = t + np.random.rand(n_samples) * t_width

    rays_o = np.asarray(rays_o)
    rays_d = np.asarray(rays_d)
    # Broadcast (N, 1, 3) against (1, n_samples, 1) rather than looping over depths.
    return rays_o[:, None, :] + rays_d[:, None, :] * t[None, :, None]

In [None]:
class RaysData(Dataset):
  """Dataset that yields freshly sampled ray batches from posed images.

  Args:
    images: array indexed as [image, row, column, channel].
    K: 3x3 intrinsic matrix passed to pixel_to_ray.
    c2ws: camera-to-world matrices, indexed like `images`.
    M: number of images drawn (with replacement) for each batch.
    numIt: length reported to the DataLoader.
    perturb: forwarded to sample_along_rays.
    n_rays: number of rays requested per item (default 10000).
  """

  def __init__(self, images, K, c2ws, M, numIt, perturb, n_rays=10000):
    self.images = images
    self.K = K
    self.c2ws = c2ws
    self.perturb = perturb
    self.M = M
    self.numIt = numIt
    self.n_rays = n_rays

  def __len__(self):
    return self.numIt

  def __getitem__(self, idx):
    # `idx` is ignored: every item is a new random batch.
    rays_o, rays_d, pixels = self.sample_rays(self.n_rays)
    points = sample_along_rays(rays_o, rays_d, self.perturb)
    return rays_d, points, pixels

  def sample_rays(self, N):
    """Draw N // M random pixels from each of M randomly chosen images.

    Returns:
      (rays_o, rays_d, pixels), each with (N // M) * M rows of 3 values.
    """
    height, width = self.images.shape[1:3]
    M = self.M
    per_image = N // M
    sample_idx = np.random.choice(self.images.shape[0], M)
    image_sample = self.images[sample_idx]
    c2w_sample = self.c2ws[sample_idx]
    # Columns (u) and rows (v) are drawn from their own ranges so that
    # non-square images are never indexed out of bounds.
    u = np.random.randint(0, width, (M, per_image))
    v = np.random.randint(0, height, (M, per_image))

    total = per_image * M
    rays_o = np.zeros((total, 3))
    rays_d = np.zeros((total, 3))
    pixels = np.zeros((total, 3))
    for i in range(M):
      block = slice(per_image * i, per_image * (i + 1))
      uv = np.stack((u[i], v[i]), axis=1)
      rays_o[block], rays_d[block] = pixel_to_ray(self.K, c2w_sample[i], uv)
      # Images are indexed [row, column], i.e. [v, u].
      pixels[block] = image_sample[i, v[i], u[i]]
    return rays_o, rays_d, pixels


In [None]:
%pip install viser

In [None]:
import time

import numpy as np
import viser  # pip install viser

# --- You Need to Implement These ------
# RaysData requires M (images per batch), numIt and perturb. One image and one
# iteration are enough here because only sample_rays is used.
dataset = RaysData(images_train, K, c2ws_train, 1, 1, True)
rays_o, rays_d, pixels = dataset.sample_rays(100)
# sample_along_rays is a module-level function, not a RaysData method.
points = sample_along_rays(rays_o, rays_d, perturb=True)
H, W = images_train.shape[1:3]
# ---------------------------------------

server = viser.ViserServer(share=True)
# One frustum per training camera, textured with its image.
for i, (image, c2w) in enumerate(zip(images_train, c2ws_train)):
    server.add_camera_frustum(
        f"/cameras/{i}",
        fov=2 * np.arctan2(H / 2, K[0, 0]),
        aspect=W / H,
        scale=0.15,
        wxyz=viser.transforms.SO3.from_matrix(c2w[:3, :3]).wxyz,
        position=c2w[:3, 3],
        image=image
    )
# Each sampled ray is drawn from its origin to depth 6.0 along its direction.
for i, (o, d) in enumerate(zip(rays_o, rays_d)):
    server.add_spline_catmull_rom(
        f"/rays/{i}", positions=np.stack((o, o + d * 6.0)),
    )
server.add_point_cloud(
    f"/samples",
    colors=np.zeros_like(points).reshape(-1, 3),
    points=points.reshape(-1, 3),
    point_size=0.02,
)
# Keep the cell running so the server stays reachable.
time.sleep(1000)

### Parts 2.4 & 2.5: Neural Radiance Field and Volume Rendering

In [None]:
class NeRF(nn.Module):
  """Radiance-field MLP with built-in volume rendering.

  forward() takes sample points along each ray plus the ray directions and
  returns one rendered colour per ray.

  Args:
    bg_color: optional RGB triple composited behind each ray in proportion to
      the opacity the ray did not accumulate. None (the default) performs
      plain volume rendering. The attribute can be reassigned after
      construction and is not part of the state dict.
  """

  # Depth range covered by the samples; used to derive the integration step.
  near = 2.0
  far = 6.0

  def __init__(self, bg_color=None):
    super().__init__()
    self.bg_color = bg_color

    # 63 = 3 raw coordinates + 3 * 2 * 10 positional-encoding features.
    self.centerMLP1 = nn.Sequential(
        nn.Linear(63, 256),
        nn.ReLU(),
        nn.Linear(256,256),
        nn.ReLU(),
        nn.Linear(256,256),
        nn.ReLU(),
        nn.Linear(256,256),
        nn.ReLU(),
    )

    # 319 = 256 features + the 63-wide encoded position (skip connection).
    self.centerMLP2 = nn.Sequential(
        nn.Linear(319,256),
        nn.ReLU(),
        nn.Linear(256,256),
        nn.ReLU(),
        nn.Linear(256,256),
        nn.ReLU(),
        nn.Linear(256,256)
    )

    # The ReLU keeps the predicted density non-negative.
    self.densityMLP = nn.Sequential(
        nn.Linear(256,1),
        nn.ReLU(),
    )

    self.rgbMLP1 = nn.Linear(256,256)
    # 283 = 256 features + 27-wide encoded direction (3 + 3 * 2 * 4).
    self.rgbMLP2 = nn.Sequential(
        nn.Linear(283,128),
        nn.ReLU(),
        nn.Linear(128,3),
        nn.Sigmoid()
    )

  def forward(self, x, r_d):
    """Render one colour per ray.

    Args:
      x: (B, N, 3) sample points, N per ray.
      r_d: (B, 3) ray directions.

    Returns:
      (B, 3) rendered colours.
    """
    n_samples = x.size(1)
    x = self.get_PE(x, 10)
    r_d = self.get_PE(r_d, 4)

    # Trunk with a skip connection that re-injects the encoded position.
    features = self.centerMLP2(torch.cat([self.centerMLP1(x), x], dim=-1))

    # Density branch
    density = self.densityMLP(features)

    # RGB branch: every sample of a ray shares that ray's encoded direction.
    rgb = self.rgbMLP1(features)
    r_d = r_d.unsqueeze(1).expand(-1, n_samples, -1)
    rgb = self.rgbMLP2(torch.cat([rgb, r_d], dim=-1))

    # Volume rendering; equals (6.0 - 2.0) / 64 for the usual 64 samples.
    step_size = (self.far - self.near) / n_samples
    return self.volrend(density, rgb, step_size, self.bg_color)

  def get_PE(self, x, L):
    """Return `x` followed by sin and cos of 2^i * pi * x for i in [0, L), on the last axis."""
    sin_dims = torch.cat([torch.sin(((2**i)*np.pi)*x) for i in range(L)],dim=-1)
    cos_dims = torch.cat([torch.cos(((2**i)*np.pi)*x) for i in range(L)],dim=-1)
    return torch.cat([x, sin_dims, cos_dims], dim=-1)

  def volrend(self, sigmas, rgbs, step_size, bg_color=None):
    """Composite per-sample colours along each ray.

    Args:
      sigmas: (B, N, 1) densities.
      rgbs: (B, N, C) per-sample values to composite.
      step_size: distance between consecutive samples.
      bg_color: optional length-C colour added with weight
        ``1 - sum(weights)``; None adds nothing.

    Returns:
      (B, C) composited values, on the same device as the inputs.
    """
    # Transmittance before sample i is exp(-step * sum of the densities of the
    # samples in front of it), hence the zero prepended to the cumulative sum.
    leading = sigmas.new_zeros(sigmas.size(0), 1, sigmas.size(2))
    accumulated = torch.cumsum(torch.cat([leading, sigmas[:, :-1, :]], dim=1), dim=1)
    T_i = torch.exp(-accumulated * step_size)
    alpha = 1 - torch.exp(-sigmas * step_size)
    weights = T_i * alpha
    # NOTE: for a depth map, composite torch.linspace(1, 0, N) (shaped
    # (1, N, 1)) in place of `rgbs`.
    rend = torch.sum(weights * rgbs, dim=1)
    if bg_color is not None:
      bg = torch.as_tensor(bg_color, dtype=rend.dtype, device=rend.device)
      rend = rend + (1 - torch.sum(weights, dim=1)) * bg
    return rend

### Test Volume Rendering

In [None]:
import torch
torch.manual_seed(42)
sigmas = torch.rand((10, 64, 1))
rgbs = torch.rand((10, 64, 3))
step_size = (6.0 - 2.0) / 64
# volrend is a method of NeRF, so it is called through an instance. The inputs
# are moved to `device` and the result is brought back to the CPU for the
# comparison.
rendered_colors = NeRF().volrend(sigmas.to(device), rgbs.to(device), step_size).cpu()

# Reference values; presumably these are for plain volume rendering with no
# background colour added.
correct = torch.tensor([
    [0.5006, 0.3728, 0.4728],
    [0.4322, 0.3559, 0.4134],
    [0.4027, 0.4394, 0.4610],
    [0.4514, 0.3829, 0.4196],
    [0.4002, 0.4599, 0.4103],
    [0.4471, 0.4044, 0.4069],
    [0.4285, 0.4072, 0.3777],
    [0.4152, 0.4190, 0.4361],
    [0.4051, 0.3651, 0.3969],
    [0.3253, 0.3587, 0.4215]
  ])
print(rendered_colors)
assert torch.allclose(rendered_colors, correct, rtol=1e-4, atol=1e-4)

### Train NeRF

In [None]:
from tqdm import tqdm

def trainNeRF(model, train_loader):
  """Optimise `model` on ray batches from `train_loader`, printing each loss.

  Each batch is (rays_d, points, pixels); only element 0 along the leading
  axis of each is used. Returns (val_losses, ims). Both lists are empty
  because the validation render below is commented out.
  """
  model.train()

  def PSNR_loss(prediction, reference):
    # 10*log10(MSE): the negated PSNR for unit-peak signals.
    return 10*torch.log10(nn.functional.mse_loss(prediction, reference))

  optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
  val_losses, ims = [], []
  steps = [1,50,100,200,500,1000]
  # Validation reference; only the disabled validation code reads it.
  target = torch.Tensor(images_val[0].reshape(200*200,3))

  for i, batch in enumerate(tqdm(train_loader)):
      rays_d, points, pixels = (part[0].to(device).to(torch.float32) for part in batch)
      loss = PSNR_loss(model(points, rays_d), pixels)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      # Disabled validation pass (renders validation view 0 every iteration):
      #im = torch.Tensor(renderImage(model,200,200,K,c2ws_val[0]))
      # if i+1 in steps:
      #   ims.append(im)
      # val_losses.append(PSNR_loss(im.reshape(200*200,3),target))
      print(loss.item())

  return val_losses, ims

In [None]:
# Each item draws rays from 50 random training images; 5000 items, with depth jitter.
lego_dataset = RaysData(images_train, K, c2ws_train, M=50, numIt=5000, perturb=True)
lego_train_loader = DataLoader(lego_dataset)

In [None]:
# Fresh, untrained NeRF placed on the selected device.
model = NeRF().to(device)

In [None]:
# Train on the lego rays; trainNeRF returns (val_losses, ims).
lego_losses, lego_ims = trainNeRF(model, lego_train_loader)

In [None]:
# The x axis is sized from the data. A hard-coded range(1, 1001) fails whenever
# trainNeRF returns a different number of values, and with its validation pass
# commented out it returns an empty list.
lego_psnr = -np.array(lego_losses)  # the loss is 10*log10(MSE), i.e. -PSNR
plt.plot(range(1, len(lego_psnr) + 1), lego_psnr, '-bo', markersize=0.5)
plt.xlabel('Number of Training Iterations')
plt.ylabel('PSNR')
plt.title('Lego Image Validation PSNR')
plt.show()

In [None]:
# plt.subplots rejects a zero-column grid, so only draw when snapshots exist.
# squeeze=False keeps `axes` two-dimensional even for a single snapshot.
if lego_ims:
  fig, axes = plt.subplots(1, len(lego_ims), figsize=(30, 30), squeeze=False)
  plt.subplots_adjust(wspace=0.05)
  for ax, snapshot in zip(axes[0], lego_ims):
    ax.imshow(snapshot)
    ax.axis('off')
  plt.show()
else:
  print("No snapshots were recorded during training; nothing to display.")

In [None]:
# Persist only the weights (state dict), not the module object.
torch.save(model.state_dict(), 'model2.pt')

### Run NeRF Model on Validation Set

In [None]:
# RaysData takes (images, K, c2ws, M, numIt, perturb). Without numIt the False
# would be read as the dataset length and perturb would be missing.
# NOTE(review): 10 validation batches is an assumed value; adjust as needed.
val_dataset = RaysData(images_val, K, c2ws_val, 10, 10, False)
val_loader = DataLoader(val_dataset)

In [None]:
from tqdm import tqdm
def testNeRF(model, val_loader):
  """Print 10*log10(MSE) (the negated PSNR) of `model` for each batch of `val_loader`.

  Runs in eval mode without gradient tracking. Each batch is
  (rays_d, points, pixels), of which only element 0 along the leading axis is
  used. Nothing is returned.
  """
  model.eval()

  def PSNR_loss(prediction, reference):
    mse = nn.functional.mse_loss(prediction.float(), reference.float())
    return 10*torch.log10(mse)

  with torch.no_grad():
    for batch in tqdm(val_loader):
        rays_d, points, pixels = (part[0].to(torch.float32).to(device) for part in batch)
        outputs = model(points, rays_d)
        print(PSNR_loss(outputs, pixels).item())


In [None]:
model = NeRF().to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load("model2.pt", map_location=device))

In [None]:
# Print the per-batch validation loss for the loaded model.
testNeRF(model, val_loader)

### Render Lego Video

In [None]:
model = NeRF().to(device)
# NOTE(review): this loads "model.pt", while the training section saves
# 'model2.pt'. Confirm that this is the intended checkpoint.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load("model.pt", map_location=device))

In [None]:
import cv2

# Function to render a single image
def renderImage(model, width, height, K, c2w):
  """Render the view of camera `c2w` as a (height, width, 3) float array.

  The image is rendered in up to 4 x 4 tiles to bound the number of rays in
  one forward pass. Tile sizes are rounded up, so sizes that are not
  multiples of 4 are still covered completely. No gradients are tracked.
  """
  tiles = 4
  w_step = max(1, -(-width // tiles))   # ceiling division
  h_step = max(1, -(-height // tiles))
  image = np.ones((height, width, 3))
  with torch.no_grad():
    for x0 in range(0, width, w_step):
      x1 = min(x0 + w_step, width)
      for y0 in range(0, height, h_step):
        y1 = min(y0 + h_step, height)
        # (u, v) pixel coordinates of the tile, row by row.
        us, vs = np.meshgrid(np.arange(x0, x1), np.arange(y0, y1))
        coords = np.stack((us.ravel(), vs.ravel()), axis=1)
        rays_o, rays_d = pixel_to_ray(K, c2w, coords)
        X = sample_along_rays(rays_o, rays_d, False)
        X = torch.Tensor(X).to(device)
        rays_d = torch.Tensor(rays_d).to(device)
        # NOTE: a depth render yields one channel per pixel; reshape to
        # (..., 1) instead of (..., 3) in that case.
        patch = model(X, rays_d).cpu().numpy()
        image[y0:y1, x0:x1] = patch.reshape(y1 - y0, x1 - x0, 3)
  return image

def renderFrames(model, width, height, K, c2ws):
  """Render one frame per camera pose in `c2ws` and return them as 8-bit images."""
  return [
      sk.img_as_ubyte(renderImage(model, width, height, K, pose))
      for pose in c2ws
  ]

def createVideo(name, images, fps=10):
  """Write `images` to the video file `name` with the mp4v codec.

  Args:
    name: output path.
    images: non-empty sequence of equally sized frames, indexed [row, column, channel].
    fps: frames per second (default 10).
  """
  frame_size = (images[0].shape[1], images[0].shape[0])  # (width, height)
  video_writer = cv2.VideoWriter(name, cv2.VideoWriter_fourcc(*'mp4v'), fps, frame_size)
  try:
    for im in images:
      # Channel-reversed views have negative strides; hand the writer a contiguous copy.
      video_writer.write(np.ascontiguousarray(im))
  finally:
    # Always finalise the file, even if writing a frame fails.
    video_writer.release()

In [None]:
## Normal video
# Create frames
lego_render_ims = renderFrames(model,200,200,K,c2ws_test)
# Reverse the channel order (RGB -> BGR) for the OpenCV writer.
lego_render_ims = [im[:,:,::-1] for im in lego_render_ims]

# Write video
# Uses its own file name: "lego_depth.mp4" is the depth video's output, which
# would otherwise overwrite this one.
createVideo("lego.mp4", lego_render_ims)

In [None]:
## Depth video
# NOTE(review): this cell calls the same renderFrames(model, ...) as the normal
# video cell. Depth output appears to rely on first swapping in the
# "depth rendering" variants noted in NeRF.volrend and renderImage; confirm.
# Create frames
lego_depth_render_ims = renderFrames(model,200,200,K,c2ws_test)
# Reverse the channel order of every frame before writing.
lego_depth_render_ims = [im[:,:,::-1] for im in lego_depth_render_ims]

# Write video
createVideo("lego_depth.mp4", lego_depth_render_ims)

In [None]:
## BG Color video
# Request the background colour for this render explicitly. A renderer with a
# configurable `bg_color` reads this attribute; one with a hard-coded
# background simply ignores it.
model.bg_color = (0, 0.667, 1)

# Create frames
lego_color_render_ims = renderFrames(model,200,200,K,c2ws_test)
# Reverse the channel order (RGB -> BGR) for the OpenCV writer.
lego_color_render_ims = [im[:,:,::-1] for im in lego_color_render_ims]

# Write video
createVideo("lego_color.mp4", lego_color_render_ims)

In [None]:
# Render a single 200x200 frame from the second test camera.
im = renderImage(model,200,200,K,c2ws_test[1])

In [None]:
# Display the frame rendered above.
plt.imshow(im)