In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import math
import time
from model import NeRF
from datasets import StanfordDragonDataset

In [2]:
# Instantiate dataset & model.
# `dataset`, `device` and `model` are read by the training and rendering cells below.
SEED = 0
torch.manual_seed(SEED)  # reproducible weight initialisation (and any later global-RNG draws)

# Use the GPU whenever one is present instead of hard-coding CPU training;
# falls back to CPU transparently on machines without CUDA.
cuda = torch.cuda.is_available()
dataset = StanfordDragonDataset("./datasets/dragon")
device = torch.device("cuda") if cuda else torch.device("cpu")

# Axis-aligned scene bounds handed to NeRF (presumably the volume sampled
# along each ray — TODO confirm in model.py).
min_bounds = torch.tensor([-10.0, -10.0, -10.0], device=device)
max_bounds = torch.tensor([10.0, 10.0, 10.0], device=device)
model = NeRF(device, min_bounds, max_bounds).to(device)

total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total model parameters: %d" % total_params)
print("Total images: %d" % len(dataset))

Total model parameters: 593924
Total images: 400


In [None]:
# Training variables
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
iterations = 1          # number of full passes over the shuffled ray set (one checkpoint per pass)
rays_per_batch = 2**11  # rays per optimisation step; also reused as the render chunk size
num_samples = 256       # samples per ray passed to model.render_rays
shuffle_seed = 0        # fixed so the ray order, and hence training, is reproducible

# Gather every ray from every training image.
all_positions = []
all_directions = []
all_gt_colors = []
for i in range(len(dataset)):
    image, pose, focal = dataset[i]
    positions, directions, gt_colors = model.get_rays(image, pose, focal)
    all_positions.append(positions)
    all_directions.append(directions)
    all_gt_colors.append(gt_colors)

# Concatenate into one flat tensor per quantity, rays along dim 0.
all_positions = torch.cat(all_positions, dim=0)
all_directions = torch.cat(all_directions, dim=0)
all_gt_colors = torch.cat(all_gt_colors, dim=0)
assert all_positions.shape[0] == all_directions.shape[0] == all_gt_colors.shape[0], \
    "get_rays returned mismatched ray counts"

# Shuffle rays across images with ONE shared permutation so every
# (position, direction, colour) triple stays aligned. A dedicated seeded
# generator makes the order reproducible without touching global RNG state.
generator = torch.Generator().manual_seed(shuffle_seed)
shuffle = torch.randperm(all_positions.shape[0], generator=generator)
all_positions = all_positions[shuffle]
all_directions = all_directions[shuffle]
all_gt_colors = all_gt_colors[shuffle]

rays_per_iteration = all_positions.shape[0]  # rays visited in one full pass

In [None]:
# Training loop: `iterations` full passes over the shuffled rays in
# mini-batches of `rays_per_batch`, minimising the per-ray colour MSE.
# A checkpoint "model-<pass>.pth" is written at the end of every pass.
log_every = 100  # print progress every N batches; the last batch of a pass is always logged
model.train()
for i in range(iterations):
    losses = []
    for batch_idx, start in enumerate(range(0, rays_per_iteration, rays_per_batch)):
        end = min(start + rays_per_batch, rays_per_iteration)
        # Contiguous slices are cheap views; no index tensor / gather copy needed.
        positions = all_positions[start:end].to(device)
        directions = all_directions[start:end].to(device)
        gt = all_gt_colors[start:end].to(device)

        optimizer.zero_grad()
        colors, depths, weights = model.render_rays(positions, directions, num_samples)
        loss = torch.mean(torch.square(colors - gt))
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        # Report progress with `end`, not start + rays_per_batch, so the final
        # partial batch shows exactly 100% rather than overshooting.
        if batch_idx % log_every == 0 or end == rays_per_iteration:
            print('iteration: %d, loss: %.4f, ray count: %.2f%%'
                  % (i, losses[-1], 100 * end / rays_per_iteration))
    if losses:
        print('iteration: %d, mean loss: %.4f' % (i, sum(losses) / len(losses)))
    torch.save(model.state_dict(), "model-%d.pth" % i)

In [None]:
# Render testing data: reload the last checkpoint written by the training loop
# and save a colour + depth JPEG for every pose in testing_poses.npy.
cuda = torch.cuda.is_available()  # use the GPU if present instead of crashing without one
device = torch.device("cuda") if cuda else torch.device("cpu")
min_bounds = torch.tensor([-10.0, -10.0, -10.0], device=device)
max_bounds = torch.tensor([10.0, 10.0, 10.0], device=device)
pth_model = NeRF(device, min_bounds, max_bounds)
# The training loop saves "model-<pass>.pth" after each pass; load the final one.
# map_location lets a checkpoint saved on one device be loaded onto another.
pth_model.load_state_dict(torch.load("model-%d.pth" % (iterations - 1), map_location=device))
pth_model.to(device)
pth_model.eval()
testing_poses = torch.from_numpy(np.load("./testing_poses.npy"))
# dataset[0] supplies the image passed to get_rays (presumably only for its size)
# and the focal length; its ground-truth colours are not used when rendering.
dummy_image, dummy_pose, focal = dataset[0]


def _to_uint8_image(t):
    """Clamp a float tensor to [0, 1], scale to 0..255 and return a uint8 numpy array."""
    return (t.clamp(0.0, 1.0) * 255).numpy().astype(np.uint8)


def render_image(idx):
    """Render test pose `idx`; save test-color-<idx>.jpeg and test-depth-<idx>.jpeg.

    Uses the module-level pth_model, testing_poses, dummy_image, focal, device,
    rays_per_batch (chunk size) and num_samples (samples per ray).
    """
    with torch.no_grad():
        pose = testing_poses[idx]
        positions, directions, _ = pth_model.get_rays(dummy_image, pose, focal)
        num_rays = positions.shape[0]
        # Flat per-pixel CPU buffers; reshape((-1, 3)) presumes a trailing RGB axis.
        color = torch.zeros(dummy_image.shape).reshape((-1, 3))
        depth = torch.zeros(dummy_image.shape).reshape((-1, 3))
        for start in range(0, num_rays, rays_per_batch):
            end = min(start + rays_per_batch, num_rays)
            colors, depths, _ = pth_model.render_rays(
                positions[start:end].to(device), directions[start:end].to(device), num_samples)
            color[start:end] = colors.float().cpu()
            # Broadcast the scalar per-ray depth across all 3 channels (grayscale).
            depth[start:end] = depths[..., None].float().cpu()
        # Depths are presumably in scene units rather than [0, 1] (TODO confirm in
        # model.py); normalise by the image maximum so the 8-bit cast cannot wrap.
        max_depth = depth.max()
        if max_depth > 0:
            depth = depth / max_depth
        color = color.reshape(dummy_image.shape)
        depth = depth.reshape(dummy_image.shape)
        Image.fromarray(_to_uint8_image(color)).save("test-color-%d.jpeg" % idx)
        Image.fromarray(_to_uint8_image(depth)).save("test-depth-%d.jpeg" % idx)


for i in range(testing_poses.shape[0]):
    render_image(i)