In [1]:
# Re-import edited modules (e.g. the project's src/ package) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys


# Project root = the notebook's working directory; it contains the `src` package imported below.
P_PATH = os.getcwd()
print(os.listdir(P_PATH))

# Make `src.*` importable. Guarded so re-running this cell does not keep appending duplicates.
if P_PATH not in sys.path:
    sys.path.append(P_PATH)

['results', 'tensorboard', 'src', 'temp_data', 'README.md', 'models', '.gitignore', 'wandb', 'exploration.ipynb', '.git', 'playground.ipynb', 'data', '.vscode', 'exploration_v2.ipynb']


In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import matplotlib.pyplot as plt
from src.data_loader import *
from src.utils import *
from src.model import *

# --- configuration ---
SCALEDOWN = 2
OBJ_NAME = 'chair'
BATCH_SIZE = 32
NUM_WORKERS = 4
ORIG_IMG_SIZE = 800   # presumably the full-resolution image side length (px) before downscaling
NUM_POINTS = 8        # samples per ray requested from the dataset
SEED = 0

# Seed torch's global RNG so DataLoader shuffling and the random index sampling in later
# cells are reproducible across Restart & Run All.
torch.manual_seed(SEED)

img_size = int(ORIG_IMG_SIZE / SCALEDOWN)

# train dataset
train_dataset = SynDatasetRay(obj_name=OBJ_NAME, root_dir=P_PATH, split="train", img_size=img_size, num_points=NUM_POINTS)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

# Statistic exposed by SynDatasetRay; its exact meaning is defined in src.data_loader.
min_max = train_dataset.min_max

print("train dataset size: ", len(train_dataset))


train dataset size:  16000000




In [4]:
# Print the shape of every field of a single dataset item.
# Each item is a dict of tensors with keys: rays_o, rays_d, points, z_vals, v_dir, rgb
# ('rgb' holds the ground-truth colour of the ray).
sample = train_dataset[0]

for key, value in sample.items():
    print(key, value.shape)


rays_o torch.Size([3])
rays_d torch.Size([3])
points torch.Size([8, 3])
z_vals torch.Size([8, 1])
v_dir torch.Size([2])
rgb torch.Size([3])


In [5]:
# Probe the positional encoding on one unshuffled batch of rays.
BATCH_SIZE = 1024
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
sample = next(iter(train_dataloader))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

points = sample['points']
v_dir = sample['v_dir']

# Number of encoding frequencies for positions / view directions.
L_P, L_V = 10, 4
points_encoded, v_dir_encoded = position_encoding(points, v_dir, L_p=L_P, L_v=L_V)
print(points_encoded.shape, v_dir_encoded.shape)

# Observed output: last dim == in_dim * 2 * L (3*2*10 = 60, 2*2*4 = 16). Fail loudly if the
# encoding in src changes shape unexpectedly.
assert points_encoded.shape[-1] == points.shape[-1] * 2 * L_P
assert v_dir_encoded.shape[-1] == v_dir.shape[-1] * 2 * L_V

torch.Size([1024, 8, 60]) torch.Size([1024, 16])


## Test the model forward pass and volume rendering

In [6]:
# Smoke test: run the default NeRF on one batch, volume-render its per-sample outputs and
# compare the rendered colours with the ground-truth ray colours.
model = NeRF().to(device)
model.eval()

BATCH_SIZE = 1024
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

sample = next(iter(train_dataloader))

for key in sample:
    print(key, sample[key].shape)
print()

points = sample['points'].to(device)
v_dir = sample['v_dir'].to(device)
# z_vals arrives as (batch, num_points, 1). Drop only the trailing axis: a bare .squeeze()
# would also remove the batch axis whenever a batch has a single ray.
z_vals = sample['z_vals'].to(device).squeeze(-1)
rgb_gt = sample['rgb'].to(device)

with torch.no_grad():
    # test model
    rgb, sigma = model(points, v_dir)
    print(rgb.shape, sigma.shape)
    print()

    # test volume rendering
    rendered_rgb = volume_rendering(z_vals, rgb, sigma)
    print(rendered_rgb.shape)
    print(rgb_gt.shape)

    # calculate loss
    loss = torch.nn.functional.mse_loss(rendered_rgb, rgb_gt)
print(loss)


rays_o torch.Size([1024, 3])
rays_d torch.Size([1024, 3])
points torch.Size([1024, 8, 3])
z_vals torch.Size([1024, 8, 1])
v_dir torch.Size([1024, 2])
rgb torch.Size([1024, 3])

torch.Size([1024, 8, 3]) torch.Size([1024, 8])

torch.Size([1024, 3])
torch.Size([1024, 3])
tensor(0.0064, device='cuda:0')


In [11]:
# Minimal hand-written training loop, kept as a reference for what NeRFTrainer (next section)
# automates. It only runs when RUN_MANUAL_LOOP is True; by default this cell just builds the
# model, optimizer and data loaders.
from torch.utils.data.sampler import SubsetRandomSampler

RUN_MANUAL_LOOP = False

model = NeRF(D=6, W=128, skips=[3]).to(device)
model.train()
BATCH_SIZE = 128
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Draw disjoint train/val subsets *without replacement*. Two independent torch.randint draws
# can repeat indices and place the same ray in both the train and the validation set.
# NOTE(review): 160000 == img_size**2 at SCALEDOWN=2, so this presumably restricts sampling to
# the first image's rays — confirm against SynDatasetRay's ray ordering.
subset_perm = torch.randperm(160000)
indice = subset_perm[:1000]
v_indice = subset_perm[1000:2000]

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=SubsetRandomSampler(indice), num_workers=NUM_WORKERS)
val_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=SubsetRandomSampler(v_indice), num_workers=NUM_WORKERS)


def _batch_loss(model, sample):
    """Forward one batch through `model`, volume-render it and return the MSE against sample['rgb']."""
    points = sample['points'].to(device)
    v_dir = sample['v_dir'].to(device)
    # drop only the trailing singleton axis so a final batch of size 1 keeps its batch dim
    z_vals = sample['z_vals'].to(device).squeeze(-1)
    rgb_gt = sample['rgb'].to(device)
    rgb, sigma = model(points, v_dir)
    rgb_pred = volume_rendering(z_vals, rgb, sigma)
    return torch.nn.functional.mse_loss(rgb_pred, rgb_gt)


def run_manual_training(model, optimizer, train_loader, val_loader, epochs=10, val_every=10):
    """Train `model` for `epochs` passes over `train_loader`.

    Every `val_every` iterations, prints the mean loss over `val_loader` and the loss of the
    current training batch.
    """
    for epoch in range(epochs):
        print("epoch: ", epoch)
        for i, sample in enumerate(train_loader):
            optimizer.zero_grad()
            train_loss = _batch_loss(model, sample)
            train_loss.backward()
            optimizer.step()

            if i % val_every == 0:
                with torch.no_grad():
                    val_loss_total = sum(_batch_loss(model, val_sample).item() for val_sample in val_loader)
                print("val loss: ", val_loss_total / len(val_loader))
                # Kept under its own name: the earlier version reused `loss` inside the validation
                # loop, so its "train loss" line actually printed the last validation batch's loss.
                print("train loss: ", train_loss.item())


if RUN_MANUAL_LOOP:
    run_manual_training(model, optimizer, train_dataloader, val_dataloader)
                    

    


epoch:  0


rgb shape:  torch.Size([128, 8, 3])
sigma shape:  torch.Size([128, 8])
z_vals shape:  torch.Size([128, 8])
rgb shape:  torch.Size([128, 8, 3])
sigma shape:  torch.Size([128, 8])
z_vals shape:  torch.Size([128, 8])
rgb shape:  torch.Size([128, 8, 3])
sigma shape:  torch.Size([128, 8])
z_vals shape:  torch.Size([128, 8])
rgb shape:  torch.Size([128, 8, 3])
sigma shape:  torch.Size([128, 8])
z_vals shape:  torch.Size([128, 8])
rgb shape:  torch.Size([128, 8, 3])
sigma shape:  torch.Size([128, 8])
z_vals shape:  torch.Size([128, 8])
rgb shape:  torch.Size([128, 8, 3])
sigma shape:  torch.Size([128, 8])
z_vals shape:  torch.Size([128, 8])
rgb shape:  torch.Size([128, 8, 3])
sigma shape:  torch.Size([128, 8])
z_vals shape:  torch.Size([128, 8])
rgb shape:  torch.Size([104, 8, 3])
sigma shape:  torch.Size([104, 8])
z_vals shape:  torch.Size([104, 8])
val loss:  0.2096224781125784
train loss:  0.20619143545627594
epoch:  1
rgb shape:  torch.Size([128, 8, 3])
sigma shape:  torch.Size([128, 8])


## Test the trainer and model

In [12]:
import wandb
from src.trainer import *
from torch.utils.data.sampler import SubsetRandomSampler



# --- model hyper-parameters (also logged to wandb) ---
D = 6
W = 128
input_ch_pos = 3
input_ch_dir = 2
L_p = 10
L_v = 4
skips = [3]

# --- optimisation / data settings ---
lr = 1e-3
BATCH_SIZE = 128
EPOCHS = 100
EARLY_STOPPING_PATIENCE = 10
N_TRAIN, N_VAL = 1000, 1000
# NOTE(review): 160000 == img_size**2 at SCALEDOWN=2, so subsets are presumably drawn from the
# first image's rays only — confirm against SynDatasetRay's ray ordering.
SUBSET_POOL = 160000

wandb.init(project="nerf",
           name="test",
           config={
                "D": D,
                "W": W,
                "input_ch_pos": input_ch_pos,
                "input_ch_dir": input_ch_dir,
                "L_p": L_p,
                "L_v": L_v,
                "skips": skips,
                "lr": lr,
                "BATCH_SIZE": BATCH_SIZE
              }
           )

# Everything after wandb.init runs under try/finally so the run is closed even if setup or
# training raises.
try:
    # init model
    model = NeRF(D=D, W=W, input_ch_pos=input_ch_pos, input_ch_dir=input_ch_dir,
                 L_p=L_p, L_v=L_v, skips=skips).to(device)
    wandb.watch(model, log="all")

    # init optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # NOTE(review): with 8 batches/epoch and at most EPOCHS=100 epochs there are <= 800 optimizer
    # steps, so step_size=1000 never decays the LR whether the trainer steps per batch or per
    # epoch — confirm intent.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.5)
    loss_fn = torch.nn.MSELoss()

    # Disjoint train/val subsets, sampled without replacement so no ray appears in both sets
    # (independent torch.randint draws can repeat indices and overlap).
    subset_perm = torch.randperm(SUBSET_POOL)
    train_i = subset_perm[:N_TRAIN]
    val_i = subset_perm[N_TRAIN:N_TRAIN + N_VAL]

    # A sampler already defines the iteration order, so no `shuffle` argument is passed.
    train_dataloader = DataLoader(train_dataset, sampler=SubsetRandomSampler(train_i), batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    val_dataloader = DataLoader(train_dataset, sampler=SubsetRandomSampler(val_i), batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

    # init trainer
    trainer = NeRFTrainer(model=model, optimizer=optimizer,
                          lr_scheduler=lr_scheduler, loss_fn=loss_fn,
                          train_loader=train_dataloader, val_loader=val_dataloader,
                          device=device, wandb_run=True)

    trainer.train(epochs=EPOCHS, log_interval=1, early_stopping_patience=EARLY_STOPPING_PATIENCE)
finally:
    wandb.finish()






Epoch 1/100 | Loss: 0.2133: : 8it [00:00, 34.72it/s]


Epoch 1/100 | Train Loss: 0.1881 | Val Loss: 0.1995


Epoch 2/100 | Loss: 0.0819: : 8it [00:00, 34.26it/s]


Epoch 2/100 | Train Loss: 0.1552 | Val Loss: 0.0737


Epoch 3/100 | Loss: 0.0490: : 8it [00:00, 35.14it/s]


Epoch 3/100 | Train Loss: 0.0677 | Val Loss: 0.0864


Epoch 4/100 | Loss: 0.0577: : 8it [00:00, 24.53it/s]


Epoch 4/100 | Train Loss: 0.0709 | Val Loss: 0.0863


Epoch 5/100 | Loss: 0.0273: : 8it [00:00, 36.05it/s]


Epoch 5/100 | Train Loss: 0.0702 | Val Loss: 0.0864


Epoch 6/100 | Loss: 0.0619: : 8it [00:00, 29.70it/s]


Epoch 6/100 | Train Loss: 0.0710 | Val Loss: 0.0858


Epoch 7/100 | Loss: 0.0902: : 8it [00:00, 31.70it/s]


Epoch 7/100 | Train Loss: 0.0716 | Val Loss: 0.0865


Epoch 8/100 | Loss: 0.1010: : 8it [00:00, 31.47it/s]


Epoch 8/100 | Train Loss: 0.0719 | Val Loss: 0.0861


Epoch 9/100 | Loss: 0.0839: : 8it [00:00, 30.17it/s]


Epoch 9/100 | Train Loss: 0.0715 | Val Loss: 0.0857


Epoch 10/100 | Loss: 0.0451: : 8it [00:00, 32.85it/s]


Epoch 10/100 | Train Loss: 0.0706 | Val Loss: 0.0859


Epoch 11/100 | Loss: 0.0840: : 8it [00:00, 32.54it/s]


Epoch 11/100 | Train Loss: 0.0715 | Val Loss: 0.0869


Epoch 12/100 | Loss: 0.0525: : 8it [00:00, 30.95it/s]


Early stopping at epoch 12




0,1
epoch,▁▂▂▃▄▄▅▅▆▇▇█
train_loss,█▆▁▁▁▁▁▁▁▁▁▁
val_loss,█▁▂▂▂▂▂▂▂▂▂▂

0,1
epoch,11.0
train_loss,0.07076
val_loss,0.08714
