In [2]:
# Automatically reload imported modules (e.g. the `src.*` star-imports in a
# later cell) before each cell runs, so edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import os
import sys

# Project root = current working directory. The listing printed below is a
# quick sanity check that the notebook was launched from the repository root
# (where `src/`, `data/`, `train.py` live).
P_PATH = os.getcwd()
print(os.listdir(P_PATH))

# Make `src.*` importable. Guard the append so re-running this cell does not
# keep adding duplicate entries to sys.path.
if P_PATH not in sys.path:
    sys.path.append(P_PATH)

['train.py', 'results', 'src', 'README.md', 'models', '.gitignore', 'wandb', 'exploration.ipynb', '.git', 'playground.ipynb', 'data', '.vscode', 'exploration_V3.ipynb']


In [4]:
# Script for training the NeRF model.
import os
import sys
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
import wandb
from torch.utils.data import DataLoader
from tqdm import tqdm
from datetime import datetime
from src.utils import *
from src.data_loader import *
from src.model import *
from src.trainer import *


In [5]:

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNG for reproducibility.
seed = 42
torch.manual_seed(seed)

# Set hyperparameters
SCALEDOWN = 2            # downscale factor applied to the source images (see img_size)
OBJ_NAME = 'chair'       # object/scene passed to data_preprocess and SynDatasetRay
BATCH_SIZE = 2048 * 2
NUM_WORKERS = 8
SAMPLE = 32              # passed as `num_points` to SynDatasetRay (presumably samples per ray)
D = 6                    # passed as NeRF(D=...) — presumably MLP depth
W = 128                  # passed as NeRF(W=...) — presumably MLP width
input_ch_pos = 3
input_ch_dir = 2         # NOTE(review): 2 channels suggests angular view directions — confirm against src.model
L_p = 10                 # presumably positional-encoding frequencies for positions
L_v = 4                  # presumably positional-encoding frequencies for view directions
skips = [3]              # presumably layer indices that receive a skip connection
lr = 1e-3

# Native image side length before downscaling.
# NOTE(review): assumes the source renders are 800x800 — confirm for OBJ_NAME.
ORIG_IMG_SIZE = 800
img_size = int(ORIG_IMG_SIZE / SCALEDOWN)

# Set paths. os.getcwd() is already an absolute path, so no join is needed;
# guard the append so re-running this cell does not duplicate sys.path entries.
P_PATH = os.getcwd()
if P_PATH not in sys.path:
    sys.path.append(P_PATH)

# Project-level preprocessing for OBJ_NAME (provided by the src.* imports).
# NOTE(review): check whether it caches its output; if not, it re-runs every
# time this cell is executed.
data_preprocess(OBJ_NAME, P_PATH)



In [None]:
# Load data. The val split is built with the train split's `min_max`
# (presumably normalisation bounds) so both splits share the same statistics.
train_dataset = SynDatasetRay(obj_name=OBJ_NAME, root_dir=P_PATH, split='train', img_size=img_size, num_points=SAMPLE)
# NOTE(review): shuffle=False on the *training* loader — confirm the dataset
# already shuffles rays; otherwise consecutive batches are correlated.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

min_max = train_dataset.min_max

val_dataset = SynDatasetRay(obj_name=OBJ_NAME, root_dir=P_PATH, split='val', img_size=img_size, num_points=SAMPLE, min_max=min_max)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

model = NeRF(D=D, W=W, input_ch_pos=input_ch_pos, input_ch_dir=input_ch_dir, skips=skips, L_p=L_p, L_v=L_v).to(device)

loss_fn = nn.MSELoss(reduction='mean')

# Baseline validation loss of the freshly initialised (untrained) model.
# loss_fn returns the mean over all elements of a batch, so each batch is
# weighted by its element count; a plain mean of per-batch losses would
# over-weight a smaller final batch.
model.eval()
total_sq_err = 0.0
total_elems = 0
with torch.no_grad():
    for data in tqdm(val_dataloader, desc='val'):
        # Unpack the batch; non_blocking pairs with pin_memory=True above.
        points = data['points'].to(device, non_blocking=True)
        v_dir = data['v_dir'].to(device, non_blocking=True)
        target_rgb = data['rgb'].to(device, non_blocking=True)
        z_vals = data['z_vals'].to(device, non_blocking=True).squeeze(-1)

        # Forward pass through the model
        rgb, sigma = model(points, v_dir)

        # Composite per-sample outputs along each ray.
        # NOTE(review): confirm white_bkgd=False matches how the dataset
        # composites target RGB (synthetic scenes often use a white background).
        rendered_rgb = volume_rendering(z_vals, rgb, sigma, white_bkgd=False)

        loss = loss_fn(rendered_rgb, target_rgb)
        n_elems = target_rgb.numel()
        total_sq_err += loss.item() * n_elems
        total_elems += n_elems

if total_elems == 0:
    raise ValueError("validation dataloader yielded no batches")
average_loss = total_sq_err / total_elems
average_loss