In [1]:
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
import pickle

In [2]:
# Enable inline plotting
%matplotlib inline

# Device check
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

Using device: cpu


In [3]:
# Load training and testing data.
# SECURITY NOTE: np.load(..., allow_pickle=True) unpickles files that are not in
# .npy/.npz format, and unpickling can execute arbitrary code. Only load these
# files from a trusted source.
DATA_DIR = 'nerf_datasets'


def load_ray_dataset(path):
    """Read a ray dataset file with np.load and return it as a torch tensor on `device`."""
    with open(path, 'rb') as f:
        return torch.from_numpy(np.load(f, allow_pickle=True)).to(device)


training_dataset = load_ray_dataset(f'{DATA_DIR}/training_data.pkl')
testing_dataset = load_ray_dataset(f'{DATA_DIR}/testing_data.pkl')

In [4]:
training_dataset.size()

torch.Size([16000000, 9])

In [5]:
testing_dataset.size()

torch.Size([32000000, 9])

In [6]:
training_dataset[0, :]  # columns: ray origin (0-2), ray direction (3-5), ground-truth RGB (6-8)

tensor([-0.0538,  3.8455,  1.2081,  0.3340, -0.9418,  0.0390,  1.0000,  1.0000,
         1.0000])

In [12]:
# Columns 0-5 hold the ray (origin xyz, direction xyz); columns 6-8 hold the target RGB.
rays_train, rgbs_train = training_dataset[:, :6], training_dataset[:, 6:]
rays_test, rgbs_test = testing_dataset[:, :6], testing_dataset[:, 6:]

# Wrap both splits as datasets; only the training split gets a (shuffling) loader here.
train_dataset = TensorDataset(rays_train, rgbs_train)
test_dataset = TensorDataset(rays_test, rgbs_test)
loader = DataLoader(train_dataset, shuffle=True, batch_size=8192)

# NeRF Model Architecture

In [13]:
# NeRF-style MLP used for training below: positional encoding -> density + view-dependent RGB.
class VolumetricRendererModel(nn.Module):
    """Small NeRF MLP mapping (3-D sample point, view direction) to (RGB, density).

    Args:
        pos_encoding_dim (int): number of frequency levels used to encode positions.
        dir_encoding_dim (int): number of frequency levels used to encode directions.
        hidden_layer_size (int): width of the hidden layers.

    forward(ray_origin, ray_direction) takes two (N, 3) tensors and returns
    (colors of shape (N, 3) in [0, 1], densities of shape (N,) that are >= 0).
    """

    def __init__(self, pos_encoding_dim=10, dir_encoding_dim=4, hidden_layer_size=128):
        super().__init__()
        # Stored so forward() encodes inputs with the same number of levels that the
        # layer widths below were sized for (plain ints: not part of the state_dict).
        self.pos_encoding_dim = pos_encoding_dim
        self.dir_encoding_dim = dir_encoding_dim

        # Encoded width for a 3-D input = raw xyz (3) + sin and cos (2) per level per axis (3).
        pos_features = pos_encoding_dim * 6 + 3
        dir_features = dir_encoding_dim * 6 + 3

        self.layer_block1 = nn.Sequential(
            nn.Linear(pos_features, hidden_layer_size), nn.ReLU(),
            nn.Linear(hidden_layer_size, hidden_layer_size), nn.ReLU(),
            nn.Linear(hidden_layer_size, hidden_layer_size), nn.ReLU(),
            nn.Linear(hidden_layer_size, hidden_layer_size), nn.ReLU()
        )

        # Skip connection: block 2 sees block-1 features concatenated with the encoded position.
        self.layer_block2 = nn.Sequential(
            nn.Linear(pos_features + hidden_layer_size, hidden_layer_size), nn.ReLU(),
            nn.Linear(hidden_layer_size, hidden_layer_size), nn.ReLU(),
            nn.Linear(hidden_layer_size, hidden_layer_size), nn.ReLU(),
            nn.Linear(hidden_layer_size, hidden_layer_size + 1)  # features + one raw density value
        )

        # View-dependent colour head: features concatenated with the encoded direction.
        self.layer_block3 = nn.Sequential(
            nn.Linear(dir_features + hidden_layer_size, hidden_layer_size // 2), nn.ReLU()
        )

        self.layer_block4 = nn.Sequential(
            nn.Linear(hidden_layer_size // 2, 3), nn.Sigmoid()
        )

    @staticmethod
    def encode_position(x, levels):
        """Concatenate x with sin(2**i * x) and cos(2**i * x) for i in range(levels) along dim 1."""
        encoded = [x]
        for i in range(levels):
            encoded.append(torch.sin(2 ** i * x))
            encoded.append(torch.cos(2 ** i * x))
        return torch.cat(encoded, dim=1)

    def forward(self, ray_origin, ray_direction):
        """Return (rgb colors, densities) for a batch of sample points and view directions."""
        encoded_origin = self.encode_position(ray_origin, self.pos_encoding_dim)
        encoded_direction = self.encode_position(ray_direction, self.dir_encoding_dim)
        features = self.layer_block1(encoded_origin)
        density_output = self.layer_block2(torch.cat((features, encoded_origin), dim=1))
        # Last column is the density (clamped non-negative); the rest feed the colour head.
        features, density = density_output[:, :-1], torch.relu(density_output[:, -1])
        features = self.layer_block3(torch.cat((features, encoded_direction), dim=1))
        color_output = self.layer_block4(features)
        return color_output, density

# Using NeRF model to compute integrated colors

In [14]:
def compute_transmittance(alphas):
    """
    Computes the accumulated transmittance for each sample along the ray.

    This is an exclusive cumulative product along the last (sample) axis:
    output[..., i] is the product of alphas[..., :i], so the first sample of every
    ray gets transmittance 1. Despite its name, ``alphas`` is expected to hold
    per-sample transmission factors (process_rays passes ``1 - alpha``).

    Args:
        alphas (torch.Tensor): [..., num_samples] per-sample transmission factors.

    Returns:
        torch.Tensor: same shape, dtype and device as ``alphas``.
    """
    transmittance = torch.cumprod(alphas, dim=-1)
    # Shift right by one and prepend ones (matching dtype/device) to make the product exclusive.
    leading_ones = torch.ones_like(transmittance[..., :1])
    return torch.cat((leading_ones, transmittance[..., :-1]), dim=-1)

def process_rays(model, origins, directions, near_plane=0, far_plane=0.5, num_bins=192):
    """
    Render a batch of rays with a NeRF model by numerical volume rendering.

    Depths are ``num_bins`` evenly spaced values in [near_plane, far_plane]; the
    interval after the last depth is given a length of 1e10 (effectively infinite).

    Args:
        model (nn.Module): Maps (points [M, 3], directions [M, 3]) to (rgb [M, 3], density [M]).
        origins (torch.Tensor): Ray origins [batch_size, 3].
        directions (torch.Tensor): Ray directions [batch_size, 3].
        near_plane (float): The starting distance for sampling along the ray.
        far_plane (float): The ending distance for sampling along the ray.
        num_bins (int): The number of sample points along each ray.

    Returns:
        torch.Tensor: [batch_size, 3] integrated colour per ray.

    NOTE(review): the defaults (0, 0.5) differ from the (2, 6) range that train_model
    uses; callers rendering a model trained there should pass the same bounds.
    """
    n_rays = origins.shape[0]
    dev = origins.device

    # Shared depth grid broadcast to every ray: [n_rays, num_bins]
    depths = torch.linspace(near_plane, far_plane, num_bins, device=dev).expand(n_rays, num_bins)
    # Distances between consecutive depths, padded with 1e10 for the final sample.
    tail = torch.tensor([1e10], device=dev).expand(n_rays, 1)
    intervals = torch.cat((depths[:, 1:] - depths[:, :-1], tail), dim=-1)

    # Sample positions o + t * d: [n_rays, num_bins, 3]
    sample_points = origins[:, None, :] + depths[:, :, None] * directions[:, None, :]
    # Every sample on a ray is queried with that ray's direction.
    sample_dirs = directions[:, None, :].expand(n_rays, num_bins, 3)

    colors, sigma = model(sample_points.reshape(-1, 3), sample_dirs.reshape(-1, 3))
    colors = colors.reshape(n_rays, num_bins, 3)
    sigma = sigma.reshape(n_rays, num_bins)

    # Per-sample opacity, then compositing weights = transmittance * opacity.
    alpha = 1 - torch.exp(-sigma * intervals)
    weights = compute_transmittance(1 - alpha).unsqueeze(2) * alpha.unsqueeze(2)

    return (weights * colors).sum(dim=1)

# Train Model

In [15]:
def train_model(model, optimizer, scheduler, data_loader, epochs=5, near_plane=2, far_plane=6, num_bins=192):
    """
    Train a NeRF model with a per-pixel mean-squared colour loss.

    Args:
        model (nn.Module): NeRF model to optimise (batches are moved to its device).
        optimizer (torch.optim.Optimizer): Optimizer over ``model``'s parameters.
        scheduler: LR scheduler; stepped once per epoch.
        data_loader (DataLoader): Yields (rays [B, 6], rgbs [B, 3]) batches.
        epochs (int): Number of passes over ``data_loader``.
        near_plane (float): Near sampling bound passed to process_rays.
        far_plane (float): Far sampling bound passed to process_rays.
        num_bins (int): Samples per ray passed to process_rays (192 matches its default).

    Returns:
        list[float]: Loss of every training iteration, in order.
    """
    model.train()
    # Use whatever device the model lives on instead of relying on a global.
    model_device = next(model.parameters()).device
    losses = []
    for epoch in tqdm(range(epochs)):
        print(f"start epoch: {epoch}")
        epoch_losses = []
        for rays, rgbs in data_loader:
            origins = rays[:, :3].to(model_device)
            directions = rays[:, 3:6].to(model_device)
            true_colors = rgbs.to(model_device)

            generated_colors = process_rays(model, origins, directions, near_plane, far_plane, num_bins)
            loss = ((true_colors - generated_colors) ** 2).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.item())

        losses.extend(epoch_losses)
        # Stepped once per epoch, so e.g. MultiStepLR milestones count epochs.
        scheduler.step()
        # Report the epoch mean (less noisy than the last batch); NaN if the loader was empty.
        epoch_mean = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float('nan')
        print(f"Epoch {epoch + 1}/{epochs}, Mean loss: {epoch_mean}")
    return losses

In [17]:
# Initialise the model, optimizer and learning-rate scheduler, then train.
model = VolumetricRendererModel().to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-4)
# train_model steps the scheduler once per epoch: the LR halves after epochs 2, 4 and 8.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, gamma=0.5, milestones=[2, 4, 8])

training_loss = train_model(model, optimizer, scheduler, loader, epochs=16)

In [None]:
# Plot the per-iteration training loss.
fig, ax = plt.subplots()
ax.plot(training_loss)
ax.set(xlabel="Iteration", ylabel="Loss", title="Training Loss Curve")
plt.show()

# Save the Model

In [None]:
torch.save(obj=model.state_dict(), f='VRM_model.pth')  # parameters only, not the module object

# Load the Model

In [None]:
# Rebuild the architecture; its hyperparameters must match those used to train the checkpoint.
model = VolumetricRendererModel(pos_encoding_dim=10, dir_encoding_dim=4, hidden_layer_size=128)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

# NOTE(review): the "Save the Model" cell writes 'VRM_model.pth', but this cell loads
# 'VRM_model_105.pth' (presumably a separately trained checkpoint). Set CHECKPOINT_PATH
# to 'VRM_model.pth' to evaluate the model trained above.
CHECKPOINT_PATH = 'VRM_model_105.pth'

# Load the saved model weights. torch.load unpickles the file, so only load
# checkpoints from a trusted source.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))

# Set the model to evaluation mode
model.eval()
print("Model loaded successfully and set to evaluation mode.")

# Test the Model

In [None]:
# Test-set DataLoader; no shuffling, so batches arrive in the dataset's stored order.
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=4096)

In [None]:
# Evaluation Function
def evaluate_model(model, data_loader, near_plane=2, far_plane=6):
    """
    Compute the mean squared colour error of a model over a whole dataset.

    Args:
        model (nn.Module): Trained NeRF model (batches are moved to its device).
        data_loader (DataLoader): Yields (rays [B, 6], rgbs [B, 3]) batches.
        near_plane (float): Near sampling bound; defaults to train_model's value.
        far_plane (float): Far sampling bound; defaults to train_model's value.

    Returns:
        float: MSE averaged over every colour value in the dataset (NaN if the loader is empty).
    """
    model.eval()  # Set the model to evaluation mode
    model_device = next(model.parameters()).device
    squared_error_sum = 0.0
    value_count = 0
    with torch.no_grad():  # Disable gradient computation
        for rays, rgbs in data_loader:
            origins = rays[:, :3].to(model_device)
            directions = rays[:, 3:6].to(model_device)
            true_colors = rgbs.to(model_device)
            # Render with the same depth range as training; process_rays' own defaults (0, 0.5) differ.
            predicted_colors = process_rays(model, origins, directions, near_plane, far_plane)

            # Accumulate sums, not per-batch means, so a short final batch is weighted correctly.
            squared_error_sum += ((true_colors - predicted_colors) ** 2).sum().item()
            value_count += true_colors.numel()

    return squared_error_sum / value_count if value_count else float('nan')

# Test the model
test_loss = evaluate_model(model, test_loader)
print(f"Average Test Loss: {test_loss}")

# Display Rendered Images vs. Ground Truth

In [None]:
def _show_image_pair(true_image, predicted_image):
    """Show a ground-truth image and a predicted image side by side."""
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(true_image)
    plt.title("Ground Truth")
    plt.subplot(1, 2, 2)
    plt.imshow(predicted_image)
    plt.title("Predicted")
    plt.show()


def display_results(model, data_loader, image_size=(400, 400), batch_size=4096, num_samples=1,
                    near_plane=2, far_plane=6):
    """
    Processes rays batch by batch and reassembles them into full images, showing each
    rendered image next to its ground truth.

    Every ``H * W`` consecutive rays from ``data_loader`` are treated as one image in
    row-major order (assumes the dataset stores each image's rays contiguously —
    TODO confirm against the data export). Rays left over after completing an image
    are carried into the next one.

    Args:
        model (nn.Module): The NeRF model (batches are moved to its device).
        data_loader (DataLoader): DataLoader for the test dataset; must not shuffle.
        image_size (tuple): The target image resolution (height, width).
        batch_size (int): Unused; kept for backward compatibility. The batch size is
            whatever ``data_loader`` yields.
        num_samples (int): Number of images to display.
        near_plane (float): Near sampling bound; defaults to train_model's value.
        far_plane (float): Far sampling bound; defaults to train_model's value.
    """
    model.eval()
    model_device = next(model.parameters()).device
    H, W = image_size
    num_rays = H * W  # rays per image

    pending_true, pending_pred = [], []
    pending_count = 0
    images_displayed = 0

    with torch.no_grad():
        for rays, rgbs in data_loader:
            # Stop once the requested number of images has been shown.
            if images_displayed >= num_samples:
                break

            origins = rays[:, :3].to(model_device)
            directions = rays[:, 3:6].to(model_device)
            # Render with the same depth range as training; process_rays' own defaults (0, 0.5) differ.
            predicted_colors = process_rays(model, origins, directions, near_plane, far_plane)

            pending_true.append(rgbs.cpu())
            pending_pred.append(predicted_colors.cpu())
            pending_count += rgbs.shape[0]

            # A batch can complete an image and spill over into the next one.
            while pending_count >= num_rays and images_displayed < num_samples:
                all_true = torch.cat(pending_true, dim=0)
                all_pred = torch.cat(pending_pred, dim=0)
                _show_image_pair(all_true[:num_rays].view(H, W, 3), all_pred[:num_rays].view(H, W, 3))
                # Keep the leftover rays: they are the start of the next image.
                pending_true, pending_pred = [all_true[num_rays:]], [all_pred[num_rays:]]
                pending_count -= num_rays
                images_displayed += 1

In [None]:
# Reuse the unshuffled test_loader from the "Test the Model" section; rays must arrive
# in stored order for display_results to reassemble whole images.
display_results(model, test_loader, image_size=(400, 400), num_samples=1)