In [4]:
import json
import os
from typing import Callable, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [6]:
class NeRFSynthetic(Dataset):
    """Map-style dataset over one NeRF-synthetic-style scene folder.

    Reads ``transforms_{split}.json`` from ``data_path``. Each entry of its
    ``"frames"`` list supplies an image path (``"file_path"``) and a 4x4 pose
    (``"transform_matrix"``). Samples are ``(img, pose)``, or
    ``(img, pose, focal_length)`` when the first frame carries an ``"intrinsics"`` entry.
    """

    # Extensions recognised as "already present" on a frame's file_path.
    _IMAGE_EXTS = (".png", ".jpg", ".jpeg")

    def __init__(self, data_path: str, split="train", img_size=(800, 800), transform: Optional[Callable] = None):
        """
        Args:
            data_path (str): Path to the dataset folder (e.g., "nerf_synthetic/chair").
            split (str): Dataset split to load ("train", "val", or "test"); selects
                ``transforms_{split}.json``.
            img_size (tuple): Target size given to ``transforms.Resize`` by the default
                pipeline. Not used when ``transform`` is supplied.
            transform (callable, optional): Applied to each RGB PIL image. Defaults to
                ``Resize(img_size)`` followed by ``ToTensor()``.

        Raises:
            FileNotFoundError: if ``transforms_{split}.json`` does not exist.
            KeyError: if the JSON lacks ``"frames"`` or a frame lacks
                ``"file_path"`` / ``"transform_matrix"``.
        """
        super().__init__()
        self.data_path = data_path
        self.transform = transform
        self.split = split
        self.img_size = img_size

        json_path = os.path.join(data_path, f"transforms_{split}.json")
        with open(json_path, "r") as f:
            self.meta = json.load(f)

        frames = self.meta["frames"]

        # Extract image file paths and corresponding poses
        self.image_paths = [self._resolve_image_path(frame["file_path"]) for frame in frames]
        self.poses = [np.array(frame["transform_matrix"], dtype=np.float32) for frame in frames]

        # Focal length is read from the first frame only (assumed shared by all images).
        # An empty split yields an empty dataset with no focal length instead of IndexError.
        first_frame = frames[0] if frames else {}
        self.focal_length = first_frame["intrinsics"][0] if "intrinsics" in first_frame else None

        # Top-level field-of-view entry, if the JSON has one; see `focal_from_fov`.
        self.camera_angle_x = self.meta.get("camera_angle_x")

        # Default image transform pipeline if not provided
        if self.transform is None:
            self.transform = transforms.Compose([
                transforms.Resize(self.img_size),
                transforms.ToTensor()
            ])

    def _resolve_image_path(self, file_path: str) -> str:
        """Join ``file_path`` onto ``data_path``, appending ".png" unless it already has an image extension."""
        path = os.path.join(self.data_path, file_path)
        if os.path.splitext(path)[1].lower() in self._IMAGE_EXTS:
            return path
        return path + ".png"

    def focal_from_fov(self, width):
        """Pinhole focal length in pixels for images ``width`` pixels wide.

        Computes ``0.5 * width / tan(0.5 * camera_angle_x)``. This treats the JSON's
        top-level ``"camera_angle_x"`` as the horizontal field of view in radians
        (the usual NeRF-synthetic convention; TODO confirm for other datasets).
        Pass the width of the *transformed* image.

        Returns:
            float, or None when the JSON has no ``"camera_angle_x"`` field.
        """
        if self.camera_angle_x is None:
            return None
        return float(0.5 * width / np.tan(0.5 * self.camera_angle_x))

    def __len__(self):
        """Number of frames listed in the split's JSON."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(img, pose)`` or ``(img, pose, focal_length)`` for frame ``idx``.

        ``img`` is ``self.transform`` applied to the RGB image (a (C, H, W) float tensor
        with the default pipeline). ``pose`` is a float32 tensor copy of the frame's
        ``transform_matrix``.
        """
        # Context manager releases the file handle deterministically. convert()
        # returns an in-memory copy, so it stays valid after the file is closed.
        # NOTE(review): convert("RGB") drops any alpha channel without compositing
        # onto a background; if the PNGs are RGBA, confirm this is intended.
        with Image.open(self.image_paths[idx]) as raw_img:
            rgb_img = raw_img.convert("RGB")
        img = self.transform(rgb_img)

        # Get the corresponding pose (torch.tensor copies, so callers cannot mutate self.poses)
        pose = torch.tensor(self.poses[idx], dtype=torch.float32)

        # Return focal length if available
        if self.focal_length is None:
            return img, pose
        focal_length = torch.tensor(self.focal_length, dtype=torch.float32)
        return img, pose, focal_length


In [23]:
# Root folder of the NeRF synthetic scenes. Set the NERF_DATA_DIR environment
# variable to avoid depending on one machine's absolute path.
NERF_DATA_DIR = os.environ.get("NERF_DATA_DIR", "/Users/rickypramanick/Desktop/nerf/nerf_synthetic")
nerf_chair_train_set = NeRFSynthetic(os.path.join(NERF_DATA_DIR, "chair"), split="train")
len(nerf_chair_train_set)

100

In [24]:
# Image tensor of the first sample: (channels, height, width)
nerf_chair_train_set[0][0].size()

torch.Size([3, 800, 800])

In [25]:
# Pose tensor of the first sample; expected to be a 4x4 matrix
nerf_chair_train_set[0][1].size()

torch.Size([4, 4])

In [26]:
# Display the pose (second element of the sample tuple) of the first frame
sample_0 = nerf_chair_train_set[0]
sample_0[1]

tensor([[-0.9250,  0.2749, -0.2623, -1.0572],
        [-0.3799, -0.6693,  0.6385,  2.5740],
        [ 0.0000,  0.6903,  0.7235,  2.9166],
        [ 0.0000,  0.0000,  0.0000,  1.0000]])