In [None]:
# Mount Google Drive so the dataset and checkpoint paths used below are reachable.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/gdrive/'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/gdrive/; to attempt to forcibly remount, call drive.mount("/content/gdrive/", force_remount=True).


In [None]:
import os

# Relative paths used later in the notebook resolve against this project directory.
PROJECT_DIR = "/content/gdrive/MyDrive/Colab Notebooks/NeRF/"
os.chdir(PROJECT_DIR)

data.py

In [None]:
import imageio
import numpy as np
import torch
# import json
# from torchvision import transforms
import os
import cv2

def load_poses(pose_dir, idxs=[]):
    """Load selected 4x4 pose matrices from the text files in ``pose_dir``.

    Files are ordered by full path name and picked with ``idxs``. Each file must
    hold 16 numbers (it is reshaped to 4x4). Every matrix is right-multiplied by
    diag(1, -1, -1, 1), i.e. its y and z columns are negated (SRN convention).
    Presumably these are camera-to-world matrices as consumed by ``get_rays``.

    Returns a float32 tensor of shape (len(idxs), 4, 4).
    """
    flip_yz = np.diag(np.array([1, -1, -1, 1]))  # SRN dataset
    ordered = np.sort([os.path.join(pose_dir, entry.name) for entry in os.scandir(pose_dir)])
    chosen = ordered[idxs]
    matrices = [np.loadtxt(path).reshape(4, 4) @ flip_yz for path in chosen]
    return torch.from_numpy(np.array(matrices)).float()

def load_imgs(img_dir, idxs = []):
    """Load selected images from ``img_dir`` as float32 RGB values in [0, 1].

    Files are ordered by full path name and picked with ``idxs``. All selected
    images must share one shape, since they are stacked into a single tensor.
    """
    ordered = np.sort([os.path.join(img_dir, entry.name) for entry in os.scandir(img_dir)])

    def read_rgb(path):
        # uint8 RGB -> float32 in [0, 1]
        return imageio.imread(path, pilmode='RGB').astype(np.float32) / 255.

    stacked = np.array([read_rgb(path) for path in ordered[idxs]])
    return torch.from_numpy(stacked)

def load_intrinsic(intrinsic_path):
    """Read an SRN ``intrinsics.txt`` file.

    The focal length is the first token of the first line; the image height and
    width are the two tokens of the last line.

    Returns (focal: float, H: int, W: int).
    """
    with open(intrinsic_path, 'r') as fh:
        rows = fh.readlines()
    focal = float(rows[0].split()[0])
    height, width = (int(tok) for tok in rows[-1].split())
    return focal, height, width

class DTU():
    """DTU multi-view scans (``rs_dtu_4`` layout), served one scan at a time.

    ``<data_dir>/<splits>.lst`` lists one scan directory name per line. Each scan
    directory holds an ``image/`` folder and a ``cameras.npz`` file with
    ``world_mat_<i>`` and ``scale_mat_<i>`` entries for every view ``i``.
    """

    def __init__(self,splits='train',data_dir='/content/gdrive/MyDrive/Colab Notebooks/NeRF/Data/dtu_dataset/rs_dtu_4/DTU',num_instances_per_obj=1,
                 img_size=128, num_views=49):
        """
        :param splits: name of the ``.lst`` file to read; 'train' enables random view sampling
        :param data_dir: root directory containing the ``.lst`` files and the scan folders
        :param num_instances_per_obj: number of views sampled per scan in training mode
        :param img_size: side of the square images returned (centre crop + resize);
            ``None`` keeps the native resolution
        :param num_views: number of views per scan (indices 0 .. num_views - 1)
        """
        with open(os.path.join(data_dir, splits + '.lst')) as f:
            # strip() instead of dropping the last character: the final line may
            # lack a trailing newline. Blank lines are skipped.
            scan_dirs = [os.path.join(data_dir, line.strip()) for line in f if line.strip()]
        self.ids = np.sort(scan_dirs)
        self.lenids = len(self.ids)
        self.num_instances_per_obj = num_instances_per_obj
        self.train = splits == 'train'
        self.img_size = img_size
        self.num_views = num_views

    def __len__(self):
        return self.lenids

    @staticmethod
    def _square_resize(img, size):
        """Centre-crop ``img`` (H, W, C) to a square and resize it to ``size`` x ``size``.

        Returns the new image and the factor by which pixel distances (and hence
        the focal length) were scaled. ``size=None`` returns ``img`` unchanged with
        factor 1.
        """
        if size is None:
            return img, 1.0
        h, w = img.shape[:2]
        side = min(h, w)
        top, left = (h - side) // 2, (w - side) // 2
        square = np.ascontiguousarray(img[top:top + side, left:left + side])
        resized = cv2.resize(square, (size, size), interpolation=cv2.INTER_AREA)
        return resized, size / side

    def load_dtu(self, path, idxs=()):
        """Load views ``idxs`` of the scan stored at ``path``.

        Returns (focal, H, W, imgs, poses):
          focal -- mean over the loaded views of (fx + fy) / 2, rescaled to the output resolution;
          imgs  -- float32 tensor (len(idxs), H*W, 3) with values in [0, 1];
          poses -- float32 tensor (len(idxs), 4, 4), camera-to-world after the axis flips below.
        NOTE(review): get_rays assumes the principal point is at the image centre;
        the centre crop keeps that approximation, but cx/cy from K are not used.
        :raises ValueError: if ``idxs`` is empty.
        """
        idxs = list(idxs)
        if not idxs:
            raise ValueError("load_dtu needs at least one view index")
        img_dir = os.path.join(path, "image")
        params = np.load(os.path.join(path, "cameras.npz"))
        # Negate y and z on both sides of every pose (diag(1, -1, -1, 1)); on the camera
        # side this matches get_rays, whose rays look down -z with +y up.
        coord_flip = np.diag([1., -1., -1., 1.])
        all_img_files = sorted(os.path.join(img_dir, entry.name) for entry in os.scandir(img_dir))

        imgs, poses, focals = [], [], []
        H = W = None
        for i in idxs:
            img = imageio.imread(all_img_files[i], pilmode='RGB').astype(np.float32) / 255.
            # Resample (ndarray.resize would only truncate the raw pixel buffer) and
            # remember the scale so the focal length matches the new resolution.
            img, pix_scale = self._square_resize(img, self.img_size)
            H, W = img.shape[:2]
            imgs.append(img.reshape(-1, 3))

            world_mat = params["world_mat_" + str(i)][:3]
            scale_mat = params["scale_mat_" + str(i)]
            K, R, t = cv2.decomposeProjectionMatrix(world_mat)[:3]
            K = K / K[2, 2]
            pose = np.eye(4, dtype=np.float32)
            pose[:3, :3] = R.transpose()
            pose[:3, 3] = (t[:3] / t[3])[:, 0]  # homogeneous camera centre -> 3D
            # Express the camera centre in the frame defined by scale_mat
            # (subtract its translation, divide by its diagonal scale).
            pose[:3, 3:] -= scale_mat[:3, 3:]
            pose[:3, 3:] /= np.diagonal(scale_mat[:3, :3])[..., None]
            poses.append(coord_flip @ pose @ coord_flip)
            focals.append(pix_scale * (K[0, 0] + K[1, 1]) / 2)

        focal = float(np.mean(focals))
        im = torch.from_numpy(np.stack(imgs))
        return focal, H, W, im, torch.from_numpy(np.array(poses)).float()

    def __getitem__(self, idx):
        """Training: a random sample (with replacement) of views, returned with their
        indices. Otherwise: all ``num_views`` views."""
        obj_id = self.ids[idx]
        if self.train:
            instances = np.random.choice(self.num_views, self.num_instances_per_obj)
        else:
            instances = np.arange(self.num_views)
        focal, H, W, imgs, poses = self.load_dtu(obj_id, instances)
        if self.train:
            return focal, H, W, imgs, poses, instances, idx
        return focal, H, W, imgs, poses, idx


class SRN():
    """ShapeNet-SRN dataset (``srn_cars`` / ``srn_chairs``), served one object at a time.

    Objects live in ``<data_dir>/<cat>/<splits>/<obj_id>/``, each holding ``pose/``,
    ``rgb/`` and ``intrinsics.txt``.
    """

    def __init__(self, cat='srn_cars', splits='cars_train',
                 data_dir = '/content/gdrive/MyDrive/Colab Notebooks/NeRF/Data/',
                num_instances_per_obj = 1, crop_img = True):
        """
        cat: srn_cars / srn_chairs
        splits: cars_train(/test/val) or chairs_train(/test/val); the token after
            the first '_' decides whether training-mode sampling is used
        num_instances_per_obj: number of views sampled per object in training mode
        crop_img: in training mode, drop a 32-pixel border on every side and halve
            H and W (this equals the cropped size only for 128x128 inputs)
        """
        self.data_dir = os.path.join(data_dir, cat, splits)
        with os.scandir(self.data_dir) as entries:
            self.ids = np.sort([entry.name for entry in entries])
        self.lenids = len(self.ids)
        self.num_instances_per_obj = num_instances_per_obj
        self.train = splits.split('_')[1] == 'train'
        self.crop_img = crop_img

    def __len__(self):
        return self.lenids

    def __getitem__(self, idx):
        obj_id = self.ids[idx]
        if not self.train:
            focal, H, W, imgs, poses = self.return_test_val_data(obj_id)
            return focal, H, W, imgs, poses, idx
        focal, H, W, imgs, poses, instances = self.return_train_data(obj_id)
        return focal, H, W, imgs, poses, instances, idx

    def _object_paths(self, obj_id):
        """(pose dir, rgb dir, intrinsics file) of one object."""
        root = os.path.join(self.data_dir, obj_id)
        return (os.path.join(root, 'pose'),
                os.path.join(root, 'rgb'),
                os.path.join(root, 'intrinsics.txt'))

    def return_train_data(self, obj_id):
        """Sample ``num_instances_per_obj`` of views 0-49 (with replacement).

        Images are returned flattened to (num_instances_per_obj, H*W, 3).
        """
        pose_dir, img_dir, intrinsic_path = self._object_paths(obj_id)
        instances = np.random.choice(50, self.num_instances_per_obj)
        poses = load_poses(pose_dir, instances)
        imgs = load_imgs(img_dir, instances)
        focal, H, W = load_intrinsic(intrinsic_path)
        if self.crop_img:
            imgs = imgs[:, 32:-32, 32:-32, :]
            H, W = H // 2, W // 2
        flat_imgs = imgs.reshape(self.num_instances_per_obj, -1, 3)
        return focal, H, W, flat_imgs, poses, instances

    def return_test_val_data(self, obj_id):
        """Load views 0-249 of an object, uncropped, as (N, H, W, 3) images."""
        pose_dir, img_dir, intrinsic_path = self._object_paths(obj_id)
        views = np.arange(250)
        poses = load_poses(pose_dir, views)
        imgs = load_imgs(img_dir, views)
        focal, H, W = load_intrinsic(intrinsic_path)
        return focal, H, W, imgs, poses

model.py

In [None]:
import torch
import torch.nn as nn

def PE(x, degree):
    """Sinusoidal positional encoding.

    For input ``x`` with last dimension D, returns
    ``[x, sin(2^0 x), ..., sin(2^(degree-1) x), cos(2^0 x), ..., cos(2^(degree-1) x)]``
    concatenated on the last axis, i.e. last dimension D * (1 + 2 * degree).
    """
    scaled = torch.cat([x * 2. ** k for k in range(degree)], -1)
    return torch.cat([x, scaled.sin(), scaled.cos()], -1)


class CodeNeRF(nn.Module):
    def __init__(self, shape_blocks = 2, texture_blocks = 1, W = 256,
                 num_xyz_freq = 10, num_dir_freq = 4, latent_dim=256):
        super().__init__()
        self.shape_blocks = shape_blocks
        self.texture_blocks = texture_blocks
        self.num_xyz_freq = num_xyz_freq
        self.num_dir_freq = num_dir_freq

        d_xyz, d_viewdir = 3 + 6 * num_xyz_freq, 3 + 6 * num_dir_freq
        self.encoding_xyz = nn.Sequential(nn.Linear(d_xyz, W), nn.ReLU())
        for j in range(shape_blocks):
            layer = nn.Sequential(nn.Linear(latent_dim,W),nn.ReLU())
            setattr(self, f"shape_latent_layer_{j+1}", layer)
            layer = nn.Sequential(nn.Linear(W,W), nn.ReLU())
            setattr(self, f"shape_layer_{j+1}", layer)
        self.encoding_shape = nn.Linear(W,W)
        self.sigma = nn.Sequential(nn.Linear(W,1), nn.Softplus())
        self.encoding_viewdir = nn.Sequential(nn.Linear(W+d_viewdir, W), nn.ReLU())
        for j in range(texture_blocks):
            layer = nn.Sequential(nn.Linear(latent_dim, W), nn.ReLU())
            setattr(self, f"texture_latent_layer_{j+1}", layer)
            layer = nn.Sequential(nn.Linear(W,W), nn.ReLU())
            setattr(self, f"texture_layer_{j+1}", layer)
        self.rgb = nn.Sequential(nn.Linear(W, W//2), nn.ReLU(), nn.Linear(W//2, 3))

    def forward(self, xyz, viewdir, shape_latent, texture_latent):
        xyz = PE(xyz, self.num_xyz_freq)
        viewdir = PE(viewdir, self.num_dir_freq)
        y = self.encoding_xyz(xyz)
        for j in range(self.shape_blocks):
            z = getattr(self, f"shape_latent_layer_{j+1}")(shape_latent)
            y = y + z
            y = getattr(self, f"shape_layer_{j+1}")(y)
        y = self.encoding_shape(y)
        sigmas = self.sigma(y)
        y = torch.cat([y, viewdir], -1)
        y = self.encoding_viewdir(y)
        for j in range(self.texture_blocks):
            z = getattr(self, f"texture_latent_layer_{j+1}")(texture_latent)
            y = y + z
            y = getattr(self, f"texture_layer_{j+1}")(y)
        rgbs = self.rgb(y)
        return sigmas, rgbs

util.py

In [None]:
import imageio
import numpy as np
import torch
# import json
# from torchvision import transforms
import os


def get_rays(H, W, focal, c2w):
    """Generate one ray per pixel of an H x W pinhole camera.

    Pixel (row r, col c) gets the camera-space direction
    ((c - W/2) / focal, -(r - H/2) / focal, -1), i.e. the camera looks down -z with
    +y up. That direction is rotated by ``c2w[..., :3, :3]``, and the ray origin is
    ``c2w[..., :3, 3]``.

    Returns (rays_o, viewdirs), each flattened to (-1, 3) in row-major pixel order;
    viewdirs are unit length.
    """
    rows, cols = torch.meshgrid(torch.linspace(0, H - 1, H), torch.linspace(0, W - 1, W),
                                indexing='ij')
    x_cam = (cols - W * .5) / focal
    y_cam = -(rows - H * .5) / focal
    cam_dirs = torch.stack([x_cam, y_cam, -torch.ones_like(cols)], -1)
    rays_d = (cam_dirs[..., None, :].type_as(c2w) * c2w[..., :3, :3]).sum(-1)
    unit_dirs = rays_d / rays_d.norm(dim=-1, keepdim=True)
    origins = c2w[..., :3, -1].expand(rays_d.shape)
    return origins.reshape(-1, 3), unit_dirs.reshape(-1, 3)

def sample_from_rays(ro, vd, near, far, N_samples, z_fixed = False):
    """Sample ``N_samples`` depths between ``near`` and ``far``, shared by all rays.

    :param ro: ray origins, (R, 3)
    :param vd: ray directions, (R, 3)
    :param z_fixed: if True, use evenly spaced depths from ``near`` to ``far``.
        Otherwise start from N_samples evenly spaced bin centres and shift them all
        by one random offset per sample in [0, (far - near) / (2 * N_samples)).
    :returns: (xyz (R, N, 3), vd repeated to (R, N, 3), z_vals (N,))
    """
    # Only the number of rays and the device/dtype of ``ro`` matter here.
    if z_fixed:
        z_vals = torch.linspace(near, far, N_samples).type_as(ro)
    else:
        dist = (far - near) / (2*N_samples)
        z_vals = torch.linspace(near+dist, far-dist, N_samples).type_as(ro)
        # Create the jitter on z_vals' device/dtype: a plain torch.rand(N) lives on
        # the CPU and cannot be added to CUDA depths.
        jitter = torch.rand(N_samples, dtype=z_vals.dtype, device=z_vals.device)
        z_vals = z_vals + jitter * (far - near) / (2*N_samples)
    xyz = ro.unsqueeze(-2) + vd.unsqueeze(-2) * z_vals.unsqueeze(-1)
    vd = vd.unsqueeze(-2).repeat(1,N_samples,1)
    return xyz, vd, z_vals

def volume_rendering(sigmas, rgbs, z_vals, white_bg = True):
    """Alpha-composite per-sample densities and colours along each ray.

    :param sigmas: densities with a trailing singleton axis, e.g. (R, N, 1)
    :param rgbs: colours, (R, N, 3)
    :param z_vals: sample depths shared by all rays, (N,)
    :param white_bg: add the leftover transmittance (1 - sum of weights) as white
    :returns: (rgb (R, 3), depth (R,))
    """
    # Spacing between consecutive samples; the last sample gets an effectively infinite
    # interval (1e10).
    gaps = z_vals[1:] - z_vals[:-1]
    gaps = torch.cat([gaps, torch.ones_like(gaps[:1]) * 1e10])
    alpha = 1 - torch.exp(-sigmas.squeeze(-1) * gaps)
    survive = 1 - alpha + 1e-10  # epsilon keeps the cumulative product from reaching 0
    padded = torch.cat([torch.ones_like(survive[..., :1]), survive], -1)
    T = torch.cumprod(padded, -1)[..., :-1]  # transmittance before each sample
    w = alpha * T
    rgb = (w.unsqueeze(-1) * rgbs).sum(-2)
    depth = (w * z_vals).sum(-1)
    if white_bg:
        rgb = rgb + 1 - w.sum(1).unsqueeze(-1)
    return rgb, depth

def image_float_to_uint8(img):
    """
    Min-max normalise a float image to [0, 255] and convert it to uint8.

    This is a contrast stretch, not a plain ``* 255``: the image's own minimum maps
    to 0 and its maximum to 255 (values are truncated by the uint8 cast). If the
    value range is below 1e-10, 1e-10 is added to the denominator, so a constant
    image maps to all zeros. The input array is not modified.
    """
    vmin = np.min(img)
    vmax = np.max(img)
    span = vmax - vmin
    if span < 1e-10:
        # Pad the span rather than vmax: for a float32 image under NumPy 2 (NEP 50),
        # vmax + 1e-10 stays float32 and rounds back to vmax, giving 0/0 = NaN.
        span = span + 1e-10
    img = (img - vmin) / span
    img *= 255.0
    return img.astype(np.uint8)


#def str2bool(v):
#    if isinstance(v, bool):
#        return v
#    if v.lower() in ('yes', 'true'):
#        return True
#    elif v.lower() in ('no', 'false'):
#        return False
#    else:
#        raise argparse.ArgumentTypeError('Boolean value expected.')

optimizer.py

In [None]:
import numpy as np
import torch
import torch.nn as nn
import json
from skimage.metrics import structural_similarity as compute_ssim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import os
import imageio
import time


class Optimizer():
    """Test-time optimisation of CodeNeRF latent codes on held-out objects.

    For every object served by ``self.dataloader``:
    - the shape and texture codes start from the mean of the trained embeddings;
    - they are optimised with AdamW against the views listed in ``instance_ids``;
    - they are then evaluated (PSNR / SSIM) on every other view of that object.

    Results are written under ``self.save_dir`` and to TensorBoard.
    """

    def __init__(self, saved_dir, gpu, instance_ids=(), splits='test',
                 jsonfile = 'srncar.json', batch_size=2048, num_opts = 200):
        """
        :param saved_dir: directory of the pre-trained model (joined under ``exps/``)
        :param gpu: index of the CUDA device to use
        :param instance_ids: view indices used for test-time optimisation (ex: 82 for 000082.png)
        :param splits: dataset split to evaluate, e.g. test or val
        :param jsonfile: where the hyper-parameters are saved (joined under ``jsonfiles/``)
        :param batch_size: number of rays rendered per forward pass
        :param num_opts: number of test-time optimisation steps per object
        """
        super().__init__()
        # Read hyper-parameters
        hpampath = os.path.join('jsonfiles', jsonfile)
        with open(hpampath, 'r') as f:
            self.hpams = json.load(f)
        self.device = torch.device('cuda:' + str(gpu))
        self.make_model()
        self.load_model_codes(saved_dir)
        self.make_dataloader(splits, len(instance_ids))
        print('we are going to save at ', self.save_dir)
        self.B = batch_size
        self.num_opts = num_opts
        self.splits = splits
        self.nviews = str(len(instance_ids))
        self.psnr_eval = {}
        self.psnr_opt = {}
        self.ssim_eval = {}

    def _render_view(self, H, W, focal, pose, tgt_img, shapecode, texturecode, backprop):
        """Render one view in chunks of ``self.B`` rays and compare it with ``tgt_img``.

        :param tgt_img: flattened ground truth, (H*W, 3)
        :param backprop: if True, back-propagate each chunk's L2 loss immediately,
            with the latent-norm regulariser added on the first chunk.
        :returns: (per-chunk L2 losses, rendered (H*W, 3) image detached,
                   regulariser value or None when ``backprop`` is False)
        """
        rays_o, viewdir = get_rays(H, W, focal, pose)
        xyz, viewdir, z_vals = sample_from_rays(rays_o, viewdir, self.hpams['near'], self.hpams['far'],
                                                self.hpams['N_samples'])
        z_vals = z_vals.to(self.device)
        chunk_losses, chunks, reg_loss = [], [], None
        for i in range(0, xyz.shape[0], self.B):
            sigmas, rgbs = self.model(xyz[i:i+self.B].to(self.device),
                                      viewdir[i:i+self.B].to(self.device),
                                      shapecode, texturecode)
            rgb_rays, _ = volume_rendering(sigmas, rgbs, z_vals)
            loss_l2 = torch.mean((rgb_rays - tgt_img[i:i+self.B].type_as(rgb_rays)) ** 2)
            if backprop:
                loss = loss_l2
                if i == 0:
                    reg_loss = torch.norm(shapecode, dim=-1) + torch.norm(texturecode, dim=-1)
                    loss = loss + self.hpams['loss_reg_coef'] * torch.mean(reg_loss)
                loss.backward()
            chunk_losses.append(loss_l2.item())
            chunks.append(rgb_rays.detach())
        return chunk_losses, torch.cat(chunks), reg_loss

    def optimize_objs(self, instance_ids, lr=1e-2, lr_half_interval=50, save_img = True):
        """Optimise, evaluate and save the latent codes of every object.

        :param instance_ids: view indices used as optimisation targets
        :param lr: initial learning rate, halved every ``lr_half_interval`` steps
            (the AdamW optimiser is re-created at each halving)
        :param lr_half_interval: steps between learning-rate halvings
        :param save_img: write rendered / ground-truth comparison images to disk
        :raises ValueError: if ``instance_ids`` is empty
        """
        opt_views = [int(v) for v in instance_ids]
        if not opt_views:
            raise ValueError("optimize_objs needs at least one view in instance_ids")
        logpath = os.path.join(self.save_dir, 'opt_hpams.json')
        hpam = {'instance_ids': opt_views, 'lr': lr, 'lr_half_interval': lr_half_interval,
                'splits': self.splits}
        with open(logpath, 'w') as f:
            json.dump(hpam, f, indent=2)

        self.lr, self.lr_half_interval = lr, lr_half_interval
        n_objs = len(self.dataloader)
        self.optimized_shapecodes = torch.zeros(n_objs, self.mean_shape.shape[1])
        self.optimized_texturecodes = torch.zeros(n_objs, self.mean_texture.shape[1])
        # Per object
        for num_obj, d in enumerate(self.dataloader):
            focal, H, W, imgs, poses, obj_idx = d
            H, W = H.item(), W.item()
            tgt_imgs, tgt_poses = imgs[0, opt_views], poses[0, opt_views]
            self.nopts = 0
            shapecode = self.mean_shape.to(self.device).clone().detach().requires_grad_()
            texturecode = self.mean_texture.to(self.device).clone().detach().requires_grad_()

            # First optimise the codes against the target views.
            self.set_optimizers(shapecode, texturecode)
            while self.nopts < self.num_opts:
                self.opts.zero_grad()
                t1 = time.time()
                generated_imgs, gt_imgs, losses = [], [], []
                for num in range(len(opt_views)):
                    tgt_img = tgt_imgs[num].reshape(-1, 3)
                    view_losses, rendered, reg_loss = self._render_view(
                        H, W, focal, tgt_poses[num], tgt_img, shapecode, texturecode, backprop=True)
                    losses.extend(view_losses)
                    generated_imgs.append(rendered.reshape(H, W, 3))
                    gt_imgs.append(tgt_img.reshape(H, W, 3))
                self.opts.step()
                # PSNR over all target views, not just the last one.
                self.log_opt_psnr_time(np.mean(losses), time.time() - t1,
                                       self.nopts + self.num_opts * num_obj, num_obj)
                self.log_regloss(reg_loss.item(), self.nopts, num_obj)
                # Use the argument, not ``self.save_img`` (a bound method, hence always truthy).
                if save_img:
                    self.save_img(generated_imgs, gt_imgs, self.ids[num_obj], self.nopts)
                self.nopts += 1
                if self.nopts % lr_half_interval == 0:
                    self.set_optimizers(shapecode, texturecode)

            # Then evaluate on every view that was not used for optimisation.
            self.psnr_eval[num_obj], self.ssim_eval[num_obj] = [], []
            opt_view_set = set(opt_views)
            with torch.no_grad():
                for num in range(imgs.shape[1]):
                    if num in opt_view_set:
                        continue
                    tgt_img = imgs[0, num].reshape(-1, 3)
                    view_losses, rendered, _ = self._render_view(
                        H, W, focal, poses[0, num], tgt_img, shapecode, texturecode, backprop=False)
                    rendered, gt = rendered.reshape(H, W, 3), tgt_img.reshape(H, W, 3)
                    self.log_eval_psnr(np.mean(view_losses), num, num_obj)
                    self.log_compute_ssim(rendered, gt, num, num_obj)
                    if save_img:
                        self.save_img([rendered], [gt], self.ids[num_obj], num, opt=False)

            # Save the optimized codes
            self.optimized_shapecodes[num_obj] = shapecode.detach().cpu()
            self.optimized_texturecodes[num_obj] = texturecode.detach().cpu()
            self.save_opts(num_obj)

    def save_opts(self, num_obj):
        """Write all optimised codes and metrics gathered so far to ``codes.pth``."""
        saved_dict = {
            'ids': self.ids,
            'num_obj' : num_obj,
            'optimized_shapecodes' : self.optimized_shapecodes,
            'optimized_texturecodes': self.optimized_texturecodes,
            'psnr_eval': self.psnr_eval,
            'ssim_eval': self.ssim_eval
        }
        torch.save(saved_dict, os.path.join(self.save_dir, 'codes.pth'))
        print('We finished the optimization of ' + str(num_obj))

    @staticmethod
    def _obj_dir_name(obj_id):
        """Folder name for an object's images: the last path component of ``obj_id``.

        DTU ids are full scan paths; joining one directly onto ``save_dir`` would
        discard ``save_dir`` (``os.path.join`` restarts at an absolute component).
        """
        return os.path.basename(os.path.normpath(str(obj_id)))

    def save_img(self, generated_imgs, gt_imgs, obj_id, instance_num, opt=True):
        """Save rendered (left) and ground-truth (right) images, views stacked vertically."""
        H, W = gt_imgs[0].shape[:2]
        nviews = int(self.nviews) if opt else 1
        generated = torch.cat([g.detach().cpu() for g in generated_imgs]).reshape(nviews * H, W, 3)
        gt = torch.cat([g.detach().cpu() for g in gt_imgs]).reshape(nviews * H, W, 3)
        ret = torch.cat([generated.float(), gt.float()], dim=1)
        ret = image_float_to_uint8(ret.numpy())
        save_img_dir = os.path.join(self.save_dir, self._obj_dir_name(obj_id))
        os.makedirs(save_img_dir, exist_ok=True)
        if opt:
            fname = 'opt' + self.nviews + '_' + str(instance_num) + '.png'
        else:
            fname = str(instance_num) + '_' + self.nviews + '.png'
        imageio.imwrite(os.path.join(save_img_dir, fname), ret)

    def log_compute_ssim(self, generated_img, gt_img, niters, obj_idx):
        """Compute SSIM of one evaluation view and record it under ``obj_idx``."""
        generated_img_np = generated_img.detach().cpu().numpy()
        gt_img_np = gt_img.detach().cpu().numpy()
        # channel_axis replaces the removed ``multichannel`` flag (skimage >= 0.19).
        # data_range=2.0 is skimage's former implicit default for float images
        # (dtype range [-1, 1]). Passing it explicitly keeps scores comparable with
        # earlier runs; recent releases require it for float input.
        ssim = compute_ssim(generated_img_np, gt_img_np, channel_axis=-1, data_range=2.0)
        if niters == 0 or obj_idx not in self.ssim_eval:
            self.ssim_eval[obj_idx] = [ssim]
        else:
            self.ssim_eval[obj_idx].append(ssim)

    def log_eval_psnr(self, loss_per_img, niters, obj_idx):
        """Convert an MSE to PSNR and record it under ``obj_idx``.

        The list is (re)started at view 0, or when ``obj_idx`` has no entry yet
        (view 0 may be an optimisation view and never be evaluated).
        """
        psnr = -10 * np.log(loss_per_img) / np.log(10)
        print(f"obj {obj_idx} view {niters}: PSNR {psnr:.3f}")
        if niters == 0 or obj_idx not in self.psnr_eval:
            self.psnr_eval[obj_idx] = [psnr]
        else:
            self.psnr_eval[obj_idx].append(psnr)

    def log_opt_psnr_time(self, loss_per_img, time_spent, niters, obj_idx):
        """Log optimisation PSNR and step time at global step ``niters``.

        ``obj_idx`` is accepted for interface compatibility; it is not passed to
        ``add_scalar`` because that 4th positional argument is ``walltime``.
        """
        psnr = -10*np.log(loss_per_img) / np.log(10)
        self.writer.add_scalar('psnr_opt/' + self.nviews + '/' + self.splits, psnr, niters)
        self.writer.add_scalar('time_opt/' + self.nviews + '/' + self.splits, time_spent, niters)

    def log_regloss(self, loss_reg, niters, obj_idx):
        """Log the latent-norm regulariser (``obj_idx`` is not used as walltime)."""
        self.writer.add_scalar('reg/'  + self.nviews + '/' + self.splits, loss_reg, niters)

    def set_optimizers(self, shapecode, texturecode):
        """(Re)create AdamW over both codes with the current step's learning rate."""
        lr = self.get_learning_rate()
        self.opts = torch.optim.AdamW([
            {'params': shapecode, 'lr': lr},
            {'params': texturecode, 'lr':lr}
        ])

    def get_learning_rate(self):
        """``self.lr`` halved once per completed ``self.lr_half_interval`` steps."""
        opt_values = self.nopts // self.lr_half_interval
        lr = self.lr * 2**(-opt_values)
        return lr

    def make_model(self):
        self.model = CodeNeRF(**self.hpams['net_hyperparams']).to(self.device)

    def load_model_codes(self, saved_dir):
        """Load the checkpoint and initialise mean shape/texture codes from its embeddings."""
        saved_path = os.path.join('exps', saved_dir, 'models.pth')
        print("Load Model: ",saved_path)
        saved_data = torch.load(saved_path, map_location = torch.device('cpu'))
        self.make_save_img_dir(os.path.join('exps', saved_dir, 'test'))
        self.make_writer(saved_dir)
        self.model.load_state_dict(saved_data['model_params'])
        self.model = self.model.to(self.device)
        self.mean_shape = torch.mean(saved_data['shape_code_params']['weight'], dim=0).reshape(1,-1)
        self.mean_texture = torch.mean(saved_data['texture_code_params']['weight'], dim=0).reshape(1,-1)

    def make_writer(self, saved_dir):
        self.writer = SummaryWriter(os.path.join('exps', saved_dir, 'test', 'runs'))

    def make_save_img_dir(self, save_dir):
        """Create ``save_dir``, or ``save_dir_2``, ``save_dir_3``, ... if it already exists."""
        save_dir_tmp = save_dir
        num = 2
        while os.path.isdir(save_dir_tmp):
            save_dir_tmp = save_dir + '_' + str(num)
            num += 1
        os.makedirs(save_dir_tmp)
        self.save_dir = save_dir_tmp

    def make_dataloader(self, splits, num_instances_per_obj, crop_img=False):
        """Build the evaluation dataloader (currently DTU) for ``splits``.

        ``num_instances_per_obj`` and ``crop_img`` are kept for interface parity
        with the SRN variant; they are not forwarded to DTU.
        """
        dtu = DTU(splits=splits)
        self.ids = dtu.ids
        self.dataloader = DataLoader(dtu, batch_size=1, num_workers =4, shuffle = False)

Trainer.py

In [None]:
import numpy as np
import torch
import torch.nn as nn
import json
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import os
import math
import time


class Trainer():
    """Auto-decoder training of CodeNeRF.

    Jointly optimises the network weights and one shape plus one texture latent
    code per training object (``nn.Embedding`` rows indexed by object id).
    """

    def __init__(self, save_dir, gpu, jsonfile = 'srncar.json', batch_size=2048,
                 check_iter = 10000):
        """
        :param save_dir: experiment folder (joined under ``exps/``)
        :param gpu: index of the CUDA device to use
        :param jsonfile: hyper-parameter file (joined under ``jsonfiles/``)
        :param batch_size: number of rays rendered per forward pass
        :param check_iter: log a rendered image every ``check_iter`` iterations
        """
        super().__init__()
        # Read Hyperparameters
        hpampath = os.path.join('jsonfiles', jsonfile)
        with open(hpampath, 'r') as f:
            self.hpams = json.load(f)
        self.device = torch.device('cuda:' + str(gpu))
        self.make_model()
        self.make_dataloader(num_instances_per_obj = 1, crop_img = False)
        self.make_codes()
        self.B = batch_size
        self.make_savedir(save_dir)
        self.niter, self.nepoch = 0, 0
        self.check_iter = check_iter

    def training(self, iters_crop, iters_all, num_instances_per_obj=1):
        """Train with cropped images until ``iters_crop``, then full images until ``iters_all``."""
        if iters_crop > iters_all:
            raise Exception("Iters_crop can't be larger than Iters_all.")
        while self.niter < iters_all:
            if self.niter < iters_crop:
                self.training_single_epoch(num_instances_per_obj = num_instances_per_obj,
                                           num_iters = iters_crop, crop_img = True)
            else:
                self.training_single_epoch(num_instances_per_obj=num_instances_per_obj,
                                           num_iters=iters_all, crop_img = False)
            self.save_models()
            self.nepoch += 1

    def training_single_epoch(self, num_instances_per_obj, num_iters, crop_img = True):
        """One pass over the training objects, stopping early once ``self.niter`` reaches ``num_iters``.

        One optimiser step is taken per object. The losses of its
        ``num_instances_per_obj`` sampled views are accumulated into the same
        gradients before ``step()``.
        """
        self.make_dataloader(num_instances_per_obj, crop_img = crop_img)
        self.set_optimizers()
        for d in self.dataloader:
            if self.niter >= num_iters:
                # Nothing left to do in this phase: stop instead of loading the
                # remaining objects only to skip them.
                break
            focal, H, W, imgs, poses, instances, obj_idx = d
            H, W = H.item(), W.item()
            obj_idx = obj_idx.to(self.device)
            t1 = time.time()
            # Zero once per object so the gradients of every sampled view are summed.
            self.opts.zero_grad()
            loss_per_img, generated_img, reg_loss = [], [], None
            for k in range(num_instances_per_obj):
                rays_o, viewdir = get_rays(H, W, focal, poses[0, k])
                xyz, viewdir, z_vals = sample_from_rays(rays_o, viewdir, self.hpams['near'], self.hpams['far'],
                                                        self.hpams['N_samples'])
                z_vals = z_vals.to(self.device)
                generated_img = []  # keep only the last view for image logging
                for i in range(0, xyz.shape[0], self.B):
                    shape_code, texture_code = self.shape_codes(obj_idx), self.texture_codes(obj_idx)
                    sigmas, rgbs = self.model(xyz[i:i+self.B].to(self.device),
                                              viewdir[i:i+self.B].to(self.device),
                                              shape_code, texture_code)
                    rgb_rays, _ = volume_rendering(sigmas, rgbs, z_vals)
                    loss_l2 = torch.mean((rgb_rays - imgs[0, k, i:i+self.B].type_as(rgb_rays))**2)
                    if i == 0:
                        reg_loss = torch.norm(shape_code, dim=-1) + torch.norm(texture_code, dim=-1)
                        loss_reg = self.hpams['loss_reg_coef'] * torch.mean(reg_loss)
                        loss = loss_l2 + loss_reg
                    else:
                        loss = loss_l2
                    loss.backward()
                    loss_per_img.append(loss_l2.item())
                    generated_img.append(rgb_rays.detach())
            self.opts.step()
            self.log_psnr_time(np.mean(loss_per_img), time.time() - t1, obj_idx)
            self.log_regloss(reg_loss.item(), obj_idx)
            if self.niter % self.check_iter == 0:
                rendered = torch.cat(generated_img).reshape(H, W, 3)
                gtimg = imgs[0, -1].reshape(H, W, 3)
                self.log_img(rendered, gtimg, obj_idx)
                print(-10*np.log(np.mean(loss_per_img))/np.log(10), self.niter)
            if self.niter % self.hpams['check_points'] == 0:
                self.save_models(self.niter)
            self.niter += 1

    def log_psnr_time(self, loss_per_img, time_spent, obj_idx):
        """Log training PSNR and step time at ``self.niter``.

        ``obj_idx`` is not passed on: ``add_scalar``'s 4th positional argument is
        ``walltime``.
        """
        psnr = -10*np.log(loss_per_img) / np.log(10)
        self.writer.add_scalar('psnr/train', psnr, self.niter)
        self.writer.add_scalar('time/train', time_spent, self.niter)

    def log_regloss(self, loss_reg, obj_idx):
        """Log the latent-norm regulariser at ``self.niter``."""
        self.writer.add_scalar('reg/train', loss_reg, self.niter)

    def log_img(self, generated_img, gtimg, obj_idx):
        """Log rendered (left) and ground-truth (right) images side by side."""
        H, W = generated_img.shape[:-1]
        ret = torch.zeros(H,2*W, 3)
        ret[:,:W,:] = generated_img.detach().cpu()
        ret[:,W:,:] = gtimg.detach().cpu()
        ret = image_float_to_uint8(ret.numpy())
        self.writer.add_image('train_'+str(self.niter) + '_' + str(obj_idx.item()), torch.from_numpy(ret).permute(2,0,1))

    def set_optimizers(self):
        """(Re)create AdamW for the network and both code tables with the scheduled learning rates."""
        lr1, lr2 = self.get_learning_rate()
        self.opts = torch.optim.AdamW([
            {'params':self.model.parameters(), 'lr': lr1},
            {'params':self.shape_codes.parameters(), 'lr': lr2},
            {'params':self.texture_codes.parameters(), 'lr':lr2}
        ])

    def get_learning_rate(self):
        """Model and latent learning rates, each halved once per completed schedule interval."""
        model_lr, latent_lr = self.hpams['lr_schedule'][0], self.hpams['lr_schedule'][1]
        num_model = self.niter // model_lr['interval']
        num_latent = self.niter // latent_lr['interval']
        lr1 = model_lr['lr'] * 2**(-num_model)
        lr2 = latent_lr['lr'] * 2**(-num_latent)
        return lr1, lr2

    def make_model(self):
        self.model = CodeNeRF(**self.hpams['net_hyperparams']).to(self.device)

    def make_codes(self):
        """One shape and one texture code per training object, drawn as randn / sqrt(latent_dim / 2)."""
        embdim = self.hpams['net_hyperparams']['latent_dim']
        d = len(self.dataloader)
        self.shape_codes = nn.Embedding(d, embdim)
        self.texture_codes = nn.Embedding(d, embdim)
        self.shape_codes.weight = nn.Parameter(torch.randn(d, embdim) / math.sqrt(embdim/2))
        self.texture_codes.weight = nn.Parameter(torch.randn(d, embdim) / math.sqrt(embdim/2))
        self.shape_codes = self.shape_codes.to(self.device)
        self.texture_codes = self.texture_codes.to(self.device)

    def make_dataloader(self, num_instances_per_obj, crop_img):
        """Build the SRN training dataloader.

        The hpams 'data' entry supplies: cat ('srn_cars' or 'srn_chairs'), splits
        (e.g. 'cars_train'), and data_dir (root of ShapeNet-SRN).
        ``num_instances_per_obj`` is how many views are sampled per object.
        """
        cat = self.hpams['data']['cat']
        data_dir = self.hpams['data']['data_dir']
        splits = self.hpams['data']['splits']
        srn = SRN(cat=cat, splits=splits, data_dir = data_dir,
                  num_instances_per_obj = num_instances_per_obj, crop_img = crop_img)
        self.dataloader = DataLoader(srn, batch_size=1, num_workers =1)

    def make_savedir(self, save_dir):
        """Create ``exps/<save_dir>/runs``, open the TensorBoard writer and store the hpams."""
        self.save_dir = os.path.join('exps', save_dir)
        os.makedirs(os.path.join(self.save_dir, 'runs'), exist_ok=True)
        self.writer = SummaryWriter(os.path.join(self.save_dir, 'runs'))
        hpampath = os.path.join(self.save_dir, 'hpam.json')
        with open(hpampath, 'w') as f:
            json.dump(self.hpams, f, indent=2)

    def save_models(self, iter = None):
        """Save weights and codes to ``models.pth`` (and to ``<iter>.pth`` when ``iter`` is given)."""
        save_dict = {'model_params': self.model.state_dict(),
                     'shape_code_params': self.shape_codes.state_dict(),
                     'texture_code_params': self.texture_codes.state_dict(),
                     'niter': self.niter,
                     'nepoch' : self.nepoch
                     }
        if iter is not None:
            torch.save(save_dict, os.path.join(self.save_dir, str(iter) + '.pth'))
        torch.save(save_dict, os.path.join(self.save_dir, 'models.pth'))

Options

In [None]:
class options():
    """Notebook stand-in for the command-line arguments of the train/optimize scripts.

    Several values are stored as strings (``gpu``, ``iters_crop``, ``iters_all``,
    ``tgt_instances``) and are converted by the cells that consume them.
    """

    def __init__(self):
        defaults = dict(
            gpu='0',                # converted with int() by the consuming cells
            saved_dir="/content/gdrive/MyDrive/Colab Notebooks/NeRF/CodeNeRF/saved/",
            tgt_instances="1",      # instance ids for the optimize cells
            splits="train",
            num_opts=200,
            lr=1e-2,
            lr_half_interval=50,
            save_img=True,
            jsonfile="/content/gdrive/MyDrive/Colab Notebooks/NeRF/CodeNeRF/jsonfiles/srncar.json",
            batchsize=2048,
            iters_crop="300000",    # converted with int() by the train cell
            iters_all="360000",     # converted with int() by the train cell
            num_instances_per_obj=2,
        )
        for key, value in defaults.items():
            setattr(self, key, value)


args = options()


Optimize.py

In [None]:
import sys, os

# Test-time optimization on the SRN config. Note: this cell overrides
# `args.jsonfile` and `args.splits` on the shared `args` object.
args.jsonfile="/content/gdrive/MyDrive/Colab Notebooks/NeRF/CodeNeRF/jsonfiles/srn.json"
args.splits="test"

saved_dir = args.saved_dir
gpu = int(args.gpu)
lr = float(args.lr)
lr_half_interval = int(args.lr_half_interval)
save_img = args.save_img
batchsize = int(args.batchsize)
num_opts = int(args.num_opts)

# tgt_instances may be a string ("1", "3,12", "3 12"), a single int, or a list.
# list("12") would wrongly split a string into characters -> [1, 2].
raw_instances = args.tgt_instances
if isinstance(raw_instances, int):
    raw_instances = [raw_instances]
elif isinstance(raw_instances, str):
    raw_instances = raw_instances.replace(',', ' ').split()
tgt_instances = [int(tok) for tok in raw_instances]
print("tgt_instances:", tgt_instances)
print("save_img:", save_img)

optimizer = Optimizer(saved_dir, gpu, tgt_instances, args.splits, args.jsonfile, batchsize, num_opts)
optimizer.optimize_objs(tgt_instances, lr, lr_half_interval, save_img)

save_img: ['1']
save_img: True
DEBUG:  <_io.TextIOWrapper name='/content/gdrive/MyDrive/Colab Notebooks/NeRF/CodeNeRF/jsonfiles/dtu.json' mode='r' encoding='UTF-8'>
Load Model:  /content/gdrive/MyDrive/Colab Notebooks/NeRF/CodeNeRF/saved/models.pth
DEBUG:  /content/gdrive/MyDrive/Colab Notebooks/NeRF/CodeNeRF/saved/
DEBUG:  [1]
we are going to save at  /content/gdrive/MyDrive/Colab Notebooks/NeRF/CodeNeRF/saved/test_32
DEBUG: instance_ids tensor([1])
psnr_eval:  {}
obj_idx:  0
PSNR:  16.27385238011467
psnr_eval:  {0: [16.27385238011467]}
obj_idx:  0
PSNR:  14.455274183257385
psnr_eval:  {0: [16.27385238011467, 14.455274183257385]}
obj_idx:  0
PSNR:  13.101883751800033
psnr_eval:  {0: [16.27385238011467, 14.455274183257385, 13.101883751800033]}
obj_idx:  0
PSNR:  13.064881365651027
psnr_eval:  {0: [16.27385238011467, 14.455274183257385, 13.101883751800033, 13.064881365651027]}
obj_idx:  0
PSNR:  12.636859384637708
psnr_eval:  {0: [16.27385238011467, 14.455274183257385, 13.10188375180003

KeyboardInterrupt: ignored

Train.py

In [None]:
import sys, os

# Train on DTU. To train on SRN cars instead, point args.jsonfile at
# "/content/gdrive/MyDrive/Colab Notebooks/NeRF/CodeNeRF/jsonfiles/srncar.json".
args.jsonfile="/content/gdrive/MyDrive/Colab Notebooks/NeRF/CodeNeRF/jsonfiles/dtu.json"
save_dir = args.saved_dir
gpu = int(args.gpu)
iters_crop = int(args.iters_crop)
iters_all = int(args.iters_all)
B = int(args.batchsize)
num_instances_per_obj = int(args.num_instances_per_obj)
trainer = Trainer(save_dir, gpu, jsonfile = args.jsonfile, batch_size = B)
# Pass the configured value. Previously it was parsed above and then ignored
# because the call hard-coded num_instances_per_obj=1.
trainer.training(iters_crop, iters_all, num_instances_per_obj=num_instances_per_obj)

25.55937952976159 0
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
27

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


2924


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


2925
2926


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():


2927


  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


2928
2929
2930
2931
2932
2933
2934
2935
2936
2937
2938
2939
2940
2941
2942
2943
2944
2945
2946
2947
2948
2949
2950
2951
2952
2953
2954
2955
2956
2957
2958
2959
2960
2961
2962
2963
2964
2965


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


2966
2967
2968


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


2969


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/pytho

2970
2971
2972
2973
2974
2975
2976
2977
2978
2979
2980
2981
2982
2983
2984
2985
2986
2987
2988
2989
2990
2991
2992
2993
2994
2995
2996
2997
2998
2999
3000
3001
3002
3003
3004
3005
3006
3007
3008
3009
3010
3011
3012
3013
3014
3015
3016
3017
3018
3019
3020
3021
3022
3023
3024
3025
3026
3027
3028
3029
3030
3031


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


3032
3033
3034
3035
3036


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


3037
3038


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


3039
3040


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


3041
3042
3043
3044
3045
3046
3047
3048
3049
3050
3051
3052
3053
3054
3055
3056
3057
3058
3059
3060
3061
3062
3063
3064
3065
3066
3067
3068
3069
3070
3071
3072
3073
3074
3075
3076
3077
3078
3079
3080
3081
3082
3083
3084


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


3085
3086
3087
3088
3089
3090


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


3091


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


3092


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


3093
3094
3095
3096
3097
3098
3099
3100
3101
3102
3103
3104
3105
3106
3107
3108
3109
3110
3111
3112
3113
3114
3115
3116
3117
3118
3119
3120
3121
3122
3123
3124
3125
3126
3127
3128
3129
3130
3131
3132
3133
3134
3135
3136
3137
3138
3139
3140
3141
3142
3143
3144
3145
3146
3147
3148
3149
3150
3151
3152
3153
3154
3155
3156
3157
3158
3159
3160
3161
3162
3163


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/pytho

3164
3165


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


3166
3167
3168
3169
3170


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


3171
3172
3173
3174
3175
3176
3177
3178
3179
3180
3181
3182
3183
3184
3185
3186
3187
3188
3189
3190
3191
3192
3193
3194
3195
3196
3197
3198
3199
3200
3201
3202
3203
3204
3205
3206
3207
3208
3209
3210
3211
3212
3213
3214
3215
3216
3217
3218
3219


Exception ignored in: Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550><function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>
Traceback (most recent call last):

Traceback (most recent call last):

Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
        self._shutdown_workers()self._shutdown_workers()        

self._shutdown_workers()se

3220
3221
3222
3223
3224
3225
3226
3227
3228
3229
3230
3231
3232
3233
3234
3235
3236
3237
3238
3239
3240
3241
3242
3243
3244
3245
3246
3247
3248
3249
3250
3251
3252
3253
3254
3255
3256
3257
3258
3259
3260
3261
3262
3263
3264
3265
3266
3267
3268
3269
3270
3271
3272
3273
3274
3275
3276
3277
3278
3279
3280
3281
3282
3283
3284
3285
3286
3287
3288
3289


Exception ignored in: Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550><function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550><function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>

Traceback (most recent call last):

Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1510, in __del__
        self._shutdown_workers()    Exception ignored in: self._shutdown_workers()
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f28b6d08550>self._shutdown_workers()  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1493, in _shutdown_workers



  File "/usr/local/lib/python3

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
136718
136719
136720
136721
136722
136723
136724
136725
136726
136727
136728
136729
136730
136731
136732
136733
136734
136735
136736
136737
136738
136739
136740
136741
136742
136743
136744
136745
136746
136747
136748
136749
136750
136751
136752
136753
136754
136755
136756
136757
136758
136759
136760
136761
136762
136763
136764
136765
136766
136767
136768
136769
136770
136771
136772
136773
136774
136775
136776
136777
136778
136779
136780
136781
136782
136783
136784
136785
136786
136787
136788
136789
136790
136791
136792
136793
136794
136795
136796
136797
136798
136799
136800
136801
136802
136803
136804
136805
136806
136807
136808
136809
136810
136811
136812
136813
136814
136815
136816
136817
136818
136819
136820
136821
136822
136823
136824
136825
136826
136827
136828
136829
136830
136831
136832
136833
136834
136835
136836
136837
136838
136839
136840
136841
136842
136843
136844
136845
136846
136847
136848
136849
136850
1368

In [None]:
import sys, os

# Test-time optimization on the DTU config. Note: this cell overrides
# `args.jsonfile` and `args.splits` on the shared `args` object.
args.jsonfile="/content/gdrive/MyDrive/Colab Notebooks/NeRF/CodeNeRF/jsonfiles/dtu.json"
args.splits="test"

saved_dir = args.saved_dir
gpu = int(args.gpu)
lr = float(args.lr)
lr_half_interval = int(args.lr_half_interval)
save_img = args.save_img
batchsize = int(args.batchsize)
num_opts = int(args.num_opts)

# tgt_instances may be a string ("1", "3,12", "3 12"), a single int, or a list.
# list("12") would wrongly split a string into characters -> [1, 2].
raw_instances = args.tgt_instances
if isinstance(raw_instances, int):
    raw_instances = [raw_instances]
elif isinstance(raw_instances, str):
    raw_instances = raw_instances.replace(',', ' ').split()
tgt_instances = [int(tok) for tok in raw_instances]
print("tgt_instances:", tgt_instances)
print("save_img:", save_img)

optimizer = Optimizer(saved_dir, gpu, tgt_instances, args.splits, args.jsonfile, batchsize, num_opts)
optimizer.optimize_objs(tgt_instances, lr, lr_half_interval, save_img)

save_img: ['1']
save_img: True
DEBUG:  <_io.TextIOWrapper name='/content/gdrive/MyDrive/Colab Notebooks/NeRF/CodeNeRF/jsonfiles/dtu.json' mode='r' encoding='UTF-8'>
Load Model:  /content/gdrive/MyDrive/Colab Notebooks/NeRF/CodeNeRF/saved/models.pth
DEBUG:  /content/gdrive/MyDrive/Colab Notebooks/NeRF/CodeNeRF/saved/
DEBUG:  [1]
we are going to save at  /content/gdrive/MyDrive/Colab Notebooks/NeRF/CodeNeRF/saved/test_37
DEBUG: instance_ids tensor([1])
psnr_eval:  {}
obj_idx:  0
PSNR:  15.28524250344347
psnr_eval:  {0: [15.28524250344347]}
obj_idx:  0
PSNR:  14.354758307486856
psnr_eval:  {0: [15.28524250344347, 14.354758307486856]}
obj_idx:  0
PSNR:  13.217992404576275
psnr_eval:  {0: [15.28524250344347, 14.354758307486856, 13.217992404576275]}
obj_idx:  0
PSNR:  11.799426650620626
psnr_eval:  {0: [15.28524250344347, 14.354758307486856, 13.217992404576275, 11.799426650620626]}
obj_idx:  0
PSNR:  11.393426771052154
psnr_eval:  {0: [15.28524250344347, 14.354758307486856, 13.21799240457627