In [1]:
import torch
from torch.utils.data import DataLoader
import numpy as np
import cv2 as cv
from ProgressNerf.Registries.DataloaderRegistry import get_dataloader
from ProgressNerf.Registries.RaypickerRegistry import get_raypicker
from ProgressNerf.Registries.RaysamplerRegistry import get_raysampler
from ProgressNerf.Registries.EncoderRegistry import get_encoder
from ProgressNerf.Registries.ModelRegistry import get_model
import ProgressNerf.Dataloading.ToolsPartsDataloader
import ProgressNerf.Raycasting.RandomRaypicker
import ProgressNerf.Raycasting.NearFarRaysampler
import ProgressNerf.Encoders.PositionalEncoder
import ProgressNerf.Raycasting.WeightedRaypicker
import ProgressNerf.Raycasting.PerturbedRaysampler
import ProgressNerf.Models.OGNerf
from ProgressNerf.Utils.CameraUtils import BuildCameraMatrix
import yaml

In [2]:
# Path to the sandbox experiment configuration, relative to the notebook's working directory.
configFile = "./configs/tests/sandboxTestConfig.yml"

# Parse the YAML config into `config` (FullLoader passed explicitly as the Loader).
with open(configFile, 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# Render resolution: first entry is used as the width, second as the height.
# The lookup key is spelled 'render_resoluion' (sic) and must match the YAML file exactly.
render_width = config['render_resoluion'][0]
render_height = config['render_resoluion'][1]

In [3]:
# Ray picker: the class is looked up by the name in config['raypicker'] and is
# constructed with the config section stored under that same name.
raypicker = get_raypicker(config['raypicker'])(config[config['raypicker']])

# Camera matrix from the fx/fy/tx/ty entries of the config
# (NOTE(review): tx/ty are presumably the principal point offsets; confirm in CameraUtils).
cam_matrix = BuildCameraMatrix(config['camera_fx'], config['camera_fy'], config['camera_tx'], config['camera_ty'])

# Reuse render_width / render_height from the config cell instead of re-indexing the
# misspelled 'render_resoluion' key, so the resolution is read from config in one place.
raypicker.setCameraParameters(cam_matrix, render_width, render_height)

# Ray sampler: same name-keyed registry pattern as the ray picker.
raysampler = get_raysampler(config['raysampler'])(config[config['raysampler']])

In [4]:
# Training data. The dataset class is chosen by train_dataloader['dataloader'] and built
# from the sub-section of the same name. Despite its name, `train_loader` is the
# dataset object that is wrapped by torch's DataLoader at the bottom of this cell.
train_loader_config = config['train_dataloader']
train_loader = get_dataloader(train_loader_config['dataloader'])(
    train_loader_config[train_loader_config['dataloader']]
)

# Training hyper-parameters and target device, all taken from config.
batch_size = config['batch_size']
device = config['device']
lr = config['optim_lr']

# Shuffled mini-batch iterator over the training dataset.
dataloader = DataLoader(
    train_loader,
    batch_size=batch_size,
    shuffle=True,
    num_workers=config['num_workers'],
)

In [5]:
# Coarse and fine networks, each built through the model registry: the class name is
# stored under 'nnModel' and the constructor config under the key with that class name.
coarse_config = config['coarse_model']
fine_config = config['fine_model']

# `nn` is the coarse model, not torch's module; this notebook spells that out as `torch.nn`.
nn = get_model(coarse_config['nnModel'])(
    coarse_config[coarse_config['nnModel']]
)
nn_fine = get_model(fine_config['nnModel'])(
    fine_config[fine_config['nnModel']]
)

In [6]:
# Put the models in training mode and move them to `device` *before* building the
# optimizer, as the PyTorch optim docs recommend. Previously only the fine model was
# moved, which left the coarse model on its construction device.
nn.train()
nn.to(device)
learning_params = list(nn.parameters())

# The fine model is treated as optional (None check kept from the original).
if nn_fine is not None:
    nn_fine.train()
    nn_fine.to(device)
    learning_params += list(nn_fine.parameters())

# A single Adam optimizer updates coarse and fine parameters jointly.
optimizer = torch.optim.Adam(learning_params, lr=lr)
# Mean-squared error criterion, named for use as the RGB loss.
loss_rgb = torch.nn.MSELoss()
