In [1]:
import math
import os

from tqdm import tqdm
import numpy as np
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
# Output folders: per-iteration validation snapshots go to iterations/,
# final novel-view renders go to out/.
os.makedirs(name="iterations/", exist_ok=True)
os.makedirs(name="out/", exist_ok=True)

In [3]:
def transform(c2w, x_c):
    """Map camera-space points to world space with a batch of 4x4 matrices.

    Args:
        c2w: (B, 4, 4) camera-to-world matrices.
        x_c: (B, H, W, 3) points in camera coordinates.

    Returns:
        (B, H, W, 3) world-space points. The homogeneous w component of the
        product is dropped without dividing by it.
    """
    batch, rows, cols, _ = x_c.shape

    ones = torch.ones(batch, rows, cols, 1, device=x_c.device)
    # Columns of `pts` are homogeneous points: (B, 4, H*W).
    pts = torch.cat((x_c, ones), dim=-1).view(batch, -1, 4).transpose(1, 2)

    world = torch.bmm(c2w, pts).transpose(1, 2).view(batch, rows, cols, 4)
    return world[..., :3]


def intrinsic_matrix(fx, fy, ox, oy):
    """Build the 3x3 pinhole intrinsic matrix as a float32 tensor.

    Layout: [[fx, 0, ox], [0, fy, oy], [0, 0, 1]], where (fx, fy) are the
    focal lengths in pixels and (ox, oy) is the principal point.
    """
    rows = (
        (fx, 0.0, ox),
        (0.0, fy, oy),
        (0.0, 0.0, 1.0),
    )
    return torch.tensor(rows, dtype=torch.float32)


def pixel_to_camera(K, uv, s):
    """Back-project pixel coordinates into camera space.

    Args:
        K: (B, 3, 3) intrinsic matrices.
        uv: (B, H, W, 3) pixel grid whose channels are (image index, row, column).
            Channel 0 is ignored. Channel 1 is used as v (y) and channel 2 as u (x).
            The grid may be non-contiguous (e.g. a slice) and may be integer-typed.
        s: Scale broadcastable to (B, H, W, 1). The back-projected point
            K^-1 [u, v, 1]^T is multiplied by it. For a standard K (last row
            [0, 0, 1]) that point has z = 1, so `s` acts as the depth.

    Returns:
        (B, H, W, 3) points in camera coordinates.
    """
    B, H, W, _ = uv.shape

    # reshape (not view) so non-contiguous grids are accepted; cast to K's dtype
    # so integer pixel grids can go through bmm with a float K.
    uv_flat = uv.reshape(B, H * W, 3).to(K.dtype)

    # Homogeneous pixel coordinates [u, v, 1] laid out as columns: (B, 3, H*W).
    # Note uv stores the row before the column, so the channels are swapped here.
    pix_h = torch.stack([
        uv_flat[..., 2],                    # u = column
        uv_flat[..., 1],                    # v = row
        torch.ones_like(uv_flat[..., 0]),   # homogeneous 1
    ], dim=1)

    x_c_flat = torch.inverse(K).bmm(pix_h)  # (B, 3, H*W)
    x_c = x_c_flat.permute(0, 2, 1).reshape(B, H, W, 3)

    return x_c * s


def pixel_to_ray(K, c2w, uv):
    """Compute world-space ray origins and unit directions for a pixel grid.

    Args:
        K: (B, 3, 3) intrinsic matrices.
        c2w: (B, 4, 4) camera-to-world matrices. Assumed to have bottom row
            [0, 0, 0, 1] (standard camera pose), so the camera centre is the
            translation column.
        uv: (B, H, W, 3) pixel grid in the layout expected by `pixel_to_camera`
            (channel 1 = row, channel 2 = column).

    Returns:
        r_o: (B, H, W, 3) ray origins (the camera centre repeated per pixel).
        r_d: (B, H, W, 3) unit-length ray directions in world space.
    """
    B, H, W, _ = uv.shape

    # Points on the unit-depth plane in camera coordinates.
    x_c = pixel_to_camera(K, uv, torch.ones((B, H, W, 1), device=uv.device))

    # For c2w = [[R, t], [0, 1]] the camera centre in world space is t.
    # Reading it directly is equivalent to -inv(R_w2c) @ t_w2c, but it avoids
    # inverting c2w and then R again, which is redundant and adds round-off.
    r_o = c2w[:, :3, 3]

    # Transform camera coordinates to world coordinates
    x_w = transform(c2w, x_c)

    # Broadcast ray origins across image dimensions (repeat copies, so r_o does
    # not alias c2w).
    r_o = r_o[:, None, None, :].repeat(1, H, W, 1)

    offsets = x_w - r_o
    r_d = offsets / torch.norm(offsets, dim=-1, keepdim=True)

    return r_o, r_d


def sample_along_rays(r_o, r_d, perturb=True, near=2.0, far=6.0, n_samples=64):
    """Sample points along rays between `near` and `far`.

    Args:
        r_o: Ray origins, shape (..., 3).
        r_d: Ray directions, shape (..., 3).
        perturb: If True, add jitter drawn from U[0, (far - near) / n_samples)
            independently for every ray and every sample (stratified sampling).
            If False, samples are evenly spaced with torch.linspace.
        near, far: Depth range of the samples before jitter.
        n_samples: Number of samples per ray.

    Returns:
        Points of shape (n_samples, ..., 3). The sample axis comes FIRST;
        callers permute it to (num_rays, n_samples, 3).
    """
    t = torch.linspace(near, far, n_samples, device=r_o.device)
    batch_shape = torch.broadcast_shapes(r_o.shape[:-1], r_d.shape[:-1])

    # (n_samples, 1, ..., 1, 1) so it broadcasts against the (..., 3) rays.
    t = t.reshape(n_samples, *([1] * len(batch_shape)), 1)

    if perturb:
        # One jitter value per (sample, ray). A single shared offset would make
        # every ray in the batch sample exactly the same depths.
        jitter = torch.rand((n_samples, *batch_shape, 1), device=r_o.device)
        t = t + jitter * (far - near) / n_samples

    return r_o + r_d * t

In [4]:
class RaysData:
    """Precomputed rays and target colours for a stack of posed images.

    All pixels of all images are flattened in (image, row, column) order, so
    `pixels[k]`, `r_o_flattened[k]` and `r_d_flattened[k]` describe the same ray.
    """

    def __init__(self, images, K, c2w, device=None):
        """
        Args:
            images: (N, H, W, 3) image stack used as per-ray target colours.
            K: (N, 3, 3) intrinsic matrices, one per image.
            c2w: (N, 4, 4) camera-to-world matrices, one per image.
            device: Device for the pixel grid and random sample indices.
                Defaults to `images.device`, so CPU-only setups work without
                passing it and the grid matches the data's device.
        """
        if device is None:
            device = images.device

        self.images = images
        self.K = K
        self.c2w = c2w
        self.device = device

        # Image dimensions
        self.num_images, self.height, self.width = images.shape[:3]

        # uv[n, i, j] = (n, i, j): image index, row, column.
        uv_grid = torch.meshgrid(
            torch.arange(self.num_images, device=device),
            torch.arange(self.height, device=device),
            torch.arange(self.width, device=device),
            indexing='ij'
        )
        self.uv = torch.stack(uv_grid, dim=-1).float()

        # Shift row/column (not the image index) to pixel centres.
        self.uv[..., 1:] += 0.5

        # Flatten UV grid for later use
        self.uv_flattened = self.uv.reshape(-1, 3)

        # Compute ray origins and directions
        self.r_o, self.r_d = pixel_to_ray(K, c2w, self.uv)

        # Flatten rays and pixels for sampling (same (image, row, column) order)
        self.pixels = images.reshape(-1, 3)
        self.r_o_flattened = self.r_o.reshape(-1, 3)
        self.r_d_flattened = self.r_d.reshape(-1, 3)

    def sample_rays(self, batch_size):
        """Sample `batch_size` rays uniformly (with replacement) from all images.

        Returns:
            (r_o, r_d, pixels), each of shape (batch_size, 3).
        """
        idx = torch.randint(0, self.pixels.shape[0], (batch_size,), device=self.device)
        return (
            self.r_o_flattened[idx],
            self.r_d_flattened[idx],
            self.pixels[idx]
        )

    def sample_rays_single_img(self, image_index=None):
        """Return every ray of one image in row-major order.

        Args:
            image_index: Image to return, in [0, num_images). A random image is
                chosen when None.

        Returns:
            (r_o, r_d, pixels), each of shape (height * width, 3).

        Raises:
            IndexError: if `image_index` is out of range. Without this check the
                slices below would silently come back empty or wrong.
        """
        if image_index is None:
            image_index = torch.randint(0, self.num_images, (1,), device=self.device).item()
        if not 0 <= image_index < self.num_images:
            raise IndexError(
                f"image_index {image_index} out of range for {self.num_images} images"
            )

        pixels_per_image = self.height * self.width
        start_idx = image_index * pixels_per_image
        end_idx = start_idx + pixels_per_image

        return (
            self.r_o_flattened[start_idx:end_idx],
            self.r_d_flattened[start_idx:end_idx],
            self.pixels[start_idx:end_idx]
        )


In [5]:
def volrend(sigmas, rgbs, step_size):
    """Composite per-sample colours along each ray (discrete volume rendering).

    colour = sum_i T_i * (1 - exp(-sigma_i * delta_i)) * rgb_i,
    where T_i = exp(-sum_{j<i} sigma_j * delta_j).

    Args:
        sigmas: (B, N, 1) non-negative densities for N samples on B rays.
        rgbs: (B, N, 3) colours of the samples.
        step_size: Distance between samples, delta. Either a scalar, or a
            tensor broadcastable to `sigmas` (e.g. (B, N, 1)) giving a
            per-sample delta_i.

    Returns:
        (B, 3) rendered colour per ray.
    """
    # Optical thickness of each segment: sigma_i * delta_i. Multiplying before
    # the cumsum is what makes per-sample step sizes correct.
    tau = sigmas * step_size

    # Exclusive cumulative sum: transmittance up to, but not including, sample i.
    accum = torch.cumsum(tau, dim=1)
    T_i = torch.exp(-torch.cat([torch.zeros_like(accum[:, :1]), accum[:, :-1]], dim=1))

    alpha = 1.0 - torch.exp(-tau)
    weights = alpha * T_i

    return torch.sum(weights * rgbs, dim=1)


In [6]:
def load_data(path="lego_200x200.npz"):
    """Load the posed-image dataset from an ``.npz`` archive.

    Args:
        path: Location of the archive (default: ``lego_200x200.npz`` in the
            working directory).

    Returns:
        Tuple ``(images_train, c2ws_train, images_val, c2ws_val, c2ws_test, focal)``.
        Image arrays are divided by 255 (uint8 data lands in [0, 1]). ``c2ws_*``
        are camera-to-world matrices, and ``focal`` is the focal length as stored.
    """
    # Use the NpzFile as a context manager so its file handle is closed.
    # Indexing it reads each array fully into memory, so the returned arrays
    # stay valid after the file is closed.
    with np.load(path) as data:
        # Training images: [100, 200, 200, 3]
        images_train = data["images_train"] / 255.0

        # Cameras for the training images
        # (camera-to-world transformation matrix): [100, 4, 4]
        c2ws_train = data["c2ws_train"]

        # Validation images: [10, 200, 200, 3]
        images_val = data["images_val"] / 255.0

        # Cameras for the validation images
        # (camera-to-world transformation matrix): [10, 4, 4]
        c2ws_val = data["c2ws_val"]

        # Test cameras for novel-view video rendering
        # (camera-to-world transformation matrix): [60, 4, 4]
        c2ws_test = data["c2ws_test"]

        # Camera focal length
        focal = data["focal"]  # float

    return images_train, c2ws_train, images_val, c2ws_val, c2ws_test, focal

In [7]:
def positional_encoding(x, L):
    """Sinusoidal positional encoding applied to the last axis of `x`.

    For every coordinate c of the last axis and k = 0..L-1 it computes
    sin(2*pi * 2^k * c) and cos(2*pi * 2^k * c). Per coordinate, the L sines
    come first, then the L cosines. The raw `x` is prepended, so an input
    of shape (..., D) becomes (..., D + 2*L*D).
    """
    powers = torch.arange(L, device=x.device).float()
    scales = torch.pow(2.0, powers)

    angles = x[..., None] * scales * 2 * torch.pi  # (..., D, L)
    flat_shape = (*x.shape[:-1], -1)
    sin_cos = torch.cat((angles.sin(), angles.cos()), dim=-1).reshape(flat_shape)

    return torch.cat((x, sin_cos), dim=-1)

def psnr(image1, image2):
    """Peak signal-to-noise ratio in dB for images with peak value 1.0.

    Identical inputs (zero mean squared error) are reported as 100 dB instead
    of infinity.
    """
    mse = np.mean(np.square(image1 - image2))
    if mse != 0:
        return 20 * math.log10(1.0 / math.sqrt(mse))
    return 100

In [8]:
class MLP(nn.Module):
    """NeRF-style MLP mapping (position, view direction) to (RGB, density).

    Positions are encoded with `L_xyz` frequencies and directions with `L_dir`
    frequencies via `positional_encoding`. The encoded position is
    re-concatenated after the first four layers (skip connection). The encoded
    direction is concatenated before the colour head.
    """

    def __init__(self, L_xyz=10, L_dir=4, hidden_dim=256, rgb_hidden_dim=128):
        super().__init__()

        # Number of encoding frequencies. forward() uses these same values, so
        # the layer input sizes below always match the encodings.
        self.L_xyz = L_xyz
        self.L_dir = L_dir

        # positional_encoding(v, L) on a 3-D input returns 3 + 2*3*L features.
        input_xyz_dim = 2 * 3 * L_xyz + 3
        input_dir_dim = 2 * 3 * L_dir + 3
        rgb_output_dim = 3

        # Block 1: Initial xyz processing
        self.xyz_layers = nn.ModuleList([
            nn.Linear(input_xyz_dim, hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Linear(hidden_dim, hidden_dim)
        ])

        # NOTE: not used by forward(). Kept so the parameter set, and therefore
        # existing state_dict checkpoints, stays loadable with strict=True.
        self.dir_layer = nn.Linear(input_dir_dim, hidden_dim)

        # Block 3: Combined processing (input = block-1 features + skip of encoded xyz)
        self.combined_layers = nn.ModuleList([
            nn.Linear(hidden_dim + input_xyz_dim, hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Linear(hidden_dim, hidden_dim)
        ])

        # Density output
        self.density_layer = nn.Linear(hidden_dim, 1)

        # RGB processing: a linear feature layer, then features + encoded direction
        self.rgb_layers = nn.ModuleList([
            nn.Linear(hidden_dim, hidden_dim),
            nn.Linear(hidden_dim + input_dir_dim, rgb_hidden_dim)
        ])
        self.rgb_output = nn.Linear(rgb_hidden_dim, rgb_output_dim)

        # Activation functions
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, r_d):
        """
        Forward pass through the network.
        Args:
            x: Position coordinates, shape (..., 3)
            r_d: Ray directions, shape (..., 3)
        Returns:
            rgb: RGB colours in (0, 1), shape (..., 3)
            density: Non-negative density values, shape (..., 1)
        """
        # Apply positional encoding
        x_encoded = positional_encoding(x, L=self.L_xyz)
        r_d_encoded = positional_encoding(r_d, L=self.L_dir)

        # Process position through Block 1
        feat = x_encoded
        for layer in self.xyz_layers:
            feat = self.relu(layer(feat))

        # Skip connection: re-inject the encoded position
        feat = torch.cat([feat, x_encoded], dim=-1)

        # Process through Block 3
        for layer in self.combined_layers:
            feat = self.relu(layer(feat))

        # Density is view-independent; ReLU keeps it non-negative
        density = self.relu(self.density_layer(feat))

        # Colour depends on the view direction; first layer is linear (no activation)
        rgb_feat = self.rgb_layers[0](feat)
        rgb_feat = torch.cat([rgb_feat, r_d_encoded], dim=-1)
        rgb_feat = self.relu(self.rgb_layers[1](rgb_feat))
        rgb = self.sigmoid(self.rgb_output(rgb_feat))

        return rgb, density


def _save_curve(xs, ys, color, xlabel, ylabel, title, path):
    """Plot one metric curve in a new figure, save it to `path`, and return the figure."""
    fig, ax = plt.subplots()
    ax.plot(list(xs), ys, color=color)
    ax.set(xlabel=xlabel, ylabel=ylabel, title=title)
    fig.savefig(path)
    return fig


def train_model(model, train_dataset, val_dataset, test_dataset, optimizer, criterion, iters=2500, batch_size=10000, device='cuda',
                val_every=100, save_every=500, near=2.0, far=6.0, n_samples=64):
    """Train the NeRF MLP on random ray batches, validating and checkpointing periodically.

    Args:
        model: Network called as ``model(points, dirs) -> (rgb, sigmas)``.
        train_dataset, val_dataset: RaysData-like objects providing
            ``sample_rays(batch_size)``, ``sample_rays_single_img()`` and
            ``height`` / ``width``.
        test_dataset: Unused here; accepted for call-site compatibility.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss between rendered and target pixel colours.
        iters: Number of optimisation steps.
        batch_size: Rays per training step.
        device: Unused; the datasets already hold their tensors on the target device.
        val_every: Render one full validation image every this many steps and on
            the last step. A falsy value disables validation. A full image is
            H*W rays, so doing it every step is very expensive.
        save_every: Write ``nerf_model_{i}.pt`` every this many steps. A falsy
            value disables intermediate checkpoints. ``nerf_model.pt`` is always
            written at the end.
        near, far, n_samples: Ray-sampling range and count. They also define the
            volume-rendering step size ``(far - near) / n_samples``.

    Side effects:
        Writes validation snapshots to ``iterations/``, checkpoints, and four
        PNG curves (train/val PSNR and loss).
    """
    step_size = (far - near) / n_samples

    def sample_and_compute_loss(dataset, is_training=True):
        # Training draws a random batch of rays; evaluation renders one full image.
        if is_training:
            rays_o, rays_d, pixels = dataset.sample_rays(batch_size)
        else:
            rays_o, rays_d, pixels = dataset.sample_rays_single_img()

        points = sample_along_rays(rays_o, rays_d, perturb=is_training,
                                   near=near, far=far, n_samples=n_samples)
        points = points.permute(1, 0, 2)  # (n_samples, R, 3) -> (R, n_samples, 3)
        dirs = rays_d.unsqueeze(1).expand(-1, points.shape[1], -1)

        rgb, sigmas = model(points, dirs)
        comp_rgb = volrend(sigmas, rgb, step_size=step_size)

        loss = criterion(comp_rgb, pixels)
        psnr_value = psnr(comp_rgb.detach().cpu().numpy(), pixels.cpu().numpy())

        return loss, psnr_value, comp_rgb

    train_psnr_scores = []
    train_losses = []
    val_steps = []
    psnr_scores = []
    val_losses = []

    model.train()
    progress = tqdm(range(iters))
    for i in progress:
        step = i + 1

        # Training step
        optimizer.zero_grad()
        loss, train_psnr, _ = sample_and_compute_loss(train_dataset)
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())
        train_psnr_scores.append(train_psnr)
        progress.set_postfix(train_psnr=f"{train_psnr:.2f}")

        # Periodic validation on one full image
        if val_every and (step % val_every == 0 or step == iters):
            model.eval()
            with torch.no_grad():
                val_loss, curr_psnr, comp_rgb = sample_and_compute_loss(val_dataset, is_training=False)
            model.train()

            val_steps.append(step)
            val_losses.append(val_loss.item())
            psnr_scores.append(curr_psnr)
            tqdm.write(f"Iter {step}: validation PSNR {curr_psnr:.2f} dB")

            image = (comp_rgb.reshape(val_dataset.height, val_dataset.width, 3)
                     .clamp(0.0, 1.0).cpu().numpy())
            plt.imsave(f"iterations/iter{step}.jpg", image)

        # Periodic checkpoint
        if save_every and step % save_every == 0:
            torch.save(model.state_dict(), f"nerf_model_{i}.pt")

    # Final weights
    torch.save(model.state_dict(), "nerf_model.pt")

    # PSNR and loss curves
    train_steps = range(1, len(train_losses) + 1)
    _save_curve(train_steps, train_psnr_scores, "blue", "Iteration", "Training PSNR (dB)",
                "Training PSNR vs. Iteration", "train_psnr_nerf.png")
    _save_curve(val_steps, psnr_scores, "blue", "Iteration", "Validation PSNR (dB)",
                "Validation PSNR vs. Iteration", "val_psnr_nerf.png")
    _save_curve(train_steps, train_losses, "red", "Iteration", "Training Loss",
                "Training Loss vs. Iteration", "train_loss_nerf.png")
    _save_curve(val_steps, val_losses, "red", "Iteration", "Validation Loss",
                "Validation Loss vs. Iteration", "val_loss_nerf.png")


In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fix the RNG so weight initialisation and ray sampling are reproducible.
SEED = 0
torch.manual_seed(SEED)

batch_size = 10000
num_iterations = 2000
learning_rate = 5e-4

# Load dataset
images_train, c2ws_train, images_val, c2ws_val, c2ws_test, focal = load_data()


def to_device(tensor_data):
    """Convert array-like data to a float32 tensor on `device`."""
    return torch.tensor(tensor_data).float().to(device)


images_train = to_device(images_train)
c2ws_train = to_device(c2ws_train)
images_val = to_device(images_val)
c2ws_val = to_device(c2ws_val)
c2ws_test = to_device(c2ws_test)
focal = to_device(focal)


def create_intrinsic_matrices(focal, image_shape, num_matrices):
    """Return `num_matrices` copies of the pinhole intrinsics as an (N, 3, 3) tensor on `device`.

    `image_shape` is the shape of an (N, H, W, 3) image stack. The principal
    point is placed at the image centre (W / 2, H / 2).
    """
    height, width = image_shape[1], image_shape[2]
    K = intrinsic_matrix(focal.item(), focal.item(), width / 2, height / 2)
    # K is already a tensor; re-wrapping it with torch.tensor() only copies and warns.
    return K.unsqueeze(0).repeat(num_matrices, 1, 1).to(device)


K_train = create_intrinsic_matrices(focal, images_train.shape, images_train.shape[0])
K_val = create_intrinsic_matrices(focal, images_val.shape, images_val.shape[0])
K_test = create_intrinsic_matrices(focal, images_val.shape, c2ws_test.shape[0])

# Training
model = MLP().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

# Define datasets. Pass `device` explicitly so random sample indices are
# created on the same device as the data.
train_dataset = RaysData(images_train, K_train, c2ws_train, device=device)
val_dataset = RaysData(images_val, K_val, c2ws_val, device=device)

# Test cameras have no ground-truth images. Zero placeholders give RaysData one
# correctly sized "image" per test camera; rendering only uses the rays.
test_placeholder_images = torch.zeros((c2ws_test.shape[0], *images_val.shape[1:]), device=device)
test_dataset = RaysData(test_placeholder_images, K_test, c2ws_test, device=device)


In [10]:
# Set to False to skip training and load the final checkpoint instead.
train = True

if train:
    train_model(model, train_dataset, val_dataset, test_dataset, optimizer, criterion,
                iters=num_iterations, batch_size=batch_size, device=device)
else:
    checkpoint_path = "nerf_model.pt"
    # NOTE(review): torch.load unpickles the file; only load checkpoints you produced yourself.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))


In [11]:
def render(model, test_dataset, near=2.0, far=6.0, n_samples=64):
    """Render every camera in `test_dataset` to ``out/render_{i}.jpg``.

    Sampling is deterministic (no jitter), so renders are repeatable and free
    of sampling noise, matching how validation images are produced.

    Args:
        model: Network called as ``model(points, dirs) -> (rgb, sigmas)``.
        test_dataset: RaysData-like object with ``c2w``, ``height``, ``width``
            and ``sample_rays_single_img(i)``.
        near, far, n_samples: Ray-sampling range and count. They also define the
            volume-rendering step size ``(far - near) / n_samples``.
    """
    step_size = (far - near) / n_samples

    model.eval()

    with torch.no_grad():
        for i in range(test_dataset.c2w.shape[0]):
            rays_o, rays_d, _ = test_dataset.sample_rays_single_img(i)
            points = sample_along_rays(rays_o, rays_d, perturb=False,
                                       near=near, far=far, n_samples=n_samples)

            points = points.permute(1, 0, 2)  # -> (num_rays, n_samples, 3)
            dirs = rays_d.unsqueeze(1).expand(-1, points.shape[1], -1)

            rgb, sigmas = model(points, dirs)
            composite_rgb = volrend(sigmas, rgb, step_size=step_size)

            image = (composite_rgb.reshape(test_dataset.height, test_dataset.width, 3)
                     .clamp(0.0, 1.0).cpu().numpy())
            plt.imsave(f"out/render_{i}.jpg", image)


render(model, test_dataset)

KeyboardInterrupt: 