# Code Practice : NeRF

## Importing Libraries

In [5]:
import os
from PIL import Image
import numpy as np 
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt 

## Preprocessing Data

- 데이터셋은 tiny_nerf_data를 씁니다.


- 데이터셋은 아래의 링크에서 다운로드 받을 수 있습니다.


- [tiny_nerf_data](http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz)




In [3]:
# Specify the environment path
PATH = 'C:/Users/user/anaconda3/envs/NeRF' # check your development environment path

# Load tiny_nerf_data
tiny_nerf_data = np.load(os.path.join(PATH, 'tiny_nerf_data.npz'))

# Check the array name of tiny_nerf_data
print('The array name of tiny_nerf_data: ', tiny_nerf_data.files)

# Check the array shape of tiny_nerf_data
for name in tiny_nerf_data.files:
    print(name,':', tiny_nerf_data[name].shape)

# Set the device
device = torch.device('cuda'if torch.cuda.is_available() else 'cpu')

# Define variables of dataset
images = tiny_nerf_data['images']
poses = tiny_nerf_data['poses']
focal = tiny_nerf_data['focal']

# Get the number of images, and the length of height and width
num_images, height, width = images.shape[:-1]

# Split the dataset into training set and test set 
test_idx = 101
test_image = images[test_idx]
test_pose = poses[test_idx]

# Move training variables to the device
images = torch.from_numpy(images[:100, ... , :3]).to(device)
poses = torch.from_numpy(poses).to(device)
focal = torch.from_numpy(focal).to(device)

The array name of tiny_nerf_data:  ['images', 'poses', 'focal']
images : (106, 100, 100, 3)
poses : (106, 4, 4)
focal : ()


## Utility Functions

In [4]:
# Compute the origin and direction vectors 
def get_rays(height, width, focal, pose) : 
    '''
    Inputs:
        height - Int. the height of an image
        width - Int. the width of an image
        focal - Float. focal length of the camera
        pose - torch.Tensor. the pose of an image 
        
    Outputs:
        rays_o - torch.Tensor. origin vector of the ray
        rays_d - torch.Tensor. direction vector of the ray
    '''
    
    # Use torch.meshgrid to build a meshgrid of size (height X width)
    i, j = torch.meshgrid(torch.arange(width, dtype = torch.float32).to(pose), 
                          torch.arange(height, dtype = torch.float32).to(pose), 
                          indexing = 'ij')
    
    # Use .transpose method to reshape the meshgrid
    i, j = i.transpose(-1, -2), j.transpose(-1, -2)
    
    # Calculate the x,y,z coordinates of a direction vector, ray_d
    # Measure the distance between current pixel coordinates and the center of the image on x,y axis.
    # Normalize x,y coordinates and set a z coordinate to 1. 
    rays_d = torch.stack([i - weight * .5 / focal, -(j - height * .5) / focal, torch.ones_like(i)], dim = -1)
    
    # Multiply rays_d with the camera pose to rotate w.r.t. world coordinates
    rays_d = torch.sum(rays_d[..., None, :] * pose[:3, :3], dim = -1) # pose[:3, :3] is the rotation part.
    
    rays_o = pose[:3, -1].expand(rays_d.shape)
    
    return rays_o, rays_d

In [10]:
# Positional Encoding
def positional_encoding(p, L):
    '''
    Inputs:
        p - torch.Tensor. p can be 3 coordinate values in vector x or Catersian viewing direction unit vector d.
            p lies in [-1, 1].
        L - Int. Dimensionality of positional encoding.
    
    Output:
        gamma_p - torch.Tensor. The positional encoding of p.
    '''
    # Define the list that saves positional encoding parameters
    gamma_p = []
    
    # Define the frequency that maps to higher dimensional space
    frequency = 2.0 ** torch.linspace(0, L-1, L, dtype = p.dtype, device = p.device)
    
    # Iterate for the number of frequency to append related positional encoding
    for freq in frequency:
        gamma_p.append(freq * torch.pi * p)
        gamma_p.append(freq * torch.pi * p)
        
    gamma_p = torch.concat(gamma_p, dim = -1)
    
    return gamma_p

In [11]:
p = torch.Tensor([-0.51, 0.125, 0.68])
print(positional_encoding(p, L = 4))

tensor([ -1.6022,   0.3927,   2.1363,  -1.6022,   0.3927,   2.1363,  -3.2044,
          0.7854,   4.2726,  -3.2044,   0.7854,   4.2726,  -6.4088,   1.5708,
          8.5451,  -6.4088,   1.5708,   8.5451, -12.8177,   3.1416,  17.0903,
        -12.8177,   3.1416,  17.0903])


In [1]:
# Stratified Sampling
def stratified_sampling(rays_o, rays_d, t_n, t_f, N):
    '''
    Inputs :
        rays_o - torch.Tensor. Shape of (width, height, 3). origin vector of the camera ray
        rays_d - torch.Tensor. Shape of (width, height, 3). direction vector of the camera ray
        t_n - Float. the nearest boundary point of the camera ray
        t_f - Float. the farthest boundary point of the camera ray
        N - Int. the number of bins to partition [t_n, t_f]
        
    Outputs : 
        t - torch.Tensor. Shape of (N,). linspace divided into N intervals. 
        x - torch.Tensor. Shape of (width, height, N, 3). the vector that satisfies 'x = o + td'
    '''
    
    # Partition the interval [t_n, t_f] by N and move the variable to the device.
    t = torch.linspace(t_n, t_f, N).to(rays_o)
    
    # Generate random noise that matches a given shape (width, height, N) and a given boundary [t_n, t_f]
    noise = torch.rand(rays_o.shape[:-1] + (N,)) * (t_f - t_n) / N 
    
    # Add the noise to t.
    t += noise
    
    # Calculate the equation of r and match the shape of (width, height, N, 3).
    x = rays_o[..., None, :] + t[..., None] * rays_d[..., None, :]
    
    return x, t

In [2]:
# Classical Volume Rendering
def classical_volume_rendering(rays_o, t, sigma, c):
    '''
    Inputs:
        rays_o - torch.Tensor. Shape of (width, height, 3). origin vector of the camera ray
        t - torch.Tensor. Shape of (N,). linspace divided into N intervals
        sigma - torch.Tensor. Tensor of volume density
        c - torch.Tensor. Tensor of 
        
    Outputs : 
        rgb - torch.Tensor. rendered color by classical volume rendering
    
    '''
    # Use ReLU function to ensure that the volume density can be non-negative.
    sigma = F.relu(sigma)[..., 0]
    
    # Use sigmoid function to ensure that the color vector, c can be in [0, 1].
    c = torch.sigmoid(c)
    
    # Calculate the distance between adjacent samples, delta.
    delta = t[1:, ...] - t[:-1, ...]
    
    # Add the endpoint of the ray to delta and assume that the endpoint of the ray has a very large value.
    delta = torch.cat([delta, torch.Tensor([1e10], dtype = rays_o.dtype, device = rays_o.device)], dim = -1)
    
    # Calculate the alpha compositing value, alpha. 
    alpha = 1. - torch.exp(-sigma * delta)
    
    # Product (1 - alpha) cumulatively and prevent multiplication by 0. 
    T = torch.cumprod(1. - alpha + 1e-10, -1)
    
    # Use torch.roll to remove the 'i'th multiplied case.
    T = torch.roll(T, 1, -1)
    
    # Fill the first case of T with 1. 
    T[..., 0] = 1.
    
    # Multiply T with alpha before aligining the shape with c. 
    w = T * alpha 
    
    rgb = (w[..., None] * c).sum(dim = -2)
    
    return rgb

## Build Model

In [18]:
class NeRF(nn.Module):
    def __init__(self, dim_x = 60, dim_d = 24, num_layers = 8, num_channels = 256, skip_connection = 4):
        super().__init__()
        
        layers = []
        
        for i in range(num_layers):
            if i == 0:
                layers.append(nn.Linear(dim_x, num_channels))
            elif i == skip_connection : 
                layers.append(nn.Linear(num_channels + dim_x, num_channels))
            else :
                layers.append(nn.Linear(num_channels, num_channels))
        
        self.layers = nn.ModuleList(layers)
        self.fc1 = nn.Linear(num_channels, num_channels)
        self.fc2 = nn.Linear(num_channels + dim_d, num_channels // 2)
        
        self.sigma = nn.Linear(num_channels, 1)
        self.rgb = nn.Linear(num_channels // 2, 3)
        
        self.skip_connection = skip_connection
                
    
    def forward(self, x, d):
        out = x 
        for i, layer in enumerate(self.layers):
            if i == skip_connection:
                out = torch.cat([out, x], dim = -1)
                out = self.layers[i](out)
                out = F.relu(out)
            else :
                out = self.layers[i](out)
                out = F.relu(out)
        
        sigma = self.sigma(out)
        
        out = self.fc1(out)
        out = torch.concat([out, d], dim = -1)
        out = self.fc2(out)
        out = F.relu(out)
        rgb = self.rgb(out)
        
        return rgb, sigma
        

## Defining Hyperparameters

In [14]:
# hyperparameters of positional encoding
L_X = 10
L_D = 4 

# hyperparameters of stratified sampling
T_N = 2.
T_F = 6.
N = 32

# hyperparameters of NeRF model 
NUM_LAYERS = 8 
NUM_CHANNELS = 256
SKIP_CONNECTION = 4 

# hyperparameters of training
EPOCHS = 10000
BATCH_SIZE = 4096

# Learning rate of the Adam optimizer
LR = 5e-4

## Training

In [12]:
def train(model, height, width, focal, pose, T_N, T_F, N, L_X, L_D):
    '''
    Inputs :
        height - Int. the height of an image
        width - Int. the width of an image
        focal - Float. focal length of the camera
        pose - torch.Tensor. the pose of an image 
        T_N - Float. the nearest boundary point of the camera ray
        T_F - Float. the farthest boundary point of the camera ray
        N - Int. the number of bins to partition [T_N, T_F] 
        L_X - Int. the dimensionality of positional encoding for x
        L_D - Int. the dimensionality of positional encoding for d  
    
    Output :
        rgb - torch.Tensor. rendered color by classical volume rendering
    '''
    rays_o, rays_d = get_rays(height, width, focal, pose)
    
    x, t = stratified_sampling(rays_o, rays_d, t_n, t_f, N)
    
    # Flatten the ray vector and directional vector to have a shape of (width * height * N , 3)
    x_flatten = x.reshape(-1, 3) # x has a shape of (width, height, N, 3)
    d_flatten = rays_d[..., None, :].expand_as(x).reshape(-1, 3) # d has a shape of (width, height, 3)
    
    gamma_x = positional_encoding(x_flatten, L_X)
    gamma_d = positional_encoding(d_flatten, L_D)
    
    # Set the list to save predictions
    preds = []
    
    # Slice the training data by batch size, predict the model's outputs, and append them to the list
    for i in range(0, gamma_x.shape[0], batch_size):
        preds.append(model(gamma_x[i : i + batch_size], gamma_d[i : i + batch_size]))
    
    # Define color and sigma variables from predictions
    color = torch.concat([pred[0] for pred in preds], dim = 0).reshape(height, width, -1, 3)
    sigma = torch.concat([pred[1] for pred in preds], dim = 0).reshape(height, width, -1, 1)
    
    rgb = classical_volume_rendering(color, sigma, rays_o, t)
    
    return rgb

In [20]:
gamma_x_dim = 3 * 2 * L_X
gamma_d_dim = 3 * 2 * L_D

model = NeRF(dim_x = gamma_x_dim, dim_d = gamma_d_dim, \
             num_layers = NUM_LAYERS, num_channels = NUM_CHANNELS, skip_connection = SKIP_CONNECTION)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr = LR)

seed = 9458
torch.manual_seed(seed)
np.random.seed(seed)

pbar = tqdm(range(num_iters))
for i in pbar:
    idx = np.random.randint(images.shape[0])
    image_i = images[idx]
    pose_i = poses[idx]
    
    rgb_predicted = train(model, height, width, focal, pose, T_N, T_F, N, L_X, L_D)
    loss = F.mse_loss(rgb_predicted, image_i)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()


NameError: name 'device' is not defined