In [5]:
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import pickle

In [6]:
class NeRFmodel(nn.Module):
    """MLP that maps an encoded 3D position and view direction to colour and density.

    Both inputs are expanded with a sin/cos frequency encoding before being fed
    to the network. The position branch (``block1`` -> ``block2``) produces a
    feature vector plus one density value; the colour branch
    (``block3`` -> ``block4``) combines those features with the encoded view
    direction and outputs three values squashed to [0, 1] by a sigmoid.

    Args:
        embedding_dim_pos: number of frequency bands used to encode the position.
        embedding_dim_direction: number of frequency bands used to encode the
            view direction.
        hidden_dim: width of the hidden layers.
    """

    def __init__(self, embedding_dim_pos=10, embedding_dim_direction=4, hidden_dim=128):
        super().__init__()

        # An encoded 3-vector keeps its 3 raw components and adds a sin and a
        # cos term per component for each frequency band: 3 + 3 * 2 * L.
        pos_features = embedding_dim_pos * 6 + 3
        dir_features = embedding_dim_direction * 6 + 3

        self.block1 = nn.Sequential(
            nn.Linear(pos_features, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )

        # Density estimation. The input is block1's output concatenated with the
        # encoded position (skip connection). The last layer emits
        # hidden_dim + 1 values: hidden_dim features for the colour branch plus
        # one raw density. It has no activation here because forward() applies
        # the ReLU to the density column itself.
        self.block2 = nn.Sequential(
            nn.Linear(pos_features + hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim + 1),
        )

        # Colour estimation from the features and the encoded view direction.
        self.block3 = nn.Sequential(nn.Linear(dir_features + hidden_dim, hidden_dim // 2), nn.ReLU(), )
        self.block4 = nn.Sequential(nn.Linear(hidden_dim // 2, 3), nn.Sigmoid(), )

        self.embedding_dim_pos = embedding_dim_pos
        self.embedding_dim_direction = embedding_dim_direction
        self.relu = nn.ReLU()

    @staticmethod
    def position_encoding(x, L):
        """Frequency-encode ``x`` with ``L`` sin/cos bands.

        Args:
            x: tensor of shape [batch_size, D].
            L: number of frequency bands; band ``j`` uses the factor ``2**j``.

        Returns:
            Tensor of shape [batch_size, D * (2 * L + 1)] laid out as
            ``[x, sin(x), cos(x), sin(2x), cos(2x), ...]`` along dim 1.
        """
        out = [x]
        for jj in range(L):
            freq = 2 ** jj
            out.append(torch.sin(freq * x))
            out.append(torch.cos(freq * x))
        return torch.cat(out, dim=1)

    def forward(self, pos, direction):
        """Predict colour and density for a batch of samples.

        Args:
            pos: sample positions, shape [batch_size, 3].
            direction: view directions, shape [batch_size, 3].

        Returns:
            Tuple ``(c, sigma)`` where ``c`` has shape [batch_size, 3] with
            values in [0, 1] and ``sigma`` has shape [batch_size] and is
            non-negative.
        """
        emb_x = self.position_encoding(pos, self.embedding_dim_pos)  # [batch_size, embedding_dim_pos * 6 + 3]
        emb_d = self.position_encoding(direction, self.embedding_dim_direction)  # [batch_size, embedding_dim_direction * 6 + 3]

        h = self.block1(emb_x)  # [batch_size, hidden_dim]
        # Skip connection: re-inject the encoded position next to block1's output.
        features_and_density = self.block2(torch.cat((h, emb_x), dim=1))  # [batch_size, hidden_dim + 1]

        # The last column is the raw density (clamped to >= 0); the rest are features.
        h = features_and_density[:, :-1]  # [batch_size, hidden_dim]
        sigma = self.relu(features_and_density[:, -1])  # [batch_size]

        h = self.block3(torch.cat((h, emb_d), dim=1))  # [batch_size, hidden_dim // 2]
        c = self.block4(h)  # [batch_size, 3]
        return c, sigma


In [10]:
import os
import numpy as np
import json
import torch
import math
from torch.utils.data import Dataset
from skimage import io
import cv2
from matplotlib import pyplot as plt

# Pick the compute device once: CUDA when PyTorch reports it usable, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Seed NumPy's global RNG so that NumPy-driven randomness repeats across runs.
np.random.seed(0)

#########################################################
# Loading Dataset {Data from original nerf paper}
#########################################################
def loadDataset(data_path, mode):
    """
    Input:
        data_path: dataset path
        mode: train or test
    Outputs:
        camera_info: image width, height, camera matrix 
        images: images
        pose: corresponding camera pose in world frame
    """
    
    if mode in ["train", "val", "test"]:
        json_file = f"transforms_{mode}.json"
    else:
        raise ValueError("Mode must be 'train', 'val', or 'test'.")
    
    transforms_path = os.path.join(data_path, json_file)
    with open(transforms_path) as file:
        data = json.load(file)
        
    print(data)
