In [122]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
import json
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn.functional as F

In [104]:
class NeRFSynthetic(Dataset):
    def __init__(self, data_path: str, split="train", img_size=(800, 800), device="cpu"):
        """
        Args:
            data_path (str): Path to the dataset folder (e.g., "nerf_synthetic/chair").
            split (str): Dataset split to load ("train", "val", or "test").
            img_size (tuple): Target size of the images.
            device (str): Device to move tensors to ("mps", "cuda", or "cpu").
        """
        super().__init__()
        self.data_path = data_path
        self.split = split
        self.img_size = img_size
        self.device = device

        # Load JSON metadata for the given split
        json_file = os.path.join(data_path, f"transforms_{split}.json")
        with open(json_file, "r") as f:
            self.meta = json.load(f)

        # Extract file paths and transformation matrices
        self.image_paths = [os.path.join(data_path, frame["file_path"][2:] + ".png") for frame in self.meta["frames"]]
        self.transform_matrices = [np.array(frame["transform_matrix"], dtype=np.float32) for frame in self.meta["frames"]]

        # Camera intrinsics
        self.camera_angle_x = self.meta["camera_angle_x"]
        self.focal_length = (0.5 * img_size[0]) / np.tan(0.5 * self.camera_angle_x)  # Focal length in pixels

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Returns:
            img: Tensor of the image [3, H, W].
            transform_matrix: Camera-to-world transformation matrix [4, 4].
            focal_length: Focal length (scalar).
        """
        # Load and process the image
        img_path = self.image_paths[idx]
        img = Image.open(img_path).convert("RGB")
        img = img.resize(self.img_size, Image.LANCZOS)
        img = np.array(img).astype(np.float32) / 255.0  # Normalize to [0, 1]
        img = torch.tensor(img).permute(2, 0, 1).to(self.device)  # Convert to [C, H, W] and move to device

        # Get the transformation matrix for the current frame
        transform_matrix = torch.tensor(self.transform_matrices[idx], dtype=torch.float32).to(self.device)

        # Return the focal length, image, and camera-to-world matrix
        focal_length = torch.tensor(self.focal_length, dtype=torch.float32).to(self.device)
        return img, transform_matrix, focal_length

In [105]:
if torch.backends.mps.is_available():
    device = torch.device('mps')  # Use MPS for Apple Silicon
elif torch.cuda.is_available():
    device = torch.device('cuda')  # Use CUDA for NVIDIA GPUs
else:
    device = torch.device('cpu')  # Fallback to CPU

print(f"Using device: {device}")

Using device: mps


In [106]:
nerf_chair_train_set = NeRFSynthetic("/Users/rickypramanick/Desktop/nerf/nerf_synthetic/chair", device=device)
len(nerf_chair_train_set)

100

In [107]:
nerf_chair_train_set[0][0].shape

torch.Size([3, 800, 800])

In [108]:
nerf_chair_train_set[0][1].shape

torch.Size([4, 4])

In [109]:
# Each 4x4 matrix is a homogeneous transformation matrix that transforms a point in camera space to world space.
# Top Left 3x3 matrix is the rotation matrix and the rightmost column is the translation vector.
# The bottom row is always [0, 0, 0, 1] for 3D homogeneous transformations.
nerf_chair_train_set[0][1]

tensor([[-0.9250,  0.2749, -0.2623, -1.0572],
        [-0.3799, -0.6693,  0.6385,  2.5740],
        [ 0.0000,  0.6903,  0.7235,  2.9166],
        [ 0.0000,  0.0000,  0.0000,  1.0000]], device='mps:0')

In [110]:
nerf_chair_train_set.image_paths

['/Users/rickypramanick/Desktop/nerf/nerf_synthetic/chair/train/r_0.png',
 '/Users/rickypramanick/Desktop/nerf/nerf_synthetic/chair/train/r_1.png',
 '/Users/rickypramanick/Desktop/nerf/nerf_synthetic/chair/train/r_2.png',
 '/Users/rickypramanick/Desktop/nerf/nerf_synthetic/chair/train/r_3.png',
 '/Users/rickypramanick/Desktop/nerf/nerf_synthetic/chair/train/r_4.png',
 '/Users/rickypramanick/Desktop/nerf/nerf_synthetic/chair/train/r_5.png',
 '/Users/rickypramanick/Desktop/nerf/nerf_synthetic/chair/train/r_6.png',
 '/Users/rickypramanick/Desktop/nerf/nerf_synthetic/chair/train/r_7.png',
 '/Users/rickypramanick/Desktop/nerf/nerf_synthetic/chair/train/r_8.png',
 '/Users/rickypramanick/Desktop/nerf/nerf_synthetic/chair/train/r_9.png',
 '/Users/rickypramanick/Desktop/nerf/nerf_synthetic/chair/train/r_10.png',
 '/Users/rickypramanick/Desktop/nerf/nerf_synthetic/chair/train/r_11.png',
 '/Users/rickypramanick/Desktop/nerf/nerf_synthetic/chair/train/r_12.png',
 '/Users/rickypramanick/Desktop/ner

In [111]:
nerf_chair_train_set.transform_matrices

[array([[-0.925014  ,  0.27488998, -0.26226836, -1.0572376 ],
        [-0.37993318, -0.6692679 ,  0.63853836,  2.5740304 ],
        [ 0.        ,  0.6903013 ,  0.72352195,  2.9166102 ],
        [ 0.        ,  0.        ,  0.        ,  1.        ]],
       dtype=float32),
 array([[ 6.8581498e-01, -7.2490275e-01,  6.4605363e-02,  2.6043257e-01],
        [ 7.2777599e-01,  6.8310744e-01, -6.0880452e-02, -2.4541697e-01],
        [ 3.7252903e-09,  8.8770814e-02,  9.9605203e-01,  4.0152144e+00],
        [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  1.0000000e+00]],
       dtype=float32),
 array([[ 0.3260958 ,  0.14048941, -0.9348391 , -3.7684572 ],
        [-0.9453368 ,  0.0484621 , -0.32247463, -1.2999368 ],
        [ 0.        ,  0.98889536,  0.1486131 ,  0.59907854],
        [ 0.        ,  0.        ,  0.        ,  1.        ]],
       dtype=float32),
 array([[-0.90634435, -0.15694898,  0.39230973,  1.581451  ],
        [ 0.4225398 , -0.33665425,  0.84150106,  3.3921993 ],
        [ 0.

In [112]:
nerf_chair_train_set.camera_angle_x

0.6911112070083618

In [113]:
nerf_chair_train_set.focal_length

1111.1110311937682

In [114]:
dataloader = DataLoader(nerf_chair_train_set, batch_size=4, shuffle=True)

In [115]:
class NeRF(nn.Module):
    def __init__(self, in_features=60, hidden_dim=256, num_layers=8):
        """
        NeRF model with sinusoidal positional encoding.
        Args:
            in_features: Number of input features (e.g., 60 for positional encoding).
            hidden_dim: Number of hidden units in each layer.
            num_layers: Number of layers in the MLP.
        """
        super().__init__()
        self.layers = nn.ModuleList()

        # Input layer
        self.layers.append(nn.Linear(in_features, hidden_dim))

        # Hidden layers
        for _ in range(num_layers - 1):
            self.layers.append(nn.Linear(hidden_dim, hidden_dim))

        # Output layer: RGB (3) + density (1)
        self.rgb_sigma = nn.Linear(hidden_dim, 4)

    def forward(self, x):
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))
        return self.rgb_sigma(x)

In [116]:
def positional_encoding(x, num_freqs=10, device="cpu"):
    """
    Encode 3D coordinates with sinusoidal positional encoding.
    Args:
        x: Tensor of shape [N, 3] (3D spatial points).
        num_freqs: Number of frequency bands to encode.
        device: Device to perform the computation ("cpu", "cuda", or "mps").
    Returns:
        Encoded tensor of shape [N, 3 * 2 * num_freqs].
    """
    # Ensure input is on the correct device
    x = x.to(device)

    # Create frequency bands
    freq_bands = 2.0 ** torch.arange(num_freqs, dtype=torch.float32, device=device)

    # Apply sinusoidal encoding
    x = x.unsqueeze(-1) * freq_bands  # Shape: [N, 3, num_freqs]
    encoded = torch.cat([torch.sin(x), torch.cos(x)], dim=-1)  # Shape: [N, 3, 2 * num_freqs]

    return encoded.view(x.shape[0], -1)  # Flatten to [N, 3 * 2 * num_freqs]


In [117]:
def get_rays(H, W, focal, c2w, device="cpu"):
    """
    Generate rays for all pixels in the image.
    Args:
        H, W: Image height and width.
        focal: Focal length.
        c2w: Camera-to-world transformation matrix.
        device: Device to perform the computation ("cpu", "cuda", or "mps").
    Returns:
        rays_o: Ray origins (camera position) on the specified device.
        rays_d: Ray directions on the specified device.
    """
    # Ensure inputs are on the correct device
    focal = focal.to(device)
    c2w = c2w.to(device)

    # Create pixel grid
    i, j = torch.meshgrid(
        torch.arange(W, device=device),
        torch.arange(H, device=device),
        indexing="xy"
    )

    # Compute ray directions in camera space
    dirs = torch.stack(
        [(i - W / 2) / focal, -(j - H / 2) / focal, -torch.ones_like(i, device=device)],
        dim=-1
    )  # Shape: [H, W, 3]

    # Rotate ray directions to world space
    rays_d = (dirs @ c2w[:3, :3].T).reshape(-1, 3)  # Shape: [H * W, 3]

    # Camera origin (same for all rays)
    rays_o = c2w[:3, 3].expand(rays_d.shape)  # Shape: [H * W, 3]

    return rays_o, rays_d


In [118]:
def volume_rendering(rgb_sigma, z_vals, dirs, device="cpu"):
    """
    Perform volume rendering along a ray.
    Args:
        rgb_sigma: Output of NeRF [N, num_samples, 4].
        z_vals: Depth values for samples [N, num_samples].
        dirs: Ray directions [N, 3].
        device: Device to perform the computation ("cpu", "cuda", or "mps").
    Returns:
        pixel_color: Rendered pixel colors [N, 3].
    """
    # Ensure inputs are on the correct device
    rgb_sigma = rgb_sigma.to(device)
    z_vals = z_vals.to(device)
    dirs = dirs.to(device)

    # Compute distances between depth samples
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat([dists, torch.full_like(dists[..., :1], 1e10, device=device)], dim=-1)
    dists *= torch.norm(dirs[..., None, :], dim=-1)

    # Extract RGB and density (sigma)
    rgb = rgb_sigma[..., :3]  # Extract RGB
    sigma = rgb_sigma[..., 3]  # Extract density

    # Compute alpha values
    alpha = 1.0 - torch.exp(-sigma * dists)  # Alpha compositing

    # Compute weights
    weights = alpha * torch.cumprod(
        torch.cat([torch.ones_like(alpha[..., :1], device=device), 1.0 - alpha + 1e-10], dim=-1),
        dim=-1,
    )[..., :-1]

    # Compute pixel colors as a weighted sum
    pixel_color = torch.sum(weights[..., None] * rgb, dim=-2)  # Weighted sum

    return pixel_color



In [119]:
def compute_loss(rendered, ground_truth):
    return torch.mean((rendered - ground_truth) ** 2)


In [120]:
def train_nerf(model, dataset, num_epochs=10, lr=5e-4, device="cpu"):
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.train()

    for epoch in range(num_epochs):
        epoch_loss = 0
        for images, transform_matrices, focal_lengths in dataloader:
            # Move inputs to the MPS device
            images = images.to(device)
            transform_matrices = transform_matrices.to(device)
            focal_lengths = focal_lengths.to(device)

            H, W = images.shape[-2:]
            rays_o, rays_d = get_rays(H, W, focal_lengths[0], transform_matrices[0], device=device)
            rays_o = rays_o.to(device)
            rays_d = rays_d.to(device)

            # Sample points along rays
            t_vals = torch.linspace(0, 1, steps=64, device=device)
            z_vals = t_vals * (rays_d.norm(dim=-1).max())  # Near-far plane approximation
            points = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]

            # Encode points and predict RGB + density
            points_encoded = positional_encoding(points.view(-1, 3), device=device)
            rgb_sigma = model(points_encoded).view(-1, z_vals.shape[1], 4)

            # Volume render
            pixel_color = volume_rendering(rgb_sigma, z_vals, rays_d, device=device)

            # Compute loss
            ground_truth = images.view(-1, 3)
            loss = compute_loss(pixel_color, ground_truth)
            epoch_loss += loss.item()

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss / len(dataloader):.4f}")


In [124]:
# Dataset path and training
data_path = "/Users/rickypramanick/Desktop/nerf/nerf_synthetic/chair"
dataset = NeRFSynthetic(data_path, split="train", img_size=(100, 100), device=device)
model = NeRF().to(device)
train_nerf(model, dataset, num_epochs=10, device=device)

RuntimeError: MPS backend out of memory (MPS allocated: 8.51 GB, other allocations: 4.75 MB, max allowed: 9.07 GB). Tried to allocate 625.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).